1use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
8
/// Newtype for a provider's configured name (e.g. `"ollama"`, `"my-proxy"`).
///
/// `#[serde(transparent)]` makes it (de)serialize as a bare string.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct ProviderName(String);
16
17impl ProviderName {
18 #[must_use]
32 pub fn new(name: impl Into<String>) -> Self {
33 Self(name.into())
34 }
35
36 #[must_use]
47 pub fn is_empty(&self) -> bool {
48 self.0.is_empty()
49 }
50
51 #[must_use]
62 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl fmt::Display for ProviderName {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 self.0.fmt(f)
70 }
71}
72
73impl AsRef<str> for ProviderName {
74 fn as_ref(&self) -> &str {
75 &self.0
76 }
77}
78
79impl std::ops::Deref for ProviderName {
80 type Target = str;
81
82 fn deref(&self) -> &str {
83 &self.0
84 }
85}
86
87impl PartialEq<str> for ProviderName {
88 fn eq(&self, other: &str) -> bool {
89 self.0 == other
90 }
91}
92
93impl PartialEq<&str> for ProviderName {
94 fn eq(&self, other: &&str) -> bool {
95 self.0 == *other
96 }
97}
98
/// Serde default for `LlmConfig::response_cache_ttl_secs`: one hour.
fn default_response_cache_ttl_secs() -> u64 {
    3600
}

/// Serde default for `LlmConfig::semantic_cache_threshold`.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Serde default for `LlmConfig::semantic_cache_max_candidates`.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

/// Serde default for `LlmConfig::router_ema_alpha` (EMA smoothing factor).
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Serde default for `LlmConfig::router_reorder_interval`.
// NOTE(review): units (requests? seconds?) are defined by the router — confirm there.
fn default_router_reorder_interval() -> u64 {
    10
}

/// Serde default for `LlmConfig::embedding_model`.
fn default_embedding_model() -> String {
    "qwen3-embedding".into()
}

/// Serde default for the Candle model source.
fn default_candle_source() -> String {
    "huggingface".into()
}

/// Serde default for the Candle chat/prompt template.
fn default_chat_template() -> String {
    "chatml".into()
}

/// Serde default for the Candle compute device.
fn default_candle_device() -> String {
    "cpu".into()
}

/// Serde default for `GenerationParams::temperature`.
fn default_temperature() -> f64 {
    0.7
}

/// Serde default for `GenerationParams::max_tokens`.
fn default_max_tokens() -> usize {
    2048
}

/// Serde default for `GenerationParams::seed`.
fn default_seed() -> u64 {
    42
}

/// Serde default for `GenerationParams::repeat_penalty`.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Serde default for `GenerationParams::repeat_last_n`.
fn default_repeat_last_n() -> usize {
    64
}

/// Serde default for `CascadeConfig::quality_threshold`.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Serde default for `CascadeConfig::max_escalations`.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Serde default for `CascadeConfig::window_size`.
fn default_cascade_window_size() -> usize {
    50
}

/// Serde default for `ReputationConfig::decay_factor`.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Serde default for `ReputationConfig::weight`.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Serde default for `ReputationConfig::min_observations`.
fn default_reputation_min_observations() -> u64 {
    5
}

/// Default STT provider name: empty, which means "auto-detect the first
/// entry with an `stt_model`" (see `LlmConfig::stt_provider_entry`).
#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

/// Default STT language hint.
#[must_use]
pub fn default_stt_language() -> String {
    "auto".into()
}

/// Public accessor for [`default_embedding_model`] (the serde default is private).
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for [`default_response_cache_ttl_secs`].
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for [`default_router_ema_alpha`].
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for [`default_router_reorder_interval`].
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
214
/// Kind of LLM backend a provider entry talks to.
///
/// Serialized in lowercase, e.g. `type = "openai"` in TOML.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    /// Ollama server (default base URL `http://localhost:11434`).
    Ollama,
    /// Anthropic Claude.
    Claude,
    /// OpenAI.
    OpenAi,
    /// Google Gemini.
    Gemini,
    /// In-process Candle inference.
    Candle,
    /// Generic compatible endpoint; must set `name` (see `ProviderEntry::validate`).
    Compatible,
}
243
244impl ProviderKind {
245 #[must_use]
256 pub fn as_str(self) -> &'static str {
257 match self {
258 Self::Ollama => "ollama",
259 Self::Claude => "claude",
260 Self::OpenAi => "openai",
261 Self::Gemini => "gemini",
262 Self::Candle => "candle",
263 Self::Compatible => "compatible",
264 }
265 }
266}
267
268impl std::fmt::Display for ProviderKind {
269 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270 f.write_str(self.as_str())
271 }
272}
273
/// Top-level `[llm]` configuration section.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool; the first entry acts as the fallback default
    /// (see `effective_provider` / `effective_model` / `effective_base_url`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Routing strategy across providers; `None` (the default) is omitted on output.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Named routes: route name -> ordered list of provider names.
    // NOTE(review): exact route semantics live in the router, not visible here.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    /// Embedding model name (default `"qwen3-embedding"`).
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Optional `[llm.candle]` in-process inference section.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Optional `[llm.stt]` speech-to-text section.
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Enables the response cache.
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response-cache TTL in seconds (default 3600).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Enables the semantic cache.
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Similarity threshold for a semantic-cache hit (default 0.95).
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Max candidates examined per semantic-cache lookup (default 10).
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Enables EMA scoring in the router.
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor (default 0.1).
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Router reorder interval (default 10).
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Optional `[llm.router]` fine-grained router settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional path to a system-instruction file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Model used for summaries, when different from the main model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Dedicated provider entry used for summaries.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// `[llm.complexity_routing]` settings for the `triage` routing strategy.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}
371
372fn default_embedding_model_opt() -> String {
373 default_embedding_model()
374}
375
376#[allow(clippy::trivially_copy_pass_by_ref)]
377fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
378 *s == LlmRoutingStrategy::None
379}
380
381impl LlmConfig {
382 #[must_use]
384 pub fn effective_provider(&self) -> ProviderKind {
385 self.providers
386 .first()
387 .map_or(ProviderKind::Ollama, |e| e.provider_type)
388 }
389
390 #[must_use]
392 pub fn effective_base_url(&self) -> &str {
393 self.providers
394 .first()
395 .and_then(|e| e.base_url.as_deref())
396 .unwrap_or("http://localhost:11434")
397 }
398
399 #[must_use]
401 pub fn effective_model(&self) -> &str {
402 self.providers
403 .first()
404 .and_then(|e| e.model.as_deref())
405 .unwrap_or("qwen3:8b")
406 }
407
408 #[must_use]
416 pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
417 let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
418 if name_hint.is_empty() {
419 self.providers.iter().find(|p| p.stt_model.is_some())
420 } else {
421 self.providers
422 .iter()
423 .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
424 }
425 }
426
427 pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
433 Ok(())
434 }
435
436 pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
442 use crate::error::ConfigError;
443
444 let Some(stt) = &self.stt else {
445 return Ok(());
446 };
447 if stt.provider.is_empty() {
448 return Ok(());
449 }
450 let found = self
451 .providers
452 .iter()
453 .find(|p| p.effective_name() == stt.provider);
454 match found {
455 None => {
456 return Err(ConfigError::Validation(format!(
457 "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
458 stt.provider
459 )));
460 }
461 Some(entry) if entry.stt_model.is_none() => {
462 tracing::warn!(
463 provider = stt.provider,
464 "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
465 );
466 }
467 _ => {}
468 }
469 Ok(())
470 }
471}
472
/// `[llm.stt]` speech-to-text section.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Name of the provider entry to use; the empty default means
    /// "auto-detect the first entry with an `stt_model`".
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint for transcription (default `"auto"`).
    #[serde(default = "default_stt_language")]
    pub language: String,
}

/// Strategy used by the `[llm.router]` section (lowercase in TOML).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Exponential-moving-average scoring (the default).
    #[default]
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Cascade with escalation (see [`CascadeConfig`]).
    Cascade,
    /// Contextual bandit (see [`BanditConfig`]).
    Bandit,
}

/// ASI coherence-check settings.
// NOTE(review): the "ASI" expansion is not visible in this file — confirm meaning.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AsiConfig {
    /// Master switch; disabled by default.
    #[serde(default)]
    pub enabled: bool,

    /// Window size (default 5).
    #[serde(default = "default_asi_window")]
    pub window: usize,

    /// Coherence threshold (default 0.7).
    #[serde(default = "default_asi_coherence_threshold")]
    pub coherence_threshold: f32,

    /// Penalty weight (default 0.3).
    #[serde(default = "default_asi_penalty_weight")]
    pub penalty_weight: f32,
}

/// Serde default for `AsiConfig::window`.
fn default_asi_window() -> usize {
    5
}

/// Serde default for `AsiConfig::coherence_threshold`.
fn default_asi_coherence_threshold() -> f32 {
    0.7
}

/// Serde default for `AsiConfig::penalty_weight`.
fn default_asi_penalty_weight() -> f32 {
    0.3
}
556
557impl Default for AsiConfig {
558 fn default() -> Self {
559 Self {
560 enabled: false,
561 window: default_asi_window(),
562 coherence_threshold: default_asi_coherence_threshold(),
563 penalty_weight: default_asi_penalty_weight(),
564 }
565 }
566}
567
/// `[llm.router]` fine-grained router settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Routing strategy (default: EMA).
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Where Thompson-sampling state is persisted, if anywhere.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Cascade-strategy settings.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Provider-reputation settings.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Contextual-bandit settings.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
    /// Optional quality-gate threshold.
    #[serde(default)]
    pub quality_gate: Option<f32>,
    /// Optional ASI coherence-check settings.
    #[serde(default)]
    pub asi: Option<AsiConfig>,
    /// Embedding concurrency limit (default 4).
    #[serde(default = "default_embed_concurrency")]
    pub embed_concurrency: usize,
}

/// Serde default for `RouterConfig::embed_concurrency`.
fn default_embed_concurrency() -> usize {
    4
}

/// Provider reputation settings.
// NOTE(review): field semantics are defined by the reputation tracker — not visible here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Master switch; disabled by default.
    #[serde(default)]
    pub enabled: bool,
    /// Decay factor (default 0.95).
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Reputation weight (default 0.3).
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Minimum observations (default 5).
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Where reputation state is persisted, if anywhere.
    #[serde(default)]
    pub state_path: Option<String>,
}

/// Cascade routing settings.
// NOTE(review): escalation semantics inferred from field names — confirm in the router.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality threshold (default 0.5).
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum escalations (default 2).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// Response-scoring mode (default: heuristic).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Window size (default 50).
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Optional cap on cascade tokens.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Optional provider ordering by cost tier.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
695
696impl Default for CascadeConfig {
697 fn default() -> Self {
698 Self {
699 quality_threshold: default_cascade_quality_threshold(),
700 max_escalations: default_cascade_max_escalations(),
701 classifier_mode: CascadeClassifierMode::default(),
702 window_size: default_cascade_window_size(),
703 max_cascade_tokens: None,
704 cost_tiers: None,
705 }
706 }
707}
708
/// How the cascade scores a response (lowercase in TOML).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Rule-based scoring (the default).
    #[default]
    Heuristic,
    /// LLM-as-judge scoring.
    Judge,
}

/// Serde default for `BanditConfig::alpha`.
fn default_bandit_alpha() -> f32 {
    1.0
}

/// Serde default for `BanditConfig::dim`.
fn default_bandit_dim() -> usize {
    32
}

/// Serde default for `BanditConfig::cost_weight`.
fn default_bandit_cost_weight() -> f32 {
    0.1
}

/// Serde default for `BanditConfig::decay_factor`.
fn default_bandit_decay_factor() -> f32 {
    1.0
}

/// Serde default for `BanditConfig::embedding_timeout_ms`.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

/// Serde default for `BanditConfig::cache_size`.
fn default_bandit_cache_size() -> usize {
    512
}

/// Contextual-bandit router settings.
// NOTE(review): field semantics are defined by the bandit implementation — not visible here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Bandit `alpha` parameter (default 1.0).
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Dimensionality (default 32).
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Cost weight (default 0.1).
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Decay factor (default 1.0).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider used for embeddings; defaults to the empty name.
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Embedding timeout in milliseconds (default 50).
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Cache capacity (default 512).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Where bandit state is persisted, if anywhere.
    #[serde(default)]
    pub state_path: Option<String>,

    /// Memory confidence threshold (default 0.9).
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Optional number of warmup queries.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}

/// Serde default for `BanditConfig::memory_confidence_threshold`.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}
830
831impl Default for BanditConfig {
832 fn default() -> Self {
833 Self {
834 alpha: default_bandit_alpha(),
835 dim: default_bandit_dim(),
836 cost_weight: default_bandit_cost_weight(),
837 decay_factor: default_bandit_decay_factor(),
838 embedding_provider: ProviderName::default(),
839 embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
840 cache_size: default_bandit_cache_size(),
841 state_path: None,
842 memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
843 warmup_queries: None,
844 }
845 }
846}
847
848#[derive(Debug, Deserialize, Serialize)]
849pub struct CandleConfig {
850 #[serde(default = "default_candle_source")]
851 pub source: String,
852 #[serde(default)]
853 pub local_path: String,
854 #[serde(default)]
855 pub filename: Option<String>,
856 #[serde(default = "default_chat_template")]
857 pub chat_template: String,
858 #[serde(default = "default_candle_device")]
859 pub device: String,
860 #[serde(default)]
861 pub embedding_repo: Option<String>,
862 #[serde(default)]
866 pub hf_token: Option<String>,
867 #[serde(default)]
868 pub generation: GenerationParams,
869 #[serde(default = "default_inference_timeout_secs")]
878 pub inference_timeout_secs: u64,
879}
880
/// Serde default for `inference_timeout_secs`: two minutes.
fn default_inference_timeout_secs() -> u64 {
    120
}

/// Sampling / generation parameters for Candle inference.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature (default 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Optional nucleus (top-p) sampling value.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional top-k sampling value.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Maximum tokens to generate (default 2048; see `capped_max_tokens`).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed (default 42).
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty (default 1.1).
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Repeat-penalty lookback length (default 64).
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
915
916pub const MAX_TOKENS_CAP: usize = 32768;
918
919impl GenerationParams {
920 #[must_use]
931 pub fn capped_max_tokens(&self) -> usize {
932 self.max_tokens.min(MAX_TOKENS_CAP)
933 }
934}
935
936impl Default for GenerationParams {
937 fn default() -> Self {
938 Self {
939 temperature: default_temperature(),
940 top_p: None,
941 top_k: None,
942 max_tokens: default_max_tokens(),
943 seed: default_seed(),
944 repeat_penalty: default_repeat_penalty(),
945 repeat_last_n: default_repeat_last_n(),
946 }
947 }
948}
949
/// Top-level `[llm].routing` strategy (lowercase in TOML).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// No routing (the default); omitted from serialized output.
    #[default]
    None,
    /// Exponential-moving-average scoring.
    Ema,
    /// Thompson sampling.
    Thompson,
    /// Cascade with escalation.
    Cascade,
    /// Task-based routing.
    Task,
    /// Complexity triage (configured via [`ComplexityRoutingConfig`]).
    Triage,
    /// Contextual bandit.
    Bandit,
}

/// Serde default for `ComplexityRoutingConfig::triage_timeout_secs`.
fn default_triage_timeout_secs() -> u64 {
    5
}

/// Serde default for `ComplexityRoutingConfig::max_triage_tokens`.
fn default_max_triage_tokens() -> u32 {
    50
}

/// Serde helper: a `bool` default of `true`.
fn default_true() -> bool {
    true
}

/// Optional provider name per complexity tier.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    pub simple: Option<String>,
    pub medium: Option<String>,
    pub complex: Option<String>,
    pub expert: Option<String>,
}

/// `[llm.complexity_routing]` settings for the `triage` routing strategy.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Provider that performs the triage classification.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Bypass triage when only one provider is configured (default true).
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Per-tier provider mapping.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Token budget for the triage call (default 50).
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Triage call timeout in seconds (default 5).
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Strategy name to fall back to, if any.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
1042
1043impl Default for ComplexityRoutingConfig {
1044 fn default() -> Self {
1045 Self {
1046 triage_provider: None,
1047 bypass_single_provider: true,
1048 tiers: TierMapping::default(),
1049 max_triage_tokens: default_max_triage_tokens(),
1050 triage_timeout_secs: default_triage_timeout_secs(),
1051 fallback_strategy: None,
1052 }
1053 }
1054}
1055
/// Per-provider inline Candle configuration; same field set as [`CandleConfig`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Model source (default `"huggingface"`).
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local filesystem path to the model; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Specific model file name, when needed.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat/prompt template name (default `"chatml"`).
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default `"cpu"`).
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional repository for the embedding model.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional Hugging Face token.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Sampling/generation parameters.
    #[serde(default)]
    pub generation: GenerationParams,
    /// Inference timeout in seconds (default 120).
    #[serde(default = "default_inference_timeout_secs")]
    pub inference_timeout_secs: u64,
}
1084
1085impl Default for CandleInlineConfig {
1086 fn default() -> Self {
1087 Self {
1088 source: default_candle_source(),
1089 local_path: String::new(),
1090 filename: None,
1091 chat_template: default_chat_template(),
1092 device: default_candle_device(),
1093 embedding_repo: None,
1094 hf_token: None,
1095 generation: GenerationParams::default(),
1096 inference_timeout_secs: default_inference_timeout_secs(),
1097 }
1098 }
1099}
1100
/// One `[[llm.providers]]` entry.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Backend kind; serialized as `type = "..."`.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Explicit name; defaults to the kind's canonical name (see `effective_name`).
    #[serde(default)]
    pub name: Option<String>,

    /// Model override; see `effective_model` for per-kind fallbacks.
    #[serde(default)]
    pub model: Option<String>,

    /// API endpoint override.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Output token cap for this provider.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Embedding model served by this provider.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// Speech-to-text model; setting it makes this entry eligible for STT
    /// (see `LlmConfig::stt_provider_entry`).
    #[serde(default)]
    pub stt_model: Option<String>,

    /// Marks this entry as usable for embeddings.
    #[serde(default)]
    pub embed: bool,

    /// Marks this entry as the pool default; at most one entry may set it
    /// (enforced by `validate_pool`).
    #[serde(default)]
    pub default: bool,

    /// Claude-only extended-thinking settings (warned on other kinds).
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    /// Server-side compaction toggle.
    #[serde(default)]
    pub server_compaction: bool,
    /// Extended-context toggle.
    #[serde(default)]
    pub enable_extended_context: bool,

    /// OpenAI-only reasoning-effort setting (warned on other kinds).
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    /// Gemini-only thinking level (warned on other kinds).
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Gemini-only thinking token budget (warned on other kinds).
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// Gemini-only `include_thoughts` flag.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    /// API key.
    // NOTE(review): stored in plain config — confirm whether an env-var override exists.
    #[serde(default)]
    pub api_key: Option<String>,

    /// Inline Candle configuration.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    /// Vision model served by this provider.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Per-provider system-instruction file.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
1182
1183impl Default for ProviderEntry {
1184 fn default() -> Self {
1185 Self {
1186 provider_type: ProviderKind::Ollama,
1187 name: None,
1188 model: None,
1189 base_url: None,
1190 max_tokens: None,
1191 embedding_model: None,
1192 stt_model: None,
1193 embed: false,
1194 default: false,
1195 thinking: None,
1196 server_compaction: false,
1197 enable_extended_context: false,
1198 reasoning_effort: None,
1199 thinking_level: None,
1200 thinking_budget: None,
1201 include_thoughts: None,
1202 api_key: None,
1203 candle: None,
1204 vision_model: None,
1205 instruction_file: None,
1206 }
1207 }
1208}
1209
1210impl ProviderEntry {
1211 #[must_use]
1213 pub fn effective_name(&self) -> String {
1214 self.name
1215 .clone()
1216 .unwrap_or_else(|| self.provider_type.as_str().to_owned())
1217 }
1218
1219 #[must_use]
1224 pub fn effective_model(&self) -> String {
1225 if let Some(ref m) = self.model {
1226 return m.clone();
1227 }
1228 match self.provider_type {
1229 ProviderKind::Ollama => "qwen3:8b".to_owned(),
1230 ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1231 ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1232 ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1233 ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1234 }
1235 }
1236
1237 pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1244 use crate::error::ConfigError;
1245
1246 if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1248 return Err(ConfigError::Validation(
1249 "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1250 ));
1251 }
1252
1253 match self.provider_type {
1255 ProviderKind::Ollama => {
1256 if self.thinking.is_some() {
1257 tracing::warn!(
1258 provider = self.effective_name(),
1259 "field `thinking` is only used by Claude providers"
1260 );
1261 }
1262 if self.reasoning_effort.is_some() {
1263 tracing::warn!(
1264 provider = self.effective_name(),
1265 "field `reasoning_effort` is only used by OpenAI providers"
1266 );
1267 }
1268 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1269 tracing::warn!(
1270 provider = self.effective_name(),
1271 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1272 );
1273 }
1274 }
1275 ProviderKind::Claude => {
1276 if self.reasoning_effort.is_some() {
1277 tracing::warn!(
1278 provider = self.effective_name(),
1279 "field `reasoning_effort` is only used by OpenAI providers"
1280 );
1281 }
1282 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1283 tracing::warn!(
1284 provider = self.effective_name(),
1285 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1286 );
1287 }
1288 }
1289 ProviderKind::OpenAi => {
1290 if self.thinking.is_some() {
1291 tracing::warn!(
1292 provider = self.effective_name(),
1293 "field `thinking` is only used by Claude providers"
1294 );
1295 }
1296 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1297 tracing::warn!(
1298 provider = self.effective_name(),
1299 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1300 );
1301 }
1302 }
1303 ProviderKind::Gemini => {
1304 if self.thinking.is_some() {
1305 tracing::warn!(
1306 provider = self.effective_name(),
1307 "field `thinking` is only used by Claude providers"
1308 );
1309 }
1310 if self.reasoning_effort.is_some() {
1311 tracing::warn!(
1312 provider = self.effective_name(),
1313 "field `reasoning_effort` is only used by OpenAI providers"
1314 );
1315 }
1316 }
1317 _ => {}
1318 }
1319
1320 if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1323 tracing::warn!(
1324 provider = self.effective_name(),
1325 "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1326 Whisper STT API — use OpenAI, compatible, or candle instead"
1327 );
1328 }
1329
1330 Ok(())
1331 }
1332}
1333
1334pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1344 use crate::error::ConfigError;
1345 use std::collections::HashSet;
1346
1347 if entries.is_empty() {
1348 return Err(ConfigError::Validation(
1349 "at least one LLM provider must be configured in [[llm.providers]]".into(),
1350 ));
1351 }
1352
1353 let default_count = entries.iter().filter(|e| e.default).count();
1354 if default_count > 1 {
1355 return Err(ConfigError::Validation(
1356 "only one [[llm.providers]] entry can be marked `default = true`".into(),
1357 ));
1358 }
1359
1360 let mut seen_names: HashSet<String> = HashSet::new();
1361 for entry in entries {
1362 let name = entry.effective_name();
1363 if !seen_names.insert(name.clone()) {
1364 return Err(ConfigError::Validation(format!(
1365 "duplicate provider name \"{name}\" in [[llm.providers]]"
1366 )));
1367 }
1368 entry.validate()?;
1369 }
1370
1371 Ok(())
1372}
1373
1374#[cfg(test)]
1375mod tests {
1376 use super::*;
1377
    // --- fixtures ---

    /// Minimal valid Ollama entry used across the pool tests.
    fn ollama_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            name: Some("ollama".into()),
            model: Some("qwen3:8b".into()),
            ..Default::default()
        }
    }

    /// Minimal valid Claude entry used across the pool tests.
    fn claude_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Claude,
            name: Some("claude".into()),
            model: Some("claude-sonnet-4-6".into()),
            max_tokens: Some(8192),
            ..Default::default()
        }
    }

    // --- ProviderEntry::validate ---

    #[test]
    fn validate_ollama_valid() {
        assert!(ollama_entry().validate().is_ok());
    }

    #[test]
    fn validate_claude_valid() {
        assert!(claude_entry().validate().is_ok());
    }

    // `compatible` entries must carry an explicit `name`.
    #[test]
    fn validate_compatible_without_name_errors() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        let err = entry.validate().unwrap_err();
        assert!(
            err.to_string().contains("compatible"),
            "error should mention compatible: {err}"
        );
    }

    #[test]
    fn validate_compatible_with_name_ok() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: Some("my-proxy".into()),
            base_url: Some("http://localhost:8080".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_openai_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            name: Some("openai".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_gemini_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            name: Some("gemini".into()),
            model: Some("gemini-2.0-flash".into()),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }
1458
1459 #[test]
1462 fn validate_pool_empty_errors() {
1463 let err = validate_pool(&[]).unwrap_err();
1464 assert!(err.to_string().contains("at least one"), "{err}");
1465 }
1466
1467 #[test]
1468 fn validate_pool_single_entry_ok() {
1469 assert!(validate_pool(&[ollama_entry()]).is_ok());
1470 }
1471
1472 #[test]
1473 fn validate_pool_duplicate_names_errors() {
1474 let a = ollama_entry();
1475 let b = ollama_entry(); let err = validate_pool(&[a, b]).unwrap_err();
1477 assert!(err.to_string().contains("duplicate"), "{err}");
1478 }
1479
1480 #[test]
1481 fn validate_pool_multiple_defaults_errors() {
1482 let mut a = ollama_entry();
1483 let mut b = claude_entry();
1484 a.default = true;
1485 b.default = true;
1486 let err = validate_pool(&[a, b]).unwrap_err();
1487 assert!(err.to_string().contains("default"), "{err}");
1488 }
1489
1490 #[test]
1491 fn validate_pool_two_different_providers_ok() {
1492 assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1493 }
1494
1495 #[test]
1496 fn validate_pool_propagates_entry_error() {
1497 let bad = ProviderEntry {
1498 provider_type: ProviderKind::Compatible,
1499 name: None, ..Default::default()
1501 };
1502 assert!(validate_pool(&[bad]).is_err());
1503 }
1504
    // --- ProviderEntry::effective_model fallbacks ---

    #[test]
    fn effective_model_returns_explicit_when_set() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: Some("claude-sonnet-4-6".into()),
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
    }

    #[test]
    fn effective_model_ollama_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Ollama,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_model_claude_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
    }

    #[test]
    fn effective_model_openai_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gpt-4o-mini");
    }

    #[test]
    fn effective_model_gemini_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
    }

    /// Parses a TOML snippet containing an `[llm]` table into `LlmConfig`.
    fn parse_llm(toml: &str) -> LlmConfig {
        #[derive(serde::Deserialize)]
        struct Wrapper {
            llm: LlmConfig,
        }
        toml::from_str::<Wrapper>(toml).unwrap().llm
    }
1567
    // --- LlmConfig format checks and effective_* fallbacks ---

    #[test]
    fn check_legacy_format_new_format_ok() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn check_legacy_format_empty_providers_no_legacy_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn effective_provider_falls_back_to_ollama_when_no_providers() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
    }

    #[test]
    fn effective_provider_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "claude"
model = "claude-sonnet-4-6"
"#,
        );
        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
    }

    #[test]
    fn effective_model_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert_eq!(cfg.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_base_url_default_when_absent() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
    }

    #[test]
    fn effective_base_url_from_providers_entry() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
base_url = "http://myhost:11434"
"#,
        );
        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
    }
1644
    // --- complexity routing configuration ---

    #[test]
    fn complexity_routing_defaults() {
        let cr = ComplexityRoutingConfig::default();
        assert!(
            cr.bypass_single_provider,
            "bypass_single_provider must default to true"
        );
        assert_eq!(cr.triage_timeout_secs, 5);
        assert_eq!(cr.max_triage_tokens, 50);
        assert!(cr.triage_provider.is_none());
        assert!(cr.tiers.simple.is_none());
    }

    #[test]
    fn complexity_routing_toml_round_trip() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing]
triage_provider = "fast"
bypass_single_provider = false
triage_timeout_secs = 10
max_triage_tokens = 100

[llm.complexity_routing.tiers]
simple = "fast"
medium = "medium"
complex = "large"
expert = "opus"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
        assert!(!cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 10);
        assert_eq!(cr.max_triage_tokens, 100);
        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
    }

    // Missing tiers deserialize to None; the rest keep serde defaults.
    #[test]
    fn complexity_routing_partial_tiers_toml() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing.tiers]
simple = "haiku"
complex = "sonnet"
"#,
        );
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
        assert!(cr.tiers.medium.is_none());
        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
        assert!(cr.tiers.expert.is_none());
        assert!(cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 5);
    }

    #[test]
    fn routing_strategy_triage_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
    }

    // --- STT provider resolution ---

    #[test]
    fn stt_provider_entry_by_name_match() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
stt_model = "gpt-4o-mini-transcribe"

[llm.stt]
provider = "quality"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should find stt provider");
        assert_eq!(entry.effective_name(), "quality");
        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
    }

    #[test]
    fn stt_provider_entry_auto_detect_when_provider_empty() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"

[llm.stt]
provider = ""
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_auto_detect_no_stt_section() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }
1788
1789 #[test]
1790 fn stt_provider_entry_none_when_no_stt_model() {
1791 let cfg = parse_llm(
1792 r#"
1793[llm]
1794
1795[[llm.providers]]
1796type = "openai"
1797name = "quality"
1798model = "gpt-5.4"
1799"#,
1800 );
1801 assert!(cfg.stt_provider_entry().is_none());
1802 }
1803
1804 #[test]
1805 fn stt_provider_entry_name_mismatch_falls_back_to_none() {
1806 let cfg = parse_llm(
1808 r#"
1809[llm]
1810
1811[[llm.providers]]
1812type = "openai"
1813name = "quality"
1814model = "gpt-5.4"
1815
1816[[llm.providers]]
1817type = "openai"
1818name = "openai-stt"
1819stt_model = "whisper-1"
1820
1821[llm.stt]
1822provider = "quality"
1823"#,
1824 );
1825 assert!(cfg.stt_provider_entry().is_none());
1827 }
1828
1829 #[test]
1830 fn stt_config_deserializes_new_slim_format() {
1831 let cfg = parse_llm(
1832 r#"
1833[llm]
1834
1835[[llm.providers]]
1836type = "openai"
1837name = "quality"
1838stt_model = "whisper-1"
1839
1840[llm.stt]
1841provider = "quality"
1842language = "en"
1843"#,
1844 );
1845 let stt = cfg.stt.as_ref().expect("stt section present");
1846 assert_eq!(stt.provider, "quality");
1847 assert_eq!(stt.language, "en");
1848 }
1849
1850 #[test]
1851 fn stt_config_default_provider_is_empty() {
1852 assert_eq!(default_stt_provider(), "");
1854 }
1855
1856 #[test]
1857 fn validate_stt_missing_provider_ok() {
1858 let cfg = parse_llm("[llm]\n");
1859 assert!(cfg.validate_stt().is_ok());
1860 }
1861
1862 #[test]
1863 fn validate_stt_valid_reference() {
1864 let cfg = parse_llm(
1865 r#"
1866[llm]
1867
1868[[llm.providers]]
1869type = "openai"
1870name = "quality"
1871stt_model = "whisper-1"
1872
1873[llm.stt]
1874provider = "quality"
1875"#,
1876 );
1877 assert!(cfg.validate_stt().is_ok());
1878 }
1879
1880 #[test]
1881 fn validate_stt_nonexistent_provider_errors() {
1882 let cfg = parse_llm(
1883 r#"
1884[llm]
1885
1886[[llm.providers]]
1887type = "openai"
1888name = "quality"
1889model = "gpt-5.4"
1890
1891[llm.stt]
1892provider = "nonexistent"
1893"#,
1894 );
1895 assert!(cfg.validate_stt().is_err());
1896 }
1897
1898 #[test]
1899 fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1900 let cfg = parse_llm(
1902 r#"
1903[llm]
1904
1905[[llm.providers]]
1906type = "openai"
1907name = "quality"
1908model = "gpt-5.4"
1909
1910[llm.stt]
1911provider = "quality"
1912"#,
1913 );
1914 assert!(cfg.validate_stt().is_ok());
1916 assert!(
1918 cfg.stt_provider_entry().is_none(),
1919 "stt_provider_entry must be None when provider has no stt_model"
1920 );
1921 }
1922
1923 #[test]
1926 fn bandit_warmup_queries_explicit_value_is_deserialized() {
1927 let cfg = parse_llm(
1928 r#"
1929[llm]
1930
1931[llm.router]
1932strategy = "bandit"
1933
1934[llm.router.bandit]
1935warmup_queries = 50
1936"#,
1937 );
1938 let bandit = cfg
1939 .router
1940 .expect("router section must be present")
1941 .bandit
1942 .expect("bandit section must be present");
1943 assert_eq!(
1944 bandit.warmup_queries,
1945 Some(50),
1946 "warmup_queries = 50 must deserialize to Some(50)"
1947 );
1948 }
1949
1950 #[test]
1951 fn bandit_warmup_queries_explicit_null_is_none() {
1952 let cfg = parse_llm(
1955 r#"
1956[llm]
1957
1958[llm.router]
1959strategy = "bandit"
1960
1961[llm.router.bandit]
1962warmup_queries = 0
1963"#,
1964 );
1965 let bandit = cfg
1966 .router
1967 .expect("router section must be present")
1968 .bandit
1969 .expect("bandit section must be present");
1970 assert_eq!(
1972 bandit.warmup_queries,
1973 Some(0),
1974 "warmup_queries = 0 must deserialize to Some(0)"
1975 );
1976 }
1977
1978 #[test]
1979 fn bandit_warmup_queries_missing_field_defaults_to_none() {
1980 let cfg = parse_llm(
1982 r#"
1983[llm]
1984
1985[llm.router]
1986strategy = "bandit"
1987
1988[llm.router.bandit]
1989alpha = 1.5
1990"#,
1991 );
1992 let bandit = cfg
1993 .router
1994 .expect("router section must be present")
1995 .bandit
1996 .expect("bandit section must be present");
1997 assert_eq!(
1998 bandit.warmup_queries, None,
1999 "omitted warmup_queries must default to None"
2000 );
2001 }
2002
2003 #[test]
2004 fn provider_name_new_and_as_str() {
2005 let n = ProviderName::new("fast");
2006 assert_eq!(n.as_str(), "fast");
2007 assert!(!n.is_empty());
2008 }
2009
2010 #[test]
2011 fn provider_name_default_is_empty() {
2012 let n = ProviderName::default();
2013 assert!(n.is_empty());
2014 assert_eq!(n.as_str(), "");
2015 }
2016
2017 #[test]
2018 fn provider_name_deref_to_str() {
2019 let n = ProviderName::new("quality");
2020 let s: &str = &n;
2021 assert_eq!(s, "quality");
2022 }
2023
2024 #[test]
2025 fn provider_name_partial_eq_str() {
2026 let n = ProviderName::new("fast");
2027 assert_eq!(n, "fast");
2028 assert_ne!(n, "slow");
2029 }
2030
2031 #[test]
2032 fn provider_name_serde_roundtrip() {
2033 let n = ProviderName::new("my-provider");
2034 let json = serde_json::to_string(&n).expect("serialize");
2035 assert_eq!(json, "\"my-provider\"");
2036 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2037 assert_eq!(back, n);
2038 }
2039
2040 #[test]
2041 fn provider_name_serde_empty_roundtrip() {
2042 let n = ProviderName::default();
2043 let json = serde_json::to_string(&n).expect("serialize");
2044 assert_eq!(json, "\"\"");
2045 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
2046 assert_eq!(back, n);
2047 assert!(back.is_empty());
2048 }
2049}