1use std::fmt;
5
6use serde::{Deserialize, Serialize};
7use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};
8
9#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(transparent)]
15pub struct ProviderName(String);
16
17impl ProviderName {
18 #[must_use]
19 pub fn new(name: impl Into<String>) -> Self {
20 Self(name.into())
21 }
22
23 #[must_use]
24 pub fn is_empty(&self) -> bool {
25 self.0.is_empty()
26 }
27
28 #[must_use]
29 pub fn as_str(&self) -> &str {
30 &self.0
31 }
32}
33
34impl fmt::Display for ProviderName {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 self.0.fmt(f)
37 }
38}
39
40impl AsRef<str> for ProviderName {
41 fn as_ref(&self) -> &str {
42 &self.0
43 }
44}
45
46impl std::ops::Deref for ProviderName {
47 type Target = str;
48
49 fn deref(&self) -> &str {
50 &self.0
51 }
52}
53
54impl PartialEq<str> for ProviderName {
55 fn eq(&self, other: &str) -> bool {
56 self.0 == other
57 }
58}
59
60impl PartialEq<&str> for ProviderName {
61 fn eq(&self, other: &&str) -> bool {
62 self.0 == *other
63 }
64}
65
// ---- `LlmConfig` cache defaults -------------------------------------------------------

/// Default for `LlmConfig::response_cache_ttl_secs`: one hour, in seconds.
fn default_response_cache_ttl_secs() -> u64 {
    60 * 60
}

/// Default for `LlmConfig::semantic_cache_threshold`.
fn default_semantic_cache_threshold() -> f32 {
    0.95
}

/// Default for `LlmConfig::semantic_cache_max_candidates`.
fn default_semantic_cache_max_candidates() -> u32 {
    10
}

// ---- `LlmConfig` router defaults ------------------------------------------------------

/// Default for `LlmConfig::router_ema_alpha`.
fn default_router_ema_alpha() -> f64 {
    0.1
}

/// Default for `LlmConfig::router_reorder_interval`.
fn default_router_reorder_interval() -> u64 {
    10
}

// ---- Model / Candle defaults ----------------------------------------------------------

/// Default embedding model name used by `LlmConfig::embedding_model`.
fn default_embedding_model() -> String {
    String::from("qwen3-embedding")
}

/// Default `source` for `CandleConfig` / `CandleInlineConfig`.
fn default_candle_source() -> String {
    String::from("huggingface")
}

/// Default `chat_template` for `CandleConfig` / `CandleInlineConfig`.
fn default_chat_template() -> String {
    String::from("chatml")
}

/// Default `device` for `CandleConfig` / `CandleInlineConfig`.
fn default_candle_device() -> String {
    String::from("cpu")
}

// ---- `GenerationParams` defaults ------------------------------------------------------

/// Default sampling temperature.
fn default_temperature() -> f64 {
    0.7
}

/// Default `max_tokens` (before `MAX_TOKENS_CAP` is applied).
fn default_max_tokens() -> usize {
    2048
}

/// Default RNG seed.
fn default_seed() -> u64 {
    42
}

/// Default repetition penalty.
fn default_repeat_penalty() -> f32 {
    1.1
}

/// Default `repeat_last_n` window.
fn default_repeat_last_n() -> usize {
    64
}

// ---- `CascadeConfig` defaults ---------------------------------------------------------

/// Default `CascadeConfig::quality_threshold`.
fn default_cascade_quality_threshold() -> f64 {
    0.5
}

/// Default `CascadeConfig::max_escalations`.
fn default_cascade_max_escalations() -> u8 {
    2
}

/// Default `CascadeConfig::window_size`.
fn default_cascade_window_size() -> usize {
    50
}

// ---- `ReputationConfig` defaults ------------------------------------------------------

/// Default `ReputationConfig::decay_factor`.
fn default_reputation_decay_factor() -> f64 {
    0.95
}

/// Default `ReputationConfig::weight`.
fn default_reputation_weight() -> f64 {
    0.3
}

/// Default `ReputationConfig::min_observations`.
fn default_reputation_min_observations() -> u64 {
    5
}

// ---- Public defaults ------------------------------------------------------------------

/// Default `SttConfig::provider`: empty, which `LlmConfig::stt_provider_entry` treats as
/// "use the first provider that has an `stt_model`".
#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

/// Default `SttConfig::language`. NOTE(review): `"auto"` presumably requests language
/// auto-detection — confirm with the STT backend.
#[must_use]
pub fn default_stt_language() -> String {
    String::from("auto")
}

/// Public accessor for the default embedding model name.
#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

/// Public accessor for the default response-cache TTL, in seconds.
#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

/// Public accessor for the default router EMA smoothing factor.
#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

/// Public accessor for the default router reorder interval.
#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}
175
/// Backend type of a `[[llm.providers]]` entry, written as the `type` key.
///
/// Serialized in lowercase (`"ollama"`, `"claude"`, `"openai"`, `"gemini"`, `"candle"`,
/// `"compatible"`), which matches [`ProviderKind::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    Ollama,
    Claude,
    OpenAi,
    Gemini,
    Candle,
    /// Generic provider; `ProviderEntry::validate` requires such entries to set `name`.
    Compatible,
}
187
188impl ProviderKind {
189 #[must_use]
190 pub fn as_str(self) -> &'static str {
191 match self {
192 Self::Ollama => "ollama",
193 Self::Claude => "claude",
194 Self::OpenAi => "openai",
195 Self::Gemini => "gemini",
196 Self::Candle => "candle",
197 Self::Compatible => "compatible",
198 }
199 }
200}
201
202impl std::fmt::Display for ProviderKind {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.write_str(self.as_str())
205 }
206}
207
/// The `[llm]` configuration section.
///
/// Every field has a serde default, so a bare `[llm]` table deserializes successfully.
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    /// Provider pool from `[[llm.providers]]`. The `effective_*` helpers read the first entry.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    /// Routing strategy across `providers`; omitted from output when it is `None`.
    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    /// Named routes. NOTE(review): presumably route key -> ordered provider names; confirm
    /// against the router that consumes it.
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    /// Embedding model name (default `"qwen3-embedding"`).
    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    /// Optional top-level `[llm.candle]` settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    /// Optional `[llm.stt]` settings; see `stt_provider_entry` and `validate_stt`.
    // NOTE(review): unlike the neighbouring `Option` fields this has no
    // `skip_serializing_if`, so non-TOML serializers may emit an explicit null.
    #[serde(default)]
    pub stt: Option<SttConfig>,
    /// Enables the response cache (off by default).
    #[serde(default)]
    pub response_cache_enabled: bool,
    /// Response-cache TTL in seconds (default 3600).
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    /// Enables the semantic cache (off by default).
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    /// Semantic-cache threshold (default 0.95). NOTE(review): presumably a minimum
    /// similarity score for a hit; confirm.
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    /// Maximum semantic-cache candidates considered (default 10).
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    /// Enables EMA-based router reordering (off by default).
    #[serde(default)]
    pub router_ema_enabled: bool,
    /// EMA smoothing factor (default 0.1).
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    /// Router reorder interval (default 10). NOTE(review): unit (requests?) not visible here.
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    /// Optional `[llm.router]` strategy settings.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    /// Optional instruction file path.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    /// Optional model name used for summaries.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    /// Optional dedicated provider entry used for summaries.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    /// Optional `[llm.complexity_routing]` settings (used with `routing = "triage"`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}
282
283fn default_embedding_model_opt() -> String {
284 default_embedding_model()
285}
286
287#[allow(clippy::trivially_copy_pass_by_ref)]
288fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
289 *s == LlmRoutingStrategy::None
290}
291
292impl LlmConfig {
293 #[must_use]
295 pub fn effective_provider(&self) -> ProviderKind {
296 self.providers
297 .first()
298 .map_or(ProviderKind::Ollama, |e| e.provider_type)
299 }
300
301 #[must_use]
303 pub fn effective_base_url(&self) -> &str {
304 self.providers
305 .first()
306 .and_then(|e| e.base_url.as_deref())
307 .unwrap_or("http://localhost:11434")
308 }
309
310 #[must_use]
312 pub fn effective_model(&self) -> &str {
313 self.providers
314 .first()
315 .and_then(|e| e.model.as_deref())
316 .unwrap_or("qwen3:8b")
317 }
318
319 #[must_use]
327 pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
328 let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
329 if name_hint.is_empty() {
330 self.providers.iter().find(|p| p.stt_model.is_some())
331 } else {
332 self.providers
333 .iter()
334 .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
335 }
336 }
337
338 pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
344 Ok(())
345 }
346
347 pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
353 use crate::error::ConfigError;
354
355 let Some(stt) = &self.stt else {
356 return Ok(());
357 };
358 if stt.provider.is_empty() {
359 return Ok(());
360 }
361 let found = self
362 .providers
363 .iter()
364 .find(|p| p.effective_name() == stt.provider);
365 match found {
366 None => {
367 return Err(ConfigError::Validation(format!(
368 "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
369 stt.provider
370 )));
371 }
372 Some(entry) if entry.stt_model.is_none() => {
373 tracing::warn!(
374 provider = stt.provider,
375 "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
376 );
377 }
378 _ => {}
379 }
380 Ok(())
381 }
382}
383
/// The `[llm.stt]` speech-to-text section.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    /// Name of the `[[llm.providers]]` entry to use for STT. When empty (the default),
    /// `LlmConfig::stt_provider_entry` picks the first entry that has an `stt_model`.
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    /// Language hint (default `"auto"`).
    #[serde(default = "default_stt_language")]
    pub language: String,
}
394
/// Strategy selected in `RouterConfig::strategy`; serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    /// Default strategy.
    #[default]
    Ema,
    /// Thompson sampling; see also `RouterConfig::thompson_state_path`.
    Thompson,
    /// Cascade routing; tuned by `RouterConfig::cascade`.
    Cascade,
    /// Bandit routing; tuned by `RouterConfig::bandit`.
    Bandit,
}
409
/// The optional `[llm.router]` section: strategy choice plus per-strategy settings.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    /// Router strategy (defaults to `ema`).
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    /// Optional path for Thompson-sampling state. NOTE(review): presumably used to
    /// persist state across runs; confirm.
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    /// Optional cascade settings.
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    /// Optional provider-reputation settings.
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    /// Optional bandit settings.
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
}
435
/// Provider-reputation tracking settings (`RouterConfig::reputation`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    /// Enables reputation tracking (off by default).
    #[serde(default)]
    pub enabled: bool,
    /// Decay factor (default 0.95). NOTE(review): presumably applied to older observations.
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    /// Weight of reputation in routing decisions (default 0.3). NOTE(review): blending
    /// formula not visible here.
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    /// Minimum observations before reputation is considered (default 5); confirm semantics.
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    /// Optional path for reputation state.
    #[serde(default)]
    pub state_path: Option<String>,
}
466
/// Settings for cascade routing (`RouterConfig::cascade`).
///
/// Every field has a serde default, and the `Default` impl produces the same values, so an
/// empty table deserializes to `CascadeConfig::default()`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    /// Quality threshold (default 0.5). NOTE(review): presumably responses scoring below
    /// it are escalated; confirm against the cascade router.
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    /// Maximum number of escalations (default 2).
    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    /// How response quality is classified (default `heuristic`).
    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    /// Window size (default 50). NOTE(review): what the window holds is not visible here.
    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    /// Optional token limit for the cascade.
    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    /// Optional explicit tier order. NOTE(review): presumably provider names ordered by
    /// cost; confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}
513
514impl Default for CascadeConfig {
515 fn default() -> Self {
516 Self {
517 quality_threshold: default_cascade_quality_threshold(),
518 max_escalations: default_cascade_max_escalations(),
519 classifier_mode: CascadeClassifierMode::default(),
520 window_size: default_cascade_window_size(),
521 max_cascade_tokens: None,
522 cost_tiers: None,
523 }
524 }
525}
526
/// How the cascade router classifies response quality; serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    /// Default mode.
    #[default]
    Heuristic,
    /// NOTE(review): presumably an LLM-as-judge classifier; confirm.
    Judge,
}
539
/// Default `BanditConfig::alpha`.
fn default_bandit_alpha() -> f32 {
    1.0_f32
}

/// Default `BanditConfig::dim`.
fn default_bandit_dim() -> usize {
    32_usize
}

/// Default `BanditConfig::cost_weight`.
fn default_bandit_cost_weight() -> f32 {
    0.1_f32
}

/// Default `BanditConfig::decay_factor`.
fn default_bandit_decay_factor() -> f32 {
    1.0_f32
}

/// Default `BanditConfig::embedding_timeout_ms`, in milliseconds.
fn default_bandit_embedding_timeout_ms() -> u64 {
    50_u64
}

/// Default `BanditConfig::cache_size`.
fn default_bandit_cache_size() -> usize {
    512_usize
}
563
/// Settings for bandit routing (`RouterConfig::bandit`).
///
/// Every field has a serde default; the `Default` impl produces the same values.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    /// Exploration parameter (default 1.0). NOTE(review): exact algorithm not visible here.
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    /// Feature dimension (default 32). NOTE(review): presumably the size of the context
    /// vector derived from query embeddings; confirm.
    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    /// Weight of cost in the reward (default 0.1); confirm the formula.
    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    /// Decay factor (default 1.0).
    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    /// Provider used to compute embeddings; empty by default.
    #[serde(default)]
    pub embedding_provider: ProviderName,

    /// Embedding timeout in milliseconds (default 50).
    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    /// Cache size (default 512).
    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    /// Optional path for bandit state.
    #[serde(default)]
    pub state_path: Option<String>,

    /// Memory confidence threshold (default 0.9); semantics not visible here.
    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,

    /// Optional number of warm-up queries.
    #[serde(default)]
    pub warmup_queries: Option<u64>,
}
644
/// Default `BanditConfig::memory_confidence_threshold`.
fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9_f32
}
648
649impl Default for BanditConfig {
650 fn default() -> Self {
651 Self {
652 alpha: default_bandit_alpha(),
653 dim: default_bandit_dim(),
654 cost_weight: default_bandit_cost_weight(),
655 decay_factor: default_bandit_decay_factor(),
656 embedding_provider: ProviderName::default(),
657 embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
658 cache_size: default_bandit_cache_size(),
659 state_path: None,
660 memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
661 warmup_queries: None,
662 }
663 }
664}
665
/// Top-level `[llm.candle]` settings (`LlmConfig::candle`).
///
/// NOTE(review): field-for-field identical to `CandleInlineConfig`, which also derives
/// `Clone` and implements `Default`; consider unifying the two.
/// NOTE(review): `Debug` is derived, so `hf_token` is printed verbatim by `{:?}`, including
/// any Debug print of `LlmConfig`. Consider a redacting `Debug` impl.
#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    /// Model source (default `"huggingface"`).
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Optional model file name.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat template name (default `"chatml"`).
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default `"cpu"`).
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional embedding-model repository.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional access token (secret); presumably for Hugging Face downloads.
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Generation / sampling parameters.
    #[serde(default)]
    pub generation: GenerationParams,
}
688
/// Sampling / generation parameters for Candle models.
///
/// Every field has a serde default; the `Default` impl produces the same values.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    /// Sampling temperature (default 0.7).
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    /// Optional nucleus-sampling `top_p`.
    #[serde(default)]
    pub top_p: Option<f64>,
    /// Optional `top_k`.
    #[serde(default)]
    pub top_k: Option<usize>,
    /// Requested maximum tokens (default 2048). Use `capped_max_tokens` for the clamped value.
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    /// RNG seed (default 42).
    #[serde(default = "default_seed")]
    pub seed: u64,
    /// Repetition penalty (default 1.1).
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    /// Number of recent tokens the repetition penalty looks at (default 64).
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}
706
707pub const MAX_TOKENS_CAP: usize = 32768;
708
709impl GenerationParams {
710 #[must_use]
711 pub fn capped_max_tokens(&self) -> usize {
712 self.max_tokens.min(MAX_TOKENS_CAP)
713 }
714}
715
716impl Default for GenerationParams {
717 fn default() -> Self {
718 Self {
719 temperature: default_temperature(),
720 top_p: None,
721 top_k: None,
722 max_tokens: default_max_tokens(),
723 seed: default_seed(),
724 repeat_penalty: default_repeat_penalty(),
725 repeat_last_n: default_repeat_last_n(),
726 }
727 }
728}
729
/// Value of `[llm].routing`; serialized in lowercase and omitted from output when `None`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    /// No routing strategy (default).
    #[default]
    None,
    Ema,
    Thompson,
    Cascade,
    /// NOTE(review): presumably routes by task via `LlmConfig::routes`; confirm.
    Task,
    /// Complexity triage; configured by `LlmConfig::complexity_routing`.
    Triage,
    Bandit,
}
752
/// Default `ComplexityRoutingConfig::triage_timeout_secs`, in seconds.
fn default_triage_timeout_secs() -> u64 {
    5_u64
}

/// Default `ComplexityRoutingConfig::max_triage_tokens`.
fn default_max_triage_tokens() -> u32 {
    50_u32
}

/// Serde default helper for boolean fields that default to enabled.
fn default_true() -> bool {
    !false
}
764
/// Per-complexity-tier target for `[llm.complexity_routing.tiers]`.
///
/// Each tier is optional; missing keys deserialize to `None`.
/// NOTE(review): values are presumably `[[llm.providers]]` names; confirm.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    pub simple: Option<String>,
    pub medium: Option<String>,
    pub complex: Option<String>,
    pub expert: Option<String>,
}
773
/// The `[llm.complexity_routing]` section used with `routing = "triage"`.
///
/// Every field has a serde default; the `Default` impl produces the same values.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    /// Optional provider that performs triage.
    #[serde(default)]
    pub triage_provider: Option<ProviderName>,

    /// Bypass flag (default `true`). NOTE(review): presumably skips triage when only one
    /// provider is configured; confirm.
    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    /// Target for each complexity tier.
    #[serde(default)]
    pub tiers: TierMapping,

    /// Token budget for the triage call (default 50).
    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    /// Triage timeout in seconds (default 5).
    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    /// Optional fallback strategy name; accepted values are not visible here.
    #[serde(default)]
    pub fallback_strategy: Option<String>,
}
822
823impl Default for ComplexityRoutingConfig {
824 fn default() -> Self {
825 Self {
826 triage_provider: None,
827 bypass_single_provider: true,
828 tiers: TierMapping::default(),
829 max_triage_tokens: default_max_triage_tokens(),
830 triage_timeout_secs: default_triage_timeout_secs(),
831 fallback_strategy: None,
832 }
833 }
834}
835
/// Candle settings embedded in a single `[[llm.providers]]` entry (`ProviderEntry::candle`).
///
/// NOTE(review): `Debug` is derived, so `hf_token` is printed verbatim by `{:?}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    /// Model source (default `"huggingface"`).
    #[serde(default = "default_candle_source")]
    pub source: String,
    /// Local model path; empty by default.
    #[serde(default)]
    pub local_path: String,
    /// Optional model file name.
    #[serde(default)]
    pub filename: Option<String>,
    /// Chat template name (default `"chatml"`).
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    /// Compute device (default `"cpu"`).
    #[serde(default = "default_candle_device")]
    pub device: String,
    /// Optional embedding-model repository.
    #[serde(default)]
    pub embedding_repo: Option<String>,
    /// Optional access token (secret).
    #[serde(default)]
    pub hf_token: Option<String>,
    /// Generation / sampling parameters.
    #[serde(default)]
    pub generation: GenerationParams,
}
858
859impl Default for CandleInlineConfig {
860 fn default() -> Self {
861 Self {
862 source: default_candle_source(),
863 local_path: String::new(),
864 filename: None,
865 chat_template: default_chat_template(),
866 device: default_candle_device(),
867 embedding_repo: None,
868 hf_token: None,
869 generation: GenerationParams::default(),
870 }
871 }
872}
873
/// One `[[llm.providers]]` entry.
///
/// Provider-specific fields are accepted on every kind; `ProviderEntry::validate` logs a
/// warning when a field is set on a kind that does not use it.
/// NOTE(review): `Debug` is derived, so `api_key` is printed verbatim by `{:?}`.
#[derive(Debug, Clone, Deserialize, Serialize)]
// The bools are independent TOML flags; an enum would change the config format.
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    /// Provider kind, written as the required `type` key.
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    /// Entry name. Falls back to the kind's name; required for `compatible` entries.
    #[serde(default)]
    pub name: Option<String>,

    /// Model name. `effective_model` supplies a per-kind default when unset.
    #[serde(default)]
    pub model: Option<String>,

    /// Optional API base URL.
    #[serde(default)]
    pub base_url: Option<String>,

    /// Optional max output tokens.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Optional embedding model for this provider.
    #[serde(default)]
    pub embedding_model: Option<String>,

    /// STT model. Entries with this set are candidates for
    /// `LlmConfig::stt_provider_entry`; setting it on an Ollama entry logs a warning.
    #[serde(default)]
    pub stt_model: Option<String>,

    /// NOTE(review): presumably marks this entry for embeddings; confirm.
    #[serde(default)]
    pub embed: bool,

    /// Marks the default entry; `validate_pool` allows at most one.
    #[serde(default)]
    pub default: bool,

    /// Thinking settings (used by Claude; warned on other kinds).
    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    #[serde(default)]
    pub server_compaction: bool,
    #[serde(default)]
    pub enable_extended_context: bool,

    /// Reasoning effort (used by OpenAI; warned on other kinds).
    #[serde(default)]
    pub reasoning_effort: Option<String>,

    /// Thinking level (used by Gemini; warned on other kinds).
    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    /// Thinking budget (used by Gemini; warned on other kinds).
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    /// NOTE(review): not covered by `validate`'s warnings; presumably Gemini-only.
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    /// Tool-use flag (used by Ollama; warned on other kinds).
    #[serde(default)]
    pub tool_use: bool,

    /// Optional API key (secret).
    #[serde(default)]
    pub api_key: Option<String>,

    /// Inline Candle settings.
    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    /// Optional vision model name.
    #[serde(default)]
    pub vision_model: Option<String>,

    /// Optional per-provider instruction file.
    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}
959
960impl Default for ProviderEntry {
961 fn default() -> Self {
962 Self {
963 provider_type: ProviderKind::Ollama,
964 name: None,
965 model: None,
966 base_url: None,
967 max_tokens: None,
968 embedding_model: None,
969 stt_model: None,
970 embed: false,
971 default: false,
972 thinking: None,
973 server_compaction: false,
974 enable_extended_context: false,
975 reasoning_effort: None,
976 thinking_level: None,
977 thinking_budget: None,
978 include_thoughts: None,
979 tool_use: false,
980 api_key: None,
981 candle: None,
982 vision_model: None,
983 instruction_file: None,
984 }
985 }
986}
987
988impl ProviderEntry {
989 #[must_use]
991 pub fn effective_name(&self) -> String {
992 self.name
993 .clone()
994 .unwrap_or_else(|| self.provider_type.as_str().to_owned())
995 }
996
997 #[must_use]
1002 pub fn effective_model(&self) -> String {
1003 if let Some(ref m) = self.model {
1004 return m.clone();
1005 }
1006 match self.provider_type {
1007 ProviderKind::Ollama => "qwen3:8b".to_owned(),
1008 ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
1009 ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
1010 ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
1011 ProviderKind::Compatible | ProviderKind::Candle => String::new(),
1012 }
1013 }
1014
1015 pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
1022 use crate::error::ConfigError;
1023
1024 if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
1026 return Err(ConfigError::Validation(
1027 "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
1028 ));
1029 }
1030
1031 match self.provider_type {
1033 ProviderKind::Ollama => {
1034 if self.thinking.is_some() {
1035 tracing::warn!(
1036 provider = self.effective_name(),
1037 "field `thinking` is only used by Claude providers"
1038 );
1039 }
1040 if self.reasoning_effort.is_some() {
1041 tracing::warn!(
1042 provider = self.effective_name(),
1043 "field `reasoning_effort` is only used by OpenAI providers"
1044 );
1045 }
1046 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1047 tracing::warn!(
1048 provider = self.effective_name(),
1049 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1050 );
1051 }
1052 }
1053 ProviderKind::Claude => {
1054 if self.reasoning_effort.is_some() {
1055 tracing::warn!(
1056 provider = self.effective_name(),
1057 "field `reasoning_effort` is only used by OpenAI providers"
1058 );
1059 }
1060 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1061 tracing::warn!(
1062 provider = self.effective_name(),
1063 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1064 );
1065 }
1066 if self.tool_use {
1067 tracing::warn!(
1068 provider = self.effective_name(),
1069 "field `tool_use` is only used by Ollama providers"
1070 );
1071 }
1072 }
1073 ProviderKind::OpenAi => {
1074 if self.thinking.is_some() {
1075 tracing::warn!(
1076 provider = self.effective_name(),
1077 "field `thinking` is only used by Claude providers"
1078 );
1079 }
1080 if self.thinking_level.is_some() || self.thinking_budget.is_some() {
1081 tracing::warn!(
1082 provider = self.effective_name(),
1083 "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
1084 );
1085 }
1086 if self.tool_use {
1087 tracing::warn!(
1088 provider = self.effective_name(),
1089 "field `tool_use` is only used by Ollama providers"
1090 );
1091 }
1092 }
1093 ProviderKind::Gemini => {
1094 if self.thinking.is_some() {
1095 tracing::warn!(
1096 provider = self.effective_name(),
1097 "field `thinking` is only used by Claude providers"
1098 );
1099 }
1100 if self.reasoning_effort.is_some() {
1101 tracing::warn!(
1102 provider = self.effective_name(),
1103 "field `reasoning_effort` is only used by OpenAI providers"
1104 );
1105 }
1106 if self.tool_use {
1107 tracing::warn!(
1108 provider = self.effective_name(),
1109 "field `tool_use` is only used by Ollama providers"
1110 );
1111 }
1112 }
1113 _ => {}
1114 }
1115
1116 if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
1119 tracing::warn!(
1120 provider = self.effective_name(),
1121 "field `stt_model` is set on an Ollama provider; Ollama does not support the \
1122 Whisper STT API — use OpenAI, compatible, or candle instead"
1123 );
1124 }
1125
1126 Ok(())
1127 }
1128}
1129
1130pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
1140 use crate::error::ConfigError;
1141 use std::collections::HashSet;
1142
1143 if entries.is_empty() {
1144 return Err(ConfigError::Validation(
1145 "at least one LLM provider must be configured in [[llm.providers]]".into(),
1146 ));
1147 }
1148
1149 let default_count = entries.iter().filter(|e| e.default).count();
1150 if default_count > 1 {
1151 return Err(ConfigError::Validation(
1152 "only one [[llm.providers]] entry can be marked `default = true`".into(),
1153 ));
1154 }
1155
1156 let mut seen_names: HashSet<String> = HashSet::new();
1157 for entry in entries {
1158 let name = entry.effective_name();
1159 if !seen_names.insert(name.clone()) {
1160 return Err(ConfigError::Validation(format!(
1161 "duplicate provider name \"{name}\" in [[llm.providers]]"
1162 )));
1163 }
1164 entry.validate()?;
1165 }
1166
1167 Ok(())
1168}
1169
1170#[cfg(test)]
1171mod tests {
1172 use super::*;
1173
1174 fn ollama_entry() -> ProviderEntry {
1175 ProviderEntry {
1176 provider_type: ProviderKind::Ollama,
1177 name: Some("ollama".into()),
1178 model: Some("qwen3:8b".into()),
1179 ..Default::default()
1180 }
1181 }
1182
1183 fn claude_entry() -> ProviderEntry {
1184 ProviderEntry {
1185 provider_type: ProviderKind::Claude,
1186 name: Some("claude".into()),
1187 model: Some("claude-sonnet-4-6".into()),
1188 max_tokens: Some(8192),
1189 ..Default::default()
1190 }
1191 }
1192
1193 #[test]
1196 fn validate_ollama_valid() {
1197 assert!(ollama_entry().validate().is_ok());
1198 }
1199
1200 #[test]
1201 fn validate_claude_valid() {
1202 assert!(claude_entry().validate().is_ok());
1203 }
1204
1205 #[test]
1206 fn validate_compatible_without_name_errors() {
1207 let entry = ProviderEntry {
1208 provider_type: ProviderKind::Compatible,
1209 name: None,
1210 ..Default::default()
1211 };
1212 let err = entry.validate().unwrap_err();
1213 assert!(
1214 err.to_string().contains("compatible"),
1215 "error should mention compatible: {err}"
1216 );
1217 }
1218
1219 #[test]
1220 fn validate_compatible_with_name_ok() {
1221 let entry = ProviderEntry {
1222 provider_type: ProviderKind::Compatible,
1223 name: Some("my-proxy".into()),
1224 base_url: Some("http://localhost:8080".into()),
1225 model: Some("gpt-4o".into()),
1226 max_tokens: Some(4096),
1227 ..Default::default()
1228 };
1229 assert!(entry.validate().is_ok());
1230 }
1231
1232 #[test]
1233 fn validate_openai_valid() {
1234 let entry = ProviderEntry {
1235 provider_type: ProviderKind::OpenAi,
1236 name: Some("openai".into()),
1237 model: Some("gpt-4o".into()),
1238 max_tokens: Some(4096),
1239 ..Default::default()
1240 };
1241 assert!(entry.validate().is_ok());
1242 }
1243
1244 #[test]
1245 fn validate_gemini_valid() {
1246 let entry = ProviderEntry {
1247 provider_type: ProviderKind::Gemini,
1248 name: Some("gemini".into()),
1249 model: Some("gemini-2.0-flash".into()),
1250 ..Default::default()
1251 };
1252 assert!(entry.validate().is_ok());
1253 }
1254
1255 #[test]
1258 fn validate_pool_empty_errors() {
1259 let err = validate_pool(&[]).unwrap_err();
1260 assert!(err.to_string().contains("at least one"), "{err}");
1261 }
1262
1263 #[test]
1264 fn validate_pool_single_entry_ok() {
1265 assert!(validate_pool(&[ollama_entry()]).is_ok());
1266 }
1267
1268 #[test]
1269 fn validate_pool_duplicate_names_errors() {
1270 let a = ollama_entry();
1271 let b = ollama_entry(); let err = validate_pool(&[a, b]).unwrap_err();
1273 assert!(err.to_string().contains("duplicate"), "{err}");
1274 }
1275
1276 #[test]
1277 fn validate_pool_multiple_defaults_errors() {
1278 let mut a = ollama_entry();
1279 let mut b = claude_entry();
1280 a.default = true;
1281 b.default = true;
1282 let err = validate_pool(&[a, b]).unwrap_err();
1283 assert!(err.to_string().contains("default"), "{err}");
1284 }
1285
1286 #[test]
1287 fn validate_pool_two_different_providers_ok() {
1288 assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
1289 }
1290
1291 #[test]
1292 fn validate_pool_propagates_entry_error() {
1293 let bad = ProviderEntry {
1294 provider_type: ProviderKind::Compatible,
1295 name: None, ..Default::default()
1297 };
1298 assert!(validate_pool(&[bad]).is_err());
1299 }
1300
1301 #[test]
1304 fn effective_model_returns_explicit_when_set() {
1305 let entry = ProviderEntry {
1306 provider_type: ProviderKind::Claude,
1307 model: Some("claude-sonnet-4-6".into()),
1308 ..Default::default()
1309 };
1310 assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
1311 }
1312
1313 #[test]
1314 fn effective_model_ollama_default_when_none() {
1315 let entry = ProviderEntry {
1316 provider_type: ProviderKind::Ollama,
1317 model: None,
1318 ..Default::default()
1319 };
1320 assert_eq!(entry.effective_model(), "qwen3:8b");
1321 }
1322
1323 #[test]
1324 fn effective_model_claude_default_when_none() {
1325 let entry = ProviderEntry {
1326 provider_type: ProviderKind::Claude,
1327 model: None,
1328 ..Default::default()
1329 };
1330 assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
1331 }
1332
1333 #[test]
1334 fn effective_model_openai_default_when_none() {
1335 let entry = ProviderEntry {
1336 provider_type: ProviderKind::OpenAi,
1337 model: None,
1338 ..Default::default()
1339 };
1340 assert_eq!(entry.effective_model(), "gpt-4o-mini");
1341 }
1342
1343 #[test]
1344 fn effective_model_gemini_default_when_none() {
1345 let entry = ProviderEntry {
1346 provider_type: ProviderKind::Gemini,
1347 model: None,
1348 ..Default::default()
1349 };
1350 assert_eq!(entry.effective_model(), "gemini-2.0-flash");
1351 }
1352
1353 fn parse_llm(toml: &str) -> LlmConfig {
1357 #[derive(serde::Deserialize)]
1358 struct Wrapper {
1359 llm: LlmConfig,
1360 }
1361 toml::from_str::<Wrapper>(toml).unwrap().llm
1362 }
1363
    // ---- LlmConfig::check_legacy_format ----

    #[test]
    fn check_legacy_format_new_format_ok() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn check_legacy_format_empty_providers_no_legacy_ok() {
        // A bare `[llm]` table has no providers and no legacy keys.
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.check_legacy_format().is_ok());
    }

    // ---- LlmConfig::effective_* helpers (read the first provider entry) ----

    #[test]
    fn effective_provider_falls_back_to_ollama_when_no_providers() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
    }

    #[test]
    fn effective_provider_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "claude"
model = "claude-sonnet-4-6"
"#,
        );
        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
    }

    #[test]
    fn effective_model_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert_eq!(cfg.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_base_url_default_when_absent() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
    }

    #[test]
    fn effective_base_url_from_providers_entry() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
base_url = "http://myhost:11434"
"#,
        );
        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
    }
1440
    // ---- ComplexityRoutingConfig ----

    #[test]
    fn complexity_routing_defaults() {
        let cr = ComplexityRoutingConfig::default();
        assert!(
            cr.bypass_single_provider,
            "bypass_single_provider must default to true"
        );
        assert_eq!(cr.triage_timeout_secs, 5);
        assert_eq!(cr.max_triage_tokens, 50);
        assert!(cr.triage_provider.is_none());
        assert!(cr.tiers.simple.is_none());
    }

    #[test]
    fn complexity_routing_toml_round_trip() {
        // Every field set explicitly; each must come back as written.
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing]
triage_provider = "fast"
bypass_single_provider = false
triage_timeout_secs = 10
max_triage_tokens = 100

[llm.complexity_routing.tiers]
simple = "fast"
medium = "medium"
complex = "large"
expert = "opus"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
        assert!(!cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 10);
        assert_eq!(cr.max_triage_tokens, 100);
        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
    }

    #[test]
    fn complexity_routing_partial_tiers_toml() {
        // Only the tiers sub-table is given: missing tiers are `None`, the other
        // fields fall back to their serde defaults.
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing.tiers]
simple = "haiku"
complex = "sonnet"
"#,
        );
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
        assert!(cr.tiers.medium.is_none());
        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
        assert!(cr.tiers.expert.is_none());
        assert!(cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 5);
    }

    #[test]
    fn routing_strategy_triage_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
    }
1525
    // ---- LlmConfig::stt_provider_entry ----

    #[test]
    fn stt_provider_entry_by_name_match() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
stt_model = "gpt-4o-mini-transcribe"

[llm.stt]
provider = "quality"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should find stt provider");
        assert_eq!(entry.effective_name(), "quality");
        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
    }

    #[test]
    fn stt_provider_entry_auto_detect_when_provider_empty() {
        // An empty `provider` means "first entry with an stt_model".
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"

[llm.stt]
provider = ""
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_auto_detect_no_stt_section() {
        // With no `[llm.stt]` section at all, auto-detection still applies.
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_none_when_no_stt_model() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }

    #[test]
    fn stt_provider_entry_name_mismatch_falls_back_to_none() {
        // The named entry lacks `stt_model`; another entry that has one must NOT be
        // picked silently.
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"

[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }
1624
1625 #[test]
1626 fn stt_config_deserializes_new_slim_format() {
1627 let cfg = parse_llm(
1628 r#"
1629[llm]
1630
1631[[llm.providers]]
1632type = "openai"
1633name = "quality"
1634stt_model = "whisper-1"
1635
1636[llm.stt]
1637provider = "quality"
1638language = "en"
1639"#,
1640 );
1641 let stt = cfg.stt.as_ref().expect("stt section present");
1642 assert_eq!(stt.provider, "quality");
1643 assert_eq!(stt.language, "en");
1644 }
1645
1646 #[test]
1647 fn stt_config_default_provider_is_empty() {
1648 assert_eq!(default_stt_provider(), "");
1650 }
1651
1652 #[test]
1653 fn validate_stt_missing_provider_ok() {
1654 let cfg = parse_llm("[llm]\n");
1655 assert!(cfg.validate_stt().is_ok());
1656 }
1657
1658 #[test]
1659 fn validate_stt_valid_reference() {
1660 let cfg = parse_llm(
1661 r#"
1662[llm]
1663
1664[[llm.providers]]
1665type = "openai"
1666name = "quality"
1667stt_model = "whisper-1"
1668
1669[llm.stt]
1670provider = "quality"
1671"#,
1672 );
1673 assert!(cfg.validate_stt().is_ok());
1674 }
1675
1676 #[test]
1677 fn validate_stt_nonexistent_provider_errors() {
1678 let cfg = parse_llm(
1679 r#"
1680[llm]
1681
1682[[llm.providers]]
1683type = "openai"
1684name = "quality"
1685model = "gpt-5.4"
1686
1687[llm.stt]
1688provider = "nonexistent"
1689"#,
1690 );
1691 assert!(cfg.validate_stt().is_err());
1692 }
1693
1694 #[test]
1695 fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
1696 let cfg = parse_llm(
1698 r#"
1699[llm]
1700
1701[[llm.providers]]
1702type = "openai"
1703name = "quality"
1704model = "gpt-5.4"
1705
1706[llm.stt]
1707provider = "quality"
1708"#,
1709 );
1710 assert!(cfg.validate_stt().is_ok());
1712 assert!(
1714 cfg.stt_provider_entry().is_none(),
1715 "stt_provider_entry must be None when provider has no stt_model"
1716 );
1717 }
1718
1719 #[test]
1722 fn bandit_warmup_queries_explicit_value_is_deserialized() {
1723 let cfg = parse_llm(
1724 r#"
1725[llm]
1726
1727[llm.router]
1728strategy = "bandit"
1729
1730[llm.router.bandit]
1731warmup_queries = 50
1732"#,
1733 );
1734 let bandit = cfg
1735 .router
1736 .expect("router section must be present")
1737 .bandit
1738 .expect("bandit section must be present");
1739 assert_eq!(
1740 bandit.warmup_queries,
1741 Some(50),
1742 "warmup_queries = 50 must deserialize to Some(50)"
1743 );
1744 }
1745
1746 #[test]
1747 fn bandit_warmup_queries_explicit_null_is_none() {
1748 let cfg = parse_llm(
1751 r#"
1752[llm]
1753
1754[llm.router]
1755strategy = "bandit"
1756
1757[llm.router.bandit]
1758warmup_queries = 0
1759"#,
1760 );
1761 let bandit = cfg
1762 .router
1763 .expect("router section must be present")
1764 .bandit
1765 .expect("bandit section must be present");
1766 assert_eq!(
1768 bandit.warmup_queries,
1769 Some(0),
1770 "warmup_queries = 0 must deserialize to Some(0)"
1771 );
1772 }
1773
1774 #[test]
1775 fn bandit_warmup_queries_missing_field_defaults_to_none() {
1776 let cfg = parse_llm(
1778 r#"
1779[llm]
1780
1781[llm.router]
1782strategy = "bandit"
1783
1784[llm.router.bandit]
1785alpha = 1.5
1786"#,
1787 );
1788 let bandit = cfg
1789 .router
1790 .expect("router section must be present")
1791 .bandit
1792 .expect("bandit section must be present");
1793 assert_eq!(
1794 bandit.warmup_queries, None,
1795 "omitted warmup_queries must default to None"
1796 );
1797 }
1798
1799 #[test]
1800 fn provider_name_new_and_as_str() {
1801 let n = ProviderName::new("fast");
1802 assert_eq!(n.as_str(), "fast");
1803 assert!(!n.is_empty());
1804 }
1805
1806 #[test]
1807 fn provider_name_default_is_empty() {
1808 let n = ProviderName::default();
1809 assert!(n.is_empty());
1810 assert_eq!(n.as_str(), "");
1811 }
1812
1813 #[test]
1814 fn provider_name_deref_to_str() {
1815 let n = ProviderName::new("quality");
1816 let s: &str = &n;
1817 assert_eq!(s, "quality");
1818 }
1819
1820 #[test]
1821 fn provider_name_partial_eq_str() {
1822 let n = ProviderName::new("fast");
1823 assert_eq!(n, "fast");
1824 assert_ne!(n, "slow");
1825 }
1826
1827 #[test]
1828 fn provider_name_serde_roundtrip() {
1829 let n = ProviderName::new("my-provider");
1830 let json = serde_json::to_string(&n).expect("serialize");
1831 assert_eq!(json, "\"my-provider\"");
1832 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1833 assert_eq!(back, n);
1834 }
1835
1836 #[test]
1837 fn provider_name_serde_empty_roundtrip() {
1838 let n = ProviderName::default();
1839 let json = serde_json::to_string(&n).expect("serialize");
1840 assert_eq!(json, "\"\"");
1841 let back: ProviderName = serde_json::from_str(&json).expect("deserialize");
1842 assert_eq!(back, n);
1843 assert!(back.is_empty());
1844 }
1845}