1use ndarray::{ArrayView,Ix1};
2use numpy::{Element,PyArray1,PyArrayDescr,PyArrayDescrMethods,PyArrayMethods,PyReadonlyArray1,PyUntypedArray,PyUntypedArrayMethods,dtype};
3use pyo3::Bound;
4use pyo3::exceptions::PyTypeError;
5use pyo3::marker::Ungil;
6use pyo3::prelude::*;
7use std::iter::DoubleEndedIterator;
8
/// Direction in which samples are arranged with respect to their prediction score.
///
/// The scoring functions internally walk samples from the highest to the lowest score;
/// `ASCENDING` input is simply traversed back to front.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Order {
    /// Scores increase from the first sample to the last.
    ASCENDING,
    /// Scores decrease from the first sample to the last.
    DESCENDING,
}
14
/// A weight source that assigns the same `value` to every sample.
///
/// `ConstWeight::one()` stands in for the weights when the caller supplies none.
struct ConstWeight {
    /// Weight reported for every sample.
    value: f64,
}

impl ConstWeight {
    /// Wraps `value` as a constant per-sample weight.
    fn new(value: f64) -> Self {
        Self { value }
    }

    /// The unit weight (`1.0`).
    fn one() -> Self {
        Self::new(1.0)
    }
}
27
/// Read access to a one-dimensional sequence of values, by position or by iteration.
pub trait Data<T>
where
    T: Clone,
{
    /// Returns an iterator over the values from front to back. It can also be consumed
    /// from the back. (The `ConstWeight` implementation never ends.)
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the value at `index`. The container implementations index directly and
    /// therefore panic when `index` is out of bounds.
    fn get_at(&self, index: usize) -> T;
}

/// Data whose values can be ranked.
pub trait SortableData<T> {
    /// Returns the indices of the values ranked from largest to smallest. The sort is
    /// unstable, so tied values may appear in any order.
    fn argsort_unstable(&self) -> Vec<usize>;
}
38
39impl Iterator for ConstWeight {
40 type Item = f64;
41 fn next(&mut self) -> Option<f64> {
42 return Some(self.value);
43 }
44}
45
46impl DoubleEndedIterator for ConstWeight {
47 fn next_back(&mut self) -> Option<f64> {
48 return Some(self.value);
49 }
50}
51
52impl Data<f64> for ConstWeight {
53 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = f64> {
54 return ConstWeight::new(self.value);
55 }
56
57 fn get_at(&self, _index: usize) -> f64 {
58 return self.value.clone();
59 }
60}
61
62impl <T: Clone> Data<T> for Vec<T> {
63 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
64 return self.iter().cloned();
65 }
66 fn get_at(&self, index: usize) -> T {
67 return self[index].clone();
68 }
69}
70
71impl SortableData<f64> for Vec<f64> {
72 fn argsort_unstable(&self) -> Vec<usize> {
73 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
74 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
75 return indices;
77 }
78}
79
80impl <T: Clone> Data<T> for &[T] {
81 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
82 return self.iter().cloned();
83 }
84 fn get_at(&self, index: usize) -> T {
85 return self[index].clone();
86 }
87}
88
89impl SortableData<f64> for &[f64] {
90 fn argsort_unstable(&self) -> Vec<usize> {
91 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
93 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
96 return indices;
98 }
99}
100
101impl <T: Clone, const N: usize> Data<T> for [T; N] {
102 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
103 return self.iter().cloned();
104 }
105 fn get_at(&self, index: usize) -> T {
106 return self[index].clone();
107 }
108}
109
110impl <const N: usize> SortableData<f64> for [f64; N] {
111 fn argsort_unstable(&self) -> Vec<usize> {
112 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
113 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
114 return indices;
115 }
116}
117
118impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
119 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
120 return self.iter().cloned();
121 }
122 fn get_at(&self, index: usize) -> T {
123 return self[index].clone();
124 }
125}
126
127impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
128 fn argsort_unstable(&self) -> Vec<usize> {
129 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
130 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
131 return indices;
132 }
133}
134
/// A value that can be read as a binary class label.
pub trait BinaryLabel: Clone + Copy {
    /// Returns `true` when the label denotes the positive class.
    fn get_value(&self) -> bool;
}

impl BinaryLabel for bool {
    fn get_value(&self) -> bool {
        *self
    }
}

/// Implements `BinaryLabel` for integer types by their least-significant bit: odd values
/// are positive, even values are negative.
// NOTE(review): under this rule `-1` (as in a {-1, +1} label convention) counts as positive
// and `2` counts as negative — confirm this is intended.
macro_rules! impl_binary_label_for_integers {
    ($($int:ty),* $(,)?) => {
        $(
            impl BinaryLabel for $int {
                fn get_value(&self) -> bool {
                    (*self & 1) == 1
                }
            }
        )*
    };
}

impl_binary_label_for_integers!(u8, u16, u32, u64, i8, i16, i32, i64);
200
201fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
202where T: Copy, I: Data<T>
203{
204 let mut selection: Vec<T> = Vec::new();
205 selection.reserve_exact(indices.len());
206 for index in indices {
207 selection.push(slice.get_at(*index));
208 }
209 return selection;
210}
211
212pub fn average_precision<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
213where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
214{
215 return average_precision_with_order(labels, predictions, weights, None);
216}
217
218pub fn average_precision_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
219where B: BinaryLabel, L: Data<B>, P: SortableData<f64>, W: Data<f64>
220{
221 return match order {
222 Some(o) => average_precision_on_sorted_labels(labels, weights, o),
223 None => {
224 let indices = predictions.argsort_unstable();
225 let sorted_labels = select(labels, &indices);
226 let ap = match weights {
227 None => {
228 average_precision_on_sorted_labels(&sorted_labels, weights, Order::DESCENDING)
230 },
231 Some(w) => average_precision_on_sorted_labels(&sorted_labels, Some(&select(w, &indices)), Order::DESCENDING),
232 };
233 ap
234 }
235 };
236}
237
238pub fn average_precision_on_sorted_labels<B, L, W>(labels: &L, weights: Option<&W>, order: Order) -> f64
239where B: BinaryLabel, L: Data<B>, W: Data<f64>
240{
241 return match weights {
242 None => average_precision_on_iterator(labels.get_iterator(), ConstWeight::one(), order),
243 Some(w) => average_precision_on_iterator(labels.get_iterator(), w.get_iterator(), order)
244 };
245}
246
247pub fn average_precision_on_iterator<B, L, W>(labels: L, weights: W, order: Order) -> f64
248where B: BinaryLabel, L: DoubleEndedIterator<Item = B>, W: DoubleEndedIterator<Item = f64>
249{
250 return match order {
251 Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
252 Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
253 };
254}
255
256pub fn average_precision_on_descending_iterator<B: BinaryLabel>(labels: impl Iterator<Item = B>, weights: impl Iterator<Item = f64>) -> f64 {
257 let mut ap: f64 = 0.0;
258 let mut tps: f64 = 0.0;
259 let mut fps: f64 = 0.0;
260 for (label, weight) in labels.zip(weights) {
261 let w: f64 = weight;
262 let l: bool = label.get_value();
263 let tp = w * f64::from(l);
264 tps += tp;
265 fps += weight - tp;
266 let ps = tps + fps;
267 let precision = tps / ps;
268 ap += tp * precision;
269 }
270 return ap / tps;
271}
272
273
274
275pub fn roc_auc<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>) -> f64
277where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
278{
279 return roc_auc_with_order(labels, predictions, weights, None, None);
280}
281
282pub fn roc_auc_max_fpr<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, max_false_positive_rate: Option<f64>) -> f64
283where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
284{
285 return roc_auc_with_order(labels, predictions, weights, None, max_false_positive_rate);
286}
287
288pub fn roc_auc_with_order<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>, max_false_positive_rate: Option<f64>) -> f64
289where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>
290{
291 return match order {
292 Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o, max_false_positive_rate),
293 None => {
294 let indices = predictions.argsort_unstable();
295 let sorted_labels = select(labels, &indices);
296 let sorted_predictions = select(predictions, &indices);
297 let roc_auc_score = match weights {
298 Some(w) => {
299 let sorted_weights = select(w, &indices);
300 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, Some(&sorted_weights), Order::DESCENDING, max_false_positive_rate)
301 },
302 None => {
303 roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, None::<&W>, Order::DESCENDING, max_false_positive_rate)
304 }
305 };
306 roc_auc_score
307 }
308 };
309}
310pub fn roc_auc_on_sorted_labels<B, L, P, W>(labels: &L, predictions: &P, weights: Option<&W>, order: Order, max_false_positive_rate: Option<f64>) -> f64
311where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
312 return match max_false_positive_rate {
313 None => match weights {
314 Some(w) => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut w.get_iterator(), order),
315 None => roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut ConstWeight::one().get_iterator(), order),
316 }
317 Some(max_fpr) => match weights {
318 Some(w) => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, w, order, max_fpr),
319 None => roc_auc_on_sorted_with_fp_cutoff(labels, predictions, &ConstWeight::one(), order, max_fpr),
320 }
321 };
322}
323
324pub fn roc_auc_on_sorted_iterator<B: BinaryLabel>(
325 labels: &mut impl DoubleEndedIterator<Item = B>,
326 predictions: &mut impl DoubleEndedIterator<Item = f64>,
327 weights: &mut impl DoubleEndedIterator<Item = f64>,
328 order: Order
329) -> f64 {
330 return match order {
331 Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
332 Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
333 }
334}
335
336pub fn roc_auc_on_descending_iterator<B: BinaryLabel>(
337 labels: &mut impl Iterator<Item = B>,
338 predictions: &mut impl Iterator<Item = f64>,
339 weights: &mut impl Iterator<Item = f64>
340) -> f64 {
341 let mut false_positives: f64 = 0.0;
342 let mut true_positives: f64 = 0.0;
343 let mut last_counted_fp = 0.0;
344 let mut last_counted_tp = 0.0;
345 let mut area_under_curve = 0.0;
346 let mut zipped = labels.zip(predictions).zip(weights).peekable();
347 loop {
348 match zipped.next() {
349 None => break,
350 Some(actual) => {
351 let l = f64::from(actual.0.0.get_value());
352 let w = actual.1;
353 let wl = l * w;
354 true_positives += wl;
355 false_positives += w - wl;
356 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
357 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
358 last_counted_fp = false_positives;
359 last_counted_tp = true_positives;
360 }
361 }
362 };
363 }
364 return area_under_curve / (true_positives * false_positives);
365}
366
367fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
368 let dx = x1 - x0;
369 let dy = y1 - y0;
370 return dx * y0 + dy * dx * 0.5;
371}
372
373fn get_positive_sum<B: BinaryLabel>(
374 labels: impl Iterator<Item = B>,
375 weights: impl Iterator<Item = f64>
376) -> (f64, f64) {
377 let mut false_positives = 0f64;
378 let mut true_positives = 0f64;
379 for (label, weight) in labels.zip(weights) {
380 let lw = weight * f64::from(label.get_value());
381 false_positives += weight - lw;
382 true_positives += lw;
383 }
384 return (false_positives, true_positives);
385}
386
387pub fn roc_auc_on_sorted_with_fp_cutoff<B, L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order, max_false_positive_rate: f64) -> f64
388where B: BinaryLabel, L: Data<B>, P: Data<f64>, W: Data<f64> {
389 let (fps, tps) = get_positive_sum(labels.get_iterator(), weights.get_iterator());
391 let mut l_it = labels.get_iterator();
392 let mut p_it = predictions.get_iterator();
393 let mut w_it = weights.get_iterator();
394 return match order {
395 Order::ASCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it.rev(), &mut p_it.rev(), &mut w_it.rev(), fps, tps, max_false_positive_rate),
396 Order::DESCENDING => roc_auc_on_descending_iterator_with_fp_cutoff(&mut l_it, &mut p_it, &mut w_it, fps, tps, max_false_positive_rate)
397 };
398}
399
400
401fn roc_auc_on_descending_iterator_with_fp_cutoff<B: BinaryLabel>(
402 labels: &mut impl Iterator<Item = B>,
403 predictions: &mut impl Iterator<Item = f64>,
404 weights: &mut impl Iterator<Item = f64>,
405 false_positive_sum: f64,
406 true_positive_sum: f64,
407 max_false_positive_rate: f64
408) -> f64 {
409 let mut false_positives: f64 = 0.0;
410 let mut true_positives: f64 = 0.0;
411 let mut last_counted_fp = 0.0;
412 let mut last_counted_tp = 0.0;
413 let mut area_under_curve = 0.0;
414 let mut zipped = labels.zip(predictions).zip(weights).peekable();
415 let false_positive_cutoff = max_false_positive_rate * false_positive_sum;
416 loop {
417 match zipped.next() {
418 None => break,
419 Some(actual) => {
420 let l = f64::from(actual.0.0.get_value());
421 let w = actual.1;
422 let wl = l * w;
423 let next_tp = true_positives + wl;
424 let next_fp = false_positives + (w - wl);
425 let is_above_max = next_fp > false_positive_cutoff;
426 if is_above_max {
427 let dx = next_fp - false_positives;
428 let dy = next_tp - true_positives;
429 true_positives += dy * false_positive_cutoff / dx;
430 false_positives = false_positive_cutoff;
431 } else {
432 true_positives = next_tp;
433 false_positives = next_fp;
434 }
435 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) || is_above_max {
436 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
437 last_counted_fp = false_positives;
438 last_counted_tp = true_positives;
439 }
440 if is_above_max {
441 break;
442 }
443 }
444 };
445 }
446 let normalized_area_under_curve = area_under_curve / (true_positive_sum * false_positive_sum);
447 let min_area = 0.5 * max_false_positive_rate * max_false_positive_rate;
448 let max_area = max_false_positive_rate;
449 return 0.5 * (1.0 + (normalized_area_under_curve - min_area) / (max_area - min_area));
450}
451
452
453#[pyclass(eq, eq_int, name="Order")]
455#[derive(Clone, Copy, PartialEq)]
456pub enum PyOrder {
457 ASCENDING,
458 DESCENDING
459}
460
461fn py_order_as_order(order: PyOrder) -> Order {
462 return match order {
463 PyOrder::ASCENDING => Order::ASCENDING,
464 PyOrder::DESCENDING => Order::DESCENDING,
465 }
466}
467
468
469trait PyScore: Ungil + Sync {
470
471 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
472 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64>;
473
474 fn score_py_generic<'py, B>(
475 &self,
476 py: Python<'py>,
477 labels: &PyReadonlyArray1<'py, B>,
478 predictions: &PyReadonlyArray1<'py, f64>,
479 weights: &Option<PyReadonlyArray1<'py, f64>>,
480 order: &Option<PyOrder>,
481 ) -> f64
482 where B: BinaryLabel + Element
483 {
484 let labels = labels.as_array();
485 let predictions = predictions.as_array();
486 let order = order.map(py_order_as_order);
487 let score = match weights {
488 Some(weight) => {
489 let weights = weight.as_array();
490 py.allow_threads(move || {
491 self.score(&labels, &predictions, Some(&weights), order)
492 })
493 },
494 None => py.allow_threads(move || {
495 self.score(&labels, &predictions, None::<&Vec<f64>>, order)
496 })
497 };
498 return score;
499 }
500
501 fn score_py_match_run<'py, T>(
502 &self,
503 py: Python<'py>,
504 labels: &Bound<'py, PyUntypedArray>,
505 predictions: &PyReadonlyArray1<'py, f64>,
506 weights: &Option<PyReadonlyArray1<'py, f64>>,
507 order: &Option<PyOrder>,
508 dt: &Bound<'py, PyArrayDescr>
509 ) -> Option<f64>
510 where T: Element + BinaryLabel
511 {
512 return if dt.is_equiv_to(&dtype::<T>(py)) {
513 let labels = labels.downcast::<PyArray1<T>>().unwrap().readonly();
514 Some(self.score_py_generic(py, &labels.readonly(), predictions, weights, order))
515 } else {
516 None
517 };
518 }
519
520 fn score_py<'py>(
521 &self,
522 py: Python<'py>,
523 labels: &Bound<'py, PyUntypedArray>,
524 predictions: PyReadonlyArray1<'py, f64>,
525 weights: Option<PyReadonlyArray1<'py, f64>>,
526 order: Option<PyOrder>,
527 ) -> PyResult<f64> {
528 if labels.ndim() != 1 {
529 return Err(PyTypeError::new_err(format!("Expected 1-dimensional array for labels but found {} dimenisons.", labels.ndim())));
530 }
531 let label_dtype = labels.dtype();
532 if let Some(score) = self.score_py_match_run::<bool>(py, &labels, &predictions, &weights, &order, &label_dtype) {
533 return Ok(score)
534 }
535 else if let Some(score) = self.score_py_match_run::<u8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
536 return Ok(score)
537 }
538 else if let Some(score) = self.score_py_match_run::<i8>(py, &labels, &predictions, &weights, &order, &label_dtype) {
539 return Ok(score)
540 }
541 else if let Some(score) = self.score_py_match_run::<u16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
542 return Ok(score)
543 }
544 else if let Some(score) = self.score_py_match_run::<i16>(py, &labels, &predictions, &weights, &order, &label_dtype) {
545 return Ok(score)
546 }
547 else if let Some(score) = self.score_py_match_run::<u32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
548 return Ok(score)
549 }
550 else if let Some(score) = self.score_py_match_run::<i32>(py, &labels, &predictions, &weights, &order, &label_dtype) {
551 return Ok(score)
552 }
553 else if let Some(score) = self.score_py_match_run::<u64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
554 return Ok(score)
555 }
556 else if let Some(score) = self.score_py_match_run::<i64>(py, &labels, &predictions, &weights, &order, &label_dtype) {
557 return Ok(score)
558 }
559 return Err(PyTypeError::new_err(format!("Unsupported dtype for labels: {}. Supported dtypes are bool, uint8, uint16, uint32, uint64, in8, int16, int32, int64", label_dtype)));
560 }
561}
562
563struct PyAveragePrecision {
564
565}
566
567impl PyAveragePrecision{
568 fn new() -> Self {
569 return PyAveragePrecision {};
570 }
571}
572
573impl PyScore for PyAveragePrecision {
574 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
575 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
576 return average_precision_with_order(labels, predictions, weights, order);
577 }
578}
579
580struct PyRocAuc {
581 max_fpr: Option<f64>
582}
583
584impl PyRocAuc {
585 fn new(max_fpr: Option<f64>) -> Self {
586 return PyRocAuc { max_fpr: max_fpr };
587 }
588}
589
590impl PyScore for PyRocAuc {
591 fn score<B, L, P, W>(&self, labels: &L, predictions: &P, weights: Option<&W>, order: Option<Order>) -> f64
592 where B: BinaryLabel, L: Data<B>, P: SortableData<f64> + Data<f64>, W: Data<f64> {
593 return roc_auc_with_order(labels, predictions, weights, order, self.max_fpr);
594 }
595}
596
597
598#[pyfunction(name = "average_precision")]
599#[pyo3(signature = (labels, predictions, *, weights=None, order=None))]
600pub fn average_precision_py<'py>(
601 py: Python<'py>,
602 labels: &Bound<'py, PyUntypedArray>,
603 predictions: PyReadonlyArray1<'py, f64>,
604 weights: Option<PyReadonlyArray1<'py, f64>>,
605 order: Option<PyOrder>
606) -> PyResult<f64> {
607 return PyAveragePrecision::new().score_py(py, labels, predictions, weights, order);
608}
609
610#[pyfunction(name = "roc_auc")]
611#[pyo3(signature = (labels, predictions, *, weights=None, order=None, max_fpr=None))]
612pub fn roc_auc_py<'py>(
613 py: Python<'py>,
614 labels: &Bound<'py, PyUntypedArray>,
615 predictions: PyReadonlyArray1<'py, f64>,
616 weights: Option<PyReadonlyArray1<'py, f64>>,
617 order: Option<PyOrder>,
618 max_fpr: Option<f64>,
619) -> PyResult<f64> {
620 return PyRocAuc::new(max_fpr).score_py(py, labels, predictions, weights, order);
621}
622
623#[pymodule(name = "_scors")]
624fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
625 m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
626 m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
627 m.add_class::<PyOrder>().unwrap();
628 return Ok(());
629}
630
631
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal weight for each of the four samples used below.
    const UNIT_WEIGHTS: [f64; 4] = [1.0; 4];

    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let score = average_precision_on_sorted_labels(&labels, Some(&UNIT_WEIGHTS), Order::DESCENDING);
        assert_eq!(score, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let score = average_precision_with_order(&labels, &predictions, Some(&UNIT_WEIGHTS), None);
        assert_eq!(score, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let score = average_precision_with_order(&labels, &predictions, Some(&UNIT_WEIGHTS), Some(Order::DESCENDING));
        assert_eq!(score, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let score = roc_auc_with_order(&labels, &predictions, Some(&UNIT_WEIGHTS), Some(Order::DESCENDING), None);
        assert_eq!(score, 0.75);
    }
}