1use ndarray::{ArrayView,Ix1};
2use numpy::{NotContiguousError,PyReadonlyArray1};
3use pyo3::prelude::*; use std::iter::DoubleEndedIterator;
5use std::time::Instant;
6
/// Direction in which pre-sorted inputs are ordered by prediction score.
///
/// `DESCENDING` means the highest-scored sample comes first; `ASCENDING` means the
/// lowest-scored sample comes first (such inputs are walked in reverse internally).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Order {
    ASCENDING,
    DESCENDING,
}
11
/// Read-only, index-addressable sequence whose elements can be produced by value and
/// walked from either end.
pub trait Data<T: Clone> {
    /// Yields every element, first to last, as an owned `T`.
    ///
    /// The implementations in this file clone each element from the underlying storage.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the element at position `index` as an owned `T`.
    ///
    /// The implementations in this file index directly and therefore panic when `index`
    /// is out of bounds.
    fn get_at(&self, index: usize) -> T;
}
18
/// Data that can report the permutation which ranks its elements.
pub trait SortableData<T> {
    /// Returns the indices of `self` ordered so that the referenced values are
    /// *descending* (largest first), despite the generic "argsort" name.
    ///
    /// The implementations in this file compare with `f64::total_cmp` and use an unstable
    /// sort, so the relative order of equal values is unspecified.
    fn argsort_unstable(&self) -> Vec<usize>;
}
22
23impl <T: Clone> Data<T> for Vec<T> {
24 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
25 return self.iter().cloned();
26 }
27 fn get_at(&self, index: usize) -> T {
28 return self[index].clone();
29 }
30}
31
32impl SortableData<f64> for Vec<f64> {
33 fn argsort_unstable(&self) -> Vec<usize> {
34 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
35 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
36 return indices;
38 }
39}
40
41impl <T: Clone> Data<T> for &[T] {
42 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
43 return self.iter().cloned();
44 }
45 fn get_at(&self, index: usize) -> T {
46 return self[index].clone();
47 }
48}
49
50impl SortableData<f64> for &[f64] {
51 fn argsort_unstable(&self) -> Vec<usize> {
52 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
54 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
57 return indices;
59 }
60}
61
62impl <T: Clone, const N: usize> Data<T> for [T; N] {
63 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
64 return self.iter().cloned();
65 }
66 fn get_at(&self, index: usize) -> T {
67 return self[index].clone();
68 }
69}
70
71impl <const N: usize> SortableData<f64> for [f64; N] {
72 fn argsort_unstable(&self) -> Vec<usize> {
73 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
74 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
75 return indices;
76 }
77}
78
79impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
80 fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
81 return self.iter().cloned();
82 }
83 fn get_at(&self, index: usize) -> T {
84 return self[index].clone();
85 }
86}
87
88impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
89 fn argsort_unstable(&self) -> Vec<usize> {
90 let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
91 indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
92 return indices;
93 }
94}
95
96fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
97where T: Copy, I: Data<T>
98{
99 let mut selection: Vec<T> = Vec::new();
100 selection.reserve_exact(indices.len());
101 for index in indices {
102 selection.push(slice.get_at(*index));
103 }
104 return selection;
105}
106
107pub fn average_precision<L, P, W>(labels: &L, predictions: &P, weights: &W) -> f64
108where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
109{
110 return average_precision_with_order(labels, predictions, weights, None);
111}
112
113pub fn average_precision_with_order<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Option<Order>) -> f64
114where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
115{
116 return match order {
117 Some(o) => average_precision_on_sorted_labels(labels, weights, o),
118 None => {
119 let indices = predictions.argsort_unstable();
120 let sorted_labels = select(labels, &indices);
121 let sorted_weights = select(weights, &indices);
122 let ap = average_precision_on_sorted_labels(&sorted_labels, &sorted_weights, Order::DESCENDING);
123 ap
124 }
125 };
126}
127
128pub fn average_precision_on_sorted_labels<L, W>(labels: &L, weights: &W, order: Order) -> f64
129where L: Data<u8>, W: Data<f64>
130{
131 return average_precision_on_iterator(labels.get_iterator(), weights.get_iterator(), order);
132}
133
134pub fn average_precision_on_iterator<L, W>(labels: L, weights: W, order: Order) -> f64
135where L: DoubleEndedIterator<Item = u8>, W: DoubleEndedIterator<Item = f64>
136{
137 return match order {
138 Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
139 Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
140 };
141}
142
/// Weighted average precision of samples already sorted by descending prediction score.
///
/// For every sample, `tp = weight * label` is added to the running positive weight. When
/// `tp` is non-zero, `tp * precision` is added to the sum, where `precision` is the
/// weighted fraction of positives among all samples seen so far (including the current
/// one). The sum is finally divided by the total positive weight.
///
/// Labels are converted with `f64::from`; values other than `0`/`1` are not rejected.
/// Presumably callers only pass binary labels (TODO confirm).
///
/// Tied scores are not grouped: each sample acts as its own threshold, so the result can
/// depend on the order of tied samples.
///
/// Iteration stops at the shorter of the two iterators. Returns `NaN` when the total
/// positive weight is zero (for example, empty input or no positive labels).
pub fn average_precision_on_descending_iterator(
    labels: impl Iterator<Item = u8>,
    weights: impl Iterator<Item = f64>,
) -> f64 {
    let mut ap = 0.0_f64;
    let mut tps = 0.0_f64;
    let mut fps = 0.0_f64;
    for (label, weight) in labels.zip(weights) {
        let tp = weight * f64::from(label);
        tps += tp;
        fps += weight - tp;
        // Samples without a positive contribution add nothing to the sum. Skipping them
        // also avoids evaluating `0 / 0` when the leading samples all carry zero weight,
        // which would otherwise turn the whole result into NaN via `0 * NaN`.
        if tp != 0.0 {
            let precision = tps / (tps + fps);
            ap += tp * precision;
        }
    }
    ap / tps
}
159
160
161
162pub fn roc_auc<L, P, W>(labels: &L, predictions: &P, weights: &W) -> f64
164where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
165{
166 return roc_auc_with_order(labels, predictions, weights, None);
167}
168
169pub fn roc_auc_with_order<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Option<Order>) -> f64
170where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
171{
172 return match order {
173 Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o),
174 None => {
175 let indices = predictions.argsort_unstable();
176 let sorted_labels = select(labels, &indices);
177 let sorted_predictions = select(predictions, &indices);
178 let sorted_weights = select(weights, &indices);
179 let ap = roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, &sorted_weights, Order::DESCENDING);
180 ap
181 }
182 };
183}
184pub fn roc_auc_on_sorted_labels<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order) -> f64
185where L: Data<u8>, P: Data<f64>, W: Data<f64> {
186 return roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut weights.get_iterator(), order);
187}
188
189pub fn roc_auc_on_sorted_iterator(
190 labels: &mut impl DoubleEndedIterator<Item = u8>,
191 predictions: &mut impl DoubleEndedIterator<Item = f64>,
192 weights: &mut impl DoubleEndedIterator<Item = f64>,
193 order: Order
194) -> f64 {
195 return match order {
196 Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
197 Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
198 }
199}
200
201pub fn roc_auc_on_descending_iterator(
202 labels: &mut impl Iterator<Item = u8>,
203 predictions: &mut impl Iterator<Item = f64>,
204 weights: &mut impl Iterator<Item = f64>
205) -> f64 {
206 let mut false_positives: f64 = 0.0;
207 let mut true_positives: f64 = 0.0;
208 let mut last_counted_fp = 0.0;
209 let mut last_counted_tp = 0.0;
210 let mut area_under_curve = 0.0;
211 let mut zipped = labels.zip(predictions).zip(weights).peekable();
212 loop {
213 match zipped.next() {
214 None => break,
215 Some(actual) => {
216 let l = actual.0.0 as f64;
217 let w = actual.1;
218 let wl = l * w;
219 true_positives += wl;
220 false_positives += w - wl;
221 if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
222 area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
223 last_counted_fp = false_positives;
224 last_counted_tp = true_positives;
225 }
226 }
227 };
228 }
229 return area_under_curve / (true_positives * false_positives);
230}
231
/// Trapezoid area below the straight segment from `(x0, y0)` to `(x1, y1)`.
///
/// Computed as the rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rectangle = width * y0;
    let triangle = (y1 - y0) * width * 0.5;
    rectangle + triangle
}
237
238
239#[pyclass(eq, eq_int, name="Order")]
241#[derive(PartialEq)]
242pub enum PyOrder {
243 ASCENDING,
244 DESCENDING
245}
246
247impl Clone for PyOrder {
248 fn clone(&self) -> Self {
249 match self {
250 PyOrder::ASCENDING => PyOrder::ASCENDING,
251 PyOrder::DESCENDING => PyOrder::DESCENDING
252 }
253 }
254}
255
256fn py_order_as_order(order: PyOrder) -> Order {
257 return match order {
258 PyOrder::ASCENDING => Order::ASCENDING,
259 PyOrder::DESCENDING => Order::DESCENDING,
260 }
261}
262
263#[pyfunction(name = "average_precision")]
264#[pyo3(signature = (labels, predictions, *, weights, order=None))]
265pub fn average_precision_py<'py>(
266 py: Python<'py>,
267 labels: PyReadonlyArray1<'py, u8>,
268 predictions: PyReadonlyArray1<'py, f64>,
269 weights: PyReadonlyArray1<'py, f64>,
270 order: Option<PyOrder>
271) -> Result<f64, NotContiguousError> {
272 let o = order.map(py_order_as_order);
273 let ap = if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weights.as_slice()) {
274 let ap = average_precision_with_order(&l, &p, &w, o);
275 ap
276 } else {
277 average_precision_with_order(&labels.as_array(), &predictions.as_array(), &weights.as_array(), o)
278 };
279
280 return Ok(ap);
281}
282
283#[pyfunction(name = "roc_auc")]
284#[pyo3(signature = (labels, predictions, *, weights, order=None))]
285pub fn roc_auc_py<'py>(
286 py: Python<'py>,
287 labels: PyReadonlyArray1<'py, u8>,
288 predictions: PyReadonlyArray1<'py, f64>,
289 weights: PyReadonlyArray1<'py, f64>,
290 order: Option<PyOrder>
291) -> Result<f64, NotContiguousError> {
292 let o = order.map(py_order_as_order);
293 let ap = if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weights.as_slice()) {
294 let roc_auc = roc_auc_with_order(&l, &p, &w, o);
295 roc_auc
296 } else {
297 roc_auc_with_order(&labels.as_array(), &predictions.as_array(), &weights.as_array(), o)
298 };
299
300 return Ok(ap);
301}
302
303#[pymodule(name = "_scors")]
304fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
305 m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
306 m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
307 m.add_class::<PyOrder>().unwrap();
308 return Ok(());
309}
310
311
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_average_precision_on_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, &weights, Order::DESCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    // Same data as above, presented lowest score first.
    #[test]
    fn test_average_precision_on_sorted_ascending() {
        let labels: [u8; 4] = [0, 1, 0, 1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_on_sorted_labels(&labels, &weights, Order::ASCENDING);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, &weights, None);
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_average_precision_sorted() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = average_precision_with_order(&labels, &predictions, &weights, Some(Order::DESCENDING));
        assert_eq!(actual, 0.8333333333333333);
    }

    #[test]
    fn test_roc_auc() {
        let labels: [u8; 4] = [1, 0, 1, 0];
        let predictions: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, &weights, Some(Order::DESCENDING));
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_roc_auc_unsorted() {
        let labels: [u8; 4] = [0, 0, 1, 1];
        let predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, &weights, None);
        assert_eq!(actual, 0.75);
    }

    #[test]
    fn test_roc_auc_ascending() {
        let labels: [u8; 4] = [0, 1, 0, 1];
        let predictions: [f64; 4] = [0.1, 0.35, 0.4, 0.8];
        let weights: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
        let actual = roc_auc_with_order(&labels, &predictions, &weights, Some(Order::ASCENDING));
        assert_eq!(actual, 0.75);
    }

    // Tied scores must form a single ROC point, so the order within a tie is irrelevant.
    #[test]
    fn test_roc_auc_groups_tied_predictions() {
        let predictions: [f64; 2] = [0.5, 0.5];
        let weights: [f64; 2] = [1.0, 1.0];
        let positive_first: [u8; 2] = [1, 0];
        let negative_first: [u8; 2] = [0, 1];
        assert_eq!(roc_auc_with_order(&positive_first, &predictions, &weights, Some(Order::DESCENDING)), 0.5);
        assert_eq!(roc_auc_with_order(&negative_first, &predictions, &weights, Some(Order::DESCENDING)), 0.5);
    }
}