// torsh_models/registry.rs
//! Model registry for managing pre-trained models

use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};

use serde::{Deserialize, Serialize};
use sha2::Digest;

use crate::{ModelError, ModelResult};
10
/// Information about a pre-trained model.
///
/// Instances are stored in the [`ModelRegistry`], keyed by the string
/// `"{name}:{version}"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model name (e.g. "resnet18")
    pub name: String,
    /// Model version string (e.g. "1.0.0")
    pub version: String,
    /// Human-readable model description
    pub description: String,
    /// Model architecture family (e.g. "ResNet", "BERT")
    pub architecture: String,
    /// Model domain (vision, nlp, audio, multimodal, ...)
    pub domain: String,
    /// Input shape/specifications
    pub input_spec: String,
    /// Output shape/specifications
    pub output_spec: String,
    /// Where the model file can be obtained
    pub source: ModelSource,
    /// Model size in bytes
    pub size_bytes: u64,
    /// Number of model parameters
    pub parameters: u64,
    /// Accuracy/quality metrics, keyed by metric name
    pub metrics: HashMap<String, f32>,
    /// Model tags for categorization and search
    pub tags: Vec<String>,
    /// License information
    pub license: String,
    /// Optional citation for the originating publication
    pub citation: Option<String>,
    /// Hex-encoded checksum used to validate the downloaded file
    pub checksum: String,
}
45
/// Model source specification: where a model's weights file lives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// Local file path
    Local(PathBuf),
    /// HTTP/HTTPS URL
    Url(String),
    /// Hugging Face Hub model (repository id + file within the repo)
    HuggingFace { repo: String, filename: String },
    /// Custom registry (registry identifier + path within it)
    Registry { registry: String, path: String },
}
58
/// Handle to a model resolved through the registry.
pub struct ModelHandle {
    /// Model information
    pub info: ModelInfo,
    /// Local filesystem path where the model file lives (or will live once downloaded)
    pub local_path: PathBuf,
    /// Whether the model weights are currently loaded in memory
    pub loaded: bool,
}
68
69impl ModelHandle {
70    /// Create new model handle
71    pub fn new(info: ModelInfo, local_path: PathBuf) -> Self {
72        Self {
73            info,
74            local_path,
75            loaded: false,
76        }
77    }
78
79    /// Check if model file exists locally
80    pub fn exists(&self) -> bool {
81        self.local_path.exists()
82    }
83
84    /// Get model file size
85    pub fn file_size(&self) -> ModelResult<u64> {
86        let metadata = std::fs::metadata(&self.local_path)?;
87        Ok(metadata.len())
88    }
89
90    /// Validate model checksum
91    pub fn validate_checksum(&self) -> ModelResult<bool> {
92        if !self.exists() {
93            return Ok(false);
94        }
95
96        let data = std::fs::read(&self.local_path)?;
97        let hash = sha2::Sha256::digest(&data);
98        let hex_hash = hex::encode(hash);
99
100        Ok(hex_hash == self.info.checksum)
101    }
102}
103
/// Model registry for managing pre-trained models.
///
/// Thread-safe: the internal maps are shared behind `Arc<Mutex<_>>`.
pub struct ModelRegistry {
    /// Registered models, keyed by "name:version"
    models: Arc<Mutex<HashMap<String, ModelInfo>>>,
    /// Cache directory where downloaded model files are stored
    cache_dir: PathBuf,
    /// Cache of handles created by `get_model_handle`, keyed by "name:version"
    handles: Arc<Mutex<HashMap<String, ModelHandle>>>,
}
113
impl ModelRegistry {
115    /// Create new model registry
116    pub fn new<P: AsRef<Path>>(cache_dir: P) -> ModelResult<Self> {
117        let cache_dir = cache_dir.as_ref().to_path_buf();
118
119        // Create cache directory if it doesn't exist
120        if !cache_dir.exists() {
121            std::fs::create_dir_all(&cache_dir)?;
122        }
123
124        Ok(Self {
125            models: Arc::new(Mutex::new(HashMap::new())),
126            cache_dir,
127            handles: Arc::new(Mutex::new(HashMap::new())),
128        })
129    }
130
131    /// Create default registry (uses ~/.torsh/models)
132    pub fn default() -> ModelResult<Self> {
133        let home_dir = dirs::home_dir().ok_or_else(|| ModelError::LoadingError {
134            reason: "Could not find home directory".to_string(),
135        })?;
136
137        let cache_dir = home_dir.join(".torsh").join("models");
138        Self::new(cache_dir)
139    }
140
141    /// Register a new model
142    pub fn register_model(&self, info: ModelInfo) -> ModelResult<()> {
143        let mut models = self.models.lock().expect("lock should not be poisoned");
144        let key = format!("{}:{}", info.name, info.version);
145        models.insert(key, info);
146        Ok(())
147    }
148
149    /// Get model information by name and version
150    pub fn get_model_info(&self, name: &str, version: Option<&str>) -> ModelResult<ModelInfo> {
151        let models = self.models.lock().expect("lock should not be poisoned");
152
153        if let Some(version) = version {
154            let key = format!("{}:{}", name, version);
155            models
156                .get(&key)
157                .cloned()
158                .ok_or_else(|| ModelError::ModelNotFound { name: key })
159        } else {
160            // Find latest version
161            let matching_models: Vec<_> =
162                models.values().filter(|info| info.name == name).collect();
163
164            if matching_models.is_empty() {
165                return Err(ModelError::ModelNotFound {
166                    name: name.to_string(),
167                });
168            }
169
170            // Sort by version and return latest
171            let mut sorted = matching_models;
172            sorted.sort_by(|a, b| a.version.cmp(&b.version));
173
174            Ok((*sorted
175                .last()
176                .expect("matching models list should not be empty"))
177            .clone())
178        }
179    }
180
181    /// List all registered models
182    pub fn list_models(&self) -> Vec<ModelInfo> {
183        let models = self.models.lock().expect("lock should not be poisoned");
184        models.values().cloned().collect()
185    }
186
187    /// Search models by domain
188    pub fn search_by_domain(&self, domain: &str) -> Vec<ModelInfo> {
189        let models = self.models.lock().expect("lock should not be poisoned");
190        models
191            .values()
192            .filter(|info| info.domain == domain)
193            .cloned()
194            .collect()
195    }
196
197    /// Search models by tags
198    pub fn search_by_tags(&self, tags: &[&str]) -> Vec<ModelInfo> {
199        let models = self.models.lock().expect("lock should not be poisoned");
200        models
201            .values()
202            .filter(|info| tags.iter().any(|tag| info.tags.contains(&tag.to_string())))
203            .cloned()
204            .collect()
205    }
206
207    /// Get model handle
208    pub fn get_model_handle(&self, name: &str, version: Option<&str>) -> ModelResult<ModelHandle> {
209        let info = self.get_model_info(name, version)?;
210        let key = format!("{}:{}", info.name, info.version);
211
212        // Check if handle already exists
213        {
214            let handles = self.handles.lock().expect("lock should not be poisoned");
215            if let Some(handle) = handles.get(&key) {
216                return Ok(ModelHandle {
217                    info: handle.info.clone(),
218                    local_path: handle.local_path.clone(),
219                    loaded: handle.loaded,
220                });
221            }
222        }
223
224        // Create new handle
225        let local_path = self.get_local_path(&info);
226        let handle = ModelHandle::new(info, local_path);
227
228        // Cache the handle
229        {
230            let mut handles = self.handles.lock().expect("lock should not be poisoned");
231            handles.insert(
232                key,
233                ModelHandle {
234                    info: handle.info.clone(),
235                    local_path: handle.local_path.clone(),
236                    loaded: handle.loaded,
237                },
238            );
239        }
240
241        Ok(handle)
242    }
243
244    /// Get local file path for a model
245    fn get_local_path(&self, info: &ModelInfo) -> PathBuf {
246        let filename = format!("{}-{}.safetensors", info.name, info.version);
247        self.cache_dir.join(filename)
248    }
249
250    /// Load models from registry file
251    pub fn load_from_file<P: AsRef<Path>>(&self, path: P) -> ModelResult<()> {
252        let content = std::fs::read_to_string(path)?;
253        let model_infos: Vec<ModelInfo> = serde_json::from_str(&content)?;
254
255        for info in model_infos {
256            self.register_model(info)?;
257        }
258
259        Ok(())
260    }
261
262    /// Save models to registry file
263    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> ModelResult<()> {
264        let models = self.list_models();
265        let content = serde_json::to_string_pretty(&models)?;
266        std::fs::write(path, content)?;
267        Ok(())
268    }
269
    /// Register built-in models.
    ///
    /// Each group is compiled in only when the corresponding cargo feature
    /// is enabled, so a minimal build carries no unused model metadata.
    pub fn register_builtin_models(&self) -> ModelResult<()> {
        // Vision models
        #[cfg(feature = "vision")]
        {
            self.register_vision_models()?;
        }

        // NLP models
        #[cfg(feature = "nlp")]
        {
            self.register_nlp_models()?;
        }

        // Audio models
        #[cfg(feature = "audio")]
        {
            self.register_audio_models()?;
        }

        // Multimodal models
        #[cfg(feature = "multimodal")]
        {
            self.register_multimodal_models()?;
        }

        Ok(())
    }
298
    /// Register the built-in vision models (ResNet-18/50, EfficientNet-B0,
    /// ViT-Base).
    ///
    /// NOTE(review): the checksums below are 40-hex-char strings (SHA-1
    /// length), while `ModelHandle::validate_checksum` computes SHA-256
    /// (64 hex chars) — these entries can never validate successfully.
    /// Verify and update the recorded digests.
    #[cfg(feature = "vision")]
    fn register_vision_models(&self) -> ModelResult<()> {
        // ResNet-18
        let resnet18 = ModelInfo {
            name: "resnet18".to_string(),
            version: "1.0.0".to_string(),
            description: "ResNet-18 model pre-trained on ImageNet".to_string(),
            architecture: "ResNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.1.9/resnet18-5c106cde.pth".to_string()),
            size_bytes: 46827520,
            parameters: 11689512,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 69.758);
                m.insert("top5_accuracy".to_string(), 89.078);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "cnn".to_string()],
            license: "BSD".to_string(),
            citation: Some("He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition.".to_string()),
            checksum: "5c106cde0abbf5e61f9b0e5d5c51b2a9e17896b7".to_string(),
        };
        self.register_model(resnet18)?;

        // ResNet-50
        let resnet50 = ModelInfo {
            name: "resnet50".to_string(),
            version: "1.0.0".to_string(),
            description: "ResNet-50 model pre-trained on ImageNet".to_string(),
            architecture: "ResNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.1.9/resnet50-19c8e357.pth".to_string()),
            size_bytes: 102502400,
            parameters: 25557032,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 76.130);
                m.insert("top5_accuracy".to_string(), 92.862);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "cnn".to_string()],
            license: "BSD".to_string(),
            citation: Some("He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition.".to_string()),
            checksum: "19c8e357f2b6c76a2a39b97e94f5e71e8bbde6b7".to_string(),
        };
        self.register_model(resnet50)?;

        // EfficientNet-B0
        let efficientnet_b0 = ModelInfo {
            name: "efficientnet_b0".to_string(),
            version: "1.0.0".to_string(),
            description: "EfficientNet-B0 model pre-trained on ImageNet".to_string(),
            architecture: "EfficientNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.13.0/efficientnet_b0_rwightman-3dd342df.pth".to_string()),
            size_bytes: 21389824,
            parameters: 5288548,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 77.692);
                m.insert("top5_accuracy".to_string(), 93.532);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "efficient".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Tan, M., & Le, Q. (2019). Efficientnet: Rethinking model scaling for convolutional neural networks.".to_string()),
            checksum: "3dd342df789abc123456".to_string(),
        };
        self.register_model(efficientnet_b0)?;

        // Vision Transformer (ViT-Base)
        let vit_base = ModelInfo {
            name: "vit_base_patch16_224".to_string(),
            version: "1.0.0".to_string(),
            description: "Vision Transformer (ViT-Base) with 16x16 patches, pre-trained on ImageNet".to_string(),
            architecture: "ViT".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.13.0/vit_b_16-c867db91.pth".to_string()),
            size_bytes: 346659840,
            parameters: 86567656,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 81.072);
                m.insert("top5_accuracy".to_string(), 95.318);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "transformer".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Dosovitskiy, A., et al. (2020). An image is worth 16x16 words: Transformers for image recognition at scale.".to_string()),
            checksum: "c867db9123456789abc".to_string(),
        };
        self.register_model(vit_base)?;

        Ok(())
    }
403
    /// Register the built-in NLP models (BERT, GPT-2, RoBERTa), sourced from
    /// the Hugging Face Hub.
    ///
    /// NOTE(review): the checksum values here look like placeholders — they
    /// are not 64-hex SHA-256 digests and will never pass
    /// `ModelHandle::validate_checksum`; confirm the real digests.
    #[cfg(feature = "nlp")]
    fn register_nlp_models(&self) -> ModelResult<()> {
        // BERT-base
        let bert_base = ModelInfo {
            name: "bert-base-uncased".to_string(),
            version: "1.0.0".to_string(),
            description: "BERT base model (uncased) pre-trained on English corpus".to_string(),
            architecture: "BERT".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "bert-base-uncased".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 440473133,
            parameters: 110000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("glue_avg".to_string(), 79.6);
                m
            },
            tags: vec!["transformer".to_string(), "encoder".to_string(), "english".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding.".to_string()),
            checksum: "abc123def456789".to_string(),
        };
        self.register_model(bert_base)?;

        // GPT-2 Base
        let gpt2_base = ModelInfo {
            name: "gpt2".to_string(),
            version: "1.0.0".to_string(),
            description: "GPT-2 base model (117M parameters) pre-trained on English text"
                .to_string(),
            architecture: "GPT-2".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Token probabilities [seq_len, vocab_size]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "gpt2".to_string(),
                filename: "pytorch_model.bin".to_string(),
            },
            size_bytes: 510342400,
            parameters: 117000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("perplexity".to_string(), 18.3);
                m
            },
            tags: vec![
                "transformer".to_string(),
                "decoder".to_string(),
                "generative".to_string(),
            ],
            license: "MIT".to_string(),
            citation: Some(
                "Radford, A., et al. (2019). Language models are unsupervised multitask learners."
                    .to_string(),
            ),
            checksum: "def456789abc123".to_string(),
        };
        self.register_model(gpt2_base)?;

        // RoBERTa Base
        let roberta_base = ModelInfo {
            name: "roberta-base".to_string(),
            version: "1.0.0".to_string(),
            description: "RoBERTa base model pre-trained on English corpus".to_string(),
            architecture: "RoBERTa".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "roberta-base".to_string(),
                filename: "pytorch_model.bin".to_string(),
            },
            size_bytes: 498677760,
            parameters: 125000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("glue_avg".to_string(), 83.2);
                m
            },
            tags: vec![
                "transformer".to_string(),
                "encoder".to_string(),
                "english".to_string(),
            ],
            license: "MIT".to_string(),
            citation: Some(
                "Liu, Y., et al. (2019). RoBERTa: A robustly optimized BERT pretraining approach."
                    .to_string(),
            ),
            checksum: "789abc123def456".to_string(),
        };
        self.register_model(roberta_base)?;

        Ok(())
    }
504
    /// Register the built-in audio models (Wav2Vec2, Whisper), sourced from
    /// the Hugging Face Hub.
    ///
    /// NOTE(review): checksum values look like placeholders (not 64-hex
    /// SHA-256 digests) — confirm before relying on checksum validation.
    #[cfg(feature = "audio")]
    fn register_audio_models(&self) -> ModelResult<()> {
        // Wav2Vec2 Base
        let wav2vec2_base = ModelInfo {
            name: "wav2vec2-base".to_string(),
            version: "1.0.0".to_string(),
            description: "Wav2Vec2 base model pre-trained for speech recognition".to_string(),
            architecture: "Wav2Vec2".to_string(),
            domain: "audio".to_string(),
            input_spec: "Audio waveform [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "facebook/wav2vec2-base".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 378000000,
            parameters: 95000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("librispeech_wer".to_string(), 6.1);
                m
            },
            tags: vec!["speech".to_string(), "recognition".to_string(), "self-supervised".to_string()],
            license: "MIT".to_string(),
            citation: Some("Baevski, A., et al. (2020). wav2vec 2.0: A framework for self-supervised learning of speech representations.".to_string()),
            checksum: "123abc456def789".to_string(),
        };
        self.register_model(wav2vec2_base)?;

        // Whisper Base
        let whisper_base = ModelInfo {
            name: "whisper-base".to_string(),
            version: "1.0.0".to_string(),
            description: "Whisper base model for speech-to-text transcription".to_string(),
            architecture: "Whisper".to_string(),
            domain: "audio".to_string(),
            input_spec: "Audio mel spectrogram [80, seq_len]".to_string(),
            output_spec: "Text tokens [seq_len]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "openai/whisper-base".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 290000000,
            parameters: 74000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("librispeech_wer".to_string(), 5.4);
                m
            },
            tags: vec!["speech".to_string(), "transcription".to_string(), "multilingual".to_string()],
            license: "MIT".to_string(),
            citation: Some("Radford, A., et al. (2022). Robust speech recognition via large-scale weak supervision.".to_string()),
            checksum: "456def789abc123".to_string(),
        };
        self.register_model(whisper_base)?;

        Ok(())
    }
563
    /// Register the built-in multimodal models (CLIP), sourced from the
    /// Hugging Face Hub.
    ///
    /// NOTE(review): the checksum value looks like a placeholder (not a
    /// 64-hex SHA-256 digest) — confirm before relying on validation.
    #[cfg(feature = "multimodal")]
    fn register_multimodal_models(&self) -> ModelResult<()> {
        // CLIP Base
        let clip_base = ModelInfo {
            name: "clip-vit-base-patch32".to_string(),
            version: "1.0.0".to_string(),
            description: "CLIP model with ViT-Base vision encoder and text encoder".to_string(),
            architecture: "CLIP".to_string(),
            domain: "multimodal".to_string(),
            input_spec: "RGB image [3, 224, 224] + text [seq_len]".to_string(),
            output_spec: "Image/text embeddings [512]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "openai/clip-vit-base-patch32".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 605000000,
            parameters: 151000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("zero_shot_imagenet".to_string(), 63.2);
                m
            },
            tags: vec!["vision-language".to_string(), "contrastive".to_string(), "zero-shot".to_string()],
            license: "MIT".to_string(),
            citation: Some("Radford, A., et al. (2021). Learning transferable visual representations from natural language supervision.".to_string()),
            checksum: "789abc123def456".to_string(),
        };
        self.register_model(clip_base)?;

        Ok(())
    }
}
596
597lazy_static::lazy_static! {
598    /// Create a global model registry instance
599    static ref GLOBAL_REGISTRY: ModelRegistry = {
600        let registry = ModelRegistry::default().expect("Failed to create model registry");
601        registry.register_builtin_models().expect("Failed to register builtin models");
602        registry
603    };
604}
605
606/// Get the global model registry
607pub fn get_global_registry() -> &'static ModelRegistry {
608    &GLOBAL_REGISTRY
609}
610
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creating a registry should create its cache directory.
    #[test]
    fn test_model_registry_creation() {
        let temp_dir = tempdir().unwrap();
        let _registry = ModelRegistry::new(temp_dir.path()).unwrap();
        assert!(temp_dir.path().exists());
    }

    /// A registered model should be retrievable by exact name and version.
    #[test]
    fn test_model_registration() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path()).unwrap();

        let info = ModelInfo {
            name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            description: "Test model".to_string(),
            architecture: "TestNet".to_string(),
            domain: "test".to_string(),
            input_spec: "test input".to_string(),
            output_spec: "test output".to_string(),
            source: ModelSource::Local(PathBuf::from("test.safetensors")),
            size_bytes: 1024,
            parameters: 100,
            metrics: HashMap::new(),
            tags: vec!["test".to_string()],
            license: "MIT".to_string(),
            citation: None,
            checksum: "test_checksum".to_string(),
        };

        registry.register_model(info.clone()).unwrap();

        let retrieved = registry
            .get_model_info("test_model", Some("1.0.0"))
            .unwrap();
        assert_eq!(retrieved.name, "test_model");
        assert_eq!(retrieved.version, "1.0.0");
    }

    /// Domain and tag searches should only return matching models.
    #[test]
    fn test_model_search() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path()).unwrap();

        // One vision model and one NLP model with disjoint tags.
        let info1 = ModelInfo {
            name: "model1".to_string(),
            version: "1.0.0".to_string(),
            description: "Model 1".to_string(),
            architecture: "Net1".to_string(),
            domain: "vision".to_string(),
            input_spec: "image".to_string(),
            output_spec: "class".to_string(),
            source: ModelSource::Local(PathBuf::from("model1.safetensors")),
            size_bytes: 1024,
            parameters: 100,
            metrics: HashMap::new(),
            tags: vec!["cnn".to_string(), "classification".to_string()],
            license: "MIT".to_string(),
            citation: None,
            checksum: "checksum1".to_string(),
        };

        let info2 = ModelInfo {
            name: "model2".to_string(),
            version: "1.0.0".to_string(),
            description: "Model 2".to_string(),
            architecture: "Net2".to_string(),
            domain: "nlp".to_string(),
            input_spec: "text".to_string(),
            output_spec: "embedding".to_string(),
            source: ModelSource::Local(PathBuf::from("model2.safetensors")),
            size_bytes: 2048,
            parameters: 200,
            metrics: HashMap::new(),
            tags: vec!["transformer".to_string(), "embedding".to_string()],
            license: "Apache-2.0".to_string(),
            citation: None,
            checksum: "checksum2".to_string(),
        };

        registry.register_model(info1).unwrap();
        registry.register_model(info2).unwrap();

        let vision_models = registry.search_by_domain("vision");
        assert_eq!(vision_models.len(), 1);
        assert_eq!(vision_models[0].name, "model1");

        let cnn_models = registry.search_by_tags(&["cnn"]);
        assert_eq!(cnn_models.len(), 1);
        assert_eq!(cnn_models[0].name, "model1");
    }
}