1#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
// Table of vector payloads: id -> bincode-encoded Vec<f32>.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");
// Table of optional metadata: id -> JSON string.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");

// Process-wide pool of open databases keyed by absolute path, so that several
// `VectorStorage` instances pointing at the same file share one `Database`
// handle instead of each trying to open the file independently.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
36
/// Persistent vector store backed by a shared, pooled `redb::Database`.
///
/// All vectors in one store have the same fixed dimensionality, enforced on
/// every insert.
#[cfg(feature = "storage")]
pub struct VectorStorage {
    // Shared handle from `DB_POOL`; cloning the Arc is how multiple
    // instances on the same path coexist.
    db: Arc<Database>,
    // Expected length of every stored vector.
    dimensions: usize,
}
42
#[cfg(feature = "storage")]
impl VectorStorage {
    /// Opens (or creates) a vector store at `path` for vectors of length
    /// `dimensions`.
    ///
    /// The underlying `redb::Database` is shared through `DB_POOL`, keyed by
    /// the absolutized path, so repeated calls with the same path reuse one
    /// handle. On first open both tables are created eagerly so subsequent
    /// read transactions never fail with a missing-table error.
    ///
    /// # Errors
    /// Returns `RuvectorError::InvalidPath` for traversal attempts or
    /// filesystem failures, and propagates redb errors.
    pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
        let path_ref = path.as_ref();

        // Validate BEFORE any filesystem mutation: previously this check ran
        // after `create_dir_all`, so directories could be created for a path
        // that was about to be rejected. Relative paths whose `..` components
        // would escape the current working directory are refused.
        // NOTE(review): absolute paths containing `..` are intentionally left
        // alone here, matching the original behavior — confirm that is the
        // desired policy.
        if path_ref.to_string_lossy().contains("..") && !path_ref.is_absolute() {
            if let Ok(cwd) = std::env::current_dir() {
                let mut normalized = cwd.clone();
                for component in path_ref.components() {
                    match component {
                        std::path::Component::ParentDir => {
                            if !normalized.pop() || !normalized.starts_with(&cwd) {
                                return Err(RuvectorError::InvalidPath(
                                    "Path traversal attempt detected".to_string(),
                                ));
                            }
                        }
                        std::path::Component::Normal(c) => normalized.push(c),
                        _ => {}
                    }
                }
            }
        }

        // Create the parent directory if it doesn't exist yet.
        if let Some(parent) = path_ref.parent() {
            if !parent.as_os_str().is_empty() && !parent.exists() {
                std::fs::create_dir_all(parent).map_err(|e| {
                    RuvectorError::InvalidPath(format!("Failed to create directory: {}", e))
                })?;
            }
        }

        // Absolutize so the pool key is stable regardless of the caller's cwd.
        let path_buf = if path_ref.is_absolute() {
            path_ref.to_path_buf()
        } else {
            std::env::current_dir()
                .map_err(|e| RuvectorError::InvalidPath(format!("Failed to get cwd: {}", e)))?
                .join(path_ref)
        };

        let db = {
            let mut pool = DB_POOL.lock();

            if let Some(existing_db) = pool.get(&path_buf) {
                Arc::clone(existing_db)
            } else {
                let new_db = Arc::new(Database::create(&path_buf)?);

                // Eagerly create both tables so later reads never see a
                // "table does not exist" error.
                let write_txn = new_db.begin_write()?;
                {
                    let _ = write_txn.open_table(VECTORS_TABLE)?;
                    let _ = write_txn.open_table(METADATA_TABLE)?;
                }
                write_txn.commit()?;

                pool.insert(path_buf, Arc::clone(&new_db));
                new_db
            }
        };

        Ok(Self { db, dimensions })
    }

    /// Returns the entry's explicit id, or a freshly generated UUIDv4 string.
    fn resolve_id(entry: &VectorEntry) -> VectorId {
        entry
            .id
            .clone()
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
    }

    /// Ensures `vector` matches the store's configured dimensionality.
    fn check_dimensions(&self, vector: &[f32]) -> Result<()> {
        if vector.len() != self.dimensions {
            return Err(RuvectorError::DimensionMismatch {
                expected: self.dimensions,
                actual: vector.len(),
            });
        }
        Ok(())
    }

    /// Inserts one vector (and its optional metadata) in a single write
    /// transaction, returning the id it was stored under.
    ///
    /// # Errors
    /// `DimensionMismatch` if the vector length is wrong,
    /// `SerializationError` on encode failure, plus propagated redb errors.
    pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
        self.check_dimensions(&entry.vector)?;
        let id = Self::resolve_id(entry);

        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;

            let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

            table.insert(id.as_str(), vector_data.as_slice())?;

            if let Some(metadata) = &entry.metadata {
                let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
                let metadata_json = serde_json::to_string(metadata)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                meta_table.insert(id.as_str(), metadata_json.as_str())?;
            }
        }
        write_txn.commit()?;

        Ok(id)
    }

    /// Inserts all `entries` atomically: if any entry fails validation or
    /// serialization, the write transaction is dropped (redb aborts it) and
    /// nothing is persisted.
    ///
    /// Returns the ids in the same order as `entries`.
    pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
        let write_txn = self.db.begin_write()?;
        let mut ids = Vec::with_capacity(entries.len());

        {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;

            for entry in entries {
                self.check_dimensions(&entry.vector)?;
                let id = Self::resolve_id(entry);

                let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                table.insert(id.as_str(), vector_data.as_slice())?;

                if let Some(metadata) = &entry.metadata {
                    let metadata_json = serde_json::to_string(metadata)
                        .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
                    meta_table.insert(id.as_str(), metadata_json.as_str())?;
                }

                ids.push(id);
            }
        }

        write_txn.commit()?;
        Ok(ids)
    }

    /// Fetches the vector (and metadata, if any) stored under `id`, or
    /// `Ok(None)` when no such entry exists.
    ///
    /// # Errors
    /// `SerializationError` if stored bytes fail to decode; propagated redb
    /// errors otherwise.
    pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let Some(vector_data) = table.get(id)? else {
            return Ok(None);
        };

        let (vector, _): (Vec<f32>, usize) =
            bincode::decode_from_slice(vector_data.value(), config::standard())
                .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;

        // Metadata is optional: absence in the table maps to `None`.
        let meta_table = read_txn.open_table(METADATA_TABLE)?;
        let metadata = if let Some(meta_data) = meta_table.get(id)? {
            let meta_str = meta_data.value();
            Some(
                serde_json::from_str(meta_str)
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
            )
        } else {
            None
        };

        Ok(Some(VectorEntry {
            id: Some(id.to_string()),
            vector,
            metadata,
        }))
    }

    /// Removes the vector and any associated metadata under `id`.
    ///
    /// Returns `true` if a vector was actually removed; metadata is deleted
    /// unconditionally either way.
    pub fn delete(&self, id: &str) -> Result<bool> {
        let write_txn = self.db.begin_write()?;

        // Bind once inside the scope instead of the previous
        // `let mut deleted = false;` dead initialization.
        let deleted = {
            let mut table = write_txn.open_table(VECTORS_TABLE)?;
            let removed = table.remove(id)?.is_some();

            let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
            let _ = meta_table.remove(id)?;

            removed
        };

        write_txn.commit()?;
        Ok(deleted)
    }

    /// Number of vectors currently stored.
    pub fn len(&self) -> Result<usize> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;
        Ok(table.len()? as usize)
    }

    /// `true` when the store holds no vectors.
    pub fn is_empty(&self) -> Result<bool> {
        Ok(self.len()? == 0)
    }

    /// Collects every stored vector id. Order follows redb's key ordering.
    pub fn all_ids(&self) -> Result<Vec<VectorId>> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(VECTORS_TABLE)?;

        let mut ids = Vec::new();
        for item in table.iter()? {
            let (key, _) = item?;
            ids.push(key.value().to_string());
        }

        Ok(ids)
    }
}
281
// uuid is used only by the storage-backed id generation; gate it like every
// other storage dependency so builds without the feature still compile.
#[cfg(feature = "storage")]
use uuid;
284
// Gated on the `storage` feature as well as `test`: these tests reference
// `VectorStorage`, which only exists when the feature is enabled, so a plain
// `#[cfg(test)]` would break `cargo test --no-default-features`.
#[cfg(all(test, feature = "storage"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Round-trip: a vector inserted under an explicit id comes back intact.
    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        let id = storage.insert(&entry)?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?;
        assert!(retrieved.is_some());
        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    // Batch insert generates ids when none are supplied and stores all rows.
    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entries = vec![
            VectorEntry {
                id: None,
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            },
            VectorEntry {
                id: None,
                vector: vec![4.0, 5.0, 6.0],
                metadata: None,
            },
        ];

        let ids = storage.insert_batch(&entries)?;
        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    // Deleting an existing id reports true and shrinks the store.
    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = VectorStorage::new(dir.path().join("test.db"), 3)?;

        let entry = VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        };

        storage.insert(&entry)?;
        assert_eq!(storage.len()?, 1);

        let deleted = storage.delete("test1")?;
        assert!(deleted);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    // Two instances on the same path share one pooled database: writes from
    // either are visible through both.
    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("shared.db");

        let storage1 = VectorStorage::new(&db_path, 3)?;

        storage1.insert(&VectorEntry {
            id: Some("test1".to_string()),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
        })?;

        let storage2 = VectorStorage::new(&db_path, 3)?;

        assert_eq!(storage1.len()?, 1);
        assert_eq!(storage2.len()?, 1);

        storage2.insert(&VectorEntry {
            id: Some("test2".to_string()),
            vector: vec![4.0, 5.0, 6.0],
            metadata: None,
        })?;

        assert_eq!(storage1.len()?, 2);
        assert_eq!(storage2.len()?, 2);

        let retrieved1 = storage1.get("test1")?;
        assert!(retrieved1.is_some());

        let retrieved2 = storage2.get("test2")?;
        assert!(retrieved2.is_some());

        Ok(())
    }
}