Skip to main content

stakpak_shared/models/integrations/
openai.rs

1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
24/// Configuration for OpenAI provider
25#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
26pub struct OpenAIConfig {
27    pub api_endpoint: Option<String>,
28    pub api_key: Option<String>,
29}
30
31impl OpenAIConfig {
32    /// Create config with API key
33    pub fn with_api_key(api_key: impl Into<String>) -> Self {
34        Self {
35            api_key: Some(api_key.into()),
36            api_endpoint: None,
37        }
38    }
39
40    /// Create config from ProviderAuth (only supports API key for OpenAI)
41    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42        match auth {
43            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
45        }
46    }
47
48    /// Merge with credentials from ProviderAuth, preserving existing endpoint
49    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50        match auth {
51            crate::models::auth::ProviderAuth::Api { key } => {
52                self.api_key = Some(key.clone());
53                Some(self)
54            }
55            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
56        }
57    }
58}
59
60// =============================================================================
61// Model Definitions
62// =============================================================================
63
64/// OpenAI model identifiers
65#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
66pub enum OpenAIModel {
67    // Reasoning Models
68    #[serde(rename = "o3-2025-04-16")]
69    O3,
70    #[serde(rename = "o4-mini-2025-04-16")]
71    O4Mini,
72
73    #[default]
74    #[serde(rename = "gpt-5-2025-08-07")]
75    GPT5,
76    #[serde(rename = "gpt-5.1-2025-11-13")]
77    GPT51,
78    #[serde(rename = "gpt-5-mini-2025-08-07")]
79    GPT5Mini,
80    #[serde(rename = "gpt-5-nano-2025-08-07")]
81    GPT5Nano,
82
83    Custom(String),
84}
85
86impl OpenAIModel {
87    pub fn from_string(s: &str) -> Result<Self, String> {
88        serde_json::from_value(serde_json::Value::String(s.to_string()))
89            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90    }
91
92    /// Default smart model for OpenAI
93    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95    /// Default eco model for OpenAI
96    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98    /// Default recovery model for OpenAI
99    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101    /// Get default smart model as string
102    pub fn default_smart_model() -> String {
103        Self::DEFAULT_SMART_MODEL.to_string()
104    }
105
106    /// Get default eco model as string
107    pub fn default_eco_model() -> String {
108        Self::DEFAULT_ECO_MODEL.to_string()
109    }
110
111    /// Get default recovery model as string
112    pub fn default_recovery_model() -> String {
113        Self::DEFAULT_RECOVERY_MODEL.to_string()
114    }
115}
116
117impl ContextAware for OpenAIModel {
118    fn context_info(&self) -> ModelContextInfo {
119        let model_name = self.to_string();
120
121        if model_name.starts_with("o3") {
122            return ModelContextInfo {
123                max_tokens: 200_000,
124                pricing_tiers: vec![ContextPricingTier {
125                    label: "Standard".to_string(),
126                    input_cost_per_million: 2.0,
127                    output_cost_per_million: 8.0,
128                    upper_bound: None,
129                }],
130                approach_warning_threshold: 0.8,
131            };
132        }
133
134        if model_name.starts_with("o4-mini") {
135            return ModelContextInfo {
136                max_tokens: 200_000,
137                pricing_tiers: vec![ContextPricingTier {
138                    label: "Standard".to_string(),
139                    input_cost_per_million: 1.10,
140                    output_cost_per_million: 4.40,
141                    upper_bound: None,
142                }],
143                approach_warning_threshold: 0.8,
144            };
145        }
146
147        if model_name.starts_with("gpt-5-mini") {
148            return ModelContextInfo {
149                max_tokens: 400_000,
150                pricing_tiers: vec![ContextPricingTier {
151                    label: "Standard".to_string(),
152                    input_cost_per_million: 0.25,
153                    output_cost_per_million: 2.0,
154                    upper_bound: None,
155                }],
156                approach_warning_threshold: 0.8,
157            };
158        }
159
160        if model_name.starts_with("gpt-5-nano") {
161            return ModelContextInfo {
162                max_tokens: 400_000,
163                pricing_tiers: vec![ContextPricingTier {
164                    label: "Standard".to_string(),
165                    input_cost_per_million: 0.05,
166                    output_cost_per_million: 0.40,
167                    upper_bound: None,
168                }],
169                approach_warning_threshold: 0.8,
170            };
171        }
172
173        if model_name.starts_with("gpt-5") {
174            return ModelContextInfo {
175                max_tokens: 400_000,
176                pricing_tiers: vec![ContextPricingTier {
177                    label: "Standard".to_string(),
178                    input_cost_per_million: 1.25,
179                    output_cost_per_million: 10.0,
180                    upper_bound: None,
181                }],
182                approach_warning_threshold: 0.8,
183            };
184        }
185
186        ModelContextInfo::default()
187    }
188
189    fn model_name(&self) -> String {
190        match self {
191            OpenAIModel::O3 => "O3".to_string(),
192            OpenAIModel::O4Mini => "O4-mini".to_string(),
193            OpenAIModel::GPT5 => "GPT-5".to_string(),
194            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197            OpenAIModel::Custom(name) => format!("Custom ({})", name),
198        }
199    }
200}
201
202impl std::fmt::Display for OpenAIModel {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        match self {
205            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212        }
213    }
214}
215
216// =============================================================================
217// Message Types (used by TUI)
218// =============================================================================
219
220/// Message role
221#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
222#[serde(rename_all = "lowercase")]
223pub enum Role {
224    System,
225    Developer,
226    User,
227    #[default]
228    Assistant,
229    Tool,
230}
231
232impl std::fmt::Display for Role {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        match self {
235            Role::System => write!(f, "system"),
236            Role::Developer => write!(f, "developer"),
237            Role::User => write!(f, "user"),
238            Role::Assistant => write!(f, "assistant"),
239            Role::Tool => write!(f, "tool"),
240        }
241    }
242}
243
244/// Model info for tracking which model generated a message
245#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
246pub struct ModelInfo {
247    /// Provider name (e.g., "anthropic", "openai")
248    pub provider: String,
249    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4")
250    pub id: String,
251}
252
253/// Chat message
254#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
255pub struct ChatMessage {
256    pub role: Role,
257    pub content: Option<MessageContent>,
258    #[serde(skip_serializing_if = "Option::is_none")]
259    pub name: Option<String>,
260    #[serde(skip_serializing_if = "Option::is_none")]
261    pub tool_calls: Option<Vec<ToolCall>>,
262    #[serde(skip_serializing_if = "Option::is_none")]
263    pub tool_call_id: Option<String>,
264    #[serde(skip_serializing_if = "Option::is_none")]
265    pub usage: Option<LLMTokenUsage>,
266
267    // === Extended fields for session tracking ===
268    /// Unique message identifier
269    #[serde(skip_serializing_if = "Option::is_none")]
270    pub id: Option<String>,
271    /// Model that generated this message (for assistant messages)
272    #[serde(skip_serializing_if = "Option::is_none")]
273    pub model: Option<ModelInfo>,
274    /// Cost in dollars for this message
275    #[serde(skip_serializing_if = "Option::is_none")]
276    pub cost: Option<f64>,
277    /// Why the model stopped: "stop", "tool_calls", "length", "error"
278    #[serde(skip_serializing_if = "Option::is_none")]
279    pub finish_reason: Option<String>,
280    /// Unix timestamp (ms) when message was created/sent
281    #[serde(skip_serializing_if = "Option::is_none")]
282    pub created_at: Option<i64>,
283    /// Unix timestamp (ms) when assistant finished generating
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub completed_at: Option<i64>,
286    /// Plugin extensibility - unstructured metadata
287    #[serde(skip_serializing_if = "Option::is_none")]
288    pub metadata: Option<serde_json::Value>,
289}
290
291impl ChatMessage {
292    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
293        messages
294            .iter()
295            .rev()
296            .find(|message| message.role != Role::User && message.role != Role::Tool)
297    }
298
299    pub fn to_xml(&self) -> String {
300        match &self.content {
301            Some(MessageContent::String(s)) => {
302                format!("<message role=\"{}\">{}</message>", self.role, s)
303            }
304            Some(MessageContent::Array(parts)) => parts
305                .iter()
306                .map(|part| {
307                    format!(
308                        "<message role=\"{}\" type=\"{}\">{}</message>",
309                        self.role,
310                        part.r#type,
311                        part.text.clone().unwrap_or_default()
312                    )
313                })
314                .collect::<Vec<String>>()
315                .join("\n"),
316            None => String::new(),
317        }
318    }
319}
320
321/// Message content (string or array of parts)
322#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
323#[serde(untagged)]
324pub enum MessageContent {
325    String(String),
326    Array(Vec<ContentPart>),
327}
328
329impl MessageContent {
330    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
331        match self {
332            MessageContent::String(s) => MessageContent::String(format!(
333                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
334            )),
335            MessageContent::Array(parts) => MessageContent::Array(
336                std::iter::once(ContentPart {
337                    r#type: "text".to_string(),
338                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
339                    image_url: None,
340                })
341                .chain(parts.iter().cloned())
342                .collect(),
343            ),
344        }
345    }
346
347    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
348        match self {
349            MessageContent::String(s) => s
350                .rfind("<checkpoint_id>")
351                .and_then(|start| {
352                    s[start..]
353                        .find("</checkpoint_id>")
354                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
355                })
356                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
357            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
358                part.text.as_deref().and_then(|text| {
359                    text.rfind("<checkpoint_id>")
360                        .and_then(|start| {
361                            text[start..]
362                                .find("</checkpoint_id>")
363                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
364                        })
365                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
366                })
367            }),
368        }
369    }
370}
371
372impl std::fmt::Display for MessageContent {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        match self {
375            MessageContent::String(s) => write!(f, "{s}"),
376            MessageContent::Array(parts) => {
377                let text_parts: Vec<String> =
378                    parts.iter().filter_map(|part| part.text.clone()).collect();
379                write!(f, "{}", text_parts.join("\n"))
380            }
381        }
382    }
383}
384
385impl Default for MessageContent {
386    fn default() -> Self {
387        MessageContent::String(String::new())
388    }
389}
390
391/// Content part (text or image)
392#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
393pub struct ContentPart {
394    pub r#type: String,
395    #[serde(skip_serializing_if = "Option::is_none")]
396    pub text: Option<String>,
397    #[serde(skip_serializing_if = "Option::is_none")]
398    pub image_url: Option<ImageUrl>,
399}
400
401/// Image URL with optional detail level
402#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
403pub struct ImageUrl {
404    pub url: String,
405    #[serde(skip_serializing_if = "Option::is_none")]
406    pub detail: Option<String>,
407}
408
409// =============================================================================
410// Tool Types (used by TUI)
411// =============================================================================
412
413/// Tool definition
414#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
415pub struct Tool {
416    pub r#type: String,
417    pub function: FunctionDefinition,
418}
419
420/// Function definition for tools
421#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
422pub struct FunctionDefinition {
423    pub name: String,
424    pub description: Option<String>,
425    pub parameters: serde_json::Value,
426}
427
428impl From<Tool> for LLMTool {
429    fn from(tool: Tool) -> Self {
430        LLMTool {
431            name: tool.function.name,
432            description: tool.function.description.unwrap_or_default(),
433            input_schema: tool.function.parameters,
434        }
435    }
436}
437
438/// Tool choice configuration
439#[derive(Debug, Clone, PartialEq)]
440pub enum ToolChoice {
441    Auto,
442    Required,
443    Object(ToolChoiceObject),
444}
445
446impl Serialize for ToolChoice {
447    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
448    where
449        S: serde::Serializer,
450    {
451        match self {
452            ToolChoice::Auto => serializer.serialize_str("auto"),
453            ToolChoice::Required => serializer.serialize_str("required"),
454            ToolChoice::Object(obj) => obj.serialize(serializer),
455        }
456    }
457}
458
459impl<'de> Deserialize<'de> for ToolChoice {
460    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
461    where
462        D: serde::Deserializer<'de>,
463    {
464        struct ToolChoiceVisitor;
465
466        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
467            type Value = ToolChoice;
468
469            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
470                formatter.write_str("string or object")
471            }
472
473            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
474            where
475                E: serde::de::Error,
476            {
477                match value {
478                    "auto" => Ok(ToolChoice::Auto),
479                    "required" => Ok(ToolChoice::Required),
480                    _ => Err(serde::de::Error::unknown_variant(
481                        value,
482                        &["auto", "required"],
483                    )),
484                }
485            }
486
487            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
488            where
489                M: serde::de::MapAccess<'de>,
490            {
491                let obj = ToolChoiceObject::deserialize(
492                    serde::de::value::MapAccessDeserializer::new(map),
493                )?;
494                Ok(ToolChoice::Object(obj))
495            }
496        }
497
498        deserializer.deserialize_any(ToolChoiceVisitor)
499    }
500}
501
502#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
503pub struct ToolChoiceObject {
504    pub r#type: String,
505    pub function: FunctionChoice,
506}
507
508#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
509pub struct FunctionChoice {
510    pub name: String,
511}
512
513/// Tool call from assistant
514#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
515pub struct ToolCall {
516    pub id: String,
517    pub r#type: String,
518    pub function: FunctionCall,
519    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
520    #[serde(skip_serializing_if = "Option::is_none")]
521    pub metadata: Option<serde_json::Value>,
522}
523
524/// Function call details
525#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
526pub struct FunctionCall {
527    pub name: String,
528    pub arguments: String,
529}
530
531/// Tool call result status
532#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
533pub enum ToolCallResultStatus {
534    Success,
535    Error,
536    Cancelled,
537}
538
539/// Tool call result
540#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
541pub struct ToolCallResult {
542    pub call: ToolCall,
543    pub result: String,
544    pub status: ToolCallResultStatus,
545}
546
547/// Tool call result progress update
548#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
549pub struct ToolCallResultProgress {
550    pub id: Uuid,
551    pub message: String,
552    /// Type of progress update for specialized handling
553    #[serde(skip_serializing_if = "Option::is_none")]
554    pub progress_type: Option<ProgressType>,
555    /// Structured task updates for task wait progress
556    #[serde(skip_serializing_if = "Option::is_none")]
557    pub task_updates: Option<Vec<TaskUpdate>>,
558    /// Overall progress percentage (0.0 - 100.0)
559    #[serde(skip_serializing_if = "Option::is_none")]
560    pub progress: Option<f64>,
561}
562
563/// Type of progress update
564#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
565pub enum ProgressType {
566    /// Command output streaming
567    CommandOutput,
568    /// Task wait progress with structured updates
569    TaskWait,
570    /// Generic progress
571    Generic,
572}
573
574/// Structured task status update
575#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
576pub struct TaskUpdate {
577    pub task_id: String,
578    pub status: String,
579    #[serde(skip_serializing_if = "Option::is_none")]
580    pub description: Option<String>,
581    #[serde(skip_serializing_if = "Option::is_none")]
582    pub duration_secs: Option<f64>,
583    #[serde(skip_serializing_if = "Option::is_none")]
584    pub output_preview: Option<String>,
585    /// Whether this is a target task being waited on
586    #[serde(default)]
587    pub is_target: bool,
588    /// Pause information for paused subagent tasks
589    #[serde(skip_serializing_if = "Option::is_none")]
590    pub pause_info: Option<TaskPauseInfo>,
591}
592
593/// Pause information for subagent tasks awaiting approval
594#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
595pub struct TaskPauseInfo {
596    /// The agent's message before pausing
597    #[serde(skip_serializing_if = "Option::is_none")]
598    pub agent_message: Option<String>,
599    /// Pending tool calls awaiting approval
600    #[serde(skip_serializing_if = "Option::is_none")]
601    pub pending_tool_calls: Option<Vec<PendingToolCall>>,
602}
603
604/// A pending tool call awaiting approval
605#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
606pub struct PendingToolCall {
607    pub id: String,
608    pub name: String,
609    #[serde(skip_serializing_if = "Option::is_none")]
610    pub arguments: Option<serde_json::Value>,
611}
612
613// =============================================================================
614// Chat Completion Types (used by TUI)
615// =============================================================================
616
617/// Chat completion request
618#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
619pub struct ChatCompletionRequest {
620    pub model: String,
621    pub messages: Vec<ChatMessage>,
622    #[serde(skip_serializing_if = "Option::is_none")]
623    pub frequency_penalty: Option<f32>,
624    #[serde(skip_serializing_if = "Option::is_none")]
625    pub logit_bias: Option<serde_json::Value>,
626    #[serde(skip_serializing_if = "Option::is_none")]
627    pub logprobs: Option<bool>,
628    #[serde(skip_serializing_if = "Option::is_none")]
629    pub max_tokens: Option<u32>,
630    #[serde(skip_serializing_if = "Option::is_none")]
631    pub n: Option<u32>,
632    #[serde(skip_serializing_if = "Option::is_none")]
633    pub presence_penalty: Option<f32>,
634    #[serde(skip_serializing_if = "Option::is_none")]
635    pub response_format: Option<ResponseFormat>,
636    #[serde(skip_serializing_if = "Option::is_none")]
637    pub seed: Option<i64>,
638    #[serde(skip_serializing_if = "Option::is_none")]
639    pub stop: Option<StopSequence>,
640    #[serde(skip_serializing_if = "Option::is_none")]
641    pub stream: Option<bool>,
642    #[serde(skip_serializing_if = "Option::is_none")]
643    pub temperature: Option<f32>,
644    #[serde(skip_serializing_if = "Option::is_none")]
645    pub top_p: Option<f32>,
646    #[serde(skip_serializing_if = "Option::is_none")]
647    pub tools: Option<Vec<Tool>>,
648    #[serde(skip_serializing_if = "Option::is_none")]
649    pub tool_choice: Option<ToolChoice>,
650    #[serde(skip_serializing_if = "Option::is_none")]
651    pub user: Option<String>,
652    #[serde(skip_serializing_if = "Option::is_none")]
653    pub context: Option<ChatCompletionContext>,
654}
655
656impl ChatCompletionRequest {
657    pub fn new(
658        model: String,
659        messages: Vec<ChatMessage>,
660        tools: Option<Vec<Tool>>,
661        stream: Option<bool>,
662    ) -> Self {
663        Self {
664            model,
665            messages,
666            frequency_penalty: None,
667            logit_bias: None,
668            logprobs: None,
669            max_tokens: None,
670            n: None,
671            presence_penalty: None,
672            response_format: None,
673            seed: None,
674            stop: None,
675            stream,
676            temperature: None,
677            top_p: None,
678            tools,
679            tool_choice: None,
680            user: None,
681            context: None,
682        }
683    }
684}
685
686#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
687pub struct ChatCompletionContext {
688    pub scratchpad: Option<Value>,
689}
690
691#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
692pub struct ResponseFormat {
693    pub r#type: String,
694}
695
696#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
697#[serde(untagged)]
698pub enum StopSequence {
699    String(String),
700    Array(Vec<String>),
701}
702
703/// Chat completion response
704#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
705pub struct ChatCompletionResponse {
706    pub id: String,
707    pub object: String,
708    pub created: u64,
709    pub model: String,
710    pub choices: Vec<ChatCompletionChoice>,
711    pub usage: LLMTokenUsage,
712    #[serde(skip_serializing_if = "Option::is_none")]
713    pub system_fingerprint: Option<String>,
714    pub metadata: Option<serde_json::Value>,
715}
716
717#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
718pub struct ChatCompletionChoice {
719    pub index: usize,
720    pub message: ChatMessage,
721    pub logprobs: Option<LogProbs>,
722    pub finish_reason: FinishReason,
723}
724
725#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
726#[serde(rename_all = "snake_case")]
727pub enum FinishReason {
728    Stop,
729    Length,
730    ContentFilter,
731    ToolCalls,
732}
733
734#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
735pub struct LogProbs {
736    pub content: Option<Vec<LogProbContent>>,
737}
738
739#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
740pub struct LogProbContent {
741    pub token: String,
742    pub logprob: f32,
743    pub bytes: Option<Vec<u8>>,
744    pub top_logprobs: Option<Vec<TokenLogprob>>,
745}
746
747#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
748pub struct TokenLogprob {
749    pub token: String,
750    pub logprob: f32,
751    pub bytes: Option<Vec<u8>>,
752}
753
754// =============================================================================
755// Streaming Types
756// =============================================================================
757
758#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
759pub struct ChatCompletionStreamResponse {
760    pub id: String,
761    pub object: String,
762    pub created: u64,
763    pub model: String,
764    pub choices: Vec<ChatCompletionStreamChoice>,
765    pub usage: Option<LLMTokenUsage>,
766    pub metadata: Option<serde_json::Value>,
767}
768
769#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
770pub struct ChatCompletionStreamChoice {
771    pub index: usize,
772    pub delta: ChatMessageDelta,
773    pub finish_reason: Option<FinishReason>,
774}
775
776#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
777pub struct ChatMessageDelta {
778    #[serde(skip_serializing_if = "Option::is_none")]
779    pub role: Option<Role>,
780    #[serde(skip_serializing_if = "Option::is_none")]
781    pub content: Option<String>,
782    #[serde(skip_serializing_if = "Option::is_none")]
783    pub tool_calls: Option<Vec<ToolCallDelta>>,
784}
785
786#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
787pub struct ToolCallDelta {
788    pub index: usize,
789    pub id: Option<String>,
790    pub r#type: Option<String>,
791    pub function: Option<FunctionCallDelta>,
792    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
793    #[serde(skip_serializing_if = "Option::is_none")]
794    pub metadata: Option<serde_json::Value>,
795}
796
797#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
798pub struct FunctionCallDelta {
799    pub name: Option<String>,
800    pub arguments: Option<String>,
801}
802
803// =============================================================================
804// Conversions
805// =============================================================================
806
807impl From<LLMMessage> for ChatMessage {
808    fn from(llm_message: LLMMessage) -> Self {
809        let role = match llm_message.role.as_str() {
810            "system" => Role::System,
811            "user" => Role::User,
812            "assistant" => Role::Assistant,
813            "tool" => Role::Tool,
814            "developer" => Role::Developer,
815            _ => Role::User,
816        };
817
818        let (content, tool_calls) = match llm_message.content {
819            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None),
820            LLMMessageContent::List(items) => {
821                let mut text_parts = Vec::new();
822                let mut tool_call_parts = Vec::new();
823
824                for item in items {
825                    match item {
826                        LLMMessageTypedContent::Text { text } => {
827                            text_parts.push(ContentPart {
828                                r#type: "text".to_string(),
829                                text: Some(text),
830                                image_url: None,
831                            });
832                        }
833                        LLMMessageTypedContent::ToolCall {
834                            id,
835                            name,
836                            args,
837                            metadata,
838                        } => {
839                            tool_call_parts.push(ToolCall {
840                                id,
841                                r#type: "function".to_string(),
842                                function: FunctionCall {
843                                    name,
844                                    arguments: args.to_string(),
845                                },
846                                metadata,
847                            });
848                        }
849                        LLMMessageTypedContent::ToolResult { content, .. } => {
850                            text_parts.push(ContentPart {
851                                r#type: "text".to_string(),
852                                text: Some(content),
853                                image_url: None,
854                            });
855                        }
856                        LLMMessageTypedContent::Image { source } => {
857                            text_parts.push(ContentPart {
858                                r#type: "image_url".to_string(),
859                                text: None,
860                                image_url: Some(ImageUrl {
861                                    url: format!(
862                                        "data:{};base64,{}",
863                                        source.media_type, source.data
864                                    ),
865                                    detail: None,
866                                }),
867                            });
868                        }
869                    }
870                }
871
872                let content = if !text_parts.is_empty() {
873                    Some(MessageContent::Array(text_parts))
874                } else {
875                    None
876                };
877
878                let tool_calls = if !tool_call_parts.is_empty() {
879                    Some(tool_call_parts)
880                } else {
881                    None
882                };
883
884                (content, tool_calls)
885            }
886        };
887
888        ChatMessage {
889            role,
890            content,
891            name: None,
892            tool_calls,
893            tool_call_id: None,
894            usage: None,
895            ..Default::default()
896        }
897    }
898}
899
900impl From<ChatMessage> for LLMMessage {
901    fn from(chat_message: ChatMessage) -> Self {
902        let mut content_parts = Vec::new();
903
904        match chat_message.content {
905            Some(MessageContent::String(s)) => {
906                if !s.is_empty() {
907                    content_parts.push(LLMMessageTypedContent::Text { text: s });
908                }
909            }
910            Some(MessageContent::Array(parts)) => {
911                for part in parts {
912                    if let Some(text) = part.text {
913                        content_parts.push(LLMMessageTypedContent::Text { text });
914                    } else if let Some(image_url) = part.image_url {
915                        let (media_type, data) = if image_url.url.starts_with("data:") {
916                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
917                            if parts.len() == 2 {
918                                let meta = parts[0];
919                                let data = parts[1];
920                                let media_type = meta
921                                    .trim_start_matches("data:")
922                                    .trim_end_matches(";base64")
923                                    .to_string();
924                                (media_type, data.to_string())
925                            } else {
926                                ("image/jpeg".to_string(), image_url.url)
927                            }
928                        } else {
929                            ("image/jpeg".to_string(), image_url.url)
930                        };
931
932                        content_parts.push(LLMMessageTypedContent::Image {
933                            source: LLMMessageImageSource {
934                                r#type: "base64".to_string(),
935                                media_type,
936                                data,
937                            },
938                        });
939                    }
940                }
941            }
942            None => {}
943        }
944
945        if let Some(tool_calls) = chat_message.tool_calls {
946            for tool_call in tool_calls {
947                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
948                content_parts.push(LLMMessageTypedContent::ToolCall {
949                    id: tool_call.id,
950                    name: tool_call.function.name,
951                    args,
952                    metadata: tool_call.metadata,
953                });
954            }
955        }
956
957        // Handle tool result messages: when role is Tool and tool_call_id is present,
958        // convert the content to a ToolResult content part. This is the generic
959        // intermediate representation - each provider's conversion layer handles
960        // the specifics (e.g., Anthropic converts to user role with tool_result blocks)
961        if chat_message.role == Role::Tool
962            && let Some(tool_call_id) = chat_message.tool_call_id
963        {
964            // Extract content as string for the tool result
965            let content_str = content_parts
966                .iter()
967                .filter_map(|p| match p {
968                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
969                    _ => None,
970                })
971                .collect::<Vec<_>>()
972                .join("\n");
973
974            // Replace content with a single ToolResult
975            content_parts = vec![LLMMessageTypedContent::ToolResult {
976                tool_use_id: tool_call_id,
977                content: content_str,
978            }];
979        }
980
981        LLMMessage {
982            role: chat_message.role.to_string(),
983            content: if content_parts.is_empty() {
984                LLMMessageContent::String(String::new())
985            } else if content_parts.len() == 1 {
986                match &content_parts[0] {
987                    LLMMessageTypedContent::Text { text } => {
988                        LLMMessageContent::String(text.clone())
989                    }
990                    _ => LLMMessageContent::List(content_parts),
991                }
992            } else {
993                LLMMessageContent::List(content_parts)
994            },
995        }
996    }
997}
998
999impl From<GenerationDelta> for ChatMessageDelta {
1000    fn from(delta: GenerationDelta) -> Self {
1001        match delta {
1002            GenerationDelta::Content { content } => ChatMessageDelta {
1003                role: Some(Role::Assistant),
1004                content: Some(content),
1005                tool_calls: None,
1006            },
1007            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1008                role: Some(Role::Assistant),
1009                content: None,
1010                tool_calls: None,
1011            },
1012            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1013                role: Some(Role::Assistant),
1014                content: None,
1015                tool_calls: Some(vec![ToolCallDelta {
1016                    index: tool_use.index,
1017                    id: tool_use.id,
1018                    r#type: Some("function".to_string()),
1019                    function: Some(FunctionCallDelta {
1020                        name: tool_use.name,
1021                        arguments: tool_use.input,
1022                    }),
1023                    metadata: tool_use.metadata,
1024                }]),
1025            },
1026            _ => ChatMessageDelta {
1027                role: Some(Role::Assistant),
1028                content: None,
1029                tool_calls: None,
1030            },
1031        }
1032    }
1033}
1034
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        // Test that Tool role messages with tool_call_id are converted to ToolResult content
        // This is critical for Anthropic compatibility - the provider layer converts
        // role="tool" to role="user" with tool_result content blocks
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        // Role should be preserved as "tool" - provider layer handles conversion
        assert_eq!(llm_message.role, "tool");

        // Content should be a ToolResult with the tool_call_id
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        // Test tool result with empty content
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        assert_eq!(content, ""); // Empty content
                    }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        // Test that assistant messages with tool_calls are converted correctly
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                // First part should be text
                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                // Second part should be tool call
                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_single_text_part_collapses_to_string() {
        // A content array holding exactly one text part is emitted in string form.
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("just text".to_string()),
                image_url: None,
            }])),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "user");
        match &llm_message.content {
            LLMMessageContent::String(text) => assert_eq!(text, "just text"),
            other => panic!("Expected String content, got {:?}", other),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_image_data_url() {
        // A base64 data URL is split into its media type and payload.
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "image_url".to_string(),
                text: None,
                image_url: Some(ImageUrl {
                    url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
                    detail: None,
                }),
            }])),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one image part");
                match &parts[0] {
                    LLMMessageTypedContent::Image { source } => {
                        assert_eq!(source.r#type, "base64");
                        assert_eq!(source.media_type, "image/png");
                        assert_eq!(source.data, "iVBORw0KGgo=");
                    }
                    other => panic!("Expected Image content part, got {:?}", other),
                }
            }
            other => panic!("Expected List content, got {:?}", other),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_malformed_tool_args_fall_back_to_empty_object() {
        // Unparseable tool-call arguments are replaced by `{}` instead of failing.
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_bad".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "broken".to_string(),
                    arguments: "{not valid json".to_string(),
                },
                metadata: None,
            }]),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one tool call part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_bad");
                        assert_eq!(name, "broken");
                        assert_eq!(args, &serde_json::json!({}));
                    }
                    other => panic!("Expected ToolCall content part, got {:?}", other),
                }
            }
            other => panic!("Expected List content, got {:?}", other),
        }
    }
}