_scors/
lib.rs

1use ndarray::{ArrayView,Ix1};
2use numpy::{NotContiguousError,PyReadonlyArray1};
3use pyo3::prelude::*; // {PyModule,PyResult,Python,pymodule};
4use std::iter::DoubleEndedIterator;
5use std::time::Instant;
6
/// Direction in which already-sorted input is ordered by prediction score.
///
/// The variant names are kept in upper case because they are part of the
/// public interface of this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Order {
    /// Scores increase from the first element to the last.
    ASCENDING,
    /// Scores decrease from the first element to the last.
    DESCENDING,
}
11
/// Minimal read-only view over a one-dimensional sequence of `T`.
///
/// TODO: A crate-local trait is used because no suitable standard trait
///       covering both reversible iteration and indexed access was found.
pub trait Data<T: Clone> {
    /// Returns an iterator yielding the elements by value; it can be walked
    /// from either end.
    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T>;

    /// Returns the element at position `index` by value.
    fn get_at(&self, index: usize) -> T;
}
18
/// Data whose elements can be ranked without moving them.
pub trait SortableData<T> {
    /// Returns a permutation of `0..len` that ranks the elements.
    ///
    /// The sort is unstable, so the relative order of equal elements is
    /// unspecified. The implementations in this file rank from the largest
    /// value to the smallest.
    fn argsort_unstable(&self) -> Vec<usize>;
}
22
23impl <T: Clone> Data<T> for Vec<T> {
24    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
25        return self.iter().cloned();
26    }
27    fn get_at(&self, index: usize) -> T {
28        return self[index].clone();
29    }
30}
31
32impl SortableData<f64> for Vec<f64> {
33    fn argsort_unstable(&self) -> Vec<usize> {
34        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
35        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
36        // indices.sort_unstable_by_key(|i| self[*i]);
37        return indices;
38    }
39}
40
41impl <T: Clone> Data<T> for &[T] {
42    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
43        return self.iter().cloned();
44    }
45    fn get_at(&self, index: usize) -> T {
46        return self[index].clone();
47    }
48}
49
50impl SortableData<f64> for &[f64] {
51    fn argsort_unstable(&self) -> Vec<usize> {
52        // let t0 = Instant::now();
53        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
54        // println!("Creating indices took {}ms", t0.elapsed().as_millis());
55        // let t1 = Instant::now();
56        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
57        // println!("Sorting took {}ms", t0.elapsed().as_millis());
58        return indices;
59    }
60}
61
62impl <T: Clone, const N: usize> Data<T> for [T; N] {
63    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
64        return self.iter().cloned();
65    }
66    fn get_at(&self, index: usize) -> T {
67        return self[index].clone();
68    }
69}
70
71impl <const N: usize> SortableData<f64> for [f64; N] {
72    fn argsort_unstable(&self) -> Vec<usize> {
73        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
74        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
75        return indices;
76    }
77}
78
79impl <T: Clone> Data<T> for ArrayView<'_, T, Ix1> {
80    fn get_iterator(&self) -> impl DoubleEndedIterator<Item = T> {
81        return self.iter().cloned();
82    }
83    fn get_at(&self, index: usize) -> T {
84        return self[index].clone();
85    }
86}
87
88impl SortableData<f64> for ArrayView<'_, f64, Ix1> {
89    fn argsort_unstable(&self) -> Vec<usize> {
90        let mut indices: Vec<usize> = (0..self.len()).collect::<Vec<_>>();
91        indices.sort_unstable_by(|i, k| self[*k].total_cmp(&self[*i]));
92        return indices;
93    }
94}
95
96fn select<T, I>(slice: &I, indices: &[usize]) -> Vec<T>
97where T: Copy, I: Data<T>
98{
99    let mut selection: Vec<T> = Vec::new();
100    selection.reserve_exact(indices.len());
101    for index in indices {
102        selection.push(slice.get_at(*index));
103    }
104    return selection;
105}
106
107pub fn average_precision<L, P, W>(labels: &L, predictions: &P, weights: &W) -> f64
108where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
109{
110    return average_precision_with_order(labels, predictions, weights, None);
111}
112
113pub fn average_precision_with_order<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Option<Order>) -> f64
114where L: Data<u8>, P: SortableData<f64>, W: Data<f64>
115{
116    return match order {
117        Some(o) => average_precision_on_sorted_labels(labels, weights, o),
118        None => {
119            let indices = predictions.argsort_unstable();
120            let sorted_labels = select(labels, &indices);
121            let sorted_weights = select(weights, &indices);
122            let ap = average_precision_on_sorted_labels(&sorted_labels, &sorted_weights, Order::DESCENDING);
123            ap
124        }
125    };
126}
127
128pub fn average_precision_on_sorted_labels<L, W>(labels: &L, weights: &W, order: Order) -> f64
129where L: Data<u8>, W: Data<f64>
130{
131    return average_precision_on_iterator(labels.get_iterator(), weights.get_iterator(), order);
132}
133
134pub fn average_precision_on_iterator<L, W>(labels: L, weights: W, order: Order) -> f64
135where L: DoubleEndedIterator<Item = u8>, W: DoubleEndedIterator<Item = f64>
136{
137    return match order {
138        Order::ASCENDING => average_precision_on_descending_iterator(labels.rev(), weights.rev()),
139        Order::DESCENDING => average_precision_on_descending_iterator(labels, weights)
140    };
141}
142
/// Computes the weighted average precision of samples supplied from the
/// highest-ranked to the lowest-ranked.
///
/// `labels` and `weights` are consumed in lock-step; iteration stops as soon
/// as either one is exhausted. Each label is converted to `f64` and multiplied
/// by its weight to obtain the positive weight contributed by the sample.
///
/// Every sample is treated as its own threshold; the function does not see
/// the prediction values, so tied predictions are not grouped.
///
/// Returns NaN when the accumulated positive weight is zero (for example on
/// empty input or when every label is `0`), because the final normalisation
/// divides by it.
pub fn average_precision_on_descending_iterator(labels: impl Iterator<Item = u8>, weights: impl Iterator<Item = f64>) -> f64 {
    let mut ap: f64 = 0.0;
    let mut tps: f64 = 0.0;
    let mut fps: f64 = 0.0;
    for (label, weight) in labels.zip(weights) {
        let tp = weight * f64::from(label);
        tps += tp;
        fps += weight - tp;
        // A sample without positive weight adds nothing to the sum. Skipping
        // it also avoids evaluating 0/0 while no weight has been accumulated
        // yet, which used to turn the whole result into NaN.
        if tp != 0.0 {
            let ps = tps + fps;
            let precision = tps / ps;
            ap += tp * precision;
        }
    }
    ap / tps
}
159
160
161
162// ROC AUC score
163pub fn roc_auc<L, P, W>(labels: &L, predictions: &P, weights: &W) -> f64
164where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
165{
166    return roc_auc_with_order(labels, predictions, weights, None);
167}
168
169pub fn roc_auc_with_order<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Option<Order>) -> f64
170where L: Data<u8>, P: SortableData<f64> + Data<f64>, W: Data<f64>
171{
172    return match order {
173        Some(o) => roc_auc_on_sorted_labels(labels, predictions, weights, o),
174        None => {
175            let indices = predictions.argsort_unstable();
176            let sorted_labels = select(labels, &indices);
177            let sorted_predictions = select(predictions, &indices);
178            let sorted_weights = select(weights, &indices);
179            let ap = roc_auc_on_sorted_labels(&sorted_labels, &sorted_predictions, &sorted_weights, Order::DESCENDING);
180            ap
181        }
182    };
183}
184pub fn roc_auc_on_sorted_labels<L, P, W>(labels: &L, predictions: &P, weights: &W, order: Order) -> f64 
185where L: Data<u8>, P: Data<f64>, W: Data<f64> {
186    return roc_auc_on_sorted_iterator(&mut labels.get_iterator(), &mut predictions.get_iterator(), &mut weights.get_iterator(), order);
187}
188
189pub fn roc_auc_on_sorted_iterator(
190    labels: &mut impl DoubleEndedIterator<Item = u8>,
191    predictions: &mut impl DoubleEndedIterator<Item = f64>,
192    weights: &mut impl DoubleEndedIterator<Item = f64>,
193    order: Order
194) -> f64 {
195    return match order {
196        Order::ASCENDING => roc_auc_on_descending_iterator(&mut labels.rev(), &mut predictions.rev(), &mut weights.rev()),
197        Order::DESCENDING => roc_auc_on_descending_iterator(labels, predictions, weights)
198    }
199}
200
201pub fn roc_auc_on_descending_iterator(
202    labels: &mut impl Iterator<Item = u8>,
203    predictions: &mut impl Iterator<Item = f64>,
204    weights: &mut impl Iterator<Item = f64>
205) -> f64 {
206    let mut false_positives: f64 = 0.0;
207    let mut true_positives: f64 = 0.0;
208    let mut last_counted_fp = 0.0;
209    let mut last_counted_tp = 0.0;
210    let mut area_under_curve = 0.0;
211    let mut zipped = labels.zip(predictions).zip(weights).peekable();
212    loop {
213        match zipped.next() {
214            None => break,
215            Some(actual) => {
216                let l = actual.0.0 as f64;
217                let w = actual.1;
218                let wl = l * w;
219                true_positives += wl;
220                false_positives += w - wl;
221                if zipped.peek().map(|x| x.0.1 != actual.0.1).unwrap_or(true) {
222                    area_under_curve += area_under_line_segment(last_counted_fp, false_positives, last_counted_tp, true_positives);
223                    last_counted_fp = false_positives;
224                    last_counted_tp = true_positives;
225                }
226            }
227        };
228    }
229    return area_under_curve / (true_positives * false_positives);
230}
231
/// Area between the x-axis and the straight line from `(x0, y0)` to
/// `(x1, y1)`: a rectangle of height `y0` plus the triangle on top of it.
fn area_under_line_segment(x0: f64, x1: f64, y0: f64, y1: f64) -> f64 {
    let width = x1 - x0;
    let rise = y1 - y0;
    let rectangle = width * y0;
    let triangle = rise * width * 0.5;
    rectangle + triangle
}
237
238
239// Python bindings
240#[pyclass(eq, eq_int, name="Order")]
241#[derive(PartialEq)]
242pub enum PyOrder {
243    ASCENDING,
244    DESCENDING
245}
246
247impl Clone for PyOrder {
248    fn clone(&self) -> Self {
249        match self {
250            PyOrder::ASCENDING => PyOrder::ASCENDING,
251            PyOrder::DESCENDING => PyOrder::DESCENDING
252        }
253    }
254}
255
256fn py_order_as_order(order: PyOrder) -> Order {
257    return match order {
258        PyOrder::ASCENDING => Order::ASCENDING,
259        PyOrder::DESCENDING => Order::DESCENDING,
260    }
261}
262
263#[pyfunction(name = "average_precision")]
264#[pyo3(signature = (labels, predictions, *, weights, order=None))]
265pub fn average_precision_py<'py>(
266    py: Python<'py>,
267    labels: PyReadonlyArray1<'py, u8>,
268    predictions: PyReadonlyArray1<'py, f64>,
269    weights: PyReadonlyArray1<'py, f64>,
270    order: Option<PyOrder>
271) -> Result<f64, NotContiguousError> {
272    let o = order.map(py_order_as_order);
273    let ap = if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weights.as_slice()) {
274        let ap = average_precision_with_order(&l, &p, &w, o);
275        ap
276    } else {
277        average_precision_with_order(&labels.as_array(), &predictions.as_array(), &weights.as_array(), o)
278    };
279
280    return Ok(ap);
281}
282
283#[pyfunction(name = "roc_auc")]
284#[pyo3(signature = (labels, predictions, *, weights, order=None))]
285pub fn roc_auc_py<'py>(
286    py: Python<'py>,
287    labels: PyReadonlyArray1<'py, u8>,
288    predictions: PyReadonlyArray1<'py, f64>,
289    weights: PyReadonlyArray1<'py, f64>,
290    order: Option<PyOrder>
291) -> Result<f64, NotContiguousError> {
292    let o = order.map(py_order_as_order);
293    let ap = if let (Ok(l), Ok(p), Ok(w)) = (labels.as_slice(), predictions.as_slice(), weights.as_slice()) {
294        let roc_auc = roc_auc_with_order(&l, &p, &w, o);
295        roc_auc
296    } else {
297        roc_auc_with_order(&labels.as_array(), &predictions.as_array(), &weights.as_array(), o)
298    };
299
300    return Ok(ap);
301}
302
303#[pymodule(name = "_scors")]
304fn scors(m: &Bound<'_, PyModule>) -> PyResult<()> {
305    m.add_function(wrap_pyfunction!(average_precision_py, m)?).unwrap();
306    m.add_function(wrap_pyfunction!(roc_auc_py, m)?).unwrap();
307    m.add_class::<PyOrder>().unwrap();
308    return Ok(());
309}
310
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Every sample counts equally in all test cases.
    const UNIT_WEIGHTS: [f64; 4] = [1.0, 1.0, 1.0, 1.0];
    /// Labels arranged by descending prediction.
    const LABELS_DESCENDING: [u8; 4] = [1, 0, 1, 0];
    /// Predictions matching `LABELS_DESCENDING`.
    const PREDICTIONS_DESCENDING: [f64; 4] = [0.8, 0.4, 0.35, 0.1];
    /// Average precision expected for the fixture above.
    const EXPECTED_AP: f64 = 0.8333333333333333;

    // Pre-sorted labels and weights, no predictions involved.
    #[test]
    fn test_average_precision_on_sorted() {
        let score = average_precision_on_sorted_labels(&LABELS_DESCENDING, &UNIT_WEIGHTS, Order::DESCENDING);
        assert_eq!(score, EXPECTED_AP);
    }

    // Unsorted input: the function has to rank the predictions itself.
    #[test]
    fn test_average_precision_unsorted() {
        let shuffled_labels: [u8; 4] = [0, 0, 1, 1];
        let shuffled_predictions: [f64; 4] = [0.1, 0.4, 0.35, 0.8];
        let score = average_precision_with_order(&shuffled_labels, &shuffled_predictions, &UNIT_WEIGHTS, None);
        assert_eq!(score, EXPECTED_AP);
    }

    // Sorted input with the order announced by the caller.
    #[test]
    fn test_average_precision_sorted() {
        let score = average_precision_with_order(
            &LABELS_DESCENDING,
            &PREDICTIONS_DESCENDING,
            &UNIT_WEIGHTS,
            Some(Order::DESCENDING),
        );
        assert_eq!(score, EXPECTED_AP);
    }

    // ROC AUC on sorted input with the order announced by the caller.
    #[test]
    fn test_roc_auc() {
        let score = roc_auc_with_order(
            &LABELS_DESCENDING,
            &PREDICTIONS_DESCENDING,
            &UNIT_WEIGHTS,
            Some(Order::DESCENDING),
        );
        assert_eq!(score, 0.75);
    }
}
351}