Skip to main content

ruve/
database.rs

1use std::collections::{HashMap, HashSet};
2use crate::bm25;
3use crate::hnsw::index::HNSW;
4use crate::index::{load_index, save_index, Index};
5use crate::similarities::cosine_similarity;
6use crate::storage::{append_record, retrieve_record};
7use crate::types::Record;
8
9const DEFAULT_CHECKPOINT_EVERY: usize = 100000;
10
11pub struct Database {
12    pub data_path: String,
13    pub index_path: String,
14    pub bm25_index_path: String,
15    pub index: Index,
16    pub bm25_index: bm25::index::Bm25Index,
17    pub bm25_tokenizer: bm25::tokenizer::Tokenizer,
18    pub hnsw: HNSW,
19    pub graph_path: String,
20    /// Flush indexes to disk every this many inserts. Default: 500.
21    pub checkpoint_every: usize,
22    dirty_count: usize,
23}
24
25impl Database {
26    pub fn new(data_path: &str, index_path: &str, bm25_index_path: &str, hnsw_path: &str, graph_path: &str) -> Database {
27        // ensure the data directory exists
28        if let Some(dir) = std::path::Path::new(data_path).parent() {
29            std::fs::create_dir_all(dir).unwrap();
30        }
31        let index = load_index(index_path);
32        let bm25_index = bm25::index::Bm25Index::load(bm25_index_path);
33        let bm25_tokenizer = bm25::tokenizer::Tokenizer::new();
34        let hnsw = HNSW::load(hnsw_path, graph_path);
35
36        Database {
37            data_path: data_path.to_string(),
38            index_path: index_path.to_string(),
39            bm25_index_path: bm25_index_path.to_string(),
40            index,
41            bm25_index,
42            bm25_tokenizer,
43            hnsw,
44            graph_path: graph_path.to_string(),
45            checkpoint_every: DEFAULT_CHECKPOINT_EVERY,
46            dirty_count: 0,
47        }
48    }
49
50    /// Flush all indexes to disk immediately.
51    pub fn save_all(&mut self) {
52        save_index(&self.index_path, &self.index);
53        self.bm25_index.save(&self.bm25_index_path);
54        self.hnsw.save();
55        self.dirty_count = 0;
56    }
57
58    /// Insert a record with a pre-computed embedding vector and text metadata.
59    /// `id` pins the record to a specific key; pass `None` to auto-generate a UUID v7.
60    /// The text is indexed into BM25 for full-text search; the vector is inserted into the HNSW graph.
61    pub fn insert_raw(&mut self, vector: Vec<f32>, text: &str, id: Option<&str>) {
62        let record = Record::new(vector, Some(text.to_string()), id.map(str::to_string));
63        let tokens = self.bm25_tokenizer.tokenize(text);
64        let offset = append_record(&record, &self.data_path);
65        self.bm25_index.index_record(&record.id, &tokens);
66        self.index.insert(record.id.clone(), offset);
67        let (node_index, assigned_layer, prev_entry, prev_highest) = self.hnsw.insert(&record);
68        self.make_connections(&record.vector, node_index, assigned_layer, prev_entry, prev_highest);
69        self.dirty_count += 1;
70        if self.dirty_count >= self.checkpoint_every {
71            self.save_all();
72        }
73    }
74
75    // find nearest neighbors for the new node at each of its layers and wire them up bidirectionally
76    fn make_connections(&mut self, vector: &[f32], node_index: u32, assigned_layer: usize, prev_entry: Option<u32>, prev_highest: usize) {
77        // first node ever — nothing to connect to
78        let entry = match prev_entry {
79            Some(ep) => ep,
80            None => return,
81        };
82
83        // search using the pre-insert entry point and its highest layer so the
84        // new (connectionless) node is never used as a starting point
85        let candidates_by_layer = self.search_layers(vector, self.hnsw.max_neighbors_per_document * 2, Some((entry, prev_highest)));
86
87        for (layer, candidates) in &candidates_by_layer {
88            // only connect layers below the assigned one
89            if *layer > assigned_layer {
90                continue;
91            }
92
93            let neighbors: Vec<u32> = candidates.iter()
94                .filter(|&&n| n != node_index) // filter out the node itself
95                .take(self.hnsw.max_neighbors_per_document) // take first max neighbors per layer
96                .cloned()
97                .collect();
98
99            // wire the node to its new best neighbors
100            self.hnsw.set_neighbors(node_index, *layer, &neighbors);
101
102            // wire bidirectionally
103            for &neighbor in &neighbors {
104                let mut existing = self.hnsw.get_neighbors(neighbor, *layer);
105
106                if !existing.contains(&node_index) {
107                    existing.push(node_index);
108
109                    if existing.len() > self.hnsw.max_neighbors_per_document {
110                        // keep only the best neighbors by similarity
111                        let neighbor_uuid = &self.hnsw.index_to_id[neighbor as usize];
112                        // reobtain the neighbor vector
113                        let neighbor_emb = retrieve_record(*self.index.get(neighbor_uuid).unwrap(), &self.data_path).vector;
114                        // score it
115                        let mut scored: Vec<(f32, u32)> = existing.iter()
116                            .map(|&n| {
117                                let uuid = &self.hnsw.index_to_id[n as usize];
118                                let vec = retrieve_record(*self.index.get(uuid).unwrap(), &self.data_path).vector;
119                                (cosine_similarity(&neighbor_emb, &vec), n)
120                            })
121                            .collect();
122                        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
123                        // keep best scoring one
124                        existing = scored.iter()
125                            .take(self.hnsw.max_neighbors_per_document)
126                            .map(|&(_, n)| n)
127                            .collect();
128                    }
129
130                    self.hnsw.set_neighbors(neighbor, *layer, &existing);
131                }
132            }
133        }
134    }
135
136    // given a query get the most relevant records using the BM25 index
137    // then convert the doc ids to offsets using the index
138    // and retrieve the records from storage
139    pub fn text_search(&self, query: &str, k: usize) -> Vec<Record> {
140        let tokens = self.bm25_tokenizer.tokenize(query);
141        let scores = self.bm25_index.score(&tokens);
142
143        scores.into_iter().take(k)
144            .filter_map(|(doc_id, _)| self.index.get(&doc_id))
145            .map(|offset| retrieve_record(*offset, &self.data_path))
146            .collect()
147    }
148
149    /// Delete a record by UUID. Returns false if the ID was not found.
150    pub fn delete(&mut self, id: &str) -> bool {
151        let Some(offset) = self.index.get(id).copied() else { return false };
152
153        let record = retrieve_record(offset, &self.data_path);
154        let tokens = self.bm25_tokenizer.tokenize(record.metadata.as_deref().unwrap_or(""));
155
156        self.bm25_index.remove_record(id, &tokens);
157        self.index.remove(id);
158        save_index(&self.index_path, &self.index);
159        self.bm25_index.save(&self.bm25_index_path);
160        true
161    }
162
163    /// Delete all records and wipe every index and data file.
164    pub fn wipe(&mut self) {
165        self.index.clear();
166        self.bm25_index = bm25::index::Bm25Index::new();
167        self.hnsw.wipe();
168        self.dirty_count = 0;
169
170        let _ = std::fs::remove_file(&self.data_path);
171        let _ = std::fs::remove_file(&self.index_path);
172        let _ = std::fs::remove_file(&self.bm25_index_path);
173    }
174
175    /// Brute-force cosine-similarity search over all stored vectors.
176    /// Returns up to k records sorted by descending score.
177    /// Useful for recall evaluation against HNSW results.
178    pub fn search_scored(&self, query_vector: &[f32], k: usize) -> Vec<(f32, Record)> {
179        let mut results: Vec<(f32, Record)> = self.index.iter()
180            .map(|(_id, offset)| retrieve_record(*offset, &self.data_path))
181            .map(|r| {
182                let score = cosine_similarity(&r.vector, query_vector);
183                (score, r)
184            })
185            .collect();
186
187        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
188        results.truncate(k);
189        results
190    }
191
192    fn score_node(&self, node_index: u32, query_vector: &[f32]) -> Option<f32> {
193        let uuid = self.hnsw.index_to_id.get(node_index as usize)?;
194        let offset = self.index.get(uuid)?;
195        let record = retrieve_record(*offset, &self.data_path);
196        Some(cosine_similarity(&record.vector, query_vector))
197    }
198
199    /// HNSW approximate nearest-neighbour search.
200    ///
201    /// Returns records sorted by descending similarity.
202    /// `ef` is the exploration factor — higher values improve recall at the cost of speed.
203    pub fn search_hnsw(&self, query_vector: &[f32], ef: usize) -> Vec<Record> {
204        if self.hnsw.node_offsets.is_empty() {
205            return Vec::new();
206        }
207
208        let (mut current, top_layer) = match self.hnsw.entry_point {
209            Some(ep) => (ep, self.hnsw.highest_layer),
210            None => return Vec::new(),
211        };
212
213        // phase 1: greedy descent — upper layers are sparse, used to fast-forward close to the query
214        for layer in (1..=top_layer).rev() {       // walk down from top to layer 1
215            loop {
216                let current_score = self.score_node(current, query_vector).unwrap_or(f32::NEG_INFINITY); // score of where we are now
217                let best = self.hnsw.get_neighbors(current, layer).into_iter()
218                    .filter_map(|n| Some((self.score_node(n, query_vector)?, n))) // score each neighbor
219                    .filter(|(s, _)| *s > current_score)                          // keep only those better than current
220                    .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); // pick the best
221                match best {
222                    Some((_, n)) => current = n, // found a better neighbor, move there and re-check
223                    None => break,               // no neighbor beats current — local optimum, drop to next layer
224                }
225            }
226        }
227
228        // phase 2: beam search at layer 0 — explore the ef best candidates
229        // both vecs are kept sorted descending by score
230        let entrypoint_score = match self.score_node(current, query_vector) {
231            Some(s) => s,
232            None => return Vec::new(),
233        };
234        // nodes already scored — never enqueue or score twice
235        let mut visited: HashSet<u32> = HashSet::from([current]);
236        // frontier: nodes discovered but not yet expanded (neighbors not yet read)
237        // sorted descending — front is always the most promising next step
238        let mut candidates: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
239        // best ef nodes seen so far — what we return at the end
240        // sorted descending — last element is the worst, evicted when results exceeds ef
241        let mut results: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
242
243        while let Some(&(c_score, c_node)) = candidates.first() {
244            // worst score currently in our result set
245            let worst_result = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
246            // if the worst result is better than the candidate we are exploring, break
247            if c_score < worst_result {
248                break; // best remaining candidate can't improve the result set
249            }
250            candidates.remove(0);
251
252            // get current candidate neighbors
253            for neighbor in self.hnsw.get_neighbors(c_node, 0) {
254                // if we haven't seen this neighbor before
255                if visited.insert(neighbor) {
256                    // evaluate the similarity between the neighbor and the given query vector
257                    if let Some(n_score) = self.score_node(neighbor, query_vector) {
258                        let worst = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
259                        // if the score is better than the worst in the results or the results is not yet full
260                        if n_score > worst || results.len() < ef {
261                            // insert into candidates keeping order
262                            let pos = candidates.partition_point(|(s, _)| *s > n_score);
263                            candidates.insert(pos, (n_score, neighbor));
264                            // insert into results keeping order
265                            let pos = results.partition_point(|(s, _)| *s > n_score);
266                            results.insert(pos, (n_score, neighbor));
267                            // if we are over the limits, pop the last one (worst)
268                            if results.len() > ef {
269                                results.pop();
270                            }
271                        }
272                    }
273                }
274            }
275        }
276
277        results.iter()
278            .filter_map(|(_, node)| {
279                let uuid = self.hnsw.index_to_id.get(*node as usize)?;
280                let offset = self.index.get(uuid)?;
281                Some(retrieve_record(*offset, &self.data_path))
282            })
283            .collect()
284    }
285
286    // used during insertion: returns candidates per layer so make_connections can wire each one
287    fn search_layers(&self, query_vector: &[f32], candidates_per_layer: usize, entry_override: Option<(u32, usize)>) -> HashMap<usize, Vec<u32>> {
288        if self.hnsw.node_offsets.is_empty() {
289            return HashMap::new();
290        }
291
292        let (entry_node, top_layer) = match entry_override {
293            Some(pair) => pair,
294            None => match self.hnsw.entry_point {
295                Some(ep) => (ep, self.hnsw.highest_layer),
296                None => return HashMap::new(),
297            },
298        };
299
300        let mut current_candidates: Vec<u32> = vec![entry_node];
301        let mut layer_candidates: HashMap<usize, Vec<u32>> = HashMap::new();
302
303        for layer in (0..=top_layer).rev() {
304            // collect current candidates + their neighbors, deduplicated
305            let mut seen: HashSet<u32> = current_candidates.iter().cloned().collect();
306            for &candidate in &current_candidates {
307                for neighbor in self.hnsw.get_neighbors(candidate, layer) {
308                    seen.insert(neighbor);
309                }
310            }
311
312            // score all candidates by cosine similarity
313            let mut scored: Vec<(f32, u32)> = seen.iter()
314                .filter_map(|&node_index| {
315                    // foreach neighbor
316                    let uuid = &self.hnsw.index_to_id[node_index as usize];
317
318                    // retrieve the actual record
319                    let data_offset = self.index.get(uuid)?;
320                    let record = retrieve_record(*data_offset, &self.data_path);
321
322                    // evaluate the score against given vector
323                    let score = cosine_similarity(&record.vector, query_vector);
324                    Some((score, node_index))
325                })
326                .collect();
327
328            // sort highest score first
329            scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
330            scored.truncate(candidates_per_layer);
331
332            layer_candidates.insert(layer, scored.iter().map(|&(_, idx)| idx).collect());
333
334            // pass top half down to next layer
335            let next_size = (candidates_per_layer / 2).max(1);
336            current_candidates = scored.iter().take(next_size).map(|&(_, idx)| idx).collect();
337        }
338
339        layer_candidates
340    }
341
342    /// Brute-force search returning the top-k records (without scores).
343    pub fn search(&self, query_vector: &[f32], k: usize) -> Vec<Record> {
344        let mut results = Vec::new();
345
346        for (_id, offset) in &self.index {
347            let record = retrieve_record(*offset, &self.data_path);
348            let similarity = cosine_similarity(&record.vector, query_vector);
349            results.push((similarity, record));
350        }
351
352        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
353        results.into_iter().take(k).map(|(_, record)| record).collect()
354    }
355}
356
357impl Drop for Database {
358    fn drop(&mut self) {
359        if self.dirty_count > 0 {
360            self.save_all();
361        }
362    }
363}