1use super::{CostTier, ModelPolicy};
17use std::time::Duration;
18
/// Routes tasks to a cost tier (and thus a concrete model) based on
/// estimated task complexity and the configured [`ModelPolicy`].
#[derive(Debug, Clone)]
pub struct ModelRouter {
    /// Tier-to-model mapping plus cost/latency constraints.
    policy: ModelPolicy,
    /// Tasks whose estimated tokens exceed this score toward the large tier.
    large_model_threshold: usize,
    /// Tasks whose estimated tokens exceed this score toward the medium tier.
    medium_model_threshold: usize,
    /// When false, `route` short-circuits and always picks the medium model.
    enabled: bool,
}
34
impl Default for ModelRouter {
    /// Equivalent to [`ModelRouter::new`].
    fn default() -> Self {
        Self::new()
    }
}
40
41impl ModelRouter {
42 pub fn new() -> Self {
44 Self {
45 policy: ModelPolicy::default(),
46 large_model_threshold: 4000,
47 medium_model_threshold: 1000,
48 enabled: true,
49 }
50 }
51
52 pub fn with_policy(policy: ModelPolicy) -> Self {
54 Self {
55 policy,
56 ..Default::default()
57 }
58 }
59
60 pub fn with_thresholds(mut self, medium: usize, large: usize) -> Self {
62 self.medium_model_threshold = medium;
63 self.large_model_threshold = large;
64 self
65 }
66
67 pub fn policy(&self) -> &ModelPolicy {
69 &self.policy
70 }
71
72 pub fn enabled(mut self, enabled: bool) -> Self {
74 self.enabled = enabled;
75 self
76 }
77
78 pub fn route(&self, task: &TaskAnalysis) -> RoutingDecision {
82 if !self.enabled {
83 return RoutingDecision {
84 model: self.policy.medium.clone(),
85 tier: CostTier::Medium,
86 reason: "Smart routing disabled".to_string(),
87 };
88 }
89
90 let tier = self.analyze_complexity(task);
92
93 let tier = self.apply_constraints(tier, task);
95
96 let model = self.policy.model_for_tier(tier).to_string();
97 let reason = self.explain_routing(task, tier);
98
99 RoutingDecision {
100 model,
101 tier,
102 reason,
103 }
104 }
105
106 fn analyze_complexity(&self, task: &TaskAnalysis) -> CostTier {
108 let mut score = 0;
109
110 if task.estimated_tokens > self.large_model_threshold {
112 score += 3;
113 } else if task.estimated_tokens > self.medium_model_threshold {
114 score += 2;
115 } else {
116 score += 1;
117 }
118
119 if task.requires_reasoning {
121 score += 2;
122 }
123 if task.requires_code_generation {
124 score += 2;
125 }
126 if task.requires_structured_output {
127 score += 1;
128 }
129 if task.multi_step {
130 score += 1;
131 }
132
133 match score {
135 0..=2 => CostTier::Low,
136 3..=5 => CostTier::Medium,
137 _ => CostTier::High,
138 }
139 }
140
141 fn apply_constraints(&self, tier: CostTier, task: &TaskAnalysis) -> CostTier {
143 let tier = match (tier, self.policy.max_cost_tier) {
145 (CostTier::High, CostTier::Low) => CostTier::Low,
146 (CostTier::High, CostTier::Medium) => CostTier::Medium,
147 (CostTier::Medium, CostTier::Low) => CostTier::Low,
148 _ => tier,
149 };
150
151 if let Some(max_latency) = self.policy.max_latency_ms {
153 let estimated_latency = self.estimate_latency(tier, task);
154 if estimated_latency > max_latency {
155 return match tier {
157 CostTier::High => CostTier::Medium,
158 CostTier::Medium => CostTier::Low,
159 CostTier::Low => CostTier::Low,
160 };
161 }
162 }
163
164 if tier == CostTier::High && !self.policy.allow_large {
166 return CostTier::Medium;
167 }
168
169 tier
170 }
171
172 fn estimate_latency(&self, tier: CostTier, task: &TaskAnalysis) -> u64 {
174 let base_latency = match tier {
176 CostTier::Low => 500,
177 CostTier::Medium => 1500,
178 CostTier::High => 3000,
179 };
180
181 let token_latency = (task.estimated_tokens as u64 / 100) * 50;
183
184 base_latency + token_latency
185 }
186
187 fn explain_routing(&self, task: &TaskAnalysis, tier: CostTier) -> String {
189 let mut reasons = Vec::new();
190
191 if task.estimated_tokens > self.large_model_threshold {
192 reasons.push("high token count");
193 }
194 if task.requires_reasoning {
195 reasons.push("requires reasoning");
196 }
197 if task.requires_code_generation {
198 reasons.push("requires code generation");
199 }
200
201 if reasons.is_empty() {
202 reasons.push("standard task");
203 }
204
205 format!("{:?} tier selected: {}", tier, reasons.join(", "))
206 }
207
208 pub fn route_simple(&self, prompt: &str) -> RoutingDecision {
210 let task = TaskAnalysis::from_prompt(prompt);
211 self.route(&task)
212 }
213}
214
/// Hard capability/cost filters a candidate model must satisfy.
///
/// Numeric fields use 0 / 0.0 as "no constraint".
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelRequirements {
    /// Require image-input support.
    pub vision: bool,
    /// Require audio-input support.
    pub audio: bool,
    /// Require video-input support.
    pub video: bool,
    /// Require PDF/file-input support.
    pub pdf: bool,
    /// Minimum context window in tokens (0 = no minimum).
    pub min_context_tokens: u32,
    /// Maximum input cost per million tokens (0.0 = no cap).
    pub max_input_cost_per_m: f32,
    /// Minimum arena quality score (0.0 = no minimum).
    pub min_arena_score: f32,
}
235
236impl ModelRequirements {
237 pub fn with_vision(mut self) -> Self {
239 self.vision = true;
240 self
241 }
242
243 pub fn with_min_context(mut self, tokens: u32) -> Self {
245 self.min_context_tokens = tokens;
246 self
247 }
248
249 pub fn with_max_cost(mut self, cost: f32) -> Self {
251 self.max_input_cost_per_m = cost;
252 self
253 }
254
255 pub fn with_min_arena(mut self, score: f32) -> Self {
257 self.min_arena_score = score;
258 self
259 }
260}
261
/// A candidate model together with the score used to rank it.
#[derive(Debug, Clone)]
pub struct ScoredModel {
    /// Model identifier (lowercased by `ModelSelector`).
    pub name: String,
    /// Ranking score; higher wins. Either a custom priority or a
    /// strategy-derived auto score.
    pub score: f32,
    /// Arena quality score, when known.
    pub arena_rank: Option<f32>,
    /// Input cost per million tokens, when known.
    pub input_cost_per_m: Option<f32>,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Maximum input context in tokens (0 when unknown).
    pub max_input_tokens: u32,
}
278
/// How `ModelSelector` scores candidates that pass the requirement filters.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SelectionStrategy {
    /// Highest arena quality score wins (default).
    #[default]
    BestQuality,
    /// Lowest input cost wins.
    CheapestFirst,
    /// Largest context window wins.
    LargestContext,
    /// Quality weighted by (square-rooted) cost efficiency.
    ValueOptimal,
}
292
/// Picks the best model from a fixed pool according to requirement
/// filters and a [`SelectionStrategy`].
#[derive(Debug, Clone)]
pub struct ModelSelector {
    /// Pool of (lowercased model name, optional custom priority) pairs.
    models: Vec<(String, Option<f32>)>,
    /// Scoring strategy used when a model has no custom priority.
    strategy: SelectionStrategy,
}
322
323impl ModelSelector {
324 pub fn new(models: &[&str]) -> Self {
326 Self {
327 models: models.iter().map(|m| (m.to_lowercase(), None)).collect(),
328 strategy: SelectionStrategy::default(),
329 }
330 }
331
332 pub fn from_owned(models: Vec<String>) -> Self {
334 Self {
335 models: models
336 .into_iter()
337 .map(|m| (m.to_lowercase(), None))
338 .collect(),
339 strategy: SelectionStrategy::default(),
340 }
341 }
342
343 pub fn set_strategy(&mut self, strategy: SelectionStrategy) {
345 self.strategy = strategy;
346 }
347
348 pub fn set_priority(&mut self, model: &str, priority: f32) {
353 let lower = model.to_lowercase();
354 for (name, prio) in &mut self.models {
355 if *name == lower {
356 *prio = Some(priority);
357 return;
358 }
359 }
360 self.models.push((lower, Some(priority)));
362 }
363
364 pub fn add_model(&mut self, model: &str) {
366 let lower = model.to_lowercase();
367 if !self.models.iter().any(|(n, _)| *n == lower) {
368 self.models.push((lower, None));
369 }
370 }
371
372 pub fn select(&self, reqs: &ModelRequirements) -> Option<ScoredModel> {
380 if self.models.len() <= 2 {
381 return self
384 .models
385 .iter()
386 .filter_map(|(name, custom_prio)| self.score_model(name, *custom_prio, reqs))
387 .next();
388 }
389 self.ranked(reqs).into_iter().next()
390 }
391
392 pub fn ranked(&self, reqs: &ModelRequirements) -> Vec<ScoredModel> {
397 let mut candidates: Vec<ScoredModel> = self
398 .models
399 .iter()
400 .filter_map(|(name, custom_prio)| self.score_model(name, *custom_prio, reqs))
401 .collect();
402
403 if candidates.len() > 2 {
405 candidates.sort_by(|a, b| {
406 b.score
407 .partial_cmp(&a.score)
408 .unwrap_or(std::cmp::Ordering::Equal)
409 });
410 }
411 candidates
412 }
413
414 pub fn select_multi(&self, requirements: &[ModelRequirements]) -> Vec<Option<ScoredModel>> {
420 let mut used: Vec<bool> = vec![false; self.models.len()];
421 let mut results = Vec::with_capacity(requirements.len());
422
423 for reqs in requirements {
424 let mut best: Option<(ScoredModel, usize)> = None;
425
426 for (idx, (name, custom_prio)) in self.models.iter().enumerate() {
427 if used[idx] {
428 continue;
429 }
430 if let Some(scored) = self.score_model(name, *custom_prio, reqs) {
431 let dominated = match &best {
432 Some((current, _)) => scored.score > current.score,
433 None => true,
434 };
435 if dominated {
436 best = Some((scored, idx));
437 }
438 }
439 }
440
441 if let Some((model, idx)) = best {
442 used[idx] = true;
443 results.push(Some(model));
444 } else {
445 let fallback = self.select(reqs);
447 results.push(fallback);
448 }
449 }
450
451 results
452 }
453
454 fn score_model(
456 &self,
457 name: &str,
458 custom_prio: Option<f32>,
459 reqs: &ModelRequirements,
460 ) -> Option<ScoredModel> {
461 let profile = llm_models_spider::model_profile(name);
462
463 let has_vision = profile
465 .as_ref()
466 .map(|p| p.capabilities.vision)
467 .unwrap_or_else(|| llm_models_spider::supports_vision(name));
468 let has_audio = profile
469 .as_ref()
470 .map(|p| p.capabilities.audio)
471 .unwrap_or(false);
472 let has_video = profile
473 .as_ref()
474 .map(|p| p.capabilities.video)
475 .unwrap_or_else(|| llm_models_spider::supports_video(name));
476 let has_pdf = profile
477 .as_ref()
478 .map(|p| p.capabilities.file)
479 .unwrap_or_else(|| llm_models_spider::supports_pdf(name));
480
481 if reqs.vision && !has_vision {
483 return None;
484 }
485 if reqs.audio && !has_audio {
486 return None;
487 }
488 if reqs.video && !has_video {
489 return None;
490 }
491 if reqs.pdf && !has_pdf {
492 return None;
493 }
494
495 let max_input = profile.as_ref().map(|p| p.max_input_tokens).unwrap_or(0);
496 if reqs.min_context_tokens > 0 && max_input < reqs.min_context_tokens {
497 return None;
498 }
499
500 let arena = profile.as_ref().and_then(|p| p.ranks.overall);
501 let input_cost = profile
502 .as_ref()
503 .and_then(|p| p.pricing.input_cost_per_m_tokens);
504
505 if reqs.max_input_cost_per_m > 0.0 {
506 if let Some(cost) = input_cost {
507 if cost > reqs.max_input_cost_per_m {
508 return None;
509 }
510 }
511 }
512
513 if reqs.min_arena_score > 0.0 {
514 match arena {
515 Some(score) if score >= reqs.min_arena_score => {}
516 Some(_) => return None,
517 None => {} }
519 }
520
521 let score = if let Some(prio) = custom_prio {
523 prio
524 } else {
525 self.auto_score(arena, input_cost, max_input)
526 };
527
528 Some(ScoredModel {
529 name: name.to_string(),
530 score,
531 arena_rank: arena,
532 input_cost_per_m: input_cost,
533 supports_vision: has_vision,
534 max_input_tokens: max_input,
535 })
536 }
537
538 fn auto_score(&self, arena: Option<f32>, cost: Option<f32>, context: u32) -> f32 {
540 match self.strategy {
541 SelectionStrategy::BestQuality => arena.unwrap_or(50.0),
542 SelectionStrategy::CheapestFirst => {
543 match cost {
545 Some(c) if c > 0.0 => 1000.0 / c,
546 _ => 50.0, }
548 }
549 SelectionStrategy::LargestContext => context as f32 / 1000.0,
550 SelectionStrategy::ValueOptimal => {
551 let quality = arena.unwrap_or(50.0);
552 let cost_factor = match cost {
553 Some(c) if c > 0.0 => 100.0 / c,
554 _ => 1.0,
555 };
556 quality * cost_factor.sqrt()
557 }
558 }
559 }
560}
561
562pub fn auto_policy(available_models: &[&str]) -> ModelPolicy {
574 if available_models.is_empty() {
575 return ModelPolicy::default();
576 }
577 if available_models.len() == 1 {
578 let m = available_models[0].to_string();
579 return ModelPolicy {
580 small: m.clone(),
581 medium: m.clone(),
582 large: m,
583 allow_large: true,
584 max_latency_ms: None,
585 max_cost_tier: CostTier::High,
586 };
587 }
588 if available_models.len() == 2 {
589 let a = available_models[0].to_string();
593 let b = available_models[1].to_string();
594 return ModelPolicy {
595 large: a.clone(),
596 medium: a,
597 small: b,
598 allow_large: true,
599 max_latency_ms: None,
600 max_cost_tier: CostTier::High,
601 };
602 }
603
604 let mut models: Vec<(&str, f32, f32)> = available_models
607 .iter()
608 .map(|&name| {
609 let profile = llm_models_spider::model_profile(name);
610 let arena = profile
611 .as_ref()
612 .and_then(|p| p.ranks.overall)
613 .unwrap_or(50.0);
614 let cost = profile
615 .as_ref()
616 .and_then(|p| p.pricing.input_cost_per_m_tokens)
617 .unwrap_or(5.0);
618 (name, arena, cost)
619 })
620 .collect();
621
622 models.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
624
625 if models.is_empty() {
628 return ModelPolicy::default();
629 }
630 let large = models[0].0.to_string();
631 let small = models
632 .last()
633 .map(|m| m.0.to_string())
634 .unwrap_or_else(|| large.clone());
635 let medium = if models.len() >= 3 {
636 models[models.len() / 2].0.to_string()
637 } else {
638 large.clone()
639 };
640
641 ModelPolicy {
642 small,
643 medium,
644 large,
645 allow_large: true,
646 max_latency_ms: None,
647 max_cost_tier: CostTier::High,
648 }
649}
650
/// Features of a task that drive tier routing and model requirements.
#[derive(Debug, Clone, Default)]
pub struct TaskAnalysis {
    /// Rough token estimate for the task input.
    pub estimated_tokens: usize,
    /// Task appears to need analysis/explanation.
    pub requires_reasoning: bool,
    /// Task appears to need code generation.
    pub requires_code_generation: bool,
    /// Task appears to need structured (e.g. JSON/list) output.
    pub requires_structured_output: bool,
    /// Task spans multiple steps.
    pub multi_step: bool,
    /// Optional latency budget. NOTE(review): not consulted by the routing
    /// logic in this file (`ModelRouter` uses `policy.max_latency_ms`).
    pub max_latency: Option<Duration>,
    /// Broad task category.
    pub category: TaskCategory,
    /// Task needs image input.
    pub requires_vision: bool,
    /// Task needs audio input.
    pub requires_audio: bool,
}
675
676impl TaskAnalysis {
677 pub fn from_prompt(prompt: &str) -> Self {
679 let estimated_tokens = estimate_tokens(prompt);
680 let lower = prompt.to_lowercase();
681
682 Self {
683 estimated_tokens,
684 requires_reasoning: lower.contains("analyze")
685 || lower.contains("compare")
686 || lower.contains("explain")
687 || lower.contains("why"),
688 requires_code_generation: lower.contains("code")
689 || lower.contains("implement")
690 || lower.contains("function")
691 || lower.contains("script"),
692 requires_structured_output: lower.contains("json")
693 || lower.contains("extract")
694 || lower.contains("list"),
695 multi_step: lower.contains("then")
696 || lower.contains("step")
697 || lower.contains("first")
698 || lower.contains("next"),
699 max_latency: None,
700 category: TaskCategory::General,
701 requires_vision: lower.contains("screenshot")
702 || lower.contains("image")
703 || lower.contains("picture")
704 || lower.contains("visual"),
705 requires_audio: lower.contains("audio")
706 || lower.contains("voice")
707 || lower.contains("speech"),
708 }
709 }
710
711 pub fn extraction(html_length: usize) -> Self {
713 Self {
714 estimated_tokens: html_length / 4 + 200, requires_reasoning: false,
716 requires_code_generation: false,
717 requires_structured_output: true,
718 multi_step: false,
719 max_latency: None,
720 category: TaskCategory::Extraction,
721 requires_vision: false,
722 requires_audio: false,
723 }
724 }
725
726 pub fn action(instruction: &str) -> Self {
728 let mut analysis = Self::from_prompt(instruction);
729 analysis.category = TaskCategory::Action;
730 analysis.requires_structured_output = true;
731 analysis
732 }
733
734 pub fn with_max_latency(mut self, latency: Duration) -> Self {
736 self.max_latency = Some(latency);
737 self
738 }
739
740 pub fn to_requirements(&self) -> ModelRequirements {
742 ModelRequirements {
743 vision: self.requires_vision,
744 audio: self.requires_audio,
745 ..Default::default()
746 }
747 }
748}
749
/// Broad classification of a task, used to tag [`TaskAnalysis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TaskCategory {
    /// Anything without a more specific category.
    #[default]
    General,
    /// Structured data extraction (e.g. from HTML).
    Extraction,
    /// Action/instruction execution.
    Action,
    /// Code-related work.
    Code,
    /// Analysis work.
    Analysis,
    /// Classification/labeling work.
    Classification,
}
767
/// Outcome of routing a task: which model, which tier, and why.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Concrete model name to use.
    pub model: String,
    /// Cost tier the task was routed to.
    pub tier: CostTier,
    /// Human-readable explanation of the routing choice.
    pub reason: String,
}
778
779impl RoutingDecision {
780 pub fn is_fast(&self) -> bool {
782 self.tier == CostTier::Low
783 }
784
785 pub fn is_powerful(&self) -> bool {
787 self.tier == CostTier::High
788 }
789}
790
/// Rough token estimate: one token per four bytes of text, plus one.
///
/// This is a cheap heuristic, not a real tokenizer; it uses byte length,
/// so multi-byte characters weigh proportionally more.
pub fn estimate_tokens(text: &str) -> usize {
    let approx = text.len() / 4;
    approx + 1
}
799
800pub fn estimate_message_tokens(messages: &[crate::Message]) -> usize {
802 messages
803 .iter()
804 .map(|m| estimate_tokens(m.content.as_text()) + 4) .sum()
806}
807
808pub fn classify_round_complexity(
813 user_prompt: &str,
814 html_len: usize,
815 round_idx: usize,
816 stagnated: bool,
817) -> TaskAnalysis {
818 let mut analysis = TaskAnalysis::from_prompt(user_prompt);
819 analysis.estimated_tokens = user_prompt.len() / 4 + 1;
820 if round_idx == 0 {
822 analysis.requires_reasoning = true;
823 }
824 if stagnated {
826 analysis.requires_reasoning = true;
827 analysis.multi_step = true;
828 }
829 if html_len > 50_000 {
831 analysis.requires_reasoning = true;
832 }
833 analysis.requires_structured_output = true;
835 analysis
836}
837
838#[cfg(test)]
839mod tests {
840 use super::*;
841
842 #[test]
843 fn test_model_router_simple() {
844 let router = ModelRouter::new();
845
846 let decision = router.route_simple("Extract the title from this page");
847 assert!(!decision.model.is_empty());
848 }
849
850 #[test]
851 fn test_model_router_complex() {
852 let router = ModelRouter::new();
853
854 let task = TaskAnalysis {
855 estimated_tokens: 5000,
856 requires_reasoning: true,
857 requires_code_generation: true,
858 ..Default::default()
859 };
860
861 let decision = router.route(&task);
862 assert_eq!(decision.tier, CostTier::High);
863 }
864
865 #[test]
866 fn test_model_router_constrained() {
867 let policy = ModelPolicy {
868 max_cost_tier: CostTier::Medium,
869 ..Default::default()
870 };
871
872 let router = ModelRouter::with_policy(policy);
873
874 let task = TaskAnalysis {
875 estimated_tokens: 5000,
876 requires_reasoning: true,
877 ..Default::default()
878 };
879
880 let decision = router.route(&task);
881 assert!(decision.tier != CostTier::High);
883 }
884
885 #[test]
886 fn test_task_analysis_from_prompt() {
887 let analysis = TaskAnalysis::from_prompt(
888 "Analyze the code and explain why it's slow, then implement a fix",
889 );
890
891 assert!(analysis.requires_reasoning);
892 assert!(analysis.requires_code_generation);
893 assert!(analysis.multi_step);
894 }
895
896 #[test]
897 fn test_task_analysis_vision_detection() {
898 let analysis = TaskAnalysis::from_prompt("Look at this screenshot and describe it");
899 assert!(analysis.requires_vision);
900
901 let analysis = TaskAnalysis::from_prompt("Summarize this text");
902 assert!(!analysis.requires_vision);
903 }
904
905 #[test]
906 fn test_estimate_tokens() {
907 assert_eq!(estimate_tokens("hello world"), 3); assert_eq!(estimate_tokens(""), 1);
909 }
910
911 #[test]
914 fn test_model_selector_basic() {
915 let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
916 let reqs = ModelRequirements::default();
917 let result = selector.select(&reqs);
918 assert!(result.is_some());
919 }
920
921 #[test]
922 fn test_model_selector_vision_filter() {
923 let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo"]);
924 let reqs = ModelRequirements::default().with_vision();
925 let ranked = selector.ranked(&reqs);
926
927 assert!(!ranked.is_empty());
929 for m in &ranked {
930 assert!(
931 m.supports_vision,
932 "non-vision model {} passed filter",
933 m.name
934 );
935 }
936 }
937
938 #[test]
939 fn test_model_selector_custom_priority() {
940 let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
941 selector.set_priority("gpt-4o-mini", 999.0);
943
944 let reqs = ModelRequirements::default();
945 let best = selector.select(&reqs).unwrap();
946 assert_eq!(best.name, "gpt-4o-mini");
947 assert_eq!(best.score, 999.0);
948 }
949
950 #[test]
951 fn test_model_selector_cheapest_strategy() {
952 let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini"]);
953 selector.set_strategy(SelectionStrategy::CheapestFirst);
954
955 let reqs = ModelRequirements::default();
956 let ranked = selector.ranked(&reqs);
957
958 assert!(ranked.len() >= 1);
960 }
961
962 #[test]
963 fn test_model_selector_multi_dispatch() {
964 let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
965
966 let requirements = vec![
967 ModelRequirements::default().with_vision(),
968 ModelRequirements::default(),
969 ];
970
971 let results = selector.select_multi(&requirements);
972 assert_eq!(results.len(), 2);
973
974 assert!(results[0].is_some());
976 assert!(results[0].as_ref().unwrap().supports_vision);
977
978 assert!(results[1].is_some());
980 }
981
982 #[test]
983 fn test_model_selector_add_model() {
984 let mut selector = ModelSelector::new(&["gpt-4o"]);
985 selector.add_model("gpt-4o-mini");
986 assert_eq!(selector.models.len(), 2);
987 selector.add_model("gpt-4o");
989 assert_eq!(selector.models.len(), 2);
990 }
991
992 #[test]
993 fn test_auto_policy_single_model() {
994 let policy = auto_policy(&["gpt-4o"]);
995 assert_eq!(policy.small, "gpt-4o");
996 assert_eq!(policy.medium, "gpt-4o");
997 assert_eq!(policy.large, "gpt-4o");
998 }
999
1000 #[test]
1001 fn test_auto_policy_multiple_models() {
1002 let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1003 assert!(!policy.small.is_empty());
1005 assert!(!policy.medium.is_empty());
1006 assert!(!policy.large.is_empty());
1007 }
1008
1009 #[test]
1010 fn test_auto_policy_empty() {
1011 let policy = auto_policy(&[]);
1012 assert_eq!(policy.small, "gpt-4o-mini");
1014 assert_eq!(policy.medium, "gpt-4o");
1015 }
1016
1017 #[test]
1018 fn test_model_requirements_builder() {
1019 let reqs = ModelRequirements::default()
1020 .with_vision()
1021 .with_min_context(100_000)
1022 .with_max_cost(10.0)
1023 .with_min_arena(60.0);
1024
1025 assert!(reqs.vision);
1026 assert_eq!(reqs.min_context_tokens, 100_000);
1027 assert_eq!(reqs.max_input_cost_per_m, 10.0);
1028 assert_eq!(reqs.min_arena_score, 60.0);
1029 }
1030
1031 #[test]
1032 fn test_task_to_requirements() {
1033 let task = TaskAnalysis::from_prompt("Look at this screenshot and extract data");
1034 let reqs = task.to_requirements();
1035 assert!(reqs.vision);
1036 }
1037
1038 #[test]
1041 fn test_llm_data_vision_models_detected() {
1042 for model in &[
1044 "gpt-4o",
1045 "gpt-4o-mini",
1046 "claude-sonnet-4-5-20250514",
1047 "gemini-2.0-flash",
1048 "qwen2-vl-72b-instruct",
1049 "llama-3.2-11b-vision-instruct",
1050 ] {
1051 assert!(
1052 llm_models_spider::supports_vision(model),
1053 "{model} should support vision"
1054 );
1055 }
1056 for model in &["gpt-3.5-turbo", "deepseek-chat"] {
1058 assert!(
1059 !llm_models_spider::supports_vision(model),
1060 "{model} should NOT support vision"
1061 );
1062 }
1063 }
1064
1065 #[test]
1066 fn test_llm_data_model_profiles_exist() {
1067 let must_have = [
1068 "gpt-4o",
1069 "gpt-4o-mini",
1070 "gpt-3.5-turbo",
1071 "claude-3-5-sonnet-20241022",
1072 "gemini-2.0-flash",
1073 "deepseek-chat",
1074 ];
1075 for name in &must_have {
1076 let profile = llm_models_spider::model_profile(name);
1077 assert!(
1078 profile.is_some(),
1079 "model_profile({name}) should return Some"
1080 );
1081 let p = profile.unwrap();
1082 assert!(
1083 p.max_input_tokens > 0,
1084 "{name} should have max_input_tokens > 0, got {}",
1085 p.max_input_tokens
1086 );
1087 }
1088 }
1089
1090 #[test]
1091 fn test_llm_data_arena_scores_present() {
1092 for name in &["claude-3.5-sonnet", "chatgpt-4o-latest", "claude-opus-4"] {
1095 let profile = llm_models_spider::model_profile(name);
1096 assert!(profile.is_some(), "{name} should have a profile");
1097 let p = profile.unwrap();
1098 assert!(
1099 p.ranks.overall.is_some(),
1100 "{name} should have an arena score"
1101 );
1102 assert!(
1103 p.ranks.overall.unwrap() > 0.0,
1104 "{name} arena score should be > 0"
1105 );
1106 }
1107 }
1108
1109 #[test]
1110 fn test_llm_data_pricing_ordering() {
1111 let cheap = llm_models_spider::model_profile("gpt-4o-mini");
1112 let expensive = llm_models_spider::model_profile("claude-opus-4-20250514");
1113 assert!(cheap.is_some() && expensive.is_some());
1114 let cheap_cost = cheap.unwrap().pricing.input_cost_per_m_tokens.unwrap();
1115 let expensive_cost = expensive.unwrap().pricing.input_cost_per_m_tokens.unwrap();
1116 assert!(
1117 cheap_cost < expensive_cost,
1118 "gpt-4o-mini (${cheap_cost}) should be cheaper than claude-opus-4 (${expensive_cost})"
1119 );
1120 }
1121
1122 #[test]
1123 fn test_llm_data_context_window_ordering() {
1124 let large_ctx = llm_models_spider::model_profile("gemini-2.5-pro-preview-05-06");
1125 let small_ctx = llm_models_spider::model_profile("gpt-3.5-turbo");
1126 assert!(large_ctx.is_some() && small_ctx.is_some());
1127 let large_tokens = large_ctx.unwrap().max_input_tokens;
1128 let small_tokens = small_ctx.unwrap().max_input_tokens;
1129 assert!(
1130 large_tokens > small_tokens,
1131 "gemini-2.5-pro ({large_tokens}) should have more context than gpt-3.5-turbo ({small_tokens})"
1132 );
1133 }
1134
1135 #[test]
1138 fn test_selector_realistic_pool_best_quality() {
1139 let selector = ModelSelector::new(&[
1140 "gpt-4o",
1141 "gpt-4o-mini",
1142 "gpt-3.5-turbo",
1143 "claude-3-5-sonnet-20241022",
1144 "gemini-2.0-flash",
1145 "deepseek-chat",
1146 ]);
1147 let reqs = ModelRequirements::default();
1148 let ranked = selector.ranked(&reqs);
1149 assert!(!ranked.is_empty());
1150 let top = &ranked[0];
1152 for other in &ranked[1..] {
1153 assert!(
1154 top.score >= other.score,
1155 "top model {} (score {}) should beat {} (score {})",
1156 top.name,
1157 top.score,
1158 other.name,
1159 other.score
1160 );
1161 }
1162 }
1163
1164 #[test]
1165 fn test_selector_realistic_pool_cheapest() {
1166 let mut selector = ModelSelector::new(&[
1167 "gpt-4o",
1168 "gpt-4o-mini",
1169 "gpt-3.5-turbo",
1170 "claude-3-5-sonnet-20241022",
1171 ]);
1172 selector.set_strategy(SelectionStrategy::CheapestFirst);
1173 let reqs = ModelRequirements::default();
1174 let ranked = selector.ranked(&reqs);
1175 assert!(ranked.len() >= 2);
1176 let top = &ranked[0];
1178 let bottom = ranked.last().unwrap();
1179 if let (Some(top_cost), Some(bottom_cost)) = (top.input_cost_per_m, bottom.input_cost_per_m)
1180 {
1181 assert!(
1182 top_cost <= bottom_cost,
1183 "cheapest ({}, ${top_cost}) should rank above expensive ({}, ${bottom_cost})",
1184 top.name,
1185 bottom.name
1186 );
1187 }
1188 }
1189
1190 #[test]
1191 fn test_selector_vision_filter_rejects_text_only() {
1192 let selector = ModelSelector::new(&["gpt-3.5-turbo", "deepseek-chat"]);
1193 let reqs = ModelRequirements::default().with_vision();
1194 let result = selector.select(&reqs);
1195 assert!(
1196 result.is_none(),
1197 "text-only pool should return None for vision requirement"
1198 );
1199 }
1200
1201 #[test]
1202 fn test_selector_unknown_models_graceful() {
1203 let selector = ModelSelector::new(&["my-custom-model", "local-llama"]);
1204 let reqs = ModelRequirements::default();
1205 let result = selector.select(&reqs);
1206 assert!(
1207 result.is_some(),
1208 "unknown models should still return Some with default score"
1209 );
1210 let scored = result.unwrap();
1211 assert_eq!(scored.score, 50.0, "unknown model gets default score 50.0");
1212 }
1213
1214 #[test]
1215 fn test_selector_single_model_all_strategies() {
1216 for strategy in &[
1217 SelectionStrategy::BestQuality,
1218 SelectionStrategy::CheapestFirst,
1219 SelectionStrategy::LargestContext,
1220 SelectionStrategy::ValueOptimal,
1221 ] {
1222 let mut selector = ModelSelector::new(&["gpt-4o"]);
1223 selector.set_strategy(*strategy);
1224 let reqs = ModelRequirements::default();
1225 let result = selector.select(&reqs);
1226 assert!(
1227 result.is_some(),
1228 "single model should be returned for {strategy:?}"
1229 );
1230 assert_eq!(result.unwrap().name, "gpt-4o");
1231 }
1232 }
1233
1234 #[test]
1235 fn test_selector_deterministic_ordering() {
1236 let selector = ModelSelector::new(&[
1237 "gpt-4o",
1238 "gpt-4o-mini",
1239 "claude-3-5-sonnet-20241022",
1240 "gemini-2.0-flash",
1241 ]);
1242 let reqs = ModelRequirements::default();
1243
1244 let first_run: Vec<String> = selector
1245 .ranked(&reqs)
1246 .iter()
1247 .map(|m| m.name.clone())
1248 .collect();
1249 let second_run: Vec<String> = selector
1250 .ranked(&reqs)
1251 .iter()
1252 .map(|m| m.name.clone())
1253 .collect();
1254 assert_eq!(
1255 first_run, second_run,
1256 "repeated calls must produce identical ordering"
1257 );
1258 }
1259
1260 #[test]
1261 fn test_selector_cost_filter_strict() {
1262 let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1263 let reqs = ModelRequirements::default().with_max_cost(1.0);
1264 let ranked = selector.ranked(&reqs);
1265 for m in &ranked {
1266 if let Some(cost) = m.input_cost_per_m {
1267 assert!(
1268 cost <= 1.0,
1269 "{} has cost ${cost} which exceeds max 1.0",
1270 m.name
1271 );
1272 }
1273 }
1274 }
1275
1276 #[test]
1277 fn test_selector_min_context_filter() {
1278 let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo", "gemini-2.0-flash"]);
1279 let reqs = ModelRequirements::default().with_min_context(500_000);
1280 let ranked = selector.ranked(&reqs);
1281 for m in &ranked {
1282 assert!(
1283 m.max_input_tokens >= 500_000,
1284 "{} has {} tokens, below 500k minimum",
1285 m.name,
1286 m.max_input_tokens
1287 );
1288 }
1289 }
1290
1291 #[test]
1292 fn test_selector_value_optimal_balances() {
1293 let mut selector = ModelSelector::new(&[
1294 "gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo", ]);
1298 selector.set_strategy(SelectionStrategy::ValueOptimal);
1299 let reqs = ModelRequirements::default();
1300 let ranked = selector.ranked(&reqs);
1301 assert!(!ranked.is_empty());
1302 let top = &ranked[0];
1303 assert!(top.score > 0.0, "ValueOptimal score should be positive");
1307 }
1308
1309 #[test]
1310 fn test_select_multi_no_reuse() {
1311 let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1312 let requirements = vec![
1313 ModelRequirements::default(),
1314 ModelRequirements::default(),
1315 ModelRequirements::default(),
1316 ];
1317 let results = selector.select_multi(&requirements);
1318 assert_eq!(results.len(), 3);
1319 let names: Vec<&str> = results
1321 .iter()
1322 .filter_map(|r| r.as_ref().map(|m| m.name.as_str()))
1323 .collect();
1324 assert_eq!(names.len(), 3, "all 3 requests should get a model");
1325 let mut deduped = names.clone();
1327 deduped.sort();
1328 deduped.dedup();
1329 assert_eq!(
1330 deduped.len(),
1331 3,
1332 "no model should be reused when pool is large enough"
1333 );
1334 }
1335
1336 #[test]
1337 fn test_select_multi_exhaustion_fallback() {
1338 let selector = ModelSelector::new(&["gpt-4o"]);
1339 let requirements = vec![
1340 ModelRequirements::default(),
1341 ModelRequirements::default(),
1342 ModelRequirements::default(),
1343 ];
1344 let results = selector.select_multi(&requirements);
1345 assert_eq!(results.len(), 3);
1346 assert!(results[0].is_some());
1348 assert!(
1349 results[1].is_some(),
1350 "fallback should reuse the single model"
1351 );
1352 assert!(
1353 results[2].is_some(),
1354 "fallback should reuse the single model"
1355 );
1356 assert_eq!(results[0].as_ref().unwrap().name, "gpt-4o");
1358 assert_eq!(results[1].as_ref().unwrap().name, "gpt-4o");
1359 assert_eq!(results[2].as_ref().unwrap().name, "gpt-4o");
1360 }
1361
1362 #[test]
1363 fn test_selector_priority_override_beats_arena() {
1364 let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1365 selector.set_priority("gpt-3.5-turbo", 999.0);
1367 let reqs = ModelRequirements::default();
1368 let best = selector.select(&reqs).unwrap();
1369 assert_eq!(
1370 best.name, "gpt-3.5-turbo",
1371 "priority override should beat natural arena score"
1372 );
1373 assert_eq!(best.score, 999.0);
1374 }
1375
1376 #[test]
1379 fn test_auto_policy_realistic_tiering() {
1380 let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1381 assert_ne!(
1385 policy.large, policy.small,
1386 "large and small should be different models"
1387 );
1388 assert_eq!(policy.model_for_tier(CostTier::High), policy.large);
1390 assert_eq!(policy.model_for_tier(CostTier::Low), policy.small);
1391 assert_eq!(policy.model_for_tier(CostTier::Medium), policy.medium);
1392 }
1393
1394 #[test]
1395 fn test_auto_policy_2_models() {
1396 let policy = auto_policy(&["gpt-4o", "gpt-4o-mini"]);
1397 assert_eq!(policy.large, "gpt-4o");
1399 assert_eq!(policy.medium, "gpt-4o");
1400 assert_eq!(policy.small, "gpt-4o-mini");
1401 assert_eq!(
1402 policy.medium, policy.large,
1403 "2-model policy should have medium == large"
1404 );
1405 assert_ne!(policy.large, policy.small, "large and small should differ");
1406 }
1407
1408 #[test]
1409 fn test_auto_policy_unknown_models() {
1410 let policy = auto_policy(&["my-custom-llm", "local-model-7b", "test-endpoint"]);
1411 assert!(!policy.small.is_empty());
1413 assert!(!policy.medium.is_empty());
1414 assert!(!policy.large.is_empty());
1415 assert!(policy.allow_large);
1416 }
1417
1418 #[test]
1419 fn test_auto_policy_to_router_e2e() {
1420 let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1421 let router = ModelRouter::with_policy(policy.clone());
1422
1423 let simple_task = TaskAnalysis {
1425 estimated_tokens: 100,
1426 ..Default::default()
1427 };
1428 let decision = router.route(&simple_task);
1429 assert_eq!(decision.tier, CostTier::Low);
1430 assert_eq!(decision.model, policy.small);
1431
1432 let hard_task = TaskAnalysis {
1434 estimated_tokens: 5000,
1435 requires_reasoning: true,
1436 requires_code_generation: true,
1437 ..Default::default()
1438 };
1439 let decision = router.route(&hard_task);
1440 assert_eq!(decision.tier, CostTier::High);
1441 assert_eq!(decision.model, policy.large);
1442
1443 let medium_task = TaskAnalysis {
1445 estimated_tokens: 2000,
1446 requires_structured_output: true,
1447 multi_step: true,
1448 ..Default::default()
1449 };
1450 let decision = router.route(&medium_task);
1451 assert_eq!(decision.tier, CostTier::Medium);
1452 assert_eq!(decision.model, policy.medium);
1453 }
1454
1455 #[test]
1456 fn test_auto_policy_to_router_e2e_single_model() {
1457 let policy = auto_policy(&["gpt-4o"]);
1459 assert_eq!(policy.small, "gpt-4o");
1460 assert_eq!(policy.medium, "gpt-4o");
1461 assert_eq!(policy.large, "gpt-4o");
1462
1463 let router = ModelRouter::with_policy(policy);
1464
1465 let simple = TaskAnalysis {
1467 estimated_tokens: 50,
1468 ..Default::default()
1469 };
1470 let medium = TaskAnalysis {
1471 estimated_tokens: 2000,
1472 requires_structured_output: true,
1473 ..Default::default()
1474 };
1475 let hard = TaskAnalysis {
1476 estimated_tokens: 5000,
1477 requires_reasoning: true,
1478 requires_code_generation: true,
1479 ..Default::default()
1480 };
1481
1482 for (label, task) in [("simple", &simple), ("medium", &medium), ("hard", &hard)] {
1483 let decision = router.route(task);
1484 assert_eq!(
1485 decision.model, "gpt-4o",
1486 "{label} task should still route to the only model"
1487 );
1488 }
1489 }
1490
1491 #[test]
1492 fn test_selector_single_model_vision_mismatch() {
1493 let selector = ModelSelector::new(&["gpt-3.5-turbo"]);
1495 let reqs = ModelRequirements::default().with_vision();
1496 assert!(
1497 selector.select(&reqs).is_none(),
1498 "single text-only model should not satisfy vision requirement"
1499 );
1500
1501 let selector = ModelSelector::new(&["gpt-4o"]);
1503 let result = selector.select(&reqs);
1504 assert!(
1505 result.is_some(),
1506 "single vision model should satisfy vision"
1507 );
1508 assert_eq!(result.unwrap().name, "gpt-4o");
1509 }
1510
1511 #[test]
1512 fn test_selector_single_model_with_cost_filter() {
1513 let selector = ModelSelector::new(&["gpt-4o"]);
1515 let reqs = ModelRequirements::default().with_max_cost(0.01);
1516 assert!(
1517 selector.select(&reqs).is_none(),
1518 "single expensive model should be filtered by strict cost limit"
1519 );
1520
1521 let reqs = ModelRequirements::default().with_max_cost(100.0);
1523 let result = selector.select(&reqs);
1524 assert!(result.is_some());
1525 assert_eq!(result.unwrap().name, "gpt-4o");
1526 }
1527
1528 #[test]
1529 fn test_selector_single_unknown_model_e2e() {
1530 let policy = auto_policy(&["my-local-llama"]);
1532 assert_eq!(policy.small, "my-local-llama");
1533 assert_eq!(policy.medium, "my-local-llama");
1534 assert_eq!(policy.large, "my-local-llama");
1535
1536 let router = ModelRouter::with_policy(policy);
1537 let decision = router.route_simple("do something complex and analyze the code");
1538 assert_eq!(
1539 decision.model, "my-local-llama",
1540 "unknown single model should still be routed to"
1541 );
1542
1543 let selector = ModelSelector::new(&["my-local-llama"]);
1545 let result = selector.select(&ModelRequirements::default());
1546 assert!(result.is_some());
1547 let scored = result.unwrap();
1548 assert_eq!(scored.name, "my-local-llama");
1549 assert_eq!(scored.score, 50.0, "unknown model gets default score");
1550 assert_eq!(
1551 scored.max_input_tokens, 0,
1552 "unknown model has no context data"
1553 );
1554 assert!(
1555 scored.arena_rank.is_none(),
1556 "unknown model has no arena data"
1557 );
1558 }
1559
1560 #[test]
1561 fn test_router_latency_constraint_downgrade() {
1562 let policy = ModelPolicy {
1563 max_latency_ms: Some(1000),
1564 ..Default::default()
1565 };
1566 let router = ModelRouter::with_policy(policy);
1567
1568 let task = TaskAnalysis {
1570 estimated_tokens: 5000,
1571 requires_reasoning: true,
1572 requires_code_generation: true,
1573 ..Default::default()
1574 };
1575 let decision = router.route(&task);
1576 assert_ne!(
1578 decision.tier,
1579 CostTier::High,
1580 "latency constraint should prevent High tier"
1581 );
1582 }
1583
1584 #[test]
1585 fn test_router_allow_large_false() {
1586 let policy = ModelPolicy {
1587 allow_large: false,
1588 ..Default::default()
1589 };
1590 let router = ModelRouter::with_policy(policy);
1591
1592 let task = TaskAnalysis {
1593 estimated_tokens: 5000,
1594 requires_reasoning: true,
1595 requires_code_generation: true,
1596 ..Default::default()
1597 };
1598 let decision = router.route(&task);
1599 assert_ne!(
1600 decision.tier,
1601 CostTier::High,
1602 "allow_large=false should cap at Medium"
1603 );
1604 }
1605
1606 #[test]
1607 fn test_router_threshold_customization() {
1608 let router = ModelRouter::new().with_thresholds(100, 200);
1609
1610 let task = TaskAnalysis {
1613 estimated_tokens: 300,
1614 requires_reasoning: true,
1615 requires_code_generation: true,
1616 ..Default::default()
1617 };
1618 let decision = router.route(&task);
1619 assert_eq!(
1620 decision.tier,
1621 CostTier::High,
1622 "lowered thresholds should promote to High tier sooner"
1623 );
1624
1625 let default_router = ModelRouter::new();
1628 let decision = default_router.route(&task);
1629 assert_eq!(
1630 decision.tier,
1631 CostTier::Medium,
1632 "default thresholds should keep this at Medium"
1633 );
1634 }
1635
1636 #[test]
1639 fn test_selector_empty_pool() {
1640 let selector = ModelSelector::new(&[]);
1641 let reqs = ModelRequirements::default();
1642 let result = selector.select(&reqs);
1643 assert!(result.is_none(), "empty pool should return None");
1644 }
1645
1646 #[test]
1647 fn test_selector_duplicate_models() {
1648 let selector = ModelSelector::new(&["gpt-4o", "gpt-4o", "gpt-4o"]);
1649 let requirements = vec![
1650 ModelRequirements::default(),
1651 ModelRequirements::default(),
1652 ModelRequirements::default(),
1653 ];
1654 let results = selector.select_multi(&requirements);
1655 assert_eq!(results.len(), 3, "should not hang on duplicates");
1656 for (i, r) in results.iter().enumerate() {
1658 assert!(r.is_some(), "request {i} should get a model");
1659 }
1660 }
1661
1662 #[test]
1663 fn test_task_analysis_edge_cases() {
1664 let analysis = TaskAnalysis::from_prompt("");
1666 assert_eq!(analysis.estimated_tokens, 1);
1667 assert!(!analysis.requires_reasoning);
1668
1669 let analysis = TaskAnalysis::from_prompt(
1671 "analyze compare explain why code implement function script json extract list then step first next screenshot image",
1672 );
1673 assert!(analysis.requires_reasoning);
1674 assert!(analysis.requires_code_generation);
1675 assert!(analysis.requires_structured_output);
1676 assert!(analysis.multi_step);
1677 assert!(analysis.requires_vision);
1678
1679 let analysis = TaskAnalysis::from_prompt("你好世界 🌍 日本語テスト");
1681 assert!(!analysis.requires_reasoning);
1682 assert!(!analysis.requires_code_generation);
1683 assert!(analysis.estimated_tokens > 0);
1684 }
1685
1686 #[test]
1687 fn test_auto_policy_large_pool() {
1688 let models: Vec<&str> = vec![
1689 "gpt-4o",
1690 "gpt-4o-mini",
1691 "gpt-3.5-turbo",
1692 "claude-3-5-sonnet-20241022",
1693 "claude-3-5-haiku-20241022",
1694 "gemini-2.0-flash",
1695 "deepseek-chat",
1696 "unknown-model-1",
1697 "unknown-model-2",
1698 "unknown-model-3",
1699 "unknown-model-4",
1700 "unknown-model-5",
1701 "unknown-model-6",
1702 "unknown-model-7",
1703 "unknown-model-8",
1704 "unknown-model-9",
1705 "unknown-model-10",
1706 "unknown-model-11",
1707 "unknown-model-12",
1708 "unknown-model-13",
1709 ];
1710 let policy = auto_policy(&models);
1711 assert!(!policy.small.is_empty());
1712 assert!(!policy.medium.is_empty());
1713 assert!(!policy.large.is_empty());
1714 assert!(policy.allow_large);
1715 assert_eq!(policy.max_cost_tier, CostTier::High);
1716 }
1717
1718 #[test]
1721 fn test_classify_round_complexity_round_0() {
1722 let analysis = classify_round_complexity("click button", 1000, 0, false);
1723 assert!(
1724 analysis.requires_reasoning,
1725 "round 0 always requires reasoning"
1726 );
1727 assert!(analysis.requires_structured_output);
1728 }
1729
1730 #[test]
1731 fn test_classify_round_complexity_stagnated() {
1732 let analysis = classify_round_complexity("click button", 1000, 5, true);
1733 assert!(
1734 analysis.requires_reasoning,
1735 "stagnated rounds need reasoning"
1736 );
1737 assert!(analysis.multi_step, "stagnated rounds are multi-step");
1738 }
1739
1740 #[test]
1741 fn test_classify_round_complexity_large_html() {
1742 let analysis = classify_round_complexity("click button", 60_000, 3, false);
1743 assert!(analysis.requires_reasoning, "large HTML needs reasoning");
1744 }
1745
1746 #[test]
1747 fn test_classify_round_complexity_simple() {
1748 let analysis = classify_round_complexity("click button", 1000, 3, false);
1749 assert!(!analysis.requires_reasoning);
1751 assert!(!analysis.multi_step);
1752 assert!(analysis.requires_structured_output);
1753 }
1754
1755 #[test]
1756 fn test_policy_getter() {
1757 let policy = ModelPolicy {
1758 small: "small-model".to_string(),
1759 medium: "medium-model".to_string(),
1760 large: "large-model".to_string(),
1761 allow_large: true,
1762 max_latency_ms: None,
1763 max_cost_tier: CostTier::High,
1764 };
1765 let router = ModelRouter::with_policy(policy.clone());
1766 assert_eq!(router.policy().small, "small-model");
1767 assert_eq!(router.policy().large, "large-model");
1768 }
1769}