// spider_agent/automation/router.rs
//! Smart model routing for optimal performance and cost.
//!
//! Routes requests to appropriate models based on:
//! - Task complexity and category
//! - Token count estimates
//! - Latency requirements
//! - Cost constraints
//! - Arena rankings and model capabilities (from `llm_models_spider`)
//!
//! ## Multi-Agent Model Selection
//!
//! The [`ModelSelector`] allows users to pass in their available models and
//! get optimal routing based on task type, arena rankings, and pricing data.
//! Users can also define custom rank/priority overrides.

use super::{CostTier, ModelPolicy};
use std::time::Duration;

/// Smart router for selecting optimal models.
///
/// Analyzes a [`TaskAnalysis`] and routes it to the most appropriate model
/// based on complexity, cost, and latency requirements. Thresholds and the
/// underlying [`ModelPolicy`] are configurable via the builder-style methods.
#[derive(Debug, Clone)]
pub struct ModelRouter {
    /// Model policy configuration (tier-to-model mapping plus constraints).
    policy: ModelPolicy,
    /// Token threshold above which tasks lean toward the large model.
    large_model_threshold: usize,
    /// Token threshold above which tasks lean toward the medium model.
    medium_model_threshold: usize,
    /// Whether smart routing is enabled; when `false`, `route` always
    /// returns the medium-tier model.
    enabled: bool,
}

35impl Default for ModelRouter {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl ModelRouter {
42    /// Create a new router with default settings.
43    pub fn new() -> Self {
44        Self {
45            policy: ModelPolicy::default(),
46            large_model_threshold: 4000,
47            medium_model_threshold: 1000,
48            enabled: true,
49        }
50    }
51
52    /// Create with custom policy.
53    pub fn with_policy(policy: ModelPolicy) -> Self {
54        Self {
55            policy,
56            ..Default::default()
57        }
58    }
59
60    /// Set token thresholds.
61    pub fn with_thresholds(mut self, medium: usize, large: usize) -> Self {
62        self.medium_model_threshold = medium;
63        self.large_model_threshold = large;
64        self
65    }
66
67    /// Get a reference to the underlying policy.
68    pub fn policy(&self) -> &ModelPolicy {
69        &self.policy
70    }
71
72    /// Enable or disable smart routing.
73    pub fn enabled(mut self, enabled: bool) -> Self {
74        self.enabled = enabled;
75        self
76    }
77
78    /// Route a task to the optimal model.
79    ///
80    /// Returns the recommended model name.
81    pub fn route(&self, task: &TaskAnalysis) -> RoutingDecision {
82        if !self.enabled {
83            return RoutingDecision {
84                model: self.policy.medium.clone(),
85                tier: CostTier::Medium,
86                reason: "Smart routing disabled".to_string(),
87            };
88        }
89
90        // Determine complexity tier
91        let tier = self.analyze_complexity(task);
92
93        // Check policy constraints
94        let tier = self.apply_constraints(tier, task);
95
96        let model = self.policy.model_for_tier(tier).to_string();
97        let reason = self.explain_routing(task, tier);
98
99        RoutingDecision {
100            model,
101            tier,
102            reason,
103        }
104    }
105
106    /// Analyze task complexity.
107    fn analyze_complexity(&self, task: &TaskAnalysis) -> CostTier {
108        let mut score = 0;
109
110        // Token count factor
111        if task.estimated_tokens > self.large_model_threshold {
112            score += 3;
113        } else if task.estimated_tokens > self.medium_model_threshold {
114            score += 2;
115        } else {
116            score += 1;
117        }
118
119        // Complexity indicators
120        if task.requires_reasoning {
121            score += 2;
122        }
123        if task.requires_code_generation {
124            score += 2;
125        }
126        if task.requires_structured_output {
127            score += 1;
128        }
129        if task.multi_step {
130            score += 1;
131        }
132
133        // Map score to tier
134        match score {
135            0..=2 => CostTier::Low,
136            3..=5 => CostTier::Medium,
137            _ => CostTier::High,
138        }
139    }
140
141    /// Apply policy constraints to the selected tier.
142    fn apply_constraints(&self, tier: CostTier, task: &TaskAnalysis) -> CostTier {
143        // Check max tier constraint
144        let tier = match (tier, self.policy.max_cost_tier) {
145            (CostTier::High, CostTier::Low) => CostTier::Low,
146            (CostTier::High, CostTier::Medium) => CostTier::Medium,
147            (CostTier::Medium, CostTier::Low) => CostTier::Low,
148            _ => tier,
149        };
150
151        // Check latency constraint
152        if let Some(max_latency) = self.policy.max_latency_ms {
153            let estimated_latency = self.estimate_latency(tier, task);
154            if estimated_latency > max_latency {
155                // Downgrade to faster model
156                return match tier {
157                    CostTier::High => CostTier::Medium,
158                    CostTier::Medium => CostTier::Low,
159                    CostTier::Low => CostTier::Low,
160                };
161            }
162        }
163
164        // Check if large model is allowed
165        if tier == CostTier::High && !self.policy.allow_large {
166            return CostTier::Medium;
167        }
168
169        tier
170    }
171
172    /// Estimate latency for a tier.
173    fn estimate_latency(&self, tier: CostTier, task: &TaskAnalysis) -> u64 {
174        // Rough estimates in milliseconds
175        let base_latency = match tier {
176            CostTier::Low => 500,
177            CostTier::Medium => 1500,
178            CostTier::High => 3000,
179        };
180
181        // Add token-based estimate (rough: 50ms per 100 tokens)
182        let token_latency = (task.estimated_tokens as u64 / 100) * 50;
183
184        base_latency + token_latency
185    }
186
187    /// Explain the routing decision.
188    fn explain_routing(&self, task: &TaskAnalysis, tier: CostTier) -> String {
189        let mut reasons = Vec::new();
190
191        if task.estimated_tokens > self.large_model_threshold {
192            reasons.push("high token count");
193        }
194        if task.requires_reasoning {
195            reasons.push("requires reasoning");
196        }
197        if task.requires_code_generation {
198            reasons.push("requires code generation");
199        }
200
201        if reasons.is_empty() {
202            reasons.push("standard task");
203        }
204
205        format!("{:?} tier selected: {}", tier, reasons.join(", "))
206    }
207
208    /// Quickly route a simple prompt.
209    pub fn route_simple(&self, prompt: &str) -> RoutingDecision {
210        let task = TaskAnalysis::from_prompt(prompt);
211        self.route(&task)
212    }
213}

// ── ModelSelector ─────────────────────────────────────────────────────────────

/// Capability requirements for model selection.
///
/// All fields act as hard filters: a model that fails any requirement is
/// excluded from selection entirely rather than merely penalized in scoring.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelRequirements {
    /// Requires vision/image input support.
    pub vision: bool,
    /// Requires audio input support.
    pub audio: bool,
    /// Requires video input support.
    pub video: bool,
    /// Requires PDF/file input support.
    pub pdf: bool,
    /// Minimum context window (input tokens). 0 = no requirement.
    pub min_context_tokens: u32,
    /// Maximum input cost per million tokens in USD. 0.0 = no limit.
    /// Models with unknown pricing are NOT filtered out by this limit.
    pub max_input_cost_per_m: f32,
    /// Minimum arena score (0.0-100.0). 0.0 = no minimum.
    /// Models with unknown arena scores are NOT filtered out by this limit.
    pub min_arena_score: f32,
}

236impl ModelRequirements {
237    /// Require vision support.
238    pub fn with_vision(mut self) -> Self {
239        self.vision = true;
240        self
241    }
242
243    /// Require a minimum context window.
244    pub fn with_min_context(mut self, tokens: u32) -> Self {
245        self.min_context_tokens = tokens;
246        self
247    }
248
249    /// Set a maximum input cost per million tokens.
250    pub fn with_max_cost(mut self, cost: f32) -> Self {
251        self.max_input_cost_per_m = cost;
252        self
253    }
254
255    /// Require a minimum arena score.
256    pub fn with_min_arena(mut self, score: f32) -> Self {
257        self.min_arena_score = score;
258        self
259    }
260}
261
/// A scored model candidate for selection.
#[derive(Debug, Clone)]
pub struct ScoredModel {
    /// Model name (lowercased, as stored in the selector's pool).
    pub name: String,
    /// Effective score used for ranking (higher is better).
    /// Either the custom priority override or the strategy's auto score.
    pub score: f32,
    /// Arena rank if available (0.0-100.0).
    pub arena_rank: Option<f32>,
    /// Input cost per million tokens if available.
    pub input_cost_per_m: Option<f32>,
    /// Whether the model supports vision.
    pub supports_vision: bool,
    /// Max input tokens (0 when the model profile is unknown).
    pub max_input_tokens: u32,
}

/// Priority strategy for scoring models.
///
/// `Eq` is derived alongside `PartialEq` because the enum is fieldless, so
/// equality is total (clippy: `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SelectionStrategy {
    /// Prefer highest arena score (quality).
    #[default]
    BestQuality,
    /// Prefer lowest cost.
    CheapestFirst,
    /// Prefer largest context window.
    LargestContext,
    /// Balance quality and cost (arena_score / cost).
    ValueOptimal,
}

/// Flexible model selector that picks optimal models from a user-provided pool.
///
/// Users pass in the models they have available (API keys, endpoints), and the
/// selector uses arena rankings, pricing, and capability data to rank them.
/// Custom priority overrides let users boost or penalize specific models.
///
/// # Example
///
/// ```rust,ignore
/// use spider_agent::automation::router::{ModelSelector, ModelRequirements, SelectionStrategy};
///
/// let mut selector = ModelSelector::new(&["gpt-4o", "claude-sonnet-4.5", "gemini-2.5-pro"]);
/// selector.set_strategy(SelectionStrategy::BestQuality);
///
/// // Pick best model for a vision task
/// let reqs = ModelRequirements::default().with_vision();
/// if let Some(best) = selector.select(&reqs) {
///     println!("Use: {} (score: {:.1})", best.name, best.score);
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ModelSelector {
    /// User's available models with optional custom priority overrides.
    /// Each entry: (lowercased model name, custom priority override).
    /// Priority override: `None` = use auto scoring; `Some(f32)` = fixed score.
    models: Vec<(String, Option<f32>)>,
    /// Selection strategy used by auto scoring.
    strategy: SelectionStrategy,
}

323impl ModelSelector {
324    /// Create a selector from a list of available model names.
325    pub fn new(models: &[&str]) -> Self {
326        Self {
327            models: models.iter().map(|m| (m.to_lowercase(), None)).collect(),
328            strategy: SelectionStrategy::default(),
329        }
330    }
331
332    /// Create from owned strings.
333    pub fn from_owned(models: Vec<String>) -> Self {
334        Self {
335            models: models
336                .into_iter()
337                .map(|m| (m.to_lowercase(), None))
338                .collect(),
339            strategy: SelectionStrategy::default(),
340        }
341    }
342
343    /// Set the selection strategy.
344    pub fn set_strategy(&mut self, strategy: SelectionStrategy) {
345        self.strategy = strategy;
346    }
347
348    /// Set a custom priority override for a specific model.
349    ///
350    /// The priority is a fixed score (higher = more preferred).
351    /// This overrides the auto-calculated score from arena/pricing data.
352    pub fn set_priority(&mut self, model: &str, priority: f32) {
353        let lower = model.to_lowercase();
354        for (name, prio) in &mut self.models {
355            if *name == lower {
356                *prio = Some(priority);
357                return;
358            }
359        }
360        // Model not in pool — add it with the override
361        self.models.push((lower, Some(priority)));
362    }
363
364    /// Add a model to the pool.
365    pub fn add_model(&mut self, model: &str) {
366        let lower = model.to_lowercase();
367        if !self.models.iter().any(|(n, _)| *n == lower) {
368            self.models.push((lower, None));
369        }
370    }
371
372    /// Select the best model matching the given requirements.
373    ///
374    /// Returns `None` if no model in the pool satisfies the requirements.
375    ///
376    /// For pools with ≤ 2 models, skips scoring/sorting and returns the first
377    /// model that satisfies the requirements. Use the full ranking pipeline
378    /// only when there are 3+ models to meaningfully choose between.
379    pub fn select(&self, reqs: &ModelRequirements) -> Option<ScoredModel> {
380        if self.models.len() <= 2 {
381            // Fast path: no meaningful selection with 0-2 models.
382            // Just check requirements and return the first match.
383            return self
384                .models
385                .iter()
386                .filter_map(|(name, custom_prio)| self.score_model(name, *custom_prio, reqs))
387                .next();
388        }
389        self.ranked(reqs).into_iter().next()
390    }
391
392    /// Return all models that satisfy the requirements, ranked best-to-worst.
393    ///
394    /// For pools with ≤ 2 models, skips the sorting step since the ordering
395    /// is trivial. Scoring/sorting is only worthwhile with 3+ candidates.
396    pub fn ranked(&self, reqs: &ModelRequirements) -> Vec<ScoredModel> {
397        let mut candidates: Vec<ScoredModel> = self
398            .models
399            .iter()
400            .filter_map(|(name, custom_prio)| self.score_model(name, *custom_prio, reqs))
401            .collect();
402
403        // Only worth sorting when there are 3+ candidates
404        if candidates.len() > 2 {
405            candidates.sort_by(|a, b| {
406                b.score
407                    .partial_cmp(&a.score)
408                    .unwrap_or(std::cmp::Ordering::Equal)
409            });
410        }
411        candidates
412    }
413
414    /// Select the best model for each distinct requirement set from a list.
415    ///
416    /// Useful for multi-agent dispatch: given N different sub-tasks with different
417    /// requirements, get the best model for each without reusing the same model
418    /// (unless it's the only option).
419    pub fn select_multi(&self, requirements: &[ModelRequirements]) -> Vec<Option<ScoredModel>> {
420        let mut used: Vec<bool> = vec![false; self.models.len()];
421        let mut results = Vec::with_capacity(requirements.len());
422
423        for reqs in requirements {
424            let mut best: Option<(ScoredModel, usize)> = None;
425
426            for (idx, (name, custom_prio)) in self.models.iter().enumerate() {
427                if used[idx] {
428                    continue;
429                }
430                if let Some(scored) = self.score_model(name, *custom_prio, reqs) {
431                    let dominated = match &best {
432                        Some((current, _)) => scored.score > current.score,
433                        None => true,
434                    };
435                    if dominated {
436                        best = Some((scored, idx));
437                    }
438                }
439            }
440
441            if let Some((model, idx)) = best {
442                used[idx] = true;
443                results.push(Some(model));
444            } else {
445                // Fallback: allow reuse if no unused model fits
446                let fallback = self.select(reqs);
447                results.push(fallback);
448            }
449        }
450
451        results
452    }
453
454    /// Score a model against requirements. Returns None if it doesn't meet them.
455    fn score_model(
456        &self,
457        name: &str,
458        custom_prio: Option<f32>,
459        reqs: &ModelRequirements,
460    ) -> Option<ScoredModel> {
461        let profile = llm_models_spider::model_profile(name);
462
463        // Extract capabilities (from profile or from individual lookups)
464        let has_vision = profile
465            .as_ref()
466            .map(|p| p.capabilities.vision)
467            .unwrap_or_else(|| llm_models_spider::supports_vision(name));
468        let has_audio = profile
469            .as_ref()
470            .map(|p| p.capabilities.audio)
471            .unwrap_or(false);
472        let has_video = profile
473            .as_ref()
474            .map(|p| p.capabilities.video)
475            .unwrap_or_else(|| llm_models_spider::supports_video(name));
476        let has_pdf = profile
477            .as_ref()
478            .map(|p| p.capabilities.file)
479            .unwrap_or_else(|| llm_models_spider::supports_pdf(name));
480
481        // Check hard requirements
482        if reqs.vision && !has_vision {
483            return None;
484        }
485        if reqs.audio && !has_audio {
486            return None;
487        }
488        if reqs.video && !has_video {
489            return None;
490        }
491        if reqs.pdf && !has_pdf {
492            return None;
493        }
494
495        let max_input = profile.as_ref().map(|p| p.max_input_tokens).unwrap_or(0);
496        if reqs.min_context_tokens > 0 && max_input < reqs.min_context_tokens {
497            return None;
498        }
499
500        let arena = profile.as_ref().and_then(|p| p.ranks.overall);
501        let input_cost = profile
502            .as_ref()
503            .and_then(|p| p.pricing.input_cost_per_m_tokens);
504
505        if reqs.max_input_cost_per_m > 0.0 {
506            if let Some(cost) = input_cost {
507                if cost > reqs.max_input_cost_per_m {
508                    return None;
509                }
510            }
511        }
512
513        if reqs.min_arena_score > 0.0 {
514            match arena {
515                Some(score) if score >= reqs.min_arena_score => {}
516                Some(_) => return None,
517                None => {} // Unknown arena — don't filter out
518            }
519        }
520
521        // Compute score
522        let score = if let Some(prio) = custom_prio {
523            prio
524        } else {
525            self.auto_score(arena, input_cost, max_input)
526        };
527
528        Some(ScoredModel {
529            name: name.to_string(),
530            score,
531            arena_rank: arena,
532            input_cost_per_m: input_cost,
533            supports_vision: has_vision,
534            max_input_tokens: max_input,
535        })
536    }
537
538    /// Compute an automatic score based on strategy.
539    fn auto_score(&self, arena: Option<f32>, cost: Option<f32>, context: u32) -> f32 {
540        match self.strategy {
541            SelectionStrategy::BestQuality => arena.unwrap_or(50.0),
542            SelectionStrategy::CheapestFirst => {
543                // Invert cost: lower cost = higher score
544                match cost {
545                    Some(c) if c > 0.0 => 1000.0 / c,
546                    _ => 50.0, // Unknown cost = neutral
547                }
548            }
549            SelectionStrategy::LargestContext => context as f32 / 1000.0,
550            SelectionStrategy::ValueOptimal => {
551                let quality = arena.unwrap_or(50.0);
552                let cost_factor = match cost {
553                    Some(c) if c > 0.0 => 100.0 / c,
554                    _ => 1.0,
555                };
556                quality * cost_factor.sqrt()
557            }
558        }
559    }
560}
561
562/// Build a [`ModelPolicy`] automatically from a pool of available models.
563///
564/// Inspects arena rankings and pricing to assign models to tiers.
565/// The best-ranked model becomes `large`, cheapest becomes `small`,
566/// and something in-between becomes `medium`.
567///
568/// For pools with ≤ 2 models, skips the full scoring pipeline:
569/// - 0 models → default policy
570/// - 1 model → all tiers use that model
571/// - 2 models → first=large/medium, second=small (no scoring needed,
572///   dual-model routing via [`VisionRouteMode`] handles the rest)
573pub fn auto_policy(available_models: &[&str]) -> ModelPolicy {
574    if available_models.is_empty() {
575        return ModelPolicy::default();
576    }
577    if available_models.len() == 1 {
578        let m = available_models[0].to_string();
579        return ModelPolicy {
580            small: m.clone(),
581            medium: m.clone(),
582            large: m,
583            allow_large: true,
584            max_latency_ms: None,
585            max_cost_tier: CostTier::High,
586        };
587    }
588    if available_models.len() == 2 {
589        // With only 2 models, skip arena/pricing lookups.
590        // Assign first as large/medium, second as small — the caller
591        // already knows which is vision vs text via VisionRouteMode.
592        let a = available_models[0].to_string();
593        let b = available_models[1].to_string();
594        return ModelPolicy {
595            large: a.clone(),
596            medium: a,
597            small: b,
598            allow_large: true,
599            max_latency_ms: None,
600            max_cost_tier: CostTier::High,
601        };
602    }
603
604    // 3+ models: full scoring pipeline
605    // Collect (name, arena_score, input_cost)
606    let mut models: Vec<(&str, f32, f32)> = available_models
607        .iter()
608        .map(|&name| {
609            let profile = llm_models_spider::model_profile(name);
610            let arena = profile
611                .as_ref()
612                .and_then(|p| p.ranks.overall)
613                .unwrap_or(50.0);
614            let cost = profile
615                .as_ref()
616                .and_then(|p| p.pricing.input_cost_per_m_tokens)
617                .unwrap_or(5.0);
618            (name, arena, cost)
619        })
620        .collect();
621
622    // Sort by arena score descending
623    models.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
624
625    // Safety: models is guaranteed non-empty here (early returns above handle 0/1/2),
626    // but use defensive indexing to avoid panics if logic changes.
627    if models.is_empty() {
628        return ModelPolicy::default();
629    }
630    let large = models[0].0.to_string();
631    let small = models
632        .last()
633        .map(|m| m.0.to_string())
634        .unwrap_or_else(|| large.clone());
635    let medium = if models.len() >= 3 {
636        models[models.len() / 2].0.to_string()
637    } else {
638        large.clone()
639    };
640
641    ModelPolicy {
642        small,
643        medium,
644        large,
645        allow_large: true,
646        max_latency_ms: None,
647        max_cost_tier: CostTier::High,
648    }
649}

// ── TaskAnalysis ──────────────────────────────────────────────────────────────

/// Analysis of a task for routing.
#[derive(Debug, Clone, Default)]
pub struct TaskAnalysis {
    /// Estimated input tokens.
    pub estimated_tokens: usize,
    /// Whether the task requires complex reasoning.
    pub requires_reasoning: bool,
    /// Whether the task requires code generation.
    pub requires_code_generation: bool,
    /// Whether structured JSON output is required.
    pub requires_structured_output: bool,
    /// Whether this is a multi-step task.
    pub multi_step: bool,
    /// Maximum acceptable latency.
    /// NOTE(review): not currently consulted by `ModelRouter::route`, which
    /// uses `policy.max_latency_ms` instead — confirm whether this field
    /// should feed into the latency constraint.
    pub max_latency: Option<Duration>,
    /// Task category.
    pub category: TaskCategory,
    /// Whether the task requires vision capabilities.
    pub requires_vision: bool,
    /// Whether the task requires audio capabilities.
    pub requires_audio: bool,
}

676impl TaskAnalysis {
677    /// Create analysis from a prompt.
678    pub fn from_prompt(prompt: &str) -> Self {
679        let estimated_tokens = estimate_tokens(prompt);
680        let lower = prompt.to_lowercase();
681
682        Self {
683            estimated_tokens,
684            requires_reasoning: lower.contains("analyze")
685                || lower.contains("compare")
686                || lower.contains("explain")
687                || lower.contains("why"),
688            requires_code_generation: lower.contains("code")
689                || lower.contains("implement")
690                || lower.contains("function")
691                || lower.contains("script"),
692            requires_structured_output: lower.contains("json")
693                || lower.contains("extract")
694                || lower.contains("list"),
695            multi_step: lower.contains("then")
696                || lower.contains("step")
697                || lower.contains("first")
698                || lower.contains("next"),
699            max_latency: None,
700            category: TaskCategory::General,
701            requires_vision: lower.contains("screenshot")
702                || lower.contains("image")
703                || lower.contains("picture")
704                || lower.contains("visual"),
705            requires_audio: lower.contains("audio")
706                || lower.contains("voice")
707                || lower.contains("speech"),
708        }
709    }
710
711    /// Create analysis for extraction task.
712    pub fn extraction(html_length: usize) -> Self {
713        Self {
714            estimated_tokens: html_length / 4 + 200, // Rough estimate
715            requires_reasoning: false,
716            requires_code_generation: false,
717            requires_structured_output: true,
718            multi_step: false,
719            max_latency: None,
720            category: TaskCategory::Extraction,
721            requires_vision: false,
722            requires_audio: false,
723        }
724    }
725
726    /// Create analysis for action task.
727    pub fn action(instruction: &str) -> Self {
728        let mut analysis = Self::from_prompt(instruction);
729        analysis.category = TaskCategory::Action;
730        analysis.requires_structured_output = true;
731        analysis
732    }
733
734    /// Set max latency requirement.
735    pub fn with_max_latency(mut self, latency: Duration) -> Self {
736        self.max_latency = Some(latency);
737        self
738    }
739
740    /// Convert to model requirements for the selector.
741    pub fn to_requirements(&self) -> ModelRequirements {
742        ModelRequirements {
743            vision: self.requires_vision,
744            audio: self.requires_audio,
745            ..Default::default()
746        }
747    }
748}
749
/// Category of task, used to tag a [`TaskAnalysis`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TaskCategory {
    /// General purpose task.
    #[default]
    General,
    /// Data extraction (structured output from page content).
    Extraction,
    /// Browser action (click, type, navigate).
    Action,
    /// Code generation.
    Code,
    /// Analysis/reasoning.
    Analysis,
    /// Simple classification.
    Classification,
}

/// Result of a routing decision made by [`ModelRouter::route`].
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Selected model name.
    pub model: String,
    /// Selected cost tier.
    pub tier: CostTier,
    /// Human-readable explanation for the decision.
    pub reason: String,
}

779impl RoutingDecision {
780    /// Check if this routes to a fast model.
781    pub fn is_fast(&self) -> bool {
782        self.tier == CostTier::Low
783    }
784
785    /// Check if this routes to a powerful model.
786    pub fn is_powerful(&self) -> bool {
787        self.tier == CostTier::High
788    }
789}
790
/// Estimate token count for text.
///
/// Uses a rough approximation of 4 characters per token, always returning
/// at least 1.
pub fn estimate_tokens(text: &str) -> usize {
    // Rough heuristic: ~4 characters per token for English text.
    // Real tokenization is more complex; this is intentionally cheap.
    const APPROX_CHARS_PER_TOKEN: usize = 4;
    1 + text.len() / APPROX_CHARS_PER_TOKEN
}

800/// Estimate tokens for messages.
801pub fn estimate_message_tokens(messages: &[crate::Message]) -> usize {
802    messages
803        .iter()
804        .map(|m| estimate_tokens(m.content.as_text()) + 4) // +4 for message overhead
805        .sum()
806}
807
808/// Classify the complexity of an automation round using signals already
809/// available in the engine loop — no additional LLM call required.
810///
811/// Returns a [`TaskAnalysis`] suitable for passing to [`ModelRouter::route`].
812pub fn classify_round_complexity(
813    user_prompt: &str,
814    html_len: usize,
815    round_idx: usize,
816    stagnated: bool,
817) -> TaskAnalysis {
818    let mut analysis = TaskAnalysis::from_prompt(user_prompt);
819    analysis.estimated_tokens = user_prompt.len() / 4 + 1;
820    // Initial round requires more reasoning (first page analysis)
821    if round_idx == 0 {
822        analysis.requires_reasoning = true;
823    }
824    // Stagnated rounds need a stronger model to break out
825    if stagnated {
826        analysis.requires_reasoning = true;
827        analysis.multi_step = true;
828    }
829    // Large pages need more capability
830    if html_len > 50_000 {
831        analysis.requires_reasoning = true;
832    }
833    // Automation always requires structured JSON output
834    analysis.requires_structured_output = true;
835    analysis
836}
837
838#[cfg(test)]
839mod tests {
840    use super::*;
841
842    #[test]
843    fn test_model_router_simple() {
844        let router = ModelRouter::new();
845
846        let decision = router.route_simple("Extract the title from this page");
847        assert!(!decision.model.is_empty());
848    }
849
850    #[test]
851    fn test_model_router_complex() {
852        let router = ModelRouter::new();
853
854        let task = TaskAnalysis {
855            estimated_tokens: 5000,
856            requires_reasoning: true,
857            requires_code_generation: true,
858            ..Default::default()
859        };
860
861        let decision = router.route(&task);
862        assert_eq!(decision.tier, CostTier::High);
863    }
864
865    #[test]
866    fn test_model_router_constrained() {
867        let policy = ModelPolicy {
868            max_cost_tier: CostTier::Medium,
869            ..Default::default()
870        };
871
872        let router = ModelRouter::with_policy(policy);
873
874        let task = TaskAnalysis {
875            estimated_tokens: 5000,
876            requires_reasoning: true,
877            ..Default::default()
878        };
879
880        let decision = router.route(&task);
881        // Should be capped at Medium due to policy
882        assert!(decision.tier != CostTier::High);
883    }
884
885    #[test]
886    fn test_task_analysis_from_prompt() {
887        let analysis = TaskAnalysis::from_prompt(
888            "Analyze the code and explain why it's slow, then implement a fix",
889        );
890
891        assert!(analysis.requires_reasoning);
892        assert!(analysis.requires_code_generation);
893        assert!(analysis.multi_step);
894    }
895
896    #[test]
897    fn test_task_analysis_vision_detection() {
898        let analysis = TaskAnalysis::from_prompt("Look at this screenshot and describe it");
899        assert!(analysis.requires_vision);
900
901        let analysis = TaskAnalysis::from_prompt("Summarize this text");
902        assert!(!analysis.requires_vision);
903    }
904
905    #[test]
906    fn test_estimate_tokens() {
907        assert_eq!(estimate_tokens("hello world"), 3); // 11 chars / 4 + 1
908        assert_eq!(estimate_tokens(""), 1);
909    }
910
911    // ── ModelSelector tests ───────────────────────────────────────────────
912
913    #[test]
914    fn test_model_selector_basic() {
915        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
916        let reqs = ModelRequirements::default();
917        let result = selector.select(&reqs);
918        assert!(result.is_some());
919    }
920
921    #[test]
922    fn test_model_selector_vision_filter() {
923        let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo"]);
924        let reqs = ModelRequirements::default().with_vision();
925        let ranked = selector.ranked(&reqs);
926
927        // gpt-4o supports vision, gpt-3.5-turbo does not
928        assert!(!ranked.is_empty());
929        for m in &ranked {
930            assert!(
931                m.supports_vision,
932                "non-vision model {} passed filter",
933                m.name
934            );
935        }
936    }
937
938    #[test]
939    fn test_model_selector_custom_priority() {
940        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
941        // Override gpt-4o-mini to be top priority
942        selector.set_priority("gpt-4o-mini", 999.0);
943
944        let reqs = ModelRequirements::default();
945        let best = selector.select(&reqs).unwrap();
946        assert_eq!(best.name, "gpt-4o-mini");
947        assert_eq!(best.score, 999.0);
948    }
949
950    #[test]
951    fn test_model_selector_cheapest_strategy() {
952        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini"]);
953        selector.set_strategy(SelectionStrategy::CheapestFirst);
954
955        let reqs = ModelRequirements::default();
956        let ranked = selector.ranked(&reqs);
957
958        // With CheapestFirst, cheaper model should rank higher
959        assert!(ranked.len() >= 1);
960    }
961
962    #[test]
963    fn test_model_selector_multi_dispatch() {
964        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
965
966        let requirements = vec![
967            ModelRequirements::default().with_vision(),
968            ModelRequirements::default(),
969        ];
970
971        let results = selector.select_multi(&requirements);
972        assert_eq!(results.len(), 2);
973
974        // First task needs vision — should pick a vision model
975        assert!(results[0].is_some());
976        assert!(results[0].as_ref().unwrap().supports_vision);
977
978        // Second task — should pick a different model if possible
979        assert!(results[1].is_some());
980    }
981
982    #[test]
983    fn test_model_selector_add_model() {
984        let mut selector = ModelSelector::new(&["gpt-4o"]);
985        selector.add_model("gpt-4o-mini");
986        assert_eq!(selector.models.len(), 2);
987        // Adding duplicate should not increase count
988        selector.add_model("gpt-4o");
989        assert_eq!(selector.models.len(), 2);
990    }
991
992    #[test]
993    fn test_auto_policy_single_model() {
994        let policy = auto_policy(&["gpt-4o"]);
995        assert_eq!(policy.small, "gpt-4o");
996        assert_eq!(policy.medium, "gpt-4o");
997        assert_eq!(policy.large, "gpt-4o");
998    }
999
1000    #[test]
1001    fn test_auto_policy_multiple_models() {
1002        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1003        // Should assign different models to different tiers
1004        assert!(!policy.small.is_empty());
1005        assert!(!policy.medium.is_empty());
1006        assert!(!policy.large.is_empty());
1007    }
1008
1009    #[test]
1010    fn test_auto_policy_empty() {
1011        let policy = auto_policy(&[]);
1012        // Should return default policy
1013        assert_eq!(policy.small, "gpt-4o-mini");
1014        assert_eq!(policy.medium, "gpt-4o");
1015    }
1016
1017    #[test]
1018    fn test_model_requirements_builder() {
1019        let reqs = ModelRequirements::default()
1020            .with_vision()
1021            .with_min_context(100_000)
1022            .with_max_cost(10.0)
1023            .with_min_arena(60.0);
1024
1025        assert!(reqs.vision);
1026        assert_eq!(reqs.min_context_tokens, 100_000);
1027        assert_eq!(reqs.max_input_cost_per_m, 10.0);
1028        assert_eq!(reqs.min_arena_score, 60.0);
1029    }
1030
1031    #[test]
1032    fn test_task_to_requirements() {
1033        let task = TaskAnalysis::from_prompt("Look at this screenshot and extract data");
1034        let reqs = task.to_requirements();
1035        assert!(reqs.vision);
1036    }
1037
1038    // ── Phase 1: llm_models_spider data accuracy ────────────────────────
1039
1040    #[test]
1041    fn test_llm_data_vision_models_detected() {
1042        // Models that MUST report vision support
1043        for model in &[
1044            "gpt-4o",
1045            "gpt-4o-mini",
1046            "claude-sonnet-4-5-20250514",
1047            "gemini-2.0-flash",
1048            "qwen2-vl-72b-instruct",
1049            "llama-3.2-11b-vision-instruct",
1050        ] {
1051            assert!(
1052                llm_models_spider::supports_vision(model),
1053                "{model} should support vision"
1054            );
1055        }
1056        // Models that MUST NOT report vision support
1057        for model in &["gpt-3.5-turbo", "deepseek-chat"] {
1058            assert!(
1059                !llm_models_spider::supports_vision(model),
1060                "{model} should NOT support vision"
1061            );
1062        }
1063    }
1064
1065    #[test]
1066    fn test_llm_data_model_profiles_exist() {
1067        let must_have = [
1068            "gpt-4o",
1069            "gpt-4o-mini",
1070            "gpt-3.5-turbo",
1071            "claude-3-5-sonnet-20241022",
1072            "gemini-2.0-flash",
1073            "deepseek-chat",
1074        ];
1075        for name in &must_have {
1076            let profile = llm_models_spider::model_profile(name);
1077            assert!(
1078                profile.is_some(),
1079                "model_profile({name}) should return Some"
1080            );
1081            let p = profile.unwrap();
1082            assert!(
1083                p.max_input_tokens > 0,
1084                "{name} should have max_input_tokens > 0, got {}",
1085                p.max_input_tokens
1086            );
1087        }
1088    }
1089
1090    #[test]
1091    fn test_llm_data_arena_scores_present() {
1092        // These well-known models should have arena scores
1093        // (use short canonical names that match the arena data)
1094        for name in &["claude-3.5-sonnet", "chatgpt-4o-latest", "claude-opus-4"] {
1095            let profile = llm_models_spider::model_profile(name);
1096            assert!(profile.is_some(), "{name} should have a profile");
1097            let p = profile.unwrap();
1098            assert!(
1099                p.ranks.overall.is_some(),
1100                "{name} should have an arena score"
1101            );
1102            assert!(
1103                p.ranks.overall.unwrap() > 0.0,
1104                "{name} arena score should be > 0"
1105            );
1106        }
1107    }
1108
1109    #[test]
1110    fn test_llm_data_pricing_ordering() {
1111        let cheap = llm_models_spider::model_profile("gpt-4o-mini");
1112        let expensive = llm_models_spider::model_profile("claude-opus-4-20250514");
1113        assert!(cheap.is_some() && expensive.is_some());
1114        let cheap_cost = cheap.unwrap().pricing.input_cost_per_m_tokens.unwrap();
1115        let expensive_cost = expensive.unwrap().pricing.input_cost_per_m_tokens.unwrap();
1116        assert!(
1117            cheap_cost < expensive_cost,
1118            "gpt-4o-mini (${cheap_cost}) should be cheaper than claude-opus-4 (${expensive_cost})"
1119        );
1120    }
1121
1122    #[test]
1123    fn test_llm_data_context_window_ordering() {
1124        let large_ctx = llm_models_spider::model_profile("gemini-2.5-pro-preview-05-06");
1125        let small_ctx = llm_models_spider::model_profile("gpt-3.5-turbo");
1126        assert!(large_ctx.is_some() && small_ctx.is_some());
1127        let large_tokens = large_ctx.unwrap().max_input_tokens;
1128        let small_tokens = small_ctx.unwrap().max_input_tokens;
1129        assert!(
1130            large_tokens > small_tokens,
1131            "gemini-2.5-pro ({large_tokens}) should have more context than gpt-3.5-turbo ({small_tokens})"
1132        );
1133    }
1134
1135    // ── Phase 2: ModelSelector reliability ──────────────────────────────
1136
1137    #[test]
1138    fn test_selector_realistic_pool_best_quality() {
1139        let selector = ModelSelector::new(&[
1140            "gpt-4o",
1141            "gpt-4o-mini",
1142            "gpt-3.5-turbo",
1143            "claude-3-5-sonnet-20241022",
1144            "gemini-2.0-flash",
1145            "deepseek-chat",
1146        ]);
1147        let reqs = ModelRequirements::default();
1148        let ranked = selector.ranked(&reqs);
1149        assert!(!ranked.is_empty());
1150        // BestQuality (default) → first should have the highest arena score
1151        let top = &ranked[0];
1152        for other in &ranked[1..] {
1153            assert!(
1154                top.score >= other.score,
1155                "top model {} (score {}) should beat {} (score {})",
1156                top.name,
1157                top.score,
1158                other.name,
1159                other.score
1160            );
1161        }
1162    }
1163
1164    #[test]
1165    fn test_selector_realistic_pool_cheapest() {
1166        let mut selector = ModelSelector::new(&[
1167            "gpt-4o",
1168            "gpt-4o-mini",
1169            "gpt-3.5-turbo",
1170            "claude-3-5-sonnet-20241022",
1171        ]);
1172        selector.set_strategy(SelectionStrategy::CheapestFirst);
1173        let reqs = ModelRequirements::default();
1174        let ranked = selector.ranked(&reqs);
1175        assert!(ranked.len() >= 2);
1176        // CheapestFirst → lower cost gets higher score
1177        let top = &ranked[0];
1178        let bottom = ranked.last().unwrap();
1179        if let (Some(top_cost), Some(bottom_cost)) = (top.input_cost_per_m, bottom.input_cost_per_m)
1180        {
1181            assert!(
1182                top_cost <= bottom_cost,
1183                "cheapest ({}, ${top_cost}) should rank above expensive ({}, ${bottom_cost})",
1184                top.name,
1185                bottom.name
1186            );
1187        }
1188    }
1189
1190    #[test]
1191    fn test_selector_vision_filter_rejects_text_only() {
1192        let selector = ModelSelector::new(&["gpt-3.5-turbo", "deepseek-chat"]);
1193        let reqs = ModelRequirements::default().with_vision();
1194        let result = selector.select(&reqs);
1195        assert!(
1196            result.is_none(),
1197            "text-only pool should return None for vision requirement"
1198        );
1199    }
1200
1201    #[test]
1202    fn test_selector_unknown_models_graceful() {
1203        let selector = ModelSelector::new(&["my-custom-model", "local-llama"]);
1204        let reqs = ModelRequirements::default();
1205        let result = selector.select(&reqs);
1206        assert!(
1207            result.is_some(),
1208            "unknown models should still return Some with default score"
1209        );
1210        let scored = result.unwrap();
1211        assert_eq!(scored.score, 50.0, "unknown model gets default score 50.0");
1212    }
1213
1214    #[test]
1215    fn test_selector_single_model_all_strategies() {
1216        for strategy in &[
1217            SelectionStrategy::BestQuality,
1218            SelectionStrategy::CheapestFirst,
1219            SelectionStrategy::LargestContext,
1220            SelectionStrategy::ValueOptimal,
1221        ] {
1222            let mut selector = ModelSelector::new(&["gpt-4o"]);
1223            selector.set_strategy(*strategy);
1224            let reqs = ModelRequirements::default();
1225            let result = selector.select(&reqs);
1226            assert!(
1227                result.is_some(),
1228                "single model should be returned for {strategy:?}"
1229            );
1230            assert_eq!(result.unwrap().name, "gpt-4o");
1231        }
1232    }
1233
1234    #[test]
1235    fn test_selector_deterministic_ordering() {
1236        let selector = ModelSelector::new(&[
1237            "gpt-4o",
1238            "gpt-4o-mini",
1239            "claude-3-5-sonnet-20241022",
1240            "gemini-2.0-flash",
1241        ]);
1242        let reqs = ModelRequirements::default();
1243
1244        let first_run: Vec<String> = selector
1245            .ranked(&reqs)
1246            .iter()
1247            .map(|m| m.name.clone())
1248            .collect();
1249        let second_run: Vec<String> = selector
1250            .ranked(&reqs)
1251            .iter()
1252            .map(|m| m.name.clone())
1253            .collect();
1254        assert_eq!(
1255            first_run, second_run,
1256            "repeated calls must produce identical ordering"
1257        );
1258    }
1259
1260    #[test]
1261    fn test_selector_cost_filter_strict() {
1262        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1263        let reqs = ModelRequirements::default().with_max_cost(1.0);
1264        let ranked = selector.ranked(&reqs);
1265        for m in &ranked {
1266            if let Some(cost) = m.input_cost_per_m {
1267                assert!(
1268                    cost <= 1.0,
1269                    "{} has cost ${cost} which exceeds max 1.0",
1270                    m.name
1271                );
1272            }
1273        }
1274    }
1275
1276    #[test]
1277    fn test_selector_min_context_filter() {
1278        let selector = ModelSelector::new(&["gpt-4o", "gpt-3.5-turbo", "gemini-2.0-flash"]);
1279        let reqs = ModelRequirements::default().with_min_context(500_000);
1280        let ranked = selector.ranked(&reqs);
1281        for m in &ranked {
1282            assert!(
1283                m.max_input_tokens >= 500_000,
1284                "{} has {} tokens, below 500k minimum",
1285                m.name,
1286                m.max_input_tokens
1287            );
1288        }
1289    }
1290
1291    #[test]
1292    fn test_selector_value_optimal_balances() {
1293        let mut selector = ModelSelector::new(&[
1294            "gpt-4o",        // high quality, moderate cost
1295            "gpt-4o-mini",   // lower quality, cheap
1296            "gpt-3.5-turbo", // lowest quality, cheapest
1297        ]);
1298        selector.set_strategy(SelectionStrategy::ValueOptimal);
1299        let reqs = ModelRequirements::default();
1300        let ranked = selector.ranked(&reqs);
1301        assert!(!ranked.is_empty());
1302        let top = &ranked[0];
1303        // ValueOptimal should NOT just pick cheapest or best — verify it's not the raw cheapest
1304        // (it uses quality * sqrt(100/cost), so moderate-cost + high-quality can win)
1305        // Just verify it returns a valid result with a positive score
1306        assert!(top.score > 0.0, "ValueOptimal score should be positive");
1307    }
1308
1309    #[test]
1310    fn test_select_multi_no_reuse() {
1311        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1312        let requirements = vec![
1313            ModelRequirements::default(),
1314            ModelRequirements::default(),
1315            ModelRequirements::default(),
1316        ];
1317        let results = selector.select_multi(&requirements);
1318        assert_eq!(results.len(), 3);
1319        // All should be Some
1320        let names: Vec<&str> = results
1321            .iter()
1322            .filter_map(|r| r.as_ref().map(|m| m.name.as_str()))
1323            .collect();
1324        assert_eq!(names.len(), 3, "all 3 requests should get a model");
1325        // No duplicates (3 models, 3 requests → each used once)
1326        let mut deduped = names.clone();
1327        deduped.sort();
1328        deduped.dedup();
1329        assert_eq!(
1330            deduped.len(),
1331            3,
1332            "no model should be reused when pool is large enough"
1333        );
1334    }
1335
1336    #[test]
1337    fn test_select_multi_exhaustion_fallback() {
1338        let selector = ModelSelector::new(&["gpt-4o"]);
1339        let requirements = vec![
1340            ModelRequirements::default(),
1341            ModelRequirements::default(),
1342            ModelRequirements::default(),
1343        ];
1344        let results = selector.select_multi(&requirements);
1345        assert_eq!(results.len(), 3);
1346        // First gets the model, rest fall back to reuse
1347        assert!(results[0].is_some());
1348        assert!(
1349            results[1].is_some(),
1350            "fallback should reuse the single model"
1351        );
1352        assert!(
1353            results[2].is_some(),
1354            "fallback should reuse the single model"
1355        );
1356        // All should be the same model
1357        assert_eq!(results[0].as_ref().unwrap().name, "gpt-4o");
1358        assert_eq!(results[1].as_ref().unwrap().name, "gpt-4o");
1359        assert_eq!(results[2].as_ref().unwrap().name, "gpt-4o");
1360    }
1361
1362    #[test]
1363    fn test_selector_priority_override_beats_arena() {
1364        let mut selector = ModelSelector::new(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1365        // gpt-3.5-turbo has low arena score; override it to beat gpt-4o
1366        selector.set_priority("gpt-3.5-turbo", 999.0);
1367        let reqs = ModelRequirements::default();
1368        let best = selector.select(&reqs).unwrap();
1369        assert_eq!(
1370            best.name, "gpt-3.5-turbo",
1371            "priority override should beat natural arena score"
1372        );
1373        assert_eq!(best.score, 999.0);
1374    }
1375
1376    // ── Phase 3: auto_policy + ModelRouter pipeline ─────────────────────
1377
1378    #[test]
1379    fn test_auto_policy_realistic_tiering() {
1380        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1381        // Sorted by arena descending: large=highest, small=lowest
1382        // large should be the model with highest arena score
1383        // small should be the model with lowest arena score
1384        assert_ne!(
1385            policy.large, policy.small,
1386            "large and small should be different models"
1387        );
1388        // Verify tiers resolve correctly
1389        assert_eq!(policy.model_for_tier(CostTier::High), policy.large);
1390        assert_eq!(policy.model_for_tier(CostTier::Low), policy.small);
1391        assert_eq!(policy.model_for_tier(CostTier::Medium), policy.medium);
1392    }
1393
1394    #[test]
1395    fn test_auto_policy_2_models() {
1396        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini"]);
1397        // With 2 models, skip scoring: first=large/medium, second=small
1398        assert_eq!(policy.large, "gpt-4o");
1399        assert_eq!(policy.medium, "gpt-4o");
1400        assert_eq!(policy.small, "gpt-4o-mini");
1401        assert_eq!(
1402            policy.medium, policy.large,
1403            "2-model policy should have medium == large"
1404        );
1405        assert_ne!(policy.large, policy.small, "large and small should differ");
1406    }
1407
1408    #[test]
1409    fn test_auto_policy_unknown_models() {
1410        let policy = auto_policy(&["my-custom-llm", "local-model-7b", "test-endpoint"]);
1411        // Should not panic, all get default arena=50.0, cost=5.0
1412        assert!(!policy.small.is_empty());
1413        assert!(!policy.medium.is_empty());
1414        assert!(!policy.large.is_empty());
1415        assert!(policy.allow_large);
1416    }
1417
1418    #[test]
1419    fn test_auto_policy_to_router_e2e() {
1420        let policy = auto_policy(&["gpt-4o", "gpt-4o-mini", "gpt-3.5-turbo"]);
1421        let router = ModelRouter::with_policy(policy.clone());
1422
1423        // Low complexity → small model
1424        let simple_task = TaskAnalysis {
1425            estimated_tokens: 100,
1426            ..Default::default()
1427        };
1428        let decision = router.route(&simple_task);
1429        assert_eq!(decision.tier, CostTier::Low);
1430        assert_eq!(decision.model, policy.small);
1431
1432        // High complexity → large model
1433        let hard_task = TaskAnalysis {
1434            estimated_tokens: 5000,
1435            requires_reasoning: true,
1436            requires_code_generation: true,
1437            ..Default::default()
1438        };
1439        let decision = router.route(&hard_task);
1440        assert_eq!(decision.tier, CostTier::High);
1441        assert_eq!(decision.model, policy.large);
1442
1443        // Medium complexity → medium model
1444        let medium_task = TaskAnalysis {
1445            estimated_tokens: 2000,
1446            requires_structured_output: true,
1447            multi_step: true,
1448            ..Default::default()
1449        };
1450        let decision = router.route(&medium_task);
1451        assert_eq!(decision.tier, CostTier::Medium);
1452        assert_eq!(decision.model, policy.medium);
1453    }
1454
1455    #[test]
1456    fn test_auto_policy_to_router_e2e_single_model() {
1457        // Common case: user only has one API key / one model
1458        let policy = auto_policy(&["gpt-4o"]);
1459        assert_eq!(policy.small, "gpt-4o");
1460        assert_eq!(policy.medium, "gpt-4o");
1461        assert_eq!(policy.large, "gpt-4o");
1462
1463        let router = ModelRouter::with_policy(policy);
1464
1465        // ALL complexity levels must resolve to the single model
1466        let simple = TaskAnalysis {
1467            estimated_tokens: 50,
1468            ..Default::default()
1469        };
1470        let medium = TaskAnalysis {
1471            estimated_tokens: 2000,
1472            requires_structured_output: true,
1473            ..Default::default()
1474        };
1475        let hard = TaskAnalysis {
1476            estimated_tokens: 5000,
1477            requires_reasoning: true,
1478            requires_code_generation: true,
1479            ..Default::default()
1480        };
1481
1482        for (label, task) in [("simple", &simple), ("medium", &medium), ("hard", &hard)] {
1483            let decision = router.route(task);
1484            assert_eq!(
1485                decision.model, "gpt-4o",
1486                "{label} task should still route to the only model"
1487            );
1488        }
1489    }
1490
1491    #[test]
1492    fn test_selector_single_model_vision_mismatch() {
1493        // User has one text-only model but needs vision → None
1494        let selector = ModelSelector::new(&["gpt-3.5-turbo"]);
1495        let reqs = ModelRequirements::default().with_vision();
1496        assert!(
1497            selector.select(&reqs).is_none(),
1498            "single text-only model should not satisfy vision requirement"
1499        );
1500
1501        // User has one vision model and needs vision → works
1502        let selector = ModelSelector::new(&["gpt-4o"]);
1503        let result = selector.select(&reqs);
1504        assert!(
1505            result.is_some(),
1506            "single vision model should satisfy vision"
1507        );
1508        assert_eq!(result.unwrap().name, "gpt-4o");
1509    }
1510
1511    #[test]
1512    fn test_selector_single_model_with_cost_filter() {
1513        // User has one expensive model but cost filter is strict → None
1514        let selector = ModelSelector::new(&["gpt-4o"]);
1515        let reqs = ModelRequirements::default().with_max_cost(0.01);
1516        assert!(
1517            selector.select(&reqs).is_none(),
1518            "single expensive model should be filtered by strict cost limit"
1519        );
1520
1521        // Relax cost filter → works
1522        let reqs = ModelRequirements::default().with_max_cost(100.0);
1523        let result = selector.select(&reqs);
1524        assert!(result.is_some());
1525        assert_eq!(result.unwrap().name, "gpt-4o");
1526    }
1527
1528    #[test]
1529    fn test_selector_single_unknown_model_e2e() {
1530        // User has one custom/self-hosted model with no data in llm_models_spider
1531        let policy = auto_policy(&["my-local-llama"]);
1532        assert_eq!(policy.small, "my-local-llama");
1533        assert_eq!(policy.medium, "my-local-llama");
1534        assert_eq!(policy.large, "my-local-llama");
1535
1536        let router = ModelRouter::with_policy(policy);
1537        let decision = router.route_simple("do something complex and analyze the code");
1538        assert_eq!(
1539            decision.model, "my-local-llama",
1540            "unknown single model should still be routed to"
1541        );
1542
1543        // Also verify selector works
1544        let selector = ModelSelector::new(&["my-local-llama"]);
1545        let result = selector.select(&ModelRequirements::default());
1546        assert!(result.is_some());
1547        let scored = result.unwrap();
1548        assert_eq!(scored.name, "my-local-llama");
1549        assert_eq!(scored.score, 50.0, "unknown model gets default score");
1550        assert_eq!(
1551            scored.max_input_tokens, 0,
1552            "unknown model has no context data"
1553        );
1554        assert!(
1555            scored.arena_rank.is_none(),
1556            "unknown model has no arena data"
1557        );
1558    }
1559
1560    #[test]
1561    fn test_router_latency_constraint_downgrade() {
1562        let policy = ModelPolicy {
1563            max_latency_ms: Some(1000),
1564            ..Default::default()
1565        };
1566        let router = ModelRouter::with_policy(policy);
1567
1568        // A task that would normally be High tier
1569        let task = TaskAnalysis {
1570            estimated_tokens: 5000,
1571            requires_reasoning: true,
1572            requires_code_generation: true,
1573            ..Default::default()
1574        };
1575        let decision = router.route(&task);
1576        // Latency constraint should downgrade from High
1577        assert_ne!(
1578            decision.tier,
1579            CostTier::High,
1580            "latency constraint should prevent High tier"
1581        );
1582    }
1583
1584    #[test]
1585    fn test_router_allow_large_false() {
1586        let policy = ModelPolicy {
1587            allow_large: false,
1588            ..Default::default()
1589        };
1590        let router = ModelRouter::with_policy(policy);
1591
1592        let task = TaskAnalysis {
1593            estimated_tokens: 5000,
1594            requires_reasoning: true,
1595            requires_code_generation: true,
1596            ..Default::default()
1597        };
1598        let decision = router.route(&task);
1599        assert_ne!(
1600            decision.tier,
1601            CostTier::High,
1602            "allow_large=false should cap at Medium"
1603        );
1604    }
1605
1606    #[test]
1607    fn test_router_threshold_customization() {
1608        let router = ModelRouter::new().with_thresholds(100, 200);
1609
1610        // With lowered thresholds: tokens=300 → +3 (> large threshold 200),
1611        // reasoning → +2, code_gen → +2 = 7 → High tier
1612        let task = TaskAnalysis {
1613            estimated_tokens: 300,
1614            requires_reasoning: true,
1615            requires_code_generation: true,
1616            ..Default::default()
1617        };
1618        let decision = router.route(&task);
1619        assert_eq!(
1620            decision.tier,
1621            CostTier::High,
1622            "lowered thresholds should promote to High tier sooner"
1623        );
1624
1625        // Same task with default thresholds (1000/4000) would be Medium:
1626        // tokens=300 → +1 (< 1000), reasoning → +2, code_gen → +2 = 5 → Medium
1627        let default_router = ModelRouter::new();
1628        let decision = default_router.route(&task);
1629        assert_eq!(
1630            decision.tier,
1631            CostTier::Medium,
1632            "default thresholds should keep this at Medium"
1633        );
1634    }
1635
1636    // ── Phase 5: Edge cases ─────────────────────────────────────────────
1637
1638    #[test]
1639    fn test_selector_empty_pool() {
1640        let selector = ModelSelector::new(&[]);
1641        let reqs = ModelRequirements::default();
1642        let result = selector.select(&reqs);
1643        assert!(result.is_none(), "empty pool should return None");
1644    }
1645
1646    #[test]
1647    fn test_selector_duplicate_models() {
1648        let selector = ModelSelector::new(&["gpt-4o", "gpt-4o", "gpt-4o"]);
1649        let requirements = vec![
1650            ModelRequirements::default(),
1651            ModelRequirements::default(),
1652            ModelRequirements::default(),
1653        ];
1654        let results = selector.select_multi(&requirements);
1655        assert_eq!(results.len(), 3, "should not hang on duplicates");
1656        // All should resolve (first gets it, rest fallback)
1657        for (i, r) in results.iter().enumerate() {
1658            assert!(r.is_some(), "request {i} should get a model");
1659        }
1660    }
1661
1662    #[test]
1663    fn test_task_analysis_edge_cases() {
1664        // Empty string
1665        let analysis = TaskAnalysis::from_prompt("");
1666        assert_eq!(analysis.estimated_tokens, 1);
1667        assert!(!analysis.requires_reasoning);
1668
1669        // All keywords
1670        let analysis = TaskAnalysis::from_prompt(
1671            "analyze compare explain why code implement function script json extract list then step first next screenshot image",
1672        );
1673        assert!(analysis.requires_reasoning);
1674        assert!(analysis.requires_code_generation);
1675        assert!(analysis.requires_structured_output);
1676        assert!(analysis.multi_step);
1677        assert!(analysis.requires_vision);
1678
1679        // Unicode-only
1680        let analysis = TaskAnalysis::from_prompt("你好世界 🌍 日本語テスト");
1681        assert!(!analysis.requires_reasoning);
1682        assert!(!analysis.requires_code_generation);
1683        assert!(analysis.estimated_tokens > 0);
1684    }
1685
1686    #[test]
1687    fn test_auto_policy_large_pool() {
1688        let models: Vec<&str> = vec![
1689            "gpt-4o",
1690            "gpt-4o-mini",
1691            "gpt-3.5-turbo",
1692            "claude-3-5-sonnet-20241022",
1693            "claude-3-5-haiku-20241022",
1694            "gemini-2.0-flash",
1695            "deepseek-chat",
1696            "unknown-model-1",
1697            "unknown-model-2",
1698            "unknown-model-3",
1699            "unknown-model-4",
1700            "unknown-model-5",
1701            "unknown-model-6",
1702            "unknown-model-7",
1703            "unknown-model-8",
1704            "unknown-model-9",
1705            "unknown-model-10",
1706            "unknown-model-11",
1707            "unknown-model-12",
1708            "unknown-model-13",
1709        ];
1710        let policy = auto_policy(&models);
1711        assert!(!policy.small.is_empty());
1712        assert!(!policy.medium.is_empty());
1713        assert!(!policy.large.is_empty());
1714        assert!(policy.allow_large);
1715        assert_eq!(policy.max_cost_tier, CostTier::High);
1716    }
1717
1718    // ── classify_round_complexity tests ────────────────────────────────
1719
1720    #[test]
1721    fn test_classify_round_complexity_round_0() {
1722        let analysis = classify_round_complexity("click button", 1000, 0, false);
1723        assert!(
1724            analysis.requires_reasoning,
1725            "round 0 always requires reasoning"
1726        );
1727        assert!(analysis.requires_structured_output);
1728    }
1729
1730    #[test]
1731    fn test_classify_round_complexity_stagnated() {
1732        let analysis = classify_round_complexity("click button", 1000, 5, true);
1733        assert!(
1734            analysis.requires_reasoning,
1735            "stagnated rounds need reasoning"
1736        );
1737        assert!(analysis.multi_step, "stagnated rounds are multi-step");
1738    }
1739
1740    #[test]
1741    fn test_classify_round_complexity_large_html() {
1742        let analysis = classify_round_complexity("click button", 60_000, 3, false);
1743        assert!(analysis.requires_reasoning, "large HTML needs reasoning");
1744    }
1745
1746    #[test]
1747    fn test_classify_round_complexity_simple() {
1748        let analysis = classify_round_complexity("click button", 1000, 3, false);
1749        // Simple round: not round 0, not stagnated, small HTML, no reasoning keywords
1750        assert!(!analysis.requires_reasoning);
1751        assert!(!analysis.multi_step);
1752        assert!(analysis.requires_structured_output);
1753    }
1754
1755    #[test]
1756    fn test_policy_getter() {
1757        let policy = ModelPolicy {
1758            small: "small-model".to_string(),
1759            medium: "medium-model".to_string(),
1760            large: "large-model".to_string(),
1761            allow_large: true,
1762            max_latency_ms: None,
1763            max_cost_tier: CostTier::High,
1764        };
1765        let router = ModelRouter::with_policy(policy.clone());
1766        assert_eq!(router.policy().small, "small-model");
1767        assert_eq!(router.policy().large, "large-model");
1768    }
1769}