1use crate::error::{Result, RuvectorError};
7use crate::types::VectorId;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Common interface for learned index structures that map float-vector
/// keys to vector identifiers.
pub trait LearnedIndex {
    /// Predicts the approximate position of `key` in the underlying
    /// sorted data array.
    fn predict(&self, key: &[f32]) -> Result<usize>;

    /// Inserts a key/value pair into the index.
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;

    /// Exact-match lookup: returns the value stored for `key`, if any.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;

    /// Returns size and prediction-accuracy statistics for the index.
    fn stats(&self) -> IndexStats;
}
25
/// Size and prediction-accuracy statistics reported by a learned index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Number of key/value entries stored in the index.
    pub total_entries: usize,
    /// Approximate in-memory size of the models, in bytes.
    pub model_size_bytes: usize,
    /// Mean absolute prediction error, in positions, over all entries.
    pub avg_error: f32,
    /// Largest absolute prediction error, in positions, observed.
    pub max_error: usize,
}
34
/// A simple linear model: prediction = weights · input + bias.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    // One weight per input dimension.
    weights: Vec<f32>,
    bias: f32,
}
41
42impl LinearModel {
43 fn new(dimensions: usize) -> Self {
44 Self {
45 weights: vec![0.0; dimensions],
46 bias: 0.0,
47 }
48 }
49
50 fn predict(&self, input: &[f32]) -> f32 {
51 let mut result = self.bias;
52 for (w, x) in self.weights.iter().zip(input.iter()) {
53 result += w * x;
54 }
55 result.max(0.0)
56 }
57
58 fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
59 if data.is_empty() {
60 return;
61 }
62
63 let n = data.len() as f32;
65 let dim = self.weights.len();
66
67 self.weights.fill(0.0);
69 self.bias = 0.0;
70
71 let mut mean_x = vec![0.0; dim];
73 let mut mean_y = 0.0;
74
75 for (x, y) in data {
76 for (i, &val) in x.iter().enumerate() {
77 mean_x[i] += val;
78 }
79 mean_y += *y as f32;
80 }
81
82 for val in mean_x.iter_mut() {
83 *val /= n;
84 }
85 mean_y /= n;
86
87 if dim > 0 {
89 let mut numerator = 0.0;
90 let mut denominator = 0.0;
91
92 for (x, y) in data {
93 let x_diff = x[0] - mean_x[0];
94 let y_diff = *y as f32 - mean_y;
95 numerator += x_diff * y_diff;
96 denominator += x_diff * x_diff;
97 }
98
99 if denominator.abs() > 1e-10 {
100 self.weights[0] = numerator / denominator;
101 }
102 self.bias = mean_y - self.weights[0] * mean_x[0];
103 }
104 }
105}
106
/// Two-stage Recursive Model Index (RMI): a root linear model routes a
/// key to one of several leaf linear models, and the chosen leaf
/// predicts the key's position in the sorted data array.
pub struct RecursiveModelIndex {
    // Stage 1: maps a key to a leaf-model index.
    root_model: LinearModel,
    // Stage 2: each model predicts an absolute position within `data`.
    leaf_models: Vec<LinearModel>,
    // Entries sorted by the first key component (established by `build`).
    data: Vec<(Vec<f32>, VectorId)>,
    // Half-width of the linear-scan window around a predicted position.
    max_error: usize,
    // Expected key dimensionality, checked by `predict`.
    dimensions: usize,
}
121
122impl RecursiveModelIndex {
123 pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
125 let leaf_models = (0..num_leaf_models)
126 .map(|_| LinearModel::new(dimensions))
127 .collect();
128
129 Self {
130 root_model: LinearModel::new(dimensions),
131 leaf_models,
132 data: Vec::new(),
133 max_error: 100,
134 dimensions,
135 }
136 }
137
138 pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
140 if data.is_empty() {
141 return Err(RuvectorError::InvalidInput(
142 "Cannot build index from empty data".into(),
143 ));
144 }
145
146 data.sort_by(|a, b| {
148 a.0[0]
149 .partial_cmp(&b.0[0])
150 .unwrap_or(std::cmp::Ordering::Equal)
151 });
152
153 let n = data.len();
154
155 let root_training_data: Vec<(Vec<f32>, usize)> = data
157 .iter()
158 .enumerate()
159 .map(|(i, (key, _))| {
160 let leaf_idx = (i * self.leaf_models.len()) / n;
161 (key.clone(), leaf_idx)
162 })
163 .collect();
164
165 self.root_model.train_simple(&root_training_data);
166
167 let num_leaf_models = self.leaf_models.len();
169 let chunk_size = n / num_leaf_models;
170 for (i, model) in self.leaf_models.iter_mut().enumerate() {
171 let start = i * chunk_size;
172 let end = if i == num_leaf_models - 1 {
173 n
174 } else {
175 (i + 1) * chunk_size
176 };
177
178 if start < n {
179 let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
180 .iter()
181 .enumerate()
182 .map(|(j, (key, _))| (key.clone(), start + j))
183 .collect();
184
185 model.train_simple(&leaf_data);
186 }
187 }
188
189 self.data = data;
190 Ok(())
191 }
192}
193
194impl LearnedIndex for RecursiveModelIndex {
195 fn predict(&self, key: &[f32]) -> Result<usize> {
196 if key.len() != self.dimensions {
197 return Err(RuvectorError::InvalidInput(
198 "Key dimensions mismatch".into(),
199 ));
200 }
201
202 let leaf_idx = self.root_model.predict(key) as usize;
204 let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);
205
206 let pos = self.leaf_models[leaf_idx].predict(key) as usize;
208 let pos = pos.min(self.data.len().saturating_sub(1));
209
210 Ok(pos)
211 }
212
213 fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
214 self.data.push((key, value));
217 Ok(())
218 }
219
220 fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
221 if self.data.is_empty() {
222 return Ok(None);
223 }
224
225 let predicted_pos = self.predict(key)?;
226
227 let start = predicted_pos.saturating_sub(self.max_error);
229 let end = (predicted_pos + self.max_error).min(self.data.len());
230
231 for i in start..end {
232 if self.data[i].0 == key {
233 return Ok(Some(self.data[i].1.clone()));
234 }
235 }
236
237 Ok(None)
238 }
239
240 fn stats(&self) -> IndexStats {
241 let model_size = std::mem::size_of_val(&self.root_model)
242 + self.leaf_models.len() * std::mem::size_of::<LinearModel>();
243
244 let mut total_error = 0.0;
246 let mut max_error = 0;
247
248 for (i, (key, _)) in self.data.iter().enumerate() {
249 if let Ok(pred_pos) = self.predict(key) {
250 let error = (i as i32 - pred_pos as i32).abs() as usize;
251 total_error += error as f32;
252 max_error = max_error.max(error);
253 }
254 }
255
256 let avg_error = if !self.data.is_empty() {
257 total_error / self.data.len() as f32
258 } else {
259 0.0
260 };
261
262 IndexStats {
263 total_entries: self.data.len(),
264 model_size_bytes: model_size,
265 avg_error,
266 max_error,
267 }
268 }
269}
270
/// Combines a static learned index with a hash-map write buffer:
/// lookups consult the buffer first, then the learned structure. When
/// the buffer reaches `rebuild_threshold`, `rebuild` folds buffered
/// entries back into the learned index.
pub struct HybridIndex {
    // Read-optimized learned structure over the static data set.
    learned: RecursiveModelIndex,
    // Recent inserts, keyed by the bincode-encoded key vector.
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    // Buffer size at which `needs_rebuild` reports true.
    rebuild_threshold: usize,
}
280
281impl HybridIndex {
282 pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
284 Self {
285 learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
286 dynamic_buffer: HashMap::new(),
287 rebuild_threshold,
288 }
289 }
290
291 pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
293 self.learned.build(data)
294 }
295
296 pub fn needs_rebuild(&self) -> bool {
298 self.dynamic_buffer.len() >= self.rebuild_threshold
299 }
300
301 pub fn rebuild(&mut self) -> Result<()> {
303 let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();
304
305 for (key_bytes, value) in &self.dynamic_buffer {
306 let (key, _): (Vec<f32>, usize) =
307 bincode::decode_from_slice(key_bytes, bincode::config::standard())
308 .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
309 all_data.push((key, value.clone()));
310 }
311
312 self.learned.build(all_data)?;
313 self.dynamic_buffer.clear();
314 Ok(())
315 }
316
317 fn serialize_key(key: &[f32]) -> Vec<u8> {
318 bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
319 }
320}
321
322impl LearnedIndex for HybridIndex {
323 fn predict(&self, key: &[f32]) -> Result<usize> {
324 self.learned.predict(key)
325 }
326
327 fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
328 let key_bytes = Self::serialize_key(&key);
329 self.dynamic_buffer.insert(key_bytes, value);
330 Ok(())
331 }
332
333 fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
334 let key_bytes = Self::serialize_key(key);
336 if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
337 return Ok(Some(value.clone()));
338 }
339
340 self.learned.search(key)
342 }
343
344 fn stats(&self) -> IndexStats {
345 let mut stats = self.learned.stats();
346 stats.total_entries += self.dynamic_buffer.len();
347 stats
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// A trained 1-D linear fit should interpolate within the range
    /// spanned by its training targets.
    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let data = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];

        model.train_simple(&data);

        let pred = model.predict(&[1.5, 1.5]);
        assert!(pred >= 0.0 && pred <= 30.0);
    }

    /// Building over 100 monotone keys indexes every entry with a
    /// reasonable average prediction error.
    #[test]
    fn test_rmi_build() {
        let mut rmi = RecursiveModelIndex::new(2, 4);

        let data: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i as VectorId)
            })
            .collect();

        rmi.build(data).unwrap();

        let stats = rmi.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0);
    }

    /// A key stored at build time must be found by an exact search.
    #[test]
    fn test_rmi_search() {
        let mut rmi = RecursiveModelIndex::new(1, 2);
        let data = vec![(vec![0.0], 0), (vec![0.5], 1), (vec![1.0], 2)];
        rmi.build(data).unwrap();

        assert_eq!(rmi.search(&[0.5]).unwrap(), Some(1));
    }

    /// Hybrid lookups must hit both the dynamic buffer (a fresh
    /// insert) and the static learned index.
    #[test]
    fn test_hybrid_index() {
        let mut hybrid = HybridIndex::new(1, 2, 10);

        let static_data = vec![(vec![0.0], 0), (vec![1.0], 1)];
        hybrid.build_static(static_data).unwrap();
        hybrid.insert(vec![2.0], 2).unwrap();

        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some(2));
        assert_eq!(hybrid.search(&[0.0]).unwrap(), Some(0));
    }
}