// trustformers_core/versioning/storage.rs
use std::collections::HashMap;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};

use anyhow::Result;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use tokio::fs;
use uuid::Uuid;

12#[async_trait]
14pub trait ModelStorage: Send + Sync {
15 async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>>;
17
18 async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>>;
20
21 async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()>;
23
24 async fn archive_version(&self, version_id: Uuid) -> Result<()>;
26
27 async fn delete_version(&self, version_id: Uuid) -> Result<()>;
29
30 async fn list_artifacts(&self, version_id: Uuid) -> Result<Vec<Artifact>>;
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub enum ArtifactType {
37 Model,
39 Config,
41 Tokenizer,
43 Vocabulary,
45 Checkpoint,
47 OptimizerState,
49 Architecture,
51 Preprocessing,
53 Metrics,
55 Documentation,
57 Custom(String),
59}
60
61impl ArtifactType {
62 pub fn default_extension(&self) -> &'static str {
64 match self {
65 ArtifactType::Model => "bin",
66 ArtifactType::Config => "json",
67 ArtifactType::Tokenizer => "json",
68 ArtifactType::Vocabulary => "txt",
69 ArtifactType::Checkpoint => "ckpt",
70 ArtifactType::OptimizerState => "bin",
71 ArtifactType::Architecture => "json",
72 ArtifactType::Preprocessing => "json",
73 ArtifactType::Metrics => "json",
74 ArtifactType::Documentation => "md",
75 ArtifactType::Custom(_) => "bin",
76 }
77 }
78
79 pub fn is_required_for_deployment(&self) -> bool {
81 matches!(self, ArtifactType::Model | ArtifactType::Config)
82 }
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct Artifact {
88 pub id: Uuid,
90 pub artifact_type: ArtifactType,
92 pub file_path: PathBuf,
94 pub size_bytes: u64,
96 pub content_hash: String,
98 pub mime_type: String,
100 pub content: Vec<u8>,
102 pub created_at: DateTime<Utc>,
104 pub metadata: HashMap<String, serde_json::Value>,
106}
107
108impl Artifact {
109 pub fn new(artifact_type: ArtifactType, file_path: PathBuf, content: Vec<u8>) -> Self {
111 let content_hash = Self::compute_hash(&content);
112 let mime_type = Self::detect_mime_type(&file_path, &artifact_type);
113
114 Self {
115 id: Uuid::new_v4(),
116 artifact_type,
117 size_bytes: content.len() as u64,
118 content_hash,
119 mime_type,
120 content,
121 file_path,
122 created_at: Utc::now(),
123 metadata: HashMap::new(),
124 }
125 }
126
127 pub async fn from_file(artifact_type: ArtifactType, file_path: PathBuf) -> Result<Self> {
129 let content = fs::read(&file_path).await?;
130 Ok(Self::new(artifact_type, file_path, content))
131 }
132
133 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
135 self.metadata.insert(key, value);
136 self
137 }
138
139 fn compute_hash(content: &[u8]) -> String {
141 use sha2::{Digest, Sha256};
142 let mut hasher = Sha256::new();
143 hasher.update(content);
144 format!("{:x}", hasher.finalize())
145 }
146
147 fn detect_mime_type(file_path: &Path, artifact_type: &ArtifactType) -> String {
149 if let Some(extension) = file_path.extension().and_then(|s| s.to_str()) {
151 match extension.to_lowercase().as_str() {
152 "json" => "application/json".to_string(),
153 "bin" | "pt" | "pth" => "application/octet-stream".to_string(),
154 "txt" => "text/plain".to_string(),
155 "md" => "text/markdown".to_string(),
156 "yaml" | "yml" => "application/x-yaml".to_string(),
157 _ => "application/octet-stream".to_string(),
158 }
159 } else {
160 match artifact_type {
161 ArtifactType::Config
162 | ArtifactType::Tokenizer
163 | ArtifactType::Architecture
164 | ArtifactType::Preprocessing
165 | ArtifactType::Metrics => "application/json".to_string(),
166 ArtifactType::Documentation => "text/markdown".to_string(),
167 ArtifactType::Vocabulary => "text/plain".to_string(),
168 _ => "application/octet-stream".to_string(),
169 }
170 }
171 }
172
173 pub fn verify_integrity(&self) -> bool {
175 Self::compute_hash(&self.content) == self.content_hash
176 }
177
178 pub fn file_extension(&self) -> Option<&str> {
180 self.file_path.extension()?.to_str()
181 }
182}
183
184pub struct FileSystemStorage {
186 base_path: PathBuf,
187 archive_path: PathBuf,
188 metadata_cache: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
189}
190
191impl FileSystemStorage {
192 pub fn new(base_path: PathBuf) -> Self {
194 let archive_path = base_path.join("archive");
195 Self {
196 base_path,
197 archive_path,
198 metadata_cache: tokio::sync::RwLock::new(HashMap::new()),
199 }
200 }
201
202 pub async fn initialize(&self) -> Result<()> {
204 fs::create_dir_all(&self.base_path).await?;
205 fs::create_dir_all(&self.archive_path).await?;
206 Ok(())
207 }
208
209 fn get_artifact_path(&self, artifact_id: Uuid) -> PathBuf {
211 let id_str = artifact_id.to_string();
212 let prefix = &id_str[0..2];
213 self.base_path.join("artifacts").join(prefix).join(&id_str)
214 }
215
216 fn get_archive_path(&self, artifact_id: Uuid) -> PathBuf {
218 let id_str = artifact_id.to_string();
219 let prefix = &id_str[0..2];
220 self.archive_path.join("artifacts").join(prefix).join(&id_str)
221 }
222
223 async fn store_metadata(&self, artifact: &Artifact) -> Result<()> {
225 let metadata_path = self.get_artifact_path(artifact.id).with_extension("meta");
226 if let Some(parent) = metadata_path.parent() {
227 fs::create_dir_all(parent).await?;
228 }
229
230 let metadata_json = serde_json::to_string_pretty(artifact)?;
231 fs::write(metadata_path, metadata_json).await?;
232
233 self.metadata_cache.write().await.insert(artifact.id, artifact.clone());
235 Ok(())
236 }
237
238 async fn load_metadata(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
240 if let Some(artifact) = self.metadata_cache.read().await.get(&artifact_id) {
242 return Ok(Some(artifact.clone()));
243 }
244
245 let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
247 if !metadata_path.exists() {
248 return Ok(None);
249 }
250
251 let metadata_json = fs::read_to_string(metadata_path).await?;
252 let mut artifact: Artifact = serde_json::from_str(&metadata_json)?;
253
254 let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
256 if content_path.exists() {
257 artifact.content = fs::read(content_path).await?;
258 }
259
260 self.metadata_cache.write().await.insert(artifact_id, artifact.clone());
262 Ok(Some(artifact))
263 }
264}
265
266#[async_trait]
267impl ModelStorage for FileSystemStorage {
268 async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
269 let mut artifact_ids = Vec::new();
270
271 for artifact in artifacts {
272 let content_path = self.get_artifact_path(artifact.id).with_extension("bin");
274 if let Some(parent) = content_path.parent() {
275 fs::create_dir_all(parent).await?;
276 }
277 fs::write(&content_path, &artifact.content).await?;
278
279 self.store_metadata(artifact).await?;
281
282 artifact_ids.push(artifact.id);
283 tracing::debug!("Stored artifact {} at {:?}", artifact.id, content_path);
284 }
285
286 Ok(artifact_ids)
287 }
288
289 async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
290 self.load_metadata(artifact_id).await
291 }
292
293 async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
294 for &artifact_id in artifact_ids {
295 let content_path = self.get_artifact_path(artifact_id).with_extension("bin");
296 let metadata_path = self.get_artifact_path(artifact_id).with_extension("meta");
297
298 if content_path.exists() {
299 fs::remove_file(content_path).await?;
300 }
301 if metadata_path.exists() {
302 fs::remove_file(metadata_path).await?;
303 }
304
305 self.metadata_cache.write().await.remove(&artifact_id);
307 tracing::debug!("Deleted artifact {}", artifact_id);
308 }
309 Ok(())
310 }
311
312 async fn archive_version(&self, version_id: Uuid) -> Result<()> {
313 let artifacts = self.list_artifacts(version_id).await?;
315
316 for artifact in artifacts {
317 let src_content = self.get_artifact_path(artifact.id).with_extension("bin");
318 let src_metadata = self.get_artifact_path(artifact.id).with_extension("meta");
319
320 let dst_content = self.get_archive_path(artifact.id).with_extension("bin");
321 let dst_metadata = self.get_archive_path(artifact.id).with_extension("meta");
322
323 if let Some(parent) = dst_content.parent() {
324 fs::create_dir_all(parent).await?;
325 }
326
327 if src_content.exists() {
328 fs::rename(src_content, dst_content).await?;
329 }
330 if src_metadata.exists() {
331 fs::rename(src_metadata, dst_metadata).await?;
332 }
333
334 self.metadata_cache.write().await.remove(&artifact.id);
336 }
337
338 tracing::info!("Archived version {}", version_id);
339 Ok(())
340 }
341
342 async fn delete_version(&self, version_id: Uuid) -> Result<()> {
343 let artifacts = self.list_artifacts(version_id).await?;
344 let artifact_ids: Vec<Uuid> = artifacts.iter().map(|a| a.id).collect();
345 self.delete_artifacts(&artifact_ids).await?;
346
347 tracing::info!("Deleted version {}", version_id);
348 Ok(())
349 }
350
351 async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
352 let cache = self.metadata_cache.read().await;
355 Ok(cache.values().cloned().collect())
356 }
357}
358
359pub struct InMemoryStorage {
361 artifacts: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
362 archived: tokio::sync::RwLock<HashMap<Uuid, Artifact>>,
363}
364
365impl InMemoryStorage {
366 pub fn new() -> Self {
367 Self {
368 artifacts: tokio::sync::RwLock::new(HashMap::new()),
369 archived: tokio::sync::RwLock::new(HashMap::new()),
370 }
371 }
372}
373
374#[async_trait]
375impl ModelStorage for InMemoryStorage {
376 async fn store_artifacts(&self, artifacts: &[Artifact]) -> Result<Vec<Uuid>> {
377 let mut artifact_ids = Vec::new();
378 let mut storage = self.artifacts.write().await;
379
380 for artifact in artifacts {
381 storage.insert(artifact.id, artifact.clone());
382 artifact_ids.push(artifact.id);
383 }
384
385 Ok(artifact_ids)
386 }
387
388 async fn get_artifact(&self, artifact_id: Uuid) -> Result<Option<Artifact>> {
389 let storage = self.artifacts.read().await;
390 Ok(storage.get(&artifact_id).cloned())
391 }
392
393 async fn delete_artifacts(&self, artifact_ids: &[Uuid]) -> Result<()> {
394 let mut storage = self.artifacts.write().await;
395 for &artifact_id in artifact_ids {
396 storage.remove(&artifact_id);
397 }
398 Ok(())
399 }
400
401 async fn archive_version(&self, version_id: Uuid) -> Result<()> {
402 let artifacts = self.list_artifacts(version_id).await?;
403
404 let mut storage = self.artifacts.write().await;
405 let mut archived = self.archived.write().await;
406
407 for artifact in artifacts {
408 if let Some(artifact) = storage.remove(&artifact.id) {
409 archived.insert(artifact.id, artifact);
410 }
411 }
412
413 Ok(())
414 }
415
416 async fn delete_version(&self, version_id: Uuid) -> Result<()> {
417 let artifacts = self.list_artifacts(version_id).await?;
418 let artifact_ids: Vec<Uuid> = artifacts.iter().map(|a| a.id).collect();
419 self.delete_artifacts(&artifact_ids).await
420 }
421
422 async fn list_artifacts(&self, _version_id: Uuid) -> Result<Vec<Artifact>> {
423 let storage = self.artifacts.read().await;
424 Ok(storage.values().cloned().collect())
425 }
426}
427
428impl Default for InMemoryStorage {
429 fn default() -> Self {
430 Self::new()
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly built artifact records its kind, bytes, size and a valid hash.
    #[tokio::test]
    async fn test_artifact_creation() {
        let payload = b"test model data".to_vec();
        let artifact = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("model.bin"),
            payload.clone(),
        );

        assert_eq!(artifact.artifact_type, ArtifactType::Model);
        assert_eq!(artifact.size_bytes, payload.len() as u64);
        assert_eq!(artifact.content, payload);
        assert!(!artifact.content_hash.is_empty());
        assert!(artifact.verify_integrity());
    }

    /// Store / fetch / delete round trip through the filesystem backend.
    #[tokio::test]
    async fn test_filesystem_storage() {
        let temp_dir = TempDir::new().expect("temp file creation failed");
        let storage = FileSystemStorage::new(temp_dir.path().to_path_buf());
        storage.initialize().await.expect("async operation failed");

        let artifact = Artifact::new(
            ArtifactType::Model,
            PathBuf::from("test_model.bin"),
            b"test content".to_vec(),
        );
        let id = artifact.id;

        let stored_ids = storage
            .store_artifacts(std::slice::from_ref(&artifact))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids, [id]);

        let fetched = storage
            .get_artifact(id)
            .await
            .expect("async operation failed")
            .expect("operation failed in test");
        assert_eq!(fetched.content_hash, artifact.content_hash);
        assert_eq!(fetched.content, artifact.content);

        storage.delete_artifacts(&[id]).await.expect("async operation failed");
        let after_delete = storage.get_artifact(id).await.expect("async operation failed");
        assert!(after_delete.is_none());
    }

    /// Store / fetch round trip through the in-memory backend.
    #[tokio::test]
    async fn test_inmemory_storage() {
        let storage = InMemoryStorage::new();
        let artifact = Artifact::new(
            ArtifactType::Config,
            PathBuf::from("config.json"),
            b"{}".to_vec(),
        );

        let stored_ids = storage
            .store_artifacts(std::slice::from_ref(&artifact))
            .await
            .expect("async operation failed");
        assert_eq!(stored_ids[0], artifact.id);

        let fetched = storage
            .get_artifact(artifact.id)
            .await
            .expect("async operation failed");
        assert!(fetched.is_some());
        let fetched = fetched.expect("operation failed in test");
        assert_eq!(fetched.content, artifact.content);
    }

    /// Extension table and deployment flags for a few representative kinds.
    #[test]
    fn test_artifact_types() {
        assert_eq!(ArtifactType::Model.default_extension(), "bin");
        assert_eq!(ArtifactType::Config.default_extension(), "json");
        for required in [ArtifactType::Model, ArtifactType::Config] {
            assert!(required.is_required_for_deployment());
        }
        assert!(!ArtifactType::Documentation.is_required_for_deployment());
    }

    /// MIME types are derived from the file extension.
    #[test]
    fn test_mime_type_detection() {
        let cases = [
            (
                ArtifactType::Config,
                "config.json",
                b"{}".to_vec(),
                "application/json",
            ),
            (
                ArtifactType::Model,
                "model.bin",
                b"binary data".to_vec(),
                "application/octet-stream",
            ),
        ];

        for (kind, path, bytes, expected) in cases {
            let artifact = Artifact::new(kind, PathBuf::from(path), bytes);
            assert_eq!(artifact.mime_type, expected);
        }
    }
}