Skip to main content

stakpak_shared/models/
llm.rs

1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
//! - `amazon-bedrock` - AWS Bedrock (uses AWS credential chain, no API key)
16//!
17//! For built-in providers, you can use the model name directly without a prefix:
18//! - `claude-sonnet-4-5` → auto-detected as Anthropic
19//! - `gpt-4` → auto-detected as OpenAI
20//! - `gemini-2.5-pro` → auto-detected as Gemini
21//!
22//! ## Custom Providers
23//!
24//! Any OpenAI-compatible API can be configured using `type = "custom"`.
25//! The provider key becomes the model prefix.
26//!
27//! # Model Routing
28//!
29//! Models can be specified with or without a provider prefix:
30//!
31//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
32//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
33//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
34//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
35//!   sends `anthropic/claude-opus` to the API
36//!
37//! # Example Configuration
38//!
39//! ```toml
40//! [profiles.default]
41//! provider = "local"
42//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
43//! eco_model = "offline/llama3"       # custom provider
44//!
45//! [profiles.default.providers.anthropic]
46//! type = "anthropic"
47//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
48//!
49//! [profiles.default.providers.offline]
50//! type = "custom"
51//! api_endpoint = "http://localhost:11434/v1"
52//! ```
53
54use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58// =============================================================================
59// Provider Configuration
60// =============================================================================
61
62/// Unified provider configuration enum
63///
64/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
65/// where the key is the provider name and becomes the model prefix for routing.
66///
67/// # Provider Key = Model Prefix
68///
69/// The key used in the HashMap becomes the prefix used in model names:
70/// - Config key: `providers.offline`
71/// - Model usage: `offline/llama3`
72/// - Routing: finds `offline` provider, sends `llama3` to API
73///
74/// # Example TOML
75/// ```toml
76/// [profiles.myprofile.providers.openai]
77/// type = "openai"
78/// api_key = "sk-..."
79///
80/// [profiles.myprofile.providers.anthropic]
81/// type = "anthropic"
82/// api_key = "sk-ant-..."
83/// access_token = "oauth-token"
84///
85/// [profiles.myprofile.providers.offline]
86/// type = "custom"
87/// api_endpoint = "http://localhost:11434/v1"
88/// ```
89#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
90#[serde(tag = "type", rename_all = "lowercase")]
91pub enum ProviderConfig {
92    /// OpenAI provider configuration
93    OpenAI {
94        #[serde(skip_serializing_if = "Option::is_none")]
95        api_key: Option<String>,
96        #[serde(skip_serializing_if = "Option::is_none")]
97        api_endpoint: Option<String>,
98    },
99    /// Anthropic provider configuration
100    Anthropic {
101        #[serde(skip_serializing_if = "Option::is_none")]
102        api_key: Option<String>,
103        #[serde(skip_serializing_if = "Option::is_none")]
104        api_endpoint: Option<String>,
105        /// OAuth access token (for Claude subscription)
106        #[serde(skip_serializing_if = "Option::is_none")]
107        access_token: Option<String>,
108    },
109    /// Google Gemini provider configuration
110    Gemini {
111        #[serde(skip_serializing_if = "Option::is_none")]
112        api_key: Option<String>,
113        #[serde(skip_serializing_if = "Option::is_none")]
114        api_endpoint: Option<String>,
115    },
116    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
117    ///
118    /// The provider key in the config becomes the model prefix.
119    /// For example, if configured as `providers.offline`, use models as:
120    /// - `offline/llama3` - passes `llama3` to the API
121    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
122    ///
123    /// # Example TOML
124    /// ```toml
125    /// [profiles.myprofile.providers.offline]
126    /// type = "custom"
127    /// api_endpoint = "http://localhost:11434/v1"
128    ///
129    /// # Then use models as:
130    /// smart_model = "offline/llama3"
131    /// eco_model = "offline/phi3"
132    /// ```
133    Custom {
134        #[serde(skip_serializing_if = "Option::is_none")]
135        api_key: Option<String>,
136        /// API endpoint URL (required for custom providers)
137        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
138        api_endpoint: String,
139    },
140    /// Stakpak provider configuration
141    ///
142    /// Routes inference through Stakpak's unified API, which provides:
143    /// - Access to multiple LLM providers via a single endpoint
144    /// - Usage tracking and billing
145    /// - Session management and checkpoints
146    ///
147    /// # Example TOML
148    /// ```toml
149    /// [profiles.myprofile.providers.stakpak]
150    /// type = "stakpak"
151    /// api_key = "your-stakpak-api-key"
152    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
153    ///
154    /// # Then use models as:
155    /// smart_model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
156    /// ```
157    Stakpak {
158        /// Stakpak API key (required)
159        api_key: String,
160        /// API endpoint URL (default: https://apiv2.stakpak.dev)
161        #[serde(skip_serializing_if = "Option::is_none")]
162        api_endpoint: Option<String>,
163    },
164    /// AWS Bedrock provider configuration
165    ///
166    /// Uses AWS credential chain for authentication (no API key needed).
167    /// Supports env vars, shared credentials, SSO, and instance roles.
168    ///
169    /// # Example TOML
170    /// ```toml
171    /// [profiles.myprofile.providers.amazon-bedrock]
172    /// type = "amazon-bedrock"
173    /// region = "us-east-1"
174    /// profile_name = "my-aws-profile"  # optional
175    ///
176    /// # Then use models as (friendly aliases work):
177    /// model = "amazon-bedrock/claude-sonnet-4-5"
178    /// ```
179    #[serde(rename = "amazon-bedrock")]
180    Bedrock {
181        /// AWS region (e.g., "us-east-1")
182        region: String,
183        /// Optional AWS named profile (from ~/.aws/config)
184        #[serde(skip_serializing_if = "Option::is_none")]
185        profile_name: Option<String>,
186    },
187}
188
189impl ProviderConfig {
190    /// Get the provider type name
191    pub fn provider_type(&self) -> &'static str {
192        match self {
193            ProviderConfig::OpenAI { .. } => "openai",
194            ProviderConfig::Anthropic { .. } => "anthropic",
195            ProviderConfig::Gemini { .. } => "gemini",
196            ProviderConfig::Custom { .. } => "custom",
197            ProviderConfig::Stakpak { .. } => "stakpak",
198            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
199        }
200    }
201
202    /// Get the API key if set
203    pub fn api_key(&self) -> Option<&str> {
204        match self {
205            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
206            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
207            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
208            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
209            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
210            ProviderConfig::Bedrock { .. } => None, // AWS credential chain, no API key
211        }
212    }
213
214    /// Get the API endpoint if set
215    pub fn api_endpoint(&self) -> Option<&str> {
216        match self {
217            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
218            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
219            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
220            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
221            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
222            ProviderConfig::Bedrock { .. } => None, // No custom endpoint in config
223        }
224    }
225
226    /// Set the API endpoint for providers that support it.
227    ///
228    /// For `Custom`, `None` is ignored because custom providers require an endpoint.
229    /// For `Bedrock`, this is a no-op.
230    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
231        match self {
232            ProviderConfig::OpenAI { api_endpoint, .. }
233            | ProviderConfig::Anthropic { api_endpoint, .. }
234            | ProviderConfig::Gemini { api_endpoint, .. }
235            | ProviderConfig::Stakpak { api_endpoint, .. } => {
236                *api_endpoint = endpoint;
237            }
238            ProviderConfig::Custom { api_endpoint, .. } => {
239                if let Some(custom_endpoint) = endpoint {
240                    *api_endpoint = custom_endpoint;
241                }
242            }
243            ProviderConfig::Bedrock { .. } => {}
244        }
245    }
246
247    /// Get the access token (Anthropic only)
248    pub fn access_token(&self) -> Option<&str> {
249        match self {
250            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
251            _ => None,
252        }
253    }
254
255    /// Create an OpenAI provider config
256    pub fn openai(api_key: Option<String>) -> Self {
257        ProviderConfig::OpenAI {
258            api_key,
259            api_endpoint: None,
260        }
261    }
262
263    /// Create an Anthropic provider config
264    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
265        ProviderConfig::Anthropic {
266            api_key,
267            api_endpoint: None,
268            access_token,
269        }
270    }
271
272    /// Create a Gemini provider config
273    pub fn gemini(api_key: Option<String>) -> Self {
274        ProviderConfig::Gemini {
275            api_key,
276            api_endpoint: None,
277        }
278    }
279
280    /// Create a custom provider config
281    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
282        ProviderConfig::Custom {
283            api_key,
284            api_endpoint,
285        }
286    }
287
288    /// Create a Stakpak provider config
289    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
290        ProviderConfig::Stakpak {
291            api_key,
292            api_endpoint,
293        }
294    }
295
296    /// Create a Bedrock provider config
297    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
298        ProviderConfig::Bedrock {
299            region,
300            profile_name,
301        }
302    }
303
304    /// Get the AWS region (Bedrock only)
305    pub fn region(&self) -> Option<&str> {
306        match self {
307            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
308            _ => None,
309        }
310    }
311
312    /// Get the AWS profile name (Bedrock only)
313    pub fn profile_name(&self) -> Option<&str> {
314        match self {
315            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
316            _ => None,
317        }
318    }
319}
320
321/// Aggregated provider configuration for LLM operations
322///
323/// This struct holds all configured providers, keyed by provider name.
324#[derive(Debug, Clone, Default)]
325pub struct LLMProviderConfig {
326    /// All provider configurations (key = provider name)
327    pub providers: HashMap<String, ProviderConfig>,
328}
329
330impl LLMProviderConfig {
331    /// Create a new empty provider config
332    pub fn new() -> Self {
333        Self {
334            providers: HashMap::new(),
335        }
336    }
337
338    /// Add a provider configuration
339    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
340        self.providers.insert(name.into(), config);
341    }
342
343    /// Get a provider configuration by name
344    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
345        self.providers.get(name)
346    }
347
348    /// Check if any providers are configured
349    pub fn is_empty(&self) -> bool {
350        self.providers.is_empty()
351    }
352}
353
354/// Provider-specific options for LLM requests
355#[derive(Clone, Debug, Serialize, Deserialize, Default)]
356pub struct LLMProviderOptions {
357    /// Anthropic-specific options
358    #[serde(skip_serializing_if = "Option::is_none")]
359    pub anthropic: Option<LLMAnthropicOptions>,
360
361    /// OpenAI-specific options
362    #[serde(skip_serializing_if = "Option::is_none")]
363    pub openai: Option<LLMOpenAIOptions>,
364
365    /// Google/Gemini-specific options
366    #[serde(skip_serializing_if = "Option::is_none")]
367    pub google: Option<LLMGoogleOptions>,
368}
369
370/// Anthropic-specific options
371#[derive(Clone, Debug, Serialize, Deserialize, Default)]
372pub struct LLMAnthropicOptions {
373    /// Extended thinking configuration
374    #[serde(skip_serializing_if = "Option::is_none")]
375    pub thinking: Option<LLMThinkingOptions>,
376}
377
378/// Thinking/reasoning options
379#[derive(Clone, Debug, Serialize, Deserialize)]
380pub struct LLMThinkingOptions {
381    /// Budget tokens for thinking (must be >= 1024)
382    pub budget_tokens: u32,
383}
384
385impl LLMThinkingOptions {
386    pub fn new(budget_tokens: u32) -> Self {
387        Self {
388            budget_tokens: budget_tokens.max(1024),
389        }
390    }
391}
392
393/// OpenAI-specific options
394#[derive(Clone, Debug, Serialize, Deserialize, Default)]
395pub struct LLMOpenAIOptions {
396    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
397    #[serde(skip_serializing_if = "Option::is_none")]
398    pub reasoning_effort: Option<String>,
399}
400
401/// Google/Gemini-specific options
402#[derive(Clone, Debug, Serialize, Deserialize, Default)]
403pub struct LLMGoogleOptions {
404    /// Thinking budget in tokens
405    #[serde(skip_serializing_if = "Option::is_none")]
406    pub thinking_budget: Option<u32>,
407}
408
409#[derive(Clone, Debug, Serialize)]
410pub struct LLMInput {
411    pub model: Model,
412    pub messages: Vec<LLMMessage>,
413    pub max_tokens: u32,
414    pub tools: Option<Vec<LLMTool>>,
415    #[serde(skip_serializing_if = "Option::is_none")]
416    pub provider_options: Option<LLMProviderOptions>,
417    /// Custom headers to pass to the inference provider
418    #[serde(skip_serializing_if = "Option::is_none")]
419    pub headers: Option<std::collections::HashMap<String, String>>,
420}
421
422#[derive(Debug)]
423pub struct LLMStreamInput {
424    pub model: Model,
425    pub messages: Vec<LLMMessage>,
426    pub max_tokens: u32,
427    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
428    pub tools: Option<Vec<LLMTool>>,
429    pub provider_options: Option<LLMProviderOptions>,
430    /// Custom headers to pass to the inference provider
431    pub headers: Option<std::collections::HashMap<String, String>>,
432}
433
434impl From<&LLMStreamInput> for LLMInput {
435    fn from(value: &LLMStreamInput) -> Self {
436        LLMInput {
437            model: value.model.clone(),
438            messages: value.messages.clone(),
439            max_tokens: value.max_tokens,
440            tools: value.tools.clone(),
441            provider_options: value.provider_options.clone(),
442            headers: value.headers.clone(),
443        }
444    }
445}
446
447#[derive(Serialize, Deserialize, Debug, Clone, Default)]
448pub struct LLMMessage {
449    pub role: String,
450    pub content: LLMMessageContent,
451}
452
453#[derive(Serialize, Deserialize, Debug, Clone)]
454pub struct SimpleLLMMessage {
455    #[serde(rename = "role")]
456    pub role: SimpleLLMRole,
457    pub content: String,
458}
459
460#[derive(Serialize, Deserialize, Debug, Clone)]
461#[serde(rename_all = "lowercase")]
462pub enum SimpleLLMRole {
463    User,
464    Assistant,
465}
466
467impl std::fmt::Display for SimpleLLMRole {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        match self {
470            SimpleLLMRole::User => write!(f, "user"),
471            SimpleLLMRole::Assistant => write!(f, "assistant"),
472        }
473    }
474}
475
476#[derive(Serialize, Deserialize, Debug, Clone)]
477#[serde(untagged)]
478pub enum LLMMessageContent {
479    String(String),
480    List(Vec<LLMMessageTypedContent>),
481}
482
483#[allow(clippy::to_string_trait_impl)]
484impl ToString for LLMMessageContent {
485    fn to_string(&self) -> String {
486        match self {
487            LLMMessageContent::String(s) => s.clone(),
488            LLMMessageContent::List(l) => l
489                .iter()
490                .map(|c| match c {
491                    LLMMessageTypedContent::Text { text } => text.clone(),
492                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
493                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
494                    LLMMessageTypedContent::Image { .. } => String::new(),
495                })
496                .collect::<Vec<_>>()
497                .join("\n"),
498        }
499    }
500}
501
502impl From<String> for LLMMessageContent {
503    fn from(value: String) -> Self {
504        LLMMessageContent::String(value)
505    }
506}
507
508impl Default for LLMMessageContent {
509    fn default() -> Self {
510        LLMMessageContent::String(String::new())
511    }
512}
513
514impl LLMMessageContent {
515    /// Convert into a Vec of typed content parts.
516    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
517    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
518        match self {
519            LLMMessageContent::List(parts) => parts,
520            LLMMessageContent::String(s) if s.is_empty() => vec![],
521            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
522        }
523    }
524}
525
526#[derive(Serialize, Deserialize, Debug, Clone)]
527#[serde(tag = "type")]
528pub enum LLMMessageTypedContent {
529    #[serde(rename = "text")]
530    Text { text: String },
531    #[serde(rename = "tool_use")]
532    ToolCall {
533        id: String,
534        name: String,
535        #[serde(alias = "input")]
536        args: serde_json::Value,
537        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
538        #[serde(skip_serializing_if = "Option::is_none")]
539        metadata: Option<serde_json::Value>,
540    },
541    #[serde(rename = "tool_result")]
542    ToolResult {
543        tool_use_id: String,
544        content: String,
545    },
546    #[serde(rename = "image")]
547    Image { source: LLMMessageImageSource },
548}
549
550#[derive(Serialize, Deserialize, Debug, Clone)]
551pub struct LLMMessageImageSource {
552    #[serde(rename = "type")]
553    pub r#type: String,
554    pub media_type: String,
555    pub data: String,
556}
557
558impl Default for LLMMessageTypedContent {
559    fn default() -> Self {
560        LLMMessageTypedContent::Text {
561            text: String::new(),
562        }
563    }
564}
565
566#[derive(Serialize, Deserialize, Debug, Clone)]
567pub struct LLMChoice {
568    pub finish_reason: Option<String>,
569    pub index: u32,
570    pub message: LLMMessage,
571}
572
573#[derive(Serialize, Deserialize, Debug, Clone)]
574pub struct LLMCompletionResponse {
575    pub model: String,
576    pub object: String,
577    pub choices: Vec<LLMChoice>,
578    pub created: u64,
579    pub usage: Option<LLMTokenUsage>,
580    pub id: String,
581}
582
583#[derive(Serialize, Deserialize, Debug, Clone)]
584pub struct LLMStreamDelta {
585    #[serde(skip_serializing_if = "Option::is_none")]
586    pub content: Option<String>,
587}
588
589#[derive(Serialize, Deserialize, Debug, Clone)]
590pub struct LLMStreamChoice {
591    pub finish_reason: Option<String>,
592    pub index: u32,
593    pub message: Option<LLMMessage>,
594    pub delta: LLMStreamDelta,
595}
596
597#[derive(Serialize, Deserialize, Debug, Clone)]
598pub struct LLMCompletionStreamResponse {
599    pub model: String,
600    pub object: String,
601    pub choices: Vec<LLMStreamChoice>,
602    pub created: u64,
603    #[serde(skip_serializing_if = "Option::is_none")]
604    pub usage: Option<LLMTokenUsage>,
605    pub id: String,
606    pub citations: Option<Vec<String>>,
607}
608
609#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
610pub struct LLMTool {
611    pub name: String,
612    pub description: String,
613    pub input_schema: serde_json::Value,
614}
615
616#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
617pub struct LLMTokenUsage {
618    pub prompt_tokens: u32,
619    pub completion_tokens: u32,
620    pub total_tokens: u32,
621
622    #[serde(skip_serializing_if = "Option::is_none")]
623    pub prompt_tokens_details: Option<PromptTokensDetails>,
624}
625
626#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
627#[serde(rename_all = "snake_case")]
628pub enum TokenType {
629    InputTokens,
630    OutputTokens,
631    CacheReadInputTokens,
632    CacheWriteInputTokens,
633}
634
635#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
636pub struct PromptTokensDetails {
637    #[serde(skip_serializing_if = "Option::is_none")]
638    pub input_tokens: Option<u32>,
639    #[serde(skip_serializing_if = "Option::is_none")]
640    pub output_tokens: Option<u32>,
641    #[serde(skip_serializing_if = "Option::is_none")]
642    pub cache_read_input_tokens: Option<u32>,
643    #[serde(skip_serializing_if = "Option::is_none")]
644    pub cache_write_input_tokens: Option<u32>,
645}
646
647impl PromptTokensDetails {
648    /// Returns an iterator over the token types and their values
649    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
650        [
651            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
652            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
653            (
654                TokenType::CacheReadInputTokens,
655                self.cache_read_input_tokens.unwrap_or(0),
656            ),
657            (
658                TokenType::CacheWriteInputTokens,
659                self.cache_write_input_tokens.unwrap_or(0),
660            ),
661        ]
662        .into_iter()
663    }
664}
665
666impl std::ops::Add for PromptTokensDetails {
667    type Output = Self;
668
669    fn add(self, rhs: Self) -> Self::Output {
670        Self {
671            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
672            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
673            cache_read_input_tokens: Some(
674                self.cache_read_input_tokens.unwrap_or(0)
675                    + rhs.cache_read_input_tokens.unwrap_or(0),
676            ),
677            cache_write_input_tokens: Some(
678                self.cache_write_input_tokens.unwrap_or(0)
679                    + rhs.cache_write_input_tokens.unwrap_or(0),
680            ),
681        }
682    }
683}
684
685impl std::ops::AddAssign for PromptTokensDetails {
686    fn add_assign(&mut self, rhs: Self) {
687        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
688        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
689        self.cache_read_input_tokens = Some(
690            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
691        );
692        self.cache_write_input_tokens = Some(
693            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
694        );
695    }
696}
697
698#[derive(Serialize, Deserialize, Debug, Clone)]
699#[serde(tag = "type")]
700pub enum GenerationDelta {
701    Content { content: String },
702    Thinking { thinking: String },
703    ToolUse { tool_use: GenerationDeltaToolUse },
704    Usage { usage: LLMTokenUsage },
705    Metadata { metadata: serde_json::Value },
706}
707
708#[derive(Serialize, Deserialize, Debug, Clone)]
709pub struct GenerationDeltaToolUse {
710    pub id: Option<String>,
711    pub name: Option<String>,
712    pub input: Option<String>,
713    pub index: usize,
714    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
715    #[serde(skip_serializing_if = "Option::is_none")]
716    pub metadata: Option<serde_json::Value>,
717}
718
719#[cfg(test)]
720mod tests {
721    use super::*;
722
723    // =========================================================================
724    // ProviderConfig Tests
725    // =========================================================================
726
727    #[test]
728    fn test_provider_config_openai_serialization() {
729        let config = ProviderConfig::OpenAI {
730            api_key: Some("sk-test".to_string()),
731            api_endpoint: None,
732        };
733        let json = serde_json::to_string(&config).unwrap();
734        assert!(json.contains("\"type\":\"openai\""));
735        assert!(json.contains("\"api_key\":\"sk-test\""));
736        assert!(!json.contains("api_endpoint")); // Should be skipped when None
737    }
738
739    #[test]
740    fn test_provider_config_openai_with_endpoint() {
741        let config = ProviderConfig::OpenAI {
742            api_key: Some("sk-test".to_string()),
743            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
744        };
745        let json = serde_json::to_string(&config).unwrap();
746        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
747    }
748
749    #[test]
750    fn test_provider_config_anthropic_serialization() {
751        let config = ProviderConfig::Anthropic {
752            api_key: Some("sk-ant-test".to_string()),
753            api_endpoint: None,
754            access_token: Some("oauth-token".to_string()),
755        };
756        let json = serde_json::to_string(&config).unwrap();
757        assert!(json.contains("\"type\":\"anthropic\""));
758        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
759        assert!(json.contains("\"access_token\":\"oauth-token\""));
760    }
761
762    #[test]
763    fn test_provider_config_gemini_serialization() {
764        let config = ProviderConfig::Gemini {
765            api_key: Some("gemini-key".to_string()),
766            api_endpoint: None,
767        };
768        let json = serde_json::to_string(&config).unwrap();
769        assert!(json.contains("\"type\":\"gemini\""));
770        assert!(json.contains("\"api_key\":\"gemini-key\""));
771    }
772
773    #[test]
774    fn test_provider_config_custom_serialization() {
775        let config = ProviderConfig::Custom {
776            api_key: Some("sk-custom".to_string()),
777            api_endpoint: "http://localhost:4000".to_string(),
778        };
779        let json = serde_json::to_string(&config).unwrap();
780        assert!(json.contains("\"type\":\"custom\""));
781        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
782        assert!(json.contains("\"api_key\":\"sk-custom\""));
783    }
784
785    #[test]
786    fn test_provider_config_custom_without_key() {
787        let config = ProviderConfig::Custom {
788            api_key: None,
789            api_endpoint: "http://localhost:11434/v1".to_string(),
790        };
791        let json = serde_json::to_string(&config).unwrap();
792        assert!(json.contains("\"type\":\"custom\""));
793        assert!(json.contains("\"api_endpoint\""));
794        assert!(!json.contains("api_key")); // Should be skipped when None
795    }
796
797    #[test]
798    fn test_provider_config_deserialization_openai() {
799        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
800        let config: ProviderConfig = serde_json::from_str(json).unwrap();
801        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
802        assert_eq!(config.api_key(), Some("sk-test"));
803    }
804
805    #[test]
806    fn test_provider_config_deserialization_anthropic() {
807        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
808        let config: ProviderConfig = serde_json::from_str(json).unwrap();
809        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
810        assert_eq!(config.api_key(), Some("sk-ant"));
811        assert_eq!(config.access_token(), Some("oauth"));
812    }
813
814    #[test]
815    fn test_provider_config_deserialization_gemini() {
816        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
817        let config: ProviderConfig = serde_json::from_str(json).unwrap();
818        assert!(matches!(config, ProviderConfig::Gemini { .. }));
819        assert_eq!(config.api_key(), Some("gemini-key"));
820    }
821
822    #[test]
823    fn test_provider_config_deserialization_custom() {
824        let json =
825            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
826        let config: ProviderConfig = serde_json::from_str(json).unwrap();
827        assert!(matches!(config, ProviderConfig::Custom { .. }));
828        assert_eq!(config.api_key(), Some("sk-custom"));
829        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
830    }
831
832    #[test]
833    fn test_provider_config_helper_methods() {
834        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
835        assert_eq!(openai.provider_type(), "openai");
836        assert_eq!(openai.api_key(), Some("sk-openai"));
837
838        let anthropic =
839            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
840        assert_eq!(anthropic.provider_type(), "anthropic");
841        assert_eq!(anthropic.access_token(), Some("oauth"));
842
843        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
844        assert_eq!(gemini.provider_type(), "gemini");
845
846        let custom = ProviderConfig::custom(
847            "http://localhost:4000".to_string(),
848            Some("sk-custom".to_string()),
849        );
850        assert_eq!(custom.provider_type(), "custom");
851        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
852    }
853
854    #[test]
855    fn test_set_api_endpoint_updates_supported_providers() {
856        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
857        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
858        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));
859
860        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
861        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
862        assert_eq!(bedrock.api_endpoint(), None);
863    }
864
865    #[test]
866    fn test_llm_provider_config_new() {
867        let config = LLMProviderConfig::new();
868        assert!(config.is_empty());
869    }
870
871    #[test]
872    fn test_llm_provider_config_add_and_get() {
873        let mut config = LLMProviderConfig::new();
874        config.add_provider(
875            "openai",
876            ProviderConfig::openai(Some("sk-test".to_string())),
877        );
878        config.add_provider(
879            "anthropic",
880            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
881        );
882
883        assert!(!config.is_empty());
884        assert!(config.get_provider("openai").is_some());
885        assert!(config.get_provider("anthropic").is_some());
886        assert!(config.get_provider("unknown").is_none());
887    }
888
889    #[test]
890    fn test_provider_config_toml_parsing() {
891        // Test parsing a HashMap of providers from TOML-like JSON
892        let json = r#"{
893            "openai": {"type": "openai", "api_key": "sk-openai"},
894            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
895            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
896        }"#;
897
898        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
899        assert_eq!(providers.len(), 3);
900
901        assert!(matches!(
902            providers.get("openai"),
903            Some(ProviderConfig::OpenAI { .. })
904        ));
905        assert!(matches!(
906            providers.get("anthropic"),
907            Some(ProviderConfig::Anthropic { .. })
908        ));
909        assert!(matches!(
910            providers.get("litellm"),
911            Some(ProviderConfig::Custom { .. })
912        ));
913    }
914
915    // =========================================================================
916    // Bedrock ProviderConfig Tests
917    // =========================================================================
918
919    #[test]
920    fn test_provider_config_bedrock_serialization() {
921        let config = ProviderConfig::Bedrock {
922            region: "us-east-1".to_string(),
923            profile_name: Some("my-profile".to_string()),
924        };
925        let json = serde_json::to_string(&config).unwrap();
926        assert!(json.contains("\"type\":\"amazon-bedrock\""));
927        assert!(json.contains("\"region\":\"us-east-1\""));
928        assert!(json.contains("\"profile_name\":\"my-profile\""));
929    }
930
931    #[test]
932    fn test_provider_config_bedrock_serialization_without_profile() {
933        let config = ProviderConfig::Bedrock {
934            region: "us-west-2".to_string(),
935            profile_name: None,
936        };
937        let json = serde_json::to_string(&config).unwrap();
938        assert!(json.contains("\"type\":\"amazon-bedrock\""));
939        assert!(json.contains("\"region\":\"us-west-2\""));
940        assert!(!json.contains("profile_name")); // Should be skipped when None
941    }
942
943    #[test]
944    fn test_provider_config_bedrock_deserialization() {
945        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
946        let config: ProviderConfig = serde_json::from_str(json).unwrap();
947        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
948        assert_eq!(config.region(), Some("us-east-1"));
949        assert_eq!(config.profile_name(), Some("prod"));
950    }
951
952    #[test]
953    fn test_provider_config_bedrock_deserialization_minimal() {
954        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
955        let config: ProviderConfig = serde_json::from_str(json).unwrap();
956        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
957        assert_eq!(config.region(), Some("eu-west-1"));
958        assert_eq!(config.profile_name(), None);
959    }
960
961    #[test]
962    fn test_provider_config_bedrock_no_api_key() {
963        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
964        assert_eq!(config.api_key(), None); // Bedrock uses AWS credential chain
965        assert_eq!(config.api_endpoint(), None); // No custom endpoint
966    }
967
968    #[test]
969    fn test_provider_config_bedrock_helper_methods() {
970        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
971        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
972        assert_eq!(bedrock.region(), Some("us-east-1"));
973        assert_eq!(bedrock.profile_name(), Some("prod"));
974        assert_eq!(bedrock.api_key(), None);
975        assert_eq!(bedrock.api_endpoint(), None);
976        assert_eq!(bedrock.access_token(), None);
977    }
978
979    #[test]
980    fn test_provider_config_bedrock_toml_roundtrip() {
981        let config = ProviderConfig::Bedrock {
982            region: "us-east-1".to_string(),
983            profile_name: Some("my-profile".to_string()),
984        };
985        let toml_str = toml::to_string(&config).unwrap();
986        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
987        assert_eq!(config, parsed);
988    }
989
990    #[test]
991    fn test_provider_config_bedrock_toml_parsing() {
992        let toml_str = r#"
993            type = "amazon-bedrock"
994            region = "us-east-1"
995            profile_name = "production"
996        "#;
997        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
998        assert!(matches!(
999            config,
1000            ProviderConfig::Bedrock {
1001                ref region,
1002                ref profile_name,
1003            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
1004        ));
1005    }
1006
1007    #[test]
1008    fn test_provider_config_bedrock_missing_region_fails() {
1009        let json = r#"{"type":"amazon-bedrock"}"#;
1010        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
1011        assert!(result.is_err()); // region is required
1012    }
1013
1014    #[test]
1015    fn test_provider_config_bedrock_in_providers_map() {
1016        let json = r#"{
1017            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
1018            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
1019        }"#;
1020        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
1021        assert_eq!(providers.len(), 2);
1022        assert!(matches!(
1023            providers.get("amazon-bedrock"),
1024            Some(ProviderConfig::Bedrock { .. })
1025        ));
1026    }
1027
1028    #[test]
1029    fn test_region_returns_none_for_non_bedrock() {
1030        let openai = ProviderConfig::openai(Some("key".to_string()));
1031        assert_eq!(openai.region(), None);
1032
1033        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
1034        assert_eq!(anthropic.region(), None);
1035    }
1036
1037    #[test]
1038    fn test_profile_name_returns_none_for_non_bedrock() {
1039        let openai = ProviderConfig::openai(Some("key".to_string()));
1040        assert_eq!(openai.profile_name(), None);
1041    }
1042}