1use crate::multivector::{codec::ResidualCodec, types::WarpSearchConfig, MultiVectorEmbedding};
10use crate::ChunkId;
11use std::collections::HashMap;
12
13pub struct CentroidSelector;
18
19impl CentroidSelector {
20 #[must_use]
34 pub fn select(
35 query: &MultiVectorEmbedding,
36 centroids: &[f32],
37 dim: usize,
38 config: &WarpSearchConfig,
39 ) -> Vec<Vec<(usize, f32)>> {
40 let num_centroids = centroids.len() / dim;
41
42 query
43 .tokens()
44 .map(|query_token| {
45 let mut scores: Vec<(usize, f32)> = (0..num_centroids)
47 .map(|c| {
48 let centroid = ¢roids[c * dim..(c + 1) * dim];
49 let score = Self::dot_product(query_token, centroid);
50 (c, score)
51 })
52 .collect();
53
54 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
56
57 scores
59 .into_iter()
60 .take(config.nprobe as usize)
61 .filter(|(_, score)| *score >= config.centroid_score_threshold)
62 .collect()
63 })
64 .collect()
65 }
66
67 #[must_use]
71 pub fn batch_scores(query_token: &[f32], centroids: &[f32], dim: usize) -> Vec<(usize, f32)> {
72 let num_centroids = centroids.len() / dim;
73
74 let mut scores: Vec<(usize, f32)> = (0..num_centroids)
75 .map(|c| {
76 let centroid = ¢roids[c * dim..(c + 1) * dim];
77 let score = Self::dot_product(query_token, centroid);
78 (c, score)
79 })
80 .collect();
81
82 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
83 scores
84 }
85
86 fn dot_product(a: &[f32], b: &[f32]) -> f32 {
87 a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
88 }
89}
90
91pub struct CandidateScorer;
96
97impl CandidateScorer {
98 #[must_use]
117 #[allow(clippy::too_many_arguments)]
118 pub fn score(
119 query_token: &[f32],
120 centroid_id: usize,
121 centroid_score: f32,
122 codec: &ResidualCodec,
123 sizes: &[usize],
124 offsets: &[usize],
125 chunk_ids: &[ChunkId],
126 token_indices: &[u16],
127 residuals: &[u8],
128 bytes_per_residual: usize,
129 ) -> Vec<(ChunkId, u16, f32)> {
130 let size = sizes.get(centroid_id).copied().unwrap_or(0);
131 if size == 0 {
132 return Vec::new();
133 }
134
135 let offset = offsets.get(centroid_id).copied().unwrap_or(0);
136
137 (0..size)
138 .map(|i| {
139 let idx = offset + i;
140 let chunk_id = chunk_ids[idx];
141 let token_idx = token_indices[idx];
142
143 let residual_start = idx * bytes_per_residual;
144 let residual_end = residual_start + bytes_per_residual;
145 let residual = &residuals[residual_start..residual_end];
146
147 let score =
148 codec.decompress_score(query_token, centroid_id, centroid_score, residual);
149
150 (chunk_id, token_idx, score)
151 })
152 .collect()
153 }
154
155 #[must_use]
157 pub fn score_single(
158 query_token: &[f32],
159 centroid_id: usize,
160 centroid_score: f32,
161 codec: &ResidualCodec,
162 residual: &[u8],
163 ) -> f32 {
164 codec.decompress_score(query_token, centroid_id, centroid_score, residual)
165 }
166}
167
168pub struct ScoreMerger;
175
176impl ScoreMerger {
177 #[must_use]
188 pub fn merge(token_scores: Vec<Vec<(ChunkId, u16, f32)>>, k: usize) -> Vec<(ChunkId, f32)> {
189 if token_scores.is_empty() {
190 return Vec::new();
191 }
192
193 let num_query_tokens = token_scores.len();
194
195 let mut doc_token_maxes: HashMap<ChunkId, Vec<f32>> = HashMap::new();
197
198 for (query_token_idx, scores) in token_scores.into_iter().enumerate() {
199 for (chunk_id, _doc_token_idx, score) in scores {
200 let maxes = doc_token_maxes
201 .entry(chunk_id)
202 .or_insert_with(|| vec![f32::NEG_INFINITY; num_query_tokens]);
203
204 if score > maxes[query_token_idx] {
205 maxes[query_token_idx] = score;
206 }
207 }
208 }
209
210 let mut doc_scores: Vec<(ChunkId, f32)> = doc_token_maxes
212 .into_iter()
213 .map(|(chunk_id, maxes)| {
214 let score: f32 = maxes.into_iter().filter(|&s| s > f32::NEG_INFINITY).sum();
215 (chunk_id, score)
216 })
217 .collect();
218
219 doc_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
221
222 doc_scores.truncate(k);
224 doc_scores
225 }
226
227 #[must_use]
231 pub fn merge_single_doc(token_max_scores: &[f32]) -> f32 {
232 token_max_scores.iter().filter(|&&s| s > f32::NEG_INFINITY).sum()
233 }
234}
235
236#[must_use]
241pub fn exact_maxsim(query: &MultiVectorEmbedding, doc: &MultiVectorEmbedding) -> f32 {
242 query
243 .tokens()
244 .map(|q| doc.tokens().map(|d| dot_product(q, d)).fold(f32::NEG_INFINITY, f32::max))
245 .filter(|&s| s > f32::NEG_INFINITY)
246 .sum()
247}
248
/// Inner product of two vectors.
///
/// Elements are paired by position; if the slices differ in length the tail
/// of the longer one is ignored because `zip` stops at the shorter input.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
254
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Builds a deterministic pseudo-random embedding from `seed`.
    ///
    /// Uses a 64-bit linear congruential step per component and maps the high
    /// bits into roughly `[-1.0, 1.0]`. The vectors are not normalised.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut embeddings = Vec::with_capacity(num_tokens * dim);
        let mut rng = seed;

        for _ in 0..(num_tokens * dim) {
            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
            // Keep the full high 32 bits so the ratio spans [0, 1]. Shifting by
            // 33 left only 31 bits, which capped the ratio at 0.5 and made
            // every generated component non-positive.
            let val = ((rng >> 32) as f32 / u32::MAX as f32) * 2.0 - 1.0;
            embeddings.push(val);
        }

        MultiVectorEmbedding::new(embeddings, num_tokens, dim)
    }

    /// Shorthand for a `ChunkId` with a fixed, readable value.
    fn chunk_id(n: u128) -> ChunkId {
        ChunkId(uuid::Uuid::from_u128(n))
    }

    /// Trains a small codec on all-zero data; enough for tests that never
    /// decode a residual.
    fn create_test_codec() -> ResidualCodec {
        let embeddings = vec![0.0f32; 200 * 4];
        ResidualCodec::train(&embeddings, 4, 4, 2, 3).unwrap()
    }

    // ---- CentroidSelector -------------------------------------------------

    #[test]
    fn test_centroid_selector_basic() {
        let query = generate_embedding(2, 4, 42);

        // Four axis-aligned centroids, dim = 4.
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0, //
            0.0, 0.0, 0.0, 1.0, //
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(2).centroid_score_threshold(-1.0);
        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        // One list per query token, each capped at nprobe.
        assert_eq!(selected.len(), 2);
        assert!(selected[0].len() <= 2);
    }

    #[test]
    fn test_centroid_selector_threshold() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        // Dot products with the query: 1.0, 0.0, 0.5, 0.0.
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.5, 0.5, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0, //
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(0.4);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert_eq!(selected.len(), 1);
        // Only two centroids reach the 0.4 threshold.
        assert!(selected[0].len() <= 2);
    }

    #[test]
    fn test_centroid_selector_sorted() {
        let query = MultiVectorEmbedding::new(vec![0.5, 0.5, 0.0, 0.0], 1, 4);

        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.5, 0.5, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0, //
        ];

        let config = WarpSearchConfig::with_k(10).nprobe(4).centroid_score_threshold(-1.0);

        let selected = CentroidSelector::select(&query, &centroids, 4, &config);

        assert!(!selected[0].is_empty());
        for i in 1..selected[0].len() {
            assert!(selected[0][i - 1].1 >= selected[0][i].1);
        }
    }

    #[test]
    fn test_batch_scores() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let centroids = vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
        ];

        let scores = CentroidSelector::batch_scores(&query_token, &centroids, 4);

        assert_eq!(scores.len(), 2);
        // Centroid 0 is aligned with the query, so it ranks first with score 1.
        assert_eq!(scores[0].0, 0);
        assert!((scores[0].1 - 1.0).abs() < 1e-6);
    }

    // ---- CandidateScorer --------------------------------------------------

    #[test]
    fn test_candidate_scorer_empty_centroid() {
        let query_token = vec![1.0, 0.0, 0.0, 0.0];
        let codec = create_test_codec();

        // Centroid 0 has size 0, so the empty posting arrays are never read.
        let sizes = vec![0, 5, 3];
        let offsets = vec![0, 0, 5];
        let chunk_ids: Vec<ChunkId> = vec![];
        let token_indices: Vec<u16> = vec![];
        let residuals: Vec<u8> = vec![];

        let results = CandidateScorer::score(
            &query_token,
            0,
            0.5,
            &codec,
            &sizes,
            &offsets,
            &chunk_ids,
            &token_indices,
            &residuals,
            2,
        );

        assert!(results.is_empty());
    }

    // ---- ScoreMerger ------------------------------------------------------

    #[test]
    fn test_score_merger_basic() {
        let token_scores = vec![
            vec![(chunk_id(1), 0, 0.9), (chunk_id(2), 0, 0.8), (chunk_id(1), 1, 0.7)],
            vec![(chunk_id(1), 0, 0.6), (chunk_id(2), 0, 0.5), (chunk_id(3), 0, 0.4)],
        ];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results.len(), 3);
        // Chunk 1: max(0.9, 0.7) for token 0 plus 0.6 for token 1.
        assert_eq!(results[0].0, chunk_id(1));
        assert!((results[0].1 - 1.5).abs() < 0.001);
    }

    #[test]
    fn test_score_merger_empty() {
        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = vec![];
        let results = ScoreMerger::merge(token_scores, 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_score_merger_respects_k() {
        let token_scores = vec![vec![
            (chunk_id(1), 0, 0.9),
            (chunk_id(2), 0, 0.8),
            (chunk_id(3), 0, 0.7),
            (chunk_id(4), 0, 0.6),
            (chunk_id(5), 0, 0.5),
        ]];

        let results = ScoreMerger::merge(token_scores, 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_score_merger_sorted_descending() {
        let token_scores =
            vec![vec![(chunk_id(1), 0, 0.3), (chunk_id(2), 0, 0.9), (chunk_id(3), 0, 0.6)]];

        let results = ScoreMerger::merge(token_scores, 10);

        assert_eq!(results[0].0, chunk_id(2));
        assert_eq!(results[1].0, chunk_id(3));
        assert_eq!(results[2].0, chunk_id(1));
    }

    #[test]
    fn test_merge_single_doc() {
        let scores = vec![0.9, 0.6, f32::NEG_INFINITY, 0.3];
        let total = ScoreMerger::merge_single_doc(&scores);

        // The NEG_INFINITY entry is skipped: 0.9 + 0.6 + 0.3.
        assert!((total - 1.8).abs() < 0.001);
    }

    // ---- exact_maxsim -----------------------------------------------------

    #[test]
    fn test_exact_maxsim_identical() {
        let emb = generate_embedding(3, 4, 42);
        let score = exact_maxsim(&emb, &emb);

        // Every token matches itself, contributing at least its squared norm.
        assert!(score > 0.0);
    }

    #[test]
    fn test_exact_maxsim_orthogonal() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![0.0, 1.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_exact_maxsim_aligned() {
        let query = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);
        let doc = MultiVectorEmbedding::new(vec![1.0, 0.0, 0.0, 0.0], 1, 4);

        let score = exact_maxsim(&query, &doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    // ---- Property tests ---------------------------------------------------

    proptest! {
        // The generated vectors are not normalised and may point in opposite
        // directions, so only finiteness is asserted.
        #[test]
        fn prop_maxsim_is_finite(
            num_q in 1usize..5,
            num_d in 1usize..5
        ) {
            let query = generate_embedding(num_q, 4, 123);
            let doc = generate_embedding(num_d, 4, 456);

            let score = exact_maxsim(&query, &doc);

            prop_assert!(score.is_finite());
        }

        #[test]
        fn prop_merger_results_count_bounded_by_k(
            k in 1usize..20,
            num_docs in 1usize..50
        ) {
            let token_scores = vec![
                (0..num_docs)
                    .map(|i| (chunk_id(i as u128), 0u16, i as f32 / 100.0))
                    .collect()
            ];

            let results = ScoreMerger::merge(token_scores, k);
            prop_assert!(results.len() <= k);
            prop_assert!(results.len() <= num_docs);
        }

        #[test]
        fn prop_centroid_selector_respects_nprobe(
            nprobe in 1u32..10
        ) {
            let query = generate_embedding(2, 4, 42);
            // 20 identical centroids; the threshold is low enough to keep all of them.
            let centroids = vec![0.5f32; 20 * 4];
            let config = WarpSearchConfig::with_k(10)
                .nprobe(nprobe)
                .centroid_score_threshold(-10.0);

            let selected = CentroidSelector::select(&query, &centroids, 4, &config);

            for token_selection in selected {
                prop_assert!(token_selection.len() <= nprobe as usize);
            }
        }
    }
}