diff --git a/.env.integration-test b/.env.integration-test
index 099acf2..63eb649 100644
--- a/.env.integration-test
+++ b/.env.integration-test
@@ -1,3 +1,3 @@
 AMRS_API_KEY=your_amrs_api_key_here
 OPENAI_API_KEY=your_openai_api_key_here
-FAKE_API_KEY=your_fake_api_key_here
+FAKER_API_KEY=your_faker_api_key_here
diff --git a/Cargo.lock b/Cargo.lock
index 6bb4ba0..35370be 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -770,6 +770,15 @@ version = "0.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"

+[[package]]
+name = "lock_api"
+version = "0.4.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965"
+dependencies = [
+ "scopeguard",
+]
+
 [[package]]
 name = "log"
 version = "0.4.29"
@@ -898,6 +907,29 @@ dependencies = [
  "vcpkg",
 ]

+[[package]]
+name = "parking_lot"
+version = "0.12.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a"
+dependencies = [
+ "lock_api",
+ "parking_lot_core",
+]
+
+[[package]]
+name = "parking_lot_core"
+version = "0.9.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "redox_syscall",
+ "smallvec",
+ "windows-link",
+]
+
 [[package]]
 name = "percent-encoding"
 version = "2.3.2"
@@ -1078,6 +1110,15 @@ dependencies = [
  "getrandom 0.3.4",
 ]

+[[package]]
+name = "redox_syscall"
+version = "0.5.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
+dependencies = [
+ "bitflags",
+]
+
 [[package]]
 name = "reqwest"
 version = "0.12.26"
@@ -1243,6 +1284,12 @@ dependencies = [
  "windows-sys 0.61.2",
 ]

+[[package]]
+name = "scopeguard"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
+
 [[package]]
 name = "secrecy"
 version = "0.10.3"
@@ -1350,6 +1397,16 @@ version = "1.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"

+[[package]]
+name = "signal-hook-registry"
+version = "1.4.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b"
+dependencies = [
+ "errno",
+ "libc",
+]
+
 [[package]]
 name = "slab"
 version = "0.4.11"
@@ -1529,7 +1586,9 @@ dependencies = [
  "bytes",
  "libc",
  "mio",
+ "parking_lot",
  "pin-project-lite",
+ "signal-hook-registry",
  "socket2",
  "tokio-macros",
  "windows-sys 0.61.2",
diff --git a/Cargo.toml b/Cargo.toml
index 563dd9f..d2cdedc 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,4 +12,4 @@ lazy_static = "1.5.0"
 rand = "0.9.2"
 reqwest = "0.12.26"
 serde = "1.0.228"
-tokio = "1.48.0"
+tokio = { version = "1.48.0", features = ["full"] }
diff --git a/README.md b/README.md
index a52ac5d..30485e4 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on
 - Flexible routing strategies, including:
   - **Random**: Randomly selects a model from the available models.
   - **WRR**: Weighted Round Robin selects models based on predefined weights.
-  - **UCB**: Upper Confidence Bound based model selection (coming soon).
+  - **UCB1**: Upper Confidence Bound based model selection (coming soon).
   - **Adaptive**: Dynamically selects models based on performance metrics (coming soon).
 - Broad provider support:
@@ -27,21 +27,22 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
 // Before running the code, make sure to set your OpenAI API key in the environment variable:
 // export OPENAI_API_KEY="your_openai_api_key"

-use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};
+use arms::client;
+use arms::types::responses;

-let config = Config::builder()
+let config = client::Config::builder()
     .provider("openai")
-    .routing_mode(RoutingMode::WRR)
+    .routing_mode(client::RoutingMode::WRR)
     .model(
-        ModelConfig::builder()
-            .id("gpt-3.5-turbo")
+        client::ModelConfig::builder()
+            .name("gpt-3.5-turbo")
             .weight(2)
             .build()
             .unwrap(),
     )
     .model(
-        ModelConfig::builder()
-            .id("gpt-4")
+        client::ModelConfig::builder()
+            .name("gpt-4")
             .weight(1)
             .build()
             .unwrap(),
@@ -49,8 +50,8 @@ let config = Config::builder()
     .build()
     .unwrap();

-let mut client = Client::new(config);
-let request = CreateResponseArgs::default()
+let mut client = client::Client::new(config);
+let request = responses::CreateResponseArgs::default()
     .input("give me a poem about nature")
     .build()
     .unwrap();
diff --git a/bindings/python/amrs/config.py b/bindings/python/amrs/config.py
index a011e6d..955e58e 100644
--- a/bindings/python/amrs/config.py
+++ b/bindings/python/amrs/config.py
@@ -47,10 +47,10 @@ class BasicModelConfig(BaseModel):
     )

-type ModelID = str
+type ModelName = str

 class ModelConfig(BasicModelConfig):
-    id: ModelID = Field(
+    id: ModelName = Field(
         description="ID of the model to be used."
     )
     weight: Optional[int] = Field(
diff --git a/bindings/python/amrs/router/random.py b/bindings/python/amrs/router/random.py
index 73c5f6f..337f0e1 100644
--- a/bindings/python/amrs/router/random.py
+++ b/bindings/python/amrs/router/random.py
@@ -1,11 +1,11 @@
 import random

-from amrs.config import ModelID
+from amrs.config import ModelName
 from amrs.router.router import Router

 class RandomRouter(Router):
-    def __init__(self, model_list: list[ModelID]):
+    def __init__(self, model_list: list[ModelName]):
         super().__init__(model_list)

-    def sample(self, _: str) -> ModelID:
+    def sample(self, _: str) -> ModelName:
         return random.choice(self._model_list)
diff --git a/bindings/python/amrs/router/router.py b/bindings/python/amrs/router/router.py
index 9266a3f..c8bceba 100644
--- a/bindings/python/amrs/router/router.py
+++ b/bindings/python/amrs/router/router.py
@@ -7,11 +7,11 @@ class ModelInfo:
     average_latency: float = 0.0

 class Router(abc.ABC):
-    def __init__(self, model_list: list[config.ModelID]):
+    def __init__(self, model_list: list[config.ModelName]):
         self._model_list = model_list

     @abc.abstractmethod
-    def sample(self, content: str) -> config.ModelID:
+    def sample(self, content: str) -> config.ModelName:
         pass

 def new_router(model_cfgs: list[config.ModelConfig], mode: config.RoutingMode) -> Router:
diff --git a/src/client/client.rs b/src/client/client.rs
index dca327c..2f628f6 100644
--- a/src/client/client.rs
+++ b/src/client/client.rs
@@ -1,11 +1,13 @@
 use std::collections::HashMap;

-use crate::config::{Config, ModelId};
+use crate::client::config::{Config, ModelName};
 use crate::provider::provider;
 use crate::router::router;
+use crate::types::error::OpenAIError;
+use crate::types::responses::{CreateResponse, Response};

 pub struct Client {
-    providers: HashMap<ModelId, Box<dyn provider::Provider>>,
+    providers: HashMap<ModelName, Box<dyn provider::Provider>>,
     router: Box<dyn router::Router>,
 }
@@ -17,7 +19,7 @@ impl Client {
         let providers = cfg
             .models
             .iter()
-            .map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
+            .map(|m| (m.name.clone(), provider::construct_provider(m.clone())))
             .collect();

         Self {
@@ -28,10 +30,10 @@ impl Client {
     pub async fn create_response(
         &mut self,
-        request: provider::CreateResponseReq,
-    ) -> Result<provider::CreateResponseRes, provider::APIError> {
-        let model_id = self.router.sample(&request);
-        let provider = self.providers.get(&model_id).unwrap();
+        request: CreateResponse,
+    ) -> Result<Response, OpenAIError> {
+        let candidate = self.router.sample(&request);
+        let provider = self.providers.get(&candidate).unwrap();
         provider.create_response(request).await
     }
 }
@@ -39,7 +41,7 @@ impl Client {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::config::{Config, ModelConfig, RoutingMode};
+    use crate::client::config::{Config, ModelConfig, RoutingMode};
     use dotenvy::from_filename;

     #[test]
@@ -58,7 +60,7 @@
             config: Config::builder()
                 .models(vec![
                     ModelConfig::builder()
-                        .id("model_c".to_string())
+                        .name("model_c".to_string())
                         .build()
                         .unwrap(),
                 ])
@@ -71,15 +73,15 @@
             config: Config::builder()
                 .routing_mode(RoutingMode::WRR)
                 .models(vec![
-                    crate::config::ModelConfig::builder()
-                        .id("model_a".to_string())
+                    crate::client::config::ModelConfig::builder()
+                        .name("model_a".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .weight(1)
                         .build()
                         .unwrap(),
-                    crate::config::ModelConfig::builder()
-                        .id("model_b".to_string())
+                    crate::client::config::ModelConfig::builder()
+                        .name("model_b".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .weight(3)
                         .build()
                         .unwrap(),
@@ -95,13 +97,13 @@
             config: Config::builder()
                 .models(vec![
                     ModelConfig::builder()
-                        .id("model_a".to_string())
+                        .name("model_a".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .build()
                         .unwrap(),
                     ModelConfig::builder()
-                        .id("model_b".to_string())
+                        .name("model_b".to_string())
                         .provider(Some("openai".to_string()))
                         .base_url(Some("https://api.openai.com/v1".to_string()))
                         .build()
diff --git a/src/config.rs b/src/client/config.rs
similarity index 80%
rename from src/config.rs
rename to src/client/config.rs
index a056407..933dd3f 100644
--- a/src/config.rs
+++ b/src/client/config.rs
@@ -5,8 +5,7 @@
 use derive_builder::Builder;
 use lazy_static::lazy_static;

 // ------------------ Provider ------------------
-pub type ProviderName = String;
-const OPENAI_PROVIDER: &str = "OPENAI";
+pub const DEFAULT_PROVIDER: &str = "OPENAI";

 lazy_static! {
     pub static ref PROVIDER_BASE_URLS: HashMap<&'static str, &'static str> = {
@@ -15,7 +14,7 @@ lazy_static! {
         m.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai");
         m.insert("OPENROUTER", "https://openrouter.ai/api/v1");

-        m.insert("FAKE", "http://localhost:8080"); // test only
+        m.insert("FAKER", "http://localhost:8080"); // test only
         // TODO: support more providers here...
         m
     };
@@ -29,7 +28,7 @@ pub enum RoutingMode {
 }

 // ------------------ Model Config ------------------
-pub type ModelId = String;
+pub type ModelName = String;

 #[derive(Debug, Clone, Builder)]
 #[builder(build_fn(validate = "Self::validate"), pattern = "mutable")]
@@ -38,21 +37,21 @@ pub struct ModelConfig {
     #[builder(default = "None")]
     pub(crate) base_url: Option<String>,
     #[builder(default = "None", setter(custom))]
-    pub(crate) provider: Option<ProviderName>,
+    pub(crate) provider: Option<String>,
     #[builder(default = "None")]
     pub(crate) temperature: Option<f32>,
     #[builder(default = "None")]
     pub(crate) max_output_tokens: Option<u32>,
     #[builder(setter(custom))]
-    pub(crate) id: ModelId,
+    pub(crate) name: ModelName,
     #[builder(default=-1)]
     pub(crate) weight: i32,
 }

 impl ModelConfigBuilder {
-    pub fn id<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
-        self.id = Some(name.as_ref().to_string());
+    pub fn name<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
+        self.name = Some(name.as_ref().to_string());
         self
     }

@@ -65,8 +64,8 @@ impl ModelConfigBuilder {
     }

     fn validate(&self) -> Result<(), String> {
-        if self.id.is_none() {
-            return Err("Model id must be provided.".to_string());
+        if self.name.is_none() {
+            return Err("Model name must be provided.".to_string());
         }
         Ok(())
     }
@@ -83,10 +82,10 @@ impl ModelConfig {
 #[builder(build_fn(validate = "Self::validate"), pattern = "mutable")]
 pub struct Config {
     // global configs for models, will be overridden by model-specific configs
-    #[builder(default = "https://api.openai.com/v1".to_string())]
-    pub(crate) base_url: String,
-    #[builder(default = "ProviderName::from(OPENAI_PROVIDER)", setter(custom))]
-    pub(crate) provider: ProviderName,
+    #[builder(default=None, setter(custom))]
+    pub(crate) base_url: Option<String>,
+    #[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
+    pub(crate) provider: String,
     #[builder(default = "0.8")]
     pub(crate) temperature: f32,
     #[builder(default = "1024")]
@@ -105,22 +104,34 @@ impl Config {
     // populate will fill in the missing model-specific configs with global configs.
     pub fn populate(&mut self) -> &mut Self {
+        let global_base_url = match self.base_url.is_some() {
+            true => self.base_url.clone(),
+            false => Some(
+                PROVIDER_BASE_URLS
+                    .get(self.provider.as_str())
+                    .unwrap()
+                    .to_string(),
+            ),
+        };
+
         for model in &mut self.models {
-            let model_url_exist = model.base_url.is_some();
+            if model.base_url.is_none() {
+                if model.provider.is_some() {
+                    model.base_url = Some(
+                        PROVIDER_BASE_URLS
+                            .get(model.provider.as_ref().unwrap().as_str())
+                            .unwrap()
+                            .to_string(),
+                    );
+                } else {
+                    model.base_url = global_base_url.clone();
+                }
+            }
             if model.provider.is_none() {
                 model.provider = Some(self.provider.clone());
             }
-            if !model_url_exist
-                && PROVIDER_BASE_URLS.contains_key(model.provider.as_ref().unwrap().as_str())
-            {
-                model.base_url =
-                    Some(PROVIDER_BASE_URLS[model.provider.as_ref().unwrap().as_str()].to_string());
-            }
-            if !model_url_exist {
-                model.base_url = Some(self.base_url.clone());
-            }
             if model.temperature.is_none() {
                 model.temperature = Some(self.temperature);
             }
@@ -133,6 +144,10 @@ impl Config {
 }

 impl ConfigBuilder {
+    pub fn base_url<S: AsRef<str>>(&mut self, url: S) -> &mut Self {
+        self.base_url = Some(Some(url.as_ref().to_string()));
+        self
+    }
     pub fn model(&mut self, model: ModelConfig) -> &mut Self {
         let mut models = self.models.clone().unwrap_or_default();
         models.push(model);
@@ -157,7 +172,7 @@
             {
                 return Err(format!(
                     "Model '{}' weight must be non-negative in Weighted routing mode.",
-                    model.id
+                    model.name
                 ));
             }

@@ -165,7 +180,7 @@
             if max_output_tokens <= 0 {
                 return Err(format!(
                     "Model '{}' max_output_tokens must be positive.",
-                    model.id
+                    model.name
                 ));
             }
         }
@@ -174,14 +189,14 @@
             if temperature < 0.0 || temperature > 1.0 {
                 return Err(format!(
                     "Model '{}' temperature must be between 0.0 and 1.0.",
-                    model.id
+                    model.name
                 ));
             }
         }

         // check the existence of API key in environment variables
         if let Some(provider) = &model.provider {
-            let env_var = format!("{}_API_KEY", provider.to_uppercase());
+            let env_var = format!("{}_API_KEY", provider);
             if env::var(&env_var).is_err() {
                 return Err(format!(
                     "API key for provider '{}' not found in environment variable '{}'",
@@ -195,7 +210,7 @@
             "{}_API_KEY",
             self.provider
                 .as_ref()
-                .unwrap_or(&ProviderName::from(OPENAI_PROVIDER))
+                .unwrap_or(&DEFAULT_PROVIDER.to_string())
                 .to_uppercase()
         );
         if env::var(&env_var).is_err() {
             return Err(format!(
                 "API key for provider '{}' not found in environment variable '{}'",
                 self.provider
                     .as_ref()
-                    .unwrap_or(&ProviderName::from(OPENAI_PROVIDER))
-                    .to_uppercase(),
+                    .unwrap_or(&DEFAULT_PROVIDER.to_string()),
                 env_var
             ));
         }
@@ -229,16 +243,14 @@ mod tests {
         let valid_simplest_models_cfg = Config::builder()
             .model(
                 ModelConfig::builder()
-                    .id("gpt-4".to_string())
+                    .name("gpt-4".to_string())
                     .build()
                     .unwrap(),
             )
             .build();
         assert!(valid_simplest_models_cfg.is_ok());
-        assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == OPENAI_PROVIDER);
-        assert!(
-            valid_simplest_models_cfg.as_ref().unwrap().base_url == "https://api.openai.com/v1"
-        );
+        assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
+        assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8);
         assert!(
             valid_simplest_models_cfg
@@ -259,11 +271,11 @@
         let valid_cfg = Config::builder()
             .models(vec![
                 ModelConfig::builder()
-                    .id("gpt-3.5-turbo".to_string())
.name("gpt-3.5-turbo".to_string()) .build() .unwrap(), ModelConfig::builder() - .id("gpt-4".to_string()) + .name("gpt-4".to_string()) .build() .unwrap(), ]) @@ -275,7 +287,7 @@ mod tests { let invalid_cfg_with_no_api_key = Config::builder() .model( ModelConfig::builder() - .id("some-model".to_string()) + .name("some-model".to_string()) .build() .unwrap(), ) @@ -286,11 +298,11 @@ mod tests { // case 4: // AMRS_API_KEY is set in .env.test already. let valid_cfg_with_customized_provider = Config::builder() - .base_url("http://example.ai".to_string()) + .base_url("http://example.ai") .max_output_tokens(2048) .model( ModelConfig::builder() - .id("custom-model") + .name("custom-model") .provider(Some("AMRS")) .build() .unwrap(), @@ -303,7 +315,7 @@ mod tests { assert!(invalid_empty_models_cfg.is_err()); // case 6: - print!("validating invalid empty model id config"); + print!("validating invalid empty model name config"); let invalid_empty_model_id_cfg = ModelConfig::builder().build(); assert!(invalid_empty_model_id_cfg.is_err()); } @@ -317,7 +329,7 @@ mod tests { .max_output_tokens(1500) .model( ModelConfig::builder() - .id("model-1".to_string()) + .name("model-1".to_string()) .build() .unwrap(), ) @@ -338,7 +350,7 @@ mod tests { let mut valid_specified_cfg = Config::builder() .provider("AMRS".to_string()) .base_url("http://custom-api.ai".to_string()) - .model(ModelConfig::builder().id("model-2").build().unwrap()) + .model(ModelConfig::builder().name("model-2").build().unwrap()) .build(); valid_specified_cfg.as_mut().unwrap().populate(); diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..96f5bea --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,5 @@ +pub mod client; +pub mod config; + +pub use client::Client; +pub use config::{Config, ModelConfig, ModelName, RoutingMode}; diff --git a/src/lib.rs b/src/lib.rs index 7027782..6fc90da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,18 @@ +pub mod client; + mod router { mod random; pub mod router; pub mod stats; mod wrr; } -mod config; -mod client { - pub mod client; -} + mod provider { - mod fake; + mod faker; mod openai; pub mod provider; } - -pub use crate::client::client::Client; -pub use crate::config::{Config, ModelConfig, RoutingMode}; -pub use crate::provider::provider::{ - APIError, CreateResponseArgs, CreateResponseReq, CreateResponseRes, -}; +pub mod types { + pub mod error; + pub mod responses; +} diff --git a/src/main.rs b/src/main.rs index e7a11a9..cf1447b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,52 @@ +use tokio::runtime::Runtime; + +use arms::client; +use arms::types::responses; + fn main() { - println!("Hello, world!"); + let config = client::Config::builder() + .provider("deepinfra") + .routing_mode(client::RoutingMode::WRR) + .model( + client::ModelConfig::builder() + .name("nvidia/Nemotron-3-Nano-30B-A3B") + .weight(1) + .build() + .unwrap(), + ) + .model( + client::ModelConfig::builder() + .name("deepseek-ai/DeepSeek-V3.2") + .weight(2) + .build() + .unwrap(), + ) + .build() + .unwrap(); + + let mut client = client::Client::new(config); + + let request = responses::CreateResponseArgs::default() + .input(responses::InputParam::Items(vec![ + responses::InputItem::EasyMessage(responses::EasyInputMessage { + r#type: responses::MessageType::Message, + role: responses::Role::User, + content: responses::EasyInputContent::Text("What is AGI?".to_string()), + }), + ])) + .build() + .unwrap(); + + let result = Runtime::new() + .unwrap() + .block_on(client.create_response(request)); + + 
+    match result {
+        Ok(response) => {
+            println!("Response ID: {}", response.id);
+        }
+        Err(e) => {
+            eprintln!("Error: {}", e);
+        }
+    }
 }
diff --git a/src/provider/fake.rs b/src/provider/faker.rs
similarity index 68%
rename from src/provider/fake.rs
rename to src/provider/faker.rs
index a38b7b7..614c7f0 100644
--- a/src/provider/fake.rs
+++ b/src/provider/faker.rs
@@ -1,43 +1,35 @@
-use std::str::FromStr;
-
-use async_openai::types::responses::{
-    AssistantRole, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
-    OutputTextContent, Status,
-};
-use async_openai::{Client, config::OpenAIConfig};
 use async_trait::async_trait;
-use reqwest::header::HeaderName;

-use crate::config::{ModelConfig, ModelId};
-use crate::provider::provider::{
-    APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
+use crate::client::config::{ModelConfig, ModelName};
+use crate::provider::provider;
+use crate::types::error::OpenAIError;
+use crate::types::responses::{
+    AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
+    OutputTextContent, Response, Status,
 };

-pub struct FakeProvider {
-    model: ModelId,
+pub struct FakerProvider {
+    model: ModelName,
 }

-impl FakeProvider {
+impl FakerProvider {
     pub fn new(config: ModelConfig) -> Self {
         Self {
-            model: config.id.clone(),
+            model: config.name.clone(),
         }
     }
 }

 #[async_trait]
-impl Provider for FakeProvider {
+impl provider::Provider for FakerProvider {
     fn name(&self) -> &'static str {
         "FakeProvider"
     }

-    async fn create_response(
-        &self,
-        request: CreateResponseReq,
-    ) -> Result<CreateResponseRes, APIError> {
-        validate_request(&request)?;
+    async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
+        provider::validate_responses_request(&request)?;

-        Ok(CreateResponseRes {
+        Ok(Response {
             id: "fake-response-id".to_string(),
             object: "text_completion".to_string(),
             model: self.model.clone(),
diff --git a/src/provider/openai.rs b/src/provider/openai.rs
index 753e591..5a97069 100644
--- a/src/provider/openai.rs
+++ b/src/provider/openai.rs
@@ -2,17 +2,19 @@ use async_openai::{Client, config::OpenAIConfig};
 use async_trait::async_trait;
 use derive_builder::Builder;

-use crate::config::{ModelConfig, ModelId};
-use crate::provider::provider::{
-    APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
-};
+use crate::client::config::{DEFAULT_PROVIDER, ModelConfig, ModelName};
+use crate::provider::provider;
+use crate::types::error::OpenAIError;
+use crate::types::responses::{CreateResponse, Response};

 #[derive(Debug, Clone, Builder)]
 #[builder(pattern = "mutable", build_fn(skip))]
 pub struct OpenAIProvider {
-    model: ModelId,
+    model: ModelName,
     config: OpenAIConfig,
     client: Client<OpenAIConfig>,
+    #[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
+    provider_name: String,
 }

 impl OpenAIProvider {
@@ -28,34 +30,48 @@ impl OpenAIProvider {
             .with_api_key(api_key);

         OpenAIProviderBuilder {
-            model: Some(config.id.clone()),
+            model: Some(config.name.clone()),
             config: Some(openai_config),
             client: None,
+            provider_name: None,
         }
     }
 }

 impl OpenAIProviderBuilder {
+    pub fn provider_name<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
+        self.provider_name = Some(name.as_ref().to_string());
+        self
+    }
+
     pub fn build(&mut self) -> OpenAIProvider {
         OpenAIProvider {
             model: self.model.clone().unwrap(),
             config: self.config.clone().unwrap(),
             client: Client::with_config(self.config.as_ref().unwrap().clone()),
+            provider_name: self
+                .provider_name
+                .clone()
+                .unwrap_or(DEFAULT_PROVIDER.to_string()),
         }
     }
 }

 #[async_trait]
-impl Provider for OpenAIProvider {
+impl provider::Provider for OpenAIProvider {
     fn name(&self) -> &'static str {
         "OpenAIProvider"
     }

-    async fn create_response(
-        &self,
-        request: CreateResponseReq,
-    ) -> Result<CreateResponseRes, APIError> {
-        validate_request(&request)?;
+    async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
+        if self.provider_name == "DEEPINFRA" {
+            return Err(OpenAIError::InvalidArgument(format!(
+                "Provider '{}' doesn't support Responses endpoint",
+                self.provider_name
+            )));
+        }
+
+        provider::validate_responses_request(&request)?;

         self.client.responses().create(request).await
     }
 }
diff --git a/src/provider/provider.rs b/src/provider/provider.rs
index 6ffe305..e760a1f 100644
--- a/src/provider/provider.rs
+++ b/src/provider/provider.rs
@@ -1,23 +1,21 @@
-use async_openai::error::OpenAIError as OpenAI_Error;
-use async_openai::types::responses::{
-    CreateResponse, CreateResponseArgs as OpenAICreateResponseArgs, Response,
-};
 use async_trait::async_trait;

-use crate::config::ModelConfig;
-use crate::provider::fake::FakeProvider;
+use crate::client::config::ModelConfig;
+use crate::provider::faker::FakerProvider;
 use crate::provider::openai::OpenAIProvider;
-
-pub type CreateResponseReq = CreateResponse;
-pub type CreateResponseArgs = OpenAICreateResponseArgs;
-pub type CreateResponseRes = Response;
-pub type APIError = OpenAI_Error;
+use crate::types::error::OpenAIError;
+use crate::types::responses::{CreateResponse, Response};

 pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> {
-    let provider = config.provider.as_ref().unwrap();
+    let provider = config.provider.clone().unwrap();
+
     match provider.to_uppercase().as_ref() {
-        "FAKE" => Box::new(FakeProvider::new(config)),
-        "OPENAI" => Box::new(OpenAIProvider::builder(config).build()),
+        "FAKER" => Box::new(FakerProvider::new(config)),
+        "OPENAI" | "DEEPINFRA" => Box::new(
+            OpenAIProvider::builder(config)
+                .provider_name(provider)
+                .build(),
+        ),
         _ => panic!("Unsupported provider: {}", provider),
     }
 }
@@ -25,16 +23,13 @@ pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> {
 #[async_trait]
 pub trait Provider: Send + Sync {
     fn name(&self) -> &'static str;
-    async fn create_response(
-        &self,
-        request: CreateResponseReq,
-    ) -> Result<CreateResponseRes, APIError>;
+    async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError>;
 }

-pub fn validate_request(request: &CreateResponseReq) -> Result<(), APIError> {
+pub fn validate_responses_request(request: &CreateResponse) -> Result<(), OpenAIError> {
     if request.model.is_some() {
-        return Err(APIError::InvalidArgument(
-            "Model ID must be specified in the config".to_string(),
+        return Err(OpenAIError::InvalidArgument(
+            "Model must be specified in the client.Config".to_string(),
         ));
     }
     Ok(())
@@ -56,7 +51,7 @@
             TestCase {
                 name: "OpenAI Provider",
                 config: ModelConfig::builder()
-                    .id("test-model".to_string())
+                    .name("test-model".to_string())
                     .provider(Some("openai".to_string()))
                     .base_url(Some("https://api.openai.com/v1".to_string()))
                     .build()
@@ -66,7 +61,7 @@
             TestCase {
                 name: "Unsupported Provider",
                 config: ModelConfig::builder()
-                    .id("test-model".to_string())
+                    .name("test-model".to_string())
                     .provider(Some("unsupported".to_string()))
                     .base_url(Some("https://api.openai.com/v1".to_string()))
                     .build()
diff --git a/src/router/random.rs b/src/router/random.rs
index 1a5591b..c7f6283 100644
--- a/src/router/random.rs
+++ b/src/router/random.rs
@@ -1,8 +1,8 @@
 use rand::Rng;

-use crate::config::ModelId;
-use crate::provider::provider::CreateResponseReq;
+use crate::client::config::ModelName;
 use crate::router::router::{ModelInfo, Router};
+use crate::types::responses::CreateResponse;

 pub struct RandomRouter {
     pub model_infos: Vec<ModelInfo>,
 }
@@ -19,10 +19,10 @@ impl Router for RandomRouter {
         "RandomRouter"
     }

-    fn sample(&mut self, _input: &CreateResponseReq) -> ModelId {
+    fn sample(&mut self, _input: &CreateResponse) -> ModelName {
         let mut rng = rand::rng();
         let idx = rng.random_range(0..self.model_infos.len());
-        self.model_infos[idx].id.clone()
+        self.model_infos[idx].name.clone()
     }
 }

@@ -34,15 +34,15 @@ mod tests {
     fn test_random_router_sampling() {
         let model_infos = vec![
             ModelInfo {
-                id: "model_x".to_string(),
+                name: "model_x".to_string(),
                 weight: 1,
             },
             ModelInfo {
-                id: "model_y".to_string(),
+                name: "model_y".to_string(),
                 weight: 2,
             },
             ModelInfo {
-                id: "model_z".to_string(),
+                name: "model_z".to_string(),
                 weight: 3,
             },
         ];
@@ -50,8 +50,8 @@
         let mut counts = std::collections::HashMap::new();

         for _ in 0..1000 {
-            let sampled_id = router.sample(&CreateResponseReq::default());
-            *counts.entry(sampled_id.clone()).or_insert(0) += 1;
+            let candidate = router.sample(&CreateResponse::default());
+            *counts.entry(candidate.clone()).or_insert(0) += 1;
         }

         assert!(counts.len() == model_infos.len());
         for count in counts.values() {
diff --git a/src/router/router.rs b/src/router/router.rs
index 0647cec..3182984 100644
--- a/src/router/router.rs
+++ b/src/router/router.rs
@@ -1,14 +1,11 @@
-use std::collections::HashMap;
-use std::sync::atomic::AtomicUsize;
-
-use crate::config::{ModelConfig, ModelId, RoutingMode};
-use crate::provider::provider::CreateResponseReq;
+use crate::client::config::{ModelConfig, ModelName, RoutingMode};
 use crate::router::random::RandomRouter;
 use crate::router::wrr::WeightedRoundRobinRouter;
+use crate::types::responses::CreateResponse;

 #[derive(Debug, Clone)]
 pub struct ModelInfo {
-    pub id: ModelId,
+    pub name: ModelName,
     pub weight: i32,
 }

@@ -16,7 +13,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
     let model_infos: Vec<ModelInfo> = models
         .iter()
         .map(|m| ModelInfo {
-            id: m.id.clone(),
+            name: m.name.clone(),
             weight: m.weight.clone(),
         })
         .collect();
@@ -28,7 +25,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
 pub trait Router: Send + Sync {
     fn name(&self) -> &'static str;
-    fn sample(&mut self, input: &CreateResponseReq) -> ModelId;
+    fn sample(&mut self, input: &CreateResponse) -> ModelName;
 }

 #[cfg(test)]
@@ -38,13 +35,13 @@ mod tests {
     fn test_router_construction() {
         let model_configs = vec![
             ModelConfig::builder()
-                .id("model_a".to_string())
+                .name("model_a".to_string())
                 .provider(Some("openai".to_string()))
                 .base_url(Some("https://api.openai.com/v1".to_string()))
                 .build()
                 .unwrap(),
             ModelConfig::builder()
-                .id("model_b".to_string())
+                .name("model_b".to_string())
                 .provider(Some("openai".to_string()))
                 .base_url(Some("https://api.openai.com/v1".to_string()))
                 .build()
diff --git a/src/router/stats.rs b/src/router/stats.rs
index 7b4251a..f586d0f 100644
--- a/src/router/stats.rs
+++ b/src/router/stats.rs
@@ -1,10 +1,10 @@
 use std::collections::HashMap;
 use std::sync::atomic::{AtomicUsize, Ordering};

-use crate::config::ModelId;
+use crate::client::config::ModelName;

 pub struct RouterStats {
-    requests_per_model: HashMap<ModelId, AtomicUsize>,
+    requests_per_model: HashMap<ModelName, AtomicUsize>,
 }

 impl RouterStats {
@@ -14,7 +14,7 @@ impl RouterStats {
         }
     }

-    pub fn increment_request(&mut self, model_id: &ModelId) -> usize {
+    pub fn increment_request(&mut self, model_id: &ModelName) -> usize {
         let counter = self
             .requests_per_model
             .entry(model_id.clone())
diff --git a/src/router/wrr.rs b/src/router/wrr.rs
index fa5e481..591def3 100644
--- a/src/router/wrr.rs
+++ b/src/router/wrr.rs
@@ -1,5 +1,6 @@
+use crate::client::config::ModelName;
 use crate::router::router::{ModelInfo, Router};
-use crate::{config::ModelId, provider::provider::CreateResponseReq};
+use crate::types::responses::CreateResponse;

 pub struct WeightedRoundRobinRouter {
     total_weight: i32,
@@ -27,10 +28,10 @@ impl Router for WeightedRoundRobinRouter {
     }

     // Use Smooth Weighted Round Robin Algorithm.
-    fn sample(&mut self, _input: &CreateResponseReq) -> ModelId {
+    fn sample(&mut self, _input: &CreateResponse) -> ModelName {
         // return early if only one model.
         if self.model_infos.len() == 1 {
-            return self.model_infos[0].id.clone();
+            return self.model_infos[0].name.clone();
         }

         self.current_weights
@@ -48,7 +49,7 @@
         }

         self.current_weights[max_index] -= self.total_weight;
-        self.model_infos[max_index].id.clone()
+        self.model_infos[max_index].name.clone()
     }
 }

@@ -61,22 +62,22 @@
     fn test_weighted_round_robin_sampling() {
         let model_infos = vec![
             ModelInfo {
-                id: "model_x".to_string(),
+                name: "model_x".to_string(),
                 weight: 1,
             },
             ModelInfo {
-                id: "model_y".to_string(),
+                name: "model_y".to_string(),
                 weight: 3,
             },
             ModelInfo {
-                id: "model_z".to_string(),
+                name: "model_z".to_string(),
                 weight: 6,
             },
         ];
         let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone());
         let mut counts = HashMap::new();
         for _ in 0..1000 {
-            let sampled_id = wrr.sample(&CreateResponseReq::default());
+            let sampled_id = wrr.sample(&CreateResponse::default());
             *counts.entry(sampled_id.clone()).or_insert(0) += 1;
         }
         assert!(counts.len() == model_infos.len());
diff --git a/src/types/error.rs b/src/types/error.rs
new file mode 100644
index 0000000..b3eeac8
--- /dev/null
+++ b/src/types/error.rs
@@ -0,0 +1 @@
+pub use async_openai::error::*;
diff --git a/src/types/responses.rs b/src/types/responses.rs
new file mode 100644
index 0000000..3756fa2
--- /dev/null
+++ b/src/types/responses.rs
@@ -0,0 +1 @@
+pub use async_openai::types::responses::*;
diff --git a/tests/client.rs b/tests/client.rs
index ecdea8c..732113e 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -1,6 +1,7 @@
 use dotenvy::from_filename;

-use arms::{Client, Config, CreateResponseArgs, ModelConfig, RoutingMode};
+use arms::client;
+use arms::types::responses;

 #[cfg(test)]
 mod tests {
@@ -11,14 +12,19 @@ mod tests {
         from_filename(".env.integration-test").ok();

         // case 1: one model.
-        let config = Config::builder()
-            .provider("fake")
-            .model(ModelConfig::builder().id("fake-model").build().unwrap())
+        let config = client::Config::builder()
+            .provider("faker")
+            .model(
+                client::ModelConfig::builder()
+                    .name("fake-model")
+                    .build()
+                    .unwrap(),
+            )
             .build()
             .unwrap();
-        let mut client = Client::new(config);
-        let request = CreateResponseArgs::default()
+        let mut client = client::Client::new(config);
+        let request = responses::CreateResponseArgs::default()
             .input("tell me the weather today")
             .build()
             .unwrap();
@@ -28,13 +34,18 @@
         assert!(response.model == "fake-model");

         // case 2: specify model in request.
-        let config = Config::builder()
+        let config = client::Config::builder()
             .provider("openai")
-            .model(ModelConfig::builder().id("gpt-3.5-turbo").build().unwrap())
+            .model(
+                client::ModelConfig::builder()
+                    .name("gpt-3.5-turbo")
+                    .build()
+                    .unwrap(),
+            )
             .build()
             .unwrap();
-        let mut client = Client::new(config);
-        let request = CreateResponseArgs::default()
+        let mut client = client::Client::new(config);
+        let request = responses::CreateResponseArgs::default()
             .model("gpt-3.5-turbo")
             .input("tell me a joke")
             .build()
@@ -43,27 +54,27 @@
         assert!(response.is_err());

         // case 3: multiple models with router.
-        let config = Config::builder()
-            .provider("fake")
-            .routing_mode(RoutingMode::WRR)
+        let config = client::Config::builder()
+            .provider("faker")
+            .routing_mode(client::RoutingMode::WRR)
             .model(
-                ModelConfig::builder()
-                    .id("gpt-3.5-turbo")
+                client::ModelConfig::builder()
+                    .name("gpt-3.5-turbo")
                     .weight(1)
                     .build()
                     .unwrap(),
             )
             .model(
-                ModelConfig::builder()
-                    .id("gpt-4")
+                client::ModelConfig::builder()
+                    .name("gpt-4")
                     .weight(1)
                     .build()
                     .unwrap(),
             )
             .build()
             .unwrap();
-        let mut client = Client::new(config);
-        let request = CreateResponseArgs::default()
+        let mut client = client::Client::new(config);
+        let request = responses::CreateResponseArgs::default()
             .input("give me a poem about nature")
             .build()
             .unwrap();
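Reviewer note: the renamed `src/router/wrr.rs` keeps the smooth weighted round robin selection intact apart from the `id` → `name` rename. For readers new to the algorithm, here is a minimal, self-contained sketch of the same technique; `Entry` and `pick` are illustrative names, not the crate's API, and the tie-breaking here (last maximum) may differ slightly from the router's scan:

```rust
// Smooth Weighted Round Robin, the technique used by WeightedRoundRobinRouter:
// every pick bumps each entry by its static weight, serves the entry with the
// highest running score, then docks the winner by the total weight.
struct Entry {
    name: &'static str,
    weight: i32,         // static weight from the model config
    current_weight: i32, // running score adjusted on every pick
}

fn pick(entries: &mut [Entry], total_weight: i32) -> &'static str {
    for e in entries.iter_mut() {
        e.current_weight += e.weight;
    }
    let winner = entries
        .iter_mut()
        .max_by_key(|e| e.current_weight)
        .expect("at least one model");
    winner.current_weight -= total_weight;
    winner.name
}

fn main() {
    // Mirrors the 1:3:6 weights in the wrr.rs test.
    let mut entries = vec![
        Entry { name: "model_x", weight: 1, current_weight: 0 },
        Entry { name: "model_y", weight: 3, current_weight: 0 },
        Entry { name: "model_z", weight: 6, current_weight: 0 },
    ];
    let total: i32 = entries.iter().map(|e| e.weight).sum();
    for _ in 0..10 {
        print!("{} ", pick(&mut entries, total));
    }
    println!();
}
```

Compared with naive weighted selection, the running score spreads a heavy model's turns across the cycle instead of serving them back to back, which is the interleaving the 1000-draw distribution test in `wrr.rs` relies on.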
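The rewritten `Config::populate` also changes base-URL resolution: a model's explicit `base_url` wins, then the model's own provider default from `PROVIDER_BASE_URLS`, then the global `base_url`, and finally the global provider's default (the old code let the global URL overwrite provider defaults). A sketch of that precedence under the same assumptions; `resolve_base_url` and `lookup` are hypothetical helpers, not functions in this crate:

```rust
use std::collections::HashMap;

// Hypothetical helper mirroring the PROVIDER_BASE_URLS lookup.
fn lookup(table: &HashMap<&str, &str>, provider: &str) -> Option<String> {
    table.get(provider).map(|s| s.to_string())
}

// Precedence: model URL > model-provider default > global URL > global-provider default.
// Unlike populate(), a missing provider yields None here instead of panicking.
fn resolve_base_url(
    model_url: Option<String>,
    model_provider: Option<&str>,
    global_url: Option<String>,
    global_provider: &str,
    table: &HashMap<&str, &str>,
) -> Option<String> {
    model_url
        .or_else(|| model_provider.and_then(|p| lookup(table, p)))
        .or(global_url)
        .or_else(|| lookup(table, global_provider))
}

fn main() {
    let mut table = HashMap::new();
    table.insert("OPENAI", "https://api.openai.com/v1");
    table.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai");

    // A model that names DEEPINFRA gets that provider's default URL,
    // even when the global config points at OPENAI.
    let url = resolve_base_url(None, Some("DEEPINFRA"), None, "OPENAI", &table);
    assert_eq!(url.as_deref(), Some("https://api.deepinfra.com/v1/openai"));
    println!("resolved: {:?}", url);
}
```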