ruvector_snapshot/
storage.rs

1use async_trait::async_trait;
2use flate2::read::GzDecoder;
3use flate2::write::GzEncoder;
4use flate2::Compression;
5use sha2::{Digest, Sha256};
6use std::io::{Read, Write};
7use std::path::PathBuf;
8use tokio::fs;
9
10use crate::error::{Result, SnapshotError};
11use crate::snapshot::{Snapshot, SnapshotData};
12
13/// Trait for snapshot storage backends
14#[async_trait]
15pub trait SnapshotStorage: Send + Sync {
16    /// Save a snapshot to storage
17    async fn save(&self, snapshot: &SnapshotData) -> Result<Snapshot>;
18
19    /// Load a snapshot from storage
20    async fn load(&self, id: &str) -> Result<SnapshotData>;
21
22    /// List all available snapshots
23    async fn list(&self) -> Result<Vec<Snapshot>>;
24
25    /// Delete a snapshot from storage
26    async fn delete(&self, id: &str) -> Result<()>;
27}
28
/// Snapshot storage backed by a directory on the local filesystem.
///
/// Each snapshot `id` is kept as two sibling files under `base_path`:
/// `<id>.snapshot.gz` (the gzip-compressed payload) and
/// `<id>.metadata.json` (the pretty-printed [`Snapshot`] metadata).
#[derive(Debug, Clone)]
pub struct LocalStorage {
    /// Directory that holds every snapshot payload and metadata file.
    base_path: PathBuf,
}
33
34impl LocalStorage {
35    /// Create a new local storage instance
36    pub fn new(base_path: PathBuf) -> Self {
37        Self { base_path }
38    }
39
40    /// Get the path for a snapshot file
41    fn snapshot_path(&self, id: &str) -> PathBuf {
42        self.base_path.join(format!("{}.snapshot.gz", id))
43    }
44
45    /// Get the path for a snapshot metadata file
46    fn metadata_path(&self, id: &str) -> PathBuf {
47        self.base_path.join(format!("{}.metadata.json", id))
48    }
49
50    /// Compress data using gzip
51    fn compress(data: &[u8]) -> Result<Vec<u8>> {
52        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
53        encoder
54            .write_all(data)
55            .map_err(|e| SnapshotError::compression(format!("Compression failed: {}", e)))?;
56        encoder
57            .finish()
58            .map_err(|e| SnapshotError::compression(format!("Finish compression failed: {}", e)))
59    }
60
61    /// Decompress gzip data
62    fn decompress(data: &[u8]) -> Result<Vec<u8>> {
63        let mut decoder = GzDecoder::new(data);
64        let mut decompressed = Vec::new();
65        decoder
66            .read_to_end(&mut decompressed)
67            .map_err(|e| SnapshotError::compression(format!("Decompression failed: {}", e)))?;
68        Ok(decompressed)
69    }
70
71    /// Calculate SHA-256 checksum
72    fn calculate_checksum(data: &[u8]) -> String {
73        let mut hasher = Sha256::new();
74        hasher.update(data);
75        format!("{:x}", hasher.finalize())
76    }
77
78    /// Ensure the base directory exists
79    async fn ensure_dir(&self) -> Result<()> {
80        if !self.base_path.exists() {
81            fs::create_dir_all(&self.base_path).await?;
82        }
83        Ok(())
84    }
85}
86
87#[async_trait]
88impl SnapshotStorage for LocalStorage {
89    async fn save(&self, snapshot_data: &SnapshotData) -> Result<Snapshot> {
90        self.ensure_dir().await?;
91
92        let id = snapshot_data.id().to_string();
93        let snapshot_path = self.snapshot_path(&id);
94        let metadata_path = self.metadata_path(&id);
95
96        // Serialize snapshot data
97        let config = bincode::config::standard();
98        let serialized = bincode::encode_to_vec(snapshot_data, config)
99            .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
100
101        // Calculate checksum before compression
102        let checksum = Self::calculate_checksum(&serialized);
103
104        // Compress data
105        let compressed = Self::compress(&serialized)?;
106        let size_bytes = compressed.len() as u64;
107
108        // Write compressed data
109        fs::write(&snapshot_path, &compressed).await?;
110
111        // Create snapshot metadata
112        let created_at = chrono::DateTime::parse_from_rfc3339(&snapshot_data.metadata.created_at)
113            .map_err(|e| SnapshotError::storage(format!("Invalid timestamp: {}", e)))?
114            .with_timezone(&chrono::Utc);
115
116        let snapshot = Snapshot {
117            id: id.clone(),
118            collection_name: snapshot_data.collection_name().to_string(),
119            created_at,
120            vectors_count: snapshot_data.vectors_count(),
121            checksum,
122            size_bytes,
123        };
124
125        // Write metadata
126        let metadata_json = serde_json::to_string_pretty(&snapshot)?;
127        fs::write(&metadata_path, metadata_json).await?;
128
129        Ok(snapshot)
130    }
131
132    async fn load(&self, id: &str) -> Result<SnapshotData> {
133        let snapshot_path = self.snapshot_path(id);
134        let metadata_path = self.metadata_path(id);
135
136        // Check if files exist
137        if !snapshot_path.exists() {
138            return Err(SnapshotError::SnapshotNotFound(id.to_string()));
139        }
140
141        // Load and verify metadata
142        let metadata_json = fs::read_to_string(&metadata_path).await?;
143        let snapshot: Snapshot = serde_json::from_str(&metadata_json)?;
144
145        // Load compressed data
146        let compressed = fs::read(&snapshot_path).await?;
147
148        // Decompress
149        let decompressed = Self::decompress(&compressed)?;
150
151        // Verify checksum
152        let actual_checksum = Self::calculate_checksum(&decompressed);
153        if actual_checksum != snapshot.checksum {
154            return Err(SnapshotError::InvalidChecksum {
155                expected: snapshot.checksum,
156                actual: actual_checksum,
157            });
158        }
159
160        // Deserialize
161        let config = bincode::config::standard();
162        let (snapshot_data, _): (SnapshotData, usize) =
163            bincode::decode_from_slice(&decompressed, config)
164                .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
165
166        Ok(snapshot_data)
167    }
168
169    async fn list(&self) -> Result<Vec<Snapshot>> {
170        self.ensure_dir().await?;
171
172        let mut snapshots = Vec::new();
173        let mut entries = fs::read_dir(&self.base_path).await?;
174
175        while let Some(entry) = entries.next_entry().await? {
176            let path = entry.path();
177            if let Some(extension) = path.extension() {
178                if extension == "json" {
179                    if let Some(file_name) = path.file_stem() {
180                        let file_name_str = file_name.to_string_lossy();
181                        if file_name_str.ends_with(".metadata") {
182                            let contents = fs::read_to_string(&path).await?;
183                            if let Ok(snapshot) = serde_json::from_str::<Snapshot>(&contents) {
184                                snapshots.push(snapshot);
185                            }
186                        }
187                    }
188                }
189            }
190        }
191
192        // Sort by creation date (newest first)
193        snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
194
195        Ok(snapshots)
196    }
197
198    async fn delete(&self, id: &str) -> Result<()> {
199        let snapshot_path = self.snapshot_path(id);
200        let metadata_path = self.metadata_path(id);
201
202        if !snapshot_path.exists() {
203            return Err(SnapshotError::SnapshotNotFound(id.to_string()));
204        }
205
206        // Delete both files
207        fs::remove_file(&snapshot_path).await?;
208
209        if metadata_path.exists() {
210            fs::remove_file(&metadata_path).await?;
211        }
212
213        Ok(())
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::{CollectionConfig, DistanceMetric, VectorRecord};

    /// Scratch directory that is removed on drop, so cleanup also happens
    /// when an assertion panics partway through a test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Build a directory path unique to `label` and this process, so
        /// concurrent test runs never share (and clobber) the same directory.
        fn new(label: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "ruvector-snapshot-{}-{}",
                label,
                std::process::id()
            ));
            // Start clean in case an earlier run with the same pid crashed.
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// A small two-vector snapshot used by the storage tests.
    fn sample_snapshot() -> SnapshotData {
        let config = CollectionConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            hnsw_config: None,
        };

        let vectors = vec![
            VectorRecord::new("v1".to_string(), vec![1.0, 0.0, 0.0], None),
            VectorRecord::new("v2".to_string(), vec![0.0, 1.0, 0.0], None),
        ];

        SnapshotData::new("test-collection".to_string(), config, vectors)
    }

    #[test]
    fn test_compression_roundtrip() {
        let data = b"Hello, World! This is test data for compression.";
        let compressed = LocalStorage::compress(data).unwrap();
        let decompressed = LocalStorage::decompress(&compressed).unwrap();
        assert_eq!(data.to_vec(), decompressed);

        // Empty input must round-trip too.
        let empty = LocalStorage::compress(b"").unwrap();
        assert!(LocalStorage::decompress(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_checksum_calculation() {
        let checksum = LocalStorage::calculate_checksum(b"test data");
        assert_eq!(checksum.len(), 64); // SHA-256 produces 64 hex characters

        // Known-answer vectors from FIPS 180-2.
        assert_eq!(
            LocalStorage::calculate_checksum(b""),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert_eq!(
            LocalStorage::calculate_checksum(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[tokio::test]
    async fn test_local_storage_roundtrip() {
        let scratch = ScratchDir::new("roundtrip");
        let storage = LocalStorage::new(scratch.0.clone());

        let snapshot_data = sample_snapshot();
        let id = snapshot_data.id().to_string();

        // Save
        let snapshot = storage.save(&snapshot_data).await.unwrap();
        assert_eq!(snapshot.id, id);
        assert_eq!(snapshot.vectors_count, 2);

        // List
        let snapshots = storage.list().await.unwrap();
        assert!(snapshots.iter().any(|s| s.id == id));

        // Load
        let loaded = storage.load(&id).await.unwrap();
        assert_eq!(loaded.id(), id);
        assert_eq!(loaded.vectors_count(), 2);

        // Delete, then confirm the snapshot is really gone.
        storage.delete(&id).await.unwrap();
        assert!(matches!(
            storage.load(&id).await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
        assert!(matches!(
            storage.delete(&id).await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
        assert!(storage.list().await.unwrap().iter().all(|s| s.id != id));
    }

    #[tokio::test]
    async fn test_load_detects_tampered_payload() {
        let scratch = ScratchDir::new("tampered");
        let storage = LocalStorage::new(scratch.0.clone());

        let snapshot_data = sample_snapshot();
        let id = snapshot_data.id().to_string();
        storage.save(&snapshot_data).await.unwrap();

        // Replace the payload with valid gzip of different bytes, so
        // decompression succeeds but the checksum no longer matches.
        let forged = LocalStorage::compress(b"tampered").unwrap();
        std::fs::write(storage.snapshot_path(&id), forged).unwrap();

        assert!(matches!(
            storage.load(&id).await,
            Err(SnapshotError::InvalidChecksum { .. })
        ));
    }
}