_scors/
lib.rs

1#![feature(trait_alias)]
2
3mod combine;
4
5use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
6use num;
7use num::traits::float::TotalOrder;
8use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescrMethods,PyArrayDyn,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
9use pyo3::Bound;
10use pyo3::exceptions::PyTypeError;
11use pyo3::marker::Ungil;
12use pyo3::prelude::*;
13use std::cmp::PartialOrd;
14use std::iter::{DoubleEndedIterator,repeat};
15use std::ops::AddAssign;
16
/// Sort order of the prediction scores handed to the scoring functions.
///
/// `Order::DESCENDING` means the highest prediction comes first and the data is scored as-is;
/// `Order::ASCENDING` means the data is traversed back to front before scoring.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    ASCENDING,
    DESCENDING
}
22
23#[derive(Clone, Copy)]
24struct ConstWeight<F: num::Float> {
25    value: F
26}
27
28impl <F: num::Float> ConstWeight<F> {
29    fn new(value: F) -> Self {
30        return ConstWeight { value: value };
31    }
32    fn one() -> Self {
33        return Self::new(F::one());
34    }
35}
36
/// Read-only access to a sequence of values, either by iteration or by index.
///
/// The standard library has no single trait that offers both a cloneable, double-ended
/// iterator and random access, so this crate defines its own.
pub trait Data<T: Clone> {
    /// Iterates over all values front to back (`.rev()` walks them back to front).
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone;

    /// Returns the value stored at position `index`.
    fn get_at(&self, index: usize) -> T;
}

/// Data that can report the permutation that orders its values.
pub trait SortableData<T> {
    /// Returns the indices of the values ordered from largest to smallest.
    ///
    /// Every implementation in this module sorts in *descending* order (ties in unspecified
    /// order); `score_maybe_sorted_sample` relies on this when it passes `Order::DESCENDING`.
    fn argsort_unstable(&self) -> Vec<usize>;
}
47
48impl <F: num::Float> Iterator for ConstWeight<F> {
49    type Item = F;
50    fn next(&mut self) -> Option<F> {
51        return Some(self.value);
52    }
53}
54
55impl <F: num::Float> DoubleEndedIterator for ConstWeight<F> {
56    fn next_back(&mut self) -> Option<F> {
57        return Some(self.value);
58    }
59}
60
61impl <F: num::Float> Data<F> for ConstWeight<F> {
62    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = F> + Clone {
63        return ConstWeight::new(self.value);
64    }
65
66    fn get_at(&self, _index: usize) -> F {
67        return self.value.clone();
68    }
69}
70
71impl <T: Clone> Data<T> for Vec<T> {
72    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
73        return self.iter().cloned();
74    }
75    fn get_at(&self, index: usize) -> T {
76        return self[index].clone();
77    }
78}
79
80impl SortableData<f64> for Vec<f64> {
81    fn argsort_unstable(&self) -> Vec<usize> {
82        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
83        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
84        // indices.sort_unstable_by_key(|i| self[*i]);
85        return indices;
86    }
87}
88
89impl <T: Clone> Data<T> for &[T] {
90    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
91        return self.iter().cloned();
92    }
93    fn get_at(&self, index: usize) -> T {
94        return self[index].clone();
95    }
96}
97
98impl SortableData<f64> for &[f64] {
99    fn argsort_unstable(&self) -> Vec<usize> {
100        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
101        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
102        return indices;
103    }
104}
105
106impl <T: Clone, const N: usize> Data<T> for [T; N] {
107    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
108        return self.iter().cloned();
109    }
110    fn get_at(&self, index: usize) -> T {
111        return self[index].clone();
112    }
113}
114
115impl <const N: usize> SortableData<f64> for [f64; N] {
116    fn argsort_unstable(&self) -> Vec<usize> {
117        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
118        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
119        return indices;
120    }
121}
122
123impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
124    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
125        return self.iter().cloned();
126    }
127    fn get_at(&self, index: usize) -> T {
128        return self[index].clone();
129    }
130}
131
132impl <F> SortableData<F> for ArrayView<'_, F, Ix1>
133where F: num::Float + TotalOrder
134{
135    fn argsort_unstable(&self) -> Vec<usize> {
136        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
137        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
138        return indices;
139    }
140}
141
/// A value that can be interpreted as a binary (positive/negative) label.
pub trait BinaryLabel: Clone + Copy {
    /// `true` if the label denotes the positive class.
    fn get_value(&self) -> bool;
}

impl BinaryLabel for bool {
    /// A boolean label is its own value.
    fn get_value(&self) -> bool {
        *self
    }
}
151
152impl BinaryLabel for u8 {
153    fn get_value(&self) -> bool {
154        return (self & 1u8) == 1u8;
155    }
156}
157
158impl BinaryLabel for u16 {
159    fn get_value(&self) -> bool {
160        return (self & 1u16) == 1u16;
161    }
162}
163
164impl BinaryLabel for u32 {
165    fn get_value(&self) -> bool {
166        return (self & 1u32) == 1u32;
167    }
168}
169
170impl BinaryLabel for u64 {
171    fn get_value(&self) -> bool {
172        return (self & 1u64) == 1u64;
173    }
174}
175
176impl BinaryLabel for i8 {
177    fn get_value(&self) -> bool {
178        return (self & 1i8) == 1i8;
179    }
180}
181
182impl BinaryLabel for i16 {
183    fn get_value(&self) -> bool {
184        return (self & 1i16) == 1i16;
185    }
186}
187
188impl BinaryLabel for i32 {
189    fn get_value(&self) -> bool {
190        return (self & 1i32) == 1i32;
191    }
192}
193
194impl BinaryLabel for i64 {
195    fn get_value(&self) -> bool {
196        return (self & 1i64) == 1i64;
197    }
198}
199
200fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
201where T: Copy, I: Data<T>
202{
203    let mut selection: Vec<T> = Vec::new();
204    selection.reserve_exact(indices.len());
205    for index in indices {
206        selection.push(slice.get_at(*index));
207    }
208    return selection;
209}
210
211pub trait ScoreAccumulator = num::Float + AddAssign + From<bool> + From<f32>;
212pub trait IntoScore<S: ScoreAccumulator> =  Into<S> + num::Float;
213
214
215
216pub trait ScoreSortedDescending {
217    fn _score<S: ScoreAccumulator>(&self, labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S;
218    fn score<S, P, B, W>(&self, labels_with_weights: impl Iterator<Item = (P, (B, W))> + Clone) -> S
219    where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S>
220    {
221        return self._score(
222            labels_with_weights.map(|(p, (b, w))| -> (S, (bool, S)) { (p.into(), (b.get_value(), w.into()))})
223        )
224    }
225}
226
227
228pub fn score_sorted_iterators<S, SA, P, B, W>(
229    score: S,
230    predictions: impl Iterator<Item = P> + Clone,
231    labels: impl Iterator<Item = B> + Clone,
232    weights: impl Iterator<Item = W> + Clone,
233) -> SA
234where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> {
235    let zipped = predictions.zip(labels.zip(weights));
236    return score.score(zipped);
237}
238
239
240pub fn score_sorted_sample<S, SA, P, B, W>(
241    score: S,
242    predictions: &impl Data<P>,
243    labels: &impl Data<B>,
244    weights: &impl Data<W>,
245    order: Order,
246) -> SA
247where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone {
248    let p = predictions.get_iterator();
249    let l = labels.get_iterator();
250    let w = weights.get_iterator();
251    return match order {
252        Order::ASCENDING => score_sorted_iterators(score, p.rev(), l.rev(), w.rev()),
253        Order::DESCENDING => score_sorted_iterators(score, p, l, w),
254    };
255}
256
257
258pub fn score_maybe_sorted_sample<S, SA, P, B, W>(
259    score: S,
260    predictions: &(impl Data<P> + SortableData<P>),
261    labels: &impl Data<B>,
262    weights: Option<&impl Data<W>>,
263    order: Option<Order>,
264) -> SA
265where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone
266{
267    return match order {
268        Some(o) => {
269            match weights {
270                Some(w) => score_sorted_sample(score, predictions, labels, w, o),
271                None => score_sorted_sample(score, predictions, labels, &ConstWeight::<W>::one(), o),
272            }
273        }
274        None => {
275            let indices = predictions.argsort_unstable();
276            let sorted_labels = select(labels, &indices);
277            let sorted_predictions = select(predictions, &indices);
278            match weights {
279                Some(w) => {
280                    let sorted_weights = select(w, &indices);
281                    score_sorted_sample(score, &sorted_predictions, &sorted_labels, &sorted_weights, Order::DESCENDING)
282                }
283                None => score_sorted_sample(score, &sorted_predictions, &sorted_labels, &ConstWeight::<W>::one(), Order::DESCENDING)
284            }
285        }
286    };
287}
288
289
290pub fn score_sample<S, SA, P, B, W>(
291    score: S,
292    predictions: &(impl Data<P> + SortableData<P>),
293    labels: &impl Data<B>,
294    weights: Option<&impl Data<W>>,
295) -> SA
296
297where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone {
298    return score_maybe_sorted_sample(score, predictions, labels, weights, None);
299}
300
301
302pub fn score_two_sorted_samples<S, SA, P, B, W>(
303    score: S,
304    predictions1: impl Iterator<Item = P> + Clone,
305    label1: impl Iterator<Item = B> + Clone,
306    weight1: impl Iterator<Item = W> + Clone,
307    predictions2: impl Iterator<Item = P> + Clone,
308    label2: impl Iterator<Item = B> + Clone,
309    weight2: impl Iterator<Item = W> + Clone,
310) -> SA
311where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel + PartialOrd, W: IntoScore<SA>
312{
313    return score_two_sorted_samples_zipped(
314        score,
315        predictions1.zip(label1.zip(weight1)),
316        predictions2.zip(label2.zip(weight2)),
317    );
318}
319
320
321pub fn score_two_sorted_samples_zipped<S, SA, P, B, W>(
322    score: S,
323    iter1: impl Iterator<Item = (P, (B, W))> + Clone,
324    iter2: impl Iterator<Item = (P, (B, W))> + Clone,
325) -> SA
326where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel + PartialOrd, W: IntoScore<SA>
327{
328    let combined_iter = combine::combine::CombineIterDescending::new(iter1, iter2);
329    return score.score(combined_iter);
330}
331
332
/// Average precision metric; see its `ScoreSortedDescending` implementation.
struct AveragePrecision {}

impl AveragePrecision {
    fn new() -> Self {
        Self {}
    }
}
343
344
345#[derive(Clone,Copy,Debug)]
346struct Positives<P>
347where P: num::Float + From<bool> + AddAssign
348{
349    tps: P,
350    fps: P,
351}
352
353impl <P> Positives<P>
354where P: num::Float + From<bool> + AddAssign
355{
356    fn new(tps: P, fps: P) -> Self {
357        return Positives { tps, fps };
358    }
359
360    fn zero() -> Self {
361        return Positives::new(P::zero(), P::zero());
362    }
363
364    fn add(&mut self, label: bool, weight: P) {
365        let label: P = label.into();
366        let tp = weight * label;
367        let fp = weight - tp;  // (weight*(1 -label) = weight - weight * label = weight - tp)
368        self.tps += tp;
369        self.fps += fp;
370    }
371
372    fn positives_sum(&self) -> P {
373        return self.tps + self.fps;
374    }
375
376    fn precision(&self) -> P {
377        return self.tps / self.positives_sum();
378    }
379}
380
381
382impl ScoreSortedDescending for AveragePrecision {
383    fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
384    {
385        let mut positives: Positives<S> = Positives::zero();
386        let mut last_p: S = f32::NAN.into();
387        let mut last_tps: S = S::zero();
388        let mut ap: S = S::zero();
389
390        // TODO can we unify this preparation step with the loop?
391        match labels_with_weights.next() {
392            None => (), // TODO: Sohuld we return an error in this case?
393            Some((p, (label, w))) => {
394                positives.add(label, w);
395                last_p = p;
396            }
397        }
398        
399        for (p, (label, w)) in labels_with_weights {
400            if last_p != p {
401                ap += (positives.tps - last_tps) * positives.precision();
402                last_p = p;
403                last_tps = positives.tps;
404            }
405            positives.add(label.get_value(), w.into());
406        }
407
408        ap += (positives.tps - last_tps) * positives.precision();
409        
410        // Special case for tps == 0 following sklearn
411        // https://github.com/scikit-learn/scikit-learn/blob/5cce87176a530d2abea45b5a7e5a4d837c481749/sklearn/metrics/_ranking.py#L1032-L1039
412        // I.e. if tps is 0.0, there are no positive samples in labels: Either all labels are 0, or all weights (for positive labels) are 0
413        return if positives.tps == S::zero() {
414            S::zero()
415        } else {
416            ap / positives.tps
417        };
418    }
419}
420
421
/// Full ROC AUC metric; see its `ScoreSortedDescending` implementation.
struct RocAuc {}

impl RocAuc {
    fn new() -> Self {
        Self {}
    }
}
432
433
434impl ScoreSortedDescending for RocAuc {
435    fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
436    {
437        let mut positives: Positives<S> = Positives::zero();
438        let mut last_p: S = f32::NAN.into();
439        let mut last_counted_fp = S::zero();
440        let mut last_counted_tp = S::zero();
441        let mut area_under_curve = S::zero();
442
443        // TODO can we unify this preparation step with the loop?
444        match labels_with_weights.next() {
445            None => (), // TODO: Should we return an error in this case?
446            Some((p, (label, w))) => {
447                positives.add(label, w);
448                last_p = p;
449            }
450        }
451        
452        for (p, (label, w)) in labels_with_weights {
453            if last_p != p {
454                area_under_curve += area_under_line_segment(
455                    last_counted_fp,
456                    positives.fps,
457                    last_counted_tp,
458                    positives.tps,
459                );
460                last_counted_fp = positives.fps;
461                last_counted_tp = positives.tps;
462                last_p = p;
463            }
464            positives.add(label, w);
465        }
466        area_under_curve += area_under_line_segment(
467            last_counted_fp,
468            positives.fps,
469            last_counted_tp,
470            positives.tps,
471        );
472        return area_under_curve / (positives.tps * positives.fps);
473    }
474}
475
476
477struct RocAucWithMaxFPR {
478    max_fpr: f32,
479}
480
481
482impl RocAucWithMaxFPR {
483    fn new(max_fpr: f32) -> Self {
484        return RocAucWithMaxFPR { max_fpr };
485    }
486
487    fn get_positive_sum<B, W>(labels_with_weights: impl Iterator<Item = (B, W)>) -> Positives<W>
488    where B: BinaryLabel, W: num::Float + From::<bool> + AddAssign
489    {
490        let mut positives: Positives<W>  = Positives::zero();
491        for (label, weight) in labels_with_weights {
492            positives.add(label.get_value(), weight);
493        }
494        return positives;
495    }
496}
497
498
499impl ScoreSortedDescending for RocAucWithMaxFPR {
500    fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
501    {
502        let total_positives = Self::get_positive_sum(labels_with_weights.clone().map(|(_a, b)| b));
503        let max_fpr: S = self.max_fpr.into();
504        let false_positive_cutoff = max_fpr * total_positives.fps;
505
506        let mut positives: Positives<S> = Positives::zero();
507        let mut last_p: S = f32::NAN.into();
508        let mut last_counted_fp = S::zero();
509        let mut last_counted_tp = S::zero();
510        let mut area_under_curve = S::zero();
511
512        // TODO can we unify this preparation step with the loop?
513        match labels_with_weights.next() {
514            None => (), // TODO: Should we return an error in this case?
515            Some((p, (label, w))) => {
516                positives.add(label, w);
517                last_p = p;
518            }
519        }
520        
521        for (p, (label, w)) in labels_with_weights {
522            if last_p != p {
523                area_under_curve += area_under_line_segment(
524                    last_counted_fp,
525                    positives.fps,
526                    last_counted_tp,
527                    positives.tps,
528                );
529                last_counted_fp = positives.fps;
530                last_counted_tp = positives.tps;
531                last_p = p;
532            }
533            let mut next_pos = positives.clone();
534            next_pos.add(label, w);
535            if next_pos.fps > false_positive_cutoff {
536                let dx = next_pos.fps - positives.fps;
537                let dy = next_pos.tps - positives.tps;
538                positives = Positives::new(
539                    positives.tps + dy * false_positive_cutoff / dx,
540                    false_positive_cutoff,
541                );
542                break;
543            }
544            else {
545                positives = next_pos;
546            }
547        }
548
549        area_under_curve += area_under_line_segment(
550            last_counted_fp,
551            positives.fps,
552            last_counted_tp,
553            positives.tps,
554        );
555        
556        let normalized_area_under_curve = area_under_curve / (total_positives.tps * total_positives.fps);
557        let one_half: S = 0.5f32.into(); 
558        let min_area = one_half * max_fpr * max_fpr;
559        let max_area = max_fpr;
560        return one_half * (S::one() + (normalized_area_under_curve - min_area) / (max_area - min_area));
561    }
562}
563
564
/// ROC AUC that is full when `max_fpr` is `None`, and partial up to `max_fpr` otherwise.
struct RocAucWithOptionalMaxFPR {
    // TODO: Can we have a single implementation for this and RocAuc?
    //       This would add an unnecessary check to RocAuc, but the performance
    //       penalty may be negligible.
    max_fpr: Option<f32>,
}

impl RocAucWithOptionalMaxFPR {
    fn new(max_fpr: Option<f32>) -> Self {
        Self { max_fpr }
    }
}
577
578
579impl ScoreSortedDescending for RocAucWithOptionalMaxFPR {
580    fn _score<S: ScoreAccumulator>(&self, labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
581    {
582        return match self.max_fpr {
583            Some(mfpr) => RocAucWithMaxFPR::new(mfpr).score(labels_with_weights),
584            None => RocAuc::new().score(labels_with_weights),
585        }
586    }
587}
588
589
590pub fn average_precision<S, P, B, W>(
591    predictions: &(impl Data<P> + SortableData<P>),
592    labels: &impl Data<B>,
593    weights: Option<&impl Data<W>>,
594    order: Option<Order>,
595) -> S
596where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S> + Clone
597{
598    return score_maybe_sorted_sample(AveragePrecision::new(), predictions, labels, weights, order);
599}
600
601
602pub fn roc_auc<S, P, B, W>(
603    predictions: &(impl Data<P> + SortableData<P>),
604    labels: &impl Data<B>,
605    weights: Option<&impl Data<W>>,
606    order: Option<Order>,
607    max_fpr: Option<f32>,
608) -> S
609where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S> + Clone
610{
611    return score_maybe_sorted_sample(RocAucWithOptionalMaxFPR::new(max_fpr), predictions, labels, weights, order);
612}
613
614
615fn area_under_line_segment<P>(x0: P, x1: P, y0: P, y1: P) -> P
616where P: num::Float + From<f32>
617{
618    let dx = x1 - x0;
619    let dy = y1 - y0;
620    let one_half: P = 0.5f32.into();
621    return dx * y0 + dy * dx * one_half;
622}
623
624
625pub fn loo_cossim<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>, replicate_sum: &mut ArrayViewMut1<'_, F>) -> F {
626    let num_replicates = mat.shape()[0];
627    let loo_weight = F::from(num_replicates - 1).unwrap();
628    let loo_weight_factor = F::from(1).unwrap() / loo_weight;
629    for mat_replicate in mat.outer_iter() {
630        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
631            *feature_sum += *feature;
632        }
633    }
634
635    let mut result = F::zero();
636
637    for mat_replicate in mat.outer_iter() {
638        let mut m_sqs = F::zero();
639        let mut l_sqs = F::zero();
640        let mut prod_sum = F::zero();
641        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
642            let m_f = *feature;
643            let l_f = (*feature_sum - *feature) * loo_weight_factor;
644            prod_sum += m_f * l_f;
645            m_sqs += m_f * m_f;
646            l_sqs += l_f * l_f;
647        }
648        result += prod_sum / (m_sqs * l_sqs).sqrt();
649    }
650
651    return result / F::from(num_replicates).unwrap();
652}
653
654
655pub fn loo_cossim_single<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>) -> F {
656    let mut replicate_sum = Array1::<F>::zeros(mat.shape()[1]);
657    return loo_cossim(mat, &mut replicate_sum.view_mut());
658}
659
660
661pub fn loo_cossim_many<F: num::Float + AddAssign>(mat: &ArrayView3<'_, F>) -> Array1<F> {
662    let mut cossims = Array1::<F>::zeros(mat.shape()[0]);
663    let mut replicate_sum = Array1::<F>::zeros(mat.shape()[2]);
664    for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
665        replicate_sum.fill(F::zero());
666        *c = loo_cossim(&m, &mut replicate_sum.view_mut());
667    }
668    return cossims;
669}
670
671
672// Python bindings
673#[pyclass(eq, eq_int, name="Order")]
674#[derive(Clone, Copy, PartialEq)]
675pub enum PyOrder {
676    ASCENDING,
677    DESCENDING
678}
679
680fn py_order_as_order(order: PyOrder) -> Order {
681    return match order {
682        PyOrder::ASCENDING => Order::ASCENDING,
683        PyOrder::DESCENDING => Order::DESCENDING,
684    }
685}
686
687trait PyScoreGeneric<S: ScoreSortedDescending>: Ungil + Sync {
688
689    fn get_score(&self) -> S;
690
691    fn score_py<'py, P, B, W>(
692        &self,
693        py: Python<'py>,
694        labels: PyReadonlyArray1<'py, B>,
695        predictions: PyReadonlyArray1<'py, P>,
696        weights: Option<PyReadonlyArray1<'py, W>>,
697        order: Option<PyOrder>,
698    ) -> P
699    where P: ScoreAccumulator + Element + TotalOrder, B: BinaryLabel + Element, W: IntoScore<P> + Element
700    {
701        let labels = labels.as_array();
702        let predictions = predictions.as_array();
703        let order = order.map(py_order_as_order);
704        let score = match weights {
705            Some(weight) => {
706                let w = weight.as_array();
707                py.allow_threads(move || {
708                    score_maybe_sorted_sample(self.get_score(), &predictions, &labels, Some(&w), order)
709                })
710            },
711            None => py.allow_threads(move || {
712                score_maybe_sorted_sample(self.get_score(), &predictions, &labels, None::<&Vec<W>>, order)
713            })
714        };
715        return score;
716    }
717
718    fn score_two_sorted_samples_py_generic<'py, B, F, W, B1, B2, F1, F2, W1, W2>(
719        &self,
720        py: Python<'py>,
721        labels1: PyReadonlyArray1<'py, B1>,
722        predictions1: PyReadonlyArray1<'py, F1>,
723        weights1: Option<PyReadonlyArray1<'py, W1>>,
724        labels2: PyReadonlyArray1<'py, B1>,
725        predictions2: PyReadonlyArray1<'py, F2>,
726        weights2: Option<PyReadonlyArray1<'py, W2>>,
727    ) -> F
728    where B: BinaryLabel + PartialOrd, F: ScoreAccumulator + TotalOrder + Ungil, W: IntoScore<F>, B1: Element + Into<B> + Clone, B2: Element + Into<B> + Clone, F1: Element + Into<F> + Clone, F2: Element + Into<F> + Clone, W1: Element + Into<W> + Clone, W2: Element + Into<W> + Clone
729    {
730        let l1 = labels1.as_array().into_iter().cloned().map(|l| -> B { l.into() });
731        let l2 = labels2.as_array().into_iter().cloned().map(|l| -> B { l.into() });
732        let p1 = predictions1.as_array().into_iter().cloned().map(|f| -> F { f.into() });
733        let p2 = predictions2.as_array().into_iter().cloned().map(|f| -> F { f.into() });
734
735
736        return match (weights1, weights2) {
737            (None, None) => {
738                py.allow_threads(move || {
739                    score_two_sorted_samples(self.get_score(), p1, l1, repeat(W::one()), p2, l2, repeat(W::one()))
740                })
741            }
742            (Some(w1), None) => {
743                let w1i = w1.as_array().into_iter().cloned().map(|w| -> W { w.into() });
744                py.allow_threads(move || {
745                    score_two_sorted_samples(self.get_score(), p1, l1, w1i, p2, l2, repeat(W::one()))
746                })
747            }
748            (None, Some(w2)) => {
749                let w2i = w2.as_array().into_iter().cloned().map(|w| -> W { w.into() });
750                py.allow_threads(move || {
751                    score_two_sorted_samples(self.get_score(), p1, l1, repeat(W::one()), p2, l2, w2i)
752                })
753            }
754            (Some(w1), Some(w2)) =>  {
755                let w1i = w1.as_array().into_iter().cloned().map(|w| -> W { w.into() });
756                let w2i = w2.as_array().into_iter().cloned().map(|w| -> W { w.into() });
757                py.allow_threads(move || {
758                    score_two_sorted_samples(self.get_score(), p1, l1, w1i, p2, l2, w2i)
759                })
760            }
761        };
762    }
763}
764
765struct AveragePrecisionPyGeneric {
766
767}
768
769impl AveragePrecisionPyGeneric {
770    fn new() -> Self {
771        return AveragePrecisionPyGeneric {};
772    }
773}
774
775impl PyScoreGeneric<AveragePrecision> for AveragePrecisionPyGeneric {
776    fn get_score(&self) -> AveragePrecision {
777        return AveragePrecision::new();
778    }
779}
780
781struct RocAucPyGeneric {
782    max_fpr: Option<f32>,
783}
784
785impl RocAucPyGeneric {
786    fn new(max_fpr: Option<f32>) -> Self {
787        return RocAucPyGeneric { max_fpr: max_fpr };
788    }
789}
790
791impl PyScoreGeneric<RocAucWithOptionalMaxFPR> for RocAucPyGeneric {
792    fn get_score(&self) -> RocAucWithOptionalMaxFPR {
793        return RocAucWithOptionalMaxFPR::new(self.max_fpr);
794    }
795}
796
// https://stackoverflow.com/questions/70128978/how-to-define-different-function-names-with-a-macro
// https://stackoverflow.com/questions/70872059/using-a-rust-macro-to-generate-a-function-with-variable-parameters
// https://doc.rust-lang.org/rust-by-example/macros/designators.html
// https://users.rust-lang.org/t/is-there-a-way-to-convert-given-identifier-to-a-string-in-a-macro/42907
//
// Defines a `#[pyfunction]` `$fname`, exported to Python as `$pyname`, that computes average
// precision for the given label / prediction / weight element types.
// The second form additionally registers the function with `$py_module`. It must be expanded
// inside a function returning `PyResult`, because registration errors are propagated with `?`.
macro_rules! average_precision_py {
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels: PyReadonlyArray1<'py, $label_type>,
            predictions: PyReadonlyArray1<'py, $prediction_type>,
            weights: Option<PyReadonlyArray1<'py, $weight_type>>,
            order: Option<PyOrder>
        ) -> $prediction_type
        {
            return AveragePrecisionPyGeneric::new().score_py(py, labels, predictions, weights, order);
        }
    };
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty, $py_module:ident) => {
        average_precision_py!($fname, $pyname, $label_type, $prediction_type, $weight_type);
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
821
822
// Defines a `#[pyfunction]` `$fname`, exported to Python as `$pyname`, that computes (optionally
// partial) ROC AUC for the given label / prediction / weight element types.
// The second form additionally registers the function with `$py_module`. It must be expanded
// inside a function returning `PyResult`, because registration errors are propagated with `?`.
macro_rules! roc_auc_py {
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels: PyReadonlyArray1<'py, $label_type>,
            predictions: PyReadonlyArray1<'py, $prediction_type>,
            weights: Option<PyReadonlyArray1<'py, $weight_type>>,
            order: Option<PyOrder>,
            max_fpr: Option<f32>,
        ) -> $prediction_type
        {
            return RocAucPyGeneric::new(max_fpr).score_py(py, labels, predictions, weights, order);
        }
    };
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type: ty, $py_module:ident) => {
        roc_auc_py!($fname, $pyname, $label_type, $prediction_type, $weight_type);
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
844
845
// Defines a `#[pyfunction]` `$fname`, exported to Python as `$pyname`, that computes average
// precision over the union of two samples, each already sorted by descending prediction.
// `$lt`/`$pt`/`$wt` are the common label/prediction/weight types; `$lt1`.. and `$lt2`.. are the
// element types of the first and second sample.
// The second form additionally registers the function with `$py_module`. It must be expanded
// inside a function returning `PyResult`, because registration errors are propagated with `?`.
macro_rules! average_precision_on_two_sorted_samples_py {
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels1, predictions1, weights1, labels2, predictions2, weights2, *))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels1: PyReadonlyArray1<'py, $lt1>,
            predictions1: PyReadonlyArray1<'py, $pt1>,
            weights1: Option<PyReadonlyArray1<'py, $wt1>>,
            labels2: PyReadonlyArray1<'py, $lt2>,
            predictions2: PyReadonlyArray1<'py, $pt2>,
            weights2: Option<PyReadonlyArray1<'py, $wt2>>,
        ) -> $pt
        {
            return AveragePrecisionPyGeneric::new().score_two_sorted_samples_py_generic::<$lt, $pt, $wt, $lt1, $lt2, $pt1, $pt2, $wt1, $wt2>(py, labels1, predictions1, weights1, labels2, predictions2, weights2);
        }
    };
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty, $py_module:ident) => {
        average_precision_on_two_sorted_samples_py!($fname, $pyname, $lt, $pt, $wt, $lt1, $pt1, $wt1, $lt2, $pt2, $wt2);
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
868
869
// Defines a `#[pyfunction]` `$fname`, exported to Python as `$pyname`, that computes (optionally
// partial) ROC AUC over the union of two samples, each already sorted by descending prediction.
// `$lt`/`$pt`/`$wt` are the common label/prediction/weight types; `$lt1`.. and `$lt2`.. are the
// element types of the first and second sample.
// The second form additionally registers the function with `$py_module`. It must be expanded
// inside a function returning `PyResult`, because registration errors are propagated with `?`.
macro_rules! roc_auc_on_two_sorted_samples_py {
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels1, predictions1, weights1, labels2, predictions2, weights2, *, max_fpr=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels1: PyReadonlyArray1<'py, $lt1>,
            predictions1: PyReadonlyArray1<'py, $pt1>,
            weights1: Option<PyReadonlyArray1<'py, $wt1>>,
            labels2: PyReadonlyArray1<'py, $lt2>,
            predictions2: PyReadonlyArray1<'py, $pt2>,
            weights2: Option<PyReadonlyArray1<'py, $wt2>>,
            max_fpr: Option<f32>,
        ) -> $pt
        {
            return RocAucPyGeneric::new(max_fpr).score_two_sorted_samples_py_generic::<$lt, $pt, $wt, $lt1, $lt2, $pt1, $pt2, $wt1, $wt2>(py, labels1, predictions1, weights1, labels2, predictions2, weights2);
        }
    };
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty, $py_module:ident) => {
        roc_auc_on_two_sorted_samples_py!($fname, $pyname, $lt, $pt, $wt, $lt1, $pt1, $wt1, $lt2, $pt2, $wt2);
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
893
894
895#[pyfunction(name = "loo_cossim")]
896#[pyo3(signature = (data))]
897pub fn loo_cossim_py<'py>(
898    py: Python<'py>,
899    data: &Bound<'py, PyUntypedArray>
900) -> PyResult<f64> {
901    if data.ndim() != 2 {
902        return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
903    }
904
905    let dt = data.dtype();
906    if dt.is_equiv_to(&dtype::<f32>(py)) {
907        let typed_data = data.downcast::<PyArray2<f32>>().unwrap().readonly();
908        let array = typed_data.as_array();
909        let score = py.allow_threads(move || {
910            loo_cossim_single(&array)
911        });
912        return Ok(score as f64);
913    }
914    if dt.is_equiv_to(&dtype::<f64>(py)) {
915        let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
916        let array = typed_data.as_array();
917        let score = py.allow_threads(move || {
918            loo_cossim_single(&array)
919        });
920        return Ok(score);
921    }
922    return Err(PyTypeError::new_err(format!("Only float32 and float64 data supported, but found {}", dt)));
923}
924
925pub fn loo_cossim_many_generic_py<'py, F: num::Float + AddAssign + Element>(
926    py: Python<'py>,
927    data: &Bound<'py, PyArrayDyn<F>>
928) -> PyResult<Bound<'py, PyArray1<F>>> {
929    if data.ndim() != 3 {
930        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
931    }
932    let typed_data = data.downcast::<PyArray3<F>>().unwrap().readonly();
933    let array = typed_data.as_array();
934    let score = py.allow_threads(move || {
935        loo_cossim_many(&array)
936    });
937    // TODO how can we return this generically without making a copy at the end?
938    let score_py = PyArray::from_owned_array(py, score);
939    return Ok(score_py);
940}
941
942#[pyfunction(name = "loo_cossim_many_f64")]
943#[pyo3(signature = (data))]
944pub fn loo_cossim_many_py_f64<'py>(
945    py: Python<'py>,
946    data: &Bound<'py, PyUntypedArray>
947) -> PyResult<Bound<'py, PyArray1<f64>>> {
948    if data.ndim() != 3 {
949        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
950    }
951
952    let dt = data.dtype();
953    if !dt.is_equiv_to(&dtype::<f64>(py)) {
954        return Err(PyTypeError::new_err(format!("Only float64 data supported, but found {}", dt)));
955    }
956    let typed_data = data.downcast::<PyArrayDyn<f64>>().unwrap();
957    return loo_cossim_many_generic_py(py, typed_data);
958}
959
960#[pyfunction(name = "loo_cossim_many_f32")]
961#[pyo3(signature = (data))]
962pub fn loo_cossim_many_py_f32<'py>(
963    py: Python<'py>,
964    data: &Bound<'py, PyUntypedArray>
965) -> PyResult<Bound<'py, PyArray1<f32>>> {
966    if data.ndim() != 3 {
967        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
968    }
969
970    let dt = data.dtype();
971    if !dt.is_equiv_to(&dtype::<f32>(py)) {
972        return Err(PyTypeError::new_err(format!("Only float32 data supported, but found {}", dt)));
973    }
974    let typed_data = data.downcast::<PyArrayDyn<f32>>().unwrap();
975    return loo_cossim_many_generic_py(py, typed_data);
976}
977
978#[pymodule(name = "_scors")]
979fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
980    average_precision_py!(average_precision_bool_f32, "average_precision_bool_f32", bool, f32, f32, m);
981    average_precision_py!(average_precision_i8_f32, "average_precision_i8_f32", i8, f32, f32, m);
982    average_precision_py!(average_precision_i16_f32, "average_precision_i16_f32", i16, f32, f32, m);
983    average_precision_py!(average_precision_i32_f32, "average_precision_i32_f32", i32, f32, f32, m);
984    average_precision_py!(average_precision_i64_f32, "average_precision_i64_f32", i64, f32, f32, m);
985    average_precision_py!(average_precision_u8_f32, "average_precision_u8_f32", u8, f32, f32, m);
986    average_precision_py!(average_precision_u16_f32, "average_precision_u16_f32", u16, f32, f32, m);
987    average_precision_py!(average_precision_u32_f32, "average_precision_u32_f32", u32, f32, f32, m);
988    average_precision_py!(average_precision_u64_f32, "average_precision_u64_f32", u64, f32, f32, m);
989    average_precision_py!(average_precision_bool_f64, "average_precision_bool_f64", bool, f64, f64, m);
990    average_precision_py!(average_precision_i8_f64, "average_precision_i8_f64", i8, f64, f64, m);
991    average_precision_py!(average_precision_i16_f64, "average_precision_i16_f64", i16, f64, f64, m);
992    average_precision_py!(average_precision_i32_f64, "average_precision_i32_f64", i32, f64, f64, m);
993    average_precision_py!(average_precision_i64_f64, "average_precision_i64_f64", i64, f64, f64, m);
994    average_precision_py!(average_precision_u8_f64, "average_precision_u8_f64", u8, f64, f64, m);
995    average_precision_py!(average_precision_u16_f64, "average_precision_u16_f64", u16, f64, f64, m);
996    average_precision_py!(average_precision_u32_f64, "average_precision_u32_f64", u32, f64, f64, m);
997    average_precision_py!(average_precision_u64_f64, "average_precision_u64_f64", u64, f64, f64, m);
998
999    roc_auc_py!(roc_auc_bool_f32, "roc_auc_bool_f32", bool, f32, f32, m);
1000    roc_auc_py!(roc_auc_i8_f32, "roc_auc_i8_f32", i8, f32, f32, m);
1001    roc_auc_py!(roc_auc_i16_f32, "roc_auc_i16_f32", i16, f32, f32, m);
1002    roc_auc_py!(roc_auc_i32_f32, "roc_auc_i32_f32", i32, f32, f32, m);
1003    roc_auc_py!(roc_auc_i64_f32, "roc_auc_i64_f32", i64, f32, f32, m);
1004    roc_auc_py!(roc_auc_u8_f32, "roc_auc_u8_f32", u8, f32, f32, m);
1005    roc_auc_py!(roc_auc_u16_f32, "roc_auc_u16_f32", u16, f32, f32, m);
1006    roc_auc_py!(roc_auc_u32_f32, "roc_auc_u32_f32", u32, f32, f32, m);
1007    roc_auc_py!(roc_auc_u64_f32, "roc_auc_u64_f32", u64, f32, f32, m);
1008    roc_auc_py!(roc_auc_bool_f64, "roc_auc_bool_f64", bool, f64, f64, m);
1009    roc_auc_py!(roc_auc_i8_f64, "roc_auc_i8_f64", i8, f64, f64, m);
1010    roc_auc_py!(roc_auc_i16_f64, "roc_auc_i16_f64", i16, f64, f64, m);
1011    roc_auc_py!(roc_auc_i32_f64, "roc_auc_i32_f64", i32, f64, f64, m);
1012    roc_auc_py!(roc_auc_i64_f64, "roc_auc_i64_f64", i64, f64, f64, m);
1013    roc_auc_py!(roc_auc_u8_f64, "roc_auc_u8_f64", u8, f64, f64, m);
1014    roc_auc_py!(roc_auc_u16_f64, "roc_auc_u16_f64", u16, f64, f64, m);
1015    roc_auc_py!(roc_auc_u32_f64, "roc_auc_u32_f64", u32, f64, f64, m);
1016    roc_auc_py!(roc_auc_u64_f64, "roc_auc_u64_f64", u64, f64, f64, m);
1017
1018    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_bool_f32, "average_precision_on_two_sorted_samples_bool_f32", bool, f32, f32, bool, f32, f32, bool, f32, f32, m);
1019    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i8_f32, "average_precision_on_two_sorted_samples_i8_f32", i8, f32, f32, i8, f32, f32, i8, f32, f32, m);
1020    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i16_f32, "average_precision_on_two_sorted_samples_i16_f32", i16, f32, f32, i16, f32, f32, i16, f32, f32, m);
1021    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i32_f32, "average_precision_on_two_sorted_samples_i32_f32", i32, f32, f32, i32, f32, f32, i32, f32, f32, m);
1022    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i64_f32, "average_precision_on_two_sorted_samples_i64_f32", i64, f32, f32, i64, f32, f32, i64, f32, f32, m);
1023    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u8_f32, "average_precision_on_two_sorted_samples_u8_f32", u8, f32, f32, u8, f32, f32, u8, f32, f32, m);
1024    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u16_f32, "average_precision_on_two_sorted_samples_u16_f32", u16, f32, f32, u16, f32, f32, u16, f32, f32, m);
1025    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u32_f32, "average_precision_on_two_sorted_samples_u32_f32", u32, f32, f32, u32, f32, f32, u32, f32, f32, m);
1026    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u64_f32, "average_precision_on_two_sorted_samples_u64_f32", u64, f32, f32, u64, f32, f32, u64, f32, f32, m);
1027    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_bool_f64, "average_precision_on_two_sorted_samples_bool_f64", bool, f64, f64, bool, f64, f64, bool, f64, f64, m);
1028    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i8_f64, "average_precision_on_two_sorted_samples_i8_f64", i8, f64, f64, i8, f64, f64, i8, f64, f64, m);
1029    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i16_f64, "average_precision_on_two_sorted_samples_i16_f64", i16, f64, f64, i16, f64, f64, i16, f64, f64, m);
1030    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i32_f64, "average_precision_on_two_sorted_samples_i32_f64", i32, f64, f64, i16, f64, f64, i16, f64, f64, m);
1031    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i64_f64, "average_precision_on_two_sorted_samples_i64_f64", i64, f64, f64, i64, f64, f64, i64, f64, f64, m);
1032    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u8_f64, "average_precision_on_two_sorted_samples_u8_f64", u8, f64, f64, u8, f64, f64, u8, f64, f64, m);
1033    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u16_f64, "average_precision_on_two_sorted_samples_u16_f64", u16, f64, f64, u16, f64, f64, u16, f64, f64, m);
1034    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u32_f64, "average_precision_on_two_sorted_samples_u32_f64", u32, f64, f64, u32, f64, f64, u32, f64, f64, m);
1035    average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u64_f64, "average_precision_on_two_sorted_samples_u64_f64", u64, f64, f64, u64, f64, f64, u64, f64, f64, m);
1036
1037    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_bool_f32, "roc_auc_on_two_sorted_samples_bool_f32", bool, f32, f32, bool, f32, f32, bool, f32, f32, m);
1038    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i8_f32, "roc_auc_on_two_sorted_samples_i8_f32", i8, f32, f32, i8, f32, f32, i8, f32, f32, m);
1039    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i16_f32, "roc_auc_on_two_sorted_samples_i16_f32", i16, f32, f32, i16, f32, f32, i16, f32, f32, m);
1040    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i32_f32, "roc_auc_on_two_sorted_samples_i32_f32", i32, f32, f32, i32, f32, f32, i32, f32, f32, m);
1041    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i64_f32, "roc_auc_on_two_sorted_samples_i64_f32", i64, f32, f32, i64, f32, f32, i64, f32, f32, m);
1042    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u8_f32, "roc_auc_on_two_sorted_samples_u8_f32", u8, f32, f32, u8, f32, f32, u8, f32, f32, m);
1043    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u16_f32, "roc_auc_on_two_sorted_samples_u16_f32", u16, f32, f32, u16, f32, f32, u16, f32, f32, m);
1044    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u32_f32, "roc_auc_on_two_sorted_samples_u32_f32", u32, f32, f32, u32, f32, f32, u32, f32, f32, m);
1045    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u64_f32, "roc_auc_on_two_sorted_samples_u64_f32", u64, f32, f32, u64, f32, f32, u64, f32, f32, m);
1046    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_bool_f64, "roc_auc_on_two_sorted_samples_bool_f64", bool, f64, f64, bool, f64, f64, bool, f64, f64, m);
1047    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i8_f64, "roc_auc_on_two_sorted_samples_i8_f64", i8, f64, f64, i8, f64, f64, i8, f64, f64, m);
1048    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i16_f64, "roc_auc_on_two_sorted_samples_i16_f64", i16, f64, f64, i16, f64, f64, i16, f64, f64, m);
1049    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i32_f64, "roc_auc_on_two_sorted_samples_i32_f64", i32, f64, f64, i16, f64, f64, i16, f64, f64, m);
1050    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i64_f64, "roc_auc_on_two_sorted_samples_i64_f64", i64, f64, f64, i64, f64, f64, i64, f64, f64, m);
1051    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u8_f64, "roc_auc_on_two_sorted_samples_u8_f64", u8, f64, f64, u8, f64, f64, u8, f64, f64, m);
1052    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u16_f64, "roc_auc_on_two_sorted_samples_u16_f64", u16, f64, f64, u16, f64, f64, u16, f64, f64, m);
1053    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u32_f64, "roc_auc_on_two_sorted_samples_u32_f64", u32, f64, f64, u32, f64, f64, u32, f64, f64, m);
1054    roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u64_f64, "roc_auc_on_two_sorted_samples_u64_f64", u64, f64, f64, u64, f64, f64, u64, f64, f64, m);
1055
1056    m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
1057    m.add_function(wrap_pyfunction!(loo_cossim_many_py_f64, m)?).unwrap();
1058    m.add_function(wrap_pyfunction!(loo_cossim_many_py_f32, m)?).unwrap();
1059    m.add_class::<PyOrder>().unwrap();
1060    return Ok(());
1061}
1062
1063
#[cfg(test)]
mod tests {
    use super::*;

    // Four-sample fixture, already ordered by descending prediction.
    const LABELS: [u8; 4] = [1, 0, 1, 0];
    const PREDICTIONS: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
    const WEIGHTS: [f64; 4] = [1.0; 4];

    // The same fixture concatenated with itself (eight samples, not sorted).
    const LABELS_TWICE: [u8; 8] = [1, 0, 1, 0, 1, 0, 1, 0];
    const PREDICTIONS_TWICE: [f64; 8] = [0.8, 0.4, 0.35, 0.1, 0.8, 0.4, 0.35, 0.1];

    // Reference scores for the fixture.
    const EXPECTED_AP: f64 = 0.8333333333333333;
    const EXPECTED_ROC_AUC: f64 = 0.75;
    const EXPECTED_ROC_AUC_MAX_FPR: f64 = 0.7142857142857143;
    const MAX_FPR: f64 = 0.25;

    #[test]
    fn test_average_precision_on_sorted() {
        let score: f64 = score_sorted_sample(AveragePrecision::new(), &PREDICTIONS, &LABELS, &WEIGHTS, Order::DESCENDING);
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_on_sorted_double() {
        // Every sample duplicated in place: AP must not change.
        let doubled_labels: [u8; 8] = [1, 1, 0, 0, 1, 1, 0, 0];
        let doubled_predictions: [f64; 8] = [0.8, 0.8, 0.4, 0.4, 0.35, 0.35, 0.1, 0.1];
        let doubled_weights: [f64; 8] = [1.0; 8];
        let score: f64 = score_sorted_sample(AveragePrecision::new(), &doubled_predictions, &doubled_labels, &doubled_weights, Order::DESCENDING);
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_unsorted() {
        let shuffled_labels: [u8; 4] = [0, 0, 1, 1];
        let shuffled_predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let score: f64 = average_precision(&shuffled_predictions, &shuffled_labels, Some(&WEIGHTS), None);
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted() {
        let score: f64 = average_precision(&PREDICTIONS, &LABELS, Some(&WEIGHTS), Some(Order::DESCENDING));
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        // The fixture paired with itself behaves like the doubled sample.
        let score: f64 = score_two_sorted_samples(
            AveragePrecision::new(),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned(),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned()
        );
        assert_eq!(score, EXPECTED_AP);
    }

    #[test]
    fn test_roc_auc() {
        let score: f64 = roc_auc(&PREDICTIONS, &LABELS, Some(&WEIGHTS), Some(Order::DESCENDING), None);
        assert_eq!(score, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_auc_double() {
        let score: f64 = roc_auc(&PREDICTIONS_TWICE, &LABELS_TWICE, None::<&[f64; 8]>, None, None);
        assert_eq!(score, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_sorted_pair() {
        let score: f64 = score_two_sorted_samples(
            RocAuc::new(),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned(),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned()
        );
        assert_eq!(score, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_auc_max_fpr() {
        let score: f64 = roc_auc(&PREDICTIONS, &LABELS, Some(&WEIGHTS), Some(Order::DESCENDING), Some(MAX_FPR));
        assert_eq!(score, EXPECTED_ROC_AUC_MAX_FPR);
    }

    #[test]
    fn test_roc_auc_max_fpr_double() {
        let score: f64 = roc_auc(&PREDICTIONS_TWICE, &LABELS_TWICE, None::<&[f64; 8]>, None, Some(MAX_FPR));
        assert_eq!(score, EXPECTED_ROC_AUC_MAX_FPR);
    }

    #[test]
    fn test_roc_auc_max_fpr_sorted_pair() {
        let score: f64 = score_two_sorted_samples(
            RocAucWithMaxFPR::new(MAX_FPR),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned(),
            PREDICTIONS.iter().cloned(),
            LABELS.iter().cloned(),
            WEIGHTS.iter().cloned()
        );
        assert_eq!(score, EXPECTED_ROC_AUC_MAX_FPR);
    }
}