vectx_storage/
manager.rs

1use vectx_core::{Collection, CollectionConfig, Distance, Error, Result, Point, PointId, Vector, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use std::time::Duration;
7use crate::lmdb_storage::LmdbStorage;
8use crate::wal::WriteAheadLog;
9use crate::snapshot::{SnapshotManager, SnapshotDescription, CollectionSnapshotData, CollectionConfigData, PointData};
10use crate::persistence::ForkBasedPersistence;
11
12/// Manages collections and persistence
13pub struct StorageManager {
14    collections: Arc<RwLock<HashMap<String, Arc<Collection>>>>,
15    /// Aliases: alias_name -> collection_name
16    aliases: Arc<RwLock<HashMap<String, String>>>,
17    data_dir: PathBuf,
18    #[allow(dead_code)]
19    lmdb: Option<Arc<LmdbStorage>>,
20    #[allow(dead_code)]
21    wal: Option<Arc<WriteAheadLog>>,
22    snapshots: Arc<SnapshotManager>,
23    persistence: Arc<ForkBasedPersistence>,
24    #[allow(dead_code)]
25    save_interval: Option<Duration>,
26}
27
28impl StorageManager {
29    pub fn new<P: AsRef<Path>>(data_dir: P) -> Result<Self> {
30        let data_dir = data_dir.as_ref().to_path_buf();
31        std::fs::create_dir_all(&data_dir)?;
32
33        let lmdb_path = data_dir.join("lmdb");
34        let lmdb = Arc::new(LmdbStorage::new(&lmdb_path)
35            .map_err(|e| Error::Storage(e.to_string()))?);
36
37        let wal_path = data_dir.join("wal.log");
38        let wal = Arc::new(WriteAheadLog::new(&wal_path)
39            .map_err(|e| Error::Storage(e.to_string()))?);
40
41        let snapshot_dir = data_dir.join("snapshots");
42        let snapshots = Arc::new(SnapshotManager::new(&snapshot_dir)
43            .map_err(|e| Error::Storage(e.to_string()))?);
44
45        let persistence = Arc::new(ForkBasedPersistence::new(&data_dir));
46
47        let collections = Arc::new(RwLock::new(HashMap::new()));
48        let aliases = Arc::new(RwLock::new(HashMap::new()));
49        
50        if let Some(snapshot) = persistence.load_snapshot()
51            .map_err(|e| Error::Persistence(e.to_string()))? {
52            eprintln!("Loading snapshot from disk...");
53            let mut collections_map = HashMap::new();
54            
55            for col_snapshot in snapshot.collections {
56                let config = CollectionConfig {
57                    name: col_snapshot.name.clone(),
58                    vector_dim: col_snapshot.config.vector_dim,
59                    distance: match col_snapshot.config.distance.as_str() {
60                        "Cosine" => Distance::Cosine,
61                        "Euclidean" => Distance::Euclidean,
62                        "Dot" => Distance::Dot,
63                        _ => Distance::Cosine,
64                    },
65                    use_hnsw: col_snapshot.config.use_hnsw,
66                    enable_bm25: col_snapshot.config.enable_bm25,
67                };
68                
69                let collection = Arc::new(Collection::new(config));
70                
71                for point_snapshot in col_snapshot.points {
72                    let point = Point::new(
73                        PointId::String(point_snapshot.id.clone()),
74                        Vector::new(point_snapshot.vector),
75                        point_snapshot.payload,
76                    );
77                    if let Err(e) = collection.upsert(point) {
78                        eprintln!("Warning: Failed to restore point {}: {}", point_snapshot.id, e);
79                    }
80                }
81                
82                collections_map.insert(col_snapshot.name, collection);
83            }
84            
85            *collections.write() = collections_map;
86            eprintln!("Snapshot loaded: {} collections", collections.read().len());
87        }
88
89        let manager = Self {
90            collections,
91            aliases,
92            data_dir,
93            lmdb: Some(lmdb),
94            wal: Some(wal),
95            snapshots,
96            persistence,
97            save_interval: Some(Duration::from_secs(300)),
98        };
99
100        manager.start_background_save();
101
102        Ok(manager)
103    }
104
105    /// Start background save thread
106    fn start_background_save(&self) {
107        let collections = self.collections.clone();
108        let persistence = self.persistence.clone();
109        let interval = self.save_interval.unwrap_or(Duration::from_secs(300));
110
111        std::thread::spawn(move || {
112            loop {
113                std::thread::sleep(interval);
114                
115                if !ForkBasedPersistence::is_bgsave_in_progress() {
116                    let collections_map = collections.read();
117                    if let Err(e) = persistence.bgsave(&collections_map) {
118                        eprintln!("Background save error: {}", e);
119                    }
120                }
121            }
122        });
123    }
124
125    pub fn create_collection(&self, config: CollectionConfig) -> Result<Arc<Collection>> {
126        let name = config.name.clone();
127        let mut collections = self.collections.write();
128
129        if collections.contains_key(&name) {
130            return Err(Error::CollectionExists(name));
131        }
132
133        let collection = Arc::new(Collection::new(config));
134        collections.insert(name.clone(), collection.clone());
135        Ok(collection)
136    }
137
138    #[inline]
139    pub fn get_collection(&self, name: &str) -> Option<Arc<Collection>> {
140        let collections = self.collections.read();
141        // First try direct collection lookup
142        if let Some(col) = collections.get(name) {
143            return Some(col.clone());
144        }
145        // Then try alias lookup
146        let aliases = self.aliases.read();
147        if let Some(collection_name) = aliases.get(name) {
148            return collections.get(collection_name).cloned();
149        }
150        None
151    }
152
153    pub fn delete_collection(&self, name: &str) -> Result<bool> {
154        let mut collections = self.collections.write();
155        let removed = collections.remove(name).is_some();
156        
157        // Also clean up similarity schema if it exists
158        if removed {
159        }
160        
161        Ok(removed)
162    }
163
164    #[inline]
165    #[must_use]
166    pub fn list_collections(&self) -> Vec<String> {
167        self.collections.read().keys().cloned().collect()
168    }
169
170    #[inline]
171    #[must_use]
172    pub fn collection_exists(&self, name: &str) -> bool {
173        self.collections.read().contains_key(name)
174    }
175
176    /// Create an alias for a collection
177    pub fn create_alias(&self, alias_name: &str, collection_name: &str) -> Result<bool> {
178        // Check that collection exists
179        if !self.collection_exists(collection_name) {
180            return Err(Error::CollectionNotFound(collection_name.to_string()));
181        }
182        let mut aliases = self.aliases.write();
183        aliases.insert(alias_name.to_string(), collection_name.to_string());
184        Ok(true)
185    }
186
187    /// Delete an alias
188    pub fn delete_alias(&self, alias_name: &str) -> Result<bool> {
189        let mut aliases = self.aliases.write();
190        Ok(aliases.remove(alias_name).is_some())
191    }
192
193    /// Rename an alias
194    pub fn rename_alias(&self, old_alias: &str, new_alias: &str) -> Result<bool> {
195        let mut aliases = self.aliases.write();
196        if let Some(collection_name) = aliases.remove(old_alias) {
197            aliases.insert(new_alias.to_string(), collection_name);
198            Ok(true)
199        } else {
200            Ok(false)
201        }
202    }
203
204    /// List all aliases
205    pub fn list_aliases(&self) -> Vec<(String, String)> {
206        self.aliases.read()
207            .iter()
208            .map(|(alias, collection)| (alias.clone(), collection.clone()))
209            .collect()
210    }
211
212    /// List aliases for a specific collection
213    pub fn list_collection_aliases(&self, collection_name: &str) -> Vec<String> {
214        self.aliases.read()
215            .iter()
216            .filter(|(_, col)| *col == collection_name)
217            .map(|(alias, _)| alias.clone())
218            .collect()
219    }
220
221    #[inline]
222    #[must_use]
223    pub fn data_dir(&self) -> &Path {
224        &self.data_dir
225    }
226
227
228    /// Trigger background save
229    pub fn bgsave(&self) -> Result<bool> {
230        let collections = self.collections.read();
231        self.persistence.bgsave(&collections)
232            .map_err(|e| Error::Storage(e.to_string()))
233    }
234
235    /// Force save
236    pub fn save(&self) -> Result<()> {
237        let collections = self.collections.read();
238        self.persistence.save(&collections)
239            .map_err(|e| Error::Storage(e.to_string()))
240    }
241
242    /// Get last save time
243    pub fn last_save_time(&self) -> u64 {
244        ForkBasedPersistence::last_save_time()
245    }
246
247    /// Check if background save is in progress
248    pub fn is_bgsave_in_progress(&self) -> bool {
249        ForkBasedPersistence::is_bgsave_in_progress()
250    }
251
252    // ==================== Snapshot Methods ====================
253
254    /// Create a snapshot for a collection
255    pub fn create_collection_snapshot(&self, collection_name: &str) -> Result<SnapshotDescription> {
256        let collections = self.collections.read();
257        let collection = collections.get(collection_name)
258            .ok_or_else(|| Error::CollectionNotFound(collection_name.to_string()))?;
259
260        let points = collection.get_all_points();
261
262        let snapshot_data = CollectionSnapshotData {
263            name: collection_name.to_string(),
264            config: CollectionConfigData {
265                vector_dim: collection.vector_dim(),
266                distance: match collection.distance() {
267                    Distance::Cosine => "Cosine".to_string(),
268                    Distance::Euclidean => "Euclidean".to_string(),
269                    Distance::Dot => "Dot".to_string(),
270                },
271                use_hnsw: collection.use_hnsw(),
272                enable_bm25: collection.enable_bm25(),
273            },
274            points: points.iter().map(|p| PointData {
275                id: match &p.id {
276                    PointId::Integer(i) => i.to_string(),
277                    PointId::String(s) => s.clone(),
278                    PointId::Uuid(u) => u.to_string(),
279                },
280                vector: p.vector.as_slice().to_vec(),
281                multivector: p.multivector.as_ref().map(|mv: &MultiVector| mv.vectors().to_vec()),
282                payload: p.payload.clone(),
283            }).collect(),
284            created_at: std::time::SystemTime::now()
285                .duration_since(std::time::UNIX_EPOCH)
286                .map(|d| d.as_secs())
287                .unwrap_or(0),
288        };
289
290        self.snapshots.create_collection_snapshot(snapshot_data)
291            .map_err(|e| Error::Storage(e.to_string()))
292    }
293
294    /// List snapshots for a collection
295    pub fn list_collection_snapshots(&self, collection_name: &str) -> Result<Vec<SnapshotDescription>> {
296        self.snapshots.list_collection_snapshots(collection_name)
297            .map_err(|e| Error::Storage(e.to_string()))
298    }
299
300    /// Delete a snapshot
301    pub fn delete_collection_snapshot(&self, collection_name: &str, snapshot_name: &str) -> Result<bool> {
302        self.snapshots.delete_collection_snapshot(collection_name, snapshot_name)
303            .map_err(|e| Error::Storage(e.to_string()))
304    }
305
306    /// Get snapshot file path for download
307    pub fn get_snapshot_path(&self, collection_name: &str, snapshot_name: &str) -> Option<PathBuf> {
308        self.snapshots.get_snapshot_path(collection_name, snapshot_name)
309    }
310
311    /// Recover collection from a snapshot file
312    pub fn recover_from_snapshot(&self, collection_name: &str, snapshot_name: &str) -> Result<Arc<Collection>> {
313        let snapshot_data = self.snapshots.load_collection_snapshot(collection_name, snapshot_name)
314            .map_err(|e| Error::Storage(e.to_string()))?;
315
316        self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
317    }
318
319    /// Recover collection from a URL
320    pub async fn recover_from_url(&self, collection_name: &str, url: &str, checksum: Option<&str>) -> Result<Arc<Collection>> {
321        let snapshot_path = self.snapshots.download_snapshot_from_url(collection_name, url, checksum)
322            .await
323            .map_err(|e| Error::Storage(e.to_string()))?;
324
325        let snapshot_data = self.snapshots.load_snapshot_from_path(&snapshot_path)
326            .map_err(|e| Error::Storage(e.to_string()))?;
327
328        self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
329    }
330
331    fn restore_collection_from_data_with_name(&self, data: CollectionSnapshotData, target_name: Option<&str>) -> Result<Arc<Collection>> {
332        let collection_name = target_name.unwrap_or(&data.name).to_string();
333        
334        let config = CollectionConfig {
335            name: collection_name.clone(),
336            vector_dim: data.config.vector_dim,
337            distance: match data.config.distance.as_str() {
338                "Cosine" => Distance::Cosine,
339                "Euclidean" => Distance::Euclidean,
340                "Dot" => Distance::Dot,
341                _ => Distance::Cosine,
342            },
343            use_hnsw: data.config.use_hnsw,
344            enable_bm25: data.config.enable_bm25,
345        };
346
347        {
348            let mut collections = self.collections.write();
349            collections.remove(&collection_name);
350        }
351
352        let collection = Arc::new(Collection::new(config));
353
354        for point_data in data.points {
355            let point_id = point_data.id.parse::<u64>()
356                .map(PointId::Integer)
357                .unwrap_or_else(|_| PointId::String(point_data.id.clone()));
358
359            let point = if let Some(mv_data) = point_data.multivector {
360                match MultiVector::new(mv_data) {
361                    Ok(mv) => Point::new_multi(point_id, mv, point_data.payload),
362                    Err(e) => {
363                        eprintln!("Warning: Failed to create multivector: {}", e);
364                        Point::new(point_id, Vector::new(point_data.vector), point_data.payload)
365                    }
366                }
367            } else {
368                Point::new(
369                    point_id,
370                    Vector::new(point_data.vector),
371                    point_data.payload,
372                )
373            };
374
375            if let Err(e) = collection.upsert(point) {
376                eprintln!("Warning: Failed to restore point: {}", e);
377            }
378        }
379
380        {
381            let mut collections = self.collections.write();
382            collections.insert(collection_name, collection.clone());
383        }
384
385        Ok(collection)
386    }
387
388    /// List all snapshots
389    pub fn list_all_snapshots(&self) -> Result<Vec<SnapshotDescription>> {
390        self.snapshots.list_all_snapshots()
391            .map_err(|e| Error::Storage(e.to_string()))
392    }
393
394    /// Upload and restore a snapshot from raw bytes
395    pub fn upload_and_restore_snapshot(
396        &self, 
397        collection_name: &str, 
398        data: &[u8],
399        filename: Option<&str>,
400    ) -> Result<Arc<Collection>> {
401        let snapshot_path = self.snapshots.save_uploaded_snapshot(collection_name, data, filename)
402            .map_err(|e| Error::Storage(e.to_string()))?;
403
404        let snapshot_data = self.snapshots.load_snapshot_from_path(&snapshot_path)
405            .map_err(|e| Error::Storage(e.to_string()))?;
406
407        self.restore_collection_from_data_with_name(snapshot_data, Some(collection_name))
408    }
409}