1use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
2use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescr,PyArrayDescrMethods,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
3use pyo3::Bound;
4use pyo3::exceptions::PyTypeError;
5use pyo3::marker::Ungil;
6use pyo3::prelude::*;
7use std::iter::DoubleEndedIterator;
8use std::marker::PhantomData;
9
/// Direction in which already-sorted input is ordered by prediction score.
///
/// The scoring routines walk samples from the highest to the lowest prediction;
/// input declared as `ASCENDING` is therefore iterated in reverse.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    /// Lowest prediction first.
    ASCENDING,
    /// Highest prediction first.
    DESCENDING,
}
15
/// A weight source that reports the same value for every sample.
struct ConstWeight {
    value: f64,
}

impl ConstWeight {
    /// Creates a constant weight source yielding `value`.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// Creates the unit weight source (every sample weighs `1.0`).
    fn one() -> Self {
        Self::new(1.0)
    }
}
28
/// Read access to a one-dimensional sequence of cloneable values.
pub trait Data<T>
where
    T: Clone,
{
    /// Returns an iterator over owned copies of the elements; it can be walked from both ends.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns an owned copy of the element at position `index`.
    fn get_at(&self, index: usize) -> T;
}
35
/// Data whose elements can be ranked.
pub trait SortableData<T> {
    /// Returns the element indices in sorted order, without any stability guarantee for ties.
    ///
    /// The implementations in this file rank from the largest to the smallest value.
    fn argsort_unstable(&self) -> Vec<usize>;
}
39
40impl Iterator for ConstWeight {
41 type Item = f64;
42 fn next(&mut self) -> Option<f64> {
43 return Some(self.value);
44 }
45}
46
47impl DoubleEndedIterator for ConstWeight {
48 fn next_back(&mut self) -> Option<f64> {
49 return Some(self.value);
50 }
51}
52
53impl Data<f64> for ConstWeight {
54 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
55 return ConstWeight::new(self.value);
56 }
57
58 fn get_at(&self, _index: usize) -> f64 {
59 return self.value.clone();
60 }
61}
62
63impl <T: Clone> Data<T> for Vec<T> {
64 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
65 return self.iter().cloned();
66 }
67 fn get_at(&self, index: usize) -> T {
68 return self[index].clone();
69 }
70}
71
72impl SortableData<f64> for Vec<f64> {
73 fn argsort_unstable(&self) -> Vec<usize> {
74 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
75 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
76 return indices;
78 }
79}
80
81impl <T: Clone> Data<T> for &[T] {
82 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
83 return self.iter().cloned();
84 }
85 fn get_at(&self, index: usize) -> T {
86 return self[index].clone();
87 }
88}
89
90impl SortableData<f64> for &[f64] {
91 fn argsort_unstable(&self) -> Vec<usize> {
92 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
94 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
97 return indices;
99 }
100}
101
102impl <T: Clone, const N: usize> Data<T> for [T; N] {
103 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
104 return self.iter().cloned();
105 }
106 fn get_at(&self, index: usize) -> T {
107 return self[index].clone();
108 }
109}
110
111impl <const N: usize> SortableData<f64> for [f64; N] {
112 fn argsort_unstable(&self) -> Vec<usize> {
113 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
114 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
115 return indices;
116 }
117}
118
119impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
120 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
121 return self.iter().cloned();
122 }
123 fn get_at(&self, index: usize) -> T {
124 return self[index].clone();
125 }
126}
127
128impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
129 fn argsort_unstable(&self) -> Vec<usize> {
130 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
131 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
132 return indices;
133 }
134}
135
/// A cheaply copyable value that can be read as a positive/negative class label.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label denotes the positive class.
    fn get_value(&self) -> bool;
}
147
148impl BinaryLabel for bool {
149 fn get_value(&self) -> bool {
150 return self.clone();
151 }
152}
153
154impl BinaryLabel for u8 {
155 fn get_value(&self) -> bool {
156 return (self & 1) == 1;
157 }
158}
159
160impl BinaryLabel for u16 {
161 fn get_value(&self) -> bool {
162 return (self & 1) == 1;
163 }
164}
165
166impl BinaryLabel for u32 {
167 fn get_value(&self) -> bool {
168 return (self & 1) == 1;
169 }
170}
171
172impl BinaryLabel for u64 {
173 fn get_value(&self) -> bool {
174 return (self & 1) == 1;
175 }
176}
177
178impl BinaryLabel for i8 {
179 fn get_value(&self) -> bool {
180 return (self & 1) == 1;
181 }
182}
183
184impl BinaryLabel for i16 {
185 fn get_value(&self) -> bool {
186 return (self & 1) == 1;
187 }
188}
189
190impl BinaryLabel for i32 {
191 fn get_value(&self) -> bool {
192 return (self & 1) == 1;
193 }
194}
195
196impl BinaryLabel for i64 {
197 fn get_value(&self) -> bool {
198 return (self & 1) == 1;
199 }
200}
201
202struct SortedSampleDescending<'a, B, L, P, W>
203where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
204{
205 labels: &'a L,
206 predictions: &'a P,
207 weights: &'a W,
208 label_type: PhantomData<B>,
209}
210
211impl <'a, B, L, P, W> SortedSampleDescending<'a, B, L, P, W>
212where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
213{
214 fn new(labels: &'a L, predictions: &'a P, weights: &'a W) -> Self {
215 return SortedSampleDescending {
216 labels: labels,
217 predictions: predictions,
218 weights: weights,
219 label_type: PhantomData
220 }
221 }
222}
223
224fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
314where T: Copy, I: Data<T>
315{
316 let mut selection: Vec<T> = Vec::new();
317 selection.reserve_exact(indices.len());
318 for index in indices {
319 selection.push(slice.get_at(*index));
320 }
321 return selection;
322}
323
324pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
325where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
326{
327 return average_precision_with_order(labels, predictions, weights, None);
328}
329
330pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
331where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
332{
333 return match order {
334 Some(o) => average_precision_on_sorted_labels(labels, weights, o),
335 None => {
336 let indices = predictions.argsort_unstable();
337 let sorted_labels = select(labels, &indices);
338 let ap = match weights {
339 None => {
340 average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
342 },
343 Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
344 };
345 ap
346 }
347 };
348}
349
350pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
351where B: BinaryLabel, L: Data<B>, W: Data<f64>
352{
353 return match weights {
354 None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
355 Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
356 };
357}
358
359pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
360where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
361{
362 return match order {
363 Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
364 Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
365 };
366}
367
368pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
369 return average_precision_on_descending_iterators(labels.zip(weights));
370}
371
372pub fn average_precision_on_sorted_samples<'a, B, L, P, W>(l1: &'a L, p1: &'a P, w1: &'a W, l2: &'a L, p2: &'a P, w2: &'a W) -> f64
376where B: BinaryLabel + Clone + 'a, &'a L: IntoIterator<Item = &'a B>, &'a P: IntoIterator<Item = &'a f64>, &'a W: IntoIterator<Item = &'a f64>
377{
378 let i1 = p1.into_iter().cloned().zip(l1.into_iter().cloned().zip(w1.into_iter().cloned()));
380 let i2 = p2.into_iter().cloned().zip(l2.into_iter().cloned().zip(w2.into_iter().cloned()));
381 let labels_and_weights = i1.zip(i2).map(|(t1, t2)| {
382 if t1.0 > t2.0 {
383 t1.1
384 } else {
385 t2.1
386 }
387 });
388 return average_precision_on_descending_iterators(labels_and_weights);
389}
390
391pub fn average_precision_on_descending_iterators<B: BinaryLabel>(labels_and_weights: impl Iterator<Item = (B, f64)>) -> f64 {
392 let mut ap: f64 = 0.0;
393 let mut tps: f64 = 0.0;
394 let mut fps: f64 = 0.0;
395 for (label, weight) in labels_and_weights {
396 let w: f64 = weight;
397 let l: bool = label.get_value();
398 let tp = w * f64::from(l);
399 tps += tp;
400 fps += weight - tp;
401 let ps = tps + fps;
402 let precision = tps / ps;
403 ap += tp * precision;
404 }
405 return if tps == 0.0 {
409 0.0
410 } else {
411 ap / tps
412 };
413}
414
415
416
417pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
419where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
420{
421 return roc_auc_with_order(labels, predictions, weights, None, None);
422}
423
424pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
425where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
426{
427 return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
428}
429
430pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
431where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
432{
433 return match order {
434 Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
435 None => {
436 let indices = predictions.argsort_unstable();
437 let sorted_labels = select(labels, &indices);
438 let sorted_predictions = select(predictions, &indices);
439 let roc_auc_score = match weights {
440 Some(w) => {
441 let sorted_weights = select(w, &indices);
442 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
443 },
444 None => {
445 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
446 }
447 };
448 roc_auc_score
449 }
450 };
451}
452pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
453where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
454 return match max_false_positive_rate {
455 None => match weights {
456 Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
457 None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
458 }
459 Some(max_fpr) => match weights {
460 Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
461 None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
462 }
463 };
464}
465
466pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
467 labels: &mut impl DoubleEndedIterator<Item = B>,
468 predictions: &mut impl DoubleEndedIterator<Item = f64>,
469 weights: &mut impl DoubleEndedIterator<Item = f64>,
470 order: Order
471) -> f64 {
472 return match order {
473 Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
474 Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
475 }
476}
477
478pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
479 labels: &mut impl Iterator<Item = B>,
480 predictions: &mut impl Iterator<Item = f64>,
481 weights: &mut impl Iterator<Item = f64>
482) -> f64 {
483 let mut false_positives: f64 = 0.0;
484 let mut true_positives: f64 = 0.0;
485 let mut last_counted_fp = 0.0;
486 let mut last_counted_tp = 0.0;
487 let mut area_under_curve = 0.0;
488 let mut zipped = labels.zip(predictions).zip(weights).peekable();
489 loop {
490 match zipped.next() {
491 None => break,
492 Some(actual) => {
493 let l = f64::from(actual.0.0.get_value());
494 let w = actual.1;
495 let wl = l * w;
496 true_positives += wl;
497 false_positives += w - wl;
498 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
499 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
500 last_counted_fp = false_positives;
501 last_counted_tp = true_positives;
502 }
503 }
504 };
505 }
506 return area_under_curve / (true_positives * false_positives);
507}
508
/// Area between the x-axis and the straight line from `(x0, y0)` to `(x1, y1)`.
///
/// Computed as the rectangle below `y0` plus the triangle formed by the rise.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    let rectangle = width * y0;
    let triangle = rise * width * 0.5;
    rectangle + triangle
}
514
515fn get_positive_sum<B: BinaryLabel>(
516 labels: impl Iterator<Item = B>,
517 weights: impl Iterator<Item = f64>
518) -> (f64, f64) {
519 let mut false_positives = 0f64;
520 let mut true_positives = 0f64;
521 for (label, weight) in labels.zip(weights) {
522 let lw = weight * f64::from(label.get_value());
523 false_positives += weight - lw;
524 true_positives += lw;
525 }
526 return (false_positives, true_positives);
527}
528
529pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
530where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
531 let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
533 let mut l_it = labels.get_iterator();
534 let mut p_it = predictions.get_iterator();
535 let mut w_it = weights.get_iterator();
536 return match order {
537 Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
538 Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
539 };
540}
541
542
543fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
544 labels: &mut impl Iterator<Item = B>,
545 predictions: &mut impl Iterator<Item = f64>,
546 weights: &mut impl Iterator<Item = f64>,
547 false_positive_sum: f64,
548 true_positive_sum: f64,
549 max_false_positive_rate: f64
550) -> f64 {
551 let mut false_positives: f64 = 0.0;
552 let mut true_positives: f64 = 0.0;
553 let mut last_counted_fp = 0.0;
554 let mut last_counted_tp = 0.0;
555 let mut area_under_curve = 0.0;
556 let mut zipped = labels.zip(predictions).zip(weights).peekable();
557 let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
558 loop {
559 match zipped.next() {
560 None => break,
561 Some(actual) => {
562 let l = f64::from(actual.0.0.get_value());
563 let w = actual.1;
564 let wl = l * w;
565 let next_tp = true_positives + wl;
566 let next_fp = false_positives + (w - wl);
567 let is_above_max = next_fp > false_positive_cutoff;
568 if is_above_max {
569 let dx = next_fp - false_positives;
570 let dy = next_tp - true_positives;
571 true_positives += dy * false_positive_cutoff / dx;
572 false_positives = false_positive_cutoff;
573 } else {
574 true_positives = next_tp;
575 false_positives = next_fp;
576 }
577 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
578 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
579 last_counted_fp = false_positives;
580 last_counted_tp = true_positives;
581 }
582 if is_above_max {
583 break;
584 }
585 }
586 };
587 }
588 let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
589 let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
590 let max_area = max_false_positive_rate;
591 return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
592}
593
594pub fn loo_cossim(mat: &ArrayView2<'_, f64>, replicate_sum: &mut ArrayViewMut1<'_, f64>) -> f64 {
595 let num_replicates = mat.shape()[0];
596 let loo_weight = num_replicates as f64 - 1.0;
597 let loo_weight_factor = 1.0 / loo_weight;
598 for mat_replicate in mat.outer_iter() {
599 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
600 *feature_sum += feature;
601 }
602 }
603
604 let mut result = 0f64;
605
606 for mat_replicate in mat.outer_iter() {
607 let mut m_sqs = 0.0;
608 let mut l_sqs = 0.0;
609 let mut prod_sum = 0.0;
610 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
611 let m_f = feature;
612 let l_f = (feature_sum - feature) * loo_weight_factor;
613 prod_sum += m_f * l_f;
614 m_sqs += m_f * m_f;
615 l_sqs += l_f * l_f;
616 }
617 result += prod_sum / (m_sqs * l_sqs).sqrt();
618 }
619
620 return result / num_replicates as f64;
621}
622
623pub fn loo_cossim_single(mat: &ArrayView2<'_, f64>) -> f64 {
624 let mut replicate_sum = Array1::<f64>::zeros(mat.shape()[1]);
625 return loo_cossim(mat, &mut replicate_sum.view_mut());
626}
627
628pub fn loo_cossim_many(mat: &ArrayView3<'_, f64>) -> Array1<f64> {
629 let mut cossims = Array1::<f64>::zeros(mat.shape()[0]);
630 let mut replicate_sum = Array1::<f64>::zeros(mat.shape()[2]);
631 for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
632 replicate_sum.fill(0.0);
633 *c = loo_cossim(&m, &mut replicate_sum.view_mut());
634 }
635 return cossims;
636}
637
638
639#[pyclass(eq, eq_int, name="Order")]
641#[derive(Clone, Copy, PartialEq)]
642pub enum PyOrder {
643 ASCENDING,
644 DESCENDING
645}
646
647fn py_order_as_order(order: PyOrder) -> Order {
648 return match order {
649 PyOrder::ASCENDING => Order::ASCENDING,
650 PyOrder::DESCENDING => Order::DESCENDING,
651 }
652}
653
654
655trait PyScore: Ungil + Sync {
656
657 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
658 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
659
660 fn score_py_generic<'py, B>(
661 &self,
662 py: Python<'py>,
663 labels: &PyReadonlyArray1<'py, B>,
664 predictions: &PyReadonlyArray1<'py, f64>,
665 weights: &Option<PyReadonlyArray1<'py, f64>>,
666 order: &Option<PyOrder>,
667 ) -> f64
668 where B: BinaryLabel + Element
669 {
670 let labels = labels.as_array();
671 let predictions = predictions.as_array();
672 let order = order.map(py_order_as_order);
673 let score = match weights {
674 Some(weight) => {
675 let weights = weight.as_array();
676 py.allow_threads(move || {
677 self.score(&labels, &predictions, Some(&weights), order)
678 })
679 },
680 None => py.allow_threads(move || {
681 self.score(&labels, &predictions, None::<&Vec<f64>>, order)
682 })
683 };
684 return score;
685 }
686
687 fn score_py_match_run<'py, T>(
688 &self,
689 py: Python<'py>,
690 labels: &Bound<'py, PyUntypedArray>,
691 predictions: &PyReadonlyArray1<'py, f64>,
692 weights: &Option<PyReadonlyArray1<'py, f64>>,
693 order: &Option<PyOrder>,
694 dt: &Bound<'py, PyArrayDescr>
695 ) -> Option<f64>
696 where T: Element + BinaryLabel
697 {
698 return if dt.is_equiv_to(&dtype::<T>(py)) {
699 let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
700 Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
701 } else {
702 None
703 };
704 }
705
706 fn score_py<'py>(
707 &self,
708 py: Python<'py>,
709 labels: &Bound<'py, PyUntypedArray>,
710 predictions: PyReadonlyArray1<'py, f64>,
711 weights: Option<PyReadonlyArray1<'py, f64>>,
712 order: Option<PyOrder>,
713 ) -> PyResult<f64> {
714 if labels.ndim() != 1 {
715 return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
716 }
717 let label_dtype = labels.dtype();
718 if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
719 return Ok(score)
720 }
721 else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
722 return Ok(score)
723 }
724 else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
725 return Ok(score)
726 }
727 else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
728 return Ok(score)
729 }
730 else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
731 return Ok(score)
732 }
733 else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
734 return Ok(score)
735 }
736 else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
737 return Ok(score)
738 }
739 else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
740 return Ok(score)
741 }
742 else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
743 return Ok(score)
744 }
745 return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
746 }
747}
748
/// Stateless scorer that computes average precision.
struct PyAveragePrecision {}

impl PyAveragePrecision {
    /// Creates the scorer; it carries no configuration.
    fn new() -> Self {
        Self {}
    }
}
758
759impl PyScore for PyAveragePrecision {
760 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
761 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
762 return average_precision_with_order(labels, predictions, weights, order);
763 }
764}
765
/// Scorer that computes ROC AUC, optionally limited to a maximum false positive rate.
struct PyRocAuc {
    max_fpr: Option<f64>,
}

impl PyRocAuc {
    /// Creates the scorer; `None` means the full curve is integrated.
    fn new(max_fpr: Option<f64>) -> Self {
        Self { max_fpr }
    }
}
775
776impl PyScore for PyRocAuc {
777 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
778 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
779 return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
780 }
781}
782
783
784#[pyfunction(name = "average_precision")]
785#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
786pub fn average_precision_py<'py>(
787 py: Python<'py>,
788 labels: &Bound<'py, PyUntypedArray>,
789 predictions: PyReadonlyArray1<'py, f64>,
790 weights: Option<PyReadonlyArray1<'py, f64>>,
791 order: Option<PyOrder>
792) -> PyResult<f64> {
793 return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
794}
795
796#[pyfunction(name = "roc_auc")]
797#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
798pub fn roc_auc_py<'py>(
799 py: Python<'py>,
800 labels: &Bound<'py, PyUntypedArray>,
801 predictions: PyReadonlyArray1<'py, f64>,
802 weights: Option<PyReadonlyArray1<'py, f64>>,
803 order: Option<PyOrder>,
804 max_fpr: Option<f64>,
805) -> PyResult<f64> {
806 return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
807}
808
809#[pyfunction(name = "loo_cossim")]
810#[pyo3(signature = (data))]
811pub fn loo_cossim_py<'py>(
812 py: Python<'py>,
813 data: &Bound<'py, PyUntypedArray>
814) -> PyResult<f64> {
815 if data.ndim() != 2 {
816 return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
817 }
818
819 let dt = data.dtype();
820 if !dt.is_equiv_to(&dtype::<f64>(py)) {
821 return Err(PyTypeError::new_err(format!("Currently only float64 data supported, but found {}", dt)));
822 }
823 let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
824 let array = typed_data.as_array();
825 let score = py.allow_threads(move || {
826 loo_cossim_single(&array)
827 });
828 return Ok(score);
829}
830
831#[pyfunction(name = "loo_cossim_many")]
832#[pyo3(signature = (data))]
833pub fn loo_cossim_many_py<'py>(
834 py: Python<'py>,
835 data: &Bound<'py, PyUntypedArray>
836) -> PyResult<Bound<'py, PyArray1<f64>>> {
837 if data.ndim() != 3 {
838 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
839 }
840
841 let dt = data.dtype();
842 if !dt.is_equiv_to(&dtype::<f64>(py)) {
843 return Err(PyTypeError::new_err(format!("Currently only float64 data supported, but found {}", dt)));
844 }
845 let typed_data = data.downcast::<PyArray3<f64>>().unwrap().readonly();
846 let array = typed_data.as_array();
847 let score = py.allow_threads(move || {
848 loo_cossim_many(&array)
849 });
850 let score_py = PyArray::from_owned_array(py, score);
851 return Ok(score_py);
852}
853
854#[pymodule(name = "_scors")]
855fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
856 m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
857 m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
858 m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
859 m.add_function(wrap_pyfunction!(loo_cossim_many_py, m)?).unwrap();
860 m.add_class::<PyOrder>().unwrap();
861 return Ok(());
862}
863
864
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    // The array constructors are only needed by the tests and are not imported by the parent module.
    use ndarray::{arr1, arr2, arr3};

    /// Relative tolerance for the cosine similarity tests: the reference values are
    /// only written down to 8 significant digits.
    const LOO_COSSIM_MAX_RELATIVE: f64 = 1e-7;

    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, Some(&weights), Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), None);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING));
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_samples(&labels, &predictions, &weights, &labels, &predictions, &weights);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, Some(&weights), Some(Order::DESCENDING), None);
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_loo_cossim_single() {
        let data = arr2(&[[0.77395605, 0.43887844, 0.85859792],
                          [0.69736803, 0.09417735, 0.97562235]]);
        let cossim = loo_cossim_single(&data.view());
        let expected = 0.95385941;
        assert_relative_eq!(cossim, expected, max_relative = LOO_COSSIM_MAX_RELATIVE);
    }

    #[test]
    fn test_loo_cossim_many() {
        let data = arr3(&[[[0.77395605, 0.43887844, 0.85859792],
                           [0.69736803, 0.09417735, 0.97562235]],
                          [[0.7611397 , 0.78606431, 0.12811363],
                           [0.45038594, 0.37079802, 0.92676499]],
                          [[0.64386512, 0.82276161, 0.4434142 ],
                           [0.22723872, 0.55458479, 0.06381726]],
                          [[0.82763117, 0.6316644 , 0.75808774],
                           [0.35452597, 0.97069802, 0.89312112]]]);
        let cossim = loo_cossim_many(&data.view());
        let expected = arr1(&[0.95385941, 0.62417001, 0.92228589, 0.90025417]);
        assert_eq!(cossim.shape(), expected.shape());
        for (c, e) in cossim.iter().zip(expected.iter()) {
            assert_relative_eq!(*c, *e, max_relative = LOO_COSSIM_MAX_RELATIVE);
        }
    }
}