Skip to main content

ruvector_core/advanced_features/
matryoshka.rs

1//! Matryoshka Representation Learning Support
2//!
3//! Implements adaptive-dimension embedding search inspired by Matryoshka
4//! Representation Learning (MRL). Full-dimensional embeddings are stored once,
5//! but searches can be performed at any prefix dimension—smaller prefixes run
6//! faster while larger ones are more accurate.
7//!
8//! # Two-Phase Funnel Search
9//!
10//! The flagship feature is [`MatryoshkaIndex::funnel_search`], which:
11//! 1. Filters candidates at a low dimension (fast, coarse)
12//! 2. Reranks the survivors at full dimension (slower, precise)
13//!
14//! This typically yields the same recall as full-dimension search at a fraction
15//! of the cost.
16//!
17//! # Example
18//!
19//! ```
20//! use ruvector_core::advanced_features::matryoshka::*;
21//! use ruvector_core::types::DistanceMetric;
22//!
23//! let config = MatryoshkaConfig {
24//!     full_dim: 8,
25//!     supported_dims: vec![2, 4, 8],
26//!     metric: DistanceMetric::Cosine,
27//! };
28//! let mut index = MatryoshkaIndex::new(config).unwrap();
29//! index.insert("v1".into(), vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], None).unwrap();
30//! let results = index.search(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
31//! assert_eq!(results[0].id, "v1");
32//! ```
33
34use crate::error::{Result, RuvectorError};
35use crate::types::{DistanceMetric, SearchResult, VectorId};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39/// Configuration for a Matryoshka embedding index.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct MatryoshkaConfig {
42    /// The full (maximum) embedding dimension.
43    pub full_dim: usize,
44    /// Supported truncation dimensions, sorted ascending.
45    /// Each must be <= `full_dim`. The last element should equal `full_dim`.
46    pub supported_dims: Vec<usize>,
47    /// Distance metric for similarity computation.
48    pub metric: DistanceMetric,
49}
50
51impl Default for MatryoshkaConfig {
52    fn default() -> Self {
53        Self {
54            full_dim: 768,
55            supported_dims: vec![64, 128, 256, 512, 768],
56            metric: DistanceMetric::Cosine,
57        }
58    }
59}
60
61/// Configuration for the multi-phase funnel search.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct FunnelConfig {
64    /// Dimension used for the coarse filtering phase.
65    pub filter_dim: usize,
66    /// Multiplier applied to `top_k` to determine how many candidates
67    /// survive the coarse phase. E.g., 4.0 means 4x top_k candidates.
68    pub candidate_multiplier: f32,
69}
70
71impl Default for FunnelConfig {
72    fn default() -> Self {
73        Self {
74            filter_dim: 64,
75            candidate_multiplier: 4.0,
76        }
77    }
78}
79
80/// Entry stored in the Matryoshka index.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82struct MatryoshkaEntry {
83    id: VectorId,
84    /// Full-dimensional embedding.
85    embedding: Vec<f32>,
86    /// Precomputed L2 norm of the full embedding.
87    full_norm: f32,
88    /// Optional metadata.
89    metadata: Option<HashMap<String, serde_json::Value>>,
90}
91
92/// Matryoshka embedding index supporting adaptive-dimension search.
93///
94/// Stores embeddings at full dimensionality but can search at any prefix
95/// dimension for a speed-accuracy trade-off.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MatryoshkaIndex {
98    /// Index configuration.
99    pub config: MatryoshkaConfig,
100    /// Stored entries.
101    entries: Vec<MatryoshkaEntry>,
102    /// Map from vector ID to index in `entries`.
103    id_map: HashMap<VectorId, usize>,
104}
105
106impl MatryoshkaIndex {
107    /// Create a new Matryoshka index.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if `supported_dims` is empty, any dimension is zero,
112    /// or any dimension exceeds `full_dim`.
113    pub fn new(mut config: MatryoshkaConfig) -> Result<Self> {
114        if config.supported_dims.is_empty() {
115            return Err(RuvectorError::InvalidParameter(
116                "supported_dims must not be empty".into(),
117            ));
118        }
119        config.supported_dims.sort_unstable();
120        config.supported_dims.dedup();
121
122        for &d in &config.supported_dims {
123            if d == 0 {
124                return Err(RuvectorError::InvalidParameter(
125                    "Dimensions must be > 0".into(),
126                ));
127            }
128            if d > config.full_dim {
129                return Err(RuvectorError::InvalidParameter(format!(
130                    "Supported dimension {} exceeds full_dim {}",
131                    d, config.full_dim
132                )));
133            }
134        }
135
136        Ok(Self {
137            config,
138            entries: Vec::new(),
139            id_map: HashMap::new(),
140        })
141    }
142
143    /// Insert a full-dimensional embedding into the index.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the embedding dimension does not match `full_dim`.
148    pub fn insert(
149        &mut self,
150        id: VectorId,
151        embedding: Vec<f32>,
152        metadata: Option<HashMap<String, serde_json::Value>>,
153    ) -> Result<()> {
154        if embedding.len() != self.config.full_dim {
155            return Err(RuvectorError::DimensionMismatch {
156                expected: self.config.full_dim,
157                actual: embedding.len(),
158            });
159        }
160
161        let full_norm = compute_norm(&embedding);
162
163        if let Some(&existing_idx) = self.id_map.get(&id) {
164            self.entries[existing_idx] = MatryoshkaEntry {
165                id,
166                embedding,
167                full_norm,
168                metadata,
169            };
170        } else {
171            let idx = self.entries.len();
172            self.entries.push(MatryoshkaEntry {
173                id: id.clone(),
174                embedding,
175                full_norm,
176                metadata,
177            });
178            self.id_map.insert(id, idx);
179        }
180
181        Ok(())
182    }
183
184    /// Return the number of stored vectors.
185    pub fn len(&self) -> usize {
186        self.entries.len()
187    }
188
189    /// Check if the index is empty.
190    pub fn is_empty(&self) -> bool {
191        self.entries.is_empty()
192    }
193
194    /// Search at a specific dimension by truncating embeddings to the first
195    /// `dim` components.
196    ///
197    /// # Arguments
198    ///
199    /// * `query` - Full-dimensional (or at least `dim`-dimensional) query vector.
200    /// * `dim` - The truncation dimension to use for search.
201    /// * `top_k` - Number of results to return.
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if `dim` exceeds the query length or `full_dim`.
206    pub fn search(
207        &self,
208        query: &[f32],
209        dim: usize,
210        top_k: usize,
211    ) -> Result<Vec<SearchResult>> {
212        if dim == 0 {
213            return Err(RuvectorError::InvalidParameter(
214                "Search dimension must be > 0".into(),
215            ));
216        }
217        if dim > self.config.full_dim {
218            return Err(RuvectorError::InvalidParameter(format!(
219                "Search dimension {} exceeds full_dim {}",
220                dim, self.config.full_dim
221            )));
222        }
223        if query.len() < dim {
224            return Err(RuvectorError::DimensionMismatch {
225                expected: dim,
226                actual: query.len(),
227            });
228        }
229
230        let query_prefix = &query[..dim];
231        let query_norm = compute_norm(query_prefix);
232
233        let mut scored: Vec<(usize, f32)> = self
234            .entries
235            .iter()
236            .enumerate()
237            .map(|(idx, entry)| {
238                let doc_prefix = &entry.embedding[..dim];
239                let doc_norm = compute_norm(doc_prefix);
240                let sim = similarity(query_prefix, query_norm, doc_prefix, doc_norm, self.config.metric);
241                (idx, sim)
242            })
243            .collect();
244
245        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246        scored.truncate(top_k);
247
248        Ok(scored
249            .into_iter()
250            .map(|(idx, score)| {
251                let entry = &self.entries[idx];
252                SearchResult {
253                    id: entry.id.clone(),
254                    score,
255                    vector: None,
256                    metadata: entry.metadata.clone(),
257                }
258            })
259            .collect())
260    }
261
262    /// Two-phase funnel search: coarse filter at low dimension, rerank at full dimension.
263    ///
264    /// 1. Search at `funnel_config.filter_dim` for `candidate_multiplier * top_k` candidates.
265    /// 2. Rerank those candidates at `full_dim`.
266    /// 3. Return the top `top_k`.
267    ///
268    /// # Errors
269    ///
270    /// Returns an error if the query is shorter than `full_dim`.
271    pub fn funnel_search(
272        &self,
273        query: &[f32],
274        top_k: usize,
275        funnel_config: &FunnelConfig,
276    ) -> Result<Vec<SearchResult>> {
277        if query.len() < self.config.full_dim {
278            return Err(RuvectorError::DimensionMismatch {
279                expected: self.config.full_dim,
280                actual: query.len(),
281            });
282        }
283
284        let filter_dim = funnel_config.filter_dim.min(self.config.full_dim);
285        let num_candidates = ((top_k as f32) * funnel_config.candidate_multiplier).ceil() as usize;
286        let num_candidates = num_candidates.max(top_k);
287
288        // Phase 1: coarse search at low dimension.
289        let coarse_results = self.search(query, filter_dim, num_candidates)?;
290
291        // Phase 2: rerank at full dimension.
292        let query_full = &query[..self.config.full_dim];
293        let query_full_norm = compute_norm(query_full);
294
295        let mut reranked: Vec<(VectorId, f32, Option<HashMap<String, serde_json::Value>>)> =
296            coarse_results
297                .into_iter()
298                .filter_map(|r| {
299                    let idx = self.id_map.get(&r.id)?;
300                    let entry = &self.entries[*idx];
301                    let sim = similarity(
302                        query_full,
303                        query_full_norm,
304                        &entry.embedding,
305                        entry.full_norm,
306                        self.config.metric,
307                    );
308                    Some((entry.id.clone(), sim, entry.metadata.clone()))
309                })
310                .collect();
311
312        reranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
313        reranked.truncate(top_k);
314
315        Ok(reranked
316            .into_iter()
317            .map(|(id, score, metadata)| SearchResult {
318                id,
319                score,
320                vector: None,
321                metadata,
322            })
323            .collect())
324    }
325
326    /// Multi-stage cascade search through multiple dimensions.
327    ///
328    /// Searches through dimensions in ascending order, progressively narrowing
329    /// candidates. At each stage, the candidate set is reduced by the
330    /// `reduction_factor`.
331    pub fn cascade_search(
332        &self,
333        query: &[f32],
334        top_k: usize,
335        dims: &[usize],
336        reduction_factor: f32,
337    ) -> Result<Vec<SearchResult>> {
338        if dims.is_empty() {
339            return Err(RuvectorError::InvalidParameter(
340                "Dimension cascade must not be empty".into(),
341            ));
342        }
343        if query.len() < self.config.full_dim {
344            return Err(RuvectorError::DimensionMismatch {
345                expected: self.config.full_dim,
346                actual: query.len(),
347            });
348        }
349
350        // Start with all candidates at the lowest dimension.
351        let mut candidate_indices: Vec<usize> = (0..self.entries.len()).collect();
352
353        for &dim in dims {
354            let dim = dim.min(self.config.full_dim);
355            let query_prefix = &query[..dim];
356            let query_norm = compute_norm(query_prefix);
357
358            let mut scored: Vec<(usize, f32)> = candidate_indices
359                .iter()
360                .map(|&idx| {
361                    let entry = &self.entries[idx];
362                    let doc_prefix = &entry.embedding[..dim];
363                    let doc_norm = compute_norm(doc_prefix);
364                    let sim = similarity(
365                        query_prefix,
366                        query_norm,
367                        doc_prefix,
368                        doc_norm,
369                        self.config.metric,
370                    );
371                    (idx, sim)
372                })
373                .collect();
374
375            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
376
377            let keep = ((candidate_indices.len() as f32) / reduction_factor)
378                .ceil()
379                .max(top_k as f32) as usize;
380            scored.truncate(keep);
381            candidate_indices = scored.into_iter().map(|(idx, _)| idx).collect();
382        }
383
384        // Final scoring at the last dimension in the cascade.
385        let last_dim = dims.last().copied().unwrap_or(self.config.full_dim);
386        let last_dim = last_dim.min(self.config.full_dim);
387        let query_prefix = &query[..last_dim];
388        let query_norm = compute_norm(query_prefix);
389
390        let mut final_scored: Vec<(usize, f32)> = candidate_indices
391            .iter()
392            .map(|&idx| {
393                let entry = &self.entries[idx];
394                let doc_prefix = &entry.embedding[..last_dim];
395                let doc_norm = compute_norm(doc_prefix);
396                let sim = similarity(
397                    query_prefix,
398                    query_norm,
399                    doc_prefix,
400                    doc_norm,
401                    self.config.metric,
402                );
403                (idx, sim)
404            })
405            .collect();
406
407        final_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
408        final_scored.truncate(top_k);
409
410        Ok(final_scored
411            .into_iter()
412            .map(|(idx, score)| {
413                let entry = &self.entries[idx];
414                SearchResult {
415                    id: entry.id.clone(),
416                    score,
417                    vector: None,
418                    metadata: entry.metadata.clone(),
419                }
420            })
421            .collect())
422    }
423}
424
/// Return the Euclidean (L2) length of `v`: the square root of the sum of
/// its squared components.
#[inline]
fn compute_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
430
431/// Compute similarity between two vectors using the given metric and precomputed norms.
432#[inline]
433fn similarity(a: &[f32], norm_a: f32, b: &[f32], norm_b: f32, metric: DistanceMetric) -> f32 {
434    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
435    match metric {
436        DistanceMetric::Cosine => {
437            let denom = norm_a * norm_b;
438            if denom < f32::EPSILON {
439                0.0
440            } else {
441                dot / denom
442            }
443        }
444        DistanceMetric::DotProduct => dot,
445        DistanceMetric::Euclidean => {
446            let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
447            1.0 / (1.0 + dist_sq.sqrt())
448        }
449        DistanceMetric::Manhattan => {
450            let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
451            1.0 / (1.0 + dist)
452        }
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point score comparisons.
    const TOL: f32 = 1e-5;

    /// Cosine configuration with the given full dimension and prefix dimensions.
    fn make_config(full_dim: usize, dims: Vec<usize>) -> MatryoshkaConfig {
        MatryoshkaConfig {
            full_dim,
            supported_dims: dims,
            metric: DistanceMetric::Cosine,
        }
    }

    /// Empty cosine index whose supported dimensions are the powers of two up
    /// to `full_dim`, plus `full_dim` itself.
    fn make_index(full_dim: usize) -> MatryoshkaIndex {
        let mut dims = Vec::new();
        for d in 1..=full_dim {
            if d.is_power_of_two() || d == full_dim {
                dims.push(d);
            }
        }
        MatryoshkaIndex::new(make_config(full_dim, dims)).unwrap()
    }

    /// Insert `values` under `id` without metadata, panicking on failure.
    fn put(index: &mut MatryoshkaIndex, id: &str, values: &[f32]) {
        index.insert(id.into(), values.to_vec(), None).unwrap();
    }

    #[test]
    fn test_insert_and_len() {
        let mut index = make_index(4);
        assert!(index.is_empty());

        put(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_insert_wrong_dimension_error() {
        let mut index = make_index(4);
        let outcome = index.insert("v1".into(), vec![1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_at_full_dim() {
        let mut index = make_index(4);
        put(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        put(&mut index, "v2", &[0.0, 1.0, 0.0, 0.0]);

        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < TOL);
        // v2 is orthogonal to the query, so its score should be ~0.
        assert!(hits[1].score.abs() < TOL);
    }

    #[test]
    fn test_search_at_truncated_dim() {
        let mut index = make_index(4);
        // The two vectors share their first two components.
        put(&mut index, "v1", &[1.0, 0.0, 1.0, 0.0]);
        put(&mut index, "v2", &[1.0, 0.0, 0.0, 1.0]);

        // At dim=2 both truncate to [1.0, 0.0], so the scores coincide.
        let coarse = index.search(&[1.0, 0.0, 0.5, 0.5], 2, 10).unwrap();
        assert!((coarse[0].score - coarse[1].score).abs() < TOL);

        // At dim=4 the trailing components separate them.
        let fine = index.search(&[1.0, 0.0, 1.0, 0.0], 4, 10).unwrap();
        assert_eq!(fine[0].id, "v1");
        assert!(fine[0].score > fine[1].score);
    }

    #[test]
    fn test_funnel_search() {
        let mut index = make_index(8);
        // "best" and "good" agree on the first two dims and differ afterwards.
        put(&mut index, "best", &[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        put(&mut index, "good", &[1.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]);
        put(&mut index, "bad", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 2.0,
        };
        let query = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let hits = index.funnel_search(&query, 2, &funnel).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "best");
    }

    #[test]
    fn test_funnel_search_finds_correct_top_k() {
        let mut index = make_index(4);
        for i in 0..20 {
            let angle = (i as f32) * std::f32::consts::PI / 20.0;
            let embedding = vec![angle.cos(), angle.sin(), 0.0, 0.0];
            index.insert(format!("v{}", i), embedding, None).unwrap();
        }

        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 4.0,
        };
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let hits = index.funnel_search(&query, 3, &funnel).unwrap();

        assert_eq!(hits.len(), 3);
        // v0 sits at angle 0 (cos=1, sin=0), exactly on the query.
        assert_eq!(hits[0].id, "v0");
    }

    #[test]
    fn test_cascade_search() {
        let mut index = make_index(8);
        put(&mut index, "a", &[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
        put(&mut index, "b", &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        put(&mut index, "c", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let query = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let hits = index.cascade_search(&query, 2, &[2, 4, 8], 1.5).unwrap();
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn test_search_dim_exceeds_full_dim_error() {
        let index = make_index(4);
        let outcome = index.search(&[1.0, 0.0, 0.0, 0.0], 8, 10);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_empty_index() {
        let index = make_index(4);
        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_upsert_overwrites() {
        let mut index = make_index(4);
        put(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        put(&mut index, "v1", &[0.0, 1.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);

        let hits = index.search(&[0.0, 1.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < TOL);
    }

    #[test]
    fn test_config_validation_empty_dims() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_config_validation_dim_exceeds_full() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 8],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_dot_product_metric() {
        let mut index = MatryoshkaIndex::new(MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 4],
            metric: DistanceMetric::DotProduct,
        })
        .unwrap();
        put(&mut index, "v1", &[2.0, 0.0, 0.0, 0.0]);

        let hits = index.search(&[3.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!((hits[0].score - 6.0).abs() < TOL);
    }
}