use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{Model, ModelConfig, StreamEventStream};
use crate::types::{
    content::{Message, Role, SystemContentBlock},
    errors::StrandsError,
    streaming::{
        ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockDeltaToolUse, ContentBlockStart,
        ContentBlockStartEvent, ContentBlockStartToolUse, ContentBlockStopEvent, MessageStartEvent,
        MessageStopEvent, MetadataEvent, Metrics, StopReason, StreamEvent, Usage,
    },
    tools::{ToolChoice, ToolSpec},
};

const DEFAULT_MODEL_ID: &str = "gpt-4o";
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

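/// Streaming chat-completions client for the OpenAI API, also usable with any
/// OpenAI-compatible endpoint via [`OpenAIModel::with_base_url`].
///
/// A minimal usage sketch (the setup is illustrative; how the key is obtained
/// and where this module sits in the crate are up to the caller):
///
/// ```ignore
/// let model = OpenAIModel::new(std::env::var("OPENAI_API_KEY")?)
///     .with_model("gpt-4o-mini");
/// ```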
#[derive(Clone)]
pub struct OpenAIModel {
    config: ModelConfig,
    api_key: String,
    base_url: Option<String>,
    client: Client,
}

// Manual Debug impl so the API key is never printed in logs.
impl std::fmt::Debug for OpenAIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIModel")
            .field("config", &self.config)
            .field("base_url", &self.base_url)
            .finish()
    }
}

// Request/response wire types for the OpenAI chat-completions API.

#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<OpenAIMessage>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAITool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
    stream_options: StreamOptions,
}

#[derive(Debug, Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct OpenAIToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAIFunction,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct OpenAIFunction {
    name: String,
    /// JSON-encoded arguments, exactly as the API returns them.
    arguments: String,
}

#[derive(Debug, Serialize)]
struct OpenAITool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAIFunctionDef,
}

#[derive(Debug, Serialize)]
struct OpenAIFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    choices: Vec<OpenAIChoice>,
    #[serde(default)]
    usage: Option<OpenAIUsage>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    delta: OpenAIDelta,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    content: Option<String>,
    tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}

#[derive(Debug, Deserialize, Clone)]
struct OpenAIToolCallDelta {
    index: usize,
    id: Option<String>,
    function: Option<OpenAIFunctionDelta>,
}

#[derive(Debug, Deserialize, Clone)]
struct OpenAIFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OpenAIUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

impl OpenAIModel {
    /// Creates a client for the hosted OpenAI API with the default model id
    /// (`gpt-4o`).
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            config: ModelConfig::new(DEFAULT_MODEL_ID),
            api_key: api_key.into(),
            base_url: None,
            client: Client::new(),
        }
    }

    /// Overrides the model id (e.g. `gpt-4o-mini`).
    pub fn with_model(mut self, model_id: impl Into<String>) -> Self {
        self.config.model_id = model_id.into();
        self
    }

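    /// Points the client at an OpenAI-compatible endpoint. The value is used
    /// verbatim as the POST target, so it must be the full chat-completions
    /// URL (e.g. `https://host/v1/chat/completions`), not just the host.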
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    pub fn with_config(mut self, config: ModelConfig) -> Self {
        self.config = config;
        self
    }

    /// Converts the unified message history into OpenAI chat messages.
    /// Text and tool-use blocks map onto `user`/`assistant` messages, while
    /// tool-result blocks become separate `tool`-role messages that point
    /// back at the originating call via `tool_call_id`.
    fn format_messages(&self, messages: &[Message], system_prompt: Option<&str>) -> Vec<OpenAIMessage> {
        let mut formatted = Vec::new();

        if let Some(prompt) = system_prompt {
            formatted.push(OpenAIMessage {
                role: "system".to_string(),
                content: serde_json::Value::String(prompt.to_string()),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        for msg in messages {
            let role = match msg.role {
                Role::User => "user",
                Role::Assistant => "assistant",
            };

            let mut text_content = Vec::new();
            let mut tool_calls = Vec::new();
            let mut tool_results = Vec::new();

            for block in &msg.content {
                if let Some(ref text) = block.text {
                    text_content.push(serde_json::json!({ "type": "text", "text": text }));
                }

                if let Some(ref tu) = block.tool_use {
                    tool_calls.push(OpenAIToolCall {
                        id: tu.tool_use_id.clone(),
                        call_type: "function".to_string(),
                        function: OpenAIFunction {
                            name: tu.name.clone(),
                            arguments: serde_json::to_string(&tu.input).unwrap_or_default(),
                        },
                    });
                }

                if let Some(ref tr) = block.tool_result {
                    let content = tr
                        .content
                        .iter()
                        .filter_map(|c| c.text.clone())
                        .collect::<Vec<_>>()
                        .join("\n");
                    tool_results.push((tr.tool_use_id.clone(), content));
                }
            }

            if !tool_calls.is_empty() {
                formatted.push(OpenAIMessage {
                    role: role.to_string(),
                    content: if text_content.is_empty() {
                        serde_json::Value::Null
                    } else {
                        serde_json::Value::Array(text_content.clone())
                    },
                    tool_calls: Some(tool_calls),
                    tool_call_id: None,
                });
            } else if !text_content.is_empty() {
                formatted.push(OpenAIMessage {
                    role: role.to_string(),
                    content: serde_json::Value::Array(text_content),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }

            for (tool_id, content) in tool_results {
                formatted.push(OpenAIMessage {
                    role: "tool".to_string(),
                    content: serde_json::Value::String(content),
                    tool_calls: None,
                    tool_call_id: Some(tool_id),
                });
            }
        }

        formatted
    }

    fn format_tools(&self, tool_specs: &[ToolSpec]) -> Vec<OpenAITool> {
        tool_specs
            .iter()
            .map(|spec| OpenAITool {
                tool_type: "function".to_string(),
                function: OpenAIFunctionDef {
                    name: spec.name.clone(),
                    description: spec.description.clone(),
                    parameters: spec.input_schema.json.clone(),
                },
            })
            .collect()
    }

    fn format_tool_choice(&self, tool_choice: Option<ToolChoice>) -> Option<serde_json::Value> {
        tool_choice.map(|tc| match tc {
            ToolChoice::Auto(_) => serde_json::json!("auto"),
            ToolChoice::Any(_) => serde_json::json!("required"),
            ToolChoice::Tool(t) => serde_json::json!({
                "type": "function",
                "function": { "name": t.name }
            }),
        })
    }

    fn map_stop_reason(reason: &str) -> StopReason {
        match reason {
            "tool_calls" => StopReason::ToolUse,
            "length" => StopReason::MaxTokens,
            "content_filter" => StopReason::ContentFiltered,
            _ => StopReason::EndTurn,
        }
    }
}

#[async_trait]
impl Model for OpenAIModel {
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn update_config(&mut self, config: ModelConfig) {
        self.config = config;
    }

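    // The stream below emits the unified event sequence the rest of the
    // crate expects: message_start, content-block start/delta/stop events
    // (index 0 for text, 1.. for accumulated tool calls), message_stop, and
    // finally a metadata event carrying token usage.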
    fn stream<'a>(
        &'a self,
        messages: &'a [Message],
        tool_specs: Option<&'a [ToolSpec]>,
        system_prompt: Option<&'a str>,
        tool_choice: Option<ToolChoice>,
        _system_prompt_content: Option<&'a [SystemContentBlock]>,
    ) -> StreamEventStream<'a> {
        let url = self.base_url.clone().unwrap_or_else(|| OPENAI_API_URL.to_string());
        let api_key = self.api_key.clone();
        let client = self.client.clone();

        let request = OpenAIRequest {
            model: self.config.model_id.clone(),
            messages: self.format_messages(messages, system_prompt),
            stream: true,
            max_tokens: self.config.max_tokens,
            temperature: self.config.temperature,
            top_p: self.config.top_p,
            tools: tool_specs.map(|s| self.format_tools(s)).unwrap_or_default(),
            tool_choice: self.format_tool_choice(tool_choice),
            stream_options: StreamOptions { include_usage: true },
        };

        Box::pin(async_stream::stream! {
            let response = match client
                .post(&url)
                .header("Authorization", format!("Bearer {api_key}"))
                .header("Content-Type", "application/json")
                .json(&request)
                .send()
                .await
            {
                Ok(resp) => resp,
                Err(e) => {
                    yield Err(StrandsError::NetworkError(e.to_string()));
                    return;
                }
            };

            // Map provider failures onto the crate's error taxonomy: HTTP 429
            // becomes a throttling error, and OpenAI's
            // `context_length_exceeded` code becomes a context overflow.
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                if status.as_u16() == 429 {
                    yield Err(StrandsError::ModelThrottled { message: body });
                } else if body.contains("context_length_exceeded") {
                    yield Err(StrandsError::ContextWindowOverflow { message: body });
                } else {
                    yield Err(StrandsError::model_error(format!("HTTP {status}: {body}")));
                }
                return;
            }

            yield Ok(StreamEvent {
                message_start: Some(MessageStartEvent { role: Role::Assistant }),
                ..Default::default()
            });

            let mut content_started = false;
            // Accumulated tool calls keyed by stream index; a BTreeMap keeps
            // emission order deterministic (HashMap iteration order is not).
            let mut tool_calls: std::collections::BTreeMap<usize, (String, String, String)> =
                std::collections::BTreeMap::new();
            let mut finish_reason = None;
            let mut final_usage = None;

            use futures::StreamExt;
            let mut byte_stream = response.bytes_stream();
            let mut buffer = String::new();

            // Read the SSE stream chunk by chunk. A single `data:` line can
            // be split across network reads, so bytes are buffered and only
            // complete (newline-terminated) lines are parsed; any
            // unterminated remainder is dropped once the stream closes,
            // which is safe because SSE events always end with a newline.
            while let Some(next) = byte_stream.next().await {
                let bytes = match next {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        yield Err(StrandsError::NetworkError(e.to_string()));
                        return;
                    }
                };
                buffer.push_str(&String::from_utf8_lossy(&bytes));

                while let Some(pos) = buffer.find('\n') {
                    let line: String = buffer.drain(..=pos).collect();
                    let line = line.trim();
                    if line.is_empty() || line == "data: [DONE]" {
                        continue;
                    }

                    if let Some(json_str) = line.strip_prefix("data: ") {
                        if let Ok(chunk) = serde_json::from_str::<OpenAIStreamChunk>(json_str) {
                            if let Some(usage) = chunk.usage {
                                final_usage = Some(usage);
                            }

                            for choice in chunk.choices {
                                if let Some(ref content) = choice.delta.content {
                                    if !content_started {
                                        yield Ok(StreamEvent {
                                            content_block_start: Some(ContentBlockStartEvent {
                                                content_block_index: Some(0),
                                                start: None,
                                            }),
                                            ..Default::default()
                                        });
                                        content_started = true;
                                    }

                                    yield Ok(StreamEvent {
                                        content_block_delta: Some(ContentBlockDeltaEvent {
                                            content_block_index: Some(0),
                                            delta: Some(ContentBlockDelta {
                                                text: Some(content.clone()),
                                                ..Default::default()
                                            }),
                                        }),
                                        ..Default::default()
                                    });
                                }

                                if let Some(ref tcs) = choice.delta.tool_calls {
                                    // Tool-call arguments arrive as fragments
                                    // keyed by index; accumulate (id, name,
                                    // argument JSON) until the stream ends.
                                    for tc in tcs {
                                        let entry = tool_calls.entry(tc.index).or_insert_with(|| {
                                            (String::new(), String::new(), String::new())
                                        });
                                        if let Some(ref id) = tc.id {
                                            entry.0 = id.clone();
                                        }
                                        if let Some(ref f) = tc.function {
                                            if let Some(ref name) = f.name {
                                                entry.1 = name.clone();
                                            }
                                            if let Some(ref args) = f.arguments {
                                                entry.2.push_str(args);
                                            }
                                        }
                                    }
                                }

                                if let Some(ref reason) = choice.finish_reason {
                                    finish_reason = Some(reason.clone());
                                }
                            }
                        }
                    }
                }
            }

            if content_started {
                yield Ok(StreamEvent {
                    content_block_stop: Some(ContentBlockStopEvent {
                        content_block_index: Some(0),
                    }),
                    ..Default::default()
                });
            }

            // Emit each accumulated tool call as its own content block,
            // numbered after the text block at index 0.
            let mut tool_index = 1u32;
            for (_idx, (id, name, args)) in tool_calls {
                yield Ok(StreamEvent {
                    content_block_start: Some(ContentBlockStartEvent {
                        content_block_index: Some(tool_index),
                        start: Some(ContentBlockStart {
                            tool_use: Some(ContentBlockStartToolUse {
                                name: name.clone(),
                                tool_use_id: id.clone(),
                            }),
                        }),
                    }),
                    ..Default::default()
                });

                yield Ok(StreamEvent {
                    content_block_delta: Some(ContentBlockDeltaEvent {
                        content_block_index: Some(tool_index),
                        delta: Some(ContentBlockDelta {
                            tool_use: Some(ContentBlockDeltaToolUse { input: args }),
                            ..Default::default()
                        }),
                    }),
                    ..Default::default()
                });

                yield Ok(StreamEvent {
                    content_block_stop: Some(ContentBlockStopEvent {
                        content_block_index: Some(tool_index),
                    }),
                    ..Default::default()
                });

                tool_index += 1;
            }

            let stop = finish_reason
                .as_deref()
                .map(Self::map_stop_reason)
                .unwrap_or(StopReason::EndTurn);

            yield Ok(StreamEvent {
                message_stop: Some(MessageStopEvent {
                    stop_reason: Some(stop),
                    additional_model_response_fields: None,
                }),
                ..Default::default()
            });

            if let Some(usage) = final_usage {
                yield Ok(StreamEvent {
                    metadata: Some(MetadataEvent {
                        usage: Some(Usage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            total_tokens: usage.total_tokens,
                            cache_read_input_tokens: 0,
                            cache_write_input_tokens: 0,
                        }),
                        // Latency is not measured by this provider yet.
                        metrics: Some(Metrics {
                            latency_ms: 0,
                            time_to_first_byte_ms: 0,
                        }),
                        trace: None,
                    }),
                    ..Default::default()
                });
            }
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_model_creation() {
        let model = OpenAIModel::new("test-key").with_model("gpt-4o-mini");
        assert_eq!(model.config().model_id, "gpt-4o-mini");
    }
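
    // Illustrative check, grounded in the struct definitions above: unset
    // options and empty tool lists are omitted from the serialized request.
    #[test]
    fn test_request_serialization_skips_empty_options() {
        let request = OpenAIRequest {
            model: "gpt-4o".to_string(),
            messages: vec![],
            stream: true,
            max_tokens: None,
            temperature: None,
            top_p: None,
            tools: vec![],
            tool_choice: None,
            stream_options: StreamOptions { include_usage: true },
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(!json.contains("max_tokens"));
        assert!(!json.contains("temperature"));
        assert!(!json.contains("tool_choice"));
        assert!(json.contains("\"include_usage\":true"));
    }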

    #[test]
    fn test_openai_with_base_url() {
        let model = OpenAIModel::new("test-key").with_base_url("https://custom.api.com");
        assert_eq!(model.base_url, Some("https://custom.api.com".to_string()));
    }
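
    // Finish-reason mapping is total: unknown reasons fall back to EndTurn.
    #[test]
    fn test_map_stop_reason() {
        assert!(matches!(OpenAIModel::map_stop_reason("tool_calls"), StopReason::ToolUse));
        assert!(matches!(OpenAIModel::map_stop_reason("length"), StopReason::MaxTokens));
        assert!(matches!(
            OpenAIModel::map_stop_reason("content_filter"),
            StopReason::ContentFiltered
        ));
        assert!(matches!(OpenAIModel::map_stop_reason("stop"), StopReason::EndTurn));
    }

    // A sketch of the SSE payload shape this module parses. The sample JSON
    // is a hand-written approximation of an OpenAI chunk; fields this module
    // does not model (id, object, index, ...) are ignored by serde.
    #[test]
    fn test_stream_chunk_deserialization() {
        let json = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let chunk: OpenAIStreamChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(chunk.usage.is_none());
    }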
}