// stakpak_shared/models/integrations/openai.rs
1//! OpenAI provider configuration and chat message types
2//!
3//! This module contains:
4//! - Configuration types for OpenAI provider
5//! - OpenAI model enums with pricing info
6//! - Chat message types used throughout the TUI
7//! - Tool call types for agent interactions
8//!
9//! Note: Low-level API request/response types are in `libs/ai/src/providers/openai/`.
10
11use crate::models::llm::{
12    GenerationDelta, LLMMessage, LLMMessageContent, LLMMessageImageSource, LLMMessageTypedContent,
13    LLMTokenUsage, LLMTool,
14};
15use crate::models::model_pricing::{ContextAware, ContextPricingTier, ModelContextInfo};
16use serde::{Deserialize, Serialize};
17use serde_json::{Value, json};
18use uuid::Uuid;
19
20// =============================================================================
21// Provider Configuration
22// =============================================================================
23
/// Configuration for OpenAI provider
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct OpenAIConfig {
    /// Optional override for the API endpoint; `None` means the provider default.
    pub api_endpoint: Option<String>,
    /// API key used as the bearer credential; set via [`OpenAIConfig::with_api_key`].
    pub api_key: Option<String>,
}
30
31impl OpenAIConfig {
32    pub const OPENAI_CODEX_BASE_URL: &'static str = "https://chatgpt.com/backend-api/codex";
33    const OPENAI_AUTH_CLAIM: &'static str = "https://api.openai.com/auth";
34
35    /// Create config with API key
36    pub fn with_api_key(api_key: impl Into<String>) -> Self {
37        Self {
38            api_key: Some(api_key.into()),
39            api_endpoint: None,
40        }
41    }
42
43    /// Decode an OpenAI access token and extract the ChatGPT account ID.
44    ///
45    /// This intentionally reads the JWT payload without signature verification.
46    /// The claim is only used for request routing/header construction; OpenAI's
47    /// servers still validate the bearer token on use.
48    pub fn extract_chatgpt_account_id(access_token: &str) -> Option<String> {
49        let claims = crate::jwt::decode_jwt_payload_unverified(access_token)?;
50        let auth_claim = claims.get(Self::OPENAI_AUTH_CLAIM)?;
51
52        match auth_claim {
53            Value::Object(map) => map
54                .get("chatgpt_account_id")
55                .and_then(Value::as_str)
56                .map(ToString::to_string),
57            Value::String(raw_json) => {
58                serde_json::from_str::<Value>(raw_json)
59                    .ok()
60                    .and_then(|value| {
61                        value
62                            .get("chatgpt_account_id")
63                            .and_then(Value::as_str)
64                            .map(ToString::to_string)
65                    })
66            }
67            _ => None,
68        }
69    }
70}
71
72// =============================================================================
73// Model Definitions
74// =============================================================================
75
/// OpenAI model identifiers
///
/// Serde renames map each variant to the dated API model string; see the
/// matching `Display` impl which emits the same identifiers.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub enum OpenAIModel {
    // Reasoning Models
    #[serde(rename = "o3-2025-04-16")]
    O3,
    #[serde(rename = "o4-mini-2025-04-16")]
    O4Mini,

    // GPT-5 family (GPT5 is the default model)
    #[default]
    #[serde(rename = "gpt-5-2025-08-07")]
    GPT5,
    #[serde(rename = "gpt-5.1-2025-11-13")]
    GPT51,
    #[serde(rename = "gpt-5-mini-2025-08-07")]
    GPT5Mini,
    #[serde(rename = "gpt-5-nano-2025-08-07")]
    GPT5Nano,

    /// User-specified model name passed through verbatim.
    Custom(String),
}
97
98impl OpenAIModel {
99    pub fn from_string(s: &str) -> Result<Self, String> {
100        serde_json::from_value(serde_json::Value::String(s.to_string()))
101            .map_err(|_| "Failed to deserialize OpenAI model".to_string())
102    }
103}
104
105impl ContextAware for OpenAIModel {
106    fn context_info(&self) -> ModelContextInfo {
107        let model_name = self.to_string();
108
109        if model_name.starts_with("o3") {
110            return ModelContextInfo {
111                max_tokens: 200_000,
112                pricing_tiers: vec![ContextPricingTier {
113                    label: "Standard".to_string(),
114                    input_cost_per_million: 2.0,
115                    output_cost_per_million: 8.0,
116                    upper_bound: None,
117                }],
118                approach_warning_threshold: 0.8,
119            };
120        }
121
122        if model_name.starts_with("o4-mini") {
123            return ModelContextInfo {
124                max_tokens: 200_000,
125                pricing_tiers: vec![ContextPricingTier {
126                    label: "Standard".to_string(),
127                    input_cost_per_million: 1.10,
128                    output_cost_per_million: 4.40,
129                    upper_bound: None,
130                }],
131                approach_warning_threshold: 0.8,
132            };
133        }
134
135        if model_name.starts_with("gpt-5-mini") {
136            return ModelContextInfo {
137                max_tokens: 400_000,
138                pricing_tiers: vec![ContextPricingTier {
139                    label: "Standard".to_string(),
140                    input_cost_per_million: 0.25,
141                    output_cost_per_million: 2.0,
142                    upper_bound: None,
143                }],
144                approach_warning_threshold: 0.8,
145            };
146        }
147
148        if model_name.starts_with("gpt-5-nano") {
149            return ModelContextInfo {
150                max_tokens: 400_000,
151                pricing_tiers: vec![ContextPricingTier {
152                    label: "Standard".to_string(),
153                    input_cost_per_million: 0.05,
154                    output_cost_per_million: 0.40,
155                    upper_bound: None,
156                }],
157                approach_warning_threshold: 0.8,
158            };
159        }
160
161        if model_name.starts_with("gpt-5") {
162            return ModelContextInfo {
163                max_tokens: 400_000,
164                pricing_tiers: vec![ContextPricingTier {
165                    label: "Standard".to_string(),
166                    input_cost_per_million: 1.25,
167                    output_cost_per_million: 10.0,
168                    upper_bound: None,
169                }],
170                approach_warning_threshold: 0.8,
171            };
172        }
173
174        ModelContextInfo::default()
175    }
176
177    fn model_name(&self) -> String {
178        match self {
179            OpenAIModel::O3 => "O3".to_string(),
180            OpenAIModel::O4Mini => "O4-mini".to_string(),
181            OpenAIModel::GPT5 => "GPT-5".to_string(),
182            OpenAIModel::GPT51 => "GPT-5.1".to_string(),
183            OpenAIModel::GPT5Mini => "GPT-5 Mini".to_string(),
184            OpenAIModel::GPT5Nano => "GPT-5 Nano".to_string(),
185            OpenAIModel::Custom(name) => format!("Custom ({})", name),
186        }
187    }
188}
189
190impl std::fmt::Display for OpenAIModel {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            OpenAIModel::O3 => write!(f, "o3-2025-04-16"),
194            OpenAIModel::O4Mini => write!(f, "o4-mini-2025-04-16"),
195            OpenAIModel::GPT5Nano => write!(f, "gpt-5-nano-2025-08-07"),
196            OpenAIModel::GPT5Mini => write!(f, "gpt-5-mini-2025-08-07"),
197            OpenAIModel::GPT5 => write!(f, "gpt-5-2025-08-07"),
198            OpenAIModel::GPT51 => write!(f, "gpt-5.1-2025-11-13"),
199            OpenAIModel::Custom(model_name) => write!(f, "{}", model_name),
200        }
201    }
202}
203
204// =============================================================================
205// Message Types (used by TUI)
206// =============================================================================
207
/// Message role
///
/// Serialized in lowercase; `Assistant` is the default.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Developer,
    User,
    #[default]
    Assistant,
    Tool,
}
219
220impl std::fmt::Display for Role {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            Role::System => write!(f, "system"),
224            Role::Developer => write!(f, "developer"),
225            Role::User => write!(f, "user"),
226            Role::Assistant => write!(f, "assistant"),
227            Role::Tool => write!(f, "tool"),
228        }
229    }
230}
231
/// Model info for tracking which model generated a message
///
/// Stored on assistant messages via [`ChatMessage::model`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ModelInfo {
    /// Provider name (e.g., "anthropic", "openai")
    pub provider: String,
    /// Model identifier (e.g., "claude-sonnet-4-20250514", "gpt-4")
    pub id: String,
}
240
/// Chat message
///
/// The core fields mirror the OpenAI chat-completions message shape; the
/// "extended" fields below are local session-tracking additions and are
/// omitted from serialization when `None`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct ChatMessage {
    /// Author of the message.
    pub role: Role,
    /// Message body; `None` for e.g. assistant messages that only carry tool calls.
    pub content: Option<MessageContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls requested by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `Role::Tool` messages: id of the call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Token usage recorded for this message, if known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,

    // === Extended fields for session tracking ===
    /// Unique message identifier
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Model that generated this message (for assistant messages)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<ModelInfo>,
    /// Cost in dollars for this message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost: Option<f64>,
    /// Why the model stopped: "stop", "tool_calls", "length", "error"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Unix timestamp (ms) when message was created/sent
    #[serde(skip_serializing_if = "Option::is_none")]
    pub created_at: Option<i64>,
    /// Unix timestamp (ms) when assistant finished generating
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<i64>,
    /// Plugin extensibility - unstructured metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
278
279impl ChatMessage {
280    pub fn last_server_message(messages: &[ChatMessage]) -> Option<&ChatMessage> {
281        messages
282            .iter()
283            .rev()
284            .find(|message| message.role != Role::User && message.role != Role::Tool)
285    }
286
287    pub fn to_xml(&self) -> String {
288        match &self.content {
289            Some(MessageContent::String(s)) => {
290                format!("<message role=\"{}\">{}</message>", self.role, s)
291            }
292            Some(MessageContent::Array(parts)) => parts
293                .iter()
294                .map(|part| {
295                    format!(
296                        "<message role=\"{}\" type=\"{}\">{}</message>",
297                        self.role,
298                        part.r#type,
299                        part.text.clone().unwrap_or_default()
300                    )
301                })
302                .collect::<Vec<String>>()
303                .join("\n"),
304            None => String::new(),
305        }
306    }
307}
308
/// Message content (string or array of parts)
///
/// Untagged: deserializes from either a bare string or a JSON array of parts.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    String(String),
    Array(Vec<ContentPart>),
}
316
317impl MessageContent {
318    pub fn inject_checkpoint_id(&self, checkpoint_id: Uuid) -> Self {
319        match self {
320            MessageContent::String(s) => MessageContent::String(format!(
321                "<checkpoint_id>{checkpoint_id}</checkpoint_id>\n{s}"
322            )),
323            MessageContent::Array(parts) => MessageContent::Array(
324                std::iter::once(ContentPart {
325                    r#type: "text".to_string(),
326                    text: Some(format!("<checkpoint_id>{checkpoint_id}</checkpoint_id>")),
327                    image_url: None,
328                })
329                .chain(parts.iter().cloned())
330                .collect(),
331            ),
332        }
333    }
334
335    /// All indices from rfind()/find() of ASCII XML tags on same string
336    #[allow(clippy::string_slice)]
337    pub fn extract_checkpoint_id(&self) -> Option<Uuid> {
338        match self {
339            MessageContent::String(s) => s
340                .rfind("<checkpoint_id>")
341                .and_then(|start| {
342                    s[start..]
343                        .find("</checkpoint_id>")
344                        .map(|end| (start + "<checkpoint_id>".len(), start + end))
345                })
346                .and_then(|(start, end)| Uuid::parse_str(&s[start..end]).ok()),
347            MessageContent::Array(parts) => parts.iter().rev().find_map(|part| {
348                part.text.as_deref().and_then(|text| {
349                    text.rfind("<checkpoint_id>")
350                        .and_then(|start| {
351                            text[start..]
352                                .find("</checkpoint_id>")
353                                .map(|end| (start + "<checkpoint_id>".len(), start + end))
354                        })
355                        .and_then(|(start, end)| Uuid::parse_str(&text[start..end]).ok())
356                })
357            }),
358        }
359    }
360}
361
362impl std::fmt::Display for MessageContent {
363    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364        match self {
365            MessageContent::String(s) => write!(f, "{s}"),
366            MessageContent::Array(parts) => {
367                let text_parts: Vec<String> =
368                    parts.iter().filter_map(|part| part.text.clone()).collect();
369                write!(f, "{}", text_parts.join("\n"))
370            }
371        }
372    }
373}
374
375impl Default for MessageContent {
376    fn default() -> Self {
377        MessageContent::String(String::new())
378    }
379}
380
/// Content part (text or image)
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ContentPart {
    /// Part discriminator; this file uses "text" and "image_url".
    pub r#type: String,
    /// Text payload, present for "text" parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Image payload, present for "image_url" parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_url: Option<ImageUrl>,
}
390
/// Image URL with optional detail level
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ImageUrl {
    /// Image location; may be an http(s) URL or a `data:...;base64,...` URL.
    pub url: String,
    /// Optional detail hint (e.g. "low"/"high") — passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
398
399// =============================================================================
400// Tool Types (used by TUI)
401// =============================================================================
402
/// Tool definition
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Tool {
    /// Tool kind discriminator (e.g. "function").
    pub r#type: String,
    /// The callable function this tool exposes.
    pub function: FunctionDefinition,
}
409
/// Function definition for tools
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the function's parameters.
    pub parameters: serde_json::Value,
}
417
418impl From<Tool> for LLMTool {
419    fn from(tool: Tool) -> Self {
420        LLMTool {
421            name: tool.function.name,
422            description: tool.function.description.unwrap_or_default(),
423            input_schema: tool.function.parameters,
424        }
425    }
426}
427
/// Tool choice configuration
///
/// Serialized either as the bare strings "auto"/"required" or as a full
/// object; see the hand-written Serialize/Deserialize impls below.
#[derive(Debug, Clone, PartialEq)]
pub enum ToolChoice {
    Auto,
    Required,
    Object(ToolChoiceObject),
}
435
436impl Serialize for ToolChoice {
437    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
438    where
439        S: serde::Serializer,
440    {
441        match self {
442            ToolChoice::Auto => serializer.serialize_str("auto"),
443            ToolChoice::Required => serializer.serialize_str("required"),
444            ToolChoice::Object(obj) => obj.serialize(serializer),
445        }
446    }
447}
448
impl<'de> Deserialize<'de> for ToolChoice {
    /// Accept either the strings `"auto"`/`"required"` or a full
    /// `ToolChoiceObject` map (mirrors the custom `Serialize` impl above).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct ToolChoiceVisitor;

        impl<'de> serde::de::Visitor<'de> for ToolChoiceVisitor {
            type Value = ToolChoice;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("string or object")
            }

            // String form: only "auto" and "required" are recognized.
            fn visit_str<E>(self, value: &str) -> Result<ToolChoice, E>
            where
                E: serde::de::Error,
            {
                match value {
                    "auto" => Ok(ToolChoice::Auto),
                    "required" => Ok(ToolChoice::Required),
                    _ => Err(serde::de::Error::unknown_variant(
                        value,
                        &["auto", "required"],
                    )),
                }
            }

            // Object form: delegate to ToolChoiceObject's derived Deserialize.
            fn visit_map<M>(self, map: M) -> Result<ToolChoice, M::Error>
            where
                M: serde::de::MapAccess<'de>,
            {
                let obj = ToolChoiceObject::deserialize(
                    serde::de::value::MapAccessDeserializer::new(map),
                )?;
                Ok(ToolChoice::Object(obj))
            }
        }

        // deserialize_any is required because the wire type varies
        // (string vs map) and only the input tells us which.
        deserializer.deserialize_any(ToolChoiceVisitor)
    }
}
491
/// Object form of [`ToolChoice`]: forces a specific function to be called.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolChoiceObject {
    pub r#type: String,
    pub function: FunctionChoice,
}
497
/// Names the function selected by a [`ToolChoiceObject`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionChoice {
    pub name: String,
}
502
/// Tool call from assistant
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCall {
    /// Call identifier, echoed back in the matching tool-result message.
    pub id: String,
    /// Call kind discriminator (e.g. "function").
    pub r#type: String,
    pub function: FunctionCall,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
513
/// Function call details
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    /// Raw argument payload as a JSON-encoded string (not parsed here).
    pub arguments: String,
}
520
/// Tool call result status
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum ToolCallResultStatus {
    Success,
    Error,
    Cancelled,
}
528
/// Tool call result
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallResult {
    /// The call this result answers.
    pub call: ToolCall,
    /// Raw result text produced by executing the tool.
    pub result: String,
    pub status: ToolCallResultStatus,
}
536
/// Streaming progress info for a tool call being generated by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallStreamInfo {
    /// Tool name (may be empty if not yet streamed)
    pub name: String,
    /// Estimated token count of arguments streamed so far (~chars/4)
    pub args_tokens: usize,
    /// Optional description extracted from arguments (best-effort, may be None if JSON incomplete)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
548
/// Tool call result progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCallResultProgress {
    /// Identifier correlating progress updates to one tool execution.
    pub id: Uuid,
    /// Free-form progress text.
    pub message: String,
    /// Type of progress update for specialized handling
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress_type: Option<ProgressType>,
    /// Structured task updates for task wait progress
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_updates: Option<Vec<TaskUpdate>>,
    /// Overall progress percentage (0.0 - 100.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<f64>,
}
564
/// Type of progress update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProgressType {
    /// Command output streaming
    CommandOutput,
    /// Task wait progress with structured updates
    TaskWait,
    /// Generic progress
    Generic,
}
575
/// Structured task status update
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskUpdate {
    pub task_id: String,
    /// Task status as a free-form string.
    pub status: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_preview: Option<String>,
    /// Whether this is a target task being waited on
    #[serde(default)]
    pub is_target: bool,
    /// Pause information for paused subagent tasks
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pause_info: Option<TaskPauseInfo>,
}
594
/// Pause information for subagent tasks awaiting approval
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TaskPauseInfo {
    /// The agent's message before pausing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_message: Option<String>,
    /// Pending tool calls awaiting approval
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending_tool_calls: Option<Vec<crate::models::async_manifest::PendingToolCall>>,
}
605
606pub use crate::models::tools::ask_user::{
607    AskUserAnswer, AskUserOption, AskUserQuestion, AskUserRequest, AskUserResult,
608};
609
/// Chat completion request
///
/// Field set mirrors the OpenAI chat-completions request body; optional
/// fields are omitted from the wire when `None`. `context` is a local
/// extension (see [`ChatCompletionContext`]).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StopSequence>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<ChatCompletionContext>,
}
648
649impl ChatCompletionRequest {
650    pub fn new(
651        model: String,
652        messages: Vec<ChatMessage>,
653        tools: Option<Vec<Tool>>,
654        stream: Option<bool>,
655    ) -> Self {
656        Self {
657            model,
658            messages,
659            frequency_penalty: None,
660            logit_bias: None,
661            logprobs: None,
662            max_tokens: None,
663            n: None,
664            presence_penalty: None,
665            response_format: None,
666            seed: None,
667            stop: None,
668            stream,
669            temperature: None,
670            top_p: None,
671            tools,
672            tool_choice: None,
673            user: None,
674            context: None,
675        }
676    }
677}
678
/// Extra request context (local extension, not part of the OpenAI API).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionContext {
    // NOTE(review): unlike other optional fields in this file, this has no
    // `skip_serializing_if`, so it serializes as `null` when `None` — confirm intended.
    pub scratchpad: Option<Value>,
}
683
/// Requested response format (e.g. `{"type": "json_object"}`).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ResponseFormat {
    pub r#type: String,
}
688
/// Stop sequence(s): a single string or a list (untagged on the wire).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(untagged)]
pub enum StopSequence {
    String(String),
    Array(Vec<String>),
}
695
/// Chat completion response
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp (seconds) reported by the server.
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: LLMTokenUsage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    // NOTE(review): no `skip_serializing_if` here, so `metadata` serializes
    // as `null` when absent — confirm intended.
    pub metadata: Option<serde_json::Value>,
}
709
/// One generated completion within a [`ChatCompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionChoice {
    pub index: usize,
    pub message: ChatMessage,
    pub logprobs: Option<LogProbs>,
    pub finish_reason: FinishReason,
}
717
/// Why generation stopped (serialized in snake_case, e.g. "tool_calls").
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum FinishReason {
    Stop,
    Length,
    ContentFilter,
    ToolCalls,
}
726
/// Log-probability info attached to a choice.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbs {
    pub content: Option<Vec<LogProbContent>>,
}
731
/// Per-token log probability with optional alternatives.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LogProbContent {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Option<Vec<TokenLogprob>>,
}
739
/// An alternative token candidate and its log probability.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct TokenLogprob {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
746
747// =============================================================================
748// Streaming Types
749// =============================================================================
750
/// One SSE chunk of a streaming chat completion.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatCompletionStreamChoice>,
    /// Usage is typically only present on the final chunk.
    pub usage: Option<LLMTokenUsage>,
    pub metadata: Option<serde_json::Value>,
}
761
/// One choice within a streaming chunk; carries an incremental delta.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatCompletionStreamChoice {
    pub index: usize,
    pub delta: ChatMessageDelta,
    pub finish_reason: Option<FinishReason>,
}
768
/// Incremental message fragment accumulated across stream chunks.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<Role>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
}
778
/// Incremental tool-call fragment; `index` identifies which call to extend.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct ToolCallDelta {
    pub index: usize,
    pub id: Option<String>,
    pub r#type: Option<String>,
    pub function: Option<FunctionCallDelta>,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
789
/// Incremental function-call fragment (name and/or argument text pieces).
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct FunctionCallDelta {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
795
796// =============================================================================
797// Conversions
798// =============================================================================
799
impl From<LLMMessage> for ChatMessage {
    /// Convert a provider-agnostic `LLMMessage` into the TUI `ChatMessage` shape.
    ///
    /// Typed content items are flattened: text and tool-result items become
    /// content parts, tool calls are collected separately, and images become
    /// data-URL image parts. Extended session-tracking fields stay at their
    /// defaults.
    fn from(llm_message: LLMMessage) -> Self {
        // Unknown role strings fall back to `User`.
        let role = match llm_message.role.as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            "developer" => Role::Developer,
            _ => Role::User,
        };

        let (content, tool_calls, tool_call_id) = match llm_message.content {
            LLMMessageContent::String(text) => (Some(MessageContent::String(text)), None, None),
            LLMMessageContent::List(items) => {
                let mut text_parts = Vec::new();
                let mut tool_call_parts = Vec::new();
                // Only the FIRST tool result's id is kept as `tool_call_id`.
                let mut tool_result_id: Option<String> = None;

                for item in items {
                    match item {
                        LLMMessageTypedContent::Text { text } => {
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(text),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::ToolCall {
                            id,
                            name,
                            args,
                            metadata,
                        } => {
                            // Args are re-serialized to the string form ToolCall uses.
                            tool_call_parts.push(ToolCall {
                                id,
                                r#type: "function".to_string(),
                                function: FunctionCall {
                                    name,
                                    arguments: args.to_string(),
                                },
                                metadata,
                            });
                        }
                        LLMMessageTypedContent::ToolResult {
                            tool_use_id,
                            content,
                        } => {
                            if tool_result_id.is_none() {
                                tool_result_id = Some(tool_use_id);
                            }
                            // Tool-result bodies are carried as plain text parts.
                            text_parts.push(ContentPart {
                                r#type: "text".to_string(),
                                text: Some(content),
                                image_url: None,
                            });
                        }
                        LLMMessageTypedContent::Image { source } => {
                            // Re-encode as an OpenAI-style base64 data URL.
                            text_parts.push(ContentPart {
                                r#type: "image_url".to_string(),
                                text: None,
                                image_url: Some(ImageUrl {
                                    url: format!(
                                        "data:{};base64,{}",
                                        source.media_type, source.data
                                    ),
                                    detail: None,
                                }),
                            });
                        }
                    }
                }

                // Empty collections collapse to `None` rather than `Some(vec![])`.
                let content = if !text_parts.is_empty() {
                    Some(MessageContent::Array(text_parts))
                } else {
                    None
                };

                let tool_calls = if !tool_call_parts.is_empty() {
                    Some(tool_call_parts)
                } else {
                    None
                };

                (content, tool_calls, tool_result_id)
            }
        };

        ChatMessage {
            role,
            content,
            name: None,
            tool_calls,
            tool_call_id,
            usage: None,
            ..Default::default()
        }
    }
}
899
900impl From<ChatMessage> for LLMMessage {
901    fn from(chat_message: ChatMessage) -> Self {
902        let mut content_parts = Vec::new();
903
904        match chat_message.content {
905            Some(MessageContent::String(s)) => {
906                if !s.is_empty() {
907                    content_parts.push(LLMMessageTypedContent::Text { text: s });
908                }
909            }
910            Some(MessageContent::Array(parts)) => {
911                for part in parts {
912                    if let Some(text) = part.text {
913                        content_parts.push(LLMMessageTypedContent::Text { text });
914                    } else if let Some(image_url) = part.image_url {
915                        let (media_type, data) = if image_url.url.starts_with("data:") {
916                            let parts: Vec<&str> = image_url.url.splitn(2, ',').collect();
917                            if parts.len() == 2 {
918                                let meta = parts[0];
919                                let data = parts[1];
920                                let media_type = meta
921                                    .trim_start_matches("data:")
922                                    .trim_end_matches(";base64")
923                                    .to_string();
924                                (media_type, data.to_string())
925                            } else {
926                                ("image/jpeg".to_string(), image_url.url)
927                            }
928                        } else {
929                            ("image/jpeg".to_string(), image_url.url)
930                        };
931
932                        content_parts.push(LLMMessageTypedContent::Image {
933                            source: LLMMessageImageSource {
934                                r#type: "base64".to_string(),
935                                media_type,
936                                data,
937                            },
938                        });
939                    }
940                }
941            }
942            None => {}
943        }
944
945        if let Some(tool_calls) = chat_message.tool_calls {
946            for tool_call in tool_calls {
947                let args = serde_json::from_str(&tool_call.function.arguments).unwrap_or(json!({}));
948                content_parts.push(LLMMessageTypedContent::ToolCall {
949                    id: tool_call.id,
950                    name: tool_call.function.name,
951                    args,
952                    metadata: tool_call.metadata,
953                });
954            }
955        }
956
957        // Handle tool result messages: when role is Tool and tool_call_id is present,
958        // convert the content to a ToolResult content part. This is the generic
959        // intermediate representation - each provider's conversion layer handles
960        // the specifics (e.g., Anthropic converts to user role with tool_result blocks)
961        if chat_message.role == Role::Tool
962            && let Some(tool_call_id) = chat_message.tool_call_id
963        {
964            // Extract content as string for the tool result
965            let content_str = content_parts
966                .iter()
967                .filter_map(|p| match p {
968                    LLMMessageTypedContent::Text { text } => Some(text.clone()),
969                    _ => None,
970                })
971                .collect::<Vec<_>>()
972                .join("\n");
973
974            // Replace content with a single ToolResult
975            content_parts = vec![LLMMessageTypedContent::ToolResult {
976                tool_use_id: tool_call_id,
977                content: content_str,
978            }];
979        }
980
981        LLMMessage {
982            role: chat_message.role.to_string(),
983            content: if content_parts.is_empty() {
984                LLMMessageContent::String(String::new())
985            } else if content_parts.len() == 1 {
986                match &content_parts[0] {
987                    LLMMessageTypedContent::Text { text } => {
988                        LLMMessageContent::String(text.clone())
989                    }
990                    _ => LLMMessageContent::List(content_parts),
991                }
992            } else {
993                LLMMessageContent::List(content_parts)
994            },
995        }
996    }
997}
998
999impl From<GenerationDelta> for ChatMessageDelta {
1000    fn from(delta: GenerationDelta) -> Self {
1001        match delta {
1002            GenerationDelta::Content { content } => ChatMessageDelta {
1003                role: Some(Role::Assistant),
1004                content: Some(content),
1005                tool_calls: None,
1006            },
1007            GenerationDelta::Thinking { thinking: _ } => ChatMessageDelta {
1008                role: Some(Role::Assistant),
1009                content: None,
1010                tool_calls: None,
1011            },
1012            GenerationDelta::ToolUse { tool_use } => ChatMessageDelta {
1013                role: Some(Role::Assistant),
1014                content: None,
1015                tool_calls: Some(vec![ToolCallDelta {
1016                    index: tool_use.index,
1017                    id: tool_use.id,
1018                    r#type: Some("function".to_string()),
1019                    function: Some(FunctionCallDelta {
1020                        name: tool_use.name,
1021                        arguments: tool_use.input,
1022                    }),
1023                    metadata: tool_use.metadata,
1024                }]),
1025            },
1026            _ => ChatMessageDelta {
1027                role: Some(Role::Assistant),
1028                content: None,
1029                tool_calls: None,
1030            },
1031        }
1032    }
1033}
1034
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fake three-segment JWT (`header.payload.signature`) whose
    /// payload is the given JSON value, base64url-encoded without padding —
    /// the shape `OpenAIConfig::extract_chatgpt_account_id` expects.
    ///
    /// Extracted because the same boilerplate was previously duplicated in
    /// every account-id extraction test.
    fn fake_jwt(payload: &Value) -> String {
        use base64::Engine;
        let encoded_payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload.to_string().as_bytes());
        format!("header.{}.signature", encoded_payload)
    }

    #[test]
    fn test_serialize_basic_request() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: Role::System,
                    content: Some(MessageContent::String(
                        "You are a helpful assistant.".to_string(),
                    )),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
                ChatMessage {
                    role: Role::User,
                    content: Some(MessageContent::String("Hello!".to_string())),
                    name: None,
                    tool_calls: None,
                    tool_call_id: None,
                    usage: None,
                    ..Default::default()
                },
            ],
            frequency_penalty: None,
            logit_bias: None,
            logprobs: None,
            max_tokens: Some(100),
            n: None,
            presence_penalty: None,
            response_format: None,
            seed: None,
            stop: None,
            stream: None,
            temperature: Some(0.7),
            top_p: None,
            tools: None,
            tool_choice: None,
            user: None,
            context: None,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"model\":\"gpt-4\""));
        assert!(json.contains("\"messages\":["));
        assert!(json.contains("\"role\":\"system\""));
    }

    #[test]
    fn test_llm_message_to_chat_message() {
        let llm_message = LLMMessage {
            role: "user".to_string(),
            content: LLMMessageContent::String("Hello, world!".to_string()),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::User);
        match &chat_message.content {
            Some(MessageContent::String(text)) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected string content"),
        }
    }

    #[test]
    fn test_llm_message_to_chat_message_tool_result_preserves_tool_call_id() {
        let llm_message = LLMMessage {
            role: "tool".to_string(),
            content: LLMMessageContent::List(vec![LLMMessageTypedContent::ToolResult {
                tool_use_id: "toolu_01Abc123".to_string(),
                content: "Tool execution result".to_string(),
            }]),
        };

        let chat_message = ChatMessage::from(llm_message);
        assert_eq!(chat_message.role, Role::Tool);
        assert_eq!(chat_message.tool_call_id.as_deref(), Some("toolu_01Abc123"));
        assert_eq!(
            chat_message.content,
            Some(MessageContent::Array(vec![ContentPart {
                r#type: "text".to_string(),
                text: Some("Tool execution result".to_string()),
                image_url: None,
            }]))
        );
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result() {
        // Test that Tool role messages with tool_call_id are converted to ToolResult content
        // This is critical for Anthropic compatibility - the provider layer converts
        // role="tool" to role="user" with tool_result content blocks
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: Some(MessageContent::String("Tool execution result".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_01Abc123".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        // Role should be preserved as "tool" - provider layer handles conversion
        assert_eq!(llm_message.role, "tool");

        // Content should be a ToolResult with the tool_call_id
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1, "Should have exactly one content part");
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_01Abc123");
                        assert_eq!(content, "Tool execution result");
                    }
                    _ => panic!("Expected ToolResult content part, got {:?}", parts[0]),
                }
            }
            _ => panic!(
                "Expected List content with ToolResult, got {:?}",
                llm_message.content
            ),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_tool_result_empty_content() {
        // Test tool result with empty content
        let chat_message = ChatMessage {
            role: Role::Tool,
            content: None,
            name: None,
            tool_calls: None,
            tool_call_id: Some("toolu_02Xyz789".to_string()),
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "tool");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 1);
                match &parts[0] {
                    LLMMessageTypedContent::ToolResult {
                        tool_use_id,
                        content,
                    } => {
                        assert_eq!(tool_use_id, "toolu_02Xyz789");
                        assert_eq!(content, ""); // Empty content
                    }
                    _ => panic!("Expected ToolResult content part"),
                }
            }
            _ => panic!("Expected List content with ToolResult"),
        }
    }

    #[test]
    fn test_chat_message_to_llm_message_assistant_with_tool_calls() {
        // Test that assistant messages with tool_calls are converted correctly
        let chat_message = ChatMessage {
            role: Role::Assistant,
            content: Some(MessageContent::String(
                "I'll help you with that.".to_string(),
            )),
            name: None,
            tool_calls: Some(vec![ToolCall {
                id: "call_abc123".to_string(),
                r#type: "function".to_string(),
                function: FunctionCall {
                    name: "get_weather".to_string(),
                    arguments: r#"{"location": "Paris"}"#.to_string(),
                },
                metadata: None,
            }]),
            tool_call_id: None,
            usage: None,
            ..Default::default()
        };

        let llm_message: LLMMessage = chat_message.into();

        assert_eq!(llm_message.role, "assistant");
        match &llm_message.content {
            LLMMessageContent::List(parts) => {
                assert_eq!(parts.len(), 2, "Should have text and tool call");

                // First part should be text
                match &parts[0] {
                    LLMMessageTypedContent::Text { text } => {
                        assert_eq!(text, "I'll help you with that.");
                    }
                    _ => panic!("Expected Text content part first"),
                }

                // Second part should be tool call
                match &parts[1] {
                    LLMMessageTypedContent::ToolCall { id, name, args, .. } => {
                        assert_eq!(id, "call_abc123");
                        assert_eq!(name, "get_weather");
                        assert_eq!(args["location"], "Paris");
                    }
                    _ => panic!("Expected ToolCall content part second"),
                }
            }
            _ => panic!("Expected List content"),
        }
    }

    #[test]
    fn test_extract_chatgpt_account_id_from_access_token() {
        let payload = json!({
            "https://api.openai.com/auth": {
                "chatgpt_account_id": "acct_test_123"
            }
        });
        let access_token = fake_jwt(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            Some("acct_test_123".to_string())
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_missing_claim() {
        let payload = json!({
            "sub": "user_123"
        });
        let access_token = fake_jwt(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_token_shape() {
        // Not a three-segment JWT at all.
        assert_eq!(OpenAIConfig::extract_chatgpt_account_id("not-a-jwt"), None);
    }

    #[test]
    fn test_extract_chatgpt_account_id_returns_none_for_invalid_claim_json() {
        // Claim present but its value is a garbage string, not a JSON object.
        let payload = json!({
            "https://api.openai.com/auth": "{not-json}"
        });
        let access_token = fake_jwt(&payload);

        assert_eq!(
            OpenAIConfig::extract_chatgpt_account_id(&access_token),
            None
        );
    }
}