_scors/
lib.rs

1use ndarray::{ArrayView,Ix1};
2use numpy::{Element,PyArray1,PyArrayDescr,PyArrayDescrMethods,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
3use pyo3::Bound;
4use pyo3::exceptions::PyTypeError;
5use pyo3::marker::Ungil;
6use pyo3::prelude::*;
7use std::iter::DoubleEndedIterator;
8
/// Direction in which pre-sorted input is ordered by prediction score.
///
/// The scoring kernels consume data from the highest score to the lowest, so
/// `ASCENDING` input is walked back to front while `DESCENDING` input is used as is.
/// The upper-case variant names are kept because callers refer to them by name.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Scores increase along the data.
    ASCENDING,
    /// Scores decrease along the data.
    DESCENDING,
}
14
/// A single weight value that stands in for a whole sequence of identical weights.
struct ConstWeight {
    value: f64,
}

impl ConstWeight {
    /// Builds a constant weight of `value`.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// Builds the neutral constant weight `1.0`.
    fn one() -> Self {
        Self::new(1.0)
    }
}
27
/// Read-only access to a sequence of `T`, both by iteration and by position.
///
// TODO The standard library does not appear to offer a single trait for this,
//      so a crate-local one is used for now.
pub trait Data<T>
where
    T: Clone,
{
    /// Returns an iterator over the values that can be consumed from either end.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the value at position `index`.
    fn get_at(&self, index: usize) -> T;
}
34
/// Data that can report the permutation of positions that puts it in sorted order.
pub trait SortableData<T> {
    /// Returns element positions in sorted order. The sort direction is chosen by the
    /// implementation; as the name says, equal elements get no guaranteed relative order.
    fn argsort_unstable(&self) -> Vec<usize>;
}
38
39impl Iterator for ConstWeight {
40    type Item = f64;
41    fn next(&mut self) -> Option<f64> {
42        return Some(self.value);
43    }
44}
45
46impl DoubleEndedIterator for ConstWeight {
47    fn next_back(&mut self) -> Option<f64> {
48        return Some(self.value);
49    }
50}
51
52impl Data<f64> for ConstWeight {
53    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
54        return ConstWeight::new(self.value);
55    }
56
57    fn get_at(&self, _index: usize) -> f64 {
58        return self.value.clone();
59    }
60}
61
62impl <T: Clone> Data<T> for Vec<T> {
63    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
64        return self.iter().cloned();
65    }
66    fn get_at(&self, index: usize) -> T {
67        return self[index].clone();
68    }
69}
70
71impl SortableData<f64> for Vec<f64> {
72    fn argsort_unstable(&self) -> Vec<usize> {
73        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
74        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
75        // indices.sort_unstable_by_key(|i| self[*i]);
76        return indices;
77    }
78}
79
80impl <T: Clone> Data<T> for &[T] {
81    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
82        return self.iter().cloned();
83    }
84    fn get_at(&self, index: usize) -> T {
85        return self[index].clone();
86    }
87}
88
89impl SortableData<f64> for &[f64] {
90    fn argsort_unstable(&self) -> Vec<usize> {
91        // let t0 = Instant::now();
92        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
93        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
94        // let t1 = Instant::now();
95        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
96        // println!("Sorting took {}ms", t0.elapsed().as_millis());
97        return indices;
98    }
99}
100
101impl <T: Clone, const N: usize> Data<T> for [T; N] {
102    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
103        return self.iter().cloned();
104    }
105    fn get_at(&self, index: usize) -> T {
106        return self[index].clone();
107    }
108}
109
110impl <const N: usize> SortableData<f64> for [f64; N] {
111    fn argsort_unstable(&self) -> Vec<usize> {
112        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
113        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
114        return indices;
115    }
116}
117
118impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
119    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
120        return self.iter().cloned();
121    }
122    fn get_at(&self, index: usize) -> T {
123        return self[index].clone();
124    }
125}
126
127impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
128    fn argsort_unstable(&self) -> Vec<usize> {
129        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
130        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
131        return indices;
132    }
133}
134
135// struct IndexView<'a, T, D> where T: Clone, D: Data<T>{
136//     data: &'a D,
137//     indices: &'a Vec<usize>,
138// }
139
140// impl <'a, T: Clone> Data<T> for IndexView<'a, T> {
141// }
142
/// A cheaply copyable value that can be read as a positive/negative class label.
///
/// `Copy` already requires `Clone`, so the single bound covers both.
pub trait BinaryLabel: Copy {
    /// Returns `true` when the label denotes the positive class.
    fn get_value(&self) -> bool;
}
146
147impl BinaryLabel for bool {
148    fn get_value(&self) -> bool {
149        return self.clone();
150    }
151}
152
153impl BinaryLabel for u8 {
154    fn get_value(&self) -> bool {
155        return (self & 1) == 1;
156    }
157}
158
159impl BinaryLabel for u16 {
160    fn get_value(&self) -> bool {
161        return (self & 1) == 1;
162    }
163}
164
165impl BinaryLabel for u32 {
166    fn get_value(&self) -> bool {
167        return (self & 1) == 1;
168    }
169}
170
171impl BinaryLabel for u64 {
172    fn get_value(&self) -> bool {
173        return (self & 1) == 1;
174    }
175}
176
177impl BinaryLabel for i8 {
178    fn get_value(&self) -> bool {
179        return (self & 1) == 1;
180    }
181}
182
183impl BinaryLabel for i16 {
184    fn get_value(&self) -> bool {
185        return (self & 1) == 1;
186    }
187}
188
189impl BinaryLabel for i32 {
190    fn get_value(&self) -> bool {
191        return (self & 1) == 1;
192    }
193}
194
195impl BinaryLabel for i64 {
196    fn get_value(&self) -> bool {
197        return (self & 1) == 1;
198    }
199}
200
201fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
202where T: Copy, I: Data<T>
203{
204    let mut selection: Vec<T> = Vec::new();
205    selection.reserve_exact(indices.len());
206    for index in indices {
207        selection.push(slice.get_at(*index));
208    }
209    return selection;
210}
211
212pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
213where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
214{
215    return average_precision_with_order(labels, predictions, weights, None);
216}
217
218pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
219where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
220{
221    return match order {
222        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
223        None => {
224            let indices = predictions.argsort_unstable();
225            let sorted_labels = select(labels, &indices);
226            let ap = match weights {
227                None => {
228                    // let w: Oepion<&
229                    average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
230                },
231                Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
232            };
233            ap
234        }
235    };
236}
237
238pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
239where B: BinaryLabel, L: Data<B>, W: Data<f64>
240{
241    return match weights {
242        None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
243        Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
244    };
245}
246
247pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
248where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
249{
250    return match order {
251        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
252        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
253    };
254}
255
256pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
257    let mut ap: f64 = 0.0;
258    let mut tps: f64 = 0.0;
259    let mut fps: f64 = 0.0;
260    for (label, weight) in labels.zip(weights) {
261        let w: f64 = weight;
262        let l: bool = label.get_value();
263        let tp = w * f64::from(l);
264        tps += tp;
265        fps += weight - tp;
266        let ps = tps + fps;
267        let precision = tps / ps;
268        ap += tp * precision;
269    }
270    // Special case for tps == 0 following sklearn
271    // https://github.com/scikit-learn/scikit-learn/blob/5cce87176a530d2abea45b5a7e5a4d837c481749/sklearn/metrics/_ranking.py#L1032-L1039
272    // I.e. if tps is 0.0, there are no positive samples in labels: Either all labels are 0, or all weights (for positive labels) are 0
273    return if tps == 0.0 {
274        0.0
275    } else {
276        ap / tps
277    };
278}
279
280
281
282// ROC AUC score
283pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
284where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
285{
286    return roc_auc_with_order(labels, predictions, weights, None, None);
287}
288
289pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
290where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
291{
292    return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
293}
294
295pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
296where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
297{
298    return match order {
299        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
300        None => {
301            let indices = predictions.argsort_unstable();
302            let sorted_labels = select(labels, &indices);
303            let sorted_predictions = select(predictions, &indices);
304            let roc_auc_score = match weights {
305                Some(w) => {
306                    let sorted_weights = select(w, &indices);
307                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
308                },
309                None => {
310                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
311                }
312            };
313            roc_auc_score
314        }
315    };
316}
317pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
318where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
319    return match max_false_positive_rate {
320        None => match weights {
321            Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
322            None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
323        }
324        Some(max_fpr) => match weights {
325            Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
326            None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
327        }
328    };
329}
330
331pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
332    labels: &mut impl DoubleEndedIterator<Item = B>,
333    predictions: &mut impl DoubleEndedIterator<Item = f64>,
334    weights: &mut impl DoubleEndedIterator<Item = f64>,
335    order: Order
336) -> f64 {
337    return match order {
338        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
339        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
340    }
341}
342
343pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
344    labels: &mut impl Iterator<Item = B>,
345    predictions: &mut impl Iterator<Item = f64>,
346    weights: &mut impl Iterator<Item = f64>
347) -> f64 {
348    let mut false_positives: f64 = 0.0;
349    let mut true_positives: f64 = 0.0;
350    let mut last_counted_fp = 0.0;
351    let mut last_counted_tp = 0.0;
352    let mut area_under_curve = 0.0;
353    let mut zipped = labels.zip(predictions).zip(weights).peekable();
354    loop {
355        match zipped.next() {
356            None => break,
357            Some(actual) => {
358                let l = f64::from(actual.0.0.get_value());
359                let w = actual.1;
360                let wl = l * w;
361                true_positives += wl;
362                false_positives += w - wl;
363                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
364                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
365                    last_counted_fp = false_positives;
366                    last_counted_tp = true_positives;
367                }
368            }
369        };
370    }
371    return area_under_curve / (true_positives * false_positives);
372}
373
/// Area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`:
/// a rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    // Operand order is kept as is so the result is bit-for-bit reproducible.
    width * y0 + rise * width * 0.5
}
379
380fn get_positive_sum<B: BinaryLabel>(
381    labels: impl Iterator<Item = B>,
382    weights: impl Iterator<Item = f64>
383) -> (f64, f64) {
384    let mut false_positives = 0f64;
385    let mut true_positives = 0f64;
386    for (label, weight) in labels.zip(weights) {
387        let lw = weight * f64::from(label.get_value());
388        false_positives += weight - lw;
389        true_positives += lw;
390    }
391    return (false_positives, true_positives);
392}
393
394pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
395where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
396    // TODO validate max_fpr
397    let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
398    let mut l_it = labels.get_iterator();
399    let mut p_it = predictions.get_iterator();
400    let mut w_it = weights.get_iterator();
401    return match order {
402        Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
403        Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
404    };
405}
406    
407
408fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
409    labels: &mut impl Iterator<Item = B>,
410    predictions: &mut impl Iterator<Item = f64>,
411    weights: &mut impl Iterator<Item = f64>,
412    false_positive_sum: f64,
413    true_positive_sum: f64,
414    max_false_positive_rate: f64
415) -> f64 {
416    let mut false_positives: f64 = 0.0;
417    let mut true_positives: f64 = 0.0;
418    let mut last_counted_fp = 0.0;
419    let mut last_counted_tp = 0.0;
420    let mut area_under_curve = 0.0;
421    let mut zipped = labels.zip(predictions).zip(weights).peekable();
422    let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
423    loop {
424        match zipped.next() {
425            None => break,
426            Some(actual) => {
427                let l = f64::from(actual.0.0.get_value());
428                let w = actual.1;
429                let wl = l * w;
430                let next_tp = true_positives + wl;
431                let next_fp = false_positives + (w - wl);
432                let is_above_max = next_fp > false_positive_cutoff;
433                if is_above_max {
434                    let dx = next_fp  - false_positives;
435                    let dy = next_tp - true_positives;
436                    true_positives += dy * false_positive_cutoff / dx;
437                    false_positives = false_positive_cutoff;
438                } else {
439                    true_positives = next_tp;
440                    false_positives = next_fp;
441                }
442                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
443                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
444                    last_counted_fp = false_positives;
445                    last_counted_tp = true_positives;
446                }
447                if is_above_max {
448                    break;
449                }                
450            }
451        };
452    }
453    let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
454    let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
455    let max_area = max_false_positive_rate;
456    return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
457}
458
459
460// Python bindings
// Python-facing mirror of `Order`; the `name` attribute registers it as "Order" in Python.
// It is converted to the internal `Order` by `py_order_as_order`.
#[pyclass(eq, eq_int, name="Order")]
#[derive(Clone, Copy, PartialEq)]
pub enum PyOrder {
    // Input is sorted by increasing prediction.
    ASCENDING,
    // Input is sorted by decreasing prediction.
    DESCENDING
}
467
468fn py_order_as_order(order: PyOrder) -> Order {
469    return match order {
470        PyOrder::ASCENDING => Order::ASCENDING,
471        PyOrder::DESCENDING => Order::DESCENDING,
472    }
473}
474
475
476trait PyScore: Ungil + Sync {
477
478    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
479    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
480
481    fn score_py_generic<'py, B>(
482        &self,
483        py: Python<'py>,
484        labels: &PyReadonlyArray1<'py, B>,
485        predictions: &PyReadonlyArray1<'py, f64>,
486        weights: &Option<PyReadonlyArray1<'py, f64>>,
487        order: &Option<PyOrder>,
488    ) -> f64
489    where B: BinaryLabel + Element
490    {
491        let labels = labels.as_array();
492        let predictions = predictions.as_array();
493        let order = order.map(py_order_as_order);
494        let score = match weights {
495            Some(weight) => {
496                let weights = weight.as_array();
497                py.allow_threads(move || {
498                    self.score(&labels, &predictions, Some(&weights), order)
499                })
500            },
501            None => py.allow_threads(move || {
502                self.score(&labels, &predictions, None::<&Vec<f64>>, order)
503            })
504        };
505        return score;
506    }
507
508    fn score_py_match_run<'py, T>(
509        &self,
510        py: Python<'py>,
511        labels: &Bound<'py, PyUntypedArray>,
512        predictions: &PyReadonlyArray1<'py, f64>,
513        weights: &Option<PyReadonlyArray1<'py, f64>>,
514        order: &Option<PyOrder>,
515        dt: &Bound<'py, PyArrayDescr>
516    ) -> Option<f64>
517    where T: Element + BinaryLabel
518    {
519        return if dt.is_equiv_to(&dtype::<T>(py)) {
520            let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
521            Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
522        } else {
523            None
524        };
525    }
526    
527    fn score_py<'py>(
528        &self,
529        py: Python<'py>,
530        labels: &Bound<'py, PyUntypedArray>,
531        predictions: PyReadonlyArray1<'py, f64>,
532        weights: Option<PyReadonlyArray1<'py, f64>>,
533        order: Option<PyOrder>,
534    ) -> PyResult<f64> {
535        if labels.ndim() != 1 {
536            return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
537        }
538        let label_dtype = labels.dtype();
539        if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
540            return Ok(score)
541        }
542        else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
543            return Ok(score)
544        }
545        else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
546            return Ok(score)
547        }
548        else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
549            return Ok(score)
550        }
551        else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
552            return Ok(score)
553        }
554        else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
555            return Ok(score)
556        }
557        else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
558            return Ok(score)
559        }
560        else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
561            return Ok(score)
562        }
563        else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
564            return Ok(score)
565        }
566        return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
567    }
568}
569
/// Stateless scorer that computes average precision for the Python bindings.
struct PyAveragePrecision {}

impl PyAveragePrecision {
    /// Creates the scorer; there is nothing to configure.
    fn new() -> Self {
        Self {}
    }
}
579
580impl PyScore for PyAveragePrecision {
581    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
582    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
583        return average_precision_with_order(labels, predictions, weights, order);
584    }
585}
586
/// Scorer that computes ROC AUC for the Python bindings.
struct PyRocAuc {
    /// Optional upper bound on the false positive rate; `None` integrates the full curve.
    max_fpr: Option<f64>,
}

impl PyRocAuc {
    /// Creates a scorer with the given false-positive-rate limit.
    fn new(max_fpr: Option<f64>) -> Self {
        Self { max_fpr }
    }
}
596
597impl PyScore for PyRocAuc {
598    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
599    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
600        return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
601    }
602}
603
604
605#[pyfunction(name = "average_precision")]
606#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
607pub fn average_precision_py<'py>(
608    py: Python<'py>,
609    labels: &Bound<'py, PyUntypedArray>,
610    predictions: PyReadonlyArray1<'py, f64>,
611    weights: Option<PyReadonlyArray1<'py, f64>>,
612    order: Option<PyOrder>
613) -> PyResult<f64> {
614    return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
615}
616
617#[pyfunction(name = "roc_auc")]
618#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
619pub fn roc_auc_py<'py>(
620    py: Python<'py>,
621    labels: &Bound<'py, PyUntypedArray>,
622    predictions: PyReadonlyArray1<'py, f64>,
623    weights: Option<PyReadonlyArray1<'py, f64>>,
624    order: Option<PyOrder>,
625    max_fpr: Option<f64>,
626) -> PyResult<f64> {
627    return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
628}
629
630#[pymodule(name = "_scors")]
631fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
632    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
633    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
634    m.add_class::<PyOrder>().unwrap();
635    return Ok(());
636}
637
638
#[cfg(test)]
mod tests {
    use super::*;

    // Four samples, already ranked from the highest prediction to the lowest.
    const RANKED_LABELS: [u8; 4] = [1, 0, 1, 0];
    const RANKED_PREDICTIONS: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
    // The same four samples in a shuffled order.
    const SHUFFLED_LABELS: [u8; 4] = [0, 0, 1, 1];
    const SHUFFLED_PREDICTIONS: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
    const UNIT_WEIGHTS: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
    const EXPECTED_AVERAGE_PRECISION: f64 = 0.8333333333333333;

    /// Labels that are already ranked can be scored without any predictions.
    #[test]
    fn test_average_precision_on_sorted() {
        let score = average_precision_on_sorted_labels(&RANKED_LABELS, Some(&UNIT_WEIGHTS), Order::DESCENDING);
        assert_eq!(score, EXPECTED_AVERAGE_PRECISION);
    }

    /// Without an order hint the data is ranked by prediction before scoring.
    #[test]
    fn test_average_precision_unsorted() {
        let score = average_precision_with_order(&SHUFFLED_LABELS, &SHUFFLED_PREDICTIONS, Some(&UNIT_WEIGHTS), None);
        assert_eq!(score, EXPECTED_AVERAGE_PRECISION);
    }

    /// With an order hint the data is used as given.
    #[test]
    fn test_average_precision_sorted() {
        let score = average_precision_with_order(&RANKED_LABELS, &RANKED_PREDICTIONS, Some(&UNIT_WEIGHTS), Some(Order::DESCENDING));
        assert_eq!(score, EXPECTED_AVERAGE_PRECISION);
    }

    /// Full ROC AUC on ranked data.
    #[test]
    fn test_roc_auc() {
        let score = roc_auc_with_order(&RANKED_LABELS, &RANKED_PREDICTIONS, Some(&UNIT_WEIGHTS), Some(Order::DESCENDING), None);
        assert_eq!(score, 0.75);
    }
}