stakpak_shared/models/integrations/
openai.rs

1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
/// Configuration for OpenAI provider
///
/// Both fields are optional, and the derived `Default` leaves both unset.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// API endpoint, when one is configured.
    /// NOTE(review): presumably a base-URL override — confirm with the consumers of this config.
    pub api_endpoint: Option<String>,
    /// API key, when one is configured.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32    /// Create config with API key
33    pub fn with_api_key(api_key: impl Into<String>) -> Self {
34        Self {
35            api_key: Some(api_key.into()),
36            api_endpoint: None,
37        }
38    }
39
40    /// Create config from ProviderAuth (only supports API key for OpenAI)
41    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42        match auth {
43            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
45        }
46    }
47
48    /// Merge with credentials from ProviderAuth, preserving existing endpoint
49    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50        match auth {
51            crate::models::auth::ProviderAuth::Api { key } => {
52                self.api_key = Some(key.clone());
53                Some(self)
54            }
55            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
56        }
57    }
58}
59
60// =============================================================================
61// Model Definitions
62// =============================================================================
63
/// OpenAI model identifiers
///
/// Each named variant is (de)serialized as the dated identifier in its
/// `#[serde(rename = ...)]` attribute. The default variant is `GPT5`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Arbitrary model identifier. This variant has no `rename`, so it uses
    /// serde's default enum representation rather than a bare string.
    Custom(String),
}
85
86impl OpenAIModel {
87    pub fn from_string(s: &str) -> Result<Self, String> {
88        serde_json::from_value(serde_json::Value::String(s.to_string()))
89            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90    }
91
92    /// Default smart model for OpenAI
93    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95    /// Default eco model for OpenAI
96    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98    /// Default recovery model for OpenAI
99    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101    /// Get default smart model as string
102    pub fn default_smart_model() -> String {
103        Self::DEFAULT_SMART_MODEL.to_string()
104    }
105
106    /// Get default eco model as string
107    pub fn default_eco_model() -> String {
108        Self::DEFAULT_ECO_MODEL.to_string()
109    }
110
111    /// Get default recovery model as string
112    pub fn default_recovery_model() -> String {
113        Self::DEFAULT_RECOVERY_MODEL.to_string()
114    }
115}
116
117impl ContextAware for OpenAIModel {
118    fn context_info(&self) -> ModelContextInfo {
119        let model_name = self.to_string();
120
121        if model_name.starts_with("o3") {
122            return ModelContextInfo {
123                max_tokens: 200_000,
124                pricing_tiers: vec![ContextPricingTier {
125                    label: "Standard".to_string(),
126                    input_cost_per_million: 2.0,
127                    output_cost_per_million: 8.0,
128                    upper_bound: None,
129                }],
130                approach_warning_threshold: 0.8,
131            };
132        }
133
134        if model_name.starts_with("o4-mini") {
135            return ModelContextInfo {
136                max_tokens: 200_000,
137                pricing_tiers: vec![ContextPricingTier {
138                    label: "Standard".to_string(),
139                    input_cost_per_million: 1.10,
140                    output_cost_per_million: 4.40,
141                    upper_bound: None,
142                }],
143                approach_warning_threshold: 0.8,
144            };
145        }
146
147        if model_name.starts_with("gpt-5-mini") {
148            return ModelContextInfo {
149                max_tokens: 400_000,
150                pricing_tiers: vec![ContextPricingTier {
151                    label: "Standard".to_string(),
152                    input_cost_per_million: 0.25,
153                    output_cost_per_million: 2.0,
154                    upper_bound: None,
155                }],
156                approach_warning_threshold: 0.8,
157            };
158        }
159
160        if model_name.starts_with("gpt-5-nano") {
161            return ModelContextInfo {
162                max_tokens: 400_000,
163                pricing_tiers: vec![ContextPricingTier {
164                    label: "Standard".to_string(),
165                    input_cost_per_million: 0.05,
166                    output_cost_per_million: 0.40,
167                    upper_bound: None,
168                }],
169                approach_warning_threshold: 0.8,
170            };
171        }
172
173        if model_name.starts_with("gpt-5") {
174            return ModelContextInfo {
175                max_tokens: 400_000,
176                pricing_tiers: vec![ContextPricingTier {
177                    label: "Standard".to_string(),
178                    input_cost_per_million: 1.25,
179                    output_cost_per_million: 10.0,
180                    upper_bound: None,
181                }],
182                approach_warning_threshold: 0.8,
183            };
184        }
185
186        ModelContextInfo::default()
187    }
188
189    fn model_name(&self) -> String {
190        match self {
191            OpenAIModel::O3 => "O3".to_string(),
192            OpenAIModel::O4Mini => "O4-mini".to_string(),
193            OpenAIModel::GPT5 => "GPT-5".to_string(),
194            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197            OpenAIModel::Custom(name) => format!("Custom ({})", name),
198        }
199    }
200}
201
202impl std::fmt::Display for OpenAIModel {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        match self {
205            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212        }
213    }
214}
215
/// Agent model type (smart/eco/recovery)
///
/// Serialized as the lowercase strings given in the `rename` attributes.
/// `Smart` is the default.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub enum AgentModel {
    #[serde(rename = "smart")]
    #[default]
    Smart,
    #[serde(rename = "eco")]
    Eco,
    #[serde(rename = "recovery")]
    Recovery,
}
227
228impl std::fmt::Display for AgentModel {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        match self {
231            AgentModel::Smart => write!(f, "smart"),
232            AgentModel::Eco => write!(f, "eco"),
233            AgentModel::Recovery => write!(f, "recovery"),
234        }
235    }
236}
237
238impl From<String> for AgentModel {
239    fn from(value: String) -> Self {
240        match value.as_str() {
241            "eco" => AgentModel::Eco,
242            "recovery" => AgentModel::Recovery,
243            _ => AgentModel::Smart,
244        }
245    }
246}
247
248// =============================================================================
249// Message Types (used by TUI)
250// =============================================================================
251
/// Message role
///
/// Serialized in lowercase (`rename_all = "lowercase"`). `Assistant` is the default.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}
263
264impl std::fmt::Display for Role {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        match self {
267            Role::System => write!(f, "system"),
268            Role::Developer => write!(f, "developer"),
269            Role::User => write!(f, "user"),
270            Role::Assistant => write!(f, "assistant"),
271            Role::Tool => write!(f, "tool"),
272        }
273    }
274}
275
/// Chat message
///
/// `role` and `content` are always serialized (`content` as `null` when
/// `None`); the remaining optional fields are omitted from the output when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessage {
    /// Who authored the message.
    pub role: Role,
    /// Message body, either a plain string or a list of parts.
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls attached to the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    // NOTE(review): presumably the id of the tool call this message answers — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage attached to the message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
}
290
291impl ChatMessage {
292    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
293        messages
294            .iter()
295            .rev()
296            .find(|message| message.role != Role::User && message.role != Role::Tool)
297    }
298
299    pub fn to_xml(&self) -> String {
300        match &self.content {
301            Some(MessageContent::String(s)) => {
302                format!("<message role=\"{}\">{}</message>", self.role, s)
303            }
304            Some(MessageContent::Array(parts)) => parts
305                .iter()
306                .map(|part| {
307                    format!(
308                        "<message role=\"{}\" type=\"{}\">{}</message>",
309                        self.role,
310                        part.r#type,
311                        part.text.clone().unwrap_or_default()
312                    )
313                })
314                .collect::<Vec<String>>()
315                .join("\n"),
316            None => String::new(),
317        }
318    }
319}
320
/// Message content (string or array of parts)
///
/// Untagged: serialized either as a bare string or as an array of
/// `ContentPart` objects, with no wrapper naming the variant.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
328
329impl MessageContent {
330    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
331        match self {
332            MessageContent::String(s) => MessageContent::String(format!(
333                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
334            )),
335            MessageContent::Array(parts) => MessageContent::Array(
336                std::iter::once(ContentPart {
337                    r#type: "text".to_string(),
338                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
339                    image_url: None,
340                })
341                .chain(parts.iter().cloned())
342                .collect(),
343            ),
344        }
345    }
346
347    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
348        match self {
349            MessageContent::String(s) => s
350                .rfind("<checkpoint_id>")
351                .and_then(|start| {
352                    s[start..]
353                        .find("</checkpoint_id>")
354                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
355                })
356                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
357            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
358                part.text.as_deref().and_then(|text| {
359                    text.rfind("<checkpoint_id>")
360                        .and_then(|start| {
361                            text[start..]
362                                .find("</checkpoint_id>")
363                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
364                        })
365                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
366                })
367            }),
368        }
369    }
370}
371
372impl std::fmt::Display for MessageContent {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        match self {
375            MessageContent::String(s) => write!(f, "{s}"),
376            MessageContent::Array(parts) => {
377                let text_parts: Vec<String> =
378                    parts.iter().filter_map(|part| part.text.clone()).collect();
379                write!(f, "{}", text_parts.join("\n"))
380            }
381        }
382    }
383}
384
385impl Default for MessageContent {
386    fn default() -> Self {
387        MessageContent::String(String::new())
388    }
389}
390
/// Content part (text or image)
///
/// `text` and `image_url` are each omitted from serialized output when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part kind tag; this file produces `"text"` and `"image_url"`.
    pub r#type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
400
/// Image URL with optional detail level
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image location; conversions in this file build `data:<media>;base64,<data>` URLs here.
    pub url: String,
    /// Optional detail level; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
408
409// =============================================================================
410// Tool Types (used by TUI)
411// =============================================================================
412
/// Tool definition
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind tag. NOTE(review): presumably always `"function"` — confirm with callers.
    pub r#type: String,
    /// The function this tool exposes.
    pub function: FunctionDefinition,
}
419
/// Function definition for tools
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional description; serialized as `null` when `None` (there is no skip attribute).
    pub description: Option<String>,
    /// Parameters as arbitrary JSON; the `LLMTool` conversion uses this value as its `input_schema`.
    pub parameters: serde_json::Value,
}
427
428impl From<Tool> for LLMTool {
429    fn from(tool: Tool) -> Self {
430        LLMTool {
431            name: tool.function.name,
432            description: tool.function.description.unwrap_or_default(),
433            input_schema: tool.function.parameters,
434        }
435    }
436}
437
/// Tool choice configuration
///
/// Has hand-written serde impls: `Auto` and `Required` are the bare strings
/// `"auto"` and `"required"`, while `Object` is the wrapped object itself.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
445
446impl Serialize for ToolChoice {
447    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
448    where
449        S: serde::Serializer,
450    {
451        match self {
452            ToolChoice::Auto => serializer.serialize_str("auto"),
453            ToolChoice::Required => serializer.serialize_str("required"),
454            ToolChoice::Object(obj) => obj.serialize(serializer),
455        }
456    }
457}
458
459impl<'de> Deserialize<'de> for ToolChoice {
460    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
461    where
462        D: serde::Deserializer<'de>,
463    {
464        struct ToolChoiceVisitor;
465
466        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
467            type Value = ToolChoice;
468
469            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
470                formatter.write_str("string or object")
471            }
472
473            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
474            where
475                E: serde::de::Error,
476            {
477                match value {
478                    "auto" => Ok(ToolChoice::Auto),
479                    "required" => Ok(ToolChoice::Required),
480                    _ => Err(serde::de::Error::unknown_variant(
481                        value,
482                        &["auto", "required"],
483                    )),
484                }
485            }
486
487            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
488            where
489                M: serde::de::MapAccess<'de>,
490            {
491                let obj = ToolChoiceObject::deserialize(
492                    serde::de::value::MapAccessDeserializer::new(map),
493                )?;
494                Ok(ToolChoice::Object(obj))
495            }
496        }
497
498        deserializer.deserialize_any(ToolChoiceVisitor)
499    }
500}
501
/// Object form of a tool choice, wrapped by `ToolChoice::Object`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Kind tag. NOTE(review): presumably `"function"` — confirm with callers.
    pub r#type: String,
    /// The function being selected.
    pub function: FunctionChoice,
}
507
/// Names the function selected by a `ToolChoiceObject`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    /// Function name.
    pub name: String,
}
512
/// Tool call from assistant
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind tag; conversions in this file always set it to `"function"`.
    pub r#type: String,
    /// The function invoked and its arguments.
    pub function: FunctionCall,
}
520
/// Function call details
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call.
    pub name: String,
    /// Arguments as a string; the `LLMMessage` conversion in this file parses
    /// it as JSON and falls back to `{}` when parsing fails.
    pub arguments: String,
}
527
/// Tool call result status
///
/// No serde renaming is applied, so the variants serialize under their Rust
/// names (`"Success"`, `"Error"`, `"Cancelled"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
535
/// Tool call result
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The call this result belongs to.
    pub call: ToolCall,
    /// Result payload as a string.
    pub result: String,
    /// Outcome of the call.
    pub status: ToolCallResultStatus,
}
543
/// Tool call result progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    // NOTE(review): it is not visible here what this UUID identifies (the
    // update or the tool call) — confirm with the producer.
    pub id: Uuid,
    /// Progress message text.
    pub message: String,
}
550
551// =============================================================================
552// Chat Completion Types (used by TUI)
553// =============================================================================
554
/// Chat completion request
///
/// `model` and `messages` are always serialized. Every other field is
/// optional and is left out of the serialized output when `None`.
/// NOTE(review): most field names appear to mirror the OpenAI chat-completions
/// request parameters; `context` looks project-specific — confirm.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    /// Model identifier string.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra context carried with the request (see `ChatCompletionContext`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
593
594impl ChatCompletionRequest {
595    pub fn new(
596        model: String,
597        messages: Vec<ChatMessage>,
598        tools: Option<Vec<Tool>>,
599        stream: Option<bool>,
600    ) -> Self {
601        Self {
602            model,
603            messages,
604            frequency_penalty: None,
605            logit_bias: None,
606            logprobs: None,
607            max_tokens: None,
608            n: None,
609            presence_penalty: None,
610            response_format: None,
611            seed: None,
612            stop: None,
613            stream,
614            temperature: None,
615            top_p: None,
616            tools,
617            tool_choice: None,
618            user: None,
619            context: None,
620        }
621    }
622}
623
/// Extra context attached to a `ChatCompletionRequest`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Optional arbitrary JSON value; serialized as `null` when `None`.
    pub scratchpad: Option<Value>,
}
628
/// Response format selector used by `ChatCompletionRequest::response_format`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    /// Format kind tag, carried as a free-form string.
    pub r#type: String,
}
633
/// Stop sequence(s) for a request.
///
/// Untagged: serialized either as a single string or as an array of strings.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
640
/// Chat completion response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm with the producer.
    pub created: u64,
    pub model: String,
    /// Generated choices.
    pub choices: Vec<ChatCompletionChoice>,
    /// Token usage for the request; required here, unlike in the streaming response.
    pub usage: LLMTokenUsage,
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Arbitrary JSON metadata; has no skip attribute, so `None` serializes as `null`.
    pub metadata: Option<serde_json::Value>,
}
654
/// One choice in a `ChatCompletionResponse`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    /// The generated message.
    pub message: ChatMessage,
    /// Log-probability details, if present.
    pub logprobs: Option<LogProbs>,
    /// Why generation ended for this choice.
    pub finish_reason: FinishReason,
}
662
/// Reason a choice finished generating.
///
/// Serialized in snake_case: `"stop"`, `"length"`, `"content_filter"`,
/// `"tool_calls"`. Any other value fails deserialization.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
671
/// Log-probability information attached to a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    /// Per-token entries, if present.
    pub content: Option<Vec<LogProbContent>>,
}
676
/// Log-probability entry for a single token.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    // NOTE(review): presumably the token's raw byte representation — confirm.
    pub bytes: Option<Vec<u8>>,
    /// Alternative tokens with their log probabilities, if present.
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
684
/// A token with its log probability, as listed in `LogProbContent::top_logprobs`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    // NOTE(review): presumably the token's raw byte representation — confirm.
    pub bytes: Option<Vec<u8>>,
}
691
692// =============================================================================
693// Streaming Types
694// =============================================================================
695
/// One chunk of a streamed chat completion.
///
/// Has the same fields as `ChatCompletionResponse`, except that choices carry
/// deltas, `usage` is optional, and there is no `system_fingerprint`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm with the producer.
    pub created: u64,
    pub model: String,
    /// Incremental choices in this chunk.
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Token usage, when the chunk carries it.
    pub usage: Option<LLMTokenUsage>,
    /// Arbitrary JSON metadata, if any.
    pub metadata: Option<serde_json::Value>,
}
706
/// One choice inside a streamed chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    /// Incremental message update.
    pub delta: ChatMessageDelta,
    /// Finish reason, when the chunk carries one.
    pub finish_reason: Option<FinishReason>,
}
713
/// Incremental update to a chat message within a stream.
///
/// Every field is optional and omitted from serialized output when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    /// Text fragment, if this delta carries text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool-call fragments, if this delta carries any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
723
/// Incremental update to a tool call within a stream.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    // NOTE(review): presumably the position of the tool call this fragment
    // belongs to — confirm with the stream consumer.
    pub index: usize,
    pub id: Option<String>,
    /// Kind tag; the `GenerationDelta` conversion in this file sets `"function"`.
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}
731
/// Incremental function-call data carried by a `ToolCallDelta`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    /// Function name, when present in this fragment.
    pub name: Option<String>,
    /// Arguments text, when present in this fragment.
    pub arguments: Option<String>,
}
737
738// =============================================================================
739// Conversions
740// =============================================================================
741
742impl From<LLMMessage> for ChatMessage {
743    fn from(llm_message: LLMMessage) -> Self {
744        let role = match llm_message.role.as_str() {
745            "system" => Role::System,
746            "user" => Role::User,
747            "assistant" => Role::Assistant,
748            "tool" => Role::Tool,
749            "developer" => Role::Developer,
750            _ => Role::User,
751        };
752
753        let (content, tool_calls) = match llm_message.content {
754            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
755            LLMMessageContent::List(items) => {
756                let mut text_parts = Vec::new();
757                let mut tool_call_parts = Vec::new();
758
759                for item in items {
760                    match item {
761                        LLMMessageTypedContent::Text { text } => {
762                            text_parts.push(ContentPart {
763                                r#type: "text".to_string(),
764                                text: Some(text),
765                                image_url: None,
766                            });
767                        }
768                        LLMMessageTypedContent::ToolCall { id, name, args } => {
769                            tool_call_parts.push(ToolCall {
770                                id,
771                                r#type: "function".to_string(),
772                                function: FunctionCall {
773                                    name,
774                                    arguments: args.to_string(),
775                                },
776                            });
777                        }
778                        LLMMessageTypedContent::ToolResult { content, .. } => {
779                            text_parts.push(ContentPart {
780                                r#type: "text".to_string(),
781                                text: Some(content),
782                                image_url: None,
783                            });
784                        }
785                        LLMMessageTypedContent::Image { source } => {
786                            text_parts.push(ContentPart {
787                                r#type: "image_url".to_string(),
788                                text: None,
789                                image_url: Some(ImageUrl {
790                                    url: format!(
791                                        "data:{};base64,{}",
792                                        source.media_type, source.data
793                                    ),
794                                    detail: None,
795                                }),
796                            });
797                        }
798                    }
799                }
800
801                let content = if !text_parts.is_empty() {
802                    Some(MessageContent::Array(text_parts))
803                } else {
804                    None
805                };
806
807                let tool_calls = if !tool_call_parts.is_empty() {
808                    Some(tool_call_parts)
809                } else {
810                    None
811                };
812
813                (content, tool_calls)
814            }
815        };
816
817        ChatMessage {
818            role,
819            content,
820            name: None,
821            tool_calls,
822            tool_call_id: None,
823            usage: None,
824        }
825    }
826}
827
828impl From<ChatMessage> for LLMMessage {
829    fn from(chat_message: ChatMessage) -> Self {
830        let mut content_parts = Vec::new();
831
832        match chat_message.content {
833            Some(MessageContent::String(s)) => {
834                if !s.is_empty() {
835                    content_parts.push(LLMMessageTypedContent::Text { text: s });
836                }
837            }
838            Some(MessageContent::Array(parts)) => {
839                for part in parts {
840                    if let Some(text) = part.text {
841                        content_parts.push(LLMMessageTypedContent::Text { text });
842                    } else if let Some(image_url) = part.image_url {
843                        let (media_type, data) = if image_url.url.starts_with("data:") {
844                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
845                            if parts.len() == 2 {
846                                let meta = parts[0];
847                                let data = parts[1];
848                                let media_type = meta
849                                    .trim_start_matches("data:")
850                                    .trim_end_matches(";base64")
851                                    .to_string();
852                                (media_type, data.to_string())
853                            } else {
854                                ("image/jpeg".to_string(), image_url.url)
855                            }
856                        } else {
857                            ("image/jpeg".to_string(), image_url.url)
858                        };
859
860                        content_parts.push(LLMMessageTypedContent::Image {
861                            source: LLMMessageImageSource {
862                                r#type: "base64".to_string(),
863                                media_type,
864                                data,
865                            },
866                        });
867                    }
868                }
869            }
870            None => {}
871        }
872
873        if let Some(tool_calls) = chat_message.tool_calls {
874            for tool_call in tool_calls {
875                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
876                content_parts.push(LLMMessageTypedContent::ToolCall {
877                    id: tool_call.id,
878                    name: tool_call.function.name,
879                    args,
880                });
881            }
882        }
883
884        LLMMessage {
885            role: chat_message.role.to_string(),
886            content: if content_parts.is_empty() {
887                LLMMessageContent::String(String::new())
888            } else if content_parts.len() == 1 {
889                match &content_parts[0] {
890                    LLMMessageTypedContent::Text { text } => {
891                        LLMMessageContent::String(text.clone())
892                    }
893                    _ => LLMMessageContent::List(content_parts),
894                }
895            } else {
896                LLMMessageContent::List(content_parts)
897            },
898        }
899    }
900}
901
902impl From<GenerationDelta> for ChatMessageDelta {
903    fn from(delta: GenerationDelta) -> Self {
904        match delta {
905            GenerationDelta::Content { content } => ChatMessageDelta {
906                role: Some(Role::Assistant),
907                content: Some(content),
908                tool_calls: None,
909            },
910            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
911                role: Some(Role::Assistant),
912                content: None,
913                tool_calls: None,
914            },
915            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
916                role: Some(Role::Assistant),
917                content: None,
918                tool_calls: Some(vec![ToolCallDelta {
919                    index: tool_use.index,
920                    id: tool_use.id,
921                    r#type: Some("function".to_string()),
922                    function: Some(FunctionCallDelta {
923                        name: tool_use.name,
924                        arguments: tool_use.input,
925                    }),
926                }]),
927            },
928            _ => ChatMessageDelta {
929                role: Some(Role::Assistant),
930                content: None,
931                tool_calls: None,
932            },
933        }
934    }
935}
936
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ChatMessage` holding plain string content for `role`, with
    /// every optional field left unset.
    fn text_message(role: Role, text: &str) -> ChatMessage {
        ChatMessage {
            role,
            content: Some(MessageContent::String(text.to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            usage: None,
        }
    }

    /// A minimal request serializes to JSON containing the model name, the
    /// messages array and the system role.
    #[test]
    fn test_serialize_basic_request() {
        let messages = vec![
            text_message(Role::System, "You are a helpful assistant."),
            text_message(Role::User, "Hello!"),
        ];

        let request = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let serialized = serde_json::to_string(&request).unwrap();

        assert!(serialized.contains("\"model\":\"smart\""));
        assert!(serialized.contains("\"messages\":["));
        assert!(serialized.contains("\"role\":\"system\""));
    }

    /// A string-content `LLMMessage` with role "user" converts to a user
    /// `ChatMessage` carrying the same text.
    #[test]
    fn test_llm_message_to_chat_message() {
        let source = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let converted = ChatMessage::from(source);

        assert_eq!(converted.role, Role::User);
        if let Some(MessageContent::String(text)) = &converted.content {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected string content");
        }
    }
}