// zeph_config/providers.rs — LLM provider pool, routing, and generation configuration types.

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
8
9/// Newtype wrapper for a provider name referencing an entry in `[[llm.providers]]`.
10///
11/// Using a dedicated type instead of bare `String` makes provider cross-references
12/// explicit in the type system and enables validation at config load time.
13#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct ProviderName(String);
16
17impl ProviderName {
18    #[must_use]
19    pub fn new(name: impl Into<String>) -> Self {
20        Self(name.into())
21    }
22
23    #[must_use]
24    pub fn is_empty(&self) -> bool {
25        self.0.is_empty()
26    }
27
28    #[must_use]
29    pub fn as_str(&self) -> &str {
30        &self.0
31    }
32}
33
34impl fmt::Display for ProviderName {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        self.0.fmt(f)
37    }
38}
39
40impl AsRef<str> for ProviderName {
41    fn as_ref(&self) -> &str {
42        &self.0
43    }
44}
45
46impl std::ops::Deref for ProviderName {
47    type Target = str;
48
49    fn deref(&self) -> &str {
50        &self.0
51    }
52}
53
54impl PartialEq<str> for ProviderName {
55    fn eq(&self, other: &str) -> bool {
56        self.0 == other
57    }
58}
59
60impl PartialEq<&str> for ProviderName {
61    fn eq(&self, other: &&str) -> bool {
62        self.0 == *other
63    }
64}
65
/// Serde default: response cache TTL in seconds (one hour).
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Serde default: cosine-similarity threshold for semantic cache hits.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Serde default: max cached entries scanned per semantic lookup.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Serde default: EMA smoothing factor for latency-aware routing.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Serde default: requests between router reorder passes.
fn default_router_reorder_interval() -> u64 {
    10
}

/// Serde default: embedding model identifier.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Serde default: Candle model source.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Serde default: chat prompt template name.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Serde default: Candle compute device.
fn default_candle_device() -> String {
    String::from("cpu")
}

/// Serde default: sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Serde default: maximum generated tokens.
fn default_max_tokens() -> usize {
    2048
}

/// Serde default: RNG seed for generation.
fn default_seed() -> u64 {
    42
}

/// Serde default: repetition penalty factor.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Serde default: lookback window for the repetition penalty.
fn default_repeat_last_n() -> usize {
    64
}

/// Serde default: minimum quality score before cascade escalation.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Serde default: max quality-based escalations per request.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Serde default: rolling quality history window per provider.
fn default_cascade_window_size() -> usize {
    50
}

/// Serde default: reputation decay factor applied per load.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Serde default: reputation weight in the routing score blend.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Serde default: observations required before reputation applies.
fn default_reputation_min_observations() -> u64 {
    5
}

/// Serde default: STT provider hint (empty string = auto-detect).
#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

/// Serde default: STT transcription language hint.
#[must_use]
pub fn default_stt_language() -> String {
    "auto".to_owned()
}

/// Public accessor for the private embedding-model default.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for the private response-cache TTL default.
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for the private router EMA alpha default.
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for the private router reorder-interval default.
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
175
176/// LLM provider backend selector.
177#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
178#[serde(rename_all = "lowercase")]
179pub enum ProviderKind {
180    Ollama,
181    Claude,
182    OpenAi,
183    Gemini,
184    Candle,
185    Compatible,
186}
187
188impl ProviderKind {
189    #[must_use]
190    pub fn as_str(self) -> &'static str {
191        match self {
192            Self::Ollama => "ollama",
193            Self::Claude => "claude",
194            Self::OpenAi => "openai",
195            Self::Gemini => "gemini",
196            Self::Candle => "candle",
197            Self::Compatible => "compatible",
198        }
199    }
200}
201
202impl std::fmt::Display for ProviderKind {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.write_str(self.as_str())
205    }
206}
207
/// Top-level `[llm]` configuration: provider pool, routing, caching, and summarization.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool. First entry is default unless one is marked `default = true`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Routing strategy for multi-provider configs.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Task-based routes (only used when `routing = "task"`).
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    /// Embedding model identifier. Default: `"qwen3-embedding"`.
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Local Candle backend configuration, when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Speech-to-text settings; `None` means STT resolution relies on auto-detect
    /// in `stt_provider_entry()`.
    /// NOTE(review): unlike `candle`, this field has no `skip_serializing_if` —
    /// confirm whether serializing an absent `stt` is intended.
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Enable exact-match response caching. Default: false.
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// TTL for cached responses, in seconds. Default: 3600.
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Enable semantic similarity-based response caching. Requires embedding support.
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Cosine similarity threshold for semantic cache hits (0.0–1.0).
    ///
    /// Only the highest-scoring candidate above this threshold is returned.
    /// Lower values produce more cache hits but risk returning less relevant responses.
    /// Recommended range: 0.92–0.98; default: 0.95.
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Maximum cached entries to examine per semantic lookup (SQL `LIMIT` clause in
    /// `ResponseCache::get_semantic()`). Controls the recall-vs-performance tradeoff:
    ///
    /// - **Higher values** (e.g. 50): scan more entries, better chance of finding a
    ///   semantically similar cached response, but slower queries.
    /// - **Lower values** (e.g. 5): faster queries, but may miss relevant cached entries
    ///   when the cache is large.
    /// - **Default (10)**: balanced middle ground for typical workloads.
    ///
    /// Tuning guidance: set to 50+ when recall matters more than latency (e.g. long-running
    /// sessions with many cached responses); reduce to 5 for low-latency interactive use.
    /// Env override: `ZEPH_LLM_SEMANTIC_CACHE_MAX_CANDIDATES`.
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Enable EMA latency-aware provider reordering. Default: false.
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor. Default: 0.1.
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Requests between router reorder passes. Default: 10.
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Routing configuration for Thompson/Cascade strategies.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Provider-specific instruction file to inject into the system prompt.
    /// Merged with `agent.instruction_files` at startup.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Shorthand model spec for tool-pair summarization and context compaction.
    /// Format: `ollama/<model>`, `claude[/<model>]`, `openai[/<model>]`, `compatible/<name>`, `candle`.
    /// Ignored when `[llm.summary_provider]` is set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Structured provider config for summarization. Takes precedence over `summary_model`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// Complexity triage routing configuration. Required when `routing = "triage"`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}
282
283fn default_embedding_model_opt() -> String {
284    default_embedding_model()
285}
286
287#[allow(clippy::trivially_copy_pass_by_ref)]
288fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
289    *s == LlmRoutingStrategy::None
290}
291
impl LlmConfig {
    /// Effective provider kind for the primary (first/default) provider in the pool.
    /// Falls back to `ProviderKind::Ollama` when the pool is empty.
    #[must_use]
    pub fn effective_provider(&self) -> ProviderKind {
        self.providers
            .first()
            .map_or(ProviderKind::Ollama, |e| e.provider_type)
    }

    /// Effective base URL for the primary provider.
    /// Falls back to the local Ollama default when unset or the pool is empty.
    #[must_use]
    pub fn effective_base_url(&self) -> &str {
        self.providers
            .first()
            .and_then(|e| e.base_url.as_deref())
            .unwrap_or("http://localhost:11434")
    }

    /// Effective model for the primary provider. Falls back to `"qwen3:8b"`.
    #[must_use]
    pub fn effective_model(&self) -> &str {
        self.providers
            .first()
            .and_then(|e| e.model.as_deref())
            .unwrap_or("qwen3:8b")
    }

    /// Find the provider entry designated for STT.
    ///
    /// Resolution priority:
    /// 1. `[llm.stt].provider` matches `[[llm.providers]].name` and the entry has `stt_model`
    /// 2. `[llm.stt].provider` is empty — fall through to auto-detect
    /// 3. First provider with `stt_model` set (auto-detect fallback)
    /// 4. `None` — STT disabled
    #[must_use]
    pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
        // A missing [llm.stt] section behaves like an empty provider hint (auto-detect).
        let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
        if name_hint.is_empty() {
            self.providers.iter().find(|p| p.stt_model.is_some())
        } else {
            self.providers
                .iter()
                .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
        }
    }

    /// Validate that the config uses the new `[[llm.providers]]` format.
    ///
    /// # Errors
    ///
    /// Currently infallible — always returns `Ok(())`; the `Result` signature is kept
    /// for call-site stability.
    /// NOTE(review): earlier docs claimed `ConfigError::Validation` when no providers
    /// are configured, but the body never errors (and the `effective_*` accessors
    /// deliberately support an empty pool via fallbacks). Confirm the check was
    /// removed intentionally.
    pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
        Ok(())
    }

    /// Validate STT config cross-references.
    ///
    /// # Errors
    ///
    /// Returns `ConfigError::Validation` when the referenced STT provider does not exist.
    /// A provider that exists but lacks `stt_model` only logs a warning — the config
    /// stays loadable and STT is simply not activated.
    pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
        use crate::error::ConfigError;

        // No [llm.stt] section, or an empty provider hint: nothing to cross-check.
        let Some(stt) = &self.stt else {
            return Ok(());
        };
        if stt.provider.is_empty() {
            return Ok(());
        }
        let found = self
            .providers
            .iter()
            .find(|p| p.effective_name() == stt.provider);
        match found {
            None => {
                return Err(ConfigError::Validation(format!(
                    "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
                    stt.provider
                )));
            }
            Some(entry) if entry.stt_model.is_none() => {
                // Soft failure: surface the misconfiguration without rejecting the config.
                tracing::warn!(
                    provider = stt.provider,
                    "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
                );
            }
            _ => {}
        }
        Ok(())
    }
}
383
/// `[llm.stt]` section: speech-to-text provider selection and transcription hints.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Provider name from `[[llm.providers]]`. Empty string means auto-detect first provider
    /// with `stt_model` set.
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint for transcription (e.g. `"en"`, `"auto"`).
    #[serde(default = "default_stt_language")]
    pub language: String,
}

/// Routing strategy selection for multi-provider routing.
///
/// Used by `RouterConfig::strategy`; distinct from the top-level
/// `LlmRoutingStrategy` consumed by `LlmConfig::routing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Exponential moving average latency-aware ordering.
    #[default]
    Ema,
    /// Thompson Sampling with Beta distributions (persistence-backed).
    Thompson,
    /// Cascade routing: try cheapest provider first, escalate on degenerate output.
    Cascade,
    /// PILOT: `LinUCB` contextual bandit with online learning and cost-aware reward.
    Bandit,
}

/// Agent Stability Index (ASI) configuration.
///
/// Tracks per-provider response coherence via a sliding window of response embeddings.
/// When coherence drops below `coherence_threshold`, the provider's routing prior is
/// penalized by `penalty_weight`. Disabled by default; session-only (no persistence).
///
/// # Known Limitation
///
/// ASI embeddings are computed in a background `tokio::spawn` task after the response is
/// returned to the caller. Under high request rates, the coherence score used for routing
/// may lag 1–2 responses behind due to this fire-and-forget design. With the default
/// `window = 5`, this lag is tolerable — coherence is a slow-moving signal.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AsiConfig {
    /// Enable ASI coherence tracking. Default: false.
    #[serde(default)]
    pub enabled: bool,

    /// Sliding window size for response embeddings per provider. Default: 5.
    #[serde(default = "default_asi_window")]
    pub window: usize,

    /// Coherence score [0.0, 1.0] below which the provider is penalized. Default: 0.7.
    #[serde(default = "default_asi_coherence_threshold")]
    pub coherence_threshold: f32,

    /// Penalty weight applied to Thompson beta / EMA score on low coherence. Default: 0.3.
    ///
    /// For Thompson, this shifts the beta prior: `beta += penalty_weight * (threshold - coherence)`.
    /// For EMA, the score is multiplied by `max(0.5, coherence / threshold)`.
    #[serde(default = "default_asi_penalty_weight")]
    pub penalty_weight: f32,
}

/// Serde default: ASI sliding window size.
fn default_asi_window() -> usize {
    5
}

/// Serde default: ASI coherence penalty threshold.
fn default_asi_coherence_threshold() -> f32 {
    0.7
}

/// Serde default: ASI penalty weight.
fn default_asi_penalty_weight() -> f32 {
    0.3
}

impl Default for AsiConfig {
    // Keep in sync with the serde per-field defaults above.
    fn default() -> Self {
        Self {
            enabled: false,
            window: default_asi_window(),
            coherence_threshold: default_asi_coherence_threshold(),
            penalty_weight: default_asi_penalty_weight(),
        }
    }
}
466
/// Routing configuration for multi-provider setups.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Routing strategy: `"ema"` (default), `"thompson"`, `"cascade"`, or `"bandit"`.
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Path for persisting Thompson Sampling state. Defaults to `~/.zeph/router_thompson_state.json`.
    ///
    /// # Security
    ///
    /// This path is user-controlled. The application writes and reads a JSON file at
    /// this location. Ensure the path is within a directory that is not world-writable
    /// (e.g., avoid `/tmp`). The file is created with mode `0o600` on Unix.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade routing configuration. Only used when `strategy = "cascade"`.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Bayesian reputation scoring configuration (RAPS). Disabled by default.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// PILOT bandit routing configuration. Only used when `strategy = "bandit"`.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
    /// Embedding-based quality gate threshold for Thompson/EMA routing. Default: disabled.
    ///
    /// When set, after provider selection, the cosine similarity between the query embedding
    /// and the response embedding is computed. If below this threshold, the next provider in
    /// the ordered list is tried. On exhaustion, the best response seen is returned.
    ///
    /// Only applies to Thompson and EMA strategies. Cascade uses its own quality classifier.
    /// Fail-open: embedding errors disable the gate for that request.
    #[serde(default)]
    pub quality_gate: Option<f32>,
    /// Agent Stability Index configuration. Disabled by default.
    #[serde(default)]
    pub asi: Option<AsiConfig>,
    /// Maximum number of concurrent `embed_batch` calls through the router.
    ///
    /// Limits simultaneous embedding HTTP requests to prevent provider rate-limiting
    /// and memory pressure during indexing or high-frequency recall. Default: 4.
    /// Set to 0 to disable the semaphore (unlimited concurrency).
    #[serde(default = "default_embed_concurrency")]
    pub embed_concurrency: usize,
}

/// Serde default: concurrent `embed_batch` call limit.
fn default_embed_concurrency() -> usize {
    4
}

/// Configuration for Bayesian reputation scoring (RAPS — Reputation-Adjusted Provider Selection).
///
/// When enabled, quality outcomes from tool execution shift the routing scores over time,
/// giving an advantage to providers that consistently produce valid tool arguments.
///
/// Default: disabled. Set `enabled = true` to activate.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Enable reputation scoring. Default: false.
    #[serde(default)]
    pub enabled: bool,
    /// Session-level decay factor applied on each load. Range: (0.0, 1.0]. Default: 0.95.
    /// Lower values make reputation forget faster; 1.0 = no decay.
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of reputation in routing score blend. Range: [0.0, 1.0]. Default: 0.3.
    ///
    /// **Warning**: values above 0.5 can aggressively suppress low-reputation providers.
    /// At `weight = 1.0` with `rep_factor = 0.0` (all failures), the routing score
    /// drops to zero — the provider becomes unreachable for that session. Stick to
    /// the default (0.3) unless you intentionally want strong reputation gating.
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Minimum quality observations before reputation influences routing. Default: 5.
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Path for persisting reputation state. Defaults to `~/.config/zeph/router_reputation_state.json`.
    /// NOTE(review): default directory differs from `thompson_state_path` (`~/.zeph/`) —
    /// confirm the inconsistency is intentional.
    #[serde(default)]
    pub state_path: Option<String>,
}
547
/// Configuration for cascade routing (`strategy = "cascade"`).
///
/// Cascade routing tries providers in chain order (cheapest first), escalating to
/// the next provider when the response is classified as degenerate (empty, repetitive,
/// incoherent). Chain order determines cost order: first provider = cheapest.
///
/// # Limitations
///
/// The heuristic classifier detects degenerate outputs only, not semantic failures.
/// Use `classifier_mode = "judge"` for semantic quality gating (adds LLM call cost).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Minimum quality score [0.0, 1.0] to accept a response without escalating.
    /// Responses scoring below this threshold trigger escalation. Default: 0.5.
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum number of quality-based escalations per request.
    /// Network/API errors do not count against this budget.
    /// Default: 2 (allows up to 3 providers: cheap → mid → expensive).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// Quality classifier mode: `"heuristic"` (default) or `"judge"`.
    /// Heuristic is zero-cost but detects only degenerate outputs.
    /// Judge requires a configured `summary_model` and adds one LLM call per evaluation.
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Rolling quality history window size per provider. Default: 50.
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Maximum cumulative input+output tokens across all escalation levels.
    /// When exceeded, returns the best-seen response instead of escalating further.
    /// `None` disables the budget (unbounded escalation cost).
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Explicit cost ordering of provider names (cheapest first).
    /// When set, cascade routing sorts providers by their position in this list before
    /// trying them. Providers not in the list are appended after listed ones in their
    /// original chain order. When unset, chain order is used (default behavior).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
594
595impl Default for CascadeConfig {
596    fn default() -> Self {
597        Self {
598            quality_threshold: default_cascade_quality_threshold(),
599            max_escalations: default_cascade_max_escalations(),
600            classifier_mode: CascadeClassifierMode::default(),
601            window_size: default_cascade_window_size(),
602            max_cascade_tokens: None,
603            cost_tiers: None,
604        }
605    }
606}
607
/// Quality classifier mode for cascade routing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Zero-cost heuristic: detects degenerate outputs (empty, repetitive, incoherent).
    /// Does not detect semantic failures (hallucinations, wrong answers).
    #[default]
    Heuristic,
    /// LLM-based judge: more accurate but adds latency. Falls back to heuristic on failure.
    /// Requires `summary_model` to be configured.
    Judge,
}

/// Serde default: `LinUCB` exploration parameter.
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Serde default: bandit feature vector dimension.
fn default_bandit_dim() -> usize {
    32
}

/// Serde default: cost penalty weight in the bandit reward.
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Serde default: bandit arm-state decay on startup (1.0 = no decay).
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Serde default: embedding call timeout in milliseconds.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Serde default: embedding cache capacity.
fn default_bandit_cache_size() -> usize {
    512
}
644
/// Configuration for PILOT bandit routing (`strategy = "bandit"`).
///
/// PILOT (Provider Intelligence via Learned Online Tuning) uses a `LinUCB` contextual
/// bandit to learn which provider performs best for a given query context. The feature
/// vector is derived from the query embedding (first `dim` components, L2-normalised).
///
/// **Cold start**: the bandit falls back to Thompson sampling for the first
/// `10 * num_providers` queries (configurable). After warmup, `LinUCB` takes over.
///
/// **Embedding**: an `embedding_provider` must be set for feature vectors. If the embed
/// call exceeds `embedding_timeout_ms` or fails, the bandit falls back to Thompson/uniform.
/// Use a local provider (Ollama, Candle) to avoid network latency on the hot path.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// `LinUCB` exploration parameter. Default: 1.0.
    /// Higher values increase exploration; lower values favour exploitation.
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Feature vector dimension (first `dim` components of the embedding).
    ///
    /// This is simple truncation, not PCA. The first raw embedding dimensions do not
    /// necessarily capture the most variance. For `OpenAI` `text-embedding-3-*` models,
    /// consider using the `dimensions` API parameter (Matryoshka embeddings) instead.
    /// Default: 32.
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Cost penalty weight in the reward signal: `reward = quality - cost_weight * cost_fraction`.
    /// Default: 0.1. Increase to penalise expensive providers more aggressively.
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Session-level decay applied to arm state on startup: `A = I + decay*(A-I)`, `b = decay*b`.
    /// Values < 1.0 cause re-exploration after provider quality changes. Default: 1.0 (no decay).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider name from `[[llm.providers]]` used for query embeddings.
    ///
    /// SLM recommended: prefer a fast local model (e.g. Ollama `nomic-embed-text`,
    /// Candle, or `text-embedding-3-small`) — this is called on every bandit request.
    /// Empty string disables `LinUCB` (bandit always falls back to Thompson/uniform).
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Hard timeout for the embedding call in milliseconds. Default: 50.
    /// If exceeded, the request falls back to Thompson/uniform selection.
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Maximum cached embeddings (keyed by query text hash). Default: 512.
    /// NOTE(review): eviction policy is not visible in this file — confirm LRU vs FIFO
    /// in the cache consumer.
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Path for persisting bandit state. Defaults to `~/.config/zeph/router_bandit_state.json`.
    ///
    /// # Security
    ///
    /// This path is user-controlled. The file is created with mode `0o600` on Unix.
    /// Do not place it in world-writable directories.
    #[serde(default)]
    pub state_path: Option<String>,

    /// MAR (Memory-Augmented Routing) confidence threshold.
    ///
    /// When the top-1 semantic recall score for the current query is >= this value,
    /// the bandit biases toward cheaper providers (the answer is likely in memory).
    /// Set to 1.0 to disable MAR. Default: 0.9.
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Minimum number of queries before `LinUCB` takes over from Thompson warmup.
    ///
    /// When unset or `0`, defaults to `10 × number of providers` (computed at startup).
    /// Set explicitly to control how long the bandit explores uniformly before
    /// switching to context-aware routing. Setting `0` preserves the computed default.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}
725
726fn default_bandit_memory_confidence_threshold() -> f32 {
727    0.9
728}
729
730impl Default for BanditConfig {
731    fn default() -> Self {
732        Self {
733            alpha: default_bandit_alpha(),
734            dim: default_bandit_dim(),
735            cost_weight: default_bandit_cost_weight(),
736            decay_factor: default_bandit_decay_factor(),
737            embedding_provider: ProviderName::default(),
738            embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
739            cache_size: default_bandit_cache_size(),
740            state_path: None,
741            memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
742            warmup_queries: None,
743        }
744    }
745}
746
/// `[llm.candle]` section: local Candle backend model source and generation settings.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source. Default: `"huggingface"`.
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local filesystem path to the model (empty by default).
    /// NOTE(review): presumably used when `source` selects a local mode — confirm in loader.
    #[serde(default)]
    pub local_path: String,
    /// Optional specific model filename within the source.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template name. Default: `"chatml"`.
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device selector. Default: `"cpu"`.
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Resolved `HuggingFace` Hub API token for authenticated model downloads.
    ///
    /// Must be the **token value** — resolved by the caller before constructing this config.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters for the Candle backend.
    #[serde(default)]
    pub generation: GenerationParams,
}

/// Sampling/generation parameters shared by the Candle backend.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature. Default: 0.7.
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Nucleus sampling cutoff; `None` disables top-p.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff; `None` disables top-k.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Maximum generated tokens. Default: 2048 (capped by `MAX_TOKENS_CAP`).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed for reproducible sampling. Default: 42.
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty factor. Default: 1.1.
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Lookback window (tokens) for the repetition penalty. Default: 64.
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
787
788pub const MAX_TOKENS_CAP: usize = 32768;
789
790impl GenerationParams {
791    #[must_use]
792    pub fn capped_max_tokens(&self) -> usize {
793        self.max_tokens.min(MAX_TOKENS_CAP)
794    }
795}
796
797impl Default for GenerationParams {
798    fn default() -> Self {
799        Self {
800            temperature: default_temperature(),
801            top_p: None,
802            top_k: None,
803            max_tokens: default_max_tokens(),
804            seed: default_seed(),
805            repeat_penalty: default_repeat_penalty(),
806            repeat_last_n: default_repeat_last_n(),
807        }
808    }
809}
810
// ─── Unified config types ─────────────────────────────────────────────────────

/// Routing strategy for the `[[llm.providers]]` pool.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// Single provider or first-in-pool (default).
    #[default]
    None,
    /// Exponential moving average latency-aware ordering.
    Ema,
    /// Thompson Sampling with Beta distributions.
    Thompson,
    /// Cascade: try cheapest provider first, escalate on degenerate output.
    Cascade,
    /// Task-based routing using `[llm.routes]` map.
    Task,
    /// Complexity triage routing: pre-classify each request, delegate to appropriate tier.
    Triage,
    /// PILOT: `LinUCB` contextual bandit with online learning and budget-aware reward.
    Bandit,
}

/// Serde default: triage classification call timeout (seconds).
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Serde default: max output tokens for the triage classification call.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Serde default helper for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}

/// Tier-to-provider name mapping for complexity routing.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    /// Provider name for the `simple` tier; `None` leaves the tier unmapped.
    pub simple: Option<String>,
    /// Provider name for the `medium` tier; `None` leaves the tier unmapped.
    pub medium: Option<String>,
    /// Provider name for the `complex` tier; `None` leaves the tier unmapped.
    pub complex: Option<String>,
    /// Provider name for the `expert` tier; `None` leaves the tier unmapped.
    pub expert: Option<String>,
}
854
/// Configuration for complexity-based triage routing (`routing = "triage"`).
///
/// When `[llm] routing = "triage"` is set, a cheap triage model classifies each request
/// and routes it to the appropriate tier provider. Requires at least one tier mapping.
///
/// # Example
///
/// ```toml
/// [llm]
/// routing = "triage"
///
/// [llm.complexity_routing]
/// triage_provider = "local-fast"
///
/// [llm.complexity_routing.tiers]
/// simple = "local-fast"
/// medium = "haiku"
/// complex = "sonnet"
/// expert = "opus"
/// ```
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider name from `[[llm.providers]]` used for triage classification.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Skip triage when all tiers map to the same provider. Default: true.
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Tier-to-provider name mapping.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Max output tokens for the triage classification call. Default: 50.
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Timeout in seconds for the triage classification call. Default: 5.
    /// On timeout, falls back to the default (first) tier provider.
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Optional fallback strategy when triage misclassifies.
    /// Only `"cascade"` is currently supported (Phase 4).
    /// NOTE(review): stringly-typed — an enum like `CascadeClassifierMode` would reject
    /// invalid values at load time; confirm before changing the config surface.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
903
904impl Default for ComplexityRoutingConfig {
905    fn default() -> Self {
906        Self {
907            triage_provider: None,
908            bypass_single_provider: true,
909            tiers: TierMapping::default(),
910            max_triage_tokens: default_max_triage_tokens(),
911            triage_timeout_secs: default_triage_timeout_secs(),
912            fallback_strategy: None,
913        }
914    }
915}
916
/// Inline candle config for use inside `ProviderEntry`.
/// Re-uses the generation params from `CandleConfig`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Where model weights come from; see `default_candle_source()` for the default.
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local filesystem path to model weights; empty when unused.
    #[serde(default)]
    pub local_path: String,
    /// Specific weights file to load — NOTE(review): exact semantics live in
    /// the candle loader; confirm before documenting further.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template identifier; see `default_chat_template()` for the default.
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device selector; see `default_candle_device()` for the default.
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional `HuggingFace` repo for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Resolved `HuggingFace` Hub API token for authenticated model downloads.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling / generation parameters shared with `CandleConfig`.
    #[serde(default)]
    pub generation: GenerationParams,
}
939
940impl Default for CandleInlineConfig {
941    fn default() -> Self {
942        Self {
943            source: default_candle_source(),
944            local_path: String::new(),
945            filename: None,
946            chat_template: default_chat_template(),
947            device: default_candle_device(),
948            embedding_repo: None,
949            hf_token: None,
950            generation: GenerationParams::default(),
951        }
952    }
953}
954
/// Unified provider entry: one struct replaces `CloudLlmConfig`, `OpenAiConfig`,
/// `GeminiConfig`, `OllamaConfig`, `CompatibleConfig`, and `OrchestratorProviderConfig`.
///
/// Provider-specific fields use `#[serde(default)]` and are ignored by backends
/// that do not use them (flat-union pattern). `validate()` warns when a field
/// irrelevant to the chosen `type` is set.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Required: provider backend type.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Optional name for multi-provider configs. Auto-generated from type if absent.
    /// Mandatory for `type = "compatible"` (enforced by `validate()`).
    #[serde(default)]
    pub name: Option<String>,

    /// Model identifier. Required for most types.
    /// When absent, `effective_model()` supplies a per-type default.
    #[serde(default)]
    pub model: Option<String>,

    /// API base URL. Each type has its own default.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Max output tokens.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Embedding model. When set, this provider supports `embed()` calls.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// STT model. When set, this provider supports speech-to-text via the Whisper API or
    /// Candle-local inference.
    #[serde(default)]
    pub stt_model: Option<String>,

    /// Mark this entry as the embedding provider (handles `embed()` calls).
    #[serde(default)]
    pub embed: bool,

    /// Mark this entry as the default chat provider (overrides position-based default).
    /// At most one entry in the pool may set this (enforced by `validate_pool`).
    #[serde(default)]
    pub default: bool,

    // --- Claude-specific ---
    /// Extended-thinking configuration (Claude only; other backends warn and ignore).
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    // NOTE(review): presumably enables server-side context compaction for
    // Claude — confirm against the Claude backend.
    #[serde(default)]
    pub server_compaction: bool,
    // NOTE(review): presumably opts into an extended context window for
    // Claude — confirm against the Claude backend.
    #[serde(default)]
    pub enable_extended_context: bool,

    // --- OpenAI-specific ---
    /// Reasoning effort hint (OpenAI only; other backends warn and ignore).
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    // --- Gemini-specific ---
    /// Thinking level (Gemini only; other backends warn and ignore).
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Thinking token budget (Gemini only; other backends warn and ignore).
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Gemini `include_thoughts` flag — NOTE(review): semantics defined by the
    /// Gemini backend; confirm before documenting further.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    // --- Compatible-specific: optional inline api_key ---
    #[serde(default)]
    pub api_key: Option<String>,

    // --- Candle-specific ---
    /// Inline local-inference settings; only meaningful for `type = "candle"`.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    // --- Vision ---
    // NOTE(review): usage of vision_model is not visible in this file —
    // confirm which backends honour it.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Provider-specific instruction file.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
1036
1037impl Default for ProviderEntry {
1038    fn default() -> Self {
1039        Self {
1040            provider_type: ProviderKind::Ollama,
1041            name: None,
1042            model: None,
1043            base_url: None,
1044            max_tokens: None,
1045            embedding_model: None,
1046            stt_model: None,
1047            embed: false,
1048            default: false,
1049            thinking: None,
1050            server_compaction: false,
1051            enable_extended_context: false,
1052            reasoning_effort: None,
1053            thinking_level: None,
1054            thinking_budget: None,
1055            include_thoughts: None,
1056            api_key: None,
1057            candle: None,
1058            vision_model: None,
1059            instruction_file: None,
1060        }
1061    }
1062}
1063
1064impl ProviderEntry {
1065    /// Resolve the effective name: explicit `name` field or type string.
1066    #[must_use]
1067    pub fn effective_name(&self) -> String {
1068        self.name
1069            .clone()
1070            .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1071    }
1072
1073    /// Resolve the effective model: explicit `model` field or the provider-type default.
1074    ///
1075    /// Defaults mirror those used in `build_provider_from_entry` so that `runtime.model_name`
1076    /// always reflects the actual model being used rather than the provider type string.
1077    #[must_use]
1078    pub fn effective_model(&self) -> String {
1079        if let Some(ref m) = self.model {
1080            return m.clone();
1081        }
1082        match self.provider_type {
1083            ProviderKind::Ollama => "qwen3:8b".to_owned(),
1084            ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1085            ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1086            ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1087            ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1088        }
1089    }
1090
1091    /// Validate this entry for cross-field consistency.
1092    ///
1093    /// # Errors
1094    ///
1095    /// Returns `ConfigError` when a fatal invariant is violated (e.g. compatible provider
1096    /// without a name).
1097    pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1098        use crate::error::ConfigError;
1099
1100        // B2: compatible provider MUST have name set.
1101        if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1102            return Err(ConfigError::Validation(
1103                "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1104            ));
1105        }
1106
1107        // B1: warn on irrelevant fields.
1108        match self.provider_type {
1109            ProviderKind::Ollama => {
1110                if self.thinking.is_some() {
1111                    tracing::warn!(
1112                        provider = self.effective_name(),
1113                        "field `thinking` is only used by Claude providers"
1114                    );
1115                }
1116                if self.reasoning_effort.is_some() {
1117                    tracing::warn!(
1118                        provider = self.effective_name(),
1119                        "field `reasoning_effort` is only used by OpenAI providers"
1120                    );
1121                }
1122                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1123                    tracing::warn!(
1124                        provider = self.effective_name(),
1125                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1126                    );
1127                }
1128            }
1129            ProviderKind::Claude => {
1130                if self.reasoning_effort.is_some() {
1131                    tracing::warn!(
1132                        provider = self.effective_name(),
1133                        "field `reasoning_effort` is only used by OpenAI providers"
1134                    );
1135                }
1136                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1137                    tracing::warn!(
1138                        provider = self.effective_name(),
1139                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1140                    );
1141                }
1142            }
1143            ProviderKind::OpenAi => {
1144                if self.thinking.is_some() {
1145                    tracing::warn!(
1146                        provider = self.effective_name(),
1147                        "field `thinking` is only used by Claude providers"
1148                    );
1149                }
1150                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1151                    tracing::warn!(
1152                        provider = self.effective_name(),
1153                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1154                    );
1155                }
1156            }
1157            ProviderKind::Gemini => {
1158                if self.thinking.is_some() {
1159                    tracing::warn!(
1160                        provider = self.effective_name(),
1161                        "field `thinking` is only used by Claude providers"
1162                    );
1163                }
1164                if self.reasoning_effort.is_some() {
1165                    tracing::warn!(
1166                        provider = self.effective_name(),
1167                        "field `reasoning_effort` is only used by OpenAI providers"
1168                    );
1169                }
1170            }
1171            _ => {}
1172        }
1173
1174        // W6: Candle STT-only provider (stt_model set, no model) is valid — no warning needed.
1175        // Warn if Ollama has stt_model set (Ollama does not support Whisper API).
1176        if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1177            tracing::warn!(
1178                provider = self.effective_name(),
1179                "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1180                 Whisper STT API — use OpenAI, compatible, or candle instead"
1181            );
1182        }
1183
1184        Ok(())
1185    }
1186}
1187
1188/// Validate a pool of `ProviderEntry` items.
1189///
1190/// # Errors
1191///
1192/// Returns `ConfigError` for fatal validation failures:
1193/// - Empty pool
1194/// - Duplicate names
1195/// - Multiple entries marked `default = true`
1196/// - Individual entry validation errors
1197pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1198    use crate::error::ConfigError;
1199    use std::collections::HashSet;
1200
1201    if entries.is_empty() {
1202        return Err(ConfigError::Validation(
1203            "at least one LLM provider must be configured in [[llm.providers]]".into(),
1204        ));
1205    }
1206
1207    let default_count = entries.iter().filter(|e| e.default).count();
1208    if default_count > 1 {
1209        return Err(ConfigError::Validation(
1210            "only one [[llm.providers]] entry can be marked `default = true`".into(),
1211        ));
1212    }
1213
1214    let mut seen_names: HashSet<String> = HashSet::new();
1215    for entry in entries {
1216        let name = entry.effective_name();
1217        if !seen_names.insert(name.clone()) {
1218            return Err(ConfigError::Validation(format!(
1219                "duplicate provider name \"{name}\" in [[llm.providers]]"
1220            )));
1221        }
1222        entry.validate()?;
1223    }
1224
1225    Ok(())
1226}
1227
1228#[cfg(test)]
1229mod tests {
1230    use super::*;
1231
1232    fn ollama_entry() -> ProviderEntry {
1233        ProviderEntry {
1234            provider_type: ProviderKind::Ollama,
1235            name: Some("ollama".into()),
1236            model: Some("qwen3:8b".into()),
1237            ..Default::default()
1238        }
1239    }
1240
1241    fn claude_entry() -> ProviderEntry {
1242        ProviderEntry {
1243            provider_type: ProviderKind::Claude,
1244            name: Some("claude".into()),
1245            model: Some("claude-sonnet-4-6".into()),
1246            max_tokens: Some(8192),
1247            ..Default::default()
1248        }
1249    }
1250
1251    // ─── ProviderEntry::validate ─────────────────────────────────────────────
1252
1253    #[test]
1254    fn validate_ollama_valid() {
1255        assert!(ollama_entry().validate().is_ok());
1256    }
1257
1258    #[test]
1259    fn validate_claude_valid() {
1260        assert!(claude_entry().validate().is_ok());
1261    }
1262
1263    #[test]
1264    fn validate_compatible_without_name_errors() {
1265        let entry = ProviderEntry {
1266            provider_type: ProviderKind::Compatible,
1267            name: None,
1268            ..Default::default()
1269        };
1270        let err = entry.validate().unwrap_err();
1271        assert!(
1272            err.to_string().contains("compatible"),
1273            "error should mention compatible: {err}"
1274        );
1275    }
1276
1277    #[test]
1278    fn validate_compatible_with_name_ok() {
1279        let entry = ProviderEntry {
1280            provider_type: ProviderKind::Compatible,
1281            name: Some("my-proxy".into()),
1282            base_url: Some("http://localhost:8080".into()),
1283            model: Some("gpt-4o".into()),
1284            max_tokens: Some(4096),
1285            ..Default::default()
1286        };
1287        assert!(entry.validate().is_ok());
1288    }
1289
1290    #[test]
1291    fn validate_openai_valid() {
1292        let entry = ProviderEntry {
1293            provider_type: ProviderKind::OpenAi,
1294            name: Some("openai".into()),
1295            model: Some("gpt-4o".into()),
1296            max_tokens: Some(4096),
1297            ..Default::default()
1298        };
1299        assert!(entry.validate().is_ok());
1300    }
1301
1302    #[test]
1303    fn validate_gemini_valid() {
1304        let entry = ProviderEntry {
1305            provider_type: ProviderKind::Gemini,
1306            name: Some("gemini".into()),
1307            model: Some("gemini-2.0-flash".into()),
1308            ..Default::default()
1309        };
1310        assert!(entry.validate().is_ok());
1311    }
1312
1313    // ─── validate_pool ───────────────────────────────────────────────────────
1314
1315    #[test]
1316    fn validate_pool_empty_errors() {
1317        let err = validate_pool(&[]).unwrap_err();
1318        assert!(err.to_string().contains("at least one"), "{err}");
1319    }
1320
1321    #[test]
1322    fn validate_pool_single_entry_ok() {
1323        assert!(validate_pool(&[ollama_entry()]).is_ok());
1324    }
1325
1326    #[test]
1327    fn validate_pool_duplicate_names_errors() {
1328        let a = ollama_entry();
1329        let b = ollama_entry(); // same effective name "ollama"
1330        let err = validate_pool(&[a, b]).unwrap_err();
1331        assert!(err.to_string().contains("duplicate"), "{err}");
1332    }
1333
1334    #[test]
1335    fn validate_pool_multiple_defaults_errors() {
1336        let mut a = ollama_entry();
1337        let mut b = claude_entry();
1338        a.default = true;
1339        b.default = true;
1340        let err = validate_pool(&[a, b]).unwrap_err();
1341        assert!(err.to_string().contains("default"), "{err}");
1342    }
1343
1344    #[test]
1345    fn validate_pool_two_different_providers_ok() {
1346        assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1347    }
1348
1349    #[test]
1350    fn validate_pool_propagates_entry_error() {
1351        let bad = ProviderEntry {
1352            provider_type: ProviderKind::Compatible,
1353            name: None, // invalid: compatible without name
1354            ..Default::default()
1355        };
1356        assert!(validate_pool(&[bad]).is_err());
1357    }
1358
1359    // ─── ProviderEntry::effective_model ──────────────────────────────────────
1360
1361    #[test]
1362    fn effective_model_returns_explicit_when_set() {
1363        let entry = ProviderEntry {
1364            provider_type: ProviderKind::Claude,
1365            model: Some("claude-sonnet-4-6".into()),
1366            ..Default::default()
1367        };
1368        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
1369    }
1370
1371    #[test]
1372    fn effective_model_ollama_default_when_none() {
1373        let entry = ProviderEntry {
1374            provider_type: ProviderKind::Ollama,
1375            model: None,
1376            ..Default::default()
1377        };
1378        assert_eq!(entry.effective_model(), "qwen3:8b");
1379    }
1380
1381    #[test]
1382    fn effective_model_claude_default_when_none() {
1383        let entry = ProviderEntry {
1384            provider_type: ProviderKind::Claude,
1385            model: None,
1386            ..Default::default()
1387        };
1388        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
1389    }
1390
1391    #[test]
1392    fn effective_model_openai_default_when_none() {
1393        let entry = ProviderEntry {
1394            provider_type: ProviderKind::OpenAi,
1395            model: None,
1396            ..Default::default()
1397        };
1398        assert_eq!(entry.effective_model(), "gpt-4o-mini");
1399    }
1400
1401    #[test]
1402    fn effective_model_gemini_default_when_none() {
1403        let entry = ProviderEntry {
1404            provider_type: ProviderKind::Gemini,
1405            model: None,
1406            ..Default::default()
1407        };
1408        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
1409    }
1410
1411    // ─── LlmConfig::check_legacy_format ──────────────────────────────────────
1412
1413    // Parse a complete TOML snippet that includes the [llm] header.
1414    fn parse_llm(toml: &str) -> LlmConfig {
1415        #[derive(serde::Deserialize)]
1416        struct Wrapper {
1417            llm: LlmConfig,
1418        }
1419        toml::from_str::<Wrapper>(toml).unwrap().llm
1420    }
1421
1422    #[test]
1423    fn check_legacy_format_new_format_ok() {
1424        let cfg = parse_llm(
1425            r#"
1426[llm]
1427
1428[[llm.providers]]
1429type = "ollama"
1430model = "qwen3:8b"
1431"#,
1432        );
1433        assert!(cfg.check_legacy_format().is_ok());
1434    }
1435
1436    #[test]
1437    fn check_legacy_format_empty_providers_no_legacy_ok() {
1438        // No providers, no legacy fields — passes (empty [llm] is acceptable here)
1439        let cfg = parse_llm("[llm]\n");
1440        assert!(cfg.check_legacy_format().is_ok());
1441    }
1442
1443    // ─── LlmConfig::effective_* helpers ──────────────────────────────────────
1444
1445    #[test]
1446    fn effective_provider_falls_back_to_ollama_when_no_providers() {
1447        let cfg = parse_llm("[llm]\n");
1448        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
1449    }
1450
1451    #[test]
1452    fn effective_provider_reads_from_providers_first() {
1453        let cfg = parse_llm(
1454            r#"
1455[llm]
1456
1457[[llm.providers]]
1458type = "claude"
1459model = "claude-sonnet-4-6"
1460"#,
1461        );
1462        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
1463    }
1464
1465    #[test]
1466    fn effective_model_reads_from_providers_first() {
1467        let cfg = parse_llm(
1468            r#"
1469[llm]
1470
1471[[llm.providers]]
1472type = "ollama"
1473model = "qwen3:8b"
1474"#,
1475        );
1476        assert_eq!(cfg.effective_model(), "qwen3:8b");
1477    }
1478
1479    #[test]
1480    fn effective_base_url_default_when_absent() {
1481        let cfg = parse_llm("[llm]\n");
1482        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
1483    }
1484
1485    #[test]
1486    fn effective_base_url_from_providers_entry() {
1487        let cfg = parse_llm(
1488            r#"
1489[llm]
1490
1491[[llm.providers]]
1492type = "ollama"
1493base_url = "http://myhost:11434"
1494"#,
1495        );
1496        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
1497    }
1498
1499    // ─── ComplexityRoutingConfig / LlmRoutingStrategy::Triage TOML parsing ──
1500
1501    #[test]
1502    fn complexity_routing_defaults() {
1503        let cr = ComplexityRoutingConfig::default();
1504        assert!(
1505            cr.bypass_single_provider,
1506            "bypass_single_provider must default to true"
1507        );
1508        assert_eq!(cr.triage_timeout_secs, 5);
1509        assert_eq!(cr.max_triage_tokens, 50);
1510        assert!(cr.triage_provider.is_none());
1511        assert!(cr.tiers.simple.is_none());
1512    }
1513
1514    #[test]
1515    fn complexity_routing_toml_round_trip() {
1516        let cfg = parse_llm(
1517            r#"
1518[llm]
1519routing = "triage"
1520
1521[llm.complexity_routing]
1522triage_provider = "fast"
1523bypass_single_provider = false
1524triage_timeout_secs = 10
1525max_triage_tokens = 100
1526
1527[llm.complexity_routing.tiers]
1528simple = "fast"
1529medium = "medium"
1530complex = "large"
1531expert = "opus"
1532"#,
1533        );
1534        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1535        let cr = cfg
1536            .complexity_routing
1537            .expect("complexity_routing must be present");
1538        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
1539        assert!(!cr.bypass_single_provider);
1540        assert_eq!(cr.triage_timeout_secs, 10);
1541        assert_eq!(cr.max_triage_tokens, 100);
1542        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
1543        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
1544        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
1545        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
1546    }
1547
1548    #[test]
1549    fn complexity_routing_partial_tiers_toml() {
1550        // Only simple + complex configured; medium and expert are None.
1551        let cfg = parse_llm(
1552            r#"
1553[llm]
1554routing = "triage"
1555
1556[llm.complexity_routing.tiers]
1557simple = "haiku"
1558complex = "sonnet"
1559"#,
1560        );
1561        let cr = cfg
1562            .complexity_routing
1563            .expect("complexity_routing must be present");
1564        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
1565        assert!(cr.tiers.medium.is_none());
1566        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
1567        assert!(cr.tiers.expert.is_none());
1568        // Defaults still applied.
1569        assert!(cr.bypass_single_provider);
1570        assert_eq!(cr.triage_timeout_secs, 5);
1571    }
1572
1573    #[test]
1574    fn routing_strategy_triage_deserialized() {
1575        let cfg = parse_llm(
1576            r#"
1577[llm]
1578routing = "triage"
1579"#,
1580        );
1581        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1582    }
1583
1584    // ─── stt_provider_entry ───────────────────────────────────────────────────
1585
1586    #[test]
1587    fn stt_provider_entry_by_name_match() {
1588        let cfg = parse_llm(
1589            r#"
1590[llm]
1591
1592[[llm.providers]]
1593type = "openai"
1594name = "quality"
1595model = "gpt-5.4"
1596stt_model = "gpt-4o-mini-transcribe"
1597
1598[llm.stt]
1599provider = "quality"
1600"#,
1601        );
1602        let entry = cfg.stt_provider_entry().expect("should find stt provider");
1603        assert_eq!(entry.effective_name(), "quality");
1604        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
1605    }
1606
1607    #[test]
1608    fn stt_provider_entry_auto_detect_when_provider_empty() {
1609        let cfg = parse_llm(
1610            r#"
1611[llm]
1612
1613[[llm.providers]]
1614type = "openai"
1615name = "openai-stt"
1616stt_model = "whisper-1"
1617
1618[llm.stt]
1619provider = ""
1620"#,
1621        );
1622        let entry = cfg.stt_provider_entry().expect("should auto-detect");
1623        assert_eq!(entry.effective_name(), "openai-stt");
1624    }
1625
1626    #[test]
1627    fn stt_provider_entry_auto_detect_no_stt_section() {
1628        let cfg = parse_llm(
1629            r#"
1630[llm]
1631
1632[[llm.providers]]
1633type = "openai"
1634name = "openai-stt"
1635stt_model = "whisper-1"
1636"#,
1637        );
1638        // No [llm.stt] section — should still find first provider with stt_model.
1639        let entry = cfg.stt_provider_entry().expect("should auto-detect");
1640        assert_eq!(entry.effective_name(), "openai-stt");
1641    }
1642
1643    #[test]
1644    fn stt_provider_entry_none_when_no_stt_model() {
1645        let cfg = parse_llm(
1646            r#"
1647[llm]
1648
1649[[llm.providers]]
1650type = "openai"
1651name = "quality"
1652model = "gpt-5.4"
1653"#,
1654        );
1655        assert!(cfg.stt_provider_entry().is_none());
1656    }
1657
1658    #[test]
1659    fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1660        // Named provider exists but has no stt_model; another unnamed has stt_model.
1661        let cfg = parse_llm(
1662            r#"
1663[llm]
1664
1665[[llm.providers]]
1666type = "openai"
1667name = "quality"
1668model = "gpt-5.4"
1669
1670[[llm.providers]]
1671type = "openai"
1672name = "openai-stt"
1673stt_model = "whisper-1"
1674
1675[llm.stt]
1676provider = "quality"
1677"#,
1678        );
1679        // "quality" has no stt_model — returns None for name-based lookup.
1680        assert!(cfg.stt_provider_entry().is_none());
1681    }
1682
1683    #[test]
1684    fn stt_config_deserializes_new_slim_format() {
1685        let cfg = parse_llm(
1686            r#"
1687[llm]
1688
1689[[llm.providers]]
1690type = "openai"
1691name = "quality"
1692stt_model = "whisper-1"
1693
1694[llm.stt]
1695provider = "quality"
1696language = "en"
1697"#,
1698        );
1699        let stt = cfg.stt.as_ref().expect("stt section present");
1700        assert_eq!(stt.provider, "quality");
1701        assert_eq!(stt.language, "en");
1702    }
1703
1704    #[test]
1705    fn stt_config_default_provider_is_empty() {
1706        // Verify that W4 fix: default_stt_provider() returns "" not "whisper".
1707        assert_eq!(default_stt_provider(), "");
1708    }
1709
1710    #[test]
1711    fn validate_stt_missing_provider_ok() {
1712        let cfg = parse_llm("[llm]\n");
1713        assert!(cfg.validate_stt().is_ok());
1714    }
1715
1716    #[test]
1717    fn validate_stt_valid_reference() {
1718        let cfg = parse_llm(
1719            r#"
1720[llm]
1721
1722[[llm.providers]]
1723type = "openai"
1724name = "quality"
1725stt_model = "whisper-1"
1726
1727[llm.stt]
1728provider = "quality"
1729"#,
1730        );
1731        assert!(cfg.validate_stt().is_ok());
1732    }
1733
1734    #[test]
1735    fn validate_stt_nonexistent_provider_errors() {
1736        let cfg = parse_llm(
1737            r#"
1738[llm]
1739
1740[[llm.providers]]
1741type = "openai"
1742name = "quality"
1743model = "gpt-5.4"
1744
1745[llm.stt]
1746provider = "nonexistent"
1747"#,
1748        );
1749        assert!(cfg.validate_stt().is_err());
1750    }
1751
1752    #[test]
1753    fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1754        // MEDIUM: provider is found but has no stt_model — should return Ok (warn path, not error).
1755        let cfg = parse_llm(
1756            r#"
1757[llm]
1758
1759[[llm.providers]]
1760type = "openai"
1761name = "quality"
1762model = "gpt-5.4"
1763
1764[llm.stt]
1765provider = "quality"
1766"#,
1767        );
1768        // validate_stt must succeed (only a tracing::warn is emitted — not an error).
1769        assert!(cfg.validate_stt().is_ok());
1770        // stt_provider_entry must return None because no stt_model is set.
1771        assert!(
1772            cfg.stt_provider_entry().is_none(),
1773            "stt_provider_entry must be None when provider has no stt_model"
1774        );
1775    }
1776
1777    // ─── BanditConfig::warmup_queries deserialization ─────────────────────────
1778
1779    #[test]
1780    fn bandit_warmup_queries_explicit_value_is_deserialized() {
1781        let cfg = parse_llm(
1782            r#"
1783[llm]
1784
1785[llm.router]
1786strategy = "bandit"
1787
1788[llm.router.bandit]
1789warmup_queries = 50
1790"#,
1791        );
1792        let bandit = cfg
1793            .router
1794            .expect("router section must be present")
1795            .bandit
1796            .expect("bandit section must be present");
1797        assert_eq!(
1798            bandit.warmup_queries,
1799            Some(50),
1800            "warmup_queries = 50 must deserialize to Some(50)"
1801        );
1802    }
1803
1804    #[test]
1805    fn bandit_warmup_queries_explicit_null_is_none() {
1806        // Explicitly writing the field as absent: field simply not present is
1807        // equivalent due to #[serde(default)]. Test that an explicit 0 is Some(0).
1808        let cfg = parse_llm(
1809            r#"
1810[llm]
1811
1812[llm.router]
1813strategy = "bandit"
1814
1815[llm.router.bandit]
1816warmup_queries = 0
1817"#,
1818        );
1819        let bandit = cfg
1820            .router
1821            .expect("router section must be present")
1822            .bandit
1823            .expect("bandit section must be present");
1824        // 0 is a valid explicit value — it means "preserve computed default".
1825        assert_eq!(
1826            bandit.warmup_queries,
1827            Some(0),
1828            "warmup_queries = 0 must deserialize to Some(0)"
1829        );
1830    }
1831
1832    #[test]
1833    fn bandit_warmup_queries_missing_field_defaults_to_none() {
1834        // When warmup_queries is omitted entirely, #[serde(default)] must produce None.
1835        let cfg = parse_llm(
1836            r#"
1837[llm]
1838
1839[llm.router]
1840strategy = "bandit"
1841
1842[llm.router.bandit]
1843alpha = 1.5
1844"#,
1845        );
1846        let bandit = cfg
1847            .router
1848            .expect("router section must be present")
1849            .bandit
1850            .expect("bandit section must be present");
1851        assert_eq!(
1852            bandit.warmup_queries, None,
1853            "omitted warmup_queries must default to None"
1854        );
1855    }
1856
1857    #[test]
1858    fn provider_name_new_and_as_str() {
1859        let n = ProviderName::new("fast");
1860        assert_eq!(n.as_str(), "fast");
1861        assert!(!n.is_empty());
1862    }
1863
1864    #[test]
1865    fn provider_name_default_is_empty() {
1866        let n = ProviderName::default();
1867        assert!(n.is_empty());
1868        assert_eq!(n.as_str(), "");
1869    }
1870
1871    #[test]
1872    fn provider_name_deref_to_str() {
1873        let n = ProviderName::new("quality");
1874        let s: &str = &n;
1875        assert_eq!(s, "quality");
1876    }
1877
1878    #[test]
1879    fn provider_name_partial_eq_str() {
1880        let n = ProviderName::new("fast");
1881        assert_eq!(n, "fast");
1882        assert_ne!(n, "slow");
1883    }
1884
1885    #[test]
1886    fn provider_name_serde_roundtrip() {
1887        let n = ProviderName::new("my-provider");
1888        let json = serde_json::to_string(&n).expect("serialize");
1889        assert_eq!(json, "\"my-provider\"");
1890        let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1891        assert_eq!(back, n);
1892    }
1893
1894    #[test]
1895    fn provider_name_serde_empty_roundtrip() {
1896        let n = ProviderName::default();
1897        let json = serde_json::to_string(&n).expect("serialize");
1898        assert_eq!(json, "\"\"");
1899        let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1900        assert_eq!(back, n);
1901        assert!(back.is_empty());
1902    }
1903}