2 changes: 1 addition & 1 deletion Cargo.toml
@@ -4,7 +4,7 @@ version = "0.1.0"
edition = "2024"

[dependencies]
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses",] }
async-openai = { version = "0.31.1", features = ["_api", "response-types", "responses", "chat-completion"] }
async-trait = "0.1.89"
derive_builder = "0.20.2"
dotenvy = "0.15.7"
12 changes: 8 additions & 4 deletions README.md
@@ -27,8 +27,9 @@ Here's a simple example with the Weighted Round Robin (WRR) routing mode:
// Before running the code, make sure to set your OpenAI API key in the environment variable:
// export OPENAI_API_KEY="your_openai_api_key"

use tokio::runtime::Runtime;
use arms::client;
use arms::types::responses;
use arms::types::chat;

let config = client::Config::builder()
.provider("openai")
@@ -51,12 +52,15 @@ let config = client::Config::builder()
.unwrap();

let mut client = client::Client::new(config);
let request = responses::CreateResponseArgs::default()
.input("give me a poem about nature")
let request = chat::CreateChatCompletionRequestArgs::default()
.messages([
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
chat::ChatCompletionRequestUserMessage::from("How is the weather today?").into(),
])
.build()
.unwrap();

let response = client.create_response(request).await.unwrap();
let result = Runtime::new().unwrap().block_on(client.create_completion(request));
```
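For context on the new README example, here is a minimal sketch of consuming the value returned by `create_completion`; the field names (`choices`, `message.content`) are assumed to mirror the `chat` response types that the faker provider constructs later in this diff, not something the README itself shows.

```rust
// Hypothetical continuation of the README example above.
match result {
    Ok(response) => {
        // Each choice carries an assistant message whose content is an Option<String>.
        if let Some(reply) = response.choices.first().and_then(|c| c.message.content.clone()) {
            println!("assistant: {}", reply);
        }
    }
    Err(e) => eprintln!("request failed: {}", e),
}
```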

## Contributing
18 changes: 14 additions & 4 deletions src/client/client.rs
@@ -4,7 +4,7 @@ use crate::client::config::{Config, ModelName};
use crate::provider::provider;
use crate::router::router;
use crate::types::error::OpenAIError;
use crate::types::responses::{CreateResponse, Response};
use crate::types::{chat, responses};

pub struct Client {
providers: HashMap<ModelName, Box<dyn provider::Provider>>,
@@ -30,12 +30,22 @@ impl Client {

pub async fn create_response(
&mut self,
request: CreateResponse,
) -> Result<Response, OpenAIError> {
let candidate = self.router.sample(&request);
request: responses::CreateResponse,
) -> Result<responses::Response, OpenAIError> {
let candidate = self.router.sample();
let provider = self.providers.get(&candidate).unwrap();
provider.create_response(request).await
}

// This is the chat completion endpoint.
pub async fn create_completion(
&mut self,
request: chat::CreateChatCompletionRequest,
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
let candidate = self.router.sample();
let provider = self.providers.get(&candidate).unwrap();
provider.create_completion(request).await
}
}

#[cfg(test)]
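The `provider::Provider` trait itself is not shown in this diff, but `Client::create_completion` and the faker implementation below imply its shape. A hedged sketch of what `src/provider/provider.rs` presumably declares after this PR; the `name` signature and the `Send + Sync` bounds are assumptions.

```rust
// Assumed reconstruction of src/provider/provider.rs (not part of the visible diff).
use async_trait::async_trait;

use crate::types::chat;
use crate::types::error::OpenAIError;
use crate::types::responses;

#[async_trait]
pub trait Provider: Send + Sync {
    // Human-readable provider name, e.g. "FakeProvider" (signature assumed).
    fn name(&self) -> &str;

    // Responses API endpoint (already present before this PR).
    async fn create_response(
        &self,
        request: responses::CreateResponse,
    ) -> Result<responses::Response, OpenAIError>;

    // New in this PR: the chat-completion counterpart used by Client::create_completion.
    async fn create_completion(
        &self,
        request: chat::CreateChatCompletionRequest,
    ) -> Result<chat::CreateChatCompletionResponse, OpenAIError>;
}
```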
48 changes: 0 additions & 48 deletions src/client/config.rs
@@ -38,10 +38,6 @@ pub struct ModelConfig {
pub(crate) base_url: Option<String>,
#[builder(default = "None", setter(custom))]
pub(crate) provider: Option<String>,
#[builder(default = "None")]
pub(crate) temperature: Option<f32>,
#[builder(default = "None")]
pub(crate) max_output_tokens: Option<usize>,

#[builder(setter(custom))]
pub(crate) name: ModelName,
@@ -86,10 +82,6 @@ pub struct Config {
pub(crate) base_url: Option<String>,
#[builder(default = "DEFAULT_PROVIDER.to_string()", setter(custom))]
pub(crate) provider: String,
#[builder(default = "0.8")]
pub(crate) temperature: f32,
#[builder(default = "1024")]
pub(crate) max_output_tokens: usize,

#[builder(default = "RoutingMode::Random")]
pub(crate) routing_mode: RoutingMode,
@@ -131,13 +123,6 @@ impl Config {
if model.provider.is_none() {
model.provider = Some(self.provider.clone());
}

if model.temperature.is_none() {
model.temperature = Some(self.temperature);
}
if model.max_output_tokens.is_none() {
model.max_output_tokens = Some(self.max_output_tokens);
}
}
self
}
@@ -176,24 +161,6 @@ impl ConfigBuilder {
));
}

if let Some(max_output_tokens) = model.max_output_tokens {
if max_output_tokens <= 0 {
return Err(format!(
"Model '{}' max_output_tokens must be positive.",
model.name
));
}
}

if let Some(temperature) = model.temperature {
if temperature < 0.0 || temperature > 1.0 {
return Err(format!(
"Model '{}' temperature must be between 0.0 and 1.0.",
model.name
));
}
}

// check the existence of API key in environment variables
if let Some(provider) = &model.provider {
let env_var = format!("{}_API_KEY", provider);
@@ -251,20 +218,10 @@ mod tests {
assert!(valid_simplest_models_cfg.is_ok());
assert!(valid_simplest_models_cfg.as_ref().unwrap().provider == DEFAULT_PROVIDER);
assert!(valid_simplest_models_cfg.as_ref().unwrap().base_url == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().temperature == 0.8);
assert!(
valid_simplest_models_cfg
.as_ref()
.unwrap()
.max_output_tokens
== 1024
);
assert!(valid_simplest_models_cfg.as_ref().unwrap().routing_mode == RoutingMode::Random);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models.len() == 1);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].base_url == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].provider == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].temperature == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].max_output_tokens == None);
assert!(valid_simplest_models_cfg.as_ref().unwrap().models[0].weight == -1);

// case 2:
@@ -299,7 +256,6 @@ mod tests {
// AMRS_API_KEY is set in .env.test already.
let valid_cfg_with_customized_provider = Config::builder()
.base_url("http://example.ai")
.max_output_tokens(2048)
.model(
ModelConfig::builder()
.name("custom-model")
@@ -325,8 +281,6 @@
from_filename(".env.test").ok();

let mut valid_cfg = Config::builder()
.temperature(0.5)
.max_output_tokens(1500)
.model(
ModelConfig::builder()
.name("model-1".to_string())
@@ -338,8 +292,6 @@

assert!(valid_cfg.is_ok());
assert!(valid_cfg.as_ref().unwrap().models.len() == 1);
assert!(valid_cfg.as_ref().unwrap().models[0].temperature == Some(0.5));
assert!(valid_cfg.as_ref().unwrap().models[0].max_output_tokens == Some(1500));
assert!(valid_cfg.as_ref().unwrap().models[0].provider == Some("OPENAI".to_string()));
assert!(
valid_cfg.as_ref().unwrap().models[0].base_url
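Since `temperature` and `max_output_tokens` are removed from `Config` and `ModelConfig`, these knobs presumably now travel on the request itself. A hypothetical example using builder setters assumed to be re-exported through `types::chat` (neither the setter names nor the re-export are confirmed by this diff):

```rust
// Hypothetical: per-request sampling settings instead of per-config defaults.
let request = chat::CreateChatCompletionRequestArgs::default()
    .messages([chat::ChatCompletionRequestUserMessage::from("hello").into()])
    .temperature(0.5)               // assumed setter, mirrors the removed Config field
    .max_completion_tokens(1500u32) // assumed setter, replaces max_output_tokens
    .build()
    .unwrap();
```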
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -8,11 +8,13 @@ mod router {
}

mod provider {
mod common;
mod faker;
mod openai;
pub mod provider;
}
pub mod types {
pub mod chat;
pub mod error;
pub mod responses;
}
73 changes: 62 additions & 11 deletions src/main.rs
@@ -1,9 +1,10 @@
use tokio::runtime::Runtime;

use arms::client;
use arms::types::responses;
use arms::types::chat;

fn main() {
// case 1: completion with DeepInfra provider.
let config = client::Config::builder()
.provider("deepinfra")
.routing_mode(client::RoutingMode::WRR)
@@ -26,27 +27,77 @@ fn main() {

let mut client = client::Client::new(config);

let request = responses::CreateResponseArgs::default()
.input(responses::InputParam::Items(vec![
responses::InputItem::EasyMessage(responses::EasyInputMessage {
r#type: responses::MessageType::Message,
role: responses::Role::User,
content: responses::EasyInputContent::Text("What is AGI?".to_string()),
}),
]))
let request = chat::CreateChatCompletionRequestArgs::default()
.messages([
chat::ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(),
chat::ChatCompletionRequestUserMessage::from("Who won the world series in 2020?")
.into(),
chat::ChatCompletionRequestAssistantMessage::from(
"The Los Angeles Dodgers won the World Series in 2020.",
)
.into(),
chat::ChatCompletionRequestUserMessage::from("Where was it played?").into(),
])
.build()
.unwrap();

let result = Runtime::new()
.unwrap()
.block_on(client.create_response(request));
.block_on(client.create_completion(request));

match result {
Ok(response) => {
println!("Response ID: {}", response.id);
println!("Response: {:?}", response);
}
Err(e) => {
eprintln!("Error: {}", e);
}
}

// case 2: response with DeepInfra provider.
// let config = client::Config::builder()
// .provider("deepinfra")
// .routing_mode(client::RoutingMode::WRR)
// .model(
// client::ModelConfig::builder()
// .name("nvidia/Nemotron-3-Nano-30B-A3B")
// .weight(1)
// .build()
// .unwrap(),
// )
// .model(
// client::ModelConfig::builder()
// .name("deepseek-ai/DeepSeek-V3.2")
// .weight(2)
// .build()
// .unwrap(),
// )
// .build()
// .unwrap();

// let mut client = client::Client::new(config);

// let request = responses::CreateResponseArgs::default()
// .input(responses::InputParam::Items(vec![
// responses::InputItem::EasyMessage(responses::EasyInputMessage {
// r#type: responses::MessageType::Message,
// role: responses::Role::User,
// content: responses::EasyInputContent::Text("What is AGI?".to_string()),
// }),
// ]))
// .build()
// .unwrap();

// let result = Runtime::new()
// .unwrap()
// .block_on(client.create_response(request));

// match result {
// Ok(response) => {
// println!("Response ID: {}", response.id);
// }
// Err(e) => {
// eprintln!("Error: {}", e);
// }
// }
}
22 changes: 22 additions & 0 deletions src/provider/common.rs
@@ -0,0 +1,22 @@
use crate::types::error::OpenAIError;
use crate::types::{chat, responses};

pub fn validate_completion_request(
request: &chat::CreateChatCompletionRequest,
) -> Result<(), OpenAIError> {
if !request.model.is_empty() {
return Err(OpenAIError::InvalidArgument(
"Model must be specified in the client.Config".to_string(),
));
}
Ok(())
}

pub fn validate_response_request(request: &responses::CreateResponse) -> Result<(), OpenAIError> {
if request.model.is_some() {
return Err(OpenAIError::InvalidArgument(
"Model must be specified in the client.Config".to_string(),
));
}
Ok(())
}
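A hedged test sketch for the new validators; it assumes `CreateChatCompletionRequest` derives `Default` (as the async-openai type does) and that the `chat` module simply re-exports it.

```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn completion_request_with_model_is_rejected() {
        // The model is supplied by client::Config via the router, so a model set
        // directly on the request should fail validation.
        let mut request = chat::CreateChatCompletionRequest::default();
        request.model = "gpt-4o-mini".to_string(); // hypothetical model name
        assert!(validate_completion_request(&request).is_err());

        // An empty model field passes.
        request.model = String::new();
        assert!(validate_completion_request(&request).is_ok());
    }
}
```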
37 changes: 34 additions & 3 deletions src/provider/faker.rs
@@ -1,7 +1,8 @@
use async_trait::async_trait;

use crate::client::config::{ModelConfig, ModelName};
use crate::provider::provider;
use crate::provider::{common, provider};
use crate::types::chat;
use crate::types::error::OpenAIError;
use crate::types::responses::{
AssistantRole, CreateResponse, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
@@ -26,8 +27,8 @@ impl provider::Provider for FakerProvider {
"FakeProvider"
}

async fn create_response(&self, request: CreateResponse) -> Result<Response, OpenAIError> {
provider::validate_responses_request(&request)?;
async fn create_response(&self, _request: CreateResponse) -> Result<Response, OpenAIError> {
common::validate_response_request(&_request)?;

Ok(Response {
id: "fake-response-id".to_string(),
@@ -71,4 +72,34 @@ impl provider::Provider for FakerProvider {
truncation: None,
})
}

async fn create_completion(
&self,
request: chat::CreateChatCompletionRequest,
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
common::validate_completion_request(&request)?;
Ok(chat::CreateChatCompletionResponse {
id: "fake-completion-id".to_string(),
object: "chat.completion".to_string(),
created: 1_600_000_000,
model: self.model.clone(),
usage: None,
service_tier: None,
choices: vec![chat::ChatChoice {
index: 0,
message: chat::ChatCompletionResponseMessage {
role: chat::Role::Assistant,
content: Some("This is a fake chat completion.".to_string()),
refusal: None,
tool_calls: None,
annotations: None,
function_call: None,
audio: None,
},
finish_reason: None,
logprobs: None,
}],
system_fingerprint: None,
})
}
}
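The matching change to `src/provider/openai.rs` is not visible in this diff, so the following is only an assumed sketch of how the real provider's `create_completion` might look: validate, overwrite the model with the one routed from `client::Config`, then delegate to async-openai's chat endpoint. `self.client` and `self.model` are hypothetical fields modeled on the faker provider.

```rust
// Assumed sketch, NOT part of the visible diff (src/provider/openai.rs).
async fn create_completion(
    &self,
    mut request: chat::CreateChatCompletionRequest,
) -> Result<chat::CreateChatCompletionResponse, OpenAIError> {
    common::validate_completion_request(&request)?;
    // The routed model comes from client::Config, never from the caller's request.
    request.model = self.model.clone();
    // Delegate to the underlying async-openai client (hypothetical field).
    self.client.chat().create(request).await
}
```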