// smolagents_rs/models/openai.rs

use std::collections::HashMap;

use crate::errors::AgentError;
use crate::models::model_traits::{Model, ModelResponse};
use crate::models::types::{Message, MessageRole};
use crate::tools::ToolInfo;
use anyhow::Result;
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};

/// Top-level body returned by the OpenAI chat-completions endpoint.
#[derive(Debug, Deserialize)]
pub struct OpenAIResponse {
    /// Candidate completions; only the first entry is read by this crate.
    pub choices: Vec<Choice>,
}
16
/// One completion candidate inside an [`OpenAIResponse`].
#[derive(Debug, Deserialize)]
pub struct Choice {
    /// The assistant turn produced for this choice.
    pub message: AssistantMessage,
}
21
/// The assistant message inside a [`Choice`].
#[derive(Debug, Deserialize)]
pub struct AssistantMessage {
    /// Role reported by the API (expected to be the assistant role).
    pub role: MessageRole,
    /// Text content; absent when the model answered with tool calls only.
    pub content: Option<String>,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Refusal message, when the model declined to answer.
    pub refusal: Option<String>,
}
29
/// A single tool invocation requested by the model.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Provider-assigned call id, used to correlate tool results.
    pub id: Option<String>,
    /// Call kind as reported by the API (serialized as `type`).
    #[serde(rename = "type")]
    pub call_type: Option<String>,
    /// The function name and arguments to invoke.
    pub function: FunctionCall,
}
37
/// The function portion of a [`ToolCall`]: which tool to run and with what.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
    /// Name of the tool/function to invoke.
    pub name: String,
    /// Arguments payload; the API sends these as a JSON-encoded string, so a
    /// custom deserializer decodes that string into a structured `Value`.
    #[serde(deserialize_with = "deserialize_arguments")]
    pub arguments: Value,
}
44
45// Add this function to handle argument deserialization
46fn deserialize_arguments<'de, D>(deserializer: D) -> Result<Value, D::Error>
47where
48    D: serde::Deserializer<'de>,
49{
50    let value = Value::deserialize(deserializer)?;
51
52    // If it's a string, try to parse it as JSON
53    if let Value::String(s) = &value {
54        if let Ok(parsed) = serde_json::from_str(s) {
55            return Ok(parsed);
56        }
57    }
58
59    Ok(value)
60}
61
62impl FunctionCall {
63    pub fn get_arguments(&self) -> Result<HashMap<String, String>> {
64        // First try to parse as a HashMap directly
65        if let Ok(map) = serde_json::from_value(self.arguments.clone()) {
66            return Ok(map);
67        }
68
69        // If that fails, try to parse as a string and then parse that string as JSON
70        if let Value::String(arg_str) = &self.arguments {
71            if let Ok(parsed) = serde_json::from_str(arg_str) {
72                return Ok(parsed);
73            }
74        }
75
76        // If all parsing attempts fail, return the original error
77        Err(anyhow::anyhow!(
78            "Failed to parse arguments as HashMap or JSON string"
79        ))
80    }
81}
82
83impl ModelResponse for OpenAIResponse {
84    fn get_response(&self) -> Result<String, AgentError> {
85        Ok(self
86            .choices
87            .first()
88            .ok_or(AgentError::Generation(
89                "No message returned from OpenAI".to_string(),
90            ))?
91            .message
92            .content
93            .clone()
94            .unwrap_or_default())
95    }
96
97    fn get_tools_used(&self) -> Result<Vec<ToolCall>, AgentError> {
98        if let Some(tool_calls) = &self
99            .choices
100            .first()
101            .ok_or(AgentError::Generation(
102                "No message returned from OpenAI".to_string(),
103            ))?
104            .message
105            .tool_calls
106        {
107            Ok(tool_calls.clone())
108        } else {
109            Err(AgentError::Generation(
110                "No tool calls returned from OpenAI".to_string(),
111            ))
112        }
113    }
114}
115
/// Blocking client for an OpenAI-compatible chat-completions server.
#[derive(Debug)]
pub struct OpenAIServerModel {
    /// Full URL of the chat-completions endpoint.
    pub base_url: String,
    /// Model identifier sent in each request (e.g. "gpt-4o-mini").
    pub model_id: String,
    /// Reusable blocking HTTP client.
    pub client: Client,
    /// Sampling temperature forwarded with every request.
    pub temperature: f32,
    /// Bearer token used for the `Authorization` header.
    pub api_key: String,
}
124
125impl OpenAIServerModel {
126    pub fn new(
127        base_url: Option<&str>,
128        model_id: Option<&str>,
129        temperature: Option<f32>,
130        api_key: Option<String>,
131    ) -> Self {
132        let api_key = api_key.unwrap_or_else(|| {
133            std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set")
134        });
135        let model_id = model_id.unwrap_or("gpt-4o-mini").to_string();
136        let base_url = base_url.unwrap_or("https://api.openai.com/v1/chat/completions");
137        let client = Client::new();
138
139        OpenAIServerModel {
140            base_url: base_url.to_string(),
141            model_id,
142            client,
143            temperature: temperature.unwrap_or(0.5),
144            api_key,
145        }
146    }
147}
148
149impl Model for OpenAIServerModel {
150    fn run(
151        &self,
152        messages: Vec<Message>,
153        tools_to_call_from: Vec<ToolInfo>,
154        max_tokens: Option<usize>,
155        args: Option<HashMap<String, Vec<String>>>,
156    ) -> Result<Box<dyn ModelResponse>, AgentError> {
157        let max_tokens = max_tokens.unwrap_or(1500);
158        let messages = messages
159            .iter()
160            .map(|message| {
161                json!({
162                    "role": message.role,
163                    "content": message.content
164                })
165            })
166            .collect::<Vec<_>>();
167
168        let mut body = json!({
169            "model": self.model_id,
170            "messages": messages,
171            "temperature": self.temperature,
172            "max_tokens": max_tokens,
173        });
174
175        if !tools_to_call_from.is_empty() {
176            body["tools"] = json!(tools_to_call_from);
177            body["tool_choice"] = json!("required");
178        }
179
180        if let Some(args) = args {
181            let body_map = body.as_object_mut().unwrap();
182            for (key, value) in args {
183                body_map.insert(key, json!(value));
184            }
185        }
186
187        let response = self
188            .client
189            .post(&self.base_url)
190            .header("Authorization", format!("Bearer {}", self.api_key))
191            .json(&body)
192            .send()
193            .map_err(|e| {
194                AgentError::Generation(format!("Failed to get response from OpenAI: {}", e))
195            })?;
196
197        match response.status() {
198            reqwest::StatusCode::OK => {
199                let response = response.json::<OpenAIResponse>().unwrap();
200                Ok(Box::new(response))
201            }
202            _ => Err(AgentError::Generation(format!(
203                "Failed to get response from OpenAI: {}",
204                response.text().unwrap()
205            ))),
206        }
207    }
208}