Skip to main content

trueno_rag/multivector/
search.rs

1//! WARP search algorithm components
2//!
3//! This module implements the three phases of WARP search:
4//!
5//! 1. **Centroid Selection** - For each query token, find top-nprobe centroids
6//! 2. **Candidate Scoring** - Decompress and score tokens from selected centroids
7//! 3. **Score Merging** - Aggregate per-token scores into document scores via MaxSim
8
9use crate::multivector::{codec::ResidualCodec, types::WarpSearchConfig, MultiVectorEmbedding};
10use crate::ChunkId;
11use std::collections::HashMap;
12
13/// Phase 1: Select top centroids per query token.
14///
15/// For each query token, compute its similarity with all centroids and
16/// select the top-nprobe centroids above the score threshold.
17pub struct CentroidSelector;
18
19impl CentroidSelector {
20    /// Select top centroids for each query token.
21    ///
22    /// # Arguments
23    ///
24    /// * `query` - Query multi-vector embedding
25    /// * `centroids` - Flattened centroid vectors [num_centroids × dim]
26    /// * `dim` - Token embedding dimension
27    /// * `config` - Search configuration
28    ///
29    /// # Returns
30    ///
31    /// For each query token, a vector of (centroid_id, centroid_score) pairs
32    /// sorted by score descending.
33    #[must_use]
34    pub fn select(
35        query: &MultiVectorEmbedding,
36        centroids: &[f32],
37        dim: usize,
38        config: &WarpSearchConfig,
39    ) -> Vec<Vec<(usize, f32)>> {
40        let num_centroids = centroids.len() / dim;
41
42        query
43            .tokens()
44            .map(|query_token| {
45                // Compute scores with all centroids
46                let mut scores: Vec<(usize, f32)> = (0..num_centroids)
47                    .map(|c| {
48                        let centroid = &centroids[c * dim..(c + 1) * dim];
49                        let score = Self::dot_product(query_token, centroid);
50                        (c, score)
51                    })
52                    .collect();
53
54                // Sort by score descending
55                scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
56
57                // Take top nprobe, filtered by threshold
58                scores
59                    .into_iter()
60                    .take(config.nprobe as usize)
61                    .filter(|(_, score)| *score >= config.centroid_score_threshold)
62                    .collect()
63            })
64            .collect()
65    }
66
67    /// Batch compute centroid scores for a single query token.
68    ///
69    /// Returns scores for all centroids sorted by score descending.
70    #[must_use]
71    pub fn batch_scores(query_token: &[f32], centroids: &[f32], dim: usize) -> Vec<(usize, f32)> {
72        let num_centroids = centroids.len() / dim;
73
74        let mut scores: Vec<(usize, f32)> = (0..num_centroids)
75            .map(|c| {
76                let centroid = &centroids[c * dim..(c + 1) * dim];
77                let score = Self::dot_product(query_token, centroid);
78                (c, score)
79            })
80            .collect();
81
82        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
83        scores
84    }
85
86    fn dot_product(a: &[f32], b: &[f32]) -> f32 {
87        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
88    }
89}
90
91/// Phase 2: Score candidates from a centroid.
92///
93/// For a single query token and centroid, decompress and score all
94/// document tokens assigned to that centroid.
95pub struct CandidateScorer;
96
97impl CandidateScorer {
98    /// Score candidates from a centroid for one query token.
99    ///
100    /// # Arguments
101    ///
102    /// * `query_token` - Query embedding for this token
103    /// * `centroid_id` - Selected centroid ID
104    /// * `centroid_score` - Precomputed q · c
105    /// * `codec` - Residual codec for decompression
106    /// * `sizes` - Number of tokens per centroid
107    /// * `offsets` - Cumulative offsets per centroid
108    /// * `chunk_ids` - Chunk IDs for all tokens
109    /// * `token_indices` - Token indices within chunks
110    /// * `residuals` - Packed residuals for all tokens
111    /// * `bytes_per_residual` - Bytes per packed residual
112    ///
113    /// # Returns
114    ///
115    /// Vector of (ChunkId, token_index, score) for all candidates.
116    #[must_use]
117    #[allow(clippy::too_many_arguments)]
118    pub fn score(
119        query_token: &[f32],
120        centroid_id: usize,
121        centroid_score: f32,
122        codec: &ResidualCodec,
123        sizes: &[usize],
124        offsets: &[usize],
125        chunk_ids: &[ChunkId],
126        token_indices: &[u16],
127        residuals: &[u8],
128        bytes_per_residual: usize,
129    ) -> Vec<(ChunkId, u16, f32)> {
130        let size = sizes.get(centroid_id).copied().unwrap_or(0);
131        if size == 0 {
132            return Vec::new();
133        }
134
135        let offset = offsets.get(centroid_id).copied().unwrap_or(0);
136
137        (0..size)
138            .map(|i| {
139                let idx = offset + i;
140                let chunk_id = chunk_ids[idx];
141                let token_idx = token_indices[idx];
142
143                let residual_start = idx * bytes_per_residual;
144                let residual_end = residual_start + bytes_per_residual;
145                let residual = &residuals[residual_start..residual_end];
146
147                let score =
148                    codec.decompress_score(query_token, centroid_id, centroid_score, residual);
149
150                (chunk_id, token_idx, score)
151            })
152            .collect()
153    }
154
155    /// Score a single candidate.
156    #[must_use]
157    pub fn score_single(
158        query_token: &[f32],
159        centroid_id: usize,
160        centroid_score: f32,
161        codec: &ResidualCodec,
162        residual: &[u8],
163    ) -> f32 {
164        codec.decompress_score(query_token, centroid_id, centroid_score, residual)
165    }
166}
167
168/// Phase 3: Merge per-token scores into document scores via MaxSim.
169///
170/// MaxSim computes: score(Q, D) = Σ_i max_j(q_i · d_j)
171///
172/// For each query token, find the maximum score with any document token,
173/// then sum across query tokens.
174pub struct ScoreMerger;
175
176impl ScoreMerger {
177    /// Merge per-token scores into document scores via MaxSim.
178    ///
179    /// # Arguments
180    ///
181    /// * `token_scores` - For each query token: (ChunkId, doc_token_idx, score)
182    /// * `k` - Number of top results to return
183    ///
184    /// # Returns
185    ///
186    /// Vector of (ChunkId, total_score) sorted by score descending.
187    #[must_use]
188    pub fn merge(token_scores: Vec<Vec<(ChunkId, u16, f32)>>, k: usize) -> Vec<(ChunkId, f32)> {
189        if token_scores.is_empty() {
190            return Vec::new();
191        }
192
193        let num_query_tokens = token_scores.len();
194
195        // For each document, track max score per query token
196        let mut doc_token_maxes: HashMap<ChunkId, Vec<f32>> = HashMap::new();
197
198        for (query_token_idx, scores) in token_scores.into_iter().enumerate() {
199            for (chunk_id, _doc_token_idx, score) in scores {
200                let maxes = doc_token_maxes
201                    .entry(chunk_id)
202                    .or_insert_with(|| vec![f32::NEG_INFINITY; num_query_tokens]);
203
204                if score > maxes[query_token_idx] {
205                    maxes[query_token_idx] = score;
206                }
207            }
208        }
209
210        // Sum max scores across query tokens
211        let mut doc_scores: Vec<(ChunkId, f32)> = doc_token_maxes
212            .into_iter()
213            .map(|(chunk_id, maxes)| {
214                let score: f32 = maxes.into_iter().filter(|&s| s > f32::NEG_INFINITY).sum();
215                (chunk_id, score)
216            })
217            .collect();
218
219        // Sort by score descending
220        doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221
222        // Take top-k
223        doc_scores.truncate(k);
224        doc_scores
225    }
226
227    /// Merge scores for a single document across query tokens.
228    ///
229    /// This is useful when you have per-token scores already grouped by document.
230    #[must_use]
231    pub fn merge_single_doc(token_max_scores: &[f32]) -> f32 {
232        token_max_scores.iter().filter(|&&s| s > f32::NEG_INFINITY).sum()
233    }
234}
235
236/// Compute exact MaxSim score (for testing/comparison).
237///
238/// This computes the full MaxSim score without compression:
239/// score(Q, D) = Σ_i max_j(q_i · d_j)
240#[must_use]
241pub fn exact_maxsim(query: &MultiVectorEmbedding, doc: &MultiVectorEmbedding) -> f32 {
242    query
243        .tokens()
244        .map(|q| doc.tokens().map(|d| dot_product(q, d)).fold(f32::NEG_INFINITY, f32::max))
245        .filter(|&s| s > f32::NEG_INFINITY)
246        .sum()
247}
248
/// Dot product of two `f32` vectors.
///
/// Elements are paired positionally. Pairing stops at the end of the shorter
/// slice, so any extra trailing elements of the longer one are ignored.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().copied().zip(b.iter().copied());
    pairs.map(|(lhs, rhs)| lhs * rhs).sum::<f32>()
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic pseudo-random embedding with every component in `[-1.0, 1.0]`.
    ///
    /// Advances a 64-bit LCG once per component and maps the high 32 bits of the
    /// state onto `[0, 1]`. A 31-bit sample divided by `u32::MAX` would only reach
    /// `[0, 0.5)`, which would leave every component negative.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut embeddings = Vec::with_capacity(num_tokens * dim);
        let mut rng = seed;

        for _ in 0..(num_tokens * dim) {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            let unit = (rng >> 32) as f32 / u32::MAX as f32;
            embeddings.push(unit * 2.0 - 1.0);
        }

        MultiVectorEmbedding::new(embeddings, num_tokens, dim)
    }

    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    // ============ CentroidSelector Tests ============

    #[test]
    fn test_centroid_selector_basic() {
        let query = generate_embedding(2, 4, 42);

        // Axis-aligned centroids: each score is a single query component in [-1, 1].
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
            0.0, 0.0, 1.0, 0.0, // centroid 2
            0.0, 0.0, 0.0, 1.0, // centroid 3
        ];

        // A threshold of -1.0 accepts every score in [-1, 1].
        let config = WarpSearchConfig::with_k(10).nprobe(2).centroid_score_threshold(-1.0);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert_eq!(selected.len(), 2); // one entry per query token
        for token_selection in &selected {
            assert_eq!(token_selection.len(), 2); // exactly nprobe
        }
    }

    #[test]
    fn test_centroid_selector_threshold() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0: score = 1.0
            0.0, 1.0, 0.0, 0.0, // centroid 1: score = 0.0
            0.5, 0.5, 0.0, 0.0, // centroid 2: score = 0.5
            0.0, 0.0, 1.0, 0.0, // centroid 3: score = 0.0
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(0.4);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        // Only centroids with score >= 0.4 are kept, best first.
        assert_eq!(selected.len(), 1);
        let ids: Vec<usize> = selected[0].iter().map(|&(c, _)| c).collect();
        assert_eq!(ids, vec![0, 2]);
        assert!((selected[0][0].1 - 1.0).abs() < 1e-6);
        assert!((selected[0][1].1 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_centroid_selector_sorted() {
        let query = MultiVectorEmbedding::new(vec![0.5, 0.5, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0: score 0.5
            0.0, 1.0, 0.0, 0.0, // centroid 1: score 0.5
            1.0, 1.0, 0.0, 0.0, // centroid 2: score 1.0 (best match)
            0.0, 0.0, 1.0, 0.0, // centroid 3: score 0.0
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(-1.0);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert_eq!(selected[0].len(), 4);
        assert_eq!(selected[0][0].0, 2);
        // Results must be sorted by score descending.
        for pair in selected[0].windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn test_batch_scores() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, // centroid 0
            0.0, 1.0, 0.0, 0.0, // centroid 1
        ];

        let scores = CentroidSelector::batch_scores(&query_token, &centroids, 4);

        assert_eq!(scores.len(), 2);
        assert_eq!(scores[0].0, 0); // Best match is centroid 0
        assert!((scores[0].1 - 1.0).abs() < 1e-6);
        assert_eq!(scores[1].0, 1);
        assert!(scores[1].1.abs() < 1e-6);
    }

    // ============ CandidateScorer Tests ============

    #[test]
    fn test_candidate_scorer_empty_centroid() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let codec = create_test_codec();

        let sizes = vec![0, 5, 3]; // centroid 0 is empty
        let offsets = vec![0, 0, 5];
        let chunk_ids: Vec<ChunkId> = vec![];
        let token_indices: Vec<u16> = vec![];
        let residuals: Vec<u8> = vec![];

        let results = CandidateScorer::score(
            &query_token,
            0, // empty centroid
            0.5,
            &codec,
            &sizes,
            &offsets,
            &chunk_ids,
            &token_indices,
            &residuals,
            2, // bytes per residual
        );

        assert!(results.is_empty());
    }

    fn create_test_codec() -> ResidualCodec {
        // Create a minimal test codec
        let embeddings = vec![0.0f32; 200 * 4]; // 200 samples, dim=4
        ResidualCodec::train(&embeddings, 4, 4, 2, 3).unwrap()
    }

    // ============ ScoreMerger Tests ============

    #[test]
    fn test_score_merger_basic() {
        let token_scores = vec![
            vec![(chunk_id(1), 0, 0.9), (chunk_id(2), 0, 0.8), (chunk_id(1), 1, 0.7)],
            vec![(chunk_id(1), 0, 0.6), (chunk_id(2), 0, 0.5), (chunk_id(3), 0, 0.4)],
        ];

        let results = ScoreMerger::merge(token_scores, 10);

        // chunk_id(1): max(0.9, 0.7) + 0.6 = 1.5
        // chunk_id(2): 0.8 + 0.5 = 1.3
        // chunk_id(3): (no match for query token 0) + 0.4 = 0.4
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 1.5).abs() < 0.001);
        assert_eq!(results[1].0, chunk_id(2));
        assert!((results[1].1 - 1.3).abs() < 0.001);
        assert_eq!(results[2].0, chunk_id(3));
        assert!((results[2].1 - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_score_merger_empty() {
        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = vec![];
        let results = ScoreMerger::merge(token_scores, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_score_merger_respects_k() {
        let token_scores = vec![vec![
            (chunk_id(1), 0, 0.9),
            (chunk_id(2), 0, 0.8),
            (chunk_id(3), 0, 0.7),
            (chunk_id(4), 0, 0.6),
            (chunk_id(5), 0, 0.5),
        ]];

        let results = ScoreMerger::merge(token_scores, 3);
        assert_eq!(results.len(), 3);
        let ids: Vec<ChunkId> = results.iter().map(|&(c, _)| c).collect();
        assert_eq!(ids, vec![chunk_id(1), chunk_id(2), chunk_id(3)]);
    }

    #[test]
    fn test_score_merger_sorted_descending() {
        let token_scores =
            vec![vec![(chunk_id(1), 0, 0.3), (chunk_id(2), 0, 0.9), (chunk_id(3), 0, 0.6)]];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results[0].0, chunk_id(2)); // highest
        assert_eq!(results[1].0, chunk_id(3));
        assert_eq!(results[2].0, chunk_id(1)); // lowest
    }

    #[test]
    fn test_merge_single_doc() {
        let scores = vec![0.9, 0.6, f32::NEG_INFINITY, 0.3];
        let total = ScoreMerger::merge_single_doc(&scores);

        assert!((total - 1.8).abs() < 0.001); // 0.9 + 0.6 + 0.3
    }

    // ============ Exact MaxSim Tests ============

    #[test]
    fn test_exact_maxsim_identical() {
        let emb = generate_embedding(3, 4, 42);
        let score = exact_maxsim(&emb, &emb);

        // Each query token can match itself, so the score is at least the sum of the
        // tokens' squared norms (which is positive for non-zero tokens).
        let self_sum: f32 = emb.tokens().map(|t| dot_product(t, t)).sum();
        assert!(self_sum > 0.0);
        assert!(score >= self_sum - 1e-5);
    }

    #[test]
    fn test_exact_maxsim_orthogonal() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![0.0, 1.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_maxsim_aligned() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_maxsim_is_finite(
            num_q in 1usize..5,
            num_d in 1usize..5
        ) {
            // Components are bounded in [-1, 1], so MaxSim (which may be negative)
            // must stay finite.
            let query = generate_embedding(num_q, 4, 123);
            let doc = generate_embedding(num_d, 4, 456);

            let score = exact_maxsim(&query, &doc);

            prop_assert!(score.is_finite());
        }

        #[test]
        fn prop_merger_results_count_bounded_by_k(
            k in 1usize..20,
            num_docs in 1usize..50
        ) {
            let token_scores = vec![
                (0..num_docs)
                    .map(|i| (chunk_id(i as u128), 0u16, i as f32 / 100.0))
                    .collect()
            ];

            let results = ScoreMerger::merge(token_scores, k);
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= num_docs);
        }

        #[test]
        fn prop_centroid_selector_respects_nprobe(
            nprobe in 1u32..10
        ) {
            let query = generate_embedding(2, 4, 42);
            let centroids = vec![0.5f32; 20 * 4]; // 20 identical centroids

            // Every score is 0.5 * (sum of 4 components in [-1, 1]) >= -2 > -10.
            let config = WarpSearchConfig::with_k(10)
                .nprobe(nprobe)
                .centroid_score_threshold(-10.0);

            let selected = CentroidSelector::select(&query, &centroids, 4, &config);

            prop_assert_eq!(selected.len(), 2);
            for token_selection in selected {
                // All 20 centroids pass the threshold, so exactly nprobe are kept.
                prop_assert_eq!(token_selection.len(), nprobe as usize);
            }
        }
    }
}