_scors/
lib.rs

1use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
2use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescr,PyArrayDescrMethods,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
3use pyo3::Bound;
4use pyo3::exceptions::PyTypeError;
5use pyo3::marker::Ungil;
6use pyo3::prelude::*;
7use std::iter::DoubleEndedIterator;
8use std::marker::PhantomData;
9
/// Sort direction of data that is handed to a scoring function already sorted by prediction.
///
/// Scoring functions in this file reverse `ASCENDING` input and then process it
/// exactly like `DESCENDING` input.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Predictions increase from the first to the last element.
    ASCENDING,
    /// Predictions decrease from the first to the last element.
    DESCENDING,
}
15
/// Weight source that reports the same `f64` value for every sample.
struct ConstWeight {
    value: f64,
}

impl ConstWeight {
    /// Creates a weight source that always reports `value`.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// Creates a weight source that always reports `1.0`.
    fn one() -> Self {
        Self::new(1.0)
    }
}
28
/// Uniform read access to a sequence of `T` values.
///
/// Generic scoring code uses this trait to walk a container in either direction
/// and to fetch single elements by position.
// TODO: A custom trait is used because no suitable standard trait was known when
//       this was written; replace it if one turns out to exist.
pub trait Data<T: Clone> {
    /// Returns a double-ended iterator that yields the elements by value.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the element stored at position `index`.
    fn get_at(&self, index: usize) -> T;
}
35
/// Data whose elements can be ranked without modifying the container itself.
pub trait SortableData<T> {
    /// Returns the element indices arranged by an unstable sort of the values.
    ///
    /// The sort direction is decided by the implementation.
    fn argsort_unstable(&self) -> Vec<usize>;
}
39
40impl Iterator for ConstWeight {
41    type Item = f64;
42    fn next(&mut self) -> Option<f64> {
43        return Some(self.value);
44    }
45}
46
47impl DoubleEndedIterator for ConstWeight {
48    fn next_back(&mut self) -> Option<f64> {
49        return Some(self.value);
50    }
51}
52
53impl Data<f64> for ConstWeight {
54    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
55        return ConstWeight::new(self.value);
56    }
57
58    fn get_at(&self, _index: usize) -> f64 {
59        return self.value.clone();
60    }
61}
62
63impl <T: Clone> Data<T> for Vec<T> {
64    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
65        return self.iter().cloned();
66    }
67    fn get_at(&self, index: usize) -> T {
68        return self[index].clone();
69    }
70}
71
72impl SortableData<f64> for Vec<f64> {
73    fn argsort_unstable(&self) -> Vec<usize> {
74        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
75        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
76        // indices.sort_unstable_by_key(|i| self[*i]);
77        return indices;
78    }
79}
80
81impl <T: Clone> Data<T> for &[T] {
82    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
83        return self.iter().cloned();
84    }
85    fn get_at(&self, index: usize) -> T {
86        return self[index].clone();
87    }
88}
89
90impl SortableData<f64> for &[f64] {
91    fn argsort_unstable(&self) -> Vec<usize> {
92        // let t0 = Instant::now();
93        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
94        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
95        // let t1 = Instant::now();
96        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
97        // println!("Sorting took {}ms", t0.elapsed().as_millis());
98        return indices;
99    }
100}
101
102impl <T: Clone, const N: usize> Data<T> for [T; N] {
103    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
104        return self.iter().cloned();
105    }
106    fn get_at(&self, index: usize) -> T {
107        return self[index].clone();
108    }
109}
110
111impl <const N: usize> SortableData<f64> for [f64; N] {
112    fn argsort_unstable(&self) -> Vec<usize> {
113        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
114        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
115        return indices;
116    }
117}
118
119impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
120    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
121        return self.iter().cloned();
122    }
123    fn get_at(&self, index: usize) -> T {
124        return self[index].clone();
125    }
126}
127
128impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
129    fn argsort_unstable(&self) -> Vec<usize> {
130        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
131        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
132        return indices;
133    }
134}
135
136// struct IndexView<'a, T, D> where T: Clone, D: Data<T>{
137//     data: &'a D,
138//     indices: &'a Vec<usize>,
139// }
140
141// impl <'a, T: Clone> Data<T> for IndexView<'a, T> {
142// }
143
/// A label type that can be interpreted as a boolean class membership.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` if this label counts as a positive sample.
    fn get_value(&self) -> bool;
}
147
148impl BinaryLabel for bool {
149    fn get_value(&self) -> bool {
150        return self.clone();
151    }
152}
153
154impl BinaryLabel for u8 {
155    fn get_value(&self) -> bool {
156        return (self & 1) == 1;
157    }
158}
159
160impl BinaryLabel for u16 {
161    fn get_value(&self) -> bool {
162        return (self & 1) == 1;
163    }
164}
165
166impl BinaryLabel for u32 {
167    fn get_value(&self) -> bool {
168        return (self & 1) == 1;
169    }
170}
171
172impl BinaryLabel for u64 {
173    fn get_value(&self) -> bool {
174        return (self & 1) == 1;
175    }
176}
177
178impl BinaryLabel for i8 {
179    fn get_value(&self) -> bool {
180        return (self & 1) == 1;
181    }
182}
183
184impl BinaryLabel for i16 {
185    fn get_value(&self) -> bool {
186        return (self & 1) == 1;
187    }
188}
189
190impl BinaryLabel for i32 {
191    fn get_value(&self) -> bool {
192        return (self & 1) == 1;
193    }
194}
195
196impl BinaryLabel for i64 {
197    fn get_value(&self) -> bool {
198        return (self & 1) == 1;
199    }
200}
201
202struct SortedSampleDescending<'a, B, L, P, W>
203where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
204{
205    labels: &'a L,
206    predictions: &'a P,
207    weights: &'a W,
208    label_type: PhantomData<B>,
209}
210
211impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
212where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
213{
214    fn new(labels: &'a L, predictions: &'a P, weights: &'a W) -> Self {
215        return SortedSampleDescending {
216            labels: labels,
217            predictions: predictions,
218            weights: weights,
219            label_type: PhantomData
220        }
221    }
222}
223
224// struct SortedSampleIterator<'a, B, L, P, W>
225// where B: BinaryLabel + Clone + 'a, L: Iterator<Item = &'a B>, P: Iterator<Item = &'a f64>, W: Iterator<Item = &' f64>
226// {
227//     labels: &'a mut L,
228//     predictions: &'a mut P,
229//     weights: &'a mut W,
230// }
231
232
233// impl <'a, B, L, P, W> Iterator for SortedSampleIterator<'a, B, L, P, W>
234// where B: BinaryLabel, L: Iterator<Item = B>, P: Iterator<Item = f64>, W: Iterator<Item = f64>
235// {
236//     type Item = (B, f64, f64);
237//     fn next(&mut self) -> Option<(B, f64, f64)> {
238//         return match (self.labels.next(), self.predictions.next(), self.weights.next()) {
239//             (Some(l), Some(p), Some(w)) => Some((l, p, w)),
240//             _ => None
241//         };
242//     }
243// }
244
245// impl <'a, B, L, P, W> IntoIterator for &SortedSampleDescending<'a, B, L, P, W>
246// where B: BinaryLabel, L: IntoIterator<Item = B>, P: IntoIterator<Item = f64>, W: IntoIterator<Item = f64> {
247//     type Item = (B, f64, f64);
248//     type IntoIter = SortedSampleIterator<'a, B, L::IntoIter, P::IntoIter, W::IntoIter>;
249
250//     fn into_iter(self) -> Self::IntoIter {
251//         let l = self.labels.into_iter();
252//         // return SortedSampleIterator {
253//         //     labels: self.labels.into_iter(),
254//         //     predictions: self.labels.into_iter(),
255//         //     weights: self.weights.into_iter(),
256//         // }
257//     }
258// }
259
260// struct CombineView<'a, B, L, P, W> 
261// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
262// {
263//     sample1: &'a SortedSampleDescending<'a, B, L, P, W>,
264//     sample2: &'a SortedSampleDescending<'a, B, L, P, W>,
265// }
266
267// impl <'a, B, L, P, W> CombineView<'a, B, L, P, W>
268// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
269// {
270//     fn new(sample1: &'a SortedSampleDescending<'a, B, L, P, W>, sample2: &'a SortedSampleDescending<'a, B, L, P, W>,) -> Self {
271//         return PairedSortedSampleDescending { sample1: sample1, sample2: sample2 };
272//     }
273// }
274
275// struct CombineViewIterator<'a, B, L, P, W>
276// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
277// {
278//     view: &'a CombineView<'a, B, L, P, W>,
279//     iterator1: impl Iterator<Item: ((B, f64), f64),
280//     iterator2: impl Iterator<Item: ((B, f64), f64),
281//     current_index: usize,
282//     num_elements: usize
283// }
284
285// impl <'a, B, L, P, W> CombineViewIterator<'a, B, L, P, W>
286// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
287// {
288//     fn new(view: &'a CombineView<'a, B, L, P, W>) -> Self {
289//         return CombineViewIerator {
290//             view: view,
291//             iterator1: view.sample1.iterator(),
292//             iterator2: view.sample2.iterator(),
293//             current_index: 0,
294//             num_elements: usize
295//         };
296//     }
297// }
298
299
300// impl <'a, B, L, P, W> Iterator for CombineViewIterator<'a, B, L, P, W>
301// where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64>
302// {
303//     type Item = (B, f64, f64);
304//     fn next(&mut self) -> Option<(B, f64, f64)> {
305//         if self.current_index == self.num_elements {
306//             return None;
307//         }
308        
309//         return Some(self.value);
310//     }
311// }
312
313fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
314where T: Copy, I: Data<T>
315{
316    let mut selection: Vec<T> = Vec::new();
317    selection.reserve_exact(indices.len());
318    for index in indices {
319        selection.push(slice.get_at(*index));
320    }
321    return selection;
322}
323
324pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
325where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
326{
327    return average_precision_with_order(labels, predictions, weights, None);
328}
329
330pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
331where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
332{
333    return match order {
334        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
335        None => {
336            let indices = predictions.argsort_unstable();
337            let sorted_labels = select(labels, &indices);
338            let ap = match weights {
339                None => {
340                    // let w: Oepion<&
341                    average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
342                },
343                Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
344            };
345            ap
346        }
347    };
348}
349
350pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
351where B: BinaryLabel, L: Data<B>, W: Data<f64>
352{
353    return match weights {
354        None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
355        Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
356    };
357}
358
359pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
360where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
361{
362    return match order {
363        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
364        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
365    };
366}
367
368pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
369    return average_precision_on_descending_iterators(labels.zip(weights));
370}
371
372// impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
373// where B: BinaryLabel, L: IntoIterator<Item = B>, P: IntoIterator<Item = f64>, W: IntoIterator<Item = f64>
374
375pub fn average_precision_on_sorted_samples<'a, B, L, P, W>(l1: &'a L, p1: &'a P, w1: &'a W, l2: &'a L, p2: &'a P, w2: &'a W) -> f64
376where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
377{
378    // let mut it1 = p1.into_iter();
379    let i1 = p1.into_iter().cloned().zip(l1.into_iter().cloned().zip(w1.into_iter().cloned()));
380    let i2 = p2.into_iter().cloned().zip(l2.into_iter().cloned().zip(w2.into_iter().cloned()));
381    let labels_and_weights = i1.zip(i2).map(|(t1, t2)| {
382        if t1.0 > t2.0 {
383            t1.1
384        } else {
385            t2.1
386        }
387    });
388    return average_precision_on_descending_iterators(labels_and_weights);
389}
390
391pub fn average_precision_on_descending_iterators<B: BinaryLabel>(labels_and_weights: impl Iterator<Item = (B, f64)>) -> f64 {
392    let mut ap: f64 = 0.0;
393    let mut tps: f64 = 0.0;
394    let mut fps: f64 = 0.0;
395    for (label, weight) in labels_and_weights {
396        let w: f64 = weight;
397        let l: bool = label.get_value();
398        let tp = w * f64::from(l);
399        tps += tp;
400        fps += weight - tp;
401        let ps = tps + fps;
402        let precision = tps / ps;
403        ap += tp * precision;
404    }
405    // Special case for tps == 0 following sklearn
406    // https://github.com/scikit-learn/scikit-learn/blob/5cce87176a530d2abea45b5a7e5a4d837c481749/sklearn/metrics/_ranking.py#L1032-L1039
407    // I.e. if tps is 0.0, there are no positive samples in labels: Either all labels are 0, or all weights (for positive labels) are 0
408    return if tps == 0.0 {
409        0.0
410    } else {
411        ap / tps
412    };
413}
414
415
416
417// ROC AUC score
418pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
419where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
420{
421    return roc_auc_with_order(labels, predictions, weights, None, None);
422}
423
424pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
425where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
426{
427    return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
428}
429
430pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
431where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
432{
433    return match order {
434        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
435        None => {
436            let indices = predictions.argsort_unstable();
437            let sorted_labels = select(labels, &indices);
438            let sorted_predictions = select(predictions, &indices);
439            let roc_auc_score = match weights {
440                Some(w) => {
441                    let sorted_weights = select(w, &indices);
442                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
443                },
444                None => {
445                    roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
446                }
447            };
448            roc_auc_score
449        }
450    };
451}
452pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
453where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
454    return match max_false_positive_rate {
455        None => match weights {
456            Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
457            None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
458        }
459        Some(max_fpr) => match weights {
460            Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
461            None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
462        }
463    };
464}
465
466pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
467    labels: &mut impl DoubleEndedIterator<Item = B>,
468    predictions: &mut impl DoubleEndedIterator<Item = f64>,
469    weights: &mut impl DoubleEndedIterator<Item = f64>,
470    order: Order
471) -> f64 {
472    return match order {
473        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
474        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
475    }
476}
477
478pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
479    labels: &mut impl Iterator<Item = B>,
480    predictions: &mut impl Iterator<Item = f64>,
481    weights: &mut impl Iterator<Item = f64>
482) -> f64 {
483    let mut false_positives: f64 = 0.0;
484    let mut true_positives: f64 = 0.0;
485    let mut last_counted_fp = 0.0;
486    let mut last_counted_tp = 0.0;
487    let mut area_under_curve = 0.0;
488    let mut zipped = labels.zip(predictions).zip(weights).peekable();
489    loop {
490        match zipped.next() {
491            None => break,
492            Some(actual) => {
493                let l = f64::from(actual.0.0.get_value());
494                let w = actual.1;
495                let wl = l * w;
496                true_positives += wl;
497                false_positives += w - wl;
498                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
499                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
500                    last_counted_fp = false_positives;
501                    last_counted_tp = true_positives;
502                }
503            }
504        };
505    }
506    return area_under_curve / (true_positives * false_positives);
507}
508
/// Returns the area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`.
///
/// The area is the rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    let rectangle = width * y0;
    let triangle = rise * width * 0.5;
    rectangle + triangle
}
514
515fn get_positive_sum<B: BinaryLabel>(
516    labels: impl Iterator<Item = B>,
517    weights: impl Iterator<Item = f64>
518) -> (f64, f64) {
519    let mut false_positives = 0f64;
520    let mut true_positives = 0f64;
521    for (label, weight) in labels.zip(weights) {
522        let lw = weight * f64::from(label.get_value());
523        false_positives += weight - lw;
524        true_positives += lw;
525    }
526    return (false_positives, true_positives);
527}
528
529pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
530where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
531    // TODO validate max_fpr
532    let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
533    let mut l_it = labels.get_iterator();
534    let mut p_it = predictions.get_iterator();
535    let mut w_it = weights.get_iterator();
536    return match order {
537        Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
538        Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
539    };
540}
541    
542
543fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
544    labels: &mut impl Iterator<Item = B>,
545    predictions: &mut impl Iterator<Item = f64>,
546    weights: &mut impl Iterator<Item = f64>,
547    false_positive_sum: f64,
548    true_positive_sum: f64,
549    max_false_positive_rate: f64
550) -> f64 {
551    let mut false_positives: f64 = 0.0;
552    let mut true_positives: f64 = 0.0;
553    let mut last_counted_fp = 0.0;
554    let mut last_counted_tp = 0.0;
555    let mut area_under_curve = 0.0;
556    let mut zipped = labels.zip(predictions).zip(weights).peekable();
557    let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
558    loop {
559        match zipped.next() {
560            None => break,
561            Some(actual) => {
562                let l = f64::from(actual.0.0.get_value());
563                let w = actual.1;
564                let wl = l * w;
565                let next_tp = true_positives + wl;
566                let next_fp = false_positives + (w - wl);
567                let is_above_max = next_fp > false_positive_cutoff;
568                if is_above_max {
569                    let dx = next_fp  - false_positives;
570                    let dy = next_tp - true_positives;
571                    true_positives += dy * false_positive_cutoff / dx;
572                    false_positives = false_positive_cutoff;
573                } else {
574                    true_positives = next_tp;
575                    false_positives = next_fp;
576                }
577                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
578                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
579                    last_counted_fp = false_positives;
580                    last_counted_tp = true_positives;
581                }
582                if is_above_max {
583                    break;
584                }                
585            }
586        };
587    }
588    let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
589    let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
590    let max_area = max_false_positive_rate;
591    return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
592}
593
594pub fn loo_cossim(mat: &ArrayView2<'_, f64>, replicate_sum: &mut ArrayViewMut1<'_, f64>) -> f64 {
595    let num_replicates = mat.shape()[0];
596    let loo_weight = num_replicates as f64 - 1.0;
597    let loo_weight_factor = 1.0 / loo_weight;
598    for mat_replicate in mat.outer_iter() {
599        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
600            *feature_sum += feature;
601        }
602    }
603
604    let mut result = 0f64;
605
606    for mat_replicate in mat.outer_iter() {
607        let mut m_sqs = 0.0;
608        let mut l_sqs = 0.0;
609        let mut prod_sum = 0.0;
610        for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
611            let m_f = feature;
612            let l_f = (feature_sum - feature) * loo_weight_factor;
613            prod_sum += m_f * l_f;
614            m_sqs += m_f * m_f;
615            l_sqs += l_f * l_f;
616        }
617        result += prod_sum / (m_sqs * l_sqs).sqrt();
618    }
619
620    return result / num_replicates as f64;
621}
622
623pub fn loo_cossim_single(mat: &ArrayView2<'_, f64>) -> f64 {
624    let mut replicate_sum = Array1::<f64>::zeros(mat.shape()[1]);
625    return loo_cossim(mat, &mut replicate_sum.view_mut());
626}
627
628pub fn loo_cossim_many(mat: &ArrayView3<'_, f64>) -> Array1<f64> {
629    let mut cossims = Array1::<f64>::zeros(mat.shape()[0]);
630    let mut replicate_sum = Array1::<f64>::zeros(mat.shape()[2]);
631    for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
632        replicate_sum.fill(0.0);
633        *c = loo_cossim(&m, &mut replicate_sum.view_mut());
634    }
635    return cossims;
636}
637
638
639// Python bindings
640#[pyclass(eq, eq_int, name="Order")]
641#[derive(Clone, Copy, PartialEq)]
642pub enum PyOrder {
643    ASCENDING,
644    DESCENDING
645}
646
647fn py_order_as_order(order: PyOrder) -> Order {
648    return match order {
649        PyOrder::ASCENDING => Order::ASCENDING,
650        PyOrder::DESCENDING => Order::DESCENDING,
651    }
652}
653
654
655trait PyScore: Ungil + Sync {
656
657    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
658    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
659
660    fn score_py_generic<'py, B>(
661        &self,
662        py: Python<'py>,
663        labels: &PyReadonlyArray1<'py, B>,
664        predictions: &PyReadonlyArray1<'py, f64>,
665        weights: &Option<PyReadonlyArray1<'py, f64>>,
666        order: &Option<PyOrder>,
667    ) -> f64
668    where B: BinaryLabel + Element
669    {
670        let labels = labels.as_array();
671        let predictions = predictions.as_array();
672        let order = order.map(py_order_as_order);
673        let score = match weights {
674            Some(weight) => {
675                let weights = weight.as_array();
676                py.allow_threads(move || {
677                    self.score(&labels, &predictions, Some(&weights), order)
678                })
679            },
680            None => py.allow_threads(move || {
681                self.score(&labels, &predictions, None::<&Vec<f64>>, order)
682            })
683        };
684        return score;
685    }
686
687    fn score_py_match_run<'py, T>(
688        &self,
689        py: Python<'py>,
690        labels: &Bound<'py, PyUntypedArray>,
691        predictions: &PyReadonlyArray1<'py, f64>,
692        weights: &Option<PyReadonlyArray1<'py, f64>>,
693        order: &Option<PyOrder>,
694        dt: &Bound<'py, PyArrayDescr>
695    ) -> Option<f64>
696    where T: Element + BinaryLabel
697    {
698        return if dt.is_equiv_to(&dtype::<T>(py)) {
699            let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
700            Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
701        } else {
702            None
703        };
704    }
705    
706    fn score_py<'py>(
707        &self,
708        py: Python<'py>,
709        labels: &Bound<'py, PyUntypedArray>,
710        predictions: PyReadonlyArray1<'py, f64>,
711        weights: Option<PyReadonlyArray1<'py, f64>>,
712        order: Option<PyOrder>,
713    ) -> PyResult<f64> {
714        if labels.ndim() != 1 {
715            return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
716        }
717        let label_dtype = labels.dtype();
718        if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
719            return Ok(score)
720        }
721        else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
722            return Ok(score)
723        }
724        else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
725            return Ok(score)
726        }
727        else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
728            return Ok(score)
729        }
730        else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
731            return Ok(score)
732        }
733        else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
734            return Ok(score)
735        }
736        else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
737            return Ok(score)
738        }
739        else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
740            return Ok(score)
741        }
742        else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
743            return Ok(score)
744        }
745        return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
746    }
747}
748
749struct PyAveragePrecision {
750    
751}
752
753impl PyAveragePrecision{
754    fn new() -> Self {
755        return PyAveragePrecision {};
756    }
757}
758
759impl PyScore for PyAveragePrecision {
760    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
761    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
762        return average_precision_with_order(labels, predictions, weights, order);
763    }
764}
765
766struct PyRocAuc {
767    max_fpr: Option<f64>
768}
769
770impl PyRocAuc {
771    fn new(max_fpr: Option<f64>) -> Self {
772        return PyRocAuc { max_fpr: max_fpr };
773    }
774}
775
776impl PyScore for PyRocAuc {
777    fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
778    where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
779        return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
780    }
781}
782
783
784#[pyfunction(name = "average_precision")]
785#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
786pub fn average_precision_py<'py>(
787    py: Python<'py>,
788    labels: &Bound<'py, PyUntypedArray>,
789    predictions: PyReadonlyArray1<'py, f64>,
790    weights: Option<PyReadonlyArray1<'py, f64>>,
791    order: Option<PyOrder>
792) -> PyResult<f64> {
793    return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
794}
795
796#[pyfunction(name = "roc_auc")]
797#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
798pub fn roc_auc_py<'py>(
799    py: Python<'py>,
800    labels: &Bound<'py, PyUntypedArray>,
801    predictions: PyReadonlyArray1<'py, f64>,
802    weights: Option<PyReadonlyArray1<'py, f64>>,
803    order: Option<PyOrder>,
804    max_fpr: Option<f64>,
805) -> PyResult<f64> {
806    return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
807}
808
809#[pyfunction(name = "loo_cossim")]
810#[pyo3(signature = (data))]
811pub fn loo_cossim_py<'py>(
812    py: Python<'py>,
813    data: &Bound<'py, PyUntypedArray>
814) -> PyResult<f64> {
815    if data.ndim() != 2 {
816        return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
817    }
818
819    let dt = data.dtype();
820    if !dt.is_equiv_to(&dtype::<f64>(py)) {
821        return Err(PyTypeError::new_err(format!("Currently only float64 data supported, but found {}", dt)));
822    }
823    let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
824    let array = typed_data.as_array();
825    let score = py.allow_threads(move || {
826        loo_cossim_single(&array)
827    });
828    return Ok(score);
829}
830
831#[pyfunction(name = "loo_cossim_many")]
832#[pyo3(signature = (data))]
833pub fn loo_cossim_many_py<'py>(
834    py: Python<'py>,
835    data: &Bound<'py, PyUntypedArray>
836) -> PyResult<Bound<'py, PyArray1<f64>>> {
837    if data.ndim() != 3 {
838        return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
839    }
840
841    let dt = data.dtype();
842    if !dt.is_equiv_to(&dtype::<f64>(py)) {
843        return Err(PyTypeError::new_err(format!("Currently only float64 data supported, but found {}", dt)));
844    }
845    let typed_data = data.downcast::<PyArray3<f64>>().unwrap().readonly();
846    let array = typed_data.as_array();
847    let score = py.allow_threads(move || {
848        loo_cossim_many(&array)
849    });
850    let score_py = PyArray::from_owned_array(py, score);
851    return Ok(score_py);
852}
853
854#[pymodule(name = "_scors")]
855fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
856    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
857    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
858    m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
859    m.add_function(wrap_pyfunction!(loo_cossim_many_py, m)?).unwrap();
860    m.add_class::<PyOrder>().unwrap();
861    return Ok(());
862}
863
864
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    // Imported explicitly so the array literals below do not depend on whatever
    // the parent module happens to import from `ndarray`.
    use ndarray::{arr1, arr2, arr3};

    /// Absolute tolerance for the cosine-similarity checks.
    ///
    /// The reference values below are written out to only 8 decimal places, so
    /// the default tolerance of `assert_relative_eq!` (`f64::EPSILON`) is far
    /// too strict for them.
    const COSSIM_TOLERANCE: f64 = 1e-7;

    #[test]
    fn test_average_precision_on_sorted() {
        // Labels are given in the order of descending predictions
        // (e.g. predictions 0.8, 0.4, 0.35, 0.1); the predictions themselves
        // are not needed by the function under test.
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, Some(&weights), Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        // No order is passed, so the predictions are supplied unsorted.
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), None);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(
            &labels,
            &predictions,
            Some(&weights),
            Some(Order::DESCENDING),
        );
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_samples(
            &labels,
            &predictions,
            &weights,
            &labels,
            &predictions,
            &weights,
        );
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(
            &labels,
            &predictions,
            Some(&weights),
            Some(Order::DESCENDING),
            None,
        );
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_loo_cossim_single() {
        let data = arr2(&[
            [0.77395605, 0.43887844, 0.85859792],
            [0.69736803, 0.09417735, 0.97562235],
        ]);
        let cossim = loo_cossim_single(&data.view());
        let expected = 0.95385941;
        assert_relative_eq!(cossim, expected, epsilon = COSSIM_TOLERANCE);
    }

    #[test]
    fn test_loo_cossim_many() {
        let data = arr3(&[
            [
                [0.77395605, 0.43887844, 0.85859792],
                [0.69736803, 0.09417735, 0.97562235],
            ],
            [
                [0.7611397, 0.78606431, 0.12811363],
                [0.45038594, 0.37079802, 0.92676499],
            ],
            [
                [0.64386512, 0.82276161, 0.4434142],
                [0.22723872, 0.55458479, 0.06381726],
            ],
            [
                [0.82763117, 0.6316644, 0.75808774],
                [0.35452597, 0.97069802, 0.89312112],
            ],
        ]);
        let cossim = loo_cossim_many(&data.view());
        let expected = arr1(&[0.95385941, 0.62417001, 0.92228589, 0.90025417]);
        assert_eq!(cossim.shape(), expected.shape());
        for (actual, wanted) in cossim.iter().zip(expected.iter()) {
            assert_relative_eq!(*actual, *wanted, epsilon = COSSIM_TOLERANCE);
        }
    }
}