diff --git a/Cargo.toml b/Cargo.toml
index d2cdedc..8d3c6e2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,7 @@
 version = "0.1.0"
 edition = "2024"
 
 [dependencies]
-async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] }
+async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "chat-completion"] }
 async-trait = "0.1.89"
 derive_builder = "0.20.2"
 dotenvy = "0.15.7"
diff --git a/README.md b/README.md
index 30485e4..f57774b 100644
--- a/README.md
+++ b/README.md
@@ -27,8 +27,9 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
 // Before running the code, make sure to set your OpenAI API key in the environment variable:
 // export OPENAI_API_KEY="your_openai_api_key"
 
+use tokio::runtime::Runtime;
 use arms::client;
-use arms::types::responses;
+use arms::types::chat;
 
 let config = client::Config::builder()
     .provider("openai")
@@ -51,12 +52,15 @@
     .unwrap();
 
 let mut client = client::Client::new(config);
-let request = responses::CreateResponseArgs::default()
-    .input("give me a poem about nature")
+let request = chat::CreateChatCompletionRequestArgs::default()
+    .messages([
+        chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
+        chat::ChatCompletionRequestUserMessage::from("How is the weather today?").into(),
+    ])
     .build()
     .unwrap();
 
-let response = client.create_response(request).await.unwrap();
+let result = Runtime::new().unwrap().block_on(client.create_completion(request));
 ```
 
 ## Contributing
diff --git a/src/client/client.rs b/src/client/client.rs
index 2f628f6..d5762c7 100644
--- a/src/client/client.rs
+++ b/src/client/client.rs
@@ -4,7 +4,7 @@
 use crate::client::config::{Config, ModelName};
 use crate::provider::provider;
 use crate::router::router;
 use crate::types::error::OpenAIError;
-use crate::types::responses::{CreateResponse, Response};
+use crate::types::{chat, responses};
 
 pub struct Client {
     providers: HashMap<ModelName, Box<dyn provider::Provider>>,
@@ -30,12 +30,22 @@ impl Client {
     pub async fn create_response(
         &mut self,
-        request: CreateResponse,
-    ) -> Result<Response, OpenAIError> {
-        let candidate = self.router.sample(&request);
+        request: responses::CreateResponse,
+    ) -> Result<responses::Response, OpenAIError> {
+        let candidate = self.router.sample();
         let provider = self.providers.get(&candidate).unwrap();
         provider.create_response(request).await
     }
+
+    // Chat Completions endpoint.
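+    // Mirrors create_response: the router samples a candidate model and the
+    // request is dispatched to the provider that owns that model.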
+    pub async fn create_completion(
+        &mut self,
+        request: chat::CreateChatCompletionRequest,
+    ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
+        let candidate = self.router.sample();
+        let provider = self.providers.get(&candidate).unwrap();
+        provider.create_completion(request).await
+    }
 }
 
 #[cfg(test)]
diff --git a/src/client/config.rs b/src/client/config.rs
index 933dd3f..f50de68 100644
--- a/src/client/config.rs
+++ b/src/client/config.rs
@@ -38,10 +38,6 @@ pub struct ModelConfig {
     pub(crate) base_url: Option<String>,
     #[builder(default = "None", setter(custom))]
     pub(crate) provider: Option<String>,
-    #[builder(default = "None")]
-    pub(crate) temperature: Option<f32>,
-    #[builder(default = "None")]
-    pub(crate) max_output_tokens: Option<usize>,
     #[builder(setter(custom))]
     pub(crate) name: ModelName,
@@ -86,10 +82,6 @@ pub struct Config {
     pub(crate) base_url: Option<String>,
     #[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
     pub(crate) provider: String,
-    #[builder(default = "0.8")]
-    pub(crate) temperature: f32,
-    #[builder(default = "1024")]
-    pub(crate) max_output_tokens: usize,
     #[builder(default = "RoutingMode::Random")]
     pub(crate) routing_mode: RoutingMode,
@@ -131,13 +123,6 @@ impl Config {
             if model.provider.is_none() {
                 model.provider = Some(self.provider.clone());
             }
-
-            if model.temperature.is_none() {
-                model.temperature = Some(self.temperature);
-            }
-            if model.max_output_tokens.is_none() {
-                model.max_output_tokens = Some(self.max_output_tokens);
-            }
         }
         self
     }
@@ -176,24 +161,6 @@
                 ));
             }
 
-            if let Some(max_output_tokens) = model.max_output_tokens {
-                if max_output_tokens <= 0 {
-                    return Err(format!(
-                        "Model '{}' max_output_tokens must be positive.",
-                        model.name
-                    ));
-                }
-            }
-
-            if let Some(temperature) = model.temperature {
-                if temperature < 0.0 || temperature > 1.0 {
-                    return Err(format!(
-                        "Model '{}' temperature must be between 0.0 and 1.0.",
-                        model.name
-                    ));
-                }
-            }
-
             // check the existence of API key in environment variables
             if let Some(provider) = &model.provider {
                 let env_var = format!("{}_API_KEY", provider);
@@ -251,20 +218,10 @@ mod tests {
         assert!(valid_simplest_models_cfg.is_ok());
         assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
-        assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8);
-        assert!(
-            valid_simplest_models_cfg
-                .as_ref()
-                .unwrap()
-                .max_output_tokens
-                == 1024
-        );
         assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None);
-        assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None);
-        assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None);
         assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1);
 
         // case 2:
@@ -299,7 +256,6 @@
         // AMRS_API_KEY is set in .env.test already.
         let valid_cfg_with_customized_provider = Config::builder()
             .base_url("http://example.ai")
-            .max_output_tokens(2048)
             .model(
                 ModelConfig::builder()
                     .name("custom-model")
@@ -325,8 +281,6 @@
         from_filename(".env.test").ok();
 
         let mut valid_cfg = Config::builder()
-            .temperature(0.5)
-            .max_output_tokens(1500)
             .model(
                 ModelConfig::builder()
                     .name("model-1".to_string())
@@ -338,8 +292,6 @@
         assert!(valid_cfg.is_ok());
         assert!(valid_cfg.as_ref().unwrap().models.len() == 1);
-        assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5));
-        assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500));
         assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string()));
         assert!(
             valid_cfg.as_ref().unwrap().models[0].base_url
diff --git a/src/lib.rs b/src/lib.rs
index 6fc90da..d985e2e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,11 +8,13 @@ mod router {
 }
 
 mod provider {
+    mod common;
     mod faker;
     mod openai;
     pub mod provider;
 }
 
 pub mod types {
+    pub mod chat;
     pub mod error;
     pub mod responses;
 }
diff --git a/src/main.rs b/src/main.rs
index cf1447b..8667a88 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,9 +1,10 @@
 use tokio::runtime::Runtime;
 
 use arms::client;
-use arms::types::responses;
+use arms::types::chat;
 
 fn main() {
+    // case 1: chat completion with the DeepInfra provider.
     let config = client::Config::builder()
         .provider("deepinfra")
         .routing_mode(client::RoutingMode::WRR)
@@ -26,27 +27,77 @@ fn main() {
     let mut client = client::Client::new(config);
 
-    let request = responses::CreateResponseArgs::default()
-        .input(responses::InputParam::Items(vec![
-            responses::InputItem::EasyMessage(responses::EasyInputMessage {
-                r#type: responses::MessageType::Message,
-                role: responses::Role::User,
-                content: responses::EasyInputContent::Text("What is AGI?".to_string()),
-            }),
-        ]))
+    let request = chat::CreateChatCompletionRequestArgs::default()
+        .messages([
+            chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
+            chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?")
+                .into(),
+            chat::ChatCompletionRequestAssistantMessage::from(
+                "The Los Angeles Dodgers won the World Series in 2020.",
+            )
+            .into(),
+            chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(),
+        ])
         .build()
         .unwrap();
 
     let result = Runtime::new()
         .unwrap()
-        .block_on(client.create_response(request));
+        .block_on(client.create_completion(request));
 
     match result {
         Ok(response) => {
-            println!("Response ID: {}", response.id);
+            println!("Response: {:?}", response);
         }
         Err(e) => {
             eprintln!("Error: {}", e);
         }
     }
+
+    // case 2: Responses endpoint with the DeepInfra provider.
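+    // NOTE: this stays commented out because "DEEPINFRA" is not listed in
+    // provider::RESPONSE_ENDPOINT_PROVIDERS, so create_response would return
+    // an InvalidArgument error for it.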
+    // let config = client::Config::builder()
+    //     .provider("deepinfra")
+    //     .routing_mode(client::RoutingMode::WRR)
+    //     .model(
+    //         client::ModelConfig::builder()
+    //             .name("nvidia/Nemotron-3-Nano-30B-A3B")
+    //             .weight(1)
+    //             .build()
+    //             .unwrap(),
+    //     )
+    //     .model(
+    //         client::ModelConfig::builder()
+    //             .name("deepseek-ai/DeepSeek-V3.2")
+    //             .weight(2)
+    //             .build()
+    //             .unwrap(),
+    //     )
+    //     .build()
+    //     .unwrap();
+
+    // let mut client = client::Client::new(config);
+
+    // let request = responses::CreateResponseArgs::default()
+    //     .input(responses::InputParam::Items(vec![
+    //         responses::InputItem::EasyMessage(responses::EasyInputMessage {
+    //             r#type: responses::MessageType::Message,
+    //             role: responses::Role::User,
+    //             content: responses::EasyInputContent::Text("What is AGI?".to_string()),
+    //         }),
+    //     ]))
+    //     .build()
+    //     .unwrap();
+
+    // let result = Runtime::new()
+    //     .unwrap()
+    //     .block_on(client.create_response(request));
+
+    // match result {
+    //     Ok(response) => {
+    //         println!("Response ID: {}", response.id);
+    //     }
+    //     Err(e) => {
+    //         eprintln!("Error: {}", e);
+    //     }
+    // }
 }
diff --git a/src/provider/common.rs b/src/provider/common.rs
new file mode 100644
index 0000000..3b041ef
--- /dev/null
+++ b/src/provider/common.rs
@@ -0,0 +1,22 @@
+use crate::types::error::OpenAIError;
+use crate::types::{chat, responses};
+
+pub fn validate_completion_request(
+    request: &chat::CreateChatCompletionRequest,
+) -> Result<(), OpenAIError> {
+    if !request.model.is_empty() {
+        return Err(OpenAIError::InvalidArgument(
+            "Model must be set via the client Config, not on the request.".to_string(),
+        ));
+    }
+    Ok(())
+}
+
+pub fn validate_response_request(request: &responses::CreateResponse) -> Result<(), OpenAIError> {
+    if request.model.is_some() {
+        return Err(OpenAIError::InvalidArgument(
+            "Model must be set via the client Config, not on the request.".to_string(),
+        ));
+    }
+    Ok(())
+}
diff --git a/src/provider/faker.rs b/src/provider/faker.rs
index 614c7f0..e38edd0 100644
--- a/src/provider/faker.rs
+++ b/src/provider/faker.rs
@@ -1,7 +1,8 @@
 use async_trait::async_trait;
 
 use crate::client::config::{ModelConfig, ModelName};
-use crate::provider::provider;
+use crate::provider::{common, provider};
+use crate::types::chat;
 use crate::types::error::OpenAIError;
 use crate::types::responses::{
     AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
@@ -26,8 +27,8 @@
         "FakeProvider"
     }
 
     async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
-        provider::validate_responses_request(&request)?;
+        common::validate_response_request(&request)?;
 
         Ok(Response {
             id: "fake-response-id".to_string(),
@@ -71,4 +72,34 @@
             truncation: None,
         })
     }
+
+    async fn create_completion(
+        &self,
+        request: chat::CreateChatCompletionRequest,
+    ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
+        common::validate_completion_request(&request)?;
+        Ok(chat::CreateChatCompletionResponse {
+            id: "fake-completion-id".to_string(),
+            object: "chat.completion".to_string(),
+            created: 1_600_000_000,
+            model: self.model.clone(),
+            usage: None,
+            service_tier: None,
+            choices: vec![chat::ChatChoice {
+                index: 0,
+                message: chat::ChatCompletionResponseMessage {
+                    role: chat::Role::Assistant,
+                    content: Some("This is a fake chat completion.".to_string()),
+                    refusal: None,
+                    tool_calls: None,
+                    annotations: None,
+                    function_call: None,
+                    audio: None,
+                },
+                finish_reason: None,
+                logprobs: None,
+            }],
+            system_fingerprint: None,
+        })
+    }
 }
diff --git a/src/provider/openai.rs b/src/provider/openai.rs
index 5a97069..d89853e 100644
--- a/src/provider/openai.rs
+++ b/src/provider/openai.rs
@@ -3,9 +3,9 @@
 use async_trait::async_trait;
 use derive_builder::Builder;
 
 use crate::client::config::{DEFAULT_PROVIDER, ModelConfig, ModelName};
-use crate::provider::provider;
+use crate::provider::{common, provider};
 use crate::types::error::OpenAIError;
-use crate::types::responses::{CreateResponse, Response};
+use crate::types::{chat, responses};
 
 #[derive(Debug, Clone, Builder)]
 #[builder(pattern = "mutable", build_fn(skip))]
@@ -63,15 +63,34 @@
         "OpenAIProvider"
     }
 
-    async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
-        if self.provider_name == "DEEPINFRA" {
+    async fn create_completion(
+        &self,
+        request: chat::CreateChatCompletionRequest,
+    ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
+        common::validate_completion_request(&request)?;
+
+        // Set the model after validation, since the model is bound to the provider.
+        let mut req = request;
+        req.model = self.model.clone();
+        self.client.chat().create(req).await
+    }
+
+    async fn create_response(
+        &self,
+        request: responses::CreateResponse,
+    ) -> Result<responses::Response, OpenAIError> {
+        if !provider::RESPONSE_ENDPOINT_PROVIDERS.contains(&self.provider_name.as_str()) {
             return Err(OpenAIError::InvalidArgument(format!(
                 "Provider '{}' doesn't support Responses endpoint",
                 self.provider_name
             )));
         }
-        provider::validate_responses_request(&request)?;
-        self.client.responses().create(request).await
+        common::validate_response_request(&request)?;
+
+        // Set the model after validation, since the model is bound to the provider.
+        let mut req = request;
+        req.model = Some(self.model.clone());
+        self.client.responses().create(req).await
     }
 }
diff --git a/src/provider/provider.rs b/src/provider/provider.rs
index e760a1f..2bf83d2 100644
--- a/src/provider/provider.rs
+++ b/src/provider/provider.rs
@@ -4,7 +4,10 @@
 use crate::client::config::ModelConfig;
 use crate::provider::faker::FakerProvider;
 use crate::provider::openai::OpenAIProvider;
 use crate::types::error::OpenAIError;
-use crate::types::responses::{CreateResponse, Response};
+use crate::types::{chat, responses};
+
+// Not all providers support the Responses endpoint.
+pub const RESPONSE_ENDPOINT_PROVIDERS: &[&str] = &["FAKER", "OPENAI"];
 
 pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> {
     let provider = config.provider.clone().unwrap();
@@ -22,17 +25,16 @@
 
 #[async_trait]
 pub trait Provider: Send + Sync {
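+    // Both endpoint methods expect the request's model to be unset; each
+    // provider injects its own configured model after validation.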
+    // name() is currently used only in tests.
     fn name(&self) -> &'static str;
-    async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError>;
-}
-
-pub fn validate_responses_request(request: &CreateResponse) -> Result<(), OpenAIError> {
-    if request.model.is_some() {
-        return Err(OpenAIError::InvalidArgument(
-            "Model must be specified in the client.Config".to_string(),
-        ));
-    }
-    Ok(())
+    async fn create_response(
+        &self,
+        request: responses::CreateResponse,
+    ) -> Result<responses::Response, OpenAIError>;
+    async fn create_completion(
+        &self,
+        request: chat::CreateChatCompletionRequest,
+    ) -> Result<chat::CreateChatCompletionResponse, OpenAIError>;
 }
 
 #[cfg(test)]
diff --git a/src/router/random.rs b/src/router/random.rs
index c7f6283..1ca2221 100644
--- a/src/router/random.rs
+++ b/src/router/random.rs
@@ -2,7 +2,6 @@
 use rand::Rng;
 
 use crate::client::config::ModelName;
 use crate::router::router::{ModelInfo, Router};
-use crate::types::responses::CreateResponse;
 
 pub struct RandomRouter {
     pub model_infos: Vec<ModelInfo>,
@@ -19,7 +18,7 @@
         "RandomRouter"
     }
 
-    fn sample(&mut self, _input: &CreateResponse) -> ModelName {
+    fn sample(&mut self) -> ModelName {
         let mut rng = rand::rng();
         let idx = rng.random_range(0..self.model_infos.len());
         self.model_infos[idx].name.clone()
@@ -50,7 +49,7 @@
         let mut counts = std::collections::HashMap::new();
 
         for _ in 0..1000 {
-            let candidate = router.sample(&CreateResponse::default());
+            let candidate = router.sample();
             *counts.entry(candidate.clone()).or_insert(0) += 1;
         }
         assert!(counts.len() == model_infos.len());
diff --git a/src/router/router.rs b/src/router/router.rs
index 3182984..6532a07 100644
--- a/src/router/router.rs
+++ b/src/router/router.rs
@@ -1,7 +1,6 @@
 use crate::client::config::{ModelConfig, ModelName, RoutingMode};
 use crate::router::random::RandomRouter;
 use crate::router::wrr::WeightedRoundRobinRouter;
-use crate::types::responses::CreateResponse;
 
 #[derive(Debug, Clone)]
 pub struct ModelInfo {
@@ -25,7 +24,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> {
 
 pub trait Router: Send + Sync {
     fn name(&self) -> &'static str;
-    fn sample(&mut self, input: &CreateResponse) -> ModelName;
+    fn sample(&mut self) -> ModelName;
 }
 
 #[cfg(test)]
diff --git a/src/router/wrr.rs b/src/router/wrr.rs
index 591def3..cf9c416 100644
--- a/src/router/wrr.rs
+++ b/src/router/wrr.rs
@@ -1,6 +1,5 @@
 use crate::client::config::ModelName;
 use crate::router::router::{ModelInfo, Router};
-use crate::types::responses::CreateResponse;
 
 pub struct WeightedRoundRobinRouter {
     total_weight: i32,
@@ -28,7 +27,7 @@
     }
 
     // Use Smooth Weighted Round Robin Algorithm.
-    fn sample(&mut self, _input: &CreateResponse) -> ModelName {
+    fn sample(&mut self) -> ModelName {
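+        // Smooth WRR (as in nginx): each call adds every model's weight to a
+        // running counter, picks the model with the largest counter, then
+        // subtracts the total weight from the winner, so heavier models are
+        // picked proportionally more often without bursting.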
         // return early if only one model.
         if self.model_infos.len() == 1 {
             return self.model_infos[0].name.clone();
         }
@@ -77,7 +76,7 @@
         let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone());
         let mut counts = HashMap::new();
         for _ in 0..1000 {
-            let sampled_id = wrr.sample(&CreateResponse::default());
+            let sampled_id = wrr.sample();
             *counts.entry(sampled_id.clone()).or_insert(0) += 1;
         }
         assert!(counts.len() == model_infos.len());
diff --git a/src/types/chat.rs b/src/types/chat.rs
new file mode 100644
index 0000000..e92a02e
--- /dev/null
+++ b/src/types/chat.rs
@@ -0,0 +1 @@
+pub use async_openai::types::chat::*;
diff --git a/tests/client.rs b/tests/client.rs
index 732113e..bf9a41d 100644
--- a/tests/client.rs
+++ b/tests/client.rs
@@ -1,6 +1,7 @@
 use dotenvy::from_filename;
 
 use arms::client;
+use arms::types::chat;
 use arms::types::responses;
 
 #[cfg(test)]
@@ -80,4 +81,29 @@
             .unwrap();
         let _ = client.create_response(request).await.unwrap();
     }
+
+    #[tokio::test]
+    async fn test_completion() {
+        from_filename(".env.integration-test").ok();
+
+        let config = client::Config::builder()
+            .provider("faker")
+            .model(
+                client::ModelConfig::builder()
+                    .name("fake-completion-model")
+                    .build()
+                    .unwrap(),
+            )
+            .build()
+            .unwrap();
+
+        let mut client = client::Client::new(config);
+        let request = chat::CreateChatCompletionRequestArgs::default()
+            .build()
+            .unwrap();
+
+        let response = client.create_completion(request).await.unwrap();
+        assert!(response.id.starts_with("fake-completion-id"));
+        assert!(response.model == "fake-completion-model");
+    }
 }