Skip to main content

stakpak_shared/models/
llm.rs

1//! LLM Provider and Model Configuration
2//!
3//! This module provides the configuration types for LLM providers and models.
4//!
5//! # Provider Configuration
6//!
7//! Providers are configured in a `providers` HashMap where the key becomes the
8//! model prefix for routing requests to the correct provider.
9//!
10//! ## Built-in Providers
11//!
12//! - `openai` - OpenAI API
13//! - `anthropic` - Anthropic API (supports OAuth via `access_token`)
14//! - `gemini` - Google Gemini API
15//! - `bedrock` - AWS Bedrock (uses AWS credential chain, no API key)
16//!
17//! For built-in providers, you can use the model name directly without a prefix:
18//! - `claude-sonnet-4-5` → auto-detected as Anthropic
19//! - `gpt-4` → auto-detected as OpenAI
20//! - `gemini-2.5-pro` → auto-detected as Gemini
21//!
22//! ## Custom Providers
23//!
24//! Any OpenAI-compatible API can be configured using `type = "custom"`.
25//! The provider key becomes the model prefix.
26//!
27//! # Model Routing
28//!
29//! Models can be specified with or without a provider prefix:
30//!
31//! - `claude-sonnet-4-5` → auto-detected as `anthropic` provider
32//! - `anthropic/claude-sonnet-4-5` → explicit `anthropic` provider
33//! - `offline/llama3` → routes to `offline` custom provider, sends `llama3` to API
34//! - `custom/anthropic/claude-opus` → routes to `custom` provider,
35//!   sends `anthropic/claude-opus` to the API
36//!
37//! # Example Configuration
38//!
39//! ```toml
40//! [profiles.default]
41//! provider = "local"
42//! smart_model = "claude-sonnet-4-5"  # auto-detected as anthropic
43//! eco_model = "offline/llama3"       # custom provider
44//!
45//! [profiles.default.providers.anthropic]
46//! type = "anthropic"
47//! # api_key from auth.toml or ANTHROPIC_API_KEY env var
48//!
49//! [profiles.default.providers.offline]
50//! type = "custom"
51//! api_endpoint = "http://localhost:11434/v1"
52//! ```
53
54use serde::{Deserialize, Serialize};
55use stakai::Model;
56use std::collections::HashMap;
57
58use super::auth::ProviderAuth;
59
60// =============================================================================
61// Provider Configuration
62// =============================================================================
63
64/// Unified provider configuration enum
65///
66/// All provider configurations are stored in a `HashMap<String, ProviderConfig>`
67/// where the key is the provider name and becomes the model prefix for routing.
68///
69/// # Provider Key = Model Prefix
70///
71/// The key used in the HashMap becomes the prefix used in model names:
72/// - Config key: `providers.offline`
73/// - Model usage: `offline/llama3`
74/// - Routing: finds `offline` provider, sends `llama3` to API
75///
76/// # Example TOML
77/// ```toml
78/// [profiles.myprofile.providers.openai]
79/// type = "openai"
80///
81/// [profiles.myprofile.providers.openai.auth]
82/// type = "api"
83/// key = "sk-..."
84///
85/// [profiles.myprofile.providers.anthropic]
86/// type = "anthropic"
87///
88/// [profiles.myprofile.providers.anthropic.auth]
89/// type = "oauth"
90/// access = "eyJ..."
91/// refresh = "eyJ..."
92/// expires = 1735600000000
93/// name = "Claude Max"
94///
95/// [profiles.myprofile.providers.offline]
96/// type = "custom"
97/// api_endpoint = "http://localhost:11434/v1"
98/// ```
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI provider configuration
    OpenAI {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Anthropic provider configuration
    Anthropic {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Legacy OAuth access token (prefer `auth` field with OAuth type)
        #[serde(skip_serializing_if = "Option::is_none")]
        access_token: Option<String>,
        /// Authentication credentials (preferred over api_key/access_token)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Google Gemini provider configuration
    Gemini {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// Optional custom API endpoint
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Custom OpenAI-compatible provider (Ollama, vLLM, etc.)
    ///
    /// The provider key in the config becomes the model prefix.
    /// For example, if configured as `providers.offline`, use models as:
    /// - `offline/llama3` - passes `llama3` to the API
    /// - `offline/anthropic/claude-opus` - passes `anthropic/claude-opus` to the API
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.offline]
    /// type = "custom"
    /// api_endpoint = "http://localhost:11434/v1"
    ///
    /// # Then use models as:
    /// model = "offline/llama3"
    /// ```
    Custom {
        /// Legacy API key field (prefer `auth` field)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (required for custom providers)
        /// Use the base URL as required by your provider (e.g., "http://localhost:11434/v1")
        api_endpoint: String,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// Stakpak provider configuration
    ///
    /// Routes inference through Stakpak's unified API, which provides:
    /// - Access to multiple LLM providers via a single endpoint
    /// - Usage tracking and billing
    /// - Session management and checkpoints
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.stakpak]
    /// type = "stakpak"
    /// api_endpoint = "https://apiv2.stakpak.dev"  # optional, this is the default
    ///
    /// [profiles.myprofile.providers.stakpak.auth]
    /// type = "api"
    /// key = "your-stakpak-api-key"
    ///
    /// # Then use models as:
    /// model = "stakpak/anthropic/claude-sonnet-4-5-20250929"
    /// ```
    Stakpak {
        /// Legacy API key field (prefer `auth` field)
        /// Note: This field is optional when using `auth`
        #[serde(skip_serializing_if = "Option::is_none")]
        api_key: Option<String>,
        /// API endpoint URL (default: https://apiv2.stakpak.dev)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (preferred over api_key)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
    /// AWS Bedrock provider configuration
    ///
    /// Uses AWS credential chain for authentication (no API key needed).
    /// Supports env vars, shared credentials, SSO, and instance roles.
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.amazon-bedrock]
    /// type = "amazon-bedrock"
    /// region = "us-east-1"
    /// profile_name = "my-aws-profile"  # optional
    ///
    /// # Then use models as (friendly aliases work):
    /// model = "amazon-bedrock/claude-sonnet-4-5"
    /// ```
    #[serde(rename = "amazon-bedrock")]
    Bedrock {
        /// AWS region (e.g., "us-east-1")
        region: String,
        /// Optional AWS named profile (from ~/.aws/config)
        #[serde(skip_serializing_if = "Option::is_none")]
        profile_name: Option<String>,
    },
    /// GitHub Copilot provider configuration
    ///
    /// Uses the GitHub Device Authorization Grant to obtain an OAuth token, then
    /// calls the OpenAI-compatible Copilot API endpoint.
    ///
    /// # Example TOML
    /// ```toml
    /// [profiles.myprofile.providers.github-copilot]
    /// type = "github-copilot"
    ///
    /// [profiles.myprofile.providers.github-copilot.auth]
    /// type = "oauth"
    /// access = "ghu_..."
    /// refresh = ""
    /// expires = 9223372036854775807
    /// name = "GitHub Copilot"
    ///
    /// # Then use models as:
    /// model = "github-copilot/gpt-4o"
    /// ```
    #[serde(rename = "github-copilot")]
    GitHubCopilot {
        /// Optional custom API endpoint (defaults to https://api.githubcopilot.com)
        #[serde(skip_serializing_if = "Option::is_none")]
        api_endpoint: Option<String>,
        /// Authentication credentials (OAuth access token from device flow)
        #[serde(skip_serializing_if = "Option::is_none")]
        auth: Option<ProviderAuth>,
    },
}
250
251impl ProviderConfig {
252    /// Get the provider type name
253    pub fn provider_type(&self) -> &'static str {
254        match self {
255            ProviderConfig::OpenAI { .. } => "openai",
256            ProviderConfig::Anthropic { .. } => "anthropic",
257            ProviderConfig::Gemini { .. } => "gemini",
258            ProviderConfig::Custom { .. } => "custom",
259            ProviderConfig::Stakpak { .. } => "stakpak",
260            ProviderConfig::Bedrock { .. } => "amazon-bedrock",
261            ProviderConfig::GitHubCopilot { .. } => "github-copilot",
262        }
263    }
264
265    /// Get the API key if set (checks `auth` field first, then legacy `api_key`)
266    pub fn api_key(&self) -> Option<&str> {
267        // First check auth field
268        if let Some(auth) = self.get_auth_ref()
269            && let Some(key) = auth.api_key_value()
270        {
271            return Some(key);
272        }
273        // Fall back to legacy api_key field
274        match self {
275            ProviderConfig::OpenAI { api_key, .. } => api_key.as_deref(),
276            ProviderConfig::Anthropic { api_key, .. } => api_key.as_deref(),
277            ProviderConfig::Gemini { api_key, .. } => api_key.as_deref(),
278            ProviderConfig::Custom { api_key, .. } => api_key.as_deref(),
279            ProviderConfig::Stakpak { api_key, .. } => api_key.as_deref(),
280            ProviderConfig::Bedrock { .. } => None, // AWS credential chain, no API key
281            ProviderConfig::GitHubCopilot { .. } => None, // OAuth only, no API key
282        }
283    }
284
285    /// Get the auth credentials reference
286    fn get_auth_ref(&self) -> Option<&ProviderAuth> {
287        match self {
288            ProviderConfig::OpenAI { auth, .. } => auth.as_ref(),
289            ProviderConfig::Anthropic { auth, .. } => auth.as_ref(),
290            ProviderConfig::Gemini { auth, .. } => auth.as_ref(),
291            ProviderConfig::Custom { auth, .. } => auth.as_ref(),
292            ProviderConfig::Stakpak { auth, .. } => auth.as_ref(),
293            ProviderConfig::Bedrock { .. } => None,
294            ProviderConfig::GitHubCopilot { auth, .. } => auth.as_ref(),
295        }
296    }
297
298    /// Get resolved authentication credentials.
299    ///
300    /// Resolution order:
301    /// 1. `auth` field (preferred)
302    /// 2. Legacy `api_key` field (converted to ProviderAuth::Api)
303    /// 3. Legacy `access_token` field for Anthropic (converted to ProviderAuth with access token)
304    pub fn get_auth(&self) -> Option<ProviderAuth> {
305        // First check auth field
306        if let Some(auth) = self.get_auth_ref() {
307            return Some(auth.clone());
308        }
309
310        // Fall back to legacy fields
311        match self {
312            ProviderConfig::OpenAI { api_key, .. }
313            | ProviderConfig::Gemini { api_key, .. }
314            | ProviderConfig::Custom { api_key, .. }
315            | ProviderConfig::Stakpak { api_key, .. } => {
316                api_key.as_ref().map(ProviderAuth::api_key)
317            }
318            ProviderConfig::Anthropic {
319                api_key,
320                access_token,
321                ..
322            } => {
323                // Prefer api_key, then access_token (as OAuth bearer token, not API key)
324                if let Some(key) = api_key {
325                    Some(ProviderAuth::api_key(key))
326                } else {
327                    // Legacy access_token is an OAuth bearer token — wrap it as OAuth
328                    // with empty refresh token and zero expiry so it will be treated as
329                    // expired and trigger a re-auth rather than silently failing.
330                    access_token
331                        .as_ref()
332                        .map(|token| ProviderAuth::oauth(token, "", 0))
333                }
334            }
335            ProviderConfig::Bedrock { .. } => None,
336            // GitHubCopilot has no legacy fields; auth is always in the `auth` field
337            ProviderConfig::GitHubCopilot { .. } => None,
338        }
339    }
340
341    /// Set authentication credentials on this provider config.
342    ///
343    /// Also clears any legacy credential fields (`api_key`, `access_token`)
344    /// so they don't shadow the new `auth` field on future reads.
345    pub fn set_auth(&mut self, auth: ProviderAuth) {
346        match self {
347            ProviderConfig::OpenAI {
348                auth: auth_field,
349                api_key,
350                ..
351            }
352            | ProviderConfig::Gemini {
353                auth: auth_field,
354                api_key,
355                ..
356            }
357            | ProviderConfig::Custom {
358                auth: auth_field,
359                api_key,
360                ..
361            }
362            | ProviderConfig::Stakpak {
363                auth: auth_field,
364                api_key,
365                ..
366            } => {
367                *auth_field = Some(auth);
368                *api_key = None;
369            }
370            ProviderConfig::Anthropic {
371                auth: auth_field,
372                api_key,
373                access_token,
374                ..
375            } => {
376                *auth_field = Some(auth);
377                *api_key = None;
378                *access_token = None;
379            }
380            ProviderConfig::GitHubCopilot {
381                auth: auth_field, ..
382            } => {
383                *auth_field = Some(auth);
384            }
385            ProviderConfig::Bedrock { .. } => {
386                // Bedrock uses AWS credential chain, no auth field
387            }
388        }
389    }
390
391    /// Clear authentication credentials from this provider config.
392    ///
393    /// Clears both the `auth` field and any legacy credential fields
394    /// (`api_key`, `access_token`) to ensure credentials are fully removed.
395    pub fn clear_auth(&mut self) {
396        match self {
397            ProviderConfig::OpenAI {
398                auth: auth_field,
399                api_key,
400                ..
401            }
402            | ProviderConfig::Gemini {
403                auth: auth_field,
404                api_key,
405                ..
406            }
407            | ProviderConfig::Custom {
408                auth: auth_field,
409                api_key,
410                ..
411            }
412            | ProviderConfig::Stakpak {
413                auth: auth_field,
414                api_key,
415                ..
416            } => {
417                *auth_field = None;
418                *api_key = None;
419            }
420            ProviderConfig::Anthropic {
421                auth: auth_field,
422                api_key,
423                access_token,
424                ..
425            } => {
426                *auth_field = None;
427                *api_key = None;
428                *access_token = None;
429            }
430            ProviderConfig::GitHubCopilot {
431                auth: auth_field, ..
432            } => {
433                *auth_field = None;
434            }
435            ProviderConfig::Bedrock { .. } => {
436                // Bedrock uses AWS credential chain, no auth field
437            }
438        }
439    }
440
441    /// Get the API endpoint if set
442    pub fn api_endpoint(&self) -> Option<&str> {
443        match self {
444            ProviderConfig::OpenAI { api_endpoint, .. } => api_endpoint.as_deref(),
445            ProviderConfig::Anthropic { api_endpoint, .. } => api_endpoint.as_deref(),
446            ProviderConfig::Gemini { api_endpoint, .. } => api_endpoint.as_deref(),
447            ProviderConfig::Custom { api_endpoint, .. } => Some(api_endpoint.as_str()),
448            ProviderConfig::Stakpak { api_endpoint, .. } => api_endpoint.as_deref(),
449            ProviderConfig::Bedrock { .. } => None, // No custom endpoint in config
450            ProviderConfig::GitHubCopilot { api_endpoint, .. } => api_endpoint.as_deref(),
451        }
452    }
453
454    /// Set the API endpoint for providers that support it.
455    ///
456    /// For `Custom`, `None` is ignored because custom providers require an endpoint.
457    /// For `Bedrock`, this is a no-op.
458    pub fn set_api_endpoint(&mut self, endpoint: Option<String>) {
459        match self {
460            ProviderConfig::OpenAI { api_endpoint, .. }
461            | ProviderConfig::Anthropic { api_endpoint, .. }
462            | ProviderConfig::Gemini { api_endpoint, .. }
463            | ProviderConfig::Stakpak { api_endpoint, .. }
464            | ProviderConfig::GitHubCopilot { api_endpoint, .. } => {
465                *api_endpoint = endpoint;
466            }
467            ProviderConfig::Custom { api_endpoint, .. } => {
468                if let Some(custom_endpoint) = endpoint {
469                    *api_endpoint = custom_endpoint;
470                }
471            }
472            ProviderConfig::Bedrock { .. } => {}
473        }
474    }
475
476    /// Get the access token (for OAuth-based providers such as Anthropic and GitHub Copilot)
477    ///
478    /// Checks the `auth` field first for OAuth access token, then falls back
479    /// to the legacy `access_token` field (Anthropic only).
480    pub fn access_token(&self) -> Option<&str> {
481        // First check auth field for OAuth access token
482        if let Some(auth) = self.get_auth_ref()
483            && let Some(token) = auth.access_token()
484        {
485            return Some(token);
486        }
487        // Fall back to legacy access_token field (Anthropic only)
488        match self {
489            ProviderConfig::Anthropic { access_token, .. } => access_token.as_deref(),
490            _ => None,
491        }
492    }
493
494    /// Create an OpenAI provider config (legacy, uses api_key field)
495    pub fn openai(api_key: Option<String>) -> Self {
496        ProviderConfig::OpenAI {
497            api_key,
498            api_endpoint: None,
499            auth: None,
500        }
501    }
502
503    /// Create an OpenAI provider config with auth
504    pub fn openai_with_auth(auth: ProviderAuth) -> Self {
505        ProviderConfig::OpenAI {
506            api_key: None,
507            api_endpoint: None,
508            auth: Some(auth),
509        }
510    }
511
512    /// Create an Anthropic provider config (legacy, uses api_key/access_token fields)
513    pub fn anthropic(api_key: Option<String>, access_token: Option<String>) -> Self {
514        ProviderConfig::Anthropic {
515            api_key,
516            api_endpoint: None,
517            access_token,
518            auth: None,
519        }
520    }
521
522    /// Create an Anthropic provider config with auth
523    pub fn anthropic_with_auth(auth: ProviderAuth) -> Self {
524        ProviderConfig::Anthropic {
525            api_key: None,
526            api_endpoint: None,
527            access_token: None,
528            auth: Some(auth),
529        }
530    }
531
532    /// Create a Gemini provider config (legacy, uses api_key field)
533    pub fn gemini(api_key: Option<String>) -> Self {
534        ProviderConfig::Gemini {
535            api_key,
536            api_endpoint: None,
537            auth: None,
538        }
539    }
540
541    /// Create a Gemini provider config with auth
542    pub fn gemini_with_auth(auth: ProviderAuth) -> Self {
543        ProviderConfig::Gemini {
544            api_key: None,
545            api_endpoint: None,
546            auth: Some(auth),
547        }
548    }
549
550    /// Create a custom provider config (legacy, uses api_key field)
551    pub fn custom(api_endpoint: String, api_key: Option<String>) -> Self {
552        ProviderConfig::Custom {
553            api_key,
554            api_endpoint,
555            auth: None,
556        }
557    }
558
559    /// Create a custom provider config with auth
560    pub fn custom_with_auth(api_endpoint: String, auth: ProviderAuth) -> Self {
561        ProviderConfig::Custom {
562            api_key: None,
563            api_endpoint,
564            auth: Some(auth),
565        }
566    }
567
568    /// Create a Stakpak provider config (legacy, uses api_key field)
569    pub fn stakpak(api_key: String, api_endpoint: Option<String>) -> Self {
570        ProviderConfig::Stakpak {
571            api_key: Some(api_key),
572            api_endpoint,
573            auth: None,
574        }
575    }
576
577    /// Create a Stakpak provider config with auth
578    pub fn stakpak_with_auth(auth: ProviderAuth, api_endpoint: Option<String>) -> Self {
579        ProviderConfig::Stakpak {
580            api_key: None,
581            api_endpoint,
582            auth: Some(auth),
583        }
584    }
585
586    /// Create a GitHub Copilot provider config with auth (OAuth token from device flow)
587    pub fn github_copilot_with_auth(auth: ProviderAuth) -> Self {
588        ProviderConfig::GitHubCopilot {
589            api_endpoint: None,
590            auth: Some(auth),
591        }
592    }
593
594    /// Create a Bedrock provider config
595    pub fn bedrock(region: String, profile_name: Option<String>) -> Self {
596        ProviderConfig::Bedrock {
597            region,
598            profile_name,
599        }
600    }
601
602    /// Get the AWS region (Bedrock only)
603    pub fn region(&self) -> Option<&str> {
604        match self {
605            ProviderConfig::Bedrock { region, .. } => Some(region.as_str()),
606            _ => None,
607        }
608    }
609
610    /// Get the AWS profile name (Bedrock only)
611    pub fn profile_name(&self) -> Option<&str> {
612        match self {
613            ProviderConfig::Bedrock { profile_name, .. } => profile_name.as_deref(),
614            _ => None,
615        }
616    }
617
618    /// Create an empty provider config for a given provider name.
619    ///
620    /// Used during migration when we need to create a provider config
621    /// to attach auth credentials to.
622    pub fn empty_for_provider(provider_name: &str) -> Option<Self> {
623        match provider_name {
624            "openai" => Some(ProviderConfig::OpenAI {
625                api_key: None,
626                api_endpoint: None,
627                auth: None,
628            }),
629            "anthropic" => Some(ProviderConfig::Anthropic {
630                api_key: None,
631                api_endpoint: None,
632                access_token: None,
633                auth: None,
634            }),
635            "gemini" => Some(ProviderConfig::Gemini {
636                api_key: None,
637                api_endpoint: None,
638                auth: None,
639            }),
640            "stakpak" => Some(ProviderConfig::Stakpak {
641                api_key: None,
642                api_endpoint: None,
643                auth: None,
644            }),
645            "github-copilot" => Some(ProviderConfig::GitHubCopilot {
646                api_endpoint: None,
647                auth: None,
648            }),
649            // Custom providers need an endpoint, Bedrock uses AWS credential chain
650            _ => None,
651        }
652    }
653}
654
/// Aggregated provider configuration for LLM operations
///
/// This struct holds all configured providers, keyed by provider name.
/// The key doubles as the model prefix used for routing (e.g. the provider
/// stored under `"offline"` serves models written as `offline/llama3`).
#[derive(Debug, Clone, Default)]
pub struct LLMProviderConfig {
    /// All provider configurations (key = provider name)
    pub providers: HashMap<String, ProviderConfig>,
}
663
664impl LLMProviderConfig {
665    /// Create a new empty provider config
666    pub fn new() -> Self {
667        Self {
668            providers: HashMap::new(),
669        }
670    }
671
672    /// Add a provider configuration
673    pub fn add_provider(&mut self, name: impl Into<String>, config: ProviderConfig) {
674        self.providers.insert(name.into(), config);
675    }
676
677    /// Get a provider configuration by name
678    pub fn get_provider(&self, name: &str) -> Option<&ProviderConfig> {
679        self.providers.get(name)
680    }
681
682    /// Check if any providers are configured
683    pub fn is_empty(&self) -> bool {
684        self.providers.is_empty()
685    }
686}
687
/// Provider-specific options for LLM requests
///
/// Each field carries options that only apply to the matching provider;
/// unset fields are omitted from serialized requests.
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMProviderOptions {
    /// Anthropic-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub anthropic: Option<LLMAnthropicOptions>,

    /// OpenAI-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub openai: Option<LLMOpenAIOptions>,

    /// Google/Gemini-specific options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub google: Option<LLMGoogleOptions>,
}
703
/// Anthropic-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMAnthropicOptions {
    /// Extended thinking configuration; omitted from requests when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking: Option<LLMThinkingOptions>,
}
711
/// Thinking/reasoning options
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct LLMThinkingOptions {
    /// Budget tokens for thinking (must be >= 1024;
    /// `LLMThinkingOptions::new` clamps smaller values up to 1024)
    pub budget_tokens: u32,
}
718
719impl LLMThinkingOptions {
720    pub fn new(budget_tokens: u32) -> Self {
721        Self {
722            budget_tokens: budget_tokens.max(1024),
723        }
724    }
725}
726
/// OpenAI-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMOpenAIOptions {
    /// Reasoning effort for o1/o3/o4 models ("low", "medium", "high");
    /// omitted from requests when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
734
/// Google/Gemini-specific options
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct LLMGoogleOptions {
    /// Thinking budget in tokens; omitted from requests when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub thinking_budget: Option<u32>,
}
742
/// Input for a single (non-streaming) LLM completion request.
#[derive(Clone, Debug, Serialize)]
pub struct LLMInput {
    /// Target model for the request.
    pub model: Model,
    /// Conversation history to send to the model.
    pub messages: Vec<LLMMessage>,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: u32,
    /// Tools the model may call, if any.
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific request options (thinking budgets, reasoning effort, ...).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<std::collections::HashMap<String, String>>,
}
755
/// Input for a streaming LLM request: same payload as [`LLMInput`] plus the
/// channel on which incremental generation deltas are delivered.
#[derive(Debug)]
pub struct LLMStreamInput {
    /// Target model for the request.
    pub model: Model,
    /// Conversation history to send to the model.
    pub messages: Vec<LLMMessage>,
    /// Maximum number of tokens the model may generate.
    pub max_tokens: u32,
    /// Sender for incremental `GenerationDelta`s as they arrive.
    pub stream_channel_tx: tokio::sync::mpsc::Sender<GenerationDelta>,
    /// Tools the model may call, if any.
    pub tools: Option<Vec<LLMTool>>,
    /// Provider-specific request options (thinking budgets, reasoning effort, ...).
    pub provider_options: Option<LLMProviderOptions>,
    /// Custom headers to pass to the inference provider
    pub headers: Option<std::collections::HashMap<String, String>>,
}
767
768impl From<&LLMStreamInput> for LLMInput {
769    fn from(value: &LLMStreamInput) -> Self {
770        LLMInput {
771            model: value.model.clone(),
772            messages: value.messages.clone(),
773            max_tokens: value.max_tokens,
774            tools: value.tools.clone(),
775            provider_options: value.provider_options.clone(),
776            headers: value.headers.clone(),
777        }
778    }
779}
780
/// A chat message with a free-form role string and typed or plain content.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct LLMMessage {
    /// Message role (e.g. "user", "assistant") — stored as a plain string.
    pub role: String,
    /// Message body: plain text or a list of typed content parts.
    pub content: LLMMessageContent,
}
786
787#[derive(Serialize, Deserialize, Debug, Clone)]
788pub struct SimpleLLMMessage {
789    #[serde(rename = "role")]
790    pub role: SimpleLLMRole,
791    pub content: String,
792}
793
/// Role of a [`SimpleLLMMessage`]; serializes as lowercase ("user"/"assistant").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum SimpleLLMRole {
    User,
    Assistant,
}
800
801impl std::fmt::Display for SimpleLLMRole {
802    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
803        match self {
804            SimpleLLMRole::User => write!(f, "user"),
805            SimpleLLMRole::Assistant => write!(f, "assistant"),
806        }
807    }
808}
809
/// Message body: either a single plain string or a list of typed parts.
/// Untagged, so deserialization accepts both JSON shapes transparently.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum LLMMessageContent {
    String(String),
    List(Vec<LLMMessageTypedContent>),
}
816
817#[allow(clippy::to_string_trait_impl)]
818impl ToString for LLMMessageContent {
819    fn to_string(&self) -> String {
820        match self {
821            LLMMessageContent::String(s) => s.clone(),
822            LLMMessageContent::List(l) => l
823                .iter()
824                .map(|c| match c {
825                    LLMMessageTypedContent::Text { text } => text.clone(),
826                    LLMMessageTypedContent::ToolCall { .. } => String::new(),
827                    LLMMessageTypedContent::ToolResult { content, .. } => content.clone(),
828                    LLMMessageTypedContent::Image { .. } => String::new(),
829                })
830                .collect::<Vec<_>>()
831                .join("\n"),
832        }
833    }
834}
835
836impl From<String> for LLMMessageContent {
837    fn from(value: String) -> Self {
838        LLMMessageContent::String(value)
839    }
840}
841
842impl Default for LLMMessageContent {
843    fn default() -> Self {
844        LLMMessageContent::String(String::new())
845    }
846}
847
848impl LLMMessageContent {
849    /// Convert into a Vec of typed content parts.
850    /// A `String` variant is returned as a single `Text` part (empty strings yield an empty vec).
851    pub fn into_parts(self) -> Vec<LLMMessageTypedContent> {
852        match self {
853            LLMMessageContent::List(parts) => parts,
854            LLMMessageContent::String(s) if s.is_empty() => vec![],
855            LLMMessageContent::String(s) => vec![LLMMessageTypedContent::Text { text: s }],
856        }
857    }
858}
859
/// A single typed content part of a message; tagged by `type` on the wire
/// ("text", "tool_use", "tool_result", "image").
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum LLMMessageTypedContent {
    /// Plain text segment.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation requested by the model.
    #[serde(rename = "tool_use")]
    ToolCall {
        /// Identifier linking this call to its `ToolResult`.
        id: String,
        /// Tool name to invoke.
        name: String,
        /// Tool arguments; `input` is accepted as an alias on deserialization.
        #[serde(alias = "input")]
        args: serde_json::Value,
        /// Opaque provider-specific metadata (e.g., Gemini thought_signature).
        #[serde(skip_serializing_if = "Option::is_none")]
        metadata: Option<serde_json::Value>,
    },
    /// The result of a previously requested tool call.
    #[serde(rename = "tool_result")]
    ToolResult {
        /// Matches the `id` of the originating `ToolCall`.
        tool_use_id: String,
        content: String,
    },
    /// An inline image.
    #[serde(rename = "image")]
    Image { source: LLMMessageImageSource },
}
883
/// Source descriptor for an inline image content part.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMMessageImageSource {
    /// Encoding kind, serialized as `type` (presumably "base64" — confirm against callers).
    #[serde(rename = "type")]
    pub r#type: String,
    /// MIME type of the image (e.g. "image/png").
    pub media_type: String,
    /// Image payload, encoded per `type`.
    pub data: String,
}
891
892impl Default for LLMMessageTypedContent {
893    fn default() -> Self {
894        LLMMessageTypedContent::Text {
895            text: String::new(),
896        }
897    }
898}
899
/// One completion choice in an OpenAI-style (non-streaming) response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMChoice {
    /// Why generation stopped; `None` if the provider did not report it.
    pub finish_reason: Option<String>,
    /// Position of this choice within the response's choice list.
    pub index: u32,
    /// The generated message.
    pub message: LLMMessage,
}
906
/// A full (non-streaming) completion response, OpenAI-style.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMCompletionResponse {
    /// Model that produced the completion.
    pub model: String,
    /// OpenAI-style object tag (presumably "chat.completion" — confirm against providers).
    pub object: String,
    /// Generated choices.
    pub choices: Vec<LLMChoice>,
    /// Creation timestamp (presumably Unix seconds, per the OpenAI schema).
    pub created: u64,
    /// Token accounting, when the provider reports it.
    pub usage: Option<LLMTokenUsage>,
    /// Provider-assigned response id.
    pub id: String,
}
916
/// Incremental content carried by a single stream chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamDelta {
    /// New text appended by this chunk, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
922
/// One choice within a streamed completion chunk.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLMStreamChoice {
    /// Set on the final chunk of a choice; `None` while streaming.
    pub finish_reason: Option<String>,
    /// Position of this choice within the chunk's choice list.
    pub index: u32,
    /// Full message when the provider sends one (NOTE(review): presumably only
    /// on terminal chunks — confirm against the stream assembler).
    pub message: Option<LLMMessage>,
    /// The incremental delta for this chunk.
    pub delta: LLMStreamDelta,
}
930
931#[derive(Serialize, Deserialize, Debug, Clone)]
932pub struct LLMCompletionStreamResponse {
933    pub model: String,
934    pub object: String,
935    pub choices: Vec<LLMStreamChoice>,
936    pub created: u64,
937    #[serde(skip_serializing_if = "Option::is_none")]
938    pub usage: Option<LLMTokenUsage>,
939    pub id: String,
940    pub citations: Option<Vec<String>>,
941}
942
/// A tool definition advertised to the model.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct LLMTool {
    /// Tool name the model will reference in `tool_use` blocks.
    pub name: String,
    /// Human-readable description shown to the model.
    pub description: String,
    /// JSON Schema describing the tool's input parameters.
    pub input_schema: serde_json::Value,
}
949
/// Aggregate token usage for a request, using OpenAI-style field names.
#[derive(Default, Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct LLMTokenUsage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,

    /// Per-category breakdown (cache reads/writes etc.), when reported;
    /// omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokensDetails>,
}
959
/// Token accounting categories, serialized in snake_case
/// (e.g. `cache_read_input_tokens`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    InputTokens,
    OutputTokens,
    CacheReadInputTokens,
    CacheWriteInputTokens,
}
968
/// Optional per-category token counts. `None` means "not reported by the
/// provider" and is treated as zero by this type's arithmetic impls.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)]
pub struct PromptTokensDetails {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cache_write_input_tokens: Option<u32>,
}
980
981impl PromptTokensDetails {
982    /// Returns an iterator over the token types and their values
983    pub fn iter(&self) -> impl Iterator<Item = (TokenType, u32)> {
984        [
985            (TokenType::InputTokens, self.input_tokens.unwrap_or(0)),
986            (TokenType::OutputTokens, self.output_tokens.unwrap_or(0)),
987            (
988                TokenType::CacheReadInputTokens,
989                self.cache_read_input_tokens.unwrap_or(0),
990            ),
991            (
992                TokenType::CacheWriteInputTokens,
993                self.cache_write_input_tokens.unwrap_or(0),
994            ),
995        ]
996        .into_iter()
997    }
998}
999
1000impl std::ops::Add for PromptTokensDetails {
1001    type Output = Self;
1002
1003    fn add(self, rhs: Self) -> Self::Output {
1004        Self {
1005            input_tokens: Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0)),
1006            output_tokens: Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0)),
1007            cache_read_input_tokens: Some(
1008                self.cache_read_input_tokens.unwrap_or(0)
1009                    + rhs.cache_read_input_tokens.unwrap_or(0),
1010            ),
1011            cache_write_input_tokens: Some(
1012                self.cache_write_input_tokens.unwrap_or(0)
1013                    + rhs.cache_write_input_tokens.unwrap_or(0),
1014            ),
1015        }
1016    }
1017}
1018
1019impl std::ops::AddAssign for PromptTokensDetails {
1020    fn add_assign(&mut self, rhs: Self) {
1021        self.input_tokens = Some(self.input_tokens.unwrap_or(0) + rhs.input_tokens.unwrap_or(0));
1022        self.output_tokens = Some(self.output_tokens.unwrap_or(0) + rhs.output_tokens.unwrap_or(0));
1023        self.cache_read_input_tokens = Some(
1024            self.cache_read_input_tokens.unwrap_or(0) + rhs.cache_read_input_tokens.unwrap_or(0),
1025        );
1026        self.cache_write_input_tokens = Some(
1027            self.cache_write_input_tokens.unwrap_or(0) + rhs.cache_write_input_tokens.unwrap_or(0),
1028        );
1029    }
1030}
1031
/// A provider-agnostic streaming event, tagged by `type`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "type")]
pub enum GenerationDelta {
    /// Visible output text.
    Content { content: String },
    /// Model reasoning/"thinking" text.
    Thinking { thinking: String },
    /// Incremental tool-call information.
    ToolUse { tool_use: GenerationDeltaToolUse },
    /// Token usage update.
    Usage { usage: LLMTokenUsage },
    /// Arbitrary provider metadata.
    Metadata { metadata: serde_json::Value },
}
1041
/// Incremental tool-call data within a stream; fields may arrive piecemeal
/// across events for the same `index`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenerationDeltaToolUse {
    /// Tool call id, when present on this event.
    pub id: Option<String>,
    /// Tool name, when present on this event.
    pub name: Option<String>,
    /// Tool-argument JSON fragment (NOTE(review): presumably concatenated
    /// across events by the stream assembler — confirm).
    pub input: Option<String>,
    /// Index of the tool call this fragment belongs to.
    pub index: usize,
    /// Opaque provider-specific metadata (e.g., Gemini thought_signature)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
1052
#[cfg(test)]
mod tests {
    // Unit tests for ProviderConfig serialization/deserialization and the
    // LLMProviderConfig providers map. The JSON substring assertions below
    // rely on serde_json's compact output (no spaces around ':').
    use super::*;

    // =========================================================================
    // ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_openai_serialization() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"openai\""));
        assert!(json.contains("\"api_key\":\"sk-test\""));
        assert!(!json.contains("api_endpoint")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_openai_with_endpoint() {
        let config = ProviderConfig::OpenAI {
            api_key: Some("sk-test".to_string()),
            api_endpoint: Some("https://custom.openai.com/v1".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"api_endpoint\":\"https://custom.openai.com/v1\""));
    }

    #[test]
    fn test_provider_config_anthropic_serialization() {
        let config = ProviderConfig::Anthropic {
            api_key: Some("sk-ant-test".to_string()),
            api_endpoint: None,
            access_token: Some("oauth-token".to_string()),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"anthropic\""));
        assert!(json.contains("\"api_key\":\"sk-ant-test\""));
        assert!(json.contains("\"access_token\":\"oauth-token\""));
    }

    #[test]
    fn test_provider_config_gemini_serialization() {
        let config = ProviderConfig::Gemini {
            api_key: Some("gemini-key".to_string()),
            api_endpoint: None,
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"gemini\""));
        assert!(json.contains("\"api_key\":\"gemini-key\""));
    }

    #[test]
    fn test_provider_config_custom_serialization() {
        let config = ProviderConfig::Custom {
            api_key: Some("sk-custom".to_string()),
            api_endpoint: "http://localhost:4000".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\":\"http://localhost:4000\""));
        assert!(json.contains("\"api_key\":\"sk-custom\""));
    }

    #[test]
    fn test_provider_config_custom_without_key() {
        let config = ProviderConfig::Custom {
            api_key: None,
            api_endpoint: "http://localhost:11434/v1".to_string(),
            auth: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"custom\""));
        assert!(json.contains("\"api_endpoint\""));
        assert!(!json.contains("api_key")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_deserialization_openai() {
        let json = r#"{"type":"openai","api_key":"sk-test"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::OpenAI { .. }));
        assert_eq!(config.api_key(), Some("sk-test"));
    }

    #[test]
    fn test_provider_config_deserialization_anthropic() {
        let json = r#"{"type":"anthropic","api_key":"sk-ant","access_token":"oauth"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Anthropic { .. }));
        assert_eq!(config.api_key(), Some("sk-ant"));
        assert_eq!(config.access_token(), Some("oauth"));
    }

    #[test]
    fn test_provider_config_deserialization_gemini() {
        let json = r#"{"type":"gemini","api_key":"gemini-key"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Gemini { .. }));
        assert_eq!(config.api_key(), Some("gemini-key"));
    }

    #[test]
    fn test_provider_config_deserialization_custom() {
        let json =
            r#"{"type":"custom","api_endpoint":"http://localhost:4000","api_key":"sk-custom"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Custom { .. }));
        assert_eq!(config.api_key(), Some("sk-custom"));
        assert_eq!(config.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_provider_config_helper_methods() {
        let openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        assert_eq!(openai.provider_type(), "openai");
        assert_eq!(openai.api_key(), Some("sk-openai"));

        let anthropic =
            ProviderConfig::anthropic(Some("sk-ant".to_string()), Some("oauth".to_string()));
        assert_eq!(anthropic.provider_type(), "anthropic");
        assert_eq!(anthropic.access_token(), Some("oauth"));

        let gemini = ProviderConfig::gemini(Some("gemini-key".to_string()));
        assert_eq!(gemini.provider_type(), "gemini");

        let custom = ProviderConfig::custom(
            "http://localhost:4000".to_string(),
            Some("sk-custom".to_string()),
        );
        assert_eq!(custom.provider_type(), "custom");
        assert_eq!(custom.api_endpoint(), Some("http://localhost:4000"));
    }

    #[test]
    fn test_set_api_endpoint_updates_supported_providers() {
        let mut openai = ProviderConfig::openai(Some("sk-openai".to_string()));
        openai.set_api_endpoint(Some("https://proxy.example.com/v1".to_string()));
        assert_eq!(openai.api_endpoint(), Some("https://proxy.example.com/v1"));

        // Bedrock has no endpoint field, so set_api_endpoint is a no-op there.
        let mut bedrock = ProviderConfig::bedrock("us-east-1".to_string(), None);
        bedrock.set_api_endpoint(Some("https://ignored.example.com".to_string()));
        assert_eq!(bedrock.api_endpoint(), None);
    }

    #[test]
    fn test_llm_provider_config_new() {
        let config = LLMProviderConfig::new();
        assert!(config.is_empty());
    }

    #[test]
    fn test_llm_provider_config_add_and_get() {
        let mut config = LLMProviderConfig::new();
        config.add_provider(
            "openai",
            ProviderConfig::openai(Some("sk-test".to_string())),
        );
        config.add_provider(
            "anthropic",
            ProviderConfig::anthropic(Some("sk-ant".to_string()), None),
        );

        assert!(!config.is_empty());
        assert!(config.get_provider("openai").is_some());
        assert!(config.get_provider("anthropic").is_some());
        assert!(config.get_provider("unknown").is_none());
    }

    #[test]
    fn test_provider_config_toml_parsing() {
        // Test parsing a HashMap of providers from TOML-like JSON
        let json = r#"{
            "openai": {"type": "openai", "api_key": "sk-openai"},
            "anthropic": {"type": "anthropic", "api_key": "sk-ant", "access_token": "oauth"},
            "litellm": {"type": "custom", "api_endpoint": "http://localhost:4000", "api_key": "sk-litellm"}
        }"#;

        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 3);

        assert!(matches!(
            providers.get("openai"),
            Some(ProviderConfig::OpenAI { .. })
        ));
        assert!(matches!(
            providers.get("anthropic"),
            Some(ProviderConfig::Anthropic { .. })
        ));
        assert!(matches!(
            providers.get("litellm"),
            Some(ProviderConfig::Custom { .. })
        ));
    }

    // =========================================================================
    // Bedrock ProviderConfig Tests
    // =========================================================================

    #[test]
    fn test_provider_config_bedrock_serialization() {
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-east-1\""));
        assert!(json.contains("\"profile_name\":\"my-profile\""));
    }

    #[test]
    fn test_provider_config_bedrock_serialization_without_profile() {
        let config = ProviderConfig::Bedrock {
            region: "us-west-2".to_string(),
            profile_name: None,
        };
        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"amazon-bedrock\""));
        assert!(json.contains("\"region\":\"us-west-2\""));
        assert!(!json.contains("profile_name")); // Should be skipped when None
    }

    #[test]
    fn test_provider_config_bedrock_deserialization() {
        let json = r#"{"type":"amazon-bedrock","region":"us-east-1","profile_name":"prod"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("us-east-1"));
        assert_eq!(config.profile_name(), Some("prod"));
    }

    #[test]
    fn test_provider_config_bedrock_deserialization_minimal() {
        let json = r#"{"type":"amazon-bedrock","region":"eu-west-1"}"#;
        let config: ProviderConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(config, ProviderConfig::Bedrock { .. }));
        assert_eq!(config.region(), Some("eu-west-1"));
        assert_eq!(config.profile_name(), None);
    }

    #[test]
    fn test_provider_config_bedrock_no_api_key() {
        let config = ProviderConfig::bedrock("us-east-1".to_string(), None);
        assert_eq!(config.api_key(), None); // Bedrock uses AWS credential chain
        assert_eq!(config.api_endpoint(), None); // No custom endpoint
    }

    #[test]
    fn test_provider_config_bedrock_helper_methods() {
        let bedrock = ProviderConfig::bedrock("us-east-1".to_string(), Some("prod".to_string()));
        assert_eq!(bedrock.provider_type(), "amazon-bedrock");
        assert_eq!(bedrock.region(), Some("us-east-1"));
        assert_eq!(bedrock.profile_name(), Some("prod"));
        assert_eq!(bedrock.api_key(), None);
        assert_eq!(bedrock.api_endpoint(), None);
        assert_eq!(bedrock.access_token(), None);
    }

    #[test]
    fn test_provider_config_bedrock_toml_roundtrip() {
        // Relies on ProviderConfig deriving PartialEq for the equality check.
        let config = ProviderConfig::Bedrock {
            region: "us-east-1".to_string(),
            profile_name: Some("my-profile".to_string()),
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: ProviderConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(config, parsed);
    }

    #[test]
    fn test_provider_config_bedrock_toml_parsing() {
        let toml_str = r#"
            type = "amazon-bedrock"
            region = "us-east-1"
            profile_name = "production"
        "#;
        let config: ProviderConfig = toml::from_str(toml_str).unwrap();
        assert!(matches!(
            config,
            ProviderConfig::Bedrock {
                ref region,
                ref profile_name,
            } if region == "us-east-1" && profile_name.as_deref() == Some("production")
        ));
    }

    #[test]
    fn test_provider_config_bedrock_missing_region_fails() {
        let json = r#"{"type":"amazon-bedrock"}"#;
        let result: Result<ProviderConfig, _> = serde_json::from_str(json);
        assert!(result.is_err()); // region is required
    }

    #[test]
    fn test_provider_config_bedrock_in_providers_map() {
        let json = r#"{
            "anthropic": {"type": "anthropic", "api_key": "sk-ant"},
            "amazon-bedrock": {"type": "amazon-bedrock", "region": "us-east-1"}
        }"#;
        let providers: HashMap<String, ProviderConfig> = serde_json::from_str(json).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(matches!(
            providers.get("amazon-bedrock"),
            Some(ProviderConfig::Bedrock { .. })
        ));
    }

    #[test]
    fn test_region_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.region(), None);

        let anthropic = ProviderConfig::anthropic(Some("key".to_string()), None);
        assert_eq!(anthropic.region(), None);
    }

    #[test]
    fn test_profile_name_returns_none_for_non_bedrock() {
        let openai = ProviderConfig::openai(Some("key".to_string()));
        assert_eq!(openai.profile_name(), None);
    }
}