use ndarray::{ArrayView, Ix1};
use numpy::{
    dtype, Element, PyArray1, PyArrayDescr, PyArrayDescrMethods, PyArrayMethods, PyReadonlyArray1,
    PyUntypedArray, PyUntypedArrayMethods,
};
use pyo3::Bound;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use std::iter::DoubleEndedIterator;
7
/// Direction in which pre-sorted input is arranged by score.
///
/// `DESCENDING` means the highest score comes first. `ASCENDING` means the
/// lowest score comes first; the metric functions then walk the input back to
/// front so that it is processed from the highest score down.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Order {
    ASCENDING,
    DESCENDING,
}
13
/// A weight source that reports the same weight for every sample.
///
/// Stands in for an explicit weight array when the caller supplies no
/// weights; the metric functions use the unit weight (`1.0`) in that case.
struct ConstWeight {
    value: f64,
}

impl ConstWeight {
    /// Builds a constant weight source that always yields `value`.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// Builds the unit weight source: every sample weighs `1.0`.
    fn one() -> Self {
        Self::new(1.0)
    }
}
26
/// Read-only access to a one-dimensional sequence of values.
///
/// Elements are handed out by value (cloned), either one at a time through
/// `get_at` or as a double-ended iterator, so callers can walk the sequence
/// front to back or back to front.
pub trait Data<T: Clone> {
    /// Returns an iterator over every element, front to back.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns a clone of the element at `index`.
    ///
    /// Sequence-backed implementations index directly and panic when `index`
    /// is out of bounds.
    fn get_at(&self, index: usize) -> T;
}
33
/// Data that can report the permutation which ranks its elements.
///
/// Every implementation in this module ranks from the largest value to the
/// smallest (via `f64::total_cmp`). The metric functions rely on this and
/// treat the selected data as `Order::DESCENDING`.
pub trait SortableData<T> {
    /// Returns element indices in ranking order.
    ///
    /// The sort is unstable: equal elements end up in unspecified relative order.
    fn argsort_unstable(&self) -> Vec<usize>;
}
37
38impl Iterator for ConstWeight {
39 type Item = f64;
40 fn next(&mut self) -> Option<f64> {
41 return Some(self.value);
42 }
43}
44
45impl DoubleEndedIterator for ConstWeight {
46 fn next_back(&mut self) -> Option<f64> {
47 return Some(self.value);
48 }
49}
50
51impl Data<f64> for ConstWeight {
52 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
53 return ConstWeight::new(self.value);
54 }
55
56 fn get_at(&self, _index: usize) -> f64 {
57 return self.value.clone();
58 }
59}
60
61impl <T: Clone> Data<T> for Vec<T> {
62 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
63 return self.iter().cloned();
64 }
65 fn get_at(&self, index: usize) -> T {
66 return self[index].clone();
67 }
68}
69
70impl SortableData<f64> for Vec<f64> {
71 fn argsort_unstable(&self) -> Vec<usize> {
72 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
73 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
74 return indices;
76 }
77}
78
79impl <T: Clone> Data<T> for &[T] {
80 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
81 return self.iter().cloned();
82 }
83 fn get_at(&self, index: usize) -> T {
84 return self[index].clone();
85 }
86}
87
88impl SortableData<f64> for &[f64] {
89 fn argsort_unstable(&self) -> Vec<usize> {
90 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
92 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
95 return indices;
97 }
98}
99
100impl <T: Clone, const N: usize> Data<T> for [T; N] {
101 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
102 return self.iter().cloned();
103 }
104 fn get_at(&self, index: usize) -> T {
105 return self[index].clone();
106 }
107}
108
109impl <const N: usize> SortableData<f64> for [f64; N] {
110 fn argsort_unstable(&self) -> Vec<usize> {
111 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
112 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
113 return indices;
114 }
115}
116
117impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
118 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
119 return self.iter().cloned();
120 }
121 fn get_at(&self, index: usize) -> T {
122 return self[index].clone();
123 }
124}
125
126impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
127 fn argsort_unstable(&self) -> Vec<usize> {
128 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
129 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
130 return indices;
131 }
132}
133
/// A label type that can be read as a binary (positive / negative) class.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label denotes the positive class.
    fn get_value(&self) -> bool;
}

impl BinaryLabel for bool {
    fn get_value(&self) -> bool {
        *self
    }
}

// Implements `BinaryLabel` for integer types by testing the lowest bit:
// odd values are positive, even values are negative.
//
// NOTE(review): for signed types this makes odd negative values such as `-1`
// positive, so a `{-1, 1}` label encoding is read as all-positive. Confirm
// this is intended.
macro_rules! impl_binary_label_by_lowest_bit {
    ($($int:ty),*) => {
        $(
            impl BinaryLabel for $int {
                fn get_value(&self) -> bool {
                    (*self & 1) == 1
                }
            }
        )*
    };
}

impl_binary_label_by_lowest_bit!(u8, u16, u32, u64, i8, i16, i32, i64);
199
200fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
201where T: Copy, I: Data<T>
202{
203 let mut selection: Vec<T> = Vec::new();
204 selection.reserve_exact(indices.len());
205 for index in indices {
206 selection.push(slice.get_at(*index));
207 }
208 return selection;
209}
210
211pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
212where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
213{
214 return average_precision_with_order(labels, predictions, weights, None);
215}
216
217pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
218where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
219{
220 return match order {
221 Some(o) => average_precision_on_sorted_labels(labels, weights, o),
222 None => {
223 let indices = predictions.argsort_unstable();
224 let sorted_labels = select(labels, &indices);
225 let ap = match weights {
226 None => {
227 average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
229 },
230 Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
231 };
232 ap
233 }
234 };
235}
236
237pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
238where B: BinaryLabel, L: Data<B>, W: Data<f64>
239{
240 return match weights {
241 None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
242 Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
243 };
244}
245
246pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
247where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
248{
249 return match order {
250 Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
251 Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
252 };
253}
254
255pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
256 let mut ap: f64 = 0.0;
257 let mut tps: f64 = 0.0;
258 let mut fps: f64 = 0.0;
259 for (label, weight) in labels.zip(weights) {
260 let w: f64 = weight;
261 let l: bool = label.get_value();
262 let tp = w * f64::from(l);
263 tps += tp;
264 fps += weight - tp;
265 let ps = tps + fps;
266 let precision = tps / ps;
267 ap += tp * precision;
268 }
269 return ap / tps;
270}
271
272
273
274pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
276where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
277{
278 return roc_auc_with_order(labels, predictions, weights, None, None);
279}
280
281pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
282where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
283{
284 return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
285}
286
287pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
288where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
289{
290 return match order {
291 Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
292 None => {
293 let indices = predictions.argsort_unstable();
294 let sorted_labels = select(labels, &indices);
295 let sorted_predictions = select(predictions, &indices);
296 let roc_auc_score = match weights {
297 Some(w) => {
298 let sorted_weights = select(w, &indices);
299 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
300 },
301 None => {
302 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
303 }
304 };
305 roc_auc_score
306 }
307 };
308}
309pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
310where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
311 return match max_false_positive_rate {
312 None => match weights {
313 Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
314 None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
315 }
316 Some(max_fpr) => match weights {
317 Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
318 None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
319 }
320 };
321}
322
323pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
324 labels: &mut impl DoubleEndedIterator<Item = B>,
325 predictions: &mut impl DoubleEndedIterator<Item = f64>,
326 weights: &mut impl DoubleEndedIterator<Item = f64>,
327 order: Order
328) -> f64 {
329 return match order {
330 Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
331 Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
332 }
333}
334
335pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
336 labels: &mut impl Iterator<Item = B>,
337 predictions: &mut impl Iterator<Item = f64>,
338 weights: &mut impl Iterator<Item = f64>
339) -> f64 {
340 let mut false_positives: f64 = 0.0;
341 let mut true_positives: f64 = 0.0;
342 let mut last_counted_fp = 0.0;
343 let mut last_counted_tp = 0.0;
344 let mut area_under_curve = 0.0;
345 let mut zipped = labels.zip(predictions).zip(weights).peekable();
346 loop {
347 match zipped.next() {
348 None => break,
349 Some(actual) => {
350 let l = f64::from(actual.0.0.get_value());
351 let w = actual.1;
352 let wl = l * w;
353 true_positives += wl;
354 false_positives += w - wl;
355 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
356 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
357 last_counted_fp = false_positives;
358 last_counted_tp = true_positives;
359 }
360 }
361 };
362 }
363 return area_under_curve / (true_positives * false_positives);
364}
365
/// Area between the x-axis and the straight segment from `(x0, y0)` to
/// `(x1, y1)`, i.e. the trapezoid rule: a rectangle of height `y0` plus the
/// triangle above it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    let rectangle = width * y0;
    let triangle = rise * width * 0.5;
    rectangle + triangle
}
371
372fn get_positive_sum<B: BinaryLabel>(
373 labels: impl Iterator<Item = B>,
374 weights: impl Iterator<Item = f64>
375) -> (f64, f64) {
376 let mut false_positives = 0f64;
377 let mut true_positives = 0f64;
378 for (label, weight) in labels.zip(weights) {
379 let lw = weight * f64::from(label.get_value());
380 false_positives += weight - lw;
381 true_positives += lw;
382 }
383 return (false_positives, true_positives);
384}
385
386pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
387where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
388 let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
390 let mut l_it = labels.get_iterator();
391 let mut p_it = predictions.get_iterator();
392 let mut w_it = weights.get_iterator();
393 return match order {
394 Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
395 Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
396 };
397}
398
399
400fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
401 labels: &mut impl Iterator<Item = B>,
402 predictions: &mut impl Iterator<Item = f64>,
403 weights: &mut impl Iterator<Item = f64>,
404 false_positive_sum: f64,
405 true_positive_sum: f64,
406 max_false_positive_rate: f64
407) -> f64 {
408 let mut false_positives: f64 = 0.0;
409 let mut true_positives: f64 = 0.0;
410 let mut last_counted_fp = 0.0;
411 let mut last_counted_tp = 0.0;
412 let mut area_under_curve = 0.0;
413 let mut zipped = labels.zip(predictions).zip(weights).peekable();
414 let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
415 loop {
416 match zipped.next() {
417 None => break,
418 Some(actual) => {
419 let l = f64::from(actual.0.0.get_value());
420 let w = actual.1;
421 let wl = l * w;
422 let next_tp = true_positives + wl;
423 let next_fp = false_positives + (w - wl);
424 let is_above_max = next_fp > false_positive_cutoff;
425 if is_above_max {
426 let dx = next_fp - false_positives;
427 let dy = next_tp - true_positives;
428 true_positives += dy * false_positive_cutoff / dx;
429 false_positives = false_positive_cutoff;
430 } else {
431 true_positives = next_tp;
432 false_positives = next_fp;
433 }
434 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
435 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
436 last_counted_fp = false_positives;
437 last_counted_tp = true_positives;
438 }
439 if is_above_max {
440 break;
441 }
442 }
443 };
444 }
445 let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
446 let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
447 let max_area = max_false_positive_rate;
448 return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
449}
450
451
452#[pyclass(eq, eq_int, name="Order")]
454#[derive(Clone, Copy, PartialEq)]
455pub enum PyOrder {
456 ASCENDING,
457 DESCENDING
458}
459
460fn py_order_as_order(order: PyOrder) -> Order {
461 return match order {
462 PyOrder::ASCENDING => Order::ASCENDING,
463 PyOrder::DESCENDING => Order::DESCENDING,
464 }
465}
466
467fn average_precision_py_generic<'py, B>(
468 py: Python<'py>,
469 labels: &PyReadonlyArray1<'py, B>,
470 predictions: &PyReadonlyArray1<'py, f64>,
471 weights: &Option<PyReadonlyArray1<'py, f64>>,
472 order: &Option<PyOrder>
473) -> f64
474where B: BinaryLabel + Element
475{
476 let labels = labels.as_array();
477 let predictions = predictions.as_array();
478 let order = order.map(py_order_as_order);
479 let ap = match weights {
480 Some(w) => {
481 let weights = w.as_array();
482 py.allow_threads(move || {
483 average_precision_with_order(&labels, &predictions, Some(&weights), order)
484 })
485 },
486 None => {
487 py.allow_threads(move || {
488 average_precision_with_order(&labels, &predictions, None::<&ArrayView<'_, f64, Ix1>>, order)
489 })
490 }
491 };
492 return ap;
493}
494
495fn average_precision_py_match_run<'py, T>(
496 py: Python<'py>,
497 labels: &Bound<'py, PyUntypedArray>,
498 predictions: &PyReadonlyArray1<'py, f64>,
499 weights: &Option<PyReadonlyArray1<'py, f64>>,
500 order: &Option<PyOrder>,
501 dt: &Bound<'py, PyArrayDescr>
502) -> Option<f64>
503where T: Element + BinaryLabel
504{
505 return if dt.is_equiv_to(&dtype::<T>(py)) {
506 let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
507 Some(average_precision_py_generic(py, &labels.readonly(), predictions, weights, order))
508 } else {
509 None
510 }
511}
512
513#[pyfunction(name = "average_precision")]
514#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
515pub fn average_precision_py<'py>(
516 py: Python<'py>,
517 labels: &Bound<'py, PyUntypedArray>,
518 predictions: PyReadonlyArray1<'py, f64>,
519 weights: Option<PyReadonlyArray1<'py, f64>>,
520 order: Option<PyOrder>
521) -> PyResult<f64> {
522 if labels.ndim() != 1 {
523 return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
524 }
525 let label_dtype = labels.dtype();
526 if let Some(ap) = average_precision_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
527 return Ok(ap)
528 }
529 else if let Some(ap) = average_precision_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
530 return Ok(ap)
531 }
532 else if let Some(ap) = average_precision_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
533 return Ok(ap)
534 }
535 else if let Some(ap) = average_precision_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
536 return Ok(ap)
537 }
538 else if let Some(ap) = average_precision_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
539 return Ok(ap)
540 }
541 else if let Some(ap) = average_precision_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
542 return Ok(ap)
543 }
544 else if let Some(ap) = average_precision_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
545 return Ok(ap)
546 }
547 else if let Some(ap) = average_precision_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
548 return Ok(ap)
549 }
550 else if let Some(ap) = average_precision_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
551 return Ok(ap)
552 }
553 return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
554}
555
556fn roc_auc_py_generic<'py, B>(
557 py: Python<'py>,
558 labels: &PyReadonlyArray1<'py, B>,
559 predictions: &PyReadonlyArray1<'py, f64>,
560 weights: &Option<PyReadonlyArray1<'py, f64>>,
561 order: &Option<PyOrder>,
562 max_false_positive_rate: Option<f64>,
563) -> f64
564where B: BinaryLabel + Element
565{
566 let labels = labels.as_array();
567 let predictions = predictions.as_array();
568 let order = order.map(py_order_as_order);
569 let auc = match weights {
570 Some(weight) => {
571 let weights = weight.as_array();
572 py.allow_threads(move || {
573 roc_auc_with_order(&labels, &predictions, Some(&weights), order, max_false_positive_rate)
574 })
575 },
576 None => py.allow_threads(move || {
577 roc_auc_with_order(&labels, &predictions, None::<&Vec<f64>>, order, max_false_positive_rate)
578 })
579 };
580 return auc;
581}
582
583fn roc_auc_py_match_run<'py, T>(
584 py: Python<'py>,
585 labels: &Bound<'py, PyUntypedArray>,
586 predictions: &PyReadonlyArray1<'py, f64>,
587 weights: &Option<PyReadonlyArray1<'py, f64>>,
588 order: &Option<PyOrder>,
589 max_false_positive_rate: Option<f64>,
590 dt: &Bound<'py, PyArrayDescr>
591) -> Option<f64>
592where T: Element + BinaryLabel
593{
594 return if dt.is_equiv_to(&dtype::<T>(py)) {
595 let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
596 Some(roc_auc_py_generic(py, &labels.readonly(), predictions, weights, order, max_false_positive_rate))
597 } else {
598 None
599 }
600}
601
602#[pyfunction(name = "roc_auc")]
603#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_false_positive_rate=None))]
604pub fn roc_auc_py<'py>(
605 py: Python<'py>,
606 labels: &Bound<'py, PyUntypedArray>,
607 predictions: PyReadonlyArray1<'py, f64>,
608 weights: Option<PyReadonlyArray1<'py, f64>>,
609 order: Option<PyOrder>,
610 max_false_positive_rate: Option<f64>,
611) -> PyResult<f64> {
612 if labels.ndim() != 1 {
613 return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
614 }
615 let label_dtype = labels.dtype();
616 if let Some(auc) = roc_auc_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
617 return Ok(auc)
618 }
619 else if let Some(auc) = roc_auc_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
620 return Ok(auc)
621 }
622 else if let Some(auc) = roc_auc_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
623 return Ok(auc)
624 }
625 else if let Some(auc) = roc_auc_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
626 return Ok(auc)
627 }
628 else if let Some(auc) = roc_auc_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
629 return Ok(auc)
630 }
631 else if let Some(auc) = roc_auc_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
632 return Ok(auc)
633 }
634 else if let Some(auc) = roc_auc_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
635 return Ok(auc)
636 }
637 else if let Some(auc) = roc_auc_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
638 return Ok(auc)
639 }
640 else if let Some(auc) = roc_auc_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, max_false_positive_rate, &label_dtype) {
641 return Ok(auc)
642 }
643 return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
644}
645
646#[pymodule(name = "_scors")]
647fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
648 m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
649 m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
650 m.add_class::<PyOrder>().unwrap();
651 return Ok(());
652}
653
654
#[cfg(test)]
mod tests {
    use super::*;

    // Every fixture uses unit weights, so the weighted code paths must agree
    // with the unweighted definitions.
    const UNIT_WEIGHTS: [f64; 4] = [1.0, 1.0, 1.0, 1.0];

    // Ranking [pos, neg, pos, neg]: precision is 1 at the first positive and
    // 2/3 at the second, so AP = (1 + 2/3) / 2.
    const EXPECTED_AP: f64 = 0.8333333333333333;

    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let actual = average_precision_on_sorted_labels(&labels, Some(&UNIT_WEIGHTS), Order::DESCENDING);
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_unsorted() {
        // Ranking by prediction reproduces the [1, 0, 1, 0] fixture above.
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let actual = average_precision_with_order(&labels, &predictions, Some(&UNIT_WEIGHTS), None);
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let actual = average_precision_with_order(
            &labels,
            &predictions,
            Some(&UNIT_WEIGHTS),
            Some(Order::DESCENDING),
        );
        assert_eq!(actual, EXPECTED_AP);
    }

    #[test]
    fn test_roc_auc() {
        // Three of the four (positive, negative) pairs are ranked correctly.
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let actual = roc_auc_with_order(
            &labels,
            &predictions,
            Some(&UNIT_WEIGHTS),
            Some(Order::DESCENDING),
            None,
        );
        assert_eq!(actual, 0.75);
    }
}