1#[cfg(feature = "storage")]
7use crate::error::{Result, RuvectorError};
8#[cfg(feature = "storage")]
9use crate::types::{VectorEntry, VectorId};
10#[cfg(feature = "storage")]
11use bincode::config;
12#[cfg(feature = "storage")]
13use once_cell::sync::Lazy;
14#[cfg(feature = "storage")]
15use parking_lot::Mutex;
16#[cfg(feature = "storage")]
17use redb::{Database, ReadableTable, ReadableTableMetadata, TableDefinition};
18#[cfg(feature = "storage")]
19use serde_json;
20#[cfg(feature = "storage")]
21use std::collections::HashMap;
22#[cfg(feature = "storage")]
23use std::path::{Path, PathBuf};
24#[cfg(feature = "storage")]
25use std::sync::Arc;
26
/// redb table mapping a vector id to its bincode-encoded `Vec<f32>`.
#[cfg(feature = "storage")]
const VECTORS_TABLE: TableDefinition<&str, &[u8]> = TableDefinition::new("vectors");

/// redb table mapping a vector id to its metadata, serialized as JSON text.
///
/// A row exists only for ids whose entry carried metadata.
#[cfg(feature = "storage")]
const METADATA_TABLE: TableDefinition<&str, &str> = TableDefinition::new("metadata");

/// Process-wide cache of open databases, keyed by resolved file path.
///
/// `VectorStorage::new` consults this map so that several `VectorStorage`
/// handles for the same file share one `Database`.
/// NOTE(review): presumably required because redb will not open a file that
/// is already open in the same process — confirm against the redb version in use.
/// Nothing in this module removes entries, so each database stays open for
/// the remainder of the process.
#[cfg(feature = "storage")]
static DB_POOL: Lazy<Mutex<HashMap<PathBuf, Arc<Database>>>> =
    Lazy::new(|| Mutex::new(HashMap::new()));
36
37pub struct VectorStorage {
39 db: Arc<Database>,
40 dimensions: usize,
41}
42
43impl VectorStorage {
44 pub fn new<P: AsRef<Path>>(path: P, dimensions: usize) -> Result<Self> {
50 let path_ref = path.as_ref();
52 let path_buf = path_ref
53 .canonicalize()
54 .unwrap_or_else(|_| path_ref.to_path_buf());
55
56 if let Ok(cwd) = std::env::current_dir() {
58 if !path_buf.starts_with(&cwd) && !path_buf.is_absolute() {
59 return Err(RuvectorError::InvalidPath(
60 "Path traversal attempt detected".to_string()
61 ));
62 }
63 }
64
65 let db = {
67 let mut pool = DB_POOL.lock();
68
69 if let Some(existing_db) = pool.get(&path_buf) {
70 Arc::clone(existing_db)
72 } else {
73 let new_db = Arc::new(Database::create(&path_buf)?);
75
76 let write_txn = new_db.begin_write()?;
78 {
79 let _ = write_txn.open_table(VECTORS_TABLE)?;
80 let _ = write_txn.open_table(METADATA_TABLE)?;
81 }
82 write_txn.commit()?;
83
84 pool.insert(path_buf, Arc::clone(&new_db));
85 new_db
86 }
87 };
88
89 Ok(Self { db, dimensions })
90 }
91
92 pub fn insert(&self, entry: &VectorEntry) -> Result<VectorId> {
94 if entry.vector.len() != self.dimensions {
95 return Err(RuvectorError::DimensionMismatch {
96 expected: self.dimensions,
97 actual: entry.vector.len(),
98 });
99 }
100
101 let id = entry
102 .id
103 .clone()
104 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
105
106 let write_txn = self.db.begin_write()?;
107 {
108 let mut table = write_txn.open_table(VECTORS_TABLE)?;
109
110 let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
112 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
113
114 table.insert(id.as_str(), vector_data.as_slice())?;
115
116 if let Some(metadata) = &entry.metadata {
118 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
119 let metadata_json = serde_json::to_string(metadata)
120 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
121 meta_table.insert(id.as_str(), metadata_json.as_str())?;
122 }
123 }
124 write_txn.commit()?;
125
126 Ok(id)
127 }
128
129 pub fn insert_batch(&self, entries: &[VectorEntry]) -> Result<Vec<VectorId>> {
131 let write_txn = self.db.begin_write()?;
132 let mut ids = Vec::with_capacity(entries.len());
133
134 {
135 let mut table = write_txn.open_table(VECTORS_TABLE)?;
136 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
137
138 for entry in entries {
139 if entry.vector.len() != self.dimensions {
140 return Err(RuvectorError::DimensionMismatch {
141 expected: self.dimensions,
142 actual: entry.vector.len(),
143 });
144 }
145
146 let id = entry
147 .id
148 .clone()
149 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
150
151 let vector_data = bincode::encode_to_vec(&entry.vector, config::standard())
153 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
154 table.insert(id.as_str(), vector_data.as_slice())?;
155
156 if let Some(metadata) = &entry.metadata {
158 let metadata_json = serde_json::to_string(metadata)
159 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
160 meta_table.insert(id.as_str(), metadata_json.as_str())?;
161 }
162
163 ids.push(id);
164 }
165 }
166
167 write_txn.commit()?;
168 Ok(ids)
169 }
170
171 pub fn get(&self, id: &str) -> Result<Option<VectorEntry>> {
173 let read_txn = self.db.begin_read()?;
174 let table = read_txn.open_table(VECTORS_TABLE)?;
175
176 let Some(vector_data) = table.get(id)? else {
177 return Ok(None);
178 };
179
180 let (vector, _): (Vec<f32>, usize) =
181 bincode::decode_from_slice(vector_data.value(), config::standard())
182 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
183
184 let meta_table = read_txn.open_table(METADATA_TABLE)?;
186 let metadata = if let Some(meta_data) = meta_table.get(id)? {
187 let meta_str = meta_data.value();
188 Some(
189 serde_json::from_str(meta_str)
190 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?,
191 )
192 } else {
193 None
194 };
195
196 Ok(Some(VectorEntry {
197 id: Some(id.to_string()),
198 vector,
199 metadata,
200 }))
201 }
202
203 pub fn delete(&self, id: &str) -> Result<bool> {
205 let write_txn = self.db.begin_write()?;
206 let mut deleted = false;
207
208 {
209 let mut table = write_txn.open_table(VECTORS_TABLE)?;
210 deleted = table.remove(id)?.is_some();
211
212 let mut meta_table = write_txn.open_table(METADATA_TABLE)?;
213 let _ = meta_table.remove(id)?;
214 }
215
216 write_txn.commit()?;
217 Ok(deleted)
218 }
219
220 pub fn len(&self) -> Result<usize> {
222 let read_txn = self.db.begin_read()?;
223 let table = read_txn.open_table(VECTORS_TABLE)?;
224 Ok(table.len()? as usize)
225 }
226
227 pub fn is_empty(&self) -> Result<bool> {
229 Ok(self.len()? == 0)
230 }
231
232 pub fn all_ids(&self) -> Result<Vec<VectorId>> {
234 let read_txn = self.db.begin_read()?;
235 let table = read_txn.open_table(VECTORS_TABLE)?;
236
237 let mut ids = Vec::new();
238 let iter = table.iter()?;
239 for item in iter {
240 let (key, _) = item?;
241 ids.push(key.value().to_string());
242 }
243
244 Ok(ids)
245 }
246}
247
248use uuid;
250
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a metadata-free entry with the given id and components.
    fn entry(id: Option<&str>, vector: &[f32]) -> VectorEntry {
        VectorEntry {
            id: id.map(str::to_string),
            vector: vector.to_vec(),
            metadata: None,
        }
    }

    /// Opens a 3-dimensional store for the file `file` inside `dir`.
    fn open(dir: &tempfile::TempDir, file: &str) -> Result<VectorStorage> {
        VectorStorage::new(dir.path().join(file), 3)
    }

    #[test]
    fn test_insert_and_get() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = open(&dir, "test.db")?;

        let id = storage.insert(&entry(Some("test1"), &[1.0, 2.0, 3.0]))?;
        assert_eq!(id, "test1");

        let retrieved = storage.get("test1")?.expect("test1 was just inserted");
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);

        Ok(())
    }

    #[test]
    fn test_batch_insert() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = open(&dir, "test.db")?;

        let batch = [entry(None, &[1.0, 2.0, 3.0]), entry(None, &[4.0, 5.0, 6.0])];
        let ids = storage.insert_batch(&batch)?;

        assert_eq!(ids.len(), 2);
        assert_eq!(storage.len()?, 2);

        Ok(())
    }

    #[test]
    fn test_delete() -> Result<()> {
        let dir = tempdir().unwrap();
        let storage = open(&dir, "test.db")?;

        storage.insert(&entry(Some("test1"), &[1.0, 2.0, 3.0]))?;
        assert_eq!(storage.len()?, 1);

        assert!(storage.delete("test1")?);
        assert_eq!(storage.len()?, 0);

        Ok(())
    }

    #[test]
    fn test_multiple_instances_same_path() -> Result<()> {
        let dir = tempdir().unwrap();

        // Both handles point at the same file and must observe each other's writes.
        let first = open(&dir, "shared.db")?;
        first.insert(&entry(Some("test1"), &[1.0, 2.0, 3.0]))?;

        let second = open(&dir, "shared.db")?;
        for store in [&first, &second] {
            assert_eq!(store.len()?, 1);
        }

        second.insert(&entry(Some("test2"), &[4.0, 5.0, 6.0]))?;
        for store in [&first, &second] {
            assert_eq!(store.len()?, 2);
        }

        assert!(first.get("test1")?.is_some());
        assert!(second.get("test2")?.is_some());

        Ok(())
    }
}