1use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
2use num;
3use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescr,PyArrayDescrMethods,PyArrayDyn,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
4use pyo3::Bound;
5use pyo3::exceptions::PyTypeError;
6use pyo3::marker::Ungil;
7use pyo3::prelude::*;
8use std::iter::DoubleEndedIterator;
9use std::marker::PhantomData;
10use std::ops::AddAssign;
11
/// Sort order of the data handed to the scoring functions.
///
/// The scoring kernels consume samples from the highest prediction to the
/// lowest, so `ASCENDING` input is traversed back to front while
/// `DESCENDING` input is traversed front to back.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Input is sorted by increasing prediction.
    ASCENDING,
    /// Input is sorted by decreasing prediction.
    DESCENDING,
}
17
/// A weight source that holds one fixed `f64` value.
struct ConstWeight {
    value: f64,
}

impl ConstWeight {
    /// Creates a constant weight holding `value`.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// Creates the unit weight (`1.0`).
    fn one() -> Self {
        Self::new(1.0)
    }
}
30
/// Read access to a sequence of `T` values, both sequentially and by index.
pub trait Data<T: Clone> {
    /// Returns an iterator over owned copies of the values; it can be
    /// walked from either end.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns an owned copy of the value stored at `index`.
    fn get_at(&self, index: usize) -> T;
}
37
/// Data whose elements can be ranked.
pub trait SortableData<T> {
    /// Returns a permutation of `0..len` that sorts the data.
    ///
    /// The sort is unstable: the relative order of equal elements is
    /// unspecified. The implementations in this file rank from the largest
    /// value to the smallest.
    fn argsort_unstable(&self) -> Vec<usize>;
}
41
42impl Iterator for ConstWeight {
43 type Item = f64;
44 fn next(&mut self) -> Option<f64> {
45 return Some(self.value);
46 }
47}
48
49impl DoubleEndedIterator for ConstWeight {
50 fn next_back(&mut self) -> Option<f64> {
51 return Some(self.value);
52 }
53}
54
55impl Data<f64> for ConstWeight {
56 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
57 return ConstWeight::new(self.value);
58 }
59
60 fn get_at(&self, _index: usize) -> f64 {
61 return self.value.clone();
62 }
63}
64
65impl <T: Clone> Data<T> for Vec<T> {
66 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
67 return self.iter().cloned();
68 }
69 fn get_at(&self, index: usize) -> T {
70 return self[index].clone();
71 }
72}
73
74impl SortableData<f64> for Vec<f64> {
75 fn argsort_unstable(&self) -> Vec<usize> {
76 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
77 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
78 return indices;
80 }
81}
82
83impl <T: Clone> Data<T> for &[T] {
84 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
85 return self.iter().cloned();
86 }
87 fn get_at(&self, index: usize) -> T {
88 return self[index].clone();
89 }
90}
91
92impl SortableData<f64> for &[f64] {
93 fn argsort_unstable(&self) -> Vec<usize> {
94 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
96 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
99 return indices;
101 }
102}
103
104impl <T: Clone, const N: usize> Data<T> for [T; N] {
105 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
106 return self.iter().cloned();
107 }
108 fn get_at(&self, index: usize) -> T {
109 return self[index].clone();
110 }
111}
112
113impl <const N: usize> SortableData<f64> for [f64; N] {
114 fn argsort_unstable(&self) -> Vec<usize> {
115 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
116 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
117 return indices;
118 }
119}
120
121impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
122 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
123 return self.iter().cloned();
124 }
125 fn get_at(&self, index: usize) -> T {
126 return self[index].clone();
127 }
128}
129
130impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
131 fn argsort_unstable(&self) -> Vec<usize> {
132 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
133 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
134 return indices;
135 }
136}
137
/// A label type that can be read as a binary class.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label denotes the positive class.
    fn get_value(&self) -> bool;
}
149
150impl BinaryLabel for bool {
151 fn get_value(&self) -> bool {
152 return self.clone();
153 }
154}
155
156impl BinaryLabel for u8 {
157 fn get_value(&self) -> bool {
158 return (self & 1) == 1;
159 }
160}
161
162impl BinaryLabel for u16 {
163 fn get_value(&self) -> bool {
164 return (self & 1) == 1;
165 }
166}
167
168impl BinaryLabel for u32 {
169 fn get_value(&self) -> bool {
170 return (self & 1) == 1;
171 }
172}
173
174impl BinaryLabel for u64 {
175 fn get_value(&self) -> bool {
176 return (self & 1) == 1;
177 }
178}
179
180impl BinaryLabel for i8 {
181 fn get_value(&self) -> bool {
182 return (self & 1) == 1;
183 }
184}
185
186impl BinaryLabel for i16 {
187 fn get_value(&self) -> bool {
188 return (self & 1) == 1;
189 }
190}
191
192impl BinaryLabel for i32 {
193 fn get_value(&self) -> bool {
194 return (self & 1) == 1;
195 }
196}
197
198impl BinaryLabel for i64 {
199 fn get_value(&self) -> bool {
200 return (self & 1) == 1;
201 }
202}
203
204struct SortedSampleDescending<'a, B, L, P, W>
205where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
206{
207 labels: &'a L,
208 predictions: &'a P,
209 weights: &'a W,
210 label_type: PhantomData<B>,
211}
212
213impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
214where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
215{
216 fn new(labels: &'a L, predictions: &'a P, weights: &'a W) -> Self {
217 return SortedSampleDescending {
218 labels: labels,
219 predictions: predictions,
220 weights: weights,
221 label_type: PhantomData
222 }
223 }
224}
225
226fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
316where T: Copy, I: Data<T>
317{
318 let mut selection: Vec<T> = Vec::new();
319 selection.reserve_exact(indices.len());
320 for index in indices {
321 selection.push(slice.get_at(*index));
322 }
323 return selection;
324}
325
326pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
327where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
328{
329 return average_precision_with_order(labels, predictions, weights, None);
330}
331
332pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
333where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
334{
335 return match order {
336 Some(o) => average_precision_on_sorted_labels(labels, weights, o),
337 None => {
338 let indices = predictions.argsort_unstable();
339 let sorted_labels = select(labels, &indices);
340 let ap = match weights {
341 None => {
342 average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
344 },
345 Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
346 };
347 ap
348 }
349 };
350}
351
352pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
353where B: BinaryLabel, L: Data<B>, W: Data<f64>
354{
355 return match weights {
356 None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
357 Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
358 };
359}
360
361pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
362where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
363{
364 return match order {
365 Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
366 Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
367 };
368}
369
370pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
371 return average_precision_on_descending_iterators(labels.zip(weights));
372}
373
374pub fn average_precision_on_sorted_samples<'a, B, L, P, W>(l1: &'a L, p1: &'a P, w1: &'a W, l2: &'a L, p2: &'a P, w2: &'a W) -> f64
378where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
379{
380 let i1 = p1.into_iter().cloned().zip(l1.into_iter().cloned().zip(w1.into_iter().cloned()));
382 let i2 = p2.into_iter().cloned().zip(l2.into_iter().cloned().zip(w2.into_iter().cloned()));
383 let labels_and_weights = i1.zip(i2).map(|(t1, t2)| {
384 if t1.0 > t2.0 {
385 t1.1
386 } else {
387 t2.1
388 }
389 });
390 return average_precision_on_descending_iterators(labels_and_weights);
391}
392
393pub fn average_precision_on_descending_iterators<B: BinaryLabel>(labels_and_weights: impl Iterator<Item = (B, f64)>) -> f64 {
394 let mut ap: f64 = 0.0;
395 let mut tps: f64 = 0.0;
396 let mut fps: f64 = 0.0;
397 for (label, weight) in labels_and_weights {
398 let w: f64 = weight;
399 let l: bool = label.get_value();
400 let tp = w * f64::from(l);
401 tps += tp;
402 fps += weight - tp;
403 let ps = tps + fps;
404 let precision = tps / ps;
405 ap += tp * precision;
406 }
407 return if tps == 0.0 {
411 0.0
412 } else {
413 ap / tps
414 };
415}
416
417
418
419pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
421where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
422{
423 return roc_auc_with_order(labels, predictions, weights, None, None);
424}
425
426pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
427where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
428{
429 return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
430}
431
432pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
433where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
434{
435 return match order {
436 Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
437 None => {
438 let indices = predictions.argsort_unstable();
439 let sorted_labels = select(labels, &indices);
440 let sorted_predictions = select(predictions, &indices);
441 let roc_auc_score = match weights {
442 Some(w) => {
443 let sorted_weights = select(w, &indices);
444 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
445 },
446 None => {
447 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
448 }
449 };
450 roc_auc_score
451 }
452 };
453}
454pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
455where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
456 return match max_false_positive_rate {
457 None => match weights {
458 Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
459 None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
460 }
461 Some(max_fpr) => match weights {
462 Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
463 None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
464 }
465 };
466}
467
468pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
469 labels: &mut impl DoubleEndedIterator<Item = B>,
470 predictions: &mut impl DoubleEndedIterator<Item = f64>,
471 weights: &mut impl DoubleEndedIterator<Item = f64>,
472 order: Order
473) -> f64 {
474 return match order {
475 Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
476 Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
477 }
478}
479
480pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
481 labels: &mut impl Iterator<Item = B>,
482 predictions: &mut impl Iterator<Item = f64>,
483 weights: &mut impl Iterator<Item = f64>
484) -> f64 {
485 let mut false_positives: f64 = 0.0;
486 let mut true_positives: f64 = 0.0;
487 let mut last_counted_fp = 0.0;
488 let mut last_counted_tp = 0.0;
489 let mut area_under_curve = 0.0;
490 let mut zipped = labels.zip(predictions).zip(weights).peekable();
491 loop {
492 match zipped.next() {
493 None => break,
494 Some(actual) => {
495 let l = f64::from(actual.0.0.get_value());
496 let w = actual.1;
497 let wl = l * w;
498 true_positives += wl;
499 false_positives += w - wl;
500 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
501 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
502 last_counted_fp = false_positives;
503 last_counted_tp = true_positives;
504 }
505 }
506 };
507 }
508 return area_under_curve / (true_positives * false_positives);
509}
510
/// Area between the x-axis and the straight line from `(x0, y0)` to
/// `(x1, y1)`: a rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    let rectangle = width * y0;
    let triangle = rise * width * 0.5;
    rectangle + triangle
}
516
517fn get_positive_sum<B: BinaryLabel>(
518 labels: impl Iterator<Item = B>,
519 weights: impl Iterator<Item = f64>
520) -> (f64, f64) {
521 let mut false_positives = 0f64;
522 let mut true_positives = 0f64;
523 for (label, weight) in labels.zip(weights) {
524 let lw = weight * f64::from(label.get_value());
525 false_positives += weight - lw;
526 true_positives += lw;
527 }
528 return (false_positives, true_positives);
529}
530
531pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
532where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
533 let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
535 let mut l_it = labels.get_iterator();
536 let mut p_it = predictions.get_iterator();
537 let mut w_it = weights.get_iterator();
538 return match order {
539 Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
540 Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
541 };
542}
543
544
545fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
546 labels: &mut impl Iterator<Item = B>,
547 predictions: &mut impl Iterator<Item = f64>,
548 weights: &mut impl Iterator<Item = f64>,
549 false_positive_sum: f64,
550 true_positive_sum: f64,
551 max_false_positive_rate: f64
552) -> f64 {
553 let mut false_positives: f64 = 0.0;
554 let mut true_positives: f64 = 0.0;
555 let mut last_counted_fp = 0.0;
556 let mut last_counted_tp = 0.0;
557 let mut area_under_curve = 0.0;
558 let mut zipped = labels.zip(predictions).zip(weights).peekable();
559 let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
560 loop {
561 match zipped.next() {
562 None => break,
563 Some(actual) => {
564 let l = f64::from(actual.0.0.get_value());
565 let w = actual.1;
566 let wl = l * w;
567 let next_tp = true_positives + wl;
568 let next_fp = false_positives + (w - wl);
569 let is_above_max = next_fp > false_positive_cutoff;
570 if is_above_max {
571 let dx = next_fp - false_positives;
572 let dy = next_tp - true_positives;
573 true_positives += dy * false_positive_cutoff / dx;
574 false_positives = false_positive_cutoff;
575 } else {
576 true_positives = next_tp;
577 false_positives = next_fp;
578 }
579 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
580 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
581 last_counted_fp = false_positives;
582 last_counted_tp = true_positives;
583 }
584 if is_above_max {
585 break;
586 }
587 }
588 };
589 }
590 let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
591 let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
592 let max_area = max_false_positive_rate;
593 return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
594}
595
596pub fn loo_cossim<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>, replicate_sum: &mut ArrayViewMut1<'_, F>) -> F {
597 let num_replicates = mat.shape()[0];
598 let loo_weight = F::from(num_replicates - 1).unwrap();
599 let loo_weight_factor = F::from(1).unwrap() / loo_weight;
600 for mat_replicate in mat.outer_iter() {
601 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
602 *feature_sum += *feature;
603 }
604 }
605
606 let mut result = F::zero();
607
608 for mat_replicate in mat.outer_iter() {
609 let mut m_sqs = F::zero();
610 let mut l_sqs = F::zero();
611 let mut prod_sum = F::zero();
612 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
613 let m_f = *feature;
614 let l_f = (*feature_sum - *feature) * loo_weight_factor;
615 prod_sum += m_f * l_f;
616 m_sqs += m_f * m_f;
617 l_sqs += l_f * l_f;
618 }
619 result += prod_sum / (m_sqs * l_sqs).sqrt();
620 }
621
622 return result / F::from(num_replicates).unwrap();
623}
624
625pub fn loo_cossim_single<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>) -> F {
626 let mut replicate_sum = Array1::<F>::zeros(mat.shape()[1]);
627 return loo_cossim(mat, &mut replicate_sum.view_mut());
628}
629
630pub fn loo_cossim_many<F: num::Float + AddAssign>(mat: &ArrayView3<'_, F>) -> Array1<F> {
631 let mut cossims = Array1::<F>::zeros(mat.shape()[0]);
632 let mut replicate_sum = Array1::<F>::zeros(mat.shape()[2]);
633 for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
634 replicate_sum.fill(F::zero());
635 *c = loo_cossim(&m, &mut replicate_sum.view_mut());
636 }
637 return cossims;
638}
639
640
641#[pyclass(eq, eq_int, name="Order")]
643#[derive(Clone, Copy, PartialEq)]
644pub enum PyOrder {
645 ASCENDING,
646 DESCENDING
647}
648
649fn py_order_as_order(order: PyOrder) -> Order {
650 return match order {
651 PyOrder::ASCENDING => Order::ASCENDING,
652 PyOrder::DESCENDING => Order::DESCENDING,
653 }
654}
655
656
657trait PyScore: Ungil + Sync {
658
659 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
660 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
661
662 fn score_py_generic<'py, B>(
663 &self,
664 py: Python<'py>,
665 labels: &PyReadonlyArray1<'py, B>,
666 predictions: &PyReadonlyArray1<'py, f64>,
667 weights: &Option<PyReadonlyArray1<'py, f64>>,
668 order: &Option<PyOrder>,
669 ) -> f64
670 where B: BinaryLabel + Element
671 {
672 let labels = labels.as_array();
673 let predictions = predictions.as_array();
674 let order = order.map(py_order_as_order);
675 let score = match weights {
676 Some(weight) => {
677 let weights = weight.as_array();
678 py.allow_threads(move || {
679 self.score(&labels, &predictions, Some(&weights), order)
680 })
681 },
682 None => py.allow_threads(move || {
683 self.score(&labels, &predictions, None::<&Vec<f64>>, order)
684 })
685 };
686 return score;
687 }
688
689 fn score_py_match_run<'py, T>(
690 &self,
691 py: Python<'py>,
692 labels: &Bound<'py, PyUntypedArray>,
693 predictions: &PyReadonlyArray1<'py, f64>,
694 weights: &Option<PyReadonlyArray1<'py, f64>>,
695 order: &Option<PyOrder>,
696 dt: &Bound<'py, PyArrayDescr>
697 ) -> Option<f64>
698 where T: Element + BinaryLabel
699 {
700 return if dt.is_equiv_to(&dtype::<T>(py)) {
701 let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
702 Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
703 } else {
704 None
705 };
706 }
707
708 fn score_py<'py>(
709 &self,
710 py: Python<'py>,
711 labels: &Bound<'py, PyUntypedArray>,
712 predictions: PyReadonlyArray1<'py, f64>,
713 weights: Option<PyReadonlyArray1<'py, f64>>,
714 order: Option<PyOrder>,
715 ) -> PyResult<f64> {
716 if labels.ndim() != 1 {
717 return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
718 }
719 let label_dtype = labels.dtype();
720 if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
721 return Ok(score)
722 }
723 else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
724 return Ok(score)
725 }
726 else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
727 return Ok(score)
728 }
729 else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
730 return Ok(score)
731 }
732 else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
733 return Ok(score)
734 }
735 else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
736 return Ok(score)
737 }
738 else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
739 return Ok(score)
740 }
741 else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
742 return Ok(score)
743 }
744 else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
745 return Ok(score)
746 }
747 return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
748 }
749}
750
751struct PyAveragePrecision {
752
753}
754
755impl PyAveragePrecision{
756 fn new() -> Self {
757 return PyAveragePrecision {};
758 }
759}
760
761impl PyScore for PyAveragePrecision {
762 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
763 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
764 return average_precision_with_order(labels, predictions, weights, order);
765 }
766}
767
768struct PyRocAuc {
769 max_fpr: Option<f64>
770}
771
772impl PyRocAuc {
773 fn new(max_fpr: Option<f64>) -> Self {
774 return PyRocAuc { max_fpr: max_fpr };
775 }
776}
777
778impl PyScore for PyRocAuc {
779 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
780 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
781 return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
782 }
783}
784
785
786#[pyfunction(name = "average_precision")]
787#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
788pub fn average_precision_py<'py>(
789 py: Python<'py>,
790 labels: &Bound<'py, PyUntypedArray>,
791 predictions: PyReadonlyArray1<'py, f64>,
792 weights: Option<PyReadonlyArray1<'py, f64>>,
793 order: Option<PyOrder>
794) -> PyResult<f64> {
795 return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
796}
797
798#[pyfunction(name = "roc_auc")]
799#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
800pub fn roc_auc_py<'py>(
801 py: Python<'py>,
802 labels: &Bound<'py, PyUntypedArray>,
803 predictions: PyReadonlyArray1<'py, f64>,
804 weights: Option<PyReadonlyArray1<'py, f64>>,
805 order: Option<PyOrder>,
806 max_fpr: Option<f64>,
807) -> PyResult<f64> {
808 return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
809}
810
811#[pyfunction(name = "loo_cossim")]
812#[pyo3(signature = (data))]
813pub fn loo_cossim_py<'py>(
814 py: Python<'py>,
815 data: &Bound<'py, PyUntypedArray>
816) -> PyResult<f64> {
817 if data.ndim() != 2 {
818 return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
819 }
820
821 let dt = data.dtype();
822 if dt.is_equiv_to(&dtype::<f32>(py)) {
823 let typed_data = data.downcast::<PyArray2<f32>>().unwrap().readonly();
824 let array = typed_data.as_array();
825 let score = py.allow_threads(move || {
826 loo_cossim_single(&array)
827 });
828 return Ok(score as f64);
829 }
830 if dt.is_equiv_to(&dtype::<f64>(py)) {
831 let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
832 let array = typed_data.as_array();
833 let score = py.allow_threads(move || {
834 loo_cossim_single(&array)
835 });
836 return Ok(score);
837 }
838 return Err(PyTypeError::new_err(format!("Only float32 and float64 data supported, but found {}", dt)));
839}
840
841pub fn loo_cossim_many_generic_py<'py, F: num::Float + AddAssign + Element>(
842 py: Python<'py>,
843 data: &Bound<'py, PyArrayDyn<F>>
844) -> PyResult<Bound<'py, PyArray1<F>>> {
845 if data.ndim() != 3 {
846 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
847 }
848 let typed_data = data.downcast::<PyArray3<F>>().unwrap().readonly();
849 let array = typed_data.as_array();
850 let score = py.allow_threads(move || {
851 loo_cossim_many(&array)
852 });
853 let score_py = PyArray::from_owned_array(py, score);
855 return Ok(score_py);
856}
857
858#[pyfunction(name = "loo_cossim_many_f64")]
859#[pyo3(signature = (data))]
860pub fn loo_cossim_many_py_f64<'py>(
861 py: Python<'py>,
862 data: &Bound<'py, PyUntypedArray>
863) -> PyResult<Bound<'py, PyArray1<f64>>> {
864 if data.ndim() != 3 {
865 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
866 }
867
868 let dt = data.dtype();
869 if !dt.is_equiv_to(&dtype::<f64>(py)) {
870 return Err(PyTypeError::new_err(format!("Only float64 data supported, but found {}", dt)));
871 }
872 let typed_data = data.downcast::<PyArrayDyn<f64>>().unwrap();
873 return loo_cossim_many_generic_py(py, typed_data);
874}
875
876#[pyfunction(name = "loo_cossim_many_f32")]
877#[pyo3(signature = (data))]
878pub fn loo_cossim_many_py_f32<'py>(
879 py: Python<'py>,
880 data: &Bound<'py, PyUntypedArray>
881) -> PyResult<Bound<'py, PyArray1<f32>>> {
882 if data.ndim() != 3 {
883 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
884 }
885
886 let dt = data.dtype();
887 if !dt.is_equiv_to(&dtype::<f32>(py)) {
888 return Err(PyTypeError::new_err(format!("Only float32 data supported, but found {}", dt)));
889 }
890 let typed_data = data.downcast::<PyArrayDyn<f32>>().unwrap();
891 return loo_cossim_many_generic_py(py, typed_data);
892}
893
894#[pymodule(name = "_scors")]
895fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
896 m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
897 m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
898 m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
899 m.add_function(wrap_pyfunction!(loo_cossim_many_py_f64, m)?).unwrap();
900 m.add_function(wrap_pyfunction!(loo_cossim_many_py_f32, m)?).unwrap();
901 m.add_class::<PyOrder>().unwrap();
902 return Ok(());
903}
904
905
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    // The array constructors are not imported by the parent module, so
    // `use super::*` does not provide them.
    use ndarray::{arr1, arr2, arr3};
    use super::*;

    /// Absolute tolerance for comparisons against reference values that are
    /// only given to eight decimal places.
    const REFERENCE_TOLERANCE: f64 = 1e-7;

    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, Some(&weights), Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), None);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING));
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_samples(&labels, &predictions, &weights, &labels, &predictions, &weights);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING), None);
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_loo_cossim_single() {
        let data = arr2(&[[0.77395605, 0.43887844, 0.85859792],
                          [0.69736803, 0.09417735, 0.97562235]]);
        let cossim = loo_cossim_single(&data.view());
        let expected = 0.95385941;
        assert_relative_eq!(cossim, expected, epsilon = REFERENCE_TOLERANCE);
    }

    #[test]
    fn test_loo_cossim_many() {
        let data = arr3(&[[[0.77395605, 0.43887844, 0.85859792],
                           [0.69736803, 0.09417735, 0.97562235]],
                          [[0.7611397 , 0.78606431, 0.12811363],
                           [0.45038594, 0.37079802, 0.92676499]],
                          [[0.64386512, 0.82276161, 0.4434142 ],
                           [0.22723872, 0.55458479, 0.06381726]],
                          [[0.82763117, 0.6316644 , 0.75808774],
                           [0.35452597, 0.97069802, 0.89312112]]]);
        let cossim = loo_cossim_many(&data.view());
        let expected = arr1(&[0.95385941, 0.62417001, 0.92228589, 0.90025417]);
        assert_eq!(cossim.shape(), expected.shape());
        for (c, e) in cossim.iter().zip(expected.iter()) {
            assert_relative_eq!(*c, *e, epsilon = REFERENCE_TOLERANCE);
        }
    }
}