Skip to main content

ruve/
database.rs

1use std::collections::{HashMap, HashSet};
2use crate::bm25;
3use crate::hnsw::index::HNSW;
4use crate::index::{load_index, save_index, Index};
5use crate::similarities::cosine_similarity;
6use crate::storage::{append_record, retrieve_record};
7use crate::types::Record;
8
9pub struct Database {
10    pub data_path: String,
11    pub index_path: String,
12    pub bm25_index_path: String,
13    pub index: Index,
14    pub bm25_index: bm25::index::Bm25Index,
15    pub bm25_tokenizer: bm25::tokenizer::Tokenizer,
16    pub hnsw: HNSW,
17    pub graph_path: String,
18}
19
20impl Database {
21    pub fn new(data_path: &str, index_path: &str, bm25_index_path: &str, hnsw_path: &str, graph_path: &str) -> Database {
22        // ensure the data directory exists
23        if let Some(dir) = std::path::Path::new(data_path).parent() {
24            std::fs::create_dir_all(dir).unwrap();
25        }
26        let index = load_index(index_path);
27        let bm25_index = bm25::index::Bm25Index::load(bm25_index_path);
28        let bm25_tokenizer = bm25::tokenizer::Tokenizer::new();
29        let hnsw = HNSW::load(hnsw_path, graph_path);
30
31        Database {
32            data_path: data_path.to_string(),
33            index_path: index_path.to_string(),
34            bm25_index_path: bm25_index_path.to_string(),
35            index,
36            bm25_index,
37            bm25_tokenizer,
38            hnsw,
39            graph_path: graph_path.to_string(),
40        }
41    }
42
43    /// Insert a record with a pre-computed embedding vector and text metadata.
44    /// `id` pins the record to a specific key; pass `None` to auto-generate a UUID v7.
45    /// The text is indexed into BM25 for full-text search; the vector is inserted into the HNSW graph.
46    pub fn insert_raw(&mut self, vector: Vec<f32>, text: &str, id: Option<&str>) {
47        let record = Record::new(vector, Some(text.to_string()), id.map(str::to_string));
48        let tokens = self.bm25_tokenizer.tokenize(text);
49        let offset = append_record(&record, &self.data_path);
50        self.bm25_index.index_record(&record.id, &tokens);
51        self.index.insert(record.id.clone(), offset);
52        let (node_index, assigned_layer, prev_entry, prev_highest) = self.hnsw.insert(&record);
53        self.make_connections(&record.vector, node_index, assigned_layer, prev_entry, prev_highest);
54        save_index(&self.index_path, &self.index);
55        self.bm25_index.save(&self.bm25_index_path);
56    }
57
58    // find nearest neighbors for the new node at each of its layers and wire them up bidirectionally
59    fn make_connections(&mut self, vector: &[f32], node_index: u32, assigned_layer: usize, prev_entry: Option<u32>, prev_highest: usize) {
60        // first node ever — nothing to connect to
61        let entry = match prev_entry {
62            Some(ep) => ep,
63            None => return,
64        };
65
66        // search using the pre-insert entry point and its highest layer so the
67        // new (connectionless) node is never used as a starting point
68        let candidates_by_layer = self.search_layers(vector, self.hnsw.max_neighbors_per_document * 2, Some((entry, prev_highest)));
69
70        for (layer, candidates) in &candidates_by_layer {
71            // only connect layers below the assigned one
72            if *layer > assigned_layer {
73                continue;
74            }
75
76            let neighbors: Vec<u32> = candidates.iter()
77                .filter(|&&n| n != node_index) // filter out the node itself
78                .take(self.hnsw.max_neighbors_per_document) // take first max neighbors per layer
79                .cloned()
80                .collect();
81
82            // wire the node to its new best neighbors
83            self.hnsw.set_neighbors(node_index, *layer, &neighbors);
84
85            // wire bidirectionally
86            for &neighbor in &neighbors {
87                let mut existing = self.hnsw.get_neighbors(neighbor, *layer);
88
89                if !existing.contains(&node_index) {
90                    existing.push(node_index);
91
92                    if existing.len() > self.hnsw.max_neighbors_per_document {
93                        // keep only the best neighbors by similarity
94                        let neighbor_uuid = &self.hnsw.index_to_id[neighbor as usize];
95                        // reobtain the neighbor vector
96                        let neighbor_emb = retrieve_record(*self.index.get(neighbor_uuid).unwrap(), &self.data_path).vector;
97                        // score it
98                        let mut scored: Vec<(f32, u32)> = existing.iter()
99                            .map(|&n| {
100                                let uuid = &self.hnsw.index_to_id[n as usize];
101                                let vec = retrieve_record(*self.index.get(uuid).unwrap(), &self.data_path).vector;
102                                (cosine_similarity(&neighbor_emb, &vec), n)
103                            })
104                            .collect();
105                        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
106                        // keep best scoring one
107                        existing = scored.iter()
108                            .take(self.hnsw.max_neighbors_per_document)
109                            .map(|&(_, n)| n)
110                            .collect();
111                    }
112
113                    self.hnsw.set_neighbors(neighbor, *layer, &existing);
114                }
115            }
116        }
117    }
118
119    // given a query get the most relevant records using the BM25 index
120    // then convert the doc ids to offsets using the index
121    // and retrieve the records from storage
122    pub fn text_search(&self, query: &str, k: usize) -> Vec<Record> {
123        let tokens = self.bm25_tokenizer.tokenize(query);
124        let scores = self.bm25_index.score(&tokens);
125
126        scores.into_iter().take(k)
127            .filter_map(|(doc_id, _)| self.index.get(&doc_id))
128            .map(|offset| retrieve_record(*offset, &self.data_path))
129            .collect()
130    }
131
132    /// Delete a record by UUID. Returns false if the ID was not found.
133    pub fn delete(&mut self, id: &str) -> bool {
134        let Some(offset) = self.index.get(id).copied() else { return false };
135
136        let record = retrieve_record(offset, &self.data_path);
137        let tokens = self.bm25_tokenizer.tokenize(record.metadata.as_deref().unwrap_or(""));
138
139        self.bm25_index.remove_record(id, &tokens);
140        self.index.remove(id);
141        save_index(&self.index_path, &self.index);
142        self.bm25_index.save(&self.bm25_index_path);
143        true
144    }
145
146    /// Delete all records and wipe every index and data file.
147    pub fn wipe(&mut self) {
148        self.index.clear();
149        self.bm25_index = bm25::index::Bm25Index::new();
150        self.hnsw.wipe();
151
152        let _ = std::fs::remove_file(&self.data_path);
153        let _ = std::fs::remove_file(&self.index_path);
154        let _ = std::fs::remove_file(&self.bm25_index_path);
155    }
156
157    /// Brute-force cosine-similarity search over all stored vectors.
158    /// Returns up to k records sorted by descending score.
159    /// Useful for recall evaluation against HNSW results.
160    pub fn search_scored(&self, query_vector: &[f32], k: usize) -> Vec<(f32, Record)> {
161        let mut results: Vec<(f32, Record)> = self.index.iter()
162            .map(|(_id, offset)| retrieve_record(*offset, &self.data_path))
163            .map(|r| {
164                let score = cosine_similarity(&r.vector, query_vector);
165                (score, r)
166            })
167            .collect();
168
169        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
170        results.truncate(k);
171        results
172    }
173
174    fn score_node(&self, node_index: u32, query_vector: &[f32]) -> Option<f32> {
175        let uuid = self.hnsw.index_to_id.get(node_index as usize)?;
176        let offset = self.index.get(uuid)?;
177        let record = retrieve_record(*offset, &self.data_path);
178        Some(cosine_similarity(&record.vector, query_vector))
179    }
180
181    /// HNSW approximate nearest-neighbour search.
182    ///
183    /// Returns records sorted by descending similarity.
184    /// `ef` is the exploration factor — higher values improve recall at the cost of speed.
185    pub fn search_hnsw(&self, query_vector: &[f32], ef: usize) -> Vec<Record> {
186        if self.hnsw.node_offsets.is_empty() {
187            return Vec::new();
188        }
189
190        let (mut current, top_layer) = match self.hnsw.entry_point {
191            Some(ep) => (ep, self.hnsw.highest_layer),
192            None => return Vec::new(),
193        };
194
195        // phase 1: greedy descent — upper layers are sparse, used to fast-forward close to the query
196        for layer in (1..=top_layer).rev() {       // walk down from top to layer 1
197            loop {
198                let current_score = self.score_node(current, query_vector).unwrap_or(f32::NEG_INFINITY); // score of where we are now
199                let best = self.hnsw.get_neighbors(current, layer).into_iter()
200                    .filter_map(|n| Some((self.score_node(n, query_vector)?, n))) // score each neighbor
201                    .filter(|(s, _)| *s > current_score)                          // keep only those better than current
202                    .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); // pick the best
203                match best {
204                    Some((_, n)) => current = n, // found a better neighbor, move there and re-check
205                    None => break,               // no neighbor beats current — local optimum, drop to next layer
206                }
207            }
208        }
209
210        // phase 2: beam search at layer 0 — explore the ef best candidates
211        // both vecs are kept sorted descending by score
212        let entrypoint_score = match self.score_node(current, query_vector) {
213            Some(s) => s,
214            None => return Vec::new(),
215        };
216        // nodes already scored — never enqueue or score twice
217        let mut visited: HashSet<u32> = HashSet::from([current]);
218        // frontier: nodes discovered but not yet expanded (neighbors not yet read)
219        // sorted descending — front is always the most promising next step
220        let mut candidates: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
221        // best ef nodes seen so far — what we return at the end
222        // sorted descending — last element is the worst, evicted when results exceeds ef
223        let mut results: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
224
225        while let Some(&(c_score, c_node)) = candidates.first() {
226            // worst score currently in our result set
227            let worst_result = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
228            // if the worst result is better than the candidate we are exploring, break
229            if c_score < worst_result {
230                break; // best remaining candidate can't improve the result set
231            }
232            candidates.remove(0);
233
234            // get current candidate neighbors
235            for neighbor in self.hnsw.get_neighbors(c_node, 0) {
236                // if we haven't seen this neighbor before
237                if visited.insert(neighbor) {
238                    // evaluate the similarity between the neighbor and the given query vector
239                    if let Some(n_score) = self.score_node(neighbor, query_vector) {
240                        let worst = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
241                        // if the score is better than the worst in the results or the results is not yet full
242                        if n_score > worst || results.len() < ef {
243                            // insert into candidates keeping order
244                            let pos = candidates.partition_point(|(s, _)| *s > n_score);
245                            candidates.insert(pos, (n_score, neighbor));
246                            // insert into results keeping order
247                            let pos = results.partition_point(|(s, _)| *s > n_score);
248                            results.insert(pos, (n_score, neighbor));
249                            // if we are over the limits, pop the last one (worst)
250                            if results.len() > ef {
251                                results.pop();
252                            }
253                        }
254                    }
255                }
256            }
257        }
258
259        results.iter()
260            .filter_map(|(_, node)| {
261                let uuid = self.hnsw.index_to_id.get(*node as usize)?;
262                let offset = self.index.get(uuid)?;
263                Some(retrieve_record(*offset, &self.data_path))
264            })
265            .collect()
266    }
267
268    // used during insertion: returns candidates per layer so make_connections can wire each one
269    fn search_layers(&self, query_vector: &[f32], candidates_per_layer: usize, entry_override: Option<(u32, usize)>) -> HashMap<usize, Vec<u32>> {
270        if self.hnsw.node_offsets.is_empty() {
271            return HashMap::new();
272        }
273
274        let (entry_node, top_layer) = match entry_override {
275            Some(pair) => pair,
276            None => match self.hnsw.entry_point {
277                Some(ep) => (ep, self.hnsw.highest_layer),
278                None => return HashMap::new(),
279            },
280        };
281
282        let mut current_candidates: Vec<u32> = vec![entry_node];
283        let mut layer_candidates: HashMap<usize, Vec<u32>> = HashMap::new();
284
285        for layer in (0..=top_layer).rev() {
286            // collect current candidates + their neighbors, deduplicated
287            let mut seen: HashSet<u32> = current_candidates.iter().cloned().collect();
288            for &candidate in &current_candidates {
289                for neighbor in self.hnsw.get_neighbors(candidate, layer) {
290                    seen.insert(neighbor);
291                }
292            }
293
294            // score all candidates by cosine similarity
295            let mut scored: Vec<(f32, u32)> = seen.iter()
296                .filter_map(|&node_index| {
297                    // foreach neighbor
298                    let uuid = &self.hnsw.index_to_id[node_index as usize];
299
300                    // retrieve the actual record
301                    let data_offset = self.index.get(uuid)?;
302                    let record = retrieve_record(*data_offset, &self.data_path);
303
304                    // evaluate the score against given vector
305                    let score = cosine_similarity(&record.vector, query_vector);
306                    Some((score, node_index))
307                })
308                .collect();
309
310            // sort highest score first
311            scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
312            scored.truncate(candidates_per_layer);
313
314            layer_candidates.insert(layer, scored.iter().map(|&(_, idx)| idx).collect());
315
316            // pass top half down to next layer
317            let next_size = (candidates_per_layer / 2).max(1);
318            current_candidates = scored.iter().take(next_size).map(|&(_, idx)| idx).collect();
319        }
320
321        layer_candidates
322    }
323
324    /// Brute-force search returning the top-k records (without scores).
325    pub fn search(&self, query_vector: &[f32], k: usize) -> Vec<Record> {
326        let mut results = Vec::new();
327
328        for (_id, offset) in &self.index {
329            let record = retrieve_record(*offset, &self.data_path);
330            let similarity = cosine_similarity(&record.vector, query_vector);
331            results.push((similarity, record));
332        }
333
334        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
335        results.into_iter().take(k).map(|(_, record)| record).collect()
336    }
337}