stakpak_shared/models/integrations/
openai.rs

1use crate::models::error::{AgentError, BadRequestErrorMessage};
2use crate::models::llm::{
3    GenerationDelta, GenerationDeltaToolUse, LLMChoice, LLMCompletionResponse, LLMMessage,
4    LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent, LLMTokenUsage, LLMTool,
5};
6use futures_util::StreamExt;
7use reqwest_middleware::ClientBuilder;
8use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use serde_json::json;
12use uuid::Uuid;
13
/// Chat-completions URL that requests are sent to when
/// `OpenAIConfig::api_endpoint` is `None`.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1/chat/completions";
15
/// Connection settings used by `OpenAI::chat` and `OpenAI::chat_stream`.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Full URL the request is POSTed to; `DEFAULT_BASE_URL` is used when `None`.
    pub api_endpoint: Option<String>,
    /// Token placed in the `Authorization: Bearer …` header; an empty token is
    /// sent when `None`.
    pub api_key: Option<String>,
}
21
/// Model identifiers understood by this client.
///
/// The named variants (de)serialize as the dated model-id strings given in their
/// `serde(rename)` attributes; `GPT5` is the `Default`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any other model name; `Display` prints the inner string verbatim.
    // NOTE(review): this variant has no `rename`, so it is presumably not
    // reachable by deserializing a bare model-name string — confirm.
    Custom(String),
}
40
41impl OpenAIModel {
42    pub fn from_string(s: &str) -> Result<Self, String> {
43        serde_json::from_value(serde_json::Value::String(s.to_string()))
44            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
45    }
46
47    /// Default smart model for OpenAI
48    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
49
50    /// Default eco model for OpenAI
51    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
52
53    /// Default recovery model for OpenAI
54    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
55
56    /// Get default smart model as string
57    pub fn default_smart_model() -> String {
58        Self::DEFAULT_SMART_MODEL.to_string()
59    }
60
61    /// Get default eco model as string
62    pub fn default_eco_model() -> String {
63        Self::DEFAULT_ECO_MODEL.to_string()
64    }
65
66    /// Get default recovery model as string
67    pub fn default_recovery_model() -> String {
68        Self::DEFAULT_RECOVERY_MODEL.to_string()
69    }
70}
71
72impl std::fmt::Display for OpenAIModel {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        match self {
75            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
76
77            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
78
79            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
80            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
81            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
82
83            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
84        }
85    }
86}
87
/// Everything needed to build one chat-completion request.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAIInput {
    /// Model to call; sent as its `Display` string.
    pub model: OpenAIModel,
    /// Conversation messages, converted to `ChatMessage` before sending.
    pub messages: Vec<LLMMessage>,
    /// Sent to the API as `max_completion_tokens`.
    pub max_tokens: u32,

    /// Optional JSON schema; when set it is sent as a strict `json_schema`
    /// `response_format`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json: Option<serde_json::Value>,

    /// Optional tools, converted to `OpenAITool` before sending.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<LLMTool>>,

    /// Only sent for reasoning models, where `Medium` is used when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<OpenAIReasoningEffort>,
}
103
104impl OpenAIInput {
105    pub fn is_reasoning_model(&self) -> bool {
106        matches!(self.model, |OpenAIModel::O3| OpenAIModel::O4Mini
107            | OpenAIModel::GPT5
108            | OpenAIModel::GPT5Mini
109            | OpenAIModel::GPT5Nano)
110    }
111
112    pub fn is_standard_model(&self) -> bool {
113        !self.is_reasoning_model()
114    }
115}
116
/// Value of the `reasoning_effort` request field; serialized as the lowercase
/// strings given in the `serde(rename)` attributes. `Medium` is the `Default`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIReasoningEffort {
    #[serde(rename = "minimal")]
    Minimal,
    #[serde(rename = "low")]
    Low,
    #[default]
    #[serde(rename = "medium")]
    Medium,
    #[serde(rename = "high")]
    High,
}
129
/// Tool entry as placed in the request's `tools` array.
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAITool {
    /// Tool kind; the `From<LLMTool>` conversion always sets `"function"`.
    pub r#type: String,
    /// The function this tool exposes.
    pub function: OpenAIToolFunction,
}
135
/// Function description nested inside an [`OpenAITool`].
#[derive(Serialize, Deserialize, Debug)]
pub struct OpenAIToolFunction {
    /// Function name, taken from `LLMTool::name`.
    pub name: String,
    /// Description, taken from `LLMTool::description`.
    pub description: String,
    /// Parameter schema, taken from `LLMTool::input_schema`.
    pub parameters: serde_json::Value,
}
142
143impl From<LLMTool> for OpenAITool {
144    fn from(tool: LLMTool) -> Self {
145        OpenAITool {
146            r#type: "function".to_string(),
147            function: OpenAIToolFunction {
148                name: tool.name,
149                description: tool.description,
150                parameters: tool.input_schema,
151            },
152        }
153    }
154}
155
/// Body of a non-streaming chat-completion response, as decoded by
/// `OpenAI::chat`. Converts field-for-field into `LLMCompletionResponse`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAIOutput {
    pub model: String,
    pub object: String,
    pub choices: Vec<OpenAILLMChoice>,
    pub created: u64,
    /// Token usage; optional in the response.
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}
165
166impl From<OpenAIOutput> for LLMCompletionResponse {
167    fn from(val: OpenAIOutput) -> Self {
168        LLMCompletionResponse {
169            model: val.model,
170            object: val.object,
171            choices: val.choices.into_iter().map(OpenAILLMChoice::into).collect(),
172            created: val.created,
173            usage: val.usage,
174            id: val.id,
175        }
176    }
177}
178
/// One entry of `OpenAIOutput::choices`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMChoice {
    /// Finish reason string, when the response provides one.
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: OpenAILLMMessage,
}
185
186impl From<OpenAILLMChoice> for LLMChoice {
187    fn from(val: OpenAILLMChoice) -> Self {
188        LLMChoice {
189            finish_reason: val.finish_reason,
190            index: val.index,
191            message: val.message.into(),
192        }
193    }
194}
195
/// Message object inside a response choice.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessage {
    pub role: String,
    /// Text content; treated as an empty string when `None` during conversion
    /// to `LLMMessage`.
    pub content: Option<String>,
    /// Tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<OpenAILLMMessageToolCall>>,
}
202impl From<OpenAILLMMessage> for LLMMessage {
203    fn from(val: OpenAILLMMessage) -> Self {
204        LLMMessage {
205            role: val.role,
206            content: match val.tool_calls {
207                None => LLMMessageContent::String(val.content.unwrap_or_default()),
208                Some(tool_calls) => LLMMessageContent::List(
209                    std::iter::once(LLMMessageTypedContent::Text {
210                        text: val.content.unwrap_or_default(),
211                    })
212                    .chain(tool_calls.into_iter().map(|tool_call| {
213                        LLMMessageTypedContent::ToolCall {
214                            id: tool_call.id,
215                            name: tool_call.function.name,
216                            args: match serde_json::from_str(&tool_call.function.arguments) {
217                                Ok(args) => args,
218                                Err(_) => {
219                                    return LLMMessageTypedContent::Text {
220                                        text: String::from("Error parsing tool call arguments"),
221                                    };
222                                }
223                            },
224                        }
225                    }))
226                    .collect(),
227                ),
228            },
229        }
230    }
231}
232
/// One tool call attached to a response message.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessageToolCall {
    /// Identifier of the call; carried over as the `ToolCall` id.
    pub id: String,
    pub r#type: String,
    /// Function name and raw argument string.
    pub function: OpenAILLMMessageToolCallFunction,
}
/// Function part of a tool call.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct OpenAILLMMessageToolCallFunction {
    /// Arguments as a string; parsed as JSON when converting to `LLMMessage`.
    pub arguments: String,
    pub name: String,
}
244
/// Stateless namespace for the OpenAI client functions (`chat`, `chat_stream`);
/// all of them are associated functions that take the config explicitly.
pub struct OpenAI {}
246
247impl OpenAI {
248    pub async fn chat(
249        config: &OpenAIConfig,
250        input: OpenAIInput,
251    ) -> Result<LLMCompletionResponse, AgentError> {
252        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
253        let client = ClientBuilder::new(reqwest::Client::new())
254            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
255            .build();
256
257        let is_reasoning_model = input.is_reasoning_model();
258
259        let mut payload = json!({
260            "model": input.model.to_string(),
261            "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
262            "max_completion_tokens": input.max_tokens,
263            "stream": false,
264        });
265
266        if is_reasoning_model {
267            if let Some(reasoning_effort) = input.reasoning_effort {
268                payload["reasoning_effort"] = json!(reasoning_effort);
269            } else {
270                payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
271            }
272        } else {
273            payload["temperature"] = json!(0);
274        }
275
276        if let Some(tools) = input.tools {
277            let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
278            payload["tools"] = json!(openai_tools);
279        }
280
281        if let Some(schema) = input.json {
282            payload["response_format"] = json!({
283                "type": "json_schema",
284                "json_schema": {
285                    "strict": true,
286                    "schema": schema,
287                    "name": "my-schema"
288                }
289            });
290        }
291
292        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
293        let api_key = config.api_key.as_ref().map_or("", |v| v);
294
295        let response = client
296            .post(api_endpoint)
297            .header("Authorization", format!("Bearer {}", api_key))
298            .header("Accept", "application/json")
299            .header("Content-Type", "application/json")
300            .body(serde_json::to_string(&payload).map_err(|e| {
301                AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string()))
302            })?)
303            .send()
304            .await;
305
306        let response = match response {
307            Ok(resp) => resp,
308            Err(e) => {
309                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
310                    e.to_string(),
311                )));
312            }
313        };
314
315        match response.json::<Value>().await {
316            Ok(json) => match serde_json::from_value::<OpenAIOutput>(json.clone()) {
317                Ok(json_response) => Ok(json_response.into()),
318                Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
319                    e.to_string(),
320                ))),
321            },
322            Err(e) => Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
323                e.to_string(),
324            ))),
325        }
326    }
327
328    pub async fn chat_stream(
329        config: &OpenAIConfig,
330        stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
331        input: OpenAIInput,
332    ) -> Result<LLMCompletionResponse, AgentError> {
333        let is_reasoning_model = input.is_reasoning_model();
334
335        let mut payload = json!({
336            "model": input.model.to_string(),
337            "messages": input.messages.into_iter().map(ChatMessage::from).collect::<Vec<ChatMessage>>(),
338            "max_completion_tokens": input.max_tokens,
339            "stream": true,
340            "stream_options":{
341                "include_usage": true
342            }
343        });
344
345        if is_reasoning_model {
346            if let Some(reasoning_effort) = input.reasoning_effort {
347                payload["reasoning_effort"] = json!(reasoning_effort);
348            } else {
349                payload["reasoning_effort"] = json!(OpenAIReasoningEffort::Medium);
350            }
351        } else {
352            payload["temperature"] = json!(0);
353        }
354
355        if let Some(tools) = input.tools {
356            let openai_tools: Vec<OpenAITool> = tools.into_iter().map(|t| t.into()).collect();
357            payload["tools"] = json!(openai_tools);
358        }
359
360        if let Some(schema) = input.json {
361            payload["response_format"] = json!({
362                "type": "json_schema",
363                "json_schema": {
364                    "strict": true,
365                    "schema": schema,
366                    "name": "my-schema"
367                }
368            });
369        }
370
371        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(3);
372        let client = ClientBuilder::new(reqwest::Client::new())
373            .with(RetryTransientMiddleware::new_with_policy(retry_policy))
374            .build();
375
376        let api_endpoint = config.api_endpoint.as_ref().map_or(DEFAULT_BASE_URL, |v| v);
377        let api_key = config.api_key.as_ref().map_or("", |v| v);
378
379        // Send the POST request
380        let response = client
381            .post(api_endpoint)
382            .header("Authorization", format!("Bearer {}", api_key))
383            .header("Accept", "application/json")
384            .header("Content-Type", "application/json")
385            .json(&payload)
386            .send()
387            .await;
388
389        let response = match response {
390            Ok(resp) => resp,
391            Err(e) => {
392                return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
393                    e.to_string(),
394                )));
395            }
396        };
397
398        if !response.status().is_success() {
399            return Err(AgentError::BadRequest(BadRequestErrorMessage::ApiError(
400                format!(
401                    "{}: {}",
402                    response.status(),
403                    response.text().await.unwrap_or_default(),
404                ),
405            )));
406        }
407
408        process_openai_stream(response, input.model.to_string(), stream_channel_tx)
409            .await
410            .map_err(|e| AgentError::BadRequest(BadRequestErrorMessage::ApiError(e.to_string())))
411    }
412}
413
414/// Process OpenAI stream and convert to GenerationDelta format
415async fn process_openai_stream(
416    response: reqwest::Response,
417    model: String,
418    stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
419) -> Result<LLMCompletionResponse, AgentError> {
420    let mut completion_response = LLMCompletionResponse {
421        id: "".to_string(),
422        model: model.clone(),
423        object: "chat.completion".to_string(),
424        choices: vec![],
425        created: chrono::Utc::now().timestamp_millis() as u64,
426        usage: None,
427    };
428
429    let mut stream = response.bytes_stream();
430    let mut unparsed_data = String::new();
431    let mut current_tool_calls: std::collections::HashMap<usize, (String, String, String)> =
432        std::collections::HashMap::new();
433    let mut accumulated_content = String::new();
434    let mut finish_reason: Option<String> = None;
435
436    while let Some(chunk) = stream.next().await {
437        let chunk = chunk.map_err(|e| {
438            let error_message = format!("Failed to read stream chunk from OpenAI API: {e}");
439            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
440        })?;
441
442        let text = std::str::from_utf8(&chunk).map_err(|e| {
443            let error_message = format!("Failed to parse UTF-8 from OpenAI response: {e}");
444            AgentError::BadRequest(BadRequestErrorMessage::ApiError(error_message))
445        })?;
446
447        unparsed_data.push_str(text);
448
449        while let Some(line_end) = unparsed_data.find('\n') {
450            let line = unparsed_data[..line_end].to_string();
451            unparsed_data = unparsed_data[line_end + 1..].to_string();
452
453            if line.trim().is_empty() {
454                continue;
455            }
456
457            if !line.starts_with("data: ") {
458                continue;
459            }
460
461            let json_str = &line[6..];
462            if json_str == "[DONE]" {
463                continue;
464            }
465
466            match serde_json::from_str::<ChatCompletionStreamResponse>(json_str) {
467                Ok(stream_response) => {
468                    // Update completion response metadata
469                    if completion_response.id.is_empty() {
470                        completion_response.id = stream_response.id.clone();
471                        completion_response.model = stream_response.model.clone();
472                        completion_response.object = stream_response.object.clone();
473                        completion_response.created = stream_response.created;
474                    }
475
476                    // Process choices
477                    for choice in &stream_response.choices {
478                        if let Some(content) = &choice.delta.content {
479                            // Send content delta
480                            stream_channel_tx
481                                .send(GenerationDelta::Content {
482                                    content: content.clone(),
483                                })
484                                .await
485                                .map_err(|e| {
486                                    AgentError::BadRequest(BadRequestErrorMessage::ApiError(
487                                        e.to_string(),
488                                    ))
489                                })?;
490                            accumulated_content.push_str(content);
491                        }
492
493                        // Handle tool calls
494                        if let Some(tool_calls) = &choice.delta.tool_calls {
495                            for tool_call in tool_calls {
496                                let index = tool_call.index;
497
498                                // Initialize or update tool call
499                                let entry = current_tool_calls.entry(index).or_insert((
500                                    String::new(),
501                                    String::new(),
502                                    String::new(),
503                                ));
504
505                                if let Some(id) = &tool_call.id {
506                                    entry.0 = id.clone();
507                                }
508                                if let Some(function) = &tool_call.function {
509                                    if let Some(name) = &function.name {
510                                        entry.1 = name.clone();
511                                    }
512                                    if let Some(args) = &function.arguments {
513                                        entry.2.push_str(args);
514                                    }
515                                }
516
517                                // Send tool use delta
518                                stream_channel_tx
519                                    .send(GenerationDelta::ToolUse {
520                                        tool_use: GenerationDeltaToolUse {
521                                            id: tool_call.id.clone(),
522                                            name: tool_call
523                                                .function
524                                                .as_ref()
525                                                .and_then(|f| f.name.clone())
526                                                .and_then(|n| {
527                                                    if n.is_empty() { None } else { Some(n) }
528                                                }),
529                                            input: tool_call
530                                                .function
531                                                .as_ref()
532                                                .and_then(|f| f.arguments.clone()),
533                                            index,
534                                        },
535                                    })
536                                    .await
537                                    .map_err(|e| {
538                                        AgentError::BadRequest(BadRequestErrorMessage::ApiError(
539                                            e.to_string(),
540                                        ))
541                                    })?;
542                            }
543                        }
544
545                        if let Some(reason) = &choice.finish_reason {
546                            finish_reason = Some(match reason {
547                                FinishReason::Stop => "stop".to_string(),
548                                FinishReason::Length => "length".to_string(),
549                                FinishReason::ContentFilter => "content_filter".to_string(),
550                                FinishReason::ToolCalls => "tool_calls".to_string(),
551                            });
552                        }
553                    }
554
555                    // Update usage if available
556                    if let Some(usage) = &stream_response.usage {
557                        stream_channel_tx
558                            .send(GenerationDelta::Usage {
559                                usage: usage.clone(),
560                            })
561                            .await
562                            .map_err(|e| {
563                                AgentError::BadRequest(BadRequestErrorMessage::ApiError(
564                                    e.to_string(),
565                                ))
566                            })?;
567                        completion_response.usage = Some(usage.clone());
568                    }
569                }
570                Err(e) => {
571                    eprintln!("Error parsing response: {}", e);
572                }
573            }
574        }
575    }
576
577    // Build final response
578    let mut message_content = vec![];
579
580    if !accumulated_content.is_empty() {
581        message_content.push(LLMMessageTypedContent::Text {
582            text: accumulated_content,
583        });
584    }
585
586    for (_, (id, name, args)) in current_tool_calls {
587        if let Ok(parsed_args) = serde_json::from_str(&args) {
588            message_content.push(LLMMessageTypedContent::ToolCall {
589                id,
590                name,
591                args: parsed_args,
592            });
593        }
594    }
595
596    completion_response.choices = vec![LLMChoice {
597        finish_reason,
598        index: 0,
599        message: LLMMessage {
600            role: "assistant".to_string(),
601            content: if message_content.is_empty() {
602                LLMMessageContent::String(String::new())
603            } else if message_content.len() == 1
604                && matches!(&message_content[0], LLMMessageTypedContent::Text { .. })
605            {
606                if let LLMMessageTypedContent::Text { text } = &message_content[0] {
607                    LLMMessageContent::String(text.clone())
608                } else {
609                    LLMMessageContent::List(message_content)
610                }
611            } else {
612                LLMMessageContent::List(message_content)
613            },
614        },
615    }];
616
617    Ok(completion_response)
618}
619
/// Author role of a [`ChatMessage`]; (de)serialized as the lowercase variant
/// name. `Assistant` is the `Default`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
    // Function,
}
631
632impl std::fmt::Display for Role {
633    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
634        match self {
635            Role::System => write!(f, "system"),
636            Role::Developer => write!(f, "developer"),
637            Role::User => write!(f, "user"),
638            Role::Assistant => write!(f, "assistant"),
639            Role::Tool => write!(f, "tool"),
640        }
641    }
642}
643
/// Typed chat-completion request body.
///
/// Every `Option` field is left out of the serialized form when it is `None`.
// NOTE(review): field names presumably mirror the OpenAI Chat Completions
// request parameters — confirm against the API reference before relying on it.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    // NOTE(review): looks like a project-specific extension rather than an
    // OpenAI parameter — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
682
/// Extra context attached to a [`ChatCompletionRequest`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Free-form JSON value; serialized as `null` when `None`.
    pub scratchpad: Option<Value>,
}
687
/// One message in a chat conversation, in the shape sent to the API.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessage {
    pub role: Role,
    /// Message body; note this field is serialized even when `None` (as `null`),
    /// unlike the optional fields below.
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    // TODO: check if this is needed
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
}
703
704impl ChatMessage {
705    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
706        messages
707            .iter()
708            .rev()
709            .find(|message| message.role != Role::User && message.role != Role::Tool)
710    }
711
712    pub fn to_xml(&self) -> String {
713        match &self.content {
714            Some(MessageContent::String(s)) => {
715                format!("<message role=\"{}\">{}</message>", self.role, s)
716            }
717            Some(MessageContent::Array(parts)) => parts
718                .iter()
719                .map(|part| {
720                    format!(
721                        "<message role=\"{}\" type=\"{}\">{}</message>",
722                        self.role,
723                        part.r#type,
724                        part.text.clone().unwrap_or_default()
725                    )
726                })
727                .collect::<Vec<String>>()
728                .join("\n"),
729            None => String::new(),
730        }
731    }
732}
733
/// Body of a [`ChatMessage`]: either plain text or a list of typed parts.
///
/// `untagged`, so it (de)serializes as a bare string or a bare array.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
740
741impl MessageContent {
742    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
743        match self {
744            MessageContent::String(s) => MessageContent::String(format!(
745                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
746            )),
747            MessageContent::Array(parts) => MessageContent::Array(
748                std::iter::once(ContentPart {
749                    r#type: "text".to_string(),
750                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
751                    image_url: None,
752                })
753                .chain(parts.iter().cloned())
754                .collect(),
755            ),
756        }
757    }
758
759    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
760        match self {
761            MessageContent::String(s) => s
762                .rfind("<checkpoint_id>")
763                .and_then(|start| {
764                    s[start..]
765                        .find("</checkpoint_id>")
766                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
767                })
768                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
769            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
770                part.text.as_deref().and_then(|text| {
771                    text.rfind("<checkpoint_id>")
772                        .and_then(|start| {
773                            text[start..]
774                                .find("</checkpoint_id>")
775                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
776                        })
777                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
778                })
779            }),
780        }
781    }
782}
783
784impl std::fmt::Display for MessageContent {
785    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786        match self {
787            MessageContent::String(s) => write!(f, "{s}"),
788            MessageContent::Array(parts) => {
789                let text_parts: Vec<String> =
790                    parts.iter().filter_map(|part| part.text.clone()).collect();
791                write!(f, "{}", text_parts.join("\n"))
792            }
793        }
794    }
795}
796impl Default for MessageContent {
797    fn default() -> Self {
798        MessageContent::String(String::new())
799    }
800}
801
/// One element of [`MessageContent::Array`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part kind; this file itself only ever writes `"text"`.
    pub r#type: String,
    /// Text payload; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image payload; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
810
/// Image reference carried by a [`ContentPart`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    pub url: String,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
817
/// Value of `ChatCompletionRequest::response_format`; carries only a type name.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
822
/// Value of `ChatCompletionRequest::stop`: a single string or a list of strings.
///
/// `untagged`, so it (de)serializes as a bare string or a bare array.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
829
/// Tool entry of `ChatCompletionRequest::tools`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    pub r#type: String,
    pub function: FunctionDefinition,
}
835
/// Function description nested inside a [`Tool`]. Unlike `OpenAIToolFunction`,
/// the description here is optional.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (no skip attribute).
    pub description: Option<String>,
    pub parameters: serde_json::Value,
}
842
843#[derive(Debug, Clone, PartialEq)]
844pub enum ToolChoice {
845    Auto,
846    Required,
847    Object(ToolChoiceObject),
848}
849
850impl Serialize for ToolChoice {
851    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
852    where
853        S: serde::Serializer,
854    {
855        match self {
856            ToolChoice::Auto => serializer.serialize_str("auto"),
857            ToolChoice::Required => serializer.serialize_str("required"),
858            ToolChoice::Object(obj) => obj.serialize(serializer),
859        }
860    }
861}
862
863impl<'de> Deserialize<'de> for ToolChoice {
864    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
865    where
866        D: serde::Deserializer<'de>,
867    {
868        struct ToolChoiceVisitor;
869
870        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
871            type Value = ToolChoice;
872
873            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
874                formatter.write_str("string or object")
875            }
876
877            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
878            where
879                E: serde::de::Error,
880            {
881                match value {
882                    "auto" => Ok(ToolChoice::Auto),
883                    "required" => Ok(ToolChoice::Required),
884                    _ => Err(serde::de::Error::unknown_variant(
885                        value,
886                        &["auto", "required"],
887                    )),
888                }
889            }
890
891            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
892            where
893                M: serde::de::MapAccess<'de>,
894            {
895                let obj = ToolChoiceObject::deserialize(
896                    serde::de::value::MapAccessDeserializer::new(map),
897                )?;
898                Ok(ToolChoice::Object(obj))
899            }
900        }
901
902        deserializer.deserialize_any(ToolChoiceVisitor)
903    }
904}
905
906#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
907pub struct ToolChoiceObject {
908    pub r#type: String,
909    pub function: FunctionChoice,
910}
911
912#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
913pub struct FunctionChoice {
914    pub name: String,
915}
916
917#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
918pub struct ToolCall {
919    pub id: String,
920    pub r#type: String,
921    pub function: FunctionCall,
922}
923
924#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
925pub struct FunctionCall {
926    pub name: String,
927    pub arguments: String,
928}
929
930#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
931pub struct ChatCompletionResponse {
932    pub id: String,
933    pub object: String,
934    pub created: u64,
935    pub model: String,
936    pub choices: Vec<ChatCompletionChoice>,
937    pub usage: LLMTokenUsage,
938    #[serde(skip_serializing_if = "Option::is_none")]
939    pub system_fingerprint: Option<String>,
940    pub metadata: Option<serde_json::Value>,
941}
942
943#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
944pub struct ChatCompletionChoice {
945    pub index: usize,
946    pub message: ChatMessage,
947    pub logprobs: Option<LogProbs>,
948    pub finish_reason: FinishReason,
949}
950
951#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
952#[serde(rename_all = "snake_case")]
953pub enum FinishReason {
954    Stop,
955    Length,
956    ContentFilter,
957    ToolCalls,
958}
959
960#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
961pub struct LogProbs {
962    pub content: Option<Vec<LogProbContent>>,
963}
964
965#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
966pub struct LogProbContent {
967    pub token: String,
968    pub logprob: f32,
969    pub bytes: Option<Vec<u8>>,
970    pub top_logprobs: Option<Vec<TokenLogprob>>,
971}
972
973#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
974pub struct TokenLogprob {
975    pub token: String,
976    pub logprob: f32,
977    pub bytes: Option<Vec<u8>>,
978}
979
980#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
981pub struct ChatCompletionStreamResponse {
982    pub id: String,
983    pub object: String,
984    pub created: u64,
985    pub model: String,
986    pub choices: Vec<ChatCompletionStreamChoice>,
987    pub usage: Option<LLMTokenUsage>,
988    pub metadata: Option<serde_json::Value>,
989}
990#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
991pub struct ChatCompletionStreamChoice {
992    pub index: usize,
993    pub delta: ChatMessageDelta,
994    pub finish_reason: Option<FinishReason>,
995}
996#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
997pub struct ChatMessageDelta {
998    #[serde(skip_serializing_if = "Option::is_none")]
999    pub role: Option<Role>,
1000    #[serde(skip_serializing_if = "Option::is_none")]
1001    pub content: Option<String>,
1002    #[serde(skip_serializing_if = "Option::is_none")]
1003    pub tool_calls: Option<Vec<ToolCallDelta>>,
1004}
1005
1006#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1007pub struct ToolCallDelta {
1008    pub index: usize,
1009    pub id: Option<String>,
1010    pub r#type: Option<String>,
1011    pub function: Option<FunctionCallDelta>,
1012}
1013
1014#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1015pub struct FunctionCallDelta {
1016    pub name: Option<String>,
1017    pub arguments: Option<String>,
1018}
1019
1020impl From<LLMMessage> for ChatMessage {
1021    fn from(llm_message: LLMMessage) -> Self {
1022        let role = match llm_message.role.as_str() {
1023            "system" => Role::System,
1024            "user" => Role::User,
1025            "assistant" => Role::Assistant,
1026            "tool" => Role::Tool,
1027            // "function" => Role::Function,
1028            "developer" => Role::Developer,
1029            _ => Role::User, // Default to user for unknown roles
1030        };
1031
1032        let (content, tool_calls) = match llm_message.content {
1033            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
1034            LLMMessageContent::List(items) => {
1035                let mut text_parts = Vec::new();
1036                let mut tool_call_parts = Vec::new();
1037
1038                for item in items {
1039                    match item {
1040                        LLMMessageTypedContent::Text { text } => {
1041                            text_parts.push(ContentPart {
1042                                r#type: "text".to_string(),
1043                                text: Some(text),
1044                                image_url: None,
1045                            });
1046                        }
1047                        LLMMessageTypedContent::ToolCall { id, name, args } => {
1048                            tool_call_parts.push(ToolCall {
1049                                id,
1050                                r#type: "function".to_string(),
1051                                function: FunctionCall {
1052                                    name,
1053                                    arguments: args.to_string(),
1054                                },
1055                            });
1056                        }
1057                        LLMMessageTypedContent::ToolResult { content, .. } => {
1058                            text_parts.push(ContentPart {
1059                                r#type: "text".to_string(),
1060                                text: Some(content),
1061                                image_url: None,
1062                            });
1063                        }
1064                        LLMMessageTypedContent::Image { source } => {
1065                            text_parts.push(ContentPart {
1066                                r#type: "image_url".to_string(),
1067                                text: None,
1068                                image_url: Some(ImageUrl {
1069                                    url: format!(
1070                                        "data:{};base64,{}",
1071                                        source.media_type, source.data
1072                                    ),
1073                                    detail: None,
1074                                }),
1075                            });
1076                        }
1077                    }
1078                }
1079
1080                let content = if !text_parts.is_empty() {
1081                    Some(MessageContent::Array(text_parts))
1082                } else {
1083                    None
1084                };
1085
1086                let tool_calls = if !tool_call_parts.is_empty() {
1087                    Some(tool_call_parts)
1088                } else {
1089                    None
1090                };
1091
1092                (content, tool_calls)
1093            }
1094        };
1095
1096        ChatMessage {
1097            role,
1098            content,
1099            name: None, // LLMMessage doesn't have a name field
1100            tool_calls,
1101            tool_call_id: None, // LLMMessage doesn't have a tool_call_id field
1102            usage: None,
1103        }
1104    }
1105}
1106
1107impl From<ChatMessage> for LLMMessage {
1108    fn from(chat_message: ChatMessage) -> Self {
1109        let mut content_parts = Vec::new();
1110
1111        // Handle text content
1112        match chat_message.content {
1113            Some(MessageContent::String(s)) => {
1114                if !s.is_empty() {
1115                    content_parts.push(LLMMessageTypedContent::Text { text: s });
1116                }
1117            }
1118            Some(MessageContent::Array(parts)) => {
1119                for part in parts {
1120                    if let Some(text) = part.text {
1121                        content_parts.push(LLMMessageTypedContent::Text { text });
1122                    } else if let Some(image_url) = part.image_url {
1123                        let (media_type, data) = if image_url.url.starts_with("data:") {
1124                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
1125                            if parts.len() == 2 {
1126                                let meta = parts[0];
1127                                let data = parts[1];
1128                                let media_type = meta
1129                                    .trim_start_matches("data:")
1130                                    .trim_end_matches(";base64")
1131                                    .to_string();
1132                                (media_type, data.to_string())
1133                            } else {
1134                                ("image/jpeg".to_string(), image_url.url)
1135                            }
1136                        } else {
1137                            ("image/jpeg".to_string(), image_url.url)
1138                        };
1139
1140                        content_parts.push(LLMMessageTypedContent::Image {
1141                            source: LLMMessageImageSource {
1142                                r#type: "base64".to_string(),
1143                                media_type,
1144                                data,
1145                            },
1146                        });
1147                    }
1148                }
1149            }
1150            None => {}
1151        }
1152
1153        // Handle tool calls
1154        if let Some(tool_calls) = chat_message.tool_calls {
1155            for tool_call in tool_calls {
1156                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
1157                content_parts.push(LLMMessageTypedContent::ToolCall {
1158                    id: tool_call.id,
1159                    name: tool_call.function.name,
1160                    args,
1161                });
1162            }
1163        }
1164
1165        LLMMessage {
1166            role: chat_message.role.to_string(),
1167            content: if content_parts.is_empty() {
1168                LLMMessageContent::String(String::new())
1169            } else if content_parts.len() == 1 {
1170                match &content_parts[0] {
1171                    LLMMessageTypedContent::Text { text } => {
1172                        LLMMessageContent::String(text.clone())
1173                    }
1174                    _ => LLMMessageContent::List(content_parts),
1175                }
1176            } else {
1177                LLMMessageContent::List(content_parts)
1178            },
1179        }
1180    }
1181}
1182
1183#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
1184pub enum AgentModel {
1185    #[serde(rename = "smart")]
1186    #[default]
1187    Smart,
1188    #[serde(rename = "eco")]
1189    Eco,
1190    #[serde(rename = "recovery")]
1191    Recovery,
1192}
1193
1194impl std::fmt::Display for AgentModel {
1195    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1196        match self {
1197            AgentModel::Smart => write!(f, "smart"),
1198            AgentModel::Eco => write!(f, "eco"),
1199            AgentModel::Recovery => write!(f, "recovery"),
1200        }
1201    }
1202}
1203
1204impl From<String> for AgentModel {
1205    fn from(value: String) -> Self {
1206        match value.as_str() {
1207            "eco" => AgentModel::Eco,
1208            "recovery" => AgentModel::Recovery,
1209            _ => AgentModel::Smart,
1210        }
1211    }
1212}
1213
1214impl ChatCompletionRequest {
1215    pub fn new(
1216        model: String,
1217        messages: Vec<ChatMessage>,
1218        tools: Option<Vec<Tool>>,
1219        stream: Option<bool>,
1220    ) -> Self {
1221        Self {
1222            model,
1223            messages,
1224            frequency_penalty: None,
1225            logit_bias: None,
1226            logprobs: None,
1227            max_tokens: None,
1228            n: None,
1229            presence_penalty: None,
1230            response_format: None,
1231            seed: None,
1232            stop: None,
1233            stream,
1234            temperature: None,
1235            top_p: None,
1236            tools,
1237            tool_choice: None,
1238            user: None,
1239            context: None,
1240        }
1241    }
1242}
1243
1244impl From<Tool> for LLMTool {
1245    fn from(tool: Tool) -> Self {
1246        LLMTool {
1247            name: tool.function.name,
1248            description: tool.function.description.unwrap_or_default(),
1249            input_schema: tool.function.parameters,
1250        }
1251    }
1252}
1253
1254#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1255pub enum ToolCallResultStatus {
1256    Success,
1257    Error,
1258    Cancelled,
1259}
1260
1261#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
1262pub struct ToolCallResult {
1263    pub call: ToolCall,
1264    pub result: String,
1265    pub status: ToolCallResultStatus,
1266}
1267
1268#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
1269pub struct ToolCallResultProgress {
1270    pub id: Uuid,
1271    pub message: String,
1272}
1273
1274impl From<GenerationDelta> for ChatMessageDelta {
1275    fn from(delta: GenerationDelta) -> Self {
1276        match delta {
1277            GenerationDelta::Content { content } => ChatMessageDelta {
1278                role: Some(Role::Assistant),
1279                content: Some(content),
1280                tool_calls: None,
1281            },
1282            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1283                role: Some(Role::Assistant),
1284                content: None,
1285                tool_calls: None,
1286            },
1287            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1288                role: Some(Role::Assistant),
1289                content: None,
1290                tool_calls: Some(vec![ToolCallDelta {
1291                    index: tool_use.index,
1292                    id: tool_use.id,
1293                    r#type: Some("function".to_string()),
1294                    function: Some(FunctionCallDelta {
1295                        name: tool_use.name,
1296                        arguments: tool_use.input,
1297                    }),
1298                }]),
1299            },
1300            _ => ChatMessageDelta {
1301                role: Some(Role::Assistant),
1302                content: None,
1303                tool_calls: None,
1304            },
1305        }
1306    }
1307}
1308
1309#[cfg(test)]
1310mod tests {
1311    use super::*;
1312
1313    #[test]
1314    fn test_serialize_basic_request() {
1315        let request = ChatCompletionRequest {
1316            model: AgentModel::Smart.to_string(),
1317            messages: vec![
1318                ChatMessage {
1319                    role: Role::System,
1320                    content: Some(MessageContent::String(
1321                        "You are a helpful assistant.".to_string(),
1322                    )),
1323                    name: None,
1324                    tool_calls: None,
1325                    tool_call_id: None,
1326                    usage: None,
1327                },
1328                ChatMessage {
1329                    role: Role::User,
1330                    content: Some(MessageContent::String("Hello!".to_string())),
1331                    name: None,
1332                    tool_calls: None,
1333                    tool_call_id: None,
1334                    usage: None,
1335                },
1336            ],
1337            frequency_penalty: None,
1338            logit_bias: None,
1339            logprobs: None,
1340            max_tokens: Some(100),
1341            n: None,
1342            presence_penalty: None,
1343            response_format: None,
1344            seed: None,
1345            stop: None,
1346            stream: None,
1347            temperature: Some(0.7),
1348            top_p: None,
1349            tools: None,
1350            tool_choice: None,
1351            user: None,
1352            context: None,
1353        };
1354
1355        let json = serde_json::to_string(&request).unwrap();
1356        assert!(json.contains("\"model\":\"smart\""));
1357        assert!(json.contains("\"messages\":["));
1358        assert!(json.contains("\"role\":\"system\""));
1359        assert!(json.contains("\"content\":\"You are a helpful assistant.\""));
1360        assert!(json.contains("\"role\":\"user\""));
1361        assert!(json.contains("\"content\":\"Hello!\""));
1362        assert!(json.contains("\"max_tokens\":100"));
1363        assert!(json.contains("\"temperature\":0.7"));
1364    }
1365
1366    #[test]
1367    fn test_deserialize_response() {
1368        let json = r#"{
1369            "id": "chatcmpl-123",
1370            "object": "chat.completion",
1371            "created": 1677652288,
1372            "model": "smart",
1373            "system_fingerprint": "fp_123abc",
1374            "choices": [{
1375                "index": 0,
1376                "message": {
1377                    "role": "assistant",
1378                    "content": "Hello! How can I help you today?"
1379                },
1380                "finish_reason": "stop"
1381            }],
1382            "usage": {
1383                "prompt_tokens": 9,
1384                "completion_tokens": 12,
1385                "total_tokens": 21
1386            }
1387        }"#;
1388
1389        let response: ChatCompletionResponse = serde_json::from_str(json).unwrap();
1390        assert_eq!(response.id, "chatcmpl-123");
1391        assert_eq!(response.object, "chat.completion");
1392        assert_eq!(response.created, 1677652288);
1393        assert_eq!(response.model, AgentModel::Smart.to_string());
1394        assert_eq!(response.system_fingerprint, Some("fp_123abc".to_string()));
1395
1396        assert_eq!(response.choices.len(), 1);
1397        assert_eq!(response.choices[0].index, 0);
1398        assert_eq!(response.choices[0].message.role, Role::Assistant);
1399
1400        match &response.choices[0].message.content {
1401            Some(MessageContent::String(content)) => {
1402                assert_eq!(content, "Hello! How can I help you today?");
1403            }
1404            _ => panic!("Expected string content"),
1405        }
1406
1407        assert_eq!(response.choices[0].finish_reason, FinishReason::Stop);
1408        assert_eq!(response.usage.prompt_tokens, 9);
1409        assert_eq!(response.usage.completion_tokens, 12);
1410        assert_eq!(response.usage.total_tokens, 21);
1411    }
1412
1413    #[test]
1414    fn test_tool_calls_request_response() {
1415        // Test a request with tools
1416        let tools_request = ChatCompletionRequest {
1417            model: AgentModel::Smart.to_string(),
1418            messages: vec![ChatMessage {
1419                role: Role::User,
1420                content: Some(MessageContent::String(
1421                    "What's the weather in San Francisco?".to_string(),
1422                )),
1423                name: None,
1424                tool_calls: None,
1425                tool_call_id: None,
1426                usage: None,
1427            }],
1428            tools: Some(vec![Tool {
1429                r#type: "function".to_string(),
1430                function: FunctionDefinition {
1431                    name: "get_weather".to_string(),
1432                    description: Some("Get the current weather in a given location".to_string()),
1433                    parameters: serde_json::json!({
1434                        "type": "object",
1435                        "properties": {
1436                            "location": {
1437                                "type": "string",
1438                                "description": "The city and state, e.g. San Francisco, CA"
1439                            }
1440                        },
1441                        "required": ["location"]
1442                    }),
1443                },
1444            }]),
1445            tool_choice: Some(ToolChoice::Auto),
1446            max_tokens: Some(100),
1447            temperature: Some(0.7),
1448            frequency_penalty: None,
1449            logit_bias: None,
1450            logprobs: None,
1451            n: None,
1452            presence_penalty: None,
1453            response_format: None,
1454            seed: None,
1455            stop: None,
1456            stream: None,
1457            top_p: None,
1458            user: None,
1459            context: None,
1460        };
1461
1462        let json = serde_json::to_string(&tools_request).unwrap();
1463        println!("Tool request JSON: {}", json);
1464
1465        assert!(json.contains("\"tools\":["));
1466        assert!(json.contains("\"type\":\"function\""));
1467        assert!(json.contains("\"name\":\"get_weather\""));
1468        // Auto should be serialized as "auto" (string)
1469        assert!(json.contains("\"tool_choice\":\"auto\""));
1470
1471        // Test response with tool calls
1472        let tool_response_json = r#"{
1473            "id": "chatcmpl-123",
1474            "object": "chat.completion",
1475            "created": 1677652288,
1476            "model": "eco",
1477            "choices": [{
1478                "index": 0,
1479                "message": {
1480                    "role": "assistant",
1481                    "content": null,
1482                    "tool_calls": [
1483                        {
1484                            "id": "call_abc123",
1485                            "type": "function",
1486                            "function": {
1487                                "name": "get_weather",
1488                                "arguments": "{\"location\":\"San Francisco, CA\"}"
1489                            }
1490                        }
1491                    ]
1492                },
1493                "finish_reason": "tool_calls"
1494            }],
1495            "usage": {
1496                "prompt_tokens": 82,
1497                "completion_tokens": 17,
1498                "total_tokens": 99
1499            }
1500        }"#;
1501
1502        let tool_response: ChatCompletionResponse =
1503            serde_json::from_str(tool_response_json).unwrap();
1504        assert_eq!(tool_response.choices[0].message.role, Role::Assistant);
1505        assert_eq!(tool_response.choices[0].message.content, None);
1506        assert!(tool_response.choices[0].message.tool_calls.is_some());
1507
1508        let tool_calls = tool_response.choices[0]
1509            .message
1510            .tool_calls
1511            .as_ref()
1512            .unwrap();
1513        assert_eq!(tool_calls.len(), 1);
1514        assert_eq!(tool_calls[0].id, "call_abc123");
1515        assert_eq!(tool_calls[0].r#type, "function");
1516        assert_eq!(tool_calls[0].function.name, "get_weather");
1517        assert_eq!(
1518            tool_calls[0].function.arguments,
1519            "{\"location\":\"San Francisco, CA\"}"
1520        );
1521
1522        assert_eq!(
1523            tool_response.choices[0].finish_reason,
1524            FinishReason::ToolCalls,
1525        );
1526    }
1527
1528    #[test]
1529    fn test_content_with_image() {
1530        let message_with_image = ChatMessage {
1531            role: Role::User,
1532            content: Some(MessageContent::Array(vec![
1533                ContentPart {
1534                    r#type: "text".to_string(),
1535                    text: Some("What's in this image?".to_string()),
1536                    image_url: None,
1537                },
1538                ContentPart {
1539                    r#type: "input_image".to_string(),
1540                    text: None,
1541                    image_url: Some(ImageUrl {
1542                        url: "...".to_string(),
1543                        detail: Some("low".to_string()),
1544                    }),
1545                },
1546            ])),
1547            name: None,
1548            tool_calls: None,
1549            tool_call_id: None,
1550            usage: None,
1551        };
1552
1553        let json = serde_json::to_string(&message_with_image).unwrap();
1554        println!("Serialized JSON: {}", json);
1555
1556        assert!(json.contains("\"role\":\"user\""));
1557        assert!(json.contains("\"type\":\"text\""));
1558        assert!(json.contains("\"text\":\"What's in this image?\""));
1559        assert!(json.contains("\"type\":\"input_image\""));
1560        assert!(json.contains("\"url\":\"...\""));
1561        assert!(json.contains("\"detail\":\"low\""));
1562    }
1563
1564    #[test]
1565    fn test_response_format() {
1566        let json_format_request = ChatCompletionRequest {
1567            model: AgentModel::Smart.to_string(),
1568            messages: vec![ChatMessage {
1569                role: Role::User,
1570                content: Some(MessageContent::String(
1571                    "Generate a JSON object with name and age fields".to_string(),
1572                )),
1573                name: None,
1574                tool_calls: None,
1575                tool_call_id: None,
1576                usage: None,
1577            }],
1578            response_format: Some(ResponseFormat {
1579                r#type: "json_object".to_string(),
1580            }),
1581            max_tokens: Some(100),
1582            temperature: None,
1583            frequency_penalty: None,
1584            logit_bias: None,
1585            logprobs: None,
1586            n: None,
1587            presence_penalty: None,
1588            seed: None,
1589            stop: None,
1590            stream: None,
1591            top_p: None,
1592            tools: None,
1593            tool_choice: None,
1594            user: None,
1595            context: None,
1596        };
1597
1598        let json = serde_json::to_string(&json_format_request).unwrap();
1599        assert!(json.contains("\"response_format\":{\"type\":\"json_object\"}"));
1600    }
1601
1602    #[test]
1603    fn test_llm_message_to_chat_message() {
1604        // Test simple string content
1605        let llm_message = LLMMessage {
1606            role: "user".to_string(),
1607            content: LLMMessageContent::String("Hello, world!".to_string()),
1608        };
1609
1610        let chat_message = ChatMessage::from(llm_message);
1611        assert_eq!(chat_message.role, Role::User);
1612        match &chat_message.content {
1613            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
1614            _ => panic!("Expected string content"),
1615        }
1616        assert_eq!(chat_message.tool_calls, None);
1617
1618        // Test tool call conversion
1619        let llm_message_with_tool = LLMMessage {
1620            role: "assistant".to_string(),
1621            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolCall {
1622                id: "call_123".to_string(),
1623                name: "get_weather".to_string(),
1624                args: serde_json::json!({"location": "San Francisco"}),
1625            }]),
1626        };
1627
1628        let chat_message = ChatMessage::from(llm_message_with_tool);
1629        assert_eq!(chat_message.role, Role::Assistant);
1630        assert_eq!(chat_message.content, None); // No text content
1631        assert!(chat_message.tool_calls.is_some());
1632
1633        let tool_calls = chat_message.tool_calls.unwrap();
1634        assert_eq!(tool_calls.len(), 1);
1635        assert_eq!(tool_calls[0].id, "call_123");
1636        assert_eq!(tool_calls[0].function.name, "get_weather");
1637        assert!(tool_calls[0].function.arguments.contains("San Francisco"));
1638
1639        // Test mixed content
1640        let llm_message_mixed = LLMMessage {
1641            role: "assistant".to_string(),
1642            content: LLMMessageContent::List(vec![
1643                LLMMessageTypedContent::Text {
1644                    text: "The weather is:".to_string(),
1645                },
1646                LLMMessageTypedContent::ToolCall {
1647                    id: "call_456".to_string(),
1648                    name: "get_weather".to_string(),
1649                    args: serde_json::json!({"location": "New York"}),
1650                },
1651            ]),
1652        };
1653
1654        let chat_message = ChatMessage::from(llm_message_mixed);
1655        assert_eq!(chat_message.role, Role::Assistant);
1656
1657        match &chat_message.content {
1658            Some(MessageContent::Array(parts)) => {
1659                assert_eq!(parts.len(), 1);
1660                assert_eq!(parts[0].r#type, "text");
1661                assert_eq!(parts[0].text, Some("The weather is:".to_string()));
1662            }
1663            _ => panic!("Expected array content"),
1664        }
1665
1666        let tool_calls = chat_message.tool_calls.unwrap();
1667        assert_eq!(tool_calls.len(), 1);
1668        assert_eq!(tool_calls[0].id, "call_456");
1669        assert_eq!(tool_calls[0].function.name, "get_weather");
1670        assert!(tool_calls[0].function.arguments.contains("New York"));
1671    }
1672}