1use crate::debug::DebugLogger;
12use crate::quantization::WebQuantizer;
13#[cfg(feature = "indexeddb")]
14use crate::storage::ModelStorage;
15use js_sys::{Array, Date, Object};
16use serde::{Deserialize, Serialize};
17use std::string::{String, ToString};
18use std::vec::Vec;
19use std::{format, vec};
20use wasm_bindgen::prelude::*;
21
/// Lifecycle state of a managed model.
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelStatus {
    /// Registered but not resident in memory.
    NotLoaded,
    /// Download / session initialization in progress.
    Loading,
    /// Loaded and available for inference.
    Ready,
    /// Being removed from memory.
    Unloading,
    /// Load or inference failed.
    Error,
    /// Running warm-up work before being marked `Ready`.
    WarmingUp,
}
39
/// Relative importance of a model. Higher values win: `Ord` is derived from
/// the explicit discriminants, and `Critical` models are exempted from
/// automatic unloading (see `MultiModelManager::optimize_memory`).
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ModelPriority {
    Low = 0,
    Normal = 1,
    High = 2,
    Critical = 3,
}
53
/// Deployment environment a model is intended for. Carried in metadata;
/// informational only in this module (no environment-based gating here).
#[wasm_bindgen]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeploymentEnvironment {
    Development,
    Staging,
    Production,
    Testing,
}
67
/// Descriptive record for a model known to the manager's registry.
/// Deserialized from JSON in `MultiModelManager::register_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Unique identifier; duplicate ids are rejected at registration.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Version string (free-form).
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Passed to `crate::InferenceSession::new` when loading.
    pub model_type: String,
    /// Architecture label; stored alongside cached model bytes.
    pub architecture: String,
    /// Model size used for the manager's memory budgeting.
    pub size_bytes: usize,
    /// Eviction/preload priority; `Critical` is never auto-unloaded.
    pub priority: ModelPriority,
    /// Arbitrary labels.
    pub tags: Vec<String>,
    /// Target deployment environment.
    pub environment: DeploymentEnvironment,
    /// Creation timestamp, ms since the Unix epoch (`Date::now()` style).
    pub created_at: f64,
    /// Last-use timestamp (ms); drives inactivity-based unloading.
    pub last_used: f64,
    /// Usage counter; feeds the preload scoring in `preload_models`.
    pub usage_count: u32,
    /// Capability labels advertised by the model.
    pub capabilities: Vec<String>,
    /// Declared hardware/runtime requirements (see `ModelRequirements`).
    pub requirements: ModelRequirements,
    /// Download source; `None` makes `load_model` fail unless the bytes
    /// are already present in storage.
    pub download_url: Option<String>,
}
88
/// Hardware/runtime requirements attached to model metadata.
/// NOTE(review): these values are carried through serialization but are not
/// enforced by `MultiModelManager` in this file — resource checks only use
/// `size_bytes` and the concurrency cap.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRequirements {
    pub min_memory_mb: u32,
    pub min_gpu_memory_mb: u32,
    pub requires_gpu: bool,
    pub requires_webgpu: bool,
    pub min_cpu_cores: u32,
    pub recommended_batch_size: u32,
}
99
/// A model resident in memory together with its live inference session.
pub struct LoadedModel {
    /// Snapshot of the registry metadata taken at load time.
    pub metadata: ModelMetadata,
    /// Current lifecycle state.
    pub status: ModelStatus,
    /// Active inference session created during `load_model`.
    pub session: Option<crate::InferenceSession>,
    /// Timestamp (ms since epoch) recorded when loading finished —
    /// a point in time, not a duration.
    pub load_time: f64,
    /// Estimated host memory usage; set to `metadata.size_bytes` on load.
    pub memory_usage: usize,
    /// Estimated GPU memory usage; initialized to 0.
    pub gpu_memory_usage: usize,
    /// Whether warm-up has completed for this model.
    pub warmup_completed: bool,
    /// Rolling inference statistics.
    pub performance_stats: ModelPerformanceStats,
}
111
/// Per-model rolling inference statistics, serialized by
/// `MultiModelManager::get_performance_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPerformanceStats {
    /// Number of inferences executed.
    pub inference_count: u32,
    /// Cumulative inference wall time (ms).
    pub total_inference_time_ms: f64,
    /// Mean inference time (ms).
    pub average_inference_time_ms: f64,
    /// Duration of the most recent inference (ms).
    pub last_inference_time_ms: f64,
    /// Number of failed inferences.
    pub errors: u32,
    /// Cache hit counter.
    pub cache_hits: u32,
    /// Cache miss counter.
    pub cache_misses: u32,
}
123
/// A routing rule: when `condition` matches the request context,
/// `route_request` returns `target_model_id`. Rules are evaluated in
/// insertion order; disabled rules are skipped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRouting {
    /// Identifier for this rule.
    pub route_id: String,
    /// Predicate evaluated against the request context JSON.
    pub condition: RoutingCondition,
    /// Model id returned when the condition matches.
    pub target_model_id: String,
    // NOTE(review): `weight` is deserialized but never consulted by
    // `route_request` in this file — confirm whether weighted routing
    // is implemented elsewhere or still pending.
    pub weight: f32, pub enabled: bool,
}
133
/// Predicates for request routing, evaluated against a JSON context object
/// by `MultiModelManager::matches_routing_condition`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RoutingCondition {
    /// Matches when the context's `token_count` lies within the optional
    /// inclusive bounds (missing `token_count` is treated as 0).
    InputSize {
        min_tokens: Option<u32>,
        max_tokens: Option<u32>,
    },
    /// Matches when the context's `user_segment` equals `segment` exactly.
    UserSegment { segment: String },
    /// Matches when the context's `required_capabilities` array contains
    /// this capability string.
    Capability { required_capability: String },
    /// Matches when the context's `max_latency_ms` is <= this bound.
    Performance { max_latency_ms: f64 },
    /// Matches approximately `percentage` percent of requests.
    Random { percentage: f32 },
    /// Always matches; useful as a catch-all final rule.
    Always,
}
153
/// Tunables controlling how many models may be resident and when they are
/// evicted, preloaded, or warmed. Built via `new`/`development`/`production`
/// and the builder-style `set_*` methods.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct MultiModelConfig {
    /// Hard cap on simultaneously loaded models.
    max_concurrent_models: usize,
    /// Total memory budget across loaded models, in MB.
    max_memory_usage_mb: u32,
    /// If true, `optimize_memory` unloads models idle past the timeout.
    auto_unload_inactive: bool,
    /// Idle threshold (ms) before a model becomes eligible for unloading.
    inactive_timeout_ms: u32,
    /// If true, `preload_models` loads high-scoring registered models.
    enable_preloading: bool,
    /// If true, `load_model` triggers a warm-up pass after loading.
    enable_model_warming: bool,
    // Currently unused in this file (hence allow(dead_code)).
    #[allow(dead_code)]
    cache_enabled: bool,
    // Currently unused in this file (hence allow(dead_code)).
    #[allow(dead_code)]
    performance_monitoring: bool,
}
169
impl Default for MultiModelConfig {
    /// Delegates to `new()` so `Default` and the JS constructor agree.
    fn default() -> Self {
        Self::new()
    }
}
175
176#[wasm_bindgen]
177impl MultiModelConfig {
178 #[wasm_bindgen(constructor)]
180 pub fn new() -> Self {
181 Self {
182 max_concurrent_models: 3,
183 max_memory_usage_mb: 500,
184 auto_unload_inactive: true,
185 inactive_timeout_ms: 300_000, enable_preloading: true,
187 enable_model_warming: true,
188 cache_enabled: true,
189 performance_monitoring: true,
190 }
191 }
192
193 pub fn development() -> Self {
195 Self {
196 max_concurrent_models: 2,
197 max_memory_usage_mb: 200,
198 auto_unload_inactive: false,
199 inactive_timeout_ms: 600_000, enable_preloading: false,
201 enable_model_warming: false,
202 cache_enabled: true,
203 performance_monitoring: true,
204 }
205 }
206
207 pub fn production() -> Self {
209 Self {
210 max_concurrent_models: 5,
211 max_memory_usage_mb: 1000,
212 auto_unload_inactive: true,
213 inactive_timeout_ms: 180_000, enable_preloading: true,
215 enable_model_warming: true,
216 cache_enabled: true,
217 performance_monitoring: false,
218 }
219 }
220
221 pub fn set_max_concurrent_models(mut self, max: usize) -> Self {
223 self.max_concurrent_models = max;
224 self
225 }
226
227 pub fn set_max_memory_usage_mb(mut self, mb: u32) -> Self {
229 self.max_memory_usage_mb = mb;
230 self
231 }
232
233 pub fn set_auto_unload_inactive(mut self, enabled: bool) -> Self {
235 self.auto_unload_inactive = enabled;
236 self
237 }
238
239 pub fn set_inactive_timeout_ms(mut self, timeout: u32) -> Self {
241 self.inactive_timeout_ms = timeout;
242 self
243 }
244
245 pub fn set_enable_preloading(mut self, enabled: bool) -> Self {
247 self.enable_preloading = enabled;
248 self
249 }
250}
251
/// Orchestrates multiple models in the browser: registration, (un)loading
/// within a memory/concurrency budget, request routing, preloading, and
/// optional IndexedDB caching of downloaded model bytes.
#[wasm_bindgen]
pub struct MultiModelManager {
    /// Budget and behavior tunables.
    config: MultiModelConfig,
    /// Models currently resident in memory.
    models: Vec<LoadedModel>,
    /// Metadata for all registered (not necessarily loaded) models.
    model_registry: Vec<ModelMetadata>,
    /// Routing rules, evaluated in insertion order.
    routing_rules: Vec<ModelRouting>,
    /// Fallback model id when no routing rule matches.
    default_model_id: Option<String>,
    /// IndexedDB-backed byte cache, set by `initialize_storage`.
    #[cfg(feature = "indexeddb")]
    storage: Option<ModelStorage>,
    // NOTE(review): stored but not used by any method in this file.
    quantizer: Option<WebQuantizer>,
    /// Optional logger for lifecycle/trace messages.
    debug_logger: Option<DebugLogger>,
}
265
266#[wasm_bindgen]
267impl MultiModelManager {
268 #[wasm_bindgen(constructor)]
270 pub fn new(config: MultiModelConfig) -> Self {
271 Self {
272 config,
273 models: Vec::new(),
274 model_registry: Vec::new(),
275 routing_rules: Vec::new(),
276 default_model_id: None,
277 #[cfg(feature = "indexeddb")]
278 storage: None,
279 quantizer: None,
280 debug_logger: None,
281 }
282 }
283
284 #[cfg(feature = "indexeddb")]
286 pub async fn initialize_storage(&mut self, max_storage_mb: f64) -> Result<(), JsValue> {
287 let mut storage = ModelStorage::new("trustformers-models".to_string(), max_storage_mb);
288 storage.initialize().await?;
289 self.storage = Some(storage);
290
291 if let Some(ref mut logger) = self.debug_logger {
292 logger.info(
293 &format!("Initialized model storage ({}MB)", max_storage_mb),
294 "multi_model",
295 );
296 }
297
298 Ok(())
299 }
300
301 pub fn set_debug_logger(&mut self, logger: DebugLogger) {
303 self.debug_logger = Some(logger);
304 }
305
306 pub fn set_quantizer(&mut self, quantizer: WebQuantizer) {
308 self.quantizer = Some(quantizer);
309 }
310
311 pub fn register_model(&mut self, metadata: &str) -> Result<(), JsValue> {
313 let model_metadata: ModelMetadata = serde_json::from_str(metadata)
314 .map_err(|e| JsValue::from_str(&format!("Invalid metadata: {}", e)))?;
315
316 if self.model_registry.iter().any(|m| m.id == model_metadata.id) {
318 return Err(JsValue::from_str("Model already registered"));
319 }
320
321 self.model_registry.push(model_metadata.clone());
322
323 if let Some(ref mut logger) = self.debug_logger {
324 logger.info(
325 &format!(
326 "Registered model: {} ({})",
327 model_metadata.name, model_metadata.id
328 ),
329 "multi_model",
330 );
331 }
332
333 Ok(())
334 }
335
336 pub async fn load_model(&mut self, model_id: &str) -> Result<(), JsValue> {
338 if self.models.iter().any(|m| m.metadata.id == model_id) {
340 return Ok(());
341 }
342
343 let metadata = self
345 .model_registry
346 .iter()
347 .find(|m| m.id == model_id)
348 .ok_or_else(|| JsValue::from_str("Model not found in registry"))?
349 .clone();
350
351 self.ensure_resources_available(&metadata)?;
353
354 if let Some(ref mut logger) = self.debug_logger {
355 logger.start_timer(&format!("load_model_{}", model_id));
356 logger.info(&format!("Loading model: {}", metadata.name), "multi_model");
357 }
358
359 let mut session = crate::InferenceSession::new(metadata.model_type.clone())?;
361 session.initialize_with_auto_device().await?;
362
363 #[cfg(feature = "indexeddb")]
365 if let Some(ref storage) = self.storage {
366 if let Some(cached_data) = storage.get_model(model_id).await? {
367 session.load_model(&cached_data).await?;
368 } else {
369 if let Some(url) = &metadata.download_url {
371 let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
372 JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
373 })?;
374 session.load_model(&model_data).await?;
375 storage
376 .store_model(
377 model_id,
378 &metadata.name,
379 &metadata.architecture,
380 &metadata.version,
381 &model_data,
382 )
383 .await?;
384 } else {
385 return Err(JsValue::from_str("No download URL provided for model"));
386 }
387 }
388 } else {
389 if let Some(url) = &metadata.download_url {
391 let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
392 JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
393 })?;
394 session.load_model(&model_data).await?;
395 } else {
396 return Err(JsValue::from_str("No download URL provided for model"));
397 }
398 }
399
400 #[cfg(not(feature = "indexeddb"))]
401 {
402 if let Some(url) = &metadata.download_url {
404 let model_data = self.fetch_model_from_url(url).await.map_err(|e| {
405 JsValue::from_str(&format!("Failed to fetch model from {}: {:?}", url, e))
406 })?;
407 session.load_model(&model_data).await?;
408 } else {
409 return Err(JsValue::from_str("No download URL provided for model"));
410 }
411 }
412
413 let load_time = Date::now();
414
415 let loaded_model = LoadedModel {
416 metadata: metadata.clone(),
417 status: ModelStatus::Ready,
418 session: Some(session),
419 load_time,
420 memory_usage: metadata.size_bytes,
421 gpu_memory_usage: 0, warmup_completed: false,
423 performance_stats: ModelPerformanceStats {
424 inference_count: 0,
425 total_inference_time_ms: 0.0,
426 average_inference_time_ms: 0.0,
427 last_inference_time_ms: 0.0,
428 errors: 0,
429 cache_hits: 0,
430 cache_misses: 0,
431 },
432 };
433
434 self.models.push(loaded_model);
435
436 if let Some(ref mut logger) = self.debug_logger {
437 logger.end_timer(&format!("load_model_{}", model_id));
438 logger.info(
439 &format!("Model loaded successfully: {}", metadata.name),
440 "multi_model",
441 );
442 }
443
444 if self.config.enable_model_warming {
446 self.warmup_model(model_id).await?;
447 }
448
449 Ok(())
450 }
451
452 pub fn unload_model(&mut self, model_id: &str) -> Result<(), JsValue> {
454 if let Some(pos) = self.models.iter().position(|m| m.metadata.id == model_id) {
455 let model = self.models.remove(pos);
456
457 if let Some(ref mut logger) = self.debug_logger {
458 logger.info(
459 &format!("Unloaded model: {}", model.metadata.name),
460 "multi_model",
461 );
462 }
463
464 Ok(())
465 } else {
466 Err(JsValue::from_str("Model not found"))
467 }
468 }
469
470 pub fn get_model_index(&self, model_id: &str) -> Option<usize> {
472 self.models.iter().position(|m| m.metadata.id == model_id)
473 }
474
475 pub fn route_request(&self, input_context: &str) -> Result<String, JsValue> {
477 let context: serde_json::Value = serde_json::from_str(input_context)
479 .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
480
481 for rule in &self.routing_rules {
483 if !rule.enabled {
484 continue;
485 }
486
487 if self.matches_routing_condition(&rule.condition, &context)? {
488 return Ok(rule.target_model_id.clone());
489 }
490 }
491
492 self.default_model_id
494 .clone()
495 .ok_or_else(|| JsValue::from_str("No default model configured"))
496 }
497
498 pub fn add_routing_rule(&mut self, rule_json: &str) -> Result<(), JsValue> {
500 let rule: ModelRouting = serde_json::from_str(rule_json)
501 .map_err(|e| JsValue::from_str(&format!("Invalid routing rule: {}", e)))?;
502
503 self.routing_rules.push(rule);
504 Ok(())
505 }
506
507 pub fn set_default_model(&mut self, model_id: &str) {
509 self.default_model_id = Some(model_id.to_string());
510 }
511
512 pub fn get_loaded_models(&self) -> Array {
514 let array = Array::new();
515
516 for model in &self.models {
517 let obj = Object::new();
518 let _ = js_sys::Reflect::set(&obj, &"id".into(), &model.metadata.id.clone().into());
519 let _ = js_sys::Reflect::set(&obj, &"name".into(), &model.metadata.name.clone().into());
520 let _ = js_sys::Reflect::set(
521 &obj,
522 &"status".into(),
523 &format!("{:?}", model.status).into(),
524 );
525 let _ = js_sys::Reflect::set(&obj, &"memory_usage".into(), &model.memory_usage.into());
526 let _ =
527 js_sys::Reflect::set(&obj, &"last_used".into(), &model.metadata.last_used.into());
528 let _ = js_sys::Reflect::set(
529 &obj,
530 &"usage_count".into(),
531 &model.metadata.usage_count.into(),
532 );
533 array.push(&obj);
534 }
535
536 array
537 }
538
539 pub fn get_performance_stats(&self) -> String {
541 let stats: Vec<_> = self
542 .models
543 .iter()
544 .map(|m| (m.metadata.id.clone(), m.performance_stats.clone()))
545 .collect();
546
547 serde_json::to_string_pretty(&stats).unwrap_or_else(|_| "{}".to_string())
548 }
549
550 pub fn optimize_memory(&mut self) -> Result<(), JsValue> {
552 if !self.config.auto_unload_inactive {
553 return Ok(());
554 }
555
556 let current_time = Date::now();
557 let mut models_to_unload = Vec::new();
558
559 for (i, model) in self.models.iter().enumerate() {
561 let inactive_time = current_time - model.metadata.last_used;
562
563 if model.metadata.priority == ModelPriority::Critical {
565 continue;
566 }
567
568 if inactive_time > self.config.inactive_timeout_ms as f64 {
570 models_to_unload.push((i, model.metadata.id.clone()));
571 }
572 }
573
574 models_to_unload.sort_by(|a, b| {
576 let model_a = &self.models[a.0];
577 let model_b = &self.models[b.0];
578 model_a.metadata.priority.cmp(&model_b.metadata.priority)
579 });
580
581 for (_, model_id) in models_to_unload {
583 self.unload_model(&model_id)?;
584
585 if let Some(ref mut logger) = self.debug_logger {
586 logger.info(
587 &format!("Auto-unloaded inactive model: {}", model_id),
588 "multi_model",
589 );
590 }
591 }
592
593 Ok(())
594 }
595
596 pub async fn preload_models(&mut self) -> Result<(), JsValue> {
598 if !self.config.enable_preloading {
599 return Ok(());
600 }
601
602 let mut candidates: Vec<_> = self
604 .model_registry
605 .iter()
606 .filter(|m| !self.models.iter().any(|loaded| loaded.metadata.id == m.id))
607 .cloned()
608 .collect();
609
610 candidates.sort_by(|a, b| {
611 let a_score = (a.usage_count as f32) * (a.priority as u8 as f32);
612 let b_score = (b.usage_count as f32) * (b.priority as u8 as f32);
613 b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
614 });
615
616 for candidate in
618 candidates.iter().take(self.config.max_concurrent_models - self.models.len())
619 {
620 if self.has_capacity_for_model(candidate) {
621 self.load_model(&candidate.id).await?;
622
623 if let Some(ref mut logger) = self.debug_logger {
624 logger.info(
625 &format!("Preloaded model: {}", candidate.name),
626 "multi_model",
627 );
628 }
629 }
630 }
631
632 Ok(())
633 }
634
635 pub fn get_total_memory_usage(&self) -> usize {
637 self.models.iter().map(|m| m.memory_usage).sum()
638 }
639
640 pub fn get_system_status(&self) -> String {
642 let total_memory = self.get_total_memory_usage();
643 let total_memory_mb = total_memory / (1024 * 1024);
644
645 format!(
647 r#"{{
648 "loaded_models": {},
649 "registered_models": {},
650 "routing_rules": {},
651 "total_memory_usage_mb": {},
652 "max_memory_mb": {},
653 "memory_utilization": {:.3},
654 "default_model": {}
655}}"#,
656 self.models.len(),
657 self.model_registry.len(),
658 self.routing_rules.len(),
659 total_memory_mb,
660 self.config.max_memory_usage_mb,
661 (total_memory_mb as f32) / (self.config.max_memory_usage_mb as f32),
662 if let Some(ref default) = self.default_model_id {
663 format!("\"{}\"", default)
664 } else {
665 "null".to_string()
666 }
667 )
668 }
669
670 async fn warmup_model(&mut self, model_id: &str) -> Result<(), JsValue> {
673 if let Some(model) = self.models.iter_mut().find(|m| m.metadata.id == model_id) {
674 model.status = ModelStatus::WarmingUp;
675
676 model.status = ModelStatus::Ready;
680 model.warmup_completed = true;
681
682 if let Some(ref mut logger) = self.debug_logger {
683 logger.info(
684 &format!("Model warmed up: {}", model.metadata.name),
685 "multi_model",
686 );
687 }
688 }
689
690 Ok(())
691 }
692
693 fn ensure_resources_available(&self, metadata: &ModelMetadata) -> Result<(), JsValue> {
694 if self.models.len() >= self.config.max_concurrent_models {
696 return Err(JsValue::from_str("Maximum concurrent models reached"));
697 }
698
699 let current_memory = self.get_total_memory_usage() / (1024 * 1024); let required_memory = current_memory + (metadata.size_bytes / (1024 * 1024));
702
703 if required_memory > self.config.max_memory_usage_mb as usize {
704 return Err(JsValue::from_str("Insufficient memory for model"));
705 }
706
707 Ok(())
708 }
709
710 fn has_capacity_for_model(&self, metadata: &ModelMetadata) -> bool {
711 if self.models.len() >= self.config.max_concurrent_models {
712 return false;
713 }
714
715 let current_memory = self.get_total_memory_usage() / (1024 * 1024);
716 let required_memory = current_memory + (metadata.size_bytes / (1024 * 1024));
717
718 required_memory <= self.config.max_memory_usage_mb as usize
719 }
720
721 fn matches_routing_condition(
722 &self,
723 condition: &RoutingCondition,
724 context: &serde_json::Value,
725 ) -> Result<bool, JsValue> {
726 match condition {
727 RoutingCondition::InputSize {
728 min_tokens,
729 max_tokens,
730 } => {
731 let token_count =
732 context.get("token_count").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
733
734 if let Some(min) = min_tokens {
735 if token_count < *min {
736 return Ok(false);
737 }
738 }
739
740 if let Some(max) = max_tokens {
741 if token_count > *max {
742 return Ok(false);
743 }
744 }
745
746 Ok(true)
747 },
748 RoutingCondition::UserSegment { segment } => {
749 let user_segment =
750 context.get("user_segment").and_then(|v| v.as_str()).unwrap_or("");
751 Ok(user_segment == segment)
752 },
753 RoutingCondition::Capability {
754 required_capability,
755 } => {
756 let empty_vec = vec![];
757 let capabilities = context
758 .get("required_capabilities")
759 .and_then(|v| v.as_array())
760 .unwrap_or(&empty_vec);
761
762 Ok(capabilities
763 .iter()
764 .any(|cap| cap.as_str().is_some_and(|s| s == required_capability)))
765 },
766 RoutingCondition::Performance { max_latency_ms } => {
767 let required_latency =
768 context.get("max_latency_ms").and_then(|v| v.as_f64()).unwrap_or(f64::INFINITY);
769 Ok(required_latency <= *max_latency_ms)
770 },
771 RoutingCondition::Random { percentage } => {
772 let random_value = (Date::now() % 100.0) / 100.0;
773 Ok(random_value < (*percentage / 100.0) as f64)
774 },
775 RoutingCondition::Always => Ok(true),
776 }
777 }
778
779 async fn fetch_model_from_url(&self, url: &str) -> Result<Vec<u8>, JsValue> {
781 use wasm_bindgen::JsCast;
782 use wasm_bindgen_futures::JsFuture;
783
784 let request = web_sys::Request::new_with_str(url)?;
786 request.headers().set("Accept", "application/octet-stream")?;
787
788 let window = web_sys::window().ok_or("No global window object")?;
790
791 let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?;
793 let resp: web_sys::Response = resp_value.dyn_into()?;
794
795 if !resp.ok() {
797 return Err(JsValue::from_str(&format!(
798 "HTTP error: {} {}",
799 resp.status(),
800 resp.status_text()
801 )));
802 }
803
804 let array_buffer = JsFuture::from(resp.array_buffer()?).await?;
806 let uint8_array = js_sys::Uint8Array::new(&array_buffer);
807
808 let mut data = vec![0u8; uint8_array.length() as usize];
810 uint8_array.copy_to(&mut data);
811
812 web_sys::console::log_1(
813 &format!(
814 "Successfully downloaded {} bytes for MultiModelManager",
815 data.len()
816 )
817 .into(),
818 );
819
820 Ok(data)
821 }
822}
823
824#[wasm_bindgen]
826pub fn create_development_multi_model_manager() -> MultiModelManager {
827 MultiModelManager::new(MultiModelConfig::development())
828}
829
830#[wasm_bindgen]
832pub fn create_production_multi_model_manager() -> MultiModelManager {
833 MultiModelManager::new(MultiModelConfig::production())
834}
835
#[cfg(test)]
mod tests {
    use super::*;

    /// The preset profiles expose their documented limits.
    #[test]
    fn test_multi_model_config() {
        let config = MultiModelConfig::development();
        assert_eq!(config.max_concurrent_models, 2);
        assert_eq!(config.max_memory_usage_mb, 200);

        let prod_config = MultiModelConfig::production();
        assert_eq!(prod_config.max_concurrent_models, 5);
    }

    /// Registration accepts valid metadata JSON and grows the registry.
    /// wasm32-only: metadata timestamps come from `js_sys::Date::now()`.
    #[test]
    #[cfg(target_arch = "wasm32")]
    fn test_model_registration() {
        let config = MultiModelConfig::new();
        let mut manager = MultiModelManager::new(config);

        let metadata = ModelMetadata {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            version: "1.0".to_string(),
            description: "A test model".to_string(),
            model_type: "transformer".to_string(),
            architecture: "bert".to_string(),
            size_bytes: 1024 * 1024,
            priority: ModelPriority::Normal,
            tags: vec!["test".to_string()],
            environment: DeploymentEnvironment::Development,
            created_at: Date::now(),
            last_used: Date::now(),
            usage_count: 0,
            capabilities: vec!["text-generation".to_string()],
            requirements: ModelRequirements {
                min_memory_mb: 100,
                min_gpu_memory_mb: 0,
                requires_gpu: false,
                requires_webgpu: false,
                min_cpu_cores: 1,
                recommended_batch_size: 1,
            },
            download_url: Some("https://example.com/models/test-model.bin".to_string()),
        };

        let metadata_json =
            serde_json::to_string(&metadata).expect("JSON serialization should succeed");
        let result = manager.register_model(&metadata_json);
        assert!(result.is_ok());
        assert_eq!(manager.model_registry.len(), 1);
    }

    /// A fresh manager starts with an empty registry and no routing rules.
    /// Non-wasm32 counterpart of the registration test.
    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_model_manager_config() {
        let config = MultiModelConfig::new();
        let manager = MultiModelManager::new(config);
        assert_eq!(manager.model_registry.len(), 0);
        assert!(manager.routing_rules.is_empty());
    }

    /// `InputSize` matches within bounds and rejects below the minimum.
    #[test]
    fn test_routing_conditions() {
        let manager = MultiModelManager::new(MultiModelConfig::new());

        let condition = RoutingCondition::InputSize {
            min_tokens: Some(10),
            max_tokens: Some(100),
        };
        let context: serde_json::Value = serde_json::from_str(r#"{"token_count": 50}"#)
            .expect("JSON parsing should succeed for valid test input");

        let matches = manager
            .matches_routing_condition(&condition, &context)
            .expect("matching should succeed in test");
        assert!(matches);

        let context_too_small: serde_json::Value = serde_json::from_str(r#"{"token_count": 5}"#)
            .expect("JSON parsing should succeed for valid test input");
        let matches_small = manager
            .matches_routing_condition(&condition, &context_too_small)
            .expect("matching should succeed in test");
        assert!(!matches_small);
    }
}