Skip to main content

terraphim_router/
strategy.rs

1//! Routing strategies for selecting the best provider.
2
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use terraphim_types::capability::{CostLevel, Latency, Provider};
7
8/// Trait for routing strategies
9pub trait RoutingStrategy: Send + Sync {
10    /// Select the best provider from candidates
11    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider>;
12
13    /// Get strategy name
14    fn name(&self) -> &'static str;
15}
16
17/// Strategy: Optimize for lowest cost
18#[derive(Debug, Clone, Default)]
19pub struct CostOptimized;
20
21impl RoutingStrategy for CostOptimized {
22    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
23        let result = candidates.into_iter().min_by_key(|p| p.cost_level);
24        tracing::debug!(
25            strategy = "cost_optimized",
26            selected_provider = result.map(|p| p.id.as_str()),
27            "Strategy selection complete"
28        );
29        result
30    }
31
32    fn name(&self) -> &'static str {
33        "cost_optimized"
34    }
35}
36
37/// Strategy: Optimize for lowest latency
38#[derive(Debug, Clone, Default)]
39pub struct LatencyOptimized;
40
41impl RoutingStrategy for LatencyOptimized {
42    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
43        let result = candidates.into_iter().min_by_key(|p| p.latency);
44        tracing::debug!(
45            strategy = "latency_optimized",
46            selected_provider = result.map(|p| p.id.as_str()),
47            "Strategy selection complete"
48        );
49        result
50    }
51
52    fn name(&self) -> &'static str {
53        "latency_optimized"
54    }
55}
56
57/// Strategy: Optimize for best capability match
58#[derive(Debug, Clone, Default)]
59pub struct CapabilityFirst;
60
61impl RoutingStrategy for CapabilityFirst {
62    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
63        let result = candidates.into_iter().max_by_key(|p| p.capabilities.len());
64        tracing::debug!(
65            strategy = "capability_first",
66            selected_provider = result.map(|p| p.id.as_str()),
67            "Strategy selection complete"
68        );
69        result
70    }
71
72    fn name(&self) -> &'static str {
73        "capability_first"
74    }
75}
76
77/// Strategy: Round-robin for load balancing
78#[derive(Debug)]
79pub struct RoundRobin {
80    index: std::sync::atomic::AtomicUsize,
81}
82
83impl RoundRobin {
84    pub fn new() -> Self {
85        Self {
86            index: std::sync::atomic::AtomicUsize::new(0),
87        }
88    }
89}
90
91impl Default for RoundRobin {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl RoutingStrategy for RoundRobin {
98    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
99        if candidates.is_empty() {
100            return None;
101        }
102
103        let index = self
104            .index
105            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
106        let selected = index % candidates.len();
107
108        let result = candidates.into_iter().nth(selected);
109        tracing::debug!(
110            strategy = "round_robin",
111            selected_provider = result.map(|p| p.id.as_str()),
112            index = index,
113            "Strategy selection complete"
114        );
115        result
116    }
117
118    fn name(&self) -> &'static str {
119        "round_robin"
120    }
121}
122
123/// Strategy: A/B testing -- probabilistically route between two strategies.
124///
125/// Uses a weight (0.0-1.0) to determine how often strategy A vs B is chosen.
126/// A weight of 0.7 means 70% of requests go through strategy A.
127pub struct WeightedStrategy {
128    strategy_a: Box<dyn RoutingStrategy>,
129    strategy_b: Box<dyn RoutingStrategy>,
130    /// Weight for strategy A (0.0 = always B, 1.0 = always A)
131    weight_a: f64,
132    /// Counter for deterministic round-based selection
133    counter: AtomicU64,
134}
135
136impl std::fmt::Debug for WeightedStrategy {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        f.debug_struct("WeightedStrategy")
139            .field("strategy_a", &self.strategy_a.name())
140            .field("strategy_b", &self.strategy_b.name())
141            .field("weight_a", &self.weight_a)
142            .finish()
143    }
144}
145
146impl WeightedStrategy {
147    /// Create a new weighted strategy for A/B testing.
148    ///
149    /// `weight_a` is the fraction of requests routed to strategy A (0.0 to 1.0).
150    pub fn new(
151        strategy_a: Box<dyn RoutingStrategy>,
152        strategy_b: Box<dyn RoutingStrategy>,
153        weight_a: f64,
154    ) -> Self {
155        Self {
156            strategy_a,
157            strategy_b,
158            weight_a: weight_a.clamp(0.0, 1.0),
159            counter: AtomicU64::new(0),
160        }
161    }
162}
163
164impl RoutingStrategy for WeightedStrategy {
165    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
166        let count = self.counter.fetch_add(1, Ordering::Relaxed);
167        // Deterministic: use counter mod 100 vs weight percentage
168        let threshold = (self.weight_a * 100.0) as u64;
169        let use_a = (count % 100) < threshold;
170
171        let (chosen, name) = if use_a {
172            (&self.strategy_a, self.strategy_a.name())
173        } else {
174            (&self.strategy_b, self.strategy_b.name())
175        };
176
177        tracing::debug!(
178            strategy = "weighted",
179            chosen_branch = name,
180            branch = if use_a { "A" } else { "B" },
181            counter = count,
182            weight_a = self.weight_a,
183            "A/B strategy selection"
184        );
185
186        chosen.select_provider(candidates)
187    }
188
189    fn name(&self) -> &'static str {
190        "weighted"
191    }
192}
193
194/// Strategy: Filter candidates by user preferences, then delegate to a base strategy.
195///
196/// Removes providers exceeding `max_cost` or `max_latency` before passing
197/// to the inner strategy. If filtering eliminates all candidates, falls through
198/// to the base strategy with unfiltered candidates.
199pub struct PreferenceFilter {
200    base: Box<dyn RoutingStrategy>,
201    max_cost: Option<CostLevel>,
202    max_latency: Option<Latency>,
203}
204
205impl PreferenceFilter {
206    /// Create a preference filter wrapping a base strategy.
207    pub fn new(base: Box<dyn RoutingStrategy>) -> Self {
208        Self {
209            base,
210            max_cost: None,
211            max_latency: None,
212        }
213    }
214
215    /// Set maximum acceptable cost level.
216    pub fn with_max_cost(mut self, cost: CostLevel) -> Self {
217        self.max_cost = Some(cost);
218        self
219    }
220
221    /// Set maximum acceptable latency.
222    pub fn with_max_latency(mut self, latency: Latency) -> Self {
223        self.max_latency = Some(latency);
224        self
225    }
226}
227
228impl std::fmt::Debug for PreferenceFilter {
229    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230        f.debug_struct("PreferenceFilter")
231            .field("base", &self.base.name())
232            .field("max_cost", &self.max_cost)
233            .field("max_latency", &self.max_latency)
234            .finish()
235    }
236}
237
238impl RoutingStrategy for PreferenceFilter {
239    fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
240        let filtered: Vec<&'a Provider> = candidates
241            .iter()
242            .copied()
243            .filter(|p| {
244                if let Some(max_cost) = self.max_cost {
245                    if p.cost_level > max_cost {
246                        return false;
247                    }
248                }
249                if let Some(max_latency) = self.max_latency {
250                    if p.latency > max_latency {
251                        return false;
252                    }
253                }
254                true
255            })
256            .collect();
257
258        let filtered_count = filtered.len();
259        tracing::debug!(
260            strategy = "preference_filter",
261            base = self.base.name(),
262            max_cost = ?self.max_cost,
263            max_latency = ?self.max_latency,
264            original_count = candidates.len(),
265            filtered_count = filtered_count,
266            "Applied preference filters"
267        );
268
269        if filtered.is_empty() {
270            // Fall through to base with unfiltered candidates
271            tracing::debug!("Preference filter eliminated all candidates, using unfiltered");
272            self.base.select_provider(candidates)
273        } else {
274            self.base.select_provider(filtered)
275        }
276    }
277
278    fn name(&self) -> &'static str {
279        "preference_filter"
280    }
281}
282
283/// Registry of named strategies for runtime lookup.
284///
285/// Strategies are stored as factory functions so each lookup produces a fresh
286/// instance (important for stateful strategies like `RoundRobin`).
287pub struct StrategyRegistry {
288    factories: HashMap<String, Box<dyn Fn() -> Box<dyn RoutingStrategy> + Send + Sync>>,
289}
290
291impl StrategyRegistry {
292    /// Create a registry pre-populated with the four built-in strategies.
293    pub fn new() -> Self {
294        let mut reg = Self {
295            factories: HashMap::new(),
296        };
297        reg.register("cost_optimized", || Box::new(CostOptimized));
298        reg.register("latency_optimized", || Box::new(LatencyOptimized));
299        reg.register("capability_first", || Box::new(CapabilityFirst));
300        reg.register("round_robin", || Box::new(RoundRobin::new()));
301        reg
302    }
303
304    /// Register a named strategy factory.
305    pub fn register<F>(&mut self, name: &str, factory: F)
306    where
307        F: Fn() -> Box<dyn RoutingStrategy> + Send + Sync + 'static,
308    {
309        self.factories.insert(name.to_string(), Box::new(factory));
310    }
311
312    /// Look up and instantiate a strategy by name.
313    pub fn get(&self, name: &str) -> Option<Box<dyn RoutingStrategy>> {
314        self.factories.get(name).map(|f| f())
315    }
316
317    /// List registered strategy names.
318    pub fn names(&self) -> Vec<&str> {
319        self.factories.keys().map(|k| k.as_str()).collect()
320    }
321}
322
323impl Default for StrategyRegistry {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329impl std::fmt::Debug for StrategyRegistry {
330    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331        f.debug_struct("StrategyRegistry")
332            .field("strategies", &self.factories.keys().collect::<Vec<_>>())
333            .finish()
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use terraphim_types::capability::{Capability, ProviderType};

    /// Build an LLM provider with a single `CodeGeneration` capability. `id` doubles
    /// as the provider name and the model id.
    fn create_test_provider(id: &str, cost: CostLevel, latency: Latency) -> Provider {
        let owned_id = id.to_string();
        Provider {
            id: owned_id.clone(),
            name: owned_id.clone(),
            provider_type: ProviderType::Llm {
                model_id: owned_id,
                api_endpoint: "https://example.com".to_string(),
            },
            capabilities: vec![Capability::CodeGeneration],
            cost_level: cost,
            latency,
            keywords: vec![],
        }
    }

    /// Offer every provider in `providers` to `strategy` and return the chosen id.
    fn pick_id<'p>(strategy: &dyn RoutingStrategy, providers: &'p [Provider]) -> Option<&'p str> {
        strategy
            .select_provider(providers.iter().collect())
            .map(|p| p.id.as_str())
    }

    /// Two providers that trade cost against latency, used by the A/B tests.
    fn ab_providers() -> Vec<Provider> {
        vec![
            create_test_provider("cheap-slow", CostLevel::Cheap, Latency::Slow),
            create_test_provider("expensive-fast", CostLevel::Expensive, Latency::Fast),
        ]
    }

    /// One provider per cost/latency tier, used by the preference-filter tests.
    fn three_tier_providers() -> Vec<Provider> {
        vec![
            create_test_provider("cheap-slow", CostLevel::Cheap, Latency::Slow),
            create_test_provider("expensive-fast", CostLevel::Expensive, Latency::Fast),
            create_test_provider("moderate-medium", CostLevel::Moderate, Latency::Medium),
        ]
    }

    #[test]
    fn test_cost_optimized() {
        let providers = [
            create_test_provider("expensive", CostLevel::Expensive, Latency::Medium),
            create_test_provider("cheap", CostLevel::Cheap, Latency::Medium),
            create_test_provider("moderate", CostLevel::Moderate, Latency::Medium),
        ];
        assert_eq!(pick_id(&CostOptimized, &providers), Some("cheap"));
    }

    #[test]
    fn test_latency_optimized() {
        let providers = [
            create_test_provider("slow", CostLevel::Moderate, Latency::Slow),
            create_test_provider("fast", CostLevel::Moderate, Latency::Fast),
            create_test_provider("medium", CostLevel::Moderate, Latency::Medium),
        ];
        assert_eq!(pick_id(&LatencyOptimized, &providers), Some("fast"));
    }

    #[test]
    fn test_round_robin() {
        let strategy = RoundRobin::new();
        let providers = [
            create_test_provider("a", CostLevel::Cheap, Latency::Fast),
            create_test_provider("b", CostLevel::Cheap, Latency::Fast),
            create_test_provider("c", CostLevel::Cheap, Latency::Fast),
        ];

        // Four calls over three providers: a, b, c, then wrap back around to a.
        for expected in ["a", "b", "c", "a"] {
            assert_eq!(pick_id(&strategy, &providers), Some(expected));
        }
    }

    #[test]
    fn test_weighted_strategy_all_a() {
        // Weight 1.0: always route through A (CostOptimized).
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 1.0);
        let providers = ab_providers();

        for _ in 0..10 {
            assert_eq!(pick_id(&strategy, &providers), Some("cheap-slow"));
        }
    }

    #[test]
    fn test_weighted_strategy_all_b() {
        // Weight 0.0: always route through B (LatencyOptimized).
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 0.0);
        let providers = ab_providers();

        for _ in 0..10 {
            assert_eq!(pick_id(&strategy, &providers), Some("expensive-fast"));
        }
    }

    #[test]
    fn test_weighted_strategy_split() {
        // Weight 0.5: a 50/50 split.
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 0.5);
        let providers = ab_providers();

        let total = 100;
        let a_count = (0..total)
            .filter(|_| pick_id(&strategy, &providers) == Some("cheap-slow"))
            .count();
        let b_count = total - a_count;

        // The split is counter-based, so it is exact rather than approximate.
        assert_eq!(a_count, 50);
        assert_eq!(b_count, 50);
    }

    #[test]
    fn test_preference_filter_cost() {
        let strategy =
            PreferenceFilter::new(Box::new(LatencyOptimized)).with_max_cost(CostLevel::Moderate);
        let providers = three_tier_providers();

        // Expensive is filtered out, then LatencyOptimized picks "moderate-medium" (Medium < Slow)
        assert_eq!(pick_id(&strategy, &providers), Some("moderate-medium"));
    }

    #[test]
    fn test_preference_filter_latency() {
        let strategy =
            PreferenceFilter::new(Box::new(CostOptimized)).with_max_latency(Latency::Medium);
        let providers = three_tier_providers();

        // Slow is filtered out, then CostOptimized picks "moderate-medium" (Moderate < Expensive)
        assert_eq!(pick_id(&strategy, &providers), Some("moderate-medium"));
    }

    #[test]
    fn test_preference_filter_fallthrough() {
        // Filter requires max_latency=Fast, but both providers are Slow.
        let strategy =
            PreferenceFilter::new(Box::new(CostOptimized)).with_max_latency(Latency::Fast);
        let providers = [
            create_test_provider("cheap-slow", CostLevel::Cheap, Latency::Slow),
            create_test_provider("expensive-slow", CostLevel::Expensive, Latency::Slow),
        ];

        // All filtered out -> falls through to unfiltered CostOptimized -> "cheap-slow"
        assert_eq!(pick_id(&strategy, &providers), Some("cheap-slow"));
    }

    #[test]
    fn test_strategy_registry_builtins() {
        let registry = StrategyRegistry::new();

        for name in ["cost_optimized", "latency_optimized", "capability_first", "round_robin"] {
            assert!(registry.get(name).is_some(), "missing built-in strategy: {}", name);
        }
        assert!(registry.get("nonexistent").is_none());
        assert_eq!(registry.names().len(), 4);
    }

    #[test]
    fn test_strategy_registry_custom() {
        let mut registry = StrategyRegistry::new();
        registry.register("my_strategy", || Box::new(CostOptimized));

        let strategy = registry.get("my_strategy").unwrap();
        assert_eq!(strategy.name(), "cost_optimized");
        assert_eq!(registry.names().len(), 5);
    }

    #[test]
    fn test_strategy_registry_returns_fresh_instances() {
        let registry = StrategyRegistry::new();

        // RoundRobin is stateful -- each lookup must yield an independent instance.
        let rr1 = registry.get("round_robin").unwrap();
        let rr2 = registry.get("round_robin").unwrap();
        let providers = [
            create_test_provider("a", CostLevel::Cheap, Latency::Fast),
            create_test_provider("b", CostLevel::Cheap, Latency::Fast),
        ];

        // Both should start at index 0 (independent state).
        let first = pick_id(&*rr1, &providers);
        let second = pick_id(&*rr2, &providers);
        assert_eq!(first, Some("a"));
        assert_eq!(second, Some("a"));
    }
}
589
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;
    use terraphim_types::capability::{Capability, ProviderType};

    /// Any of the three cost levels.
    fn arb_cost_level() -> impl Strategy<Value = CostLevel> {
        prop_oneof![
            Just(CostLevel::Cheap),
            Just(CostLevel::Moderate),
            Just(CostLevel::Expensive),
        ]
    }

    /// Any of the three latency tiers.
    fn arb_latency() -> impl Strategy<Value = Latency> {
        prop_oneof![
            Just(Latency::Fast),
            Just(Latency::Medium),
            Just(Latency::Slow),
        ]
    }

    /// A provider with the given `id`, every capability, and a random cost/latency.
    fn arb_provider(id: &str) -> impl Strategy<Value = Provider> {
        let provider_id = id.to_string();
        (arb_cost_level(), arb_latency()).prop_map(move |(cost, latency)| {
            let display_name = format!("Provider {}", provider_id);
            let kind = ProviderType::Llm {
                model_id: "test".to_string(),
                api_endpoint: "https://api.test.com".to_string(),
            };
            Provider::new(provider_id.clone(), display_name, kind, Capability::all())
                .with_cost(cost)
                .with_latency(latency)
        })
    }

    proptest! {
        #[test]
        fn cost_optimized_always_picks_cheapest(
            p1 in arb_provider("p1"),
            p2 in arb_provider("p2"),
            p3 in arb_provider("p3"),
        ) {
            let providers = vec![p1, p2, p3];
            let picked = CostOptimized.select_provider(providers.iter().collect());

            if let Some(chosen) = picked {
                // The chosen cost must be a lower bound on every candidate's cost.
                let is_lower_bound = providers.iter().all(|p| chosen.cost_level <= p.cost_level);
                prop_assert!(is_lower_bound);
            }
        }

        #[test]
        fn latency_optimized_always_picks_fastest(
            p1 in arb_provider("p1"),
            p2 in arb_provider("p2"),
            p3 in arb_provider("p3"),
        ) {
            let providers = vec![p1, p2, p3];
            let picked = LatencyOptimized.select_provider(providers.iter().collect());

            if let Some(chosen) = picked {
                // The chosen latency must be a lower bound on every candidate's latency.
                let is_lower_bound = providers.iter().all(|p| chosen.latency <= p.latency);
                prop_assert!(is_lower_bound);
            }
        }

        #[test]
        fn round_robin_cycles_through_all(
            p1 in arb_provider("rr1"),
            p2 in arb_provider("rr2"),
        ) {
            let strategy = RoundRobin::new();
            let providers = vec![p1, p2];

            // One call per provider should visit every provider exactly once.
            let seen: std::collections::HashSet<String> = (0..providers.len())
                .filter_map(|_| strategy.select_provider(providers.iter().collect()))
                .map(|p| p.id.clone())
                .collect();
            prop_assert_eq!(seen.len(), providers.len());
        }
    }
}