1use std::collections::{HashMap, HashSet};
8
9use anyhow::{anyhow, Result};
10
11use crate::types::{HitResult, IndexMetadata, QueryRecord};
12use crate::ShardedInvertedIndex;
13
/// How a read's log-ratio was obtained.
///
/// Reads either go through the full numerator/denominator computation or are
/// short-circuited by the numerator-only skip threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FastPath {
    /// No shortcut: the log-ratio was computed from both the numerator and
    /// the denominator score.
    None,
    /// The numerator score met the skip threshold. The denominator was not
    /// consulted and the log-ratio is reported as `+inf`.
    NumHigh,
}

impl FastPath {
    /// Stable lowercase label for this variant (e.g. for tabular output).
    pub fn as_str(&self) -> &'static str {
        let label = match *self {
            Self::NumHigh => "num_high",
            Self::None => "none",
        };
        label
    }
}
33
34#[derive(Debug, Clone, PartialEq)]
36pub struct LogRatioResult {
37 pub query_id: i64,
38 pub log_ratio: f64,
39 pub fast_path: FastPath,
40}
41
/// Returns `log10(numerator / denominator)`, with explicit handling of zero scores.
///
/// | numerator | denominator | result                      |
/// |-----------|-------------|-----------------------------|
/// | zero      | zero        | `NaN`                       |
/// | zero      | non-zero    | `-inf`                      |
/// | non-zero  | zero        | `+inf`                      |
/// | non-zero  | non-zero    | `log10(numerator / denominator)` |
///
/// "Zero" is tested with `== 0.0`, so both `0.0` and `-0.0` count.
pub fn compute_log_ratio(numerator: f64, denominator: f64) -> f64 {
    let num_is_zero = numerator == 0.0;
    let den_is_zero = denominator == 0.0;
    match (num_is_zero, den_is_zero) {
        (true, true) => f64::NAN,
        (true, false) => f64::NEG_INFINITY,
        (false, true) => f64::INFINITY,
        (false, false) => (numerator / denominator).log10(),
    }
}
59
60pub fn validate_single_bucket_index(bucket_names: &HashMap<u32, String>) -> Result<(u32, String)> {
64 if bucket_names.len() != 1 {
65 return Err(anyhow!(
66 "log-ratio mode requires each index to have exactly 1 bucket, but found {}.\n\
67 Use 'rype index stats -i <index>' to see bucket information.",
68 bucket_names.len()
69 ));
70 }
71
72 let (&bucket_id, bucket_name) = bucket_names.iter().next().unwrap();
73 Ok((bucket_id, bucket_name.clone()))
74}
75
76pub fn validate_compatible_indices(a: &IndexMetadata, b: &IndexMetadata) -> Result<()> {
80 if a.k != b.k {
81 return Err(anyhow!(
82 "Numerator and denominator indices have different k values: {} vs {}.\n\
83 Both indices must be built with the same k, w, and salt.",
84 a.k,
85 b.k
86 ));
87 }
88 if a.w != b.w {
89 return Err(anyhow!(
90 "Numerator and denominator indices have different w values: {} vs {}.\n\
91 Both indices must be built with the same k, w, and salt.",
92 a.w,
93 b.w
94 ));
95 }
96 if a.salt != b.salt {
97 return Err(anyhow!(
98 "Numerator and denominator indices have different salt values: {:#x} vs {:#x}.\n\
99 Both indices must be built with the same k, w, and salt.",
100 a.salt,
101 b.salt
102 ));
103 }
104 Ok(())
105}
106
107pub struct PartitionResult {
109 pub fast_path_results: Vec<LogRatioResult>,
111 pub needs_denom_query_ids: Vec<i64>,
113 pub num_scores: Vec<f64>,
116}
117
118pub fn partition_by_numerator_score(
129 num_results: &[HitResult],
130 total_reads: usize,
131 skip_threshold: Option<f64>,
132) -> PartitionResult {
133 let mut num_scores = vec![0.0_f64; total_reads];
135 for hit in num_results {
136 num_scores[hit.query_id as usize] = hit.score;
137 }
138
139 let mut fast_path_results = Vec::new();
140 let mut needs_denom_query_ids = Vec::new();
141
142 for query_id in 0..total_reads as i64 {
143 let score = num_scores[query_id as usize];
144
145 if let Some(thresh) = skip_threshold {
146 if score >= thresh {
147 fast_path_results.push(LogRatioResult {
149 query_id,
150 log_ratio: f64::INFINITY,
151 fast_path: FastPath::NumHigh,
152 });
153 } else {
154 needs_denom_query_ids.push(query_id);
155 }
156 } else {
157 needs_denom_query_ids.push(query_id);
158 }
159 }
160
161 PartitionResult {
162 fast_path_results,
163 needs_denom_query_ids,
164 num_scores,
165 }
166}
167
168pub fn validate_log_ratio_indices(
173 numerator: &ShardedInvertedIndex,
174 denominator: &ShardedInvertedIndex,
175) -> Result<(usize, usize, u64)> {
176 let num_manifest = numerator.manifest();
177 let denom_manifest = denominator.manifest();
178
179 validate_single_bucket_index(&num_manifest.bucket_names)
180 .map_err(|e| anyhow!("numerator index: {}", e))?;
181 validate_single_bucket_index(&denom_manifest.bucket_names)
182 .map_err(|e| anyhow!("denominator index: {}", e))?;
183
184 if num_manifest.k != denom_manifest.k {
185 return Err(anyhow!(
186 "Numerator and denominator indices have different k values: {} vs {}.\n\
187 Both indices must be built with the same k, w, and salt.",
188 num_manifest.k,
189 denom_manifest.k
190 ));
191 }
192 if num_manifest.w != denom_manifest.w {
193 return Err(anyhow!(
194 "Numerator and denominator indices have different w values: {} vs {}.\n\
195 Both indices must be built with the same k, w, and salt.",
196 num_manifest.w,
197 denom_manifest.w
198 ));
199 }
200 if num_manifest.salt != denom_manifest.salt {
201 return Err(anyhow!(
202 "Numerator and denominator indices have different salt values: {:#x} vs {:#x}.\n\
203 Both indices must be built with the same k, w, and salt.",
204 num_manifest.salt,
205 denom_manifest.salt
206 ));
207 }
208
209 Ok((num_manifest.k, num_manifest.w, num_manifest.salt))
210}
211
212pub fn classify_log_ratio_batch(
229 numerator: &ShardedInvertedIndex,
230 denominator: &ShardedInvertedIndex,
231 records: &[QueryRecord],
232 skip_threshold: Option<f64>,
233) -> Result<Vec<LogRatioResult>> {
234 let num_queries = records.len();
235 if num_queries == 0 {
236 return Ok(Vec::new());
237 }
238
239 let (k, w, salt) = validate_log_ratio_indices(numerator, denominator)?;
240
241 let original_ids: Vec<i64> = records.iter().map(|r| r.0).collect();
245 let sequential_ids: Vec<i64> = (0..num_queries as i64).collect();
246
247 let extracted = crate::extract_batch_minimizers(k, w, salt, None, records);
248
249 let num_results = crate::classify_from_extracted_minimizers(
251 numerator,
252 &extracted,
253 &sequential_ids,
254 0.0,
255 None,
256 )?;
257
258 let partition = partition_by_numerator_score(&num_results, num_queries, skip_threshold);
260
261 let needs_denom_set: HashSet<i64> = partition.needs_denom_query_ids.iter().copied().collect();
263
264 let mut denom_extracted = Vec::new();
265 let mut denom_ids = Vec::new();
266 for (i, ext) in extracted.iter().enumerate() {
267 let seq_id = i as i64;
268 if needs_denom_set.contains(&seq_id) {
269 denom_extracted.push(ext.clone());
270 denom_ids.push(seq_id);
271 }
272 }
273
274 let denom_results = if !denom_ids.is_empty() {
276 crate::classify_from_extracted_minimizers(
277 denominator,
278 &denom_extracted,
279 &denom_ids,
280 0.0,
281 None,
282 )?
283 } else {
284 Vec::new()
285 };
286
287 let mut denom_scores = vec![0.0_f64; num_queries];
289 for hit in &denom_results {
290 denom_scores[hit.query_id as usize] = hit.score;
291 }
292
293 let mut results: Vec<LogRatioResult> = Vec::with_capacity(num_queries);
295
296 for lr in &partition.fast_path_results {
297 results.push(LogRatioResult {
298 query_id: original_ids[lr.query_id as usize],
299 log_ratio: lr.log_ratio,
300 fast_path: lr.fast_path,
301 });
302 }
303
304 for &seq_id in &partition.needs_denom_query_ids {
305 let idx = seq_id as usize;
306 let num_score = partition.num_scores[idx];
307 let denom_score = denom_scores[idx];
308 let log_ratio = compute_log_ratio(num_score, denom_score);
309 results.push(LogRatioResult {
310 query_id: original_ids[idx],
311 log_ratio,
312 fast_path: FastPath::None,
313 });
314 }
315
316 results.sort_by_key(|r| r.query_id);
318 Ok(results)
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// Salt shared by the "compatible" metadata fixtures.
    const SALT_A: u64 = 0x5555555555555555;
    /// A salt that differs from `SALT_A` in every bit.
    const SALT_B: u64 = 0xAAAAAAAAAAAAAAAA;

    /// Asserts that `actual` lies within 1e-10 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {} but got {}",
            expected,
            actual
        );
    }

    /// Builds a bucket-id -> name map from borrowed `(id, name)` pairs.
    fn buckets(pairs: &[(u32, &str)]) -> HashMap<u32, String> {
        pairs.iter().map(|&(id, name)| (id, name.to_string())).collect()
    }

    /// Index metadata with the given minimizer parameters and otherwise empty fields.
    fn make_metadata(k: usize, w: usize, salt: u64) -> IndexMetadata {
        IndexMetadata {
            k,
            w,
            salt,
            bucket_names: HashMap::new(),
            bucket_sources: HashMap::new(),
            bucket_minimizer_counts: HashMap::new(),
            largest_shard_entries: 0,
            bucket_file_stats: None,
        }
    }

    /// Runs `validate_compatible_indices` on a pair that must be rejected and
    /// returns the rendered error message.
    fn incompatibility_message(a: &IndexMetadata, b: &IndexMetadata) -> String {
        validate_compatible_indices(a, b)
            .expect_err("mismatched indices must be rejected")
            .to_string()
    }

    // ---- FastPath / LogRatioResult ----

    #[test]
    fn test_fast_path_as_str() {
        let expectations = [(FastPath::None, "none"), (FastPath::NumHigh, "num_high")];
        for &(variant, label) in &expectations {
            assert_eq!(variant.as_str(), label);
        }
    }

    #[test]
    fn test_log_ratio_result_with_fast_path() {
        let shortcut = LogRatioResult {
            query_id: 7,
            log_ratio: f64::INFINITY,
            fast_path: FastPath::NumHigh,
        };
        assert_eq!(shortcut.fast_path, FastPath::NumHigh);

        let computed = LogRatioResult {
            query_id: 0,
            log_ratio: 1.5,
            fast_path: FastPath::None,
        };
        assert_eq!(computed.fast_path, FastPath::None);
    }

    // ---- compute_log_ratio ----

    #[test]
    fn test_compute_log_ratio_both_positive() {
        assert_close(compute_log_ratio(100.0, 10.0), 1.0);
        assert_close(compute_log_ratio(10.0, 100.0), -1.0);
    }

    #[test]
    fn test_compute_log_ratio_equal_scores() {
        assert_close(compute_log_ratio(50.0, 50.0), 0.0);
    }

    #[test]
    fn test_compute_log_ratio_numerator_zero() {
        let ratio = compute_log_ratio(0.0, 50.0);
        assert!(ratio.is_infinite());
        assert!(ratio.is_sign_negative());
    }

    #[test]
    fn test_compute_log_ratio_denominator_zero() {
        let ratio = compute_log_ratio(50.0, 0.0);
        assert!(ratio.is_infinite());
        assert!(ratio.is_sign_positive());
    }

    #[test]
    fn test_compute_log_ratio_both_zero() {
        assert!(compute_log_ratio(0.0, 0.0).is_nan());
    }

    // ---- validate_single_bucket_index ----

    #[test]
    fn test_validate_single_bucket_index_passes() {
        let (bucket_id, bucket_name) =
            validate_single_bucket_index(&buckets(&[(0, "MyBucket")]))
                .expect("a single bucket must validate");
        assert_eq!(bucket_id, 0);
        assert_eq!(bucket_name, "MyBucket");
    }

    #[test]
    fn test_validate_single_bucket_index_fails_empty() {
        let message = validate_single_bucket_index(&buckets(&[]))
            .unwrap_err()
            .to_string();
        assert!(message.contains("exactly 1 bucket"));
        assert!(message.contains("found 0"));
    }

    #[test]
    fn test_validate_single_bucket_index_fails_two_buckets() {
        let message = validate_single_bucket_index(&buckets(&[(0, "A"), (1, "B")]))
            .unwrap_err()
            .to_string();
        assert!(message.contains("exactly 1 bucket"));
        assert!(message.contains("found 2"));
    }

    #[test]
    fn test_validate_single_bucket_index_preserves_id() {
        let (bucket_id, bucket_name) =
            validate_single_bucket_index(&buckets(&[(42, "HighId")]))
                .expect("a single bucket must validate");
        assert_eq!(bucket_id, 42);
        assert_eq!(bucket_name, "HighId");
    }

    // ---- validate_compatible_indices ----

    #[test]
    fn test_validate_compatible_indices_passes_when_matching() {
        let lhs = make_metadata(32, 10, SALT_A);
        let rhs = make_metadata(32, 10, SALT_A);
        assert!(validate_compatible_indices(&lhs, &rhs).is_ok());
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_k_mismatch() {
        let message =
            incompatibility_message(&make_metadata(32, 10, SALT_A), &make_metadata(64, 10, SALT_A));
        assert!(message.contains("different k values"));
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_w_mismatch() {
        let message =
            incompatibility_message(&make_metadata(32, 10, SALT_A), &make_metadata(32, 20, SALT_A));
        assert!(message.contains("different w values"));
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_salt_mismatch() {
        let message =
            incompatibility_message(&make_metadata(32, 10, SALT_A), &make_metadata(32, 10, SALT_B));
        assert!(message.contains("different salt values"));
    }

    // ---- partition_by_numerator_score ----

    #[test]
    fn test_partition_all_zeros_goes_to_needs_denom() {
        let partition = partition_by_numerator_score(&[], 3, None);

        assert!(partition.fast_path_results.is_empty());
        assert_eq!(partition.needs_denom_query_ids, vec![0, 1, 2]);
        assert!(partition.num_scores.iter().all(|&s| s == 0.0));
    }

    #[test]
    fn test_partition_with_skip_threshold_creates_two_groups() {
        let hits = vec![
            HitResult { query_id: 1, bucket_id: 0, score: 0.05 },
            HitResult { query_id: 2, bucket_id: 0, score: 0.5 },
            HitResult { query_id: 3, bucket_id: 0, score: 0.01 },
        ];

        let partition = partition_by_numerator_score(&hits, 4, Some(0.1));

        assert_eq!(partition.fast_path_results.len(), 1);
        let shortcut = &partition.fast_path_results[0];
        assert_eq!(shortcut.query_id, 2);
        assert_eq!(shortcut.fast_path, FastPath::NumHigh);
        assert_eq!(shortcut.log_ratio, f64::INFINITY);

        let mut pending = partition.needs_denom_query_ids.clone();
        pending.sort_unstable();
        assert_eq!(pending, vec![0, 1, 3]);
    }

    #[test]
    fn test_partition_without_skip_threshold_no_fast_path() {
        let hits = vec![
            HitResult { query_id: 1, bucket_id: 0, score: 0.5 },
            HitResult { query_id: 2, bucket_id: 0, score: 0.9 },
        ];

        let partition = partition_by_numerator_score(&hits, 3, None);

        assert!(partition.fast_path_results.is_empty());
        assert_eq!(partition.needs_denom_query_ids.len(), 3);
    }

    #[test]
    fn test_partition_skip_threshold_at_boundary() {
        let hits = vec![HitResult { query_id: 0, bucket_id: 0, score: 0.1 }];

        let partition = partition_by_numerator_score(&hits, 1, Some(0.1));

        assert_eq!(partition.fast_path_results.len(), 1);
        assert_eq!(partition.fast_path_results[0].fast_path, FastPath::NumHigh);
        assert!(partition.needs_denom_query_ids.is_empty());
    }

    #[test]
    fn test_partition_empty_batch() {
        let partition = partition_by_numerator_score(&[], 0, None);

        assert!(partition.fast_path_results.is_empty());
        assert!(partition.needs_denom_query_ids.is_empty());
        assert!(partition.num_scores.is_empty());
    }
}