stakpak_shared/models/integrations/openai.rs

use crate::models::error::{AgentError, BadRequestErrorMessage};
use crate::models::llm::{
    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
    LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
};
use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
use futures_util::StreamExt;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::json;
use uuid::Uuid;

const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1/chat/completions";

#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    pub api_endpoint: Option<String>,
    pub api_key: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    Custom(String),
}

impl OpenAIModel {
    pub fn from_string(s: &str) -> Result<Self, String> {
        serde_json::from_value(serde_json::Value::String(s.to_string()))
            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
    }

    /// Default smart model for OpenAI
    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;

    /// Default eco model for OpenAI
    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;

    /// Default recovery model for OpenAI
    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;

    /// Get default smart model as string
    pub fn default_smart_model() -> String {
        Self::DEFAULT_SMART_MODEL.to_string()
    }

    /// Get default eco model as string
    pub fn default_eco_model() -> String {
        Self::DEFAULT_ECO_MODEL.to_string()
    }

    /// Get default recovery model as string
    pub fn default_recovery_model() -> String {
        Self::DEFAULT_RECOVERY_MODEL.to_string()
    }
}
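
// A small sanity check (added here as a sketch, not part of the original file):
// `from_string` deserializes through the same serde rename strings that
// `Display` produces, so the two should round-trip for the named variants.
#[cfg(test)]
mod model_name_tests {
    use super::*;

    #[test]
    fn test_from_string_roundtrips_display() {
        for model in [OpenAIModel::O3, OpenAIModel::O4Mini, OpenAIModel::GPT5] {
            let name = model.to_string();
            assert_eq!(OpenAIModel::from_string(&name), Ok(model));
        }
    }
}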

impl ContextAware for OpenAIModel {
    fn context_info(&self) -> ModelContextInfo {
        let model_name = self.to_string();

        if model_name.starts_with("o3") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 2.0,
                    output_cost_per_million: 8.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("o4-mini") {
            return ModelContextInfo {
                max_tokens: 200_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 1.10,
                    output_cost_per_million: 4.40,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5-mini") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.25,
                    output_cost_per_million: 2.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5-nano") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 0.05,
                    output_cost_per_million: 0.40,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        if model_name.starts_with("gpt-5") {
            return ModelContextInfo {
                max_tokens: 400_000,
                pricing_tiers: vec![ContextPricingTier {
                    label: "Standard".to_string(),
                    input_cost_per_million: 1.25,
                    output_cost_per_million: 10.0,
                    upper_bound: None,
                }],
                approach_warning_threshold: 0.8,
            };
        }

        ModelContextInfo::default()
    }

    fn model_name(&self) -> String {
        match self {
            OpenAIModel::O3 => "O3".to_string(),
            OpenAIModel::O4Mini => "O4-mini".to_string(),
            OpenAIModel::GPT5 => "GPT-5".to_string(),
            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
            OpenAIModel::Custom(name) => format!("Custom ({})", name),
        }
    }
}
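
// Regression sketch (added, not original): the `starts_with` checks in
// `context_info` are order-sensitive -- "gpt-5-mini" and "gpt-5-nano" must be
// tested before the broader "gpt-5" prefix, or both would inherit flagship
// pricing.
#[cfg(test)]
mod context_info_tests {
    use super::*;

    #[test]
    fn test_mini_pricing_not_shadowed_by_gpt5_prefix() {
        let info = OpenAIModel::GPT5Mini.context_info();
        assert_eq!(info.pricing_tiers[0].input_cost_per_million, 0.25);
        assert_eq!(info.pricing_tiers[0].output_cost_per_million, 2.0);
    }
}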

impl std::fmt::Display for OpenAIModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),

            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),

            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
        }
    }
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAIInput {
    pub model: OpenAIModel,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub json: Option<serde_json::Value>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<OpenAIReasoningEffort>,
}

impl OpenAIInput {
    pub fn is_reasoning_model(&self) -> bool {
        // GPT-5.1 belongs to the same reasoning family as the other GPT-5
        // models; treating it as a standard model would send `temperature`
        // instead of `reasoning_effort`.
        matches!(
            self.model,
            OpenAIModel::O3
                | OpenAIModel::O4Mini
                | OpenAIModel::GPT5
                | OpenAIModel::GPT51
                | OpenAIModel::GPT5Mini
                | OpenAIModel::GPT5Nano
        )
    }

    pub fn is_standard_model(&self) -> bool {
        !self.is_reasoning_model()
    }
}
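
// Usage sketch (added): classification depends only on the model variant, so
// an otherwise-empty input is enough to exercise it; the field values below
// are placeholders.
#[cfg(test)]
mod reasoning_model_tests {
    use super::*;

    #[test]
    fn test_gpt5_counts_as_reasoning() {
        let input = OpenAIInput {
            model: OpenAIModel::GPT5,
            messages: vec![],
            max_tokens: 16,
            json: None,
            tools: None,
            reasoning_effort: None,
        };
        assert!(input.is_reasoning_model());
        assert!(!input.is_standard_model());
    }
}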

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIReasoningEffort {
    #[serde(rename = "minimal")]
    Minimal,
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAITool {
    pub r#type: String,
    pub function: OpenAIToolFunction,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAIToolFunction {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

impl From<LLMTool> for OpenAITool {
    fn from(tool: LLMTool) -> Self {
        OpenAITool {
            r#type: "function".to_string(),
            function: OpenAIToolFunction {
                name: tool.name,
                description: tool.description,
                parameters: tool.input_schema,
            },
        }
    }
}
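
// Conversion sketch (added): `LLMTool` maps onto the Chat Completions tool
// shape, with the JSON schema carried through unchanged as `parameters`.
#[cfg(test)]
mod tool_conversion_tests {
    use super::*;

    #[test]
    fn test_llm_tool_into_openai_tool() {
        let tool = LLMTool {
            name: "get_weather".to_string(),
            description: "Weather lookup".to_string(),
            input_schema: json!({"type": "object"}),
        };
        let openai_tool: OpenAITool = tool.into();
        assert_eq!(openai_tool.r#type, "function");
        assert_eq!(openai_tool.function.name, "get_weather");
    }
}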

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAIOutput {
    pub model: String,
    pub object: String,
    pub choices: Vec<OpenAILLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

impl From<OpenAIOutput> for LLMCompletionResponse {
    fn from(val: OpenAIOutput) -> Self {
        LLMCompletionResponse {
            model: val.model,
            object: val.object,
            choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
            created: val.created,
            usage: val.usage,
            id: val.id,
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: OpenAILLMMessage,
}

impl From<OpenAILLMChoice> for LLMChoice {
    fn from(val: OpenAILLMChoice) -> Self {
        LLMChoice {
            finish_reason: val.finish_reason,
            index: val.index,
            message: val.message.into(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessage {
    pub role: String,
    pub content: Option<String>,
    pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
}

impl From<OpenAILLMMessage> for LLMMessage {
    fn from(val: OpenAILLMMessage) -> Self {
        LLMMessage {
            role: val.role,
            content: match val.tool_calls {
                None => LLMMessageContent::String(val.content.unwrap_or_default()),
                Some(tool_calls) => LLMMessageContent::List(
                    std::iter::once(LLMMessageTypedContent::Text {
                        text: val.content.unwrap_or_default(),
                    })
                    .chain(tool_calls.into_iter().map(|tool_call| {
                        LLMMessageTypedContent::ToolCall {
                            id: tool_call.id,
                            name: tool_call.function.name,
                            args: match serde_json::from_str(&tool_call.function.arguments) {
                                Ok(args) => args,
                                Err(_) => {
                                    return LLMMessageTypedContent::Text {
                                        text: String::from("Error parsing tool call arguments"),
                                    };
                                }
                            },
                        }
                    }))
                    .collect(),
                ),
            },
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessageToolCall {
    pub id: String,
    pub r#type: String,
    pub function: OpenAILLMMessageToolCallFunction,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessageToolCallFunction {
    pub arguments: String,
    pub name: String,
}
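
// Behaviour sketch (added): a response message that carries tool calls is
// flattened into a typed-content list, with the (possibly empty) text part
// first and one `ToolCall` entry per call.
#[cfg(test)]
mod openai_message_tests {
    use super::*;

    #[test]
    fn test_tool_call_message_becomes_list() {
        let msg = OpenAILLMMessage {
            role: "assistant".to_string(),
            content: None,
            tool_calls: Some(vec![OpenAILLMMessageToolCall {
                id: "call_1".to_string(),
                r#type: "function".to_string(),
                function: OpenAILLMMessageToolCallFunction {
                    arguments: "{\"location\":\"SF\"}".to_string(),
                    name: "get_weather".to_string(),
                },
            }]),
        };
        let llm: LLMMessage = msg.into();
        match llm.content {
            LLMMessageContent::List(items) => assert_eq!(items.len(), 2),
            _ => panic!("expected a typed content list"),
        }
    }
}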

pub struct OpenAI {}

impl OpenAI {
    pub async fn chat(
        config: &OpenAIConfig,
        input: OpenAIInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let is_reasoning_model = input.is_reasoning_model();

        let mut payload = json!({
            "model": input.model.to_string(),
            "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
            "max_completion_tokens": input.max_tokens,
            "stream": false,
        });

        if is_reasoning_model {
            // Reasoning models take an effort level instead of a temperature;
            // fall back to the Medium default when none is given.
            payload["reasoning_effort"] = json!(input.reasoning_effort.unwrap_or_default());
        } else {
            payload["temperature"] = json!(0);
        }

        if let Some(tools) = input.tools {
            let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
            payload["tools"] = json!(openai_tools);
        }

        if let Some(schema) = input.json {
            payload["response_format"] = json!({
                "type": "json_schema",
                "json_schema": {
                    "strict": true,
                    "schema": schema,
                    "name": "my-schema"
                }
            });
        }

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let final_payload_str = serde_json::to_string(&payload)
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        let response = client
            .post(api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Accept", "application/json")
            .header("Content-Type", "application/json")
            .body(final_payload_str)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    e.to_string(),
                )));
            }
        };

        match response.json::<Value>().await {
            Ok(json) => {
                // Check if the response contains an error field
                if let Some(error_obj) = json.get("error") {
                    let error_message =
                        if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
                            message.to_string()
                        } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
                            format!("API error: {}", code)
                        } else {
                            serde_json::to_string(error_obj)
                                .unwrap_or_else(|_| "Unknown API error".to_string())
                        };
                    return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        error_message,
                    )));
                }

                // Try to deserialize as successful response
                match serde_json::from_value::<OpenAIOutput>(json.clone()) {
                    Ok(json_response) => Ok(json_response.into()),
                    Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                        e.to_string(),
                    ))),
                }
            }
            Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                e.to_string(),
            ))),
        }
    }

    pub async fn chat_stream(
        config: &OpenAIConfig,
        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
        input: OpenAIInput,
    ) -> Result<LLMCompletionResponse, AgentError> {
        let is_reasoning_model = input.is_reasoning_model();

        let mut payload = json!({
            "model": input.model.to_string(),
            "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
            "max_completion_tokens": input.max_tokens,
            "stream": true,
            "stream_options": {
                "include_usage": true
            }
        });

        if is_reasoning_model {
            // Same default as the non-streaming path: Medium effort.
            payload["reasoning_effort"] = json!(input.reasoning_effort.unwrap_or_default());
        } else {
            payload["temperature"] = json!(0);
        }

        if let Some(tools) = input.tools {
            let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
            payload["tools"] = json!(openai_tools);
        }

        if let Some(schema) = input.json {
            payload["response_format"] = json!({
                "type": "json_schema",
                "json_schema": {
                    "strict": true,
                    "schema": schema,
                    "name": "my-schema"
                }
            });
        }

        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
        let client = ClientBuilder::new(reqwest::Client::new())
            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
            .build();

        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
        let api_key = config.api_key.as_ref().map_or("", |v| v);

        let final_payload_str = serde_json::to_string(&payload)
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))?;

        let response = client
            .post(api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Accept", "application/json")
            .header("Content-Type", "application/json")
            .body(final_payload_str)
            .send()
            .await;

        let response = match response {
            Ok(resp) => resp,
            Err(e) => {
                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                    e.to_string(),
                )));
            }
        };

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();

            // Try to parse as JSON and extract error message
            let error_message = if let Ok(json) = serde_json::from_str::<Value>(&error_text) {
                if let Some(error_obj) = json.get("error") {
                    if let Some(message) = error_obj.get("message").and_then(|m| m.as_str()) {
                        message.to_string()
                    } else if let Some(code) = error_obj.get("code").and_then(|c| c.as_str()) {
                        format!("API error: {}", code)
                    } else {
                        error_text
                    }
                } else {
                    error_text
                }
            } else {
                error_text
            };

            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                error_message,
            )));
        }

        process_openai_stream(response, input.model.to_string(), stream_channel_tx)
            .await
            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))
    }
}

/// Process OpenAI stream and convert to GenerationDelta format
async fn process_openai_stream(
    response: reqwest::Response,
    model: String,
    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
) -> Result<LLMCompletionResponse, AgentError> {
    let mut completion_response = LLMCompletionResponse {
        id: "".to_string(),
        model: model.clone(),
        object: "chat.completion".to_string(),
        choices: vec![],
        created: chrono::Utc::now().timestamp_millis() as u64,
        usage: None,
    };

    let mut stream = response.bytes_stream();
    let mut unparsed_data = String::new();
    let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
        std::collections::HashMap::new();
    let mut accumulated_content = String::new();
    let mut finish_reason: Option<String> = None;

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| {
            let error_message = format!("Failed to read stream chunk from OpenAI API: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        let text = std::str::from_utf8(&chunk).map_err(|e| {
            let error_message = format!("Failed to parse UTF-8 from OpenAI response: {e}");
            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
        })?;

        unparsed_data.push_str(text);

        while let Some(line_end) = unparsed_data.find('\n') {
            let line = unparsed_data[..line_end].to_string();
            unparsed_data = unparsed_data[line_end + 1..].to_string();

            if line.trim().is_empty() {
                continue;
            }

            if !line.starts_with("data: ") {
                continue;
            }

            // Strip the "data: " prefix; trim to tolerate CRLF-framed events.
            let json_str = line[6..].trim();
            if json_str == "[DONE]" {
                continue;
            }

            match serde_json::from_str::<ChatCompletionStreamResponse>(json_str) {
                Ok(stream_response) => {
                    // Update completion response metadata
                    if completion_response.id.is_empty() {
                        completion_response.id = stream_response.id.clone();
                        completion_response.model = stream_response.model.clone();
                        completion_response.object = stream_response.object.clone();
                        completion_response.created = stream_response.created;
                    }

                    // Process choices
                    for choice in &stream_response.choices {
                        if let Some(content) = &choice.delta.content {
                            // Send content delta
                            stream_channel_tx
                                .send(GenerationDelta::Content {
                                    content: content.clone(),
                                })
                                .await
                                .map_err(|e| {
                                    AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                        e.to_string(),
                                    ))
                                })?;
                            accumulated_content.push_str(content);
                        }

                        // Handle tool calls
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tool_call in tool_calls {
                                let index = tool_call.index;

                                // Initialize or update tool call
                                let entry = current_tool_calls.entry(index).or_insert((
                                    String::new(),
                                    String::new(),
                                    String::new(),
                                ));

                                if let Some(id) = &tool_call.id {
                                    entry.0 = id.clone();
                                }
                                if let Some(function) = &tool_call.function {
                                    if let Some(name) = &function.name {
                                        entry.1 = name.clone();
                                    }
                                    if let Some(args) = &function.arguments {
                                        entry.2.push_str(args);
                                    }
                                }

                                // Send tool use delta
                                stream_channel_tx
                                    .send(GenerationDelta::ToolUse {
                                        tool_use: GenerationDeltaToolUse {
                                            id: tool_call.id.clone(),
                                            name: tool_call
                                                .function
                                                .as_ref()
                                                .and_then(|f| f.name.clone())
                                                .and_then(|n| {
                                                    if n.is_empty() { None } else { Some(n) }
                                                }),
                                            input: tool_call
                                                .function
                                                .as_ref()
                                                .and_then(|f| f.arguments.clone()),
                                            index,
                                        },
                                    })
                                    .await
                                    .map_err(|e| {
                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                            e.to_string(),
                                        ))
                                    })?;
                            }
                        }

                        if let Some(reason) = &choice.finish_reason {
                            finish_reason = Some(match reason {
                                FinishReason::Stop => "stop".to_string(),
                                FinishReason::Length => "length".to_string(),
                                FinishReason::ContentFilter => "content_filter".to_string(),
                                FinishReason::ToolCalls => "tool_calls".to_string(),
                            });
                        }
                    }

                    // Update usage if available
                    if let Some(usage) = &stream_response.usage {
                        stream_channel_tx
                            .send(GenerationDelta::Usage {
                                usage: usage.clone(),
                            })
                            .await
                            .map_err(|e| {
                                AgentError::BadRequest(BadRequestErrorMessage::ApiError(
                                    e.to_string(),
                                ))
                            })?;
                        completion_response.usage = Some(usage.clone());
                    }
                }
                Err(e) => {
                    eprintln!("Error parsing response: {}", e);
                }
            }
        }
    }

    // Build final response
    let mut message_content = vec![];

    if !accumulated_content.is_empty() {
        message_content.push(LLMMessageTypedContent::Text {
            text: accumulated_content,
        });
    }

    // HashMap iteration order is arbitrary; sort by stream index so tool calls
    // come out in the order the model emitted them.
    let mut ordered_tool_calls: Vec<_> = current_tool_calls.into_iter().collect();
    ordered_tool_calls.sort_by_key(|(index, _)| *index);
    for (_, (id, name, args)) in ordered_tool_calls {
        if let Ok(parsed_args) = serde_json::from_str(&args) {
            message_content.push(LLMMessageTypedContent::ToolCall {
                id,
                name,
                args: parsed_args,
            });
        }
    }

    completion_response.choices = vec![LLMChoice {
        finish_reason,
        index: 0,
        message: LLMMessage {
            role: "assistant".to_string(),
            content: if message_content.is_empty() {
                LLMMessageContent::String(String::new())
            } else if message_content.len() == 1
                && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
            {
                if let LLMMessageTypedContent::Text { text } = &message_content[0] {
                    LLMMessageContent::String(text.clone())
                } else {
                    LLMMessageContent::List(message_content)
                }
            } else {
                LLMMessageContent::List(message_content)
            },
        },
    }];

    Ok(completion_response)
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
    // Function,
}

impl std::fmt::Display for Role {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Role::System => write!(f, "system"),
            Role::Developer => write!(f, "developer"),
            Role::User => write!(f, "user"),
            Role::Assistant => write!(f, "assistant"),
            Role::Tool => write!(f, "tool"),
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    pub scratchpad: Option<Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessage {
    pub role: Role,
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    // TODO: check if this is needed
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
}

impl ChatMessage {
    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
        messages
            .iter()
            .rev()
            .find(|message| message.role != Role::User && message.role != Role::Tool)
    }

    pub fn to_xml(&self) -> String {
        match &self.content {
            Some(MessageContent::String(s)) => {
                format!("<message role=\"{}\">{}</message>", self.role, s)
            }
            Some(MessageContent::Array(parts)) => parts
                .iter()
                .map(|part| {
                    format!(
                        "<message role=\"{}\" type=\"{}\">{}</message>",
                        self.role,
                        part.r#type,
                        part.text.clone().unwrap_or_default()
                    )
                })
                .collect::<Vec<String>>()
                .join("\n"),
            None => String::new(),
        }
    }
}
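
// Usage sketch (added, not in the original): plain string content renders as a
// single <message> element whose role attribute comes from `Role`'s Display impl.
#[cfg(test)]
mod chat_message_xml_tests {
    use super::*;

    #[test]
    fn test_to_xml_string_content() {
        let msg = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::String("hi".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };
        assert_eq!(msg.to_xml(), "<message role=\"user\">hi</message>");
    }
}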

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}

impl MessageContent {
    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
        match self {
            MessageContent::String(s) => MessageContent::String(format!(
                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
            )),
            MessageContent::Array(parts) => MessageContent::Array(
                std::iter::once(ContentPart {
                    r#type: "text".to_string(),
                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
                    image_url: None,
                })
                .chain(parts.iter().cloned())
                .collect(),
            ),
        }
    }

    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
        match self {
            MessageContent::String(s) => s
                .rfind("<checkpoint_id>")
                .and_then(|start| {
                    s[start..]
                        .find("</checkpoint_id>")
                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
                })
                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
                part.text.as_deref().and_then(|text| {
                    text.rfind("<checkpoint_id>")
                        .and_then(|start| {
                            text[start..]
                                .find("</checkpoint_id>")
                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
                        })
                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
                })
            }),
        }
    }
}
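
// Round-trip sketch (added): `inject_checkpoint_id` prepends a <checkpoint_id>
// tag that `extract_checkpoint_id` recovers. `Uuid::nil()` is used only to
// avoid assuming the uuid crate's `v4` feature is enabled.
#[cfg(test)]
mod checkpoint_id_tests {
    use super::*;

    #[test]
    fn test_checkpoint_id_roundtrip() {
        let id = Uuid::nil();
        let content = MessageContent::String("hello".to_string());
        let injected = content.inject_checkpoint_id(id);
        assert_eq!(injected.extract_checkpoint_id(), Some(id));
    }
}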

impl std::fmt::Display for MessageContent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MessageContent::String(s) => write!(f, "{s}"),
            MessageContent::Array(parts) => {
                let text_parts: Vec<String> =
                    parts.iter().filter_map(|part| part.text.clone()).collect();
                write!(f, "{}", text_parts.join("\n"))
            }
        }
    }
}

impl Default for MessageContent {
    fn default() -> Self {
        MessageContent::String(String::new())
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    pub r#type: String,
    pub function: FunctionDefinition,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}

impl Serialize for ToolChoice {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            ToolChoice::Auto => serializer.serialize_str("auto"),
            ToolChoice::Required => serializer.serialize_str("required"),
            ToolChoice::Object(obj) => obj.serialize(serializer),
        }
    }
}

impl<'de> Deserialize<'de> for ToolChoice {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    pub r#type: String,
    pub function: FunctionChoice,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
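
// Serde sketch (added): the custom impls let `tool_choice` round-trip both the
// bare-string forms ("auto"/"required") and the object form that names a
// specific function.
#[cfg(test)]
mod tool_choice_serde_tests {
    use super::*;

    #[test]
    fn test_tool_choice_roundtrip() {
        assert_eq!(serde_json::to_string(&ToolChoice::Auto).unwrap(), "\"auto\"");
        let parsed: ToolChoice = serde_json::from_str("\"required\"").unwrap();
        assert_eq!(parsed, ToolChoice::Required);
        let obj: ToolChoice =
            serde_json::from_str(r#"{"type":"function","function":{"name":"f"}}"#).unwrap();
        assert!(matches!(obj, ToolChoice::Object(_)));
    }
}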

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: FunctionCall,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: String,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub metadata: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    pub finish_reason: Option<FinishReason>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

impl From<LLMMessage> for ChatMessage {
    fn from(llm_message: LLMMessage) -> Self {
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            // "function" => Role::Function,
            "developer" => Role::Developer,
            _ => Role::User, // Default to user for unknown roles
        };

        let (content, tool_calls) = match llm_message.content {
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall { id, name, args } => {
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                            });
                        }
                        LLMMessageTypedContent::ToolResult { content, .. } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls)
            }
        };

        ChatMessage {
            role,
            content,
            name: None, // LLMMessage doesn't have a name field
            tool_calls,
            tool_call_id: None, // LLMMessage doesn't have a tool_call_id field
            usage: None,
        }
    }
}

impl From<ChatMessage> for LLMMessage {
    fn from(chat_message: ChatMessage) -> Self {
        let mut content_parts = Vec::new();

        // Handle text content
        match chat_message.content {
            Some(MessageContent::String(s)) => {
                if !s.is_empty() {
                    content_parts.push(LLMMessageTypedContent::Text { text: s });
                }
            }
            Some(MessageContent::Array(parts)) => {
                for part in parts {
                    if let Some(text) = part.text {
                        content_parts.push(LLMMessageTypedContent::Text { text });
                    } else if let Some(image_url) = part.image_url {
                        let (media_type, data) = if image_url.url.starts_with("data:") {
                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
                            if parts.len() == 2 {
                                let meta = parts[0];
                                let data = parts[1];
                                let media_type = meta
                                    .trim_start_matches("data:")
                                    .trim_end_matches(";base64")
                                    .to_string();
                                (media_type, data.to_string())
                            } else {
                                ("image/jpeg".to_string(), image_url.url)
                            }
                        } else {
                            ("image/jpeg".to_string(), image_url.url)
                        };

                        content_parts.push(LLMMessageTypedContent::Image {
                            source: LLMMessageImageSource {
                                r#type: "base64".to_string(),
                                media_type,
                                data,
                            },
                        });
                    }
                }
            }
            None => {}
        }

        // Handle tool calls
        if let Some(tool_calls) = chat_message.tool_calls {
            for tool_call in tool_calls {
                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
                content_parts.push(LLMMessageTypedContent::ToolCall {
                    id: tool_call.id,
                    name: tool_call.function.name,
                    args,
                });
            }
        }

        LLMMessage {
            role: chat_message.role.to_string(),
            content: if content_parts.is_empty() {
                LLMMessageContent::String(String::new())
            } else if content_parts.len() == 1 {
                match &content_parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        LLMMessageContent::String(text.clone())
                    }
                    _ => LLMMessageContent::List(content_parts),
                }
            } else {
                LLMMessageContent::List(content_parts)
            },
        }
    }
}

#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub enum AgentModel {
    #[serde(rename = "smart")]
    #[default]
    Smart,
    #[serde(rename = "eco")]
    Eco,
    #[serde(rename = "recovery")]
    Recovery,
}

impl std::fmt::Display for AgentModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AgentModel::Smart => write!(f, "smart"),
            AgentModel::Eco => write!(f, "eco"),
            AgentModel::Recovery => write!(f, "recovery"),
        }
    }
}

impl From<String> for AgentModel {
    fn from(value: String) -> Self {
        match value.as_str() {
            "eco" => AgentModel::Eco,
            "recovery" => AgentModel::Recovery,
            _ => AgentModel::Smart,
        }
    }
}
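
// Mapping sketch (added): unknown aliases deliberately fall back to the Smart
// tier rather than failing.
#[cfg(test)]
mod agent_model_tests {
    use super::*;

    #[test]
    fn test_agent_model_from_string() {
        assert_eq!(AgentModel::from("eco".to_string()), AgentModel::Eco);
        assert_eq!(AgentModel::from("recovery".to_string()), AgentModel::Recovery);
        assert_eq!(AgentModel::from("unknown".to_string()), AgentModel::Smart);
    }
}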

impl ChatCompletionRequest {
    pub fn new(
        model: String,
        messages: Vec<ChatMessage>,
        tools: Option<Vec<Tool>>,
        stream: Option<bool>,
    ) -> Self {
        Self {
            model,
            messages,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream,
            temperature: None,
            top_p: None,
            tools,
            tool_choice: None,
            user: None,
            context: None,
        }
    }
}

impl From<Tool> for LLMTool {
    fn from(tool: Tool) -> Self {
        LLMTool {
            name: tool.function.name,
            description: tool.function.description.unwrap_or_default(),
            input_schema: tool.function.parameters,
        }
    }
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    pub call: ToolCall,
    pub result: String,
    pub status: ToolCallResultStatus,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    pub message: String,
}

impl From<GenerationDelta> for ChatMessageDelta {
    fn from(delta: GenerationDelta) -> Self {
        match delta {
            GenerationDelta::Content { content } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: Some(content),
                tool_calls: None,
            },
            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: None,
            },
            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: Some(vec![ToolCallDelta {
                    index: tool_use.index,
                    id: tool_use.id,
                    r#type: Some("function".to_string()),
                    function: Some(FunctionCallDelta {
                        name: tool_use.name,
                        arguments: tool_use.input,
                    }),
                }]),
            },
            _ => ChatMessageDelta {
                role: Some(Role::Assistant),
                content: None,
                tool_calls: None,
            },
        }
    }
}
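
// Conversion sketch (added): content deltas carry their text through, while
// thinking and other non-content variants map to empty assistant deltas.
#[cfg(test)]
mod delta_conversion_tests {
    use super::*;

    #[test]
    fn test_content_delta_maps_to_assistant_content() {
        let delta = GenerationDelta::Content {
            content: "hi".to_string(),
        };
        let msg: ChatMessageDelta = delta.into();
        assert_eq!(msg.role, Some(Role::Assistant));
        assert_eq!(msg.content, Some("hi".to_string()));
        assert_eq!(msg.tool_calls, None);
    }
}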
1434
1435#[cfg(test)]
1436mod tests {
1437    use super::*;
1438
1439    #[test]
1440    fn test_serialize_basic_request() {
1441        let request = ChatCompletionRequest {
1442            model: AgentModel::Smart.to_string(),
1443            messages: vec![
1444                ChatMessage {
1445                    role: Role::System,
1446                    content: Some(MessageContent::String(
1447                        "You are a helpful assistant.".to_string(),
1448                    )),
1449                    name: None,
1450                    tool_calls: None,
1451                    tool_call_id: None,
1452                    usage: None,
1453                },
1454                ChatMessage {
1455                    role: Role::User,
1456                    content: Some(MessageContent::String("Hello!".to_string())),
1457                    name: None,
1458                    tool_calls: None,
1459                    tool_call_id: None,
1460                    usage: None,
1461                },
1462            ],
1463            frequency_penalty: None,
1464            logit_bias: None,
1465            logprobs: None,
1466            max_tokens: Some(100),
1467            n: None,
1468            presence_penalty: None,
1469            response_format: None,
1470            seed: None,
1471            stop: None,
1472            stream: None,
1473            temperature: Some(0.7),
1474            top_p: None,
1475            tools: None,
1476            tool_choice: None,
1477            user: None,
1478            context: None,
1479        };
1480
1481        let json = serde_json::to_string(&request).unwrap();
1482        assert!(json.contains("\"model\":\"smart\""));
1483        assert!(json.contains("\"messages\":["));
1484        assert!(json.contains("\"role\":\"system\""));
1485        assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
1486        assert!(json.contains("\"role\":\"user\""));
1487        assert!(json.contains("\"content\":\"Hello!\""));
1488        assert!(json.contains("\"max_tokens\":100"));
1489        assert!(json.contains("\"temperature\":0.7"));
1490    }
1491
    #[test]
    fn test_deserialize_response() {
        let json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "smart",
            "system_fingerprint": "fp_123abc",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello! How can I help you today?"
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21
            }
        }"#;

        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.object, "chat.completion");
        assert_eq!(response.created, 1677652288);
        assert_eq!(response.model, AgentModel::Smart.to_string());
        assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));

        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].index, 0);
        assert_eq!(response.choices[0].message.role, Role::Assistant);

        match &response.choices[0].message.content {
            Some(MessageContent::String(content)) => {
                assert_eq!(content, "Hello! How can I help you today?");
            }
            _ => panic!("Expected string content"),
        }

        assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
        assert_eq!(response.usage.prompt_tokens, 9);
        assert_eq!(response.usage.completion_tokens, 12);
        assert_eq!(response.usage.total_tokens, 21);
    }

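    // Covers both directions: serializing a tool definition with ToolChoice::Auto, and parsing a tool_calls response.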
    #[test]
    fn test_tool_calls_request_response() {
        // Test a request with tools
        let tools_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "What's the weather in San Francisco?".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            tools: Some(vec![Tool {
                r#type: "function".to_string(),
                function: FunctionDefinition {
                    name: "get_weather".to_string(),
                    description: Some("Get the current weather in a given location".to_string()),
                    parameters: serde_json::json!({
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "The city and state, e.g. San Francisco, CA"
                            }
                        },
                        "required": ["location"]
                    }),
                },
            }]),
            tool_choice: Some(ToolChoice::Auto),
            max_tokens: Some(100),
            temperature: Some(0.7),
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&tools_request).unwrap();
        println!("Tool request JSON: {}", json);

        assert!(json.contains("\"tools\":["));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"get_weather\""));
        // Auto should be serialized as "auto" (string)
        assert!(json.contains("\"tool_choice\":\"auto\""));

        // Test response with tool calls
        let tool_response_json = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "eco",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [
                        {
                            "id": "call_abc123",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": "{\"location\":\"San Francisco, CA\"}"
                            }
                        }
                    ]
                },
                "finish_reason": "tool_calls"
            }],
            "usage": {
                "prompt_tokens": 82,
                "completion_tokens": 17,
                "total_tokens": 99
            }
        }"#;

        let tool_response: ChatCompletionResponse =
            serde_json::from_str(tool_response_json).unwrap();
        assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
        assert_eq!(tool_response.choices[0].message.content, None);
        assert!(tool_response.choices[0].message.tool_calls.is_some());

        let tool_calls = tool_response.choices[0]
            .message
            .tool_calls
            .as_ref()
            .unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc123");
        assert_eq!(tool_calls[0].r#type, "function");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert_eq!(
            tool_calls[0].function.arguments,
            "{\"location\":\"San Francisco, CA\"}"
        );

        assert_eq!(
            tool_response.choices[0].finish_reason,
            FinishReason::ToolCalls,
        );
    }

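    // Mixed text + image content should serialize as an array of typed parts.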
    #[test]
    fn test_content_with_image() {
        let message_with_image = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![
                ContentPart {
                    r#type: "text".to_string(),
                    text: Some("What's in this image?".to_string()),
                    image_url: None,
                },
                ContentPart {
                    r#type: "input_image".to_string(),
                    text: None,
                    image_url: Some(ImageUrl {
                        url: "data:image/jpeg;base64,/9j/4AAQSkZ...".to_string(),
                        detail: Some("low".to_string()),
                    }),
                },
            ])),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        };

        let json = serde_json::to_string(&message_with_image).unwrap();
        println!("Serialized JSON: {}", json);

        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"What's in this image?\""));
        assert!(json.contains("\"type\":\"input_image\""));
        assert!(json.contains("\"url\":\"data:image/jpeg;base64,/9j/4AAQSkZ...\""));
        assert!(json.contains("\"detail\":\"low\""));
    }

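    // response_format should serialize as a nested {"type":"json_object"} object.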
    #[test]
    fn test_response_format() {
        let json_format_request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String(
                    "Generate a JSON object with name and age fields".to_string(),
                )),
                name: None,
                tool_calls: None,
                tool_call_id: None,
                usage: None,
            }],
            response_format: Some(ResponseFormat {
                r#type: "json_object".to_string(),
            }),
            max_tokens: Some(100),
            temperature: None,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            n: None,
            presence_penalty: None,
            seed: None,
            stop: None,
            stream: None,
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&json_format_request).unwrap();
        assert!(json.contains("\"response_format\":{\"type\":\"json_object\"}"));
    }

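    // LLMMessage -> ChatMessage conversion: plain text, a bare tool call, and mixed text + tool-call content.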
    #[test]
    fn test_llm_message_to_chat_message() {
        // Test simple string content
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
        assert_eq!(chat_message.tool_calls, None);

        // Test tool call conversion
        let llm_message_with_tool = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
                id: "call_123".to_string(),
                name: "get_weather".to_string(),
                args: serde_json::json!({"location": "San Francisco"}),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message_with_tool);
        assert_eq!(chat_message.role, Role::Assistant);
        assert_eq!(chat_message.content, None); // No text content
        assert!(chat_message.tool_calls.is_some());

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_123");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("San Francisco"));

        // Test mixed content
        let llm_message_mixed = LLMMessage {
            role: "assistant".to_string(),
            content: LLMMessageContent::List(vec![
                LLMMessageTypedContent::Text {
                    text: "The weather is:".to_string(),
                },
                LLMMessageTypedContent::ToolCall {
                    id: "call_456".to_string(),
                    name: "get_weather".to_string(),
                    args: serde_json::json!({"location": "New York"}),
                },
            ]),
        };

        let chat_message = ChatMessage::from(llm_message_mixed);
        assert_eq!(chat_message.role, Role::Assistant);

        match &chat_message.content {
            Some(MessageContent::Array(parts)) => {
                assert_eq!(parts.len(), 1);
                assert_eq!(parts[0].r#type, "text");
                assert_eq!(parts[0].text, Some("The weather is:".to_string()));
            }
            _ => panic!("Expected array content"),
        }

        let tool_calls = chat_message.tool_calls.unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_456");
        assert_eq!(tool_calls[0].function.name, "get_weather");
        assert!(tool_calls[0].function.arguments.contains("New York"));
    }
}