diff --git a/app/main.py b/app/main.py index cbccede..4c87cc5 100644 --- a/app/main.py +++ b/app/main.py @@ -18,7 +18,18 @@ ConductorStatus, ) from app.scheduler import Scheduler, SchedulingStrategy -from app.providers import CerebrasProvider, NvidiaProvider +from app.providers import ( + CerebrasProvider, + NvidiaProvider, + GroqProvider, + GeminiProvider, + MistralProvider, + OpenRouterProvider, + DeepSeekProvider, + HuggingFaceProvider, + CohereProvider, + SambaNovaProvider, +) from app.providers.base import ProviderKey from app.rate_limiter import RateLimitBucket from app.models_mapping import ModelMapper, DEFAULT_MODEL @@ -52,91 +63,140 @@ scheduler: Optional[Scheduler] = None model_mapper: Optional[ModelMapper] = None +# Provider definition table — add new providers here only +PROVIDER_DEFINITIONS = [ + { + "name": "cerebras", + "class": CerebrasProvider, + "default_base_url": "https://api.cerebras.ai/v1", + "default_rpm": 1000, + "default_tpm": 1_000_000, + }, + { + "name": "nvidia", + "class": NvidiaProvider, + "default_base_url": "https://integrate.api.nvidia.com/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "groq", + "class": GroqProvider, + "default_base_url": "https://api.groq.com/openai/v1", + "default_rpm": 30, + "default_tpm": 6_000, + }, + { + "name": "gemini", + "class": GeminiProvider, + "default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai", + "default_rpm": 15, + "default_tpm": 1_000_000, + }, + { + "name": "mistral", + "class": MistralProvider, + "default_base_url": "https://api.mistral.ai/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "openrouter", + "class": OpenRouterProvider, + "default_base_url": "https://openrouter.ai/api/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "deepseek", + "class": DeepSeekProvider, + "default_base_url": "https://api.deepseek.com/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "huggingface", + 
"class": HuggingFaceProvider, + "default_base_url": "https://api-inference.huggingface.co/v1", + "default_rpm": 30, + "default_tpm": 50_000, + }, + { + "name": "cohere", + "class": CohereProvider, + "default_base_url": "https://api.cohere.com/compatibility/v1", + "default_rpm": 20, + "default_tpm": 100_000, + }, + { + "name": "sambanova", + "class": SambaNovaProvider, + "default_base_url": "https://api.sambanova.ai/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, +] + async def initialize_scheduler(config: dict) -> Scheduler: """Initialize the scheduler with providers from config.""" global model_mapper - + conductor_config = config.get("conductor", {}) strategy = SchedulingStrategy( conductor_config.get("scheduling_strategy", "round_robin") ) - + # Initialize model mapper with custom mappings from config custom_models = config.get("models", {}) model_mapper = ModelMapper(custom_models) - + await logger.ainfo( "Model mapper initialized", available_models=model_mapper.get_available_models() ) - + sched = Scheduler(strategy=strategy) - providers_config = config.get("providers", {}) - - # Initialize Cerebras - cerebras_config = providers_config.get("cerebras", {}) - if cerebras_config.get("enabled", False) and cerebras_config.get("keys"): - provider = CerebrasProvider( - base_url=cerebras_config.get("base_url", "https://api.cerebras.ai/v1"), - model_mapper=model_mapper, - ) - - for i, key_config in enumerate(cerebras_config.get("keys", [])): - key_name = key_config.get("name", f"cerebras-key-{i+1}") - bucket = RateLimitBucket( - name=f"cerebras:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 1000), - tokens_per_minute=key_config.get("tokens_per_minute", 1_000_000), - ) - provider_key = ProviderKey( - provider_name="cerebras", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - - if provider.keys: - await sched.add_provider(provider) - await 
logger.ainfo( - "Cerebras provider initialized", - keys_count=len(provider.keys) - ) - - # Initialize NVIDIA NIM - nvidia_config = providers_config.get("nvidia", {}) - if nvidia_config.get("enabled", False) and nvidia_config.get("keys"): - provider = NvidiaProvider( - base_url=nvidia_config.get("base_url", "https://integrate.api.nvidia.com/v1"), + + # Initialize all providers from the definition table + for definition in PROVIDER_DEFINITIONS: + name = definition["name"] + provider_config = providers_config.get(name, {}) + + if not provider_config.get("enabled", False): + continue + if not provider_config.get("keys"): + continue + + provider = definition["class"]( + base_url=provider_config.get("base_url", definition["default_base_url"]), model_mapper=model_mapper, ) - - for i, key_config in enumerate(nvidia_config.get("keys", [])): - key_name = key_config.get("name", f"nvidia-key-{i+1}") + + for i, key_config in enumerate(provider_config.get("keys", [])): + key_name = key_config.get("name", f"{name}-key-{i+1}") bucket = RateLimitBucket( - name=f"nvidia:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + name=f"{name}:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", definition["default_rpm"]), + tokens_per_minute=key_config.get("tokens_per_minute", definition["default_tpm"]), ) provider_key = ProviderKey( - provider_name="nvidia", + provider_name=name, key_name=key_name, api_key=key_config["api_key"], bucket=bucket, base_url=provider.base_url, ) provider.add_key(provider_key) - + if provider.keys: await sched.add_provider(provider) await logger.ainfo( - "NVIDIA NIM provider initialized", + f"{name.upper()} provider initialized", keys_count=len(provider.keys) ) - + await sched.start() return sched @@ -145,29 +205,26 @@ async def initialize_scheduler(config: dict) -> Scheduler: async def lifespan(app: FastAPI): """Application lifespan manager.""" 
global scheduler - + await logger.ainfo("Starting TrainForgeConductor...") - - # Load configuration + config = load_config(settings.config_path) await logger.ainfo("Configuration loaded", config_path=settings.config_path) - - # Initialize scheduler + scheduler = await initialize_scheduler(config) - + if not scheduler.providers: await logger.awarning( "No providers configured! Add API keys to config/config.yaml" ) - + await logger.ainfo( "TrainForgeConductor ready", providers=list(scheduler.providers.keys()), ) - + yield - - # Shutdown + await logger.ainfo("Shutting down TrainForgeConductor...") if scheduler: await scheduler.stop() @@ -202,10 +259,9 @@ async def get_status(): """Get conductor status including all provider rate limits.""" if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + status = await scheduler.get_status() - - # Convert to response model + from app.models import ProviderStatus provider_statuses = [ ProviderStatus( @@ -222,7 +278,7 @@ async def get_status(): ) for p in status["providers"] ] - + return ConductorStatus( status=status["status"], total_providers=status["total_providers"], @@ -238,36 +294,33 @@ async def get_status(): async def chat_completion(request: ChatCompletionRequest): """ OpenAI-compatible chat completion endpoint. - + Use unified model names like "llama-70b" or "llama-8b" - the conductor will automatically translate to the correct provider-specific name. - + The conductor routes requests to available providers based on rate limits and the configured scheduling strategy. - + Optional fields: - model: Model to use (default: llama-70b). Use unified names. - provider: Force a specific provider (e.g., "cerebras" or "nvidia") - priority: Request priority (0-10, higher = more priority) - auto_retry: Automatically retry on failures with exponential backoff (default: true). 
- When enabled, the conductor absorbs transient errors and keeps retrying until it gets - a response or exhausts max_retries, so the caller always gets an answer. - max_retries: Maximum number of retry attempts when auto_retry is enabled (default: 10). """ if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + if not scheduler.providers: raise HTTPException( - status_code=503, + status_code=503, detail="No providers configured. Add API keys to config/config.yaml" ) - + try: response = await scheduler.submit(request, wait=True) return response except MaxRetriesExhaustedError as e: - # Auto-retry exhausted all attempts — return full retry log await logger.aerror( "Auto-retry exhausted", total_attempts=e.total_attempts, @@ -285,7 +338,6 @@ async def chat_completion(request: ChatCompletionRequest): }, ) except AllProvidersExhaustedError as e: - # All providers failed (auto_retry=False path) - return detailed error await logger.aerror( "All providers exhausted", error_count=len(e.errors), @@ -338,56 +390,46 @@ async def chat_completion(request: ChatCompletionRequest): async def batch_chat_completion(batch: BatchRequest): """ Submit a batch of chat completion requests. - + All requests will be scheduled across available providers to maximize throughput while respecting rate limits. """ if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + if not scheduler.providers: raise HTTPException( status_code=503, detail="No providers configured. 
Add API keys to config/config.yaml" ) - + start_time = time.time() - - # Create tasks for all requests + tasks = [ scheduler.submit(req, wait=True) for req in batch.requests ] - + responses: list[ChatCompletionResponse] = [] failed: list[dict] = [] - + if batch.wait_for_all: - # Wait for all to complete, collect results results = await asyncio.gather(*tasks, return_exceptions=True) - for i, result in enumerate(results): if isinstance(result, Exception): - failed.append({ - "index": i, - "error": str(result), - }) + failed.append({"index": i, "error": str(result)}) else: responses.append(result) else: - # Return as they complete for i, coro in enumerate(asyncio.as_completed(tasks)): try: result = await coro responses.append(result) except Exception as e: - failed.append({ - "index": i, - "error": str(e), - }) - + failed.append({"index": i, "error": str(e)}) + elapsed_ms = (time.time() - start_time) * 1000 - + return BatchResponse( responses=responses, failed=failed, @@ -399,21 +441,20 @@ async def batch_chat_completion(batch: BatchRequest): async def list_models(): """ List all available unified model names. - + These are the model names you can use in requests. The conductor automatically translates them to provider-specific names. """ if not model_mapper: raise HTTPException(status_code=503, detail="Model mapper not initialized") - - # Return unified model names + models = [] for unified_name in model_mapper.get_available_models(): models.append({ "id": unified_name, "object": "model", }) - + return { "data": models, "object": "list", @@ -433,4 +474,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/app/models_mapping.py b/app/models_mapping.py index a49dd7e..44dc3af 100644 --- a/app/models_mapping.py +++ b/app/models_mapping.py @@ -1,87 +1,168 @@ """ Unified model name mapping. 
- This allows you to use simple names like "llama-70b" and the conductor will automatically translate to the correct provider-specific name. """ -# Default model mappings: unified name -> provider-specific names DEFAULT_MODEL_MAPPING = { - # Llama 3.3 70B - the flagship model + # Llama 3.3 70B "llama-70b": { "cerebras": "llama-3.3-70b", "nvidia": "meta/llama-3.3-70b-instruct", + "groq": "llama-3.3-70b-versatile", + "openrouter": "meta-llama/llama-3.3-70b-instruct", + "sambanova": "Meta-Llama-3.3-70B-Instruct", + "huggingface": "meta-llama/Llama-3.3-70B-Instruct", }, "llama-3.3-70b": { "cerebras": "llama-3.3-70b", "nvidia": "meta/llama-3.3-70b-instruct", + "groq": "llama-3.3-70b-versatile", + "openrouter": "meta-llama/llama-3.3-70b-instruct", + "sambanova": "Meta-Llama-3.3-70B-Instruct", + "huggingface": "meta-llama/Llama-3.3-70B-Instruct", }, - - # Llama 3.1 8B - fast and cheap + # Llama 3.1 8B "llama-8b": { "cerebras": "llama3.1-8b", "nvidia": "meta/llama-3.1-8b-instruct", + "groq": "llama-3.1-8b-instant", + "openrouter": "meta-llama/llama-3.1-8b-instruct", + "sambanova": "Meta-Llama-3.1-8B-Instruct", + "huggingface": "meta-llama/Llama-3.1-8B-Instruct", }, "llama-3.1-8b": { "cerebras": "llama3.1-8b", "nvidia": "meta/llama-3.1-8b-instruct", + "groq": "llama-3.1-8b-instant", + "openrouter": "meta-llama/llama-3.1-8b-instruct", + "sambanova": "Meta-Llama-3.1-8B-Instruct", + "huggingface": "meta-llama/Llama-3.1-8B-Instruct", }, - # Llama 3.1 70B "llama-3.1-70b": { "cerebras": "llama-3.1-70b", "nvidia": "meta/llama-3.1-70b-instruct", + "groq": "llama-3.1-70b-versatile", + "openrouter": "meta-llama/llama-3.1-70b-instruct", + "sambanova": "Meta-Llama-3.1-70B-Instruct", + "huggingface": "meta-llama/Llama-3.1-70B-Instruct", + }, + # Llama 3.1 405B + "llama-405b": { + "nvidia": "meta/llama-3.1-405b-instruct", + "sambanova": "Meta-Llama-3.1-405B-Instruct", + "openrouter": "meta-llama/llama-3.1-405b-instruct", + }, + # Gemini models + "gemini-flash": { + "gemini": 
"gemini-2.0-flash", + "openrouter": "google/gemini-2.0-flash-exp:free", + }, + "gemini-2.0-flash": { + "gemini": "gemini-2.0-flash", + "openrouter": "google/gemini-2.0-flash-exp:free", + }, + "gemini-flash-lite": { + "gemini": "gemini-2.0-flash-lite", + }, + # Mistral models + "mistral-small": { + "mistral": "mistral-small-latest", + "openrouter": "mistralai/mistral-small", + }, + "mistral-7b": { + "mistral": "open-mistral-7b", + "openrouter": "mistralai/mistral-7b-instruct:free", + "huggingface": "mistralai/Mistral-7B-Instruct-v0.3", + }, + # DeepSeek models + "deepseek-chat": { + "deepseek": "deepseek-chat", + "openrouter": "deepseek/deepseek-chat:free", + }, + "deepseek-coder": { + "deepseek": "deepseek-coder", + "openrouter": "deepseek/deepseek-coder:free", + }, + "deepseek-r1": { + "deepseek": "deepseek-reasoner", + "openrouter": "deepseek/deepseek-r1:free", + "sambanova": "DeepSeek-R1", + }, + # Cohere models + "command-r": { + "cohere": "command-r-08-2024", + "openrouter": "cohere/command-r", + }, + "command-r-plus": { + "cohere": "command-r-plus-08-2024", + "openrouter": "cohere/command-r-plus", + }, + # Groq specific + "mixtral-8x7b": { + "groq": "mixtral-8x7b-32768", + "openrouter": "mistralai/mixtral-8x7b-instruct", + }, + "gemma-7b": { + "groq": "gemma-7b-it", + "openrouter": "google/gemma-7b-it:free", + "huggingface": "google/gemma-7b-it", + }, + "gemma2-9b": { + "groq": "gemma2-9b-it", + "openrouter": "google/gemma-2-9b-it:free", + "huggingface": "google/gemma-2-9b-it", + }, + # Qwen models + "qwen-72b": { + "sambanova": "Qwen2.5-72B-Instruct", + "openrouter": "qwen/qwen-2.5-72b-instruct:free", + "huggingface": "Qwen/Qwen2.5-72B-Instruct", }, } -# Default model to use when none specified DEFAULT_MODEL = "llama-70b" class ModelMapper: """Maps unified model names to provider-specific names.""" - + def __init__(self, custom_mappings: dict = None): """ Initialize with optional custom mappings from config. 
- + Args: custom_mappings: Dict of {unified_name: {provider: provider_name}} """ self.mappings = DEFAULT_MODEL_MAPPING.copy() if custom_mappings: self.mappings.update(custom_mappings) - + def get_provider_model(self, unified_name: str, provider: str) -> str: """ Get the provider-specific model name. - + Args: unified_name: The unified model name (e.g., "llama-70b") provider: The provider name (e.g., "cerebras", "nvidia") - + Returns: Provider-specific model name """ if not unified_name: unified_name = DEFAULT_MODEL - - # Normalize the name name_lower = unified_name.lower().strip() - - # Check if it's in our mappings if name_lower in self.mappings: provider_models = self.mappings[name_lower] if provider in provider_models: return provider_models[provider] - - # If not found, return as-is (maybe it's already provider-specific) return unified_name - + def get_available_models(self) -> list[str]: """Get list of available unified model names.""" return list(self.mappings.keys()) - + def add_mapping(self, unified_name: str, provider_models: dict): """Add a custom model mapping.""" - self.mappings[unified_name.lower()] = provider_models + self.mappings[unified_name.lower()] = provider_models \ No newline at end of file diff --git a/app/providers/__init__.py b/app/providers/__init__.py index 9deb484..982c3c6 100644 --- a/app/providers/__init__.py +++ b/app/providers/__init__.py @@ -1,13 +1,27 @@ """Provider implementations for TrainForgeConductor.""" - from .base import BaseProvider, ProviderKey from .cerebras import CerebrasProvider from .nvidia import NvidiaProvider +from .groq import GroqProvider +from .gemini import GeminiProvider +from .mistral import MistralProvider +from .openrouter import OpenRouterProvider +from .deepseek import DeepSeekProvider +from .huggingface import HuggingFaceProvider +from .cohere import CohereProvider +from .sambanova import SambaNovaProvider __all__ = [ "BaseProvider", "ProviderKey", "CerebrasProvider", "NvidiaProvider", -] - + 
"GroqProvider", + "GeminiProvider", + "MistralProvider", + "OpenRouterProvider", + "DeepSeekProvider", + "HuggingFaceProvider", + "CohereProvider", + "SambaNovaProvider", +] \ No newline at end of file diff --git a/app/providers/cohere.py b/app/providers/cohere.py new file mode 100644 index 0000000..89672ca --- /dev/null +++ b/app/providers/cohere.py @@ -0,0 +1,162 @@ +"""Cohere provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class CohereProvider(BaseProvider): + """Cohere provider — free tier available.""" + + name = "cohere" + + def __init__(self, base_url: str = "https://api.cohere.com/compatibility/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Cohere API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Cohere", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + 
response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + return ChatCompletionResponse( + id=data.get("id", f"cohere-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Cohere API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Cohere server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except 
httpx.TimeoutException as e: + await logger.aerror("Cohere request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Cohere request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/deepseek.py b/app/providers/deepseek.py new file mode 100644 index 0000000..d5980fc --- /dev/null +++ b/app/providers/deepseek.py @@ -0,0 +1,166 @@ +"""DeepSeek provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class DeepSeekProvider(BaseProvider): + """DeepSeek provider — free tier with strong coding models.""" + + name = "deepseek" + + def __init__(self, base_url: str = "https://api.deepseek.com/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via DeepSeek API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to 
DeepSeek", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"deepseek-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "DeepSeek API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"DeepSeek server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" 
in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("DeepSeek request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("DeepSeek request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/gemini.py b/app/providers/gemini.py new file mode 100644 index 0000000..48977f9 --- /dev/null +++ b/app/providers/gemini.py @@ -0,0 +1,162 @@ +"""Google Gemini provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class GeminiProvider(BaseProvider): + """Google Gemini provider — free tier available.""" + + name = "gemini" + + def __init__(self, base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Gemini API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + 
"temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Gemini", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + return ChatCompletionResponse( + id=data.get("id", f"gemini-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Gemini API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Gemini server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = 
response_text.lower()
+                if "image" in response_lower or "vision" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="vision",
+                        message=f"Vision not supported: {response_text[:200]}",
+                    )
+                if "tool" in response_lower or "function" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="tool_calls",
+                        message=f"Tool calling error: {response_text[:200]}",
+                    )
+                raise CapabilityError(
+                    provider=self.name,
+                    capability="unknown",
+                    message=f"Request error: {response_text[:200]}",
+                )
+
+            raise
+
+        except httpx.TimeoutException as e:
+            await logger.aerror("Gemini request timeout", error=str(e))
+            raise ProviderUnavailableError(
+                provider=self.name,
+                status_code=504,
+                message=f"Request timeout: {str(e)}",
+            )
+        except (RateLimitError, CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("Gemini request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/app/providers/groq.py b/app/providers/groq.py
new file mode 100644
index 0000000..645f280
--- /dev/null
+++ b/app/providers/groq.py
@@ -0,0 +1,189 @@
+"""Groq provider implementation."""
+
+import time
+import httpx
+import structlog
+
+from app.models import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionChoice,
+    Message,
+    Usage,
+)
+from app.providers.base import BaseProvider, ProviderKey
+from app.models_mapping import ModelMapper
+from app.exceptions import (
+    RateLimitError,
+    CapabilityError,
+    ProviderUnavailableError,
+)
+
+logger = structlog.get_logger()
+
+
+class GroqProvider(BaseProvider):
+    """Groq provider — free tier, very fast inference."""
+
+    name = "groq"
+
+    def __init__(self, base_url: str = "https://api.groq.com/openai/v1", model_mapper: ModelMapper = None):
+        super().__init__(base_url, model_mapper)
+
+    async def chat_completion(
+        self,
+        key: ProviderKey,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        """Execute chat completion via Groq API.
+
+        Args:
+            key: Provider key supplying the bearer token and the rate-limit
+                bucket charged with the tokens the call consumes.
+            request: Chat completion request; its model name is translated
+                through ``get_model_name`` before it is sent.
+
+        Returns:
+            The first choice of the reply and its usage figures, tagged with
+            this provider's name and the key name.
+
+        Raises:
+            RateLimitError: The API answered 429.
+            ProviderUnavailableError: The API answered 5xx or timed out.
+            CapabilityError: The API answered 400.
+            httpx.HTTPStatusError: The API answered any other error status.
+        """
+
+        model = self.get_model_name(request.model)
+
+        payload = {
+            "model": model,
+            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
+            "temperature": request.temperature,
+            "max_tokens": request.max_tokens,
+            "top_p": request.top_p,
+        }
+
+        if request.stop:
+            payload["stop"] = request.stop
+
+        headers = {
+            "Authorization": f"Bearer {key.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        client = await self.get_client()
+
+        await logger.ainfo(
+            "Sending request to Groq",
+            model=model,
+            key_name=key.key_name,
+            messages_count=len(request.messages),
+        )
+
+        try:
+            response = await client.post(
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers=headers,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            choice = data["choices"][0]
+            usage = data.get("usage", {})
+
+            # Charge the tokens this call used to the key's rate-limit bucket.
+            total_tokens = usage.get("total_tokens", 0)
+            if total_tokens > 0:
+                await key.bucket.consume_tokens(total_tokens)
+
+            return ChatCompletionResponse(
+                id=data.get("id", f"groq-{int(time.time())}"),
+                created=data.get("created", int(time.time())),
+                model=model,
+                choices=[
+                    ChatCompletionChoice(
+                        index=0,
+                        message=Message(
+                            role="assistant",
+                            content=choice["message"]["content"],
+                        ),
+                        finish_reason=choice.get("finish_reason", "stop"),
+                    )
+                ],
+                usage=Usage(
+                    prompt_tokens=usage.get("prompt_tokens", 0),
+                    completion_tokens=usage.get("completion_tokens", 0),
+                    total_tokens=usage.get("total_tokens", 0),
+                ),
+                provider=self.name,
+                provider_key_name=key.key_name,
+            )
+
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code
+            response_text = e.response.text
+
+            await logger.aerror(
+                "Groq API error",
+                status_code=status_code,
+                response=response_text,
+                key_name=key.key_name,
+            )
+
+            if status_code == 429:
+                retry_after = e.response.headers.get("retry-after")
+                try:
+                    retry_seconds = float(retry_after) if retry_after else None
+                except ValueError:
+                    # Retry-After may also be an HTTP-date; fall back to no hint
+                    # rather than masking the 429 with a ValueError.
+                    retry_seconds = None
+                raise RateLimitError(
+                    provider=self.name,
+                    retry_after=retry_seconds,
+                    message=f"Rate limit exceeded: {response_text[:200]}",
+                )
+
+            if 500 <= status_code < 600:
+                raise ProviderUnavailableError(
+                    provider=self.name,
+                    status_code=status_code,
+                    message=f"Groq server error: {response_text[:200]}",
+                )
+
+            if status_code == 400:
+                response_lower = response_text.lower()
+                if "image" in response_lower or "vision" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="vision",
+                        message=f"Vision not supported: {response_text[:200]}",
+                    )
+                if "tool" in response_lower or "function" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="tool_calls",
+                        message=f"Tool calling error: {response_text[:200]}",
+                    )
+                raise CapabilityError(
+                    provider=self.name,
+                    capability="unknown",
+                    message=f"Request error: {response_text[:200]}",
+                )
+
+            raise
+
+        except httpx.TimeoutException as e:
+            await logger.aerror("Groq request timeout", error=str(e))
+            raise ProviderUnavailableError(
+                provider=self.name,
+                status_code=504,
+                message=f"Request timeout: {str(e)}",
+            )
+        except (RateLimitError, CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("Groq request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/app/providers/huggingface.py b/app/providers/huggingface.py
new file mode 100644
index 0000000..f10dc54
--- /dev/null
+++ b/app/providers/huggingface.py
@@ -0,0 +1,166 @@
+"""Hugging Face provider implementation."""
+
+import time
+import httpx
+import structlog
+
+from app.models import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionChoice,
+    Message,
+    Usage,
+)
+from app.providers.base import BaseProvider, ProviderKey
+from app.models_mapping import ModelMapper
+from app.exceptions import (
+    RateLimitError,
+    CapabilityError,
+    ProviderUnavailableError,
+)
+
+logger = structlog.get_logger()
+
+
+class HuggingFaceProvider(BaseProvider):
+    """Hugging Face Inference API provider — free tier available."""
+
+    name = "huggingface"
+
+    def __init__(self, base_url: str =
"https://api-inference.huggingface.co/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Hugging Face Inference API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Hugging Face", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"huggingface-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Hugging Face 
API error",
+                status_code=status_code,
+                response=response_text,
+                key_name=key.key_name,
+            )
+
+            if status_code == 429:
+                retry_after = e.response.headers.get("retry-after")
+                retry_seconds = float(retry_after) if retry_after else None
+                raise RateLimitError(
+                    provider=self.name,
+                    retry_after=retry_seconds,
+                    message=f"Rate limit exceeded: {response_text[:200]}",
+                )
+
+            if 500 <= status_code < 600:
+                raise ProviderUnavailableError(
+                    provider=self.name,
+                    status_code=status_code,
+                    message=f"Hugging Face server error: {response_text[:200]}",
+                )
+
+            if status_code == 400:
+                response_lower = response_text.lower()
+                if "image" in response_lower or "vision" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="vision",
+                        message=f"Vision not supported: {response_text[:200]}",
+                    )
+                if "tool" in response_lower or "function" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="tool_calls",
+                        message=f"Tool calling error: {response_text[:200]}",
+                    )
+                raise CapabilityError(
+                    provider=self.name,
+                    capability="unknown",
+                    message=f"Request error: {response_text[:200]}",
+                )
+
+            raise
+
+        except httpx.TimeoutException as e:
+            await logger.aerror("Hugging Face request timeout", error=str(e))
+            raise ProviderUnavailableError(
+                provider=self.name,
+                status_code=504,
+                message=f"Request timeout: {str(e)}",
+            )
+        except (RateLimitError, CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("Hugging Face request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/app/providers/mistral.py b/app/providers/mistral.py
new file mode 100644
index 0000000..722c7f1
--- /dev/null
+++ b/app/providers/mistral.py
@@ -0,0 +1,189 @@
+"""Mistral provider implementation."""
+
+import time
+import httpx
+import structlog
+
+from app.models import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionChoice,
+    Message,
+    Usage,
+)
+from app.providers.base import BaseProvider, ProviderKey
+from app.models_mapping import ModelMapper
+from app.exceptions import (
+    RateLimitError,
+    CapabilityError,
+    ProviderUnavailableError,
+)
+
+logger = structlog.get_logger()
+
+
+class MistralProvider(BaseProvider):
+    """Mistral AI provider — free tier available."""
+
+    name = "mistral"
+
+    def __init__(self, base_url: str = "https://api.mistral.ai/v1", model_mapper: ModelMapper = None):
+        super().__init__(base_url, model_mapper)
+
+    async def chat_completion(
+        self,
+        key: ProviderKey,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        """Execute chat completion via Mistral API.
+
+        Args:
+            key: Provider key supplying the bearer token and the rate-limit
+                bucket charged with the tokens the call consumes.
+            request: Chat completion request; its model name is translated
+                through ``get_model_name`` before it is sent.
+
+        Returns:
+            The first choice of the reply and its usage figures, tagged with
+            this provider's name and the key name.
+
+        Raises:
+            RateLimitError: The API answered 429.
+            ProviderUnavailableError: The API answered 5xx or timed out.
+            CapabilityError: The API answered 400.
+            httpx.HTTPStatusError: The API answered any other error status.
+        """
+
+        model = self.get_model_name(request.model)
+
+        payload = {
+            "model": model,
+            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
+            "temperature": request.temperature,
+            "max_tokens": request.max_tokens,
+            "top_p": request.top_p,
+        }
+
+        if request.stop:
+            payload["stop"] = request.stop
+
+        headers = {
+            "Authorization": f"Bearer {key.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        client = await self.get_client()
+
+        await logger.ainfo(
+            "Sending request to Mistral",
+            model=model,
+            key_name=key.key_name,
+            messages_count=len(request.messages),
+        )
+
+        try:
+            response = await client.post(
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers=headers,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            choice = data["choices"][0]
+            usage = data.get("usage", {})
+
+            # Charge the tokens this call used to the key's rate-limit bucket.
+            total_tokens = usage.get("total_tokens", 0)
+            if total_tokens > 0:
+                await key.bucket.consume_tokens(total_tokens)
+
+            return ChatCompletionResponse(
+                id=data.get("id", f"mistral-{int(time.time())}"),
+                created=data.get("created", int(time.time())),
+                model=model,
+                choices=[
+                    ChatCompletionChoice(
+                        index=0,
+                        message=Message(
+                            role="assistant",
+                            content=choice["message"]["content"],
+                        ),
+                        finish_reason=choice.get("finish_reason", "stop"),
+                    )
+                ],
+                usage=Usage(
+                    prompt_tokens=usage.get("prompt_tokens", 0),
+                    completion_tokens=usage.get("completion_tokens", 0),
+                    total_tokens=usage.get("total_tokens", 0),
+                ),
+                provider=self.name,
+                provider_key_name=key.key_name,
+            )
+
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code
+            response_text = e.response.text
+
+            await logger.aerror(
+                "Mistral API error",
+                status_code=status_code,
+                response=response_text,
+                key_name=key.key_name,
+            )
+
+            if status_code == 429:
+                retry_after = e.response.headers.get("retry-after")
+                try:
+                    retry_seconds = float(retry_after) if retry_after else None
+                except ValueError:
+                    # Retry-After may also be an HTTP-date; fall back to no hint
+                    # rather than masking the 429 with a ValueError.
+                    retry_seconds = None
+                raise RateLimitError(
+                    provider=self.name,
+                    retry_after=retry_seconds,
+                    message=f"Rate limit exceeded: {response_text[:200]}",
+                )
+
+            if 500 <= status_code < 600:
+                raise ProviderUnavailableError(
+                    provider=self.name,
+                    status_code=status_code,
+                    message=f"Mistral server error: {response_text[:200]}",
+                )
+
+            if status_code == 400:
+                response_lower = response_text.lower()
+                if "image" in response_lower or "vision" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="vision",
+                        message=f"Vision not supported: {response_text[:200]}",
+                    )
+                if "tool" in response_lower or "function" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="tool_calls",
+                        message=f"Tool calling error: {response_text[:200]}",
+                    )
+                raise CapabilityError(
+                    provider=self.name,
+                    capability="unknown",
+                    message=f"Request error: {response_text[:200]}",
+                )
+
+            raise
+
+        except httpx.TimeoutException as e:
+            await logger.aerror("Mistral request timeout", error=str(e))
+            raise ProviderUnavailableError(
+                provider=self.name,
+                status_code=504,
+                message=f"Request timeout: {str(e)}",
+            )
+        except (RateLimitError, CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("Mistral request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/app/providers/nvidia.py b/app/providers/nvidia.py
index
53c3541..5a50135 100644 --- a/app/providers/nvidia.py +++ b/app/providers/nvidia.py @@ -24,23 +24,21 @@ class NvidiaProvider(BaseProvider): """NVIDIA NIM provider (via build.nvidia.com or self-hosted).""" - + name = "nvidia" - + def __init__(self, base_url: str = "https://integrate.api.nvidia.com/v1", model_mapper: ModelMapper = None): super().__init__(base_url, model_mapper) - + async def chat_completion( self, key: ProviderKey, request: ChatCompletionRequest, ) -> ChatCompletionResponse: """Execute chat completion via NVIDIA NIM API.""" - - # Translate unified model name to NVIDIA-specific name + model = self.get_model_name(request.model) - - # Prepare request payload (OpenAI-compatible) + payload = { "model": model, "messages": [{"role": m.role, "content": m.content} for m in request.messages], @@ -48,24 +46,24 @@ async def chat_completion( "max_tokens": request.max_tokens, "top_p": request.top_p, } - + if request.stop: payload["stop"] = request.stop - + headers = { "Authorization": f"Bearer {key.api_key}", "Content-Type": "application/json", } - + client = await self.get_client() - + await logger.ainfo( "Sending request to NVIDIA NIM", model=model, key_name=key.key_name, messages_count=len(request.messages), ) - + try: response = await client.post( f"{self.base_url}/chat/completions", @@ -74,16 +72,14 @@ async def chat_completion( ) response.raise_for_status() data = response.json() - - # Parse response + choice = data["choices"][0] usage = data.get("usage", {}) - - # Update token consumption + total_tokens = usage.get("total_tokens", 0) if total_tokens > 0: await key.bucket.consume_tokens(total_tokens) - + return ChatCompletionResponse( id=data.get("id", f"nvidia-{int(time.time())}"), created=data.get("created", int(time.time())), @@ -106,21 +102,19 @@ async def chat_completion( provider=self.name, provider_key_name=key.key_name, ) - + except httpx.HTTPStatusError as e: status_code = e.response.status_code response_text = e.response.text - + await 
logger.aerror( "NVIDIA NIM API error", status_code=status_code, response=response_text, key_name=key.key_name, ) - - # Handle rate limit (429) + if status_code == 429: - # Try to extract retry-after header retry_after = e.response.headers.get("retry-after") retry_seconds = float(retry_after) if retry_after else None raise RateLimitError( @@ -128,18 +122,15 @@ async def chat_completion( retry_after=retry_seconds, message=f"Rate limit exceeded: {response_text[:200]}", ) - - # Handle server errors (5xx) - provider temporarily unavailable + if 500 <= status_code < 600: raise ProviderUnavailableError( provider=self.name, status_code=status_code, message=f"NVIDIA server error: {response_text[:200]}", ) - - # Handle capability/validation errors (400) + if status_code == 400: - # Check for known capability issues response_lower = response_text.lower() if "tool" in response_lower or "function" in response_lower: raise CapabilityError( @@ -153,16 +144,14 @@ async def chat_completion( capability="vision", message=f"Vision/image error: {response_text[:200]}", ) - # Generic capability error raise CapabilityError( provider=self.name, capability="unknown", message=f"Request error: {response_text[:200]}", ) - - # Re-raise other errors + raise - + except httpx.TimeoutException as e: await logger.aerror("NVIDIA NIM request timeout", error=str(e)) raise ProviderUnavailableError( @@ -171,8 +160,7 @@ async def chat_completion( message=f"Request timeout: {str(e)}", ) except (RateLimitError, CapabilityError, ProviderUnavailableError): - # Re-raise our custom exceptions raise except Exception as e: await logger.aerror("NVIDIA NIM request failed", error=str(e)) - raise + raise \ No newline at end of file diff --git a/app/providers/openrouter.py b/app/providers/openrouter.py new file mode 100644 index 0000000..09af140 --- /dev/null +++ b/app/providers/openrouter.py @@ -0,0 +1,168 @@ +"""OpenRouter provider implementation.""" + +import time +import httpx +import structlog + +from app.models 
import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class OpenRouterProvider(BaseProvider): + """OpenRouter provider — access to many free models.""" + + name = "openrouter" + + def __init__(self, base_url: str = "https://openrouter.ai/api/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via OpenRouter API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + "HTTP-Referer": "https://github.com/Alcray/TrainForgeConductor", + "X-Title": "TrainForgeConductor", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to OpenRouter", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"openrouter-{int(time.time())}"), + created=data.get("created", int(time.time())), + 
model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "OpenRouter API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"OpenRouter server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("OpenRouter request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, 
CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("OpenRouter request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/app/providers/sambanova.py b/app/providers/sambanova.py
new file mode 100644
index 0000000..62659a6
--- /dev/null
+++ b/app/providers/sambanova.py
@@ -0,0 +1,189 @@
+"""SambaNova provider implementation."""
+
+import time
+import httpx
+import structlog
+
+from app.models import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionChoice,
+    Message,
+    Usage,
+)
+from app.providers.base import BaseProvider, ProviderKey
+from app.models_mapping import ModelMapper
+from app.exceptions import (
+    RateLimitError,
+    CapabilityError,
+    ProviderUnavailableError,
+)
+
+logger = structlog.get_logger()
+
+
+class SambaNovaProvider(BaseProvider):
+    """SambaNova provider — free tier, fast Llama inference."""
+
+    name = "sambanova"
+
+    def __init__(self, base_url: str = "https://api.sambanova.ai/v1", model_mapper: ModelMapper = None):
+        super().__init__(base_url, model_mapper)
+
+    async def chat_completion(
+        self,
+        key: ProviderKey,
+        request: ChatCompletionRequest,
+    ) -> ChatCompletionResponse:
+        """Execute chat completion via SambaNova API.
+
+        Args:
+            key: Provider key supplying the bearer token and the rate-limit
+                bucket charged with the tokens the call consumes.
+            request: Chat completion request; its model name is translated
+                through ``get_model_name`` before it is sent.
+
+        Returns:
+            The first choice of the reply and its usage figures, tagged with
+            this provider's name and the key name.
+
+        Raises:
+            RateLimitError: The API answered 429.
+            ProviderUnavailableError: The API answered 5xx or timed out.
+            CapabilityError: The API answered 400.
+            httpx.HTTPStatusError: The API answered any other error status.
+        """
+
+        model = self.get_model_name(request.model)
+
+        payload = {
+            "model": model,
+            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
+            "temperature": request.temperature,
+            "max_tokens": request.max_tokens,
+            "top_p": request.top_p,
+        }
+
+        if request.stop:
+            payload["stop"] = request.stop
+
+        headers = {
+            "Authorization": f"Bearer {key.api_key}",
+            "Content-Type": "application/json",
+        }
+
+        client = await self.get_client()
+
+        await logger.ainfo(
+            "Sending request to SambaNova",
+            model=model,
+            key_name=key.key_name,
+            messages_count=len(request.messages),
+        )
+
+        try:
+            response = await client.post(
+                f"{self.base_url}/chat/completions",
+                json=payload,
+                headers=headers,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            choice = data["choices"][0]
+            usage = data.get("usage", {})
+
+            # Charge the tokens this call used to the key's rate-limit bucket.
+            total_tokens = usage.get("total_tokens", 0)
+            if total_tokens > 0:
+                await key.bucket.consume_tokens(total_tokens)
+
+            return ChatCompletionResponse(
+                id=data.get("id", f"sambanova-{int(time.time())}"),
+                created=data.get("created", int(time.time())),
+                model=model,
+                choices=[
+                    ChatCompletionChoice(
+                        index=0,
+                        message=Message(
+                            role="assistant",
+                            content=choice["message"]["content"],
+                        ),
+                        finish_reason=choice.get("finish_reason", "stop"),
+                    )
+                ],
+                usage=Usage(
+                    prompt_tokens=usage.get("prompt_tokens", 0),
+                    completion_tokens=usage.get("completion_tokens", 0),
+                    total_tokens=usage.get("total_tokens", 0),
+                ),
+                provider=self.name,
+                provider_key_name=key.key_name,
+            )
+
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code
+            response_text = e.response.text
+
+            await logger.aerror(
+                "SambaNova API error",
+                status_code=status_code,
+                response=response_text,
+                key_name=key.key_name,
+            )
+
+            if status_code == 429:
+                retry_after = e.response.headers.get("retry-after")
+                try:
+                    retry_seconds = float(retry_after) if retry_after else None
+                except ValueError:
+                    # Retry-After may also be an HTTP-date; fall back to no hint
+                    # rather than masking the 429 with a ValueError.
+                    retry_seconds = None
+                raise RateLimitError(
+                    provider=self.name,
+                    retry_after=retry_seconds,
+                    message=f"Rate limit exceeded: {response_text[:200]}",
+                )
+
+            if 500 <= status_code < 600:
+                raise ProviderUnavailableError(
+                    provider=self.name,
+                    status_code=status_code,
+                    message=f"SambaNova server error: {response_text[:200]}",
+                )
+
+            if status_code == 400:
+                response_lower = response_text.lower()
+                if "image" in response_lower or "vision" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="vision",
+                        message=f"Vision not supported: {response_text[:200]}",
+                    )
+                if "tool" in response_lower or "function" in response_lower:
+                    raise CapabilityError(
+                        provider=self.name,
+                        capability="tool_calls",
+                        message=f"Tool calling error: {response_text[:200]}",
+                    )
+                raise CapabilityError(
+                    provider=self.name,
+                    capability="unknown",
+                    message=f"Request error: {response_text[:200]}",
+                )
+
+            raise
+
+        except httpx.TimeoutException as e:
+            await logger.aerror("SambaNova request timeout", error=str(e))
+            raise ProviderUnavailableError(
+                provider=self.name,
+                status_code=504,
+                message=f"Request timeout: {str(e)}",
+            )
+        except (RateLimitError, CapabilityError, ProviderUnavailableError):
+            raise
+        except Exception as e:
+            await logger.aerror("SambaNova request failed", error=str(e))
+            raise
\ No newline at end of file
diff --git a/tests/test_models_mapping.py b/tests/test_models_mapping.py
new file mode 100644
index 0000000..4bdab29
--- /dev/null
+++ b/tests/test_models_mapping.py
@@ -0,0 +1,78 @@
+"""Tests for ModelMapper with all providers."""
+
+import pytest
+from app.models_mapping import ModelMapper
+
+
+@pytest.fixture
+def mapper():
+    return ModelMapper()
+
+
+def test_llama_70b_groq(mapper):
+    assert mapper.get_provider_model("llama-70b", "groq") == "llama-3.3-70b-versatile"
+
+
+def test_llama_70b_sambanova(mapper):
+    assert mapper.get_provider_model("llama-70b", "sambanova") == "Meta-Llama-3.3-70B-Instruct"
+
+
+def test_llama_70b_huggingface(mapper):
+    assert mapper.get_provider_model("llama-70b", "huggingface") == "meta-llama/Llama-3.3-70B-Instruct"
+
+
+def test_llama_8b_groq(mapper):
+    assert mapper.get_provider_model("llama-8b", "groq") == "llama-3.1-8b-instant"
+
+
+def test_gemini_flash(mapper):
+    assert mapper.get_provider_model("gemini-flash", "gemini") == "gemini-2.0-flash"
+
+
+def test_gemini_flash_openrouter(mapper):
+    assert mapper.get_provider_model("gemini-flash", "openrouter") == "google/gemini-2.0-flash-exp:free"
+
+
+def test_mistral_7b(mapper):
+    assert mapper.get_provider_model("mistral-7b", "mistral") == "open-mistral-7b"
+
+
+def test_deepseek_r1(mapper):
+    assert mapper.get_provider_model("deepseek-r1", "sambanova") == "DeepSeek-R1"
+
+
+def test_deepseek_chat_openrouter(mapper):
+    assert
mapper.get_provider_model("deepseek-chat", "openrouter") == "deepseek/deepseek-chat:free" + + +def test_command_r(mapper): + assert mapper.get_provider_model("command-r", "cohere") == "command-r-08-2024" + + +def test_command_r_plus(mapper): + assert mapper.get_provider_model("command-r-plus", "cohere") == "command-r-plus-08-2024" + + +def test_gemma2_9b_groq(mapper): + assert mapper.get_provider_model("gemma2-9b", "groq") == "gemma2-9b-it" + + +def test_qwen_72b_openrouter(mapper): + assert mapper.get_provider_model("qwen-72b", "openrouter") == "qwen/qwen-2.5-72b-instruct:free" + + +def test_unknown_model_returns_as_is(mapper): + assert mapper.get_provider_model("some-custom-model", "groq") == "some-custom-model" + + +def test_none_model_returns_default(mapper): + result = mapper.get_provider_model(None, "cerebras") + assert result == "llama-3.3-70b" + + +def test_get_available_models(mapper): + models = mapper.get_available_models() + assert "llama-70b" in models + assert "gemini-flash" in models + assert "deepseek-r1" in models + assert "command-r" in models \ No newline at end of file diff --git a/tests/test_new_providers.py b/tests/test_new_providers.py new file mode 100644 index 0000000..26e132d --- /dev/null +++ b/tests/test_new_providers.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +"""Integration tests for new providers added in PR #4.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import httpx +import pytest + +BASE_URL = "http://localhost:8000" + + +class TestGroqProvider: + """Tests specifically for Groq provider.""" + + @pytest.mark.asyncio + async def test_groq_simple_completion(self): + """Test a simple completion via Groq using unified model name.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'GROQ_OK' and nothing else."} + ], + "model": 
"llama-70b", + "max_tokens": 20, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "groq" + assert len(data["choices"]) > 0 + content = data["choices"][0]["message"]["content"] + print(f"✔ Groq response: {content[:50]}") + + @pytest.mark.asyncio + async def test_groq_with_system_prompt(self): + """Test Groq with system prompt.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "system", "content": "You are a math tutor. Be concise."}, + {"role": "user", "content": "What is 10 + 32?"} + ], + "model": "llama-70b", + "max_tokens": 50, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "42" in content + print(f"✔ Groq math: {content[:50]}") + + @pytest.mark.asyncio + async def test_groq_fast_model(self): + """Test Groq with llama-8b instant model.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "What is the capital of France?"} + ], + "model": "llama-8b", + "max_tokens": 50, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "Paris" in content or "paris" in content.lower() + print(f"✔ Groq fast model: {content[:50]}") + + +class TestOpenRouterProvider: + """Tests specifically for OpenRouter provider.""" + + @pytest.mark.asyncio + async def test_openrouter_simple_completion(self): + """Test a simple completion via OpenRouter using unified model name.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + 
"messages": [ + {"role": "user", "content": "Say only the word 'OPENROUTER_OK' and nothing else."} + ], + "model": "llama-70b", + "max_tokens": 20, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "openrouter" + assert len(data["choices"]) > 0 + content = data["choices"][0]["message"]["content"] + print(f"✔ OpenRouter response: {content[:50]}") + + @pytest.mark.asyncio + async def test_openrouter_with_system_prompt(self): + """Test OpenRouter with system prompt.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Be very brief."}, + {"role": "user", "content": "What is the capital of Germany?"} + ], + "model": "llama-70b", + "max_tokens": 50, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "Berlin" in content or "berlin" in content.lower() + print(f"✔ OpenRouter geography: {content[:50]}") + + @pytest.mark.asyncio + async def test_openrouter_free_model(self): + """Test OpenRouter with a free Gemini model.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "What is 5 + 5?"} + ], + "model": "gemini-flash", + "max_tokens": 20, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "10" in content + print(f"✔ OpenRouter Gemini: {content[:50]}") + + +class TestDeepSeekProvider: + """Tests specifically for DeepSeek provider.""" + + @pytest.mark.asyncio + async def test_deepseek_simple_completion(self): + """Test a simple completion 
via DeepSeek.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'DEEPSEEK_OK' and nothing else."} + ], + "model": "deepseek-chat", + "max_tokens": 20, + "provider": "deepseek", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "deepseek" + content = data["choices"][0]["message"]["content"] + print(f"✔ DeepSeek response: {content[:50]}") + + +class TestGeminiProvider: + """Tests specifically for Gemini provider.""" + + @pytest.mark.asyncio + async def test_gemini_simple_completion(self): + """Test a simple completion via Gemini.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'GEMINI_OK' and nothing else."} + ], + "model": "gemini-flash", + "max_tokens": 20, + "provider": "gemini", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "gemini" + content = data["choices"][0]["message"]["content"] + print(f"✔ Gemini response: {content[:50]}") \ No newline at end of file