Skip to main content

trustformers_core/versioning/
registry.rs

1//! Model version registry for tracking and querying versions
2
3use anyhow::Result;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use super::metadata::VersionedModel;
11
12/// Model version registry
13pub struct ModelRegistry {
14    /// Map of version ID to versioned model
15    versions: RwLock<HashMap<Uuid, VersionedModel>>,
16    /// Index by model name -> version string -> version ID
17    name_index: RwLock<HashMap<String, HashMap<String, Uuid>>>,
18    /// Index by tag -> version IDs
19    tag_index: RwLock<HashMap<String, HashSet<Uuid>>>,
20    /// Index by creation time (sorted)
21    time_index: RwLock<Vec<(DateTime<Utc>, Uuid)>>,
22}
23
24impl ModelRegistry {
25    /// Create a new model registry
26    pub fn new() -> Self {
27        Self {
28            versions: RwLock::new(HashMap::new()),
29            name_index: RwLock::new(HashMap::new()),
30            tag_index: RwLock::new(HashMap::new()),
31            time_index: RwLock::new(Vec::new()),
32        }
33    }
34
35    /// Register a new model version
36    pub async fn register(&self, model: VersionedModel) -> Result<Uuid> {
37        let version_id = model.id();
38        let model_name = model.model_name().to_string();
39        let version_string = model.version().to_string();
40
41        // Check for duplicate versions
42        {
43            let name_index = self.name_index.read().await;
44            if let Some(versions) = name_index.get(&model_name) {
45                if versions.contains_key(&version_string) {
46                    anyhow::bail!("Version {}:{} already exists", model_name, version_string);
47                }
48            }
49        }
50
51        // Update version storage
52        {
53            let mut versions = self.versions.write().await;
54            versions.insert(version_id, model.clone());
55        }
56
57        // Update name index
58        {
59            let mut name_index = self.name_index.write().await;
60            name_index.entry(model_name).or_default().insert(version_string, version_id);
61        }
62
63        // Update tag index
64        {
65            let mut tag_index = self.tag_index.write().await;
66            for tag in &model.metadata().tags {
67                tag_index.entry(tag.name.clone()).or_default().insert(version_id);
68            }
69        }
70
71        // Update time index
72        {
73            let mut time_index = self.time_index.write().await;
74            time_index.push((model.metadata().created_at, version_id));
75            time_index.sort_by_key(|(time, _)| *time);
76        }
77
78        tracing::debug!("Registered model version: {}", version_id);
79        Ok(version_id)
80    }
81
82    /// Get a model version by ID
83    pub async fn get_version(&self, version_id: Uuid) -> Result<Option<VersionedModel>> {
84        let versions = self.versions.read().await;
85        Ok(versions.get(&version_id).cloned())
86    }
87
88    /// Get a model version by name and version string
89    pub async fn get_version_by_name(
90        &self,
91        model_name: &str,
92        version: &str,
93    ) -> Result<Option<VersionedModel>> {
94        let name_index = self.name_index.read().await;
95        if let Some(versions) = name_index.get(model_name) {
96            if let Some(&version_id) = versions.get(version) {
97                let versions_map = self.versions.read().await;
98                return Ok(versions_map.get(&version_id).cloned());
99            }
100        }
101        Ok(None)
102    }
103
104    /// List all versions for a model
105    pub async fn list_versions(&self, model_name: &str) -> Result<Vec<VersionedModel>> {
106        let name_index = self.name_index.read().await;
107        let versions_map = self.versions.read().await;
108
109        if let Some(versions) = name_index.get(model_name) {
110            let mut models: Vec<VersionedModel> =
111                versions.values().filter_map(|&id| versions_map.get(&id).cloned()).collect();
112
113            // Sort by creation time (newest first)
114            models.sort_by_key(|model| std::cmp::Reverse(model.metadata().created_at));
115            Ok(models)
116        } else {
117            Ok(vec![])
118        }
119    }
120
121    /// List all model names
122    pub async fn list_models(&self) -> Result<Vec<String>> {
123        let name_index = self.name_index.read().await;
124        let mut names: Vec<String> = name_index.keys().cloned().collect();
125        names.sort();
126        Ok(names)
127    }
128
129    /// Query versions with filters
130    pub async fn query_versions(&self, query: VersionQuery) -> Result<Vec<VersionedModel>> {
131        let versions_map = self.versions.read().await;
132        let mut results = Vec::new();
133
134        for model in versions_map.values() {
135            if self.matches_query(model, &query).await {
136                results.push(model.clone());
137            }
138        }
139
140        // Apply sorting
141        self.sort_results(&mut results, &query.sort_by);
142
143        // Apply pagination
144        if let Some(limit) = query.limit {
145            let offset = query.offset.unwrap_or(0);
146            let end = std::cmp::min(offset + limit, results.len());
147            results = results[offset..end].to_vec();
148        }
149
150        Ok(results)
151    }
152
153    /// Remove a version from the registry
154    pub async fn remove_version(&self, version_id: Uuid) -> Result<Option<VersionedModel>> {
155        // Get the model to remove from indices
156        let model = {
157            let mut versions = self.versions.write().await;
158            versions.remove(&version_id)
159        };
160
161        if let Some(ref model) = model {
162            // Remove from name index
163            {
164                let mut name_index = self.name_index.write().await;
165                if let Some(versions) = name_index.get_mut(model.model_name()) {
166                    versions.remove(model.version());
167                    if versions.is_empty() {
168                        name_index.remove(model.model_name());
169                    }
170                }
171            }
172
173            // Remove from tag index
174            {
175                let mut tag_index = self.tag_index.write().await;
176                for tag in &model.metadata().tags {
177                    if let Some(tag_set) = tag_index.get_mut(&tag.name) {
178                        tag_set.remove(&version_id);
179                        if tag_set.is_empty() {
180                            tag_index.remove(&tag.name);
181                        }
182                    }
183                }
184            }
185
186            // Remove from time index
187            {
188                let mut time_index = self.time_index.write().await;
189                time_index.retain(|(_, id)| *id != version_id);
190            }
191
192            tracing::debug!("Removed model version: {}", version_id);
193        }
194
195        Ok(model)
196    }
197
198    /// Get versions by tag
199    pub async fn get_versions_by_tag(&self, tag_name: &str) -> Result<Vec<VersionedModel>> {
200        let tag_index = self.tag_index.read().await;
201        let versions_map = self.versions.read().await;
202
203        if let Some(version_ids) = tag_index.get(tag_name) {
204            let models: Vec<VersionedModel> =
205                version_ids.iter().filter_map(|&id| versions_map.get(&id).cloned()).collect();
206            Ok(models)
207        } else {
208            Ok(vec![])
209        }
210    }
211
212    /// Get latest version for a model
213    pub async fn get_latest_version(&self, model_name: &str) -> Result<Option<VersionedModel>> {
214        let versions = self.list_versions(model_name).await?;
215        Ok(versions.into_iter().next()) // Already sorted by creation time (newest first)
216    }
217
218    /// Get registry statistics
219    pub async fn get_statistics(&self) -> Result<RegistryStatistics> {
220        let versions_map = self.versions.read().await;
221        let name_index = self.name_index.read().await;
222        let tag_index = self.tag_index.read().await;
223
224        let total_versions = versions_map.len();
225        let total_models = name_index.len();
226        let total_tags = tag_index.len();
227
228        // Calculate storage statistics
229        let mut total_artifacts = 0;
230        let mut total_size_bytes = 0;
231
232        for model in versions_map.values() {
233            total_artifacts += model.artifact_ids().len();
234            if let Some(size) = model.metadata().size_bytes {
235                total_size_bytes += size;
236            }
237        }
238
239        Ok(RegistryStatistics {
240            total_versions,
241            total_models,
242            total_tags,
243            total_artifacts,
244            total_size_bytes,
245        })
246    }
247
248    // Helper methods
249
250    async fn matches_query(&self, model: &VersionedModel, query: &VersionQuery) -> bool {
251        // Model name filter
252        if let Some(ref pattern) = query.model_name_pattern {
253            if !self.matches_pattern(model.model_name(), pattern) {
254                return false;
255            }
256        }
257
258        // Version filter
259        if let Some(ref version_filter) = query.version_filter {
260            if !self.matches_version_filter(model, version_filter) {
261                return false;
262            }
263        }
264
265        // Tag filter
266        if !query.tags.is_empty() {
267            let model_tags: HashSet<String> =
268                model.metadata().tags.iter().map(|tag| tag.name.clone()).collect();
269
270            match query.tag_mode {
271                TagMatchMode::Any => {
272                    if !query.tags.iter().any(|tag| model_tags.contains(tag)) {
273                        return false;
274                    }
275                },
276                TagMatchMode::All => {
277                    if !query.tags.iter().all(|tag| model_tags.contains(tag)) {
278                        return false;
279                    }
280                },
281            }
282        }
283
284        // Created date range
285        if let Some(ref date_range) = query.created_date_range {
286            let created_at = model.metadata().created_at;
287            if let Some(start) = date_range.start {
288                if created_at < start {
289                    return false;
290                }
291            }
292            if let Some(end) = date_range.end {
293                if created_at > end {
294                    return false;
295                }
296            }
297        }
298
299        // Model type filter
300        if let Some(ref model_type) = query.model_type {
301            if model.metadata().model_type != *model_type {
302                return false;
303            }
304        }
305
306        true
307    }
308
309    fn matches_pattern(&self, text: &str, pattern: &str) -> bool {
310        // Simple pattern matching - could be enhanced with regex
311        if pattern.contains('*') {
312            // Wildcard matching
313            let parts: Vec<&str> = pattern.split('*').collect();
314            if parts.len() == 2 {
315                let prefix = parts[0];
316                let suffix = parts[1];
317                return text.starts_with(prefix) && text.ends_with(suffix);
318            }
319        }
320        text.contains(pattern)
321    }
322
323    fn matches_version_filter(&self, model: &VersionedModel, filter: &VersionFilter) -> bool {
324        match filter {
325            VersionFilter::Exact(version) => model.version() == version,
326            VersionFilter::Prefix(prefix) => model.version().starts_with(prefix),
327            VersionFilter::Regex(regex_str) => {
328                if let Ok(regex) = regex::Regex::new(regex_str) {
329                    regex.is_match(model.version())
330                } else {
331                    false
332                }
333            },
334        }
335    }
336
337    fn sort_results(&self, results: &mut [VersionedModel], sort_by: &SortBy) {
338        match sort_by {
339            SortBy::CreatedAt(order) => {
340                results.sort_by(|a, b| {
341                    let cmp = a.metadata().created_at.cmp(&b.metadata().created_at);
342                    match order {
343                        SortOrder::Ascending => cmp,
344                        SortOrder::Descending => cmp.reverse(),
345                    }
346                });
347            },
348            SortBy::ModelName(order) => {
349                results.sort_by(|a, b| {
350                    let cmp = a.model_name().cmp(b.model_name());
351                    match order {
352                        SortOrder::Ascending => cmp,
353                        SortOrder::Descending => cmp.reverse(),
354                    }
355                });
356            },
357            SortBy::Version(order) => {
358                results.sort_by(|a, b| {
359                    let cmp = a.version().cmp(b.version());
360                    match order {
361                        SortOrder::Ascending => cmp,
362                        SortOrder::Descending => cmp.reverse(),
363                    }
364                });
365            },
366        }
367    }
368}
369
370impl Default for ModelRegistry {
371    fn default() -> Self {
372        Self::new()
373    }
374}
375
376/// Query for searching model versions
377#[derive(Debug, Clone)]
378pub struct VersionQuery {
379    /// Model name pattern (supports wildcards)
380    pub model_name_pattern: Option<String>,
381    /// Version filter
382    pub version_filter: Option<VersionFilter>,
383    /// Tags to match
384    pub tags: Vec<String>,
385    /// Tag matching mode
386    pub tag_mode: TagMatchMode,
387    /// Created date range
388    pub created_date_range: Option<DateRange>,
389    /// Model type filter
390    pub model_type: Option<String>,
391    /// Sort order
392    pub sort_by: SortBy,
393    /// Pagination offset
394    pub offset: Option<usize>,
395    /// Pagination limit
396    pub limit: Option<usize>,
397}
398
399impl Default for VersionQuery {
400    fn default() -> Self {
401        Self {
402            model_name_pattern: None,
403            version_filter: None,
404            tags: Vec::new(),
405            tag_mode: TagMatchMode::Any,
406            created_date_range: None,
407            model_type: None,
408            sort_by: SortBy::CreatedAt(SortOrder::Descending),
409            offset: None,
410            limit: None,
411        }
412    }
413}
414
415impl VersionQuery {
416    pub fn new() -> Self {
417        Self::default()
418    }
419
420    pub fn model_name_pattern(mut self, pattern: String) -> Self {
421        self.model_name_pattern = Some(pattern);
422        self
423    }
424
425    pub fn version_filter(mut self, filter: VersionFilter) -> Self {
426        self.version_filter = Some(filter);
427        self
428    }
429
430    pub fn with_tag(mut self, tag: String) -> Self {
431        self.tags.push(tag);
432        self
433    }
434
435    pub fn tag_mode(mut self, mode: TagMatchMode) -> Self {
436        self.tag_mode = mode;
437        self
438    }
439
440    pub fn created_after(mut self, date: DateTime<Utc>) -> Self {
441        let range = self.created_date_range.get_or_insert(DateRange::default());
442        range.start = Some(date);
443        self
444    }
445
446    pub fn created_before(mut self, date: DateTime<Utc>) -> Self {
447        let range = self.created_date_range.get_or_insert(DateRange::default());
448        range.end = Some(date);
449        self
450    }
451
452    pub fn model_type(mut self, model_type: String) -> Self {
453        self.model_type = Some(model_type);
454        self
455    }
456
457    pub fn sort_by(mut self, sort_by: SortBy) -> Self {
458        self.sort_by = sort_by;
459        self
460    }
461
462    pub fn limit(mut self, limit: usize) -> Self {
463        self.limit = Some(limit);
464        self
465    }
466
467    pub fn offset(mut self, offset: usize) -> Self {
468        self.offset = Some(offset);
469        self
470    }
471}
472
/// Constraint applied to a model's version string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VersionFilter {
    /// The version must equal this string exactly.
    Exact(String),
    /// The version must start with this string.
    Prefix(String),
    /// The version must match this regular expression.
    Regex(String),
}
483
/// How a query's tag list is combined when filtering versions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TagMatchMode {
    /// A version matches if it carries at least one of the query's tags.
    Any,
    /// A version matches only if it carries every one of the query's tags.
    All,
}
492
493/// Date range filter
494#[derive(Debug, Clone, Default)]
495pub struct DateRange {
496    pub start: Option<DateTime<Utc>>,
497    pub end: Option<DateTime<Utc>>,
498}
499
/// Field used to order query results, paired with the direction to sort in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortBy {
    /// Order by the metadata creation timestamp.
    CreatedAt(SortOrder),
    /// Order by model name.
    ModelName(SortOrder),
    /// Order by version string.
    Version(SortOrder),
}

/// Direction in which query results are sorted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SortOrder {
    /// Smallest value first.
    Ascending,
    /// Largest value first.
    Descending,
}
514
515/// Registry statistics
516#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct RegistryStatistics {
518    pub total_versions: usize,
519    pub total_models: usize,
520    pub total_tags: usize,
521    pub total_artifacts: usize,
522    pub total_size_bytes: u64,
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioning::metadata::{ModelMetadata, ModelTag};

    /// Build a `VersionedModel` with fixed test metadata and the given tags.
    ///
    /// Nothing here awaits, so this is a plain function rather than an `async fn`.
    fn create_test_model(name: &str, version: &str, tags: Vec<ModelTag>) -> VersionedModel {
        let mut metadata_builder = ModelMetadata::builder()
            .description(format!("Test model {}", name))
            .created_by("test_user".to_string())
            .model_type("transformer".to_string());

        for tag in tags {
            metadata_builder = metadata_builder.tag(tag);
        }

        let metadata = metadata_builder.build();

        VersionedModel::new(name.to_string(), version.to_string(), metadata, vec![])
    }

    #[tokio::test]
    async fn test_registry_operations() {
        let registry = ModelRegistry::new();

        // Register a model
        let model = create_test_model("gpt2", "1.0.0", vec![ModelTag::new("production")]);
        let version_id = registry.register(model.clone()).await.expect("async operation failed");
        assert_eq!(version_id, model.id());

        // Get by ID
        let retrieved = registry.get_version(version_id).await.expect("async operation failed");
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.expect("operation failed in test").version(),
            "1.0.0"
        );

        // Get by name and version
        let retrieved = registry
            .get_version_by_name("gpt2", "1.0.0")
            .await
            .expect("async operation failed");
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.expect("operation failed in test").model_name(),
            "gpt2"
        );

        // List versions
        let versions = registry.list_versions("gpt2").await.expect("async operation failed");
        assert_eq!(versions.len(), 1);

        // List models
        let models = registry.list_models().await.expect("async operation failed");
        assert_eq!(models, vec!["gpt2"]);
    }

    #[tokio::test]
    async fn test_query_functionality() {
        let registry = ModelRegistry::new();

        // Register multiple models
        let models = vec![
            create_test_model("gpt2", "1.0.0", vec![ModelTag::new("production")]),
            create_test_model("gpt2", "1.1.0", vec![ModelTag::new("staging")]),
            create_test_model("bert", "1.0.0", vec![ModelTag::new("production")]),
        ];

        for model in models {
            registry.register(model).await.expect("async operation failed");
        }

        // Query by model name pattern
        let query = VersionQuery::new().model_name_pattern("gpt*".to_string());
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 2);

        // Query by tag
        let query = VersionQuery::new().with_tag("production".to_string());
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 2);

        // Query with limit
        let query = VersionQuery::new().limit(1);
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn test_query_filters_sorting_and_pagination() {
        let registry = ModelRegistry::new();

        let models = vec![
            create_test_model(
                "gpt2",
                "1.0.0",
                vec![ModelTag::new("production"), ModelTag::new("gpu")],
            ),
            create_test_model("gpt2", "1.1.0", vec![ModelTag::new("staging")]),
            create_test_model("bert", "1.0.0", vec![ModelTag::new("production")]),
        ];
        for model in models {
            registry.register(model).await.expect("async operation failed");
        }

        // Exact and prefix version filters
        let query = VersionQuery::new().version_filter(VersionFilter::Exact("1.0.0".to_string()));
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 2);

        let query = VersionQuery::new().version_filter(VersionFilter::Prefix("1.1".to_string()));
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 1);

        // All-mode requires every listed tag; Any-mode requires at least one
        let query = VersionQuery::new()
            .with_tag("production".to_string())
            .with_tag("gpu".to_string())
            .tag_mode(TagMatchMode::All);
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 1);

        let query = VersionQuery::new()
            .with_tag("production".to_string())
            .with_tag("gpu".to_string())
            .tag_mode(TagMatchMode::Any);
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 2);

        // Model type filter
        let query = VersionQuery::new().model_type("transformer".to_string());
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 3);

        let query = VersionQuery::new().model_type("cnn".to_string());
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert!(results.is_empty());

        // Sorting by model name
        let query = VersionQuery::new().sort_by(SortBy::ModelName(SortOrder::Ascending));
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results[0].model_name(), "bert");

        // Offset combined with a limit larger than the remaining results
        let query = VersionQuery::new().offset(1).limit(5);
        let results = registry.query_versions(query).await.expect("async operation failed");
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_tag_operations() {
        let registry = ModelRegistry::new();

        let model = create_test_model(
            "test",
            "1.0.0",
            vec![ModelTag::new("production"), ModelTag::new("gpu")],
        );

        registry.register(model).await.expect("async operation failed");

        // Get by tag
        let results = registry
            .get_versions_by_tag("production")
            .await
            .expect("async operation failed");
        assert_eq!(results.len(), 1);

        let results = registry
            .get_versions_by_tag("nonexistent")
            .await
            .expect("async operation failed");
        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn test_duplicate_prevention() {
        let registry = ModelRegistry::new();

        let model1 = create_test_model("test", "1.0.0", vec![]);
        let model2 = create_test_model("test", "1.0.0", vec![]);

        registry.register(model1).await.expect("async operation failed");
        let result = registry.register(model2).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_latest_version_and_removal() {
        let registry = ModelRegistry::new();

        let model = create_test_model("gpt2", "1.0.0", vec![ModelTag::new("production")]);
        let version_id = registry.register(model).await.expect("async operation failed");

        // Latest version of a single-version model is that version
        let latest = registry.get_latest_version("gpt2").await.expect("async operation failed");
        assert_eq!(latest.map(|m| m.id()), Some(version_id));
        let missing = registry.get_latest_version("missing").await.expect("async operation failed");
        assert!(missing.is_none());

        // Removal returns the model and clears every index
        let removed = registry.remove_version(version_id).await.expect("async operation failed");
        assert_eq!(removed.map(|m| m.id()), Some(version_id));

        assert!(registry
            .get_version(version_id)
            .await
            .expect("async operation failed")
            .is_none());
        assert!(registry
            .get_version_by_name("gpt2", "1.0.0")
            .await
            .expect("async operation failed")
            .is_none());
        assert!(registry.list_models().await.expect("async operation failed").is_empty());
        assert!(registry
            .get_versions_by_tag("production")
            .await
            .expect("async operation failed")
            .is_empty());

        let stats = registry.get_statistics().await.expect("async operation failed");
        assert_eq!(stats.total_versions, 0);
        assert_eq!(stats.total_models, 0);
        assert_eq!(stats.total_tags, 0);

        // Removing an unknown ID is a no-op
        let again = registry.remove_version(version_id).await.expect("async operation failed");
        assert!(again.is_none());

        // The name:version pair is free to be registered again
        let replacement = create_test_model("gpt2", "1.0.0", vec![]);
        registry.register(replacement).await.expect("async operation failed");
    }

    #[tokio::test]
    async fn test_registry_statistics() {
        let registry = ModelRegistry::new();

        let model = create_test_model("test", "1.0.0", vec![ModelTag::new("test")]);
        registry.register(model).await.expect("async operation failed");

        let stats = registry.get_statistics().await.expect("async operation failed");
        assert_eq!(stats.total_versions, 1);
        assert_eq!(stats.total_models, 1);
        assert_eq!(stats.total_tags, 1);
    }
}