1#![feature(trait_alias)]
2
3mod combine;
4
5use ndarray::{Array1,ArrayView,ArrayView2,ArrayView3,ArrayViewMut1,Ix1};
6use num;
7use num::traits::float::TotalOrder;
8use numpy::{Element,PyArray,PyArray1,PyArray2,PyArray3,PyArrayDescrMethods,PyArrayDyn,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
9use pyo3::Bound;
10use pyo3::exceptions::PyTypeError;
11use pyo3::marker::Ungil;
12use pyo3::prelude::*;
13use std::cmp::PartialOrd;
14use std::iter::{DoubleEndedIterator,repeat};
15use std::ops::AddAssign;
16
/// Direction in which a sample is already sorted by prediction.
///
/// Scorers consume samples in descending prediction order; `ASCENDING` input is
/// walked from the back to obtain that order.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    ASCENDING,
    DESCENDING
}
22
23#[derive(Clone, Copy)]
24struct ConstWeight<F: num::Float> {
25 value: F
26}
27
28impl <F: num::Float> ConstWeight<F> {
29 fn new(value: F) -> Self {
30 return ConstWeight { value: value };
31 }
32 fn one() -> Self {
33 return Self::new(F::one());
34 }
35}
36
/// Random-access sequence of `T` values that can also be streamed.
pub trait Data<T>
where
    T: Clone,
{
    /// Returns a cloneable iterator over all values, front to back.
    ///
    /// The iterator is double-ended so callers can walk the data in reverse.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone;

    /// Returns (a clone of) the value at position `index`.
    fn get_at(&self, index: usize) -> T;
}
43
/// Inputs that can produce a permutation ordering their values.
///
/// NOTE: despite the generic name, every implementation in this file orders
/// indices by *descending* value (the comparator compares `self[k]` against
/// `self[i]`), which is the order the scorers consume.
pub trait SortableData<T> {
    /// Returns `0..len` permuted so the referenced values are in descending order.
    /// Ties end up in an unspecified order (unstable sort).
    fn argsort_unstable(&self) -> Vec<usize>;
}
47
48impl <F: num::Float> Iterator for ConstWeight<F> {
49 type Item = F;
50 fn next(&mut self) -> Option<F> {
51 return Some(self.value);
52 }
53}
54
55impl <F: num::Float> DoubleEndedIterator for ConstWeight<F> {
56 fn next_back(&mut self) -> Option<F> {
57 return Some(self.value);
58 }
59}
60
61impl <F: num::Float> Data<F> for ConstWeight<F> {
62 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = F> + Clone {
63 return ConstWeight::new(self.value);
64 }
65
66 fn get_at(&self, _index: usize) -> F {
67 return self.value.clone();
68 }
69}
70
71impl <T: Clone> Data<T> for Vec<T> {
72 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
73 return self.iter().cloned();
74 }
75 fn get_at(&self, index: usize) -> T {
76 return self[index].clone();
77 }
78}
79
80impl SortableData<f64> for Vec<f64> {
81 fn argsort_unstable(&self) -> Vec<usize> {
82 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
83 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
84 return indices;
86 }
87}
88
89impl <T: Clone> Data<T> for &[T] {
90 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
91 return self.iter().cloned();
92 }
93 fn get_at(&self, index: usize) -> T {
94 return self[index].clone();
95 }
96}
97
98impl SortableData<f64> for &[f64] {
99 fn argsort_unstable(&self) -> Vec<usize> {
100 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
101 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
102 return indices;
103 }
104}
105
106impl <T: Clone, const N: usize> Data<T> for [T; N] {
107 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
108 return self.iter().cloned();
109 }
110 fn get_at(&self, index: usize) -> T {
111 return self[index].clone();
112 }
113}
114
115impl <const N: usize> SortableData<f64> for [f64; N] {
116 fn argsort_unstable(&self) -> Vec<usize> {
117 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
118 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
119 return indices;
120 }
121}
122
123impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
124 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> + Clone {
125 return self.iter().cloned();
126 }
127 fn get_at(&self, index: usize) -> T {
128 return self[index].clone();
129 }
130}
131
132impl <F> SortableData<F> for ArrayView<'_, F, Ix1>
133where F: num::Float + TotalOrder
134{
135 fn argsort_unstable(&self) -> Vec<usize> {
136 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
137 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
138 return indices;
139 }
140}
141
/// A label type that can be interpreted as a binary positive/negative flag.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` for a positive label.
    fn get_value(&self) -> bool;
}

impl BinaryLabel for bool {
    fn get_value(&self) -> bool {
        *self
    }
}

/// Integer labels are positive when their lowest bit is set (i.e. they are odd).
///
/// NOTE(review): with this rule a `{-1, 1}` label encoding maps *both* classes to
/// positive (-1 is odd); callers presumably must pass `{0, 1}` labels — confirm.
macro_rules! impl_binary_label_via_low_bit {
    ($($int:ty),* $(,)?) => {
        $(
            impl BinaryLabel for $int {
                fn get_value(&self) -> bool {
                    (*self & 1) == 1
                }
            }
        )*
    };
}

impl_binary_label_via_low_bit!(u8, u16, u32, u64, i8, i16, i32, i64);
199
200fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
201where T: Copy, I: Data<T>
202{
203 let mut selection: Vec<T> = Vec::new();
204 selection.reserve_exact(indices.len());
205 for index in indices {
206 selection.push(slice.get_at(*index));
207 }
208 return selection;
209}
210
/// Floating-point type scores are accumulated in: supports `+=` and conversion from
/// `bool` (0/1) and from `f32` constants (the scorers use e.g. `0.5f32` and `f32::NAN`).
pub trait ScoreAccumulator = num::Float + AddAssign + From<bool> + From<f32>;
/// A floating-point value (prediction or weight) convertible into the accumulator `S`.
pub trait IntoScore<S: ScoreAccumulator> = Into<S> + num::Float;
213
214
215
216pub trait ScoreSortedDescending {
217 fn _score<S: ScoreAccumulator>(&self, labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S;
218 fn score<S, P, B, W>(&self, labels_with_weights: impl Iterator<Item = (P, (B, W))> + Clone) -> S
219 where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S>
220 {
221 return self._score(
222 labels_with_weights.map(|(p, (b, w))| -> (S, (bool, S)) { (p.into(), (b.get_value(), w.into()))})
223 )
224 }
225}
226
227
228pub fn score_sorted_iterators<S, SA, P, B, W>(
229 score: S,
230 predictions: impl Iterator<Item = P> + Clone,
231 labels: impl Iterator<Item = B> + Clone,
232 weights: impl Iterator<Item = W> + Clone,
233) -> SA
234where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> {
235 let zipped = predictions.zip(labels.zip(weights));
236 return score.score(zipped);
237}
238
239
240pub fn score_sorted_sample<S, SA, P, B, W>(
241 score: S,
242 predictions: &impl Data<P>,
243 labels: &impl Data<B>,
244 weights: &impl Data<W>,
245 order: Order,
246) -> SA
247where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone {
248 let p = predictions.get_iterator();
249 let l = labels.get_iterator();
250 let w = weights.get_iterator();
251 return match order {
252 Order::ASCENDING => score_sorted_iterators(score, p.rev(), l.rev(), w.rev()),
253 Order::DESCENDING => score_sorted_iterators(score, p, l, w),
254 };
255}
256
257
258pub fn score_maybe_sorted_sample<S, SA, P, B, W>(
259 score: S,
260 predictions: &(impl Data<P> + SortableData<P>),
261 labels: &impl Data<B>,
262 weights: Option<&impl Data<W>>,
263 order: Option<Order>,
264) -> SA
265where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone
266{
267 return match order {
268 Some(o) => {
269 match weights {
270 Some(w) => score_sorted_sample(score, predictions, labels, w, o),
271 None => score_sorted_sample(score, predictions, labels, &ConstWeight::<W>::one(), o),
272 }
273 }
274 None => {
275 let indices = predictions.argsort_unstable();
276 let sorted_labels = select(labels, &indices);
277 let sorted_predictions = select(predictions, &indices);
278 match weights {
279 Some(w) => {
280 let sorted_weights = select(w, &indices);
281 score_sorted_sample(score, &sorted_predictions, &sorted_labels, &sorted_weights, Order::DESCENDING)
282 }
283 None => score_sorted_sample(score, &sorted_predictions, &sorted_labels, &ConstWeight::<W>::one(), Order::DESCENDING)
284 }
285 }
286 };
287}
288
289
290pub fn score_sample<S, SA, P, B, W>(
291 score: S,
292 predictions: &(impl Data<P> + SortableData<P>),
293 labels: &impl Data<B>,
294 weights: Option<&impl Data<W>>,
295) -> SA
296
297where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel, W: IntoScore<SA> + Clone {
298 return score_maybe_sorted_sample(score, predictions, labels, weights, None);
299}
300
301
302pub fn score_two_sorted_samples<S, SA, P, B, W>(
303 score: S,
304 predictions1: impl Iterator<Item = P> + Clone,
305 label1: impl Iterator<Item = B> + Clone,
306 weight1: impl Iterator<Item = W> + Clone,
307 predictions2: impl Iterator<Item = P> + Clone,
308 label2: impl Iterator<Item = B> + Clone,
309 weight2: impl Iterator<Item = W> + Clone,
310) -> SA
311where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel + PartialOrd, W: IntoScore<SA>
312{
313 return score_two_sorted_samples_zipped(
314 score,
315 predictions1.zip(label1.zip(weight1)),
316 predictions2.zip(label2.zip(weight2)),
317 );
318}
319
320
321pub fn score_two_sorted_samples_zipped<S, SA, P, B, W>(
322 score: S,
323 iter1: impl Iterator<Item = (P, (B, W))> + Clone,
324 iter2: impl Iterator<Item = (P, (B, W))> + Clone,
325) -> SA
326where S: ScoreSortedDescending, SA: ScoreAccumulator, P: IntoScore<SA>, B: BinaryLabel + PartialOrd, W: IntoScore<SA>
327{
328 let combined_iter = combine::combine::CombineIterDescending::new(iter1, iter2);
329 return score.score(combined_iter);
330}
331
332
/// Average precision: recall-weighted mean of precision over distinct thresholds.
struct AveragePrecision {

}

impl AveragePrecision {
    /// Creates the (stateless) metric.
    fn new() -> Self {
        Self {}
    }
}
343
344
345#[derive(Clone,Copy,Debug)]
346struct Positives<P>
347where P: num::Float + From<bool> + AddAssign
348{
349 tps: P,
350 fps: P,
351}
352
353impl <P> Positives<P>
354where P: num::Float + From<bool> + AddAssign
355{
356 fn new(tps: P, fps: P) -> Self {
357 return Positives { tps, fps };
358 }
359
360 fn zero() -> Self {
361 return Positives::new(P::zero(), P::zero());
362 }
363
364 fn add(&mut self, label: bool, weight: P) {
365 let label: P = label.into();
366 let tp = weight * label;
367 let fp = weight - tp; self.tps += tp;
369 self.fps += fp;
370 }
371
372 fn positives_sum(&self) -> P {
373 return self.tps + self.fps;
374 }
375
376 fn precision(&self) -> P {
377 return self.tps / self.positives_sum();
378 }
379}
380
381
382impl ScoreSortedDescending for AveragePrecision {
383 fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
384 {
385 let mut positives: Positives<S> = Positives::zero();
386 let mut last_p: S = f32::NAN.into();
387 let mut last_tps: S = S::zero();
388 let mut ap: S = S::zero();
389
390 match labels_with_weights.next() {
392 None => (), Some((p, (label, w))) => {
394 positives.add(label, w);
395 last_p = p;
396 }
397 }
398
399 for (p, (label, w)) in labels_with_weights {
400 if last_p != p {
401 ap += (positives.tps - last_tps) * positives.precision();
402 last_p = p;
403 last_tps = positives.tps;
404 }
405 positives.add(label.get_value(), w.into());
406 }
407
408 ap += (positives.tps - last_tps) * positives.precision();
409
410 return if positives.tps == S::zero() {
414 S::zero()
415 } else {
416 ap / positives.tps
417 };
418 }
419}
420
421
/// Area under the full ROC curve.
struct RocAuc {

}

impl RocAuc {
    /// Creates the (stateless) metric.
    fn new() -> Self {
        Self {}
    }
}
432
433
434impl ScoreSortedDescending for RocAuc {
435 fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
436 {
437 let mut positives: Positives<S> = Positives::zero();
438 let mut last_p: S = f32::NAN.into();
439 let mut last_counted_fp = S::zero();
440 let mut last_counted_tp = S::zero();
441 let mut area_under_curve = S::zero();
442
443 match labels_with_weights.next() {
445 None => (), Some((p, (label, w))) => {
447 positives.add(label, w);
448 last_p = p;
449 }
450 }
451
452 for (p, (label, w)) in labels_with_weights {
453 if last_p != p {
454 area_under_curve += area_under_line_segment(
455 last_counted_fp,
456 positives.fps,
457 last_counted_tp,
458 positives.tps,
459 );
460 last_counted_fp = positives.fps;
461 last_counted_tp = positives.tps;
462 last_p = p;
463 }
464 positives.add(label, w);
465 }
466 area_under_curve += area_under_line_segment(
467 last_counted_fp,
468 positives.fps,
469 last_counted_tp,
470 positives.tps,
471 );
472 return area_under_curve / (positives.tps * positives.fps);
473 }
474}
475
476
477struct RocAucWithMaxFPR {
478 max_fpr: f32,
479}
480
481
482impl RocAucWithMaxFPR {
483 fn new(max_fpr: f32) -> Self {
484 return RocAucWithMaxFPR { max_fpr };
485 }
486
487 fn get_positive_sum<B, W>(labels_with_weights: impl Iterator<Item = (B, W)>) -> Positives<W>
488 where B: BinaryLabel, W: num::Float + From::<bool> + AddAssign
489 {
490 let mut positives: Positives<W> = Positives::zero();
491 for (label, weight) in labels_with_weights {
492 positives.add(label.get_value(), weight);
493 }
494 return positives;
495 }
496}
497
498
499impl ScoreSortedDescending for RocAucWithMaxFPR {
500 fn _score<S: ScoreAccumulator>(&self, mut labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
501 {
502 let total_positives = Self::get_positive_sum(labels_with_weights.clone().map(|(_a, b)| b));
503 let max_fpr: S = self.max_fpr.into();
504 let false_positive_cutoff = max_fpr * total_positives.fps;
505
506 let mut positives: Positives<S> = Positives::zero();
507 let mut last_p: S = f32::NAN.into();
508 let mut last_counted_fp = S::zero();
509 let mut last_counted_tp = S::zero();
510 let mut area_under_curve = S::zero();
511
512 match labels_with_weights.next() {
514 None => (), Some((p, (label, w))) => {
516 positives.add(label, w);
517 last_p = p;
518 }
519 }
520
521 for (p, (label, w)) in labels_with_weights {
522 if last_p != p {
523 area_under_curve += area_under_line_segment(
524 last_counted_fp,
525 positives.fps,
526 last_counted_tp,
527 positives.tps,
528 );
529 last_counted_fp = positives.fps;
530 last_counted_tp = positives.tps;
531 last_p = p;
532 }
533 let mut next_pos = positives.clone();
534 next_pos.add(label, w);
535 if next_pos.fps > false_positive_cutoff {
536 let dx = next_pos.fps - positives.fps;
537 let dy = next_pos.tps - positives.tps;
538 positives = Positives::new(
539 positives.tps + dy * false_positive_cutoff / dx,
540 false_positive_cutoff,
541 );
542 break;
543 }
544 else {
545 positives = next_pos;
546 }
547 }
548
549 area_under_curve += area_under_line_segment(
550 last_counted_fp,
551 positives.fps,
552 last_counted_tp,
553 positives.tps,
554 );
555
556 let normalized_area_under_curve = area_under_curve / (total_positives.tps * total_positives.fps);
557 let one_half: S = 0.5f32.into();
558 let min_area = one_half * max_fpr * max_fpr;
559 let max_area = max_fpr;
560 return one_half * (S::one() + (normalized_area_under_curve - min_area) / (max_area - min_area));
561 }
562}
563
564
/// ROC AUC that is partial (`RocAucWithMaxFPR`) when `max_fpr` is set and full
/// (`RocAuc`) otherwise.
struct RocAucWithOptionalMaxFPR {
    max_fpr: Option<f32>,
}

impl RocAucWithOptionalMaxFPR {
    /// Creates the metric; `None` selects the full-range ROC AUC.
    fn new(max_fpr: Option<f32>) -> Self {
        Self { max_fpr }
    }
}
577
578
579impl ScoreSortedDescending for RocAucWithOptionalMaxFPR {
580 fn _score<S: ScoreAccumulator>(&self, labels_with_weights: impl Iterator<Item = (S, (bool, S))> + Clone) -> S
581 {
582 return match self.max_fpr {
583 Some(mfpr) => RocAucWithMaxFPR::new(mfpr).score(labels_with_weights),
584 None => RocAuc::new().score(labels_with_weights),
585 }
586 }
587}
588
589
590pub fn average_precision<S, P, B, W>(
591 predictions: &(impl Data<P> + SortableData<P>),
592 labels: &impl Data<B>,
593 weights: Option<&impl Data<W>>,
594 order: Option<Order>,
595) -> S
596where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S> + Clone
597{
598 return score_maybe_sorted_sample(AveragePrecision::new(), predictions, labels, weights, order);
599}
600
601
602pub fn roc_auc<S, P, B, W>(
603 predictions: &(impl Data<P> + SortableData<P>),
604 labels: &impl Data<B>,
605 weights: Option<&impl Data<W>>,
606 order: Option<Order>,
607 max_fpr: Option<f32>,
608) -> S
609where S: ScoreAccumulator, P: IntoScore<S>, B: BinaryLabel, W: IntoScore<S> + Clone
610{
611 return score_maybe_sorted_sample(RocAucWithOptionalMaxFPR::new(max_fpr), predictions, labels, weights, order);
612}
613
614
615fn area_under_line_segment<P>(x0: P, x1: P, y0: P, y1: P) -> P
616where P: num::Float + From<f32>
617{
618 let dx = x1 - x0;
619 let dy = y1 - y0;
620 let one_half: P = 0.5f32.into();
621 return dx * y0 + dy * dx * one_half;
622}
623
624
625pub fn loo_cossim<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>, replicate_sum: &mut ArrayViewMut1<'_, F>) -> F {
626 let num_replicates = mat.shape()[0];
627 let loo_weight = F::from(num_replicates - 1).unwrap();
628 let loo_weight_factor = F::from(1).unwrap() / loo_weight;
629 for mat_replicate in mat.outer_iter() {
630 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter_mut()) {
631 *feature_sum += *feature;
632 }
633 }
634
635 let mut result = F::zero();
636
637 for mat_replicate in mat.outer_iter() {
638 let mut m_sqs = F::zero();
639 let mut l_sqs = F::zero();
640 let mut prod_sum = F::zero();
641 for (feature, feature_sum) in mat_replicate.iter().zip(replicate_sum.iter()) {
642 let m_f = *feature;
643 let l_f = (*feature_sum - *feature) * loo_weight_factor;
644 prod_sum += m_f * l_f;
645 m_sqs += m_f * m_f;
646 l_sqs += l_f * l_f;
647 }
648 result += prod_sum / (m_sqs * l_sqs).sqrt();
649 }
650
651 return result / F::from(num_replicates).unwrap();
652}
653
654
655pub fn loo_cossim_single<F: num::Float + AddAssign>(mat: &ArrayView2<'_, F>) -> F {
656 let mut replicate_sum = Array1::<F>::zeros(mat.shape()[1]);
657 return loo_cossim(mat, &mut replicate_sum.view_mut());
658}
659
660
661pub fn loo_cossim_many<F: num::Float + AddAssign>(mat: &ArrayView3<'_, F>) -> Array1<F> {
662 let mut cossims = Array1::<F>::zeros(mat.shape()[0]);
663 let mut replicate_sum = Array1::<F>::zeros(mat.shape()[2]);
664 for (m, c) in mat.outer_iter().zip(cossims.iter_mut()) {
665 replicate_sum.fill(F::zero());
666 *c = loo_cossim(&m, &mut replicate_sum.view_mut());
667 }
668 return cossims;
669}
670
671
/// Python-visible sort-order enum, exposed to Python under the name `Order`;
/// converted to the internal [`Order`] by `py_order_as_order`.
#[pyclass(eq, eq_int, name="Order")]
#[derive(Clone, Copy, PartialEq)]
pub enum PyOrder {
    ASCENDING,
    DESCENDING
}
679
680fn py_order_as_order(order: PyOrder) -> Order {
681 return match order {
682 PyOrder::ASCENDING => Order::ASCENDING,
683 PyOrder::DESCENDING => Order::DESCENDING,
684 }
685}
686
687trait PyScoreGeneric<S: ScoreSortedDescending>: Ungil + Sync {
688
689 fn get_score(&self) -> S;
690
691 fn score_py<'py, P, B, W>(
692 &self,
693 py: Python<'py>,
694 labels: PyReadonlyArray1<'py, B>,
695 predictions: PyReadonlyArray1<'py, P>,
696 weights: Option<PyReadonlyArray1<'py, W>>,
697 order: Option<PyOrder>,
698 ) -> P
699 where P: ScoreAccumulator + Element + TotalOrder, B: BinaryLabel + Element, W: IntoScore<P> + Element
700 {
701 let labels = labels.as_array();
702 let predictions = predictions.as_array();
703 let order = order.map(py_order_as_order);
704 let score = match weights {
705 Some(weight) => {
706 let w = weight.as_array();
707 py.allow_threads(move || {
708 score_maybe_sorted_sample(self.get_score(), &predictions, &labels, Some(&w), order)
709 })
710 },
711 None => py.allow_threads(move || {
712 score_maybe_sorted_sample(self.get_score(), &predictions, &labels, None::<&Vec<W>>, order)
713 })
714 };
715 return score;
716 }
717
718 fn score_two_sorted_samples_py_generic<'py, B, F, W, B1, B2, F1, F2, W1, W2>(
719 &self,
720 py: Python<'py>,
721 labels1: PyReadonlyArray1<'py, B1>,
722 predictions1: PyReadonlyArray1<'py, F1>,
723 weights1: Option<PyReadonlyArray1<'py, W1>>,
724 labels2: PyReadonlyArray1<'py, B1>,
725 predictions2: PyReadonlyArray1<'py, F2>,
726 weights2: Option<PyReadonlyArray1<'py, W2>>,
727 ) -> F
728 where B: BinaryLabel + PartialOrd, F: ScoreAccumulator + TotalOrder + Ungil, W: IntoScore<F>, B1: Element + Into<B> + Clone, B2: Element + Into<B> + Clone, F1: Element + Into<F> + Clone, F2: Element + Into<F> + Clone, W1: Element + Into<W> + Clone, W2: Element + Into<W> + Clone
729 {
730 let l1 = labels1.as_array().into_iter().cloned().map(|l| -> B { l.into() });
731 let l2 = labels2.as_array().into_iter().cloned().map(|l| -> B { l.into() });
732 let p1 = predictions1.as_array().into_iter().cloned().map(|f| -> F { f.into() });
733 let p2 = predictions2.as_array().into_iter().cloned().map(|f| -> F { f.into() });
734
735
736 return match (weights1, weights2) {
737 (None, None) => {
738 py.allow_threads(move || {
739 score_two_sorted_samples(self.get_score(), p1, l1, repeat(W::one()), p2, l2, repeat(W::one()))
740 })
741 }
742 (Some(w1), None) => {
743 let w1i = w1.as_array().into_iter().cloned().map(|w| -> W { w.into() });
744 py.allow_threads(move || {
745 score_two_sorted_samples(self.get_score(), p1, l1, w1i, p2, l2, repeat(W::one()))
746 })
747 }
748 (None, Some(w2)) => {
749 let w2i = w2.as_array().into_iter().cloned().map(|w| -> W { w.into() });
750 py.allow_threads(move || {
751 score_two_sorted_samples(self.get_score(), p1, l1, repeat(W::one()), p2, l2, w2i)
752 })
753 }
754 (Some(w1), Some(w2)) => {
755 let w1i = w1.as_array().into_iter().cloned().map(|w| -> W { w.into() });
756 let w2i = w2.as_array().into_iter().cloned().map(|w| -> W { w.into() });
757 py.allow_threads(move || {
758 score_two_sorted_samples(self.get_score(), p1, l1, w1i, p2, l2, w2i)
759 })
760 }
761 };
762 }
763}
764
765struct AveragePrecisionPyGeneric {
766
767}
768
769impl AveragePrecisionPyGeneric {
770 fn new() -> Self {
771 return AveragePrecisionPyGeneric {};
772 }
773}
774
775impl PyScoreGeneric<AveragePrecision> for AveragePrecisionPyGeneric {
776 fn get_score(&self) -> AveragePrecision {
777 return AveragePrecision::new();
778 }
779}
780
781struct RocAucPyGeneric {
782 max_fpr: Option<f32>,
783}
784
785impl RocAucPyGeneric {
786 fn new(max_fpr: Option<f32>) -> Self {
787 return RocAucPyGeneric { max_fpr: max_fpr };
788 }
789}
790
791impl PyScoreGeneric<RocAucWithOptionalMaxFPR> for RocAucPyGeneric {
792 fn get_score(&self) -> RocAucWithOptionalMaxFPR {
793 return RocAucWithOptionalMaxFPR::new(self.max_fpr);
794 }
795}
796
/// Defines (and optionally registers) a typed `average_precision_*` Python function.
///
/// * 5-argument form: defines `pub fn $fname`, exposed to Python as `$pyname`. It
///   takes `$label_type` labels, `$prediction_type` predictions and optional
///   `$weight_type` weights, and returns the score as `$prediction_type`.
/// * 6-argument form: also adds the function to module `$py_module`. Registration
///   errors are propagated with `?`, so the invoking function must return `PyResult`.
macro_rules! average_precision_py {
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels: PyReadonlyArray1<'py, $label_type>,
            predictions: PyReadonlyArray1<'py, $prediction_type>,
            weights: Option<PyReadonlyArray1<'py, $weight_type>>,
            order: Option<PyOrder>
        ) -> $prediction_type
        {
            return AveragePrecisionPyGeneric::new().score_py(py, labels, predictions, weights, order);
        }
    };
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty, $py_module:ident) => {
        average_precision_py!($fname, $pyname, $label_type, $prediction_type, $weight_type);
        // Propagate registration failures instead of panicking during module init.
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
821
822
/// Defines (and optionally registers) a typed `roc_auc_*` Python function.
///
/// * 5-argument form: defines `pub fn $fname`, exposed to Python as `$pyname`. It
///   takes `$label_type` labels, `$prediction_type` predictions, optional
///   `$weight_type` weights and an optional `max_fpr`, and returns the score as
///   `$prediction_type`.
/// * 6-argument form: also adds the function to module `$py_module`. Registration
///   errors are propagated with `?`, so the invoking function must return `PyResult`.
macro_rules! roc_auc_py {
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type:ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels: PyReadonlyArray1<'py, $label_type>,
            predictions: PyReadonlyArray1<'py, $prediction_type>,
            weights: Option<PyReadonlyArray1<'py, $weight_type>>,
            order: Option<PyOrder>,
            max_fpr: Option<f32>,
        ) -> $prediction_type
        {
            return RocAucPyGeneric::new(max_fpr).score_py(py, labels, predictions, weights, order);
        }
    };
    ($fname: ident, $pyname:literal, $label_type:ty, $prediction_type:ty, $weight_type: ty, $py_module:ident) => {
        roc_auc_py!($fname, $pyname, $label_type, $prediction_type, $weight_type);
        // Propagate registration failures instead of panicking during module init.
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
844
845
/// Defines (and optionally registers) a typed Python function computing average
/// precision over two samples, each already sorted by descending prediction.
///
/// `$lt`/`$pt`/`$wt` are the label/score/weight types used for scoring. The `*1` and
/// `*2` types are the numpy element types of the first and second sample. In the
/// form with `$py_module`, the function is also added to that module, and
/// registration errors are propagated with `?`.
macro_rules! average_precision_on_two_sorted_samples_py {
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels1, predictions1, weights1, labels2, predictions2, weights2, *))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels1: PyReadonlyArray1<'py, $lt1>,
            predictions1: PyReadonlyArray1<'py, $pt1>,
            weights1: Option<PyReadonlyArray1<'py, $wt1>>,
            labels2: PyReadonlyArray1<'py, $lt2>,
            predictions2: PyReadonlyArray1<'py, $pt2>,
            weights2: Option<PyReadonlyArray1<'py, $wt2>>,
        ) -> $pt
        {
            return AveragePrecisionPyGeneric::new().score_two_sorted_samples_py_generic::<$lt, $pt, $wt, $lt1, $lt2, $pt1, $pt2, $wt1, $wt2>(py, labels1, predictions1, weights1, labels2, predictions2, weights2);
        }
    };
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty, $py_module:ident) => {
        average_precision_on_two_sorted_samples_py!($fname, $pyname, $lt, $pt, $wt, $lt1, $pt1, $wt1, $lt2, $pt2, $wt2);
        // Propagate registration failures instead of panicking during module init.
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
868
869
/// Defines (and optionally registers) a typed Python function computing (optionally
/// partial) ROC AUC over two samples, each already sorted by descending prediction.
///
/// `$lt`/`$pt`/`$wt` are the label/score/weight types used for scoring. The `*1` and
/// `*2` types are the numpy element types of the first and second sample. In the
/// form with `$py_module`, the function is also added to that module, and
/// registration errors are propagated with `?`.
macro_rules! roc_auc_on_two_sorted_samples_py {
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty) => {
        #[pyfunction(name = $pyname)]
        #[pyo3(signature = (labels1, predictions1, weights1, labels2, predictions2, weights2, *, max_fpr=None))]
        pub fn $fname<'py>(
            py: Python<'py>,
            labels1: PyReadonlyArray1<'py, $lt1>,
            predictions1: PyReadonlyArray1<'py, $pt1>,
            weights1: Option<PyReadonlyArray1<'py, $wt1>>,
            labels2: PyReadonlyArray1<'py, $lt2>,
            predictions2: PyReadonlyArray1<'py, $pt2>,
            weights2: Option<PyReadonlyArray1<'py, $wt2>>,
            max_fpr: Option<f32>,
        ) -> $pt
        {
            return RocAucPyGeneric::new(max_fpr).score_two_sorted_samples_py_generic::<$lt, $pt, $wt, $lt1, $lt2, $pt1, $pt2, $wt1, $wt2>(py, labels1, predictions1, weights1, labels2, predictions2, weights2);
        }
    };
    ($fname: ident, $pyname:literal, $lt:ty, $pt:ty, $wt:ty, $lt1:ty, $pt1:ty, $wt1:ty, $lt2:ty, $pt2:ty, $wt2: ty, $py_module:ident) => {
        roc_auc_on_two_sorted_samples_py!($fname, $pyname, $lt, $pt, $wt, $lt1, $pt1, $wt1, $lt2, $pt2, $wt2);
        // Propagate registration failures instead of panicking during module init.
        $py_module.add_function(wrap_pyfunction!($fname, $py_module)?)?;
    };
}
893
894
895#[pyfunction(name = "loo_cossim")]
896#[pyo3(signature = (data))]
897pub fn loo_cossim_py<'py>(
898 py: Python<'py>,
899 data: &Bound<'py, PyUntypedArray>
900) -> PyResult<f64> {
901 if data.ndim() != 2 {
902 return Err(PyTypeError::new_err(format!("Expected 2-dimensional array for data (samples x features) but found {} dimenisons.", data.ndim())));
903 }
904
905 let dt = data.dtype();
906 if dt.is_equiv_to(&dtype::<f32>(py)) {
907 let typed_data = data.downcast::<PyArray2<f32>>().unwrap().readonly();
908 let array = typed_data.as_array();
909 let score = py.allow_threads(move || {
910 loo_cossim_single(&array)
911 });
912 return Ok(score as f64);
913 }
914 if dt.is_equiv_to(&dtype::<f64>(py)) {
915 let typed_data = data.downcast::<PyArray2<f64>>().unwrap().readonly();
916 let array = typed_data.as_array();
917 let score = py.allow_threads(move || {
918 loo_cossim_single(&array)
919 });
920 return Ok(score);
921 }
922 return Err(PyTypeError::new_err(format!("Only float32 and float64 data supported, but found {}", dt)));
923}
924
925pub fn loo_cossim_many_generic_py<'py, F: num::Float + AddAssign + Element>(
926 py: Python<'py>,
927 data: &Bound<'py, PyArrayDyn<F>>
928) -> PyResult<Bound<'py, PyArray1<F>>> {
929 if data.ndim() != 3 {
930 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
931 }
932 let typed_data = data.downcast::<PyArray3<F>>().unwrap().readonly();
933 let array = typed_data.as_array();
934 let score = py.allow_threads(move || {
935 loo_cossim_many(&array)
936 });
937 let score_py = PyArray::from_owned_array(py, score);
939 return Ok(score_py);
940}
941
942#[pyfunction(name = "loo_cossim_many_f64")]
943#[pyo3(signature = (data))]
944pub fn loo_cossim_many_py_f64<'py>(
945 py: Python<'py>,
946 data: &Bound<'py, PyUntypedArray>
947) -> PyResult<Bound<'py, PyArray1<f64>>> {
948 if data.ndim() != 3 {
949 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
950 }
951
952 let dt = data.dtype();
953 if !dt.is_equiv_to(&dtype::<f64>(py)) {
954 return Err(PyTypeError::new_err(format!("Only float64 data supported, but found {}", dt)));
955 }
956 let typed_data = data.downcast::<PyArrayDyn<f64>>().unwrap();
957 return loo_cossim_many_generic_py(py, typed_data);
958}
959
960#[pyfunction(name = "loo_cossim_many_f32")]
961#[pyo3(signature = (data))]
962pub fn loo_cossim_many_py_f32<'py>(
963 py: Python<'py>,
964 data: &Bound<'py, PyUntypedArray>
965) -> PyResult<Bound<'py, PyArray1<f32>>> {
966 if data.ndim() != 3 {
967 return Err(PyTypeError::new_err(format!("Expected 3-dimensional array for data (outer(?) x samples x features) but found {} dimenisons.", data.ndim())));
968 }
969
970 let dt = data.dtype();
971 if !dt.is_equiv_to(&dtype::<f32>(py)) {
972 return Err(PyTypeError::new_err(format!("Only float32 data supported, but found {}", dt)));
973 }
974 let typed_data = data.downcast::<PyArrayDyn<f32>>().unwrap();
975 return loo_cossim_many_generic_py(py, typed_data);
976}
977
978#[pymodule(name = "_scors")]
979fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
980 average_precision_py!(average_precision_bool_f32, "average_precision_bool_f32", bool, f32, f32, m);
981 average_precision_py!(average_precision_i8_f32, "average_precision_i8_f32", i8, f32, f32, m);
982 average_precision_py!(average_precision_i16_f32, "average_precision_i16_f32", i16, f32, f32, m);
983 average_precision_py!(average_precision_i32_f32, "average_precision_i32_f32", i32, f32, f32, m);
984 average_precision_py!(average_precision_i64_f32, "average_precision_i64_f32", i64, f32, f32, m);
985 average_precision_py!(average_precision_u8_f32, "average_precision_u8_f32", u8, f32, f32, m);
986 average_precision_py!(average_precision_u16_f32, "average_precision_u16_f32", u16, f32, f32, m);
987 average_precision_py!(average_precision_u32_f32, "average_precision_u32_f32", u32, f32, f32, m);
988 average_precision_py!(average_precision_u64_f32, "average_precision_u64_f32", u64, f32, f32, m);
989 average_precision_py!(average_precision_bool_f64, "average_precision_bool_f64", bool, f64, f64, m);
990 average_precision_py!(average_precision_i8_f64, "average_precision_i8_f64", i8, f64, f64, m);
991 average_precision_py!(average_precision_i16_f64, "average_precision_i16_f64", i16, f64, f64, m);
992 average_precision_py!(average_precision_i32_f64, "average_precision_i32_f64", i32, f64, f64, m);
993 average_precision_py!(average_precision_i64_f64, "average_precision_i64_f64", i64, f64, f64, m);
994 average_precision_py!(average_precision_u8_f64, "average_precision_u8_f64", u8, f64, f64, m);
995 average_precision_py!(average_precision_u16_f64, "average_precision_u16_f64", u16, f64, f64, m);
996 average_precision_py!(average_precision_u32_f64, "average_precision_u32_f64", u32, f64, f64, m);
997 average_precision_py!(average_precision_u64_f64, "average_precision_u64_f64", u64, f64, f64, m);
998
999 roc_auc_py!(roc_auc_bool_f32, "roc_auc_bool_f32", bool, f32, f32, m);
1000 roc_auc_py!(roc_auc_i8_f32, "roc_auc_i8_f32", i8, f32, f32, m);
1001 roc_auc_py!(roc_auc_i16_f32, "roc_auc_i16_f32", i16, f32, f32, m);
1002 roc_auc_py!(roc_auc_i32_f32, "roc_auc_i32_f32", i32, f32, f32, m);
1003 roc_auc_py!(roc_auc_i64_f32, "roc_auc_i64_f32", i64, f32, f32, m);
1004 roc_auc_py!(roc_auc_u8_f32, "roc_auc_u8_f32", u8, f32, f32, m);
1005 roc_auc_py!(roc_auc_u16_f32, "roc_auc_u16_f32", u16, f32, f32, m);
1006 roc_auc_py!(roc_auc_u32_f32, "roc_auc_u32_f32", u32, f32, f32, m);
1007 roc_auc_py!(roc_auc_u64_f32, "roc_auc_u64_f32", u64, f32, f32, m);
1008 roc_auc_py!(roc_auc_bool_f64, "roc_auc_bool_f64", bool, f64, f64, m);
1009 roc_auc_py!(roc_auc_i8_f64, "roc_auc_i8_f64", i8, f64, f64, m);
1010 roc_auc_py!(roc_auc_i16_f64, "roc_auc_i16_f64", i16, f64, f64, m);
1011 roc_auc_py!(roc_auc_i32_f64, "roc_auc_i32_f64", i32, f64, f64, m);
1012 roc_auc_py!(roc_auc_i64_f64, "roc_auc_i64_f64", i64, f64, f64, m);
1013 roc_auc_py!(roc_auc_u8_f64, "roc_auc_u8_f64", u8, f64, f64, m);
1014 roc_auc_py!(roc_auc_u16_f64, "roc_auc_u16_f64", u16, f64, f64, m);
1015 roc_auc_py!(roc_auc_u32_f64, "roc_auc_u32_f64", u32, f64, f64, m);
1016 roc_auc_py!(roc_auc_u64_f64, "roc_auc_u64_f64", u64, f64, f64, m);
1017
1018 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_bool_f32, "average_precision_on_two_sorted_samples_bool_f32", bool, f32, f32, bool, f32, f32, bool, f32, f32, m);
1019 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i8_f32, "average_precision_on_two_sorted_samples_i8_f32", i8, f32, f32, i8, f32, f32, i8, f32, f32, m);
1020 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i16_f32, "average_precision_on_two_sorted_samples_i16_f32", i16, f32, f32, i16, f32, f32, i16, f32, f32, m);
1021 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i32_f32, "average_precision_on_two_sorted_samples_i32_f32", i32, f32, f32, i32, f32, f32, i32, f32, f32, m);
1022 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i64_f32, "average_precision_on_two_sorted_samples_i64_f32", i64, f32, f32, i64, f32, f32, i64, f32, f32, m);
1023 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u8_f32, "average_precision_on_two_sorted_samples_u8_f32", u8, f32, f32, u8, f32, f32, u8, f32, f32, m);
1024 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u16_f32, "average_precision_on_two_sorted_samples_u16_f32", u16, f32, f32, u16, f32, f32, u16, f32, f32, m);
1025 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u32_f32, "average_precision_on_two_sorted_samples_u32_f32", u32, f32, f32, u32, f32, f32, u32, f32, f32, m);
1026 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u64_f32, "average_precision_on_two_sorted_samples_u64_f32", u64, f32, f32, u64, f32, f32, u64, f32, f32, m);
1027 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_bool_f64, "average_precision_on_two_sorted_samples_bool_f64", bool, f64, f64, bool, f64, f64, bool, f64, f64, m);
1028 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i8_f64, "average_precision_on_two_sorted_samples_i8_f64", i8, f64, f64, i8, f64, f64, i8, f64, f64, m);
1029 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i16_f64, "average_precision_on_two_sorted_samples_i16_f64", i16, f64, f64, i16, f64, f64, i16, f64, f64, m);
1030 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i32_f64, "average_precision_on_two_sorted_samples_i32_f64", i32, f64, f64, i16, f64, f64, i16, f64, f64, m);
1031 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_i64_f64, "average_precision_on_two_sorted_samples_i64_f64", i64, f64, f64, i64, f64, f64, i64, f64, f64, m);
1032 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u8_f64, "average_precision_on_two_sorted_samples_u8_f64", u8, f64, f64, u8, f64, f64, u8, f64, f64, m);
1033 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u16_f64, "average_precision_on_two_sorted_samples_u16_f64", u16, f64, f64, u16, f64, f64, u16, f64, f64, m);
1034 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u32_f64, "average_precision_on_two_sorted_samples_u32_f64", u32, f64, f64, u32, f64, f64, u32, f64, f64, m);
1035 average_precision_on_two_sorted_samples_py!(average_precision_on_two_sorted_samples_u64_f64, "average_precision_on_two_sorted_samples_u64_f64", u64, f64, f64, u64, f64, f64, u64, f64, f64, m);
1036
1037 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_bool_f32, "roc_auc_on_two_sorted_samples_bool_f32", bool, f32, f32, bool, f32, f32, bool, f32, f32, m);
1038 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i8_f32, "roc_auc_on_two_sorted_samples_i8_f32", i8, f32, f32, i8, f32, f32, i8, f32, f32, m);
1039 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i16_f32, "roc_auc_on_two_sorted_samples_i16_f32", i16, f32, f32, i16, f32, f32, i16, f32, f32, m);
1040 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i32_f32, "roc_auc_on_two_sorted_samples_i32_f32", i32, f32, f32, i32, f32, f32, i32, f32, f32, m);
1041 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i64_f32, "roc_auc_on_two_sorted_samples_i64_f32", i64, f32, f32, i64, f32, f32, i64, f32, f32, m);
1042 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u8_f32, "roc_auc_on_two_sorted_samples_u8_f32", u8, f32, f32, u8, f32, f32, u8, f32, f32, m);
1043 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u16_f32, "roc_auc_on_two_sorted_samples_u16_f32", u16, f32, f32, u16, f32, f32, u16, f32, f32, m);
1044 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u32_f32, "roc_auc_on_two_sorted_samples_u32_f32", u32, f32, f32, u32, f32, f32, u32, f32, f32, m);
1045 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u64_f32, "roc_auc_on_two_sorted_samples_u64_f32", u64, f32, f32, u64, f32, f32, u64, f32, f32, m);
1046 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_bool_f64, "roc_auc_on_two_sorted_samples_bool_f64", bool, f64, f64, bool, f64, f64, bool, f64, f64, m);
1047 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i8_f64, "roc_auc_on_two_sorted_samples_i8_f64", i8, f64, f64, i8, f64, f64, i8, f64, f64, m);
1048 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i16_f64, "roc_auc_on_two_sorted_samples_i16_f64", i16, f64, f64, i16, f64, f64, i16, f64, f64, m);
1049 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i32_f64, "roc_auc_on_two_sorted_samples_i32_f64", i32, f64, f64, i16, f64, f64, i16, f64, f64, m);
1050 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_i64_f64, "roc_auc_on_two_sorted_samples_i64_f64", i64, f64, f64, i64, f64, f64, i64, f64, f64, m);
1051 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u8_f64, "roc_auc_on_two_sorted_samples_u8_f64", u8, f64, f64, u8, f64, f64, u8, f64, f64, m);
1052 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u16_f64, "roc_auc_on_two_sorted_samples_u16_f64", u16, f64, f64, u16, f64, f64, u16, f64, f64, m);
1053 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u32_f64, "roc_auc_on_two_sorted_samples_u32_f64", u32, f64, f64, u32, f64, f64, u32, f64, f64, m);
1054 roc_auc_on_two_sorted_samples_py!(roc_auc_on_two_sorted_samples_u64_f64, "roc_auc_on_two_sorted_samples_u64_f64", u64, f64, f64, u64, f64, f64, u64, f64, f64, m);
1055
1056 m.add_function(wrap_pyfunction!(loo_cossim_py, m)?).unwrap();
1057 m.add_function(wrap_pyfunction!(loo_cossim_many_py_f64, m)?).unwrap();
1058 m.add_function(wrap_pyfunction!(loo_cossim_many_py_f32, m)?).unwrap();
1059 m.add_class::<PyOrder>().unwrap();
1060 return Ok(());
1061}
1062
1063
#[cfg(test)]
mod tests {
    //! Unit tests for average precision and ROC AUC (plain, with `max_fpr`,
    //! and on pairs of pre-sorted samples).

    use super::*;

    /// Average precision expected for the four-sample fixture.
    const EXPECTED_AP: f64 = 0.8333333333333333;
    /// ROC AUC expected for the four-sample fixture.
    const EXPECTED_ROC_AUC: f64 = 0.75;
    /// ROC AUC expected for the four-sample fixture with `max_fpr = 0.25`.
    const EXPECTED_ROC_AUC_MAX_FPR: f64 = 0.7142857142857143;

    /// Builds the shared `(labels, predictions, weights)` fixture: four samples
    /// whose predictions are already in descending order, all with weight 1.
    fn sorted_fixture() -> ([u8; 4], [f64; 4], [f64; 4]) {
        ([1, 0, 1, 0], [0.8, 0.4, 0.35, 0.1], [1.0; 4])
    }

    #[test]
    fn test_average_precision_on_sorted() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = score_sorted_sample(
            AveragePrecision::new(),
            &predictions,
            &labels,
            &weights,
            Order::DESCENDING,
        );
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_on_sorted_double() {
        // Each fixture sample appears twice in a row (same label, same prediction);
        // the expected score is the same as for the single fixture.
        let labels: [u8; 8] = [1, 1, 0, 0, 1, 1, 0, 0];
        let predictions: [f64; 8] = [0.8, 0.8, 0.4, 0.4, 0.35, 0.35, 0.1, 0.1];
        let weights = [1.0_f64; 8];
        let actual: f64 = score_sorted_sample(
            AveragePrecision::new(),
            &predictions,
            &labels,
            &weights,
            Order::DESCENDING,
        );
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_unsorted() {
        // Same samples as the fixture, shuffled; no order hint is passed.
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights = [1.0_f64; 4];
        let actual: f64 = average_precision(&predictions, &labels, Some(&weights), None);
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 =
            average_precision(&predictions, &labels, Some(&weights), Some(Order::DESCENDING));
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted_pair() {
        // The fixture is scored against a second copy of itself.
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = score_two_sorted_samples(
            AveragePrecision::new(),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
        );
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_roc_auc() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = roc_auc(
            &predictions,
            &labels,
            Some(&weights),
            Some(Order::DESCENDING),
            None,
        );
        assert_eq!(actual, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_auc_double() {
        // The fixture concatenated with itself, unweighted, with no order hint.
        let labels: [u8; 8] = [1, 0, 1, 0, 1, 0, 1, 0];
        let predictions: [f64; 8] = [0.8, 0.4, 0.35, 0.1, 0.8, 0.4, 0.35, 0.1];
        let actual: f64 = roc_auc(&predictions, &labels, None::<&[f64; 8]>, None, None);
        assert_eq!(actual, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_sorted_pair() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = score_two_sorted_samples(
            RocAuc::new(),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
        );
        assert_eq!(actual, EXPECTED_ROC_AUC);
    }

    #[test]
    fn test_roc_auc_max_fpr() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = roc_auc(
            &predictions,
            &labels,
            Some(&weights),
            Some(Order::DESCENDING),
            Some(0.25),
        );
        assert_eq!(actual, EXPECTED_ROC_AUC_MAX_FPR);
    }

    #[test]
    fn test_roc_auc_max_fpr_double() {
        let labels: [u8; 8] = [1, 0, 1, 0, 1, 0, 1, 0];
        let predictions: [f64; 8] = [0.8, 0.4, 0.35, 0.1, 0.8, 0.4, 0.35, 0.1];
        let actual: f64 = roc_auc(&predictions, &labels, None::<&[f64; 8]>, None, Some(0.25));
        assert_eq!(actual, EXPECTED_ROC_AUC_MAX_FPR);
    }

    #[test]
    fn test_roc_auc_max_fpr_sorted_pair() {
        let (labels, predictions, weights) = sorted_fixture();
        let actual: f64 = score_two_sorted_samples(
            RocAucWithMaxFPR::new(0.25),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
            predictions.iter().copied(),
            labels.iter().copied(),
            weights.iter().copied(),
        );
        assert_eq!(actual, EXPECTED_ROC_AUC_MAX_FPR);
    }
}