1use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{CacheTtl, GeminiThinkingLevel, ThinkingConfig};
8
9#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct ProviderName(String);
16
17impl ProviderName {
18 #[must_use]
32 pub fn new(name: impl Into<String>) -> Self {
33 Self(name.into())
34 }
35
36 #[must_use]
47 pub fn is_empty(&self) -> bool {
48 self.0.is_empty()
49 }
50
51 #[must_use]
62 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl fmt::Display for ProviderName {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 self.0.fmt(f)
70 }
71}
72
73impl AsRef<str> for ProviderName {
74 fn as_ref(&self) -> &str {
75 &self.0
76 }
77}
78
79impl std::ops::Deref for ProviderName {
80 type Target = str;
81
82 fn deref(&self) -> &str {
83 &self.0
84 }
85}
86
87impl PartialEq<str> for ProviderName {
88 fn eq(&self, other: &str) -> bool {
89 self.0 == other
90 }
91}
92
93impl PartialEq<&str> for ProviderName {
94 fn eq(&self, other: &&str) -> bool {
95 self.0 == *other
96 }
97}
98
/// Serde default for `LlmConfig::response_cache_ttl_secs`: one hour.
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Serde default for `LlmConfig::semantic_cache_threshold`.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Serde default for `LlmConfig::semantic_cache_max_candidates`.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Serde default for `LlmConfig::router_ema_alpha`.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Serde default for `LlmConfig::router_reorder_interval`.
fn default_router_reorder_interval() -> u64 {
    10
}
118
/// Serde default for the embedding model identifier.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Serde default for `CandleConfig::source`.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Serde default for `CandleConfig::chat_template`.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Serde default for `CandleConfig::device`.
fn default_candle_device() -> String {
    String::from("cpu")
}
134
/// Serde default for `GenerationParams::temperature`.
fn default_temperature() -> f64 {
    0.7
}

/// Serde default for `GenerationParams::max_tokens`.
fn default_max_tokens() -> usize {
    2048
}

/// Serde default for `GenerationParams::seed` (fixed value keeps local
/// generation reproducible unless overridden).
fn default_seed() -> u64 {
    42
}

/// Serde default for `GenerationParams::repeat_penalty`.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Serde default for `GenerationParams::repeat_last_n`.
fn default_repeat_last_n() -> usize {
    64
}

/// Serde default for `CascadeConfig::quality_threshold`.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Serde default for `CascadeConfig::max_escalations`.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Serde default for `CascadeConfig::window_size`.
fn default_cascade_window_size() -> usize {
    50
}

/// Serde default for `ReputationConfig::decay_factor`.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Serde default for `ReputationConfig::weight`.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Serde default for `ReputationConfig::min_observations`.
fn default_reputation_min_observations() -> u64 {
    5
}
178
/// Serde default for `SttConfig::provider`: empty string, which the
/// resolver treats as "auto-detect" (see `LlmConfig::stt_provider_entry`).
#[must_use]
pub fn default_stt_provider() -> String {
    String::default()
}

/// Serde default for `SttConfig::language`.
#[must_use]
pub fn default_stt_language() -> String {
    String::from("auto")
}
190
/// Public accessor exposing [`default_embedding_model`] to other modules.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor exposing [`default_response_cache_ttl_secs`].
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor exposing [`default_router_ema_alpha`].
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor exposing [`default_router_reorder_interval`].
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
214
/// Kind of LLM backend a provider entry talks to.
///
/// Serialized as the lowercase variant name (e.g. `"openai"`), matching
/// the strings returned by [`ProviderKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    Ollama,
    Claude,
    OpenAi,
    Gemini,
    Candle,
    Compatible,
}

impl ProviderKind {
    /// Canonical lowercase string for this kind; agrees with the serde
    /// representation produced by `rename_all = "lowercase"`.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::Claude => "claude",
            Self::OpenAi => "openai",
            Self::Gemini => "gemini",
            Self::Candle => "candle",
            Self::Compatible => "compatible",
        }
    }
}

/// Displays the canonical lowercase string. Note `write_str` bypasses
/// formatter width/fill flags.
impl std::fmt::Display for ProviderKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
273
/// The `[llm]` section of the application configuration.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Configured provider pool (`[[llm.providers]]`); the first entry is
    /// what the `effective_*` accessors fall back on.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Routing strategy; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Embedding model identifier (default: "qwen3-embedding").
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Optional Candle (in-process inference) settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Optional speech-to-text settings (`[llm.stt]`).
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Enables the response cache (off by default).
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response-cache TTL in seconds (default: 3600).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Enables the semantic (similarity-based) cache.
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Similarity threshold for a semantic-cache hit (default: 0.95).
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Max candidates examined per semantic-cache lookup (default: 10).
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Enables EMA-based router stats.
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor (default: 0.1).
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Router reorder interval (default: 10).
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Advanced router configuration (`[llm.router]`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional path to a global instruction file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Optional dedicated model for summarization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Optional dedicated provider entry for summarization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// Complexity-based (triage) routing settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,

    /// CoE settings — NOTE(review): expansion of "coe" is not visible in
    /// this file; confirm against `CoeConfig`'s consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub coe: Option<CoeConfig>,
}
371
/// Serde default for `LlmConfig::embedding_model`; delegates to the
/// shared [`default_embedding_model`] value.
fn default_embedding_model_opt() -> String {
    default_embedding_model()
}
375
376#[allow(clippy::trivially_copy_pass_by_ref)]
377fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
378 *s == LlmRoutingStrategy::None
379}
380
381impl LlmConfig {
382 #[must_use]
384 pub fn effective_provider(&self) -> ProviderKind {
385 self.providers
386 .first()
387 .map_or(ProviderKind::Ollama, |e| e.provider_type)
388 }
389
390 #[must_use]
392 pub fn effective_base_url(&self) -> &str {
393 self.providers
394 .first()
395 .and_then(|e| e.base_url.as_deref())
396 .unwrap_or("http://localhost:11434")
397 }
398
399 #[must_use]
401 pub fn effective_model(&self) -> &str {
402 self.providers
403 .first()
404 .and_then(|e| e.model.as_deref())
405 .unwrap_or("qwen3:8b")
406 }
407
408 #[must_use]
416 pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
417 let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
418 if name_hint.is_empty() {
419 self.providers.iter().find(|p| p.stt_model.is_some())
420 } else {
421 self.providers
422 .iter()
423 .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
424 }
425 }
426
427 pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
433 Ok(())
434 }
435
436 pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
442 use crate::error::ConfigError;
443
444 let Some(stt) = &self.stt else {
445 return Ok(());
446 };
447 if stt.provider.is_empty() {
448 return Ok(());
449 }
450 let found = self
451 .providers
452 .iter()
453 .find(|p| p.effective_name() == stt.provider);
454 match found {
455 None => {
456 return Err(ConfigError::Validation(format!(
457 "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
458 stt.provider
459 )));
460 }
461 Some(entry) if entry.stt_model.is_none() => {
462 tracing::warn!(
463 provider = stt.provider,
464 "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
465 );
466 }
467 _ => {}
468 }
469 Ok(())
470 }
471}
472
/// `[llm.stt]` speech-to-text settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Provider name to use; empty selects the first STT-capable entry
    /// (see `LlmConfig::stt_provider_entry`).
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint (default: "auto").
    #[serde(default = "default_stt_language")]
    pub language: String,
}

/// Strategy selector for `RouterConfig::strategy`; serialized lowercase,
/// defaults to `ema`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    #[default]
    Ema,
    Thompson,
    Cascade,
    Bandit,
}
510
/// Settings for the router's `asi` feature (see `RouterConfig::asi`).
/// NOTE(review): the "ASI" expansion is not visible in this file; confirm.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AsiConfig {
    /// Feature toggle (off by default).
    #[serde(default)]
    pub enabled: bool,

    /// Observation window size (default: 5).
    #[serde(default = "default_asi_window")]
    pub window: usize,

    /// Coherence threshold (default: 0.7).
    #[serde(default = "default_asi_coherence_threshold")]
    pub coherence_threshold: f32,

    /// Penalty weight (default: 0.3).
    #[serde(default = "default_asi_penalty_weight")]
    pub penalty_weight: f32,
}

/// Serde default for `AsiConfig::window`.
fn default_asi_window() -> usize {
    5
}

/// Serde default for `AsiConfig::coherence_threshold`.
fn default_asi_coherence_threshold() -> f32 {
    0.7
}

/// Serde default for `AsiConfig::penalty_weight`.
fn default_asi_penalty_weight() -> f32 {
    0.3
}

/// Mirrors the per-field serde defaults, disabled.
impl Default for AsiConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            window: default_asi_window(),
            coherence_threshold: default_asi_coherence_threshold(),
            penalty_weight: default_asi_penalty_weight(),
        }
    }
}
567
/// `[llm.router]` advanced router settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Routing strategy (default: `ema`).
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Optional path for persisting Thompson-strategy state.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade-strategy settings.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Provider-reputation settings.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Bandit-strategy settings.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
    /// Optional quality gate threshold.
    #[serde(default)]
    pub quality_gate: Option<f32>,
    /// Optional ASI settings.
    #[serde(default)]
    pub asi: Option<AsiConfig>,
    /// Embedding concurrency limit (default: 4).
    #[serde(default = "default_embed_concurrency")]
    pub embed_concurrency: usize,
}

/// Serde default for `RouterConfig::embed_concurrency`.
fn default_embed_concurrency() -> usize {
    4
}

/// Provider-reputation tracking settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Feature toggle (off by default).
    #[serde(default)]
    pub enabled: bool,
    /// Score decay factor (default: 0.95).
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of reputation in routing (default: 0.3).
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Minimum observations before reputation applies (default: 5).
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Optional path for persisting reputation state.
    #[serde(default)]
    pub state_path: Option<String>,
}
648
/// Cascade-strategy settings (see `RouterStrategyConfig::Cascade`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality score below which a request escalates (default: 0.5).
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum escalations per request (default: 2).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// How response quality is classified (default: `heuristic`).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Window size for quality statistics (default: 50).
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Optional hard cap on tokens spent across a cascade.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Optional explicit provider ordering by cost tier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}

/// Mirrors the per-field serde defaults.
impl Default for CascadeConfig {
    fn default() -> Self {
        Self {
            quality_threshold: default_cascade_quality_threshold(),
            max_escalations: default_cascade_max_escalations(),
            classifier_mode: CascadeClassifierMode::default(),
            window_size: default_cascade_window_size(),
            max_cascade_tokens: None,
            cost_tiers: None,
        }
    }
}

/// Quality classifier used by the cascade strategy (default: `heuristic`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    #[default]
    Heuristic,
    Judge,
}
721
/// Serde default for `BanditConfig::alpha`.
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Serde default for `BanditConfig::dim`.
fn default_bandit_dim() -> usize {
    32
}

/// Serde default for `BanditConfig::cost_weight`.
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Serde default for `BanditConfig::decay_factor` (1.0 = no decay).
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Serde default for `BanditConfig::embedding_timeout_ms`.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Serde default for `BanditConfig::cache_size`.
fn default_bandit_cache_size() -> usize {
    512
}

/// Bandit-strategy settings (see `RouterStrategyConfig::Bandit`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Alpha parameter (default: 1.0).
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Feature dimensionality (default: 32).
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Weight of provider cost (default: 0.1).
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Decay factor (default: 1.0, i.e. no decay).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider used for embeddings; empty name = unset.
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Embedding call timeout in milliseconds (default: 50).
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Cache capacity (default: 512).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Optional path for persisting bandit state.
    #[serde(default)]
    pub state_path: Option<String>,

    /// Memory confidence threshold (default: 0.9).
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Optional warmup query count.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}

/// Serde default for `BanditConfig::memory_confidence_threshold`.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}

/// Mirrors the per-field serde defaults.
impl Default for BanditConfig {
    fn default() -> Self {
        Self {
            alpha: default_bandit_alpha(),
            dim: default_bandit_dim(),
            cost_weight: default_bandit_cost_weight(),
            decay_factor: default_bandit_decay_factor(),
            embedding_provider: ProviderName::default(),
            embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
            cache_size: default_bandit_cache_size(),
            state_path: None,
            memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
            warmup_queries: None,
        }
    }
}
847
/// `[llm.candle]` in-process (Candle) inference settings.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Weight source (default: "huggingface").
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Specific weight file to load, if any.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template name (default: "chatml").
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default: "cpu").
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional embedding model repository.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
    /// Inference timeout in seconds (default: 120).
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}

/// Serde default for `inference_timeout_secs` (two minutes).
fn default_inference_timeout_secs() -> u64 {
    120
}
884
/// Sampling parameters for local (Candle) generation.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature (default: 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Optional nucleus-sampling cutoff.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional top-k cutoff.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Max tokens to generate (default: 2048; capped via
    /// `capped_max_tokens`).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed (default: 42).
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty (default: 1.1).
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Lookback window for the repetition penalty (default: 64).
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
915
916pub const MAX_TOKENS_CAP: usize = 32768;
918
919impl GenerationParams {
920 #[must_use]
931 pub fn capped_max_tokens(&self) -> usize {
932 self.max_tokens.min(MAX_TOKENS_CAP)
933 }
934}
935
/// Mirrors the per-field serde defaults.
impl Default for GenerationParams {
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: None,
            top_k: None,
            max_tokens: default_max_tokens(),
            seed: default_seed(),
            repeat_penalty: default_repeat_penalty(),
            repeat_last_n: default_repeat_last_n(),
        }
    }
}
949
/// Top-level routing strategy for `LlmConfig::routing`.
///
/// Serialized lowercase; defaults to `none` (and is then omitted from
/// serialized output via `is_routing_none`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    #[default]
    None,
    Ema,
    Thompson,
    Cascade,
    Triage,
    Bandit,
}
970
/// Serde default for `ComplexityRoutingConfig::triage_timeout_secs`.
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Serde default for `ComplexityRoutingConfig::max_triage_tokens`.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Serde helper: derive's `default = "..."` takes a function path, so a
/// literal `true` default needs this wrapper.
fn default_true() -> bool {
    true
}

/// Maps complexity tiers to provider names; unset tiers have no mapping.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    pub simple: Option<String>,
    pub medium: Option<String>,
    pub complex: Option<String>,
    pub expert: Option<String>,
}
991
/// `[llm.complexity_routing]` triage-based routing settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider used for the triage step.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Skip triage when only one provider is configured (default: true).
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Complexity-tier → provider-name mapping.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Token budget for the triage call (default: 50).
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Triage call timeout in seconds (default: 5).
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Optional strategy name to fall back to.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}

/// Mirrors the per-field serde defaults.
impl Default for ComplexityRoutingConfig {
    fn default() -> Self {
        Self {
            triage_provider: None,
            bypass_single_provider: true,
            tiers: TierMapping::default(),
            max_triage_tokens: default_max_triage_tokens(),
            triage_timeout_secs: default_triage_timeout_secs(),
            fallback_strategy: None,
        }
    }
}
1053
/// `[llm.coe]` settings. The whole struct is `#[serde(default)]`, so any
/// subset of keys may appear in config.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct CoeConfig {
    /// Feature toggle (off by default).
    pub enabled: bool,
    /// Intra threshold (default: 0.8) — exact semantics live in the
    /// consumer, not visible here.
    pub intra_threshold: f64,
    /// Inter threshold (default: 0.2).
    pub inter_threshold: f64,
    /// Shadow sampling rate (default: 0.1).
    pub shadow_sample_rate: f64,
    /// Secondary provider name (empty = unset).
    pub secondary_provider: ProviderName,
    /// Embedding provider name (empty = unset).
    pub embed_provider: ProviderName,
}

/// Disabled, with the documented threshold defaults.
impl Default for CoeConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            intra_threshold: 0.8,
            inter_threshold: 0.20,
            shadow_sample_rate: 0.1,
            secondary_provider: ProviderName::default(),
            embed_provider: ProviderName::default(),
        }
    }
}
1100
/// Per-provider inline Candle settings; mirrors [`CandleConfig`]
/// field-for-field, but carries a `Default` impl so it can be embedded
/// in a `ProviderEntry`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Weight source (default: "huggingface").
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Specific weight file to load, if any.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat prompt template name (default: "chatml").
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default: "cpu").
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional embedding model repository.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face access token.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
    /// Inference timeout in seconds (default: 120).
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}

/// Mirrors the per-field serde defaults.
impl Default for CandleInlineConfig {
    fn default() -> Self {
        Self {
            source: default_candle_source(),
            local_path: String::new(),
            filename: None,
            chat_template: default_chat_template(),
            device: default_candle_device(),
            embedding_repo: None,
            hf_token: None,
            generation: GenerationParams::default(),
            inference_timeout_secs: default_inference_timeout_secs(),
        }
    }
}
1145
/// One `[[llm.providers]]` entry.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Backend kind; written as `type = "..."` in config.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Optional name; falls back to the kind's canonical string (see
    /// `effective_name`). Required for `type = "compatible"`.
    #[serde(default)]
    pub name: Option<String>,

    /// Model identifier; per-kind default applied by `effective_model`.
    #[serde(default)]
    pub model: Option<String>,

    /// Override for the provider's base URL.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Per-provider max output tokens.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Embedding model served by this provider, if any.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// Speech-to-text model; presence marks the entry as STT-capable
    /// (see `LlmConfig::stt_provider_entry`).
    #[serde(default)]
    pub stt_model: Option<String>,

    /// Marks this entry for embedding use.
    #[serde(default)]
    pub embed: bool,

    /// Marks this entry as the pool default; `validate_pool` allows at
    /// most one.
    #[serde(default)]
    pub default: bool,

    /// Claude-only thinking settings (`validate` warns on other kinds).
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    /// NOTE(review): consumer not visible in this file — presumably
    /// server-side context compaction; confirm.
    #[serde(default)]
    pub server_compaction: bool,
    /// Opt-in extended context flag.
    #[serde(default)]
    pub enable_extended_context: bool,
    /// Prompt-cache TTL passed through to the provider layer.
    #[serde(default)]
    pub prompt_cache_ttl: Option<CacheTtl>,

    /// OpenAI-only reasoning effort (`validate` warns on other kinds).
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    /// Gemini-only thinking level (`validate` warns on other kinds).
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Gemini-only thinking budget (`validate` warns on other kinds).
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Gemini: whether to include thoughts in responses.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    /// API key stored directly in config, if provided.
    #[serde(default)]
    pub api_key: Option<String>,

    /// Inline Candle settings for `type = "candle"` entries.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    /// Dedicated vision model, if any.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Per-provider instruction file path.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
1231
/// An Ollama entry with every optional field unset and all flags off.
impl Default for ProviderEntry {
    fn default() -> Self {
        Self {
            provider_type: ProviderKind::Ollama,
            name: None,
            model: None,
            base_url: None,
            max_tokens: None,
            embedding_model: None,
            stt_model: None,
            embed: false,
            default: false,
            thinking: None,
            server_compaction: false,
            enable_extended_context: false,
            prompt_cache_ttl: None,
            reasoning_effort: None,
            thinking_level: None,
            thinking_budget: None,
            include_thoughts: None,
            api_key: None,
            candle: None,
            vision_model: None,
            instruction_file: None,
        }
    }
}
1259
1260impl ProviderEntry {
1261 #[must_use]
1263 pub fn effective_name(&self) -> String {
1264 self.name
1265 .clone()
1266 .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1267 }
1268
1269 #[must_use]
1274 pub fn effective_model(&self) -> String {
1275 if let Some(ref m) = self.model {
1276 return m.clone();
1277 }
1278 match self.provider_type {
1279 ProviderKind::Ollama => "qwen3:8b".to_owned(),
1280 ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1281 ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1282 ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1283 ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1284 }
1285 }
1286
1287 pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1294 use crate::error::ConfigError;
1295
1296 if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1298 return Err(ConfigError::Validation(
1299 "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1300 ));
1301 }
1302
1303 match self.provider_type {
1305 ProviderKind::Ollama => {
1306 if self.thinking.is_some() {
1307 tracing::warn!(
1308 provider = self.effective_name(),
1309 "field `thinking` is only used by Claude providers"
1310 );
1311 }
1312 if self.reasoning_effort.is_some() {
1313 tracing::warn!(
1314 provider = self.effective_name(),
1315 "field `reasoning_effort` is only used by OpenAI providers"
1316 );
1317 }
1318 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1319 tracing::warn!(
1320 provider = self.effective_name(),
1321 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1322 );
1323 }
1324 }
1325 ProviderKind::Claude => {
1326 if self.reasoning_effort.is_some() {
1327 tracing::warn!(
1328 provider = self.effective_name(),
1329 "field `reasoning_effort` is only used by OpenAI providers"
1330 );
1331 }
1332 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1333 tracing::warn!(
1334 provider = self.effective_name(),
1335 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1336 );
1337 }
1338 }
1339 ProviderKind::OpenAi => {
1340 if self.thinking.is_some() {
1341 tracing::warn!(
1342 provider = self.effective_name(),
1343 "field `thinking` is only used by Claude providers"
1344 );
1345 }
1346 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1347 tracing::warn!(
1348 provider = self.effective_name(),
1349 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1350 );
1351 }
1352 }
1353 ProviderKind::Gemini => {
1354 if self.thinking.is_some() {
1355 tracing::warn!(
1356 provider = self.effective_name(),
1357 "field `thinking` is only used by Claude providers"
1358 );
1359 }
1360 if self.reasoning_effort.is_some() {
1361 tracing::warn!(
1362 provider = self.effective_name(),
1363 "field `reasoning_effort` is only used by OpenAI providers"
1364 );
1365 }
1366 }
1367 _ => {}
1368 }
1369
1370 if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1373 tracing::warn!(
1374 provider = self.effective_name(),
1375 "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1376 Whisper STT API — use OpenAI, compatible, or candle instead"
1377 );
1378 }
1379
1380 Ok(())
1381 }
1382}
1383
1384pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1394 use crate::error::ConfigError;
1395 use std::collections::HashSet;
1396
1397 if entries.is_empty() {
1398 return Err(ConfigError::Validation(
1399 "at least one LLM provider must be configured in [[llm.providers]]".into(),
1400 ));
1401 }
1402
1403 let default_count = entries.iter().filter(|e| e.default).count();
1404 if default_count > 1 {
1405 return Err(ConfigError::Validation(
1406 "only one [[llm.providers]] entry can be marked `default = true`".into(),
1407 ));
1408 }
1409
1410 let mut seen_names: HashSet<String> = HashSet::new();
1411 for entry in entries {
1412 let name = entry.effective_name();
1413 if !seen_names.insert(name.clone()) {
1414 return Err(ConfigError::Validation(format!(
1415 "duplicate provider name \"{name}\" in [[llm.providers]]"
1416 )));
1417 }
1418 entry.validate()?;
1419 }
1420
1421 Ok(())
1422}
1423
1424#[cfg(test)]
1425mod tests {
1426 use super::*;
1427
1428 fn ollama_entry() -> ProviderEntry {
1429 ProviderEntry {
1430 provider_type: ProviderKind::Ollama,
1431 name: Some("ollama".into()),
1432 model: Some("qwen3:8b".into()),
1433 ..Default::default()
1434 }
1435 }
1436
1437 fn claude_entry() -> ProviderEntry {
1438 ProviderEntry {
1439 provider_type: ProviderKind::Claude,
1440 name: Some("claude".into()),
1441 model: Some("claude-sonnet-4-6".into()),
1442 max_tokens: Some(8192),
1443 ..Default::default()
1444 }
1445 }
1446
1447 #[test]
1450 fn validate_ollama_valid() {
1451 assert!(ollama_entry().validate().is_ok());
1452 }
1453
1454 #[test]
1455 fn validate_claude_valid() {
1456 assert!(claude_entry().validate().is_ok());
1457 }
1458
1459 #[test]
1460 fn validate_compatible_without_name_errors() {
1461 let entry = ProviderEntry {
1462 provider_type: ProviderKind::Compatible,
1463 name: None,
1464 ..Default::default()
1465 };
1466 let err = entry.validate().unwrap_err();
1467 assert!(
1468 err.to_string().contains("compatible"),
1469 "error should mention compatible: {err}"
1470 );
1471 }
1472
1473 #[test]
1474 fn validate_compatible_with_name_ok() {
1475 let entry = ProviderEntry {
1476 provider_type: ProviderKind::Compatible,
1477 name: Some("my-proxy".into()),
1478 base_url: Some("http://localhost:8080".into()),
1479 model: Some("gpt-4o".into()),
1480 max_tokens: Some(4096),
1481 ..Default::default()
1482 };
1483 assert!(entry.validate().is_ok());
1484 }
1485
1486 #[test]
1487 fn validate_openai_valid() {
1488 let entry = ProviderEntry {
1489 provider_type: ProviderKind::OpenAi,
1490 name: Some("openai".into()),
1491 model: Some("gpt-4o".into()),
1492 max_tokens: Some(4096),
1493 ..Default::default()
1494 };
1495 assert!(entry.validate().is_ok());
1496 }
1497
1498 #[test]
1499 fn validate_gemini_valid() {
1500 let entry = ProviderEntry {
1501 provider_type: ProviderKind::Gemini,
1502 name: Some("gemini".into()),
1503 model: Some("gemini-2.0-flash".into()),
1504 ..Default::default()
1505 };
1506 assert!(entry.validate().is_ok());
1507 }
1508
1509 #[test]
1512 fn validate_pool_empty_errors() {
1513 let err = validate_pool(&[]).unwrap_err();
1514 assert!(err.to_string().contains("at least one"), "{err}");
1515 }
1516
1517 #[test]
1518 fn validate_pool_single_entry_ok() {
1519 assert!(validate_pool(&[ollama_entry()]).is_ok());
1520 }
1521
1522 #[test]
1523 fn validate_pool_duplicate_names_errors() {
1524 let a = ollama_entry();
1525 let b = ollama_entry(); let err = validate_pool(&[a, b]).unwrap_err();
1527 assert!(err.to_string().contains("duplicate"), "{err}");
1528 }
1529
1530 #[test]
1531 fn validate_pool_multiple_defaults_errors() {
1532 let mut a = ollama_entry();
1533 let mut b = claude_entry();
1534 a.default = true;
1535 b.default = true;
1536 let err = validate_pool(&[a, b]).unwrap_err();
1537 assert!(err.to_string().contains("default"), "{err}");
1538 }
1539
1540 #[test]
1541 fn validate_pool_two_different_providers_ok() {
1542 assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1543 }
1544
1545 #[test]
1546 fn validate_pool_propagates_entry_error() {
1547 let bad = ProviderEntry {
1548 provider_type: ProviderKind::Compatible,
1549 name: None, ..Default::default()
1551 };
1552 assert!(validate_pool(&[bad]).is_err());
1553 }
1554
1555 #[test]
1558 fn effective_model_returns_explicit_when_set() {
1559 let entry = ProviderEntry {
1560 provider_type: ProviderKind::Claude,
1561 model: Some("claude-sonnet-4-6".into()),
1562 ..Default::default()
1563 };
1564 assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
1565 }
1566
1567 #[test]
1568 fn effective_model_ollama_default_when_none() {
1569 let entry = ProviderEntry {
1570 provider_type: ProviderKind::Ollama,
1571 model: None,
1572 ..Default::default()
1573 };
1574 assert_eq!(entry.effective_model(), "qwen3:8b");
1575 }
1576
1577 #[test]
1578 fn effective_model_claude_default_when_none() {
1579 let entry = ProviderEntry {
1580 provider_type: ProviderKind::Claude,
1581 model: None,
1582 ..Default::default()
1583 };
1584 assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
1585 }
1586
1587 #[test]
1588 fn effective_model_openai_default_when_none() {
1589 let entry = ProviderEntry {
1590 provider_type: ProviderKind::OpenAi,
1591 model: None,
1592 ..Default::default()
1593 };
1594 assert_eq!(entry.effective_model(), "gpt-4o-mini");
1595 }
1596
1597 #[test]
1598 fn effective_model_gemini_default_when_none() {
1599 let entry = ProviderEntry {
1600 provider_type: ProviderKind::Gemini,
1601 model: None,
1602 ..Default::default()
1603 };
1604 assert_eq!(entry.effective_model(), "gemini-2.0-flash");
1605 }
1606
1607 fn parse_llm(toml: &str) -> LlmConfig {
1611 #[derive(serde::Deserialize)]
1612 struct Wrapper {
1613 llm: LlmConfig,
1614 }
1615 toml::from_str::<Wrapper>(toml).unwrap().llm
1616 }
1617
1618 #[test]
1619 fn check_legacy_format_new_format_ok() {
1620 let cfg = parse_llm(
1621 r#"
1622[llm]
1623
1624[[llm.providers]]
1625type = "ollama"
1626model = "qwen3:8b"
1627"#,
1628 );
1629 assert!(cfg.check_legacy_format().is_ok());
1630 }
1631
1632 #[test]
1633 fn check_legacy_format_empty_providers_no_legacy_ok() {
1634 let cfg = parse_llm("[llm]\n");
1636 assert!(cfg.check_legacy_format().is_ok());
1637 }
1638
1639 #[test]
1642 fn effective_provider_falls_back_to_ollama_when_no_providers() {
1643 let cfg = parse_llm("[llm]\n");
1644 assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
1645 }
1646
1647 #[test]
1648 fn effective_provider_reads_from_providers_first() {
1649 let cfg = parse_llm(
1650 r#"
1651[llm]
1652
1653[[llm.providers]]
1654type = "claude"
1655model = "claude-sonnet-4-6"
1656"#,
1657 );
1658 assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
1659 }
1660
1661 #[test]
1662 fn effective_model_reads_from_providers_first() {
1663 let cfg = parse_llm(
1664 r#"
1665[llm]
1666
1667[[llm.providers]]
1668type = "ollama"
1669model = "qwen3:8b"
1670"#,
1671 );
1672 assert_eq!(cfg.effective_model(), "qwen3:8b");
1673 }
1674
1675 #[test]
1676 fn effective_base_url_default_when_absent() {
1677 let cfg = parse_llm("[llm]\n");
1678 assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
1679 }
1680
1681 #[test]
1682 fn effective_base_url_from_providers_entry() {
1683 let cfg = parse_llm(
1684 r#"
1685[llm]
1686
1687[[llm.providers]]
1688type = "ollama"
1689base_url = "http://myhost:11434"
1690"#,
1691 );
1692 assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
1693 }
1694
1695 #[test]
1698 fn complexity_routing_defaults() {
1699 let cr = ComplexityRoutingConfig::default();
1700 assert!(
1701 cr.bypass_single_provider,
1702 "bypass_single_provider must default to true"
1703 );
1704 assert_eq!(cr.triage_timeout_secs, 5);
1705 assert_eq!(cr.max_triage_tokens, 50);
1706 assert!(cr.triage_provider.is_none());
1707 assert!(cr.tiers.simple.is_none());
1708 }
1709
1710 #[test]
1711 fn complexity_routing_toml_round_trip() {
1712 let cfg = parse_llm(
1713 r#"
1714[llm]
1715routing = "triage"
1716
1717[llm.complexity_routing]
1718triage_provider = "fast"
1719bypass_single_provider = false
1720triage_timeout_secs = 10
1721max_triage_tokens = 100
1722
1723[llm.complexity_routing.tiers]
1724simple = "fast"
1725medium = "medium"
1726complex = "large"
1727expert = "opus"
1728"#,
1729 );
1730 assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1731 let cr = cfg
1732 .complexity_routing
1733 .expect("complexity_routing must be present");
1734 assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
1735 assert!(!cr.bypass_single_provider);
1736 assert_eq!(cr.triage_timeout_secs, 10);
1737 assert_eq!(cr.max_triage_tokens, 100);
1738 assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
1739 assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
1740 assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
1741 assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
1742 }
1743
1744 #[test]
1745 fn complexity_routing_partial_tiers_toml() {
1746 let cfg = parse_llm(
1748 r#"
1749[llm]
1750routing = "triage"
1751
1752[llm.complexity_routing.tiers]
1753simple = "haiku"
1754complex = "sonnet"
1755"#,
1756 );
1757 let cr = cfg
1758 .complexity_routing
1759 .expect("complexity_routing must be present");
1760 assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
1761 assert!(cr.tiers.medium.is_none());
1762 assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
1763 assert!(cr.tiers.expert.is_none());
1764 assert!(cr.bypass_single_provider);
1766 assert_eq!(cr.triage_timeout_secs, 5);
1767 }
1768
1769 #[test]
1770 fn routing_strategy_triage_deserialized() {
1771 let cfg = parse_llm(
1772 r#"
1773[llm]
1774routing = "triage"
1775"#,
1776 );
1777 assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
1778 }
1779
1780 #[test]
1783 fn stt_provider_entry_by_name_match() {
1784 let cfg = parse_llm(
1785 r#"
1786[llm]
1787
1788[[llm.providers]]
1789type = "openai"
1790name = "quality"
1791model = "gpt-5.4"
1792stt_model = "gpt-4o-mini-transcribe"
1793
1794[llm.stt]
1795provider = "quality"
1796"#,
1797 );
1798 let entry = cfg.stt_provider_entry().expect("should find stt provider");
1799 assert_eq!(entry.effective_name(), "quality");
1800 assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
1801 }
1802
1803 #[test]
1804 fn stt_provider_entry_auto_detect_when_provider_empty() {
1805 let cfg = parse_llm(
1806 r#"
1807[llm]
1808
1809[[llm.providers]]
1810type = "openai"
1811name = "openai-stt"
1812stt_model = "whisper-1"
1813
1814[llm.stt]
1815provider = ""
1816"#,
1817 );
1818 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1819 assert_eq!(entry.effective_name(), "openai-stt");
1820 }
1821
1822 #[test]
1823 fn stt_provider_entry_auto_detect_no_stt_section() {
1824 let cfg = parse_llm(
1825 r#"
1826[llm]
1827
1828[[llm.providers]]
1829type = "openai"
1830name = "openai-stt"
1831stt_model = "whisper-1"
1832"#,
1833 );
1834 let entry = cfg.stt_provider_entry().expect("should auto-detect");
1836 assert_eq!(entry.effective_name(), "openai-stt");
1837 }
1838
1839 #[test]
1840 fn stt_provider_entry_none_when_no_stt_model() {
1841 let cfg = parse_llm(
1842 r#"
1843[llm]
1844
1845[[llm.providers]]
1846type = "openai"
1847name = "quality"
1848model = "gpt-5.4"
1849"#,
1850 );
1851 assert!(cfg.stt_provider_entry().is_none());
1852 }
1853
1854 #[test]
1855 fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1856 let cfg = parse_llm(
1858 r#"
1859[llm]
1860
1861[[llm.providers]]
1862type = "openai"
1863name = "quality"
1864model = "gpt-5.4"
1865
1866[[llm.providers]]
1867type = "openai"
1868name = "openai-stt"
1869stt_model = "whisper-1"
1870
1871[llm.stt]
1872provider = "quality"
1873"#,
1874 );
1875 assert!(cfg.stt_provider_entry().is_none());
1877 }
1878
1879 #[test]
1880 fn stt_config_deserializes_new_slim_format() {
1881 let cfg = parse_llm(
1882 r#"
1883[llm]
1884
1885[[llm.providers]]
1886type = "openai"
1887name = "quality"
1888stt_model = "whisper-1"
1889
1890[llm.stt]
1891provider = "quality"
1892language = "en"
1893"#,
1894 );
1895 let stt = cfg.stt.as_ref().expect("stt section present");
1896 assert_eq!(stt.provider, "quality");
1897 assert_eq!(stt.language, "en");
1898 }
1899
1900 #[test]
1901 fn stt_config_default_provider_is_empty() {
1902 assert_eq!(default_stt_provider(), "");
1904 }
1905
1906 #[test]
1907 fn validate_stt_missing_provider_ok() {
1908 let cfg = parse_llm("[llm]\n");
1909 assert!(cfg.validate_stt().is_ok());
1910 }
1911
1912 #[test]
1913 fn validate_stt_valid_reference() {
1914 let cfg = parse_llm(
1915 r#"
1916[llm]
1917
1918[[llm.providers]]
1919type = "openai"
1920name = "quality"
1921stt_model = "whisper-1"
1922
1923[llm.stt]
1924provider = "quality"
1925"#,
1926 );
1927 assert!(cfg.validate_stt().is_ok());
1928 }
1929
1930 #[test]
1931 fn validate_stt_nonexistent_provider_errors() {
1932 let cfg = parse_llm(
1933 r#"
1934[llm]
1935
1936[[llm.providers]]
1937type = "openai"
1938name = "quality"
1939model = "gpt-5.4"
1940
1941[llm.stt]
1942provider = "nonexistent"
1943"#,
1944 );
1945 assert!(cfg.validate_stt().is_err());
1946 }
1947
1948 #[test]
1949 fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1950 let cfg = parse_llm(
1952 r#"
1953[llm]
1954
1955[[llm.providers]]
1956type = "openai"
1957name = "quality"
1958model = "gpt-5.4"
1959
1960[llm.stt]
1961provider = "quality"
1962"#,
1963 );
1964 assert!(cfg.validate_stt().is_ok());
1966 assert!(
1968 cfg.stt_provider_entry().is_none(),
1969 "stt_provider_entry must be None when provider has no stt_model"
1970 );
1971 }
1972
1973 #[test]
1976 fn bandit_warmup_queries_explicit_value_is_deserialized() {
1977 let cfg = parse_llm(
1978 r#"
1979[llm]
1980
1981[llm.router]
1982strategy = "bandit"
1983
1984[llm.router.bandit]
1985warmup_queries = 50
1986"#,
1987 );
1988 let bandit = cfg
1989 .router
1990 .expect("router section must be present")
1991 .bandit
1992 .expect("bandit section must be present");
1993 assert_eq!(
1994 bandit.warmup_queries,
1995 Some(50),
1996 "warmup_queries = 50 must deserialize to Some(50)"
1997 );
1998 }
1999
2000 #[test]
2001 fn bandit_warmup_queries_explicit_null_is_none() {
2002 let cfg = parse_llm(
2005 r#"
2006[llm]
2007
2008[llm.router]
2009strategy = "bandit"
2010
2011[llm.router.bandit]
2012warmup_queries = 0
2013"#,
2014 );
2015 let bandit = cfg
2016 .router
2017 .expect("router section must be present")
2018 .bandit
2019 .expect("bandit section must be present");
2020 assert_eq!(
2022 bandit.warmup_queries,
2023 Some(0),
2024 "warmup_queries = 0 must deserialize to Some(0)"
2025 );
2026 }
2027
2028 #[test]
2029 fn bandit_warmup_queries_missing_field_defaults_to_none() {
2030 let cfg = parse_llm(
2032 r#"
2033[llm]
2034
2035[llm.router]
2036strategy = "bandit"
2037
2038[llm.router.bandit]
2039alpha = 1.5
2040"#,
2041 );
2042 let bandit = cfg
2043 .router
2044 .expect("router section must be present")
2045 .bandit
2046 .expect("bandit section must be present");
2047 assert_eq!(
2048 bandit.warmup_queries, None,
2049 "omitted warmup_queries must default to None"
2050 );
2051 }
2052
2053 #[test]
2054 fn provider_name_new_and_as_str() {
2055 let n = ProviderName::new("fast");
2056 assert_eq!(n.as_str(), "fast");
2057 assert!(!n.is_empty());
2058 }
2059
2060 #[test]
2061 fn provider_name_default_is_empty() {
2062 let n = ProviderName::default();
2063 assert!(n.is_empty());
2064 assert_eq!(n.as_str(), "");
2065 }
2066
2067 #[test]
2068 fn provider_name_deref_to_str() {
2069 let n = ProviderName::new("quality");
2070 let s: &str = &n;
2071 assert_eq!(s, "quality");
2072 }
2073
2074 #[test]
2075 fn provider_name_partial_eq_str() {
2076 let n = ProviderName::new("fast");
2077 assert_eq!(n, "fast");
2078 assert_ne!(n, "slow");
2079 }
2080
2081 #[test]
2082 fn provider_name_serde_roundtrip() {
2083 let n = ProviderName::new("my-provider");
2084 let json = serde_json::to_string(&n).expect("serialize");
2085 assert_eq!(json, "\"my-provider\"");
2086 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2087 assert_eq!(back, n);
2088 }
2089
2090 #[test]
2091 fn provider_name_serde_empty_roundtrip() {
2092 let n = ProviderName::default();
2093 let json = serde_json::to_string(&n).expect("serialize");
2094 assert_eq!(json, "\"\"");
2095 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2096 assert_eq!(back, n);
2097 assert!(back.is_empty());
2098 }
2099}