Skip to main content

stakpak_shared/models/integrations/
openai.rs

1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
/// Configuration for the OpenAI provider.
///
/// Both fields are optional; `OpenAIConfig::default()` yields a config with
/// neither an endpoint override nor an API key.
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Optional API base URL override (`None` when not configured; how the
    /// fallback endpoint is chosen is up to the consumer — not visible here).
    pub api_endpoint: Option<String>,
    /// Optional API key / bearer credential.
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32    pub const OPENAI_CODEX_BASE_URL: &'static str = "https://chatgpt.com/backend-api/codex";
33    const OPENAI_AUTH_CLAIM: &'static str = "https://api.openai.com/auth";
34
35    /// Create config with API key
36    pub fn with_api_key(api_key: impl Into<String>) -> Self {
37        Self {
38            api_key: Some(api_key.into()),
39            api_endpoint: None,
40        }
41    }
42
43    /// Decode an OpenAI access token and extract the ChatGPT account ID.
44    ///
45    /// This intentionally reads the JWT payload without signature verification.
46    /// The claim is only used for request routing/header construction; OpenAI's
47    /// servers still validate the bearer token on use.
48    pub fn extract_chatgpt_account_id(access_token: &str) -> Option<String> {
49        let claims = crate::jwt::decode_jwt_payload_unverified(access_token)?;
50        let auth_claim = claims.get(Self::OPENAI_AUTH_CLAIM)?;
51
52        match auth_claim {
53            Value::Object(map) => map
54                .get("chatgpt_account_id")
55                .and_then(Value::as_str)
56                .map(ToString::to_string),
57            Value::String(raw_json) => {
58                serde_json::from_str::<Value>(raw_json)
59                    .ok()
60                    .and_then(|value| {
61                        value
62                            .get("chatgpt_account_id")
63                            .and_then(Value::as_str)
64                            .map(ToString::to_string)
65                    })
66            }
67            _ => None,
68        }
69    }
70}
71
72// =============================================================================
73// Model Definitions
74// =============================================================================
75
/// OpenAI model identifiers.
///
/// Each known variant (de)serializes as its dated model id through
/// `#[serde(rename)]` (e.g. `"gpt-5-2025-08-07"`). `Custom` has no rename and
/// uses serde's default externally-tagged form (`{"Custom": "<id>"}`), so a bare
/// unknown model string does NOT deserialize into `Custom`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    // GPT-5 family; `GPT5` is the `Default` variant.
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// Any other model id, carried verbatim.
    Custom(String),
}
97
98impl OpenAIModel {
99    pub fn from_string(s: &str) -> Result<Self, String> {
100        serde_json::from_value(serde_json::Value::String(s.to_string()))
101            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
102    }
103
104    /// Default smart model for OpenAI
105    pub const DEFAULT_SMART_MODEL: OpenAIModel = OpenAIModel::GPT5;
106
107    /// Default eco model for OpenAI
108    pub const DEFAULT_ECO_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
109
110    /// Default recovery model for OpenAI
111    pub const DEFAULT_RECOVERY_MODEL: OpenAIModel = OpenAIModel::GPT5Mini;
112
113    /// Get default smart model as string
114    pub fn default_smart_model() -> String {
115        Self::DEFAULT_SMART_MODEL.to_string()
116    }
117
118    /// Get default eco model as string
119    pub fn default_eco_model() -> String {
120        Self::DEFAULT_ECO_MODEL.to_string()
121    }
122
123    /// Get default recovery model as string
124    pub fn default_recovery_model() -> String {
125        Self::DEFAULT_RECOVERY_MODEL.to_string()
126    }
127}
128
129impl ContextAware for OpenAIModel {
130    fn context_info(&self) -> ModelContextInfo {
131        let model_name = self.to_string();
132
133        if model_name.starts_with("o3") {
134            return ModelContextInfo {
135                max_tokens: 200_000,
136                pricing_tiers: vec![ContextPricingTier {
137                    label: "Standard".to_string(),
138                    input_cost_per_million: 2.0,
139                    output_cost_per_million: 8.0,
140                    upper_bound: None,
141                }],
142                approach_warning_threshold: 0.8,
143            };
144        }
145
146        if model_name.starts_with("o4-mini") {
147            return ModelContextInfo {
148                max_tokens: 200_000,
149                pricing_tiers: vec![ContextPricingTier {
150                    label: "Standard".to_string(),
151                    input_cost_per_million: 1.10,
152                    output_cost_per_million: 4.40,
153                    upper_bound: None,
154                }],
155                approach_warning_threshold: 0.8,
156            };
157        }
158
159        if model_name.starts_with("gpt-5-mini") {
160            return ModelContextInfo {
161                max_tokens: 400_000,
162                pricing_tiers: vec![ContextPricingTier {
163                    label: "Standard".to_string(),
164                    input_cost_per_million: 0.25,
165                    output_cost_per_million: 2.0,
166                    upper_bound: None,
167                }],
168                approach_warning_threshold: 0.8,
169            };
170        }
171
172        if model_name.starts_with("gpt-5-nano") {
173            return ModelContextInfo {
174                max_tokens: 400_000,
175                pricing_tiers: vec![ContextPricingTier {
176                    label: "Standard".to_string(),
177                    input_cost_per_million: 0.05,
178                    output_cost_per_million: 0.40,
179                    upper_bound: None,
180                }],
181                approach_warning_threshold: 0.8,
182            };
183        }
184
185        if model_name.starts_with("gpt-5") {
186            return ModelContextInfo {
187                max_tokens: 400_000,
188                pricing_tiers: vec![ContextPricingTier {
189                    label: "Standard".to_string(),
190                    input_cost_per_million: 1.25,
191                    output_cost_per_million: 10.0,
192                    upper_bound: None,
193                }],
194                approach_warning_threshold: 0.8,
195            };
196        }
197
198        ModelContextInfo::default()
199    }
200
201    fn model_name(&self) -> String {
202        match self {
203            OpenAIModel::O3 => "O3".to_string(),
204            OpenAIModel::O4Mini => "O4-mini".to_string(),
205            OpenAIModel::GPT5 => "GPT-5".to_string(),
206            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
207            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
208            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
209            OpenAIModel::Custom(name) => format!("Custom ({})", name),
210        }
211    }
212}
213
214impl std::fmt::Display for OpenAIModel {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        match self {
217            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
218            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
219            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
220            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
221            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
222            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
223            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
224        }
225    }
226}
227
228// =============================================================================
229// Message Types (used by TUI)
230// =============================================================================
231
/// Role of a chat message author.
///
/// Serialized in lowercase (`"system"`, `"developer"`, `"user"`,
/// `"assistant"`, `"tool"`); `Role::default()` is `Assistant`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}
243
244impl std::fmt::Display for Role {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        match self {
247            Role::System => write!(f, "system"),
248            Role::Developer => write!(f, "developer"),
249            Role::User => write!(f, "user"),
250            Role::Assistant => write!(f, "assistant"),
251            Role::Tool => write!(f, "tool"),
252        }
253    }
254}
255
/// Identifies which provider/model produced a message.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (e.g., "anthropic", "openai")
    pub provider: String,
    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4")
    pub id: String,
}
264
/// A single chat message.
///
/// The first group of fields follows OpenAI chat-message field names. The
/// second group holds local session-tracking extensions, all omitted from the
/// serialized form when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    /// Author role; defaults to `Role::Assistant`.
    pub role: Role,
    /// Message body. This field has no `skip_serializing_if`, so `None`
    /// serializes as `"content": null`.
    pub content: Option<MessageContent>,
    /// Optional participant name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by this message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Id of the tool call this message answers (tool-result messages).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage attributed to this message, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    // === Extended fields for session tracking ===
    /// Unique message identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Model that generated this message (for assistant messages)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Cost in dollars for this message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    /// Why the model stopped: "stop", "tool_calls", "length", "error"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Unix timestamp (ms) when message was created/sent
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Unix timestamp (ms) when assistant finished generating
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Plugin extensibility - unstructured metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
302
303impl ChatMessage {
304    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
305        messages
306            .iter()
307            .rev()
308            .find(|message| message.role != Role::User && message.role != Role::Tool)
309    }
310
311    pub fn to_xml(&self) -> String {
312        match &self.content {
313            Some(MessageContent::String(s)) => {
314                format!("<message role=\"{}\">{}</message>", self.role, s)
315            }
316            Some(MessageContent::Array(parts)) => parts
317                .iter()
318                .map(|part| {
319                    format!(
320                        "<message role=\"{}\" type=\"{}\">{}</message>",
321                        self.role,
322                        part.r#type,
323                        part.text.clone().unwrap_or_default()
324                    )
325                })
326                .collect::<Vec<String>>()
327                .join("\n"),
328            None => String::new(),
329        }
330    }
331}
332
/// Message content: either a plain string or an array of typed parts.
///
/// `#[serde(untagged)]`: serialized as a bare JSON string or a JSON array.
/// On deserialization serde tries `String` first, then `Array`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
340
341impl MessageContent {
342    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
343        match self {
344            MessageContent::String(s) => MessageContent::String(format!(
345                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
346            )),
347            MessageContent::Array(parts) => MessageContent::Array(
348                std::iter::once(ContentPart {
349                    r#type: "text".to_string(),
350                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
351                    image_url: None,
352                })
353                .chain(parts.iter().cloned())
354                .collect(),
355            ),
356        }
357    }
358
359    /// All indices from rfind()/find() of ASCII XML tags on same string
360    #[allow(clippy::string_slice)]
361    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
362        match self {
363            MessageContent::String(s) => s
364                .rfind("<checkpoint_id>")
365                .and_then(|start| {
366                    s[start..]
367                        .find("</checkpoint_id>")
368                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
369                })
370                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
371            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
372                part.text.as_deref().and_then(|text| {
373                    text.rfind("<checkpoint_id>")
374                        .and_then(|start| {
375                            text[start..]
376                                .find("</checkpoint_id>")
377                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
378                        })
379                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
380                })
381            }),
382        }
383    }
384}
385
386impl std::fmt::Display for MessageContent {
387    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388        match self {
389            MessageContent::String(s) => write!(f, "{s}"),
390            MessageContent::Array(parts) => {
391                let text_parts: Vec<String> =
392                    parts.iter().filter_map(|part| part.text.clone()).collect();
393                write!(f, "{}", text_parts.join("\n"))
394            }
395        }
396    }
397}
398
399impl Default for MessageContent {
400    fn default() -> Self {
401        MessageContent::String(String::new())
402    }
403}
404
/// One part of array-form message content (text or image).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part kind, serialized as `"type"`; this file produces `"text"` and `"image_url"`.
    pub r#type: String,
    /// Text payload (omitted from JSON when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image payload (omitted from JSON when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
414
/// Image reference with an optional detail level.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image URL; may be a `data:<media_type>;base64,<data>` URL.
    pub url: String,
    /// Optional detail hint (free-form string; omitted when `None`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
422
423// =============================================================================
424// Tool Types (used by TUI)
425// =============================================================================
426
/// Tool definition offered to the model.
///
/// Serializes as `{"type": …, "function": {…}}`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind (serialized as `"type"`).
    pub r#type: String,
    pub function: FunctionDefinition,
}
433
/// Function signature exposed as a tool.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    /// No `skip_serializing_if`: `None` serializes as `null`.
    pub description: Option<String>,
    /// Parameter schema as arbitrary JSON (presumably a JSON Schema object).
    pub parameters: serde_json::Value,
}
441
442impl From<Tool> for LLMTool {
443    fn from(tool: Tool) -> Self {
444        LLMTool {
445            name: tool.function.name,
446            description: tool.function.description.unwrap_or_default(),
447            input_schema: tool.function.parameters,
448        }
449    }
450}
451
/// Tool-choice setting for a request.
///
/// Uses hand-written serde impls. `Auto` and `Required` map to the strings
/// `"auto"` and `"required"`, and `Object` maps to a JSON object. Any other
/// string is rejected on deserialization.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
459
460impl Serialize for ToolChoice {
461    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
462    where
463        S: serde::Serializer,
464    {
465        match self {
466            ToolChoice::Auto => serializer.serialize_str("auto"),
467            ToolChoice::Required => serializer.serialize_str("required"),
468            ToolChoice::Object(obj) => obj.serialize(serializer),
469        }
470    }
471}
472
473impl<'de> Deserialize<'de> for ToolChoice {
474    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
475    where
476        D: serde::Deserializer<'de>,
477    {
478        struct ToolChoiceVisitor;
479
480        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
481            type Value = ToolChoice;
482
483            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
484                formatter.write_str("string or object")
485            }
486
487            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
488            where
489                E: serde::de::Error,
490            {
491                match value {
492                    "auto" => Ok(ToolChoice::Auto),
493                    "required" => Ok(ToolChoice::Required),
494                    _ => Err(serde::de::Error::unknown_variant(
495                        value,
496                        &["auto", "required"],
497                    )),
498                }
499            }
500
501            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
502            where
503                M: serde::de::MapAccess<'de>,
504            {
505                let obj = ToolChoiceObject::deserialize(
506                    serde::de::value::MapAccessDeserializer::new(map),
507                )?;
508                Ok(ToolChoice::Object(obj))
509            }
510        }
511
512        deserializer.deserialize_any(ToolChoiceVisitor)
513    }
514}
515
/// Object form of a tool choice: `{"type": …, "function": {"name": …}}`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    /// Choice kind (serialized as `"type"`).
    pub r#type: String,
    pub function: FunctionChoice,
}

/// Names the specific function a `ToolChoiceObject` selects.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
526
/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call id used to correlate the tool's result.
    pub id: String,
    /// Call kind (serialized as `"type"`).
    pub r#type: String,
    pub function: FunctionCall,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
537
/// Function name plus its arguments for a `ToolCall`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Arguments kept as a raw string (typically JSON text; not parsed here).
    pub arguments: String,
}
544
/// Outcome of a tool call.
///
/// No serde renames: variants serialize by name (`"Success"`, `"Error"`, `"Cancelled"`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
552
/// A tool call paired with its textual result and status.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The call that produced this result.
    pub call: ToolCall,
    /// Result text.
    pub result: String,
    pub status: ToolCallResultStatus,
}
560
/// Streaming progress info for a tool call being generated by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Tool name (may be empty if not yet streamed)
    pub name: String,
    /// Estimated token count of arguments streamed so far (~chars/4)
    pub args_tokens: usize,
    /// Optional description extracted from arguments (best-effort, may be None if JSON incomplete)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
572
/// Incremental progress update for a running tool call.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    /// Identifier of the update/stream (which entity it refers to is not visible here).
    pub id: Uuid,
    /// Human-readable progress text.
    pub message: String,
    /// Type of progress update for specialized handling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Structured task updates for task wait progress
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Overall progress percentage (0.0 - 100.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}
588
/// Type of progress update (serialized by variant name).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    /// Command output streaming
    CommandOutput,
    /// Task wait progress with structured updates
    TaskWait,
    /// Generic progress
    Generic,
}
599
/// Structured status update for a single task.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Free-form status string (allowed values are not constrained here).
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    /// Optional excerpt of the task's output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Whether this is a target task being waited on.
    /// `#[serde(default)]`: a missing field deserializes as `false`.
    #[serde(default)]
    pub is_target: bool,
    /// Pause information for paused subagent tasks
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}
618
/// Pause information for subagent tasks awaiting approval
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// The agent's message before pausing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Pending tool calls awaiting approval
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
629
630pub use crate::models::tools::ask_user::{
631    AskUserAnswer, AskUserOption, AskUserQuestion, AskUserRequest, AskUserResult,
632};
633
/// Chat completion request body.
///
/// Field names follow the OpenAI Chat Completions request (confirm against
/// the API reference). Every optional field is omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Extra context carried with the request (appears to be a local
    /// extension rather than an OpenAI field — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
672
673impl ChatCompletionRequest {
674    pub fn new(
675        model: String,
676        messages: Vec<ChatMessage>,
677        tools: Option<Vec<Tool>>,
678        stream: Option<bool>,
679    ) -> Self {
680        Self {
681            model,
682            messages,
683            frequency_penalty: None,
684            logit_bias: None,
685            logprobs: None,
686            max_tokens: None,
687            n: None,
688            presence_penalty: None,
689            response_format: None,
690            seed: None,
691            stop: None,
692            stream,
693            temperature: None,
694            top_p: None,
695            tools,
696            tool_choice: None,
697            user: None,
698            context: None,
699        }
700    }
701}
702
/// Extra request context. `scratchpad` has no `skip_serializing_if`, so
/// `None` serializes as `null`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    pub scratchpad: Option<Value>,
}

/// Requested response format, serialized as `{"type": …}`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}

/// Stop sequence(s): a single string or an array of strings (`#[serde(untagged)]`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
719
/// Non-streaming chat completion response.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Creation time as an integer (presumably Unix seconds — confirm).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    /// Required: a response without `usage` fails to deserialize.
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Unlike `system_fingerprint`, this has no `skip_serializing_if`, so
    /// `None` serializes as `null`.
    pub metadata: Option<serde_json::Value>,
}
733
/// One candidate completion in a `ChatCompletionResponse`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}
741
/// Why generation stopped.
///
/// Serialized in snake_case (`"stop"`, `"length"`, `"content_filter"`,
/// `"tool_calls"`). Any other value fails to deserialize.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
750
/// Log-probability information attached to a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}

/// Log-probability of one generated token, with optional alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    /// Raw bytes of the token, when provided.
    pub bytes: Option<Vec<u8>>,
    /// Most likely alternative tokens at this position, when requested.
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}

/// An alternative token and its log-probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
770
771// =============================================================================
772// Streaming Types
773// =============================================================================
774
/// One chunk of a streamed chat completion.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Optional here (unlike the non-streaming response); chunks may omit it.
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
785
/// Per-choice delta within a stream chunk.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    /// `None` until the chunk that ends this choice.
    pub finish_reason: Option<FinishReason>,
}
792
/// Incremental message fragment; absent fields are omitted from the JSON.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    /// Text fragment to append to the message content.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
802
/// Streaming fragment of a tool call. `index` identifies which call the
/// fragment belongs to; the other fields may arrive piecemeal.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

/// Partial function name and/or argument text of a streamed tool call.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
819
820// =============================================================================
821// Conversions
822// =============================================================================
823
824impl From<LLMMessage> for ChatMessage {
825    fn from(llm_message: LLMMessage) -> Self {
826        let role = match llm_message.role.as_str() {
827            "system" => Role::System,
828            "user" => Role::User,
829            "assistant" => Role::Assistant,
830            "tool" => Role::Tool,
831            "developer" => Role::Developer,
832            _ => Role::User,
833        };
834
835        let (content, tool_calls, tool_call_id) = match llm_message.content {
836            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
837            LLMMessageContent::List(items) => {
838                let mut text_parts = Vec::new();
839                let mut tool_call_parts = Vec::new();
840                let mut tool_result_id: Option<String> = None;
841
842                for item in items {
843                    match item {
844                        LLMMessageTypedContent::Text { text } => {
845                            text_parts.push(ContentPart {
846                                r#type: "text".to_string(),
847                                text: Some(text),
848                                image_url: None,
849                            });
850                        }
851                        LLMMessageTypedContent::ToolCall {
852                            id,
853                            name,
854                            args,
855                            metadata,
856                        } => {
857                            tool_call_parts.push(ToolCall {
858                                id,
859                                r#type: "function".to_string(),
860                                function: FunctionCall {
861                                    name,
862                                    arguments: args.to_string(),
863                                },
864                                metadata,
865                            });
866                        }
867                        LLMMessageTypedContent::ToolResult {
868                            tool_use_id,
869                            content,
870                        } => {
871                            if tool_result_id.is_none() {
872                                tool_result_id = Some(tool_use_id);
873                            }
874                            text_parts.push(ContentPart {
875                                r#type: "text".to_string(),
876                                text: Some(content),
877                                image_url: None,
878                            });
879                        }
880                        LLMMessageTypedContent::Image { source } => {
881                            text_parts.push(ContentPart {
882                                r#type: "image_url".to_string(),
883                                text: None,
884                                image_url: Some(ImageUrl {
885                                    url: format!(
886                                        "data:{};base64,{}",
887                                        source.media_type, source.data
888                                    ),
889                                    detail: None,
890                                }),
891                            });
892                        }
893                    }
894                }
895
896                let content = if !text_parts.is_empty() {
897                    Some(MessageContent::Array(text_parts))
898                } else {
899                    None
900                };
901
902                let tool_calls = if !tool_call_parts.is_empty() {
903                    Some(tool_call_parts)
904                } else {
905                    None
906                };
907
908                (content, tool_calls, tool_result_id)
909            }
910        };
911
912        ChatMessage {
913            role,
914            content,
915            name: None,
916            tool_calls,
917            tool_call_id,
918            usage: None,
919            ..Default::default()
920        }
921    }
922}
923
924impl From<ChatMessage> for LLMMessage {
925    fn from(chat_message: ChatMessage) -> Self {
926        let mut content_parts = Vec::new();
927
928        match chat_message.content {
929            Some(MessageContent::String(s)) => {
930                if !s.is_empty() {
931                    content_parts.push(LLMMessageTypedContent::Text { text: s });
932                }
933            }
934            Some(MessageContent::Array(parts)) => {
935                for part in parts {
936                    if let Some(text) = part.text {
937                        content_parts.push(LLMMessageTypedContent::Text { text });
938                    } else if let Some(image_url) = part.image_url {
939                        let (media_type, data) = if image_url.url.starts_with("data:") {
940                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
941                            if parts.len() == 2 {
942                                let meta = parts[0];
943                                let data = parts[1];
944                                let media_type = meta
945                                    .trim_start_matches("data:")
946                                    .trim_end_matches(";base64")
947                                    .to_string();
948                                (media_type, data.to_string())
949                            } else {
950                                ("image/jpeg".to_string(), image_url.url)
951                            }
952                        } else {
953                            ("image/jpeg".to_string(), image_url.url)
954                        };
955
956                        content_parts.push(LLMMessageTypedContent::Image {
957                            source: LLMMessageImageSource {
958                                r#type: "base64".to_string(),
959                                media_type,
960                                data,
961                            },
962                        });
963                    }
964                }
965            }
966            None => {}
967        }
968
969        if let Some(tool_calls) = chat_message.tool_calls {
970            for tool_call in tool_calls {
971                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
972                content_parts.push(LLMMessageTypedContent::ToolCall {
973                    id: tool_call.id,
974                    name: tool_call.function.name,
975                    args,
976                    metadata: tool_call.metadata,
977                });
978            }
979        }
980
981        // Handle tool result messages: when role is Tool and tool_call_id is present,
982        // convert the content to a ToolResult content part. This is the generic
983        // intermediate representation - each provider's conversion layer handles
984        // the specifics (e.g., Anthropic converts to user role with tool_result blocks)
985        if chat_message.role == Role::Tool
986            && let Some(tool_call_id) = chat_message.tool_call_id
987        {
988            // Extract content as string for the tool result
989            let content_str = content_parts
990                .iter()
991                .filter_map(|p| match p {
992                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
993                    _ => None,
994                })
995                .collect::<Vec<_>>()
996                .join("\n");
997
998            // Replace content with a single ToolResult
999            content_parts = vec![LLMMessageTypedContent::ToolResult {
1000                tool_use_id: tool_call_id,
1001                content: content_str,
1002            }];
1003        }
1004
1005        LLMMessage {
1006            role: chat_message.role.to_string(),
1007            content: if content_parts.is_empty() {
1008                LLMMessageContent::String(String::new())
1009            } else if content_parts.len() == 1 {
1010                match &content_parts[0] {
1011                    LLMMessageTypedContent::Text { text } => {
1012                        LLMMessageContent::String(text.clone())
1013                    }
1014                    _ => LLMMessageContent::List(content_parts),
1015                }
1016            } else {
1017                LLMMessageContent::List(content_parts)
1018            },
1019        }
1020    }
1021}
1022
1023impl From<GenerationDelta> for ChatMessageDelta {
1024    fn from(delta: GenerationDelta) -> Self {
1025        match delta {
1026            GenerationDelta::Content { content } => ChatMessageDelta {
1027                role: Some(Role::Assistant),
1028                content: Some(content),
1029                tool_calls: None,
1030            },
1031            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1032                role: Some(Role::Assistant),
1033                content: None,
1034                tool_calls: None,
1035            },
1036            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1037                role: Some(Role::Assistant),
1038                content: None,
1039                tool_calls: Some(vec![ToolCallDelta {
1040                    index: tool_use.index,
1041                    id: tool_use.id,
1042                    r#type: Some("function".to_string()),
1043                    function: Some(FunctionCallDelta {
1044                        name: tool_use.name,
1045                        arguments: tool_use.input,
1046                    }),
1047                    metadata: tool_use.metadata,
1048                }]),
1049            },
1050            _ => ChatMessageDelta {
1051                role: Some(Role::Assistant),
1052                content: None,
1053                tool_calls: None,
1054            },
1055        }
1056    }
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unsigned JWT-shaped token (`header.<payload>.signature`) whose middle
    /// segment is `payload` serialized as JSON and base64url-encoded without padding.
    fn fake_access_token(payload: &serde_json::Value) -> String {
        use base64::Engine;

        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        format!("header.{}.signature", encoded_payload)
    }

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let llm_message = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Tool execution result".to_string()),
                image_url: None,
            }]))
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        // Test that Tool role messages with tool_call_id are converted to ToolResult content
        // This is critical for Anthropic compatibility - the provider layer converts
        // role="tool" to role="user" with tool_result content blocks
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        // Role should be preserved as "tool" - provider layer handles conversion
        assert_eq!(llm_message.role, "tool");

        // Content should be a ToolResult with the tool_call_id
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        // Test tool result with empty content
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        assert_eq!(content, ""); // Empty content
                    }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        // Test that assistant messages with tool_calls are converted correctly
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                // First part should be text
                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                // Second part should be tool call
                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_invalid_tool_arguments_default_to_empty_object() {
        // Arguments that are not valid JSON must not fail the conversion.
        let chat_message = ChatMessage {
            role: Role::Assistant,
            tool_calls: Some(vec![ToolCall {
                id: "call_empty".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "list_files".to_string(),
                    arguments: String::new(),
                },
                metadata: None,
            }]),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly the tool call");
                match &parts[0] {
                    LLMMessageTypedContent::ToolCall { id, args, .. } => {
                        assert_eq!(id, "call_empty");
                        assert_eq!(args, &json!({}));
                    }
                    other => panic!("Expected ToolCall content part, got {:?}", other),
                }
            }
            other => panic!("Expected List content, got {:?}", other),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_single_text_part_collapses_to_string() {
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Only text".to_string()),
                image_url: None,
            }])),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::String(text) => assert_eq!(text, "Only text"),
            other => panic!("Expected String content, got {:?}", other),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_image_data_url() {
        let chat_message = ChatMessage {
            role: Role::User,
            content: Some(MessageContent::Array(vec![ContentPart {
                r#type: "image_url".to_string(),
                text: None,
                image_url: Some(ImageUrl {
                    url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
                    detail: None,
                }),
            }])),
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one image part");
                match &parts[0] {
                    LLMMessageTypedContent::Image { source } => {
                        assert_eq!(source.r#type, "base64");
                        assert_eq!(source.media_type, "image/png");
                        assert_eq!(source.data, "iVBORw0KGgo=");
                    }
                    other => panic!("Expected Image content part, got {:?}", other),
                }
            }
            other => panic!("Expected List content, got {:?}", other),
        }
    }

    #[test]
    fn test_extract_chatgpt_account_id_from_access_token() {
        let claim = json!({
            "chatgpt_account_id": "acct_test_123"
        });
        let payload = json!({
            "https://api.openai.com/auth": claim
        });
        let access_token = fake_access_token(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            Some("acct_test_123".to_string())
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_missing_claim() {
        let payload = json!({
            "sub": "user_123"
        });
        let access_token = fake_access_token(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_token_shape() {
        assert_eq!(OpenAIConfig::extract_chatgpt_account_id("not-a-jwt"), None);
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_claim_json() {
        let payload = json!({
            "https://api.openai.com/auth": "{not-json}"
        });
        let access_token = fake_access_token(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }
}