Skip to main content

ruvector_core/advanced_features/
matryoshka.rs

1//! Matryoshka Representation Learning Support
2//!
3//! Implements adaptive-dimension embedding search inspired by Matryoshka
4//! Representation Learning (MRL). Full-dimensional embeddings are stored once,
5//! but searches can be performed at any prefix dimension—smaller prefixes run
6//! faster while larger ones are more accurate.
7//!
8//! # Two-Phase Funnel Search
9//!
10//! The flagship feature is [`MatryoshkaIndex::funnel_search`], which:
11//! 1. Filters candidates at a low dimension (fast, coarse)
12//! 2. Reranks the survivors at full dimension (slower, precise)
13//!
14//! This typically yields the same recall as full-dimension search at a fraction
15//! of the cost.
16//!
17//! # Example
18//!
19//! ```
20//! use ruvector_core::advanced_features::matryoshka::*;
21//! use ruvector_core::types::DistanceMetric;
22//!
23//! let config = MatryoshkaConfig {
24//!     full_dim: 8,
25//!     supported_dims: vec![2, 4, 8],
26//!     metric: DistanceMetric::Cosine,
27//! };
28//! let mut index = MatryoshkaIndex::new(config).unwrap();
29//! index.insert("v1".into(), vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], None).unwrap();
30//! let results = index.search(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
31//! assert_eq!(results[0].id, "v1");
32//! ```
33
34use crate::error::{Result, RuvectorError};
35use crate::types::{DistanceMetric, SearchResult, VectorId};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39/// Configuration for a Matryoshka embedding index.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct MatryoshkaConfig {
42    /// The full (maximum) embedding dimension.
43    pub full_dim: usize,
44    /// Supported truncation dimensions, sorted ascending.
45    /// Each must be <= `full_dim`. The last element should equal `full_dim`.
46    pub supported_dims: Vec<usize>,
47    /// Distance metric for similarity computation.
48    pub metric: DistanceMetric,
49}
50
51impl Default for MatryoshkaConfig {
52    fn default() -> Self {
53        Self {
54            full_dim: 768,
55            supported_dims: vec![64, 128, 256, 512, 768],
56            metric: DistanceMetric::Cosine,
57        }
58    }
59}
60
61/// Configuration for the multi-phase funnel search.
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct FunnelConfig {
64    /// Dimension used for the coarse filtering phase.
65    pub filter_dim: usize,
66    /// Multiplier applied to `top_k` to determine how many candidates
67    /// survive the coarse phase. E.g., 4.0 means 4x top_k candidates.
68    pub candidate_multiplier: f32,
69}
70
71impl Default for FunnelConfig {
72    fn default() -> Self {
73        Self {
74            filter_dim: 64,
75            candidate_multiplier: 4.0,
76        }
77    }
78}
79
80/// Entry stored in the Matryoshka index.
81#[derive(Debug, Clone, Serialize, Deserialize)]
82struct MatryoshkaEntry {
83    id: VectorId,
84    /// Full-dimensional embedding.
85    embedding: Vec<f32>,
86    /// Precomputed L2 norm of the full embedding.
87    full_norm: f32,
88    /// Optional metadata.
89    metadata: Option<HashMap<String, serde_json::Value>>,
90}
91
92/// Matryoshka embedding index supporting adaptive-dimension search.
93///
94/// Stores embeddings at full dimensionality but can search at any prefix
95/// dimension for a speed-accuracy trade-off.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MatryoshkaIndex {
98    /// Index configuration.
99    pub config: MatryoshkaConfig,
100    /// Stored entries.
101    entries: Vec<MatryoshkaEntry>,
102    /// Map from vector ID to index in `entries`.
103    id_map: HashMap<VectorId, usize>,
104}
105
106impl MatryoshkaIndex {
107    /// Create a new Matryoshka index.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if `supported_dims` is empty, any dimension is zero,
112    /// or any dimension exceeds `full_dim`.
113    pub fn new(mut config: MatryoshkaConfig) -> Result<Self> {
114        if config.supported_dims.is_empty() {
115            return Err(RuvectorError::InvalidParameter(
116                "supported_dims must not be empty".into(),
117            ));
118        }
119        config.supported_dims.sort_unstable();
120        config.supported_dims.dedup();
121
122        for &d in &config.supported_dims {
123            if d == 0 {
124                return Err(RuvectorError::InvalidParameter(
125                    "Dimensions must be > 0".into(),
126                ));
127            }
128            if d > config.full_dim {
129                return Err(RuvectorError::InvalidParameter(format!(
130                    "Supported dimension {} exceeds full_dim {}",
131                    d, config.full_dim
132                )));
133            }
134        }
135
136        Ok(Self {
137            config,
138            entries: Vec::new(),
139            id_map: HashMap::new(),
140        })
141    }
142
143    /// Insert a full-dimensional embedding into the index.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the embedding dimension does not match `full_dim`.
148    pub fn insert(
149        &mut self,
150        id: VectorId,
151        embedding: Vec<f32>,
152        metadata: Option<HashMap<String, serde_json::Value>>,
153    ) -> Result<()> {
154        if embedding.len() != self.config.full_dim {
155            return Err(RuvectorError::DimensionMismatch {
156                expected: self.config.full_dim,
157                actual: embedding.len(),
158            });
159        }
160
161        let full_norm = compute_norm(&embedding);
162
163        if let Some(&existing_idx) = self.id_map.get(&id) {
164            self.entries[existing_idx] = MatryoshkaEntry {
165                id,
166                embedding,
167                full_norm,
168                metadata,
169            };
170        } else {
171            let idx = self.entries.len();
172            self.entries.push(MatryoshkaEntry {
173                id: id.clone(),
174                embedding,
175                full_norm,
176                metadata,
177            });
178            self.id_map.insert(id, idx);
179        }
180
181        Ok(())
182    }
183
184    /// Return the number of stored vectors.
185    pub fn len(&self) -> usize {
186        self.entries.len()
187    }
188
189    /// Check if the index is empty.
190    pub fn is_empty(&self) -> bool {
191        self.entries.is_empty()
192    }
193
194    /// Search at a specific dimension by truncating embeddings to the first
195    /// `dim` components.
196    ///
197    /// # Arguments
198    ///
199    /// * `query` - Full-dimensional (or at least `dim`-dimensional) query vector.
200    /// * `dim` - The truncation dimension to use for search.
201    /// * `top_k` - Number of results to return.
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if `dim` exceeds the query length or `full_dim`.
206    pub fn search(&self, query: &[f32], dim: usize, top_k: usize) -> Result<Vec<SearchResult>> {
207        if dim == 0 {
208            return Err(RuvectorError::InvalidParameter(
209                "Search dimension must be > 0".into(),
210            ));
211        }
212        if dim > self.config.full_dim {
213            return Err(RuvectorError::InvalidParameter(format!(
214                "Search dimension {} exceeds full_dim {}",
215                dim, self.config.full_dim
216            )));
217        }
218        if query.len() < dim {
219            return Err(RuvectorError::DimensionMismatch {
220                expected: dim,
221                actual: query.len(),
222            });
223        }
224
225        let query_prefix = &query[..dim];
226        let query_norm = compute_norm(query_prefix);
227
228        let mut scored: Vec<(usize, f32)> = self
229            .entries
230            .iter()
231            .enumerate()
232            .map(|(idx, entry)| {
233                let doc_prefix = &entry.embedding[..dim];
234                let doc_norm = compute_norm(doc_prefix);
235                let sim = similarity(
236                    query_prefix,
237                    query_norm,
238                    doc_prefix,
239                    doc_norm,
240                    self.config.metric,
241                );
242                (idx, sim)
243            })
244            .collect();
245
246        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
247        scored.truncate(top_k);
248
249        Ok(scored
250            .into_iter()
251            .map(|(idx, score)| {
252                let entry = &self.entries[idx];
253                SearchResult {
254                    id: entry.id.clone(),
255                    score,
256                    vector: None,
257                    metadata: entry.metadata.clone(),
258                }
259            })
260            .collect())
261    }
262
263    /// Two-phase funnel search: coarse filter at low dimension, rerank at full dimension.
264    ///
265    /// 1. Search at `funnel_config.filter_dim` for `candidate_multiplier * top_k` candidates.
266    /// 2. Rerank those candidates at `full_dim`.
267    /// 3. Return the top `top_k`.
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if the query is shorter than `full_dim`.
272    pub fn funnel_search(
273        &self,
274        query: &[f32],
275        top_k: usize,
276        funnel_config: &FunnelConfig,
277    ) -> Result<Vec<SearchResult>> {
278        if query.len() < self.config.full_dim {
279            return Err(RuvectorError::DimensionMismatch {
280                expected: self.config.full_dim,
281                actual: query.len(),
282            });
283        }
284
285        let filter_dim = funnel_config.filter_dim.min(self.config.full_dim);
286        let num_candidates = ((top_k as f32) * funnel_config.candidate_multiplier).ceil() as usize;
287        let num_candidates = num_candidates.max(top_k);
288
289        // Phase 1: coarse search at low dimension.
290        let coarse_results = self.search(query, filter_dim, num_candidates)?;
291
292        // Phase 2: rerank at full dimension.
293        let query_full = &query[..self.config.full_dim];
294        let query_full_norm = compute_norm(query_full);
295
296        let mut reranked: Vec<(VectorId, f32, Option<HashMap<String, serde_json::Value>>)> =
297            coarse_results
298                .into_iter()
299                .filter_map(|r| {
300                    let idx = self.id_map.get(&r.id)?;
301                    let entry = &self.entries[*idx];
302                    let sim = similarity(
303                        query_full,
304                        query_full_norm,
305                        &entry.embedding,
306                        entry.full_norm,
307                        self.config.metric,
308                    );
309                    Some((entry.id.clone(), sim, entry.metadata.clone()))
310                })
311                .collect();
312
313        reranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
314        reranked.truncate(top_k);
315
316        Ok(reranked
317            .into_iter()
318            .map(|(id, score, metadata)| SearchResult {
319                id,
320                score,
321                vector: None,
322                metadata,
323            })
324            .collect())
325    }
326
327    /// Multi-stage cascade search through multiple dimensions.
328    ///
329    /// Searches through dimensions in ascending order, progressively narrowing
330    /// candidates. At each stage, the candidate set is reduced by the
331    /// `reduction_factor`.
332    pub fn cascade_search(
333        &self,
334        query: &[f32],
335        top_k: usize,
336        dims: &[usize],
337        reduction_factor: f32,
338    ) -> Result<Vec<SearchResult>> {
339        if dims.is_empty() {
340            return Err(RuvectorError::InvalidParameter(
341                "Dimension cascade must not be empty".into(),
342            ));
343        }
344        if query.len() < self.config.full_dim {
345            return Err(RuvectorError::DimensionMismatch {
346                expected: self.config.full_dim,
347                actual: query.len(),
348            });
349        }
350
351        // Start with all candidates at the lowest dimension.
352        let mut candidate_indices: Vec<usize> = (0..self.entries.len()).collect();
353
354        for &dim in dims {
355            let dim = dim.min(self.config.full_dim);
356            let query_prefix = &query[..dim];
357            let query_norm = compute_norm(query_prefix);
358
359            let mut scored: Vec<(usize, f32)> = candidate_indices
360                .iter()
361                .map(|&idx| {
362                    let entry = &self.entries[idx];
363                    let doc_prefix = &entry.embedding[..dim];
364                    let doc_norm = compute_norm(doc_prefix);
365                    let sim = similarity(
366                        query_prefix,
367                        query_norm,
368                        doc_prefix,
369                        doc_norm,
370                        self.config.metric,
371                    );
372                    (idx, sim)
373                })
374                .collect();
375
376            scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
377
378            let keep = ((candidate_indices.len() as f32) / reduction_factor)
379                .ceil()
380                .max(top_k as f32) as usize;
381            scored.truncate(keep);
382            candidate_indices = scored.into_iter().map(|(idx, _)| idx).collect();
383        }
384
385        // Final scoring at the last dimension in the cascade.
386        let last_dim = dims.last().copied().unwrap_or(self.config.full_dim);
387        let last_dim = last_dim.min(self.config.full_dim);
388        let query_prefix = &query[..last_dim];
389        let query_norm = compute_norm(query_prefix);
390
391        let mut final_scored: Vec<(usize, f32)> = candidate_indices
392            .iter()
393            .map(|&idx| {
394                let entry = &self.entries[idx];
395                let doc_prefix = &entry.embedding[..last_dim];
396                let doc_norm = compute_norm(doc_prefix);
397                let sim = similarity(
398                    query_prefix,
399                    query_norm,
400                    doc_prefix,
401                    doc_norm,
402                    self.config.metric,
403                );
404                (idx, sim)
405            })
406            .collect();
407
408        final_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
409        final_scored.truncate(top_k);
410
411        Ok(final_scored
412            .into_iter()
413            .map(|(idx, score)| {
414                let entry = &self.entries[idx];
415                SearchResult {
416                    id: entry.id.clone(),
417                    score,
418                    vector: None,
419                    metadata: entry.metadata.clone(),
420                }
421            })
422            .collect())
423    }
424}
425
/// Euclidean (L2) length of `v`: the square root of the sum of squared
/// components. An empty slice has length zero.
#[inline]
fn compute_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
431
432/// Compute similarity between two vectors using the given metric and precomputed norms.
433#[inline]
434fn similarity(a: &[f32], norm_a: f32, b: &[f32], norm_b: f32, metric: DistanceMetric) -> f32 {
435    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
436    match metric {
437        DistanceMetric::Cosine => {
438            let denom = norm_a * norm_b;
439            if denom < f32::EPSILON {
440                0.0
441            } else {
442                dot / denom
443            }
444        }
445        DistanceMetric::DotProduct => dot,
446        DistanceMetric::Euclidean => {
447            let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
448            1.0 / (1.0 + dist_sq.sqrt())
449        }
450        DistanceMetric::Manhattan => {
451            let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
452            1.0 / (1.0 + dist)
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine-metric configuration with the given dimensions.
    fn make_config(full_dim: usize, dims: Vec<usize>) -> MatryoshkaConfig {
        let metric = DistanceMetric::Cosine;
        MatryoshkaConfig {
            full_dim,
            supported_dims: dims,
            metric,
        }
    }

    /// Index whose supported dims are every power of two up to `full_dim`,
    /// plus `full_dim` itself.
    fn make_index(full_dim: usize) -> MatryoshkaIndex {
        let mut dims = Vec::new();
        for d in 1..=full_dim {
            if d.is_power_of_two() || d == full_dim {
                dims.push(d);
            }
        }
        MatryoshkaIndex::new(make_config(full_dim, dims)).unwrap()
    }

    /// Insert each `(id, embedding)` pair without metadata, panicking on error.
    fn insert_all(index: &mut MatryoshkaIndex, rows: Vec<(&str, Vec<f32>)>) {
        for (id, embedding) in rows {
            index.insert(id.into(), embedding, None).unwrap();
        }
    }

    #[test]
    fn test_insert_and_len() {
        let mut index = make_index(4);
        assert!(index.is_empty());

        insert_all(&mut index, vec![("v1", vec![1.0, 0.0, 0.0, 0.0])]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_insert_wrong_dimension_error() {
        let mut index = make_index(4);
        let outcome = index.insert("v1".into(), vec![1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_at_full_dim() {
        let mut index = make_index(4);
        insert_all(
            &mut index,
            vec![
                ("v1", vec![1.0, 0.0, 0.0, 0.0]),
                ("v2", vec![0.0, 1.0, 0.0, 0.0]),
            ],
        );

        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        let (top, runner_up) = (&hits[0], &hits[1]);
        assert_eq!(top.id, "v1");
        assert!((top.score - 1.0).abs() < 1e-5);
        // v2 is orthogonal to the query, so its score should be ~0.
        assert!(runner_up.score.abs() < 1e-5);
    }

    #[test]
    fn test_search_at_truncated_dim() {
        let mut index = make_index(4);
        // The two vectors agree on the first two components only.
        insert_all(
            &mut index,
            vec![
                ("v1", vec![1.0, 0.0, 1.0, 0.0]),
                ("v2", vec![1.0, 0.0, 0.0, 1.0]),
            ],
        );

        // With dim=2 both truncate to [1.0, 0.0], so the scores match.
        let coarse = index.search(&[1.0, 0.0, 0.5, 0.5], 2, 10).unwrap();
        assert!((coarse[0].score - coarse[1].score).abs() < 1e-5);

        // With dim=4 the tail components separate them.
        let fine = index.search(&[1.0, 0.0, 1.0, 0.0], 4, 10).unwrap();
        assert_eq!(fine[0].id, "v1");
        assert!(fine[0].score > fine[1].score);
    }

    #[test]
    fn test_funnel_search() {
        let mut index = make_index(8);
        // "best" and "good" share their first two dims and differ afterwards.
        insert_all(
            &mut index,
            vec![
                ("best", vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
                ("good", vec![1.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]),
                ("bad", vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            ],
        );

        let query = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 2.0,
        };
        let hits = index.funnel_search(&query, 2, &funnel).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "best");
    }

    #[test]
    fn test_funnel_search_finds_correct_top_k() {
        let mut index = make_index(4);
        // Twenty unit vectors fanned across a half circle in the first plane.
        for i in 0..20 {
            let angle = (i as f32) * std::f32::consts::PI / 20.0;
            let embedding = vec![angle.cos(), angle.sin(), 0.0, 0.0];
            index.insert(format!("v{}", i), embedding, None).unwrap();
        }

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 4.0,
        };
        let hits = index.funnel_search(&query, 3, &funnel).unwrap();
        assert_eq!(hits.len(), 3);
        // v0 sits at angle 0 (cos=1, sin=0), exactly on the query.
        assert_eq!(hits[0].id, "v0");
    }

    #[test]
    fn test_cascade_search() {
        let mut index = make_index(8);
        insert_all(
            &mut index,
            vec![
                ("a", vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]),
                ("b", vec![1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
                ("c", vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            ],
        );

        let query = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let hits = index.cascade_search(&query, 2, &[2, 4, 8], 1.5).unwrap();
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn test_search_dim_exceeds_full_dim_error() {
        let index = make_index(4);
        assert!(index.search(&[1.0, 0.0, 0.0, 0.0], 8, 10).is_err());
    }

    #[test]
    fn test_search_empty_index() {
        let index = make_index(4);
        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_upsert_overwrites() {
        let mut index = make_index(4);
        // Same id twice: the second embedding must replace the first.
        insert_all(
            &mut index,
            vec![
                ("v1", vec![1.0, 0.0, 0.0, 0.0]),
                ("v1", vec![0.0, 1.0, 0.0, 0.0]),
            ],
        );
        assert_eq!(index.len(), 1);

        let hits = index.search(&[0.0, 1.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_config_validation_empty_dims() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_config_validation_dim_exceeds_full() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 8],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_dot_product_metric() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 4],
            metric: DistanceMetric::DotProduct,
        };
        let mut index = MatryoshkaIndex::new(config).unwrap();
        insert_all(&mut index, vec![("v1", vec![2.0, 0.0, 0.0, 0.0])]);

        let hits = index.search(&[3.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!((hits[0].score - 6.0).abs() < 1e-5);
    }
}
685}