use serde::{Deserialize, Serialize};
use zeph_llm::{GeminiThinkingLevel, ThinkingConfig};

fn default_response_cache_ttl_secs() -> u64 {
    3600
}

fn default_semantic_cache_threshold() -> f32 {
    0.95
}

fn default_semantic_cache_max_candidates() -> u32 {
    10
}

fn default_router_ema_alpha() -> f64 {
    0.1
}

fn default_router_reorder_interval() -> u64 {
    10
}

fn default_embedding_model() -> String {
    "qwen3-embedding".into()
}

fn default_candle_source() -> String {
    "huggingface".into()
}

fn default_chat_template() -> String {
    "chatml".into()
}

fn default_candle_device() -> String {
    "cpu".into()
}

fn default_temperature() -> f64 {
    0.7
}

fn default_max_tokens() -> usize {
    2048
}

fn default_seed() -> u64 {
    42
}

fn default_repeat_penalty() -> f32 {
    1.1
}

fn default_repeat_last_n() -> usize {
    64
}

fn default_cascade_quality_threshold() -> f64 {
    0.5
}

fn default_cascade_max_escalations() -> u8 {
    2
}

fn default_cascade_window_size() -> usize {
    50
}

fn default_reputation_decay_factor() -> f64 {
    0.95
}

fn default_reputation_weight() -> f64 {
    0.3
}

fn default_reputation_min_observations() -> u64 {
    5
}

#[must_use]
pub fn default_stt_provider() -> String {
    String::new()
}

#[must_use]
pub fn default_stt_language() -> String {
    "auto".into()
}

#[must_use]
pub fn get_default_embedding_model() -> String {
    default_embedding_model()
}

#[must_use]
pub fn get_default_response_cache_ttl_secs() -> u64 {
    default_response_cache_ttl_secs()
}

#[must_use]
pub fn get_default_router_ema_alpha() -> f64 {
    default_router_ema_alpha()
}

#[must_use]
pub fn get_default_router_reorder_interval() -> u64 {
    default_router_reorder_interval()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderKind {
    Ollama,
    Claude,
    OpenAi,
    Gemini,
    Candle,
    Compatible,
}

impl ProviderKind {
    #[must_use]
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::Claude => "claude",
            Self::OpenAi => "openai",
            Self::Gemini => "gemini",
            Self::Candle => "candle",
            Self::Compatible => "compatible",
        }
    }
}

impl std::fmt::Display for ProviderKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}

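/// Configuration for the `[llm]` section of the config file.
///
/// Providers are declared as `[[llm.providers]]` array-of-tables entries; the
/// `effective_*` accessors fall back to the first entry when present.
///
/// Illustrative example (a minimal sketch, not an exhaustive key listing):
///
/// ```toml
/// [llm]
/// routing = "none"
///
/// [[llm.providers]]
/// type = "ollama"
/// model = "qwen3:8b"
/// ```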
#[derive(Debug, Deserialize, Serialize)]
pub struct LlmConfig {
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub providers: Vec<ProviderEntry>,

    #[serde(default, skip_serializing_if = "is_routing_none")]
    pub routing: LlmRoutingStrategy,

    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub routes: std::collections::HashMap<String, Vec<String>>,

    #[serde(default = "default_embedding_model_opt")]
    pub embedding_model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub candle: Option<CandleConfig>,
    #[serde(default)]
    pub stt: Option<SttConfig>,
    #[serde(default)]
    pub response_cache_enabled: bool,
    #[serde(default = "default_response_cache_ttl_secs")]
    pub response_cache_ttl_secs: u64,
    #[serde(default)]
    pub semantic_cache_enabled: bool,
    #[serde(default = "default_semantic_cache_threshold")]
    pub semantic_cache_threshold: f32,
    #[serde(default = "default_semantic_cache_max_candidates")]
    pub semantic_cache_max_candidates: u32,
    #[serde(default)]
    pub router_ema_enabled: bool,
    #[serde(default = "default_router_ema_alpha")]
    pub router_ema_alpha: f64,
    #[serde(default = "default_router_reorder_interval")]
    pub router_reorder_interval: u64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub router: Option<RouterConfig>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub instruction_file: Option<std::path::PathBuf>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_model: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub summary_provider: Option<ProviderEntry>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub complexity_routing: Option<ComplexityRoutingConfig>,
}

fn default_embedding_model_opt() -> String {
    default_embedding_model()
}

#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_routing_none(s: &LlmRoutingStrategy) -> bool {
    *s == LlmRoutingStrategy::None
}

impl LlmConfig {
    #[must_use]
    pub fn effective_provider(&self) -> ProviderKind {
        self.providers
            .first()
            .map_or(ProviderKind::Ollama, |e| e.provider_type)
    }

    #[must_use]
    pub fn effective_base_url(&self) -> &str {
        self.providers
            .first()
            .and_then(|e| e.base_url.as_deref())
            .unwrap_or("http://localhost:11434")
    }

    #[must_use]
    pub fn effective_model(&self) -> &str {
        self.providers
            .first()
            .and_then(|e| e.model.as_deref())
            .unwrap_or("qwen3:8b")
    }

    #[must_use]
    pub fn stt_provider_entry(&self) -> Option<&ProviderEntry> {
        let name_hint = self.stt.as_ref().map_or("", |s| s.provider.as_str());
        if name_hint.is_empty() {
            self.providers.iter().find(|p| p.stt_model.is_some())
        } else {
            self.providers
                .iter()
                .find(|p| p.effective_name() == name_hint && p.stt_model.is_some())
        }
    }

    pub fn check_legacy_format(&self) -> Result<(), crate::error::ConfigError> {
        Ok(())
    }

    pub fn validate_stt(&self) -> Result<(), crate::error::ConfigError> {
        use crate::error::ConfigError;

        let Some(stt) = &self.stt else {
            return Ok(());
        };
        if stt.provider.is_empty() {
            return Ok(());
        }
        let found = self
            .providers
            .iter()
            .find(|p| p.effective_name() == stt.provider);
        match found {
            None => {
                return Err(ConfigError::Validation(format!(
                    "[llm.stt].provider = {:?} does not match any [[llm.providers]] entry",
                    stt.provider
                )));
            }
            Some(entry) if entry.stt_model.is_none() => {
                tracing::warn!(
                    provider = stt.provider,
                    "[[llm.providers]] entry exists but has no `stt_model` — STT will not be activated"
                );
            }
            _ => {}
        }
        Ok(())
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SttConfig {
    #[serde(default = "default_stt_provider")]
    pub provider: String,
    #[serde(default = "default_stt_language")]
    pub language: String,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum RouterStrategyConfig {
    #[default]
    Ema,
    Thompson,
    Cascade,
    Bandit,
}

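/// Configuration for the `[llm.router]` section.
///
/// `strategy` picks one of the [`RouterStrategyConfig`] variants; the optional
/// sub-tables hold strategy-specific settings.
///
/// Illustrative example (a sketch using the field names defined below):
///
/// ```toml
/// [llm.router]
/// strategy = "cascade"
///
/// [llm.router.cascade]
/// quality_threshold = 0.6
/// max_escalations = 2
/// ```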
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RouterConfig {
    #[serde(default)]
    pub strategy: RouterStrategyConfig,
    #[serde(default)]
    pub thompson_state_path: Option<String>,
    #[serde(default)]
    pub cascade: Option<CascadeConfig>,
    #[serde(default)]
    pub reputation: Option<ReputationConfig>,
    #[serde(default)]
    pub bandit: Option<BanditConfig>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ReputationConfig {
    #[serde(default)]
    pub enabled: bool,
    #[serde(default = "default_reputation_decay_factor")]
    pub decay_factor: f64,
    #[serde(default = "default_reputation_weight")]
    pub weight: f64,
    #[serde(default = "default_reputation_min_observations")]
    pub min_observations: u64,
    #[serde(default)]
    pub state_path: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CascadeConfig {
    #[serde(default = "default_cascade_quality_threshold")]
    pub quality_threshold: f64,

    #[serde(default = "default_cascade_max_escalations")]
    pub max_escalations: u8,

    #[serde(default)]
    pub classifier_mode: CascadeClassifierMode,

    #[serde(default = "default_cascade_window_size")]
    pub window_size: usize,

    #[serde(default)]
    pub max_cascade_tokens: Option<u32>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cost_tiers: Option<Vec<String>>,
}

impl Default for CascadeConfig {
    fn default() -> Self {
        Self {
            quality_threshold: default_cascade_quality_threshold(),
            max_escalations: default_cascade_max_escalations(),
            classifier_mode: CascadeClassifierMode::default(),
            window_size: default_cascade_window_size(),
            max_cascade_tokens: None,
            cost_tiers: None,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum CascadeClassifierMode {
    #[default]
    Heuristic,
    Judge,
}

fn default_bandit_alpha() -> f32 {
    1.0
}

fn default_bandit_dim() -> usize {
    32
}

fn default_bandit_cost_weight() -> f32 {
    0.1
}

fn default_bandit_decay_factor() -> f32 {
    1.0
}

fn default_bandit_embedding_timeout_ms() -> u64 {
    50
}

fn default_bandit_cache_size() -> usize {
    512
}

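/// Configuration for the `bandit` router strategy (`[llm.router.bandit]`).
///
/// Illustrative example restating the defaults defined above:
///
/// ```toml
/// [llm.router.bandit]
/// alpha = 1.0
/// dim = 32
/// cost_weight = 0.1
/// decay_factor = 1.0
/// embedding_timeout_ms = 50
/// cache_size = 512
/// ```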
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct BanditConfig {
    #[serde(default = "default_bandit_alpha")]
    pub alpha: f32,

    #[serde(default = "default_bandit_dim")]
    pub dim: usize,

    #[serde(default = "default_bandit_cost_weight")]
    pub cost_weight: f32,

    #[serde(default = "default_bandit_decay_factor")]
    pub decay_factor: f32,

    #[serde(default)]
    pub embedding_provider: String,

    #[serde(default = "default_bandit_embedding_timeout_ms")]
    pub embedding_timeout_ms: u64,

    #[serde(default = "default_bandit_cache_size")]
    pub cache_size: usize,

    #[serde(default)]
    pub state_path: Option<String>,

    #[serde(default = "default_bandit_memory_confidence_threshold")]
    pub memory_confidence_threshold: f32,
}

fn default_bandit_memory_confidence_threshold() -> f32 {
    0.9
}

impl Default for BanditConfig {
    fn default() -> Self {
        Self {
            alpha: default_bandit_alpha(),
            dim: default_bandit_dim(),
            cost_weight: default_bandit_cost_weight(),
            decay_factor: default_bandit_decay_factor(),
            embedding_provider: String::new(),
            embedding_timeout_ms: default_bandit_embedding_timeout_ms(),
            cache_size: default_bandit_cache_size(),
            state_path: None,
            memory_confidence_threshold: default_bandit_memory_confidence_threshold(),
        }
    }
}

#[derive(Debug, Deserialize, Serialize)]
pub struct CandleConfig {
    #[serde(default = "default_candle_source")]
    pub source: String,
    #[serde(default)]
    pub local_path: String,
    #[serde(default)]
    pub filename: Option<String>,
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    #[serde(default = "default_candle_device")]
    pub device: String,
    #[serde(default)]
    pub embedding_repo: Option<String>,
    #[serde(default)]
    pub hf_token: Option<String>,
    #[serde(default)]
    pub generation: GenerationParams,
}

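/// Sampling parameters for the local Candle backend
/// (`[llm.candle.generation]` and the inline `candle.generation` table on a
/// provider entry).
///
/// `max_tokens` is clamped to [`MAX_TOKENS_CAP`] by
/// [`GenerationParams::capped_max_tokens`].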
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct GenerationParams {
    #[serde(default = "default_temperature")]
    pub temperature: f64,
    #[serde(default)]
    pub top_p: Option<f64>,
    #[serde(default)]
    pub top_k: Option<usize>,
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    #[serde(default = "default_seed")]
    pub seed: u64,
    #[serde(default = "default_repeat_penalty")]
    pub repeat_penalty: f32,
    #[serde(default = "default_repeat_last_n")]
    pub repeat_last_n: usize,
}

pub const MAX_TOKENS_CAP: usize = 32768;

impl GenerationParams {
    #[must_use]
    pub fn capped_max_tokens(&self) -> usize {
        self.max_tokens.min(MAX_TOKENS_CAP)
    }
}

impl Default for GenerationParams {
    fn default() -> Self {
        Self {
            temperature: default_temperature(),
            top_p: None,
            top_k: None,
            max_tokens: default_max_tokens(),
            seed: default_seed(),
            repeat_penalty: default_repeat_penalty(),
            repeat_last_n: default_repeat_last_n(),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum LlmRoutingStrategy {
    #[default]
    None,
    Ema,
    Thompson,
    Cascade,
    Task,
    Triage,
    Bandit,
}

fn default_triage_timeout_secs() -> u64 {
    5
}

fn default_max_triage_tokens() -> u32 {
    50
}

fn default_true() -> bool {
    true
}

#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct TierMapping {
    pub simple: Option<String>,
    pub medium: Option<String>,
    pub complex: Option<String>,
    pub expert: Option<String>,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ComplexityRoutingConfig {
    #[serde(default)]
    pub triage_provider: Option<String>,

    #[serde(default = "default_true")]
    pub bypass_single_provider: bool,

    #[serde(default)]
    pub tiers: TierMapping,

    #[serde(default = "default_max_triage_tokens")]
    pub max_triage_tokens: u32,

    #[serde(default = "default_triage_timeout_secs")]
    pub triage_timeout_secs: u64,

    #[serde(default)]
    pub fallback_strategy: Option<String>,
}

impl Default for ComplexityRoutingConfig {
    fn default() -> Self {
        Self {
            triage_provider: None,
            bypass_single_provider: true,
            tiers: TierMapping::default(),
            max_triage_tokens: default_max_triage_tokens(),
            triage_timeout_secs: default_triage_timeout_secs(),
            fallback_strategy: None,
        }
    }
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CandleInlineConfig {
    #[serde(default = "default_candle_source")]
    pub source: String,
    #[serde(default)]
    pub local_path: String,
    #[serde(default)]
    pub filename: Option<String>,
    #[serde(default = "default_chat_template")]
    pub chat_template: String,
    #[serde(default = "default_candle_device")]
    pub device: String,
    #[serde(default)]
    pub embedding_repo: Option<String>,
    #[serde(default)]
    pub hf_token: Option<String>,
    #[serde(default)]
    pub generation: GenerationParams,
}

impl Default for CandleInlineConfig {
    fn default() -> Self {
        Self {
            source: default_candle_source(),
            local_path: String::new(),
            filename: None,
            chat_template: default_chat_template(),
            device: default_candle_device(),
            embedding_repo: None,
            hf_token: None,
            generation: GenerationParams::default(),
        }
    }
}

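/// A single `[[llm.providers]]` entry.
///
/// Only `type` is required at the serde level; `name` falls back to the
/// provider type and `model` to a per-provider default (see
/// [`ProviderEntry::effective_name`] and [`ProviderEntry::effective_model`]).
/// Entries with `type = "compatible"` must also set `name`
/// (see [`ProviderEntry::validate`]).
///
/// Illustrative example:
///
/// ```toml
/// [[llm.providers]]
/// type = "claude"
/// name = "quality"
/// model = "claude-sonnet-4-6"
/// max_tokens = 8192
/// ```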
#[derive(Debug, Clone, Deserialize, Serialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ProviderEntry {
    #[serde(rename = "type")]
    pub provider_type: ProviderKind,

    #[serde(default)]
    pub name: Option<String>,

    #[serde(default)]
    pub model: Option<String>,

    #[serde(default)]
    pub base_url: Option<String>,

    #[serde(default)]
    pub max_tokens: Option<u32>,

    #[serde(default)]
    pub embedding_model: Option<String>,

    #[serde(default)]
    pub stt_model: Option<String>,

    #[serde(default)]
    pub embed: bool,

    #[serde(default)]
    pub default: bool,

    #[serde(default)]
    pub thinking: Option<ThinkingConfig>,
    #[serde(default)]
    pub server_compaction: bool,
    #[serde(default)]
    pub enable_extended_context: bool,

    #[serde(default)]
    pub reasoning_effort: Option<String>,

    #[serde(default)]
    pub thinking_level: Option<GeminiThinkingLevel>,
    #[serde(default)]
    pub thinking_budget: Option<i32>,
    #[serde(default)]
    pub include_thoughts: Option<bool>,

    #[serde(default)]
    pub tool_use: bool,

    #[serde(default)]
    pub api_key: Option<String>,

    #[serde(default)]
    pub candle: Option<CandleInlineConfig>,

    #[serde(default)]
    pub vision_model: Option<String>,

    #[serde(default)]
    pub instruction_file: Option<std::path::PathBuf>,
}

impl Default for ProviderEntry {
    fn default() -> Self {
        Self {
            provider_type: ProviderKind::Ollama,
            name: None,
            model: None,
            base_url: None,
            max_tokens: None,
            embedding_model: None,
            stt_model: None,
            embed: false,
            default: false,
            thinking: None,
            server_compaction: false,
            enable_extended_context: false,
            reasoning_effort: None,
            thinking_level: None,
            thinking_budget: None,
            include_thoughts: None,
            tool_use: false,
            api_key: None,
            candle: None,
            vision_model: None,
            instruction_file: None,
        }
    }
}

impl ProviderEntry {
    #[must_use]
    pub fn effective_name(&self) -> String {
        self.name
            .clone()
            .unwrap_or_else(|| self.provider_type.as_str().to_owned())
    }

    #[must_use]
    pub fn effective_model(&self) -> String {
        if let Some(ref m) = self.model {
            return m.clone();
        }
        match self.provider_type {
            ProviderKind::Ollama => "qwen3:8b".to_owned(),
            ProviderKind::Claude => "claude-haiku-4-5-20251001".to_owned(),
            ProviderKind::OpenAi => "gpt-4o-mini".to_owned(),
            ProviderKind::Gemini => "gemini-2.0-flash".to_owned(),
            ProviderKind::Compatible | ProviderKind::Candle => String::new(),
        }
    }

    pub fn validate(&self) -> Result<(), crate::error::ConfigError> {
        use crate::error::ConfigError;

        if self.provider_type == ProviderKind::Compatible && self.name.is_none() {
            return Err(ConfigError::Validation(
                "[[llm.providers]] entry with type=\"compatible\" must set `name`".into(),
            ));
        }

        match self.provider_type {
            ProviderKind::Ollama => {
                if self.thinking.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `thinking` is only used by Claude providers"
                    );
                }
                if self.reasoning_effort.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `reasoning_effort` is only used by OpenAI providers"
                    );
                }
                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
                    );
                }
            }
            ProviderKind::Claude => {
                if self.reasoning_effort.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `reasoning_effort` is only used by OpenAI providers"
                    );
                }
                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
                    );
                }
                if self.tool_use {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `tool_use` is only used by Ollama providers"
                    );
                }
            }
            ProviderKind::OpenAi => {
                if self.thinking.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `thinking` is only used by Claude providers"
                    );
                }
                if self.thinking_level.is_some() || self.thinking_budget.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "fields `thinking_level`/`thinking_budget` are only used by Gemini providers"
                    );
                }
                if self.tool_use {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `tool_use` is only used by Ollama providers"
                    );
                }
            }
            ProviderKind::Gemini => {
                if self.thinking.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `thinking` is only used by Claude providers"
                    );
                }
                if self.reasoning_effort.is_some() {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `reasoning_effort` is only used by OpenAI providers"
                    );
                }
                if self.tool_use {
                    tracing::warn!(
                        provider = self.effective_name(),
                        "field `tool_use` is only used by Ollama providers"
                    );
                }
            }
            _ => {}
        }

        if self.stt_model.is_some() && self.provider_type == ProviderKind::Ollama {
            tracing::warn!(
                provider = self.effective_name(),
                "field `stt_model` is set on an Ollama provider; Ollama does not support the \
                 Whisper STT API — use OpenAI, compatible, or candle instead"
            );
        }

        Ok(())
    }
}

pub fn validate_pool(entries: &[ProviderEntry]) -> Result<(), crate::error::ConfigError> {
    use crate::error::ConfigError;
    use std::collections::HashSet;

    if entries.is_empty() {
        return Err(ConfigError::Validation(
            "at least one LLM provider must be configured in [[llm.providers]]".into(),
        ));
    }

    let default_count = entries.iter().filter(|e| e.default).count();
    if default_count > 1 {
        return Err(ConfigError::Validation(
            "only one [[llm.providers]] entry can be marked `default = true`".into(),
        ));
    }

    let mut seen_names: HashSet<String> = HashSet::new();
    for entry in entries {
        let name = entry.effective_name();
        if !seen_names.insert(name.clone()) {
            return Err(ConfigError::Validation(format!(
                "duplicate provider name \"{name}\" in [[llm.providers]]"
            )));
        }
        entry.validate()?;
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    fn ollama_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Ollama,
            name: Some("ollama".into()),
            model: Some("qwen3:8b".into()),
            ..Default::default()
        }
    }

    fn claude_entry() -> ProviderEntry {
        ProviderEntry {
            provider_type: ProviderKind::Claude,
            name: Some("claude".into()),
            model: Some("claude-sonnet-4-6".into()),
            max_tokens: Some(8192),
            ..Default::default()
        }
    }

    #[test]
    fn validate_ollama_valid() {
        assert!(ollama_entry().validate().is_ok());
    }

    #[test]
    fn validate_claude_valid() {
        assert!(claude_entry().validate().is_ok());
    }

    #[test]
    fn validate_compatible_without_name_errors() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        let err = entry.validate().unwrap_err();
        assert!(
            err.to_string().contains("compatible"),
            "error should mention compatible: {err}"
        );
    }

    #[test]
    fn validate_compatible_with_name_ok() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: Some("my-proxy".into()),
            base_url: Some("http://localhost:8080".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_openai_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            name: Some("openai".into()),
            model: Some("gpt-4o".into()),
            max_tokens: Some(4096),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }

    #[test]
    fn validate_gemini_valid() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            name: Some("gemini".into()),
            model: Some("gemini-2.0-flash".into()),
            ..Default::default()
        };
        assert!(entry.validate().is_ok());
    }
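
    // Sanity checks: `Display` mirrors `as_str`, and the `Default` impls mirror
    // the corresponding `default_*` helpers declared at the top of this module.
    #[test]
    fn provider_kind_display_matches_as_str() {
        for kind in [
            ProviderKind::Ollama,
            ProviderKind::Claude,
            ProviderKind::OpenAi,
            ProviderKind::Gemini,
            ProviderKind::Candle,
            ProviderKind::Compatible,
        ] {
            assert_eq!(kind.to_string(), kind.as_str());
        }
    }

    #[test]
    fn cascade_and_bandit_defaults_match_helper_values() {
        let cascade = CascadeConfig::default();
        assert!((cascade.quality_threshold - 0.5).abs() < f64::EPSILON);
        assert_eq!(cascade.max_escalations, 2);
        assert_eq!(cascade.window_size, 50);

        let bandit = BanditConfig::default();
        assert!((bandit.alpha - 1.0).abs() < f32::EPSILON);
        assert_eq!(bandit.dim, 32);
        assert_eq!(bandit.cache_size, 512);
    }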

    #[test]
    fn validate_pool_empty_errors() {
        let err = validate_pool(&[]).unwrap_err();
        assert!(err.to_string().contains("at least one"), "{err}");
    }

    #[test]
    fn validate_pool_single_entry_ok() {
        assert!(validate_pool(&[ollama_entry()]).is_ok());
    }

    #[test]
    fn validate_pool_duplicate_names_errors() {
        let a = ollama_entry();
        let b = ollama_entry();
        let err = validate_pool(&[a, b]).unwrap_err();
        assert!(err.to_string().contains("duplicate"), "{err}");
    }

    #[test]
    fn validate_pool_multiple_defaults_errors() {
        let mut a = ollama_entry();
        let mut b = claude_entry();
        a.default = true;
        b.default = true;
        let err = validate_pool(&[a, b]).unwrap_err();
        assert!(err.to_string().contains("default"), "{err}");
    }

    #[test]
    fn validate_pool_two_different_providers_ok() {
        assert!(validate_pool(&[ollama_entry(), claude_entry()]).is_ok());
    }

    #[test]
    fn validate_pool_propagates_entry_error() {
        let bad = ProviderEntry {
            provider_type: ProviderKind::Compatible,
            name: None,
            ..Default::default()
        };
        assert!(validate_pool(&[bad]).is_err());
    }

    #[test]
    fn effective_model_returns_explicit_when_set() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: Some("claude-sonnet-4-6".into()),
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-sonnet-4-6");
    }

    #[test]
    fn effective_model_ollama_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Ollama,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_model_claude_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Claude,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "claude-haiku-4-5-20251001");
    }

    #[test]
    fn effective_model_openai_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::OpenAi,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gpt-4o-mini");
    }

    #[test]
    fn effective_model_gemini_default_when_none() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            model: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_model(), "gemini-2.0-flash");
    }
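
    // Companion checks: `effective_name` falls back to the provider type, and
    // `capped_max_tokens` never exceeds `MAX_TOKENS_CAP`.
    #[test]
    fn effective_name_falls_back_to_provider_type() {
        let entry = ProviderEntry {
            provider_type: ProviderKind::Gemini,
            name: None,
            ..Default::default()
        };
        assert_eq!(entry.effective_name(), "gemini");
    }

    #[test]
    fn capped_max_tokens_clamps_to_cap() {
        let params = GenerationParams {
            max_tokens: MAX_TOKENS_CAP * 2,
            ..GenerationParams::default()
        };
        assert_eq!(params.capped_max_tokens(), MAX_TOKENS_CAP);
        assert_eq!(GenerationParams::default().capped_max_tokens(), 2048);
    }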

    fn parse_llm(toml: &str) -> LlmConfig {
        #[derive(serde::Deserialize)]
        struct Wrapper {
            llm: LlmConfig,
        }
        toml::from_str::<Wrapper>(toml).unwrap().llm
    }

    #[test]
    fn check_legacy_format_new_format_ok() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn check_legacy_format_empty_providers_no_legacy_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.check_legacy_format().is_ok());
    }

    #[test]
    fn effective_provider_falls_back_to_ollama_when_no_providers() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_provider(), ProviderKind::Ollama);
    }

    #[test]
    fn effective_provider_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "claude"
model = "claude-sonnet-4-6"
"#,
        );
        assert_eq!(cfg.effective_provider(), ProviderKind::Claude);
    }

    #[test]
    fn effective_model_reads_from_providers_first() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
model = "qwen3:8b"
"#,
        );
        assert_eq!(cfg.effective_model(), "qwen3:8b");
    }

    #[test]
    fn effective_base_url_default_when_absent() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.effective_base_url(), "http://localhost:11434");
    }

    #[test]
    fn effective_base_url_from_providers_entry() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "ollama"
base_url = "http://myhost:11434"
"#,
        );
        assert_eq!(cfg.effective_base_url(), "http://myhost:11434");
    }
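
    // Illustrative check: a bare `[llm]` table falls back to the defaults
    // declared by the `default_*` helpers at the top of this module.
    #[test]
    fn llm_section_defaults_applied_when_empty() {
        let cfg = parse_llm("[llm]\n");
        assert_eq!(cfg.embedding_model, "qwen3-embedding");
        assert!(!cfg.response_cache_enabled);
        assert_eq!(cfg.response_cache_ttl_secs, 3600);
        assert!((cfg.semantic_cache_threshold - 0.95).abs() < f32::EPSILON);
        assert_eq!(cfg.semantic_cache_max_candidates, 10);
        assert!((cfg.router_ema_alpha - 0.1).abs() < f64::EPSILON);
        assert_eq!(cfg.router_reorder_interval, 10);
        assert!(matches!(cfg.routing, LlmRoutingStrategy::None));
        assert!(cfg.router.is_none());
        assert!(cfg.complexity_routing.is_none());
    }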

    #[test]
    fn complexity_routing_defaults() {
        let cr = ComplexityRoutingConfig::default();
        assert!(
            cr.bypass_single_provider,
            "bypass_single_provider must default to true"
        );
        assert_eq!(cr.triage_timeout_secs, 5);
        assert_eq!(cr.max_triage_tokens, 50);
        assert!(cr.triage_provider.is_none());
        assert!(cr.tiers.simple.is_none());
    }

    #[test]
    fn complexity_routing_toml_round_trip() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing]
triage_provider = "fast"
bypass_single_provider = false
triage_timeout_secs = 10
max_triage_tokens = 100

[llm.complexity_routing.tiers]
simple = "fast"
medium = "medium"
complex = "large"
expert = "opus"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.triage_provider.as_deref(), Some("fast"));
        assert!(!cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 10);
        assert_eq!(cr.max_triage_tokens, 100);
        assert_eq!(cr.tiers.simple.as_deref(), Some("fast"));
        assert_eq!(cr.tiers.medium.as_deref(), Some("medium"));
        assert_eq!(cr.tiers.complex.as_deref(), Some("large"));
        assert_eq!(cr.tiers.expert.as_deref(), Some("opus"));
    }

    #[test]
    fn complexity_routing_partial_tiers_toml() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"

[llm.complexity_routing.tiers]
simple = "haiku"
complex = "sonnet"
"#,
        );
        let cr = cfg
            .complexity_routing
            .expect("complexity_routing must be present");
        assert_eq!(cr.tiers.simple.as_deref(), Some("haiku"));
        assert!(cr.tiers.medium.is_none());
        assert_eq!(cr.tiers.complex.as_deref(), Some("sonnet"));
        assert!(cr.tiers.expert.is_none());
        assert!(cr.bypass_single_provider);
        assert_eq!(cr.triage_timeout_secs, 5);
    }

    #[test]
    fn routing_strategy_triage_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]
routing = "triage"
"#,
        );
        assert!(matches!(cfg.routing, LlmRoutingStrategy::Triage));
    }
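
    // Illustrative check for the `[llm.router]` table, mirroring the
    // `RouterConfig`/`CascadeConfig` field names defined above.
    #[test]
    fn router_cascade_strategy_deserialized() {
        let cfg = parse_llm(
            r#"
[llm]

[llm.router]
strategy = "cascade"

[llm.router.cascade]
quality_threshold = 0.8
max_escalations = 3
"#,
        );
        let router = cfg.router.expect("router section must be present");
        assert_eq!(router.strategy, RouterStrategyConfig::Cascade);
        let cascade = router.cascade.expect("cascade section must be present");
        assert!((cascade.quality_threshold - 0.8).abs() < f64::EPSILON);
        assert_eq!(cascade.max_escalations, 3);
        assert_eq!(cascade.window_size, 50);
        assert_eq!(cascade.classifier_mode, CascadeClassifierMode::Heuristic);
    }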

    #[test]
    fn stt_provider_entry_by_name_match() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
stt_model = "gpt-4o-mini-transcribe"

[llm.stt]
provider = "quality"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should find stt provider");
        assert_eq!(entry.effective_name(), "quality");
        assert_eq!(entry.stt_model.as_deref(), Some("gpt-4o-mini-transcribe"));
    }

    #[test]
    fn stt_provider_entry_auto_detect_when_provider_empty() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"

[llm.stt]
provider = ""
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_auto_detect_no_stt_section() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"
"#,
        );
        let entry = cfg.stt_provider_entry().expect("should auto-detect");
        assert_eq!(entry.effective_name(), "openai-stt");
    }

    #[test]
    fn stt_provider_entry_none_when_no_stt_model() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }

    #[test]
    fn stt_provider_entry_name_mismatch_falls_back_to_none() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"

[[llm.providers]]
type = "openai"
name = "openai-stt"
stt_model = "whisper-1"

[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.stt_provider_entry().is_none());
    }

    #[test]
    fn stt_config_deserializes_new_slim_format() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
stt_model = "whisper-1"

[llm.stt]
provider = "quality"
language = "en"
"#,
        );
        let stt = cfg.stt.as_ref().expect("stt section present");
        assert_eq!(stt.provider, "quality");
        assert_eq!(stt.language, "en");
    }

    #[test]
    fn stt_config_default_provider_is_empty() {
        assert_eq!(default_stt_provider(), "");
    }

    #[test]
    fn validate_stt_missing_provider_ok() {
        let cfg = parse_llm("[llm]\n");
        assert!(cfg.validate_stt().is_ok());
    }

    #[test]
    fn validate_stt_valid_reference() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
stt_model = "whisper-1"

[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.validate_stt().is_ok());
    }

    #[test]
    fn validate_stt_nonexistent_provider_errors() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"

[llm.stt]
provider = "nonexistent"
"#,
        );
        assert!(cfg.validate_stt().is_err());
    }

    #[test]
    fn validate_stt_provider_exists_but_no_stt_model_returns_ok_with_warn() {
        let cfg = parse_llm(
            r#"
[llm]

[[llm.providers]]
type = "openai"
name = "quality"
model = "gpt-5.4"

[llm.stt]
provider = "quality"
"#,
        );
        assert!(cfg.validate_stt().is_ok());
        assert!(
            cfg.stt_provider_entry().is_none(),
            "stt_provider_entry must be None when provider has no stt_model"
        );
    }
}