1use anyhow::Result;
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::{HashMap, HashSet};
7use tokio::sync::RwLock;
8use uuid::Uuid;
9
10use super::metadata::VersionedModel;
11
12pub struct ModelRegistry {
14 versions: RwLock<HashMap<Uuid, VersionedModel>>,
16 name_index: RwLock<HashMap<String, HashMap<String, Uuid>>>,
18 tag_index: RwLock<HashMap<String, HashSet<Uuid>>>,
20 time_index: RwLock<Vec<(DateTime<Utc>, Uuid)>>,
22}
23
24impl ModelRegistry {
25 pub fn new() -> Self {
27 Self {
28 versions: RwLock::new(HashMap::new()),
29 name_index: RwLock::new(HashMap::new()),
30 tag_index: RwLock::new(HashMap::new()),
31 time_index: RwLock::new(Vec::new()),
32 }
33 }
34
35 pub async fn register(&self, model: VersionedModel) -> Result<Uuid> {
37 let version_id = model.id();
38 let model_name = model.model_name().to_string();
39 let version_string = model.version().to_string();
40
41 {
43 let name_index = self.name_index.read().await;
44 if let Some(versions) = name_index.get(&model_name) {
45 if versions.contains_key(&version_string) {
46 anyhow::bail!("Version {}:{} already exists", model_name, version_string);
47 }
48 }
49 }
50
51 {
53 let mut versions = self.versions.write().await;
54 versions.insert(version_id, model.clone());
55 }
56
57 {
59 let mut name_index = self.name_index.write().await;
60 name_index.entry(model_name).or_default().insert(version_string, version_id);
61 }
62
63 {
65 let mut tag_index = self.tag_index.write().await;
66 for tag in &model.metadata().tags {
67 tag_index.entry(tag.name.clone()).or_default().insert(version_id);
68 }
69 }
70
71 {
73 let mut time_index = self.time_index.write().await;
74 time_index.push((model.metadata().created_at, version_id));
75 time_index.sort_by_key(|(time, _)| *time);
76 }
77
78 tracing::debug!("Registered model version: {}", version_id);
79 Ok(version_id)
80 }
81
82 pub async fn get_version(&self, version_id: Uuid) -> Result<Option<VersionedModel>> {
84 let versions = self.versions.read().await;
85 Ok(versions.get(&version_id).cloned())
86 }
87
88 pub async fn get_version_by_name(
90 &self,
91 model_name: &str,
92 version: &str,
93 ) -> Result<Option<VersionedModel>> {
94 let name_index = self.name_index.read().await;
95 if let Some(versions) = name_index.get(model_name) {
96 if let Some(&version_id) = versions.get(version) {
97 let versions_map = self.versions.read().await;
98 return Ok(versions_map.get(&version_id).cloned());
99 }
100 }
101 Ok(None)
102 }
103
104 pub async fn list_versions(&self, model_name: &str) -> Result<Vec<VersionedModel>> {
106 let name_index = self.name_index.read().await;
107 let versions_map = self.versions.read().await;
108
109 if let Some(versions) = name_index.get(model_name) {
110 let mut models: Vec<VersionedModel> =
111 versions.values().filter_map(|&id| versions_map.get(&id).cloned()).collect();
112
113 models.sort_by_key(|model| std::cmp::Reverse(model.metadata().created_at));
115 Ok(models)
116 } else {
117 Ok(vec![])
118 }
119 }
120
121 pub async fn list_models(&self) -> Result<Vec<String>> {
123 let name_index = self.name_index.read().await;
124 let mut names: Vec<String> = name_index.keys().cloned().collect();
125 names.sort();
126 Ok(names)
127 }
128
129 pub async fn query_versions(&self, query: VersionQuery) -> Result<Vec<VersionedModel>> {
131 let versions_map = self.versions.read().await;
132 let mut results = Vec::new();
133
134 for model in versions_map.values() {
135 if self.matches_query(model, &query).await {
136 results.push(model.clone());
137 }
138 }
139
140 self.sort_results(&mut results, &query.sort_by);
142
143 if let Some(limit) = query.limit {
145 let offset = query.offset.unwrap_or(0);
146 let end = std::cmp::min(offset + limit, results.len());
147 results = results[offset..end].to_vec();
148 }
149
150 Ok(results)
151 }
152
153 pub async fn remove_version(&self, version_id: Uuid) -> Result<Option<VersionedModel>> {
155 let model = {
157 let mut versions = self.versions.write().await;
158 versions.remove(&version_id)
159 };
160
161 if let Some(ref model) = model {
162 {
164 let mut name_index = self.name_index.write().await;
165 if let Some(versions) = name_index.get_mut(model.model_name()) {
166 versions.remove(model.version());
167 if versions.is_empty() {
168 name_index.remove(model.model_name());
169 }
170 }
171 }
172
173 {
175 let mut tag_index = self.tag_index.write().await;
176 for tag in &model.metadata().tags {
177 if let Some(tag_set) = tag_index.get_mut(&tag.name) {
178 tag_set.remove(&version_id);
179 if tag_set.is_empty() {
180 tag_index.remove(&tag.name);
181 }
182 }
183 }
184 }
185
186 {
188 let mut time_index = self.time_index.write().await;
189 time_index.retain(|(_, id)| *id != version_id);
190 }
191
192 tracing::debug!("Removed model version: {}", version_id);
193 }
194
195 Ok(model)
196 }
197
198 pub async fn get_versions_by_tag(&self, tag_name: &str) -> Result<Vec<VersionedModel>> {
200 let tag_index = self.tag_index.read().await;
201 let versions_map = self.versions.read().await;
202
203 if let Some(version_ids) = tag_index.get(tag_name) {
204 let models: Vec<VersionedModel> =
205 version_ids.iter().filter_map(|&id| versions_map.get(&id).cloned()).collect();
206 Ok(models)
207 } else {
208 Ok(vec![])
209 }
210 }
211
212 pub async fn get_latest_version(&self, model_name: &str) -> Result<Option<VersionedModel>> {
214 let versions = self.list_versions(model_name).await?;
215 Ok(versions.into_iter().next()) }
217
218 pub async fn get_statistics(&self) -> Result<RegistryStatistics> {
220 let versions_map = self.versions.read().await;
221 let name_index = self.name_index.read().await;
222 let tag_index = self.tag_index.read().await;
223
224 let total_versions = versions_map.len();
225 let total_models = name_index.len();
226 let total_tags = tag_index.len();
227
228 let mut total_artifacts = 0;
230 let mut total_size_bytes = 0;
231
232 for model in versions_map.values() {
233 total_artifacts += model.artifact_ids().len();
234 if let Some(size) = model.metadata().size_bytes {
235 total_size_bytes += size;
236 }
237 }
238
239 Ok(RegistryStatistics {
240 total_versions,
241 total_models,
242 total_tags,
243 total_artifacts,
244 total_size_bytes,
245 })
246 }
247
248 async fn matches_query(&self, model: &VersionedModel, query: &VersionQuery) -> bool {
251 if let Some(ref pattern) = query.model_name_pattern {
253 if !self.matches_pattern(model.model_name(), pattern) {
254 return false;
255 }
256 }
257
258 if let Some(ref version_filter) = query.version_filter {
260 if !self.matches_version_filter(model, version_filter) {
261 return false;
262 }
263 }
264
265 if !query.tags.is_empty() {
267 let model_tags: HashSet<String> =
268 model.metadata().tags.iter().map(|tag| tag.name.clone()).collect();
269
270 match query.tag_mode {
271 TagMatchMode::Any => {
272 if !query.tags.iter().any(|tag| model_tags.contains(tag)) {
273 return false;
274 }
275 },
276 TagMatchMode::All => {
277 if !query.tags.iter().all(|tag| model_tags.contains(tag)) {
278 return false;
279 }
280 },
281 }
282 }
283
284 if let Some(ref date_range) = query.created_date_range {
286 let created_at = model.metadata().created_at;
287 if let Some(start) = date_range.start {
288 if created_at < start {
289 return false;
290 }
291 }
292 if let Some(end) = date_range.end {
293 if created_at > end {
294 return false;
295 }
296 }
297 }
298
299 if let Some(ref model_type) = query.model_type {
301 if model.metadata().model_type != *model_type {
302 return false;
303 }
304 }
305
306 true
307 }
308
309 fn matches_pattern(&self, text: &str, pattern: &str) -> bool {
310 if pattern.contains('*') {
312 let parts: Vec<&str> = pattern.split('*').collect();
314 if parts.len() == 2 {
315 let prefix = parts[0];
316 let suffix = parts[1];
317 return text.starts_with(prefix) && text.ends_with(suffix);
318 }
319 }
320 text.contains(pattern)
321 }
322
323 fn matches_version_filter(&self, model: &VersionedModel, filter: &VersionFilter) -> bool {
324 match filter {
325 VersionFilter::Exact(version) => model.version() == version,
326 VersionFilter::Prefix(prefix) => model.version().starts_with(prefix),
327 VersionFilter::Regex(regex_str) => {
328 if let Ok(regex) = regex::Regex::new(regex_str) {
329 regex.is_match(model.version())
330 } else {
331 false
332 }
333 },
334 }
335 }
336
337 fn sort_results(&self, results: &mut [VersionedModel], sort_by: &SortBy) {
338 match sort_by {
339 SortBy::CreatedAt(order) => {
340 results.sort_by(|a, b| {
341 let cmp = a.metadata().created_at.cmp(&b.metadata().created_at);
342 match order {
343 SortOrder::Ascending => cmp,
344 SortOrder::Descending => cmp.reverse(),
345 }
346 });
347 },
348 SortBy::ModelName(order) => {
349 results.sort_by(|a, b| {
350 let cmp = a.model_name().cmp(b.model_name());
351 match order {
352 SortOrder::Ascending => cmp,
353 SortOrder::Descending => cmp.reverse(),
354 }
355 });
356 },
357 SortBy::Version(order) => {
358 results.sort_by(|a, b| {
359 let cmp = a.version().cmp(b.version());
360 match order {
361 SortOrder::Ascending => cmp,
362 SortOrder::Descending => cmp.reverse(),
363 }
364 });
365 },
366 }
367 }
368}
369
370impl Default for ModelRegistry {
371 fn default() -> Self {
372 Self::new()
373 }
374}
375
376#[derive(Debug, Clone)]
378pub struct VersionQuery {
379 pub model_name_pattern: Option<String>,
381 pub version_filter: Option<VersionFilter>,
383 pub tags: Vec<String>,
385 pub tag_mode: TagMatchMode,
387 pub created_date_range: Option<DateRange>,
389 pub model_type: Option<String>,
391 pub sort_by: SortBy,
393 pub offset: Option<usize>,
395 pub limit: Option<usize>,
397}
398
399impl Default for VersionQuery {
400 fn default() -> Self {
401 Self {
402 model_name_pattern: None,
403 version_filter: None,
404 tags: Vec::new(),
405 tag_mode: TagMatchMode::Any,
406 created_date_range: None,
407 model_type: None,
408 sort_by: SortBy::CreatedAt(SortOrder::Descending),
409 offset: None,
410 limit: None,
411 }
412 }
413}
414
415impl VersionQuery {
416 pub fn new() -> Self {
417 Self::default()
418 }
419
420 pub fn model_name_pattern(mut self, pattern: String) -> Self {
421 self.model_name_pattern = Some(pattern);
422 self
423 }
424
425 pub fn version_filter(mut self, filter: VersionFilter) -> Self {
426 self.version_filter = Some(filter);
427 self
428 }
429
430 pub fn with_tag(mut self, tag: String) -> Self {
431 self.tags.push(tag);
432 self
433 }
434
435 pub fn tag_mode(mut self, mode: TagMatchMode) -> Self {
436 self.tag_mode = mode;
437 self
438 }
439
440 pub fn created_after(mut self, date: DateTime<Utc>) -> Self {
441 let range = self.created_date_range.get_or_insert(DateRange::default());
442 range.start = Some(date);
443 self
444 }
445
446 pub fn created_before(mut self, date: DateTime<Utc>) -> Self {
447 let range = self.created_date_range.get_or_insert(DateRange::default());
448 range.end = Some(date);
449 self
450 }
451
452 pub fn model_type(mut self, model_type: String) -> Self {
453 self.model_type = Some(model_type);
454 self
455 }
456
457 pub fn sort_by(mut self, sort_by: SortBy) -> Self {
458 self.sort_by = sort_by;
459 self
460 }
461
462 pub fn limit(mut self, limit: usize) -> Self {
463 self.limit = Some(limit);
464 self
465 }
466
467 pub fn offset(mut self, offset: usize) -> Self {
468 self.offset = Some(offset);
469 self
470 }
471}
472
/// Constraint on a model's version string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VersionFilter {
    /// The version string must equal the given value.
    Exact(String),
    /// The version string must start with the given value.
    Prefix(String),
    /// The version string must match the given regular expression.
    Regex(String),
}
483
/// How a query's tag list is matched against a model's tags.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TagMatchMode {
    /// At least one of the requested tags must be present.
    Any,
    /// Every requested tag must be present.
    All,
}
492
493#[derive(Debug, Clone, Default)]
495pub struct DateRange {
496 pub start: Option<DateTime<Utc>>,
497 pub end: Option<DateTime<Utc>>,
498}
499
500#[derive(Debug, Clone)]
502pub enum SortBy {
503 CreatedAt(SortOrder),
504 ModelName(SortOrder),
505 Version(SortOrder),
506}
507
/// Direction of a sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SortOrder {
    /// Smallest key first.
    Ascending,
    /// Largest key first.
    Descending,
}
514
515#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct RegistryStatistics {
518 pub total_versions: usize,
519 pub total_models: usize,
520 pub total_tags: usize,
521 pub total_artifacts: usize,
522 pub total_size_bytes: u64,
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::versioning::metadata::{ModelMetadata, ModelTag};

    /// Builds a "transformer" test model with the given name, version and tags.
    async fn create_test_model(name: &str, version: &str, tags: Vec<ModelTag>) -> VersionedModel {
        let base = ModelMetadata::builder()
            .description(format!("Test model {}", name))
            .created_by("test_user".to_string())
            .model_type("transformer".to_string());

        let metadata = tags.into_iter().fold(base, |builder, tag| builder.tag(tag)).build();

        VersionedModel::new(name.to_string(), version.to_string(), metadata, vec![])
    }

    /// Register one model, then read it back through every lookup path.
    #[tokio::test]
    async fn test_registry_operations() {
        let registry = ModelRegistry::new();

        let model = create_test_model("gpt2", "1.0.0", vec![ModelTag::new("production")]).await;
        let expected_id = model.id();
        let version_id = registry.register(model).await.expect("async operation failed");
        assert_eq!(version_id, expected_id);

        // Lookup by id.
        let by_id = registry.get_version(version_id).await.expect("async operation failed");
        assert!(by_id.is_some());
        assert_eq!(by_id.expect("operation failed in test").version(), "1.0.0");

        // Lookup by name and version.
        let by_name = registry
            .get_version_by_name("gpt2", "1.0.0")
            .await
            .expect("async operation failed");
        assert!(by_name.is_some());
        assert_eq!(by_name.expect("operation failed in test").model_name(), "gpt2");

        // Listings.
        let versions = registry.list_versions("gpt2").await.expect("async operation failed");
        assert_eq!(versions.len(), 1);

        let models = registry.list_models().await.expect("async operation failed");
        assert_eq!(models, vec!["gpt2"]);
    }

    /// Name-pattern, tag and limit queries over three registered versions.
    #[tokio::test]
    async fn test_query_functionality() {
        let registry = ModelRegistry::new();

        let fixtures = [
            ("gpt2", "1.0.0", "production"),
            ("gpt2", "1.1.0", "staging"),
            ("bert", "1.0.0", "production"),
        ];
        for (name, version, tag) in fixtures {
            let model = create_test_model(name, version, vec![ModelTag::new(tag)]).await;
            registry.register(model).await.expect("async operation failed");
        }

        let by_pattern = registry
            .query_versions(VersionQuery::new().model_name_pattern("gpt*".to_string()))
            .await
            .expect("async operation failed");
        assert_eq!(by_pattern.len(), 2);

        let by_tag = registry
            .query_versions(VersionQuery::new().with_tag("production".to_string()))
            .await
            .expect("async operation failed");
        assert_eq!(by_tag.len(), 2);

        let limited = registry
            .query_versions(VersionQuery::new().limit(1))
            .await
            .expect("async operation failed");
        assert_eq!(limited.len(), 1);
    }

    /// Tag lookups return tagged versions and nothing for unknown tags.
    #[tokio::test]
    async fn test_tag_operations() {
        let registry = ModelRegistry::new();

        let tags = vec![ModelTag::new("production"), ModelTag::new("gpu")];
        let model = create_test_model("test", "1.0.0", tags).await;
        registry.register(model).await.expect("async operation failed");

        let tagged = registry
            .get_versions_by_tag("production")
            .await
            .expect("async operation failed");
        assert_eq!(tagged.len(), 1);

        let untagged = registry
            .get_versions_by_tag("nonexistent")
            .await
            .expect("async operation failed");
        assert_eq!(untagged.len(), 0);
    }

    /// Registering the same name and version twice must fail.
    #[tokio::test]
    async fn test_duplicate_prevention() {
        let registry = ModelRegistry::new();

        let first = create_test_model("test", "1.0.0", vec![]).await;
        let second = create_test_model("test", "1.0.0", vec![]).await;

        registry.register(first).await.expect("async operation failed");
        assert!(registry.register(second).await.is_err());
    }

    /// Statistics reflect a single registered, single-tagged version.
    #[tokio::test]
    async fn test_registry_statistics() {
        let registry = ModelRegistry::new();

        let model = create_test_model("test", "1.0.0", vec![ModelTag::new("test")]).await;
        registry.register(model).await.expect("async operation failed");

        let stats = registry.get_statistics().await.expect("async operation failed");
        assert_eq!(stats.total_versions, 1);
        assert_eq!(stats.total_models, 1);
        assert_eq!(stats.total_tags, 1);
    }
}