Skip to main content

rype/classify/
log_ratio.rs

//! Log-ratio types and functions for two-index classification.
//!
//! This module provides the core types and logic for computing
//! log10(numerator_score / denominator_score) between two single-bucket indices.
//! It lives in the library crate so both the CLI and C API can use it.
6
7use std::collections::{HashMap, HashSet};
8
9use anyhow::{anyhow, Result};
10
11use crate::types::{HitResult, IndexMetadata, QueryRecord};
12use crate::ShardedInvertedIndex;
13
/// How a log-ratio value was obtained: computed exactly, or via a shortcut that
/// skipped classifying the read against the denominator index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FastPath {
    /// Exact result: the read was classified against both numerator and denominator.
    None,
    /// The numerator score reached the skip threshold, so the log-ratio was set to
    /// `+inf` without consulting the denominator.
    NumHigh,
}

impl FastPath {
    /// Short label used in TSV output: `"none"` or `"num_high"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::NumHigh => "num_high",
            Self::None => "none",
        }
    }
}
33
34/// Result of log-ratio computation for a single query.
35#[derive(Debug, Clone, PartialEq)]
36pub struct LogRatioResult {
37    pub query_id: i64,
38    pub log_ratio: f64,
39    pub fast_path: FastPath,
40}
41
/// Return `log10(numerator / denominator)`, mapping zero scores to sentinel floats.
///
/// Zero handling:
/// - both zero → `NaN` (no evidence either way)
/// - only the numerator is zero → `-inf` (read matches the denominator only)
/// - only the denominator is zero → `+inf` (read matches the numerator only)
///
/// Otherwise the plain base-10 logarithm of the ratio is returned.
pub fn compute_log_ratio(numerator: f64, denominator: f64) -> f64 {
    let num_is_zero = numerator == 0.0;
    let den_is_zero = denominator == 0.0;

    match (num_is_zero, den_is_zero) {
        (true, true) => f64::NAN,
        (true, false) => f64::NEG_INFINITY,
        (false, true) => f64::INFINITY,
        (false, false) => (numerator / denominator).log10(),
    }
}
59
60/// Validate that the index has exactly one bucket and return its ID and name.
61///
62/// Used for the two-index log-ratio workflow where each index holds a single bucket.
63pub fn validate_single_bucket_index(bucket_names: &HashMap<u32, String>) -> Result<(u32, String)> {
64    if bucket_names.len() != 1 {
65        return Err(anyhow!(
66            "log-ratio mode requires each index to have exactly 1 bucket, but found {}.\n\
67             Use 'rype index stats -i <index>' to see bucket information.",
68            bucket_names.len()
69        ));
70    }
71
72    let (&bucket_id, bucket_name) = bucket_names.iter().next().unwrap();
73    Ok((bucket_id, bucket_name.clone()))
74}
75
76/// Validate that two indices are compatible for log-ratio computation.
77///
78/// Checks that k, w, and salt match between the numerator and denominator indices.
79pub fn validate_compatible_indices(a: &IndexMetadata, b: &IndexMetadata) -> Result<()> {
80    if a.k != b.k {
81        return Err(anyhow!(
82            "Numerator and denominator indices have different k values: {} vs {}.\n\
83             Both indices must be built with the same k, w, and salt.",
84            a.k,
85            b.k
86        ));
87    }
88    if a.w != b.w {
89        return Err(anyhow!(
90            "Numerator and denominator indices have different w values: {} vs {}.\n\
91             Both indices must be built with the same k, w, and salt.",
92            a.w,
93            b.w
94        ));
95    }
96    if a.salt != b.salt {
97        return Err(anyhow!(
98            "Numerator and denominator indices have different salt values: {:#x} vs {:#x}.\n\
99             Both indices must be built with the same k, w, and salt.",
100            a.salt,
101            b.salt
102        ));
103    }
104    Ok(())
105}
106
107/// Result of partitioning reads by numerator score into fast-path and needs-denominator groups.
108pub struct PartitionResult {
109    /// Reads resolved via fast path (NumHigh only) — no denominator needed.
110    pub fast_path_results: Vec<LogRatioResult>,
111    /// Query IDs of reads that need denominator classification.
112    pub needs_denom_query_ids: Vec<i64>,
113    /// Numerator scores indexed by query_id (0.0 for zero-score reads).
114    /// Only entries for needs-denom reads are meaningful.
115    pub num_scores: Vec<f64>,
116}
117
118/// Partition reads by numerator classification results into fast-path and needs-denominator groups.
119///
120/// For each read in 0..total_reads:
121/// - If `skip_threshold` is set and the read's numerator score >= threshold, it gets
122///   fast-path `+inf` (NumHigh).
123/// - Otherwise (including score=0), the read needs denominator classification.
124///   Score=0 reads need the denominator to distinguish -inf (denom>0) from NaN (denom=0).
125///
126/// `num_results` are the HitResults from classifying against the numerator index
127/// (single bucket, threshold=0.0).
128pub fn partition_by_numerator_score(
129    num_results: &[HitResult],
130    total_reads: usize,
131    skip_threshold: Option<f64>,
132) -> PartitionResult {
133    // Dense score lookup: query_ids are sequential 0..total_reads
134    let mut num_scores = vec![0.0_f64; total_reads];
135    for hit in num_results {
136        num_scores[hit.query_id as usize] = hit.score;
137    }
138
139    let mut fast_path_results = Vec::new();
140    let mut needs_denom_query_ids = Vec::new();
141
142    for query_id in 0..total_reads as i64 {
143        let score = num_scores[query_id as usize];
144
145        if let Some(thresh) = skip_threshold {
146            if score >= thresh {
147                // Strong numerator signal → +inf
148                fast_path_results.push(LogRatioResult {
149                    query_id,
150                    log_ratio: f64::INFINITY,
151                    fast_path: FastPath::NumHigh,
152                });
153            } else {
154                needs_denom_query_ids.push(query_id);
155            }
156        } else {
157            needs_denom_query_ids.push(query_id);
158        }
159    }
160
161    PartitionResult {
162        fast_path_results,
163        needs_denom_query_ids,
164        num_scores,
165    }
166}
167
168/// Validate that two sharded indices are compatible for log-ratio classification.
169///
170/// Checks that both are single-bucket indices with matching k, w, and salt.
171/// Returns `(k, w, salt)` on success.
172pub fn validate_log_ratio_indices(
173    numerator: &ShardedInvertedIndex,
174    denominator: &ShardedInvertedIndex,
175) -> Result<(usize, usize, u64)> {
176    let num_manifest = numerator.manifest();
177    let denom_manifest = denominator.manifest();
178
179    validate_single_bucket_index(&num_manifest.bucket_names)
180        .map_err(|e| anyhow!("numerator index: {}", e))?;
181    validate_single_bucket_index(&denom_manifest.bucket_names)
182        .map_err(|e| anyhow!("denominator index: {}", e))?;
183
184    if num_manifest.k != denom_manifest.k {
185        return Err(anyhow!(
186            "Numerator and denominator indices have different k values: {} vs {}.\n\
187             Both indices must be built with the same k, w, and salt.",
188            num_manifest.k,
189            denom_manifest.k
190        ));
191    }
192    if num_manifest.w != denom_manifest.w {
193        return Err(anyhow!(
194            "Numerator and denominator indices have different w values: {} vs {}.\n\
195             Both indices must be built with the same k, w, and salt.",
196            num_manifest.w,
197            denom_manifest.w
198        ));
199    }
200    if num_manifest.salt != denom_manifest.salt {
201        return Err(anyhow!(
202            "Numerator and denominator indices have different salt values: {:#x} vs {:#x}.\n\
203             Both indices must be built with the same k, w, and salt.",
204            num_manifest.salt,
205            denom_manifest.salt
206        ));
207    }
208
209    Ok((num_manifest.k, num_manifest.w, num_manifest.salt))
210}
211
212/// Classify a batch of reads using log-ratio (numerator vs denominator).
213///
214/// This is the core log-ratio pipeline:
215/// 1. Validate indices (single-bucket, compatible k/w/salt)
216/// 2. Extract minimizers from all reads
217/// 3. Classify all against numerator (threshold=0.0)
218/// 4. Partition into fast-path (NumHigh) and needs-denom
219/// 5. Classify needs-denom subset against denominator
220/// 6. Compute log10(num_score / denom_score) for each read
221///
222/// Returns one `LogRatioResult` per input read, sorted by the original query IDs
223/// from the input records.
224///
225/// Note: Internally uses sequential query IDs (0..N) for the partition step,
226/// then maps back to the original IDs in the results. This avoids panics when
227/// caller-provided query IDs are non-sequential (e.g., [100, 200, 300]).
228pub fn classify_log_ratio_batch(
229    numerator: &ShardedInvertedIndex,
230    denominator: &ShardedInvertedIndex,
231    records: &[QueryRecord],
232    skip_threshold: Option<f64>,
233) -> Result<Vec<LogRatioResult>> {
234    let num_queries = records.len();
235    if num_queries == 0 {
236        return Ok(Vec::new());
237    }
238
239    let (k, w, salt) = validate_log_ratio_indices(numerator, denominator)?;
240
241    // Save original query IDs, use sequential 0..N internally.
242    // partition_by_numerator_score uses dense arrays indexed by query_id,
243    // so query IDs must be sequential 0..N.
244    let original_ids: Vec<i64> = records.iter().map(|r| r.0).collect();
245    let sequential_ids: Vec<i64> = (0..num_queries as i64).collect();
246
247    let extracted = crate::extract_batch_minimizers(k, w, salt, None, records);
248
249    // Classify against numerator (threshold=0.0 to get all scores)
250    let num_results = crate::classify_from_extracted_minimizers(
251        numerator,
252        &extracted,
253        &sequential_ids,
254        0.0,
255        None,
256    )?;
257
258    // Partition: fast-path vs needs-denom (uses sequential IDs internally)
259    let partition = partition_by_numerator_score(&num_results, num_queries, skip_threshold);
260
261    // Build needs-denom subset
262    let needs_denom_set: HashSet<i64> = partition.needs_denom_query_ids.iter().copied().collect();
263
264    let mut denom_extracted = Vec::new();
265    let mut denom_ids = Vec::new();
266    for (i, ext) in extracted.iter().enumerate() {
267        let seq_id = i as i64;
268        if needs_denom_set.contains(&seq_id) {
269            denom_extracted.push(ext.clone());
270            denom_ids.push(seq_id);
271        }
272    }
273
274    // Classify needs-denom subset against denominator
275    let denom_results = if !denom_ids.is_empty() {
276        crate::classify_from_extracted_minimizers(
277            denominator,
278            &denom_extracted,
279            &denom_ids,
280            0.0,
281            None,
282        )?
283    } else {
284        Vec::new()
285    };
286
287    // Build dense denominator score lookup (indexed by sequential ID)
288    let mut denom_scores = vec![0.0_f64; num_queries];
289    for hit in &denom_results {
290        denom_scores[hit.query_id as usize] = hit.score;
291    }
292
293    // Merge fast-path + computed results, mapping back to original query IDs
294    let mut results: Vec<LogRatioResult> = Vec::with_capacity(num_queries);
295
296    for lr in &partition.fast_path_results {
297        results.push(LogRatioResult {
298            query_id: original_ids[lr.query_id as usize],
299            log_ratio: lr.log_ratio,
300            fast_path: lr.fast_path,
301        });
302    }
303
304    for &seq_id in &partition.needs_denom_query_ids {
305        let idx = seq_id as usize;
306        let num_score = partition.num_scores[idx];
307        let denom_score = denom_scores[idx];
308        let log_ratio = compute_log_ratio(num_score, denom_score);
309        results.push(LogRatioResult {
310            query_id: original_ids[idx],
311            log_ratio,
312            fast_path: FastPath::None,
313        });
314    }
315
316    // Sort by original query_id for deterministic output
317    results.sort_by_key(|r| r.query_id);
318    Ok(results)
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    // FastPath tests

    #[test]
    fn test_fast_path_as_str() {
        assert_eq!(FastPath::None.as_str(), "none");
        assert_eq!(FastPath::NumHigh.as_str(), "num_high");
    }

    #[test]
    fn test_log_ratio_result_with_fast_path() {
        let result = LogRatioResult {
            query_id: 7,
            log_ratio: f64::INFINITY,
            fast_path: FastPath::NumHigh,
        };
        assert_eq!(result.fast_path, FastPath::NumHigh);

        let result = LogRatioResult {
            query_id: 0,
            log_ratio: 1.5,
            fast_path: FastPath::None,
        };
        assert_eq!(result.fast_path, FastPath::None);
    }

    // compute_log_ratio tests

    #[test]
    fn test_compute_log_ratio_both_positive() {
        let result = compute_log_ratio(100.0, 10.0);
        assert!((result - 1.0).abs() < 1e-10);

        let result = compute_log_ratio(10.0, 100.0);
        assert!((result - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_compute_log_ratio_equal_scores() {
        let result = compute_log_ratio(50.0, 50.0);
        assert!((result - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_log_ratio_numerator_zero() {
        let result = compute_log_ratio(0.0, 50.0);
        assert!(result.is_infinite() && result.is_sign_negative());
    }

    #[test]
    fn test_compute_log_ratio_denominator_zero() {
        let result = compute_log_ratio(50.0, 0.0);
        assert!(result.is_infinite() && result.is_sign_positive());
    }

    #[test]
    fn test_compute_log_ratio_both_zero() {
        let result = compute_log_ratio(0.0, 0.0);
        assert!(result.is_nan());
    }

    #[test]
    fn test_compute_log_ratio_is_antisymmetric() {
        // Swapping numerator and denominator must negate the log-ratio,
        // including the infinite edge cases.
        for &(a, b) in &[(30.0, 7.0), (1.0, 1000.0), (0.25, 0.5)] {
            let forward = compute_log_ratio(a, b);
            let backward = compute_log_ratio(b, a);
            assert!((forward + backward).abs() < 1e-12);
        }
        assert_eq!(compute_log_ratio(0.0, 5.0), -compute_log_ratio(5.0, 0.0));
    }

    // validate_single_bucket_index tests

    #[test]
    fn test_validate_single_bucket_index_passes() {
        let mut bucket_names = HashMap::new();
        bucket_names.insert(0, "MyBucket".to_string());

        let result = validate_single_bucket_index(&bucket_names);
        assert!(result.is_ok());

        let (bucket_id, bucket_name) = result.unwrap();
        assert_eq!(bucket_id, 0);
        assert_eq!(bucket_name, "MyBucket");
    }

    #[test]
    fn test_validate_single_bucket_index_fails_empty() {
        let bucket_names: HashMap<u32, String> = HashMap::new();

        let result = validate_single_bucket_index(&bucket_names);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exactly 1 bucket"));
        assert!(err.contains("found 0"));
    }

    #[test]
    fn test_validate_single_bucket_index_fails_two_buckets() {
        let mut bucket_names = HashMap::new();
        bucket_names.insert(0, "A".to_string());
        bucket_names.insert(1, "B".to_string());

        let result = validate_single_bucket_index(&bucket_names);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("exactly 1 bucket"));
        assert!(err.contains("found 2"));
    }

    #[test]
    fn test_validate_single_bucket_index_preserves_id() {
        let mut bucket_names = HashMap::new();
        bucket_names.insert(42, "HighId".to_string());

        let result = validate_single_bucket_index(&bucket_names);
        assert!(result.is_ok());

        let (bucket_id, bucket_name) = result.unwrap();
        assert_eq!(bucket_id, 42);
        assert_eq!(bucket_name, "HighId");
    }

    // validate_compatible_indices tests

    fn make_metadata(k: usize, w: usize, salt: u64) -> IndexMetadata {
        IndexMetadata {
            k,
            w,
            salt,
            bucket_names: HashMap::new(),
            bucket_sources: HashMap::new(),
            bucket_minimizer_counts: HashMap::new(),
            largest_shard_entries: 0,
            bucket_file_stats: None,
        }
    }

    #[test]
    fn test_validate_compatible_indices_passes_when_matching() {
        let a = make_metadata(32, 10, 0x5555555555555555);
        let b = make_metadata(32, 10, 0x5555555555555555);

        assert!(validate_compatible_indices(&a, &b).is_ok());
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_k_mismatch() {
        let a = make_metadata(32, 10, 0x5555555555555555);
        let b = make_metadata(64, 10, 0x5555555555555555);

        let result = validate_compatible_indices(&a, &b);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("different k values"));
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_w_mismatch() {
        let a = make_metadata(32, 10, 0x5555555555555555);
        let b = make_metadata(32, 20, 0x5555555555555555);

        let result = validate_compatible_indices(&a, &b);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("different w values"));
    }

    #[test]
    fn test_validate_compatible_indices_fails_on_salt_mismatch() {
        let a = make_metadata(32, 10, 0x5555555555555555);
        let b = make_metadata(32, 10, 0xAAAAAAAAAAAAAAAA);

        let result = validate_compatible_indices(&a, &b);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("different salt values"));
    }

    #[test]
    fn test_validate_compatible_indices_reports_salts_in_hex() {
        let a = make_metadata(32, 10, 0x5555555555555555);
        let b = make_metadata(32, 10, 0xAAAAAAAAAAAAAAAA);

        let err = validate_compatible_indices(&a, &b).unwrap_err().to_string();
        assert!(err.contains("0x5555555555555555 vs 0xaaaaaaaaaaaaaaaa"));
    }

    #[test]
    fn test_validate_compatible_indices_reports_k_before_w() {
        // When several parameters differ, the first check (k) determines the message.
        let a = make_metadata(32, 10, 1);
        let b = make_metadata(64, 20, 2);

        let err = validate_compatible_indices(&a, &b).unwrap_err().to_string();
        assert!(err.contains("different k values: 32 vs 64"));
    }

    // partition_by_numerator_score tests

    #[test]
    fn test_partition_all_zeros_goes_to_needs_denom() {
        let num_results: Vec<HitResult> = vec![];
        let result = partition_by_numerator_score(&num_results, 3, None);

        assert!(result.fast_path_results.is_empty());
        assert_eq!(result.needs_denom_query_ids.len(), 3);
        assert_eq!(result.needs_denom_query_ids, vec![0, 1, 2]);
        assert!(result.num_scores.iter().all(|&s| s == 0.0));
    }

    #[test]
    fn test_partition_with_skip_threshold_creates_two_groups() {
        let num_results = vec![
            HitResult {
                query_id: 1,
                bucket_id: 0,
                score: 0.05,
            },
            HitResult {
                query_id: 2,
                bucket_id: 0,
                score: 0.5,
            },
            HitResult {
                query_id: 3,
                bucket_id: 0,
                score: 0.01,
            },
        ];

        let result = partition_by_numerator_score(&num_results, 4, Some(0.1));

        assert_eq!(result.fast_path_results.len(), 1);
        assert_eq!(result.fast_path_results[0].query_id, 2);
        assert_eq!(result.fast_path_results[0].fast_path, FastPath::NumHigh);
        assert!(result.fast_path_results[0].log_ratio == f64::INFINITY);

        assert_eq!(result.needs_denom_query_ids.len(), 3);
        assert!(result.needs_denom_query_ids.contains(&0));
        assert!(result.needs_denom_query_ids.contains(&1));
        assert!(result.needs_denom_query_ids.contains(&3));
    }

    #[test]
    fn test_partition_without_skip_threshold_no_fast_path() {
        let num_results = vec![
            HitResult {
                query_id: 1,
                bucket_id: 0,
                score: 0.5,
            },
            HitResult {
                query_id: 2,
                bucket_id: 0,
                score: 0.9,
            },
        ];

        let result = partition_by_numerator_score(&num_results, 3, None);

        assert!(result.fast_path_results.is_empty());
        assert_eq!(result.needs_denom_query_ids.len(), 3);
    }

    #[test]
    fn test_partition_skip_threshold_at_boundary() {
        let num_results = vec![HitResult {
            query_id: 0,
            bucket_id: 0,
            score: 0.1,
        }];

        let result = partition_by_numerator_score(&num_results, 1, Some(0.1));

        assert_eq!(result.fast_path_results.len(), 1);
        assert_eq!(result.fast_path_results[0].fast_path, FastPath::NumHigh);
        assert!(result.needs_denom_query_ids.is_empty());
    }

    #[test]
    fn test_partition_records_dense_num_scores() {
        // Hits may arrive in any order; scores land at their query_id and
        // reads without a hit stay at 0.0.
        let num_results = vec![
            HitResult {
                query_id: 2,
                bucket_id: 0,
                score: 0.7,
            },
            HitResult {
                query_id: 0,
                bucket_id: 0,
                score: 0.3,
            },
        ];

        let result = partition_by_numerator_score(&num_results, 4, None);

        assert_eq!(result.num_scores, vec![0.3, 0.0, 0.7, 0.0]);
        assert_eq!(result.needs_denom_query_ids, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_partition_empty_batch() {
        let result = partition_by_numerator_score(&[], 0, None);

        assert!(result.fast_path_results.is_empty());
        assert!(result.needs_denom_query_ids.is_empty());
        assert!(result.num_scores.is_empty());
    }
}