sketchlib/distances/
jaccard.rs

1//! Implementation of Jaccard, core and accessory distance calculations
2use crate::sketch::multisketch::MultiSketch;
3use crate::sketch::BBITS;
4
5/// Returns the Jaccard index between two samples
6pub fn jaccard_index(
7    sketch1: &[u64],
8    sketch2: &[u64],
9    sketchsize64: u64,
10    c1: Option<f64>,
11    c2: Option<f64>,
12    completeness_cutoff: f64,
13) -> f64 {
14    let unionsize = (u64::BITS as u64 * sketchsize64) as f64;
15    let samebits: u32 = sketch1
16        .chunks_exact(BBITS as usize)
17        .zip(sketch2.chunks_exact(BBITS as usize))
18        .map(|(chunk1, chunk2)| {
19            let mut bits: u64 = !0;
20            chunk1.iter().zip(chunk2.iter()).for_each(|(&s1, &s2)| {
21                bits &= !(s1 ^ s2);
22            });
23            bits.count_ones()
24        })
25        .sum();
26    let maxnbits = sketchsize64 as u32 * u64::BITS;
27    let expected_samebits = maxnbits >> BBITS;
28
29    log::trace!("samebits:{samebits} expected_samebits:{expected_samebits} maxnbits:{maxnbits}");
30    let diff = samebits.saturating_sub(expected_samebits);
31    let intersize = (diff as f64 * maxnbits as f64) / (maxnbits - expected_samebits) as f64;
32    log::trace!("intersize:{intersize} unionsize:{unionsize}");
33    let mut jaccard_index = intersize / unionsize;
34
35    // Apply completeness correction if both completeness values are provided
36    if let (Some(c1_val), Some(c2_val)) = (c1, c2) {
37        if c1_val * c2_val >= completeness_cutoff {
38            jaccard_index = completeness_correction(jaccard_index, c1_val, c2_val);
39            // Cap the corrected Jaccard index at 1.0 to prevent negative distances
40            jaccard_index = jaccard_index.min(1.0);
41        }
42    }
43
44    jaccard_index
45}
46
/// Converts a Jaccard value into an ANI estimate, using a Poisson model of
/// mutations.
///
/// Evaluates `1 + (1 / k) * ln(2j / (1 + j))` with `j = jaccard`, and floors
/// the result at 0.0. A `jaccard` of 1.0 maps to 1.0; a `jaccard` of 0.0 maps
/// to 0.0. `k` is presumably the k-mer length (not established in this file).
#[inline(always)]
pub fn ani_pois(jaccard: f64, k: f64) -> f64 {
    let ratio = (2.0 * jaccard) / (1.0 + jaccard);
    let estimate = 1.0 + 1.0 / k * ratio.ln();
    // Negative (or NaN) estimates are reported as zero identity
    0.0_f64.max(estimate)
}
52
/// Completeness correction for MAGs.
///
/// Divides `jaccard` by the factor `c1 * c2 / (c1 + c2 - c1 * c2)`, where
/// `c1` and `c2` are the completeness values of the two samples. No bounds
/// are applied here: the result may exceed 1.0, and is infinite or NaN when
/// the factor is zero or undefined.
#[inline(always)]
pub fn completeness_correction(jaccard: f64, c1: f64, c2: f64) -> f64 {
    let both = c1 * c2;
    let either = c1 + c2 - both;
    let factor = both / either;
    jaccard / factor
}
58
59/// Core and accessory distances between two sketches, using the PopPUNK regression
60/// model
61pub fn core_acc_dist(
62    ref_sketches: &MultiSketch,
63    query_sketches: &MultiSketch,
64    ref_sketch_idx: usize,
65    query_sketch_idx: usize,
66    completeness_vec: Option<&Vec<f64>>,
67    completeness_cutoff: f64,
68) -> (f32, f32) {
69    if ref_sketches.kmer_lengths().len() < 2 {
70        panic!("Need at least two k-mer lengths to calculate core/accessory distances");
71    }
72    let (mut xsum, mut ysum, mut xysum, mut xsquaresum, mut ysquaresum, mut n) =
73        (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64);
74    let tolerance = (2.0_f64 / ((ref_sketches.sketch_size * u64::BITS as u64) as f64)).ln();
75    //let tolerance = -100.0_f32;
76    for (k_idx, k) in ref_sketches.kmer_lengths().iter().enumerate() {
77        let c1 = completeness_vec.map(|cv| cv[ref_sketch_idx]);
78        let c2 = completeness_vec.map(|cv| cv[query_sketch_idx]);
79        let y = jaccard_index(
80            ref_sketches.get_sketch_slice(ref_sketch_idx, k_idx),
81            query_sketches.get_sketch_slice(query_sketch_idx, k_idx),
82            ref_sketches.sketchsize64,
83            c1,
84            c2,
85            completeness_cutoff,
86        )
87        .ln();
88        if y < tolerance {
89            break;
90        }
91        let k_fl = *k as f64;
92        xsum += k_fl;
93        ysum += y;
94        xysum += k_fl * y;
95        xsquaresum += k_fl * k_fl;
96        ysquaresum += y * y;
97        n += 1.0;
98    }
99    simple_linear_regression(xsum, ysum, xysum, xsquaresum, ysquaresum, n)
100}
101
102// Linear regression for calculating core/accessory distances from matches, with some
103// sensible bounds for bad fits
104fn simple_linear_regression(
105    xsum: f64,
106    ysum: f64,
107    xysum: f64,
108    xsquaresum: f64,
109    ysquaresum: f64,
110    n: f64,
111) -> (f32, f32) {
112    log::trace!(
113        "xsum:{xsum} ysum:{ysum} xysum:{xysum} xsquaresum:{xsquaresum} ysquaresum:{ysquaresum}"
114    );
115    // No matches
116    if ysum.is_nan() || ysum == f64::NEG_INFINITY || n < 3.0 {
117        return (1.0, 1.0);
118    }
119
120    let xbar = xsum / n;
121    let ybar = ysum / n;
122    let x_diff = xsquaresum - xsum * xsum / n;
123    let y_diff = ysquaresum - ysum * ysum / n;
124    let xstddev = ((xsquaresum - xsum * xsum / n) / n).sqrt();
125    let ystddev = ((ysquaresum - ysum * ysum / n) / n).sqrt();
126    let r = (xysum - xsum * ysum / n) / (x_diff * y_diff).sqrt();
127    let beta = r * ystddev / xstddev;
128    let alpha = -beta * xbar + ybar;
129    log::trace!("r:{r} alpha:{alpha} beta:{beta}");
130
131    let (mut core, mut acc) = (0.0_f64, 0.0_f64);
132    if beta < 0.0 {
133        core = 1.0 - beta.exp();
134    } else if r > 0.0 {
135        core = 1.0;
136    }
137    if alpha < 0.0 {
138        acc = 1.0 - alpha.exp();
139    }
140    (core as f32, acc as f32)
141}