_scors/
lib.rs

1use ndarray::{ArrayView,Ix1};
2use numpy::{Element,PyArray1,PyArrayDescr,PyArrayDescrMethods,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
3use pyo3::Bound;
4use pyo3::exceptions::PyTypeError;
5use pyo3::prelude::*;
6use std::iter::DoubleEndedIterator;
7
/// Sort order of data that is handed to the scoring functions.
///
/// The scoring routines walk the samples from the highest prediction to the
/// lowest, so data declared as `ASCENDING` is traversed back to front.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Samples are ordered from the lowest to the highest prediction.
    ASCENDING,
    /// Samples are ordered from the highest to the lowest prediction.
    DESCENDING,
}
13
/// Weight source that reports the same weight for every sample.
struct ConstWeight {
    /// The weight reported for every sample.
    value: f64,
}
17
18impl ConstWeight {
19    fn new(value: f64) -> Self {
20        return ConstWeight { value: value };
21    }
22    fn one() -> Self {
23        return Self::new(1.0);
24    }
25}
26
/// Read access to a one-dimensional sequence of `T` values.
pub trait Data<T: Clone> {
    // TODO: There seems to be no standard trait that offers both double-ended
    //       iteration and indexed access, so this crate-local trait is used for now.

    /// Returns an iterator over the elements (by value) that can be walked from either end.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the element stored at position `index`.
    fn get_at(&self, index: usize) -> T;
}
33
/// Data whose elements can be ranked without moving them.
pub trait SortableData<T> {
    /// Returns the element indices arranged in sorted order.
    ///
    /// The sort direction is chosen by the implementor; the relative order of
    /// equal elements is not guaranteed.
    fn argsort_unstable(&self) -> Vec<usize>;
}
37
38impl Iterator for ConstWeight {
39    type Item = f64;
40    fn next(&mut self) -> Option<f64> {
41        return Some(self.value);
42    }
43}
44
45impl DoubleEndedIterator for ConstWeight {
46    fn next_back(&mut self) -> Option<f64> {
47        return Some(self.value);
48    }
49}
50
51impl Data<f64> for ConstWeight {
52    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
53        return ConstWeight::new(self.value);
54    }
55
56    fn get_at(&self, _index: usize) -> f64 {
57        return self.value.clone();
58    }
59}
60
61impl <T: Clone> Data<T> for Vec<T> {
62    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
63        return self.iter().cloned();
64    }
65    fn get_at(&self, index: usize) -> T {
66        return self[index].clone();
67    }
68}
69
70impl SortableData<f64> for Vec<f64> {
71    fn argsort_unstable(&self) -> Vec<usize> {
72        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
73        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
74        // indices.sort_unstable_by_key(|i| self[*i]);
75        return indices;
76    }
77}
78
79impl <T: Clone> Data<T> for &[T] {
80    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
81        return self.iter().cloned();
82    }
83    fn get_at(&self, index: usize) -> T {
84        return self[index].clone();
85    }
86}
87
88impl SortableData<f64> for &[f64] {
89    fn argsort_unstable(&self) -> Vec<usize> {
90        // let t0 = Instant::now();
91        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
92        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
93        // let t1 = Instant::now();
94        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
95        // println!("Sorting took {}ms", t0.elapsed().as_millis());
96        return indices;
97    }
98}
99
100impl <T: Clone, const N: usize> Data<T> for [T; N] {
101    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
102        return self.iter().cloned();
103    }
104    fn get_at(&self, index: usize) -> T {
105        return self[index].clone();
106    }
107}
108
109impl <const N: usize> SortableData<f64> for [f64; N] {
110    fn argsort_unstable(&self) -> Vec<usize> {
111        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
112        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
113        return indices;
114    }
115}
116
117impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
118    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
119        return self.iter().cloned();
120    }
121    fn get_at(&self, index: usize) -> T {
122        return self[index].clone();
123    }
124}
125
126impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
127    fn argsort_unstable(&self) -> Vec<usize> {
128        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
129        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
130        return indices;
131    }
132}
133
134// struct IndexView<'a, T, D> where T: Clone, D: Data<T>{
135//     data: &'a D,
136//     indices: &'a Vec<usize>,
137// }
138
139// impl <'a, T: Clone> Data<T> for IndexView<'a, T> {
140// }
141
/// A sample label that can be read as "positive" or "negative".
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label marks a positive sample.
    fn get_value(&self) -> bool;
}
145
146impl BinaryLabel for bool {
147    fn get_value(&self) -> bool {
148        return self.clone();
149    }
150}
151
152impl BinaryLabel for u8 {
153    fn get_value(&self) -> bool {
154        return (self & 1) == 1;
155    }
156}
157
158impl BinaryLabel for u16 {
159    fn get_value(&self) -> bool {
160        return (self & 1) == 1;
161    }
162}
163
164impl BinaryLabel for u32 {
165    fn get_value(&self) -> bool {
166        return (self & 1) == 1;
167    }
168}
169
170impl BinaryLabel for u64 {
171    fn get_value(&self) -> bool {
172        return (self & 1) == 1;
173    }
174}
175
176impl BinaryLabel for i8 {
177    fn get_value(&self) -> bool {
178        return (self & 1) == 1;
179    }
180}
181
182impl BinaryLabel for i16 {
183    fn get_value(&self) -> bool {
184        return (self & 1) == 1;
185    }
186}
187
188impl BinaryLabel for i32 {
189    fn get_value(&self) -> bool {
190        return (self & 1) == 1;
191    }
192}
193
194impl BinaryLabel for i64 {
195    fn get_value(&self) -> bool {
196        return (self & 1) == 1;
197    }
198}
199
200fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
201where T: Copy, I: Data<T>
202{
203    let mut selection: Vec<T> = Vec::new();
204    selection.reserve_exact(indices.len());
205    for index in indices {
206        selection.push(slice.get_at(*index));
207    }
208    return selection;
209}
210
211pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
212where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
213{
214    return average_precision_with_order(labels, predictions, weights, None);
215}
216
217pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
218where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
219{
220    return match order {
221        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
222        None => {
223            let indices = predictions.argsort_unstable();
224            let sorted_labels = select(labels, &indices);
225            let ap = match weights {
226                None => {
227                    // let w: Oepion<&
228                    average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
229                },
230                Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
231            };
232            ap
233        }
234    };
235}
236
237pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
238where B: BinaryLabel, L: Data<B>, W: Data<f64>
239{
240    return match weights {
241        None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
242        Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
243    };
244}
245
246pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
247where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
248{
249    return match order {
250        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
251        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
252    };
253}
254
255pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
256    let mut ap: f64 = 0.0;
257    let mut tps: f64 = 0.0;
258    let mut fps: f64 = 0.0;
259    for (label, weight) in labels.zip(weights) {
260        let w: f64 = weight;
261        let l: bool = label.get_value();
262        let tp = w * f64::from(l);
263        tps += tp;
264        fps += weight - tp;
265        let ps = tps + fps;
266        let precision = tps / ps;
267        ap += tp * precision;
268    }
269    return ap / tps;
270}
271
272
273
274// ROC AUC score
275pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
276where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
277{
278    return roc_auc_with_order(labels, predictions, weights, None, None);
279}
280
281pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
282where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
283{
284    return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
285}
286
287pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
288where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
289{
290    return match order {
291        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
292        None => {
293            let indices = predictions.argsort_unstable();
294            let sorted_labels = select(labels, &indices);
295            let sorted_predictions = select(predictions, &indices);
296            let roc_auc_score = match weights {
297                Some(w) => {
298                    let sorted_weights = select(w, &indices);
299                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
300                },
301                None => {
302                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
303                }
304            };
305            roc_auc_score
306        }
307    };
308}
309pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
310where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
311    return match max_false_positive_rate {
312        None => match weights {
313            Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
314            None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
315        }
316        Some(max_fpr) => match weights {
317            Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
318            None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
319        }
320    };
321}
322
323pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
324    labels: &mut impl DoubleEndedIterator<Item = B>,
325    predictions: &mut impl DoubleEndedIterator<Item = f64>,
326    weights: &mut impl DoubleEndedIterator<Item = f64>,
327    order: Order
328) -> f64 {
329    return match order {
330        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
331        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
332    }
333}
334
335pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
336    labels: &mut impl Iterator<Item = B>,
337    predictions: &mut impl Iterator<Item = f64>,
338    weights: &mut impl Iterator<Item = f64>
339) -> f64 {
340    let mut false_positives: f64 = 0.0;
341    let mut true_positives: f64 = 0.0;
342    let mut last_counted_fp = 0.0;
343    let mut last_counted_tp = 0.0;
344    let mut area_under_curve = 0.0;
345    let mut zipped = labels.zip(predictions).zip(weights).peekable();
346    loop {
347        match zipped.next() {
348            None => break,
349            Some(actual) => {
350                let l = f64::from(actual.0.0.get_value());
351                let w = actual.1;
352                let wl = l * w;
353                true_positives += wl;
354                false_positives += w - wl;
355                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
356                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
357                    last_counted_fp = false_positives;
358                    last_counted_tp = true_positives;
359                }
360            }
361        };
362    }
363    return area_under_curve / (true_positives * false_positives);
364}
365
/// Area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`.
///
/// This is the rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    width * y0 + rise * width * 0.5
}
371
372fn get_positive_sum<B: BinaryLabel>(
373    labels: impl Iterator<Item = B>,
374    weights: impl Iterator<Item = f64>
375) -> (f64, f64) {
376    let mut false_positives = 0f64;
377    let mut true_positives = 0f64;
378    for (label, weight) in labels.zip(weights) {
379        let lw = weight * f64::from(label.get_value());
380        false_positives += weight - lw;
381        true_positives += lw;
382    }
383    return (false_positives, true_positives);
384}
385
386pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
387where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
388    // TODO validate max_fpr
389    let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
390    let mut l_it = labels.get_iterator();
391    let mut p_it = predictions.get_iterator();
392    let mut w_it = weights.get_iterator();
393    return match order {
394        Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
395        Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
396    };
397}
398    
399
400fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
401    labels: &mut impl Iterator<Item = B>,
402    predictions: &mut impl Iterator<Item = f64>,
403    weights: &mut impl Iterator<Item = f64>,
404    false_positive_sum: f64,
405    true_positive_sum: f64,
406    max_false_positive_rate: f64
407) -> f64 {
408    let mut false_positives: f64 = 0.0;
409    let mut true_positives: f64 = 0.0;
410    let mut last_counted_fp = 0.0;
411    let mut last_counted_tp = 0.0;
412    let mut area_under_curve = 0.0;
413    let mut zipped = labels.zip(predictions).zip(weights).peekable();
414    let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
415    loop {
416        match zipped.next() {
417            None => break,
418            Some(actual) => {
419                let l = f64::from(actual.0.0.get_value());
420                let w = actual.1;
421                let wl = l * w;
422                let next_tp = true_positives + wl;
423                let next_fp = false_positives + (w - wl);
424                let is_above_max = next_fp > false_positive_cutoff;
425                if is_above_max {
426                    let dx = next_fp  - false_positives;
427                    let dy = next_tp - true_positives;
428                    true_positives += dy * false_positive_cutoff / dx;
429                    false_positives = false_positive_cutoff;
430                } else {
431                    true_positives = next_tp;
432                    false_positives = next_fp;
433                }
434                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
435                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
436                    last_counted_fp = false_positives;
437                    last_counted_tp = true_positives;
438                }
439                if is_above_max {
440                    break;
441                }                
442            }
443        };
444    }
445    let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
446    let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
447    let max_area = max_false_positive_rate;
448    return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
449}
450
451
452// Python bindings
453#[pyclass(eq, eq_int, name="Order")]
454#[derive(Clone, Copy, PartialEq)]
455pub enum PyOrder {
456    ASCENDING,
457    DESCENDING
458}
459
460fn py_order_as_order(order: PyOrder) -> Order {
461    return match order {
462        PyOrder::ASCENDING => Order::ASCENDING,
463        PyOrder::DESCENDING => Order::DESCENDING,
464    }
465}
466
467fn average_precision_py_generic<'py, B>(
468    py: Python<'py>,
469    labels: &PyReadonlyArray1<'py, B>,
470    predictions: &PyReadonlyArray1<'py, f64>,
471    weights: &Option<PyReadonlyArray1<'py, f64>>,
472    order: &Option<PyOrder>
473) -> f64
474where B: BinaryLabel + Element
475{
476    let labels = labels.as_array();
477    let predictions = predictions.as_array();
478    let order = order.map(py_order_as_order);
479    let ap = match weights {
480        Some(w) => {
481            let weights = w.as_array();
482            py.allow_threads(move || {
483                average_precision_with_order(&labels, &predictions, Some(&weights), order)
484            })
485        },
486        None => {
487            py.allow_threads(move || {
488                average_precision_with_order(&labels, &predictions, None::<&ArrayView<'_, f64, Ix1>>, order)
489            })
490        }
491    };
492    return ap;
493}
494
495fn average_precision_py_match_run<'py, T>(
496    py: Python<'py>,
497    labels: &Bound<'py, PyUntypedArray>,
498    predictions: &PyReadonlyArray1<'py, f64>,
499    weights: &Option<PyReadonlyArray1<'py, f64>>,
500    order: &Option<PyOrder>,
501    dt: &Bound<'py, PyArrayDescr>
502) -> Option<f64>
503where T: Element + BinaryLabel
504{
505    return if dt.is_equiv_to(&dtype::<T>(py)) {
506        let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
507        Some(average_precision_py_generic(py, &labels.readonly(), predictions, weights, order))
508    } else {
509        None
510    }
511}
512
513#[pyfunction(name = "average_precision")]
514#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
515pub fn average_precision_py<'py>(
516    py: Python<'py>,
517    labels: &Bound<'py, PyUntypedArray>,
518    predictions: PyReadonlyArray1<'py, f64>,
519    weights: Option<PyReadonlyArray1<'py, f64>>,
520    order: Option<PyOrder>
521) -> PyResult<f64> {
522    if labels.ndim() != 1 {
523        return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
524    }
525    let label_dtype = labels.dtype();
526    if let Some(ap) = average_precision_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
527        return Ok(ap)
528    }
529    else if let Some(ap) = average_precision_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
530        return Ok(ap)
531    }
532    else if let Some(ap) = average_precision_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
533        return Ok(ap)
534    }
535    else if let Some(ap) = average_precision_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
536        return Ok(ap)
537    }
538    else if let Some(ap) = average_precision_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
539        return Ok(ap)
540    }
541    else if let Some(ap) = average_precision_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
542        return Ok(ap)
543    }
544    else if let Some(ap) = average_precision_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
545        return Ok(ap)
546    }
547    else if let Some(ap) = average_precision_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
548        return Ok(ap)
549    }
550    else if let Some(ap) = average_precision_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
551        return Ok(ap)
552    }
553    return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
554}
555
556fn roc_auc_py_generic<'py, B>(
557    py: Python<'py>,
558    labels: &PyReadonlyArray1<'py, B>,
559    predictions: &PyReadonlyArray1<'py, f64>,
560    weights: &Option<PyReadonlyArray1<'py, f64>>,
561    order: &Option<PyOrder>,
562    max_false_positive_rate: Option<f64>,
563) -> f64
564where B: BinaryLabel + Element
565{
566    let labels = labels.as_array();
567    let predictions = predictions.as_array();
568    let order = order.map(py_order_as_order);
569    let auc = match weights {
570        Some(weight) => {
571            let weights = weight.as_array();
572            py.allow_threads(move || {
573                roc_auc_with_order(&labels, &predictions, Some(&weights), order, max_false_positive_rate)
574            })
575        },
576        None => py.allow_threads(move || {
577            roc_auc_with_order(&labels, &predictions, None::<&Vec<f64>>, order, max_false_positive_rate)
578        })
579    };
580    return auc;
581}
582
583fn roc_auc_py_match_run<'py, T>(
584    py: Python<'py>,
585    labels: &Bound<'py, PyUntypedArray>,
586    predictions: &PyReadonlyArray1<'py, f64>,
587    weights: &Option<PyReadonlyArray1<'py, f64>>,
588    order: &Option<PyOrder>,
589    max_false_positive_rate: Option<f64>,
590    dt: &Bound<'py, PyArrayDescr>
591) -> Option<f64>
592where T: Element + BinaryLabel
593{
594    return if dt.is_equiv_to(&dtype::<T>(py)) {
595        let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
596        Some(roc_auc_py_generic(py, &labels.readonly(), predictions, weights, order, max_false_positive_rate))
597    } else {
598        None
599    }
600}
601
602#[pyfunction(name = "roc_auc")]
603#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_false_positive_rate=None))]
604pub fn roc_auc_py<'py>(
605    py: Python<'py>,
606    labels: &Bound<'py, PyUntypedArray>,
607    predictions: PyReadonlyArray1<'py, f64>,
608    weights: Option<PyReadonlyArray1<'py, f64>>,
609    order: Option<PyOrder>,
610    max_false_positive_rate: Option<f64>,
611) -> PyResult<f64> {
612    if labels.ndim() != 1 {
613        return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
614    }
615    let label_dtype = labels.dtype();
616    if let Some(auc) = roc_auc_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
617        return Ok(auc)
618    }
619    else if let Some(auc) = roc_auc_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
620        return Ok(auc)
621    }
622    else if let Some(auc) = roc_auc_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
623        return Ok(auc)
624    }
625    else if let Some(auc) = roc_auc_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
626        return Ok(auc)
627    }
628    else if let Some(auc) = roc_auc_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
629        return Ok(auc)
630    }
631    else if let Some(auc) = roc_auc_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
632        return Ok(auc)
633    }
634    else if let Some(auc) = roc_auc_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
635        return Ok(auc)
636    }
637    else if let Some(auc) = roc_auc_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
638        return Ok(auc)
639    }
640    else if let Some(auc) = roc_auc_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
641        return Ok(auc)
642    }
643    return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
644}
645
646#[pymodule(name = "_scors")]
647fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
648    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
649    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
650    m.add_class::<PyOrder>().unwrap();
651    return Ok(());
652}
653
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Average precision of labels `[1, 0, 1, 0]` when ranked by descending prediction.
    const EXPECTED_AP: f64 = 0.8333333333333333;
    /// Uniform weights shared by all tests.
    const UNIT_WEIGHTS: [f64; 4] = [1.0, 1.0, 1.0, 1.0];

    #[test]
    fn test_average_precision_on_sorted() {
        // Labels are already in descending prediction order; no predictions needed.
        let sorted_labels: [u8; 4] = [1, 0, 1, 0];
        let score = average_precision_on_sorted_labels(&sorted_labels, Some(&UNIT_WEIGHTS), Order::DESCENDING);
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_unsorted() {
        // Same samples as the sorted tests, given in a different order.
        let shuffled_labels: [u8; 4] = [0, 0, 1, 1];
        let shuffled_predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let score = average_precision_with_order(&shuffled_labels, &shuffled_predictions, Some(&UNIT_WEIGHTS), None);
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted() {
        let sorted_labels: [u8; 4] = [1, 0, 1, 0];
        let sorted_predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let score = average_precision_with_order(
            &sorted_labels,
            &sorted_predictions,
            Some(&UNIT_WEIGHTS),
            Some(Order::DESCENDING),
        );
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_roc_auc() {
        let sorted_labels: [u8; 4] = [1, 0, 1, 0];
        let sorted_predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let score = roc_auc_with_order(
            &sorted_labels,
            &sorted_predictions,
            Some(&UNIT_WEIGHTS),
            Some(Order::DESCENDING),
            None,
        );
        assert_eq!(score, 0.75);
    }
}
692        assert_eq!(actual, 0.75);
693    }
694}