1use std::collections::{HashMap, HashSet};
2use crate::bm25;
3use crate::hnsw::index::HNSW;
4use crate::index::{load_index, save_index, Index};
5use crate::similarities::cosine_similarity;
6use crate::storage::{append_record, retrieve_record};
7use crate::types::Record;
8
/// Default number of inserts between automatic checkpoints
/// (the initial value of `Database::checkpoint_every`).
const DEFAULT_CHECKPOINT_EVERY: usize = 100_000;
10
11pub struct Database {
12 pub data_path: String,
13 pub index_path: String,
14 pub bm25_index_path: String,
15 pub index: Index,
16 pub bm25_index: bm25::index::Bm25Index,
17 pub bm25_tokenizer: bm25::tokenizer::Tokenizer,
18 pub hnsw: HNSW,
19 pub graph_path: String,
20 pub checkpoint_every: usize,
22 dirty_count: usize,
23}
24
25impl Database {
26 pub fn new(data_path: &str, index_path: &str, bm25_index_path: &str, hnsw_path: &str, graph_path: &str) -> Database {
27 if let Some(dir) = std::path::Path::new(data_path).parent() {
29 std::fs::create_dir_all(dir).unwrap();
30 }
31 let index = load_index(index_path);
32 let bm25_index = bm25::index::Bm25Index::load(bm25_index_path);
33 let bm25_tokenizer = bm25::tokenizer::Tokenizer::new();
34 let hnsw = HNSW::load(hnsw_path, graph_path);
35
36 Database {
37 data_path: data_path.to_string(),
38 index_path: index_path.to_string(),
39 bm25_index_path: bm25_index_path.to_string(),
40 index,
41 bm25_index,
42 bm25_tokenizer,
43 hnsw,
44 graph_path: graph_path.to_string(),
45 checkpoint_every: DEFAULT_CHECKPOINT_EVERY,
46 dirty_count: 0,
47 }
48 }
49
50 pub fn save_all(&mut self) {
52 save_index(&self.index_path, &self.index);
53 self.bm25_index.save(&self.bm25_index_path);
54 self.hnsw.save();
55 self.dirty_count = 0;
56 }
57
58 pub fn insert_raw(&mut self, vector: Vec<f32>, text: &str, id: Option<&str>) {
62 let record = Record::new(vector, Some(text.to_string()), id.map(str::to_string));
63 let tokens = self.bm25_tokenizer.tokenize(text);
64 let offset = append_record(&record, &self.data_path);
65 self.bm25_index.index_record(&record.id, &tokens);
66 self.index.insert(record.id.clone(), offset);
67 let (node_index, assigned_layer, prev_entry, prev_highest) = self.hnsw.insert(&record);
68 self.make_connections(&record.vector, node_index, assigned_layer, prev_entry, prev_highest);
69 self.dirty_count += 1;
70 if self.dirty_count >= self.checkpoint_every {
71 self.save_all();
72 }
73 }
74
75 fn make_connections(&mut self, vector: &[f32], node_index: u32, assigned_layer: usize, prev_entry: Option<u32>, prev_highest: usize) {
77 let entry = match prev_entry {
79 Some(ep) => ep,
80 None => return,
81 };
82
83 let candidates_by_layer = self.search_layers(vector, self.hnsw.max_neighbors_per_document * 2, Some((entry, prev_highest)));
86
87 for (layer, candidates) in &candidates_by_layer {
88 if *layer > assigned_layer {
90 continue;
91 }
92
93 let neighbors: Vec<u32> = candidates.iter()
94 .filter(|&&n| n != node_index) .take(self.hnsw.max_neighbors_per_document) .cloned()
97 .collect();
98
99 self.hnsw.set_neighbors(node_index, *layer, &neighbors);
101
102 for &neighbor in &neighbors {
104 let mut existing = self.hnsw.get_neighbors(neighbor, *layer);
105
106 if !existing.contains(&node_index) {
107 existing.push(node_index);
108
109 if existing.len() > self.hnsw.max_neighbors_per_document {
110 let neighbor_uuid = &self.hnsw.index_to_id[neighbor as usize];
112 let neighbor_emb = retrieve_record(*self.index.get(neighbor_uuid).unwrap(), &self.data_path).vector;
114 let mut scored: Vec<(f32, u32)> = existing.iter()
116 .map(|&n| {
117 let uuid = &self.hnsw.index_to_id[n as usize];
118 let vec = retrieve_record(*self.index.get(uuid).unwrap(), &self.data_path).vector;
119 (cosine_similarity(&neighbor_emb, &vec), n)
120 })
121 .collect();
122 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
123 existing = scored.iter()
125 .take(self.hnsw.max_neighbors_per_document)
126 .map(|&(_, n)| n)
127 .collect();
128 }
129
130 self.hnsw.set_neighbors(neighbor, *layer, &existing);
131 }
132 }
133 }
134 }
135
136 pub fn text_search(&self, query: &str, k: usize) -> Vec<Record> {
140 let tokens = self.bm25_tokenizer.tokenize(query);
141 let scores = self.bm25_index.score(&tokens);
142
143 scores.into_iter().take(k)
144 .filter_map(|(doc_id, _)| self.index.get(&doc_id))
145 .map(|offset| retrieve_record(*offset, &self.data_path))
146 .collect()
147 }
148
149 pub fn delete(&mut self, id: &str) -> bool {
151 let Some(offset) = self.index.get(id).copied() else { return false };
152
153 let record = retrieve_record(offset, &self.data_path);
154 let tokens = self.bm25_tokenizer.tokenize(record.metadata.as_deref().unwrap_or(""));
155
156 self.bm25_index.remove_record(id, &tokens);
157 self.index.remove(id);
158 save_index(&self.index_path, &self.index);
159 self.bm25_index.save(&self.bm25_index_path);
160 true
161 }
162
163 pub fn wipe(&mut self) {
165 self.index.clear();
166 self.bm25_index = bm25::index::Bm25Index::new();
167 self.hnsw.wipe();
168 self.dirty_count = 0;
169
170 let _ = std::fs::remove_file(&self.data_path);
171 let _ = std::fs::remove_file(&self.index_path);
172 let _ = std::fs::remove_file(&self.bm25_index_path);
173 }
174
175 pub fn search_scored(&self, query_vector: &[f32], k: usize) -> Vec<(f32, Record)> {
179 let mut results: Vec<(f32, Record)> = self.index.iter()
180 .map(|(_id, offset)| retrieve_record(*offset, &self.data_path))
181 .map(|r| {
182 let score = cosine_similarity(&r.vector, query_vector);
183 (score, r)
184 })
185 .collect();
186
187 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
188 results.truncate(k);
189 results
190 }
191
192 fn score_node(&self, node_index: u32, query_vector: &[f32]) -> Option<f32> {
193 let uuid = self.hnsw.index_to_id.get(node_index as usize)?;
194 let offset = self.index.get(uuid)?;
195 let record = retrieve_record(*offset, &self.data_path);
196 Some(cosine_similarity(&record.vector, query_vector))
197 }
198
199 pub fn search_hnsw(&self, query_vector: &[f32], ef: usize) -> Vec<Record> {
204 if self.hnsw.node_offsets.is_empty() {
205 return Vec::new();
206 }
207
208 let (mut current, top_layer) = match self.hnsw.entry_point {
209 Some(ep) => (ep, self.hnsw.highest_layer),
210 None => return Vec::new(),
211 };
212
213 for layer in (1..=top_layer).rev() { loop {
216 let current_score = self.score_node(current, query_vector).unwrap_or(f32::NEG_INFINITY); let best = self.hnsw.get_neighbors(current, layer).into_iter()
218 .filter_map(|n| Some((self.score_node(n, query_vector)?, n))) .filter(|(s, _)| *s > current_score) .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); match best {
222 Some((_, n)) => current = n, None => break, }
225 }
226 }
227
228 let entrypoint_score = match self.score_node(current, query_vector) {
231 Some(s) => s,
232 None => return Vec::new(),
233 };
234 let mut visited: HashSet<u32> = HashSet::from([current]);
236 let mut candidates: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
239 let mut results: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
242
243 while let Some(&(c_score, c_node)) = candidates.first() {
244 let worst_result = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
246 if c_score < worst_result {
248 break; }
250 candidates.remove(0);
251
252 for neighbor in self.hnsw.get_neighbors(c_node, 0) {
254 if visited.insert(neighbor) {
256 if let Some(n_score) = self.score_node(neighbor, query_vector) {
258 let worst = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
259 if n_score > worst || results.len() < ef {
261 let pos = candidates.partition_point(|(s, _)| *s > n_score);
263 candidates.insert(pos, (n_score, neighbor));
264 let pos = results.partition_point(|(s, _)| *s > n_score);
266 results.insert(pos, (n_score, neighbor));
267 if results.len() > ef {
269 results.pop();
270 }
271 }
272 }
273 }
274 }
275 }
276
277 results.iter()
278 .filter_map(|(_, node)| {
279 let uuid = self.hnsw.index_to_id.get(*node as usize)?;
280 let offset = self.index.get(uuid)?;
281 Some(retrieve_record(*offset, &self.data_path))
282 })
283 .collect()
284 }
285
286 fn search_layers(&self, query_vector: &[f32], candidates_per_layer: usize, entry_override: Option<(u32, usize)>) -> HashMap<usize, Vec<u32>> {
288 if self.hnsw.node_offsets.is_empty() {
289 return HashMap::new();
290 }
291
292 let (entry_node, top_layer) = match entry_override {
293 Some(pair) => pair,
294 None => match self.hnsw.entry_point {
295 Some(ep) => (ep, self.hnsw.highest_layer),
296 None => return HashMap::new(),
297 },
298 };
299
300 let mut current_candidates: Vec<u32> = vec![entry_node];
301 let mut layer_candidates: HashMap<usize, Vec<u32>> = HashMap::new();
302
303 for layer in (0..=top_layer).rev() {
304 let mut seen: HashSet<u32> = current_candidates.iter().cloned().collect();
306 for &candidate in ¤t_candidates {
307 for neighbor in self.hnsw.get_neighbors(candidate, layer) {
308 seen.insert(neighbor);
309 }
310 }
311
312 let mut scored: Vec<(f32, u32)> = seen.iter()
314 .filter_map(|&node_index| {
315 let uuid = &self.hnsw.index_to_id[node_index as usize];
317
318 let data_offset = self.index.get(uuid)?;
320 let record = retrieve_record(*data_offset, &self.data_path);
321
322 let score = cosine_similarity(&record.vector, query_vector);
324 Some((score, node_index))
325 })
326 .collect();
327
328 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
330 scored.truncate(candidates_per_layer);
331
332 layer_candidates.insert(layer, scored.iter().map(|&(_, idx)| idx).collect());
333
334 let next_size = (candidates_per_layer / 2).max(1);
336 current_candidates = scored.iter().take(next_size).map(|&(_, idx)| idx).collect();
337 }
338
339 layer_candidates
340 }
341
342 pub fn search(&self, query_vector: &[f32], k: usize) -> Vec<Record> {
344 let mut results = Vec::new();
345
346 for (_id, offset) in &self.index {
347 let record = retrieve_record(*offset, &self.data_path);
348 let similarity = cosine_similarity(&record.vector, query_vector);
349 results.push((similarity, record));
350 }
351
352 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
353 results.into_iter().take(k).map(|(_, record)| record).collect()
354 }
355}
356
357impl Drop for Database {
358 fn drop(&mut self) {
359 if self.dirty_count > 0 {
360 self.save_all();
361 }
362 }
363}