Skip to main content

trustformers_core/versioning/
metadata.rs

1//! Model metadata and version definitions
2
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use uuid::Uuid;
8
9/// Model metadata containing version information
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct ModelMetadata {
12    /// Human-readable description
13    pub description: String,
14    /// Creator/author of this version
15    pub created_by: String,
16    /// Creation timestamp
17    pub created_at: DateTime<Utc>,
18    /// Model type (e.g., "transformer", "cnn", "rnn")
19    pub model_type: String,
20    /// Model architecture (e.g., "gpt2", "bert", "llama")
21    pub architecture: Option<String>,
22    /// Model size/parameters (e.g., "125M", "1.3B", "7B")
23    pub size: Option<String>,
24    /// Training configuration used
25    pub training_config: Option<serde_json::Value>,
26    /// Performance metrics
27    pub metrics: HashMap<String, f64>,
28    /// Tags for categorization
29    pub tags: Vec<ModelTag>,
30    /// Custom attributes
31    pub attributes: HashMap<String, serde_json::Value>,
32    /// Source information (dataset, training run, etc.)
33    pub source: Option<ModelSource>,
34    /// Checksum for integrity verification
35    pub checksum: Option<String>,
36    /// Model size in bytes
37    pub size_bytes: Option<u64>,
38    /// Compatible framework versions
39    pub framework_versions: Vec<String>,
40}
41
42impl ModelMetadata {
43    /// Create a new metadata builder
44    pub fn builder() -> ModelMetadataBuilder {
45        ModelMetadataBuilder::new()
46    }
47
48    /// Add a metric
49    pub fn add_metric(&mut self, name: String, value: f64) {
50        self.metrics.insert(name, value);
51    }
52
53    /// Add a tag
54    pub fn add_tag(&mut self, tag: ModelTag) {
55        self.tags.push(tag);
56    }
57
58    /// Get metric value
59    pub fn get_metric(&self, name: &str) -> Option<f64> {
60        self.metrics.get(name).copied()
61    }
62
63    /// Check if model has tag
64    pub fn has_tag(&self, tag_name: &str) -> bool {
65        self.tags.iter().any(|t| t.name == tag_name)
66    }
67}
68
69/// Builder for model metadata
70pub struct ModelMetadataBuilder {
71    description: Option<String>,
72    created_by: Option<String>,
73    model_type: Option<String>,
74    architecture: Option<String>,
75    size: Option<String>,
76    training_config: Option<serde_json::Value>,
77    metrics: HashMap<String, f64>,
78    tags: Vec<ModelTag>,
79    attributes: HashMap<String, serde_json::Value>,
80    source: Option<ModelSource>,
81    checksum: Option<String>,
82    size_bytes: Option<u64>,
83    framework_versions: Vec<String>,
84}
85
86impl ModelMetadataBuilder {
87    fn new() -> Self {
88        Self {
89            description: None,
90            created_by: None,
91            model_type: None,
92            architecture: None,
93            size: None,
94            training_config: None,
95            metrics: HashMap::new(),
96            tags: Vec::new(),
97            attributes: HashMap::new(),
98            source: None,
99            checksum: None,
100            size_bytes: None,
101            framework_versions: Vec::new(),
102        }
103    }
104
105    pub fn description(mut self, description: String) -> Self {
106        self.description = Some(description);
107        self
108    }
109
110    pub fn created_by(mut self, created_by: String) -> Self {
111        self.created_by = Some(created_by);
112        self
113    }
114
115    pub fn model_type(mut self, model_type: String) -> Self {
116        self.model_type = Some(model_type);
117        self
118    }
119
120    pub fn architecture(mut self, architecture: String) -> Self {
121        self.architecture = Some(architecture);
122        self
123    }
124
125    pub fn size(mut self, size: String) -> Self {
126        self.size = Some(size);
127        self
128    }
129
130    pub fn training_config(mut self, config: serde_json::Value) -> Self {
131        self.training_config = Some(config);
132        self
133    }
134
135    pub fn metric(mut self, name: String, value: f64) -> Self {
136        self.metrics.insert(name, value);
137        self
138    }
139
140    pub fn tag(mut self, tag: ModelTag) -> Self {
141        self.tags.push(tag);
142        self
143    }
144
145    pub fn attribute(mut self, key: String, value: serde_json::Value) -> Self {
146        self.attributes.insert(key, value);
147        self
148    }
149
150    pub fn source(mut self, source: ModelSource) -> Self {
151        self.source = Some(source);
152        self
153    }
154
155    pub fn checksum(mut self, checksum: String) -> Self {
156        self.checksum = Some(checksum);
157        self
158    }
159
160    pub fn size_bytes(mut self, size_bytes: u64) -> Self {
161        self.size_bytes = Some(size_bytes);
162        self
163    }
164
165    pub fn framework_version(mut self, version: String) -> Self {
166        self.framework_versions.push(version);
167        self
168    }
169
170    pub fn build(self) -> ModelMetadata {
171        ModelMetadata {
172            description: self.description.unwrap_or_default(),
173            created_by: self.created_by.unwrap_or_default(),
174            created_at: Utc::now(),
175            model_type: self.model_type.unwrap_or_default(),
176            architecture: self.architecture,
177            size: self.size,
178            training_config: self.training_config,
179            metrics: self.metrics,
180            tags: self.tags,
181            attributes: self.attributes,
182            source: self.source,
183            checksum: self.checksum,
184            size_bytes: self.size_bytes,
185            framework_versions: self.framework_versions,
186        }
187    }
188}
189
190/// Model source information
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct ModelSource {
193    /// Source type (e.g., "training", "fine_tuning", "conversion")
194    pub source_type: String,
195    /// Training dataset name/identifier
196    pub dataset: Option<String>,
197    /// Training run identifier
198    pub training_run_id: Option<String>,
199    /// Base model (for fine-tuned models)
200    pub base_model: Option<String>,
201    /// Training configuration reference
202    pub config_ref: Option<String>,
203    /// Additional source metadata
204    pub metadata: HashMap<String, serde_json::Value>,
205}
206
207/// Model tag for categorization
208#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
209pub struct ModelTag {
210    pub name: String,
211    pub value: Option<String>,
212    pub category: Option<String>,
213}
214
215impl ModelTag {
216    pub fn new(name: &str) -> Self {
217        Self {
218            name: name.to_string(),
219            value: None,
220            category: None,
221        }
222    }
223
224    pub fn with_value(name: &str, value: &str) -> Self {
225        Self {
226            name: name.to_string(),
227            value: Some(value.to_string()),
228            category: None,
229        }
230    }
231
232    pub fn with_category(name: &str, value: &str, category: &str) -> Self {
233        Self {
234            name: name.to_string(),
235            value: Some(value.to_string()),
236            category: Some(category.to_string()),
237        }
238    }
239}
240
241/// A versioned model containing metadata and artifact references
242#[derive(Debug, Clone, Serialize, Deserialize)]
243pub struct VersionedModel {
244    /// Unique identifier
245    id: Uuid,
246    /// Model name
247    model_name: String,
248    /// Version string (e.g., "1.0.0", "v2.1-beta")
249    version: String,
250    /// Model metadata
251    metadata: ModelMetadata,
252    /// Artifact IDs
253    artifact_ids: Vec<Uuid>,
254    /// Parent version (for incremental versions)
255    parent_version: Option<Uuid>,
256    /// Child versions (derived from this version)
257    child_versions: Vec<Uuid>,
258}
259
260impl VersionedModel {
261    /// Create a new versioned model
262    pub fn new(
263        model_name: String,
264        version: String,
265        metadata: ModelMetadata,
266        artifact_ids: Vec<Uuid>,
267    ) -> Self {
268        Self {
269            id: Uuid::new_v4(),
270            model_name,
271            version,
272            metadata,
273            artifact_ids,
274            parent_version: None,
275            child_versions: Vec::new(),
276        }
277    }
278
279    /// Create a versioned model with parent
280    pub fn with_parent(
281        model_name: String,
282        version: String,
283        metadata: ModelMetadata,
284        artifact_ids: Vec<Uuid>,
285        parent_id: Uuid,
286    ) -> Self {
287        Self {
288            id: Uuid::new_v4(),
289            model_name,
290            version,
291            metadata,
292            artifact_ids,
293            parent_version: Some(parent_id),
294            child_versions: Vec::new(),
295        }
296    }
297
298    /// Get unique identifier
299    pub fn id(&self) -> Uuid {
300        self.id
301    }
302
303    /// Get model name
304    pub fn model_name(&self) -> &str {
305        &self.model_name
306    }
307
308    /// Get version string
309    pub fn version(&self) -> &str {
310        &self.version
311    }
312
313    /// Get metadata
314    pub fn metadata(&self) -> &ModelMetadata {
315        &self.metadata
316    }
317
318    /// Get artifact IDs
319    pub fn artifact_ids(&self) -> &[Uuid] {
320        &self.artifact_ids
321    }
322
323    /// Get parent version ID
324    pub fn parent_version(&self) -> Option<Uuid> {
325        self.parent_version
326    }
327
328    /// Get child version IDs
329    pub fn child_versions(&self) -> &[Uuid] {
330        &self.child_versions
331    }
332
333    /// Add child version
334    pub fn add_child(&mut self, child_id: Uuid) {
335        if !self.child_versions.contains(&child_id) {
336            self.child_versions.push(child_id);
337        }
338    }
339
340    /// Remove child version
341    pub fn remove_child(&mut self, child_id: Uuid) {
342        self.child_versions.retain(|&id| id != child_id);
343    }
344
345    /// Check if this is a root version (no parent)
346    pub fn is_root(&self) -> bool {
347        self.parent_version.is_none()
348    }
349
350    /// Check if this is a leaf version (no children)
351    pub fn is_leaf(&self) -> bool {
352        self.child_versions.is_empty()
353    }
354
355    /// Get full qualified name
356    pub fn qualified_name(&self) -> String {
357        format!("{}:{}", self.model_name, self.version)
358    }
359
360    /// Validate version format
361    pub fn validate_version_format(&self) -> Result<()> {
362        // Basic semantic version validation
363        if self.version.is_empty() {
364            anyhow::bail!("Version cannot be empty");
365        }
366
367        // Allow semver, git tags, or custom formats
368        if !self.is_valid_version_format() {
369            anyhow::bail!("Invalid version format: {}", self.version);
370        }
371
372        Ok(())
373    }
374
375    fn is_valid_version_format(&self) -> bool {
376        // Accept semver (1.0.0), git-style (v1.0.0), or custom formats
377        let version = &self.version;
378
379        // Semver pattern
380        if regex::Regex::new(r"^\d+\.\d+\.\d+(-[a-zA-Z0-9.-]+)?(\+[a-zA-Z0-9.-]+)?$")
381            .expect("semver regex pattern is valid")
382            .is_match(version)
383        {
384            return true;
385        }
386
387        // Git tag pattern
388        if regex::Regex::new(r"^v?\d+\.\d+(\.\d+)?(-[a-zA-Z0-9.-]+)?$")
389            .expect("git tag regex pattern is valid")
390            .is_match(version)
391        {
392            return true;
393        }
394
395        // Custom format (alphanumeric, dots, dashes, underscores)
396        if regex::Regex::new(r"^[a-zA-Z0-9._-]+$")
397            .expect("Regex compilation failed")
398            .is_match(version)
399        {
400            return true;
401        }
402
403        false
404    }
405}
406
407/// Version comparison for sorting
408impl PartialOrd for VersionedModel {
409    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
410        Some(self.cmp(other))
411    }
412}
413
414impl Ord for VersionedModel {
415    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
416        // First compare by model name
417        match self.model_name.cmp(&other.model_name) {
418            std::cmp::Ordering::Equal => {
419                // Then by creation time
420                self.metadata.created_at.cmp(&other.metadata.created_at)
421            },
422            other => other,
423        }
424    }
425}
426
427impl PartialEq for VersionedModel {
428    fn eq(&self, other: &Self) -> bool {
429        self.id == other.id
430    }
431}
432
433impl Eq for VersionedModel {}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal metadata shared by tests that do not care about its contents.
    fn sample_metadata() -> ModelMetadata {
        ModelMetadata::builder()
            .description("Test".to_string())
            .created_by("test".to_string())
            .model_type("test".to_string())
            .build()
    }

    #[test]
    fn test_metadata_builder() {
        let metadata = ModelMetadata::builder()
            .description("Test model".to_string())
            .created_by("test_user".to_string())
            .model_type("transformer".to_string())
            .architecture("gpt2".to_string())
            .metric("accuracy".to_string(), 0.95)
            .tag(ModelTag::new("experimental"))
            .build();

        assert_eq!(metadata.description, "Test model");
        assert_eq!(metadata.created_by, "test_user");
        assert_eq!(metadata.model_type, "transformer");
        assert_eq!(metadata.architecture, Some("gpt2".to_string()));
        assert_eq!(metadata.get_metric("accuracy"), Some(0.95));
        assert!(metadata.has_tag("experimental"));
    }

    #[test]
    fn test_versioned_model() {
        let metadata = ModelMetadata::builder()
            .description("Test model".to_string())
            .created_by("test_user".to_string())
            .model_type("transformer".to_string())
            .build();

        let model = VersionedModel::new(
            "test_model".to_string(),
            "1.0.0".to_string(),
            metadata,
            vec![Uuid::new_v4()],
        );

        assert_eq!(model.model_name(), "test_model");
        assert_eq!(model.version(), "1.0.0");
        assert_eq!(model.qualified_name(), "test_model:1.0.0");
        assert!(model.is_root());
        assert!(model.is_leaf());
        assert!(model.validate_version_format().is_ok());
    }

    #[test]
    fn test_version_format_validation() {
        let test_cases = vec![
            ("1.0.0", true),
            ("v1.0.0", true),
            ("2.1.3-beta", true),
            ("1.0.0+build.1", true),
            ("main", true),
            ("experimental-v2", true),
            ("", false),
            ("1.0", true), // Two-component version: accepted by the git-tag pattern
            ("invalid version!", false),
        ];

        for (version, should_be_valid) in test_cases {
            let model = VersionedModel::new(
                "test".to_string(),
                version.to_string(),
                sample_metadata(),
                vec![],
            );

            let is_valid = model.validate_version_format().is_ok();
            assert_eq!(
                is_valid, should_be_valid,
                "Version '{}' validation failed",
                version
            );
        }
    }

    #[test]
    fn test_model_tags() {
        let tag1 = ModelTag::new("production");
        let tag2 = ModelTag::with_value("environment", "staging");
        let tag3 = ModelTag::with_category("model_type", "llm", "architecture");

        assert_eq!(tag1.name, "production");
        assert_eq!(tag1.value, None);
        assert_eq!(tag1.category, None);

        assert_eq!(tag2.name, "environment");
        assert_eq!(tag2.value, Some("staging".to_string()));
        assert_eq!(tag2.category, None);

        assert_eq!(tag3.name, "model_type");
        assert_eq!(tag3.value, Some("llm".to_string()));
        assert_eq!(tag3.category, Some("architecture".to_string()));
    }

    #[test]
    fn test_parent_child_bookkeeping() {
        let mut parent = VersionedModel::new(
            "test_model".to_string(),
            "1.0.0".to_string(),
            sample_metadata(),
            vec![],
        );
        let child = VersionedModel::with_parent(
            "test_model".to_string(),
            "1.1.0".to_string(),
            sample_metadata(),
            vec![],
            parent.id(),
        );

        assert!(!child.is_root());
        assert_eq!(child.parent_version(), Some(parent.id()));

        // Adding the same child twice must not create a duplicate entry.
        parent.add_child(child.id());
        parent.add_child(child.id());
        assert_eq!(parent.child_versions().to_vec(), vec![child.id()]);
        assert!(!parent.is_leaf());

        parent.remove_child(child.id());
        assert!(parent.is_leaf());

        // Removing an ID that is not registered is a no-op.
        parent.remove_child(child.id());
        assert!(parent.child_versions().is_empty());
    }

    #[test]
    fn test_equality_is_by_id() {
        let first = VersionedModel::new(
            "test_model".to_string(),
            "1.0.0".to_string(),
            sample_metadata(),
            vec![],
        );
        let second = VersionedModel::new(
            "test_model".to_string(),
            "1.0.0".to_string(),
            sample_metadata(),
            vec![],
        );

        assert_eq!(first, first.clone());
        // Same name and version, but separately generated IDs.
        assert_ne!(first, second);
    }

    #[test]
    fn test_ordering_by_model_name() {
        let alpha = VersionedModel::new(
            "alpha".to_string(),
            "1.0.0".to_string(),
            sample_metadata(),
            vec![],
        );
        let beta = VersionedModel::new(
            "beta".to_string(),
            "1.0.0".to_string(),
            sample_metadata(),
            vec![],
        );

        // Model name is the primary sort key, regardless of creation order.
        assert!(alpha < beta);
        assert!(beta > alpha);
    }
}