strands_agents/models/openai.rs

1//! OpenAI model provider.
2
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7use super::{Model, ModelConfig, StreamEventStream};
8use crate::types::{
9    content::{Message, Role, SystemContentBlock},
10    errors::StrandsError,
11    streaming::{
12        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
13        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
14        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
15    },
16    tools::{ToolChoice, ToolSpec},
17};
18
/// Model id used when the caller does not pick one via `with_model`.
const DEFAULT_MODEL_ID: &str = "gpt-4o";
/// Default Chat Completions endpoint; overridden by `with_base_url`.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
21
/// OpenAI model provider.
///
/// Cheap to clone; construct with [`OpenAIModel::new`] and customize with the
/// `with_*` builder methods.
#[derive(Clone)]
pub struct OpenAIModel {
    // Model id plus sampling parameters sent with each request.
    config: ModelConfig,
    // Bearer token placed in the `Authorization` header.
    api_key: String,
    // Optional endpoint override; `OPENAI_API_URL` is used when `None`.
    base_url: Option<String>,
    // HTTP client reused across requests.
    client: Client,
}
30
/// Hand-written `Debug` that deliberately omits `api_key` (and the HTTP
/// client) so the secret cannot leak into logs or error output.
impl std::fmt::Debug for OpenAIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIModel")
            .field("config", &self.config)
            .field("base_url", &self.base_url)
            .finish()
    }
}
39
/// Request body for the chat-completions endpoint.
///
/// Unset optional fields are omitted from the JSON; `tools` is omitted when
/// empty.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    // Always `true` in this provider — only streaming responses are requested.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
    // Either "auto"/"required" or a {"type":"function",...} object.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    stream_options: StreamOptions,
}
57
/// Streaming options; `include_usage` requests token-usage reporting on the
/// stream (consumed as the final usage chunk).
#[derive(Debug, Serialize)]
struct StreamOptions {
    include_usage: bool,
}
62
/// One chat message in OpenAI wire format.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    // "system", "user", "assistant", or "tool".
    role: String,
    // A JSON string, an array of content parts, or `null` for assistant
    // messages that carry only tool calls.
    content: serde_json::Value,
    // Tool invocations requested by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    // For `role == "tool"` messages: id of the tool call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
72
/// A complete tool invocation, as echoed back to the API in history.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct OpenAIToolCall {
    id: String,
    // Always "function" here; serialized under the key `type`.
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAIFunction,
}
80
/// Function name plus arguments for a tool call.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct OpenAIFunction {
    name: String,
    // Arguments as a JSON-encoded string, not a nested JSON object.
    arguments: String,
}
86
/// One entry of the request's `tools` array.
#[derive(Debug, Serialize)]
struct OpenAITool {
    // Always "function"; serialized under the key `type`.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionDef,
}
93
/// Declaration of a callable function offered to the model.
#[derive(Debug, Serialize)]
struct OpenAIFunctionDef {
    name: String,
    description: String,
    // JSON schema describing the function's parameters.
    parameters: serde_json::Value,
}
100
101#[derive(Debug, Deserialize)]
102struct OpenAIStreamChunk {
103    choices: Vec<OpenAIChoice>,
104    #[serde(default)]
105    usage: Option<OpenAIUsage>,
106}
107
/// One choice's incremental update within a stream chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    delta: OpenAIDelta,
    // Set on the choice's last delta, e.g. "stop", "length", "tool_calls".
    finish_reason: Option<String>,
}
113
/// Incremental content carried by a streaming delta.
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    // Next fragment of assistant text, if any.
    content: Option<String>,
    // Incremental tool-call fragments, if any.
    tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}
119
/// A fragment of a tool call. Fragments for the same call share an `index`
/// and are reassembled by the stream loop.
#[derive(Debug, Deserialize, Clone)]
struct OpenAIToolCallDelta {
    // Position of the tool call this fragment belongs to.
    index: usize,
    // Typically present only on the first fragment of a call.
    id: Option<String>,
    function: Option<OpenAIFunctionDelta>,
}
126
/// Incremental function fields; `arguments` pieces are concatenated by the
/// stream loop until the call is complete.
#[derive(Debug, Deserialize, Clone)]
struct OpenAIFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}
132
/// Token usage as reported by the API on the usage chunk.
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
139
140impl OpenAIModel {
141    pub fn new(api_key: impl Into<String>) -> Self {
142        Self {
143            config: ModelConfig::new(DEFAULT_MODEL_ID),
144            api_key: api_key.into(),
145            base_url: None,
146            client: Client::new(),
147        }
148    }
149
150    pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
151        self.config.model_id = model_id.into();
152        self
153    }
154
155    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
156        self.base_url = Some(base_url.into());
157        self
158    }
159
160    pub fn with_config(mut self, config: ModelConfig) -> Self {
161        self.config = config;
162        self
163    }
164
165    fn format_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<OpenAIMessage> {
166        let mut formatted = Vec::new();
167
168        if let Some(prompt) = system_prompt {
169            formatted.push(OpenAIMessage {
170                role: "system".to_string(),
171                content: serde_json::Value::String(prompt.to_string()),
172                tool_calls: None,
173                tool_call_id: None,
174            });
175        }
176
177        for msg in messages {
178            let role = match msg.role {
179                Role::User => "user",
180                Role::Assistant => "assistant",
181            };
182
183            let mut text_content = Vec::new();
184            let mut tool_calls = Vec::new();
185            let mut tool_results = Vec::new();
186
187            for block in &msg.content {
188                if let Some(ref text) = block.text {
189                    text_content.push(serde_json::json!({ "type": "text", "text": text }));
190                }
191
192                if let Some(ref tu) = block.tool_use {
193                    tool_calls.push(OpenAIToolCall {
194                        id: tu.tool_use_id.clone(),
195                        call_type: "function".to_string(),
196                        function: OpenAIFunction {
197                            name: tu.name.clone(),
198                            arguments: serde_json::to_string(&tu.input).unwrap_or_default(),
199                        },
200                    });
201                }
202
203                if let Some(ref tr) = block.tool_result {
204                    let content = tr
205                        .content
206                        .iter()
207                        .filter_map(|c| c.text.clone())
208                        .collect::<Vec<_>>()
209                        .join("\n");
210                    tool_results.push((tr.tool_use_id.clone(), content));
211                }
212            }
213
214            if !tool_calls.is_empty() {
215                formatted.push(OpenAIMessage {
216                    role: role.to_string(),
217                    content: if text_content.is_empty() {
218                        serde_json::Value::Null
219                    } else {
220                        serde_json::Value::Array(text_content.clone())
221                    },
222                    tool_calls: Some(tool_calls),
223                    tool_call_id: None,
224                });
225            } else if !text_content.is_empty() {
226                formatted.push(OpenAIMessage {
227                    role: role.to_string(),
228                    content: serde_json::Value::Array(text_content),
229                    tool_calls: None,
230                    tool_call_id: None,
231                });
232            }
233
234            for (tool_id, content) in tool_results {
235                formatted.push(OpenAIMessage {
236                    role: "tool".to_string(),
237                    content: serde_json::Value::String(content),
238                    tool_calls: None,
239                    tool_call_id: Some(tool_id),
240                });
241            }
242        }
243
244        formatted
245    }
246
247    fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<OpenAITool> {
248        tool_specs
249            .iter()
250            .map(|spec| OpenAITool {
251                tool_type: "function".to_string(),
252                function: OpenAIFunctionDef {
253                    name: spec.name.clone(),
254                    description: spec.description.clone(),
255                    parameters: spec.input_schema.json.clone(),
256                },
257            })
258            .collect()
259    }
260
261    fn format_tool_choice(&self, tool_choice: Option<ToolChoice>) -> Option<serde_json::Value> {
262        tool_choice.map(|tc| match tc {
263            ToolChoice::Auto(_) => serde_json::json!("auto"),
264            ToolChoice::Any(_) => serde_json::json!("required"),
265            ToolChoice::Tool(t) => serde_json::json!({
266                "type": "function",
267                "function": { "name": t.name }
268            }),
269        })
270    }
271
272    fn map_stop_reason(reason: &str) -> StopReason {
273        match reason {
274            "tool_calls" => StopReason::ToolUse,
275            "length" => StopReason::MaxTokens,
276            "content_filter" => StopReason::ContentFiltered,
277            _ => StopReason::EndTurn,
278        }
279    }
280}
281
282#[async_trait]
283impl Model for OpenAIModel {
284    fn config(&self) -> &ModelConfig {
285        &self.config
286    }
287
288    fn update_config(&mut self, config: ModelConfig) {
289        self.config = config;
290    }
291
292    fn stream<'a>(
293        &'a self,
294        messages: &'a [Message],
295        tool_specs: Option<&'a [ToolSpec]>,
296        system_prompt: Option<&'a str>,
297        tool_choice: Option<ToolChoice>,
298        _system_prompt_content: Option<&'a [SystemContentBlock]>,
299    ) -> StreamEventStream<'a> {
300        let url = self.base_url.clone().unwrap_or_else(|| OPENAI_API_URL.to_string());
301        let api_key = self.api_key.clone();
302        let client = self.client.clone();
303
304        let request = OpenAIRequest {
305            model: self.config.model_id.clone(),
306            messages: self.format_messages(messages, system_prompt),
307            stream: true,
308            max_tokens: self.config.max_tokens,
309            temperature: self.config.temperature,
310            top_p: self.config.top_p,
311            tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
312            tool_choice: self.format_tool_choice(tool_choice),
313            stream_options: StreamOptions { include_usage: true },
314        };
315
316        Box::pin(async_stream::stream! {
317            let response = match client
318                .post(&url)
319                .header("Authorization", format!("Bearer {api_key}"))
320                .header("Content-Type", "application/json")
321                .json(&request)
322                .send()
323                .await
324            {
325                Ok(resp) => resp,
326                Err(e) => {
327                    yield Err(StrandsError::NetworkError(e.to_string()));
328                    return;
329                }
330            };
331
332            if !response.status().is_success() {
333                let status = response.status();
334                let body = response.text().await.unwrap_or_default();
335                if status.as_u16() == 429 {
336                    yield Err(StrandsError::ModelThrottled { message: body });
337                } else if body.contains("context_length_exceeded") {
338                    yield Err(StrandsError::ContextWindowOverflow { message: body });
339                } else {
340                    yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
341                }
342                return;
343            }
344
345            yield Ok(StreamEvent {
346                message_start: Some(MessageStartEvent { role: Role::Assistant }),
347                ..Default::default()
348            });
349
350            let mut content_started = false;
351            let mut tool_calls: std::collections::HashMap<usize, (String, String, String)> = std::collections::HashMap::new();
352            let mut finish_reason = None;
353            let mut final_usage = None;
354
355            use futures::StreamExt;
356            let mut byte_stream = response.bytes_stream();
357            let mut buffer = String::new();
358
359            loop {
360                for line in buffer.lines() {
361                    let line = line.trim();
362                    if line.is_empty() || line == "data: [DONE]" {
363                        continue;
364                    }
365
366                    if let Some(json_str) = line.strip_prefix("data: ") {
367                        if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(json_str) {
368                            if let Some(usage) = chunk.usage {
369                                final_usage = Some(usage);
370                            }
371
372                            for choice in chunk.choices {
373                                if let Some(ref content) = choice.delta.content {
374                                    if !content_started {
375                                        yield Ok(StreamEvent {
376                                            content_block_start: Some(ContentBlockStartEvent {
377                                                content_block_index: Some(0),
378                                                start: None,
379                                            }),
380                                            ..Default::default()
381                                        });
382                                        content_started = true;
383                                    }
384
385                                    yield Ok(StreamEvent {
386                                        content_block_delta: Some(ContentBlockDeltaEvent {
387                                            content_block_index: Some(0),
388                                            delta: Some(ContentBlockDelta {
389                                                text: Some(content.clone()),
390                                                ..Default::default()
391                                            }),
392                                        }),
393                                        ..Default::default()
394                                    });
395                                }
396
397                                if let Some(ref tcs) = choice.delta.tool_calls {
398                                    for tc in tcs {
399                                        let entry = tool_calls.entry(tc.index).or_insert_with(|| {
400                                            (String::new(), String::new(), String::new())
401                                        });
402                                        if let Some(ref id) = tc.id {
403                                            entry.0 = id.clone();
404                                        }
405                                        if let Some(ref f) = tc.function {
406                                            if let Some(ref name) = f.name {
407                                                entry.1 = name.clone();
408                                            }
409                                            if let Some(ref args) = f.arguments {
410                                                entry.2.push_str(args);
411                                            }
412                                        }
413                                    }
414                                }
415
416                                if let Some(ref reason) = choice.finish_reason {
417                                    finish_reason = Some(reason.clone());
418                                }
419                            }
420                        }
421                    }
422                }
423
424                match byte_stream.next().await {
425                    Some(Ok(bytes)) => {
426                        buffer = String::from_utf8_lossy(&bytes).to_string();
427                    }
428                    _ => break,
429                }
430            }
431
432            if content_started {
433                yield Ok(StreamEvent {
434                    content_block_stop: Some(ContentBlockStopEvent {
435                        content_block_index: Some(0),
436                    }),
437                    ..Default::default()
438                });
439            }
440
441            let mut tool_index = 1u32;
442            for (_idx, (id, name, args)) in tool_calls {
443                yield Ok(StreamEvent {
444                    content_block_start: Some(ContentBlockStartEvent {
445                        content_block_index: Some(tool_index),
446                        start: Some(ContentBlockStart {
447                            tool_use: Some(ContentBlockStartToolUse {
448                                name: name.clone(),
449                                tool_use_id: id.clone(),
450                            }),
451                        }),
452                    }),
453                    ..Default::default()
454                });
455
456                yield Ok(StreamEvent {
457                    content_block_delta: Some(ContentBlockDeltaEvent {
458                        content_block_index: Some(tool_index),
459                        delta: Some(ContentBlockDelta {
460                            tool_use: Some(ContentBlockDeltaToolUse { input: args }),
461                            ..Default::default()
462                        }),
463                    }),
464                    ..Default::default()
465                });
466
467                yield Ok(StreamEvent {
468                    content_block_stop: Some(ContentBlockStopEvent {
469                        content_block_index: Some(tool_index),
470                    }),
471                    ..Default::default()
472                });
473
474                tool_index += 1;
475            }
476
477            let stop = finish_reason.as_deref().map(Self::map_stop_reason).unwrap_or(StopReason::EndTurn);
478
479            yield Ok(StreamEvent {
480                message_stop: Some(MessageStopEvent {
481                    stop_reason: Some(stop),
482                    additional_model_response_fields: None,
483                }),
484                ..Default::default()
485            });
486
487            if let Some(usage) = final_usage {
488                yield Ok(StreamEvent {
489                    metadata: Some(MetadataEvent {
490                        usage: Some(Usage {
491                            input_tokens: usage.prompt_tokens,
492                            output_tokens: usage.completion_tokens,
493                            total_tokens: usage.total_tokens,
494                            cache_read_input_tokens: 0,
495                            cache_write_input_tokens: 0,
496                        }),
497                        metrics: Some(Metrics {
498                            latency_ms: 0,
499                            time_to_first_byte_ms: 0,
500                        }),
501                        trace: None,
502                    }),
503                    ..Default::default()
504                });
505            }
506        })
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    // The builder should store a caller-supplied model id in the config.
    #[test]
    fn test_openai_model_creation() {
        let provider = OpenAIModel::new("test-key").with_model("gpt-4o-mini");
        assert_eq!(provider.config().model_id, "gpt-4o-mini");
    }

    // The builder should record a custom base URL for later requests.
    #[test]
    fn test_openai_with_base_url() {
        let provider = OpenAIModel::new("test-key").with_base_url("https://custom.api.com");
        assert_eq!(provider.base_url, Some("https://custom.api.com".to_string()));
    }
}
526