1use std::collections::{HashMap, HashSet};
2use crate::bm25;
3use crate::hnsw::index::HNSW;
4use crate::index::{load_index, save_index, Index};
5use crate::similarities::cosine_similarity;
6use crate::storage::{append_record, retrieve_record};
7use crate::types::Record;
8
9pub struct Database {
10 pub data_path: String,
11 pub index_path: String,
12 pub bm25_index_path: String,
13 pub index: Index,
14 pub bm25_index: bm25::index::Bm25Index,
15 pub bm25_tokenizer: bm25::tokenizer::Tokenizer,
16 pub hnsw: HNSW,
17 pub graph_path: String,
18}
19
20impl Database {
21 pub fn new(data_path: &str, index_path: &str, bm25_index_path: &str, hnsw_path: &str, graph_path: &str) -> Database {
22 if let Some(dir) = std::path::Path::new(data_path).parent() {
24 std::fs::create_dir_all(dir).unwrap();
25 }
26 let index = load_index(index_path);
27 let bm25_index = bm25::index::Bm25Index::load(bm25_index_path);
28 let bm25_tokenizer = bm25::tokenizer::Tokenizer::new();
29 let hnsw = HNSW::load(hnsw_path, graph_path);
30
31 Database {
32 data_path: data_path.to_string(),
33 index_path: index_path.to_string(),
34 bm25_index_path: bm25_index_path.to_string(),
35 index,
36 bm25_index,
37 bm25_tokenizer,
38 hnsw,
39 graph_path: graph_path.to_string(),
40 }
41 }
42
43 pub fn insert_raw(&mut self, vector: Vec<f32>, text: &str, id: Option<&str>) {
47 let record = Record::new(vector, Some(text.to_string()), id.map(str::to_string));
48 let tokens = self.bm25_tokenizer.tokenize(text);
49 let offset = append_record(&record, &self.data_path);
50 self.bm25_index.index_record(&record.id, &tokens);
51 self.index.insert(record.id.clone(), offset);
52 let (node_index, assigned_layer, prev_entry, prev_highest) = self.hnsw.insert(&record);
53 self.make_connections(&record.vector, node_index, assigned_layer, prev_entry, prev_highest);
54 save_index(&self.index_path, &self.index);
55 self.bm25_index.save(&self.bm25_index_path);
56 }
57
58 fn make_connections(&mut self, vector: &[f32], node_index: u32, assigned_layer: usize, prev_entry: Option<u32>, prev_highest: usize) {
60 let entry = match prev_entry {
62 Some(ep) => ep,
63 None => return,
64 };
65
66 let candidates_by_layer = self.search_layers(vector, self.hnsw.max_neighbors_per_document * 2, Some((entry, prev_highest)));
69
70 for (layer, candidates) in &candidates_by_layer {
71 if *layer > assigned_layer {
73 continue;
74 }
75
76 let neighbors: Vec<u32> = candidates.iter()
77 .filter(|&&n| n != node_index) .take(self.hnsw.max_neighbors_per_document) .cloned()
80 .collect();
81
82 self.hnsw.set_neighbors(node_index, *layer, &neighbors);
84
85 for &neighbor in &neighbors {
87 let mut existing = self.hnsw.get_neighbors(neighbor, *layer);
88
89 if !existing.contains(&node_index) {
90 existing.push(node_index);
91
92 if existing.len() > self.hnsw.max_neighbors_per_document {
93 let neighbor_uuid = &self.hnsw.index_to_id[neighbor as usize];
95 let neighbor_emb = retrieve_record(*self.index.get(neighbor_uuid).unwrap(), &self.data_path).vector;
97 let mut scored: Vec<(f32, u32)> = existing.iter()
99 .map(|&n| {
100 let uuid = &self.hnsw.index_to_id[n as usize];
101 let vec = retrieve_record(*self.index.get(uuid).unwrap(), &self.data_path).vector;
102 (cosine_similarity(&neighbor_emb, &vec), n)
103 })
104 .collect();
105 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
106 existing = scored.iter()
108 .take(self.hnsw.max_neighbors_per_document)
109 .map(|&(_, n)| n)
110 .collect();
111 }
112
113 self.hnsw.set_neighbors(neighbor, *layer, &existing);
114 }
115 }
116 }
117 }
118
119 pub fn text_search(&self, query: &str, k: usize) -> Vec<Record> {
123 let tokens = self.bm25_tokenizer.tokenize(query);
124 let scores = self.bm25_index.score(&tokens);
125
126 scores.into_iter().take(k)
127 .filter_map(|(doc_id, _)| self.index.get(&doc_id))
128 .map(|offset| retrieve_record(*offset, &self.data_path))
129 .collect()
130 }
131
132 pub fn delete(&mut self, id: &str) -> bool {
134 let Some(offset) = self.index.get(id).copied() else { return false };
135
136 let record = retrieve_record(offset, &self.data_path);
137 let tokens = self.bm25_tokenizer.tokenize(record.metadata.as_deref().unwrap_or(""));
138
139 self.bm25_index.remove_record(id, &tokens);
140 self.index.remove(id);
141 save_index(&self.index_path, &self.index);
142 self.bm25_index.save(&self.bm25_index_path);
143 true
144 }
145
146 pub fn wipe(&mut self) {
148 self.index.clear();
149 self.bm25_index = bm25::index::Bm25Index::new();
150 self.hnsw.wipe();
151
152 let _ = std::fs::remove_file(&self.data_path);
153 let _ = std::fs::remove_file(&self.index_path);
154 let _ = std::fs::remove_file(&self.bm25_index_path);
155 }
156
157 pub fn search_scored(&self, query_vector: &[f32], k: usize) -> Vec<(f32, Record)> {
161 let mut results: Vec<(f32, Record)> = self.index.iter()
162 .map(|(_id, offset)| retrieve_record(*offset, &self.data_path))
163 .map(|r| {
164 let score = cosine_similarity(&r.vector, query_vector);
165 (score, r)
166 })
167 .collect();
168
169 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
170 results.truncate(k);
171 results
172 }
173
174 fn score_node(&self, node_index: u32, query_vector: &[f32]) -> Option<f32> {
175 let uuid = self.hnsw.index_to_id.get(node_index as usize)?;
176 let offset = self.index.get(uuid)?;
177 let record = retrieve_record(*offset, &self.data_path);
178 Some(cosine_similarity(&record.vector, query_vector))
179 }
180
181 pub fn search_hnsw(&self, query_vector: &[f32], ef: usize) -> Vec<Record> {
186 if self.hnsw.node_offsets.is_empty() {
187 return Vec::new();
188 }
189
190 let (mut current, top_layer) = match self.hnsw.entry_point {
191 Some(ep) => (ep, self.hnsw.highest_layer),
192 None => return Vec::new(),
193 };
194
195 for layer in (1..=top_layer).rev() { loop {
198 let current_score = self.score_node(current, query_vector).unwrap_or(f32::NEG_INFINITY); let best = self.hnsw.get_neighbors(current, layer).into_iter()
200 .filter_map(|n| Some((self.score_node(n, query_vector)?, n))) .filter(|(s, _)| *s > current_score) .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)); match best {
204 Some((_, n)) => current = n, None => break, }
207 }
208 }
209
210 let entrypoint_score = match self.score_node(current, query_vector) {
213 Some(s) => s,
214 None => return Vec::new(),
215 };
216 let mut visited: HashSet<u32> = HashSet::from([current]);
218 let mut candidates: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
221 let mut results: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
224
225 while let Some(&(c_score, c_node)) = candidates.first() {
226 let worst_result = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
228 if c_score < worst_result {
230 break; }
232 candidates.remove(0);
233
234 for neighbor in self.hnsw.get_neighbors(c_node, 0) {
236 if visited.insert(neighbor) {
238 if let Some(n_score) = self.score_node(neighbor, query_vector) {
240 let worst = results.last().map(|(s, _)| *s).unwrap_or(f32::NEG_INFINITY);
241 if n_score > worst || results.len() < ef {
243 let pos = candidates.partition_point(|(s, _)| *s > n_score);
245 candidates.insert(pos, (n_score, neighbor));
246 let pos = results.partition_point(|(s, _)| *s > n_score);
248 results.insert(pos, (n_score, neighbor));
249 if results.len() > ef {
251 results.pop();
252 }
253 }
254 }
255 }
256 }
257 }
258
259 results.iter()
260 .filter_map(|(_, node)| {
261 let uuid = self.hnsw.index_to_id.get(*node as usize)?;
262 let offset = self.index.get(uuid)?;
263 Some(retrieve_record(*offset, &self.data_path))
264 })
265 .collect()
266 }
267
268 fn search_layers(&self, query_vector: &[f32], candidates_per_layer: usize, entry_override: Option<(u32, usize)>) -> HashMap<usize, Vec<u32>> {
270 if self.hnsw.node_offsets.is_empty() {
271 return HashMap::new();
272 }
273
274 let (entry_node, top_layer) = match entry_override {
275 Some(pair) => pair,
276 None => match self.hnsw.entry_point {
277 Some(ep) => (ep, self.hnsw.highest_layer),
278 None => return HashMap::new(),
279 },
280 };
281
282 let mut current_candidates: Vec<u32> = vec![entry_node];
283 let mut layer_candidates: HashMap<usize, Vec<u32>> = HashMap::new();
284
285 for layer in (0..=top_layer).rev() {
286 let mut seen: HashSet<u32> = current_candidates.iter().cloned().collect();
288 for &candidate in ¤t_candidates {
289 for neighbor in self.hnsw.get_neighbors(candidate, layer) {
290 seen.insert(neighbor);
291 }
292 }
293
294 let mut scored: Vec<(f32, u32)> = seen.iter()
296 .filter_map(|&node_index| {
297 let uuid = &self.hnsw.index_to_id[node_index as usize];
299
300 let data_offset = self.index.get(uuid)?;
302 let record = retrieve_record(*data_offset, &self.data_path);
303
304 let score = cosine_similarity(&record.vector, query_vector);
306 Some((score, node_index))
307 })
308 .collect();
309
310 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
312 scored.truncate(candidates_per_layer);
313
314 layer_candidates.insert(layer, scored.iter().map(|&(_, idx)| idx).collect());
315
316 let next_size = (candidates_per_layer / 2).max(1);
318 current_candidates = scored.iter().take(next_size).map(|&(_, idx)| idx).collect();
319 }
320
321 layer_candidates
322 }
323
324 pub fn search(&self, query_vector: &[f32], k: usize) -> Vec<Record> {
326 let mut results = Vec::new();
327
328 for (_id, offset) in &self.index {
329 let record = retrieve_record(*offset, &self.data_path);
330 let similarity = cosine_similarity(&record.vector, query_vector);
331 results.push((similarity, record));
332 }
333
334 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
335 results.into_iter().take(k).map(|(_, record)| record).collect()
336 }
337}