_scors/
lib.rs

1use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
2use num;
3use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescr,PyArrayDescrMethods,PyArrayDyn,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
4use pyo3::Bound;
5use pyo3::exceptions::PyTypeError;
6use pyo3::marker::Ungil;
7use pyo3::prelude::*;
8use std::iter::DoubleEndedIterator;
9use std::marker::PhantomData;
10use std::ops::AddAssign;
11
/// Direction in which a sample is already sorted by prediction score.
///
/// The scoring routines in this file work on descending data; an
/// `ASCENDING` input is consumed back-to-front instead.
#[derive(Clone, Copy)]
pub enum Order {
    /// Scores grow from the first to the last element.
    ASCENDING,
    /// Scores shrink from the first to the last element.
    DESCENDING,
}
17
/// Weight source that reports the same value for every sample.
struct ConstWeight {
    /// The weight handed out for each sample.
    value: f64,
}
21
22impl ConstWeight {
23    fn new(value: f64) -> Self {
24        return ConstWeight { value: value };
25    }
26    fn one() -> Self {
27        return Self::new(1.0);
28    }
29}
30
// TODO The standard library does not seem to offer a trait of this shape
//      (maybe it exists and was overlooked), so a crate-local one is used.
/// Read-only access to a sequence of `T`, by iteration and by position.
pub trait Data<T: Clone> {
    /// Returns a double-ended iterator that yields the elements by value.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the element stored at `index` by value.
    fn get_at(&self, index: usize) -> T;
}
37
/// A sequence whose elements can be ranked without moving them.
pub trait SortableData<T> {
    /// Returns a permutation of `0..len` that sorts the elements.
    ///
    /// The implementations in this file rank `f64` values from largest to
    /// smallest using `f64::total_cmp`; the relative order of equal values is
    /// unspecified.
    fn argsort_unstable(&self) -> Vec<usize>;
}
41
42impl Iterator for ConstWeight {
43    type Item = f64;
44    fn next(&mut self) -> Option<f64> {
45        return Some(self.value);
46    }
47}
48
49impl DoubleEndedIterator for ConstWeight {
50    fn next_back(&mut self) -> Option<f64> {
51        return Some(self.value);
52    }
53}
54
55impl Data<f64> for ConstWeight {
56    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
57        return ConstWeight::new(self.value);
58    }
59
60    fn get_at(&self, _index: usize) -> f64 {
61        return self.value.clone();
62    }
63}
64
65impl <T: Clone> Data<T> for Vec<T> {
66    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
67        return self.iter().cloned();
68    }
69    fn get_at(&self, index: usize) -> T {
70        return self[index].clone();
71    }
72}
73
74impl SortableData<f64> for Vec<f64> {
75    fn argsort_unstable(&self) -> Vec<usize> {
76        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
77        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
78        // indices.sort_unstable_by_key(|i| self[*i]);
79        return indices;
80    }
81}
82
83impl <T: Clone> Data<T> for &[T] {
84    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
85        return self.iter().cloned();
86    }
87    fn get_at(&self, index: usize) -> T {
88        return self[index].clone();
89    }
90}
91
92impl SortableData<f64> for &[f64] {
93    fn argsort_unstable(&self) -> Vec<usize> {
94        // let t0 = Instant::now();
95        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
96        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
97        // let t1 = Instant::now();
98        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
99        // println!("Sorting took {}ms", t0.elapsed().as_millis());
100        return indices;
101    }
102}
103
104impl <T: Clone, const N: usize> Data<T> for [T; N] {
105    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
106        return self.iter().cloned();
107    }
108    fn get_at(&self, index: usize) -> T {
109        return self[index].clone();
110    }
111}
112
113impl <const N: usize> SortableData<f64> for [f64; N] {
114    fn argsort_unstable(&self) -> Vec<usize> {
115        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
116        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
117        return indices;
118    }
119}
120
121impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
122    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
123        return self.iter().cloned();
124    }
125    fn get_at(&self, index: usize) -> T {
126        return self[index].clone();
127    }
128}
129
130impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
131    fn argsort_unstable(&self) -> Vec<usize> {
132        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
133        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
134        return indices;
135    }
136}
137
138// struct IndexView<'a, T, D> where T: Clone, D: Data<T>{
139//     data: &'a D,
140//     indices: &'a Vec<usize>,
141// }
142
143// impl <'a, T: Clone> Data<T> for IndexView<'a, T> {
144// }
145
/// A label type that can be interpreted as positive (`true`) or negative (`false`).
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label marks a positive sample.
    fn get_value(&self) -> bool;
}
149
150impl BinaryLabel for bool {
151    fn get_value(&self) -> bool {
152        return self.clone();
153    }
154}
155
156impl BinaryLabel for u8 {
157    fn get_value(&self) -> bool {
158        return (self & 1) == 1;
159    }
160}
161
162impl BinaryLabel for u16 {
163    fn get_value(&self) -> bool {
164        return (self & 1) == 1;
165    }
166}
167
168impl BinaryLabel for u32 {
169    fn get_value(&self) -> bool {
170        return (self & 1) == 1;
171    }
172}
173
174impl BinaryLabel for u64 {
175    fn get_value(&self) -> bool {
176        return (self & 1) == 1;
177    }
178}
179
180impl BinaryLabel for i8 {
181    fn get_value(&self) -> bool {
182        return (self & 1) == 1;
183    }
184}
185
186impl BinaryLabel for i16 {
187    fn get_value(&self) -> bool {
188        return (self & 1) == 1;
189    }
190}
191
192impl BinaryLabel for i32 {
193    fn get_value(&self) -> bool {
194        return (self & 1) == 1;
195    }
196}
197
198impl BinaryLabel for i64 {
199    fn get_value(&self) -> bool {
200        return (self & 1) == 1;
201    }
202}
203
204struct SortedSampleDescending<'a, B, L, P, W>
205where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
206{
207    labels: &'a L,
208    predictions: &'a P,
209    weights: &'a W,
210    label_type: PhantomData<B>,
211}
212
213impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
214where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
215{
216    fn new(labels: &'a L, predictions: &'a P, weights: &'a W) -> Self {
217        return SortedSampleDescending {
218            labels: labels,
219            predictions: predictions,
220            weights: weights,
221            label_type: PhantomData
222        }
223    }
224}
225
226// struct SortedSampleIterator<'a, B, L, P, W>
227// where B: BinaryLabel + Clone + 'a, L: Iterator<Item = &'a B>, P: Iterator<Item = &'a f64>, W: Iterator<Item = &' f64>
228// {
229//     labels: &'a mut L,
230//     predictions: &'a mut P,
231//     weights: &'a mut W,
232// }
233
234
235// impl <'a, B, L, P, W> Iterator for SortedSampleIterator<'a, B, L, P, W>
236// where B: BinaryLabel, L: Iterator<Item = B>, P: Iterator<Item = f64>, W: Iterator<Item = f64>
237// {
238//     type Item = (B, f64, f64);
239//     fn next(&mut self) -> Option<(B, f64, f64)> {
240//         return match (self.labels.next(), self.predictions.next(), self.weights.next()) {
241//             (Some(l), Some(p), Some(w)) => Some((l, p, w)),
242//             _ => None
243//         };
244//     }
245// }
246
247// impl <'a, B, L, P, W> IntoIterator for &SortedSampleDescending<'a, B, L, P, W>
248// where B: BinaryLabel, L: IntoIterator<Item = B>, P: IntoIterator<Item = f64>, W: IntoIterator<Item = f64> {
249//     type Item = (B, f64, f64);
250//     type IntoIter = SortedSampleIterator<'a, B, L::IntoIter, P::IntoIter, W::IntoIter>;
251
252//     fn into_iter(self) -> Self::IntoIter {
253//         let l = self.labels.into_iter();
254//         // return SortedSampleIterator {
255//         //     labels: self.labels.into_iter(),
256//         //     predictions: self.labels.into_iter(),
257//         //     weights: self.weights.into_iter(),
258//         // }
259//     }
260// }
261
262// struct CombineView<'a, B, L, P, W> 
263// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
264// {
265//     sample1: &'a SortedSampleDescending<'a, B, L, P, W>,
266//     sample2: &'a SortedSampleDescending<'a, B, L, P, W>,
267// }
268
269// impl <'a, B, L, P, W> CombineView<'a, B, L, P, W>
270// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
271// {
272//     fn new(sample1: &'a SortedSampleDescending<'a, B, L, P, W>, sample2: &'a SortedSampleDescending<'a, B, L, P, W>,) -> Self {
273//         return PairedSortedSampleDescending { sample1: sample1, sample2: sample2 };
274//     }
275// }
276
277// struct CombineViewIterator<'a, B, L, P, W>
278// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
279// {
280//     view: &'a CombineView<'a, B, L, P, W>,
281//     iterator1: impl Iterator<Item: ((B, f64), f64),
282//     iterator2: impl Iterator<Item: ((B, f64), f64),
283//     current_index: usize,
284//     num_elements: usize
285// }
286
287// impl <'a, B, L, P, W> CombineViewIterator<'a, B, L, P, W>
288// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
289// {
290//     fn new(view: &'a CombineView<'a, B, L, P, W>) -> Self {
291//         return CombineViewIerator {
292//             view: view,
293//             iterator1: view.sample1.iterator(),
294//             iterator2: view.sample2.iterator(),
295//             current_index: 0,
296//             num_elements: usize
297//         };
298//     }
299// }
300
301
302// impl <'a, B, L, P, W> Iterator for CombineViewIterator<'a, B, L, P, W>
303// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
304// {
305//     type Item = (B, f64, f64);
306//     fn next(&mut self) -> Option<(B, f64, f64)> {
307//         if self.current_index == self.num_elements {
308//             return None;
309//         }
310        
311//         return Some(self.value);
312//     }
313// }
314
315fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
316where T: Copy, I: Data<T>
317{
318    let mut selection: Vec<T> = Vec::new();
319    selection.reserve_exact(indices.len());
320    for index in indices {
321        selection.push(slice.get_at(*index));
322    }
323    return selection;
324}
325
326pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
327where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
328{
329    return average_precision_with_order(labels, predictions, weights, None);
330}
331
332pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
333where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
334{
335    return match order {
336        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
337        None => {
338            let indices = predictions.argsort_unstable();
339            let sorted_labels = select(labels, &indices);
340            let ap = match weights {
341                None => {
342                    // let w: Oepion<&
343                    average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
344                },
345                Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
346            };
347            ap
348        }
349    };
350}
351
352pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
353where B: BinaryLabel, L: Data<B>, W: Data<f64>
354{
355    return match weights {
356        None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
357        Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
358    };
359}
360
361pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
362where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
363{
364    return match order {
365        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
366        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
367    };
368}
369
370pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
371    return average_precision_on_descending_iterators(labels.zip(weights));
372}
373
374// impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
375// where B: BinaryLabel, L: IntoIterator<Item = B>, P: IntoIterator<Item = f64>, W: IntoIterator<Item = f64>
376
377pub fn average_precision_on_sorted_samples<'a, B, L, P, W>(l1: &'a L, p1: &'a P, w1: &'a W, l2: &'a L, p2: &'a P, w2: &'a W) -> f64
378where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
379{
380    // let mut it1 = p1.into_iter();
381    let i1 = p1.into_iter().cloned().zip(l1.into_iter().cloned().zip(w1.into_iter().cloned()));
382    let i2 = p2.into_iter().cloned().zip(l2.into_iter().cloned().zip(w2.into_iter().cloned()));
383    let labels_and_weights = i1.zip(i2).map(|(t1, t2)| {
384        if t1.0 > t2.0 {
385            t1.1
386        } else {
387            t2.1
388        }
389    });
390    return average_precision_on_descending_iterators(labels_and_weights);
391}
392
393pub fn average_precision_on_descending_iterators<B: BinaryLabel>(labels_and_weights: impl Iterator<Item = (B, f64)>) -> f64 {
394    let mut ap: f64 = 0.0;
395    let mut tps: f64 = 0.0;
396    let mut fps: f64 = 0.0;
397    for (label, weight) in labels_and_weights {
398        let w: f64 = weight;
399        let l: bool = label.get_value();
400        let tp = w * f64::from(l);
401        tps += tp;
402        fps += weight - tp;
403        let ps = tps + fps;
404        let precision = tps / ps;
405        ap += tp * precision;
406    }
407    // Special case for tps == 0 following sklearn
408    // https://github.com/scikit-learn/scikit-learn/blob/5cce87176a530d2abea45b5a7e5a4d837c481749/sklearn/metrics/_ranking.py#L1032-L1039
409    // I.e. if tps is 0.0, there are no positive samples in labels: Either all labels are 0, or all weights (for positive labels) are 0
410    return if tps == 0.0 {
411        0.0
412    } else {
413        ap / tps
414    };
415}
416
417
418
419// ROC AUC score
420pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
421where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
422{
423    return roc_auc_with_order(labels, predictions, weights, None, None);
424}
425
426pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
427where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
428{
429    return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
430}
431
432pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
433where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
434{
435    return match order {
436        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
437        None => {
438            let indices = predictions.argsort_unstable();
439            let sorted_labels = select(labels, &indices);
440            let sorted_predictions = select(predictions, &indices);
441            let roc_auc_score = match weights {
442                Some(w) => {
443                    let sorted_weights = select(w, &indices);
444                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
445                },
446                None => {
447                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
448                }
449            };
450            roc_auc_score
451        }
452    };
453}
454pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
455where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
456    return match max_false_positive_rate {
457        None => match weights {
458            Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
459            None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
460        }
461        Some(max_fpr) => match weights {
462            Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
463            None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
464        }
465    };
466}
467
468pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
469    labels: &mut impl DoubleEndedIterator<Item = B>,
470    predictions: &mut impl DoubleEndedIterator<Item = f64>,
471    weights: &mut impl DoubleEndedIterator<Item = f64>,
472    order: Order
473) -> f64 {
474    return match order {
475        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
476        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
477    }
478}
479
480pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
481    labels: &mut impl Iterator<Item = B>,
482    predictions: &mut impl Iterator<Item = f64>,
483    weights: &mut impl Iterator<Item = f64>
484) -> f64 {
485    let mut false_positives: f64 = 0.0;
486    let mut true_positives: f64 = 0.0;
487    let mut last_counted_fp = 0.0;
488    let mut last_counted_tp = 0.0;
489    let mut area_under_curve = 0.0;
490    let mut zipped = labels.zip(predictions).zip(weights).peekable();
491    loop {
492        match zipped.next() {
493            None => break,
494            Some(actual) => {
495                let l = f64::from(actual.0.0.get_value());
496                let w = actual.1;
497                let wl = l * w;
498                true_positives += wl;
499                false_positives += w - wl;
500                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
501                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
502                    last_counted_fp = false_positives;
503                    last_counted_tp = true_positives;
504                }
505            }
506        };
507    }
508    return area_under_curve / (true_positives * false_positives);
509}
510
/// Area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`:
/// a rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    width * y0 + rise * width * 0.5
}
516
517fn get_positive_sum<B: BinaryLabel>(
518    labels: impl Iterator<Item = B>,
519    weights: impl Iterator<Item = f64>
520) -> (f64, f64) {
521    let mut false_positives = 0f64;
522    let mut true_positives = 0f64;
523    for (label, weight) in labels.zip(weights) {
524        let lw = weight * f64::from(label.get_value());
525        false_positives += weight - lw;
526        true_positives += lw;
527    }
528    return (false_positives, true_positives);
529}
530
531pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
532where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
533    // TODO validate max_fpr
534    let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
535    let mut l_it = labels.get_iterator();
536    let mut p_it = predictions.get_iterator();
537    let mut w_it = weights.get_iterator();
538    return match order {
539        Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
540        Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
541    };
542}
543    
544
545fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
546    labels: &mut impl Iterator<Item = B>,
547    predictions: &mut impl Iterator<Item = f64>,
548    weights: &mut impl Iterator<Item = f64>,
549    false_positive_sum: f64,
550    true_positive_sum: f64,
551    max_false_positive_rate: f64
552) -> f64 {
553    let mut false_positives: f64 = 0.0;
554    let mut true_positives: f64 = 0.0;
555    let mut last_counted_fp = 0.0;
556    let mut last_counted_tp = 0.0;
557    let mut area_under_curve = 0.0;
558    let mut zipped = labels.zip(predictions).zip(weights).peekable();
559    let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
560    loop {
561        match zipped.next() {
562            None => break,
563            Some(actual) => {
564                let l = f64::from(actual.0.0.get_value());
565                let w = actual.1;
566                let wl = l * w;
567                let next_tp = true_positives + wl;
568                let next_fp = false_positives + (w - wl);
569                let is_above_max = next_fp > false_positive_cutoff;
570                if is_above_max {
571                    let dx = next_fp  - false_positives;
572                    let dy = next_tp - true_positives;
573                    true_positives += dy * false_positive_cutoff / dx;
574                    false_positives = false_positive_cutoff;
575                } else {
576                    true_positives = next_tp;
577                    false_positives = next_fp;
578                }
579                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
580                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
581                    last_counted_fp = false_positives;
582                    last_counted_tp = true_positives;
583                }
584                if is_above_max {
585                    break;
586                }                
587            }
588        };
589    }
590    let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
591    let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
592    let max_area = max_false_positive_rate;
593    return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
594}
595
596pub fn loo_cossim<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>, replicate_sum: &mut ArrayViewMut1<'_, F>) -> F {
597    let num_replicates = mat.shape()[0];
598    let loo_weight = F::from(num_replicates - 1).unwrap();
599    let loo_weight_factor = F::from(1).unwrap() / loo_weight;
600    for mat_replicate in mat.outer_iter() {
601        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
602            *feature_sum += *feature;
603        }
604    }
605
606    let mut result = F::zero();
607
608    for mat_replicate in mat.outer_iter() {
609        let mut m_sqs = F::zero();
610        let mut l_sqs = F::zero();
611        let mut prod_sum = F::zero();
612        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
613            let m_f = *feature;
614            let l_f = (*feature_sum - *feature) * loo_weight_factor;
615            prod_sum += m_f * l_f;
616            m_sqs += m_f * m_f;
617            l_sqs += l_f * l_f;
618        }
619        result += prod_sum / (m_sqs * l_sqs).sqrt();
620    }
621
622    return result / F::from(num_replicates).unwrap();
623}
624
625pub fn loo_cossim_single<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>) -> F {
626    let mut replicate_sum = Array1::<F>::zeros(mat.shape()[1]);
627    return loo_cossim(mat, &mut replicate_sum.view_mut());
628}
629
630pub fn loo_cossim_many<F: num::Float + AddAssign>(mat: &ArrayView3<'_, F>) -> Array1<F> {
631    let mut cossims = Array1::<F>::zeros(mat.shape()[0]);
632    let mut replicate_sum = Array1::<F>::zeros(mat.shape()[2]);
633    for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
634        replicate_sum.fill(F::zero());
635        *c = loo_cossim(&m, &mut replicate_sum.view_mut());
636    }
637    return cossims;
638}
639
640
641// Python bindings
642#[pyclass(eq, eq_int, name="Order")]
643#[derive(Clone, Copy, PartialEq)]
644pub enum PyOrder {
645    ASCENDING,
646    DESCENDING
647}
648
649fn py_order_as_order(order: PyOrder) -> Order {
650    return match order {
651        PyOrder::ASCENDING => Order::ASCENDING,
652        PyOrder::DESCENDING => Order::DESCENDING,
653    }
654}
655
656
657trait PyScore: Ungil + Sync {
658
659    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
660    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
661
662    fn score_py_generic<'py, B>(
663        &self,
664        py: Python<'py>,
665        labels: &PyReadonlyArray1<'py, B>,
666        predictions: &PyReadonlyArray1<'py, f64>,
667        weights: &Option<PyReadonlyArray1<'py, f64>>,
668        order: &Option<PyOrder>,
669    ) -> f64
670    where B: BinaryLabel + Element
671    {
672        let labels = labels.as_array();
673        let predictions = predictions.as_array();
674        let order = order.map(py_order_as_order);
675        let score = match weights {
676            Some(weight) => {
677                let weights = weight.as_array();
678                py.allow_threads(move || {
679                    self.score(&labels, &predictions, Some(&weights), order)
680                })
681            },
682            None => py.allow_threads(move || {
683                self.score(&labels, &predictions, None::<&Vec<f64>>, order)
684            })
685        };
686        return score;
687    }
688
689    fn score_py_match_run<'py, T>(
690        &self,
691        py: Python<'py>,
692        labels: &Bound<'py, PyUntypedArray>,
693        predictions: &PyReadonlyArray1<'py, f64>,
694        weights: &Option<PyReadonlyArray1<'py, f64>>,
695        order: &Option<PyOrder>,
696        dt: &Bound<'py, PyArrayDescr>
697    ) -> Option<f64>
698    where T: Element + BinaryLabel
699    {
700        return if dt.is_equiv_to(&dtype::<T>(py)) {
701            let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
702            Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
703        } else {
704            None
705        };
706    }
707    
708    fn score_py<'py>(
709        &self,
710        py: Python<'py>,
711        labels: &Bound<'py, PyUntypedArray>,
712        predictions: PyReadonlyArray1<'py, f64>,
713        weights: Option<PyReadonlyArray1<'py, f64>>,
714        order: Option<PyOrder>,
715    ) -> PyResult<f64> {
716        if labels.ndim() != 1 {
717            return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
718        }
719        let label_dtype = labels.dtype();
720        if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
721            return Ok(score)
722        }
723        else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
724            return Ok(score)
725        }
726        else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
727            return Ok(score)
728        }
729        else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
730            return Ok(score)
731        }
732        else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
733            return Ok(score)
734        }
735        else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
736            return Ok(score)
737        }
738        else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
739            return Ok(score)
740        }
741        else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
742            return Ok(score)
743        }
744        else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
745            return Ok(score)
746        }
747        return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
748    }
749}
750
/// Stateless scorer that computes the average precision.
struct PyAveragePrecision {}
754
755impl PyAveragePrecision{
756    fn new() -> Self {
757        return PyAveragePrecision {};
758    }
759}
760
761impl PyScore for PyAveragePrecision {
762    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
763    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
764        return average_precision_with_order(labels, predictions, weights, order);
765    }
766}
767
/// Scorer that computes the ROC AUC, optionally limited to a maximum
/// false positive rate.
struct PyRocAuc {
    /// Upper bound on the false positive rate; `None` integrates the full curve.
    max_fpr: Option<f64>,
}
771
772impl PyRocAuc {
773    fn new(max_fpr: Option<f64>) -> Self {
774        return PyRocAuc { max_fpr: max_fpr };
775    }
776}
777
778impl PyScore for PyRocAuc {
779    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
780    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
781        return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
782    }
783}
784
785
786#[pyfunction(name = "average_precision")]
787#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
788pub fn average_precision_py<'py>(
789    py: Python<'py>,
790    labels: &Bound<'py, PyUntypedArray>,
791    predictions: PyReadonlyArray1<'py, f64>,
792    weights: Option<PyReadonlyArray1<'py, f64>>,
793    order: Option<PyOrder>
794) -> PyResult<f64> {
795    return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
796}
797
798#[pyfunction(name = "roc_auc")]
799#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
800pub fn roc_auc_py<'py>(
801    py: Python<'py>,
802    labels: &Bound<'py, PyUntypedArray>,
803    predictions: PyReadonlyArray1<'py, f64>,
804    weights: Option<PyReadonlyArray1<'py, f64>>,
805    order: Option<PyOrder>,
806    max_fpr: Option<f64>,
807) -> PyResult<f64> {
808    return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
809}
810
811#[pyfunction(name = "loo_cossim")]
812#[pyo3(signature = (data))]
813pub fn loo_cossim_py<'py>(
814    py: Python<'py>,
815    data: &Bound<'py, PyUntypedArray>
816) -> PyResult<f64> {
817    if data.ndim() != 2 {
818        return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
819    }
820
821    let dt = data.dtype();
822    if dt.is_equiv_to(&dtype::<f32>(py)) {
823        let typed_data = data.downcast::<PyArray2<f32>>().unwrap().readonly();
824        let array = typed_data.as_array();
825        let score = py.allow_threads(move || {
826            loo_cossim_single(&array)
827        });
828        return Ok(score as f64);
829    }
830    if dt.is_equiv_to(&dtype::<f64>(py)) {
831        let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
832        let array = typed_data.as_array();
833        let score = py.allow_threads(move || {
834            loo_cossim_single(&array)
835        });
836        return Ok(score);
837    }
838    return Err(PyTypeError::new_err(format!("Only float32 and float64 data supported, but found {}", dt)));
839}
840
841pub fn loo_cossim_many_generic_py<'py, F: num::Float + AddAssign + Element>(
842    py: Python<'py>,
843    data: &Bound<'py, PyArrayDyn<F>>
844) -> PyResult<Bound<'py, PyArray1<F>>> {
845    if data.ndim() != 3 {
846        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
847    }
848    let typed_data = data.downcast::<PyArray3<F>>().unwrap().readonly();
849    let array = typed_data.as_array();
850    let score = py.allow_threads(move || {
851        loo_cossim_many(&array)
852    });
853    // TODO how can we return this generically without making a copy at the end?
854    let score_py = PyArray::from_owned_array(py, score);
855    return Ok(score_py);
856}
857
858#[pyfunction(name = "loo_cossim_many_f64")]
859#[pyo3(signature = (data))]
860pub fn loo_cossim_many_py_f64<'py>(
861    py: Python<'py>,
862    data: &Bound<'py, PyUntypedArray>
863) -> PyResult<Bound<'py, PyArray1<f64>>> {
864    if data.ndim() != 3 {
865        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
866    }
867
868    let dt = data.dtype();
869    if !dt.is_equiv_to(&dtype::<f64>(py)) {
870        return Err(PyTypeError::new_err(format!("Only float64 data supported, but found {}", dt)));
871    }
872    let typed_data = data.downcast::<PyArrayDyn<f64>>().unwrap();
873    return loo_cossim_many_generic_py(py, typed_data);
874}
875
876#[pyfunction(name = "loo_cossim_many_f32")]
877#[pyo3(signature = (data))]
878pub fn loo_cossim_many_py_f32<'py>(
879    py: Python<'py>,
880    data: &Bound<'py, PyUntypedArray>
881) -> PyResult<Bound<'py, PyArray1<f32>>> {
882    if data.ndim() != 3 {
883        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
884    }
885
886    let dt = data.dtype();
887    if !dt.is_equiv_to(&dtype::<f32>(py)) {
888        return Err(PyTypeError::new_err(format!("Only float32 data supported, but found {}", dt)));
889    }
890    let typed_data = data.downcast::<PyArrayDyn<f32>>().unwrap();
891    return loo_cossim_many_generic_py(py, typed_data);
892}
893
894#[pymodule(name = "_scors")]
895fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
896    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
897    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
898    m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
899    m.add_function(wrap_pyfunction!(loo_cossim_many_py_f64, m)?).unwrap();
900    m.add_function(wrap_pyfunction!(loo_cossim_many_py_f32, m)?).unwrap();
901    m.add_class::<PyOrder>().unwrap();
902    return Ok(());
903}
904
905
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    // Array literal constructors used by the loo_cossim tests; imported
    // explicitly so the tests do not rely on the parent module's imports.
    use ndarray::{arr1, arr2, arr3};
    use super::*;

    // The expected loo_cossim values below are written with 8 decimals, so
    // they can only be matched up to that precision (the default tolerance of
    // `assert_relative_eq!` is `f64::EPSILON`).
    const LOO_COSSIM_TOLERANCE: f64 = 1e-8;

    #[test]
    fn test_average_precision_on_sorted() {
        // Labels are already ordered by descending prediction; the matching
        // predictions would be [0.8, 0.4, 0.35, 0.1].
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, Some(&weights), Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        // No order is given, so the predictions are passed unsorted.
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), None);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        // Same samples as the unsorted test, pre-sorted in descending order.
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING));
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_samples(&labels, &predictions, &weights, &labels, &predictions, &weights);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING), None);
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_loo_cossim_single() {
        let data = arr2(&[[0.77395605, 0.43887844, 0.85859792],
                          [0.69736803, 0.09417735, 0.97562235]]);
        let cossim = loo_cossim_single(&data.view());
        let expected = 0.95385941;
        assert_relative_eq!(cossim, expected, epsilon = LOO_COSSIM_TOLERANCE);
    }

    #[test]
    fn test_loo_cossim_many() {
        // The first 2x3 slice is the same input as in `test_loo_cossim_single`.
        let data = arr3(&[[[0.77395605, 0.43887844, 0.85859792],
                           [0.69736803, 0.09417735, 0.97562235]],
                          [[0.7611397 , 0.78606431, 0.12811363],
                           [0.45038594, 0.37079802, 0.92676499]],
                          [[0.64386512, 0.82276161, 0.4434142 ],
                           [0.22723872, 0.55458479, 0.06381726]],
                          [[0.82763117, 0.6316644 , 0.75808774],
                           [0.35452597, 0.97069802, 0.89312112]]]);
        let cossim = loo_cossim_many(&data.view());
        let expected = arr1(&[0.95385941, 0.62417001, 0.92228589, 0.90025417]);
        assert_eq!(cossim.shape(), expected.shape());
        for (c, e) in cossim.iter().zip(expected.iter()) {
            assert_relative_eq!(*c, *e, epsilon = LOO_COSSIM_TOLERANCE);
        }
    }
}