_scors/
lib.rs

1use ndarray::{ArrayView,Ix1};
2use numpy::{NotContiguousError,PyReadonlyArray1};
3use pyo3::prelude::*; // {PyModule,PyResult,Python,pymodule};
4use std::iter::DoubleEndedIterator;
5
/// Direction in which already-sorted input is ordered by prediction score.
///
/// The metric kernels in this file consume samples from the highest score to
/// the lowest; `ASCENDING` input is walked back to front to achieve that.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Order {
    /// Input runs from the lowest score to the highest.
    ASCENDING,
    /// Input runs from the highest score to the lowest.
    DESCENDING,
}
10
/// Weight source that holds a single value used for every sample.
struct ConstWeight {
    /// The weight handed out for each sample.
    value: f64,
}
14
15impl ConstWeight {
16    fn new(value: f64) -> Self {
17        return ConstWeight { value: value };
18    }
19    fn one() -> Self {
20        return Self::new(1.0);
21    }
22}
23
/// Read access to a one-dimensional collection whose elements can be cloned out.
///
/// TODO: replace this with a standard trait if a suitable one exists; it was
///       introduced because none was known at the time of writing.
pub trait Data<T: Clone> {
    /// Returns an iterator over owned elements that can be walked from either end.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns an owned copy of the element at position `index`.
    fn get_at(&self, index: usize) -> T;
}
30
/// Data whose elements can be ranked by value.
pub trait SortableData<T> {
    /// Returns the positions of the elements in sorted order, without any
    /// guarantee about the relative order of equal elements.
    ///
    /// The implementations in this file rank from the largest value to the smallest.
    fn argsort_unstable(&self) -> Vec<usize>;
}
34
35impl Iterator for ConstWeight {
36    type Item = f64;
37    fn next(&mut self) -> Option<f64> {
38        return Some(self.value);
39    }
40}
41
42impl DoubleEndedIterator for ConstWeight {
43    fn next_back(&mut self) -> Option<f64> {
44        return Some(self.value);
45    }
46}
47
48impl Data<f64> for ConstWeight {
49    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
50        return ConstWeight::new(self.value);
51    }
52
53    fn get_at(&self, _index: usize) -> f64 {
54        return self.value.clone();
55    }
56}
57
58impl <T: Clone> Data<T> for Vec<T> {
59    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
60        return self.iter().cloned();
61    }
62    fn get_at(&self, index: usize) -> T {
63        return self[index].clone();
64    }
65}
66
67impl SortableData<f64> for Vec<f64> {
68    fn argsort_unstable(&self) -> Vec<usize> {
69        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
70        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
71        // indices.sort_unstable_by_key(|i| self[*i]);
72        return indices;
73    }
74}
75
76impl <T: Clone> Data<T> for &[T] {
77    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
78        return self.iter().cloned();
79    }
80    fn get_at(&self, index: usize) -> T {
81        return self[index].clone();
82    }
83}
84
85impl SortableData<f64> for &[f64] {
86    fn argsort_unstable(&self) -> Vec<usize> {
87        // let t0 = Instant::now();
88        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
89        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
90        // let t1 = Instant::now();
91        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
92        // println!("Sorting took {}ms", t0.elapsed().as_millis());
93        return indices;
94    }
95}
96
97impl <T: Clone, const N: usize> Data<T> for [T; N] {
98    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
99        return self.iter().cloned();
100    }
101    fn get_at(&self, index: usize) -> T {
102        return self[index].clone();
103    }
104}
105
106impl <const N: usize> SortableData<f64> for [f64; N] {
107    fn argsort_unstable(&self) -> Vec<usize> {
108        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
109        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
110        return indices;
111    }
112}
113
114impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
115    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
116        return self.iter().cloned();
117    }
118    fn get_at(&self, index: usize) -> T {
119        return self[index].clone();
120    }
121}
122
123impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
124    fn argsort_unstable(&self) -> Vec<usize> {
125        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
126        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
127        return indices;
128    }
129}
130
131fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
132where T: Copy, I: Data<T>
133{
134    let mut selection: Vec<T> = Vec::new();
135    selection.reserve_exact(indices.len());
136    for index in indices {
137        selection.push(slice.get_at(*index));
138    }
139    return selection;
140}
141
142pub fn average_precision<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
143where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
144{
145    return average_precision_with_order(labels, predictions, weights, None);
146}
147
148pub fn average_precision_with_order<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
149where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
150{
151    return match order {
152        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
153        None => {
154            let indices = predictions.argsort_unstable();
155            let sorted_labels = select(labels, &indices);
156            let ap = match weights {
157                None => {
158                    // let w: Oepion<&
159                    average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
160                },
161                Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
162            };
163            ap
164        }
165    };
166}
167
168pub fn average_precision_on_sorted_labels<L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
169where L: Data<u8>, W: Data<f64>
170{
171    return match weights {
172        None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
173        Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
174    };
175}
176
177pub fn average_precision_on_iterator<L, W>(labels: L, weights: W, order: Order) -> f64
178where L: DoubleEndedIterator<Item = u8>, W: DoubleEndedIterator<Item = f64>
179{
180    return match order {
181        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
182        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
183    };
184}
185
/// Computes the weighted average precision from labels and weights that are
/// sorted from the highest prediction to the lowest.
///
/// Each sample contributes its positive weight (`weight * label`) times the
/// precision reached at that sample; the sum is normalised by the total
/// positive weight. Iteration stops as soon as either iterator is exhausted.
///
/// Returns `NaN` when the total positive weight is zero.
pub fn average_precision_on_descending_iterator(labels: impl Iterator<Item = u8>, weights: impl Iterator<Item = f64>) -> f64 {
    let mut ap: f64 = 0.0;
    let mut tps: f64 = 0.0;
    let mut fps: f64 = 0.0;
    for (label, weight) in labels.zip(weights) {
        let tp = weight * f64::from(label);
        tps += tp;
        fps += weight - tp;
        // Samples without positive weight add nothing to the sum. Skipping them
        // also avoids evaluating 0/0 while the accumulated weight is still zero,
        // which would otherwise turn the whole result into NaN.
        if tp != 0.0 {
            let precision = tps / (tps + fps);
            ap += tp * precision;
        }
    }
    ap / tps
}
202
203
204
205// ROC AUC score
206pub fn roc_auc<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
207where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
208{
209    return roc_auc_with_order(labels, predictions, weights, None, None);
210}
211
212pub fn roc_auc_max_fpr<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
213where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
214{
215    return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
216}
217
218pub fn roc_auc_with_order<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
219where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
220{
221    return match order {
222        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
223        None => {
224            let indices = predictions.argsort_unstable();
225            let sorted_labels = select(labels, &indices);
226            let sorted_predictions = select(predictions, &indices);
227            let roc_auc_score = match weights {
228                Some(w) => {
229                    let sorted_weights = select(w, &indices);
230                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
231                },
232                None => {
233                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
234                }
235            };
236            roc_auc_score
237        }
238    };
239}
240pub fn roc_auc_on_sorted_labels<L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
241where L: Data<u8>, P: Data<f64>, W: Data<f64> {
242    return match max_false_positive_rate {
243        None => match weights {
244            Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
245            None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
246        }
247        Some(max_fpr) => match weights {
248            Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
249            None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
250        }
251    };
252}
253
254pub fn roc_auc_on_sorted_iterator(
255    labels: &mut impl DoubleEndedIterator<Item = u8>,
256    predictions: &mut impl DoubleEndedIterator<Item = f64>,
257    weights: &mut impl DoubleEndedIterator<Item = f64>,
258    order: Order
259) -> f64 {
260    return match order {
261        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
262        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
263    }
264}
265
266pub fn roc_auc_on_descending_iterator(
267    labels: &mut impl Iterator<Item = u8>,
268    predictions: &mut impl Iterator<Item = f64>,
269    weights: &mut impl Iterator<Item = f64>
270) -> f64 {
271    let mut false_positives: f64 = 0.0;
272    let mut true_positives: f64 = 0.0;
273    let mut last_counted_fp = 0.0;
274    let mut last_counted_tp = 0.0;
275    let mut area_under_curve = 0.0;
276    let mut zipped = labels.zip(predictions).zip(weights).peekable();
277    loop {
278        match zipped.next() {
279            None => break,
280            Some(actual) => {
281                let l = actual.0.0 as f64;
282                let w = actual.1;
283                let wl = l * w;
284                true_positives += wl;
285                false_positives += w - wl;
286                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
287                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
288                    last_counted_fp = false_positives;
289                    last_counted_tp = true_positives;
290                }
291            }
292        };
293    }
294    return area_under_curve / (true_positives * false_positives);
295}
296
/// Area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`.
///
/// Computed as the rectangle below `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    width * y0 + rise * width * 0.5
}
302
/// Sums the weighted false and true positives over all samples.
///
/// Returns `(false_positive_sum, true_positive_sum)`, where a sample adds
/// `weight * label` to the true positives and the rest of its weight to the
/// false positives.
fn get_positive_sum(
    labels: impl Iterator<Item = u8>,
    weights: impl Iterator<Item = f64>
) -> (f64, f64) {
    labels
        .zip(weights)
        .fold((0f64, 0f64), |(false_positives, true_positives), (label, weight)| {
            let positive_weight = weight * (label as f64);
            (
                false_positives + (weight - positive_weight),
                true_positives + positive_weight,
            )
        })
}
316
317pub fn roc_auc_on_sorted_with_fp_cutoff<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
318where L: Data<u8>, P: Data<f64>, W: Data<f64> {
319    // TODO validate max_fpr
320    let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
321    let mut l_it = labels.get_iterator();
322    let mut p_it = predictions.get_iterator();
323    let mut w_it = weights.get_iterator();
324    return match order {
325        Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
326        Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
327    };
328}
329    
330
331fn roc_auc_on_descending_iterator_with_fp_cutoff(
332    labels: &mut impl Iterator<Item = u8>,
333    predictions: &mut impl Iterator<Item = f64>,
334    weights: &mut impl Iterator<Item = f64>,
335    false_positive_sum: f64,
336    true_positive_sum: f64,
337    max_false_positive_rate: f64
338) -> f64 {
339    let mut false_positives: f64 = 0.0;
340    let mut true_positives: f64 = 0.0;
341    let mut last_counted_fp = 0.0;
342    let mut last_counted_tp = 0.0;
343    let mut area_under_curve = 0.0;
344    let mut zipped = labels.zip(predictions).zip(weights).peekable();
345    let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
346    loop {
347        match zipped.next() {
348            None => break,
349            Some(actual) => {
350                let l = actual.0.0 as f64;
351                let w = actual.1;
352                let wl = l * w;
353                let next_tp = true_positives + wl;
354                let next_fp = false_positives + (w - wl);
355                let is_above_max = next_fp > false_positive_cutoff;
356                if is_above_max {
357                    let dx = next_fp  - false_positives;
358                    let dy = next_tp - true_positives;
359                    true_positives += dy * false_positive_cutoff / dx;
360                    false_positives = false_positive_cutoff;
361                } else {
362                    true_positives = next_tp;
363                    false_positives = next_fp;
364                }
365                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
366                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
367                    last_counted_fp = false_positives;
368                    last_counted_tp = true_positives;
369                }
370                if is_above_max {
371                    break;
372                }                
373            }
374        };
375    }
376    let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
377    let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
378    let max_area = max_false_positive_rate;
379    return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
380}
381
382
383// Python bindings
384#[pyclass(eq, eq_int, name="Order")]
385#[derive(PartialEq)]
386pub enum PyOrder {
387    ASCENDING,
388    DESCENDING
389}
390
391impl Clone for PyOrder {
392    fn clone(&self) -> Self {
393        match self {
394            PyOrder::ASCENDING => PyOrder::ASCENDING,
395            PyOrder::DESCENDING => PyOrder::DESCENDING
396        }
397    }
398}
399
400fn py_order_as_order(order: PyOrder) -> Order {
401    return match order {
402        PyOrder::ASCENDING => Order::ASCENDING,
403        PyOrder::DESCENDING => Order::DESCENDING,
404    }
405}
406
407#[pyfunction(name = "average_precision")]
408#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
409pub fn average_precision_py<'py>(
410    _py: Python<'py>,
411    labels: PyReadonlyArray1<'py, u8>,
412    predictions: PyReadonlyArray1<'py, f64>,
413    weights: Option<PyReadonlyArray1<'py, f64>>,
414    order: Option<PyOrder>
415) -> Result<f64, NotContiguousError> {
416    // TODO benchmark if slice has any benefits over just as_array
417    let o = order.map(py_order_as_order);
418    let ap = match weights {
419        None => {
420            if let (Ok(l), Ok(p)) = (labels.as_slice(), predictions.as_slice()) {
421                average_precision_with_order(&l, &p, None::<&Vec<f64>>, o)  
422            } else {
423                average_precision_with_order(&labels.as_array(), &predictions.as_array(), None::<&Vec<f64>>, o)
424            }
425        },
426        Some(weight) => {
427            if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weight.as_slice()) {
428                average_precision_with_order(&l, &p, Some(&w), o)  
429            } else {
430                average_precision_with_order(&labels.as_array(), &predictions.as_array(), Some(&weight.as_array()), o)
431            }
432        }
433    };
434    return Ok(ap);
435}
436
437#[pyfunction(name = "roc_auc")]
438#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_false_positive_rate=None))]
439pub fn roc_auc_py<'py>(
440    _py: Python<'py>,
441    labels: PyReadonlyArray1<'py, u8>,
442    predictions: PyReadonlyArray1<'py, f64>,
443    weights: Option<PyReadonlyArray1<'py, f64>>,
444    order: Option<PyOrder>,
445    max_false_positive_rate: Option<f64>,
446) -> Result<f64, NotContiguousError> {
447    let o = order.map(py_order_as_order);
448    let ap = match weights {
449        Some(weight) => if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weight.as_slice()) {
450            roc_auc_with_order(&l, &p, Some(&w), o, max_false_positive_rate)
451        } else {
452            roc_auc_with_order(&labels.as_array(), &predictions.as_array(), Some(&weight.as_array()), o, max_false_positive_rate)
453        }
454        None => if let (Ok(l), Ok(p)) = (labels.as_slice(), predictions.as_slice()) {
455            roc_auc_with_order(&l, &p, None::<&Vec<f64>>, o, max_false_positive_rate)
456        } else {
457            roc_auc_with_order(&labels.as_array(), &predictions.as_array(), None::<&Vec<f64>>, o, max_false_positive_rate)
458        }
459    };
460    return Ok(ap);
461}
462
463#[pymodule(name = "_scors")]
464fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
465    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
466    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
467    m.add_class::<PyOrder>().unwrap();
468    return Ok(());
469}
470
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Labels already sorted by descending prediction, unit weights.
    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        // Weights are optional in the API and therefore passed as `Some(..)`.
        let actual = average_precision_on_sorted_labels(&labels, Some(&weights), Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    /// Unsorted input must be ranked by prediction before scoring.
    #[test]
    fn test_average_precision_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), None);
        assert_eq!(actual, 0.8333333333333333);
    }

    /// Input declared as sorted is scored without re-ranking.
    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING));
        assert_eq!(actual, 0.8333333333333333);
    }

    /// Full-curve ROC AUC on input declared as sorted, without an FPR limit.
    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING), None);
        assert_eq!(actual, 0.75);
    }
}