1use crate::error::{Result, RuvectorError};
35use crate::types::{DistanceMetric, SearchResult, VectorId};
36use serde::{Deserialize, Serialize};
37use std::collections::HashMap;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct MatryoshkaConfig {
42 pub full_dim: usize,
44 pub supported_dims: Vec<usize>,
47 pub metric: DistanceMetric,
49}
50
51impl Default for MatryoshkaConfig {
52 fn default() -> Self {
53 Self {
54 full_dim: 768,
55 supported_dims: vec![64, 128, 256, 512, 768],
56 metric: DistanceMetric::Cosine,
57 }
58 }
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct FunnelConfig {
64 pub filter_dim: usize,
66 pub candidate_multiplier: f32,
69}
70
71impl Default for FunnelConfig {
72 fn default() -> Self {
73 Self {
74 filter_dim: 64,
75 candidate_multiplier: 4.0,
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82struct MatryoshkaEntry {
83 id: VectorId,
84 embedding: Vec<f32>,
86 full_norm: f32,
88 metadata: Option<HashMap<String, serde_json::Value>>,
90}
91
92#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct MatryoshkaIndex {
98 pub config: MatryoshkaConfig,
100 entries: Vec<MatryoshkaEntry>,
102 id_map: HashMap<VectorId, usize>,
104}
105
106impl MatryoshkaIndex {
107 pub fn new(mut config: MatryoshkaConfig) -> Result<Self> {
114 if config.supported_dims.is_empty() {
115 return Err(RuvectorError::InvalidParameter(
116 "supported_dims must not be empty".into(),
117 ));
118 }
119 config.supported_dims.sort_unstable();
120 config.supported_dims.dedup();
121
122 for &d in &config.supported_dims {
123 if d == 0 {
124 return Err(RuvectorError::InvalidParameter(
125 "Dimensions must be > 0".into(),
126 ));
127 }
128 if d > config.full_dim {
129 return Err(RuvectorError::InvalidParameter(format!(
130 "Supported dimension {} exceeds full_dim {}",
131 d, config.full_dim
132 )));
133 }
134 }
135
136 Ok(Self {
137 config,
138 entries: Vec::new(),
139 id_map: HashMap::new(),
140 })
141 }
142
143 pub fn insert(
149 &mut self,
150 id: VectorId,
151 embedding: Vec<f32>,
152 metadata: Option<HashMap<String, serde_json::Value>>,
153 ) -> Result<()> {
154 if embedding.len() != self.config.full_dim {
155 return Err(RuvectorError::DimensionMismatch {
156 expected: self.config.full_dim,
157 actual: embedding.len(),
158 });
159 }
160
161 let full_norm = compute_norm(&embedding);
162
163 if let Some(&existing_idx) = self.id_map.get(&id) {
164 self.entries[existing_idx] = MatryoshkaEntry {
165 id,
166 embedding,
167 full_norm,
168 metadata,
169 };
170 } else {
171 let idx = self.entries.len();
172 self.entries.push(MatryoshkaEntry {
173 id: id.clone(),
174 embedding,
175 full_norm,
176 metadata,
177 });
178 self.id_map.insert(id, idx);
179 }
180
181 Ok(())
182 }
183
184 pub fn len(&self) -> usize {
186 self.entries.len()
187 }
188
189 pub fn is_empty(&self) -> bool {
191 self.entries.is_empty()
192 }
193
194 pub fn search(
207 &self,
208 query: &[f32],
209 dim: usize,
210 top_k: usize,
211 ) -> Result<Vec<SearchResult>> {
212 if dim == 0 {
213 return Err(RuvectorError::InvalidParameter(
214 "Search dimension must be > 0".into(),
215 ));
216 }
217 if dim > self.config.full_dim {
218 return Err(RuvectorError::InvalidParameter(format!(
219 "Search dimension {} exceeds full_dim {}",
220 dim, self.config.full_dim
221 )));
222 }
223 if query.len() < dim {
224 return Err(RuvectorError::DimensionMismatch {
225 expected: dim,
226 actual: query.len(),
227 });
228 }
229
230 let query_prefix = &query[..dim];
231 let query_norm = compute_norm(query_prefix);
232
233 let mut scored: Vec<(usize, f32)> = self
234 .entries
235 .iter()
236 .enumerate()
237 .map(|(idx, entry)| {
238 let doc_prefix = &entry.embedding[..dim];
239 let doc_norm = compute_norm(doc_prefix);
240 let sim = similarity(query_prefix, query_norm, doc_prefix, doc_norm, self.config.metric);
241 (idx, sim)
242 })
243 .collect();
244
245 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246 scored.truncate(top_k);
247
248 Ok(scored
249 .into_iter()
250 .map(|(idx, score)| {
251 let entry = &self.entries[idx];
252 SearchResult {
253 id: entry.id.clone(),
254 score,
255 vector: None,
256 metadata: entry.metadata.clone(),
257 }
258 })
259 .collect())
260 }
261
262 pub fn funnel_search(
272 &self,
273 query: &[f32],
274 top_k: usize,
275 funnel_config: &FunnelConfig,
276 ) -> Result<Vec<SearchResult>> {
277 if query.len() < self.config.full_dim {
278 return Err(RuvectorError::DimensionMismatch {
279 expected: self.config.full_dim,
280 actual: query.len(),
281 });
282 }
283
284 let filter_dim = funnel_config.filter_dim.min(self.config.full_dim);
285 let num_candidates = ((top_k as f32) * funnel_config.candidate_multiplier).ceil() as usize;
286 let num_candidates = num_candidates.max(top_k);
287
288 let coarse_results = self.search(query, filter_dim, num_candidates)?;
290
291 let query_full = &query[..self.config.full_dim];
293 let query_full_norm = compute_norm(query_full);
294
295 let mut reranked: Vec<(VectorId, f32, Option<HashMap<String, serde_json::Value>>)> =
296 coarse_results
297 .into_iter()
298 .filter_map(|r| {
299 let idx = self.id_map.get(&r.id)?;
300 let entry = &self.entries[*idx];
301 let sim = similarity(
302 query_full,
303 query_full_norm,
304 &entry.embedding,
305 entry.full_norm,
306 self.config.metric,
307 );
308 Some((entry.id.clone(), sim, entry.metadata.clone()))
309 })
310 .collect();
311
312 reranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
313 reranked.truncate(top_k);
314
315 Ok(reranked
316 .into_iter()
317 .map(|(id, score, metadata)| SearchResult {
318 id,
319 score,
320 vector: None,
321 metadata,
322 })
323 .collect())
324 }
325
326 pub fn cascade_search(
332 &self,
333 query: &[f32],
334 top_k: usize,
335 dims: &[usize],
336 reduction_factor: f32,
337 ) -> Result<Vec<SearchResult>> {
338 if dims.is_empty() {
339 return Err(RuvectorError::InvalidParameter(
340 "Dimension cascade must not be empty".into(),
341 ));
342 }
343 if query.len() < self.config.full_dim {
344 return Err(RuvectorError::DimensionMismatch {
345 expected: self.config.full_dim,
346 actual: query.len(),
347 });
348 }
349
350 let mut candidate_indices: Vec<usize> = (0..self.entries.len()).collect();
352
353 for &dim in dims {
354 let dim = dim.min(self.config.full_dim);
355 let query_prefix = &query[..dim];
356 let query_norm = compute_norm(query_prefix);
357
358 let mut scored: Vec<(usize, f32)> = candidate_indices
359 .iter()
360 .map(|&idx| {
361 let entry = &self.entries[idx];
362 let doc_prefix = &entry.embedding[..dim];
363 let doc_norm = compute_norm(doc_prefix);
364 let sim = similarity(
365 query_prefix,
366 query_norm,
367 doc_prefix,
368 doc_norm,
369 self.config.metric,
370 );
371 (idx, sim)
372 })
373 .collect();
374
375 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
376
377 let keep = ((candidate_indices.len() as f32) / reduction_factor)
378 .ceil()
379 .max(top_k as f32) as usize;
380 scored.truncate(keep);
381 candidate_indices = scored.into_iter().map(|(idx, _)| idx).collect();
382 }
383
384 let last_dim = dims.last().copied().unwrap_or(self.config.full_dim);
386 let last_dim = last_dim.min(self.config.full_dim);
387 let query_prefix = &query[..last_dim];
388 let query_norm = compute_norm(query_prefix);
389
390 let mut final_scored: Vec<(usize, f32)> = candidate_indices
391 .iter()
392 .map(|&idx| {
393 let entry = &self.entries[idx];
394 let doc_prefix = &entry.embedding[..last_dim];
395 let doc_norm = compute_norm(doc_prefix);
396 let sim = similarity(
397 query_prefix,
398 query_norm,
399 doc_prefix,
400 doc_norm,
401 self.config.metric,
402 );
403 (idx, sim)
404 })
405 .collect();
406
407 final_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
408 final_scored.truncate(top_k);
409
410 Ok(final_scored
411 .into_iter()
412 .map(|(idx, score)| {
413 let entry = &self.entries[idx];
414 SearchResult {
415 id: entry.id.clone(),
416 score,
417 vector: None,
418 metadata: entry.metadata.clone(),
419 }
420 })
421 .collect())
422 }
423}
424
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of
/// squared components.
#[inline]
fn compute_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
430
431#[inline]
433fn similarity(a: &[f32], norm_a: f32, b: &[f32], norm_b: f32, metric: DistanceMetric) -> f32 {
434 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
435 match metric {
436 DistanceMetric::Cosine => {
437 let denom = norm_a * norm_b;
438 if denom < f32::EPSILON {
439 0.0
440 } else {
441 dot / denom
442 }
443 }
444 DistanceMetric::DotProduct => dot,
445 DistanceMetric::Euclidean => {
446 let dist_sq: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
447 1.0 / (1.0 + dist_sq.sqrt())
448 }
449 DistanceMetric::Manhattan => {
450 let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
451 1.0 / (1.0 + dist)
452 }
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cosine-metric configuration with the given dimensions.
    fn make_config(full_dim: usize, dims: Vec<usize>) -> MatryoshkaConfig {
        MatryoshkaConfig {
            full_dim,
            supported_dims: dims,
            metric: DistanceMetric::Cosine,
        }
    }

    /// Builds an empty cosine index whose supported dimensions are every
    /// power of two up to `full_dim`, plus `full_dim` itself.
    fn make_index(full_dim: usize) -> MatryoshkaIndex {
        let mut dims = Vec::new();
        for d in 1..=full_dim {
            if d.is_power_of_two() || d == full_dim {
                dims.push(d);
            }
        }
        MatryoshkaIndex::new(make_config(full_dim, dims)).unwrap()
    }

    /// Inserts `values` under `id` without metadata, panicking on failure.
    fn add(index: &mut MatryoshkaIndex, id: &str, values: &[f32]) {
        index.insert(id.into(), values.to_vec(), None).unwrap();
    }

    #[test]
    fn test_insert_and_len() {
        let mut index = make_index(4);
        assert!(index.is_empty());
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_insert_wrong_dimension_error() {
        let mut index = make_index(4);
        // Two components against a full_dim of four must be rejected.
        let outcome = index.insert("v1".into(), vec![1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_at_full_dim() {
        let mut index = make_index(4);
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "v2", &[0.0, 1.0, 0.0, 0.0]);

        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
        // The orthogonal vector scores zero.
        assert!(hits[1].score.abs() < 1e-5);
    }

    #[test]
    fn test_search_at_truncated_dim() {
        let mut index = make_index(4);
        add(&mut index, "v1", &[1.0, 0.0, 1.0, 0.0]);
        add(&mut index, "v2", &[1.0, 0.0, 0.0, 1.0]);

        // Both vectors share the same two-component prefix, so they tie.
        let truncated = index.search(&[1.0, 0.0, 0.5, 0.5], 2, 10).unwrap();
        assert!((truncated[0].score - truncated[1].score).abs() < 1e-5);

        // At full dimension the exact match wins.
        let full = index.search(&[1.0, 0.0, 1.0, 0.0], 4, 10).unwrap();
        assert_eq!(full[0].id, "v1");
        assert!(full[0].score > full[1].score);
    }

    #[test]
    fn test_funnel_search() {
        let mut index = make_index(8);
        add(&mut index, "best", &[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "good", &[1.0, 0.0, 0.5, 0.5, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "bad", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let query = vec![1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 2.0,
        };
        let hits = index.funnel_search(&query, 2, &funnel).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "best");
    }

    #[test]
    fn test_funnel_search_finds_correct_top_k() {
        let mut index = make_index(4);
        // Twenty unit vectors fanned across a half circle in the first two axes.
        for step in 0..20 {
            let angle = (step as f32) * std::f32::consts::PI / 20.0;
            add(
                &mut index,
                &format!("v{}", step),
                &[angle.cos(), angle.sin(), 0.0, 0.0],
            );
        }

        let query = vec![1.0, 0.0, 0.0, 0.0];
        let funnel = FunnelConfig {
            filter_dim: 2,
            candidate_multiplier: 4.0,
        };
        let hits = index.funnel_search(&query, 3, &funnel).unwrap();
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].id, "v0");
    }

    #[test]
    fn test_cascade_search() {
        let mut index = make_index(8);
        add(&mut index, "a", &[1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "b", &[1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]);
        add(&mut index, "c", &[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);

        let query = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let hits = index.cascade_search(&query, 2, &[2, 4, 8], 1.5).unwrap();
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn test_search_dim_exceeds_full_dim_error() {
        let index = make_index(4);
        let outcome = index.search(&[1.0, 0.0, 0.0, 0.0], 8, 10);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_search_empty_index() {
        let index = make_index(4);
        let hits = index.search(&[1.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_upsert_overwrites() {
        let mut index = make_index(4);
        add(&mut index, "v1", &[1.0, 0.0, 0.0, 0.0]);
        add(&mut index, "v1", &[0.0, 1.0, 0.0, 0.0]);
        assert_eq!(index.len(), 1);

        // The search must see the second embedding, not the first.
        let hits = index.search(&[0.0, 1.0, 0.0, 0.0], 4, 10).unwrap();
        assert_eq!(hits[0].id, "v1");
        assert!((hits[0].score - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_config_validation_empty_dims() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_config_validation_dim_exceeds_full() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 8],
            metric: DistanceMetric::Cosine,
        };
        assert!(MatryoshkaIndex::new(config).is_err());
    }

    #[test]
    fn test_dot_product_metric() {
        let config = MatryoshkaConfig {
            full_dim: 4,
            supported_dims: vec![2, 4],
            metric: DistanceMetric::DotProduct,
        };
        let mut index = MatryoshkaIndex::new(config).unwrap();
        add(&mut index, "v1", &[2.0, 0.0, 0.0, 0.0]);

        let hits = index.search(&[3.0, 0.0, 0.0, 0.0], 4, 10).unwrap();
        assert!((hits[0].score - 6.0).abs() < 1e-5);
    }
}