1use super::{Metric, thresholding::apply_threshold};
4
/// Cardinality factor applied when one range overlaps several ranges of the
/// other series.
///
/// The variant selects how the summed overlap reward of a range is scaled by
/// the number of ranges it overlaps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// NOTE(review): presumably kept because not every variant is constructed
// inside this crate — confirm before removing.
#[allow(dead_code)]
pub enum Cardinality {
    /// No penalty: the scaling factor is always `1.0`.
    One,
    /// The scaling factor is `1 / n` for `n` overlapping ranges (`0.0` when
    /// there are none).
    Reciprocal,
}
14
/// Positional bias: how much each position inside a range contributes to the
/// overlap reward.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
// NOTE(review): presumably kept because not every variant is constructed
// inside this crate — confirm before removing.
#[allow(dead_code)]
pub enum Bias {
    /// Every position carries the same weight (`1.0`).
    Flat,
    /// Weights decrease linearly from the start of the range to its end.
    Front,
    /// Weights peak at the centre of the range and fall off linearly towards
    /// both ends.
    Middle,
    /// Weights increase linearly from the start of the range to its end.
    Back,
}
28
/// Collects the maximal runs of `1`s in `binary` as inclusive
/// `(first_index, last_index)` pairs, in order of appearance.
///
/// A run is opened by a `1` and closed by the next `0`; any other byte value
/// neither opens nor closes a run. A run still open at the end of the slice is
/// closed at the last index.
fn extract_ranges(binary: &[u8]) -> Vec<(usize, usize)> {
    let mut runs = Vec::new();
    let mut open: Option<usize> = None;

    for (idx, &bit) in binary.iter().enumerate() {
        match bit {
            1 => {
                if open.is_none() {
                    open = Some(idx);
                }
            }
            0 => {
                // A run can only be open if an earlier index held a `1`,
                // so `idx >= 1` here and the subtraction cannot underflow.
                if let Some(begin) = open.take() {
                    runs.push((begin, idx - 1));
                }
            }
            _ => {}
        }
    }

    if let Some(begin) = open {
        runs.push((begin, binary.len() - 1));
    }
    runs
}
48
49fn delta(pos: usize, range_start: usize, range_end: usize, bias: Bias) -> f64 {
51 let len = (range_end - range_start + 1) as f64;
52 match bias {
53 Bias::Flat => 1.0,
54 Bias::Front => {
55 let i = (pos - range_start + 1) as f64;
56 (2.0 * (len - i + 1.0)) / (len * (len + 1.0))
57 }
58 Bias::Back => {
59 let i = (pos - range_start + 1) as f64;
60 (2.0 * i) / (len * (len + 1.0))
61 }
62 Bias::Middle => {
63 let i = (pos - range_start + 1) as f64;
64 let mid = (len + 1.0) / 2.0;
65 let dist = (i - mid).abs();
66 let peak = if len % 2.0 == 0.0 { 0.5 } else { 1.0 };
67 if len == 1.0 {
69 1.0
70 } else {
71 peak - dist * peak / (len / 2.0).ceil()
72 }
73 }
74 }
75}
76
77fn omega(pred: (usize, usize), real: (usize, usize), bias: Bias) -> f64 {
79 let overlap_start = pred.0.max(real.0);
80 let overlap_end = pred.1.min(real.1);
81 if overlap_start > overlap_end {
82 return 0.0;
83 }
84 let my_len = (pred.1 - pred.0 + 1) as f64;
85 let weighted_overlap: f64 = (overlap_start..=overlap_end)
86 .map(|p| delta(p, pred.0, pred.1, bias))
87 .sum();
88 let total_weight: f64 = (pred.0..=pred.1)
89 .map(|p| delta(p, pred.0, pred.1, bias))
90 .sum();
91 if total_weight < 1e-12 {
92 return 0.0;
93 }
94 weighted_overlap / total_weight * (overlap_end - overlap_start + 1) as f64 / my_len
95}
96
97fn gamma(overlap_count: usize, cardinality: Cardinality) -> f64 {
99 match cardinality {
100 Cardinality::One => 1.0,
101 Cardinality::Reciprocal => {
102 if overlap_count == 0 {
103 0.0
104 } else {
105 1.0 / overlap_count as f64
106 }
107 }
108 }
109}
110
111fn range_score(
117 my_range: (usize, usize),
118 ref_ranges: &[(usize, usize)],
119 alpha: f64,
120 cardinality: Cardinality,
121 bias: Bias,
122) -> f64 {
123 let mut overlap_reward = 0.0;
124 let mut overlap_count = 0;
125
126 for &r in ref_ranges {
127 let ov = omega(my_range, r, bias);
128 if ov > 0.0 {
129 overlap_reward += ov;
130 overlap_count += 1;
131 }
132 }
133
134 let existence = if overlap_count > 0 { 1.0 } else { 0.0 };
135 overlap_reward *= gamma(overlap_count, cardinality);
136
137 alpha * existence + (1.0 - alpha) * overlap_reward
138}
139
140pub(crate) fn range_precision_raw(
142 real: &[u8],
143 pred: &[u8],
144 alpha: f64,
145 cardinality: Cardinality,
146 bias: Bias,
147) -> f64 {
148 let pred_ranges = extract_ranges(pred);
149 if pred_ranges.is_empty() {
150 return 0.0;
151 }
152 let real_ranges = extract_ranges(real);
153 let sum: f64 = pred_ranges
154 .iter()
155 .map(|&p| range_score(p, &real_ranges, alpha, cardinality, bias))
156 .sum();
157 sum / pred_ranges.len() as f64
158}
159
160pub(crate) fn range_recall_raw(
162 real: &[u8],
163 pred: &[u8],
164 alpha: f64,
165 cardinality: Cardinality,
166 bias: Bias,
167) -> f64 {
168 let real_ranges = extract_ranges(real);
169 if real_ranges.is_empty() {
170 return f64::NAN;
171 }
172 let pred_ranges = extract_ranges(pred);
173 let sum: f64 = real_ranges
174 .iter()
175 .map(|&r| range_score(r, &pred_ranges, alpha, cardinality, bias))
176 .sum();
177 sum / real_ranges.len() as f64
178}
179
/// F-beta combination of precision and recall:
/// `(1 + beta^2) * prec * rec / (beta^2 * prec + rec)`.
///
/// Returns `0.0` when the denominator is below `1e-12`. A `NaN` input
/// propagates to the result because the guard comparison is then false.
fn range_fscore(prec: f64, rec: f64, beta: f64) -> f64 {
    let beta_sq = beta * beta;
    let denom = beta_sq * prec + rec;
    if denom < 1e-12 {
        0.0
    } else {
        (1.0 + beta_sq) * prec * rec / denom
    }
}
188
/// Range-based precision metric evaluated at a percentile threshold of the
/// scores.
pub struct RangePrecision {
    /// Weight of the existence reward; `1 - alpha` weights the overlap reward.
    pub alpha: f64,
    /// Scaling applied when a range overlaps several ranges of the other series.
    pub cardinality: Cardinality,
    /// Positional weighting used inside each predicted range.
    pub bias: Bias,
    /// Percentile (on a 0–100 scale) of the sorted scores used as the threshold.
    pub percentile: f64,
}
202
/// Range-based recall metric evaluated at a percentile threshold of the
/// scores.
pub struct RangeRecall {
    /// Weight of the existence reward; `1 - alpha` weights the overlap reward.
    pub alpha: f64,
    /// Scaling applied when a range overlaps several ranges of the other series.
    pub cardinality: Cardinality,
    /// Positional weighting used inside each real range.
    pub bias: Bias,
    /// Percentile (on a 0–100 scale) of the sorted scores used as the threshold.
    pub percentile: f64,
}
214
/// Range-based F-beta metric combining range precision and range recall at a
/// percentile threshold of the scores.
pub struct RangeFScore {
    /// `beta` of the F-beta combination.
    pub beta: f64,
    /// Existence-reward weight used for the precision term.
    pub p_alpha: f64,
    /// Existence-reward weight used for the recall term.
    pub r_alpha: f64,
    /// Scaling applied when a range overlaps several ranges of the other
    /// series; shared by the precision and recall terms.
    pub cardinality: Cardinality,
    /// Positional weighting used for the precision term.
    pub p_bias: Bias,
    /// Positional weighting used for the recall term.
    pub r_bias: Bias,
    /// Percentile (on a 0–100 scale) of the sorted scores used as the threshold.
    pub percentile: f64,
}
232
/// Area under the range-based precision/recall curve, obtained by sweeping
/// thresholds taken from the scores.
pub struct RangeAuc {
    /// Scaling applied when a range overlaps several ranges of the other series.
    pub cardinality: Cardinality,
    /// Positional weighting used for both precision and recall.
    pub bias: Bias,
    /// Number of distinct score values above which the thresholds are subsampled.
    pub max_samples: usize,
}
242
243impl Default for RangePrecision {
244 fn default() -> Self {
245 Self {
246 alpha: 0.0,
247 cardinality: Cardinality::One,
248 bias: Bias::Flat,
249 percentile: 90.0,
250 }
251 }
252}
253
254impl Default for RangeRecall {
255 fn default() -> Self {
256 Self {
257 alpha: 0.0,
258 cardinality: Cardinality::One,
259 bias: Bias::Flat,
260 percentile: 90.0,
261 }
262 }
263}
264
265impl Default for RangeFScore {
266 fn default() -> Self {
267 Self {
268 beta: 1.0,
269 p_alpha: 0.0,
270 r_alpha: 0.0,
271 cardinality: Cardinality::One,
272 p_bias: Bias::Flat,
273 r_bias: Bias::Flat,
274 percentile: 90.0,
275 }
276 }
277}
278
279impl Default for RangeAuc {
280 fn default() -> Self {
281 Self {
282 cardinality: Cardinality::One,
283 bias: Bias::Flat,
284 max_samples: 50,
285 }
286 }
287}
288
289impl Metric for RangePrecision {
290 fn name(&self) -> &str {
291 "RangePrec"
292 }
293 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
294 let mut sorted = scores.to_vec();
295 sorted.sort_by(|a, b| a.total_cmp(b));
296 let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
297 let thresh = sorted[idx.min(sorted.len() - 1)];
298 let pred = apply_threshold(scores, thresh);
299 range_precision_raw(labels, &pred, self.alpha, self.cardinality, self.bias)
300 }
301}
302
303impl Metric for RangeRecall {
304 fn name(&self) -> &str {
305 "RangeRec"
306 }
307 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
308 let mut sorted = scores.to_vec();
309 sorted.sort_by(|a, b| a.total_cmp(b));
310 let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
311 let thresh = sorted[idx.min(sorted.len() - 1)];
312 let pred = apply_threshold(scores, thresh);
313 range_recall_raw(labels, &pred, self.alpha, self.cardinality, self.bias)
314 }
315}
316
317impl Metric for RangeFScore {
318 fn name(&self) -> &str {
319 "RangeF1"
320 }
321 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
322 let mut sorted = scores.to_vec();
323 sorted.sort_by(|a, b| a.total_cmp(b));
324 let idx = ((self.percentile / 100.0) * (sorted.len() - 1) as f64).round() as usize;
325 let thresh = sorted[idx.min(sorted.len() - 1)];
326 let pred = apply_threshold(scores, thresh);
327 let p = range_precision_raw(labels, &pred, self.p_alpha, self.cardinality, self.p_bias);
328 let r = range_recall_raw(labels, &pred, self.r_alpha, self.cardinality, self.r_bias);
329 range_fscore(p, r, self.beta)
330 }
331}
332
333pub(crate) fn range_pr_auc_impl(
335 labels: &[u8],
336 scores: &[f32],
337 cardinality: Cardinality,
338 bias: Bias,
339 max_samples: usize,
340) -> f64 {
341 let mut sorted_scores = scores.to_vec();
343 sorted_scores.sort_by(|a, b| a.total_cmp(b));
344 sorted_scores.dedup_by(|a, b| (*a - *b).abs() < f32::EPSILON);
345
346 let step = if sorted_scores.len() <= max_samples {
347 1
348 } else {
349 sorted_scores.len() / max_samples
350 };
351
352 let thresholds: Vec<f32> = sorted_scores.into_iter().step_by(step.max(1)).collect();
353
354 let mut points: Vec<(f64, f64)> = thresholds
355 .iter()
356 .map(|&t| {
357 let pred = apply_threshold(scores, t);
358 let p = range_precision_raw(labels, &pred, 0.0, cardinality, bias);
359 let r = range_recall_raw(labels, &pred, 0.0, cardinality, bias);
360 (r, p)
361 })
362 .collect();
363
364 points.push((0.0, 1.0));
366 points.push((1.0, 0.0));
367 points.sort_by(|a, b| a.0.total_cmp(&b.0));
368 points.dedup_by(|a, b| (a.0 - b.0).abs() < 1e-12);
369
370 let mut auc = 0.0;
372 for w in points.windows(2) {
373 let (r0, p0) = w[0];
374 let (r1, p1) = w[1];
375 auc += (r1 - r0) * (p0 + p1) / 2.0;
376 }
377 auc
378}
379
380impl Metric for RangeAuc {
381 fn name(&self) -> &str {
382 "RangePR-AUC"
383 }
384 fn score(&self, labels: &[u8], scores: &[f32]) -> f64 {
385 range_pr_auc_impl(
386 labels,
387 scores,
388 self.cardinality,
389 self.bias,
390 self.max_samples,
391 )
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_ranges_basic() {
        let b = vec![0, 1, 1, 0, 1, 0];
        assert_eq!(extract_ranges(&b), vec![(1, 2), (4, 4)]);
    }

    // A run that is still open at the end of the input must be closed there.
    #[test]
    fn extract_ranges_trailing_run() {
        let b = vec![1, 1, 0, 1, 1];
        assert_eq!(extract_ranges(&b), vec![(0, 1), (3, 4)]);
    }

    #[test]
    fn extract_ranges_empty_input() {
        assert!(extract_ranges(&[]).is_empty());
        assert!(extract_ranges(&[0, 0, 0]).is_empty());
    }

    #[test]
    fn omega_full_overlap() {
        let score = omega((2, 5), (2, 5), Bias::Flat);
        assert!((score - 1.0).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn omega_no_overlap() {
        let score = omega((0, 2), (5, 8), Bias::Flat);
        assert!((score).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn gamma_reciprocal_penalizes() {
        assert!((gamma(1, Cardinality::One) - 1.0).abs() < 1e-9);
        assert!((gamma(2, Cardinality::Reciprocal) - 0.5).abs() < 1e-9);
    }

    // Zero overlaps must not divide by zero under the reciprocal cardinality.
    #[test]
    fn gamma_zero_overlap() {
        assert_eq!(gamma(0, Cardinality::Reciprocal), 0.0);
        assert_eq!(gamma(0, Cardinality::One), 1.0);
    }

    #[test]
    fn range_precision_perfect() {
        let real = vec![0, 0, 1, 1, 1, 0, 0];
        let pred = vec![0, 0, 1, 1, 1, 0, 0];
        let p = range_precision_raw(&real, &pred, 0.0, Cardinality::One, Bias::Flat);
        assert!((p - 1.0).abs() < 1e-9, "got {p}");
    }

    #[test]
    fn range_precision_without_predictions_is_zero() {
        let real = vec![0, 1, 1, 0];
        let pred = vec![0, 0, 0, 0];
        let p = range_precision_raw(&real, &pred, 0.0, Cardinality::One, Bias::Flat);
        assert_eq!(p, 0.0);
    }

    #[test]
    fn range_recall_perfect() {
        let real = vec![0, 0, 1, 1, 1, 0, 0];
        let pred = vec![0, 0, 1, 1, 1, 0, 0];
        let r = range_recall_raw(&real, &pred, 0.0, Cardinality::One, Bias::Flat);
        assert!((r - 1.0).abs() < 1e-9, "got {r}");
    }

    // Recall is undefined when there is nothing to find.
    #[test]
    fn range_recall_without_real_ranges_is_nan() {
        let real = vec![0, 0, 0, 0];
        let pred = vec![0, 1, 1, 0];
        let r = range_recall_raw(&real, &pred, 0.0, Cardinality::One, Bias::Flat);
        assert!(r.is_nan(), "got {r}");
    }

    #[test]
    fn range_fscore_degenerate_and_perfect() {
        assert_eq!(range_fscore(0.0, 0.0, 1.0), 0.0);
        let f = range_fscore(1.0, 1.0, 1.0);
        assert!((f - 1.0).abs() < 1e-9, "got {f}");
    }
}