1use crate::error::{Result, RuvectorError};
7use crate::types::VectorId;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Common interface for learned index structures that map
/// multi-dimensional float keys to stored vector ids.
pub trait LearnedIndex {
    /// Predicts the position in the underlying sorted data where `key`
    /// would be located.
    ///
    /// # Errors
    /// Implementations return an error when the key is malformed or the
    /// index has not been built yet.
    fn predict(&self, key: &[f32]) -> Result<usize>;

    /// Inserts a key/value pair into the index.
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;

    /// Looks up `key` and returns the associated id, or `None` if no
    /// exact match is found.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;

    /// Returns size and accuracy statistics for the index.
    fn stats(&self) -> IndexStats;
}
25
/// Size and accuracy statistics reported by a [`LearnedIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Number of key/value entries the index currently holds.
    pub total_entries: usize,
    /// Approximate in-memory size of the model parameters, in bytes.
    pub model_size_bytes: usize,
    /// Mean absolute distance between predicted and true positions.
    pub avg_error: f32,
    /// Largest observed distance between predicted and true positions.
    pub max_error: usize,
}
34
/// A simple linear model: prediction = weights · input + bias.
/// Used for both the root and the leaf stages of the RMI.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    /// One weight per input dimension.
    weights: Vec<f32>,
    /// Additive intercept term.
    bias: f32,
}
41
42impl LinearModel {
43 fn new(dimensions: usize) -> Self {
44 Self {
45 weights: vec![0.0; dimensions],
46 bias: 0.0,
47 }
48 }
49
50 fn predict(&self, input: &[f32]) -> f32 {
51 let mut result = self.bias;
52 for (w, x) in self.weights.iter().zip(input.iter()) {
53 result += w * x;
54 }
55 result.max(0.0)
56 }
57
58 fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
59 if data.is_empty() {
60 return;
61 }
62
63 let n = data.len() as f32;
65 let dim = self.weights.len();
66
67 self.weights.fill(0.0);
69 self.bias = 0.0;
70
71 let mut mean_x = vec![0.0; dim];
73 let mut mean_y = 0.0;
74
75 for (x, y) in data {
76 for (i, &val) in x.iter().enumerate() {
77 mean_x[i] += val;
78 }
79 mean_y += *y as f32;
80 }
81
82 for val in mean_x.iter_mut() {
83 *val /= n;
84 }
85 mean_y /= n;
86
87 if dim > 0 {
89 let mut numerator = 0.0;
90 let mut denominator = 0.0;
91
92 for (x, y) in data {
93 let x_diff = x[0] - mean_x[0];
94 let y_diff = *y as f32 - mean_y;
95 numerator += x_diff * y_diff;
96 denominator += x_diff * x_diff;
97 }
98
99 if denominator.abs() > 1e-10 {
100 self.weights[0] = numerator / denominator;
101 }
102 self.bias = mean_y - self.weights[0] * mean_x[0];
103 }
104 }
105}
106
/// Two-stage Recursive Model Index (RMI): a root model routes a key to
/// one of several leaf models, and the chosen leaf model predicts the
/// key's position in a sorted data array.
pub struct RecursiveModelIndex {
    /// Stage-1 model: maps a key to a leaf model index.
    root_model: LinearModel,
    /// Stage-2 models: each maps a key to a position in `data`.
    leaf_models: Vec<LinearModel>,
    /// Entries, kept sorted by the first key component after `build`.
    data: Vec<(Vec<f32>, VectorId)>,
    /// Half-width of the window scanned around a predicted position.
    max_error: usize,
    /// Expected dimensionality of every key.
    dimensions: usize,
}
121
122impl RecursiveModelIndex {
123 pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
125 let leaf_models = (0..num_leaf_models)
126 .map(|_| LinearModel::new(dimensions))
127 .collect();
128
129 Self {
130 root_model: LinearModel::new(dimensions),
131 leaf_models,
132 data: Vec::new(),
133 max_error: 100,
134 dimensions,
135 }
136 }
137
138 pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
140 if data.is_empty() {
141 return Err(RuvectorError::InvalidInput(
142 "Cannot build index from empty data".into(),
143 ));
144 }
145
146 if data[0].0.is_empty() {
147 return Err(RuvectorError::InvalidInput(
148 "Cannot build index from vectors with zero dimensions".into(),
149 ));
150 }
151
152 if self.leaf_models.is_empty() {
153 return Err(RuvectorError::InvalidInput(
154 "Cannot build index with zero leaf models".into(),
155 ));
156 }
157
158 data.sort_by(|a, b| {
160 a.0[0]
161 .partial_cmp(&b.0[0])
162 .unwrap_or(std::cmp::Ordering::Equal)
163 });
164
165 let n = data.len();
166
167 let root_training_data: Vec<(Vec<f32>, usize)> = data
169 .iter()
170 .enumerate()
171 .map(|(i, (key, _))| {
172 let leaf_idx = (i * self.leaf_models.len()) / n;
173 (key.clone(), leaf_idx)
174 })
175 .collect();
176
177 self.root_model.train_simple(&root_training_data);
178
179 let num_leaf_models = self.leaf_models.len();
181 let chunk_size = n / num_leaf_models;
182 for (i, model) in self.leaf_models.iter_mut().enumerate() {
183 let start = i * chunk_size;
184 let end = if i == num_leaf_models - 1 {
185 n
186 } else {
187 (i + 1) * chunk_size
188 };
189
190 if start < n {
191 let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
192 .iter()
193 .enumerate()
194 .map(|(j, (key, _))| (key.clone(), start + j))
195 .collect();
196
197 model.train_simple(&leaf_data);
198 }
199 }
200
201 self.data = data;
202 Ok(())
203 }
204}
205
206impl LearnedIndex for RecursiveModelIndex {
207 fn predict(&self, key: &[f32]) -> Result<usize> {
208 if key.len() != self.dimensions {
209 return Err(RuvectorError::InvalidInput(
210 "Key dimensions mismatch".into(),
211 ));
212 }
213
214 if self.leaf_models.is_empty() {
215 return Err(RuvectorError::InvalidInput(
216 "Index not built: no leaf models available".into(),
217 ));
218 }
219
220 if self.data.is_empty() {
221 return Err(RuvectorError::InvalidInput(
222 "Index not built: no data available".into(),
223 ));
224 }
225
226 let leaf_idx = self.root_model.predict(key) as usize;
228 let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);
229
230 let pos = self.leaf_models[leaf_idx].predict(key) as usize;
232 let pos = pos.min(self.data.len().saturating_sub(1));
233
234 Ok(pos)
235 }
236
237 fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
238 self.data.push((key, value));
241 Ok(())
242 }
243
244 fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
245 if self.data.is_empty() {
246 return Ok(None);
247 }
248
249 let predicted_pos = self.predict(key)?;
250
251 let start = predicted_pos.saturating_sub(self.max_error);
253 let end = (predicted_pos + self.max_error).min(self.data.len());
254
255 for i in start..end {
256 if self.data[i].0 == key {
257 return Ok(Some(self.data[i].1.clone()));
258 }
259 }
260
261 Ok(None)
262 }
263
264 fn stats(&self) -> IndexStats {
265 let model_size = std::mem::size_of_val(&self.root_model)
266 + self.leaf_models.len() * std::mem::size_of::<LinearModel>();
267
268 let mut total_error = 0.0;
270 let mut max_error = 0;
271
272 for (i, (key, _)) in self.data.iter().enumerate() {
273 if let Ok(pred_pos) = self.predict(key) {
274 let error = (i as i32 - pred_pos as i32).abs() as usize;
275 total_error += error as f32;
276 max_error = max_error.max(error);
277 }
278 }
279
280 let avg_error = if !self.data.is_empty() {
281 total_error / self.data.len() as f32
282 } else {
283 0.0
284 };
285
286 IndexStats {
287 total_entries: self.data.len(),
288 model_size_bytes: model_size,
289 avg_error,
290 max_error,
291 }
292 }
293}
294
/// Learned index plus a hash-map write buffer: reads consult the buffer
/// first, and the static index is rebuilt once the buffer grows past
/// `rebuild_threshold`.
pub struct HybridIndex {
    /// Static, read-optimized learned index.
    learned: RecursiveModelIndex,
    /// Recent inserts, keyed by the bincode-serialized key bytes.
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    /// Buffer size at which `needs_rebuild` reports true.
    rebuild_threshold: usize,
}
304
impl HybridIndex {
    /// Creates a hybrid index with an empty write buffer; call
    /// `build_static` before querying the learned part.
    pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
        Self {
            learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
            dynamic_buffer: HashMap::new(),
            rebuild_threshold,
        }
    }

    /// Builds (or rebuilds) the static learned index from `data`.
    pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        self.learned.build(data)
    }

    /// True once the write buffer has reached the rebuild threshold.
    pub fn needs_rebuild(&self) -> bool {
        self.dynamic_buffer.len() >= self.rebuild_threshold
    }

    /// Merges all buffered entries into the static index and clears the
    /// buffer.
    ///
    /// # Errors
    /// Fails when a buffered key cannot be deserialized, or when the
    /// rebuild of the learned index itself fails.
    pub fn rebuild(&mut self) -> Result<()> {
        let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();

        for (key_bytes, value) in &self.dynamic_buffer {
            // decode_from_slice yields (decoded value, bytes consumed);
            // the byte count is not needed here.
            let (key, _): (Vec<f32>, usize) =
                bincode::decode_from_slice(key_bytes, bincode::config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
            all_data.push((key, value.clone()));
        }

        self.learned.build(all_data)?;
        self.dynamic_buffer.clear();
        Ok(())
    }

    /// Serializes a key to its canonical byte form for buffer lookups.
    // NOTE(review): unwrap_or_default() maps an encode failure to an
    // empty byte string, which would make distinct failing keys collide
    // in the buffer. bincode encoding of an f32 slice into a Vec should
    // not fail in practice, but consider propagating the error instead.
    fn serialize_key(key: &[f32]) -> Vec<u8> {
        bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
    }
}
345
346impl LearnedIndex for HybridIndex {
347 fn predict(&self, key: &[f32]) -> Result<usize> {
348 self.learned.predict(key)
349 }
350
351 fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
352 let key_bytes = Self::serialize_key(&key);
353 self.dynamic_buffer.insert(key_bytes, value);
354 Ok(())
355 }
356
357 fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
358 let key_bytes = Self::serialize_key(key);
360 if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
361 return Ok(Some(value.clone()));
362 }
363
364 self.learned.search(key)
366 }
367
368 fn stats(&self) -> IndexStats {
369 let mut stats = self.learned.stats();
370 stats.total_entries += self.dynamic_buffer.len();
371 stats
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Training on three collinear points should yield a bounded
    /// prediction for an interpolated key.
    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let samples = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];
        model.train_simple(&samples);

        let prediction = model.predict(&[1.5, 1.5]);
        assert!((0.0..=30.0).contains(&prediction));
    }

    /// Building over 100 smoothly increasing keys should record every
    /// entry and keep the average prediction error reasonable.
    #[test]
    fn test_rmi_build() {
        let mut index = RecursiveModelIndex::new(2, 4);
        let entries: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i.to_string())
            })
            .collect();

        index.build(entries).unwrap();

        let stats = index.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0);
    }

    /// An exact-match lookup on a small built index should return the
    /// stored id.
    #[test]
    fn test_rmi_search() {
        let mut index = RecursiveModelIndex::new(1, 2);
        index
            .build(vec![
                (vec![0.0], "0".to_string()),
                (vec![0.5], "1".to_string()),
                (vec![1.0], "2".to_string()),
            ])
            .unwrap();

        assert_eq!(index.search(&[0.5]).unwrap(), Some("1".to_string()));
    }

    /// Lookups should see both buffered inserts and static entries.
    #[test]
    fn test_hybrid_index() {
        let mut index = HybridIndex::new(1, 2, 10);
        index
            .build_static(vec![
                (vec![0.0], "0".to_string()),
                (vec![1.0], "1".to_string()),
            ])
            .unwrap();

        index.insert(vec![2.0], "2".to_string()).unwrap();

        assert_eq!(index.search(&[2.0]).unwrap(), Some("2".to_string()));
        assert_eq!(index.search(&[0.0]).unwrap(), Some("0".to_string()));
    }
}