ruvector_snapshot/
storage.rs1use async_trait::async_trait;
2use flate2::read::GzDecoder;
3use flate2::write::GzEncoder;
4use flate2::Compression;
5use sha2::{Digest, Sha256};
6use std::io::{Read, Write};
7use std::path::PathBuf;
8use tokio::fs;
9
10use crate::error::{Result, SnapshotError};
11use crate::snapshot::{Snapshot, SnapshotData};
12
13#[async_trait]
15pub trait SnapshotStorage: Send + Sync {
16 async fn save(&self, snapshot: &SnapshotData) -> Result<Snapshot>;
18
19 async fn load(&self, id: &str) -> Result<SnapshotData>;
21
22 async fn list(&self) -> Result<Vec<Snapshot>>;
24
25 async fn delete(&self, id: &str) -> Result<()>;
27}
28
/// Snapshot storage handle rooted at a single directory.
///
/// `Debug` and `Clone` are derived so the handle can be logged and duplicated;
/// cloning copies only the path, not any stored data.
#[derive(Debug, Clone)]
pub struct LocalStorage {
    /// Directory under which every file belonging to this storage is placed.
    base_path: PathBuf,
}
33
34impl LocalStorage {
35 pub fn new(base_path: PathBuf) -> Self {
37 Self { base_path }
38 }
39
40 fn snapshot_path(&self, id: &str) -> PathBuf {
42 self.base_path.join(format!("{}.snapshot.gz", id))
43 }
44
45 fn metadata_path(&self, id: &str) -> PathBuf {
47 self.base_path.join(format!("{}.metadata.json", id))
48 }
49
50 fn compress(data: &[u8]) -> Result<Vec<u8>> {
52 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
53 encoder
54 .write_all(data)
55 .map_err(|e| SnapshotError::compression(format!("Compression failed: {}", e)))?;
56 encoder
57 .finish()
58 .map_err(|e| SnapshotError::compression(format!("Finish compression failed: {}", e)))
59 }
60
61 fn decompress(data: &[u8]) -> Result<Vec<u8>> {
63 let mut decoder = GzDecoder::new(data);
64 let mut decompressed = Vec::new();
65 decoder
66 .read_to_end(&mut decompressed)
67 .map_err(|e| SnapshotError::compression(format!("Decompression failed: {}", e)))?;
68 Ok(decompressed)
69 }
70
71 fn calculate_checksum(data: &[u8]) -> String {
73 let mut hasher = Sha256::new();
74 hasher.update(data);
75 format!("{:x}", hasher.finalize())
76 }
77
78 async fn ensure_dir(&self) -> Result<()> {
80 if !self.base_path.exists() {
81 fs::create_dir_all(&self.base_path).await?;
82 }
83 Ok(())
84 }
85}
86
87#[async_trait]
88impl SnapshotStorage for LocalStorage {
89 async fn save(&self, snapshot_data: &SnapshotData) -> Result<Snapshot> {
90 self.ensure_dir().await?;
91
92 let id = snapshot_data.id().to_string();
93 let snapshot_path = self.snapshot_path(&id);
94 let metadata_path = self.metadata_path(&id);
95
96 let config = bincode::config::standard();
98 let serialized = bincode::encode_to_vec(snapshot_data, config)
99 .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
100
101 let checksum = Self::calculate_checksum(&serialized);
103
104 let compressed = Self::compress(&serialized)?;
106 let size_bytes = compressed.len() as u64;
107
108 fs::write(&snapshot_path, &compressed).await?;
110
111 let created_at = chrono::DateTime::parse_from_rfc3339(&snapshot_data.metadata.created_at)
113 .map_err(|e| SnapshotError::storage(format!("Invalid timestamp: {}", e)))?
114 .with_timezone(&chrono::Utc);
115
116 let snapshot = Snapshot {
117 id: id.clone(),
118 collection_name: snapshot_data.collection_name().to_string(),
119 created_at,
120 vectors_count: snapshot_data.vectors_count(),
121 checksum,
122 size_bytes,
123 };
124
125 let metadata_json = serde_json::to_string_pretty(&snapshot)?;
127 fs::write(&metadata_path, metadata_json).await?;
128
129 Ok(snapshot)
130 }
131
132 async fn load(&self, id: &str) -> Result<SnapshotData> {
133 let snapshot_path = self.snapshot_path(id);
134 let metadata_path = self.metadata_path(id);
135
136 if !snapshot_path.exists() {
138 return Err(SnapshotError::SnapshotNotFound(id.to_string()));
139 }
140
141 let metadata_json = fs::read_to_string(&metadata_path).await?;
143 let snapshot: Snapshot = serde_json::from_str(&metadata_json)?;
144
145 let compressed = fs::read(&snapshot_path).await?;
147
148 let decompressed = Self::decompress(&compressed)?;
150
151 let actual_checksum = Self::calculate_checksum(&decompressed);
153 if actual_checksum != snapshot.checksum {
154 return Err(SnapshotError::InvalidChecksum {
155 expected: snapshot.checksum,
156 actual: actual_checksum,
157 });
158 }
159
160 let config = bincode::config::standard();
162 let (snapshot_data, _): (SnapshotData, usize) =
163 bincode::decode_from_slice(&decompressed, config)
164 .map_err(|e| SnapshotError::SerializationError(e.to_string()))?;
165
166 Ok(snapshot_data)
167 }
168
169 async fn list(&self) -> Result<Vec<Snapshot>> {
170 self.ensure_dir().await?;
171
172 let mut snapshots = Vec::new();
173 let mut entries = fs::read_dir(&self.base_path).await?;
174
175 while let Some(entry) = entries.next_entry().await? {
176 let path = entry.path();
177 if let Some(extension) = path.extension() {
178 if extension == "json" {
179 if let Some(file_name) = path.file_stem() {
180 let file_name_str = file_name.to_string_lossy();
181 if file_name_str.ends_with(".metadata") {
182 let contents = fs::read_to_string(&path).await?;
183 if let Ok(snapshot) = serde_json::from_str::<Snapshot>(&contents) {
184 snapshots.push(snapshot);
185 }
186 }
187 }
188 }
189 }
190 }
191
192 snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
194
195 Ok(snapshots)
196 }
197
198 async fn delete(&self, id: &str) -> Result<()> {
199 let snapshot_path = self.snapshot_path(id);
200 let metadata_path = self.metadata_path(id);
201
202 if !snapshot_path.exists() {
203 return Err(SnapshotError::SnapshotNotFound(id.to_string()));
204 }
205
206 fs::remove_file(&snapshot_path).await?;
208
209 if metadata_path.exists() {
210 fs::remove_file(&metadata_path).await?;
211 }
212
213 Ok(())
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::snapshot::{CollectionConfig, DistanceMetric, VectorRecord};

    /// Removes the wrapped directory on drop so a failing assertion cannot
    /// leave test files behind.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Builds a temp-directory path that differs per process and per call, so
    /// concurrent or repeated test runs do not share state.
    fn unique_temp_dir(label: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or_default();
        std::env::temp_dir().join(format!(
            "ruvector-snapshot-{}-{}-{}",
            label,
            std::process::id(),
            nanos
        ))
    }

    #[test]
    fn test_compression_roundtrip() {
        let data = b"Hello, World! This is test data for compression.";
        let compressed = LocalStorage::compress(data).unwrap();
        let decompressed = LocalStorage::decompress(&compressed).unwrap();
        assert_eq!(data.to_vec(), decompressed);

        // Empty input must survive the round trip as well.
        let empty = LocalStorage::compress(b"").unwrap();
        assert!(LocalStorage::decompress(&empty).unwrap().is_empty());
    }

    #[test]
    fn test_checksum_calculation() {
        let data = b"test data";
        let checksum = LocalStorage::calculate_checksum(data);
        // SHA-256 is 32 bytes, i.e. 64 hex characters.
        assert_eq!(checksum.len(), 64);

        // Published SHA-256 test vector for "abc".
        assert_eq!(
            LocalStorage::calculate_checksum(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[tokio::test]
    async fn test_local_storage_roundtrip() {
        let temp_dir = unique_temp_dir("roundtrip");
        let _cleanup = TempDirGuard(temp_dir.clone());
        let storage = LocalStorage::new(temp_dir);

        let config = CollectionConfig {
            dimension: 3,
            metric: DistanceMetric::Cosine,
            hnsw_config: None,
        };

        let vectors = vec![
            VectorRecord::new("v1".to_string(), vec![1.0, 0.0, 0.0], None),
            VectorRecord::new("v2".to_string(), vec![0.0, 1.0, 0.0], None),
        ];

        let snapshot_data = SnapshotData::new("test-collection".to_string(), config, vectors);
        let id = snapshot_data.id().to_string();

        // Save and check the returned descriptor.
        let snapshot = storage.save(&snapshot_data).await.unwrap();
        assert_eq!(snapshot.id, id);
        assert_eq!(snapshot.vectors_count, 2);

        // The saved snapshot must be listed.
        let snapshots = storage.list().await.unwrap();
        assert!(snapshots.iter().any(|s| s.id == id));

        // Loading returns the same snapshot.
        let loaded = storage.load(&id).await.unwrap();
        assert_eq!(loaded.id(), id);
        assert_eq!(loaded.vectors_count(), 2);

        // After deletion the snapshot is neither loadable nor listed.
        storage.delete(&id).await.unwrap();
        assert!(matches!(
            storage.load(&id).await,
            Err(SnapshotError::SnapshotNotFound(_))
        ));
        let remaining = storage.list().await.unwrap();
        assert!(remaining.iter().all(|s| s.id != id));
    }
}