1use crate::error::{Result, RuvectorError};
35use crate::types::{DistanceMetric, SearchResult, VectorId};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
/// Construction-time settings for a [`MatryoshkaIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatryoshkaConfig {
    /// Length of a full embedding. `MatryoshkaIndex::insert` rejects vectors of any other
    /// length, and no search may use a prefix longer than this.
    pub full_dim: usize,
    /// Prefix lengths declared for the index. `MatryoshkaIndex::new` sorts and de-duplicates
    /// this list and rejects it when it is empty, contains `0`, or contains a value above
    /// `full_dim`.
    pub supported_dims: Vec<usize>,
    /// Metric passed to the scoring routine for every query/document comparison.
    pub metric: DistanceMetric,
}
50
51impl Default for MatryoshkaConfig {
52 fn default() -> Self {
53 Self {
54 full_dim: 768,
55 supported_dims: vec![64, 128, 256, 512, 768],
56 metric: DistanceMetric::Cosine,
57 }
58 }
59}
60
/// Tuning knobs for the two-pass `MatryoshkaIndex::funnel_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunnelConfig {
    /// Prefix length used for the coarse first pass. `funnel_search` clamps it to the
    /// index's `full_dim`.
    pub filter_dim: usize,
    /// Over-fetch factor for the coarse pass: `funnel_search` keeps
    /// `ceil(top_k * candidate_multiplier)` candidates, and never fewer than `top_k`.
    pub candidate_multiplier: f32,
}
70
71impl Default for FunnelConfig {
72 fn default() -> Self {
73 Self {
74 filter_dim: 64,
75 candidate_multiplier: 4.0,
76 }
77 }
78}
79
/// One stored vector together with the data needed to score and return it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MatryoshkaEntry {
    /// Identifier supplied by the caller at insert time.
    id: VectorId,
    /// The full-length embedding; searches slice a prefix of it.
    embedding: Vec<f32>,
    /// Norm of the whole `embedding`, computed once by `compute_norm` when the entry is
    /// inserted and reused by the full-dimension re-rank in `funnel_search`.
    full_norm: f32,
    /// Optional caller-supplied metadata, cloned into every search hit for this entry.
    metadata: Option<HashMap<String, serde_json::Value>>,
}
91
/// Flat (brute-force) index over full-length embeddings that can be searched at any
/// prefix length up to `config.full_dim`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatryoshkaIndex {
    /// Validated configuration; `supported_dims` is stored sorted and de-duplicated.
    pub config: MatryoshkaConfig,
    /// Stored entries in first-insertion order; an upsert overwrites its slot in place.
    entries: Vec<MatryoshkaEntry>,
    /// Maps each id to the position of its entry in `entries`.
    id_map: HashMap<VectorId, usize>,
}
105
106impl MatryoshkaIndex {
107 pub fn new(mut config: MatryoshkaConfig) -> Result<Self> {
114 if config.supported_dims.is_empty() {
115 return Err(RuvectorError::InvalidParameter(
116 "supported_dims must not be empty".into(),
117 ));
118 }
119 config.supported_dims.sort_unstable();
120 config.supported_dims.dedup();
121
122 for &d in &config.supported_dims {
123 if d == 0 {
124 return Err(RuvectorError::InvalidParameter(
125 "Dimensions must be > 0".into(),
126 ));
127 }
128 if d > config.full_dim {
129 return Err(RuvectorError::InvalidParameter(format!(
130 "Supported dimension {} exceeds full_dim {}",
131 d, config.full_dim
132 )));
133 }
134 }
135
136 Ok(Self {
137 config,
138 entries: Vec::new(),
139 id_map: HashMap::new(),
140 })
141 }
142
143 pub fn insert(
149 &mut self,
150 id: VectorId,
151 embedding: Vec<f32>,
152 metadata: Option<HashMap<String, serde_json::Value>>,
153 ) -> Result<()> {
154 if embedding.len() != self.config.full_dim {
155 return Err(RuvectorError::DimensionMismatch {
156 expected: self.config.full_dim,
157 actual: embedding.len(),
158 });
159 }
160
161 let full_norm = compute_norm(&embedding);
162
163 if let Some(&existing_idx) = self.id_map.get(&id) {
164 self.entries[existing_idx] = MatryoshkaEntry {
165 id,
166 embedding,
167 full_norm,
168 metadata,
169 };
170 } else {
171 let idx = self.entries.len();
172 self.entries.push(MatryoshkaEntry {
173 id: id.clone(),
174 embedding,
175 full_norm,
176 metadata,
177 });
178 self.id_map.insert(id, idx);
179 }
180
181 Ok(())
182 }
183
184 pub fn len(&self) -> usize {
186 self.entries.len()
187 }
188
189 pub fn is_empty(&self) -> bool {
191 self.entries.is_empty()
192 }
193
194 pub fn search(&self, query: &[f32], dim: usize, top_k: usize) -> Result<Vec<SearchResult>> {
207 if dim == 0 {
208 return Err(RuvectorError::InvalidParameter(
209 "Search dimension must be > 0".into(),
210 ));
211 }
212 if dim > self.config.full_dim {
213 return Err(RuvectorError::InvalidParameter(format!(
214 "Search dimension {} exceeds full_dim {}",
215 dim, self.config.full_dim
216 )));
217 }
218 if query.len() < dim {
219 return Err(RuvectorError::DimensionMismatch {
220 expected: dim,
221 actual: query.len(),
222 });
223 }
224
225 let query_prefix = &query[..dim];
226 let query_norm = compute_norm(query_prefix);
227
228 let mut scored: Vec<(usize, f32)> = self
229 .entries
230 .iter()
231 .enumerate()
232 .map(|(idx, entry)| {
233 let doc_prefix = &entry.embedding[..dim];
234 let doc_norm = compute_norm(doc_prefix);
235 let sim = similarity(
236 query_prefix,
237 query_norm,
238 doc_prefix,
239 doc_norm,
240 self.config.metric,
241 );
242 (idx, sim)
243 })
244 .collect();
245
246 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
247 scored.truncate(top_k);
248
249 Ok(scored
250 .into_iter()
251 .map(|(idx, score)| {
252 let entry = &self.entries[idx];
253 SearchResult {
254 id: entry.id.clone(),
255 score,
256 vector: None,
257 metadata: entry.metadata.clone(),
258 }
259 })
260 .collect())
261 }
262
263 pub fn funnel_search(
273 &self,
274 query: &[f32],
275 top_k: usize,
276 funnel_config: &FunnelConfig,
277 ) -> Result<Vec<SearchResult>> {
278 if query.len() < self.config.full_dim {
279 return Err(RuvectorError::DimensionMismatch {
280 expected: self.config.full_dim,
281 actual: query.len(),
282 });
283 }
284
285 let filter_dim = funnel_config.filter_dim.min(self.config.full_dim);
286 let num_candidates = ((top_k as f32) * funnel_config.candidate_multiplier).ceil() as usize;
287 let num_candidates = num_candidates.max(top_k);
288
289 let coarse_results = self.search(query, filter_dim, num_candidates)?;
291
292 let query_full = &query[..self.config.full_dim];
294 let query_full_norm = compute_norm(query_full);
295
296 let mut reranked: Vec<(VectorId, f32, Option<HashMap<String, serde_json::Value>>)> =
297 coarse_results
298 .into_iter()
299 .filter_map(|r| {
300 let idx = self.id_map.get(&r.id)?;
301 let entry = &self.entries[*idx];
302 let sim = similarity(
303 query_full,
304 query_full_norm,
305 &entry.embedding,
306 entry.full_norm,
307 self.config.metric,
308 );
309 Some((entry.id.clone(), sim, entry.metadata.clone()))
310 })
311 .collect();
312
313 reranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
314 reranked.truncate(top_k);
315
316 Ok(reranked
317 .into_iter()
318 .map(|(id, score, metadata)| SearchResult {
319 id,
320 score,
321 vector: None,
322 metadata,
323 })
324 .collect())
325 }
326
327 pub fn cascade_search(
333 &self,
334 query: &[f32],
335 top_k: usize,
336 dims: &[usize],
337 reduction_factor: f32,
338 ) -> Result<Vec<SearchResult>> {
339 if dims.is_empty() {
340 return Err(RuvectorError::InvalidParameter(
341 "Dimension cascade must not be empty".into(),
342 ));
343 }
344 if query.len() < self.config.full_dim {
345 return Err(RuvectorError::DimensionMismatch {
346 expected: self.config.full_dim,
347 actual: query.len(),
348 });
349 }
350
351 let mut candidate_indices: Vec<usize> = (0..self.entries.len()).collect();
353
354 for &dim in dims {
355 let dim = dim.min(self.config.full_dim);
356 let query_prefix = &query[..dim];
357 let query_norm = compute_norm(query_prefix);
358
359 let mut scored: Vec<(usize, f32)> = candidate_indices
360 .iter()
361 .map(|&idx| {
362 let entry = &self.entries[idx];
363 let doc_prefix = &entry.embedding[..dim];
364 let doc_norm = compute_norm(doc_prefix);
365 let sim = similarity(
366 query_prefix,
367 query_norm,
368 doc_prefix,
369 doc_norm,
370 self.config.metric,
371 );
372 (idx, sim)
373 })
374 .collect();
375
376 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
377
378 let keep = ((candidate_indices.len() as f32) / reduction_factor)
379 .ceil()
380 .max(top_k as f32) as usize;
381 scored.truncate(keep);
382 candidate_indices = scored.into_iter().map(|(idx, _)| idx).collect();
383 }
384
385 let last_dim = dims.last().copied().unwrap_or(self.config.full_dim);
387 let last_dim = last_dim.min(self.config.full_dim);
388 let query_prefix = &query[..last_dim];
389 let query_norm = compute_norm(query_prefix);
390
391 let mut final_scored: Vec<(usize, f32)> = candidate_indices
392 .iter()
393 .map(|&idx| {
394 let entry = &self.entries[idx];
395 let doc_prefix = &entry.embedding[..last_dim];
396 let doc_norm = compute_norm(doc_prefix);
397 let sim = similarity(
398 query_prefix,
399 query_norm,
400 doc_prefix,
401 doc_norm,
402 self.config.metric,
403 );
404 (idx, sim)
405 })
406 .collect();
407
408 final_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
409 final_scored.truncate(top_k);
410
411 Ok(final_scored
412 .into_iter()
413 .map(|(idx, score)| {
414 let entry = &self.entries[idx];
415 SearchResult {
416 id: entry.id.clone(),
417 score,
418 vector: None,
419 metadata: entry.metadata.clone(),
420 }
421 })
422 .collect())
423 }
424}
425
/// Euclidean (L2) norm of `v`: the square root of the sum of squared components.
#[inline]
fn compute_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
431
432#[inline]
434fn similarity(a: &[f32], norm_a: f32, b: &[f32], norm_b: f32, metric: DistanceMetric) -> f32 {
435 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
436 match metric {
437 DistanceMetric::Cosine => {
438 let denom = norm_a * norm_b;
439 if denom < f32::EPSILON {
440 0.0
441 } else {
442 dot / denom
443 }
444 }
445 DistanceMetric::DotProduct => dot,
446 DistanceMetric::Euclidean => {
447 let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
448 1.0 / (1.0 + dist_sq.sqrt())
449 }
450 DistanceMetric::Manhattan => {
451 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
452 1.0 / (1.0 + dist)
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine-metric configuration over the given dimensions.
    fn make_config(full_dim: usize, dims: Vec<usize>) -> MatryoshkaConfig {
        let metric = DistanceMetric::Cosine;
        MatryoshkaConfig {
            full_dim,
            supported_dims: dims,
            metric,
        }
    }

    /// Empty cosine index whose supported dims are the powers of two up to `full_dim`,
    /// plus `full_dim` itself.
    fn make_index(full_dim: usize) -> MatryoshkaIndex {
        let mut dims = Vec::new();
        for d in 1..=full_dim {
            if d == full_dim || d.is_power_of_two() {
                dims.push(d);
            }
        }
        MatryoshkaIndex::new(make_config(full_dim, dims)).unwrap()
    }

    /// Inserts `embedding` under `id` without metadata, panicking on failure.
    fn add(index: &mut MatryoshkaIndex, id: &str, embedding: &[f32]) {
        index.insert(id.into(), embedding.to_vec(), None).unwrap();
    }

    #[test]
    fn test_insert_and_len() {
        let mut index = make_index(4);
        assert!(index.is_empty());
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_insert_wrong_dimension_error() {
        let mut index = make_index(4);
        let outcome = index.insert("v1".into(), vec![1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_at_full_dim() {
        let mut index = make_index(4);
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "v2", &[0.0, 1.0, 0.0, 0.0]);

        // The identical vector scores 1, the orthogonal one scores 0.
        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
        assert!(hits[1].score.abs() < 1e-5);
    }

    #[test]
    fn test_search_at_truncated_dim() {
        let mut index = make_index(4);
        // Both vectors share the same first two components and differ only afterwards.
        add(&mut index, "v1", &[1.0, 0.0, 1.0, 0.0]);
        add(&mut index, "v2", &[1.0, 0.0, 0.0, 1.0]);

        // At dim 2 the two entries are indistinguishable.
        let coarse = index.search(&[1.0, 0.0, 0.5, 0.5], 2, 10).unwrap();
        assert!((coarse[0].score - coarse[1].score).abs() < 1e-5);

        // At full dimension the exact match wins.
        let fine = index.search(&[1.0, 0.0, 1.0, 0.0], 4, 10).unwrap();
        assert_eq!(fine[0].id, "v1");
        assert!(fine[0].score > fine[1].score);
    }

    #[test]
    fn test_funnel_search() {
        let mut index = make_index(8);
        add(&mut index, "best", &[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "good", &[1.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "bad", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let query = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 2.0,
        };
        let hits = index.funnel_search(&query, 2, &funnel).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "best");
    }

    #[test]
    fn test_funnel_search_finds_correct_top_k() {
        let mut index = make_index(4);
        // Twenty unit vectors fanned over a half circle in the first two components.
        (0..20).for_each(|i| {
            let angle = (i as f32) * std::f32::consts::PI / 20.0;
            add(
                &mut index,
                &format!("v{}", i),
                &[angle.cos(), angle.sin(), 0.0, 0.0],
            );
        });

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 4.0,
        };
        let hits = index.funnel_search(&query, 3, &funnel).unwrap();
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, "v0");
    }

    #[test]
    fn test_cascade_search() {
        let mut index = make_index(8);
        add(&mut index, "a", &[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "b", &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "c", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let query = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let hits = index.cascade_search(&query, 2, &[2, 4, 8], 1.5).unwrap();
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn test_search_dim_exceeds_full_dim_error() {
        let index = make_index(4);
        let outcome = index.search(&[1.0, 0.0, 0.0, 0.0], 8, 10);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_empty_index() {
        let index = make_index(4);
        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_upsert_overwrites() {
        let mut index = make_index(4);
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "v1", &[0.0, 1.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);

        // The second insert replaced the stored vector.
        let hits = index.search(&[0.0, 1.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_config_validation_empty_dims() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_config_validation_dim_exceeds_full() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 8],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_dot_product_metric() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 4],
            metric: DistanceMetric::DotProduct,
        };
        let mut index = MatryoshkaIndex::new(config).unwrap();
        add(&mut index, "v1", &[2.0, 0.0, 0.0, 0.0]);

        let hits = index.search(&[3.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!((hits[0].score - 6.0).abs() < 1e-5);
    }
}