// trustformers_wasm/multi_model_manager.rs
1//! Multi-model management system for efficient model loading and switching
2//!
3//! This module provides comprehensive multi-model management capabilities including:
4//! - Model registry and metadata management
5//! - Dynamic model loading/unloading with memory optimization
6//! - Model switching and intelligent routing
7//! - Performance optimization for multi-model scenarios
8//! - Model versioning and A/B testing support
9//! - Resource allocation and cleanup
10
11use crate::debug::DebugLogger;
12use crate::quantization::WebQuantizer;
13#[cfg(feature = "indexeddb")]
14use crate::storage::ModelStorage;
15use js_sys::{Array, Date, Object};
16use serde::{Deserialize, Serialize};
17use std::string::{String, ToString};
18use std::vec::Vec;
19use std::{format, vec};
20use wasm_bindgen::prelude::*;
21
/// Model status in the management system
///
/// Lifecycle: `NotLoaded` → `Loading` → `Ready`, optionally passing through
/// `WarmingUp` right after load; `Unloading` and `Error` cover teardown and
/// failure states.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelStatus {
    /// Model is not loaded
    NotLoaded,
    /// Model is currently loading
    Loading,
    /// Model is loaded and ready for inference
    Ready,
    /// Model is currently unloading
    Unloading,
    /// Model encountered an error
    Error,
    /// Model is warming up (first inference)
    WarmingUp,
}
39
/// Model priority for memory management
///
/// Ordered (`Ord` derive) so eviction can sort candidates: lower values
/// are unloaded first by `optimize_memory`. The explicit discriminants are
/// also used as a weighting factor when ranking preload candidates.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ModelPriority {
    /// Low priority - can be unloaded first
    Low = 0,
    /// Normal priority
    Normal = 1,
    /// High priority - kept in memory longer
    High = 2,
    /// Critical priority - never auto-unloaded
    Critical = 3,
}
53
/// Model deployment environment
///
/// Informational tag carried in `ModelMetadata`; not consulted by the
/// loading or routing logic in this module.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeploymentEnvironment {
    /// Development environment
    Development,
    /// Staging environment
    Staging,
    /// Production environment
    Production,
    /// A/B testing environment
    Testing,
}
67
/// Model metadata and configuration
///
/// Serialized to/from JSON when a model is registered through
/// `MultiModelManager::register_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Unique identifier used for registry lookups and routing targets.
    pub id: String,
    /// Human-readable display name (used in log messages).
    pub name: String,
    /// Version string; stored alongside cached model data.
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Model type string passed to `InferenceSession::new` when loading.
    pub model_type: String,
    /// Architecture name (e.g. "bert"); stored with cached model data.
    pub architecture: String,
    /// Model size in bytes; used for memory accounting and capacity checks.
    pub size_bytes: usize,
    /// Eviction priority; `Critical` models are never auto-unloaded.
    pub priority: ModelPriority,
    /// Arbitrary tags for categorization.
    pub tags: Vec<String>,
    /// Deployment environment this model targets.
    pub environment: DeploymentEnvironment,
    /// Creation timestamp (JS `Date::now()` milliseconds).
    pub created_at: f64,
    /// Last-use timestamp (milliseconds); drives inactivity-based unloading.
    pub last_used: f64,
    /// Usage counter; weights the preload candidate ranking.
    pub usage_count: u32,
    /// Capability names the model provides.
    pub capabilities: Vec<String>,
    /// Resource requirements for running the model.
    pub requirements: ModelRequirements,
    /// URL to fetch model weights from when not available in the cache.
    pub download_url: Option<String>,
}
88
/// Model resource requirements
///
/// Declarative requirements carried with the metadata. NOTE(review): these
/// fields are not checked by `ensure_resources_available` in this module —
/// presumably enforced elsewhere; confirm before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRequirements {
    /// Minimum host memory in megabytes.
    pub min_memory_mb: u32,
    /// Minimum GPU memory in megabytes (0 when no GPU needed).
    pub min_gpu_memory_mb: u32,
    /// Whether any GPU is required.
    pub requires_gpu: bool,
    /// Whether WebGPU specifically is required.
    pub requires_webgpu: bool,
    /// Minimum number of CPU cores.
    pub min_cpu_cores: u32,
    /// Suggested batch size for inference.
    pub recommended_batch_size: u32,
}
99
/// Loaded model instance
///
/// Pairs a snapshot of the registry metadata with the live inference
/// session plus bookkeeping used for eviction and statistics.
pub struct LoadedModel {
    /// Registry metadata snapshot taken at load time.
    pub metadata: ModelMetadata,
    /// Current lifecycle status.
    pub status: ModelStatus,
    /// Live inference session (`None` once torn down).
    pub session: Option<crate::InferenceSession>,
    /// Timestamp (JS `Date::now()` ms) recorded when loading completed.
    pub load_time: f64,
    /// Host memory attributed to the model, in bytes (set from `size_bytes`).
    pub memory_usage: usize,
    /// GPU memory in bytes (currently always 0 — not yet calculated).
    pub gpu_memory_usage: usize,
    /// Whether the warmup pass has completed.
    pub warmup_completed: bool,
    /// Rolling inference statistics.
    pub performance_stats: ModelPerformanceStats,
}
111
/// Performance statistics for a model
///
/// All counters start at zero when the model is loaded. NOTE(review): the
/// code in this module initializes but never increments these counters —
/// presumably updated by the inference path; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPerformanceStats {
    /// Number of completed inferences.
    pub inference_count: u32,
    /// Cumulative inference wall time in milliseconds.
    pub total_inference_time_ms: f64,
    /// Mean inference time in milliseconds.
    pub average_inference_time_ms: f64,
    /// Duration of the most recent inference in milliseconds.
    pub last_inference_time_ms: f64,
    /// Number of failed inferences.
    pub errors: u32,
    /// Cache hit counter.
    pub cache_hits: u32,
    /// Cache miss counter.
    pub cache_misses: u32,
}
123
/// Model routing configuration
///
/// Rules are evaluated in insertion order by `route_request`; the first
/// enabled rule whose condition matches wins.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRouting {
    /// Identifier for this rule.
    pub route_id: String,
    /// Condition evaluated against the request context.
    pub condition: RoutingCondition,
    /// Model id to route to when the condition matches.
    pub target_model_id: String,
    /// Traffic weight for A/B testing. NOTE(review): not consulted by
    /// `route_request` in this module — confirm intended use.
    pub weight: f32, // For A/B testing
    /// Disabled rules are skipped entirely.
    pub enabled: bool,
}
133
/// Routing conditions
///
/// Evaluated by `matches_routing_condition` against a JSON request context.
/// The context keys consulted are: `token_count`, `user_segment`,
/// `required_capabilities` and `max_latency_ms`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RoutingCondition {
    /// Route based on input size (context key `token_count`; either bound
    /// may be omitted).
    InputSize {
        min_tokens: Option<u32>,
        max_tokens: Option<u32>,
    },
    /// Route based on user segment (exact match on context key `user_segment`)
    UserSegment { segment: String },
    /// Route based on capability requirement (membership test against the
    /// `required_capabilities` array in the context)
    Capability { required_capability: String },
    /// Route based on performance requirement (context key `max_latency_ms`)
    Performance { max_latency_ms: f64 },
    /// Random routing for A/B testing; `percentage` is in 0–100
    Random { percentage: f32 },
    /// Always route (default)
    Always,
}
153
/// Multi-model manager configuration
///
/// Built via [`MultiModelConfig::new`] or the `development`/`production`
/// presets, then customized through the chained `set_*` builder methods.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct MultiModelConfig {
    // Hard cap on simultaneously loaded models.
    max_concurrent_models: usize,
    // Memory budget (MB) across all loaded models.
    max_memory_usage_mb: u32,
    // Whether `optimize_memory` may evict inactive models.
    auto_unload_inactive: bool,
    // Inactivity threshold (ms) after which a model becomes evictable.
    inactive_timeout_ms: u32,
    // Whether `preload_models` is allowed to run.
    enable_preloading: bool,
    // Whether `load_model` warms a model right after loading.
    enable_model_warming: bool,
    // Reserved: not read anywhere in this module yet.
    #[allow(dead_code)]
    cache_enabled: bool,
    // Reserved: not read anywhere in this module yet.
    #[allow(dead_code)]
    performance_monitoring: bool,
}
169
impl Default for MultiModelConfig {
    /// Equivalent to [`MultiModelConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
175
176#[wasm_bindgen]
177impl MultiModelConfig {
178    /// Create a new multi-model configuration
179    #[wasm_bindgen(constructor)]
180    pub fn new() -> Self {
181        Self {
182            max_concurrent_models: 3,
183            max_memory_usage_mb: 500,
184            auto_unload_inactive: true,
185            inactive_timeout_ms: 300_000, // 5 minutes
186            enable_preloading: true,
187            enable_model_warming: true,
188            cache_enabled: true,
189            performance_monitoring: true,
190        }
191    }
192
193    /// Create a configuration optimized for development
194    pub fn development() -> Self {
195        Self {
196            max_concurrent_models: 2,
197            max_memory_usage_mb: 200,
198            auto_unload_inactive: false,
199            inactive_timeout_ms: 600_000, // 10 minutes
200            enable_preloading: false,
201            enable_model_warming: false,
202            cache_enabled: true,
203            performance_monitoring: true,
204        }
205    }
206
207    /// Create a configuration optimized for production
208    pub fn production() -> Self {
209        Self {
210            max_concurrent_models: 5,
211            max_memory_usage_mb: 1000,
212            auto_unload_inactive: true,
213            inactive_timeout_ms: 180_000, // 3 minutes
214            enable_preloading: true,
215            enable_model_warming: true,
216            cache_enabled: true,
217            performance_monitoring: false,
218        }
219    }
220
221    /// Set maximum concurrent models
222    pub fn set_max_concurrent_models(mut self, max: usize) -> Self {
223        self.max_concurrent_models = max;
224        self
225    }
226
227    /// Set maximum memory usage in MB
228    pub fn set_max_memory_usage_mb(mut self, mb: u32) -> Self {
229        self.max_memory_usage_mb = mb;
230        self
231    }
232
233    /// Enable/disable auto-unloading of inactive models
234    pub fn set_auto_unload_inactive(mut self, enabled: bool) -> Self {
235        self.auto_unload_inactive = enabled;
236        self
237    }
238
239    /// Set inactive timeout in milliseconds
240    pub fn set_inactive_timeout_ms(mut self, timeout: u32) -> Self {
241        self.inactive_timeout_ms = timeout;
242        self
243    }
244
245    /// Enable/disable model preloading
246    pub fn set_enable_preloading(mut self, enabled: bool) -> Self {
247        self.enable_preloading = enabled;
248        self
249    }
250}
251
/// Multi-model manager
///
/// Owns the model registry, the currently loaded models, routing rules and
/// optional supporting services (storage cache, quantizer, debug logger).
#[wasm_bindgen]
pub struct MultiModelManager {
    /// Limits and feature toggles.
    config: MultiModelConfig,
    /// Currently loaded models (linear scans; expected to stay small).
    models: Vec<LoadedModel>,
    /// All registered model metadata, loaded or not.
    model_registry: Vec<ModelMetadata>,
    /// Routing rules, evaluated in insertion order.
    routing_rules: Vec<ModelRouting>,
    /// Fallback model id when no routing rule matches.
    default_model_id: Option<String>,
    /// IndexedDB-backed model cache (set via `initialize_storage`).
    #[cfg(feature = "indexeddb")]
    storage: Option<ModelStorage>,
    /// Optional quantizer for model optimization (not used in this module yet).
    quantizer: Option<WebQuantizer>,
    /// Optional logger; all methods log only when this is set.
    debug_logger: Option<DebugLogger>,
}
265
266#[wasm_bindgen]
267impl MultiModelManager {
    /// Create a new multi-model manager
    ///
    /// Starts with an empty registry, no loaded models, no routing rules,
    /// and no optional services attached (storage, quantizer, logger).
    #[wasm_bindgen(constructor)]
    pub fn new(config: MultiModelConfig) -> Self {
        Self {
            config,
            models: Vec::new(),
            model_registry: Vec::new(),
            routing_rules: Vec::new(),
            default_model_id: None,
            // Storage only exists when the crate is built with `indexeddb`.
            #[cfg(feature = "indexeddb")]
            storage: None,
            quantizer: None,
            debug_logger: None,
        }
    }
283
    /// Initialize storage for model caching
    ///
    /// Creates and initializes an IndexedDB-backed `ModelStorage` named
    /// "trustformers-models" with the given size budget (MB). Must be called
    /// before `load_model` can use the cache; only available with the
    /// `indexeddb` feature.
    ///
    /// # Errors
    /// Propagates any `JsValue` error from storage initialization.
    #[cfg(feature = "indexeddb")]
    pub async fn initialize_storage(&mut self, max_storage_mb: f64) -> Result<(), JsValue> {
        let mut storage = ModelStorage::new("trustformers-models".to_string(), max_storage_mb);
        storage.initialize().await?;
        self.storage = Some(storage);

        if let Some(ref mut logger) = self.debug_logger {
            logger.info(
                &format!("Initialized model storage ({}MB)", max_storage_mb),
                "multi_model",
            );
        }

        Ok(())
    }
300
    /// Set debug logger
    ///
    /// Takes ownership of the logger; once set, manager operations emit
    /// informational log entries under the "multi_model" category.
    pub fn set_debug_logger(&mut self, logger: DebugLogger) {
        self.debug_logger = Some(logger);
    }
305
    /// Set quantizer for model optimization
    ///
    /// Stored for later use; nothing in this module reads it yet —
    /// presumably consumed by the inference/loading path. TODO confirm.
    pub fn set_quantizer(&mut self, quantizer: WebQuantizer) {
        self.quantizer = Some(quantizer);
    }
310
311    /// Register a model in the system
312    pub fn register_model(&mut self, metadata: &str) -> Result<(), JsValue> {
313        let model_metadata: ModelMetadata = serde_json::from_str(metadata)
314            .map_err(|e| JsValue::from_str(&format!("Invalid metadata: {}", e)))?;
315
316        // Check if model already exists
317        if self.model_registry.iter().any(|m| m.id == model_metadata.id) {
318            return Err(JsValue::from_str("Model already registered"));
319        }
320
321        self.model_registry.push(model_metadata.clone());
322
323        if let Some(ref mut logger) = self.debug_logger {
324            logger.info(
325                &format!(
326                    "Registered model: {} ({})",
327                    model_metadata.name, model_metadata.id
328                ),
329                "multi_model",
330            );
331        }
332
333        Ok(())
334    }
335
    /// Load a model by ID
    ///
    /// Resolves the id in the registry, verifies capacity/memory limits,
    /// creates an `InferenceSession`, and loads weights either from the
    /// IndexedDB cache (when the `indexeddb` feature is enabled and storage
    /// has been initialized) or by downloading from `download_url`. Freshly
    /// downloaded weights are written back to the cache. Optionally warms
    /// the model afterwards when `enable_model_warming` is set.
    ///
    /// Idempotent: returns `Ok(())` immediately if the model is already
    /// loaded.
    ///
    /// # Errors
    /// Fails when the id is not registered, resources are insufficient, the
    /// model has no `download_url` when a download is required, or any
    /// session/storage/network operation fails.
    pub async fn load_model(&mut self, model_id: &str) -> Result<(), JsValue> {
        // Check if model is already loaded
        if self.models.iter().any(|m| m.metadata.id == model_id) {
            return Ok(());
        }

        // Find model metadata
        let metadata = self
            .model_registry
            .iter()
            .find(|m| m.id == model_id)
            .ok_or_else(|| JsValue::from_str("Model not found in registry"))?
            .clone();

        // Check resource availability
        self.ensure_resources_available(&metadata)?;

        if let Some(ref mut logger) = self.debug_logger {
            logger.start_timer(&format!("load_model_{}", model_id));
            logger.info(&format!("Loading model: {}", metadata.name), "multi_model");
        }

        // Create inference session
        let mut session = crate::InferenceSession::new(metadata.model_type.clone())?;
        session.initialize_with_auto_device().await?;

        // Load model data (from cache or URL)
        // With the `indexeddb` feature: prefer the cache; on a miss, download
        // and write the bytes back so the next load is cache-served.
        #[cfg(feature = "indexeddb")]
        if let Some(ref storage) = self.storage {
            if let Some(cached_data) = storage.get_model(model_id).await? {
                session.load_model(&cached_data).await?;
            } else {
                // Load from URL and cache
                if let Some(url) = &metadata.download_url {
                    let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
                        JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
                    })?;
                    session.load_model(&model_data).await?;
                    storage
                        .store_model(
                            model_id,
                            &metadata.name,
                            &metadata.architecture,
                            &metadata.version,
                            &model_data,
                        )
                        .await?;
                } else {
                    return Err(JsValue::from_str("No download URL provided for model"));
                }
            }
        } else {
            // Storage feature compiled in but never initialized:
            // load from URL without caching.
            if let Some(url) = &metadata.download_url {
                let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
                    JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
                })?;
                session.load_model(&model_data).await?;
            } else {
                return Err(JsValue::from_str("No download URL provided for model"));
            }
        }

        // Without the `indexeddb` feature there is no cache at all.
        #[cfg(not(feature = "indexeddb"))]
        {
            // Load from URL without caching
            if let Some(url) = &metadata.download_url {
                let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
                    JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
                })?;
                session.load_model(&model_data).await?;
            } else {
                return Err(JsValue::from_str("No download URL provided for model"));
            }
        }

        // NOTE(review): `load_time` is the completion timestamp (ms since
        // epoch), not a duration — confirm downstream consumers expect that.
        let load_time = Date::now();

        let loaded_model = LoadedModel {
            metadata: metadata.clone(),
            status: ModelStatus::Ready,
            session: Some(session),
            load_time,
            // Host memory is approximated by the declared model size.
            memory_usage: metadata.size_bytes,
            gpu_memory_usage: 0, // Would be calculated
            warmup_completed: false,
            performance_stats: ModelPerformanceStats {
                inference_count: 0,
                total_inference_time_ms: 0.0,
                average_inference_time_ms: 0.0,
                last_inference_time_ms: 0.0,
                errors: 0,
                cache_hits: 0,
                cache_misses: 0,
            },
        };

        self.models.push(loaded_model);

        if let Some(ref mut logger) = self.debug_logger {
            logger.end_timer(&format!("load_model_{}", model_id));
            logger.info(
                &format!("Model loaded successfully: {}", metadata.name),
                "multi_model",
            );
        }

        // Perform warmup if enabled
        if self.config.enable_model_warming {
            self.warmup_model(model_id).await?;
        }

        Ok(())
    }
451
452    /// Unload a model by ID
453    pub fn unload_model(&mut self, model_id: &str) -> Result<(), JsValue> {
454        if let Some(pos) = self.models.iter().position(|m| m.metadata.id == model_id) {
455            let model = self.models.remove(pos);
456
457            if let Some(ref mut logger) = self.debug_logger {
458                logger.info(
459                    &format!("Unloaded model: {}", model.metadata.name),
460                    "multi_model",
461                );
462            }
463
464            Ok(())
465        } else {
466            Err(JsValue::from_str("Model not found"))
467        }
468    }
469
470    /// Get a model for inference by ID
471    pub fn get_model_index(&self, model_id: &str) -> Option<usize> {
472        self.models.iter().position(|m| m.metadata.id == model_id)
473    }
474
475    /// Route a request to the appropriate model
476    pub fn route_request(&self, input_context: &str) -> Result<String, JsValue> {
477        // Parse input context (simplified)
478        let context: serde_json::Value = serde_json::from_str(input_context)
479            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
480
481        // Apply routing rules
482        for rule in &self.routing_rules {
483            if !rule.enabled {
484                continue;
485            }
486
487            if self.matches_routing_condition(&rule.condition, &context)? {
488                return Ok(rule.target_model_id.clone());
489            }
490        }
491
492        // Use default model if no rules match
493        self.default_model_id
494            .clone()
495            .ok_or_else(|| JsValue::from_str("No default model configured"))
496    }
497
498    /// Add a routing rule
499    pub fn add_routing_rule(&mut self, rule_json: &str) -> Result<(), JsValue> {
500        let rule: ModelRouting = serde_json::from_str(rule_json)
501            .map_err(|e| JsValue::from_str(&format!("Invalid routing rule: {}", e)))?;
502
503        self.routing_rules.push(rule);
504        Ok(())
505    }
506
    /// Set the default model
    ///
    /// The default is the fallback id returned by `route_request` when no
    /// routing rule matches. The id is not validated against the registry.
    pub fn set_default_model(&mut self, model_id: &str) {
        self.default_model_id = Some(model_id.to_string());
    }
511
512    /// Get list of loaded models
513    pub fn get_loaded_models(&self) -> Array {
514        let array = Array::new();
515
516        for model in &self.models {
517            let obj = Object::new();
518            let _ = js_sys::Reflect::set(&obj, &"id".into(), &model.metadata.id.clone().into());
519            let _ = js_sys::Reflect::set(&obj, &"name".into(), &model.metadata.name.clone().into());
520            let _ = js_sys::Reflect::set(
521                &obj,
522                &"status".into(),
523                &format!("{:?}", model.status).into(),
524            );
525            let _ = js_sys::Reflect::set(&obj, &"memory_usage".into(), &model.memory_usage.into());
526            let _ =
527                js_sys::Reflect::set(&obj, &"last_used".into(), &model.metadata.last_used.into());
528            let _ = js_sys::Reflect::set(
529                &obj,
530                &"usage_count".into(),
531                &model.metadata.usage_count.into(),
532            );
533            array.push(&obj);
534        }
535
536        array
537    }
538
539    /// Get performance statistics for all models
540    pub fn get_performance_stats(&self) -> String {
541        let stats: Vec<_> = self
542            .models
543            .iter()
544            .map(|m| (m.metadata.id.clone(), m.performance_stats.clone()))
545            .collect();
546
547        serde_json::to_string_pretty(&stats).unwrap_or_else(|_| "{}".to_string())
548    }
549
550    /// Optimize memory usage by unloading inactive models
551    pub fn optimize_memory(&mut self) -> Result<(), JsValue> {
552        if !self.config.auto_unload_inactive {
553            return Ok(());
554        }
555
556        let current_time = Date::now();
557        let mut models_to_unload = Vec::new();
558
559        // Find models to unload based on priority and last usage
560        for (i, model) in self.models.iter().enumerate() {
561            let inactive_time = current_time - model.metadata.last_used;
562
563            // Don't unload critical priority models
564            if model.metadata.priority == ModelPriority::Critical {
565                continue;
566            }
567
568            // Unload if inactive for too long
569            if inactive_time > self.config.inactive_timeout_ms as f64 {
570                models_to_unload.push((i, model.metadata.id.clone()));
571            }
572        }
573
574        // Sort by priority (unload low priority first)
575        models_to_unload.sort_by(|a, b| {
576            let model_a = &self.models[a.0];
577            let model_b = &self.models[b.0];
578            model_a.metadata.priority.cmp(&model_b.metadata.priority)
579        });
580
581        // Unload models
582        for (_, model_id) in models_to_unload {
583            self.unload_model(&model_id)?;
584
585            if let Some(ref mut logger) = self.debug_logger {
586                logger.info(
587                    &format!("Auto-unloaded inactive model: {}", model_id),
588                    "multi_model",
589                );
590            }
591        }
592
593        Ok(())
594    }
595
596    /// Preload models based on usage patterns
597    pub async fn preload_models(&mut self) -> Result<(), JsValue> {
598        if !self.config.enable_preloading {
599            return Ok(());
600        }
601
602        // Sort models by usage frequency and priority
603        let mut candidates: Vec<_> = self
604            .model_registry
605            .iter()
606            .filter(|m| !self.models.iter().any(|loaded| loaded.metadata.id == m.id))
607            .cloned()
608            .collect();
609
610        candidates.sort_by(|a, b| {
611            let a_score = (a.usage_count as f32) * (a.priority as u8 as f32);
612            let b_score = (b.usage_count as f32) * (b.priority as u8 as f32);
613            b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
614        });
615
616        // Preload top candidates if we have capacity
617        for candidate in
618            candidates.iter().take(self.config.max_concurrent_models - self.models.len())
619        {
620            if self.has_capacity_for_model(candidate) {
621                self.load_model(&candidate.id).await?;
622
623                if let Some(ref mut logger) = self.debug_logger {
624                    logger.info(
625                        &format!("Preloaded model: {}", candidate.name),
626                        "multi_model",
627                    );
628                }
629            }
630        }
631
632        Ok(())
633    }
634
635    /// Get current memory usage across all models
636    pub fn get_total_memory_usage(&self) -> usize {
637        self.models.iter().map(|m| m.memory_usage).sum()
638    }
639
640    /// Get system status summary
641    pub fn get_system_status(&self) -> String {
642        let total_memory = self.get_total_memory_usage();
643        let total_memory_mb = total_memory / (1024 * 1024);
644
645        // Manual JSON construction
646        format!(
647            r#"{{
648  "loaded_models": {},
649  "registered_models": {},
650  "routing_rules": {},
651  "total_memory_usage_mb": {},
652  "max_memory_mb": {},
653  "memory_utilization": {:.3},
654  "default_model": {}
655}}"#,
656            self.models.len(),
657            self.model_registry.len(),
658            self.routing_rules.len(),
659            total_memory_mb,
660            self.config.max_memory_usage_mb,
661            (total_memory_mb as f32) / (self.config.max_memory_usage_mb as f32),
662            if let Some(ref default) = self.default_model_id {
663                format!("\"{}\"", default)
664            } else {
665                "null".to_string()
666            }
667        )
668    }
669
670    // Private helper methods
671
672    async fn warmup_model(&mut self, model_id: &str) -> Result<(), JsValue> {
673        if let Some(model) = self.models.iter_mut().find(|m| m.metadata.id == model_id) {
674            model.status = ModelStatus::WarmingUp;
675
676            // Perform a dummy inference to warm up the model
677            // This would use actual inference in a real implementation
678
679            model.status = ModelStatus::Ready;
680            model.warmup_completed = true;
681
682            if let Some(ref mut logger) = self.debug_logger {
683                logger.info(
684                    &format!("Model warmed up: {}", model.metadata.name),
685                    "multi_model",
686                );
687            }
688        }
689
690        Ok(())
691    }
692
693    fn ensure_resources_available(&self, metadata: &ModelMetadata) -> Result<(), JsValue> {
694        // Check if we're at capacity
695        if self.models.len() >= self.config.max_concurrent_models {
696            return Err(JsValue::from_str("Maximum concurrent models reached"));
697        }
698
699        // Check memory requirements
700        let current_memory = self.get_total_memory_usage() / (1024 * 1024); // Convert to MB
701        let required_memory = current_memory + (metadata.size_bytes / (1024 * 1024));
702
703        if required_memory > self.config.max_memory_usage_mb as usize {
704            return Err(JsValue::from_str("Insufficient memory for model"));
705        }
706
707        Ok(())
708    }
709
710    fn has_capacity_for_model(&self, metadata: &ModelMetadata) -> bool {
711        if self.models.len() >= self.config.max_concurrent_models {
712            return false;
713        }
714
715        let current_memory = self.get_total_memory_usage() / (1024 * 1024);
716        let required_memory = current_memory + (metadata.size_bytes / (1024 * 1024));
717
718        required_memory <= self.config.max_memory_usage_mb as usize
719    }
720
721    fn matches_routing_condition(
722        &self,
723        condition: &RoutingCondition,
724        context: &serde_json::Value,
725    ) -> Result<bool, JsValue> {
726        match condition {
727            RoutingCondition::InputSize {
728                min_tokens,
729                max_tokens,
730            } => {
731                let token_count =
732                    context.get("token_count").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
733
734                if let Some(min) = min_tokens {
735                    if token_count < *min {
736                        return Ok(false);
737                    }
738                }
739
740                if let Some(max) = max_tokens {
741                    if token_count > *max {
742                        return Ok(false);
743                    }
744                }
745
746                Ok(true)
747            },
748            RoutingCondition::UserSegment { segment } => {
749                let user_segment =
750                    context.get("user_segment").and_then(|v| v.as_str()).unwrap_or("");
751                Ok(user_segment == segment)
752            },
753            RoutingCondition::Capability {
754                required_capability,
755            } => {
756                let empty_vec = vec![];
757                let capabilities = context
758                    .get("required_capabilities")
759                    .and_then(|v| v.as_array())
760                    .unwrap_or(&empty_vec);
761
762                Ok(capabilities
763                    .iter()
764                    .any(|cap| cap.as_str().is_some_and(|s| s == required_capability)))
765            },
766            RoutingCondition::Performance { max_latency_ms } => {
767                let required_latency =
768                    context.get("max_latency_ms").and_then(|v| v.as_f64()).unwrap_or(f64::INFINITY);
769                Ok(required_latency <= *max_latency_ms)
770            },
771            RoutingCondition::Random { percentage } => {
772                let random_value = (Date::now() % 100.0) / 100.0;
773                Ok(random_value < (*percentage / 100.0) as f64)
774            },
775            RoutingCondition::Always => Ok(true),
776        }
777    }
778
    /// Fetch model data from a remote URL using the Fetch API
    ///
    /// Issues a GET request with `Accept: application/octet-stream`, rejects
    /// any non-2xx response, and returns the body bytes. Requires a browser
    /// `window` object (main-thread only; fails in workers without one).
    ///
    /// # Errors
    /// Propagates request-construction, network, and HTTP-status errors as
    /// `JsValue`.
    async fn fetch_model_from_url(&self, url: &str) -> Result<Vec<u8>, JsValue> {
        use wasm_bindgen::JsCast;
        use wasm_bindgen_futures::JsFuture;

        // Create fetch request
        let request = web_sys::Request::new_with_str(url)?;
        request.headers().set("Accept", "application/octet-stream")?;

        // Get the global window object
        let window = web_sys::window().ok_or("No global window object")?;

        // Perform the fetch
        let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
        let resp: web_sys::Response = resp_value.dyn_into()?;

        // Check if the response is ok
        if !resp.ok() {
            return Err(JsValue::from_str(&format!(
                "HTTP error: {} {}",
                resp.status(),
                resp.status_text()
            )));
        }

        // Get the response as ArrayBuffer
        let array_buffer = JsFuture::from(resp.array_buffer()?).await?;
        let uint8_array = js_sys::Uint8Array::new(&array_buffer);

        // Convert to Vec<u8> (copy out of the JS heap into linear memory)
        let mut data = vec![0u8; uint8_array.length() as usize];
        uint8_array.copy_to(&mut data);

        web_sys::console::log_1(
            &format!(
                "Successfully downloaded {} bytes for MultiModelManager",
                data.len()
            )
            .into(),
        );

        Ok(data)
    }
822}
823
/// Create a multi-model manager with development settings
///
/// Convenience wrapper around `MultiModelManager::new` using
/// [`MultiModelConfig::development`] (2 models, 200 MB, no auto-unload).
#[wasm_bindgen]
pub fn create_development_multi_model_manager() -> MultiModelManager {
    MultiModelManager::new(MultiModelConfig::development())
}
829
/// Create a multi-model manager with production settings
///
/// Convenience wrapper around `MultiModelManager::new` using
/// [`MultiModelConfig::production`] (5 models, 1000 MB, aggressive unload).
#[wasm_bindgen]
pub fn create_production_multi_model_manager() -> MultiModelManager {
    MultiModelManager::new(MultiModelConfig::production())
}
835
#[cfg(test)]
mod tests {
    use super::*;

    /// Preset configurations expose the documented limits.
    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::development();
        assert_eq!(config.max_concurrent_models, 2);
        assert_eq!(config.max_memory_usage_mb, 200);

        let prod_config = MultiModelConfig::production();
        assert_eq!(prod_config.max_concurrent_models, 5);
    }

    /// Registration via JSON round-trip populates the registry.
    /// WASM-only because `Date::now()` requires a JS environment.
    #[test]
    #[cfg(target_arch = "wasm32")]
    fn test_model_registration() {
        let config = MultiModelConfig::new();
        let mut manager = MultiModelManager::new(config);

        let metadata = ModelMetadata {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            version: "1.0".to_string(),
            description: "A test model".to_string(),
            model_type: "transformer".to_string(),
            architecture: "bert".to_string(),
            size_bytes: 1024 * 1024,
            priority: ModelPriority::Normal,
            tags: vec!["test".to_string()],
            environment: DeploymentEnvironment::Development,
            created_at: Date::now(),
            last_used: Date::now(),
            usage_count: 0,
            capabilities: vec!["text-generation".to_string()],
            requirements: ModelRequirements {
                min_memory_mb: 100,
                min_gpu_memory_mb: 0,
                requires_gpu: false,
                requires_webgpu: false,
                min_cpu_cores: 1,
                recommended_batch_size: 1,
            },
            download_url: Some("https://example.com/models/test-model.bin".to_string()),
        };

        let metadata_json =
            serde_json::to_string(&metadata).expect("JSON serialization should succeed");
        let result = manager.register_model(&metadata_json);
        assert!(result.is_ok());
        assert_eq!(manager.model_registry.len(), 1);
    }

    /// A freshly constructed manager starts empty (native-only sanity check).
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_model_manager_config() {
        // Test only configuration creation for non-WASM targets
        let config = MultiModelConfig::new();
        let manager = MultiModelManager::new(config);
        assert_eq!(manager.model_registry.len(), 0);
        assert!(manager.routing_rules.is_empty());
    }

    /// InputSize condition accepts in-range and rejects under-min counts.
    #[test]
    fn test_routing_conditions() {
        let manager = MultiModelManager::new(MultiModelConfig::new());

        let condition = RoutingCondition::InputSize {
            min_tokens: Some(10),
            max_tokens: Some(100),
        };
        let context: serde_json::Value = serde_json::from_str(r#"{"token_count": 50}"#)
            .expect("JSON parsing should succeed for valid test input");

        let matches = manager
            .matches_routing_condition(&condition, &context)
            .expect("matching should succeed in test");
        assert!(matches);

        let context_too_small: serde_json::Value = serde_json::from_str(r#"{"token_count": 5}"#)
            .expect("JSON parsing should succeed for valid test input");
        let matches_small = manager
            .matches_routing_condition(&condition, &context_too_small)
            .expect("matching should succeed in test");
        assert!(!matches_small);
    }
}