1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Arc, Mutex};
6
7use crate::{ModelError, ModelResult};
8use serde::{Deserialize, Serialize};
9use sha2::Digest;
10
/// Metadata describing a registered model: identity, provenance, and summary
/// statistics. Serializable so a registry can be persisted as JSON
/// (see `ModelRegistry::save_to_file` / `load_from_file`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    // Model name; together with `version` it forms the registry key "name:version".
    pub name: String,
    // Version string, e.g. "1.0.0".
    pub version: String,
    // Human-readable description of the model.
    pub description: String,
    // Architecture family, e.g. "ResNet" or "BERT".
    pub architecture: String,
    // Domain used for search, e.g. "vision", "nlp", "audio", "multimodal".
    pub domain: String,
    // Free-form description of the expected input, e.g. "RGB image [3, 224, 224]".
    pub input_spec: String,
    // Free-form description of the produced output.
    pub output_spec: String,
    // Where the model weights can be obtained from.
    pub source: ModelSource,
    // Size of the weight file in bytes.
    pub size_bytes: u64,
    // Number of model parameters.
    pub parameters: u64,
    // Named evaluation metrics, e.g. "top1_accuracy" -> 69.758.
    pub metrics: HashMap<String, f32>,
    // Tags used for search/filtering (see `ModelRegistry::search_by_tags`).
    pub tags: Vec<String>,
    // License identifier, e.g. "MIT" or "Apache-2.0".
    pub license: String,
    // Optional citation for the originating publication.
    pub citation: Option<String>,
    // Expected hex digest of the weight file; compared against a SHA-256
    // digest in `ModelHandle::validate_checksum`.
    pub checksum: String,
}
45
/// Location a model's weights can be fetched from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelSource {
    /// A file already present on the local filesystem.
    Local(PathBuf),
    /// A direct download URL.
    Url(String),
    /// A file within a Hugging Face Hub repository.
    HuggingFace { repo: String, filename: String },
    /// A path within a named external model registry.
    Registry { registry: String, path: String },
}
58
/// A reference to a model in the local cache: its metadata plus where its
/// weight file lives (or will live once downloaded) on disk.
pub struct ModelHandle {
    // Registry metadata for this model.
    pub info: ModelInfo,
    // Path of the cached weight file; may not exist yet (see `exists`).
    pub local_path: PathBuf,
    // Whether the model has been loaded into memory. Nothing in this module
    // sets it to true; presumably loader code elsewhere flips it — confirm.
    pub loaded: bool,
}
68
69impl ModelHandle {
70 pub fn new(info: ModelInfo, local_path: PathBuf) -> Self {
72 Self {
73 info,
74 local_path,
75 loaded: false,
76 }
77 }
78
79 pub fn exists(&self) -> bool {
81 self.local_path.exists()
82 }
83
84 pub fn file_size(&self) -> ModelResult<u64> {
86 let metadata = std::fs::metadata(&self.local_path)?;
87 Ok(metadata.len())
88 }
89
90 pub fn validate_checksum(&self) -> ModelResult<bool> {
92 if !self.exists() {
93 return Ok(false);
94 }
95
96 let data = std::fs::read(&self.local_path)?;
97 let hash = sha2::Sha256::digest(&data);
98 let hex_hash = hex::encode(hash);
99
100 Ok(hex_hash == self.info.checksum)
101 }
102}
103
/// Thread-safe in-memory catalog of models plus a filesystem cache directory.
/// Cloneable handles to the maps are shared via `Arc<Mutex<_>>`.
pub struct ModelRegistry {
    // Registered metadata, keyed by "name:version".
    models: Arc<Mutex<HashMap<String, ModelInfo>>>,
    // Directory where weight files are cached (see `get_local_path`).
    cache_dir: PathBuf,
    // Handles already created via `get_model_handle`, keyed by "name:version".
    handles: Arc<Mutex<HashMap<String, ModelHandle>>>,
}
113
114impl ModelRegistry {
115 pub fn new<P: AsRef<Path>>(cache_dir: P) -> ModelResult<Self> {
117 let cache_dir = cache_dir.as_ref().to_path_buf();
118
119 if !cache_dir.exists() {
121 std::fs::create_dir_all(&cache_dir)?;
122 }
123
124 Ok(Self {
125 models: Arc::new(Mutex::new(HashMap::new())),
126 cache_dir,
127 handles: Arc::new(Mutex::new(HashMap::new())),
128 })
129 }
130
131 pub fn default() -> ModelResult<Self> {
133 let home_dir = dirs::home_dir().ok_or_else(|| ModelError::LoadingError {
134 reason: "Could not find home directory".to_string(),
135 })?;
136
137 let cache_dir = home_dir.join(".torsh").join("models");
138 Self::new(cache_dir)
139 }
140
141 pub fn register_model(&self, info: ModelInfo) -> ModelResult<()> {
143 let mut models = self.models.lock().expect("lock should not be poisoned");
144 let key = format!("{}:{}", info.name, info.version);
145 models.insert(key, info);
146 Ok(())
147 }
148
149 pub fn get_model_info(&self, name: &str, version: Option<&str>) -> ModelResult<ModelInfo> {
151 let models = self.models.lock().expect("lock should not be poisoned");
152
153 if let Some(version) = version {
154 let key = format!("{}:{}", name, version);
155 models
156 .get(&key)
157 .cloned()
158 .ok_or_else(|| ModelError::ModelNotFound { name: key })
159 } else {
160 let matching_models: Vec<_> =
162 models.values().filter(|info| info.name == name).collect();
163
164 if matching_models.is_empty() {
165 return Err(ModelError::ModelNotFound {
166 name: name.to_string(),
167 });
168 }
169
170 let mut sorted = matching_models;
172 sorted.sort_by(|a, b| a.version.cmp(&b.version));
173
174 Ok((*sorted
175 .last()
176 .expect("matching models list should not be empty"))
177 .clone())
178 }
179 }
180
181 pub fn list_models(&self) -> Vec<ModelInfo> {
183 let models = self.models.lock().expect("lock should not be poisoned");
184 models.values().cloned().collect()
185 }
186
187 pub fn search_by_domain(&self, domain: &str) -> Vec<ModelInfo> {
189 let models = self.models.lock().expect("lock should not be poisoned");
190 models
191 .values()
192 .filter(|info| info.domain == domain)
193 .cloned()
194 .collect()
195 }
196
197 pub fn search_by_tags(&self, tags: &[&str]) -> Vec<ModelInfo> {
199 let models = self.models.lock().expect("lock should not be poisoned");
200 models
201 .values()
202 .filter(|info| tags.iter().any(|tag| info.tags.contains(&tag.to_string())))
203 .cloned()
204 .collect()
205 }
206
207 pub fn get_model_handle(&self, name: &str, version: Option<&str>) -> ModelResult<ModelHandle> {
209 let info = self.get_model_info(name, version)?;
210 let key = format!("{}:{}", info.name, info.version);
211
212 {
214 let handles = self.handles.lock().expect("lock should not be poisoned");
215 if let Some(handle) = handles.get(&key) {
216 return Ok(ModelHandle {
217 info: handle.info.clone(),
218 local_path: handle.local_path.clone(),
219 loaded: handle.loaded,
220 });
221 }
222 }
223
224 let local_path = self.get_local_path(&info);
226 let handle = ModelHandle::new(info, local_path);
227
228 {
230 let mut handles = self.handles.lock().expect("lock should not be poisoned");
231 handles.insert(
232 key,
233 ModelHandle {
234 info: handle.info.clone(),
235 local_path: handle.local_path.clone(),
236 loaded: handle.loaded,
237 },
238 );
239 }
240
241 Ok(handle)
242 }
243
244 fn get_local_path(&self, info: &ModelInfo) -> PathBuf {
246 let filename = format!("{}-{}.safetensors", info.name, info.version);
247 self.cache_dir.join(filename)
248 }
249
250 pub fn load_from_file<P: AsRef<Path>>(&self, path: P) -> ModelResult<()> {
252 let content = std::fs::read_to_string(path)?;
253 let model_infos: Vec<ModelInfo> = serde_json::from_str(&content)?;
254
255 for info in model_infos {
256 self.register_model(info)?;
257 }
258
259 Ok(())
260 }
261
262 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> ModelResult<()> {
264 let models = self.list_models();
265 let content = serde_json::to_string_pretty(&models)?;
266 std::fs::write(path, content)?;
267 Ok(())
268 }
269
270 pub fn register_builtin_models(&self) -> ModelResult<()> {
272 #[cfg(feature = "vision")]
274 {
275 self.register_vision_models()?;
276 }
277
278 #[cfg(feature = "nlp")]
280 {
281 self.register_nlp_models()?;
282 }
283
284 #[cfg(feature = "audio")]
286 {
287 self.register_audio_models()?;
288 }
289
290 #[cfg(feature = "multimodal")]
292 {
293 self.register_multimodal_models()?;
294 }
295
296 Ok(())
297 }
298
    /// Register the built-in computer-vision models: ResNet-18/50,
    /// EfficientNet-B0 and ViT-Base, all ImageNet-1k classifiers.
    ///
    /// NOTE(review): several `checksum` values below look like short
    /// placeholders or SHA-1-style digests, while
    /// `ModelHandle::validate_checksum` computes SHA-256 — confirm these
    /// before relying on checksum validation.
    #[cfg(feature = "vision")]
    fn register_vision_models(&self) -> ModelResult<()> {
        // ResNet-18: small residual CNN.
        let resnet18 = ModelInfo {
            name: "resnet18".to_string(),
            version: "1.0.0".to_string(),
            description: "ResNet-18 model pre-trained on ImageNet".to_string(),
            architecture: "ResNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.1.9/resnet18-5c106cde.pth".to_string()),
            size_bytes: 46827520,
            parameters: 11689512,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 69.758);
                m.insert("top5_accuracy".to_string(), 89.078);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "cnn".to_string()],
            license: "BSD".to_string(),
            citation: Some("He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition.".to_string()),
            checksum: "5c106cde0abbf5e61f9b0e5d5c51b2a9e17896b7".to_string(),
        };
        self.register_model(resnet18)?;

        // ResNet-50: deeper residual CNN.
        let resnet50 = ModelInfo {
            name: "resnet50".to_string(),
            version: "1.0.0".to_string(),
            description: "ResNet-50 model pre-trained on ImageNet".to_string(),
            architecture: "ResNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.1.9/resnet50-19c8e357.pth".to_string()),
            size_bytes: 102502400,
            parameters: 25557032,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 76.130);
                m.insert("top5_accuracy".to_string(), 92.862);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "cnn".to_string()],
            license: "BSD".to_string(),
            citation: Some("He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition.".to_string()),
            checksum: "19c8e357f2b6c76a2a39b97e94f5e71e8bbde6b7".to_string(),
        };
        self.register_model(resnet50)?;

        // EfficientNet-B0: compute-efficient CNN.
        let efficientnet_b0 = ModelInfo {
            name: "efficientnet_b0".to_string(),
            version: "1.0.0".to_string(),
            description: "EfficientNet-B0 model pre-trained on ImageNet".to_string(),
            architecture: "EfficientNet".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.13.0/efficientnet_b0_rwightman-3dd342df.pth".to_string()),
            size_bytes: 21389824,
            parameters: 5288548,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 77.692);
                m.insert("top5_accuracy".to_string(), 93.532);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "efficient".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Tan, M., & Le, Q. (2019). Efficientnet: Rethinking model scaling for convolutional neural networks.".to_string()),
            checksum: "3dd342df789abc123456".to_string(),
        };
        self.register_model(efficientnet_b0)?;

        // ViT-Base: transformer-based classifier with 16x16 patches.
        let vit_base = ModelInfo {
            name: "vit_base_patch16_224".to_string(),
            version: "1.0.0".to_string(),
            description: "Vision Transformer (ViT-Base) with 16x16 patches, pre-trained on ImageNet".to_string(),
            architecture: "ViT".to_string(),
            domain: "vision".to_string(),
            input_spec: "RGB image [3, 224, 224]".to_string(),
            output_spec: "1000 class probabilities".to_string(),
            source: ModelSource::Url("https://github.com/pytorch/vision/releases/download/v0.13.0/vit_b_16-c867db91.pth".to_string()),
            size_bytes: 346659840,
            parameters: 86567656,
            metrics: {
                let mut m = HashMap::new();
                m.insert("top1_accuracy".to_string(), 81.072);
                m.insert("top5_accuracy".to_string(), 95.318);
                m
            },
            tags: vec!["classification".to_string(), "imagenet".to_string(), "transformer".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Dosovitskiy, A., et al. (2020). An image is worth 16x16 words: Transformers for image recognition at scale.".to_string()),
            checksum: "c867db9123456789abc".to_string(),
        };
        self.register_model(vit_base)?;

        Ok(())
    }
403
    /// Register the built-in NLP models: BERT-base, GPT-2 and RoBERTa-base,
    /// all sourced from the Hugging Face Hub.
    ///
    /// NOTE(review): the `checksum` values below look like placeholders —
    /// confirm before relying on checksum validation.
    #[cfg(feature = "nlp")]
    fn register_nlp_models(&self) -> ModelResult<()> {
        // BERT-base (uncased): bidirectional encoder.
        let bert_base = ModelInfo {
            name: "bert-base-uncased".to_string(),
            version: "1.0.0".to_string(),
            description: "BERT base model (uncased) pre-trained on English corpus".to_string(),
            architecture: "BERT".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "bert-base-uncased".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 440473133,
            parameters: 110000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("glue_avg".to_string(), 79.6);
                m
            },
            tags: vec!["transformer".to_string(), "encoder".to_string(), "english".to_string()],
            license: "Apache-2.0".to_string(),
            citation: Some("Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). Bert: Pre-training of deep bidirectional transformers for language understanding.".to_string()),
            checksum: "abc123def456789".to_string(),
        };
        self.register_model(bert_base)?;

        // GPT-2 base: autoregressive decoder-only language model.
        let gpt2_base = ModelInfo {
            name: "gpt2".to_string(),
            version: "1.0.0".to_string(),
            description: "GPT-2 base model (117M parameters) pre-trained on English text"
                .to_string(),
            architecture: "GPT-2".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Token probabilities [seq_len, vocab_size]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "gpt2".to_string(),
                filename: "pytorch_model.bin".to_string(),
            },
            size_bytes: 510342400,
            parameters: 117000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("perplexity".to_string(), 18.3);
                m
            },
            tags: vec![
                "transformer".to_string(),
                "decoder".to_string(),
                "generative".to_string(),
            ],
            license: "MIT".to_string(),
            citation: Some(
                "Radford, A., et al. (2019). Language models are unsupervised multitask learners."
                    .to_string(),
            ),
            checksum: "def456789abc123".to_string(),
        };
        self.register_model(gpt2_base)?;

        // RoBERTa-base: robustly-optimized BERT variant.
        let roberta_base = ModelInfo {
            name: "roberta-base".to_string(),
            version: "1.0.0".to_string(),
            description: "RoBERTa base model pre-trained on English corpus".to_string(),
            architecture: "RoBERTa".to_string(),
            domain: "nlp".to_string(),
            input_spec: "Tokenized text [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "roberta-base".to_string(),
                filename: "pytorch_model.bin".to_string(),
            },
            size_bytes: 498677760,
            parameters: 125000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("glue_avg".to_string(), 83.2);
                m
            },
            tags: vec![
                "transformer".to_string(),
                "encoder".to_string(),
                "english".to_string(),
            ],
            license: "MIT".to_string(),
            citation: Some(
                "Liu, Y., et al. (2019). RoBERTa: A robustly optimized BERT pretraining approach."
                    .to_string(),
            ),
            checksum: "789abc123def456".to_string(),
        };
        self.register_model(roberta_base)?;

        Ok(())
    }
504
    /// Register the built-in audio models: Wav2Vec2-base and Whisper-base,
    /// both sourced from the Hugging Face Hub.
    ///
    /// NOTE(review): the `checksum` values below look like placeholders —
    /// confirm before relying on checksum validation.
    #[cfg(feature = "audio")]
    fn register_audio_models(&self) -> ModelResult<()> {
        // Wav2Vec2-base: self-supervised speech representation model.
        let wav2vec2_base = ModelInfo {
            name: "wav2vec2-base".to_string(),
            version: "1.0.0".to_string(),
            description: "Wav2Vec2 base model pre-trained for speech recognition".to_string(),
            architecture: "Wav2Vec2".to_string(),
            domain: "audio".to_string(),
            input_spec: "Audio waveform [seq_len]".to_string(),
            output_spec: "Hidden states [seq_len, 768]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "facebook/wav2vec2-base".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 378000000,
            parameters: 95000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("librispeech_wer".to_string(), 6.1);
                m
            },
            tags: vec!["speech".to_string(), "recognition".to_string(), "self-supervised".to_string()],
            license: "MIT".to_string(),
            citation: Some("Baevski, A., et al. (2020). wav2vec 2.0: A framework for self-supervised learning of speech representations.".to_string()),
            checksum: "123abc456def789".to_string(),
        };
        self.register_model(wav2vec2_base)?;

        // Whisper-base: encoder-decoder speech-to-text model.
        let whisper_base = ModelInfo {
            name: "whisper-base".to_string(),
            version: "1.0.0".to_string(),
            description: "Whisper base model for speech-to-text transcription".to_string(),
            architecture: "Whisper".to_string(),
            domain: "audio".to_string(),
            input_spec: "Audio mel spectrogram [80, seq_len]".to_string(),
            output_spec: "Text tokens [seq_len]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "openai/whisper-base".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 290000000,
            parameters: 74000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("librispeech_wer".to_string(), 5.4);
                m
            },
            tags: vec!["speech".to_string(), "transcription".to_string(), "multilingual".to_string()],
            license: "MIT".to_string(),
            citation: Some("Radford, A., et al. (2022). Robust speech recognition via large-scale weak supervision.".to_string()),
            checksum: "456def789abc123".to_string(),
        };
        self.register_model(whisper_base)?;

        Ok(())
    }
563
    /// Register the built-in multimodal models (currently only CLIP
    /// ViT-B/32), sourced from the Hugging Face Hub.
    ///
    /// NOTE(review): the `checksum` value below looks like a placeholder —
    /// confirm before relying on checksum validation.
    #[cfg(feature = "multimodal")]
    fn register_multimodal_models(&self) -> ModelResult<()> {
        // CLIP ViT-B/32: joint image/text embedding model.
        let clip_base = ModelInfo {
            name: "clip-vit-base-patch32".to_string(),
            version: "1.0.0".to_string(),
            description: "CLIP model with ViT-Base vision encoder and text encoder".to_string(),
            architecture: "CLIP".to_string(),
            domain: "multimodal".to_string(),
            input_spec: "RGB image [3, 224, 224] + text [seq_len]".to_string(),
            output_spec: "Image/text embeddings [512]".to_string(),
            source: ModelSource::HuggingFace {
                repo: "openai/clip-vit-base-patch32".to_string(),
                filename: "pytorch_model.bin".to_string()
            },
            size_bytes: 605000000,
            parameters: 151000000,
            metrics: {
                let mut m = HashMap::new();
                m.insert("zero_shot_imagenet".to_string(), 63.2);
                m
            },
            tags: vec!["vision-language".to_string(), "contrastive".to_string(), "zero-shot".to_string()],
            license: "MIT".to_string(),
            citation: Some("Radford, A., et al. (2021). Learning transferable visual representations from natural language supervision.".to_string()),
            checksum: "789abc123def456".to_string(),
        };
        self.register_model(clip_base)?;

        Ok(())
    }
595}
596
lazy_static::lazy_static! {
    /// Process-wide registry, initialized on first access. Builtin models for
    /// every enabled feature flag are registered eagerly during creation.
    /// Initialization panics if the default cache directory cannot be set up.
    static ref GLOBAL_REGISTRY: ModelRegistry = {
        let registry = ModelRegistry::default().expect("Failed to create model registry");
        registry.register_builtin_models().expect("Failed to register builtin models");
        registry
    };
}
605
/// Access the lazily-initialized global model registry.
///
/// # Panics
/// Panics on first access if the registry cannot be created (e.g. no home
/// directory) or if registering the builtin models fails.
pub fn get_global_registry() -> &'static ModelRegistry {
    &GLOBAL_REGISTRY
}
610
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Creating a registry must create the cache directory.
    #[test]
    fn test_model_registry_creation() {
        let temp_dir = tempdir().unwrap();
        let _registry = ModelRegistry::new(temp_dir.path()).unwrap();
        assert!(temp_dir.path().exists());
    }

    // A registered model can be retrieved by exact name + version.
    #[test]
    fn test_model_registration() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path()).unwrap();

        let info = ModelInfo {
            name: "test_model".to_string(),
            version: "1.0.0".to_string(),
            description: "Test model".to_string(),
            architecture: "TestNet".to_string(),
            domain: "test".to_string(),
            input_spec: "test input".to_string(),
            output_spec: "test output".to_string(),
            source: ModelSource::Local(PathBuf::from("test.safetensors")),
            size_bytes: 1024,
            parameters: 100,
            metrics: HashMap::new(),
            tags: vec!["test".to_string()],
            license: "MIT".to_string(),
            citation: None,
            checksum: "test_checksum".to_string(),
        };

        registry.register_model(info.clone()).unwrap();

        let retrieved = registry
            .get_model_info("test_model", Some("1.0.0"))
            .unwrap();
        assert_eq!(retrieved.name, "test_model");
        assert_eq!(retrieved.version, "1.0.0");
    }

    // Domain and tag searches only match the intended model.
    #[test]
    fn test_model_search() {
        let temp_dir = tempdir().unwrap();
        let registry = ModelRegistry::new(temp_dir.path()).unwrap();

        // A vision model tagged "cnn"/"classification"...
        let info1 = ModelInfo {
            name: "model1".to_string(),
            version: "1.0.0".to_string(),
            description: "Model 1".to_string(),
            architecture: "Net1".to_string(),
            domain: "vision".to_string(),
            input_spec: "image".to_string(),
            output_spec: "class".to_string(),
            source: ModelSource::Local(PathBuf::from("model1.safetensors")),
            size_bytes: 1024,
            parameters: 100,
            metrics: HashMap::new(),
            tags: vec!["cnn".to_string(), "classification".to_string()],
            license: "MIT".to_string(),
            citation: None,
            checksum: "checksum1".to_string(),
        };

        // ...and an NLP model with disjoint tags.
        let info2 = ModelInfo {
            name: "model2".to_string(),
            version: "1.0.0".to_string(),
            description: "Model 2".to_string(),
            architecture: "Net2".to_string(),
            domain: "nlp".to_string(),
            input_spec: "text".to_string(),
            output_spec: "embedding".to_string(),
            source: ModelSource::Local(PathBuf::from("model2.safetensors")),
            size_bytes: 2048,
            parameters: 200,
            metrics: HashMap::new(),
            tags: vec!["transformer".to_string(), "embedding".to_string()],
            license: "Apache-2.0".to_string(),
            citation: None,
            checksum: "checksum2".to_string(),
        };

        registry.register_model(info1).unwrap();
        registry.register_model(info2).unwrap();

        let vision_models = registry.search_by_domain("vision");
        assert_eq!(vision_models.len(), 1);
        assert_eq!(vision_models[0].name, "model1");

        let cnn_models = registry.search_by_tags(&["cnn"]);
        assert_eq!(cnn_models.len(), 1);
        assert_eq!(cnn_models[0].name, "model1");
    }
}
707}