stakpak_shared/models/
llm.rs

1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
15//!
16//! For built-in providers, you can use the model name directly without a prefix:
17//! - `claude-sonnet-4-5` → auto-detected as Anthropic
//! - `gpt-5` → auto-detected as OpenAI
19//! - `gemini-2.5-pro` → auto-detected as Gemini
20//!
21//! ## Custom Providers
22//!
23//! Any OpenAI-compatible API can be configured using `type = "custom"`.
24//! The provider key becomes the model prefix.
25//!
26//! # Model Routing
27//!
28//! Models can be specified with or without a provider prefix:
29//!
30//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
31//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
32//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
33//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
34//!   sends `anthropic/claude-opus` to the API
35//!
36//! # Example Configuration
37//!
38//! ```toml
39//! [profiles.default]
40//! provider = "local"
41//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
42//! eco_model = "offline/llama3"       # custom provider
43//!
44//! [profiles.default.providers.anthropic]
45//! type = "anthropic"
46//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
47//!
48//! [profiles.default.providers.offline]
49//! type = "custom"
50//! api_endpoint = "http://localhost:11434/v1"
51//! ```
52
53use crate::models::{
54    integrations::{anthropic::AnthropicModel, gemini::GeminiModel, openai::OpenAIModel},
55    model_pricing::{ContextAware, ModelContextInfo},
56};
57use serde::{Deserialize, Serialize};
58use std::collections::HashMap;
59use std::fmt::Display;
60
61// =============================================================================
62// Provider Configuration
63// =============================================================================
64
65/// Unified provider configuration enum
66///
67/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
68/// where the key is the provider name and becomes the model prefix for routing.
69///
70/// # Provider Key = Model Prefix
71///
72/// The key used in the HashMap becomes the prefix used in model names:
73/// - Config key: `providers.offline`
74/// - Model usage: `offline/llama3`
75/// - Routing: finds `offline` provider, sends `llama3` to API
76///
77/// # Example TOML
78/// ```toml
79/// [profiles.myprofile.providers.openai]
80/// type = "openai"
81/// api_key = "sk-..."
82///
83/// [profiles.myprofile.providers.anthropic]
84/// type = "anthropic"
85/// api_key = "sk-ant-..."
86/// access_token = "oauth-token"
87///
88/// [profiles.myprofile.providers.offline]
89/// type = "custom"
90/// api_endpoint = "http://localhost:11434/v1"
91/// ```
92#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
93#[serde(tag = "type", rename_all = "lowercase")]
94pub enum ProviderConfig {
95    /// OpenAI provider configuration
96    OpenAI {
97        #[serde(skip_serializing_if = "Option::is_none")]
98        api_key: Option<String>,
99        #[serde(skip_serializing_if = "Option::is_none")]
100        api_endpoint: Option<String>,
101    },
102    /// Anthropic provider configuration
103    Anthropic {
104        #[serde(skip_serializing_if = "Option::is_none")]
105        api_key: Option<String>,
106        #[serde(skip_serializing_if = "Option::is_none")]
107        api_endpoint: Option<String>,
108        /// OAuth access token (for Claude subscription)
109        #[serde(skip_serializing_if = "Option::is_none")]
110        access_token: Option<String>,
111    },
112    /// Google Gemini provider configuration
113    Gemini {
114        #[serde(skip_serializing_if = "Option::is_none")]
115        api_key: Option<String>,
116        #[serde(skip_serializing_if = "Option::is_none")]
117        api_endpoint: Option<String>,
118    },
119    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
120    ///
121    /// The provider key in the config becomes the model prefix.
122    /// For example, if configured as `providers.offline`, use models as:
123    /// - `offline/llama3` - passes `llama3` to the API
124    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
125    ///
126    /// # Example TOML
127    /// ```toml
128    /// [profiles.myprofile.providers.offline]
129    /// type = "custom"
130    /// api_endpoint = "http://localhost:11434/v1"
131    ///
132    /// # Then use models as:
133    /// smart_model = "offline/llama3"
134    /// eco_model = "offline/phi3"
135    /// ```
136    Custom {
137        #[serde(skip_serializing_if = "Option::is_none")]
138        api_key: Option<String>,
139        /// API endpoint URL (required for custom providers)
140        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
141        api_endpoint: String,
142    },
143}
144
145impl ProviderConfig {
146    /// Get the provider type name
147    pub fn provider_type(&self) -> &'static str {
148        match self {
149            ProviderConfig::OpenAI { .. } => "openai",
150            ProviderConfig::Anthropic { .. } => "anthropic",
151            ProviderConfig::Gemini { .. } => "gemini",
152            ProviderConfig::Custom { .. } => "custom",
153        }
154    }
155
156    /// Get the API key if set
157    pub fn api_key(&self) -> Option<&str> {
158        match self {
159            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
160            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
161            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
162            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
163        }
164    }
165
166    /// Get the API endpoint if set
167    pub fn api_endpoint(&self) -> Option<&str> {
168        match self {
169            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
170            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
171            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
172            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
173        }
174    }
175
176    /// Get the access token (Anthropic only)
177    pub fn access_token(&self) -> Option<&str> {
178        match self {
179            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
180            _ => None,
181        }
182    }
183
184    /// Create an OpenAI provider config
185    pub fn openai(api_key: Option<String>) -> Self {
186        ProviderConfig::OpenAI {
187            api_key,
188            api_endpoint: None,
189        }
190    }
191
192    /// Create an Anthropic provider config
193    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
194        ProviderConfig::Anthropic {
195            api_key,
196            api_endpoint: None,
197            access_token,
198        }
199    }
200
201    /// Create a Gemini provider config
202    pub fn gemini(api_key: Option<String>) -> Self {
203        ProviderConfig::Gemini {
204            api_key,
205            api_endpoint: None,
206        }
207    }
208
209    /// Create a custom provider config
210    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
211        ProviderConfig::Custom {
212            api_key,
213            api_endpoint,
214        }
215    }
216}
217
218#[derive(Clone, Debug, PartialEq, Serialize)]
219pub enum LLMModel {
220    Anthropic(AnthropicModel),
221    Gemini(GeminiModel),
222    OpenAI(OpenAIModel),
223    /// Custom provider with explicit provider name and model.
224    ///
225    /// Used for custom OpenAI-compatible providers like LiteLLM, Ollama, etc.
226    /// The provider name matches the key in the `providers` HashMap config.
227    ///
228    /// # Examples
229    /// - `litellm/claude-opus` → `provider: "litellm"`, `model: "claude-opus"`
230    /// - `litellm/anthropic/claude-opus` → `provider: "litellm"`, `model: "anthropic/claude-opus"`
231    /// - `ollama/llama3` → `provider: "ollama"`, `model: "llama3"`
232    Custom {
233        /// Provider name matching the key in providers config (e.g., "litellm", "ollama")
234        provider: String,
235        /// Model name/path to pass to the provider API (can include nested prefixes)
236        model: String,
237    },
238}
239
240impl ContextAware for LLMModel {
241    fn context_info(&self) -> ModelContextInfo {
242        match self {
243            LLMModel::Anthropic(model) => model.context_info(),
244            LLMModel::Gemini(model) => model.context_info(),
245            LLMModel::OpenAI(model) => model.context_info(),
246            LLMModel::Custom { .. } => ModelContextInfo::default(),
247        }
248    }
249
250    fn model_name(&self) -> String {
251        match self {
252            LLMModel::Anthropic(model) => model.model_name(),
253            LLMModel::Gemini(model) => model.model_name(),
254            LLMModel::OpenAI(model) => model.model_name(),
255            LLMModel::Custom { provider, model } => format!("{}/{}", provider, model),
256        }
257    }
258}
259
260/// Aggregated provider configuration for LLM operations
261///
262/// This struct holds all configured providers, keyed by provider name.
263#[derive(Debug, Clone, Default)]
264pub struct LLMProviderConfig {
265    /// All provider configurations (key = provider name)
266    pub providers: HashMap<String, ProviderConfig>,
267}
268
269impl LLMProviderConfig {
270    /// Create a new empty provider config
271    pub fn new() -> Self {
272        Self {
273            providers: HashMap::new(),
274        }
275    }
276
277    /// Add a provider configuration
278    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
279        self.providers.insert(name.into(), config);
280    }
281
282    /// Get a provider configuration by name
283    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
284        self.providers.get(name)
285    }
286
287    /// Check if any providers are configured
288    pub fn is_empty(&self) -> bool {
289        self.providers.is_empty()
290    }
291}
292
293impl From<String> for LLMModel {
294    /// Parse a model string into an LLMModel.
295    ///
296    /// # Format
297    /// - `provider/model` - Explicit provider prefix
298    /// - `provider/nested/model` - Provider with nested model path (e.g., for LiteLLM)
299    /// - `model-name` - Auto-detect provider from model name
300    ///
301    /// # Examples
302    /// - `"litellm/anthropic/claude-opus"` → Custom { provider: "litellm", model: "anthropic/claude-opus" }
303    /// - `"anthropic/claude-opus-4-5"` → Anthropic(Claude45Opus) (built-in provider)
304    /// - `"claude-opus-4-5"` → Anthropic(Claude45Opus) (auto-detected)
305    /// - `"ollama/llama3"` → Custom { provider: "ollama", model: "llama3" }
306    fn from(value: String) -> Self {
307        // Check for explicit provider/model format (e.g., "litellm/anthropic/claude-opus")
308        // split_once takes only the first segment as provider, rest is the model path
309        if let Some((provider, model)) = value.split_once('/') {
310            // Check if it's a known built-in provider with explicit prefix
311            match provider {
312                "anthropic" => return Self::from_model_name(model),
313                "openai" => return Self::from_model_name(model),
314                "google" | "gemini" => return Self::from_model_name(model),
315                // Unknown provider = custom provider (model can contain additional slashes)
316                _ => {
317                    return LLMModel::Custom {
318                        provider: provider.to_string(),
319                        model: model.to_string(), // Preserves nested paths like "anthropic/claude-opus"
320                    };
321                }
322            }
323        }
324
325        // Fall back to auto-detection by model name prefix
326        Self::from_model_name(&value)
327    }
328}
329
330impl LLMModel {
331    /// Parse model name without provider prefix
332    fn from_model_name(model: &str) -> Self {
333        if model.starts_with("claude-haiku-4-5") {
334            LLMModel::Anthropic(AnthropicModel::Claude45Haiku)
335        } else if model.starts_with("claude-sonnet-4-5") {
336            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet)
337        } else if model.starts_with("claude-opus-4-5") {
338            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
339        } else if model == "gemini-2.5-flash-lite" {
340            LLMModel::Gemini(GeminiModel::Gemini25FlashLite)
341        } else if model.starts_with("gemini-2.5-flash") {
342            LLMModel::Gemini(GeminiModel::Gemini25Flash)
343        } else if model.starts_with("gemini-2.5-pro") {
344            LLMModel::Gemini(GeminiModel::Gemini25Pro)
345        } else if model.starts_with("gemini-3-pro-preview") {
346            LLMModel::Gemini(GeminiModel::Gemini3Pro)
347        } else if model.starts_with("gemini-3-flash-preview") {
348            LLMModel::Gemini(GeminiModel::Gemini3Flash)
349        } else if model.starts_with("gpt-5-mini") {
350            LLMModel::OpenAI(OpenAIModel::GPT5Mini)
351        } else if model.starts_with("gpt-5") {
352            LLMModel::OpenAI(OpenAIModel::GPT5)
353        } else {
354            // Unknown model without provider prefix - treat as custom with "custom" provider
355            LLMModel::Custom {
356                provider: "custom".to_string(),
357                model: model.to_string(),
358            }
359        }
360    }
361
362    /// Get the provider name for this model
363    pub fn provider_name(&self) -> &str {
364        match self {
365            LLMModel::Anthropic(_) => "anthropic",
366            LLMModel::Gemini(_) => "google",
367            LLMModel::OpenAI(_) => "openai",
368            LLMModel::Custom { provider, .. } => provider,
369        }
370    }
371
372    /// Get just the model name without provider prefix
373    pub fn model_id(&self) -> String {
374        match self {
375            LLMModel::Anthropic(m) => m.to_string(),
376            LLMModel::Gemini(m) => m.to_string(),
377            LLMModel::OpenAI(m) => m.to_string(),
378            LLMModel::Custom { model, .. } => model.clone(),
379        }
380    }
381}
382
383impl Display for LLMModel {
384    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
385        match self {
386            LLMModel::Anthropic(model) => write!(f, "{}", model),
387            LLMModel::Gemini(model) => write!(f, "{}", model),
388            LLMModel::OpenAI(model) => write!(f, "{}", model),
389            LLMModel::Custom { provider, model } => write!(f, "{}/{}", provider, model),
390        }
391    }
392}
393
394/// Provider-specific options for LLM requests
395#[derive(Clone, Debug, Serialize, Deserialize, Default)]
396pub struct LLMProviderOptions {
397    /// Anthropic-specific options
398    #[serde(skip_serializing_if = "Option::is_none")]
399    pub anthropic: Option<LLMAnthropicOptions>,
400
401    /// OpenAI-specific options
402    #[serde(skip_serializing_if = "Option::is_none")]
403    pub openai: Option<LLMOpenAIOptions>,
404
405    /// Google/Gemini-specific options
406    #[serde(skip_serializing_if = "Option::is_none")]
407    pub google: Option<LLMGoogleOptions>,
408}
409
410/// Anthropic-specific options
411#[derive(Clone, Debug, Serialize, Deserialize, Default)]
412pub struct LLMAnthropicOptions {
413    /// Extended thinking configuration
414    #[serde(skip_serializing_if = "Option::is_none")]
415    pub thinking: Option<LLMThinkingOptions>,
416}
417
418/// Thinking/reasoning options
419#[derive(Clone, Debug, Serialize, Deserialize)]
420pub struct LLMThinkingOptions {
421    /// Budget tokens for thinking (must be >= 1024)
422    pub budget_tokens: u32,
423}
424
425impl LLMThinkingOptions {
426    pub fn new(budget_tokens: u32) -> Self {
427        Self {
428            budget_tokens: budget_tokens.max(1024),
429        }
430    }
431}
432
433/// OpenAI-specific options
434#[derive(Clone, Debug, Serialize, Deserialize, Default)]
435pub struct LLMOpenAIOptions {
436    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
437    #[serde(skip_serializing_if = "Option::is_none")]
438    pub reasoning_effort: Option<String>,
439}
440
441/// Google/Gemini-specific options
442#[derive(Clone, Debug, Serialize, Deserialize, Default)]
443pub struct LLMGoogleOptions {
444    /// Thinking budget in tokens
445    #[serde(skip_serializing_if = "Option::is_none")]
446    pub thinking_budget: Option<u32>,
447}
448
449#[derive(Clone, Debug, Serialize)]
450pub struct LLMInput {
451    pub model: LLMModel,
452    pub messages: Vec<LLMMessage>,
453    pub max_tokens: u32,
454    pub tools: Option<Vec<LLMTool>>,
455    #[serde(skip_serializing_if = "Option::is_none")]
456    pub provider_options: Option<LLMProviderOptions>,
457}
458
459#[derive(Debug)]
460pub struct LLMStreamInput {
461    pub model: LLMModel,
462    pub messages: Vec<LLMMessage>,
463    pub max_tokens: u32,
464    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
465    pub tools: Option<Vec<LLMTool>>,
466    pub provider_options: Option<LLMProviderOptions>,
467}
468
469impl From<&LLMStreamInput> for LLMInput {
470    fn from(value: &LLMStreamInput) -> Self {
471        LLMInput {
472            model: value.model.clone(),
473            messages: value.messages.clone(),
474            max_tokens: value.max_tokens,
475            tools: value.tools.clone(),
476            provider_options: value.provider_options.clone(),
477        }
478    }
479}
480
481#[derive(Serialize, Deserialize, Debug, Clone, Default)]
482pub struct LLMMessage {
483    pub role: String,
484    pub content: LLMMessageContent,
485}
486
487#[derive(Serialize, Deserialize, Debug, Clone)]
488pub struct SimpleLLMMessage {
489    #[serde(rename = "role")]
490    pub role: SimpleLLMRole,
491    pub content: String,
492}
493
494#[derive(Serialize, Deserialize, Debug, Clone)]
495#[serde(rename_all = "lowercase")]
496pub enum SimpleLLMRole {
497    User,
498    Assistant,
499}
500
501impl std::fmt::Display for SimpleLLMRole {
502    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
503        match self {
504            SimpleLLMRole::User => write!(f, "user"),
505            SimpleLLMRole::Assistant => write!(f, "assistant"),
506        }
507    }
508}
509
510#[derive(Serialize, Deserialize, Debug, Clone)]
511#[serde(untagged)]
512pub enum LLMMessageContent {
513    String(String),
514    List(Vec<LLMMessageTypedContent>),
515}
516
517#[allow(clippy::to_string_trait_impl)]
518impl ToString for LLMMessageContent {
519    fn to_string(&self) -> String {
520        match self {
521            LLMMessageContent::String(s) => s.clone(),
522            LLMMessageContent::List(l) => l
523                .iter()
524                .map(|c| match c {
525                    LLMMessageTypedContent::Text { text } => text.clone(),
526                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
527                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
528                    LLMMessageTypedContent::Image { .. } => String::new(),
529                })
530                .collect::<Vec<_>>()
531                .join("\n"),
532        }
533    }
534}
535
536impl From<String> for LLMMessageContent {
537    fn from(value: String) -> Self {
538        LLMMessageContent::String(value)
539    }
540}
541
542impl Default for LLMMessageContent {
543    fn default() -> Self {
544        LLMMessageContent::String(String::new())
545    }
546}
547
548#[derive(Serialize, Deserialize, Debug, Clone)]
549#[serde(tag = "type")]
550pub enum LLMMessageTypedContent {
551    #[serde(rename = "text")]
552    Text { text: String },
553    #[serde(rename = "tool_use")]
554    ToolCall {
555        id: String,
556        name: String,
557        #[serde(alias = "input")]
558        args: serde_json::Value,
559    },
560    #[serde(rename = "tool_result")]
561    ToolResult {
562        tool_use_id: String,
563        content: String,
564    },
565    #[serde(rename = "image")]
566    Image { source: LLMMessageImageSource },
567}
568
569#[derive(Serialize, Deserialize, Debug, Clone)]
570pub struct LLMMessageImageSource {
571    #[serde(rename = "type")]
572    pub r#type: String,
573    pub media_type: String,
574    pub data: String,
575}
576
577impl Default for LLMMessageTypedContent {
578    fn default() -> Self {
579        LLMMessageTypedContent::Text {
580            text: String::new(),
581        }
582    }
583}
584
585#[derive(Serialize, Deserialize, Debug, Clone)]
586pub struct LLMChoice {
587    pub finish_reason: Option<String>,
588    pub index: u32,
589    pub message: LLMMessage,
590}
591
592#[derive(Serialize, Deserialize, Debug, Clone)]
593pub struct LLMCompletionResponse {
594    pub model: String,
595    pub object: String,
596    pub choices: Vec<LLMChoice>,
597    pub created: u64,
598    pub usage: Option<LLMTokenUsage>,
599    pub id: String,
600}
601
602#[derive(Serialize, Deserialize, Debug, Clone)]
603pub struct LLMStreamDelta {
604    #[serde(skip_serializing_if = "Option::is_none")]
605    pub content: Option<String>,
606}
607
608#[derive(Serialize, Deserialize, Debug, Clone)]
609pub struct LLMStreamChoice {
610    pub finish_reason: Option<String>,
611    pub index: u32,
612    pub message: Option<LLMMessage>,
613    pub delta: LLMStreamDelta,
614}
615
616#[derive(Serialize, Deserialize, Debug, Clone)]
617pub struct LLMCompletionStreamResponse {
618    pub model: String,
619    pub object: String,
620    pub choices: Vec<LLMStreamChoice>,
621    pub created: u64,
622    #[serde(skip_serializing_if = "Option::is_none")]
623    pub usage: Option<LLMTokenUsage>,
624    pub id: String,
625    pub citations: Option<Vec<String>>,
626}
627
628#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
629pub struct LLMTool {
630    pub name: String,
631    pub description: String,
632    pub input_schema: serde_json::Value,
633}
634
635#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
636pub struct LLMTokenUsage {
637    pub prompt_tokens: u32,
638    pub completion_tokens: u32,
639    pub total_tokens: u32,
640
641    #[serde(skip_serializing_if = "Option::is_none")]
642    pub prompt_tokens_details: Option<PromptTokensDetails>,
643}
644
645#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
646#[serde(rename_all = "snake_case")]
647pub enum TokenType {
648    InputTokens,
649    OutputTokens,
650    CacheReadInputTokens,
651    CacheWriteInputTokens,
652}
653
654#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
655pub struct PromptTokensDetails {
656    #[serde(skip_serializing_if = "Option::is_none")]
657    pub input_tokens: Option<u32>,
658    #[serde(skip_serializing_if = "Option::is_none")]
659    pub output_tokens: Option<u32>,
660    #[serde(skip_serializing_if = "Option::is_none")]
661    pub cache_read_input_tokens: Option<u32>,
662    #[serde(skip_serializing_if = "Option::is_none")]
663    pub cache_write_input_tokens: Option<u32>,
664}
665
666impl PromptTokensDetails {
667    /// Returns an iterator over the token types and their values
668    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
669        [
670            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
671            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
672            (
673                TokenType::CacheReadInputTokens,
674                self.cache_read_input_tokens.unwrap_or(0),
675            ),
676            (
677                TokenType::CacheWriteInputTokens,
678                self.cache_write_input_tokens.unwrap_or(0),
679            ),
680        ]
681        .into_iter()
682    }
683}
684
685impl std::ops::Add for PromptTokensDetails {
686    type Output = Self;
687
688    fn add(self, rhs: Self) -> Self::Output {
689        Self {
690            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
691            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
692            cache_read_input_tokens: Some(
693                self.cache_read_input_tokens.unwrap_or(0)
694                    + rhs.cache_read_input_tokens.unwrap_or(0),
695            ),
696            cache_write_input_tokens: Some(
697                self.cache_write_input_tokens.unwrap_or(0)
698                    + rhs.cache_write_input_tokens.unwrap_or(0),
699            ),
700        }
701    }
702}
703
704impl std::ops::AddAssign for PromptTokensDetails {
705    fn add_assign(&mut self, rhs: Self) {
706        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
707        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
708        self.cache_read_input_tokens = Some(
709            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
710        );
711        self.cache_write_input_tokens = Some(
712            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
713        );
714    }
715}
716
717#[derive(Serialize, Deserialize, Debug, Clone)]
718#[serde(tag = "type")]
719pub enum GenerationDelta {
720    Content { content: String },
721    Thinking { thinking: String },
722    ToolUse { tool_use: GenerationDeltaToolUse },
723    Usage { usage: LLMTokenUsage },
724    Metadata { metadata: serde_json::Value },
725}
726
727#[derive(Serialize, Deserialize, Debug, Clone)]
728pub struct GenerationDeltaToolUse {
729    pub id: Option<String>,
730    pub name: Option<String>,
731    pub input: Option<String>,
732    pub index: usize,
733}
734
735#[cfg(test)]
736mod tests {
737    use super::*;
738
739    #[test]
740    fn test_llm_model_from_known_anthropic_model() {
741        let model = LLMModel::from("claude-opus-4-5-20251101".to_string());
742        assert!(matches!(
743            model,
744            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
745        ));
746    }
747
748    #[test]
749    fn test_llm_model_from_known_openai_model() {
750        let model = LLMModel::from("gpt-5".to_string());
751        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
752    }
753
754    #[test]
755    fn test_llm_model_from_known_gemini_model() {
756        let model = LLMModel::from("gemini-2.5-flash".to_string());
757        assert!(matches!(
758            model,
759            LLMModel::Gemini(GeminiModel::Gemini25Flash)
760        ));
761    }
762
763    #[test]
764    fn test_llm_model_from_custom_provider_with_slash() {
765        let model = LLMModel::from("litellm/claude-opus-4-5".to_string());
766        match model {
767            LLMModel::Custom { provider, model } => {
768                assert_eq!(provider, "litellm");
769                assert_eq!(model, "claude-opus-4-5");
770            }
771            _ => panic!("Expected Custom model"),
772        }
773    }
774
775    #[test]
776    fn test_llm_model_from_ollama_provider() {
777        let model = LLMModel::from("ollama/llama3".to_string());
778        match model {
779            LLMModel::Custom { provider, model } => {
780                assert_eq!(provider, "ollama");
781                assert_eq!(model, "llama3");
782            }
783            _ => panic!("Expected Custom model"),
784        }
785    }
786
787    #[test]
788    fn test_llm_model_explicit_anthropic_prefix() {
789        // Explicit anthropic/ prefix should still parse to Anthropic variant
790        let model = LLMModel::from("anthropic/claude-opus-4-5".to_string());
791        assert!(matches!(
792            model,
793            LLMModel::Anthropic(AnthropicModel::Claude45Opus)
794        ));
795    }
796
797    #[test]
798    fn test_llm_model_explicit_openai_prefix() {
799        let model = LLMModel::from("openai/gpt-5".to_string());
800        assert!(matches!(model, LLMModel::OpenAI(OpenAIModel::GPT5)));
801    }
802
803    #[test]
804    fn test_llm_model_explicit_google_prefix() {
805        let model = LLMModel::from("google/gemini-2.5-flash".to_string());
806        assert!(matches!(
807            model,
808            LLMModel::Gemini(GeminiModel::Gemini25Flash)
809        ));
810    }
811
812    #[test]
813    fn test_llm_model_explicit_gemini_prefix() {
814        // gemini/ alias should also work
815        let model = LLMModel::from("gemini/gemini-2.5-flash".to_string());
816        assert!(matches!(
817            model,
818            LLMModel::Gemini(GeminiModel::Gemini25Flash)
819        ));
820    }
821
822    #[test]
823    fn test_llm_model_unknown_model_becomes_custom() {
824        let model = LLMModel::from("some-random-model".to_string());
825        match model {
826            LLMModel::Custom { provider, model } => {
827                assert_eq!(provider, "custom");
828                assert_eq!(model, "some-random-model");
829            }
830            _ => panic!("Expected Custom model"),
831        }
832    }
833
834    #[test]
835    fn test_llm_model_display_anthropic() {
836        let model = LLMModel::Anthropic(AnthropicModel::Claude45Sonnet);
837        let s = model.to_string();
838        assert!(s.contains("claude"));
839    }
840
841    #[test]
842    fn test_llm_model_display_custom() {
843        let model = LLMModel::Custom {
844            provider: "litellm".to_string(),
845            model: "claude-opus".to_string(),
846        };
847        assert_eq!(model.to_string(), "litellm/claude-opus");
848    }
849
850    #[test]
851    fn test_llm_model_provider_name() {
852        assert_eq!(
853            LLMModel::Anthropic(AnthropicModel::Claude45Sonnet).provider_name(),
854            "anthropic"
855        );
856        assert_eq!(
857            LLMModel::OpenAI(OpenAIModel::GPT5).provider_name(),
858            "openai"
859        );
860        assert_eq!(
861            LLMModel::Gemini(GeminiModel::Gemini25Flash).provider_name(),
862            "google"
863        );
864        assert_eq!(
865            LLMModel::Custom {
866                provider: "litellm".to_string(),
867                model: "test".to_string()
868            }
869            .provider_name(),
870            "litellm"
871        );
872    }
873
874    #[test]
875    fn test_llm_model_model_id() {
876        let model = LLMModel::Custom {
877            provider: "litellm".to_string(),
878            model: "claude-opus-4-5".to_string(),
879        };
880        assert_eq!(model.model_id(), "claude-opus-4-5");
881    }
882
883    // =========================================================================
884    // ProviderConfig Tests
885    // =========================================================================
886
887    #[test]
888    fn test_provider_config_openai_serialization() {
889        let config = ProviderConfig::OpenAI {
890            api_key: Some("sk-test".to_string()),
891            api_endpoint: None,
892        };
893        let json = serde_json::to_string(&config).unwrap();
894        assert!(json.contains("\"type\":\"openai\""));
895        assert!(json.contains("\"api_key\":\"sk-test\""));
896        assert!(!json.contains("api_endpoint")); // Should be skipped when None
897    }
898
899    #[test]
900    fn test_provider_config_openai_with_endpoint() {
901        let config = ProviderConfig::OpenAI {
902            api_key: Some("sk-test".to_string()),
903            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
904        };
905        let json = serde_json::to_string(&config).unwrap();
906        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
907    }
908
909    #[test]
910    fn test_provider_config_anthropic_serialization() {
911        let config = ProviderConfig::Anthropic {
912            api_key: Some("sk-ant-test".to_string()),
913            api_endpoint: None,
914            access_token: Some("oauth-token".to_string()),
915        };
916        let json = serde_json::to_string(&config).unwrap();
917        assert!(json.contains("\"type\":\"anthropic\""));
918        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
919        assert!(json.contains("\"access_token\":\"oauth-token\""));
920    }
921
922    #[test]
923    fn test_provider_config_gemini_serialization() {
924        let config = ProviderConfig::Gemini {
925            api_key: Some("gemini-key".to_string()),
926            api_endpoint: None,
927        };
928        let json = serde_json::to_string(&config).unwrap();
929        assert!(json.contains("\"type\":\"gemini\""));
930        assert!(json.contains("\"api_key\":\"gemini-key\""));
931    }
932
933    #[test]
934    fn test_provider_config_custom_serialization() {
935        let config = ProviderConfig::Custom {
936            api_key: Some("sk-custom".to_string()),
937            api_endpoint: "http://localhost:4000".to_string(),
938        };
939        let json = serde_json::to_string(&config).unwrap();
940        assert!(json.contains("\"type\":\"custom\""));
941        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
942        assert!(json.contains("\"api_key\":\"sk-custom\""));
943    }
944
945    #[test]
946    fn test_provider_config_custom_without_key() {
947        let config = ProviderConfig::Custom {
948            api_key: None,
949            api_endpoint: "http://localhost:11434/v1".to_string(),
950        };
951        let json = serde_json::to_string(&config).unwrap();
952        assert!(json.contains("\"type\":\"custom\""));
953        assert!(json.contains("\"api_endpoint\""));
954        assert!(!json.contains("api_key")); // Should be skipped when None
955    }
956
957    #[test]
958    fn test_provider_config_deserialization_openai() {
959        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
960        let config: ProviderConfig = serde_json::from_str(json).unwrap();
961        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
962        assert_eq!(config.api_key(), Some("sk-test"));
963    }
964
965    #[test]
966    fn test_provider_config_deserialization_anthropic() {
967        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
968        let config: ProviderConfig = serde_json::from_str(json).unwrap();
969        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
970        assert_eq!(config.api_key(), Some("sk-ant"));
971        assert_eq!(config.access_token(), Some("oauth"));
972    }
973
974    #[test]
975    fn test_provider_config_deserialization_gemini() {
976        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
977        let config: ProviderConfig = serde_json::from_str(json).unwrap();
978        assert!(matches!(config, ProviderConfig::Gemini { .. }));
979        assert_eq!(config.api_key(), Some("gemini-key"));
980    }
981
982    #[test]
983    fn test_provider_config_deserialization_custom() {
984        let json =
985            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
986        let config: ProviderConfig = serde_json::from_str(json).unwrap();
987        assert!(matches!(config, ProviderConfig::Custom { .. }));
988        assert_eq!(config.api_key(), Some("sk-custom"));
989        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
990    }
991
992    #[test]
993    fn test_provider_config_helper_methods() {
994        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
995        assert_eq!(openai.provider_type(), "openai");
996        assert_eq!(openai.api_key(), Some("sk-openai"));
997
998        let anthropic =
999            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
1000        assert_eq!(anthropic.provider_type(), "anthropic");
1001        assert_eq!(anthropic.access_token(), Some("oauth"));
1002
1003        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
1004        assert_eq!(gemini.provider_type(), "gemini");
1005
1006        let custom = ProviderConfig::custom(
1007            "http://localhost:4000".to_string(),
1008            Some("sk-custom".to_string()),
1009        );
1010        assert_eq!(custom.provider_type(), "custom");
1011        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
1012    }
1013
1014    #[test]
1015    fn test_llm_provider_config_new() {
1016        let config = LLMProviderConfig::new();
1017        assert!(config.is_empty());
1018    }
1019
1020    #[test]
1021    fn test_llm_provider_config_add_and_get() {
1022        let mut config = LLMProviderConfig::new();
1023        config.add_provider(
1024            "openai",
1025            ProviderConfig::openai(Some("sk-test".to_string())),
1026        );
1027        config.add_provider(
1028            "anthropic",
1029            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
1030        );
1031
1032        assert!(!config.is_empty());
1033        assert!(config.get_provider("openai").is_some());
1034        assert!(config.get_provider("anthropic").is_some());
1035        assert!(config.get_provider("unknown").is_none());
1036    }
1037
1038    #[test]
1039    fn test_provider_config_toml_parsing() {
1040        // Test parsing a HashMap of providers from TOML-like JSON
1041        let json = r#"{
1042            "openai": {"type": "openai", "api_key": "sk-openai"},
1043            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
1044            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
1045        }"#;
1046
1047        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
1048        assert_eq!(providers.len(), 3);
1049
1050        assert!(matches!(
1051            providers.get("openai"),
1052            Some(ProviderConfig::OpenAI { .. })
1053        ));
1054        assert!(matches!(
1055            providers.get("anthropic"),
1056            Some(ProviderConfig::Anthropic { .. })
1057        ));
1058        assert!(matches!(
1059            providers.get("litellm"),
1060            Some(ProviderConfig::Custom { .. })
1061        ));
1062    }
1063}