stakpak_shared/models/integrations/openai.rs

1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
/// Configuration for OpenAI provider
///
/// Both fields are optional: `api_endpoint` overrides the default base URL
/// when set, and `api_key` carries the bearer credential.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Custom API base URL; `None` means the provider default
    pub api_endpoint: Option<String>,
    /// API key credential; `None` means unauthenticated/unset
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32    /// Create config with API key
33    pub fn with_api_key(api_key: impl Into<String>) -> Self {
34        Self {
35            api_key: Some(api_key.into()),
36            api_endpoint: None,
37        }
38    }
39
40    /// Create config from ProviderAuth (only supports API key for OpenAI)
41    pub fn from_provider_auth(auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
42        match auth {
43            crate::models::auth::ProviderAuth::Api { key } => Some(Self::with_api_key(key)),
44            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
45        }
46    }
47
48    /// Merge with credentials from ProviderAuth, preserving existing endpoint
49    pub fn with_provider_auth(mut self, auth: &crate::models::auth::ProviderAuth) -> Option<Self> {
50        match auth {
51            crate::models::auth::ProviderAuth::Api { key } => {
52                self.api_key = Some(key.clone());
53                Some(self)
54            }
55            crate::models::auth::ProviderAuth::OAuth { .. } => None, // OpenAI doesn't support OAuth
56        }
57    }
58}
59
60// =============================================================================
61// Model Definitions
62// =============================================================================
63
/// OpenAI model identifiers
///
/// Known models serialize as their dated API ids via the `rename` attributes
/// (e.g. `"gpt-5-2025-08-07"`).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    // GPT-5 family; GPT5 is the overall default variant.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    // NOTE(review): with serde's default externally-tagged representation this
    // serializes as `{"Custom": "..."}`, not the bare id string that `Display`
    // produces — confirm that asymmetry is intended.
    Custom(String),
}
85
86impl OpenAIModel {
87    pub fn from_string(s: &str) -> Result<Self, String> {
88        serde_json::from_value(serde_json::Value::String(s.to_string()))
89            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
90    }
91
92    /// Default smart model for OpenAI
93    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
94
95    /// Default eco model for OpenAI
96    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
97
98    /// Default recovery model for OpenAI
99    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
100
101    /// Get default smart model as string
102    pub fn default_smart_model() -> String {
103        Self::DEFAULT_SMART_MODEL.to_string()
104    }
105
106    /// Get default eco model as string
107    pub fn default_eco_model() -> String {
108        Self::DEFAULT_ECO_MODEL.to_string()
109    }
110
111    /// Get default recovery model as string
112    pub fn default_recovery_model() -> String {
113        Self::DEFAULT_RECOVERY_MODEL.to_string()
114    }
115}
116
117impl ContextAware for OpenAIModel {
118    fn context_info(&self) -> ModelContextInfo {
119        let model_name = self.to_string();
120
121        if model_name.starts_with("o3") {
122            return ModelContextInfo {
123                max_tokens: 200_000,
124                pricing_tiers: vec![ContextPricingTier {
125                    label: "Standard".to_string(),
126                    input_cost_per_million: 2.0,
127                    output_cost_per_million: 8.0,
128                    upper_bound: None,
129                }],
130                approach_warning_threshold: 0.8,
131            };
132        }
133
134        if model_name.starts_with("o4-mini") {
135            return ModelContextInfo {
136                max_tokens: 200_000,
137                pricing_tiers: vec![ContextPricingTier {
138                    label: "Standard".to_string(),
139                    input_cost_per_million: 1.10,
140                    output_cost_per_million: 4.40,
141                    upper_bound: None,
142                }],
143                approach_warning_threshold: 0.8,
144            };
145        }
146
147        if model_name.starts_with("gpt-5-mini") {
148            return ModelContextInfo {
149                max_tokens: 400_000,
150                pricing_tiers: vec![ContextPricingTier {
151                    label: "Standard".to_string(),
152                    input_cost_per_million: 0.25,
153                    output_cost_per_million: 2.0,
154                    upper_bound: None,
155                }],
156                approach_warning_threshold: 0.8,
157            };
158        }
159
160        if model_name.starts_with("gpt-5-nano") {
161            return ModelContextInfo {
162                max_tokens: 400_000,
163                pricing_tiers: vec![ContextPricingTier {
164                    label: "Standard".to_string(),
165                    input_cost_per_million: 0.05,
166                    output_cost_per_million: 0.40,
167                    upper_bound: None,
168                }],
169                approach_warning_threshold: 0.8,
170            };
171        }
172
173        if model_name.starts_with("gpt-5") {
174            return ModelContextInfo {
175                max_tokens: 400_000,
176                pricing_tiers: vec![ContextPricingTier {
177                    label: "Standard".to_string(),
178                    input_cost_per_million: 1.25,
179                    output_cost_per_million: 10.0,
180                    upper_bound: None,
181                }],
182                approach_warning_threshold: 0.8,
183            };
184        }
185
186        ModelContextInfo::default()
187    }
188
189    fn model_name(&self) -> String {
190        match self {
191            OpenAIModel::O3 => "O3".to_string(),
192            OpenAIModel::O4Mini => "O4-mini".to_string(),
193            OpenAIModel::GPT5 => "GPT-5".to_string(),
194            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
195            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
196            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
197            OpenAIModel::Custom(name) => format!("Custom ({})", name),
198        }
199    }
200}
201
202impl std::fmt::Display for OpenAIModel {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        match self {
205            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
206            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
207            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
208            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
209            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
210            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
211            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
212        }
213    }
214}
215
216// =============================================================================
217// Message Types (used by TUI)
218// =============================================================================
219
/// Message role
///
/// Serialized lowercase (`"system"`, `"user"`, …); the `Display` impl below
/// produces the same strings.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}
231
232impl std::fmt::Display for Role {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        match self {
235            Role::System => write!(f, "system"),
236            Role::Developer => write!(f, "developer"),
237            Role::User => write!(f, "user"),
238            Role::Assistant => write!(f, "assistant"),
239            Role::Tool => write!(f, "tool"),
240        }
241    }
242}
243
/// Model info for tracking which model generated a message
///
/// Attached to assistant messages so they record which provider/model
/// produced them.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (e.g., "anthropic", "openai")
    pub provider: String,
    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4")
    pub id: String,
}
252
/// Chat message
///
/// OpenAI-style chat message plus extended session-tracking fields. All
/// optional fields are omitted from JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    /// Who authored this message
    pub role: Role,
    /// Text or multi-part content; may be `None` (e.g. tool-call-only messages)
    pub content: Option<MessageContent>,
    /// Optional participant name
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool messages: id of the tool call this message responds to
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage reported for this message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    // === Extended fields for session tracking ===
    /// Unique message identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Model that generated this message (for assistant messages)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Cost in dollars for this message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    /// Why the model stopped: "stop", "tool_calls", "length", "error"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Unix timestamp (ms) when message was created/sent
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Unix timestamp (ms) when assistant finished generating
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Plugin extensibility - unstructured metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
290
291impl ChatMessage {
292    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
293        messages
294            .iter()
295            .rev()
296            .find(|message| message.role != Role::User && message.role != Role::Tool)
297    }
298
299    pub fn to_xml(&self) -> String {
300        match &self.content {
301            Some(MessageContent::String(s)) => {
302                format!("<message role=\"{}\">{}</message>", self.role, s)
303            }
304            Some(MessageContent::Array(parts)) => parts
305                .iter()
306                .map(|part| {
307                    format!(
308                        "<message role=\"{}\" type=\"{}\">{}</message>",
309                        self.role,
310                        part.r#type,
311                        part.text.clone().unwrap_or_default()
312                    )
313                })
314                .collect::<Vec<String>>()
315                .join("\n"),
316            None => String::new(),
317        }
318    }
319}
320
/// Message content (string or array of parts)
///
/// `untagged`: deserializes from either a bare JSON string or an array of
/// content-part objects, matching the OpenAI content field shape.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
328
329impl MessageContent {
330    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
331        match self {
332            MessageContent::String(s) => MessageContent::String(format!(
333                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
334            )),
335            MessageContent::Array(parts) => MessageContent::Array(
336                std::iter::once(ContentPart {
337                    r#type: "text".to_string(),
338                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
339                    image_url: None,
340                })
341                .chain(parts.iter().cloned())
342                .collect(),
343            ),
344        }
345    }
346
347    /// All indices from rfind()/find() of ASCII XML tags on same string
348    #[allow(clippy::string_slice)]
349    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
350        match self {
351            MessageContent::String(s) => s
352                .rfind("<checkpoint_id>")
353                .and_then(|start| {
354                    s[start..]
355                        .find("</checkpoint_id>")
356                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
357                })
358                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
359            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
360                part.text.as_deref().and_then(|text| {
361                    text.rfind("<checkpoint_id>")
362                        .and_then(|start| {
363                            text[start..]
364                                .find("</checkpoint_id>")
365                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
366                        })
367                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
368                })
369            }),
370        }
371    }
372}
373
374impl std::fmt::Display for MessageContent {
375    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
376        match self {
377            MessageContent::String(s) => write!(f, "{s}"),
378            MessageContent::Array(parts) => {
379                let text_parts: Vec<String> =
380                    parts.iter().filter_map(|part| part.text.clone()).collect();
381                write!(f, "{}", text_parts.join("\n"))
382            }
383        }
384    }
385}
386
387impl Default for MessageContent {
388    fn default() -> Self {
389        MessageContent::String(String::new())
390    }
391}
392
/// Content part (text or image)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part kind — "text" or "image_url" as produced by the conversions below
    pub r#type: String,
    /// Text payload (set for text parts)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image payload (set for image parts)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
402
/// Image URL with optional detail level
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image URL; conversions in this file build `data:<media>;base64,<data>` URIs
    pub url: String,
    /// Optional detail level hint
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
410
411// =============================================================================
412// Tool Types (used by TUI)
413// =============================================================================
414
/// Tool definition
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind (e.g. "function")
    pub r#type: String,
    /// The function the model may call
    pub function: FunctionDefinition,
}
421
/// Function definition for tools
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    /// Function name exposed to the model
    pub name: String,
    /// Human-readable description of what the function does
    pub description: Option<String>,
    /// Schema describing the function's arguments (maps to `LLMTool::input_schema`)
    pub parameters: serde_json::Value,
}
429
430impl From<Tool> for LLMTool {
431    fn from(tool: Tool) -> Self {
432        LLMTool {
433            name: tool.function.name,
434            description: tool.function.description.unwrap_or_default(),
435            input_schema: tool.function.parameters,
436        }
437    }
438}
439
/// Tool choice configuration
///
/// Serializes as the strings `"auto"`/`"required"` or as a specific-function
/// object (see the manual Serialize/Deserialize impls below).
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
447
448impl Serialize for ToolChoice {
449    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
450    where
451        S: serde::Serializer,
452    {
453        match self {
454            ToolChoice::Auto => serializer.serialize_str("auto"),
455            ToolChoice::Required => serializer.serialize_str("required"),
456            ToolChoice::Object(obj) => obj.serialize(serializer),
457        }
458    }
459}
460
impl<'de> Deserialize<'de> for ToolChoice {
    /// Accepts either the strings "auto"/"required" or a tool-choice object,
    /// mirroring the Serialize impl above. Uses `deserialize_any`, so this
    /// requires a self-describing format (e.g. JSON).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            // String form: only "auto" and "required" are valid.
            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            // Object form: delegate to ToolChoiceObject's derived Deserialize.
            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}
503
/// Object form of tool choice: force the model to call a specific function
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Tool kind (e.g. "function")
    pub r#type: String,
    /// The function the model must call
    pub function: FunctionChoice,
}
509
/// Named function selector used by [`ToolChoiceObject`]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    /// Name of the function to call
    pub name: String,
}
514
/// Tool call from assistant
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Provider-assigned call id; echoed back via `ChatMessage::tool_call_id`
    pub id: String,
    /// Call kind (conversions in this file use "function")
    pub r#type: String,
    /// Function name and serialized arguments
    pub function: FunctionCall,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
525
/// Function call details
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to invoke
    pub name: String,
    /// Arguments as a serialized JSON string (not a parsed Value)
    pub arguments: String,
}
532
/// Tool call result status
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    /// Tool ran and produced a result
    Success,
    /// Tool failed
    Error,
    /// Tool run was cancelled before completion
    Cancelled,
}
540
/// Tool call result
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The originating call this result answers
    pub call: ToolCall,
    /// Tool output as text
    pub result: String,
    /// Whether the call succeeded, failed, or was cancelled
    pub status: ToolCallResultStatus,
}
548
/// Streaming progress info for a tool call being generated by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Tool name (may be empty if not yet streamed)
    pub name: String,
    /// Estimated token count of arguments streamed so far (~chars/4)
    pub args_tokens: usize,
    /// Optional description extracted from arguments (best-effort, may be None if JSON incomplete)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
560
/// Tool call result progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    /// Identifier correlating progress updates to one tool call
    pub id: Uuid,
    /// Free-form progress text
    pub message: String,
    /// Type of progress update for specialized handling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Structured task updates for task wait progress
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Overall progress percentage (0.0 - 100.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}
576
/// Type of progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    /// Command output streaming
    CommandOutput,
    /// Task wait progress with structured updates
    TaskWait,
    /// Generic progress
    Generic,
}
587
/// Structured task status update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    /// Identifier of the task being reported on
    pub task_id: String,
    /// Current task status as free-form text
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// How long the task has been running, in seconds
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    /// Truncated preview of the task's output
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Whether this is a target task being waited on
    #[serde(default)]
    pub is_target: bool,
    /// Pause information for paused subagent tasks
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}
606
/// Pause information for subagent tasks awaiting approval
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// The agent's message before pausing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Pending tool calls awaiting approval
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
617
618// =============================================================================
619// Chat Completion Types (used by TUI)
620// =============================================================================
621
/// Chat completion request
///
/// Mirrors the OpenAI chat-completions request body; all optional parameters
/// are omitted from JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    /// Model id string (see [`OpenAIModel`])
    pub model: String,
    /// Conversation history, oldest first
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    /// Request streamed (SSE) responses when `Some(true)`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Tools the model may call
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra request context (scratchpad) — NOTE(review): appears to be a
    /// non-standard extension field; confirm the target endpoint accepts it
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
660
661impl ChatCompletionRequest {
662    pub fn new(
663        model: String,
664        messages: Vec<ChatMessage>,
665        tools: Option<Vec<Tool>>,
666        stream: Option<bool>,
667    ) -> Self {
668        Self {
669            model,
670            messages,
671            frequency_penalty: None,
672            logit_bias: None,
673            logprobs: None,
674            max_tokens: None,
675            n: None,
676            presence_penalty: None,
677            response_format: None,
678            seed: None,
679            stop: None,
680            stream,
681            temperature: None,
682            top_p: None,
683            tools,
684            tool_choice: None,
685            user: None,
686            context: None,
687        }
688    }
689}
690
/// Extra request context attached to a chat completion request
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    /// Free-form scratchpad state carried with the request
    pub scratchpad: Option<Value>,
}
695
/// Desired response format (e.g. JSON mode)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    /// Format kind identifier
    pub r#type: String,
}
700
/// Stop sequence(s): a single string or a list of strings (`untagged`)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
707
/// Chat completion response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    /// Response id assigned by the provider
    pub id: String,
    /// Object type string
    pub object: String,
    /// Unix timestamp of creation
    pub created: u64,
    /// Model that produced the response
    pub model: String,
    /// One entry per requested completion (`n`)
    pub choices: Vec<ChatCompletionChoice>,
    /// Token usage for the whole request
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    // NOTE(review): unlike the other optional fields this one has no
    // skip_serializing_if, so it serializes as `null` when absent — confirm
    // that is intended.
    pub metadata: Option<serde_json::Value>,
}
721
/// One completion choice within a response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    /// Position of this choice in the response
    pub index: usize,
    /// The generated message
    pub message: ChatMessage,
    /// Token log-probabilities, if requested
    pub logprobs: Option<LogProbs>,
    /// Why generation stopped for this choice
    pub finish_reason: FinishReason,
}
729
/// Why the model stopped generating (serialized snake_case:
/// "stop", "length", "content_filter", "tool_calls")
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
738
/// Log-probability info for a choice
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    /// Per-token log-probabilities for the generated content
    pub content: Option<Vec<LogProbContent>>,
}
743
/// Log-probability of one generated token, with top alternatives
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    /// The generated token text
    pub token: String,
    /// Natural-log probability of the token
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when available
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternative tokens at this position
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
751
/// One alternative token candidate and its log-probability
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    /// Candidate token text
    pub token: String,
    /// Natural-log probability of the candidate
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when available
    pub bytes: Option<Vec<u8>>,
}
758
759// =============================================================================
760// Streaming Types
761// =============================================================================
762
/// One streaming chunk of a chat completion response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    /// Response id (shared across all chunks of one completion)
    pub id: String,
    /// Object type string
    pub object: String,
    /// Unix timestamp of creation
    pub created: u64,
    /// Model producing the stream
    pub model: String,
    /// Incremental deltas, one per choice
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Usage totals; typically only present on the final chunk
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
773
/// One choice's incremental update within a streaming chunk
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    /// Position of this choice in the response
    pub index: usize,
    /// Partial message content for this chunk
    pub delta: ChatMessageDelta,
    /// Set on the final chunk for this choice
    pub finish_reason: Option<FinishReason>,
}
780
/// Incremental message fragment; each field appears only when it changes
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    /// Role, usually only on the first chunk
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    /// New content text appended by this chunk
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Incremental tool-call fragments
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
790
/// Incremental fragment of a tool call being streamed
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    /// Index of the tool call this fragment belongs to
    pub index: usize,
    /// Call id, usually only on the first fragment
    pub id: Option<String>,
    /// Call kind, usually only on the first fragment
    pub r#type: Option<String>,
    /// Partial function name/arguments
    pub function: Option<FunctionCallDelta>,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
801
/// Incremental fragment of a function call (name and/or argument text)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    /// Function name, usually only on the first fragment
    pub name: Option<String>,
    /// Next slice of the serialized JSON arguments
    pub arguments: Option<String>,
}
807
808// =============================================================================
809// Conversions
810// =============================================================================
811
impl From<LLMMessage> for ChatMessage {
    /// Convert a provider-agnostic `LLMMessage` into an OpenAI-style
    /// `ChatMessage`.
    ///
    /// - Text and image items become `ContentPart`s.
    /// - Tool-call items become `tool_calls` entries.
    /// - Tool-result items are flattened into text parts; only the FIRST
    ///   result's `tool_use_id` is kept as `tool_call_id`.
    fn from(llm_message: LLMMessage) -> Self {
        // Unknown role strings fall back to User.
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            "developer" => Role::Developer,
            _ => Role::User,
        };

        let (content, tool_calls, tool_call_id) = match llm_message.content {
            // Plain string content maps 1:1; no tool calls possible.
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();
                let mut tool_result_id: Option<String> = None;

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall {
                            id,
                            name,
                            args,
                            metadata,
                        } => {
                            // args is a JSON Value; OpenAI carries arguments
                            // as a serialized string.
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                                metadata,
                            });
                        }
                        LLMMessageTypedContent::ToolResult {
                            tool_use_id,
                            content,
                        } => {
                            // Keep only the first result id; later results
                            // still contribute their text below.
                            if tool_result_id.is_none() {
                                tool_result_id = Some(tool_use_id);
                            }
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            // Encode the image as a base64 data URI.
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                // Empty collections collapse to None so serde omits them.
                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls, tool_result_id)
            }
        };

        ChatMessage {
            role,
            content,
            name: None,
            tool_calls,
            tool_call_id,
            usage: None,
            // Session-tracking fields (id, model, cost, …) stay unset.
            ..Default::default()
        }
    }
}
911
912impl From<ChatMessage> for LLMMessage {
913    fn from(chat_message: ChatMessage) -> Self {
914        let mut content_parts = Vec::new();
915
916        match chat_message.content {
917            Some(MessageContent::String(s)) => {
918                if !s.is_empty() {
919                    content_parts.push(LLMMessageTypedContent::Text { text: s });
920                }
921            }
922            Some(MessageContent::Array(parts)) => {
923                for part in parts {
924                    if let Some(text) = part.text {
925                        content_parts.push(LLMMessageTypedContent::Text { text });
926                    } else if let Some(image_url) = part.image_url {
927                        let (media_type, data) = if image_url.url.starts_with("data:") {
928                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
929                            if parts.len() == 2 {
930                                let meta = parts[0];
931                                let data = parts[1];
932                                let media_type = meta
933                                    .trim_start_matches("data:")
934                                    .trim_end_matches(";base64")
935                                    .to_string();
936                                (media_type, data.to_string())
937                            } else {
938                                ("image/jpeg".to_string(), image_url.url)
939                            }
940                        } else {
941                            ("image/jpeg".to_string(), image_url.url)
942                        };
943
944                        content_parts.push(LLMMessageTypedContent::Image {
945                            source: LLMMessageImageSource {
946                                r#type: "base64".to_string(),
947                                media_type,
948                                data,
949                            },
950                        });
951                    }
952                }
953            }
954            None => {}
955        }
956
957        if let Some(tool_calls) = chat_message.tool_calls {
958            for tool_call in tool_calls {
959                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
960                content_parts.push(LLMMessageTypedContent::ToolCall {
961                    id: tool_call.id,
962                    name: tool_call.function.name,
963                    args,
964                    metadata: tool_call.metadata,
965                });
966            }
967        }
968
969        // Handle tool result messages: when role is Tool and tool_call_id is present,
970        // convert the content to a ToolResult content part. This is the generic
971        // intermediate representation - each provider's conversion layer handles
972        // the specifics (e.g., Anthropic converts to user role with tool_result blocks)
973        if chat_message.role == Role::Tool
974            && let Some(tool_call_id) = chat_message.tool_call_id
975        {
976            // Extract content as string for the tool result
977            let content_str = content_parts
978                .iter()
979                .filter_map(|p| match p {
980                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
981                    _ => None,
982                })
983                .collect::<Vec<_>>()
984                .join("\n");
985
986            // Replace content with a single ToolResult
987            content_parts = vec![LLMMessageTypedContent::ToolResult {
988                tool_use_id: tool_call_id,
989                content: content_str,
990            }];
991        }
992
993        LLMMessage {
994            role: chat_message.role.to_string(),
995            content: if content_parts.is_empty() {
996                LLMMessageContent::String(String::new())
997            } else if content_parts.len() == 1 {
998                match &content_parts[0] {
999                    LLMMessageTypedContent::Text { text } => {
1000                        LLMMessageContent::String(text.clone())
1001                    }
1002                    _ => LLMMessageContent::List(content_parts),
1003                }
1004            } else {
1005                LLMMessageContent::List(content_parts)
1006            },
1007        }
1008    }
1009}
1010
1011impl From<GenerationDelta> for ChatMessageDelta {
1012    fn from(delta: GenerationDelta) -> Self {
1013        match delta {
1014            GenerationDelta::Content { content } => ChatMessageDelta {
1015                role: Some(Role::Assistant),
1016                content: Some(content),
1017                tool_calls: None,
1018            },
1019            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1020                role: Some(Role::Assistant),
1021                content: None,
1022                tool_calls: None,
1023            },
1024            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1025                role: Some(Role::Assistant),
1026                content: None,
1027                tool_calls: Some(vec![ToolCallDelta {
1028                    index: tool_use.index,
1029                    id: tool_use.id,
1030                    r#type: Some("function".to_string()),
1031                    function: Some(FunctionCallDelta {
1032                        name: tool_use.name,
1033                        arguments: tool_use.input,
1034                    }),
1035                    metadata: tool_use.metadata,
1036                }]),
1037            },
1038            _ => ChatMessageDelta {
1039                role: Some(Role::Assistant),
1040                content: None,
1041                tool_calls: None,
1042            },
1043        }
1044    }
1045}
1046
#[cfg(test)]
mod tests {
    use super::*;

    /// A basic request should serialize with model, messages, and roles intact.
    #[test]
    fn test_serialize_basic_request() {
        let messages = vec![
            ChatMessage {
                role: Role::System,
                content: Some(MessageContent::String(
                    "You are a helpful assistant.".to_string(),
                )),
                ..Default::default()
            },
            ChatMessage {
                role: Role::User,
                content: Some(MessageContent::String("Hello!".to_string())),
                ..Default::default()
            },
        ];

        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages,
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let serialized = serde_json::to_string(&request).unwrap();
        assert!(serialized.contains("\"model\":\"gpt-4\""));
        assert!(serialized.contains("\"messages\":["));
        assert!(serialized.contains("\"role\":\"system\""));
    }

    /// Plain string content survives the LLMMessage -> ChatMessage round.
    #[test]
    fn test_llm_message_to_chat_message() {
        let source = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let converted: ChatMessage = source.into();
        assert_eq!(converted.role, Role::User);
        if let Some(MessageContent::String(text)) = &converted.content {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected string content");
        }
    }

    /// A ToolResult part must surface its id via `tool_call_id` and its body
    /// as a text content part.
    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let source = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let converted: ChatMessage = source.into();
        assert_eq!(converted.role, Role::Tool);
        assert_eq!(converted.tool_call_id.as_deref(), Some("toolu_01Abc123"));

        let expected = ContentPart {
            r#type: "text".to_string(),
            text: Some("Tool execution result".to_string()),
            image_url: None,
        };
        assert_eq!(converted.content, Some(MessageContent::Array(vec![expected])));
    }

    /// Test that Tool role messages with tool_call_id are converted to
    /// ToolResult content. This is critical for Anthropic compatibility - the
    /// provider layer converts role="tool" to role="user" with tool_result
    /// content blocks.
    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        let tool_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            tool_call_id: Some("toolu_01Abc123".to_string()),
            ..Default::default()
        };

        let converted = LLMMessage::from(tool_message);

        // Role should be preserved as "tool" - provider layer handles conversion
        assert_eq!(converted.role, "tool");

        // Content should be a ToolResult with the tool_call_id
        let LLMMessageContent::List(parts) = &converted.content else {
            panic!(
                "Expected List content with ToolResult, got {:?}",
                converted.content
            );
        };
        assert_eq!(parts.len(), 1, "Should have exactly one content part");
        match &parts[0] {
            LLMMessageTypedContent::ToolResult {
                tool_use_id,
                content,
            } => {
                assert_eq!(tool_use_id, "toolu_01Abc123");
                assert_eq!(content, "Tool execution result");
            }
            _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
        }
    }

    /// A tool message with no content still yields a ToolResult, with an
    /// empty body.
    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        let tool_message = ChatMessage {
            role: Role::Tool,
            content: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            ..Default::default()
        };

        let converted = LLMMessage::from(tool_message);

        assert_eq!(converted.role, "tool");
        let LLMMessageContent::List(parts) = &converted.content else {
            panic!("Expected List content with ToolResult");
        };
        assert_eq!(parts.len(), 1);
        match &parts[0] {
            LLMMessageTypedContent::ToolResult {
                tool_use_id,
                content,
            } => {
                assert_eq!(tool_use_id, "toolu_02Xyz789");
                assert_eq!(content, ""); // Empty content
            }
            _ => panic!("Expected ToolResult content part"),
        }
    }

    /// Assistant messages carrying tool_calls should produce a text part
    /// followed by a ToolCall part with parsed JSON arguments.
    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        let assistant_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            ..Default::default()
        };

        let converted = LLMMessage::from(assistant_message);

        assert_eq!(converted.role, "assistant");
        let LLMMessageContent::List(parts) = &converted.content else {
            panic!("Expected List content");
        };
        assert_eq!(parts.len(), 2, "Should have text and tool call");

        // First part should be text
        if let LLMMessageTypedContent::Text { text } = &parts[0] {
            assert_eq!(text, "I'll help you with that.");
        } else {
            panic!("Expected Text content part first");
        }

        // Second part should be tool call
        if let LLMMessageTypedContent::ToolCall { id, name, args, .. } = &parts[1] {
            assert_eq!(id, "call_abc123");
            assert_eq!(name, "get_weather");
            assert_eq!(args["location"], "Paris");
        } else {
            panic!("Expected ToolCall content part second");
        }
    }
}