Skip to main content

tensorlogic_infer/
multimodel.rs

//! Multi-model coordination for ensemble and multi-task inference.
//!
//! This module provides coordination capabilities for running multiple models:
//! - Ensemble inference (voting, averaging, stacking)
//! - Multi-task model coordination
//! - Model cascades (early-exit strategies)
//! - Model routing (dynamic model selection)
//! - Resource sharing across models
//! - Load balancing for model serving
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use tensorlogic_ir::EinsumGraph;
14use thiserror::Error;
15
16/// Multi-model coordination errors.
17#[derive(Error, Debug, Clone, PartialEq)]
18pub enum MultiModelError {
19    #[error("Model not found: {0}")]
20    ModelNotFound(String),
21
22    #[error("Incompatible model outputs")]
23    IncompatibleOutputs,
24
25    #[error("Invalid ensemble configuration: {0}")]
26    InvalidEnsemble(String),
27
28    #[error("Model routing failed: {0}")]
29    RoutingFailed(String),
30
31    #[error("Resource limit exceeded: {0}")]
32    ResourceLimitExceeded(String),
33}
34
35/// Ensemble aggregation strategy.
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum EnsembleStrategy {
38    /// Simple average of predictions
39    Average,
40    /// Weighted average with learned weights
41    WeightedAverage,
42    /// Majority voting (for classification)
43    MajorityVote,
44    /// Soft voting with probabilities
45    SoftVote,
46    /// Stacking with meta-learner
47    Stacking,
48    /// Boosting-style weighted combination
49    Boosting,
50    /// Maximum confidence prediction
51    MaxConfidence,
52}
53
54/// Model metadata for coordination.
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ModelMetadata {
57    /// Model identifier
58    pub id: String,
59    /// Model name
60    pub name: String,
61    /// Model version
62    pub version: String,
63    /// Expected input shapes
64    pub input_shapes: Vec<Vec<usize>>,
65    /// Expected output shapes
66    pub output_shapes: Vec<Vec<usize>>,
67    /// Model weight (for ensemble)
68    pub weight: f64,
69    /// Priority (for routing)
70    pub priority: u32,
71    /// Resource requirements
72    pub resource_requirements: ResourceRequirements,
73}
74
75/// Resource requirements for a model.
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ResourceRequirements {
78    /// Memory required (bytes)
79    pub memory_bytes: usize,
80    /// GPU memory required (bytes)
81    pub gpu_memory_bytes: Option<usize>,
82    /// Estimated FLOPS
83    pub estimated_flops: f64,
84    /// Estimated latency (milliseconds)
85    pub estimated_latency_ms: f64,
86}
87
88/// Ensemble configuration.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct EnsembleConfig {
91    /// Ensemble strategy
92    pub strategy: EnsembleStrategy,
93    /// Model weights (for weighted averaging)
94    pub model_weights: HashMap<String, f64>,
95    /// Minimum models for consensus
96    pub min_models: usize,
97    /// Enable parallel execution
98    pub parallel_execution: bool,
99    /// Timeout for individual models
100    pub model_timeout_ms: Option<u64>,
101}
102
103impl Default for EnsembleConfig {
104    fn default() -> Self {
105        Self {
106            strategy: EnsembleStrategy::Average,
107            model_weights: HashMap::new(),
108            min_models: 1,
109            parallel_execution: true,
110            model_timeout_ms: None,
111        }
112    }
113}
114
115impl EnsembleConfig {
116    /// Create configuration for voting ensemble.
117    pub fn voting() -> Self {
118        Self {
119            strategy: EnsembleStrategy::MajorityVote,
120            min_models: 3,
121            ..Default::default()
122        }
123    }
124
125    /// Create configuration for weighted averaging.
126    pub fn weighted_average(weights: HashMap<String, f64>) -> Self {
127        Self {
128            strategy: EnsembleStrategy::WeightedAverage,
129            model_weights: weights,
130            ..Default::default()
131        }
132    }
133
134    /// Create configuration for stacking ensemble.
135    pub fn stacking() -> Self {
136        Self {
137            strategy: EnsembleStrategy::Stacking,
138            parallel_execution: true,
139            ..Default::default()
140        }
141    }
142}
143
144/// Model routing strategy.
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146pub enum RoutingStrategy {
147    /// Route to model with highest priority
148    Priority,
149    /// Route to model with lowest latency
150    LowestLatency,
151    /// Route to model with best accuracy (requires profiling)
152    BestAccuracy,
153    /// Round-robin across models
154    RoundRobin,
155    /// Random selection
156    Random,
157    /// Cascade (try fast model first, fallback to accurate)
158    Cascade,
159    /// Route based on input features
160    ContentBased,
161}
162
163/// Model cascade configuration.
164#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct CascadeConfig {
166    /// Ordered list of model IDs (fast to accurate)
167    pub model_sequence: Vec<String>,
168    /// Confidence thresholds for each model
169    pub confidence_thresholds: Vec<f64>,
170    /// Enable early exit
171    pub enable_early_exit: bool,
172    /// Maximum models to try
173    pub max_models: usize,
174}
175
176impl CascadeConfig {
177    /// Create a two-tier cascade (fast + accurate).
178    pub fn two_tier(fast_model: String, accurate_model: String, threshold: f64) -> Self {
179        Self {
180            model_sequence: vec![fast_model, accurate_model],
181            confidence_thresholds: vec![threshold],
182            enable_early_exit: true,
183            max_models: 2,
184        }
185    }
186
187    /// Create a three-tier cascade.
188    pub fn three_tier(
189        fast: String,
190        medium: String,
191        accurate: String,
192        thresholds: (f64, f64),
193    ) -> Self {
194        Self {
195            model_sequence: vec![fast, medium, accurate],
196            confidence_thresholds: vec![thresholds.0, thresholds.1],
197            enable_early_exit: true,
198            max_models: 3,
199        }
200    }
201}
202
203/// Multi-model coordinator.
204pub struct MultiModelCoordinator {
205    models: HashMap<String, EinsumGraph>,
206    metadata: HashMap<String, ModelMetadata>,
207    ensemble_config: Option<EnsembleConfig>,
208    routing_strategy: RoutingStrategy,
209    stats: CoordinationStats,
210}
211
212impl MultiModelCoordinator {
213    /// Create a new multi-model coordinator.
214    pub fn new() -> Self {
215        Self {
216            models: HashMap::new(),
217            metadata: HashMap::new(),
218            ensemble_config: None,
219            routing_strategy: RoutingStrategy::Priority,
220            stats: CoordinationStats::default(),
221        }
222    }
223
224    /// Register a model.
225    pub fn register_model(
226        &mut self,
227        graph: EinsumGraph,
228        metadata: ModelMetadata,
229    ) -> Result<(), MultiModelError> {
230        let id = metadata.id.clone();
231        self.models.insert(id.clone(), graph);
232        self.metadata.insert(id, metadata);
233        self.stats.total_models += 1;
234        Ok(())
235    }
236
237    /// Unregister a model.
238    pub fn unregister_model(&mut self, model_id: &str) -> Result<(), MultiModelError> {
239        self.models
240            .remove(model_id)
241            .ok_or_else(|| MultiModelError::ModelNotFound(model_id.to_string()))?;
242        self.metadata.remove(model_id);
243        self.stats.total_models = self.models.len();
244        Ok(())
245    }
246
247    /// Set ensemble configuration.
248    pub fn set_ensemble_config(&mut self, config: EnsembleConfig) {
249        self.ensemble_config = Some(config);
250    }
251
252    /// Set routing strategy.
253    pub fn set_routing_strategy(&mut self, strategy: RoutingStrategy) {
254        self.routing_strategy = strategy;
255    }
256
257    /// Select model based on routing strategy.
258    pub fn select_model(
259        &mut self,
260        _input_features: Option<&[f64]>,
261    ) -> Result<String, MultiModelError> {
262        if self.models.is_empty() {
263            return Err(MultiModelError::RoutingFailed(
264                "No models registered".to_string(),
265            ));
266        }
267
268        let selected = match self.routing_strategy {
269            RoutingStrategy::Priority => self.select_by_priority(),
270            RoutingStrategy::LowestLatency => self.select_by_latency(),
271            RoutingStrategy::BestAccuracy => self.select_by_accuracy(),
272            RoutingStrategy::RoundRobin => self.select_round_robin(),
273            RoutingStrategy::Random => self.select_random(),
274            RoutingStrategy::Cascade => self.select_cascade(),
275            RoutingStrategy::ContentBased => self.select_content_based(_input_features),
276        };
277
278        if let Ok(ref id) = selected {
279            self.stats.total_routings += 1;
280            *self.stats.model_usage.entry(id.clone()).or_insert(0) += 1;
281        }
282
283        selected
284    }
285
286    fn select_by_priority(&self) -> Result<String, MultiModelError> {
287        self.metadata
288            .iter()
289            .max_by_key(|(_, meta)| meta.priority)
290            .map(|(id, _)| id.clone())
291            .ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
292    }
293
294    fn select_by_latency(&self) -> Result<String, MultiModelError> {
295        self.metadata
296            .iter()
297            .min_by(|(_, a), (_, b)| {
298                a.resource_requirements
299                    .estimated_latency_ms
300                    .partial_cmp(&b.resource_requirements.estimated_latency_ms)
301                    .unwrap()
302            })
303            .map(|(id, _)| id.clone())
304            .ok_or_else(|| MultiModelError::RoutingFailed("No models available".to_string()))
305    }
306
307    fn select_by_accuracy(&self) -> Result<String, MultiModelError> {
308        // Would use historical accuracy data
309        // For now, just select highest priority
310        self.select_by_priority()
311    }
312
313    fn select_round_robin(&mut self) -> Result<String, MultiModelError> {
314        let model_ids: Vec<_> = self.models.keys().cloned().collect();
315        if model_ids.is_empty() {
316            return Err(MultiModelError::RoutingFailed(
317                "No models available".to_string(),
318            ));
319        }
320
321        let idx = self.stats.total_routings % model_ids.len();
322        Ok(model_ids[idx].clone())
323    }
324
325    fn select_random(&self) -> Result<String, MultiModelError> {
326        // In real implementation, use proper RNG
327        let model_ids: Vec<_> = self.models.keys().cloned().collect();
328        if model_ids.is_empty() {
329            return Err(MultiModelError::RoutingFailed(
330                "No models available".to_string(),
331            ));
332        }
333
334        Ok(model_ids[0].clone())
335    }
336
337    fn select_cascade(&self) -> Result<String, MultiModelError> {
338        // Start with fastest model
339        self.select_by_latency()
340    }
341
342    fn select_content_based(&self, _features: Option<&[f64]>) -> Result<String, MultiModelError> {
343        // Would analyze input features to select best model
344        // For now, fallback to priority
345        self.select_by_priority()
346    }
347
348    /// Get model by ID.
349    pub fn get_model(&self, model_id: &str) -> Option<&EinsumGraph> {
350        self.models.get(model_id)
351    }
352
353    /// Get model metadata.
354    pub fn get_metadata(&self, model_id: &str) -> Option<&ModelMetadata> {
355        self.metadata.get(model_id)
356    }
357
358    /// Get all registered model IDs.
359    pub fn model_ids(&self) -> Vec<String> {
360        self.models.keys().cloned().collect()
361    }
362
363    /// Get statistics.
364    pub fn stats(&self) -> &CoordinationStats {
365        &self.stats
366    }
367
368    /// Check if ensemble is configured.
369    pub fn has_ensemble(&self) -> bool {
370        self.ensemble_config.is_some()
371    }
372
373    /// Get ensemble configuration.
374    pub fn ensemble_config(&self) -> Option<&EnsembleConfig> {
375        self.ensemble_config.as_ref()
376    }
377
378    /// Estimate total resource requirements.
379    pub fn total_resource_requirements(&self) -> ResourceRequirements {
380        let mut total = ResourceRequirements {
381            memory_bytes: 0,
382            gpu_memory_bytes: Some(0),
383            estimated_flops: 0.0,
384            estimated_latency_ms: 0.0,
385        };
386
387        for metadata in self.metadata.values() {
388            let req = &metadata.resource_requirements;
389            total.memory_bytes += req.memory_bytes;
390            if let (Some(total_gpu), Some(req_gpu)) = (total.gpu_memory_bytes, req.gpu_memory_bytes)
391            {
392                total.gpu_memory_bytes = Some(total_gpu + req_gpu);
393            }
394            total.estimated_flops += req.estimated_flops;
395            // For latency, use max if parallel, sum if sequential
396            total.estimated_latency_ms = total.estimated_latency_ms.max(req.estimated_latency_ms);
397        }
398
399        total
400    }
401}
402
403impl Default for MultiModelCoordinator {
404    fn default() -> Self {
405        Self::new()
406    }
407}
408
409/// Coordination statistics.
410#[derive(Debug, Clone, Default, Serialize, Deserialize)]
411pub struct CoordinationStats {
412    /// Total models registered
413    pub total_models: usize,
414    /// Total routing decisions
415    pub total_routings: usize,
416    /// Total ensemble executions
417    pub total_ensemble_executions: usize,
418    /// Model usage counts
419    pub model_usage: HashMap<String, usize>,
420}
421
422impl CoordinationStats {
423    /// Get most used model.
424    pub fn most_used_model(&self) -> Option<(String, usize)> {
425        self.model_usage
426            .iter()
427            .max_by_key(|(_, &count)| count)
428            .map(|(id, &count)| (id.clone(), count))
429    }
430
431    /// Get model usage distribution.
432    pub fn usage_distribution(&self) -> HashMap<String, f64> {
433        let total = self.model_usage.values().sum::<usize>() as f64;
434        if total == 0.0 {
435            return HashMap::new();
436        }
437
438        self.model_usage
439            .iter()
440            .map(|(id, &count)| (id.clone(), count as f64 / total))
441            .collect()
442    }
443}
444
445/// Trait for multi-model ensemble execution.
446pub trait TlEnsembleExecutor {
447    /// Output type
448    type Output;
449    /// Error type
450    type Error;
451
452    /// Execute ensemble with given strategy.
453    fn execute_ensemble(
454        &self,
455        models: &[&EinsumGraph],
456        inputs: &[Self::Output],
457        strategy: EnsembleStrategy,
458    ) -> Result<Self::Output, Self::Error>;
459
460    /// Aggregate predictions from multiple models.
461    fn aggregate_predictions(
462        &self,
463        predictions: &[Self::Output],
464        strategy: EnsembleStrategy,
465    ) -> Result<Self::Output, Self::Error>;
466}
467
468/// Trait for model routing.
469pub trait TlModelRouter {
470    /// Select appropriate model based on input.
471    fn route_to_model(&self, input: &[f64]) -> Result<String, MultiModelError>;
472
473    /// Get routing confidence score.
474    fn routing_confidence(&self, input: &[f64], model_id: &str) -> f64;
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{EinsumNode, OpType};

    /// Graph with a single `"ij->ij"` einsum node; the id argument is unused.
    fn create_test_graph(_id: &str) -> EinsumGraph {
        let node = EinsumNode {
            op: OpType::Einsum {
                spec: "ij->ij".to_string(),
            },
            inputs: vec![],
            outputs: vec![0],
            metadata: Default::default(),
        };
        let mut graph = EinsumGraph::new();
        graph.nodes.push(node);
        graph
    }

    /// Metadata with fixed shapes and resource figures; only id and priority vary.
    fn create_test_metadata(id: &str, priority: u32) -> ModelMetadata {
        let resource_requirements = ResourceRequirements {
            memory_bytes: 1024 * 1024,
            gpu_memory_bytes: Some(512 * 1024),
            estimated_flops: 1e9,
            estimated_latency_ms: 10.0,
        };
        ModelMetadata {
            id: id.to_string(),
            name: format!("Model {}", id),
            version: "1.0".to_string(),
            input_shapes: vec![vec![10, 10]],
            output_shapes: vec![vec![10, 10]],
            weight: 1.0,
            priority,
            resource_requirements,
        }
    }

    /// Priority-1 metadata whose estimated latency is overridden.
    fn metadata_with_latency(id: &str, latency_ms: f64) -> ModelMetadata {
        let mut meta = create_test_metadata(id, 1);
        meta.resource_requirements.estimated_latency_ms = latency_ms;
        meta
    }

    /// Register `meta` with a fresh test graph, panicking on failure.
    fn register(coordinator: &mut MultiModelCoordinator, meta: ModelMetadata) {
        let graph = create_test_graph(&meta.id);
        coordinator.register_model(graph, meta).unwrap();
    }

    /// Coordinator pre-populated with the given `(id, priority)` models, in order.
    fn coordinator_with(models: &[(&str, u32)]) -> MultiModelCoordinator {
        let mut coordinator = MultiModelCoordinator::new();
        for &(id, priority) in models {
            register(&mut coordinator, create_test_metadata(id, priority));
        }
        coordinator
    }

    #[test]
    fn test_ensemble_strategy() {
        let voting = EnsembleConfig::voting();
        assert_eq!(voting.strategy, EnsembleStrategy::MajorityVote);

        let weights: HashMap<String, f64> =
            [("model1".to_string(), 0.6), ("model2".to_string(), 0.4)]
                .into_iter()
                .collect();
        let weighted = EnsembleConfig::weighted_average(weights);
        assert_eq!(weighted.strategy, EnsembleStrategy::WeightedAverage);
    }

    #[test]
    fn test_cascade_config() {
        let two = CascadeConfig::two_tier("fast".to_string(), "accurate".to_string(), 0.9);
        assert_eq!(two.model_sequence.len(), 2);
        assert_eq!(two.confidence_thresholds[0], 0.9);

        let three = CascadeConfig::three_tier(
            "fast".to_string(),
            "medium".to_string(),
            "accurate".to_string(),
            (0.8, 0.95),
        );
        assert_eq!(three.model_sequence.len(), 3);
    }

    #[test]
    fn test_coordinator_creation() {
        let coordinator = MultiModelCoordinator::new();
        assert_eq!(coordinator.models.len(), 0);
        assert_eq!(coordinator.stats.total_models, 0);
    }

    #[test]
    fn test_model_registration() {
        let mut coordinator = MultiModelCoordinator::new();

        let outcome = coordinator.register_model(
            create_test_graph("model1"),
            create_test_metadata("model1", 1),
        );

        assert!(outcome.is_ok());
        assert_eq!(coordinator.stats.total_models, 1);
        assert!(coordinator.get_model("model1").is_some());
    }

    #[test]
    fn test_model_unregistration() {
        let mut coordinator = coordinator_with(&[("model1", 1)]);

        assert!(coordinator.unregister_model("model1").is_ok());
        assert_eq!(coordinator.stats.total_models, 0);
        assert!(coordinator.get_model("model1").is_none());
    }

    #[test]
    fn test_routing_by_priority() {
        let mut coordinator = coordinator_with(&[("model1", 1), ("model2", 5)]);

        coordinator.set_routing_strategy(RoutingStrategy::Priority);
        let selected = coordinator.select_model(None).unwrap();
        assert_eq!(selected, "model2"); // Higher priority
    }

    #[test]
    fn test_routing_by_latency() {
        let mut coordinator = MultiModelCoordinator::new();
        register(&mut coordinator, metadata_with_latency("model1", 20.0));
        register(&mut coordinator, metadata_with_latency("model2", 5.0));

        coordinator.set_routing_strategy(RoutingStrategy::LowestLatency);
        let selected = coordinator.select_model(None).unwrap();
        assert_eq!(selected, "model2"); // Lower latency
    }

    #[test]
    fn test_ensemble_configuration() {
        let mut coordinator = MultiModelCoordinator::new();
        assert!(!coordinator.has_ensemble());

        coordinator.set_ensemble_config(EnsembleConfig::voting());
        assert!(coordinator.has_ensemble());
        assert_eq!(
            coordinator.ensemble_config().unwrap().strategy,
            EnsembleStrategy::MajorityVote
        );
    }

    #[test]
    fn test_total_resource_requirements() {
        let coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);

        let total = coordinator.total_resource_requirements();
        assert_eq!(total.memory_bytes, 2 * 1024 * 1024);
        assert_eq!(total.gpu_memory_bytes, Some(2 * 512 * 1024));
    }

    #[test]
    fn test_coordination_stats() {
        let mut stats = CoordinationStats::default();
        stats.model_usage.insert("model1".to_string(), 10);
        stats.model_usage.insert("model2".to_string(), 5);

        let (id, count) = stats.most_used_model().unwrap();
        assert_eq!(id, "model1");
        assert_eq!(count, 10);

        let dist = stats.usage_distribution();
        assert_eq!(dist.get("model1").unwrap(), &(10.0 / 15.0));
    }

    #[test]
    fn test_round_robin_routing() {
        let mut coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);
        coordinator.set_routing_strategy(RoutingStrategy::RoundRobin);

        let first = coordinator.select_model(None).unwrap();
        let second = coordinator.select_model(None).unwrap();

        // Two consecutive picks over two models must differ, whichever comes first.
        assert_ne!(first, second);
    }

    #[test]
    fn test_model_ids() {
        let coordinator = coordinator_with(&[("model1", 1), ("model2", 1)]);

        let ids = coordinator.model_ids();
        assert_eq!(ids.len(), 2);
        for expected in ["model1", "model2"] {
            assert!(ids.contains(&expected.to_string()));
        }
    }
}