
//! LLM Provider and Model Configuration
//!
//! This module provides the configuration types for LLM providers and models.
//!
//! # Provider Configuration
//!
//! Providers are configured in a `providers` HashMap where the key becomes the
//! model prefix for routing requests to the correct provider.
//!
//! ## Built-in Providers
//!
//! - `openai` - OpenAI API
//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
//! - `gemini` - Google Gemini API
//!
//! For built-in providers, you can use the model name directly without a prefix:
//! - `claude-sonnet-4-5` → auto-detected as Anthropic
//! - `gpt-4` → auto-detected as OpenAI
//! - `gemini-2.5-pro` → auto-detected as Gemini
//!
//! ## Custom Providers
//!
//! Any OpenAI-compatible API can be configured using `type = "custom"`.
//! The provider key becomes the model prefix.
//!
//! # Model Routing
//!
//! Models can be specified with or without a provider prefix:
//!
//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to the API
//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
//!   sends `anthropic/claude-opus` to the API
//!
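//! Only the first `/` separates the provider prefix from the model name. A
//! hypothetical sketch of the prefix rule (the real routing lives outside
//! this module and may differ for unconfigured prefixes):
//!
//! ```ignore
//! fn route<'a>(
//!     providers: &std::collections::HashMap<String, ProviderConfig>,
//!     model: &'a str,
//! ) -> (Option<&'a str>, &'a str) {
//!     match model.split_once('/') {
//!         // "offline/llama3" -> provider "offline", model name "llama3"
//!         Some((prefix, rest)) if providers.contains_key(prefix) => (Some(prefix), rest),
//!         // No configured prefix: auto-detect the provider from the model name.
//!         _ => (None, model),
//!     }
//! }
//! ```
//!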
//! # Example Configuration
//!
//! ```toml
//! [profiles.default]
//! provider = "local"
//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
//! eco_model = "offline/llama3"       # custom provider
//!
//! [profiles.default.providers.anthropic]
//! type = "anthropic"
//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
//!
//! [profiles.default.providers.offline]
//! type = "custom"
//! api_endpoint = "http://localhost:11434/v1"
//! ```

use serde::{Deserialize, Serialize};
use stakai::Model;
use std::collections::HashMap;

// =============================================================================
// Provider Configuration
// =============================================================================

/// Unified provider configuration enum
///
/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
/// where the key is the provider name and becomes the model prefix for routing.
///
/// # Provider Key = Model Prefix
///
/// The key used in the HashMap becomes the prefix used in model names:
/// - Config key: `providers.offline`
/// - Model usage: `offline/llama3`
/// - Routing: finds the `offline` provider, sends `llama3` to the API
///
/// # Example TOML
/// ```toml
/// [profiles.myprofile.providers.openai]
/// type = "openai"
/// api_key = "sk-..."
///
/// [profiles.myprofile.providers.anthropic]
/// type = "anthropic"
/// api_key = "sk-ant-..."
/// access_token = "oauth-token"
///
/// [profiles.myprofile.providers.offline]
/// type = "custom"
/// api_endpoint = "http://localhost:11434/v1"
/// ```
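///
/// The enum is internally tagged by `type`, so the JSON wire form mirrors the
/// TOML above. A minimal round-trip sketch (the endpoint value is illustrative):
///
/// ```ignore
/// let cfg: ProviderConfig = serde_json::from_str(
///     r#"{"type":"custom","api_endpoint":"http://localhost:11434/v1"}"#,
/// )
/// .unwrap();
/// assert_eq!(cfg.provider_type(), "custom");
/// ```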
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI provider configuration
    OpenAI {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Anthropic provider configuration
    Anthropic {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// OAuth access token (for Claude subscription)
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
    },
    /// Google Gemini provider configuration
    Gemini {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
    ///
    /// The provider key in the config becomes the model prefix.
    /// For example, if configured as `providers.offline`, use models as:
    /// - `offline/llama3` - passes `llama3` to the API
    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.offline]
    /// type = "custom"
    /// api_endpoint = "http://localhost:11434/v1"
    ///
    /// # Then use models as:
    /// smart_model = "offline/llama3"
    /// eco_model = "offline/phi3"
    /// ```
    Custom {
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (required for custom providers)
        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
        api_endpoint: String,
    },
    /// Stakpak provider configuration
    ///
    /// Routes inference through Stakpak's unified API, which provides:
    /// - Access to multiple LLM providers via a single endpoint
    /// - Usage tracking and billing
    /// - Session management and checkpoints
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.stakpak]
    /// type = "stakpak"
    /// api_key = "your-stakpak-api-key"
    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
    ///
    /// # Then use models as:
    /// smart_model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
    /// ```
    Stakpak {
        /// Stakpak API key (required)
        api_key: String,
        /// API endpoint URL (default: https://apiv2.stakpak.dev)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
    },
}

impl ProviderConfig {
    /// Get the provider type name
    pub fn provider_type(&self) -> &'static str {
        match self {
            ProviderConfig::OpenAI { .. } => "openai",
            ProviderConfig::Anthropic { .. } => "anthropic",
            ProviderConfig::Gemini { .. } => "gemini",
            ProviderConfig::Custom { .. } => "custom",
            ProviderConfig::Stakpak { .. } => "stakpak",
        }
    }

    /// Get the API key if set
    pub fn api_key(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
            ProviderConfig::Stakpak { api_key, .. } => Some(api_key.as_str()),
        }
    }

    /// Get the API endpoint if set
    pub fn api_endpoint(&self) -> Option<&str> {
        match self {
            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
        }
    }

    /// Get the access token (Anthropic only)
    pub fn access_token(&self) -> Option<&str> {
        match self {
            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
            _ => None,
        }
    }

    /// Create an OpenAI provider config
    pub fn openai(api_key: Option<String>) -> Self {
        ProviderConfig::OpenAI {
            api_key,
            api_endpoint: None,
        }
    }

    /// Create an Anthropic provider config
    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
        ProviderConfig::Anthropic {
            api_key,
            api_endpoint: None,
            access_token,
        }
    }

    /// Create a Gemini provider config
    pub fn gemini(api_key: Option<String>) -> Self {
        ProviderConfig::Gemini {
            api_key,
            api_endpoint: None,
        }
    }

    /// Create a custom provider config
    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
        ProviderConfig::Custom {
            api_key,
            api_endpoint,
        }
    }

    /// Create a Stakpak provider config
    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
        ProviderConfig::Stakpak {
            api_key,
            api_endpoint,
        }
    }
}

/// Aggregated provider configuration for LLM operations
///
/// This struct holds all configured providers, keyed by provider name.
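///
/// A minimal construction sketch (provider keys and credentials are
/// illustrative):
///
/// ```ignore
/// let mut config = LLMProviderConfig::new();
/// config.add_provider("openai", ProviderConfig::openai(Some("sk-...".to_string())));
/// config.add_provider(
///     "offline",
///     ProviderConfig::custom("http://localhost:11434/v1".to_string(), None),
/// );
/// assert!(config.get_provider("offline").is_some());
/// assert!(!config.is_empty());
/// ```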
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// All provider configurations (key = provider name)
    pub providers: HashMap<String, ProviderConfig>,
}

impl LLMProviderConfig {
    /// Create a new empty provider config
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    /// Add a provider configuration
    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
        self.providers.insert(name.into(), config);
    }

    /// Get a provider configuration by name
    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
        self.providers.get(name)
    }

    /// Returns `true` if no providers are configured
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}

/// Provider-specific options for LLM requests
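///
/// A minimal sketch enabling Anthropic extended thinking (the budget value is
/// illustrative; `LLMThinkingOptions::new` clamps it to at least 1024):
///
/// ```ignore
/// let options = LLMProviderOptions {
///     anthropic: Some(LLMAnthropicOptions {
///         thinking: Some(LLMThinkingOptions::new(2048)),
///     }),
///     ..Default::default()
/// };
/// ```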
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}

/// Anthropic-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}

/// Thinking/reasoning options
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024)
    pub budget_tokens: u32,
}

impl LLMThinkingOptions {
    /// Create thinking options, clamping the budget to the 1024-token minimum
    pub fn new(budget_tokens: u32) -> Self {
        Self {
            budget_tokens: budget_tokens.max(1024),
        }
    }
}

/// OpenAI-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}

/// Google/Gemini-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}

#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub tools: Option<Vec<LLMTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}

#[derive(Debug)]
pub struct LLMStreamInput {
    pub model: Model,
    pub messages: Vec<LLMMessage>,
    pub max_tokens: u32,
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    pub tools: Option<Vec<LLMTool>>,
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    pub headers: Option<std::collections::HashMap<String, String>>,
}

impl From<&LLMStreamInput> for LLMInput {
    fn from(value: &LLMStreamInput) -> Self {
        LLMInput {
            model: value.model.clone(),
            messages: value.messages.clone(),
            max_tokens: value.max_tokens,
            tools: value.tools.clone(),
            provider_options: value.provider_options.clone(),
            headers: value.headers.clone(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    pub role: String,
    pub content: LLMMessageContent,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SimpleLLMMessage {
    pub role: SimpleLLMRole,
    pub content: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}

impl std::fmt::Display for SimpleLLMRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SimpleLLMRole::User => write!(f, "user"),
            SimpleLLMRole::Assistant => write!(f, "assistant"),
        }
    }
}

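/// Message content, either a plain string or a list of typed content blocks;
/// the `untagged` representation accepts both wire shapes.
///
/// A minimal sketch of the plain-string form (the role value is illustrative):
///
/// ```ignore
/// let msg = LLMMessage {
///     role: "user".to_string(),
///     content: LLMMessageContent::from("Hello".to_string()),
/// };
/// ```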
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}

#[allow(clippy::to_string_trait_impl)]
impl ToString for LLMMessageContent {
    fn to_string(&self) -> String {
        match self {
            LLMMessageContent::String(s) => s.clone(),
            LLMMessageContent::List(l) => l
                .iter()
                .map(|c| match c {
                    LLMMessageTypedContent::Text { text } => text.clone(),
                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
                    LLMMessageTypedContent::Image { .. } => String::new(),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        }
    }
}

impl From<String> for LLMMessageContent {
    fn from(value: String) -> Self {
        LLMMessageContent::String(value)
    }
}

impl Default for LLMMessageContent {
    fn default() -> Self {
        LLMMessageContent::String(String::new())
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolCall {
        id: String,
        name: String,
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    #[serde(rename = "type")]
    pub r#type: String,
    pub media_type: String,
    pub data: String,
}

impl Default for LLMMessageTypedContent {
    fn default() -> Self {
        LLMMessageTypedContent::Text {
            text: String::new(),
        }
    }
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: LLMMessage,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMChoice>,
    pub created: u64,
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    pub finish_reason: Option<String>,
    pub index: u32,
    pub message: Option<LLMMessage>,
    pub delta: LLMStreamDelta,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionStreamResponse {
    pub model: String,
    pub object: String,
    pub choices: Vec<LLMStreamChoice>,
    pub created: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<LLMTokenUsage>,
    pub id: String,
    pub citations: Option<Vec<String>>,
}

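/// Tool definition advertised to the model; `input_schema` holds a JSON Schema
/// object. A minimal sketch (the tool name and schema are illustrative):
///
/// ```ignore
/// let tool = LLMTool {
///     name: "get_weather".to_string(),
///     description: "Look up the current weather for a city".to_string(),
///     input_schema: serde_json::json!({
///         "type": "object",
///         "properties": { "city": { "type": "string" } },
///         "required": ["city"]
///     }),
/// };
/// ```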
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}

impl PromptTokensDetails {
    /// Returns an iterator over the token types and their values, treating missing values as zero
    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
        [
            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
            (
                TokenType::CacheReadInputTokens,
                self.cache_read_input_tokens.unwrap_or(0),
            ),
            (
                TokenType::CacheWriteInputTokens,
                self.cache_write_input_tokens.unwrap_or(0),
            ),
        ]
        .into_iter()
    }
}

impl std::ops::Add for PromptTokensDetails {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
            cache_read_input_tokens: Some(
                self.cache_read_input_tokens.unwrap_or(0)
                    + rhs.cache_read_input_tokens.unwrap_or(0),
            ),
            cache_write_input_tokens: Some(
                self.cache_write_input_tokens.unwrap_or(0)
                    + rhs.cache_write_input_tokens.unwrap_or(0),
            ),
        }
    }
}

impl std::ops::AddAssign for PromptTokensDetails {
    fn add_assign(&mut self, rhs: Self) {
        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
        self.cache_read_input_tokens = Some(
            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
        );
        self.cache_write_input_tokens = Some(
            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
        );
    }
}

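/// Incremental events sent over the `stream_channel_tx` channel of
/// [`LLMStreamInput`] during streaming generation.
///
/// A minimal consumer sketch (the handling shown is illustrative):
///
/// ```ignore
/// // rx: tokio::sync::mpsc::Receiver<GenerationDelta>
/// while let Some(delta) = rx.recv().await {
///     match delta {
///         GenerationDelta::Content { content } => print!("{content}"),
///         GenerationDelta::Usage { usage } => eprintln!("tokens: {}", usage.total_tokens),
///         _ => {}
///     }
/// }
/// ```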
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    Content { content: String },
    Thinking { thinking: String },
    ToolUse { tool_use: GenerationDeltaToolUse },
    Usage { usage: LLMTokenUsage },
    Metadata { metadata: serde_json::Value },
}

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    pub id: Option<String>,
    pub name: Option<String>,
    pub input: Option<String>,
    pub index: usize,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}

#[cfg(test)]
mod tests {
    use super::*;

    // =========================================================================
    // ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Test parsing a HashMap of providers from TOML-like JSON
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }
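
    // Illustrative check of `PromptTokensDetails` addition semantics: missing
    // fields are treated as zero (the values here are arbitrary).
    #[test]
    fn test_prompt_tokens_details_add() {
        let a = PromptTokensDetails {
            input_tokens: Some(10),
            cache_read_input_tokens: Some(5),
            ..Default::default()
        };
        let b = PromptTokensDetails {
            input_tokens: Some(3),
            output_tokens: Some(7),
            ..Default::default()
        };
        let sum = a + b;
        assert_eq!(sum.input_tokens, Some(13));
        assert_eq!(sum.output_tokens, Some(7));
        assert_eq!(sum.cache_read_input_tokens, Some(5));
        assert_eq!(sum.cache_write_input_tokens, Some(0));
    }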
}