// smolagents_rs/models/openai.rs
use std::collections::HashMap;

use anyhow::Result;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

use crate::errors::AgentError;
use crate::models::model_traits::{Model, ModelResponse};
use crate::models::types::{Message, MessageRole};
use crate::tools::ToolInfo;
12#[derive(Debug, Deserialize)]
13pub struct OpenAIResponse {
14 pub choices: Vec<Choice>,
15}
16
17#[derive(Debug, Deserialize)]
18pub struct Choice {
19 pub message: AssistantMessage,
20}
21
22#[derive(Debug, Deserialize)]
23pub struct AssistantMessage {
24 pub role: MessageRole,
25 pub content: Option<String>,
26 pub tool_calls: Option<Vec<ToolCall>>,
27 pub refusal: Option<String>,
28}
29
30#[derive(Debug, Serialize, Deserialize, Clone)]
31pub struct ToolCall {
32 pub id: Option<String>,
33 #[serde(rename = "type")]
34 pub call_type: Option<String>,
35 pub function: FunctionCall,
36}
37
38#[derive(Debug, Serialize, Deserialize, Clone)]
39pub struct FunctionCall {
40 pub name: String,
41 #[serde(deserialize_with = "deserialize_arguments")]
42 pub arguments: Value,
43}
44
45fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
47where
48 D: serde::Deserializer<'de>,
49{
50 let value = Value::deserialize(deserializer)?;
51
52 if let Value::String(s) = &value {
54 if let Ok(parsed) = serde_json::from_str(s) {
55 return Ok(parsed);
56 }
57 }
58
59 Ok(value)
60}
61
62impl FunctionCall {
63 pub fn get_arguments(&self) -> Result<HashMap<String, String>> {
64 if let Ok(map) = serde_json::from_value(self.arguments.clone()) {
66 return Ok(map);
67 }
68
69 if let Value::String(arg_str) = &self.arguments {
71 if let Ok(parsed) = serde_json::from_str(arg_str) {
72 return Ok(parsed);
73 }
74 }
75
76 Err(anyhow::anyhow!(
78 "Failed to parse arguments as HashMap or JSON string"
79 ))
80 }
81}
82
83impl ModelResponse for OpenAIResponse {
84 fn get_response(&self) -> Result<String, AgentError> {
85 Ok(self
86 .choices
87 .first()
88 .ok_or(AgentError::Generation(
89 "No message returned from OpenAI".to_string(),
90 ))?
91 .message
92 .content
93 .clone()
94 .unwrap_or_default())
95 }
96
97 fn get_tools_used(&self) -> Result<Vec<ToolCall>, AgentError> {
98 if let Some(tool_calls) = &self
99 .choices
100 .first()
101 .ok_or(AgentError::Generation(
102 "No message returned from OpenAI".to_string(),
103 ))?
104 .message
105 .tool_calls
106 {
107 Ok(tool_calls.clone())
108 } else {
109 Err(AgentError::Generation(
110 "No tool calls returned from OpenAI".to_string(),
111 ))
112 }
113 }
114}
115
116#[derive(Debug)]
117pub struct OpenAIServerModel {
118 pub base_url: String,
119 pub model_id: String,
120 pub client: Client,
121 pub temperature: f32,
122 pub api_key: String,
123}
124
125impl OpenAIServerModel {
126 pub fn new(
127 base_url: Option<&str>,
128 model_id: Option<&str>,
129 temperature: Option<f32>,
130 api_key: Option<String>,
131 ) -> Self {
132 let api_key = api_key.unwrap_or_else(|| {
133 std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set")
134 });
135 let model_id = model_id.unwrap_or("gpt-4o-mini").to_string();
136 let base_url = base_url.unwrap_or("https://api.openai.com/v1/chat/completions");
137 let client = Client::new();
138
139 OpenAIServerModel {
140 base_url: base_url.to_string(),
141 model_id,
142 client,
143 temperature: temperature.unwrap_or(0.5),
144 api_key,
145 }
146 }
147}
148
149impl Model for OpenAIServerModel {
150 fn run(
151 &self,
152 messages: Vec<Message>,
153 tools_to_call_from: Vec<ToolInfo>,
154 max_tokens: Option<usize>,
155 args: Option<HashMap<String, Vec<String>>>,
156 ) -> Result<Box<dyn ModelResponse>, AgentError> {
157 let max_tokens = max_tokens.unwrap_or(1500);
158 let messages = messages
159 .iter()
160 .map(|message| {
161 json!({
162 "role": message.role,
163 "content": message.content
164 })
165 })
166 .collect::<Vec<_>>();
167
168 let mut body = json!({
169 "model": self.model_id,
170 "messages": messages,
171 "temperature": self.temperature,
172 "max_tokens": max_tokens,
173 });
174
175 if !tools_to_call_from.is_empty() {
176 body["tools"] = json!(tools_to_call_from);
177 body["tool_choice"] = json!("required");
178 }
179
180 if let Some(args) = args {
181 let body_map = body.as_object_mut().unwrap();
182 for (key, value) in args {
183 body_map.insert(key, json!(value));
184 }
185 }
186
187 let response = self
188 .client
189 .post(&self.base_url)
190 .header("Authorization", format!("Bearer {}", self.api_key))
191 .json(&body)
192 .send()
193 .map_err(|e| {
194 AgentError::Generation(format!("Failed to get response from OpenAI: {}", e))
195 })?;
196
197 match response.status() {
198 reqwest::StatusCode::OK => {
199 let response = response.json::<OpenAIResponse>().unwrap();
200 Ok(Box::new(response))
201 }
202 _ => Err(AgentError::Generation(format!(
203 "Failed to get response from OpenAI: {}",
204 response.text().unwrap()
205 ))),
206 }
207 }
208}