ruvector_core/advanced/learned_index.rs

//! # Learned Index Structures
//!
//! Experimental learned indexes that use learned models (simple linear models
//! here) to approximate the data distribution, based on the Recursive Model
//! Index (RMI) concept with bounded error correction.
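//!
//! A minimal usage sketch (the module path is assumed from the file
//! location, and `VectorId` is assumed to be `String`, as in the tests
//! below):
//!
//! ```ignore
//! use ruvector_core::advanced::learned_index::{LearnedIndex, RecursiveModelIndex};
//!
//! let mut rmi = RecursiveModelIndex::new(2, 4);
//! rmi.build(vec![
//!     (vec![0.1, 0.2], "a".to_string()),
//!     (vec![0.3, 0.4], "b".to_string()),
//! ])?;
//! assert_eq!(rmi.search(&[0.1, 0.2])?, Some("a".to_string()));
//! ```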

use crate::error::{Result, RuvectorError};
use crate::types::VectorId;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Trait for learned index structures
pub trait LearnedIndex {
    /// Predict the approximate position of a key in the index's sorted data
    fn predict(&self, key: &[f32]) -> Result<usize>;

    /// Insert a key-value pair
    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()>;

    /// Search for a key and return its value if present
    fn search(&self, key: &[f32]) -> Result<Option<VectorId>>;

    /// Get index statistics
    fn stats(&self) -> IndexStats;
}

/// Statistics for learned indexes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    pub total_entries: usize,
    pub model_size_bytes: usize,
    /// Mean absolute position-prediction error, in entries
    pub avg_error: f32,
    /// Largest observed position-prediction error, in entries
    pub max_error: usize,
}

/// Simple linear model for CDF approximation
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LinearModel {
    weights: Vec<f32>,
    bias: f32,
}

impl LinearModel {
    fn new(dimensions: usize) -> Self {
        Self {
            weights: vec![0.0; dimensions],
            bias: 0.0,
        }
    }

    fn predict(&self, input: &[f32]) -> f32 {
        let mut result = self.bias;
        for (w, x) in self.weights.iter().zip(input.iter()) {
            result += w * x;
        }
        // Clamp to zero: predictions are used as array positions
        result.max(0.0)
    }

    fn train_simple(&mut self, data: &[(Vec<f32>, usize)]) {
        if data.is_empty() {
            return;
        }

        // One-dimensional least-squares fit: only the first coordinate
        // (the sort key) is regressed against position
        let n = data.len() as f32;
        let dim = self.weights.len();

        // Reset weights
        self.weights.fill(0.0);
        self.bias = 0.0;

        // Compute means (only mean_x[0] is used by the fit below)
        let mut mean_x = vec![0.0; dim];
        let mut mean_y = 0.0;

        for (x, y) in data {
            for (i, &val) in x.iter().enumerate() {
                mean_x[i] += val;
            }
            mean_y += *y as f32;
        }

        for val in mean_x.iter_mut() {
            *val /= n;
        }
        mean_y /= n;

        // Simple linear regression on the first dimension
        if dim > 0 {
            let mut numerator = 0.0;
            let mut denominator = 0.0;

            for (x, y) in data {
                let x_diff = x[0] - mean_x[0];
                let y_diff = *y as f32 - mean_y;
                numerator += x_diff * y_diff;
                denominator += x_diff * x_diff;
            }

            if denominator.abs() > 1e-10 {
                self.weights[0] = numerator / denominator;
            }
            self.bias = mean_y - self.weights[0] * mean_x[0];
        }
    }
}
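
// Worked example of the closed-form fit above (a sketch, not part of the
// API): for points (x, y) = (0, 0), (1, 10), (2, 20) the means are x̄ = 1,
// ȳ = 10; the slope is Σ(x - x̄)(y - ȳ) / Σ(x - x̄)² = 20 / 2 = 10 and the
// bias is ȳ - slope · x̄ = 0, so predict(&[1.5]) returns 15.0 (compare
// `test_linear_model` below, which fits the same shape of data).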

/// Recursive Model Index (RMI)
///
/// Two stages of models make a coarse prediction and then a fine one: the
/// root model picks a leaf model, and that leaf model predicts a position
/// in the sorted data.
pub struct RecursiveModelIndex {
    /// Root model for coarse prediction (chooses a leaf model)
    root_model: LinearModel,
    /// Second-level models for fine position prediction
    leaf_models: Vec<LinearModel>,
    /// Data sorted by the first coordinate of each key
    data: Vec<(Vec<f32>, VectorId)>,
    /// Error bound for the corrective scan around a predicted position
    max_error: usize,
    /// Dimensionality of the indexed vectors
    dimensions: usize,
}

impl RecursiveModelIndex {
    /// Create a new RMI with the specified number of leaf models
    pub fn new(dimensions: usize, num_leaf_models: usize) -> Self {
        let leaf_models = (0..num_leaf_models)
            .map(|_| LinearModel::new(dimensions))
            .collect();

        Self {
            root_model: LinearModel::new(dimensions),
            leaf_models,
            data: Vec::new(),
            max_error: 100,
            dimensions,
        }
    }

    /// Build the index from data
    pub fn build(&mut self, mut data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        if data.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index from empty data".into(),
            ));
        }

        if data[0].0.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index from vectors with zero dimensions".into(),
            ));
        }

        if self.leaf_models.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Cannot build index with zero leaf models".into(),
            ));
        }

        // Sort data by the first dimension (simple heuristic; NaNs compare equal)
        data.sort_by(|a, b| {
            a.0[0]
                .partial_cmp(&b.0[0])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let n = data.len();

        // Train the root model to predict which leaf model owns each key
        let root_training_data: Vec<(Vec<f32>, usize)> = data
            .iter()
            .enumerate()
            .map(|(i, (key, _))| {
                let leaf_idx = (i * self.leaf_models.len()) / n;
                (key.clone(), leaf_idx)
            })
            .collect();

        self.root_model.train_simple(&root_training_data);

        // Train each leaf model on its contiguous chunk of the sorted data.
        // If n < num_leaf_models, chunk_size is 0 and the last model absorbs
        // the full range.
        let num_leaf_models = self.leaf_models.len();
        let chunk_size = n / num_leaf_models;
        for (i, model) in self.leaf_models.iter_mut().enumerate() {
            let start = i * chunk_size;
            let end = if i == num_leaf_models - 1 {
                n
            } else {
                (i + 1) * chunk_size
            };

            if start < n {
                // Leaf models are trained to predict absolute positions
                let leaf_data: Vec<(Vec<f32>, usize)> = data[start..end.min(n)]
                    .iter()
                    .enumerate()
                    .map(|(j, (key, _))| (key.clone(), start + j))
                    .collect();

                model.train_simple(&leaf_data);
            }
        }

        self.data = data;
        Ok(())
    }
}
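
// Build sketch: with n = 100 entries and 4 leaf models, entry i trains the
// root toward leaf index i * 4 / 100 (so 0..=3), and leaf k is fitted on
// the 25-entry chunk data[25k .. 25(k + 1)], learning absolute positions.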

impl LearnedIndex for RecursiveModelIndex {
    fn predict(&self, key: &[f32]) -> Result<usize> {
        if key.len() != self.dimensions {
            return Err(RuvectorError::InvalidInput(
                "Key dimensions mismatch".into(),
            ));
        }

        if self.leaf_models.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no leaf models available".into(),
            ));
        }

        if self.data.is_empty() {
            return Err(RuvectorError::InvalidInput(
                "Index not built: no data available".into(),
            ));
        }

        // Root model predicts which leaf model to use
        let leaf_idx = self.root_model.predict(key) as usize;
        let leaf_idx = leaf_idx.min(self.leaf_models.len() - 1);

        // Leaf model predicts the position, clamped to valid bounds
        let pos = self.leaf_models[leaf_idx].predict(key) as usize;
        let pos = pos.min(self.data.len().saturating_sub(1));

        Ok(pos)
    }

    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        // For simplicity, append without re-sorting or retraining: the entry
        // is only reliably findable after the next `build`. A production
        // implementation would use incremental updates.
        self.data.push((key, value));
        Ok(())
    }

    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        if self.data.is_empty() {
            return Ok(None);
        }

        let predicted_pos = self.predict(key)?;

        // Linear scan around the predicted position, bounded by max_error
        let start = predicted_pos.saturating_sub(self.max_error);
        let end = (predicted_pos + self.max_error).min(self.data.len());

        for i in start..end {
            if self.data[i].0 == key {
                return Ok(Some(self.data[i].1.clone()));
            }
        }

        Ok(None)
    }

    fn stats(&self) -> IndexStats {
        // Approximate model footprint, including heap-allocated weight vectors
        let weights_bytes = self.dimensions * std::mem::size_of::<f32>();
        let model_size = (1 + self.leaf_models.len())
            * (std::mem::size_of::<LinearModel>() + weights_bytes);

        // Compute average and maximum prediction error over all entries
        let mut total_error = 0.0;
        let mut max_error = 0;

        for (i, (key, _)) in self.data.iter().enumerate() {
            if let Ok(pred_pos) = self.predict(key) {
                let error = i.abs_diff(pred_pos);
                total_error += error as f32;
                max_error = max_error.max(error);
            }
        }

        let avg_error = if !self.data.is_empty() {
            total_error / self.data.len() as f32
        } else {
            0.0
        };

        IndexStats {
            total_entries: self.data.len(),
            model_size_bytes: model_size,
            avg_error,
            max_error,
        }
    }
}
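
// Search sketch: with the default max_error of 100, a lookup scans at most
// 200 entries around the predicted position; e.g. predicted_pos = 250 scans
// data[150..350.min(data.len())] for an exact key match.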

/// Hybrid index combining a learned index for static data with a hash-map
/// buffer for dynamic updates
pub struct HybridIndex {
    /// Learned index for the static segment
    learned: RecursiveModelIndex,
    /// Buffer of dynamic updates, keyed by the serialized vector
    dynamic_buffer: HashMap<Vec<u8>, VectorId>,
    /// Number of buffered updates that triggers a rebuild
    rebuild_threshold: usize,
}

impl HybridIndex {
    /// Create a new hybrid index
    pub fn new(dimensions: usize, num_leaf_models: usize, rebuild_threshold: usize) -> Self {
        Self {
            learned: RecursiveModelIndex::new(dimensions, num_leaf_models),
            dynamic_buffer: HashMap::new(),
            rebuild_threshold,
        }
    }

    /// Build the learned portion from static data
    pub fn build_static(&mut self, data: Vec<(Vec<f32>, VectorId)>) -> Result<()> {
        self.learned.build(data)
    }

    /// Check if a rebuild is needed
    pub fn needs_rebuild(&self) -> bool {
        self.dynamic_buffer.len() >= self.rebuild_threshold
    }

    /// Rebuild the learned index, incorporating dynamic updates
    pub fn rebuild(&mut self) -> Result<()> {
        let mut all_data: Vec<(Vec<f32>, VectorId)> = self.learned.data.clone();

        for (key_bytes, value) in &self.dynamic_buffer {
            // decode_from_slice returns (value, bytes_read); the byte count
            // is not needed here
            let (key, _): (Vec<f32>, usize) =
                bincode::decode_from_slice(key_bytes, bincode::config::standard())
                    .map_err(|e| RuvectorError::SerializationError(e.to_string()))?;
            all_data.push((key, value.clone()));
        }

        self.learned.build(all_data)?;
        self.dynamic_buffer.clear();
        Ok(())
    }

    fn serialize_key(key: &[f32]) -> Vec<u8> {
        // Encoding a float slice should not fail; fall back to an empty key
        bincode::encode_to_vec(key, bincode::config::standard()).unwrap_or_default()
    }
}
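
// Typical flow (sketch): inserts land in `dynamic_buffer`; once
// `needs_rebuild()` reports that the threshold has been reached, calling
// `rebuild()` folds the buffered entries into the learned index and clears
// the buffer (exercised in `test_hybrid_rebuild` below).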

impl LearnedIndex for HybridIndex {
    fn predict(&self, key: &[f32]) -> Result<usize> {
        self.learned.predict(key)
    }

    fn insert(&mut self, key: Vec<f32>, value: VectorId) -> Result<()> {
        let key_bytes = Self::serialize_key(&key);
        self.dynamic_buffer.insert(key_bytes, value);
        Ok(())
    }

    fn search(&self, key: &[f32]) -> Result<Option<VectorId>> {
        // Check dynamic buffer first
        let key_bytes = Self::serialize_key(key);
        if let Some(value) = self.dynamic_buffer.get(&key_bytes) {
            return Ok(Some(value.clone()));
        }

        // Fall back to learned index
        self.learned.search(key)
    }

    fn stats(&self) -> IndexStats {
        let mut stats = self.learned.stats();
        stats.total_entries += self.dynamic_buffer.len();
        stats
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_model() {
        let mut model = LinearModel::new(2);
        let data = vec![
            (vec![0.0, 0.0], 0),
            (vec![1.0, 1.0], 10),
            (vec![2.0, 2.0], 20),
        ];

        model.train_simple(&data);

        let pred = model.predict(&[1.5, 1.5]);
        assert!(pred >= 0.0 && pred <= 30.0);
    }

    #[test]
    fn test_rmi_build() {
        let mut rmi = RecursiveModelIndex::new(2, 4);

        let data: Vec<(Vec<f32>, VectorId)> = (0..100)
            .map(|i| {
                let x = i as f32 / 100.0;
                (vec![x, x * x], i.to_string())
            })
            .collect();

        rmi.build(data).unwrap();

        let stats = rmi.stats();
        assert_eq!(stats.total_entries, 100);
        assert!(stats.avg_error < 50.0); // Should have reasonable error
    }

    #[test]
    fn test_rmi_search() {
        let mut rmi = RecursiveModelIndex::new(1, 2);

        let data = vec![
            (vec![0.0], "0".to_string()),
            (vec![0.5], "1".to_string()),
            (vec![1.0], "2".to_string()),
        ];

        rmi.build(data).unwrap();

        let result = rmi.search(&[0.5]).unwrap();
        assert_eq!(result, Some("1".to_string()));
    }

    #[test]
    fn test_hybrid_index() {
        let mut hybrid = HybridIndex::new(1, 2, 10);

        let static_data = vec![
            (vec![0.0], "0".to_string()),
            (vec![1.0], "1".to_string()),
        ];
        hybrid.build_static(static_data).unwrap();

        // Add dynamic updates
        hybrid.insert(vec![2.0], "2".to_string()).unwrap();

        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some("2".to_string()));
        assert_eq!(hybrid.search(&[0.0]).unwrap(), Some("0".to_string()));
    }
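
    // Rebuild sketch (assumes `VectorId = String`, as in the tests above):
    // buffered inserts are folded into the learned index and the buffer is
    // cleared, so the entries remain searchable afterwards.
    #[test]
    fn test_hybrid_rebuild() {
        let mut hybrid = HybridIndex::new(1, 2, 2);

        hybrid
            .build_static(vec![
                (vec![0.0], "0".to_string()),
                (vec![1.0], "1".to_string()),
            ])
            .unwrap();

        hybrid.insert(vec![0.5], "mid".to_string()).unwrap();
        hybrid.insert(vec![2.0], "hi".to_string()).unwrap();
        assert!(hybrid.needs_rebuild());

        hybrid.rebuild().unwrap();
        assert_eq!(hybrid.search(&[0.5]).unwrap(), Some("mid".to_string()));
        assert_eq!(hybrid.search(&[2.0]).unwrap(), Some("hi".to_string()));
    }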
}