1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct ModelVersion {
10 pub version_id: String,
11 pub model_name: String,
12 pub version_number: u32,
13 pub created_at: u64,
14 pub created_by: String,
15 pub description: String,
16 pub tags: Vec<String>,
17 pub metadata: HashMap<String, String>,
18 pub model_hash: String,
19 pub file_path: PathBuf,
20 pub parent_version: Option<String>,
21 pub training_config: TrainingConfig,
22 pub performance_metrics: PerformanceMetrics,
23 pub model_size: u64,
24 pub status: ModelStatus,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct TrainingConfig {
29 pub learning_rate: f32,
30 pub batch_size: usize,
31 pub epochs: u32,
32 pub optimizer: String,
33 pub loss_function: String,
34 pub regularization: HashMap<String, f32>,
35 pub hyperparameters: HashMap<String, String>,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct PerformanceMetrics {
40 pub accuracy: f32,
41 pub loss: f32,
42 pub validation_accuracy: f32,
43 pub validation_loss: f32,
44 pub f1_score: Option<f32>,
45 pub precision: Option<f32>,
46 pub recall: Option<f32>,
47 pub custom_metrics: HashMap<String, f32>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub enum ModelStatus {
52 Training,
53 Trained,
54 Validated,
55 Deployed,
56 Archived,
57 Failed,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct ModelRegistry {
62 pub models: HashMap<String, Vec<ModelVersion>>,
63 pub latest_versions: HashMap<String, String>,
64 pub registry_path: PathBuf,
65}
66
67pub struct ModelVersioningManager {
68 registry: ModelRegistry,
69 storage_root: PathBuf,
70}
71
72impl ModelVersioningManager {
73 pub fn new(storage_root: PathBuf) -> Result<Self> {
74 let registry_path = storage_root.join("model_registry.json");
75
76 let registry = if registry_path.exists() {
77 let registry_data =
78 std::fs::read_to_string(®istry_path).context("Failed to read model registry")?;
79 serde_json::from_str(®istry_data).context("Failed to parse model registry")?
80 } else {
81 ModelRegistry {
82 models: HashMap::new(),
83 latest_versions: HashMap::new(),
84 registry_path: registry_path.clone(),
85 }
86 };
87
88 std::fs::create_dir_all(&storage_root)
90 .context("Failed to create storage root directory")?;
91
92 Ok(Self {
93 registry,
94 storage_root,
95 })
96 }
97
98 pub fn create_version(
99 &mut self,
100 model_name: String,
101 model_data: &[u8],
102 description: String,
103 created_by: String,
104 tags: Vec<String>,
105 training_config: TrainingConfig,
106 performance_metrics: PerformanceMetrics,
107 metadata: HashMap<String, String>,
108 ) -> Result<ModelVersion> {
109 let version_number = self.get_next_version_number(&model_name);
111
112 let version_id = format!("{}_{:04}", model_name, version_number);
114
115 let mut hasher = Sha256::new();
117 hasher.update(model_data);
118 let model_hash = format!("{:x}", hasher.finalize());
119
120 let file_name = format!("{}.model", version_id);
122 let file_path = self.storage_root.join(&model_name).join(&file_name);
123
124 if let Some(parent) = file_path.parent() {
126 std::fs::create_dir_all(parent).context("Failed to create model directory")?;
127 }
128
129 std::fs::write(&file_path, model_data).context("Failed to save model data")?;
131
132 let created_at = SystemTime::now()
134 .duration_since(UNIX_EPOCH)
135 .expect("SystemTime should be after UNIX_EPOCH")
136 .as_secs();
137
138 let parent_version = self.registry.latest_versions.get(&model_name).cloned();
140
141 let version = ModelVersion {
143 version_id: version_id.clone(),
144 model_name: model_name.clone(),
145 version_number,
146 created_at,
147 created_by,
148 description,
149 tags,
150 metadata,
151 model_hash,
152 file_path,
153 parent_version,
154 training_config,
155 performance_metrics,
156 model_size: model_data.len() as u64,
157 status: ModelStatus::Trained,
158 };
159
160 self.registry
162 .models
163 .entry(model_name.clone())
164 .or_default()
165 .push(version.clone());
166
167 self.registry.latest_versions.insert(model_name, version_id);
169
170 self.save_registry()?;
172
173 Ok(version)
174 }
175
176 pub fn get_version(&self, model_name: &str, version_id: &str) -> Option<&ModelVersion> {
177 self.registry
178 .models
179 .get(model_name)?
180 .iter()
181 .find(|v| v.version_id == version_id)
182 }
183
184 pub fn get_latest_version(&self, model_name: &str) -> Option<&ModelVersion> {
185 let latest_version_id = self.registry.latest_versions.get(model_name)?;
186 self.get_version(model_name, latest_version_id)
187 }
188
189 pub fn list_versions(&self, model_name: &str) -> Vec<&ModelVersion> {
190 self.registry
191 .models
192 .get(model_name)
193 .map(|versions| {
194 let mut sorted_versions: Vec<_> = versions.iter().collect();
195 sorted_versions.sort_by_key(|v| std::cmp::Reverse(v.created_at));
196 sorted_versions
197 })
198 .unwrap_or_default()
199 }
200
201 pub fn list_models(&self) -> Vec<String> {
202 self.registry.models.keys().cloned().collect()
203 }
204
205 pub fn update_status(
206 &mut self,
207 model_name: &str,
208 version_id: &str,
209 status: ModelStatus,
210 ) -> Result<()> {
211 let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
212
213 let version = versions
214 .iter_mut()
215 .find(|v| v.version_id == version_id)
216 .context("Version not found")?;
217
218 version.status = status;
219 self.save_registry()?;
220
221 Ok(())
222 }
223
224 pub fn add_tag(&mut self, model_name: &str, version_id: &str, tag: String) -> Result<()> {
225 let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
226
227 let version = versions
228 .iter_mut()
229 .find(|v| v.version_id == version_id)
230 .context("Version not found")?;
231
232 if !version.tags.contains(&tag) {
233 version.tags.push(tag);
234 self.save_registry()?;
235 }
236
237 Ok(())
238 }
239
240 pub fn remove_tag(&mut self, model_name: &str, version_id: &str, tag: &str) -> Result<()> {
241 let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
242
243 let version = versions
244 .iter_mut()
245 .find(|v| v.version_id == version_id)
246 .context("Version not found")?;
247
248 version.tags.retain(|t| t != tag);
249 self.save_registry()?;
250
251 Ok(())
252 }
253
254 pub fn find_versions_by_tag(&self, tag: &str) -> Vec<&ModelVersion> {
255 self.registry
256 .models
257 .values()
258 .flatten()
259 .filter(|version| version.tags.contains(&tag.to_string()))
260 .collect()
261 }
262
263 pub fn find_versions_by_performance(
264 &self,
265 metric_name: &str,
266 min_value: f32,
267 max_value: Option<f32>,
268 ) -> Vec<&ModelVersion> {
269 self.registry
270 .models
271 .values()
272 .flatten()
273 .filter(|version| {
274 if let Some(value) = version.performance_metrics.custom_metrics.get(metric_name) {
275 *value >= min_value && max_value.map_or(true, |max| *value <= max)
276 } else {
277 match metric_name {
279 "accuracy" => {
280 let value = version.performance_metrics.accuracy;
281 value >= min_value && max_value.map_or(true, |max| value <= max)
282 },
283 "loss" => {
284 let value = version.performance_metrics.loss;
285 value >= min_value && max_value.map_or(true, |max| value <= max)
286 },
287 "validation_accuracy" => {
288 let value = version.performance_metrics.validation_accuracy;
289 value >= min_value && max_value.map_or(true, |max| value <= max)
290 },
291 "validation_loss" => {
292 let value = version.performance_metrics.validation_loss;
293 value >= min_value && max_value.map_or(true, |max| value <= max)
294 },
295 _ => false,
296 }
297 }
298 })
299 .collect()
300 }
301
302 pub fn delete_version(&mut self, model_name: &str, version_id: &str) -> Result<()> {
303 let version =
305 self.get_version(model_name, version_id).context("Version not found")?.clone();
306
307 if let Some(versions) = self.registry.models.get_mut(model_name) {
309 versions.retain(|v| v.version_id != version_id);
310
311 if self.registry.latest_versions.get(model_name) == Some(&version_id.to_string()) {
313 if let Some(latest) = versions.iter().max_by_key(|v| v.created_at) {
314 self.registry
315 .latest_versions
316 .insert(model_name.to_string(), latest.version_id.clone());
317 } else {
318 self.registry.latest_versions.remove(model_name);
319 }
320 }
321 }
322
323 if version.file_path.exists() {
325 std::fs::remove_file(&version.file_path).context("Failed to delete model file")?;
326 }
327
328 self.save_registry()?;
329
330 Ok(())
331 }
332
333 pub fn load_model_data(&self, model_name: &str, version_id: &str) -> Result<Vec<u8>> {
334 let version = self.get_version(model_name, version_id).context("Version not found")?;
335
336 std::fs::read(&version.file_path).context("Failed to read model data")
337 }
338
339 pub fn get_version_lineage(&self, model_name: &str, version_id: &str) -> Vec<&ModelVersion> {
340 let mut lineage = Vec::new();
341 let mut current_version_id = Some(version_id.to_string());
342
343 while let Some(vid) = current_version_id {
344 if let Some(version) = self.get_version(model_name, &vid) {
345 lineage.push(version);
346 current_version_id = version.parent_version.clone();
347 } else {
348 break;
349 }
350 }
351
352 lineage
353 }
354
355 pub fn compare_versions(
356 &self,
357 model_name: &str,
358 version_id1: &str,
359 version_id2: &str,
360 ) -> Result<VersionComparison> {
361 let version1 =
362 self.get_version(model_name, version_id1).context("First version not found")?;
363 let version2 =
364 self.get_version(model_name, version_id2).context("Second version not found")?;
365
366 Ok(VersionComparison {
367 version1: version1.clone(),
368 version2: version2.clone(),
369 accuracy_diff: version2.performance_metrics.accuracy
370 - version1.performance_metrics.accuracy,
371 loss_diff: version2.performance_metrics.loss - version1.performance_metrics.loss,
372 size_diff: version2.model_size as i64 - version1.model_size as i64,
373 config_changes: self
374 .compare_training_configs(&version1.training_config, &version2.training_config),
375 })
376 }
377
378 fn compare_training_configs(
379 &self,
380 config1: &TrainingConfig,
381 config2: &TrainingConfig,
382 ) -> Vec<String> {
383 let mut changes = Vec::new();
384
385 if config1.learning_rate != config2.learning_rate {
386 changes.push(format!(
387 "Learning rate: {} -> {}",
388 config1.learning_rate, config2.learning_rate
389 ));
390 }
391
392 if config1.batch_size != config2.batch_size {
393 changes.push(format!(
394 "Batch size: {} -> {}",
395 config1.batch_size, config2.batch_size
396 ));
397 }
398
399 if config1.epochs != config2.epochs {
400 changes.push(format!("Epochs: {} -> {}", config1.epochs, config2.epochs));
401 }
402
403 if config1.optimizer != config2.optimizer {
404 changes.push(format!(
405 "Optimizer: {} -> {}",
406 config1.optimizer, config2.optimizer
407 ));
408 }
409
410 if config1.loss_function != config2.loss_function {
411 changes.push(format!(
412 "Loss function: {} -> {}",
413 config1.loss_function, config2.loss_function
414 ));
415 }
416
417 changes
418 }
419
420 fn get_next_version_number(&self, model_name: &str) -> u32 {
421 self.registry
422 .models
423 .get(model_name)
424 .map(|versions| versions.iter().map(|v| v.version_number).max().unwrap_or(0) + 1)
425 .unwrap_or(1)
426 }
427
428 fn save_registry(&self) -> Result<()> {
429 let registry_data =
430 serde_json::to_string_pretty(&self.registry).context("Failed to serialize registry")?;
431
432 std::fs::write(&self.registry.registry_path, registry_data)
433 .context("Failed to save registry")?;
434
435 Ok(())
436 }
437
438 pub fn get_statistics(&self) -> ModelRegistryStatistics {
439 let total_models = self.registry.models.len();
440 let total_versions = self.registry.models.values().map(|v| v.len()).sum();
441 let total_size: u64 = self.registry.models.values().flatten().map(|v| v.model_size).sum();
442
443 let status_counts =
444 self.registry
445 .models
446 .values()
447 .flatten()
448 .fold(HashMap::new(), |mut acc, version| {
449 *acc.entry(format!("{:?}", version.status)).or_insert(0) += 1;
450 acc
451 });
452
453 ModelRegistryStatistics {
454 total_models,
455 total_versions,
456 total_size,
457 status_counts,
458 }
459 }
460}
461
462#[derive(Debug, Clone)]
463pub struct VersionComparison {
464 pub version1: ModelVersion,
465 pub version2: ModelVersion,
466 pub accuracy_diff: f32,
467 pub loss_diff: f32,
468 pub size_diff: i64,
469 pub config_changes: Vec<String>,
470}
471
/// Aggregate numbers describing the contents of a model registry, as returned
/// by `ModelVersioningManager::get_statistics`.
#[derive(Debug)]
pub struct ModelRegistryStatistics {
    /// Number of model names in the registry.
    pub total_models: usize,
    /// Number of versions summed over all models.
    pub total_versions: usize,
    /// Sum of `model_size` over all versions, in bytes.
    pub total_size: u64,
    /// Number of versions per status, keyed by the status' `Debug` name
    /// (e.g. `"Trained"`).
    pub status_counts: HashMap<String, usize>,
}
479
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned as well: the directory is deleted when
    /// the guard is dropped, so tests must keep it alive while using the
    /// manager.
    fn new_manager() -> (TempDir, ModelVersioningManager) {
        let temp_dir = TempDir::new().expect("failed to create temporary directory");
        let manager = ModelVersioningManager::new(temp_dir.path().to_path_buf())
            .expect("failed to create versioning manager");
        (temp_dir, manager)
    }

    /// Training configuration shared by all tests.
    fn sample_training_config() -> TrainingConfig {
        TrainingConfig {
            learning_rate: 0.001,
            batch_size: 32,
            epochs: 10,
            optimizer: "Adam".to_string(),
            loss_function: "CrossEntropy".to_string(),
            regularization: HashMap::new(),
            hyperparameters: HashMap::new(),
        }
    }

    /// Performance metrics shared by all tests (optional metrics unset).
    fn sample_metrics() -> PerformanceMetrics {
        PerformanceMetrics {
            accuracy: 0.95,
            loss: 0.05,
            validation_accuracy: 0.93,
            validation_loss: 0.07,
            f1_score: None,
            precision: None,
            recall: None,
            custom_metrics: HashMap::new(),
        }
    }

    /// Creates a version with the shared sample config and metrics.
    fn add_version(
        manager: &mut ModelVersioningManager,
        model_name: &str,
        model_data: &[u8],
        description: &str,
        tags: &[&str],
    ) -> ModelVersion {
        manager
            .create_version(
                model_name.to_string(),
                model_data,
                description.to_string(),
                "test_user".to_string(),
                tags.iter().map(|tag| (*tag).to_string()).collect(),
                sample_training_config(),
                sample_metrics(),
                HashMap::new(),
            )
            .expect("failed to create model version")
    }

    #[test]
    fn test_model_versioning_manager_creation() {
        let (_temp_dir, manager) = new_manager();
        assert_eq!(manager.list_models().len(), 0);
    }

    #[test]
    fn test_create_version() {
        let (_temp_dir, mut manager) = new_manager();

        let performance_metrics = PerformanceMetrics {
            f1_score: Some(0.94),
            precision: Some(0.96),
            recall: Some(0.92),
            ..sample_metrics()
        };

        let version = manager
            .create_version(
                "test_model".to_string(),
                b"fake model data",
                "Test model version".to_string(),
                "test_user".to_string(),
                vec!["test".to_string()],
                sample_training_config(),
                performance_metrics,
                HashMap::new(),
            )
            .expect("failed to create model version");

        assert_eq!(version.model_name, "test_model");
        assert_eq!(version.version_number, 1);
        assert_eq!(version.description, "Test model version");
        assert_eq!(version.tags, vec!["test"]);
    }

    #[test]
    fn test_get_latest_version() {
        let (_temp_dir, mut manager) = new_manager();

        add_version(&mut manager, "test_model", b"fake model data v1", "Version 1", &[]);
        let version2 =
            add_version(&mut manager, "test_model", b"fake model data v2", "Version 2", &[]);

        let latest = manager
            .get_latest_version("test_model")
            .expect("model should have a latest version");
        assert_eq!(latest.version_id, version2.version_id);
        assert_eq!(latest.version_number, 2);
    }

    #[test]
    fn test_version_lineage() {
        let (_temp_dir, mut manager) = new_manager();

        add_version(&mut manager, "test_model", b"v1", "Version 1", &[]);
        add_version(&mut manager, "test_model", b"v2", "Version 2", &[]);
        let version3 = add_version(&mut manager, "test_model", b"v3", "Version 3", &[]);

        let lineage = manager.get_version_lineage("test_model", &version3.version_id);
        assert_eq!(lineage.len(), 3);
        assert_eq!(lineage[0].version_number, 3);
        assert_eq!(lineage[1].version_number, 2);
        assert_eq!(lineage[2].version_number, 1);
    }

    #[test]
    fn test_find_versions_by_tag() {
        let (_temp_dir, mut manager) = new_manager();

        add_version(&mut manager, "model1", b"data", "Production model", &["production"]);
        add_version(&mut manager, "model2", b"data", "Dev model", &["development"]);

        let production_versions = manager.find_versions_by_tag("production");
        assert_eq!(production_versions.len(), 1);
        assert_eq!(production_versions[0].model_name, "model1");

        let dev_versions = manager.find_versions_by_tag("development");
        assert_eq!(dev_versions.len(), 1);
        assert_eq!(dev_versions[0].model_name, "model2");
    }

    #[test]
    fn test_delete_version_promotes_previous_latest() {
        let (_temp_dir, mut manager) = new_manager();

        let version1 = add_version(&mut manager, "test_model", b"v1", "Version 1", &[]);
        let version2 = add_version(&mut manager, "test_model", b"v2", "Version 2", &[]);

        manager
            .delete_version("test_model", &version2.version_id)
            .expect("failed to delete model version");

        assert!(manager.get_version("test_model", &version2.version_id).is_none());
        assert!(!version2.file_path.exists());

        let latest = manager
            .get_latest_version("test_model")
            .expect("model should still have a latest version");
        assert_eq!(latest.version_id, version1.version_id);
    }

    #[test]
    fn test_registry_is_reloaded_from_disk() {
        let temp_dir = TempDir::new().expect("failed to create temporary directory");

        let created = {
            let mut manager = ModelVersioningManager::new(temp_dir.path().to_path_buf())
                .expect("failed to create versioning manager");
            add_version(&mut manager, "test_model", b"v1", "Version 1", &["test"])
        };

        let reopened = ModelVersioningManager::new(temp_dir.path().to_path_buf())
            .expect("failed to reopen versioning manager");
        assert_eq!(reopened.list_models(), vec!["test_model".to_string()]);

        let latest = reopened
            .get_latest_version("test_model")
            .expect("reloaded registry should contain the version");
        assert_eq!(latest.version_id, created.version_id);
        assert_eq!(latest.model_hash, created.model_hash);
    }
}