Skip to main content

trustformers_core/versioning/
storage.rs

1//! Model storage backend for artifacts and metadata
2
3use anyhow::Result;
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::{Path, PathBuf};
9use tokio::fs;
10use uuid::Uuid;
11
12/// Storage backend trait for model artifacts
13#[async_trait]
14pub trait ModelStorage: Send + Sync {
15    /// Store artifacts and return their IDs
16    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>>;
17
18    /// Retrieve an artifact by ID
19    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>>;
20
21    /// Delete artifacts by IDs
22    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()>;
23
24    /// Archive artifacts for a version
25    async fn archive_version(&self, version_id: Uuid) -> Result<()>;
26
27    /// Delete all artifacts for a version
28    async fn delete_version(&self, version_id: Uuid) -> Result<()>;
29
30    /// List all artifacts for a version
31    async fn list_artifacts(&self, version_id: Uuid) -> Result<Vec<Artifact>>;
32}
33
34/// Model artifact types
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum ArtifactType {
37    /// Model weights/parameters
38    Model,
39    /// Model configuration
40    Config,
41    /// Tokenizer files
42    Tokenizer,
43    /// Vocabulary files
44    Vocabulary,
45    /// Training checkpoints
46    Checkpoint,
47    /// Optimization state
48    OptimizerState,
49    /// Model architecture definition
50    Architecture,
51    /// Preprocessing pipeline
52    Preprocessing,
53    /// Evaluation metrics
54    Metrics,
55    /// Documentation
56    Documentation,
57    /// Custom artifact type
58    Custom(String),
59}
60
61impl ArtifactType {
62    /// Get file extension for artifact type
63    pub fn default_extension(&self) -> &'static str {
64        match self {
65            ArtifactType::Model => "bin",
66            ArtifactType::Config => "json",
67            ArtifactType::Tokenizer => "json",
68            ArtifactType::Vocabulary => "txt",
69            ArtifactType::Checkpoint => "ckpt",
70            ArtifactType::OptimizerState => "bin",
71            ArtifactType::Architecture => "json",
72            ArtifactType::Preprocessing => "json",
73            ArtifactType::Metrics => "json",
74            ArtifactType::Documentation => "md",
75            ArtifactType::Custom(_) => "bin",
76        }
77    }
78
79    /// Check if artifact type is required for deployment
80    pub fn is_required_for_deployment(&self) -> bool {
81        matches!(self, ArtifactType::Model | ArtifactType::Config)
82    }
83}
84
85/// Model artifact
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct Artifact {
88    /// Unique identifier
89    pub id: Uuid,
90    /// Artifact type
91    pub artifact_type: ArtifactType,
92    /// Original file path
93    pub file_path: PathBuf,
94    /// File size in bytes
95    pub size_bytes: u64,
96    /// Content hash (SHA256)
97    pub content_hash: String,
98    /// MIME type
99    pub mime_type: String,
100    /// Binary content
101    pub content: Vec<u8>,
102    /// Creation timestamp
103    pub created_at: DateTime<Utc>,
104    /// Optional metadata
105    pub metadata: HashMap<String, serde_json::Value>,
106}
107
108impl Artifact {
109    /// Create a new artifact
110    pub fn new(artifact_type: ArtifactType, file_path: PathBuf, content: Vec<u8>) -> Self {
111        let content_hash = Self::compute_hash(&content);
112        let mime_type = Self::detect_mime_type(&file_path, &artifact_type);
113
114        Self {
115            id: Uuid::new_v4(),
116            artifact_type,
117            size_bytes: content.len() as u64,
118            content_hash,
119            mime_type,
120            content,
121            file_path,
122            created_at: Utc::now(),
123            metadata: HashMap::new(),
124        }
125    }
126
127    /// Create artifact from file
128    pub async fn from_file(artifact_type: ArtifactType, file_path: PathBuf) -> Result<Self> {
129        let content = fs::read(&file_path).await?;
130        Ok(Self::new(artifact_type, file_path, content))
131    }
132
133    /// Add metadata
134    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
135        self.metadata.insert(key, value);
136        self
137    }
138
139    /// Compute SHA256 hash of content
140    fn compute_hash(content: &[u8]) -> String {
141        use sha2::{Digest, Sha256};
142        let mut hasher = Sha256::new();
143        hasher.update(content);
144        format!("{:x}", hasher.finalize())
145    }
146
147    /// Detect MIME type
148    fn detect_mime_type(file_path: &Path, artifact_type: &ArtifactType) -> String {
149        // Simple MIME type detection based on extension and artifact type
150        if let Some(extension) = file_path.extension().and_then(|s| s.to_str()) {
151            match extension.to_lowercase().as_str() {
152                "json" => "application/json".to_string(),
153                "bin" | "pt" | "pth" => "application/octet-stream".to_string(),
154                "txt" => "text/plain".to_string(),
155                "md" => "text/markdown".to_string(),
156                "yaml" | "yml" => "application/x-yaml".to_string(),
157                _ => "application/octet-stream".to_string(),
158            }
159        } else {
160            match artifact_type {
161                ArtifactType::Config
162                | ArtifactType::Tokenizer
163                | ArtifactType::Architecture
164                | ArtifactType::Preprocessing
165                | ArtifactType::Metrics => "application/json".to_string(),
166                ArtifactType::Documentation => "text/markdown".to_string(),
167                ArtifactType::Vocabulary => "text/plain".to_string(),
168                _ => "application/octet-stream".to_string(),
169            }
170        }
171    }
172
173    /// Verify content integrity
174    pub fn verify_integrity(&self) -> bool {
175        Self::compute_hash(&self.content) == self.content_hash
176    }
177
178    /// Get file extension
179    pub fn file_extension(&self) -> Option<&str> {
180        self.file_path.extension()?.to_str()
181    }
182}
183
184/// File system storage backend
185pub struct FileSystemStorage {
186    base_path: PathBuf,
187    archive_path: PathBuf,
188    metadata_cache: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
189}
190
191impl FileSystemStorage {
192    /// Create a new filesystem storage backend
193    pub fn new(base_path: PathBuf) -> Self {
194        let archive_path = base_path.join("archive");
195        Self {
196            base_path,
197            archive_path,
198            metadata_cache: tokio::sync::RwLock::new(HashMap::new()),
199        }
200    }
201
202    /// Initialize storage directories
203    pub async fn initialize(&self) -> Result<()> {
204        fs::create_dir_all(&self.base_path).await?;
205        fs::create_dir_all(&self.archive_path).await?;
206        Ok(())
207    }
208
209    /// Get storage path for an artifact
210    fn get_artifact_path(&self, artifact_id: Uuid) -> PathBuf {
211        let id_str = artifact_id.to_string();
212        let prefix = &id_str[0..2];
213        self.base_path.join("artifacts").join(prefix).join(&id_str)
214    }
215
216    /// Get archive path for an artifact
217    fn get_archive_path(&self, artifact_id: Uuid) -> PathBuf {
218        let id_str = artifact_id.to_string();
219        let prefix = &id_str[0..2];
220        self.archive_path.join("artifacts").join(prefix).join(&id_str)
221    }
222
223    /// Store artifact metadata
224    async fn store_metadata(&self, artifact: &Artifact) -> Result<()> {
225        let metadata_path = self.get_artifact_path(artifact.id).with_extension("meta");
226        if let Some(parent) = metadata_path.parent() {
227            fs::create_dir_all(parent).await?;
228        }
229
230        let metadata_json = serde_json::to_string_pretty(artifact)?;
231        fs::write(metadata_path, metadata_json).await?;
232
233        // Cache metadata
234        self.metadata_cache.write().await.insert(artifact.id, artifact.clone());
235        Ok(())
236    }
237
238    /// Load artifact metadata
239    async fn load_metadata(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
240        // Check cache first
241        if let Some(artifact) = self.metadata_cache.read().await.get(&artifact_id) {
242            return Ok(Some(artifact.clone()));
243        }
244
245        // Load from disk
246        let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
247        if !metadata_path.exists() {
248            return Ok(None);
249        }
250
251        let metadata_json = fs::read_to_string(metadata_path).await?;
252        let mut artifact: Artifact = serde_json::from_str(&metadata_json)?;
253
254        // Load content if needed
255        let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
256        if content_path.exists() {
257            artifact.content = fs::read(content_path).await?;
258        }
259
260        // Cache metadata
261        self.metadata_cache.write().await.insert(artifact_id, artifact.clone());
262        Ok(Some(artifact))
263    }
264}
265
266#[async_trait]
267impl ModelStorage for FileSystemStorage {
268    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
269        let mut artifact_ids = Vec::new();
270
271        for artifact in artifacts {
272            // Store content
273            let content_path = self.get_artifact_path(artifact.id).with_extension("bin");
274            if let Some(parent) = content_path.parent() {
275                fs::create_dir_all(parent).await?;
276            }
277            fs::write(&content_path, &artifact.content).await?;
278
279            // Store metadata
280            self.store_metadata(artifact).await?;
281
282            artifact_ids.push(artifact.id);
283            tracing::debug!("Stored artifact {} at {:?}", artifact.id, content_path);
284        }
285
286        Ok(artifact_ids)
287    }
288
289    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
290        self.load_metadata(artifact_id).await
291    }
292
293    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
294        for &artifact_id in artifact_ids {
295            let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
296            let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
297
298            if content_path.exists() {
299                fs::remove_file(content_path).await?;
300            }
301            if metadata_path.exists() {
302                fs::remove_file(metadata_path).await?;
303            }
304
305            // Remove from cache
306            self.metadata_cache.write().await.remove(&artifact_id);
307            tracing::debug!("Deleted artifact {}", artifact_id);
308        }
309        Ok(())
310    }
311
312    async fn archive_version(&self, version_id: Uuid) -> Result<()> {
313        // Move artifacts to archive directory
314        let artifacts = self.list_artifacts(version_id).await?;
315
316        for artifact in artifacts {
317            let src_content = self.get_artifact_path(artifact.id).with_extension("bin");
318            let src_metadata = self.get_artifact_path(artifact.id).with_extension("meta");
319
320            let dst_content = self.get_archive_path(artifact.id).with_extension("bin");
321            let dst_metadata = self.get_archive_path(artifact.id).with_extension("meta");
322
323            if let Some(parent) = dst_content.parent() {
324                fs::create_dir_all(parent).await?;
325            }
326
327            if src_content.exists() {
328                fs::rename(src_content, dst_content).await?;
329            }
330            if src_metadata.exists() {
331                fs::rename(src_metadata, dst_metadata).await?;
332            }
333
334            // Remove from cache
335            self.metadata_cache.write().await.remove(&artifact.id);
336        }
337
338        tracing::info!("Archived version {}", version_id);
339        Ok(())
340    }
341
342    async fn delete_version(&self, version_id: Uuid) -> Result<()> {
343        let artifacts = self.list_artifacts(version_id).await?;
344        let artifact_ids: Vec<Uuid> = artifacts.iter().map(|a| a.id).collect();
345        self.delete_artifacts(&artifact_ids).await?;
346
347        tracing::info!("Deleted version {}", version_id);
348        Ok(())
349    }
350
351    async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
352        // This would normally query a database or index
353        // For now, return artifacts from cache
354        let cache = self.metadata_cache.read().await;
355        Ok(cache.values().cloned().collect())
356    }
357}
358
359/// In-memory storage backend for testing
360pub struct InMemoryStorage {
361    artifacts: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
362    archived: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
363}
364
365impl InMemoryStorage {
366    pub fn new() -> Self {
367        Self {
368            artifacts: tokio::sync::RwLock::new(HashMap::new()),
369            archived: tokio::sync::RwLock::new(HashMap::new()),
370        }
371    }
372}
373
374#[async_trait]
375impl ModelStorage for InMemoryStorage {
376    async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
377        let mut artifact_ids = Vec::new();
378        let mut storage = self.artifacts.write().await;
379
380        for artifact in artifacts {
381            storage.insert(artifact.id, artifact.clone());
382            artifact_ids.push(artifact.id);
383        }
384
385        Ok(artifact_ids)
386    }
387
388    async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
389        let storage = self.artifacts.read().await;
390        Ok(storage.get(&artifact_id).cloned())
391    }
392
393    async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
394        let mut storage = self.artifacts.write().await;
395        for &artifact_id in artifact_ids {
396            storage.remove(&artifact_id);
397        }
398        Ok(())
399    }
400
401    async fn archive_version(&self, version_id: Uuid) -> Result<()> {
402        let artifacts = self.list_artifacts(version_id).await?;
403
404        let mut storage = self.artifacts.write().await;
405        let mut archived = self.archived.write().await;
406
407        for artifact in artifacts {
408            if let Some(artifact) = storage.remove(&artifact.id) {
409                archived.insert(artifact.id, artifact);
410            }
411        }
412
413        Ok(())
414    }
415
416    async fn delete_version(&self, version_id: Uuid) -> Result<()> {
417        let artifacts = self.list_artifacts(version_id).await?;
418        let artifact_ids: Vec<Uuid> = artifacts.iter().map(|a| a.id).collect();
419        self.delete_artifacts(&artifact_ids).await
420    }
421
422    async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
423        let storage = self.artifacts.read().await;
424        Ok(storage.values().cloned().collect())
425    }
426}
427
428impl Default for InMemoryStorage {
429    fn default() -> Self {
430        Self::new()
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A new artifact records its payload, size and a hash that verifies.
    #[tokio::test]
    async fn test_artifact_creation() {
        let payload = b"test model data".to_vec();
        let expected_size = payload.len() as u64;

        let built = Artifact::new(ArtifactType::Model, PathBuf::from("model.bin"), payload.clone());

        assert_eq!(built.artifact_type, ArtifactType::Model);
        assert_eq!(built.content, payload);
        assert_eq!(built.size_bytes, expected_size);
        assert!(!built.content_hash.is_empty());
        assert!(built.verify_integrity());
    }

    /// Store, fetch and delete one artifact through the filesystem backend.
    #[tokio::test]
    async fn test_filesystem_storage() {
        let root = TempDir::new().expect("temp file creation failed");
        let backend = FileSystemStorage::new(root.path().to_path_buf());
        backend.initialize().await.expect("async operation failed");

        let original = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("test_model.bin"),
            b"test content".to_vec(),
        );

        // Store artifact
        let stored_ids = backend
            .store_artifacts(std::slice::from_ref(&original))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids.len(), 1);
        assert_eq!(stored_ids[0], original.id);

        // Retrieve artifact
        let lookup = backend.get_artifact(original.id).await.expect("async operation failed");
        assert!(lookup.is_some());
        let fetched = lookup.expect("operation failed in test");
        assert_eq!(fetched.content, original.content);
        assert_eq!(fetched.content_hash, original.content_hash);

        // Delete artifact
        backend.delete_artifacts(&[original.id]).await.expect("async operation failed");
        let after_delete = backend.get_artifact(original.id).await.expect("async operation failed");
        assert!(after_delete.is_none());
    }

    /// Store and fetch one artifact through the in-memory backend.
    #[tokio::test]
    async fn test_inmemory_storage() {
        let backend = InMemoryStorage::new();
        let original = Artifact::new(
            ArtifactType::Config,
            PathBuf::from("config.json"),
            b"{}".to_vec(),
        );

        // Store and retrieve
        let stored_ids = backend
            .store_artifacts(std::slice::from_ref(&original))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids[0], original.id);

        let lookup = backend.get_artifact(original.id).await.expect("async operation failed");
        assert!(lookup.is_some());
        let fetched = lookup.expect("operation failed in test");
        assert_eq!(fetched.content, original.content);
    }

    /// Default extensions and deployment requirements per artifact type.
    #[test]
    fn test_artifact_types() {
        assert_eq!(ArtifactType::Model.default_extension(), "bin");
        assert_eq!(ArtifactType::Config.default_extension(), "json");

        assert!(ArtifactType::Model.is_required_for_deployment());
        assert!(ArtifactType::Config.is_required_for_deployment());
        assert!(!ArtifactType::Documentation.is_required_for_deployment());
    }

    /// The MIME type follows the file extension.
    #[test]
    fn test_mime_type_detection() {
        let from_json = Artifact::new(
            ArtifactType::Config,
            PathBuf::from("config.json"),
            b"{}".to_vec(),
        );
        assert_eq!(from_json.mime_type, "application/json");

        let from_bin = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("model.bin"),
            b"binary data".to_vec(),
        );
        assert_eq!(from_bin.mime_type, "application/octet-stream");
    }
}