1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tensorlogic_ir::EinsumGraph;
14use thiserror::Error;
15
16#[derive(Error, Debug, Clone, PartialEq)]
18pub enum MultiModelError {
19 #[error("Model not found: {0}")]
20 ModelNotFound(String),
21
22 #[error("Incompatible model outputs")]
23 IncompatibleOutputs,
24
25 #[error("Invalid ensemble configuration: {0}")]
26 InvalidEnsemble(String),
27
28 #[error("Model routing failed: {0}")]
29 RoutingFailed(String),
30
31 #[error("Resource limit exceeded: {0}")]
32 ResourceLimitExceeded(String),
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum EnsembleStrategy {
38 Average,
40 WeightedAverage,
42 MajorityVote,
44 SoftVote,
46 Stacking,
48 Boosting,
50 MaxConfidence,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ModelMetadata {
57 pub id: String,
59 pub name: String,
61 pub version: String,
63 pub input_shapes: Vec<Vec<usize>>,
65 pub output_shapes: Vec<Vec<usize>>,
67 pub weight: f64,
69 pub priority: u32,
71 pub resource_requirements: ResourceRequirements,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ResourceRequirements {
78 pub memory_bytes: usize,
80 pub gpu_memory_bytes: Option<usize>,
82 pub estimated_flops: f64,
84 pub estimated_latency_ms: f64,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct EnsembleConfig {
91 pub strategy: EnsembleStrategy,
93 pub model_weights: HashMap<String, f64>,
95 pub min_models: usize,
97 pub parallel_execution: bool,
99 pub model_timeout_ms: Option<u64>,
101}
102
103impl Default for EnsembleConfig {
104 fn default() -> Self {
105 Self {
106 strategy: EnsembleStrategy::Average,
107 model_weights: HashMap::new(),
108 min_models: 1,
109 parallel_execution: true,
110 model_timeout_ms: None,
111 }
112 }
113}
114
115impl EnsembleConfig {
116 pub fn voting() -> Self {
118 Self {
119 strategy: EnsembleStrategy::MajorityVote,
120 min_models: 3,
121 ..Default::default()
122 }
123 }
124
125 pub fn weighted_average(weights: HashMap<String, f64>) -> Self {
127 Self {
128 strategy: EnsembleStrategy::WeightedAverage,
129 model_weights: weights,
130 ..Default::default()
131 }
132 }
133
134 pub fn stacking() -> Self {
136 Self {
137 strategy: EnsembleStrategy::Stacking,
138 parallel_execution: true,
139 ..Default::default()
140 }
141 }
142}
143
144#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146pub enum RoutingStrategy {
147 Priority,
149 LowestLatency,
151 BestAccuracy,
153 RoundRobin,
155 Random,
157 Cascade,
159 ContentBased,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct CascadeConfig {
166 pub model_sequence: Vec<String>,
168 pub confidence_thresholds: Vec<f64>,
170 pub enable_early_exit: bool,
172 pub max_models: usize,
174}
175
176impl CascadeConfig {
177 pub fn two_tier(fast_model: String, accurate_model: String, threshold: f64) -> Self {
179 Self {
180 model_sequence: vec![fast_model, accurate_model],
181 confidence_thresholds: vec![threshold],
182 enable_early_exit: true,
183 max_models: 2,
184 }
185 }
186
187 pub fn three_tier(
189 fast: String,
190 medium: String,
191 accurate: String,
192 thresholds: (f64, f64),
193 ) -> Self {
194 Self {
195 model_sequence: vec![fast, medium, accurate],
196 confidence_thresholds: vec![thresholds.0, thresholds.1],
197 enable_early_exit: true,
198 max_models: 3,
199 }
200 }
201}
202
203pub struct MultiModelCoordinator {
205 models: HashMap<String, EinsumGraph>,
206 metadata: HashMap<String, ModelMetadata>,
207 ensemble_config: Option<EnsembleConfig>,
208 routing_strategy: RoutingStrategy,
209 stats: CoordinationStats,
210}
211
212impl MultiModelCoordinator {
213 pub fn new() -> Self {
215 Self {
216 models: HashMap::new(),
217 metadata: HashMap::new(),
218 ensemble_config: None,
219 routing_strategy: RoutingStrategy::Priority,
220 stats: CoordinationStats::default(),
221 }
222 }
223
224 pub fn register_model(
226 &mut self,
227 graph: EinsumGraph,
228 metadata: ModelMetadata,
229 ) -> Result<(), MultiModelError> {
230 let id = metadata.id.clone();
231 self.models.insert(id.clone(), graph);
232 self.metadata.insert(id, metadata);
233 self.stats.total_models += 1;
234 Ok(())
235 }
236
237 pub fn unregister_model(&mut self, model_id: &str) -> Result<(), MultiModelError> {
239 self.models
240 .remove(model_id)
241 .ok_or_else(|| MultiModelError::ModelNotFound(model_id.to_string()))?;
242 self.metadata.remove(model_id);
243 self.stats.total_models = self.models.len();
244 Ok(())
245 }
246
247 pub fn set_ensemble_config(&mut self, config: EnsembleConfig) {
249 self.ensemble_config = Some(config);
250 }
251
252 pub fn set_routing_strategy(&mut self, strategy: RoutingStrategy) {
254 self.routing_strategy = strategy;
255 }
256
257 pub fn select_model(
259 &mut self,
260 _input_features: Option<&[f64]>,
261 ) -> Result<String, MultiModelError> {
262 if self.models.is_empty() {
263 return Err(MultiModelError::RoutingFailed(
264 "No models registered".to_string(),
265 ));
266 }
267
268 let selected = match self.routing_strategy {
269 RoutingStrategy::Priority => self.select_by_priority(),
270 RoutingStrategy::LowestLatency => self.select_by_latency(),
271 RoutingStrategy::BestAccuracy => self.select_by_accuracy(),
272 RoutingStrategy::RoundRobin => self.select_round_robin(),
273 RoutingStrategy::Random => self.select_random(),
274 RoutingStrategy::Cascade => self.select_cascade(),
275 RoutingStrategy::ContentBased => self.select_content_based(_input_features),
276 };
277
278 if let Ok(ref id) = selected {
279 self.stats.total_routings += 1;
280 *self.stats.model_usage.entry(id.clone()).or_insert(0) += 1;
281 }
282
283 selected
284 }
285
286 fn select_by_priority(&self) -> Result<String, MultiModelError> {
287 self.metadata
288 .iter()
289 .max_by_key(|(_, meta)| meta.priority)
290 .map(|(id, _)| id.clone())
291 .ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
292 }
293
294 fn select_by_latency(&self) -> Result<String, MultiModelError> {
295 self.metadata
296 .iter()
297 .min_by(|(_, a), (_, b)| {
298 a.resource_requirements
299 .estimated_latency_ms
300 .partial_cmp(&b.resource_requirements.estimated_latency_ms)
301 .unwrap()
302 })
303 .map(|(id, _)| id.clone())
304 .ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
305 }
306
307 fn select_by_accuracy(&self) -> Result<String, MultiModelError> {
308 self.select_by_priority()
311 }
312
313 fn select_round_robin(&mut self) -> Result<String, MultiModelError> {
314 let model_ids: Vec<_> = self.models.keys().cloned().collect();
315 if model_ids.is_empty() {
316 return Err(MultiModelError::RoutingFailed(
317 "No models available".to_string(),
318 ));
319 }
320
321 let idx = self.stats.total_routings % model_ids.len();
322 Ok(model_ids[idx].clone())
323 }
324
325 fn select_random(&self) -> Result<String, MultiModelError> {
326 let model_ids: Vec<_> = self.models.keys().cloned().collect();
328 if model_ids.is_empty() {
329 return Err(MultiModelError::RoutingFailed(
330 "No models available".to_string(),
331 ));
332 }
333
334 Ok(model_ids[0].clone())
335 }
336
337 fn select_cascade(&self) -> Result<String, MultiModelError> {
338 self.select_by_latency()
340 }
341
342 fn select_content_based(&self, _features: Option<&[f64]>) -> Result<String, MultiModelError> {
343 self.select_by_priority()
346 }
347
348 pub fn get_model(&self, model_id: &str) -> Option<&EinsumGraph> {
350 self.models.get(model_id)
351 }
352
353 pub fn get_metadata(&self, model_id: &str) -> Option<&ModelMetadata> {
355 self.metadata.get(model_id)
356 }
357
358 pub fn model_ids(&self) -> Vec<String> {
360 self.models.keys().cloned().collect()
361 }
362
363 pub fn stats(&self) -> &CoordinationStats {
365 &self.stats
366 }
367
368 pub fn has_ensemble(&self) -> bool {
370 self.ensemble_config.is_some()
371 }
372
373 pub fn ensemble_config(&self) -> Option<&EnsembleConfig> {
375 self.ensemble_config.as_ref()
376 }
377
378 pub fn total_resource_requirements(&self) -> ResourceRequirements {
380 let mut total = ResourceRequirements {
381 memory_bytes: 0,
382 gpu_memory_bytes: Some(0),
383 estimated_flops: 0.0,
384 estimated_latency_ms: 0.0,
385 };
386
387 for metadata in self.metadata.values() {
388 let req = &metadata.resource_requirements;
389 total.memory_bytes += req.memory_bytes;
390 if let (Some(total_gpu), Some(req_gpu)) = (total.gpu_memory_bytes, req.gpu_memory_bytes)
391 {
392 total.gpu_memory_bytes = Some(total_gpu + req_gpu);
393 }
394 total.estimated_flops += req.estimated_flops;
395 total.estimated_latency_ms = total.estimated_latency_ms.max(req.estimated_latency_ms);
397 }
398
399 total
400 }
401}
402
403impl Default for MultiModelCoordinator {
404 fn default() -> Self {
405 Self::new()
406 }
407}
408
409#[derive(Debug, Clone, Default, Serialize, Deserialize)]
411pub struct CoordinationStats {
412 pub total_models: usize,
414 pub total_routings: usize,
416 pub total_ensemble_executions: usize,
418 pub model_usage: HashMap<String, usize>,
420}
421
422impl CoordinationStats {
423 pub fn most_used_model(&self) -> Option<(String, usize)> {
425 self.model_usage
426 .iter()
427 .max_by_key(|(_, &count)| count)
428 .map(|(id, &count)| (id.clone(), count))
429 }
430
431 pub fn usage_distribution(&self) -> HashMap<String, f64> {
433 let total = self.model_usage.values().sum::<usize>() as f64;
434 if total == 0.0 {
435 return HashMap::new();
436 }
437
438 self.model_usage
439 .iter()
440 .map(|(id, &count)| (id.clone(), count as f64 / total))
441 .collect()
442 }
443}
444
445pub trait TlEnsembleExecutor {
447 type Output;
449 type Error;
451
452 fn execute_ensemble(
454 &self,
455 models: &[&EinsumGraph],
456 inputs: &[Self::Output],
457 strategy: EnsembleStrategy,
458 ) -> Result<Self::Output, Self::Error>;
459
460 fn aggregate_predictions(
462 &self,
463 predictions: &[Self::Output],
464 strategy: EnsembleStrategy,
465 ) -> Result<Self::Output, Self::Error>;
466}
467
468pub trait TlModelRouter {
470 fn route_to_model(&self, input: &[f64]) -> Result<String, MultiModelError>;
472
473 fn routing_confidence(&self, input: &[f64], model_id: &str) -> f64;
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Builds a graph containing a single `ij->ij` einsum node.
    ///
    /// The id argument is unused; call sites pass it for readability only.
    fn create_test_graph(_id: &str) -> EinsumGraph {
        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij->ij".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        };

        let mut graph = EinsumGraph::new();
        graph.nodes.push(node);
        graph
    }

    /// Builds metadata with fixed shapes/resources and the given id and priority.
    fn create_test_metadata(id: &str, priority: u32) -> ModelMetadata {
        let resource_requirements = ResourceRequirements {
            memory_bytes: 1024 * 1024,
            gpu_memory_bytes: Some(512 * 1024),
            estimated_flops: 1e9,
            estimated_latency_ms: 10.0,
        };

        ModelMetadata {
            id: id.to_string(),
            name: format!("Model {}", id),
            version: "1.0".to_string(),
            input_shapes: vec![vec![10, 10]],
            output_shapes: vec![vec![10, 10]],
            weight: 1.0,
            priority,
            resource_requirements,
        }
    }

    /// Builds priority-1 metadata whose estimated latency is `latency_ms`.
    fn metadata_with_latency(id: &str, latency_ms: f64) -> ModelMetadata {
        let mut metadata = create_test_metadata(id, 1);
        metadata.resource_requirements.estimated_latency_ms = latency_ms;
        metadata
    }

    /// Builds a coordinator with one registered model per `(id, priority)` pair.
    fn coordinator_with(models: &[(&str, u32)]) -> MultiModelCoordinator {
        let mut coordinator = MultiModelCoordinator::new();
        for &(id, priority) in models {
            coordinator
                .register_model(create_test_graph(id), create_test_metadata(id, priority))
                .unwrap();
        }
        coordinator
    }

    #[test]
    fn test_ensemble_strategy() {
        let voting = EnsembleConfig::voting();
        assert_eq!(voting.strategy, EnsembleStrategy::MajorityVote);

        let weights: HashMap<String, f64> = [("model1", 0.6), ("model2", 0.4)]
            .into_iter()
            .map(|(id, weight)| (id.to_string(), weight))
            .collect();
        let weighted = EnsembleConfig::weighted_average(weights);
        assert_eq!(weighted.strategy, EnsembleStrategy::WeightedAverage);
    }

    #[test]
    fn test_cascade_config() {
        let two = CascadeConfig::two_tier("fast".to_string(), "accurate".to_string(), 0.9);
        assert_eq!(two.model_sequence.len(), 2);
        assert_eq!(two.confidence_thresholds[0], 0.9);

        let three = CascadeConfig::three_tier(
            "fast".to_string(),
            "medium".to_string(),
            "accurate".to_string(),
            (0.8, 0.95),
        );
        assert_eq!(three.model_sequence.len(), 3);
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = MultiModelCoordinator::new();
        assert_eq!(coordinator.models.len(), 0);
        assert_eq!(coordinator.stats.total_models, 0);
    }

    #[test]
    fn test_model_registration() {
        let mut coordinator = MultiModelCoordinator::new();

        let registration = coordinator.register_model(
            create_test_graph("model1"),
            create_test_metadata("model1", 1),
        );

        assert!(registration.is_ok());
        assert_eq!(coordinator.stats.total_models, 1);
        assert!(coordinator.get_model("model1").is_some());
    }

    #[test]
    fn test_model_unregistration() {
        let mut coordinator = coordinator_with(&[("model1", 1)]);

        assert!(coordinator.unregister_model("model1").is_ok());
        assert_eq!(coordinator.stats.total_models, 0);
        assert!(coordinator.get_model("model1").is_none());
    }

    #[test]
    fn test_routing_by_priority() {
        let mut coordinator = coordinator_with(&[("model1", 1), ("model2", 5)]);
        coordinator.set_routing_strategy(RoutingStrategy::Priority);

        // model2 carries the higher priority.
        let selected = coordinator.select_model(None).unwrap();
        assert_eq!(selected, "model2");
    }

    #[test]
    fn test_routing_by_latency() {
        let mut coordinator = MultiModelCoordinator::new();
        for (id, latency_ms) in [("model1", 20.0), ("model2", 5.0)] {
            coordinator
                .register_model(create_test_graph(id), metadata_with_latency(id, latency_ms))
                .unwrap();
        }
        coordinator.set_routing_strategy(RoutingStrategy::LowestLatency);

        // model2 has the lower estimated latency.
        let selected = coordinator.select_model(None).unwrap();
        assert_eq!(selected, "model2");
    }

    #[test]
    fn test_ensemble_configuration() {
        let mut coordinator = MultiModelCoordinator::new();
        assert!(!coordinator.has_ensemble());

        coordinator.set_ensemble_config(EnsembleConfig::voting());
        assert!(coordinator.has_ensemble());

        let stored = coordinator.ensemble_config().unwrap();
        assert_eq!(stored.strategy, EnsembleStrategy::MajorityVote);
    }

    #[test]
    fn test_total_resource_requirements() {
        let coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);

        let total = coordinator.total_resource_requirements();
        assert_eq!(total.memory_bytes, 2 * 1024 * 1024);
        assert_eq!(total.gpu_memory_bytes, Some(2 * 512 * 1024));
    }

    #[test]
    fn test_coordination_stats() {
        let mut stats = CoordinationStats::default();
        for (id, count) in [("model1", 10), ("model2", 5)] {
            stats.model_usage.insert(id.to_string(), count);
        }

        let (id, count) = stats.most_used_model().unwrap();
        assert_eq!(id, "model1");
        assert_eq!(count, 10);

        let dist = stats.usage_distribution();
        assert_eq!(dist["model1"], 10.0 / 15.0);
    }

    #[test]
    fn test_round_robin_routing() {
        let mut coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);
        coordinator.set_routing_strategy(RoutingStrategy::RoundRobin);

        let first = coordinator.select_model(None).unwrap();
        let second = coordinator.select_model(None).unwrap();

        // Two consecutive selections over two models must differ.
        assert_ne!(first, second);
    }

    #[test]
    fn test_model_ids() {
        let coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);

        let ids = coordinator.model_ids();
        assert_eq!(ids.len(), 2);
        for expected in ["model1", "model2"] {
            assert!(ids.contains(&expected.to_string()));
        }
    }
}