Skip to main content

stakpak_shared/models/integrations/
openai.rs

1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
/// Configuration for OpenAI provider
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Optional override for the API base URL; `None` = provider default.
    pub api_endpoint: Option<String>,
    /// API key used to authenticate requests.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32    /// Create config with API key
33    pub fn with_api_key(api_key: impl Into<String>) -> Self {
34        Self {
35            api_key: Some(api_key.into()),
36            api_endpoint: None,
37        }
38    }
39
40    /// Create config from ProviderAuth (only supports API key for OpenAI)
41    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42        match auth {
43            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
45        }
46    }
47
48    /// Merge with credentials from ProviderAuth, preserving existing endpoint
49    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50        match auth {
51            crate::models::auth::ProviderAuth::Api { key } => {
52                self.api_key = Some(key.clone());
53                Some(self)
54            }
55            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
56        }
57    }
58}
59
60// =============================================================================
61// Model Definitions
62// =============================================================================
63
/// OpenAI model identifiers
///
/// Known models serialize as their dated wire-format ids via `#[serde(rename)]`.
/// NOTE(review): `Custom` is a newtype variant in an externally tagged enum, so
/// it serializes as `{"Custom": "..."}` and a bare unknown model string does
/// NOT deserialize into `Custom` — confirm whether that is intended.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    // GPT-5 family; GPT5 is the crate-wide default model.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any other model id, preserved verbatim (see the `Display` impl).
    Custom(String),
}
85
86impl OpenAIModel {
87    pub fn from_string(s: &str) -> Result<Self, String> {
88        serde_json::from_value(serde_json::Value::String(s.to_string()))
89            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90    }
91
92    /// Default smart model for OpenAI
93    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95    /// Default eco model for OpenAI
96    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98    /// Default recovery model for OpenAI
99    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101    /// Get default smart model as string
102    pub fn default_smart_model() -> String {
103        Self::DEFAULT_SMART_MODEL.to_string()
104    }
105
106    /// Get default eco model as string
107    pub fn default_eco_model() -> String {
108        Self::DEFAULT_ECO_MODEL.to_string()
109    }
110
111    /// Get default recovery model as string
112    pub fn default_recovery_model() -> String {
113        Self::DEFAULT_RECOVERY_MODEL.to_string()
114    }
115}
116
117impl ContextAware for OpenAIModel {
118    fn context_info(&self) -> ModelContextInfo {
119        let model_name = self.to_string();
120
121        if model_name.starts_with("o3") {
122            return ModelContextInfo {
123                max_tokens: 200_000,
124                pricing_tiers: vec![ContextPricingTier {
125                    label: "Standard".to_string(),
126                    input_cost_per_million: 2.0,
127                    output_cost_per_million: 8.0,
128                    upper_bound: None,
129                }],
130                approach_warning_threshold: 0.8,
131            };
132        }
133
134        if model_name.starts_with("o4-mini") {
135            return ModelContextInfo {
136                max_tokens: 200_000,
137                pricing_tiers: vec![ContextPricingTier {
138                    label: "Standard".to_string(),
139                    input_cost_per_million: 1.10,
140                    output_cost_per_million: 4.40,
141                    upper_bound: None,
142                }],
143                approach_warning_threshold: 0.8,
144            };
145        }
146
147        if model_name.starts_with("gpt-5-mini") {
148            return ModelContextInfo {
149                max_tokens: 400_000,
150                pricing_tiers: vec![ContextPricingTier {
151                    label: "Standard".to_string(),
152                    input_cost_per_million: 0.25,
153                    output_cost_per_million: 2.0,
154                    upper_bound: None,
155                }],
156                approach_warning_threshold: 0.8,
157            };
158        }
159
160        if model_name.starts_with("gpt-5-nano") {
161            return ModelContextInfo {
162                max_tokens: 400_000,
163                pricing_tiers: vec![ContextPricingTier {
164                    label: "Standard".to_string(),
165                    input_cost_per_million: 0.05,
166                    output_cost_per_million: 0.40,
167                    upper_bound: None,
168                }],
169                approach_warning_threshold: 0.8,
170            };
171        }
172
173        if model_name.starts_with("gpt-5") {
174            return ModelContextInfo {
175                max_tokens: 400_000,
176                pricing_tiers: vec![ContextPricingTier {
177                    label: "Standard".to_string(),
178                    input_cost_per_million: 1.25,
179                    output_cost_per_million: 10.0,
180                    upper_bound: None,
181                }],
182                approach_warning_threshold: 0.8,
183            };
184        }
185
186        ModelContextInfo::default()
187    }
188
189    fn model_name(&self) -> String {
190        match self {
191            OpenAIModel::O3 => "O3".to_string(),
192            OpenAIModel::O4Mini => "O4-mini".to_string(),
193            OpenAIModel::GPT5 => "GPT-5".to_string(),
194            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197            OpenAIModel::Custom(name) => format!("Custom ({})", name),
198        }
199    }
200}
201
202impl std::fmt::Display for OpenAIModel {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        match self {
205            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212        }
213    }
214}
215
/// Agent model type (smart/eco/recovery)
///
/// Logical model tier selected by the agent, serialized lowercase.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub enum AgentModel {
    /// Highest-capability tier; the default.
    #[serde(rename = "smart")]
    #[default]
    Smart,
    /// Cheaper/lighter tier.
    #[serde(rename = "eco")]
    Eco,
    /// Tier used for recovery flows.
    #[serde(rename = "recovery")]
    Recovery,
}
227
228impl std::fmt::Display for AgentModel {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        match self {
231            AgentModel::Smart => write!(f, "smart"),
232            AgentModel::Eco => write!(f, "eco"),
233            AgentModel::Recovery => write!(f, "recovery"),
234        }
235    }
236}
237
238impl From<String> for AgentModel {
239    fn from(value: String) -> Self {
240        match value.as_str() {
241            "eco" => AgentModel::Eco,
242            "recovery" => AgentModel::Recovery,
243            _ => AgentModel::Smart,
244        }
245    }
246}
247
248// =============================================================================
249// Message Types (used by TUI)
250// =============================================================================
251
/// Message role
///
/// Serialized lowercase ("system", "developer", "user", "assistant", "tool").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    /// Result message for a prior assistant tool call (see
    /// `ChatMessage::tool_call_id`).
    Tool,
}
263
264impl std::fmt::Display for Role {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        match self {
267            Role::System => write!(f, "system"),
268            Role::Developer => write!(f, "developer"),
269            Role::User => write!(f, "user"),
270            Role::Assistant => write!(f, "assistant"),
271            Role::Tool => write!(f, "tool"),
272        }
273    }
274}
275
/// Model info for tracking which model generated a message
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (e.g., "anthropic", "openai")
    pub provider: String,
    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4")
    pub id: String,
}
284
/// Chat message
///
/// OpenAI-style chat message, extended with session-tracking fields that are
/// only serialized when present.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    /// Author role; defaults to `Assistant`.
    pub role: Role,
    /// Body: plain string or an array of typed parts; may be absent for
    /// tool-call-only assistant messages.
    pub content: Option<MessageContent>,
    /// Optional participant name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls emitted by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages: the id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage attributed to this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    // === Extended fields for session tracking ===
    /// Unique message identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Model that generated this message (for assistant messages)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Cost in dollars for this message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    /// Why the model stopped: "stop", "tool_calls", "length", "error"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Unix timestamp (ms) when message was created/sent
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Unix timestamp (ms) when assistant finished generating
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Plugin extensibility - unstructured metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
322
323impl ChatMessage {
324    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
325        messages
326            .iter()
327            .rev()
328            .find(|message| message.role != Role::User && message.role != Role::Tool)
329    }
330
331    pub fn to_xml(&self) -> String {
332        match &self.content {
333            Some(MessageContent::String(s)) => {
334                format!("<message role=\"{}\">{}</message>", self.role, s)
335            }
336            Some(MessageContent::Array(parts)) => parts
337                .iter()
338                .map(|part| {
339                    format!(
340                        "<message role=\"{}\" type=\"{}\">{}</message>",
341                        self.role,
342                        part.r#type,
343                        part.text.clone().unwrap_or_default()
344                    )
345                })
346                .collect::<Vec<String>>()
347                .join("\n"),
348            None => String::new(),
349        }
350    }
351}
352
/// Message content (string or array of parts)
///
/// `untagged`: deserializes from either a bare JSON string or an array of
/// `ContentPart` objects.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
360
361impl MessageContent {
362    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
363        match self {
364            MessageContent::String(s) => MessageContent::String(format!(
365                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
366            )),
367            MessageContent::Array(parts) => MessageContent::Array(
368                std::iter::once(ContentPart {
369                    r#type: "text".to_string(),
370                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
371                    image_url: None,
372                })
373                .chain(parts.iter().cloned())
374                .collect(),
375            ),
376        }
377    }
378
379    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
380        match self {
381            MessageContent::String(s) => s
382                .rfind("<checkpoint_id>")
383                .and_then(|start| {
384                    s[start..]
385                        .find("</checkpoint_id>")
386                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
387                })
388                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
389            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
390                part.text.as_deref().and_then(|text| {
391                    text.rfind("<checkpoint_id>")
392                        .and_then(|start| {
393                            text[start..]
394                                .find("</checkpoint_id>")
395                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
396                        })
397                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
398                })
399            }),
400        }
401    }
402}
403
404impl std::fmt::Display for MessageContent {
405    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
406        match self {
407            MessageContent::String(s) => write!(f, "{s}"),
408            MessageContent::Array(parts) => {
409                let text_parts: Vec<String> =
410                    parts.iter().filter_map(|part| part.text.clone()).collect();
411                write!(f, "{}", text_parts.join("\n"))
412            }
413        }
414    }
415}
416
417impl Default for MessageContent {
418    fn default() -> Self {
419        MessageContent::String(String::new())
420    }
421}
422
/// Content part (text or image)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part discriminator; the code in this file uses "text" and "image_url".
    pub r#type: String,
    /// Text payload (set when `type` is "text").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image payload (set when `type` is "image_url").
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
432
/// Image URL with optional detail level
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Either a remote URL or a `data:<media-type>;base64,<data>` URI
    /// (the conversions in this file produce the data-URI form).
    pub url: String,
    /// Optional detail level hint.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
440
441// =============================================================================
442// Tool Types (used by TUI)
443// =============================================================================
444
/// Tool definition
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind discriminator (e.g. "function").
    pub r#type: String,
    /// The callable function this tool exposes.
    pub function: FunctionDefinition,
}
451
/// Function definition for tools
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's arguments.
    pub parameters: serde_json::Value,
}
459
460impl From<Tool> for LLMTool {
461    fn from(tool: Tool) -> Self {
462        LLMTool {
463            name: tool.function.name,
464            description: tool.function.description.unwrap_or_default(),
465            input_schema: tool.function.parameters,
466        }
467    }
468}
469
/// Tool choice configuration
///
/// Serialized either as the bare strings "auto"/"required" or as a
/// `{"type": …, "function": {…}}` object (see the manual serde impls below).
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
477
impl Serialize for ToolChoice {
    // "auto"/"required" serialize as bare JSON strings; Object serializes as
    // the nested object form via ToolChoiceObject's derived Serialize.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            ToolChoice::Auto => serializer.serialize_str("auto"),
            ToolChoice::Required => serializer.serialize_str("required"),
            ToolChoice::Object(obj) => obj.serialize(serializer),
        }
    }
}
490
impl<'de> Deserialize<'de> for ToolChoice {
    /// Mirror of the manual `Serialize` impl: accepts either a bare string
    /// ("auto"/"required") or a map (deserialized as `ToolChoiceObject`).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            // Bare string form. Any string other than "auto"/"required"
            // (e.g. OpenAI's "none") is rejected as an unknown variant.
            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            // Object form: delegate to ToolChoiceObject's derived Deserialize
            // through a MapAccessDeserializer adapter.
            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        // deserialize_any lets the visitor branch on the concrete input shape
        // (requires a self-describing format such as JSON).
        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}
533
/// Object form of `ToolChoice`: forces the model to call a specific function.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Tool kind discriminator (e.g. "function").
    pub r#type: String,
    pub function: FunctionChoice,
}
539
/// Named function selected by `ToolChoiceObject`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
544
/// Tool call from assistant
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call id, echoed back in the matching tool-result message.
    pub id: String,
    /// Call kind discriminator (e.g. "function").
    pub r#type: String,
    pub function: FunctionCall,
}
552
/// Function call details
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments as a raw JSON string (parsed on conversion to `LLMMessage`).
    pub arguments: String,
}
559
/// Tool call result status
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
567
/// Tool call result
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The originating call this result answers.
    pub call: ToolCall,
    /// Result payload as text.
    pub result: String,
    pub status: ToolCallResultStatus,
}
575
/// Tool call result progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    pub id: Uuid,
    /// Incremental progress text for the running tool call.
    pub message: String,
}
582
583// =============================================================================
584// Chat Completion Types (used by TUI)
585// =============================================================================
586
/// Chat completion request
///
/// Mirrors the OpenAI chat-completions request body; all optional fields are
/// omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Number of completions to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extension field, not part of the upstream OpenAI API surface.
    /// NOTE(review): presumably consumed by stakpak's own backend — verify.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
625
626impl ChatCompletionRequest {
627    pub fn new(
628        model: String,
629        messages: Vec<ChatMessage>,
630        tools: Option<Vec<Tool>>,
631        stream: Option<bool>,
632    ) -> Self {
633        Self {
634            model,
635            messages,
636            frequency_penalty: None,
637            logit_bias: None,
638            logprobs: None,
639            max_tokens: None,
640            n: None,
641            presence_penalty: None,
642            response_format: None,
643            seed: None,
644            stop: None,
645            stream,
646            temperature: None,
647            top_p: None,
648            tools,
649            tool_choice: None,
650            user: None,
651            context: None,
652        }
653    }
654}
655
/// Extra request context (extension field on `ChatCompletionRequest`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Free-form scratchpad state.
    /// NOTE(review): unlike most optional fields in this file this has no
    /// `skip_serializing_if`, so `None` serializes as an explicit null —
    /// confirm that asymmetry is intended.
    pub scratchpad: Option<Value>,
}
660
/// Requested response format (e.g. `{"type": "json_object"}`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
665
/// Stop sequence(s): a single string or an array, accepted interchangeably
/// thanks to `untagged`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
672
/// Chat completion response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    /// Object type tag (e.g. "chat.completion").
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    /// One entry per requested completion (`n`).
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Extension field; always serialized, even when `None`.
    pub metadata: Option<serde_json::Value>,
}
686
/// One completion choice within a response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    /// Position of this choice within `choices`.
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}
694
/// Why generation stopped; serialized snake_case ("stop", "tool_calls", …).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    /// Hit the `max_tokens` limit.
    Length,
    ContentFilter,
    ToolCalls,
}
703
/// Per-choice log-probability information.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}
708
/// Log probability of one generated token, with optional alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when available.
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternative tokens at this position.
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
716
/// A candidate token and its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when available.
    pub bytes: Option<Vec<u8>>,
}
723
724// =============================================================================
725// Streaming Types
726// =============================================================================
727
/// One SSE chunk of a streaming chat completion.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    /// Object type tag (e.g. "chat.completion.chunk").
    pub object: String,
    /// Unix timestamp (seconds) of creation.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Usage totals; typically only present on the final chunk.
    pub usage: Option<LLMTokenUsage>,
    /// Extension field; always serialized, even when `None`.
    pub metadata: Option<serde_json::Value>,
}
738
/// One choice within a stream chunk; carries a delta, not a full message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// Set on the final chunk for this choice.
    pub finish_reason: Option<FinishReason>,
}
745
/// Incremental message fragment; fields are only present when they change.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    /// Usually only set on the first chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    /// Text to append to the accumulated content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
755
/// Incremental tool-call fragment; `index` identifies which call to extend.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    /// Set on the first fragment of a call.
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
}
763
/// Incremental function-call fragment; `arguments` chunks concatenate into
/// the full JSON argument string.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
769
770// =============================================================================
771// Conversions
772// =============================================================================
773
impl From<LLMMessage> for ChatMessage {
    /// Convert the generic `LLMMessage` into the OpenAI-flavored `ChatMessage`.
    ///
    /// Typed content is split into two buckets: text/image/tool-result items
    /// become `ContentPart`s, while tool-call items become `ToolCall`s. Only
    /// role, content, and tool_calls are populated; all extended tracking
    /// fields are left at their defaults.
    fn from(llm_message: LLMMessage) -> Self {
        // Unknown role strings fall back to User.
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            "developer" => Role::Developer,
            _ => Role::User,
        };

        let (content, tool_calls) = match llm_message.content {
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall { id, name, args } => {
                            // args is a serde_json::Value; ChatMessage carries
                            // arguments as a raw JSON string.
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                            });
                        }
                        LLMMessageTypedContent::ToolResult { content, .. } => {
                            // Tool results are flattened into plain text parts;
                            // the originating tool_call id is dropped here.
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            // Images are inlined as data URIs.
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                // Empty buckets map to None rather than empty collections so
                // they are omitted from serialized output.
                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls)
            }
        };

        ChatMessage {
            role,
            content,
            name: None,
            tool_calls,
            tool_call_id: None,
            usage: None,
            ..Default::default()
        }
    }
}
860
861impl From<ChatMessage> for LLMMessage {
862    fn from(chat_message: ChatMessage) -> Self {
863        let mut content_parts = Vec::new();
864
865        match chat_message.content {
866            Some(MessageContent::String(s)) => {
867                if !s.is_empty() {
868                    content_parts.push(LLMMessageTypedContent::Text { text: s });
869                }
870            }
871            Some(MessageContent::Array(parts)) => {
872                for part in parts {
873                    if let Some(text) = part.text {
874                        content_parts.push(LLMMessageTypedContent::Text { text });
875                    } else if let Some(image_url) = part.image_url {
876                        let (media_type, data) = if image_url.url.starts_with("data:") {
877                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
878                            if parts.len() == 2 {
879                                let meta = parts[0];
880                                let data = parts[1];
881                                let media_type = meta
882                                    .trim_start_matches("data:")
883                                    .trim_end_matches(";base64")
884                                    .to_string();
885                                (media_type, data.to_string())
886                            } else {
887                                ("image/jpeg".to_string(), image_url.url)
888                            }
889                        } else {
890                            ("image/jpeg".to_string(), image_url.url)
891                        };
892
893                        content_parts.push(LLMMessageTypedContent::Image {
894                            source: LLMMessageImageSource {
895                                r#type: "base64".to_string(),
896                                media_type,
897                                data,
898                            },
899                        });
900                    }
901                }
902            }
903            None => {}
904        }
905
906        if let Some(tool_calls) = chat_message.tool_calls {
907            for tool_call in tool_calls {
908                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
909                content_parts.push(LLMMessageTypedContent::ToolCall {
910                    id: tool_call.id,
911                    name: tool_call.function.name,
912                    args,
913                });
914            }
915        }
916
917        // Handle tool result messages: when role is Tool and tool_call_id is present,
918        // convert the content to a ToolResult content part. This is the generic
919        // intermediate representation - each provider's conversion layer handles
920        // the specifics (e.g., Anthropic converts to user role with tool_result blocks)
921        if chat_message.role == Role::Tool
922            && let Some(tool_call_id) = chat_message.tool_call_id
923        {
924            // Extract content as string for the tool result
925            let content_str = content_parts
926                .iter()
927                .filter_map(|p| match p {
928                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
929                    _ => None,
930                })
931                .collect::<Vec<_>>()
932                .join("\n");
933
934            // Replace content with a single ToolResult
935            content_parts = vec![LLMMessageTypedContent::ToolResult {
936                tool_use_id: tool_call_id,
937                content: content_str,
938            }];
939        }
940
941        LLMMessage {
942            role: chat_message.role.to_string(),
943            content: if content_parts.is_empty() {
944                LLMMessageContent::String(String::new())
945            } else if content_parts.len() == 1 {
946                match &content_parts[0] {
947                    LLMMessageTypedContent::Text { text } => {
948                        LLMMessageContent::String(text.clone())
949                    }
950                    _ => LLMMessageContent::List(content_parts),
951                }
952            } else {
953                LLMMessageContent::List(content_parts)
954            },
955        }
956    }
957}
958
959impl From<GenerationDelta> for ChatMessageDelta {
960    fn from(delta: GenerationDelta) -> Self {
961        match delta {
962            GenerationDelta::Content { content } => ChatMessageDelta {
963                role: Some(Role::Assistant),
964                content: Some(content),
965                tool_calls: None,
966            },
967            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
968                role: Some(Role::Assistant),
969                content: None,
970                tool_calls: None,
971            },
972            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
973                role: Some(Role::Assistant),
974                content: None,
975                tool_calls: Some(vec![ToolCallDelta {
976                    index: tool_use.index,
977                    id: tool_use.id,
978                    r#type: Some("function".to_string()),
979                    function: Some(FunctionCallDelta {
980                        name: tool_use.name,
981                        arguments: tool_use.input,
982                    }),
983                }]),
984            },
985            _ => ChatMessageDelta {
986                role: Some(Role::Assistant),
987                content: None,
988                tool_calls: None,
989            },
990        }
991    }
992}
993
#[cfg(test)]
mod tests {
    use super::*;

    // Convenience constructor for a Tool-role message carrying a tool_call_id.
    fn tool_role_message(content: Option<MessageContent>, call_id: &str) -> ChatMessage {
        ChatMessage {
            role: Role::Tool,
            content,
            tool_call_id: Some(call_id.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_serialize_basic_request() {
        let system_msg = ChatMessage {
            role: Role::System,
            content: Some(MessageContent::String(
                "You are a helpful assistant.".to_string(),
            )),
            ..Default::default()
        };
        let user_msg = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::String("Hello!".to_string())),
            ..Default::default()
        };
        let req = ChatCompletionRequest {
            model: AgentModel::Smart.to_string(),
            messages: vec![system_msg, user_msg],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let serialized = serde_json::to_string(&req).unwrap();
        assert!(serialized.contains("\"model\":\"smart\""));
        assert!(serialized.contains("\"messages\":["));
        assert!(serialized.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let source = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let converted = ChatMessage::from(source);
        assert_eq!(converted.role, Role::User);
        match &converted.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        // Tool-role messages with a tool_call_id must become ToolResult content.
        // Critical for Anthropic compatibility: the provider layer later turns
        // role="tool" into role="user" with tool_result blocks.
        let message = tool_role_message(
            Some(MessageContent::String("Tool execution result".to_string())),
            "toolu_01Abc123",
        );

        let llm_message = LLMMessage::from(message);

        // The generic layer keeps role "tool"; providers convert as needed.
        assert_eq!(llm_message.role, "tool");

        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            );
        };
        assert_eq!(parts.len(), 1, "Should have exactly one content part");
        match &parts[0] {
            LLMMessageTypedContent::ToolResult {
                tool_use_id,
                content,
            } => {
                assert_eq!(tool_use_id, "toolu_01Abc123");
                assert_eq!(content, "Tool execution result");
            }
            _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        // A tool result without content should yield an empty result string.
        let message = tool_role_message(None, "toolu_02Xyz789");

        let llm_message = LLMMessage::from(message);

        assert_eq!(llm_message.role, "tool");
        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content with ToolResult");
        };
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            LLMMessageTypedContent::ToolResult {
                tool_use_id,
                content,
            } => {
                assert_eq!(tool_use_id, "toolu_02Xyz789");
                assert_eq!(content, ""); // Empty content
            }
            _ => panic!("Expected ToolResult content part"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        // Assistant messages carrying tool_calls should produce a text part
        // followed by a tool call part, in that order.
        let message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
            }]),
            ..Default::default()
        };

        let llm_message = LLMMessage::from(message);

        assert_eq!(llm_message.role, "assistant");
        let LLMMessageContent::List(parts) = &llm_message.content else {
            panic!("Expected List content");
        };
        assert_eq!(parts.len(), 2, "Should have text and tool call");

        match &parts[0] {
            LLMMessageTypedContent::Text { text } => {
                assert_eq!(text, "I'll help you with that.");
            }
            _ => panic!("Expected Text content part first"),
        }
        match &parts[1] {
            LLMMessageTypedContent::ToolCall { id, name, args } => {
                assert_eq!(id, "call_abc123");
                assert_eq!(name, "get_weather");
                assert_eq!(args["location"], "Paris");
            }
            _ => panic!("Expected ToolCall content part second"),
        }
    }
}