From d5768845ac337d488dadf90dfff6676986a71226 Mon Sep 17 00:00:00 2001 From: Faisal Date: Sat, 16 May 2026 00:53:22 +0530 Subject: [PATCH 1/8] Add Groq provider with free tier support --- app/providers/groq.py | 166 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 app/providers/groq.py diff --git a/app/providers/groq.py b/app/providers/groq.py new file mode 100644 index 0000000..e46d8a6 --- /dev/null +++ b/app/providers/groq.py @@ -0,0 +1,166 @@ +"""Groq provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class GroqProvider(BaseProvider): + """Groq provider — free tier, very fast inference.""" + + name = "groq" + + def __init__(self, base_url: str = "https://api.groq.com/openai/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Groq API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Groq", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"groq-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Groq API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Groq server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("Groq request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Groq request failed", error=str(e)) + raise \ No newline at end of file From e358c4ba078ca0fce6bc0a8eae9a16ff65097c2f Mon Sep 17 00:00:00 2001 From: Faisal Date: Sat, 16 May 2026 00:58:18 +0530 Subject: [PATCH 2/8] Add Gemini, Mistral, OpenRouter providers and expand model mappings --- app/providers/gemini.py | 166 +++++++++++++++++++++++++++++++++++ app/providers/mistral.py | 166 +++++++++++++++++++++++++++++++++++ app/providers/openrouter.py | 168 ++++++++++++++++++++++++++++++++++++ 3 files changed, 500 insertions(+) create mode 100644 app/providers/gemini.py create mode 100644 app/providers/mistral.py create mode 100644 app/providers/openrouter.py diff --git a/app/providers/gemini.py b/app/providers/gemini.py new file mode 100644 index 0000000..3c7f641 --- /dev/null +++ b/app/providers/gemini.py @@ -0,0 +1,166 @@ +"""Google Gemini provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class GeminiProvider(BaseProvider): + """Google Gemini provider — free tier available.""" + + name = "gemini" + + def __init__(self, base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Gemini API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Gemini", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"gemini-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Gemini API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Gemini server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("Gemini request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Gemini request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/mistral.py b/app/providers/mistral.py new file mode 100644 index 0000000..722c7f1 --- /dev/null +++ b/app/providers/mistral.py @@ -0,0 +1,166 @@ +"""Mistral provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class MistralProvider(BaseProvider): + """Mistral AI provider — free tier available.""" + + name = "mistral" + + def __init__(self, base_url: str = "https://api.mistral.ai/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Mistral API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Mistral", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"mistral-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Mistral API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Mistral server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("Mistral request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Mistral request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/openrouter.py b/app/providers/openrouter.py new file mode 100644 index 0000000..09af140 --- /dev/null +++ b/app/providers/openrouter.py @@ -0,0 +1,168 @@ +"""OpenRouter provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class OpenRouterProvider(BaseProvider): + """OpenRouter provider — access to many free models.""" + + name = "openrouter" + + def __init__(self, base_url: str = "https://openrouter.ai/api/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via OpenRouter API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + "HTTP-Referer": "https://github.com/Alcray/TrainForgeConductor", + "X-Title": "TrainForgeConductor", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to OpenRouter", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"openrouter-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "OpenRouter API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"OpenRouter server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("OpenRouter request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("OpenRouter request failed", error=str(e)) + raise \ No newline at end of file From 3c2fbbc978198eb3638e961a839327bff84a09a7 Mon Sep 17 00:00:00 2001 From: Faisal Date: Sat, 16 May 2026 01:04:18 +0530 Subject: [PATCH 3/8] Add DeepSeek, HuggingFace, Cohere, SambaNova providers and expand model mappings --- app/models_mapping.py | 135 ++++++++++++++++++++-------- app/providers/__init__.py | 20 ++++- app/providers/cohere.py | 167 +++++++++++++++++++++++++++++++++++ app/providers/deepseek.py | 166 ++++++++++++++++++++++++++++++++++ app/providers/huggingface.py | 166 ++++++++++++++++++++++++++++++++++ app/providers/sambanova.py | 166 ++++++++++++++++++++++++++++++++++ 6 files changed, 781 insertions(+), 39 deletions(-) create mode 100644 app/providers/cohere.py create mode 100644 app/providers/deepseek.py create mode 100644 app/providers/huggingface.py create mode 100644 app/providers/sambanova.py diff --git a/app/models_mapping.py b/app/models_mapping.py index a49dd7e..24dc23c 100644 --- a/app/models_mapping.py +++ b/app/models_mapping.py @@ -1,87 +1,150 @@ """ Unified model name mapping. - This allows you to use simple names like "llama-70b" and the conductor will automatically translate to the correct provider-specific name. """ -# Default model mappings: unified name -> provider-specific names DEFAULT_MODEL_MAPPING = { - # Llama 3.3 70B - the flagship model + # Llama 3.3 70B "llama-70b": { "cerebras": "llama-3.3-70b", "nvidia": "meta/llama-3.3-70b-instruct", + "groq": "llama-3.3-70b-versatile", + "openrouter": "meta-llama/llama-3.3-70b-instruct", + "sambanova": "Meta-Llama-3.3-70B-Instruct", + "huggingface": "meta-llama/Llama-3.3-70B-Instruct", }, "llama-3.3-70b": { "cerebras": "llama-3.3-70b", "nvidia": "meta/llama-3.3-70b-instruct", + "groq": "llama-3.3-70b-versatile", + "openrouter": "meta-llama/llama-3.3-70b-instruct", + "sambanova": "Meta-Llama-3.3-70B-Instruct", + "huggingface": "meta-llama/Llama-3.3-70B-Instruct", }, - - # Llama 3.1 8B - fast and cheap + # Llama 3.1 8B "llama-8b": { "cerebras": "llama3.1-8b", "nvidia": "meta/llama-3.1-8b-instruct", + "groq": "llama-3.1-8b-instant", + "openrouter": "meta-llama/llama-3.1-8b-instruct", + "sambanova": "Meta-Llama-3.1-8B-Instruct", + "huggingface": "meta-llama/Llama-3.1-8B-Instruct", }, "llama-3.1-8b": { "cerebras": "llama3.1-8b", "nvidia": "meta/llama-3.1-8b-instruct", + "groq": "llama-3.1-8b-instant", + "openrouter": "meta-llama/llama-3.1-8b-instruct", + "sambanova": "Meta-Llama-3.1-8B-Instruct", + "huggingface": "meta-llama/Llama-3.1-8B-Instruct", }, - # Llama 3.1 70B "llama-3.1-70b": { "cerebras": "llama-3.1-70b", "nvidia": "meta/llama-3.1-70b-instruct", + "groq": "llama-3.1-70b-versatile", + "openrouter": "meta-llama/llama-3.1-70b-instruct", + "sambanova": "Meta-Llama-3.1-70B-Instruct", + "huggingface": "meta-llama/Llama-3.1-70B-Instruct", + }, + # Llama 3.1 405B + "llama-405b": { + "nvidia": "meta/llama-3.1-405b-instruct", + "sambanova": "Meta-Llama-3.1-405B-Instruct", + "openrouter": "meta-llama/llama-3.1-405b-instruct", + }, + # Gemini models + "gemini-flash": { + "gemini": "gemini-2.0-flash", + "openrouter": "google/gemini-2.0-flash-exp:free", + }, + "gemini-2.0-flash": { + "gemini": "gemini-2.0-flash", + "openrouter": "google/gemini-2.0-flash-exp:free", + }, + "gemini-flash-lite": { + "gemini": "gemini-2.0-flash-lite", + }, + # Mistral models + "mistral-small": { + "mistral": "mistral-small-latest", + "openrouter": "mistralai/mistral-small", + }, + "mistral-7b": { + "mistral": "open-mistral-7b", + "openrouter": "mistralai/mistral-7b-instruct:free", + "huggingface": "mistralai/Mistral-7B-Instruct-v0.3", + }, + # DeepSeek models + "deepseek-chat": { + "deepseek": "deepseek-chat", + "openrouter": "deepseek/deepseek-chat:free", + }, + "deepseek-coder": { + "deepseek": "deepseek-coder", + "openrouter": "deepseek/deepseek-coder:free", + }, + "deepseek-r1": { + "deepseek": "deepseek-reasoner", + "openrouter": "deepseek/deepseek-r1:free", + "sambanova": "DeepSeek-R1", + }, + # Cohere models + "command-r": { + "cohere": "command-r-08-2024", + "openrouter": "cohere/command-r", + }, + "command-r-plus": { + "cohere": "command-r-plus-08-2024", + "openrouter": "cohere/command-r-plus", + }, + # Groq specific + "mixtral-8x7b": { + "groq": "mixtral-8x7b-32768", + "openrouter": "mistralai/mixtral-8x7b-instruct", + }, + "gemma-7b": { + "groq": "gemma-7b-it", + "openrouter": "google/gemma-7b-it:free", + "huggingface": "google/gemma-7b-it", + }, + "gemma2-9b": { + "groq": "gemma2-9b-it", + "openrouter": "google/gemma-2-9b-it:free", + "huggingface": "google/gemma-2-9b-it", + }, + # Qwen models + "qwen-72b": { + "sambanova": "Qwen2.5-72B-Instruct", + "openrouter": "qwen/qwen-2.5-72b-instruct:free", + "huggingface": "Qwen/Qwen2.5-72B-Instruct", }, } -# Default model to use when none specified DEFAULT_MODEL = "llama-70b" class ModelMapper: """Maps unified model names to provider-specific names.""" - + def __init__(self, custom_mappings: dict = None): - """ - Initialize with optional custom mappings from config. - - Args: - custom_mappings: Dict of {unified_name: {provider: provider_name}} - """ self.mappings = DEFAULT_MODEL_MAPPING.copy() if custom_mappings: self.mappings.update(custom_mappings) - + def get_provider_model(self, unified_name: str, provider: str) -> str: - """ - Get the provider-specific model name. - - Args: - unified_name: The unified model name (e.g., "llama-70b") - provider: The provider name (e.g., "cerebras", "nvidia") - - Returns: - Provider-specific model name - """ if not unified_name: unified_name = DEFAULT_MODEL - - # Normalize the name name_lower = unified_name.lower().strip() - - # Check if it's in our mappings if name_lower in self.mappings: provider_models = self.mappings[name_lower] if provider in provider_models: return provider_models[provider] - - # If not found, return as-is (maybe it's already provider-specific) return unified_name - + def get_available_models(self) -> list[str]: - """Get list of available unified model names.""" return list(self.mappings.keys()) - + def add_mapping(self, unified_name: str, provider_models: dict): - """Add a custom model mapping.""" - self.mappings[unified_name.lower()] = provider_models + self.mappings[unified_name.lower()] = provider_models \ No newline at end of file diff --git a/app/providers/__init__.py b/app/providers/__init__.py index 9deb484..9979afb 100644 --- a/app/providers/__init__.py +++ b/app/providers/__init__.py @@ -1,13 +1,27 @@ """Provider implementations for TrainForgeConductor.""" - from .base import BaseProvider, ProviderKey from .cerebras import CerebrasProvider from .nvidia import NvidiaProvider +from .groq import GroqProvider +from .gemini import GeminiProvider +from .mistral import MistralProvider +from .openrouter import OpenRouterProvider +from .deepseek import DeepSeekProvider +from .huggingface import HuggingFaceProvider +from .cohere import CohereProvider +from .sambanova import SambaNoveProvider __all__ = [ "BaseProvider", "ProviderKey", "CerebrasProvider", "NvidiaProvider", -] - + "GroqProvider", + "GeminiProvider", + "MistralProvider", + "OpenRouterProvider", + "DeepSeekProvider", + "HuggingFaceProvider", + "CohereProvider", + "SambaNoveProvider", +] \ No newline at end of file diff --git a/app/providers/cohere.py b/app/providers/cohere.py new file mode 100644 index 0000000..083ec5f --- /dev/null +++ b/app/providers/cohere.py @@ -0,0 +1,167 @@ +"""Cohere provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class CohereProvider(BaseProvider): + """Cohere provider — free tier available.""" + + name = "cohere" + + def __init__(self, base_url: str = "https://api.cohere.com/compatibility/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Cohere API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Cohere", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"cohere-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Cohere API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Cohere server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("Cohere request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Cohere request failed", error=str(e)) + raise + \ No newline at end of file diff --git a/app/providers/deepseek.py b/app/providers/deepseek.py new file mode 100644 index 0000000..d5980fc --- /dev/null +++ b/app/providers/deepseek.py @@ -0,0 +1,166 @@ +"""DeepSeek provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class DeepSeekProvider(BaseProvider): + """DeepSeek provider — free tier with strong coding models.""" + + name = "deepseek" + + def __init__(self, base_url: str = "https://api.deepseek.com/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via DeepSeek API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to DeepSeek", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"deepseek-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "DeepSeek API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"DeepSeek server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("DeepSeek request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("DeepSeek request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/huggingface.py b/app/providers/huggingface.py new file mode 100644 index 0000000..f10dc54 --- /dev/null +++ b/app/providers/huggingface.py @@ -0,0 +1,166 @@ +"""Hugging Face provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class HuggingFaceProvider(BaseProvider): + """Hugging Face Inference API provider — free tier available.""" + + name = "huggingface" + + def __init__(self, base_url: str = "https://api-inference.huggingface.co/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via Hugging Face Inference API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to Hugging Face", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"huggingface-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "Hugging Face API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"Hugging Face server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("Hugging Face request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("Hugging Face request failed", error=str(e)) + raise \ No newline at end of file diff --git a/app/providers/sambanova.py b/app/providers/sambanova.py new file mode 100644 index 0000000..46ca032 --- /dev/null +++ b/app/providers/sambanova.py @@ -0,0 +1,166 @@ +"""SambaNova provider implementation.""" + +import time +import httpx +import structlog + +from app.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionChoice, + Message, + Usage, +) +from app.providers.base import BaseProvider, ProviderKey +from app.models_mapping import ModelMapper +from app.exceptions import ( + RateLimitError, + CapabilityError, + ProviderUnavailableError, +) + +logger = structlog.get_logger() + + +class SambaNoveProvider(BaseProvider): + """SambaNova provider — free tier, fast Llama inference.""" + + name = "sambanova" + + def __init__(self, base_url: str = "https://api.sambanova.ai/v1", model_mapper: ModelMapper = None): + super().__init__(base_url, model_mapper) + + async def chat_completion( + self, + key: ProviderKey, + request: ChatCompletionRequest, + ) -> ChatCompletionResponse: + """Execute chat completion via SambaNova API.""" + + model = self.get_model_name(request.model) + + payload = { + "model": model, + "messages": [{"role": m.role, "content": m.content} for m in request.messages], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "top_p": request.top_p, + } + + if request.stop: + payload["stop"] = request.stop + + headers = { + "Authorization": f"Bearer {key.api_key}", + "Content-Type": "application/json", + } + + client = await self.get_client() + + await logger.ainfo( + "Sending request to SambaNova", + model=model, + key_name=key.key_name, + messages_count=len(request.messages), + ) + + try: + response = await client.post( + f"{self.base_url}/chat/completions", + json=payload, + headers=headers, + ) + response.raise_for_status() + data = response.json() + + choice = data["choices"][0] + usage = data.get("usage", {}) + + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + await key.bucket.consume_tokens(total_tokens) + + return ChatCompletionResponse( + id=data.get("id", f"sambanova-{int(time.time())}"), + created=data.get("created", int(time.time())), + model=model, + choices=[ + ChatCompletionChoice( + index=0, + message=Message( + role="assistant", + content=choice["message"]["content"], + ), + finish_reason=choice.get("finish_reason", "stop"), + ) + ], + usage=Usage( + prompt_tokens=usage.get("prompt_tokens", 0), + completion_tokens=usage.get("completion_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + ), + provider=self.name, + provider_key_name=key.key_name, + ) + + except httpx.HTTPStatusError as e: + status_code = e.response.status_code + response_text = e.response.text + + await logger.aerror( + "SambaNova API error", + status_code=status_code, + response=response_text, + key_name=key.key_name, + ) + + if status_code == 429: + retry_after = e.response.headers.get("retry-after") + retry_seconds = float(retry_after) if retry_after else None + raise RateLimitError( + provider=self.name, + retry_after=retry_seconds, + message=f"Rate limit exceeded: {response_text[:200]}", + ) + + if 500 <= status_code < 600: + raise ProviderUnavailableError( + provider=self.name, + status_code=status_code, + message=f"SambaNova server error: {response_text[:200]}", + ) + + if status_code == 400: + response_lower = response_text.lower() + if "image" in response_lower or "vision" in response_lower: + raise CapabilityError( + provider=self.name, + capability="vision", + message=f"Vision not supported: {response_text[:200]}", + ) + if "tool" in response_lower or "function" in response_lower: + raise CapabilityError( + provider=self.name, + capability="tool_calls", + message=f"Tool calling error: {response_text[:200]}", + ) + raise CapabilityError( + provider=self.name, + capability="unknown", + message=f"Request error: {response_text[:200]}", + ) + + raise + + except httpx.TimeoutException as e: + await logger.aerror("SambaNova request timeout", error=str(e)) + raise ProviderUnavailableError( + provider=self.name, + status_code=504, + message=f"Request timeout: {str(e)}", + ) + except (RateLimitError, CapabilityError, ProviderUnavailableError): + raise + except Exception as e: + await logger.aerror("SambaNova request failed", error=str(e)) + raise \ No newline at end of file From b6b32a0c78420cf8ccffd51445ba73e5936b1dd8 Mon Sep 17 00:00:00 2001 From: Faisal Date: Sat, 16 May 2026 22:12:00 +0530 Subject: [PATCH 4/8] Fix SambaNova typo, restore docstrings, add trailing newline fix, add model mapping tests --- app/models_mapping.py | 18 +++++++++ app/providers/__init__.py | 4 +- app/providers/cohere.py | 3 +- app/providers/sambanova.py | 2 +- tests/test_models_mapping.py | 78 ++++++++++++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 5 deletions(-) create mode 100644 tests/test_models_mapping.py diff --git a/app/models_mapping.py b/app/models_mapping.py index 24dc23c..44dc3af 100644 --- a/app/models_mapping.py +++ b/app/models_mapping.py @@ -129,11 +129,27 @@ class ModelMapper: """Maps unified model names to provider-specific names.""" def __init__(self, custom_mappings: dict = None): + """ + Initialize with optional custom mappings from config. + + Args: + custom_mappings: Dict of {unified_name: {provider: provider_name}} + """ self.mappings = DEFAULT_MODEL_MAPPING.copy() if custom_mappings: self.mappings.update(custom_mappings) def get_provider_model(self, unified_name: str, provider: str) -> str: + """ + Get the provider-specific model name. + + Args: + unified_name: The unified model name (e.g., "llama-70b") + provider: The provider name (e.g., "cerebras", "nvidia") + + Returns: + Provider-specific model name + """ if not unified_name: unified_name = DEFAULT_MODEL name_lower = unified_name.lower().strip() @@ -144,7 +160,9 @@ def get_provider_model(self, unified_name: str, provider: str) -> str: return unified_name def get_available_models(self) -> list[str]: + """Get list of available unified model names.""" return list(self.mappings.keys()) def add_mapping(self, unified_name: str, provider_models: dict): + """Add a custom model mapping.""" self.mappings[unified_name.lower()] = provider_models \ No newline at end of file diff --git a/app/providers/__init__.py b/app/providers/__init__.py index 9979afb..982c3c6 100644 --- a/app/providers/__init__.py +++ b/app/providers/__init__.py @@ -9,7 +9,7 @@ from .deepseek import DeepSeekProvider from .huggingface import HuggingFaceProvider from .cohere import CohereProvider -from .sambanova import SambaNoveProvider +from .sambanova import SambaNovaProvider __all__ = [ "BaseProvider", @@ -23,5 +23,5 @@ "DeepSeekProvider", "HuggingFaceProvider", "CohereProvider", - "SambaNoveProvider", + "SambaNovaProvider", ] \ No newline at end of file diff --git a/app/providers/cohere.py b/app/providers/cohere.py index 083ec5f..0f401d4 100644 --- a/app/providers/cohere.py +++ b/app/providers/cohere.py @@ -163,5 +163,4 @@ async def chat_completion( raise except Exception as e: await logger.aerror("Cohere request failed", error=str(e)) - raise - \ No newline at end of file + raise \ No newline at end of file diff --git a/app/providers/sambanova.py b/app/providers/sambanova.py index 46ca032..62659a6 100644 --- a/app/providers/sambanova.py +++ b/app/providers/sambanova.py @@ -22,7 +22,7 @@ logger = structlog.get_logger() -class SambaNoveProvider(BaseProvider): +class SambaNovaProvider(BaseProvider): """SambaNova provider — free tier, fast Llama inference.""" name = "sambanova" diff --git a/tests/test_models_mapping.py b/tests/test_models_mapping.py new file mode 100644 index 0000000..4bdab29 --- /dev/null +++ b/tests/test_models_mapping.py @@ -0,0 +1,78 @@ +"""Tests for ModelMapper with all providers.""" + +import pytest +from app.models_mapping import ModelMapper + + +@pytest.fixture +def mapper(): + return ModelMapper() + + +def test_llama_70b_groq(mapper): + assert mapper.get_provider_model("llama-70b", "groq") == "llama-3.3-70b-versatile" + + +def test_llama_70b_sambanova(mapper): + assert mapper.get_provider_model("llama-70b", "sambanova") == "Meta-Llama-3.3-70B-Instruct" + + +def test_llama_70b_huggingface(mapper): + assert mapper.get_provider_model("llama-70b", "huggingface") == "meta-llama/Llama-3.3-70B-Instruct" + + +def test_llama_8b_groq(mapper): + assert mapper.get_provider_model("llama-8b", "groq") == "llama-3.1-8b-instant" + + +def test_gemini_flash(mapper): + assert mapper.get_provider_model("gemini-flash", "gemini") == "gemini-2.0-flash" + + +def test_gemini_flash_openrouter(mapper): + assert mapper.get_provider_model("gemini-flash", "openrouter") == "google/gemini-2.0-flash-exp:free" + + +def test_mistral_7b(mapper): + assert mapper.get_provider_model("mistral-7b", "mistral") == "open-mistral-7b" + + +def test_deepseek_r1(mapper): + assert mapper.get_provider_model("deepseek-r1", "sambanova") == "DeepSeek-R1" + + +def test_deepseek_chat_openrouter(mapper): + assert mapper.get_provider_model("deepseek-chat", "openrouter") == "deepseek/deepseek-chat:free" + + +def test_command_r(mapper): + assert mapper.get_provider_model("command-r", "cohere") == "command-r-08-2024" + + +def test_command_r_plus(mapper): + assert mapper.get_provider_model("command-r-plus", "cohere") == "command-r-plus-08-2024" + + +def test_gemma2_9b_groq(mapper): + assert mapper.get_provider_model("gemma2-9b", "groq") == "gemma2-9b-it" + + +def test_qwen_72b_openrouter(mapper): + assert mapper.get_provider_model("qwen-72b", "openrouter") == "qwen/qwen-2.5-72b-instruct:free" + + +def test_unknown_model_returns_as_is(mapper): + assert mapper.get_provider_model("some-custom-model", "groq") == "some-custom-model" + + +def test_none_model_returns_default(mapper): + result = mapper.get_provider_model(None, "cerebras") + assert result == "llama-3.3-70b" + + +def test_get_available_models(mapper): + models = mapper.get_available_models() + assert "llama-70b" in models + assert "gemini-flash" in models + assert "deepseek-r1" in models + assert "command-r" in models \ No newline at end of file From 5a725be22334017676fef42357a8ef91d869aa81 Mon Sep 17 00:00:00 2001 From: Faisal Date: Sat, 16 May 2026 22:21:43 +0530 Subject: [PATCH 5/8] Wire all new providers into main.py runtime initialization --- app/main.py | 325 +++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 259 insertions(+), 66 deletions(-) diff --git a/app/main.py b/app/main.py index cbccede..1ef24b4 100644 --- a/app/main.py +++ b/app/main.py @@ -18,7 +18,18 @@ ConductorStatus, ) from app.scheduler import Scheduler, SchedulingStrategy -from app.providers import CerebrasProvider, NvidiaProvider +from app.providers import ( + CerebrasProvider, + NvidiaProvider, + GroqProvider, + GeminiProvider, + MistralProvider, + OpenRouterProvider, + DeepSeekProvider, + HuggingFaceProvider, + CohereProvider, + SambaNovaProvider, +) from app.providers.base import ProviderKey from app.rate_limiter import RateLimitBucket from app.models_mapping import ModelMapper, DEFAULT_MODEL @@ -56,25 +67,25 @@ async def initialize_scheduler(config: dict) -> Scheduler: """Initialize the scheduler with providers from config.""" global model_mapper - + conductor_config = config.get("conductor", {}) strategy = SchedulingStrategy( conductor_config.get("scheduling_strategy", "round_robin") ) - + # Initialize model mapper with custom mappings from config custom_models = config.get("models", {}) model_mapper = ModelMapper(custom_models) - + await logger.ainfo( "Model mapper initialized", available_models=model_mapper.get_available_models() ) - + sched = Scheduler(strategy=strategy) - + providers_config = config.get("providers", {}) - + # Initialize Cerebras cerebras_config = providers_config.get("cerebras", {}) if cerebras_config.get("enabled", False) and cerebras_config.get("keys"): @@ -82,7 +93,6 @@ async def initialize_scheduler(config: dict) -> Scheduler: base_url=cerebras_config.get("base_url", "https://api.cerebras.ai/v1"), model_mapper=model_mapper, ) - for i, key_config in enumerate(cerebras_config.get("keys", [])): key_name = key_config.get("name", f"cerebras-key-{i+1}") bucket = RateLimitBucket( @@ -98,14 +108,10 @@ async def initialize_scheduler(config: dict) -> Scheduler: base_url=provider.base_url, ) provider.add_key(provider_key) - if provider.keys: await sched.add_provider(provider) - await logger.ainfo( - "Cerebras provider initialized", - keys_count=len(provider.keys) - ) - + await logger.ainfo("Cerebras provider initialized", keys_count=len(provider.keys)) + # Initialize NVIDIA NIM nvidia_config = providers_config.get("nvidia", {}) if nvidia_config.get("enabled", False) and nvidia_config.get("keys"): @@ -113,7 +119,6 @@ async def initialize_scheduler(config: dict) -> Scheduler: base_url=nvidia_config.get("base_url", "https://integrate.api.nvidia.com/v1"), model_mapper=model_mapper, ) - for i, key_config in enumerate(nvidia_config.get("keys", [])): key_name = key_config.get("name", f"nvidia-key-{i+1}") bucket = RateLimitBucket( @@ -129,14 +134,218 @@ async def initialize_scheduler(config: dict) -> Scheduler: base_url=provider.base_url, ) provider.add_key(provider_key) - if provider.keys: await sched.add_provider(provider) - await logger.ainfo( - "NVIDIA NIM provider initialized", - keys_count=len(provider.keys) + await logger.ainfo("NVIDIA NIM provider initialized", keys_count=len(provider.keys)) + + # Initialize Groq + groq_config = providers_config.get("groq", {}) + if groq_config.get("enabled", False) and groq_config.get("keys"): + provider = GroqProvider( + base_url=groq_config.get("base_url", "https://api.groq.com/openai/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(groq_config.get("keys", [])): + key_name = key_config.get("name", f"groq-key-{i+1}") + bucket = RateLimitBucket( + name=f"groq:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 30), + tokens_per_minute=key_config.get("tokens_per_minute", 6000), + ) + provider_key = ProviderKey( + provider_name="groq", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("Groq provider initialized", keys_count=len(provider.keys)) + + # Initialize Gemini + gemini_config = providers_config.get("gemini", {}) + if gemini_config.get("enabled", False) and gemini_config.get("keys"): + provider = GeminiProvider( + base_url=gemini_config.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(gemini_config.get("keys", [])): + key_name = key_config.get("name", f"gemini-key-{i+1}") + bucket = RateLimitBucket( + name=f"gemini:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 15), + tokens_per_minute=key_config.get("tokens_per_minute", 1_000_000), + ) + provider_key = ProviderKey( + provider_name="gemini", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("Gemini provider initialized", keys_count=len(provider.keys)) + + # Initialize Mistral + mistral_config = providers_config.get("mistral", {}) + if mistral_config.get("enabled", False) and mistral_config.get("keys"): + provider = MistralProvider( + base_url=mistral_config.get("base_url", "https://api.mistral.ai/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(mistral_config.get("keys", [])): + key_name = key_config.get("name", f"mistral-key-{i+1}") + bucket = RateLimitBucket( + name=f"mistral:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 60), + tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + ) + provider_key = ProviderKey( + provider_name="mistral", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("Mistral provider initialized", keys_count=len(provider.keys)) + + # Initialize OpenRouter + openrouter_config = providers_config.get("openrouter", {}) + if openrouter_config.get("enabled", False) and openrouter_config.get("keys"): + provider = OpenRouterProvider( + base_url=openrouter_config.get("base_url", "https://openrouter.ai/api/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(openrouter_config.get("keys", [])): + key_name = key_config.get("name", f"openrouter-key-{i+1}") + bucket = RateLimitBucket( + name=f"openrouter:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 60), + tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + ) + provider_key = ProviderKey( + provider_name="openrouter", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("OpenRouter provider initialized", keys_count=len(provider.keys)) + + # Initialize DeepSeek + deepseek_config = providers_config.get("deepseek", {}) + if deepseek_config.get("enabled", False) and deepseek_config.get("keys"): + provider = DeepSeekProvider( + base_url=deepseek_config.get("base_url", "https://api.deepseek.com/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(deepseek_config.get("keys", [])): + key_name = key_config.get("name", f"deepseek-key-{i+1}") + bucket = RateLimitBucket( + name=f"deepseek:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 60), + tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + ) + provider_key = ProviderKey( + provider_name="deepseek", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("DeepSeek provider initialized", keys_count=len(provider.keys)) + + # Initialize HuggingFace + huggingface_config = providers_config.get("huggingface", {}) + if huggingface_config.get("enabled", False) and huggingface_config.get("keys"): + provider = HuggingFaceProvider( + base_url=huggingface_config.get("base_url", "https://api-inference.huggingface.co/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(huggingface_config.get("keys", [])): + key_name = key_config.get("name", f"huggingface-key-{i+1}") + bucket = RateLimitBucket( + name=f"huggingface:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 30), + tokens_per_minute=key_config.get("tokens_per_minute", 50_000), + ) + provider_key = ProviderKey( + provider_name="huggingface", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("HuggingFace provider initialized", keys_count=len(provider.keys)) + + # Initialize Cohere + cohere_config = providers_config.get("cohere", {}) + if cohere_config.get("enabled", False) and cohere_config.get("keys"): + provider = CohereProvider( + base_url=cohere_config.get("base_url", "https://api.cohere.com/compatibility/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(cohere_config.get("keys", [])): + key_name = key_config.get("name", f"cohere-key-{i+1}") + bucket = RateLimitBucket( + name=f"cohere:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 20), + tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + ) + provider_key = ProviderKey( + provider_name="cohere", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, ) - + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("Cohere provider initialized", keys_count=len(provider.keys)) + + # Initialize SambaNova + sambanova_config = providers_config.get("sambanova", {}) + if sambanova_config.get("enabled", False) and sambanova_config.get("keys"): + provider = SambaNovaProvider( + base_url=sambanova_config.get("base_url", "https://api.sambanova.ai/v1"), + model_mapper=model_mapper, + ) + for i, key_config in enumerate(sambanova_config.get("keys", [])): + key_name = key_config.get("name", f"sambanova-key-{i+1}") + bucket = RateLimitBucket( + name=f"sambanova:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", 60), + tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + ) + provider_key = ProviderKey( + provider_name="sambanova", + key_name=key_name, + api_key=key_config["api_key"], + bucket=bucket, + base_url=provider.base_url, + ) + provider.add_key(provider_key) + if provider.keys: + await sched.add_provider(provider) + await logger.ainfo("SambaNova provider initialized", keys_count=len(provider.keys)) + await sched.start() return sched @@ -145,28 +354,28 @@ async def initialize_scheduler(config: dict) -> Scheduler: async def lifespan(app: FastAPI): """Application lifespan manager.""" global scheduler - + await logger.ainfo("Starting TrainForgeConductor...") - + # Load configuration config = load_config(settings.config_path) await logger.ainfo("Configuration loaded", config_path=settings.config_path) - + # Initialize scheduler scheduler = await initialize_scheduler(config) - + if not scheduler.providers: await logger.awarning( "No providers configured! Add API keys to config/config.yaml" ) - + await logger.ainfo( "TrainForgeConductor ready", providers=list(scheduler.providers.keys()), ) - + yield - + # Shutdown await logger.ainfo("Shutting down TrainForgeConductor...") if scheduler: @@ -202,10 +411,9 @@ async def get_status(): """Get conductor status including all provider rate limits.""" if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + status = await scheduler.get_status() - - # Convert to response model + from app.models import ProviderStatus provider_statuses = [ ProviderStatus( @@ -222,7 +430,7 @@ async def get_status(): ) for p in status["providers"] ] - + return ConductorStatus( status=status["status"], total_providers=status["total_providers"], @@ -238,36 +446,33 @@ async def get_status(): async def chat_completion(request: ChatCompletionRequest): """ OpenAI-compatible chat completion endpoint. - + Use unified model names like "llama-70b" or "llama-8b" - the conductor will automatically translate to the correct provider-specific name. - + The conductor routes requests to available providers based on rate limits and the configured scheduling strategy. - + Optional fields: - model: Model to use (default: llama-70b). Use unified names. - provider: Force a specific provider (e.g., "cerebras" or "nvidia") - priority: Request priority (0-10, higher = more priority) - auto_retry: Automatically retry on failures with exponential backoff (default: true). - When enabled, the conductor absorbs transient errors and keeps retrying until it gets - a response or exhausts max_retries, so the caller always gets an answer. - max_retries: Maximum number of retry attempts when auto_retry is enabled (default: 10). """ if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + if not scheduler.providers: raise HTTPException( - status_code=503, + status_code=503, detail="No providers configured. Add API keys to config/config.yaml" ) - + try: response = await scheduler.submit(request, wait=True) return response except MaxRetriesExhaustedError as e: - # Auto-retry exhausted all attempts — return full retry log await logger.aerror( "Auto-retry exhausted", total_attempts=e.total_attempts, @@ -285,7 +490,6 @@ async def chat_completion(request: ChatCompletionRequest): }, ) except AllProvidersExhaustedError as e: - # All providers failed (auto_retry=False path) - return detailed error await logger.aerror( "All providers exhausted", error_count=len(e.errors), @@ -338,56 +542,46 @@ async def chat_completion(request: ChatCompletionRequest): async def batch_chat_completion(batch: BatchRequest): """ Submit a batch of chat completion requests. - + All requests will be scheduled across available providers to maximize throughput while respecting rate limits. """ if not scheduler: raise HTTPException(status_code=503, detail="Scheduler not initialized") - + if not scheduler.providers: raise HTTPException( status_code=503, detail="No providers configured. Add API keys to config/config.yaml" ) - + start_time = time.time() - - # Create tasks for all requests + tasks = [ scheduler.submit(req, wait=True) for req in batch.requests ] - + responses: list[ChatCompletionResponse] = [] failed: list[dict] = [] - + if batch.wait_for_all: - # Wait for all to complete, collect results results = await asyncio.gather(*tasks, return_exceptions=True) - for i, result in enumerate(results): if isinstance(result, Exception): - failed.append({ - "index": i, - "error": str(result), - }) + failed.append({"index": i, "error": str(result)}) else: responses.append(result) else: - # Return as they complete for i, coro in enumerate(asyncio.as_completed(tasks)): try: result = await coro responses.append(result) except Exception as e: - failed.append({ - "index": i, - "error": str(e), - }) - + failed.append({"index": i, "error": str(e)}) + elapsed_ms = (time.time() - start_time) * 1000 - + return BatchResponse( responses=responses, failed=failed, @@ -399,21 +593,20 @@ async def batch_chat_completion(batch: BatchRequest): async def list_models(): """ List all available unified model names. - + These are the model names you can use in requests. The conductor automatically translates them to provider-specific names. """ if not model_mapper: raise HTTPException(status_code=503, detail="Model mapper not initialized") - - # Return unified model names + models = [] for unified_name in model_mapper.get_available_models(): models.append({ "id": unified_name, "object": "model", }) - + return { "data": models, "object": "list", @@ -433,4 +626,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file From 604708a64a3685891b33c49c3a7f1b2feffe2948 Mon Sep 17 00:00:00 2001 From: Faisal Date: Mon, 18 May 2026 00:49:25 +0530 Subject: [PATCH 6/8] Remove double token counting from all new providers --- app/providers/cohere.py | 6 +---- app/providers/gemini.py | 4 --- app/providers/groq.py | 4 --- app/providers/nvidia.py | 54 ++++++++++++++++------------------------- 4 files changed, 22 insertions(+), 46 deletions(-) diff --git a/app/providers/cohere.py b/app/providers/cohere.py index 0f401d4..89672ca 100644 --- a/app/providers/cohere.py +++ b/app/providers/cohere.py @@ -76,10 +76,6 @@ async def chat_completion( choice = data["choices"][0] usage = data.get("usage", {}) - total_tokens = usage.get("total_tokens", 0) - if total_tokens > 0: - await key.bucket.consume_tokens(total_tokens) - return ChatCompletionResponse( id=data.get("id", f"cohere-{int(time.time())}"), created=data.get("created", int(time.time())), @@ -163,4 +159,4 @@ async def chat_completion( raise except Exception as e: await logger.aerror("Cohere request failed", error=str(e)) - raise \ No newline at end of file + raise \ No newline at end of file diff --git a/app/providers/gemini.py b/app/providers/gemini.py index 3c7f641..48977f9 100644 --- a/app/providers/gemini.py +++ b/app/providers/gemini.py @@ -76,10 +76,6 @@ async def chat_completion( choice = data["choices"][0] usage = data.get("usage", {}) - total_tokens = usage.get("total_tokens", 0) - if total_tokens > 0: - await key.bucket.consume_tokens(total_tokens) - return ChatCompletionResponse( id=data.get("id", f"gemini-{int(time.time())}"), created=data.get("created", int(time.time())), diff --git a/app/providers/groq.py b/app/providers/groq.py index e46d8a6..645f280 100644 --- a/app/providers/groq.py +++ b/app/providers/groq.py @@ -76,10 +76,6 @@ async def chat_completion( choice = data["choices"][0] usage = data.get("usage", {}) - total_tokens = usage.get("total_tokens", 0) - if total_tokens > 0: - await key.bucket.consume_tokens(total_tokens) - return ChatCompletionResponse( id=data.get("id", f"groq-{int(time.time())}"), created=data.get("created", int(time.time())), diff --git a/app/providers/nvidia.py b/app/providers/nvidia.py index 53c3541..5a50135 100644 --- a/app/providers/nvidia.py +++ b/app/providers/nvidia.py @@ -24,23 +24,21 @@ class NvidiaProvider(BaseProvider): """NVIDIA NIM provider (via build.nvidia.com or self-hosted).""" - + name = "nvidia" - + def __init__(self, base_url: str = "https://integrate.api.nvidia.com/v1", model_mapper: ModelMapper = None): super().__init__(base_url, model_mapper) - + async def chat_completion( self, key: ProviderKey, request: ChatCompletionRequest, ) -> ChatCompletionResponse: """Execute chat completion via NVIDIA NIM API.""" - - # Translate unified model name to NVIDIA-specific name + model = self.get_model_name(request.model) - - # Prepare request payload (OpenAI-compatible) + payload = { "model": model, "messages": [{"role": m.role, "content": m.content} for m in request.messages], @@ -48,24 +46,24 @@ async def chat_completion( "max_tokens": request.max_tokens, "top_p": request.top_p, } - + if request.stop: payload["stop"] = request.stop - + headers = { "Authorization": f"Bearer {key.api_key}", "Content-Type": "application/json", } - + client = await self.get_client() - + await logger.ainfo( "Sending request to NVIDIA NIM", model=model, key_name=key.key_name, messages_count=len(request.messages), ) - + try: response = await client.post( f"{self.base_url}/chat/completions", @@ -74,16 +72,14 @@ async def chat_completion( ) response.raise_for_status() data = response.json() - - # Parse response + choice = data["choices"][0] usage = data.get("usage", {}) - - # Update token consumption + total_tokens = usage.get("total_tokens", 0) if total_tokens > 0: await key.bucket.consume_tokens(total_tokens) - + return ChatCompletionResponse( id=data.get("id", f"nvidia-{int(time.time())}"), created=data.get("created", int(time.time())), @@ -106,21 +102,19 @@ async def chat_completion( provider=self.name, provider_key_name=key.key_name, ) - + except httpx.HTTPStatusError as e: status_code = e.response.status_code response_text = e.response.text - + await logger.aerror( "NVIDIA NIM API error", status_code=status_code, response=response_text, key_name=key.key_name, ) - - # Handle rate limit (429) + if status_code == 429: - # Try to extract retry-after header retry_after = e.response.headers.get("retry-after") retry_seconds = float(retry_after) if retry_after else None raise RateLimitError( @@ -128,18 +122,15 @@ async def chat_completion( retry_after=retry_seconds, message=f"Rate limit exceeded: {response_text[:200]}", ) - - # Handle server errors (5xx) - provider temporarily unavailable + if 500 <= status_code < 600: raise ProviderUnavailableError( provider=self.name, status_code=status_code, message=f"NVIDIA server error: {response_text[:200]}", ) - - # Handle capability/validation errors (400) + if status_code == 400: - # Check for known capability issues response_lower = response_text.lower() if "tool" in response_lower or "function" in response_lower: raise CapabilityError( @@ -153,16 +144,14 @@ async def chat_completion( capability="vision", message=f"Vision/image error: {response_text[:200]}", ) - # Generic capability error raise CapabilityError( provider=self.name, capability="unknown", message=f"Request error: {response_text[:200]}", ) - - # Re-raise other errors + raise - + except httpx.TimeoutException as e: await logger.aerror("NVIDIA NIM request timeout", error=str(e)) raise ProviderUnavailableError( @@ -171,8 +160,7 @@ async def chat_completion( message=f"Request timeout: {str(e)}", ) except (RateLimitError, CapabilityError, ProviderUnavailableError): - # Re-raise our custom exceptions raise except Exception as e: await logger.aerror("NVIDIA NIM request failed", error=str(e)) - raise + raise \ No newline at end of file From 06be10ba1c4d1c24f1ecd91fd5d9b08208c7907e Mon Sep 17 00:00:00 2001 From: Faisal Date: Mon, 18 May 2026 13:41:16 +0530 Subject: [PATCH 7/8] Add integration tests for Groq, OpenRouter, DeepSeek and Gemini providers --- tests/test_new_providers.py | 208 ++++++++++++++++++++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 tests/test_new_providers.py diff --git a/tests/test_new_providers.py b/tests/test_new_providers.py new file mode 100644 index 0000000..26e132d --- /dev/null +++ b/tests/test_new_providers.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python3 +"""Integration tests for new providers added in PR #4.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import httpx +import pytest + +BASE_URL = "http://localhost:8000" + + +class TestGroqProvider: + """Tests specifically for Groq provider.""" + + @pytest.mark.asyncio + async def test_groq_simple_completion(self): + """Test a simple completion via Groq using unified model name.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'GROQ_OK' and nothing else."} + ], + "model": "llama-70b", + "max_tokens": 20, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "groq" + assert len(data["choices"]) > 0 + content = data["choices"][0]["message"]["content"] + print(f"✔ Groq response: {content[:50]}") + + @pytest.mark.asyncio + async def test_groq_with_system_prompt(self): + """Test Groq with system prompt.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "system", "content": "You are a math tutor. Be concise."}, + {"role": "user", "content": "What is 10 + 32?"} + ], + "model": "llama-70b", + "max_tokens": 50, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "42" in content + print(f"✔ Groq math: {content[:50]}") + + @pytest.mark.asyncio + async def test_groq_fast_model(self): + """Test Groq with llama-8b instant model.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "What is the capital of France?"} + ], + "model": "llama-8b", + "max_tokens": 50, + "provider": "groq", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "Paris" in content or "paris" in content.lower() + print(f"✔ Groq fast model: {content[:50]}") + + +class TestOpenRouterProvider: + """Tests specifically for OpenRouter provider.""" + + @pytest.mark.asyncio + async def test_openrouter_simple_completion(self): + """Test a simple completion via OpenRouter using unified model name.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'OPENROUTER_OK' and nothing else."} + ], + "model": "llama-70b", + "max_tokens": 20, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "openrouter" + assert len(data["choices"]) > 0 + content = data["choices"][0]["message"]["content"] + print(f"✔ OpenRouter response: {content[:50]}") + + @pytest.mark.asyncio + async def test_openrouter_with_system_prompt(self): + """Test OpenRouter with system prompt.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "system", "content": "You are a helpful assistant. Be very brief."}, + {"role": "user", "content": "What is the capital of Germany?"} + ], + "model": "llama-70b", + "max_tokens": 50, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "Berlin" in content or "berlin" in content.lower() + print(f"✔ OpenRouter geography: {content[:50]}") + + @pytest.mark.asyncio + async def test_openrouter_free_model(self): + """Test OpenRouter with a free Gemini model.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "What is 5 + 5?"} + ], + "model": "gemini-flash", + "max_tokens": 20, + "provider": "openrouter", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + content = data["choices"][0]["message"]["content"] + assert "10" in content + print(f"✔ OpenRouter Gemini: {content[:50]}") + + +class TestDeepSeekProvider: + """Tests specifically for DeepSeek provider.""" + + @pytest.mark.asyncio + async def test_deepseek_simple_completion(self): + """Test a simple completion via DeepSeek.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'DEEPSEEK_OK' and nothing else."} + ], + "model": "deepseek-chat", + "max_tokens": 20, + "provider": "deepseek", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "deepseek" + content = data["choices"][0]["message"]["content"] + print(f"✔ DeepSeek response: {content[:50]}") + + +class TestGeminiProvider: + """Tests specifically for Gemini provider.""" + + @pytest.mark.asyncio + async def test_gemini_simple_completion(self): + """Test a simple completion via Gemini.""" + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{BASE_URL}/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "Say only the word 'GEMINI_OK' and nothing else."} + ], + "model": "gemini-flash", + "max_tokens": 20, + "provider": "gemini", + "temperature": 0, + } + ) + assert response.status_code == 200 + data = response.json() + assert data["provider"] == "gemini" + content = data["choices"][0]["message"]["content"] + print(f"✔ Gemini response: {content[:50]}") \ No newline at end of file From fe95d616a830a068ed8d2417b92d81358d8d0369 Mon Sep 17 00:00:00 2001 From: Faisal Date: Mon, 18 May 2026 13:44:36 +0530 Subject: [PATCH 8/8] Refactor provider initialization into shared helper table in main.py --- app/main.py | 338 +++++++++++++++------------------------------------- 1 file changed, 93 insertions(+), 245 deletions(-) diff --git a/app/main.py b/app/main.py index 1ef24b4..4c87cc5 100644 --- a/app/main.py +++ b/app/main.py @@ -63,6 +63,80 @@ scheduler: Optional[Scheduler] = None model_mapper: Optional[ModelMapper] = None +# Provider definition table — add new providers here only +PROVIDER_DEFINITIONS = [ + { + "name": "cerebras", + "class": CerebrasProvider, + "default_base_url": "https://api.cerebras.ai/v1", + "default_rpm": 1000, + "default_tpm": 1_000_000, + }, + { + "name": "nvidia", + "class": NvidiaProvider, + "default_base_url": "https://integrate.api.nvidia.com/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "groq", + "class": GroqProvider, + "default_base_url": "https://api.groq.com/openai/v1", + "default_rpm": 30, + "default_tpm": 6_000, + }, + { + "name": "gemini", + "class": GeminiProvider, + "default_base_url": "https://generativelanguage.googleapis.com/v1beta/openai", + "default_rpm": 15, + "default_tpm": 1_000_000, + }, + { + "name": "mistral", + "class": MistralProvider, + "default_base_url": "https://api.mistral.ai/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "openrouter", + "class": OpenRouterProvider, + "default_base_url": "https://openrouter.ai/api/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "deepseek", + "class": DeepSeekProvider, + "default_base_url": "https://api.deepseek.com/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, + { + "name": "huggingface", + "class": HuggingFaceProvider, + "default_base_url": "https://api-inference.huggingface.co/v1", + "default_rpm": 30, + "default_tpm": 50_000, + }, + { + "name": "cohere", + "class": CohereProvider, + "default_base_url": "https://api.cohere.com/compatibility/v1", + "default_rpm": 20, + "default_tpm": 100_000, + }, + { + "name": "sambanova", + "class": SambaNovaProvider, + "default_base_url": "https://api.sambanova.ai/v1", + "default_rpm": 60, + "default_tpm": 100_000, + }, +] + async def initialize_scheduler(config: dict) -> Scheduler: """Initialize the scheduler with providers from config.""" @@ -83,268 +157,45 @@ async def initialize_scheduler(config: dict) -> Scheduler: ) sched = Scheduler(strategy=strategy) - providers_config = config.get("providers", {}) - # Initialize Cerebras - cerebras_config = providers_config.get("cerebras", {}) - if cerebras_config.get("enabled", False) and cerebras_config.get("keys"): - provider = CerebrasProvider( - base_url=cerebras_config.get("base_url", "https://api.cerebras.ai/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(cerebras_config.get("keys", [])): - key_name = key_config.get("name", f"cerebras-key-{i+1}") - bucket = RateLimitBucket( - name=f"cerebras:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 1000), - tokens_per_minute=key_config.get("tokens_per_minute", 1_000_000), - ) - provider_key = ProviderKey( - provider_name="cerebras", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("Cerebras provider initialized", keys_count=len(provider.keys)) + # Initialize all providers from the definition table + for definition in PROVIDER_DEFINITIONS: + name = definition["name"] + provider_config = providers_config.get(name, {}) - # Initialize NVIDIA NIM - nvidia_config = providers_config.get("nvidia", {}) - if nvidia_config.get("enabled", False) and nvidia_config.get("keys"): - provider = NvidiaProvider( - base_url=nvidia_config.get("base_url", "https://integrate.api.nvidia.com/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(nvidia_config.get("keys", [])): - key_name = key_config.get("name", f"nvidia-key-{i+1}") - bucket = RateLimitBucket( - name=f"nvidia:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), - ) - provider_key = ProviderKey( - provider_name="nvidia", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("NVIDIA NIM provider initialized", keys_count=len(provider.keys)) + if not provider_config.get("enabled", False): + continue + if not provider_config.get("keys"): + continue - # Initialize Groq - groq_config = providers_config.get("groq", {}) - if groq_config.get("enabled", False) and groq_config.get("keys"): - provider = GroqProvider( - base_url=groq_config.get("base_url", "https://api.groq.com/openai/v1"), + provider = definition["class"]( + base_url=provider_config.get("base_url", definition["default_base_url"]), model_mapper=model_mapper, ) - for i, key_config in enumerate(groq_config.get("keys", [])): - key_name = key_config.get("name", f"groq-key-{i+1}") - bucket = RateLimitBucket( - name=f"groq:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 30), - tokens_per_minute=key_config.get("tokens_per_minute", 6000), - ) - provider_key = ProviderKey( - provider_name="groq", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("Groq provider initialized", keys_count=len(provider.keys)) - # Initialize Gemini - gemini_config = providers_config.get("gemini", {}) - if gemini_config.get("enabled", False) and gemini_config.get("keys"): - provider = GeminiProvider( - base_url=gemini_config.get("base_url", "https://generativelanguage.googleapis.com/v1beta/openai"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(gemini_config.get("keys", [])): - key_name = key_config.get("name", f"gemini-key-{i+1}") + for i, key_config in enumerate(provider_config.get("keys", [])): + key_name = key_config.get("name", f"{name}-key-{i+1}") bucket = RateLimitBucket( - name=f"gemini:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 15), - tokens_per_minute=key_config.get("tokens_per_minute", 1_000_000), + name=f"{name}:{key_name}", + requests_per_minute=key_config.get("requests_per_minute", definition["default_rpm"]), + tokens_per_minute=key_config.get("tokens_per_minute", definition["default_tpm"]), ) provider_key = ProviderKey( - provider_name="gemini", + provider_name=name, key_name=key_name, api_key=key_config["api_key"], bucket=bucket, base_url=provider.base_url, ) provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("Gemini provider initialized", keys_count=len(provider.keys)) - # Initialize Mistral - mistral_config = providers_config.get("mistral", {}) - if mistral_config.get("enabled", False) and mistral_config.get("keys"): - provider = MistralProvider( - base_url=mistral_config.get("base_url", "https://api.mistral.ai/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(mistral_config.get("keys", [])): - key_name = key_config.get("name", f"mistral-key-{i+1}") - bucket = RateLimitBucket( - name=f"mistral:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), - ) - provider_key = ProviderKey( - provider_name="mistral", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) if provider.keys: await sched.add_provider(provider) - await logger.ainfo("Mistral provider initialized", keys_count=len(provider.keys)) - - # Initialize OpenRouter - openrouter_config = providers_config.get("openrouter", {}) - if openrouter_config.get("enabled", False) and openrouter_config.get("keys"): - provider = OpenRouterProvider( - base_url=openrouter_config.get("base_url", "https://openrouter.ai/api/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(openrouter_config.get("keys", [])): - key_name = key_config.get("name", f"openrouter-key-{i+1}") - bucket = RateLimitBucket( - name=f"openrouter:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), + await logger.ainfo( + f"{name.upper()} provider initialized", + keys_count=len(provider.keys) ) - provider_key = ProviderKey( - provider_name="openrouter", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("OpenRouter provider initialized", keys_count=len(provider.keys)) - - # Initialize DeepSeek - deepseek_config = providers_config.get("deepseek", {}) - if deepseek_config.get("enabled", False) and deepseek_config.get("keys"): - provider = DeepSeekProvider( - base_url=deepseek_config.get("base_url", "https://api.deepseek.com/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(deepseek_config.get("keys", [])): - key_name = key_config.get("name", f"deepseek-key-{i+1}") - bucket = RateLimitBucket( - name=f"deepseek:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), - ) - provider_key = ProviderKey( - provider_name="deepseek", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("DeepSeek provider initialized", keys_count=len(provider.keys)) - - # Initialize HuggingFace - huggingface_config = providers_config.get("huggingface", {}) - if huggingface_config.get("enabled", False) and huggingface_config.get("keys"): - provider = HuggingFaceProvider( - base_url=huggingface_config.get("base_url", "https://api-inference.huggingface.co/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(huggingface_config.get("keys", [])): - key_name = key_config.get("name", f"huggingface-key-{i+1}") - bucket = RateLimitBucket( - name=f"huggingface:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 30), - tokens_per_minute=key_config.get("tokens_per_minute", 50_000), - ) - provider_key = ProviderKey( - provider_name="huggingface", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("HuggingFace provider initialized", keys_count=len(provider.keys)) - - # Initialize Cohere - cohere_config = providers_config.get("cohere", {}) - if cohere_config.get("enabled", False) and cohere_config.get("keys"): - provider = CohereProvider( - base_url=cohere_config.get("base_url", "https://api.cohere.com/compatibility/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(cohere_config.get("keys", [])): - key_name = key_config.get("name", f"cohere-key-{i+1}") - bucket = RateLimitBucket( - name=f"cohere:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 20), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), - ) - provider_key = ProviderKey( - provider_name="cohere", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("Cohere provider initialized", keys_count=len(provider.keys)) - - # Initialize SambaNova - sambanova_config = providers_config.get("sambanova", {}) - if sambanova_config.get("enabled", False) and sambanova_config.get("keys"): - provider = SambaNovaProvider( - base_url=sambanova_config.get("base_url", "https://api.sambanova.ai/v1"), - model_mapper=model_mapper, - ) - for i, key_config in enumerate(sambanova_config.get("keys", [])): - key_name = key_config.get("name", f"sambanova-key-{i+1}") - bucket = RateLimitBucket( - name=f"sambanova:{key_name}", - requests_per_minute=key_config.get("requests_per_minute", 60), - tokens_per_minute=key_config.get("tokens_per_minute", 100_000), - ) - provider_key = ProviderKey( - provider_name="sambanova", - key_name=key_name, - api_key=key_config["api_key"], - bucket=bucket, - base_url=provider.base_url, - ) - provider.add_key(provider_key) - if provider.keys: - await sched.add_provider(provider) - await logger.ainfo("SambaNova provider initialized", keys_count=len(provider.keys)) await sched.start() return sched @@ -357,11 +208,9 @@ async def lifespan(app: FastAPI): await logger.ainfo("Starting TrainForgeConductor...") - # Load configuration config = load_config(settings.config_path) await logger.ainfo("Configuration loaded", config_path=settings.config_path) - # Initialize scheduler scheduler = await initialize_scheduler(config) if not scheduler.providers: @@ -376,7 +225,6 @@ async def lifespan(app: FastAPI): yield - # Shutdown await logger.ainfo("Shutting down TrainForgeConductor...") if scheduler: await scheduler.stop()