From aa09c1a4fb09c2ef5677bf44f3be133a98cc4584 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 27 Dec 2025 10:18:28 +0800 Subject: [PATCH 1/3] Support completion Signed-off-by: kerthcet --- Cargo.toml | 2 +- src/client/client.rs | 17 ++++++++--- src/client/config.rs | 48 ------------------------------- src/lib.rs | 2 ++ src/main.rs | 62 +++++++++++++++++++++++++++++++++------- src/provider/common.rs | 22 ++++++++++++++ src/provider/faker.rs | 30 +++++++++++++++++-- src/provider/openai.rs | 31 ++++++++++++++++---- src/provider/provider.rs | 24 +++++++++------- src/router/random.rs | 5 ++-- src/router/router.rs | 2 +- src/router/wrr.rs | 4 +-- src/types/completions.rs | 1 + tests/client.rs | 26 +++++++++++++++++ 14 files changed, 187 insertions(+), 89 deletions(-) create mode 100644 src/provider/common.rs create mode 100644 src/types/completions.rs diff --git a/Cargo.toml b/Cargo.toml index d2cdedc..7958d55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] } +async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "completions"] } async-trait = "0.1.89" derive_builder = "0.20.2" dotenvy = "0.15.7" diff --git a/src/client/client.rs b/src/client/client.rs index 2f628f6..ad431a2 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -4,7 +4,7 @@ use crate::client::config::{Config, ModelName}; use crate::provider::provider; use crate::router::router; use crate::types::error::OpenAIError; -use crate::types::responses::{CreateResponse, Response}; +use crate::types::{completions, responses}; pub struct Client { providers: HashMap<ModelName, Box<dyn provider::Provider>>, @@ -30,12 +30,21 @@ impl Client { pub async fn create_response( &mut self, - request: CreateResponse, - ) -> Result<Response, OpenAIError> { - let candidate = self.router.sample(&request); + request: responses::CreateResponse, + ) -> Result<responses::Response, OpenAIError> { + let candidate = self.router.sample(); let provider = self.providers.get(&candidate).unwrap(); provider.create_response(request).await } + + pub async fn create_completion( + &mut self, + request: completions::CreateCompletionRequest, + ) -> Result<completions::CreateCompletionResponse, OpenAIError> { + let candidate = self.router.sample(); + let provider = self.providers.get(&candidate).unwrap(); + provider.create_completion(request).await + } } #[cfg(test)] diff --git a/src/client/config.rs b/src/client/config.rs index 933dd3f..f50de68 100644 --- a/src/client/config.rs +++ b/src/client/config.rs @@ -38,10 +38,6 @@ pub struct ModelConfig { pub(crate) base_url: Option<String>, #[builder(default = "None", setter(custom))] pub(crate) provider: Option<String>, - #[builder(default = "None")] - pub(crate) temperature: Option<f32>, - #[builder(default = "None")] - pub(crate) max_output_tokens: Option<usize>, #[builder(setter(custom))] pub(crate) name: ModelName, @@ -86,10 +82,6 @@ pub struct Config { pub(crate) base_url: Option<String>, #[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))] pub(crate) provider: String, - #[builder(default = "0.8")] - pub(crate) temperature: f32, - #[builder(default = "1024")] - pub(crate) max_output_tokens: usize, #[builder(default = "RoutingMode::Random")] pub(crate) routing_mode: RoutingMode, @@ -131,13 +123,6 @@ impl Config { if model.provider.is_none() { model.provider = Some(self.provider.clone()); } - - if model.temperature.is_none() { - model.temperature = Some(self.temperature); - } - if model.max_output_tokens.is_none() { - model.max_output_tokens = Some(self.max_output_tokens); - } } self } @@ -176,24
+161,6 @@ impl ConfigBuilder { )); } - if let Some(max_output_tokens) = model.max_output_tokens { - if max_output_tokens <= 0 { - return Err(format!( - "Model '{}' max_output_tokens must be positive.", - model.name - )); - } - } - - if let Some(temperature) = model.temperature { - if temperature < 0.0 || temperature > 1.0 { - return Err(format!( - "Model '{}' temperature must be between 0.0 and 1.0.", - model.name - )); - } - } - // check the existence of API key in environment variables if let Some(provider) = &model.provider { let env_var = format!("{}_API_KEY", provider); @@ -251,20 +218,10 @@ mod tests { assert!(valid_simplest_models_cfg.is_ok()); assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER); assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None); - assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8); - assert!( - valid_simplest_models_cfg - .as_ref() - .unwrap() - .max_output_tokens - == 1024 - ); assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random); assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1); assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None); assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None); - assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None); - assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None); assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1); // case 2: @@ -299,7 +256,6 @@ mod tests { // AMRS_API_KEY is set in .env.test already. let valid_cfg_with_customized_provider = Config::builder() .base_url("http://example.ai") - .max_output_tokens(2048) .model( ModelConfig::builder() .name("custom-model") @@ -325,8 +281,6 @@ mod tests { from_filename(".env.test").ok(); let mut valid_cfg = Config::builder() - .temperature(0.5) - .max_output_tokens(1500) .model( ModelConfig::builder() .name("model-1".to_string()) @@ -338,8 +292,6 @@ mod tests { assert!(valid_cfg.is_ok()); assert!(valid_cfg.as_ref().unwrap().models.len() == 1); - assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5)); - assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500)); assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string())); assert!( valid_cfg.as_ref().unwrap().models[0].base_url diff --git a/src/lib.rs b/src/lib.rs index 6fc90da..9e65426 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,11 +8,13 @@ mod router { } mod provider { + mod common; mod faker; mod openai; pub mod provider; } pub mod types { + pub mod completions; pub mod error; pub mod responses; } diff --git a/src/main.rs b/src/main.rs index cf1447b..21dd104 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,10 @@ use tokio::runtime::Runtime; use arms::client; -use arms::types::responses; +use arms::types::{completions, responses}; fn main() { + // case 1: completion with DeepInfra provider. 
let config = client::Config::builder() .provider("deepinfra") .routing_mode(client::RoutingMode::WRR) .model( client::ModelConfig::builder() .name("nvidia/Nemotron-3-Nano-30B-A3B") .weight(1) .build() .unwrap(), ) @@ -26,20 +27,14 @@ fn main() { let mut client = client::Client::new(config); - let request = responses::CreateResponseArgs::default() - .input(responses::InputParam::Items(vec![ - responses::InputItem::EasyMessage(responses::EasyInputMessage { - r#type: responses::MessageType::Message, - role: responses::Role::User, - content: responses::EasyInputContent::Text("What is AGI?".to_string()), - }), - ])) + let request = completions::CreateCompletionRequestArgs::default() + .prompt("How to achieve AGI?") .build() .unwrap(); let result = Runtime::new() .unwrap() - .block_on(client.create_response(request)); + .block_on(client.create_completion(request)); match result { Ok(response) => { @@ -49,4 +44,51 @@ fn main() { eprintln!("Error: {}", e); } } + + // case 2: response with DeepInfra provider. + // let config = client::Config::builder() + // .provider("deepinfra") + // .routing_mode(client::RoutingMode::WRR) + // .model( + // client::ModelConfig::builder() + // .name("nvidia/Nemotron-3-Nano-30B-A3B") + // .weight(1) + // .build() + // .unwrap(), + // ) + // .model( + // client::ModelConfig::builder() + // .name("deepseek-ai/DeepSeek-V3.2") + // .weight(2) + // .build() + // .unwrap(), + // ) + // .build() + // .unwrap(); + + // let mut client = client::Client::new(config); + + // let request = responses::CreateResponseArgs::default() + // .input(responses::InputParam::Items(vec![ + // responses::InputItem::EasyMessage(responses::EasyInputMessage { + // r#type: responses::MessageType::Message, + // role: responses::Role::User, + // content: responses::EasyInputContent::Text("What is AGI?".to_string()), + // }), + // ])) + // .build() + // .unwrap(); + + // let result = Runtime::new() + // .unwrap() + // .block_on(client.create_response(request)); + + // match result { + // Ok(response) => { + // println!("Response ID: {}", response.id); + // } + // Err(e) => { + // eprintln!("Error: {}", e); + // } + // } } diff --git a/src/provider/common.rs b/src/provider/common.rs new file mode 100644 index 0000000..d69aad9 --- /dev/null +++ b/src/provider/common.rs @@ -0,0 +1,22 @@ +use crate::types::error::OpenAIError; +use crate::types::{completions, responses}; + +pub fn validate_completion_request( + request: &completions::CreateCompletionRequest, +) -> Result<(), OpenAIError> { + if request.model != "" { + return Err(OpenAIError::InvalidArgument( + "Model must not be set on the request; it is configured via client.Config".to_string(), + )); + } + Ok(()) +} + +pub fn validate_response_request(request: &responses::CreateResponse) -> Result<(), OpenAIError> { + if request.model.is_some() { + return Err(OpenAIError::InvalidArgument( + "Model must not be set on the request; it is configured via client.Config".to_string(), + )); + } + Ok(()) +} diff --git a/src/provider/faker.rs b/src/provider/faker.rs index 614c7f0..9a7546c 100644 --- a/src/provider/faker.rs +++ b/src/provider/faker.rs @@ -1,7 +1,9 @@ +use async_openai::types::chat::Choice; use async_trait::async_trait; use crate::client::config::{ModelConfig, ModelName}; -use crate::provider::provider; +use crate::provider::{common, provider}; +use crate::types::completions::{CreateCompletionRequest, CreateCompletionResponse}; use crate::types::error::OpenAIError; use crate::types::responses::{ AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus, @@ -26,8 +28,8 @@ impl provider::Provider for FakerProvider { "FakeProvider" } - async fn create_response(&self, request:
CreateResponse) -> Result<Response, OpenAIError> { - provider::validate_responses_request(&request)?; + async fn create_response(&self, _request: CreateResponse) -> Result<Response, OpenAIError> { + common::validate_response_request(&_request)?; Ok(Response { id: "fake-response-id".to_string(), @@ -71,4 +73,26 @@ impl provider::Provider for FakerProvider { truncation: None, }) } + + async fn create_completion( + &self, + _request: CreateCompletionRequest, + ) -> Result<CreateCompletionResponse, OpenAIError> { + common::validate_completion_request(&_request)?; + + Ok(CreateCompletionResponse { + id: "fake-completion-id".to_string(), + object: "text_completion".to_string(), + created: 1_600_000_000, + model: self.model.clone(), + choices: vec![Choice { + index: 0, + text: "This is a fake completion.".to_string(), + logprobs: None, + finish_reason: None, + }], + usage: None, + system_fingerprint: None, + }) + } } diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 5a97069..3f917aa 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -3,9 +3,9 @@ use async_trait::async_trait; use derive_builder::Builder; use crate::client::config::{DEFAULT_PROVIDER, ModelConfig, ModelName}; -use crate::provider::provider; +use crate::provider::{common, provider}; use crate::types::error::OpenAIError; -use crate::types::responses::{CreateResponse, Response}; +use crate::types::{completions, responses}; #[derive(Debug, Clone, Builder)] #[builder(pattern = "mutable", build_fn(skip))] @@ -63,15 +63,34 @@ impl provider::Provider for OpenAIProvider { "OpenAIProvider" } - async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> { - if self.provider_name == "DEEPINFRA" { + async fn create_completion( + &self, + request: completions::CreateCompletionRequest, + ) -> Result<completions::CreateCompletionResponse, OpenAIError> { + common::validate_completion_request(&request)?; + + // Set the model after validation since the model is bound to the provider. + let mut req = request.clone(); + req.model = self.model.clone(); + self.client.completions().create(req).await + } + + async fn create_response( + &self, + request: responses::CreateResponse, + ) -> Result<responses::Response, OpenAIError> { + if !provider::RESPONSE_ENDPOINT_PROVIDERS.contains(&self.provider_name.as_str()) { return Err(OpenAIError::InvalidArgument(format!( "Provider '{}' doesn't support the Responses endpoint", self.provider_name ))); } - provider::validate_responses_request(&request)?; - self.client.responses().create(request).await + common::validate_response_request(&request)?; + + // Set the model after validation since the model is bound to the provider. + let mut req = request.clone(); + req.model = Some(self.model.clone()); + self.client.responses().create(req).await } } diff --git a/src/provider/provider.rs b/src/provider/provider.rs index e760a1f..32242d5 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -4,7 +4,10 @@ use crate::client::config::ModelConfig; use crate::provider::faker::FakerProvider; use crate::provider::openai::OpenAIProvider; use crate::types::error::OpenAIError; -use crate::types::responses::{CreateResponse, Response}; +use crate::types::{completions, responses}; + +// Not all providers support the Responses endpoint. +pub const RESPONSE_ENDPOINT_PROVIDERS: &[&str] = &["FAKER", "OPENAI"]; pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> { let provider = config.provider.clone().unwrap(); @@ -22,17 +25,16 @@ pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> { #[async_trait] pub trait Provider: Send + Sync { + // Currently used only in tests.
fn name(&self) -> &'static str; - async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError>; -} - -pub fn validate_responses_request(request: &CreateResponse) -> Result<(), OpenAIError> { - if request.model.is_some() { - return Err(OpenAIError::InvalidArgument( - "Model must be specified in the client.Config".to_string(), - )); - } - Ok(()) + async fn create_response( + &self, + request: responses::CreateResponse, + ) -> Result<responses::Response, OpenAIError>; + async fn create_completion( + &self, + request: completions::CreateCompletionRequest, + ) -> Result<completions::CreateCompletionResponse, OpenAIError>; } #[cfg(test)] diff --git a/src/router/random.rs b/src/router/random.rs index c7f6283..1ca2221 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -2,7 +2,6 @@ use rand::Rng; use crate::client::config::ModelName; use crate::router::router::{ModelInfo, Router}; -use crate::types::responses::CreateResponse; pub struct RandomRouter { pub model_infos: Vec<ModelInfo>, } @@ -19,7 +18,7 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&mut self, _input: &CreateResponse) -> ModelName { + fn sample(&mut self) -> ModelName { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_infos.len()); self.model_infos[idx].name.clone() @@ -50,7 +49,7 @@ mod tests { let mut counts = std::collections::HashMap::new(); for _ in 0..1000 { - let candidate = router.sample(&CreateResponse::default()); + let candidate = router.sample(); *counts.entry(candidate.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len()); diff --git a/src/router/router.rs b/src/router/router.rs index 3182984..6561781 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -25,7 +25,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn pub trait Router: Send + Sync { fn name(&self) -> &'static str; - fn sample(&mut self, input: &CreateResponse) -> ModelName; + fn sample(&mut self) -> ModelName; } #[cfg(test)] diff --git a/src/router/wrr.rs b/src/router/wrr.rs index 591def3..b9a375f 100644 --- a/src/router/wrr.rs +++ b/src/router/wrr.rs @@ -28,7 +28,7 @@ impl Router for WeightedRoundRobinRouter { } // Use the Smooth Weighted Round Robin algorithm. - fn sample(&mut self, _input: &CreateResponse) -> ModelName { + fn sample(&mut self) -> ModelName { // return early if only one model.
if self.model_infos.len() == 1 { return self.model_infos[0].name.clone(); @@ -77,7 +77,7 @@ mod tests { let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone()); let mut counts = HashMap::new(); for _ in 0..1000 { - let sampled_id = wrr.sample(&CreateResponse::default()); + let sampled_id = wrr.sample(); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len()); diff --git a/src/types/completions.rs b/src/types/completions.rs new file mode 100644 index 0000000..4c54932 --- /dev/null +++ b/src/types/completions.rs @@ -0,0 +1 @@ +pub use async_openai::types::completions::*; diff --git a/tests/client.rs b/tests/client.rs index 732113e..2f56cc0 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,6 +1,7 @@ use dotenvy::from_filename; use arms::client; +use arms::types::completions; use arms::types::responses; #[cfg(test)] @@ -80,4 +81,29 @@ mod tests { .unwrap(); let _ = client.create_response(request).await.unwrap(); } + + #[tokio::test] + async fn test_completion() { + from_filename(".env.integration-test").ok(); + + let config = client::Config::builder() + .provider("faker") + .model( + client::ModelConfig::builder() + .name("fake-completion-model") + .build() + .unwrap(), + ) + .build() + .unwrap(); + + let mut client = client::Client::new(config); + let request = completions::CreateCompletionRequestArgs::default() + .build() + .unwrap(); + + let response = client.create_completion(request).await.unwrap(); + assert!(response.id.starts_with("fake-completion-id")); + assert!(response.model == "fake-completion-model"); + } } From f257c4a2837c55906bf726fd76d4fb6788aa0360 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 27 Dec 2025 10:59:30 +0800 Subject: [PATCH 2/3] Replace completion with chat completion Signed-off-by: kerthcet --- Cargo.toml | 2 +- README.md | 12 ++++++++---- src/client/client.rs | 7 ++++--- src/lib.rs | 2 +- src/main.rs | 17 +++++++++++++---- src/provider/common.rs | 4 ++-- src/provider/faker.rs | 29 ++++++++++++++++++----------- src/provider/openai.rs | 8 ++++---- src/provider/provider.rs | 6 +++--- src/router/router.rs | 1 - src/router/wrr.rs | 1 - src/types/chat.rs | 1 + src/types/completions.rs | 1 - tests/client.rs | 4 ++-- 14 files changed, 57 insertions(+), 38 deletions(-) create mode 100644 src/types/chat.rs delete mode 100644 src/types/completions.rs diff --git a/Cargo.toml b/Cargo.toml index 7958d55..8d3c6e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] -async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "completions"] } +async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "chat-completion"] } async-trait = "0.1.89" derive_builder = "0.20.2" dotenvy = "0.15.7" diff --git a/README.md b/README.md index 30485e4..15931ec 100644 --- a/README.md +++ b/README.md @@ -27,8 +27,9 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode: // Before running the code, make sure to set your OpenAI API key in the environment variable: // export OPENAI_API_KEY="your_openai_api_key" +use tokio::runtime::Runtime; use arms::client; -use arms::types::responses; +use arms::types::chat; let config = client::Config::builder() .provider("openai") @@ -51,12 +52,15 @@ let config = client::Config::builder() .unwrap(); let mut client = client::Client::new(config); -let request = responses::CreateResponseArgs::default() - .input("give me a poem about nature") +let request = 
chat::CreateChatCompletionRequestArgs::default() + .messages([ + chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?").into(), + ]) .build() .unwrap(); -let response = client.create_response(request).await.unwrap(); +let result = Runtime::new().unwrap().block_on(client.create_completion(request)); ``` ## Contributing diff --git a/src/client/client.rs b/src/client/client.rs index ad431a2..d5762c7 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -4,7 +4,7 @@ use crate::client::config::{Config, ModelName}; use crate::provider::provider; use crate::router::router; use crate::types::error::OpenAIError; -use crate::types::{completions, responses}; +use crate::types::{chat, responses}; pub struct Client { providers: HashMap<ModelName, Box<dyn provider::Provider>>, @@ -37,10 +37,11 @@ impl Client { provider.create_response(request).await } + // This is the chat completion endpoint. pub async fn create_completion( &mut self, - request: completions::CreateCompletionRequest, - ) -> Result<completions::CreateCompletionResponse, OpenAIError> { + request: chat::CreateChatCompletionRequest, + ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> { let candidate = self.router.sample(); let provider = self.providers.get(&candidate).unwrap(); provider.create_completion(request).await diff --git a/src/lib.rs b/src/lib.rs index 9e65426..d985e2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,7 @@ mod provider { pub mod provider; } pub mod types { - pub mod completions; + pub mod chat; pub mod error; pub mod responses; } diff --git a/src/main.rs b/src/main.rs index 21dd104..80c2e3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use tokio::runtime::Runtime; use arms::client; -use arms::types::{completions, responses}; +use arms::types::chat; fn main() { // case 1: completion with DeepInfra provider.
@@ -27,8 +27,17 @@ fn main() { let mut client = client::Client::new(config); - let request = completions::CreateCompletionRequestArgs::default() - .prompt("How to achieve AGI?") + let request = chat::CreateChatCompletionRequestArgs::default() + .messages([ + chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?") + .into(), + // chat::ChatCompletionRequestAssistantMessage::from( + // "The Los Angeles Dodgers won the World Series in 2020.", + // ) + // .into(), + // chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(), + ]) .build() .unwrap(); @@ -38,7 +47,7 @@ fn main() { match result { Ok(response) => { - println!("Response ID: {}", response.id); + println!("Response: {:?}", response); } Err(e) => { eprintln!("Error: {}", e); diff --git a/src/provider/common.rs b/src/provider/common.rs index d69aad9..3b041ef 100644 --- a/src/provider/common.rs +++ b/src/provider/common.rs @@ -1,8 +1,8 @@ use crate::types::error::OpenAIError; -use crate::types::{completions, responses}; +use crate::types::{chat, responses}; pub fn validate_completion_request( - request: &completions::CreateCompletionRequest, + request: &chat::CreateChatCompletionRequest, ) -> Result<(), OpenAIError> { if request.model != "" { return Err(OpenAIError::InvalidArgument( diff --git a/src/provider/faker.rs b/src/provider/faker.rs index 9a7546c..e38edd0 100644 --- a/src/provider/faker.rs +++ b/src/provider/faker.rs @@ -1,9 +1,8 @@ -use async_openai::types::chat::Choice; use async_trait::async_trait; use crate::client::config::{ModelConfig, ModelName}; use crate::provider::{common, provider}; -use crate::types::completions::{CreateCompletionRequest, CreateCompletionResponse}; +use crate::types::chat; use crate::types::error::OpenAIError; use crate::types::responses::{ AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus, @@ -76,22 +75,30 @@ impl provider::Provider for FakerProvider { async fn create_completion( &self, - _request: CreateCompletionRequest, - ) -> Result<CreateCompletionResponse, OpenAIError> { - common::validate_completion_request(&_request)?; - - Ok(CreateCompletionResponse { + request: chat::CreateChatCompletionRequest, + ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> { + common::validate_completion_request(&request)?; + Ok(chat::CreateChatCompletionResponse { id: "fake-completion-id".to_string(), - object: "text_completion".to_string(), + object: "chat.completion".to_string(), created: 1_600_000_000, model: self.model.clone(), - choices: vec![Choice { + usage: None, + service_tier: None, + choices: vec![chat::ChatChoice { index: 0, - text: "This is a fake completion.".to_string(), - logprobs: None, + message: chat::ChatCompletionResponseMessage { + role: chat::Role::Assistant, + content: Some("This is a fake chat completion.".to_string()), + refusal: None, + tool_calls: None, + annotations: None, + function_call: None, + audio: None, + }, finish_reason: None, + logprobs: None, }], - usage: None, system_fingerprint: None, }) } diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 3f917aa..d89853e 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -5,7 +5,7 @@ use derive_builder::Builder; use crate::client::config::{DEFAULT_PROVIDER, ModelConfig, ModelName}; use crate::provider::{common, provider}; use crate::types::error::OpenAIError; -use crate::types::{completions, responses}; +use crate::types::{chat, responses}; #[derive(Debug, Clone, Builder)] #[builder(pattern = "mutable", build_fn(skip))] @@ -65,14 +65,34 @@ impl
provider::Provider for OpenAIProvider { async fn create_completion( &self, - request: completions::CreateCompletionRequest, - ) -> Result<completions::CreateCompletionResponse, OpenAIError> { + request: chat::CreateChatCompletionRequest, + ) -> Result<chat::CreateChatCompletionResponse, OpenAIError> { common::validate_completion_request(&request)?; // Set the model after validation since the model is bound to the provider. let mut req = request.clone(); req.model = self.model.clone(); - self.client.completions().create(req).await + self.client.chat().create(req).await } async fn create_response( diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 32242d5..2bf83d2 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -4,7 +4,7 @@ use crate::client::config::ModelConfig; use crate::provider::faker::FakerProvider; use crate::provider::openai::OpenAIProvider; use crate::types::error::OpenAIError; -use crate::types::{completions, responses}; +use crate::types::{chat, responses}; // Not all providers support the Responses endpoint. pub const RESPONSE_ENDPOINT_PROVIDERS: &[&str] = &["FAKER", "OPENAI"]; @@ -33,8 +33,8 @@ pub trait Provider: Send + Sync { ) -> Result<responses::Response, OpenAIError>; async fn create_completion( &self, - request: completions::CreateCompletionRequest, - ) -> Result<completions::CreateCompletionResponse, OpenAIError>; + request: chat::CreateChatCompletionRequest, + ) -> Result<chat::CreateChatCompletionResponse, OpenAIError>; } #[cfg(test)] diff --git a/src/router/router.rs b/src/router/router.rs index 6561781..6532a07 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -1,7 +1,6 @@ use crate::client::config::{ModelConfig, ModelName, RoutingMode}; use crate::router::random::RandomRouter; use crate::router::wrr::WeightedRoundRobinRouter; -use crate::types::responses::CreateResponse; #[derive(Debug, Clone)] pub struct ModelInfo { diff --git a/src/router/wrr.rs b/src/router/wrr.rs index b9a375f..cf9c416 100644 --- a/src/router/wrr.rs +++ b/src/router/wrr.rs @@ -1,6 +1,5 @@ use crate::client::config::ModelName; use crate::router::router::{ModelInfo, Router}; -use crate::types::responses::CreateResponse; pub struct WeightedRoundRobinRouter { total_weight: i32, diff --git a/src/types/chat.rs b/src/types/chat.rs new file mode 100644 index 0000000..e92a02e --- /dev/null +++ b/src/types/chat.rs @@ -0,0 +1 @@ +pub use async_openai::types::chat::*; diff --git a/src/types/completions.rs b/src/types/completions.rs deleted file mode 100644 index 4c54932..0000000 --- a/src/types/completions.rs +++ /dev/null @@ -1 +0,0 @@ -pub use async_openai::types::completions::*; diff --git a/tests/client.rs b/tests/client.rs index 2f56cc0..bf9a41d 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1,7 +1,7 @@ use dotenvy::from_filename; use arms::client; -use arms::types::completions; +use arms::types::chat; use arms::types::responses; #[cfg(test)] @@ -98,7 +98,7 @@ mod tests { .unwrap(); let mut client = client::Client::new(config); - let request = completions::CreateCompletionRequestArgs::default() + let request = chat::CreateChatCompletionRequestArgs::default() .build() .unwrap(); From 616c2da30c87d6f39786c8dfb07e8c8599ad295a Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 27 Dec 2025 11:02:38 +0800 Subject: [PATCH 3/3] Refine chat completion examples Signed-off-by: kerthcet --- README.md | 2 +- src/main.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 15931ec..f57774b 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ let mut client = client::Client::new(config); let request = chat::CreateChatCompletionRequestArgs::default() .messages([ chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), -
chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?").into(), + chat::ChatCompletionRequestUserMessage::from("How is the weather today?").into(), ]) .build() .unwrap(); diff --git a/src/main.rs b/src/main.rs index 80c2e3a..8667a88 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,11 +32,11 @@ fn main() { chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?") .into(), - // chat::ChatCompletionRequestAssistantMessage::from( - // "The Los Angeles Dodgers won the World Series in 2020.", - // ) - // .into(), - // chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(), + chat::ChatCompletionRequestAssistantMessage::from( + "The Los Angeles Dodgers won the World Series in 2020.", + ) + .into(), + chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(), ]) .build() .unwrap();
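For reference, a minimal end-to-end sketch of the chat completion path this series lands, mirroring the setup in tests/client.rs. It is illustrative only: the dotenv file is assumed to provide a FAKER_API_KEY entry (the config builder looks up a `<PROVIDER>_API_KEY` variable during validation), and `fake-completion-model` is a placeholder name, not part of the patches:

```rust
use dotenvy::from_filename;
use tokio::runtime::Runtime;

use arms::client;
use arms::types::chat;

fn main() {
    // Config validation requires the provider's API key variable to exist;
    // the integration tests supply it (assumed: FAKER_API_KEY) via this file.
    from_filename(".env.integration-test").ok();

    let config = client::Config::builder()
        .provider("faker")
        .model(
            client::ModelConfig::builder()
                .name("fake-completion-model") // placeholder model name
                .build()
                .unwrap(),
        )
        .build()
        .unwrap();

    let mut client = client::Client::new(config);

    // Leave the model unset: providers reject requests that already carry a
    // model and fill it in from the routed ModelConfig after validation.
    let request = chat::CreateChatCompletionRequestArgs::default()
        .messages([
            chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?")
                .into(),
        ])
        .build()
        .unwrap();

    let response = Runtime::new()
        .unwrap()
        .block_on(client.create_completion(request))
        .unwrap();
    println!("Response ID: {}", response.id);
}
```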