Skip to main content

touchstone_rs/metrics/
range.rs

//! Range-based precision and recall from Tatbul et al., NeurIPS 2018,
//! "Precision and Recall for Time Series".
3use super::{Metric, thresholding::apply_threshold};
4
/// Penalization policy applied when one range overlaps several ranges on the other side.
///
/// For recall this concerns a true range overlapped by multiple predicted ranges; for
/// precision, a predicted range overlapping multiple true ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): presumably some variants are only constructed by callers outside this
// module/crate, which is why dead-code warnings are silenced; confirm before removing.
#[allow(dead_code)]
pub enum Cardinality {
    /// Each overlapping range contributes fully (factor 1.0) regardless of multiplicity.
    One,
    /// The summed overlap reward is scaled by `1 / count`, where `count` is the number of
    /// overlapping ranges.
    Reciprocal,
}
14
/// Positional weighting applied to the samples inside each range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
// NOTE(review): presumably some variants are only constructed by callers outside this
// module/crate, which is why dead-code warnings are silenced; confirm before removing.
#[allow(dead_code)]
pub enum Bias {
    /// All positions within a range have equal weight.
    Flat,
    /// Weight decreases from start to end (front-loading).
    Front,
    /// Weight is highest at the center, declining toward edges.
    Middle,
    /// Weight increases from start to end (back-loading).
    Back,
}
28
/// Extracts maximal runs of anomalous samples from a label/prediction vector.
///
/// Any non-zero value counts as anomalous, so encodings other than 0/1 (e.g. `255`) are
/// handled consistently: a run both starts and ends on the same zero/non-zero test.
///
/// Returns inclusive `(start, end)` index pairs in ascending order; empty input yields
/// an empty vector.
fn extract_ranges(binary: &[u8]) -> Vec<(usize, usize)> {
    let mut ranges = Vec::new();
    let mut start: Option<usize> = None;
    for (i, &v) in binary.iter().enumerate() {
        match (v != 0, start) {
            (true, None) => start = Some(i),
            (false, Some(s)) => {
                ranges.push((s, i - 1));
                start = None;
            }
            _ => {}
        }
    }
    // Close a run that extends to the final sample (`binary` is non-empty here).
    if let Some(s) = start {
        ranges.push((s, binary.len() - 1));
    }
    ranges
}
48
49/// Computes the positional weight for a position within a range given a bias strategy.
50fn delta(pos: usize, range_start: usize, range_end: usize, bias: Bias) -> f64 {
51    let len = (range_end - range_start + 1) as f64;
52    match bias {
53        Bias::Flat => 1.0,
54        Bias::Front => {
55            let i = (pos - range_start + 1) as f64;
56            (2.0 * (len - i + 1.0)) / (len * (len + 1.0))
57        }
58        Bias::Back => {
59            let i = (pos - range_start + 1) as f64;
60            (2.0 * i) / (len * (len + 1.0))
61        }
62        Bias::Middle => {
63            let i = (pos - range_start + 1) as f64;
64            let mid = (len + 1.0) / 2.0;
65            let dist = (i - mid).abs();
66            let peak = if len % 2.0 == 0.0 { 0.5 } else { 1.0 };
67            // Triangle: 0 at edges, peak at center
68            if len == 1.0 {
69                1.0
70            } else {
71                peak - dist * peak / (len / 2.0).ceil()
72            }
73        }
74    }
75}
76
77/// Computes the weighted overlap between a predicted and real range.
78fn omega(pred: (usize, usize), real: (usize, usize), bias: Bias) -> f64 {
79    let overlap_start = pred.0.max(real.0);
80    let overlap_end = pred.1.min(real.1);
81    if overlap_start > overlap_end {
82        return 0.0;
83    }
84    let my_len = (pred.1 - pred.0 + 1) as f64;
85    let weighted_overlap: f64 = (overlap_start..=overlap_end)
86        .map(|p| delta(p, pred.0, pred.1, bias))
87        .sum();
88    let total_weight: f64 = (pred.0..=pred.1)
89        .map(|p| delta(p, pred.0, pred.1, bias))
90        .sum();
91    if total_weight < 1e-12 {
92        return 0.0;
93    }
94    weighted_overlap / total_weight * (overlap_end - overlap_start + 1) as f64 / my_len
95}
96
97/// Computes the cardinality penalty applied to overlap counts.
98fn gamma(overlap_count: usize, cardinality: Cardinality) -> f64 {
99    match cardinality {
100        Cardinality::One => 1.0,
101        Cardinality::Reciprocal => {
102            if overlap_count == 0 {
103                0.0
104            } else {
105                1.0 / overlap_count as f64
106            }
107        }
108    }
109}
110
111/// Score a single anomaly range (either predicted or real) against the set of
112/// reference ranges on the other side.
113///
114/// When scoring precision: `my_range` = predicted range, `ref_ranges` = real ranges.
115/// When scoring recall:    `my_range` = real range,      `ref_ranges` = predicted ranges.
116fn range_score(
117    my_range: (usize, usize),
118    ref_ranges: &[(usize, usize)],
119    alpha: f64,
120    cardinality: Cardinality,
121    bias: Bias,
122) -> f64 {
123    let mut overlap_reward = 0.0;
124    let mut overlap_count = 0;
125
126    for &r in ref_ranges {
127        let ov = omega(my_range, r, bias);
128        if ov > 0.0 {
129            overlap_reward += ov;
130            overlap_count += 1;
131        }
132    }
133
134    let existence = if overlap_count > 0 { 1.0 } else { 0.0 };
135    overlap_reward *= gamma(overlap_count, cardinality);
136
137    alpha * existence + (1.0 - alpha) * overlap_reward
138}
139
140/// Computes range-based precision: average score of predicted ranges against real ranges.
141pub(crate) fn range_precision_raw(
142    real: &[u8],
143    pred: &[u8],
144    alpha: f64,
145    cardinality: Cardinality,
146    bias: Bias,
147) -> f64 {
148    let pred_ranges = extract_ranges(pred);
149    if pred_ranges.is_empty() {
150        return 0.0;
151    }
152    let real_ranges = extract_ranges(real);
153    let sum: f64 = pred_ranges
154        .iter()
155        .map(|&p| range_score(p, &real_ranges, alpha, cardinality, bias))
156        .sum();
157    sum / pred_ranges.len() as f64
158}
159
160/// Computes range-based recall: average score of real ranges against predicted ranges.
161pub(crate) fn range_recall_raw(
162    real: &[u8],
163    pred: &[u8],
164    alpha: f64,
165    cardinality: Cardinality,
166    bias: Bias,
167) -> f64 {
168    let real_ranges = extract_ranges(real);
169    if real_ranges.is_empty() {
170        return f64::NAN;
171    }
172    let pred_ranges = extract_ranges(pred);
173    let sum: f64 = real_ranges
174        .iter()
175        .map(|&r| range_score(r, &pred_ranges, alpha, cardinality, bias))
176        .sum();
177    sum / real_ranges.len() as f64
178}
179
/// Combines precision and recall into an F-beta score:
/// `(1 + beta²) * prec * rec / (beta² * prec + rec)`.
///
/// Returns `0.0` when the denominator is (numerically) zero, e.g. when both inputs are 0.
/// A `NaN` input propagates to a `NaN` result.
fn range_fscore(prec: f64, rec: f64, beta: f64) -> f64 {
    let beta_sq = beta * beta;
    let denom = beta_sq * prec + rec;
    if denom < 1e-12 {
        0.0
    } else {
        (1.0 + beta_sq) * prec * rec / denom
    }
}
188
189// ─── public metric structs ──────────────────────────────────────────────────
190
191/// Range-based precision metric (Tatbul et al., NeurIPS 2018).
192pub struct RangePrecision {
193    /// Weight between overlap-only (0.0) and existence-aware (1.0) scoring.
194    pub alpha: f64,
195    /// Penalization policy when multiple predicted ranges match one true range.
196    pub cardinality: Cardinality,
197    /// Positional weighting within each range.
198    pub bias: Bias,
199    /// Score percentile used to derive a binary prediction threshold.
200    pub percentile: f64,
201}
202
203/// Range-based recall metric (Tatbul et al., NeurIPS 2018).
204pub struct RangeRecall {
205    /// Weight between overlap-only (0.0) and existence-aware (1.0) scoring.
206    pub alpha: f64,
207    /// Penalization policy when multiple predicted ranges match one true range.
208    pub cardinality: Cardinality,
209    /// Positional weighting within each range.
210    pub bias: Bias,
211    /// Score percentile used to derive a binary prediction threshold.
212    pub percentile: f64,
213}
214
215/// Range-based F-score metric (Tatbul et al., NeurIPS 2018).
216pub struct RangeFScore {
217    /// Relative recall weight in F-score (`1.0` gives F1).
218    pub beta: f64,
219    /// Alpha used in the precision component.
220    pub p_alpha: f64,
221    /// Alpha used in the recall component.
222    pub r_alpha: f64,
223    /// Penalization policy for range multiplicity.
224    pub cardinality: Cardinality,
225    /// Bias used in the precision component.
226    pub p_bias: Bias,
227    /// Bias used in the recall component.
228    pub r_bias: Bias,
229    /// Score percentile used to derive a binary prediction threshold.
230    pub percentile: f64,
231}
232
233/// AUC of the range-based PR curve, sweeping thresholds.
234pub struct RangeAuc {
235    /// Penalization policy when multiple predicted ranges match one true range.
236    pub cardinality: Cardinality,
237    /// Positional weighting within each range.
238    pub bias: Bias,
239    /// Maximum number of thresholds sampled when approximating the curve.
240    pub max_samples: usize,
241}
242
243impl Default for RangePrecision {
244    fn default() -> Self {
245        Self {
246            alpha: 0.0,
247            cardinality: Cardinality::One,
248            bias: Bias::Flat,
249            percentile: 90.0,
250        }
251    }
252}
253
254impl Default for RangeRecall {
255    fn default() -> Self {
256        Self {
257            alpha: 0.0,
258            cardinality: Cardinality::One,
259            bias: Bias::Flat,
260            percentile: 90.0,
261        }
262    }
263}
264
265impl Default for RangeFScore {
266    fn default() -> Self {
267        Self {
268            beta: 1.0,
269            p_alpha: 0.0,
270            r_alpha: 0.0,
271            cardinality: Cardinality::One,
272            p_bias: Bias::Flat,
273            r_bias: Bias::Flat,
274            percentile: 90.0,
275        }
276    }
277}
278
279impl Default for RangeAuc {
280    fn default() -> Self {
281        Self {
282            cardinality: Cardinality::One,
283            bias: Bias::Flat,
284            max_samples: 50,
285        }
286    }
287}
288
289impl Metric for RangePrecision {
290    fn name(&self) -> &str {
291        "RangePrec"
292    }
293    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
294        let mut sorted = scores.to_vec();
295        sorted.sort_by(|a, b| a.total_cmp(b));
296        let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
297        let thresh = sorted[idx.min(sorted.len() - 1)];
298        let pred = apply_threshold(scores, thresh);
299        range_precision_raw(labels, &pred, self.alpha, self.cardinality, self.bias)
300    }
301}
302
303impl Metric for RangeRecall {
304    fn name(&self) -> &str {
305        "RangeRec"
306    }
307    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
308        let mut sorted = scores.to_vec();
309        sorted.sort_by(|a, b| a.total_cmp(b));
310        let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
311        let thresh = sorted[idx.min(sorted.len() - 1)];
312        let pred = apply_threshold(scores, thresh);
313        range_recall_raw(labels, &pred, self.alpha, self.cardinality, self.bias)
314    }
315}
316
317impl Metric for RangeFScore {
318    fn name(&self) -> &str {
319        "RangeF1"
320    }
321    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
322        let mut sorted = scores.to_vec();
323        sorted.sort_by(|a, b| a.total_cmp(b));
324        let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
325        let thresh = sorted[idx.min(sorted.len() - 1)];
326        let pred = apply_threshold(scores, thresh);
327        let p = range_precision_raw(labels, &pred, self.p_alpha, self.cardinality, self.p_bias);
328        let r = range_recall_raw(labels, &pred, self.r_alpha, self.cardinality, self.r_bias);
329        range_fscore(p, r, self.beta)
330    }
331}
332
333/// Computes the area under the range-based precision-recall curve.
334pub(crate) fn range_pr_auc_impl(
335    labels: &[u8],
336    scores: &[f32],
337    cardinality: Cardinality,
338    bias: Bias,
339    max_samples: usize,
340) -> f64 {
341    // Collect unique thresholds (capped at max_samples evenly spaced).
342    let mut sorted_scores = scores.to_vec();
343    sorted_scores.sort_by(|a, b| a.total_cmp(b));
344    sorted_scores.dedup_by(|a, b| (*a - *b).abs() < f32::EPSILON);
345
346    let step = if sorted_scores.len() <= max_samples {
347        1
348    } else {
349        sorted_scores.len() / max_samples
350    };
351
352    let thresholds: Vec<f32> = sorted_scores.into_iter().step_by(step.max(1)).collect();
353
354    let mut points: Vec<(f64, f64)> = thresholds
355        .iter()
356        .map(|&t| {
357            let pred = apply_threshold(scores, t);
358            let p = range_precision_raw(labels, &pred, 0.0, cardinality, bias);
359            let r = range_recall_raw(labels, &pred, 0.0, cardinality, bias);
360            (r, p)
361        })
362        .collect();
363
364    // Add sentinel endpoints.
365    points.push((0.0, 1.0));
366    points.push((1.0, 0.0));
367    points.sort_by(|a, b| a.0.total_cmp(&b.0));
368    points.dedup_by(|a, b| (a.0 - b.0).abs() < 1e-12);
369
370    // Trapezoidal integration.
371    let mut auc = 0.0;
372    for w in points.windows(2) {
373        let (r0, p0) = w[0];
374        let (r1, p1) = w[1];
375        auc += (r1 - r0) * (p0 + p1) / 2.0;
376    }
377    auc
378}
379
380impl Metric for RangeAuc {
381    fn name(&self) -> &str {
382        "RangePR-AUC"
383    }
384    fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
385        range_pr_auc_impl(
386            labels,
387            scores,
388            self.cardinality,
389            self.bias,
390            self.max_samples,
391        )
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const TOL: f64 = 1e-9;

    /// Asserts that `actual` is within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL, "got {actual}");
    }

    #[test]
    fn extract_ranges_basic() {
        let binary = [0u8, 1, 1, 0, 1, 0];
        let expected = vec![(1, 2), (4, 4)];
        assert_eq!(extract_ranges(&binary), expected);
    }

    #[test]
    fn omega_full_overlap() {
        // Identical ranges → overlap score = 1.0
        assert_close(omega((2, 5), (2, 5), Bias::Flat), 1.0);
    }

    #[test]
    fn omega_no_overlap() {
        assert_close(omega((0, 2), (5, 8), Bias::Flat), 0.0);
    }

    #[test]
    fn gamma_reciprocal_penalizes() {
        assert_close(gamma(1, Cardinality::One), 1.0);
        assert_close(gamma(2, Cardinality::Reciprocal), 0.5);
    }

    #[test]
    fn range_precision_perfect() {
        let series = [0u8, 0, 1, 1, 1, 0, 0];
        let p = range_precision_raw(&series, &series, 0.0, Cardinality::One, Bias::Flat);
        assert_close(p, 1.0);
    }

    #[test]
    fn range_recall_perfect() {
        let series = [0u8, 0, 1, 1, 1, 0, 0];
        let r = range_recall_raw(&series, &series, 0.0, Cardinality::One, Bias::Flat);
        assert_close(r, 1.0);
    }
}