Skip to main content

stakpak_shared/models/
llm.rs

1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
//! - `amazon-bedrock` - AWS Bedrock (uses AWS credential chain, no API key)
16//!
17//! For built-in providers, you can use the model name directly without a prefix:
18//! - `claude-sonnet-4-5` → auto-detected as Anthropic
19//! - `gpt-4` → auto-detected as OpenAI
20//! - `gemini-2.5-pro` → auto-detected as Gemini
21//!
22//! ## Custom Providers
23//!
24//! Any OpenAI-compatible API can be configured using `type = "custom"`.
25//! The provider key becomes the model prefix.
26//!
27//! # Model Routing
28//!
29//! Models can be specified with or without a provider prefix:
30//!
31//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
32//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
33//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
34//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
35//!   sends `anthropic/claude-opus` to the API
36//!
37//! # Example Configuration
38//!
39//! ```toml
40//! [profiles.default]
41//! provider = "local"
42//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
43//! eco_model = "offline/llama3"       # custom provider
44//!
45//! [profiles.default.providers.anthropic]
46//! type = "anthropic"
47//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
48//!
49//! [profiles.default.providers.offline]
50//! type = "custom"
51//! api_endpoint = "http://localhost:11434/v1"
52//! ```
53
54use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58// =============================================================================
59// Provider Configuration
60// =============================================================================
61
62/// Unified provider configuration enum
63///
64/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
65/// where the key is the provider name and becomes the model prefix for routing.
66///
67/// # Provider Key = Model Prefix
68///
69/// The key used in the HashMap becomes the prefix used in model names:
70/// - Config key: `providers.offline`
71/// - Model usage: `offline/llama3`
72/// - Routing: finds `offline` provider, sends `llama3` to API
73///
74/// # Example TOML
75/// ```toml
76/// [profiles.myprofile.providers.openai]
77/// type = "openai"
78/// api_key = "sk-..."
79///
80/// [profiles.myprofile.providers.anthropic]
81/// type = "anthropic"
82/// api_key = "sk-ant-..."
83/// access_token = "oauth-token"
84///
85/// [profiles.myprofile.providers.offline]
86/// type = "custom"
87/// api_endpoint = "http://localhost:11434/v1"
88/// ```
89#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
90#[serde(tag = "type", rename_all = "lowercase")]
91pub enum ProviderConfig {
92    /// OpenAI provider configuration
93    OpenAI {
94        #[serde(skip_serializing_if = "Option::is_none")]
95        api_key: Option<String>,
96        #[serde(skip_serializing_if = "Option::is_none")]
97        api_endpoint: Option<String>,
98    },
99    /// Anthropic provider configuration
100    Anthropic {
101        #[serde(skip_serializing_if = "Option::is_none")]
102        api_key: Option<String>,
103        #[serde(skip_serializing_if = "Option::is_none")]
104        api_endpoint: Option<String>,
105        /// OAuth access token (for Claude subscription)
106        #[serde(skip_serializing_if = "Option::is_none")]
107        access_token: Option<String>,
108    },
109    /// Google Gemini provider configuration
110    Gemini {
111        #[serde(skip_serializing_if = "Option::is_none")]
112        api_key: Option<String>,
113        #[serde(skip_serializing_if = "Option::is_none")]
114        api_endpoint: Option<String>,
115    },
116    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
117    ///
118    /// The provider key in the config becomes the model prefix.
119    /// For example, if configured as `providers.offline`, use models as:
120    /// - `offline/llama3` - passes `llama3` to the API
121    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
122    ///
123    /// # Example TOML
124    /// ```toml
125    /// [profiles.myprofile.providers.offline]
126    /// type = "custom"
127    /// api_endpoint = "http://localhost:11434/v1"
128    ///
129    /// # Then use models as:
130    /// smart_model = "offline/llama3"
131    /// eco_model = "offline/phi3"
132    /// ```
133    Custom {
134        #[serde(skip_serializing_if = "Option::is_none")]
135        api_key: Option<String>,
136        /// API endpoint URL (required for custom providers)
137        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
138        api_endpoint: String,
139    },
140    /// Stakpak provider configuration
141    ///
142    /// Routes inference through Stakpak's unified API, which provides:
143    /// - Access to multiple LLM providers via a single endpoint
144    /// - Usage tracking and billing
145    /// - Session management and checkpoints
146    ///
147    /// # Example TOML
148    /// ```toml
149    /// [profiles.myprofile.providers.stakpak]
150    /// type = "stakpak"
151    /// api_key = "your-stakpak-api-key"
152    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
153    ///
154    /// # Then use models as:
155    /// smart_model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
156    /// ```
157    Stakpak {
158        /// Stakpak API key (required)
159        api_key: String,
160        /// API endpoint URL (default: https://apiv2.stakpak.dev)
161        #[serde(skip_serializing_if = "Option::is_none")]
162        api_endpoint: Option<String>,
163    },
164    /// AWS Bedrock provider configuration
165    ///
166    /// Uses AWS credential chain for authentication (no API key needed).
167    /// Supports env vars, shared credentials, SSO, and instance roles.
168    ///
169    /// # Example TOML
170    /// ```toml
171    /// [profiles.myprofile.providers.amazon-bedrock]
172    /// type = "amazon-bedrock"
173    /// region = "us-east-1"
174    /// profile_name = "my-aws-profile"  # optional
175    ///
176    /// # Then use models as (friendly aliases work):
177    /// model = "amazon-bedrock/claude-sonnet-4-5"
178    /// ```
179    #[serde(rename = "amazon-bedrock")]
180    Bedrock {
181        /// AWS region (e.g., "us-east-1")
182        region: String,
183        /// Optional AWS named profile (from ~/.aws/config)
184        #[serde(skip_serializing_if = "Option::is_none")]
185        profile_name: Option<String>,
186    },
187}
188
189impl ProviderConfig {
190    /// Get the provider type name
191    pub fn provider_type(&self) -> &'static str {
192        match self {
193            ProviderConfig::OpenAI { .. } => "openai",
194            ProviderConfig::Anthropic { .. } => "anthropic",
195            ProviderConfig::Gemini { .. } => "gemini",
196            ProviderConfig::Custom { .. } => "custom",
197            ProviderConfig::Stakpak { .. } => "stakpak",
198            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
199        }
200    }
201
202    /// Get the API key if set
203    pub fn api_key(&self) -> Option<&str> {
204        match self {
205            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
206            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
207            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
208            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
209            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
210            ProviderConfig::Bedrock { .. } => None, // AWS credential chain, no API key
211        }
212    }
213
214    /// Get the API endpoint if set
215    pub fn api_endpoint(&self) -> Option<&str> {
216        match self {
217            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
218            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
219            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
220            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
221            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
222            ProviderConfig::Bedrock { .. } => None, // No custom endpoint in config
223        }
224    }
225
226    /// Get the access token (Anthropic only)
227    pub fn access_token(&self) -> Option<&str> {
228        match self {
229            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
230            _ => None,
231        }
232    }
233
234    /// Create an OpenAI provider config
235    pub fn openai(api_key: Option<String>) -> Self {
236        ProviderConfig::OpenAI {
237            api_key,
238            api_endpoint: None,
239        }
240    }
241
242    /// Create an Anthropic provider config
243    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
244        ProviderConfig::Anthropic {
245            api_key,
246            api_endpoint: None,
247            access_token,
248        }
249    }
250
251    /// Create a Gemini provider config
252    pub fn gemini(api_key: Option<String>) -> Self {
253        ProviderConfig::Gemini {
254            api_key,
255            api_endpoint: None,
256        }
257    }
258
259    /// Create a custom provider config
260    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
261        ProviderConfig::Custom {
262            api_key,
263            api_endpoint,
264        }
265    }
266
267    /// Create a Stakpak provider config
268    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
269        ProviderConfig::Stakpak {
270            api_key,
271            api_endpoint,
272        }
273    }
274
275    /// Create a Bedrock provider config
276    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
277        ProviderConfig::Bedrock {
278            region,
279            profile_name,
280        }
281    }
282
283    /// Get the AWS region (Bedrock only)
284    pub fn region(&self) -> Option<&str> {
285        match self {
286            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
287            _ => None,
288        }
289    }
290
291    /// Get the AWS profile name (Bedrock only)
292    pub fn profile_name(&self) -> Option<&str> {
293        match self {
294            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
295            _ => None,
296        }
297    }
298}
299
300/// Aggregated provider configuration for LLM operations
301///
302/// This struct holds all configured providers, keyed by provider name.
303#[derive(Debug, Clone, Default)]
304pub struct LLMProviderConfig {
305    /// All provider configurations (key = provider name)
306    pub providers: HashMap<String, ProviderConfig>,
307}
308
309impl LLMProviderConfig {
310    /// Create a new empty provider config
311    pub fn new() -> Self {
312        Self {
313            providers: HashMap::new(),
314        }
315    }
316
317    /// Add a provider configuration
318    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
319        self.providers.insert(name.into(), config);
320    }
321
322    /// Get a provider configuration by name
323    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
324        self.providers.get(name)
325    }
326
327    /// Check if any providers are configured
328    pub fn is_empty(&self) -> bool {
329        self.providers.is_empty()
330    }
331}
332
333/// Provider-specific options for LLM requests
334#[derive(Clone, Debug, Serialize, Deserialize, Default)]
335pub struct LLMProviderOptions {
336    /// Anthropic-specific options
337    #[serde(skip_serializing_if = "Option::is_none")]
338    pub anthropic: Option<LLMAnthropicOptions>,
339
340    /// OpenAI-specific options
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub openai: Option<LLMOpenAIOptions>,
343
344    /// Google/Gemini-specific options
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub google: Option<LLMGoogleOptions>,
347}
348
349/// Anthropic-specific options
350#[derive(Clone, Debug, Serialize, Deserialize, Default)]
351pub struct LLMAnthropicOptions {
352    /// Extended thinking configuration
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub thinking: Option<LLMThinkingOptions>,
355}
356
357/// Thinking/reasoning options
358#[derive(Clone, Debug, Serialize, Deserialize)]
359pub struct LLMThinkingOptions {
360    /// Budget tokens for thinking (must be >= 1024)
361    pub budget_tokens: u32,
362}
363
364impl LLMThinkingOptions {
365    pub fn new(budget_tokens: u32) -> Self {
366        Self {
367            budget_tokens: budget_tokens.max(1024),
368        }
369    }
370}
371
372/// OpenAI-specific options
373#[derive(Clone, Debug, Serialize, Deserialize, Default)]
374pub struct LLMOpenAIOptions {
375    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
376    #[serde(skip_serializing_if = "Option::is_none")]
377    pub reasoning_effort: Option<String>,
378}
379
380/// Google/Gemini-specific options
381#[derive(Clone, Debug, Serialize, Deserialize, Default)]
382pub struct LLMGoogleOptions {
383    /// Thinking budget in tokens
384    #[serde(skip_serializing_if = "Option::is_none")]
385    pub thinking_budget: Option<u32>,
386}
387
388#[derive(Clone, Debug, Serialize)]
389pub struct LLMInput {
390    pub model: Model,
391    pub messages: Vec<LLMMessage>,
392    pub max_tokens: u32,
393    pub tools: Option<Vec<LLMTool>>,
394    #[serde(skip_serializing_if = "Option::is_none")]
395    pub provider_options: Option<LLMProviderOptions>,
396    /// Custom headers to pass to the inference provider
397    #[serde(skip_serializing_if = "Option::is_none")]
398    pub headers: Option<std::collections::HashMap<String, String>>,
399}
400
401#[derive(Debug)]
402pub struct LLMStreamInput {
403    pub model: Model,
404    pub messages: Vec<LLMMessage>,
405    pub max_tokens: u32,
406    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
407    pub tools: Option<Vec<LLMTool>>,
408    pub provider_options: Option<LLMProviderOptions>,
409    /// Custom headers to pass to the inference provider
410    pub headers: Option<std::collections::HashMap<String, String>>,
411}
412
413impl From<&LLMStreamInput> for LLMInput {
414    fn from(value: &LLMStreamInput) -> Self {
415        LLMInput {
416            model: value.model.clone(),
417            messages: value.messages.clone(),
418            max_tokens: value.max_tokens,
419            tools: value.tools.clone(),
420            provider_options: value.provider_options.clone(),
421            headers: value.headers.clone(),
422        }
423    }
424}
425
426#[derive(Serialize, Deserialize, Debug, Clone, Default)]
427pub struct LLMMessage {
428    pub role: String,
429    pub content: LLMMessageContent,
430}
431
432#[derive(Serialize, Deserialize, Debug, Clone)]
433pub struct SimpleLLMMessage {
434    #[serde(rename = "role")]
435    pub role: SimpleLLMRole,
436    pub content: String,
437}
438
439#[derive(Serialize, Deserialize, Debug, Clone)]
440#[serde(rename_all = "lowercase")]
441pub enum SimpleLLMRole {
442    User,
443    Assistant,
444}
445
446impl std::fmt::Display for SimpleLLMRole {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        match self {
449            SimpleLLMRole::User => write!(f, "user"),
450            SimpleLLMRole::Assistant => write!(f, "assistant"),
451        }
452    }
453}
454
455#[derive(Serialize, Deserialize, Debug, Clone)]
456#[serde(untagged)]
457pub enum LLMMessageContent {
458    String(String),
459    List(Vec<LLMMessageTypedContent>),
460}
461
462#[allow(clippy::to_string_trait_impl)]
463impl ToString for LLMMessageContent {
464    fn to_string(&self) -> String {
465        match self {
466            LLMMessageContent::String(s) => s.clone(),
467            LLMMessageContent::List(l) => l
468                .iter()
469                .map(|c| match c {
470                    LLMMessageTypedContent::Text { text } => text.clone(),
471                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
472                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
473                    LLMMessageTypedContent::Image { .. } => String::new(),
474                })
475                .collect::<Vec<_>>()
476                .join("\n"),
477        }
478    }
479}
480
481impl From<String> for LLMMessageContent {
482    fn from(value: String) -> Self {
483        LLMMessageContent::String(value)
484    }
485}
486
487impl Default for LLMMessageContent {
488    fn default() -> Self {
489        LLMMessageContent::String(String::new())
490    }
491}
492
493impl LLMMessageContent {
494    /// Convert into a Vec of typed content parts.
495    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
496    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
497        match self {
498            LLMMessageContent::List(parts) => parts,
499            LLMMessageContent::String(s) if s.is_empty() => vec![],
500            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
501        }
502    }
503}
504
505#[derive(Serialize, Deserialize, Debug, Clone)]
506#[serde(tag = "type")]
507pub enum LLMMessageTypedContent {
508    #[serde(rename = "text")]
509    Text { text: String },
510    #[serde(rename = "tool_use")]
511    ToolCall {
512        id: String,
513        name: String,
514        #[serde(alias = "input")]
515        args: serde_json::Value,
516        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
517        #[serde(skip_serializing_if = "Option::is_none")]
518        metadata: Option<serde_json::Value>,
519    },
520    #[serde(rename = "tool_result")]
521    ToolResult {
522        tool_use_id: String,
523        content: String,
524    },
525    #[serde(rename = "image")]
526    Image { source: LLMMessageImageSource },
527}
528
529#[derive(Serialize, Deserialize, Debug, Clone)]
530pub struct LLMMessageImageSource {
531    #[serde(rename = "type")]
532    pub r#type: String,
533    pub media_type: String,
534    pub data: String,
535}
536
537impl Default for LLMMessageTypedContent {
538    fn default() -> Self {
539        LLMMessageTypedContent::Text {
540            text: String::new(),
541        }
542    }
543}
544
545#[derive(Serialize, Deserialize, Debug, Clone)]
546pub struct LLMChoice {
547    pub finish_reason: Option<String>,
548    pub index: u32,
549    pub message: LLMMessage,
550}
551
552#[derive(Serialize, Deserialize, Debug, Clone)]
553pub struct LLMCompletionResponse {
554    pub model: String,
555    pub object: String,
556    pub choices: Vec<LLMChoice>,
557    pub created: u64,
558    pub usage: Option<LLMTokenUsage>,
559    pub id: String,
560}
561
562#[derive(Serialize, Deserialize, Debug, Clone)]
563pub struct LLMStreamDelta {
564    #[serde(skip_serializing_if = "Option::is_none")]
565    pub content: Option<String>,
566}
567
568#[derive(Serialize, Deserialize, Debug, Clone)]
569pub struct LLMStreamChoice {
570    pub finish_reason: Option<String>,
571    pub index: u32,
572    pub message: Option<LLMMessage>,
573    pub delta: LLMStreamDelta,
574}
575
576#[derive(Serialize, Deserialize, Debug, Clone)]
577pub struct LLMCompletionStreamResponse {
578    pub model: String,
579    pub object: String,
580    pub choices: Vec<LLMStreamChoice>,
581    pub created: u64,
582    #[serde(skip_serializing_if = "Option::is_none")]
583    pub usage: Option<LLMTokenUsage>,
584    pub id: String,
585    pub citations: Option<Vec<String>>,
586}
587
588#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
589pub struct LLMTool {
590    pub name: String,
591    pub description: String,
592    pub input_schema: serde_json::Value,
593}
594
595#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
596pub struct LLMTokenUsage {
597    pub prompt_tokens: u32,
598    pub completion_tokens: u32,
599    pub total_tokens: u32,
600
601    #[serde(skip_serializing_if = "Option::is_none")]
602    pub prompt_tokens_details: Option<PromptTokensDetails>,
603}
604
605#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
606#[serde(rename_all = "snake_case")]
607pub enum TokenType {
608    InputTokens,
609    OutputTokens,
610    CacheReadInputTokens,
611    CacheWriteInputTokens,
612}
613
614#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
615pub struct PromptTokensDetails {
616    #[serde(skip_serializing_if = "Option::is_none")]
617    pub input_tokens: Option<u32>,
618    #[serde(skip_serializing_if = "Option::is_none")]
619    pub output_tokens: Option<u32>,
620    #[serde(skip_serializing_if = "Option::is_none")]
621    pub cache_read_input_tokens: Option<u32>,
622    #[serde(skip_serializing_if = "Option::is_none")]
623    pub cache_write_input_tokens: Option<u32>,
624}
625
626impl PromptTokensDetails {
627    /// Returns an iterator over the token types and their values
628    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
629        [
630            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
631            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
632            (
633                TokenType::CacheReadInputTokens,
634                self.cache_read_input_tokens.unwrap_or(0),
635            ),
636            (
637                TokenType::CacheWriteInputTokens,
638                self.cache_write_input_tokens.unwrap_or(0),
639            ),
640        ]
641        .into_iter()
642    }
643}
644
645impl std::ops::Add for PromptTokensDetails {
646    type Output = Self;
647
648    fn add(self, rhs: Self) -> Self::Output {
649        Self {
650            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
651            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
652            cache_read_input_tokens: Some(
653                self.cache_read_input_tokens.unwrap_or(0)
654                    + rhs.cache_read_input_tokens.unwrap_or(0),
655            ),
656            cache_write_input_tokens: Some(
657                self.cache_write_input_tokens.unwrap_or(0)
658                    + rhs.cache_write_input_tokens.unwrap_or(0),
659            ),
660        }
661    }
662}
663
664impl std::ops::AddAssign for PromptTokensDetails {
665    fn add_assign(&mut self, rhs: Self) {
666        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
667        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
668        self.cache_read_input_tokens = Some(
669            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
670        );
671        self.cache_write_input_tokens = Some(
672            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
673        );
674    }
675}
676
677#[derive(Serialize, Deserialize, Debug, Clone)]
678#[serde(tag = "type")]
679pub enum GenerationDelta {
680    Content { content: String },
681    Thinking { thinking: String },
682    ToolUse { tool_use: GenerationDeltaToolUse },
683    Usage { usage: LLMTokenUsage },
684    Metadata { metadata: serde_json::Value },
685}
686
687#[derive(Serialize, Deserialize, Debug, Clone)]
688pub struct GenerationDeltaToolUse {
689    pub id: Option<String>,
690    pub name: Option<String>,
691    pub input: Option<String>,
692    pub index: usize,
693    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
694    #[serde(skip_serializing_if = "Option::is_none")]
695    pub metadata: Option<serde_json::Value>,
696}
697
698#[cfg(test)]
699mod tests {
700    use super::*;
701
702    // =========================================================================
703    // ProviderConfig Tests
704    // =========================================================================
705
706    #[test]
707    fn test_provider_config_openai_serialization() {
708        let config = ProviderConfig::OpenAI {
709            api_key: Some("sk-test".to_string()),
710            api_endpoint: None,
711        };
712        let json = serde_json::to_string(&config).unwrap();
713        assert!(json.contains("\"type\":\"openai\""));
714        assert!(json.contains("\"api_key\":\"sk-test\""));
715        assert!(!json.contains("api_endpoint")); // Should be skipped when None
716    }
717
718    #[test]
719    fn test_provider_config_openai_with_endpoint() {
720        let config = ProviderConfig::OpenAI {
721            api_key: Some("sk-test".to_string()),
722            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
723        };
724        let json = serde_json::to_string(&config).unwrap();
725        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
726    }
727
728    #[test]
729    fn test_provider_config_anthropic_serialization() {
730        let config = ProviderConfig::Anthropic {
731            api_key: Some("sk-ant-test".to_string()),
732            api_endpoint: None,
733            access_token: Some("oauth-token".to_string()),
734        };
735        let json = serde_json::to_string(&config).unwrap();
736        assert!(json.contains("\"type\":\"anthropic\""));
737        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
738        assert!(json.contains("\"access_token\":\"oauth-token\""));
739    }
740
741    #[test]
742    fn test_provider_config_gemini_serialization() {
743        let config = ProviderConfig::Gemini {
744            api_key: Some("gemini-key".to_string()),
745            api_endpoint: None,
746        };
747        let json = serde_json::to_string(&config).unwrap();
748        assert!(json.contains("\"type\":\"gemini\""));
749        assert!(json.contains("\"api_key\":\"gemini-key\""));
750    }
751
752    #[test]
753    fn test_provider_config_custom_serialization() {
754        let config = ProviderConfig::Custom {
755            api_key: Some("sk-custom".to_string()),
756            api_endpoint: "http://localhost:4000".to_string(),
757        };
758        let json = serde_json::to_string(&config).unwrap();
759        assert!(json.contains("\"type\":\"custom\""));
760        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
761        assert!(json.contains("\"api_key\":\"sk-custom\""));
762    }
763
764    #[test]
765    fn test_provider_config_custom_without_key() {
766        let config = ProviderConfig::Custom {
767            api_key: None,
768            api_endpoint: "http://localhost:11434/v1".to_string(),
769        };
770        let json = serde_json::to_string(&config).unwrap();
771        assert!(json.contains("\"type\":\"custom\""));
772        assert!(json.contains("\"api_endpoint\""));
773        assert!(!json.contains("api_key")); // Should be skipped when None
774    }
775
776    #[test]
777    fn test_provider_config_deserialization_openai() {
778        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
779        let config: ProviderConfig = serde_json::from_str(json).unwrap();
780        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
781        assert_eq!(config.api_key(), Some("sk-test"));
782    }
783
784    #[test]
785    fn test_provider_config_deserialization_anthropic() {
786        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
787        let config: ProviderConfig = serde_json::from_str(json).unwrap();
788        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
789        assert_eq!(config.api_key(), Some("sk-ant"));
790        assert_eq!(config.access_token(), Some("oauth"));
791    }
792
793    #[test]
794    fn test_provider_config_deserialization_gemini() {
795        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
796        let config: ProviderConfig = serde_json::from_str(json).unwrap();
797        assert!(matches!(config, ProviderConfig::Gemini { .. }));
798        assert_eq!(config.api_key(), Some("gemini-key"));
799    }
800
801    #[test]
802    fn test_provider_config_deserialization_custom() {
803        let json =
804            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
805        let config: ProviderConfig = serde_json::from_str(json).unwrap();
806        assert!(matches!(config, ProviderConfig::Custom { .. }));
807        assert_eq!(config.api_key(), Some("sk-custom"));
808        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
809    }
810
811    #[test]
812    fn test_provider_config_helper_methods() {
813        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
814        assert_eq!(openai.provider_type(), "openai");
815        assert_eq!(openai.api_key(), Some("sk-openai"));
816
817        let anthropic =
818            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
819        assert_eq!(anthropic.provider_type(), "anthropic");
820        assert_eq!(anthropic.access_token(), Some("oauth"));
821
822        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
823        assert_eq!(gemini.provider_type(), "gemini");
824
825        let custom = ProviderConfig::custom(
826            "http://localhost:4000".to_string(),
827            Some("sk-custom".to_string()),
828        );
829        assert_eq!(custom.provider_type(), "custom");
830        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
831    }
832
833    #[test]
834    fn test_llm_provider_config_new() {
835        let config = LLMProviderConfig::new();
836        assert!(config.is_empty());
837    }
838
839    #[test]
840    fn test_llm_provider_config_add_and_get() {
841        let mut config = LLMProviderConfig::new();
842        config.add_provider(
843            "openai",
844            ProviderConfig::openai(Some("sk-test".to_string())),
845        );
846        config.add_provider(
847            "anthropic",
848            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
849        );
850
851        assert!(!config.is_empty());
852        assert!(config.get_provider("openai").is_some());
853        assert!(config.get_provider("anthropic").is_some());
854        assert!(config.get_provider("unknown").is_none());
855    }
856
857    #[test]
858    fn test_provider_config_toml_parsing() {
859        // Test parsing a HashMap of providers from TOML-like JSON
860        let json = r#"{
861            "openai": {"type": "openai", "api_key": "sk-openai"},
862            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
863            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
864        }"#;
865
866        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
867        assert_eq!(providers.len(), 3);
868
869        assert!(matches!(
870            providers.get("openai"),
871            Some(ProviderConfig::OpenAI { .. })
872        ));
873        assert!(matches!(
874            providers.get("anthropic"),
875            Some(ProviderConfig::Anthropic { .. })
876        ));
877        assert!(matches!(
878            providers.get("litellm"),
879            Some(ProviderConfig::Custom { .. })
880        ));
881    }
882
883    // =========================================================================
884    // Bedrock ProviderConfig Tests
885    // =========================================================================
886
887    #[test]
888    fn test_provider_config_bedrock_serialization() {
889        let config = ProviderConfig::Bedrock {
890            region: "us-east-1".to_string(),
891            profile_name: Some("my-profile".to_string()),
892        };
893        let json = serde_json::to_string(&config).unwrap();
894        assert!(json.contains("\"type\":\"amazon-bedrock\""));
895        assert!(json.contains("\"region\":\"us-east-1\""));
896        assert!(json.contains("\"profile_name\":\"my-profile\""));
897    }
898
899    #[test]
900    fn test_provider_config_bedrock_serialization_without_profile() {
901        let config = ProviderConfig::Bedrock {
902            region: "us-west-2".to_string(),
903            profile_name: None,
904        };
905        let json = serde_json::to_string(&config).unwrap();
906        assert!(json.contains("\"type\":\"amazon-bedrock\""));
907        assert!(json.contains("\"region\":\"us-west-2\""));
908        assert!(!json.contains("profile_name")); // Should be skipped when None
909    }
910
911    #[test]
912    fn test_provider_config_bedrock_deserialization() {
913        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
914        let config: ProviderConfig = serde_json::from_str(json).unwrap();
915        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
916        assert_eq!(config.region(), Some("us-east-1"));
917        assert_eq!(config.profile_name(), Some("prod"));
918    }
919
920    #[test]
921    fn test_provider_config_bedrock_deserialization_minimal() {
922        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
923        let config: ProviderConfig = serde_json::from_str(json).unwrap();
924        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
925        assert_eq!(config.region(), Some("eu-west-1"));
926        assert_eq!(config.profile_name(), None);
927    }
928
929    #[test]
930    fn test_provider_config_bedrock_no_api_key() {
931        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
932        assert_eq!(config.api_key(), None); // Bedrock uses AWS credential chain
933        assert_eq!(config.api_endpoint(), None); // No custom endpoint
934    }
935
936    #[test]
937    fn test_provider_config_bedrock_helper_methods() {
938        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
939        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
940        assert_eq!(bedrock.region(), Some("us-east-1"));
941        assert_eq!(bedrock.profile_name(), Some("prod"));
942        assert_eq!(bedrock.api_key(), None);
943        assert_eq!(bedrock.api_endpoint(), None);
944        assert_eq!(bedrock.access_token(), None);
945    }
946
947    #[test]
948    fn test_provider_config_bedrock_toml_roundtrip() {
949        let config = ProviderConfig::Bedrock {
950            region: "us-east-1".to_string(),
951            profile_name: Some("my-profile".to_string()),
952        };
953        let toml_str = toml::to_string(&config).unwrap();
954        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
955        assert_eq!(config, parsed);
956    }
957
958    #[test]
959    fn test_provider_config_bedrock_toml_parsing() {
960        let toml_str = r#"
961            type = "amazon-bedrock"
962            region = "us-east-1"
963            profile_name = "production"
964        "#;
965        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
966        assert!(matches!(
967            config,
968            ProviderConfig::Bedrock {
969                ref region,
970                ref profile_name,
971            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
972        ));
973    }
974
975    #[test]
976    fn test_provider_config_bedrock_missing_region_fails() {
977        let json = r#"{"type":"amazon-bedrock"}"#;
978        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
979        assert!(result.is_err()); // region is required
980    }
981
982    #[test]
983    fn test_provider_config_bedrock_in_providers_map() {
984        let json = r#"{
985            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
986            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
987        }"#;
988        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
989        assert_eq!(providers.len(), 2);
990        assert!(matches!(
991            providers.get("amazon-bedrock"),
992            Some(ProviderConfig::Bedrock { .. })
993        ));
994    }
995
996    #[test]
997    fn test_region_returns_none_for_non_bedrock() {
998        let openai = ProviderConfig::openai(Some("key".to_string()));
999        assert_eq!(openai.region(), None);
1000
1001        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
1002        assert_eq!(anthropic.region(), None);
1003    }
1004
1005    #[test]
1006    fn test_profile_name_returns_none_for_non_bedrock() {
1007        let openai = ProviderConfig::openai(Some("key".to_string()));
1008        assert_eq!(openai.profile_name(), None);
1009    }
1010}