Skip to main content

trustformers_training/
model_versioning.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use sha2::{Digest, Sha256};
4use std::collections::HashMap;
5use std::path::PathBuf;
6use std::time::{SystemTime, UNIX_EPOCH};
7
/// A single stored version of a model, as recorded in the registry.
///
/// Instances are produced by `ModelVersioningManager::create_version` and are
/// serialized as part of the registry file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelVersion {
    /// Identifier of this version; `create_version` builds it as
    /// `{model_name}_{version_number:04}`.
    pub version_id: String,
    /// Name of the model this version belongs to (the key used in the registry).
    pub model_name: String,
    /// Per-model version counter assigned at creation time.
    pub version_number: u32,
    /// Creation time in whole seconds since the UNIX epoch.
    pub created_at: u64,
    /// Caller-supplied creator identifier.
    pub created_by: String,
    /// Caller-supplied free-form description.
    pub description: String,
    /// Tags attached to this version; editable via `add_tag` / `remove_tag`.
    pub tags: Vec<String>,
    /// Caller-supplied key/value metadata.
    pub metadata: HashMap<String, String>,
    /// Lowercase hex SHA-256 digest of the model bytes passed to `create_version`.
    pub model_hash: String,
    /// Location of the stored model bytes
    /// (`<storage_root>/<model_name>/<version_id>.model` when created by `create_version`).
    pub file_path: PathBuf,
    /// `version_id` of the version that was the model's latest when this one
    /// was created, or `None` for the first version.
    pub parent_version: Option<String>,
    /// Training settings recorded for this version.
    pub training_config: TrainingConfig,
    /// Evaluation metrics recorded for this version.
    pub performance_metrics: PerformanceMetrics,
    /// Length in bytes of the model data passed to `create_version`.
    pub model_size: u64,
    /// Lifecycle status; `create_version` initializes it to `ModelStatus::Trained`.
    pub status: ModelStatus,
}
26
/// Training settings stored alongside a [`ModelVersion`].
///
/// The values are supplied by the caller of `create_version`; this module only
/// stores them and compares them in `compare_training_configs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Learning rate used for training.
    pub learning_rate: f32,
    /// Batch size used for training.
    pub batch_size: usize,
    /// Number of training epochs.
    pub epochs: u32,
    /// Optimizer name (free-form string, e.g. the tests use `"Adam"`).
    pub optimizer: String,
    /// Loss function name (free-form string, e.g. the tests use `"CrossEntropy"`).
    pub loss_function: String,
    /// Named regularization values.
    // NOTE(review): presumably regularizer name -> strength; confirm with producers.
    pub regularization: HashMap<String, f32>,
    /// Additional hyperparameters as string key/value pairs.
    pub hyperparameters: HashMap<String, String>,
}
37
/// Evaluation metrics stored alongside a [`ModelVersion`].
///
/// `find_versions_by_performance` looks a metric up in `custom_metrics` first
/// and then falls back to the named fields of this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Accuracy value supplied by the caller.
    pub accuracy: f32,
    /// Loss value supplied by the caller.
    pub loss: f32,
    /// Validation accuracy supplied by the caller.
    pub validation_accuracy: f32,
    /// Validation loss supplied by the caller.
    pub validation_loss: f32,
    /// Optional F1 score.
    pub f1_score: Option<f32>,
    /// Optional precision.
    pub precision: Option<f32>,
    /// Optional recall.
    pub recall: Option<f32>,
    /// Additional metrics keyed by name.
    pub custom_metrics: HashMap<String, f32>,
}
49
/// Lifecycle status attached to a [`ModelVersion`].
///
/// `create_version` assigns [`ModelStatus::Trained`]; any other value is set
/// explicitly through `update_status`. `get_statistics` groups versions by the
/// `Debug` rendering of this enum, so variant names appear as keys there.
// NOTE(review): this file does not enforce any ordering or transition rules
// between the variants; confirm the intended workflow with callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelStatus {
    Training,
    /// Status given to every version created by `create_version`.
    Trained,
    Validated,
    Deployed,
    Archived,
    Failed,
}
59
/// Serializable index of every known model version.
///
/// The whole structure is written as pretty-printed JSON by `save_registry`
/// and read back by `ModelVersioningManager::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRegistry {
    /// Model name -> versions of that model, in the order they were added.
    pub models: HashMap<String, Vec<ModelVersion>>,
    /// Model name -> `version_id` of the version currently treated as latest.
    pub latest_versions: HashMap<String, String>,
    /// File the registry is saved to.
    pub registry_path: PathBuf,
}
66
/// Manages model versions on disk: stores model files, tracks them in a
/// [`ModelRegistry`], and persists that registry after each mutation.
pub struct ModelVersioningManager {
    /// In-memory copy of the registry; saved via `save_registry` after changes.
    registry: ModelRegistry,
    /// Directory under which model files (and, for a new registry, the
    /// registry file itself) are stored.
    storage_root: PathBuf,
}
71
72impl ModelVersioningManager {
73    pub fn new(storage_root: PathBuf) -> Result<Self> {
74        let registry_path = storage_root.join("model_registry.json");
75
76        let registry = if registry_path.exists() {
77            let registry_data =
78                std::fs::read_to_string(&registry_path).context("Failed to read model registry")?;
79            serde_json::from_str(&registry_data).context("Failed to parse model registry")?
80        } else {
81            ModelRegistry {
82                models: HashMap::new(),
83                latest_versions: HashMap::new(),
84                registry_path: registry_path.clone(),
85            }
86        };
87
88        // Create storage root if it doesn't exist
89        std::fs::create_dir_all(&storage_root)
90            .context("Failed to create storage root directory")?;
91
92        Ok(Self {
93            registry,
94            storage_root,
95        })
96    }
97
98    pub fn create_version(
99        &mut self,
100        model_name: String,
101        model_data: &[u8],
102        description: String,
103        created_by: String,
104        tags: Vec<String>,
105        training_config: TrainingConfig,
106        performance_metrics: PerformanceMetrics,
107        metadata: HashMap<String, String>,
108    ) -> Result<ModelVersion> {
109        // Generate version number
110        let version_number = self.get_next_version_number(&model_name);
111
112        // Generate version ID
113        let version_id = format!("{}_{:04}", model_name, version_number);
114
115        // Calculate model hash
116        let mut hasher = Sha256::new();
117        hasher.update(model_data);
118        let model_hash = format!("{:x}", hasher.finalize());
119
120        // Create file path
121        let file_name = format!("{}.model", version_id);
122        let file_path = self.storage_root.join(&model_name).join(&file_name);
123
124        // Create model directory if it doesn't exist
125        if let Some(parent) = file_path.parent() {
126            std::fs::create_dir_all(parent).context("Failed to create model directory")?;
127        }
128
129        // Save model data
130        std::fs::write(&file_path, model_data).context("Failed to save model data")?;
131
132        // Get current timestamp
133        let created_at = SystemTime::now()
134            .duration_since(UNIX_EPOCH)
135            .expect("SystemTime should be after UNIX_EPOCH")
136            .as_secs();
137
138        // Get parent version (latest version of the same model)
139        let parent_version = self.registry.latest_versions.get(&model_name).cloned();
140
141        // Create version record
142        let version = ModelVersion {
143            version_id: version_id.clone(),
144            model_name: model_name.clone(),
145            version_number,
146            created_at,
147            created_by,
148            description,
149            tags,
150            metadata,
151            model_hash,
152            file_path,
153            parent_version,
154            training_config,
155            performance_metrics,
156            model_size: model_data.len() as u64,
157            status: ModelStatus::Trained,
158        };
159
160        // Add to registry
161        self.registry
162            .models
163            .entry(model_name.clone())
164            .or_default()
165            .push(version.clone());
166
167        // Update latest version
168        self.registry.latest_versions.insert(model_name, version_id);
169
170        // Save registry
171        self.save_registry()?;
172
173        Ok(version)
174    }
175
176    pub fn get_version(&self, model_name: &str, version_id: &str) -> Option<&ModelVersion> {
177        self.registry
178            .models
179            .get(model_name)?
180            .iter()
181            .find(|v| v.version_id == version_id)
182    }
183
184    pub fn get_latest_version(&self, model_name: &str) -> Option<&ModelVersion> {
185        let latest_version_id = self.registry.latest_versions.get(model_name)?;
186        self.get_version(model_name, latest_version_id)
187    }
188
189    pub fn list_versions(&self, model_name: &str) -> Vec<&ModelVersion> {
190        self.registry
191            .models
192            .get(model_name)
193            .map(|versions| {
194                let mut sorted_versions: Vec<_> = versions.iter().collect();
195                sorted_versions.sort_by_key(|v| std::cmp::Reverse(v.created_at));
196                sorted_versions
197            })
198            .unwrap_or_default()
199    }
200
201    pub fn list_models(&self) -> Vec<String> {
202        self.registry.models.keys().cloned().collect()
203    }
204
205    pub fn update_status(
206        &mut self,
207        model_name: &str,
208        version_id: &str,
209        status: ModelStatus,
210    ) -> Result<()> {
211        let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
212
213        let version = versions
214            .iter_mut()
215            .find(|v| v.version_id == version_id)
216            .context("Version not found")?;
217
218        version.status = status;
219        self.save_registry()?;
220
221        Ok(())
222    }
223
224    pub fn add_tag(&mut self, model_name: &str, version_id: &str, tag: String) -> Result<()> {
225        let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
226
227        let version = versions
228            .iter_mut()
229            .find(|v| v.version_id == version_id)
230            .context("Version not found")?;
231
232        if !version.tags.contains(&tag) {
233            version.tags.push(tag);
234            self.save_registry()?;
235        }
236
237        Ok(())
238    }
239
240    pub fn remove_tag(&mut self, model_name: &str, version_id: &str, tag: &str) -> Result<()> {
241        let versions = self.registry.models.get_mut(model_name).context("Model not found")?;
242
243        let version = versions
244            .iter_mut()
245            .find(|v| v.version_id == version_id)
246            .context("Version not found")?;
247
248        version.tags.retain(|t| t != tag);
249        self.save_registry()?;
250
251        Ok(())
252    }
253
254    pub fn find_versions_by_tag(&self, tag: &str) -> Vec<&ModelVersion> {
255        self.registry
256            .models
257            .values()
258            .flatten()
259            .filter(|version| version.tags.contains(&tag.to_string()))
260            .collect()
261    }
262
263    pub fn find_versions_by_performance(
264        &self,
265        metric_name: &str,
266        min_value: f32,
267        max_value: Option<f32>,
268    ) -> Vec<&ModelVersion> {
269        self.registry
270            .models
271            .values()
272            .flatten()
273            .filter(|version| {
274                if let Some(value) = version.performance_metrics.custom_metrics.get(metric_name) {
275                    *value >= min_value && max_value.map_or(true, |max| *value <= max)
276                } else {
277                    // Check standard metrics
278                    match metric_name {
279                        "accuracy" => {
280                            let value = version.performance_metrics.accuracy;
281                            value >= min_value && max_value.map_or(true, |max| value <= max)
282                        },
283                        "loss" => {
284                            let value = version.performance_metrics.loss;
285                            value >= min_value && max_value.map_or(true, |max| value <= max)
286                        },
287                        "validation_accuracy" => {
288                            let value = version.performance_metrics.validation_accuracy;
289                            value >= min_value && max_value.map_or(true, |max| value <= max)
290                        },
291                        "validation_loss" => {
292                            let value = version.performance_metrics.validation_loss;
293                            value >= min_value && max_value.map_or(true, |max| value <= max)
294                        },
295                        _ => false,
296                    }
297                }
298            })
299            .collect()
300    }
301
302    pub fn delete_version(&mut self, model_name: &str, version_id: &str) -> Result<()> {
303        // Get the version to delete
304        let version =
305            self.get_version(model_name, version_id).context("Version not found")?.clone();
306
307        // Remove from registry
308        if let Some(versions) = self.registry.models.get_mut(model_name) {
309            versions.retain(|v| v.version_id != version_id);
310
311            // If this was the latest version, update the latest version
312            if self.registry.latest_versions.get(model_name) == Some(&version_id.to_string()) {
313                if let Some(latest) = versions.iter().max_by_key(|v| v.created_at) {
314                    self.registry
315                        .latest_versions
316                        .insert(model_name.to_string(), latest.version_id.clone());
317                } else {
318                    self.registry.latest_versions.remove(model_name);
319                }
320            }
321        }
322
323        // Delete the model file
324        if version.file_path.exists() {
325            std::fs::remove_file(&version.file_path).context("Failed to delete model file")?;
326        }
327
328        self.save_registry()?;
329
330        Ok(())
331    }
332
333    pub fn load_model_data(&self, model_name: &str, version_id: &str) -> Result<Vec<u8>> {
334        let version = self.get_version(model_name, version_id).context("Version not found")?;
335
336        std::fs::read(&version.file_path).context("Failed to read model data")
337    }
338
339    pub fn get_version_lineage(&self, model_name: &str, version_id: &str) -> Vec<&ModelVersion> {
340        let mut lineage = Vec::new();
341        let mut current_version_id = Some(version_id.to_string());
342
343        while let Some(vid) = current_version_id {
344            if let Some(version) = self.get_version(model_name, &vid) {
345                lineage.push(version);
346                current_version_id = version.parent_version.clone();
347            } else {
348                break;
349            }
350        }
351
352        lineage
353    }
354
355    pub fn compare_versions(
356        &self,
357        model_name: &str,
358        version_id1: &str,
359        version_id2: &str,
360    ) -> Result<VersionComparison> {
361        let version1 =
362            self.get_version(model_name, version_id1).context("First version not found")?;
363        let version2 =
364            self.get_version(model_name, version_id2).context("Second version not found")?;
365
366        Ok(VersionComparison {
367            version1: version1.clone(),
368            version2: version2.clone(),
369            accuracy_diff: version2.performance_metrics.accuracy
370                - version1.performance_metrics.accuracy,
371            loss_diff: version2.performance_metrics.loss - version1.performance_metrics.loss,
372            size_diff: version2.model_size as i64 - version1.model_size as i64,
373            config_changes: self
374                .compare_training_configs(&version1.training_config, &version2.training_config),
375        })
376    }
377
378    fn compare_training_configs(
379        &self,
380        config1: &TrainingConfig,
381        config2: &TrainingConfig,
382    ) -> Vec<String> {
383        let mut changes = Vec::new();
384
385        if config1.learning_rate != config2.learning_rate {
386            changes.push(format!(
387                "Learning rate: {} -> {}",
388                config1.learning_rate, config2.learning_rate
389            ));
390        }
391
392        if config1.batch_size != config2.batch_size {
393            changes.push(format!(
394                "Batch size: {} -> {}",
395                config1.batch_size, config2.batch_size
396            ));
397        }
398
399        if config1.epochs != config2.epochs {
400            changes.push(format!("Epochs: {} -> {}", config1.epochs, config2.epochs));
401        }
402
403        if config1.optimizer != config2.optimizer {
404            changes.push(format!(
405                "Optimizer: {} -> {}",
406                config1.optimizer, config2.optimizer
407            ));
408        }
409
410        if config1.loss_function != config2.loss_function {
411            changes.push(format!(
412                "Loss function: {} -> {}",
413                config1.loss_function, config2.loss_function
414            ));
415        }
416
417        changes
418    }
419
420    fn get_next_version_number(&self, model_name: &str) -> u32 {
421        self.registry
422            .models
423            .get(model_name)
424            .map(|versions| versions.iter().map(|v| v.version_number).max().unwrap_or(0) + 1)
425            .unwrap_or(1)
426    }
427
428    fn save_registry(&self) -> Result<()> {
429        let registry_data =
430            serde_json::to_string_pretty(&self.registry).context("Failed to serialize registry")?;
431
432        std::fs::write(&self.registry.registry_path, registry_data)
433            .context("Failed to save registry")?;
434
435        Ok(())
436    }
437
438    pub fn get_statistics(&self) -> ModelRegistryStatistics {
439        let total_models = self.registry.models.len();
440        let total_versions = self.registry.models.values().map(|v| v.len()).sum();
441        let total_size: u64 = self.registry.models.values().flatten().map(|v| v.model_size).sum();
442
443        let status_counts =
444            self.registry
445                .models
446                .values()
447                .flatten()
448                .fold(HashMap::new(), |mut acc, version| {
449                    *acc.entry(format!("{:?}", version.status)).or_insert(0) += 1;
450                    acc
451                });
452
453        ModelRegistryStatistics {
454            total_models,
455            total_versions,
456            total_size,
457            status_counts,
458        }
459    }
460}
461
/// Result of `ModelVersioningManager::compare_versions`.
///
/// All `*_diff` fields are computed as `version2` minus `version1`.
#[derive(Debug, Clone)]
pub struct VersionComparison {
    /// Copy of the first (baseline) version.
    pub version1: ModelVersion,
    /// Copy of the second version.
    pub version2: ModelVersion,
    /// `version2` accuracy minus `version1` accuracy.
    pub accuracy_diff: f32,
    /// `version2` loss minus `version1` loss.
    pub loss_diff: f32,
    /// `version2` model size minus `version1` model size, in bytes.
    pub size_diff: i64,
    /// Human-readable descriptions of training-configuration differences.
    pub config_changes: Vec<String>,
}
471
/// Aggregate figures returned by `ModelVersioningManager::get_statistics`.
#[derive(Debug)]
pub struct ModelRegistryStatistics {
    /// Number of model names present in the registry.
    pub total_models: usize,
    /// Number of versions across all models.
    pub total_versions: usize,
    /// Sum of `model_size` over all versions, in bytes.
    pub total_size: u64,
    /// Number of versions per status, keyed by the `Debug` name of the status.
    pub status_counts: HashMap<String, usize>,
}
479
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the manager so the directory
    /// stays alive for the duration of the test.
    fn new_manager() -> (TempDir, ModelVersioningManager) {
        let temp_dir = TempDir::new().expect("failed to create temporary directory");
        let manager = ModelVersioningManager::new(temp_dir.path().to_path_buf())
            .expect("failed to create model versioning manager");
        (temp_dir, manager)
    }

    /// Training configuration shared by all tests.
    fn sample_training_config() -> TrainingConfig {
        TrainingConfig {
            learning_rate: 0.001,
            batch_size: 32,
            epochs: 10,
            optimizer: "Adam".to_string(),
            loss_function: "CrossEntropy".to_string(),
            regularization: HashMap::new(),
            hyperparameters: HashMap::new(),
        }
    }

    /// Metrics shared by all tests; the optional metrics are left unset.
    fn sample_metrics() -> PerformanceMetrics {
        PerformanceMetrics {
            accuracy: 0.95,
            loss: 0.05,
            validation_accuracy: 0.93,
            validation_loss: 0.07,
            f1_score: None,
            precision: None,
            recall: None,
            custom_metrics: HashMap::new(),
        }
    }

    /// Registers a version with the shared training configuration and empty
    /// metadata, panicking if creation fails.
    fn create(
        manager: &mut ModelVersioningManager,
        model_name: &str,
        model_data: &[u8],
        description: &str,
        created_by: &str,
        tags: &[&str],
        performance_metrics: PerformanceMetrics,
    ) -> ModelVersion {
        manager
            .create_version(
                model_name.to_string(),
                model_data,
                description.to_string(),
                created_by.to_string(),
                tags.iter().map(|tag| (*tag).to_string()).collect(),
                sample_training_config(),
                performance_metrics,
                HashMap::new(),
            )
            .expect("create_version failed")
    }

    #[test]
    fn test_model_versioning_manager_creation() {
        let (_temp_dir, manager) = new_manager();
        assert_eq!(manager.list_models().len(), 0);
    }

    #[test]
    fn test_create_version() {
        let (_temp_dir, mut manager) = new_manager();

        let performance_metrics = PerformanceMetrics {
            f1_score: Some(0.94),
            precision: Some(0.96),
            recall: Some(0.92),
            ..sample_metrics()
        };

        let version = create(
            &mut manager,
            "test_model",
            b"fake model data",
            "Test model version",
            "test_user",
            &["test"],
            performance_metrics,
        );

        assert_eq!(version.model_name, "test_model");
        assert_eq!(version.version_number, 1);
        assert_eq!(version.description, "Test model version");
        assert_eq!(version.tags, vec!["test"]);
    }

    #[test]
    fn test_get_latest_version() {
        let (_temp_dir, mut manager) = new_manager();

        // Create first version
        create(
            &mut manager,
            "test_model",
            b"fake model data v1",
            "Version 1",
            "test_user",
            &[],
            sample_metrics(),
        );

        // Create second version
        let version2 = create(
            &mut manager,
            "test_model",
            b"fake model data v2",
            "Version 2",
            "test_user",
            &[],
            sample_metrics(),
        );

        let latest = manager
            .get_latest_version("test_model")
            .expect("latest version should exist after creating versions");
        assert_eq!(latest.version_id, version2.version_id);
        assert_eq!(latest.version_number, 2);
    }

    #[test]
    fn test_version_lineage() {
        let (_temp_dir, mut manager) = new_manager();

        // Create multiple versions
        create(&mut manager, "test_model", b"v1", "Version 1", "user", &[], sample_metrics());
        create(&mut manager, "test_model", b"v2", "Version 2", "user", &[], sample_metrics());
        let version3 =
            create(&mut manager, "test_model", b"v3", "Version 3", "user", &[], sample_metrics());

        let lineage = manager.get_version_lineage("test_model", &version3.version_id);
        assert_eq!(lineage.len(), 3);
        assert_eq!(lineage[0].version_number, 3);
        assert_eq!(lineage[1].version_number, 2);
        assert_eq!(lineage[2].version_number, 1);
    }

    #[test]
    fn test_find_versions_by_tag() {
        let (_temp_dir, mut manager) = new_manager();

        // Create version with production tag
        create(
            &mut manager,
            "model1",
            b"data",
            "Production model",
            "user",
            &["production"],
            sample_metrics(),
        );

        // Create version with development tag
        create(
            &mut manager,
            "model2",
            b"data",
            "Dev model",
            "user",
            &["development"],
            sample_metrics(),
        );

        let production_versions = manager.find_versions_by_tag("production");
        assert_eq!(production_versions.len(), 1);
        assert_eq!(production_versions[0].model_name, "model1");

        let dev_versions = manager.find_versions_by_tag("development");
        assert_eq!(dev_versions.len(), 1);
        assert_eq!(dev_versions[0].model_name, "model2");
    }
}