1use std::collections::HashMap;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use terraphim_types::capability::{CostLevel, Latency, Provider};
7
8pub trait RoutingStrategy: Send + Sync {
10 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider>;
12
13 fn name(&self) -> &'static str;
15}
16
17#[derive(Debug, Clone, Default)]
19pub struct CostOptimized;
20
21impl RoutingStrategy for CostOptimized {
22 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
23 let result = candidates.into_iter().min_by_key(|p| p.cost_level);
24 tracing::debug!(
25 strategy = "cost_optimized",
26 selected_provider = result.map(|p| p.id.as_str()),
27 "Strategy selection complete"
28 );
29 result
30 }
31
32 fn name(&self) -> &'static str {
33 "cost_optimized"
34 }
35}
36
37#[derive(Debug, Clone, Default)]
39pub struct LatencyOptimized;
40
41impl RoutingStrategy for LatencyOptimized {
42 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
43 let result = candidates.into_iter().min_by_key(|p| p.latency);
44 tracing::debug!(
45 strategy = "latency_optimized",
46 selected_provider = result.map(|p| p.id.as_str()),
47 "Strategy selection complete"
48 );
49 result
50 }
51
52 fn name(&self) -> &'static str {
53 "latency_optimized"
54 }
55}
56
57#[derive(Debug, Clone, Default)]
59pub struct CapabilityFirst;
60
61impl RoutingStrategy for CapabilityFirst {
62 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
63 let result = candidates.into_iter().max_by_key(|p| p.capabilities.len());
64 tracing::debug!(
65 strategy = "capability_first",
66 selected_provider = result.map(|p| p.id.as_str()),
67 "Strategy selection complete"
68 );
69 result
70 }
71
72 fn name(&self) -> &'static str {
73 "capability_first"
74 }
75}
76
77#[derive(Debug)]
79pub struct RoundRobin {
80 index: std::sync::atomic::AtomicUsize,
81}
82
83impl RoundRobin {
84 pub fn new() -> Self {
85 Self {
86 index: std::sync::atomic::AtomicUsize::new(0),
87 }
88 }
89}
90
91impl Default for RoundRobin {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl RoutingStrategy for RoundRobin {
98 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
99 if candidates.is_empty() {
100 return None;
101 }
102
103 let index = self
104 .index
105 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
106 let selected = index % candidates.len();
107
108 let result = candidates.into_iter().nth(selected);
109 tracing::debug!(
110 strategy = "round_robin",
111 selected_provider = result.map(|p| p.id.as_str()),
112 index = index,
113 "Strategy selection complete"
114 );
115 result
116 }
117
118 fn name(&self) -> &'static str {
119 "round_robin"
120 }
121}
122
123pub struct WeightedStrategy {
128 strategy_a: Box<dyn RoutingStrategy>,
129 strategy_b: Box<dyn RoutingStrategy>,
130 weight_a: f64,
132 counter: AtomicU64,
134}
135
136impl std::fmt::Debug for WeightedStrategy {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("WeightedStrategy")
139 .field("strategy_a", &self.strategy_a.name())
140 .field("strategy_b", &self.strategy_b.name())
141 .field("weight_a", &self.weight_a)
142 .finish()
143 }
144}
145
146impl WeightedStrategy {
147 pub fn new(
151 strategy_a: Box<dyn RoutingStrategy>,
152 strategy_b: Box<dyn RoutingStrategy>,
153 weight_a: f64,
154 ) -> Self {
155 Self {
156 strategy_a,
157 strategy_b,
158 weight_a: weight_a.clamp(0.0, 1.0),
159 counter: AtomicU64::new(0),
160 }
161 }
162}
163
164impl RoutingStrategy for WeightedStrategy {
165 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
166 let count = self.counter.fetch_add(1, Ordering::Relaxed);
167 let threshold = (self.weight_a * 100.0) as u64;
169 let use_a = (count % 100) < threshold;
170
171 let (chosen, name) = if use_a {
172 (&self.strategy_a, self.strategy_a.name())
173 } else {
174 (&self.strategy_b, self.strategy_b.name())
175 };
176
177 tracing::debug!(
178 strategy = "weighted",
179 chosen_branch = name,
180 branch = if use_a { "A" } else { "B" },
181 counter = count,
182 weight_a = self.weight_a,
183 "A/B strategy selection"
184 );
185
186 chosen.select_provider(candidates)
187 }
188
189 fn name(&self) -> &'static str {
190 "weighted"
191 }
192}
193
194pub struct PreferenceFilter {
200 base: Box<dyn RoutingStrategy>,
201 max_cost: Option<CostLevel>,
202 max_latency: Option<Latency>,
203}
204
205impl PreferenceFilter {
206 pub fn new(base: Box<dyn RoutingStrategy>) -> Self {
208 Self {
209 base,
210 max_cost: None,
211 max_latency: None,
212 }
213 }
214
215 pub fn with_max_cost(mut self, cost: CostLevel) -> Self {
217 self.max_cost = Some(cost);
218 self
219 }
220
221 pub fn with_max_latency(mut self, latency: Latency) -> Self {
223 self.max_latency = Some(latency);
224 self
225 }
226}
227
228impl std::fmt::Debug for PreferenceFilter {
229 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
230 f.debug_struct("PreferenceFilter")
231 .field("base", &self.base.name())
232 .field("max_cost", &self.max_cost)
233 .field("max_latency", &self.max_latency)
234 .finish()
235 }
236}
237
238impl RoutingStrategy for PreferenceFilter {
239 fn select_provider<'a>(&self, candidates: Vec<&'a Provider>) -> Option<&'a Provider> {
240 let filtered: Vec<&'a Provider> = candidates
241 .iter()
242 .copied()
243 .filter(|p| {
244 if let Some(max_cost) = self.max_cost {
245 if p.cost_level > max_cost {
246 return false;
247 }
248 }
249 if let Some(max_latency) = self.max_latency {
250 if p.latency > max_latency {
251 return false;
252 }
253 }
254 true
255 })
256 .collect();
257
258 let filtered_count = filtered.len();
259 tracing::debug!(
260 strategy = "preference_filter",
261 base = self.base.name(),
262 max_cost = ?self.max_cost,
263 max_latency = ?self.max_latency,
264 original_count = candidates.len(),
265 filtered_count = filtered_count,
266 "Applied preference filters"
267 );
268
269 if filtered.is_empty() {
270 tracing::debug!("Preference filter eliminated all candidates, using unfiltered");
272 self.base.select_provider(candidates)
273 } else {
274 self.base.select_provider(filtered)
275 }
276 }
277
278 fn name(&self) -> &'static str {
279 "preference_filter"
280 }
281}
282
283pub struct StrategyRegistry {
288 factories: HashMap<String, Box<dyn Fn() -> Box<dyn RoutingStrategy> + Send + Sync>>,
289}
290
291impl StrategyRegistry {
292 pub fn new() -> Self {
294 let mut reg = Self {
295 factories: HashMap::new(),
296 };
297 reg.register("cost_optimized", || Box::new(CostOptimized));
298 reg.register("latency_optimized", || Box::new(LatencyOptimized));
299 reg.register("capability_first", || Box::new(CapabilityFirst));
300 reg.register("round_robin", || Box::new(RoundRobin::new()));
301 reg
302 }
303
304 pub fn register<F>(&mut self, name: &str, factory: F)
306 where
307 F: Fn() -> Box<dyn RoutingStrategy> + Send + Sync + 'static,
308 {
309 self.factories.insert(name.to_string(), Box::new(factory));
310 }
311
312 pub fn get(&self, name: &str) -> Option<Box<dyn RoutingStrategy>> {
314 self.factories.get(name).map(|f| f())
315 }
316
317 pub fn names(&self) -> Vec<&str> {
319 self.factories.keys().map(|k| k.as_str()).collect()
320 }
321}
322
323impl Default for StrategyRegistry {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329impl std::fmt::Debug for StrategyRegistry {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 f.debug_struct("StrategyRegistry")
332 .field("strategies", &self.factories.keys().collect::<Vec<_>>())
333 .finish()
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use terraphim_types::capability::{Capability, ProviderType};

    /// Builds an LLM provider fixture whose id, name and model id are all `id`.
    fn create_test_provider(id: &str, cost: CostLevel, latency: Latency) -> Provider {
        let id = id.to_string();
        Provider {
            name: id.clone(),
            provider_type: ProviderType::Llm {
                model_id: id.clone(),
                api_endpoint: "https://example.com".to_string(),
            },
            id,
            capabilities: vec![Capability::CodeGeneration],
            cost_level: cost,
            latency,
            keywords: vec![],
        }
    }

    /// Offers every entry of `providers` to `strategy` and returns the id of
    /// the provider it selected.
    fn selected_id<'a>(
        strategy: &dyn RoutingStrategy,
        providers: &'a [Provider],
    ) -> Option<&'a str> {
        strategy
            .select_provider(providers.iter().collect())
            .map(|provider| provider.id.as_str())
    }

    /// Two providers for which cost and latency optimisation disagree.
    fn cost_vs_latency_pair() -> Vec<Provider> {
        vec![
            create_test_provider("cheap-slow", CostLevel::Cheap, Latency::Slow),
            create_test_provider("expensive-fast", CostLevel::Expensive, Latency::Fast),
        ]
    }

    /// The pair above plus a middle-of-the-road third provider.
    fn cost_vs_latency_trio() -> Vec<Provider> {
        let mut providers = cost_vs_latency_pair();
        providers.push(create_test_provider(
            "moderate-medium",
            CostLevel::Moderate,
            Latency::Medium,
        ));
        providers
    }

    #[test]
    fn test_cost_optimized() {
        let providers = [
            create_test_provider("expensive", CostLevel::Expensive, Latency::Medium),
            create_test_provider("cheap", CostLevel::Cheap, Latency::Medium),
            create_test_provider("moderate", CostLevel::Moderate, Latency::Medium),
        ];

        assert_eq!(selected_id(&CostOptimized, &providers), Some("cheap"));
    }

    #[test]
    fn test_latency_optimized() {
        let providers = [
            create_test_provider("slow", CostLevel::Moderate, Latency::Slow),
            create_test_provider("fast", CostLevel::Moderate, Latency::Fast),
            create_test_provider("medium", CostLevel::Moderate, Latency::Medium),
        ];

        assert_eq!(selected_id(&LatencyOptimized, &providers), Some("fast"));
    }

    #[test]
    fn test_round_robin() {
        let strategy = RoundRobin::new();
        let providers = [
            create_test_provider("a", CostLevel::Cheap, Latency::Fast),
            create_test_provider("b", CostLevel::Cheap, Latency::Fast),
            create_test_provider("c", CostLevel::Cheap, Latency::Fast),
        ];

        // Three picks walk the list in order; the fourth wraps to the start.
        for expected in ["a", "b", "c", "a"] {
            assert_eq!(selected_id(&strategy, &providers), Some(expected));
        }
    }

    #[test]
    fn test_weighted_strategy_all_a() {
        // A weight of 1.0 routes every call to strategy A (cost optimised).
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 1.0);
        let providers = cost_vs_latency_pair();

        for _ in 0..10 {
            assert_eq!(selected_id(&strategy, &providers), Some("cheap-slow"));
        }
    }

    #[test]
    fn test_weighted_strategy_all_b() {
        // A weight of 0.0 routes every call to strategy B (latency optimised).
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 0.0);
        let providers = cost_vs_latency_pair();

        for _ in 0..10 {
            assert_eq!(selected_id(&strategy, &providers), Some("expensive-fast"));
        }
    }

    #[test]
    fn test_weighted_strategy_split() {
        let strategy =
            WeightedStrategy::new(Box::new(CostOptimized), Box::new(LatencyOptimized), 0.5);
        let providers = cost_vs_latency_pair();

        // (picks made by strategy A, picks made by strategy B)
        let mut tally = (0usize, 0usize);
        for _ in 0..100 {
            match selected_id(&strategy, &providers).unwrap() {
                "cheap-slow" => tally.0 += 1,
                _ => tally.1 += 1,
            }
        }

        assert_eq!(tally.0, 50);
        assert_eq!(tally.1, 50);
    }

    #[test]
    fn test_preference_filter_cost() {
        let strategy =
            PreferenceFilter::new(Box::new(LatencyOptimized)).with_max_cost(CostLevel::Moderate);
        let providers = cost_vs_latency_trio();

        // "expensive-fast" is filtered out, so the fastest survivor wins.
        assert_eq!(selected_id(&strategy, &providers), Some("moderate-medium"));
    }

    #[test]
    fn test_preference_filter_latency() {
        let strategy =
            PreferenceFilter::new(Box::new(CostOptimized)).with_max_latency(Latency::Medium);
        let providers = cost_vs_latency_trio();

        // "cheap-slow" is filtered out, so the cheapest survivor wins.
        assert_eq!(selected_id(&strategy, &providers), Some("moderate-medium"));
    }

    #[test]
    fn test_preference_filter_fallthrough() {
        let strategy =
            PreferenceFilter::new(Box::new(CostOptimized)).with_max_latency(Latency::Fast);
        let providers = [
            create_test_provider("cheap-slow", CostLevel::Cheap, Latency::Slow),
            create_test_provider("expensive-slow", CostLevel::Expensive, Latency::Slow),
        ];

        // Nothing meets the latency limit, so the unfiltered list is used.
        assert_eq!(selected_id(&strategy, &providers), Some("cheap-slow"));
    }

    #[test]
    fn test_strategy_registry_builtins() {
        let registry = StrategyRegistry::new();

        for builtin in [
            "cost_optimized",
            "latency_optimized",
            "capability_first",
            "round_robin",
        ] {
            assert!(registry.get(builtin).is_some());
        }
        assert!(registry.get("nonexistent").is_none());

        assert_eq!(registry.names().len(), 4);
    }

    #[test]
    fn test_strategy_registry_custom() {
        let mut registry = StrategyRegistry::new();
        registry.register("my_strategy", || Box::new(CostOptimized));

        let strategy = registry.get("my_strategy").unwrap();
        assert_eq!(strategy.name(), "cost_optimized");

        assert_eq!(registry.names().len(), 5);
    }

    #[test]
    fn test_strategy_registry_returns_fresh_instances() {
        let registry = StrategyRegistry::new();
        let first = registry.get("round_robin").unwrap();
        let second = registry.get("round_robin").unwrap();

        let providers = [
            create_test_provider("a", CostLevel::Cheap, Latency::Fast),
            create_test_provider("b", CostLevel::Cheap, Latency::Fast),
        ];

        // Independent instances each start their rotation at the first entry.
        let from_first = selected_id(&*first, &providers);
        let from_second = selected_id(&*second, &providers);
        assert_eq!(from_first, Some("a"));
        assert_eq!(from_second, Some("a"));
    }
}
589
#[cfg(test)]
mod proptest_tests {
    use super::*;
    use proptest::prelude::*;
    use terraphim_types::capability::{Capability, ProviderType};

    /// Generates one of the three cost levels.
    fn arb_cost_level() -> impl Strategy<Value = CostLevel> {
        prop_oneof![
            Just(CostLevel::Cheap),
            Just(CostLevel::Moderate),
            Just(CostLevel::Expensive),
        ]
    }

    /// Generates one of the three latency classes.
    fn arb_latency() -> impl Strategy<Value = Latency> {
        prop_oneof![
            Just(Latency::Fast),
            Just(Latency::Medium),
            Just(Latency::Slow),
        ]
    }

    /// Generates a provider with the fixed `id` and an arbitrary cost and latency.
    fn arb_provider(id: &str) -> impl Strategy<Value = Provider> {
        let id = id.to_string();
        (arb_cost_level(), arb_latency()).prop_map(move |(cost, latency)| {
            let provider_type = ProviderType::Llm {
                model_id: "test".to_string(),
                api_endpoint: "https://api.test.com".to_string(),
            };
            Provider::new(
                id.clone(),
                format!("Provider {}", id),
                provider_type,
                Capability::all(),
            )
            .with_cost(cost)
            .with_latency(latency)
        })
    }

    proptest! {
        /// The provider picked by `CostOptimized` never costs more than any candidate.
        #[test]
        fn cost_optimized_always_picks_cheapest(
            p1 in arb_provider("p1"),
            p2 in arb_provider("p2"),
            p3 in arb_provider("p3"),
        ) {
            let providers = vec![p1, p2, p3];
            let picked = CostOptimized.select_provider(providers.iter().collect());

            if let Some(selected) = picked {
                for other in &providers {
                    prop_assert!(selected.cost_level <= other.cost_level);
                }
            }
        }

        /// The provider picked by `LatencyOptimized` is never slower than any candidate.
        #[test]
        fn latency_optimized_always_picks_fastest(
            p1 in arb_provider("p1"),
            p2 in arb_provider("p2"),
            p3 in arb_provider("p3"),
        ) {
            let providers = vec![p1, p2, p3];
            let picked = LatencyOptimized.select_provider(providers.iter().collect());

            if let Some(selected) = picked {
                for other in &providers {
                    prop_assert!(selected.latency <= other.latency);
                }
            }
        }

        /// One full rotation of `RoundRobin` visits every candidate exactly once.
        #[test]
        fn round_robin_cycles_through_all(
            p1 in arb_provider("rr1"),
            p2 in arb_provider("rr2"),
        ) {
            let strategy = RoundRobin::new();
            let providers = vec![p1, p2];

            let seen: std::collections::HashSet<String> = (0..providers.len())
                .filter_map(|_| strategy.select_provider(providers.iter().collect()))
                .map(|selected| selected.id.clone())
                .collect();

            prop_assert_eq!(seen.len(), providers.len());
        }
    }
}