// ruvector_core/advanced/learned_index.rs

//! # Learned Index Structures
//!
//! Experimental learned indexes using neural networks to approximate data distribution.
//! Based on Recursive Model Index (RMI) concept with bounded error correction.

6use crate::error::{Result, RuvectorError};
7use crate::types::VectorId;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
/// Trait for learned index structures.
///
/// A learned index approximates the cumulative distribution of the keys with
/// a model and uses the model's prediction as the starting point for lookups.
pub trait LearnedIndex {
    /// Predict the approximate position of `key` in the underlying sorted data.
    fn predict(&self, key: &[f32]) -> Result<usize>;

    /// Insert a key-value pair.
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;

    /// Search for an exact key, returning its value if present.
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;

    /// Get index statistics.
    fn stats(&self) -> IndexStats;
}

/// Statistics for learned indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Number of entries currently indexed.
    pub total_entries: usize,
    /// Approximate in-memory size of the models (not the data), in bytes.
    pub model_size_bytes: usize,
    /// Mean absolute difference between predicted and actual positions.
    pub avg_error: f32,
    /// Largest observed prediction error, in positions.
    pub max_error: usize,
}

/// Simple linear model for CDF approximation.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    /// One weight per key dimension (only the first is trained).
    weights: Vec<f32>,
    /// Intercept term.
    bias: f32,
}

42impl LinearModel {
43    fn new(dimensions: usize) -> Self {
44        Self {
45            weights: vec![0.0; dimensions],
46            bias: 0.0,
47        }
48    }
49
50    fn predict(&self, input: &[f32]) -> f32 {
51        let mut result = self.bias;
52        for (w, x) in self.weights.iter().zip(input.iter()) {
53            result += w * x;
54        }
55        result.max(0.0)
56    }
57
58    fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
59        if data.is_empty() {
60            return;
61        }
62
63        // Simple least squares approximation
64        let n = data.len() as f32;
65        let dim = self.weights.len();
66
67        // Reset weights
68        self.weights.fill(0.0);
69        self.bias = 0.0;
70
71        // Compute means
72        let mut mean_x = vec![0.0; dim];
73        let mut mean_y = 0.0;
74
75        for (x, y) in data {
76            for (i, &val) in x.iter().enumerate() {
77                mean_x[i] += val;
78            }
79            mean_y += *y as f32;
80        }
81
82        for val in mean_x.iter_mut() {
83            *val /= n;
84        }
85        mean_y /= n;
86
87        // Simple linear regression for first dimension
88        if dim > 0 {
89            let mut numerator = 0.0;
90            let mut denominator = 0.0;
91
92            for (x, y) in data {
93                let x_diff = x[0] - mean_x[0];
94                let y_diff = *y as f32 - mean_y;
95                numerator += x_diff * y_diff;
96                denominator += x_diff * x_diff;
97            }
98
99            if denominator.abs() > 1e-10 {
100                self.weights[0] = numerator / denominator;
101            }
102            self.bias = mean_y - self.weights[0] * mean_x[0];
103        }
104    }
105}
106
/// Recursive Model Index (RMI).
/// Multi-stage models making coarse-then-fine position predictions.
pub struct RecursiveModelIndex {
    /// Root model for coarse prediction (key -> leaf model index).
    root_model: LinearModel,
    /// Second-level models for fine prediction (key -> data position).
    leaf_models: Vec<LinearModel>,
    /// Data sorted by first key component; positions are indices here.
    data: Vec<(Vec<f32>, VectorId)>,
    /// Error bound: search scans this many positions around a prediction.
    max_error: usize,
    /// Dimensionality expected of every key vector.
    dimensions: usize,
}

122impl RecursiveModelIndex {
123    /// Create a new RMI with specified number of leaf models
124    pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
125        let leaf_models = (0..num_leaf_models)
126            .map(|_| LinearModel::new(dimensions))
127            .collect();
128
129        Self {
130            root_model: LinearModel::new(dimensions),
131            leaf_models,
132            data: Vec::new(),
133            max_error: 100,
134            dimensions,
135        }
136    }
137
138    /// Build the index from data
139    pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
140        if data.is_empty() {
141            return Err(RuvectorError::InvalidInput(
142                "Cannot build index from empty data".into(),
143            ));
144        }
145
146        // Sort data by first dimension (simple heuristic)
147        data.sort_by(|a, b| {
148            a.0[0]
149                .partial_cmp(&b.0[0])
150                .unwrap_or(std::cmp::Ordering::Equal)
151        });
152
153        let n = data.len();
154
155        // Train root model to predict leaf model index
156        let root_training_data: Vec<(Vec<f32>, usize)> = data
157            .iter()
158            .enumerate()
159            .map(|(i, (key, _))| {
160                let leaf_idx = (i * self.leaf_models.len()) / n;
161                (key.clone(), leaf_idx)
162            })
163            .collect();
164
165        self.root_model.train_simple(&root_training_data);
166
167        // Train each leaf model
168        let num_leaf_models = self.leaf_models.len();
169        let chunk_size = n / num_leaf_models;
170        for (i, model) in self.leaf_models.iter_mut().enumerate() {
171            let start = i * chunk_size;
172            let end = if i == num_leaf_models - 1 {
173                n
174            } else {
175                (i + 1) * chunk_size
176            };
177
178            if start < n {
179                let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
180                    .iter()
181                    .enumerate()
182                    .map(|(j, (key, _))| (key.clone(), start + j))
183                    .collect();
184
185                model.train_simple(&leaf_data);
186            }
187        }
188
189        self.data = data;
190        Ok(())
191    }
192}
193
194impl LearnedIndex for RecursiveModelIndex {
195    fn predict(&self, key: &[f32]) -> Result<usize> {
196        if key.len() != self.dimensions {
197            return Err(RuvectorError::InvalidInput(
198                "Key dimensions mismatch".into(),
199            ));
200        }
201
202        // Root model predicts leaf model
203        let leaf_idx = self.root_model.predict(key) as usize;
204        let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);
205
206        // Leaf model predicts position
207        let pos = self.leaf_models[leaf_idx].predict(key) as usize;
208        let pos = pos.min(self.data.len().saturating_sub(1));
209
210        Ok(pos)
211    }
212
213    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
214        // For simplicity, append and mark for rebuild
215        // Production implementation would use incremental updates
216        self.data.push((key, value));
217        Ok(())
218    }
219
220    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
221        if self.data.is_empty() {
222            return Ok(None);
223        }
224
225        let predicted_pos = self.predict(key)?;
226
227        // Binary search around predicted position with error bound
228        let start = predicted_pos.saturating_sub(self.max_error);
229        let end = (predicted_pos + self.max_error).min(self.data.len());
230
231        for i in start..end {
232            if self.data[i].0 == key {
233                return Ok(Some(self.data[i].1.clone()));
234            }
235        }
236
237        Ok(None)
238    }
239
240    fn stats(&self) -> IndexStats {
241        let model_size = std::mem::size_of_val(&self.root_model)
242            + self.leaf_models.len() * std::mem::size_of::<LinearModel>();
243
244        // Compute average prediction error
245        let mut total_error = 0.0;
246        let mut max_error = 0;
247
248        for (i, (key, _)) in self.data.iter().enumerate() {
249            if let Ok(pred_pos) = self.predict(key) {
250                let error = (i as i32 - pred_pos as i32).abs() as usize;
251                total_error += error as f32;
252                max_error = max_error.max(error);
253            }
254        }
255
256        let avg_error = if !self.data.is_empty() {
257            total_error / self.data.len() as f32
258        } else {
259            0.0
260        };
261
262        IndexStats {
263            total_entries: self.data.len(),
264            model_size_bytes: model_size,
265            avg_error,
266            max_error,
267        }
268    }
269}
270
/// Hybrid index combining a learned index for static data with a buffered
/// map for dynamic updates (rebuilt into the learned segment periodically).
pub struct HybridIndex {
    /// Learned index over the static segment.
    learned: RecursiveModelIndex,
    /// Dynamic updates buffer, keyed by the bincode encoding of the key vector.
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    /// Buffer size at which `needs_rebuild` reports true.
    rebuild_threshold: usize,
}

281impl HybridIndex {
282    /// Create a new hybrid index
283    pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
284        Self {
285            learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
286            dynamic_buffer: HashMap::new(),
287            rebuild_threshold,
288        }
289    }
290
291    /// Build the learned portion from static data
292    pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
293        self.learned.build(data)
294    }
295
296    /// Check if rebuild is needed
297    pub fn needs_rebuild(&self) -> bool {
298        self.dynamic_buffer.len() >= self.rebuild_threshold
299    }
300
301    /// Rebuild learned index incorporating dynamic updates
302    pub fn rebuild(&mut self) -> Result<()> {
303        let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();
304
305        for (key_bytes, value) in &self.dynamic_buffer {
306            let (key, _): (Vec<f32>, usize) =
307                bincode::decode_from_slice(key_bytes, bincode::config::standard())
308                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
309            all_data.push((key, value.clone()));
310        }
311
312        self.learned.build(all_data)?;
313        self.dynamic_buffer.clear();
314        Ok(())
315    }
316
317    fn serialize_key(key: &[f32]) -> Vec<u8> {
318        bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
319    }
320}
321
322impl LearnedIndex for HybridIndex {
323    fn predict(&self, key: &[f32]) -> Result<usize> {
324        self.learned.predict(key)
325    }
326
327    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
328        let key_bytes = Self::serialize_key(&key);
329        self.dynamic_buffer.insert(key_bytes, value);
330        Ok(())
331    }
332
333    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
334        // Check dynamic buffer first
335        let key_bytes = Self::serialize_key(key);
336        if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
337            return Ok(Some(value.clone()));
338        }
339
340        // Fall back to learned index
341        self.learned.search(key)
342    }
343
344    fn stats(&self) -> IndexStats {
345        let mut stats = self.learned.stats();
346        stats.total_entries += self.dynamic_buffer.len();
347        stats
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    // A linear model trained on a monotone 2-D dataset should predict within
    // a loose, sane range for an interpolated point.
    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let data = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];

        model.train_simple(&data);

        let pred = model.predict(&[1.5, 1.5]);
        assert!(pred >= 0.0 && pred <= 30.0);
    }

    // Building over 100 evenly spaced keys should index them all with a
    // bounded average prediction error.
    #[test]
    fn test_rmi_build() {
        let mut rmi = RecursiveModelIndex::new(2, 4);

        let data: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i as VectorId)
            })
            .collect();

        rmi.build(data).unwrap();

        let stats = rmi.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0); // Should have reasonable error
    }

    // Exact-match search must find a key that was part of the build set.
    #[test]
    fn test_rmi_search() {
        let mut rmi = RecursiveModelIndex::new(1, 2);

        let data = vec![(vec![0.0], 0), (vec![0.5], 1), (vec![1.0], 2)];

        rmi.build(data).unwrap();

        let result = rmi.search(&[0.5]).unwrap();
        assert_eq!(result, Some(1));
    }

    // Dynamic inserts must be visible via search before any rebuild, and
    // static data must remain reachable.
    #[test]
    fn test_hybrid_index() {
        let mut hybrid = HybridIndex::new(1, 2, 10);

        let static_data = vec![(vec![0.0], 0), (vec![1.0], 1)];
        hybrid.build_static(static_data).unwrap();

        // Add dynamic updates
        hybrid.insert(vec![2.0], 2).unwrap();

        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some(2));
        assert_eq!(hybrid.search(&[0.0]).unwrap(), Some(0));
    }
}