1use vectx_core::{Collection, CollectionConfig, Distance, Error, Result, Point, PointId, Vector, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7use crate::lmdb_storage::LmdbStorage;
8use crate::wal::WriteAheadLog;
9use crate::snapshot::{SnapshotManager, SnapshotDescription, CollectionSnapshotData, CollectionConfigData, PointData};
10use crate::persistence::ForkBasedPersistence;
11
12pub struct StorageManager {
14 collections: Arc<RwLock<HashMap<String, Arc<Collection>>>>,
15 aliases: Arc<RwLock<HashMap<String, String>>>,
17 data_dir: PathBuf,
18 #[allow(dead_code)]
19 lmdb: Option<Arc<LmdbStorage>>,
20 #[allow(dead_code)]
21 wal: Option<Arc<WriteAheadLog>>,
22 snapshots: Arc<SnapshotManager>,
23 persistence: Arc<ForkBasedPersistence>,
24 #[allow(dead_code)]
25 save_interval: Option<Duration>,
26}
27
28impl StorageManager {
29 pub fn new<P: AsRef<Path>>(data_dir: P) -> Result<Self> {
30 let data_dir = data_dir.as_ref().to_path_buf();
31 std::fs::create_dir_all(&data_dir)?;
32
33 let lmdb_path = data_dir.join("lmdb");
34 let lmdb = Arc::new(LmdbStorage::new(&lmdb_path)
35 .map_err(|e| Error::Storage(e.to_string()))?);
36
37 let wal_path = data_dir.join("wal.log");
38 let wal = Arc::new(WriteAheadLog::new(&wal_path)
39 .map_err(|e| Error::Storage(e.to_string()))?);
40
41 let snapshot_dir = data_dir.join("snapshots");
42 let snapshots = Arc::new(SnapshotManager::new(&snapshot_dir)
43 .map_err(|e| Error::Storage(e.to_string()))?);
44
45 let persistence = Arc::new(ForkBasedPersistence::new(&data_dir));
46
47 let collections = Arc::new(RwLock::new(HashMap::new()));
48 let aliases = Arc::new(RwLock::new(HashMap::new()));
49
50 if let Some(snapshot) = persistence.load_snapshot()
51 .map_err(|e| Error::Persistence(e.to_string()))? {
52 eprintln!("Loading snapshot from disk...");
53 let mut collections_map = HashMap::new();
54
55 for col_snapshot in snapshot.collections {
56 let config = CollectionConfig {
57 name: col_snapshot.name.clone(),
58 vector_dim: col_snapshot.config.vector_dim,
59 distance: match col_snapshot.config.distance.as_str() {
60 "Cosine" => Distance::Cosine,
61 "Euclidean" => Distance::Euclidean,
62 "Dot" => Distance::Dot,
63 _ => Distance::Cosine,
64 },
65 use_hnsw: col_snapshot.config.use_hnsw,
66 enable_bm25: col_snapshot.config.enable_bm25,
67 };
68
69 let collection = Arc::new(Collection::new(config));
70
71 for point_snapshot in col_snapshot.points {
72 let point = Point::new(
73 PointId::String(point_snapshot.id.clone()),
74 Vector::new(point_snapshot.vector),
75 point_snapshot.payload,
76 );
77 if let Err(e) = collection.upsert(point) {
78 eprintln!("Warning: Failed to restore point {}: {}", point_snapshot.id, e);
79 }
80 }
81
82 collections_map.insert(col_snapshot.name, collection);
83 }
84
85 *collections.write() = collections_map;
86 eprintln!("Snapshot loaded: {} collections", collections.read().len());
87 }
88
89 let manager = Self {
90 collections,
91 aliases,
92 data_dir,
93 lmdb: Some(lmdb),
94 wal: Some(wal),
95 snapshots,
96 persistence,
97 save_interval: Some(Duration::from_secs(300)),
98 };
99
100 manager.start_background_save();
101
102 Ok(manager)
103 }
104
105 fn start_background_save(&self) {
107 let collections = self.collections.clone();
108 let persistence = self.persistence.clone();
109 let interval = self.save_interval.unwrap_or(Duration::from_secs(300));
110
111 std::thread::spawn(move || {
112 loop {
113 std::thread::sleep(interval);
114
115 if !ForkBasedPersistence::is_bgsave_in_progress() {
116 let collections_map = collections.read();
117 if let Err(e) = persistence.bgsave(&collections_map) {
118 eprintln!("Background save error: {}", e);
119 }
120 }
121 }
122 });
123 }
124
125 pub fn create_collection(&self, config: CollectionConfig) -> Result<Arc<Collection>> {
126 let name = config.name.clone();
127 let mut collections = self.collections.write();
128
129 if collections.contains_key(&name) {
130 return Err(Error::CollectionExists(name));
131 }
132
133 let collection = Arc::new(Collection::new(config));
134 collections.insert(name.clone(), collection.clone());
135 Ok(collection)
136 }
137
138 #[inline]
139 pub fn get_collection(&self, name: &str) -> Option<Arc<Collection>> {
140 let collections = self.collections.read();
141 if let Some(col) = collections.get(name) {
143 return Some(col.clone());
144 }
145 let aliases = self.aliases.read();
147 if let Some(collection_name) = aliases.get(name) {
148 return collections.get(collection_name).cloned();
149 }
150 None
151 }
152
153 pub fn delete_collection(&self, name: &str) -> Result<bool> {
154 let mut collections = self.collections.write();
155 let removed = collections.remove(name).is_some();
156
157 if removed {
159 }
160
161 Ok(removed)
162 }
163
164 #[inline]
165 #[must_use]
166 pub fn list_collections(&self) -> Vec<String> {
167 self.collections.read().keys().cloned().collect()
168 }
169
170 #[inline]
171 #[must_use]
172 pub fn collection_exists(&self, name: &str) -> bool {
173 self.collections.read().contains_key(name)
174 }
175
176 pub fn create_alias(&self, alias_name: &str, collection_name: &str) -> Result<bool> {
178 if !self.collection_exists(collection_name) {
180 return Err(Error::CollectionNotFound(collection_name.to_string()));
181 }
182 let mut aliases = self.aliases.write();
183 aliases.insert(alias_name.to_string(), collection_name.to_string());
184 Ok(true)
185 }
186
187 pub fn delete_alias(&self, alias_name: &str) -> Result<bool> {
189 let mut aliases = self.aliases.write();
190 Ok(aliases.remove(alias_name).is_some())
191 }
192
193 pub fn rename_alias(&self, old_alias: &str, new_alias: &str) -> Result<bool> {
195 let mut aliases = self.aliases.write();
196 if let Some(collection_name) = aliases.remove(old_alias) {
197 aliases.insert(new_alias.to_string(), collection_name);
198 Ok(true)
199 } else {
200 Ok(false)
201 }
202 }
203
204 pub fn list_aliases(&self) -> Vec<(String, String)> {
206 self.aliases.read()
207 .iter()
208 .map(|(alias, collection)| (alias.clone(), collection.clone()))
209 .collect()
210 }
211
212 pub fn list_collection_aliases(&self, collection_name: &str) -> Vec<String> {
214 self.aliases.read()
215 .iter()
216 .filter(|(_, col)| *col == collection_name)
217 .map(|(alias, _)| alias.clone())
218 .collect()
219 }
220
221 #[inline]
222 #[must_use]
223 pub fn data_dir(&self) -> &Path {
224 &self.data_dir
225 }
226
227
228 pub fn bgsave(&self) -> Result<bool> {
230 let collections = self.collections.read();
231 self.persistence.bgsave(&collections)
232 .map_err(|e| Error::Storage(e.to_string()))
233 }
234
235 pub fn save(&self) -> Result<()> {
237 let collections = self.collections.read();
238 self.persistence.save(&collections)
239 .map_err(|e| Error::Storage(e.to_string()))
240 }
241
242 pub fn last_save_time(&self) -> u64 {
244 ForkBasedPersistence::last_save_time()
245 }
246
247 pub fn is_bgsave_in_progress(&self) -> bool {
249 ForkBasedPersistence::is_bgsave_in_progress()
250 }
251
252 pub fn create_collection_snapshot(&self, collection_name: &str) -> Result<SnapshotDescription> {
256 let collections = self.collections.read();
257 let collection = collections.get(collection_name)
258 .ok_or_else(|| Error::CollectionNotFound(collection_name.to_string()))?;
259
260 let points = collection.get_all_points();
261
262 let snapshot_data = CollectionSnapshotData {
263 name: collection_name.to_string(),
264 config: CollectionConfigData {
265 vector_dim: collection.vector_dim(),
266 distance: match collection.distance() {
267 Distance::Cosine => "Cosine".to_string(),
268 Distance::Euclidean => "Euclidean".to_string(),
269 Distance::Dot => "Dot".to_string(),
270 },
271 use_hnsw: collection.use_hnsw(),
272 enable_bm25: collection.enable_bm25(),
273 },
274 points: points.iter().map(|p| PointData {
275 id: match &p.id {
276 PointId::Integer(i) => i.to_string(),
277 PointId::String(s) => s.clone(),
278 PointId::Uuid(u) => u.to_string(),
279 },
280 vector: p.vector.as_slice().to_vec(),
281 multivector: p.multivector.as_ref().map(|mv: &MultiVector| mv.vectors().to_vec()),
282 payload: p.payload.clone(),
283 }).collect(),
284 created_at: std::time::SystemTime::now()
285 .duration_since(std::time::UNIX_EPOCH)
286 .map(|d| d.as_secs())
287 .unwrap_or(0),
288 };
289
290 self.snapshots.create_collection_snapshot(snapshot_data)
291 .map_err(|e| Error::Storage(e.to_string()))
292 }
293
294 pub fn list_collection_snapshots(&self, collection_name: &str) -> Result<Vec<SnapshotDescription>> {
296 self.snapshots.list_collection_snapshots(collection_name)
297 .map_err(|e| Error::Storage(e.to_string()))
298 }
299
300 pub fn delete_collection_snapshot(&self, collection_name: &str, snapshot_name: &str) -> Result<bool> {
302 self.snapshots.delete_collection_snapshot(collection_name, snapshot_name)
303 .map_err(|e| Error::Storage(e.to_string()))
304 }
305
306 pub fn get_snapshot_path(&self, collection_name: &str, snapshot_name: &str) -> Option<PathBuf> {
308 self.snapshots.get_snapshot_path(collection_name, snapshot_name)
309 }
310
311 pub fn recover_from_snapshot(&self, collection_name: &str, snapshot_name: &str) -> Result<Arc<Collection>> {
313 let snapshot_data = self.snapshots.load_collection_snapshot(collection_name, snapshot_name)
314 .map_err(|e| Error::Storage(e.to_string()))?;
315
316 self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
317 }
318
319 pub async fn recover_from_url(&self, collection_name: &str, url: &str, checksum: Option<&str>) -> Result<Arc<Collection>> {
321 let snapshot_path = self.snapshots.download_snapshot_from_url(collection_name, url, checksum)
322 .await
323 .map_err(|e| Error::Storage(e.to_string()))?;
324
325 let snapshot_data = self.snapshots.load_snapshot_from_path(&snapshot_path)
326 .map_err(|e| Error::Storage(e.to_string()))?;
327
328 self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
329 }
330
331 fn restore_collection_from_data_with_name(&self, data: CollectionSnapshotData, target_name: Option<&str>) -> Result<Arc<Collection>> {
332 let collection_name = target_name.unwrap_or(&data.name).to_string();
333
334 let config = CollectionConfig {
335 name: collection_name.clone(),
336 vector_dim: data.config.vector_dim,
337 distance: match data.config.distance.as_str() {
338 "Cosine" => Distance::Cosine,
339 "Euclidean" => Distance::Euclidean,
340 "Dot" => Distance::Dot,
341 _ => Distance::Cosine,
342 },
343 use_hnsw: data.config.use_hnsw,
344 enable_bm25: data.config.enable_bm25,
345 };
346
347 {
348 let mut collections = self.collections.write();
349 collections.remove(&collection_name);
350 }
351
352 let collection = Arc::new(Collection::new(config));
353
354 for point_data in data.points {
355 let point_id = point_data.id.parse::<u64>()
356 .map(PointId::Integer)
357 .unwrap_or_else(|_| PointId::String(point_data.id.clone()));
358
359 let point = if let Some(mv_data) = point_data.multivector {
360 match MultiVector::new(mv_data) {
361 Ok(mv) => Point::new_multi(point_id, mv, point_data.payload),
362 Err(e) => {
363 eprintln!("Warning: Failed to create multivector: {}", e);
364 Point::new(point_id, Vector::new(point_data.vector), point_data.payload)
365 }
366 }
367 } else {
368 Point::new(
369 point_id,
370 Vector::new(point_data.vector),
371 point_data.payload,
372 )
373 };
374
375 if let Err(e) = collection.upsert(point) {
376 eprintln!("Warning: Failed to restore point: {}", e);
377 }
378 }
379
380 {
381 let mut collections = self.collections.write();
382 collections.insert(collection_name, collection.clone());
383 }
384
385 Ok(collection)
386 }
387
388 pub fn list_all_snapshots(&self) -> Result<Vec<SnapshotDescription>> {
390 self.snapshots.list_all_snapshots()
391 .map_err(|e| Error::Storage(e.to_string()))
392 }
393
394 pub fn upload_and_restore_snapshot(
396 &self,
397 collection_name: &str,
398 data: &[u8],
399 filename: Option<&str>,
400 ) -> Result<Arc<Collection>> {
401 let snapshot_path = self.snapshots.save_uploaded_snapshot(collection_name, data, filename)
402 .map_err(|e| Error::Storage(e.to_string()))?;
403
404 let snapshot_data = self.snapshots.load_snapshot_from_path(&snapshot_path)
405 .map_err(|e| Error::Storage(e.to_string()))?;
406
407 self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
408 }
409}