Skip to main content

tensorlogic_train/metrics/
ranking.rs

//! Ranking metrics: Top-K accuracy and Normalized Discounted Cumulative Gain (NDCG).
2
3use crate::{TrainError, TrainResult};
4use scirs2_core::ndarray::{ArrayView, Ix2};
5
6use super::Metric;
7
/// Top-K accuracy: the share of samples whose true class appears among the
/// `k` highest-scoring predicted classes.
#[derive(Debug, Clone)]
pub struct TopKAccuracy {
    /// How many of the highest-scoring classes count as a hit.
    pub k: usize,
}

impl TopKAccuracy {
    /// Build a Top-K accuracy metric that accepts the true class anywhere in
    /// the `k` best-scoring predictions.
    pub fn new(k: usize) -> Self {
        TopKAccuracy { k }
    }
}

impl Default for TopKAccuracy {
    /// Top-K accuracy with `k = 5`.
    fn default() -> Self {
        Self::new(5)
    }
}
28
29impl Metric for TopKAccuracy {
30    fn compute(
31        &self,
32        predictions: &ArrayView<f64, Ix2>,
33        targets: &ArrayView<f64, Ix2>,
34    ) -> TrainResult<f64> {
35        if predictions.shape() != targets.shape() {
36            return Err(TrainError::MetricsError(format!(
37                "Shape mismatch: predictions {:?} vs targets {:?}",
38                predictions.shape(),
39                targets.shape()
40            )));
41        }
42
43        let num_classes = predictions.ncols();
44        if self.k > num_classes {
45            return Err(TrainError::MetricsError(format!(
46                "K ({}) cannot be greater than number of classes ({})",
47                self.k, num_classes
48            )));
49        }
50
51        let mut correct = 0;
52        let total = predictions.nrows();
53
54        for i in 0..total {
55            // Find true class
56            let mut true_class = 0;
57            let mut max_true = targets[[i, 0]];
58            for j in 1..num_classes {
59                if targets[[i, j]] > max_true {
60                    max_true = targets[[i, j]];
61                    true_class = j;
62                }
63            }
64
65            // Get top K predictions
66            let mut indices: Vec<usize> = (0..num_classes).collect();
67            indices.sort_by(|&a, &b| {
68                predictions[[i, b]]
69                    .partial_cmp(&predictions[[i, a]])
70                    .unwrap_or(std::cmp::Ordering::Equal)
71            });
72
73            // Check if true class is in top K
74            if indices[..self.k].contains(&true_class) {
75                correct += 1;
76            }
77        }
78
79        Ok(correct as f64 / total as f64)
80    }
81
82    fn name(&self) -> &str {
83        "top_k_accuracy"
84    }
85}
86
/// Normalized Discounted Cumulative Gain (NDCG@k), a ranking-quality metric.
///
/// Items are ordered by their predicted score. Each position contributes the
/// gain of its graded relevance, discounted logarithmically by rank, so
/// misplacing a relevant item near the top costs more than misplacing it
/// further down. The total is divided by the best achievable value for the
/// same relevances.
///
/// # Formula
/// ```text
/// DCG@k  = sum_{i=1}^{k} (2^rel_i - 1) / log2(i + 1)
/// NDCG@k = DCG@k / IDCG@k
/// ```
///
/// Here `rel_i` is the relevance of the item at rank `i`. IDCG@k is the DCG@k of
/// the relevances sorted in descending order, i.e. the ideal ranking.
///
/// # Use Cases
/// - Recommendation systems
/// - Search engine ranking
/// - Information retrieval
/// - Learning to rank
///
/// Reference: Järvelin & Kekäläinen, "Cumulated gain-based evaluation of IR
/// techniques" (ACM TOIS 2002).
#[derive(Debug, Clone)]
pub struct NormalizedDiscountedCumulativeGain {
    /// Cut-off rank: only the first `k` positions of each ranking are scored.
    pub k: usize,
}

impl Default for NormalizedDiscountedCumulativeGain {
    /// NDCG with a cut-off of `k = 10`.
    fn default() -> Self {
        NormalizedDiscountedCumulativeGain { k: 10 }
    }
}
117
118impl NormalizedDiscountedCumulativeGain {
119    /// Create NDCG metric with custom k value.
120    ///
121    /// # Arguments
122    /// * `k` - Number of top results to consider
123    pub fn new(k: usize) -> Self {
124        Self { k }
125    }
126
127    /// Compute DCG (Discounted Cumulative Gain) for a single ranking.
128    ///
129    /// # Arguments
130    /// * `relevances` - Relevance scores in the predicted order
131    /// * `k` - Number of positions to consider
132    fn compute_dcg(relevances: &[f64], k: usize) -> f64 {
133        let k = k.min(relevances.len());
134        let mut dcg = 0.0;
135
136        for (i, &rel) in relevances.iter().take(k).enumerate() {
137            let position = (i + 2) as f64; // i+2 because positions start at 1 and log₂(1) = 0
138            dcg += (2.0_f64.powf(rel) - 1.0) / position.log2();
139        }
140
141        dcg
142    }
143
144    /// Compute IDCG (Ideal DCG) by sorting relevances in descending order.
145    fn compute_idcg(relevances: &[f64], k: usize) -> f64 {
146        let mut sorted_rel = relevances.to_vec();
147        sorted_rel.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
148        Self::compute_dcg(&sorted_rel, k)
149    }
150}
151
152impl Metric for NormalizedDiscountedCumulativeGain {
153    fn compute(
154        &self,
155        predictions: &ArrayView<f64, Ix2>,
156        targets: &ArrayView<f64, Ix2>,
157    ) -> TrainResult<f64> {
158        if predictions.shape() != targets.shape() {
159            return Err(TrainError::MetricsError(format!(
160                "Shape mismatch: predictions {:?} vs targets {:?}",
161                predictions.shape(),
162                targets.shape()
163            )));
164        }
165
166        let n_samples = predictions.nrows();
167        if n_samples == 0 {
168            return Ok(0.0);
169        }
170
171        let mut ndcg_sum = 0.0;
172
173        for i in 0..n_samples {
174            // Get predicted scores and true relevances for this sample
175            let pred_scores: Vec<f64> = predictions.row(i).iter().copied().collect();
176            let true_relevances: Vec<f64> = targets.row(i).iter().copied().collect();
177
178            // Create indices and sort by predicted scores (descending)
179            let mut indices: Vec<usize> = (0..pred_scores.len()).collect();
180            indices.sort_by(|&a, &b| {
181                pred_scores[b]
182                    .partial_cmp(&pred_scores[a])
183                    .unwrap_or(std::cmp::Ordering::Equal)
184            });
185
186            // Reorder relevances according to predicted ranking
187            let ranked_relevances: Vec<f64> =
188                indices.iter().map(|&idx| true_relevances[idx]).collect();
189
190            // Compute DCG for this ranking
191            let dcg = Self::compute_dcg(&ranked_relevances, self.k);
192
193            // Compute IDCG (ideal ranking)
194            let idcg = Self::compute_idcg(&true_relevances, self.k);
195
196            // Compute NDCG (handle division by zero)
197            let ndcg = if idcg > 1e-12 { dcg / idcg } else { 0.0 };
198
199            ndcg_sum += ndcg;
200        }
201
202        Ok(ndcg_sum / n_samples as f64)
203    }
204
205    fn name(&self) -> &str {
206        "ndcg"
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_top_k_accuracy() {
        let metric = TopKAccuracy::new(2);

        // Three samples over three classes.
        let predictions = array![
            [0.7, 0.2, 0.1], // True class 0; top-2 = {0, 1} -> hit
            [0.1, 0.6, 0.3], // True class 1; top-2 = {1, 2} -> hit
            [0.3, 0.4, 0.3], // True class 2; top-2 = {1, 0} (0/2 tie keeps lower index) -> miss
        ];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let top_k = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&top_k));
        assert!(
            (top_k - 2.0 / 3.0).abs() < 1e-12,
            "Exactly 2 of 3 samples should be in the top-2, got {}",
            top_k
        );
    }

    #[test]
    fn test_top_k_accuracy_k_equals_num_classes() {
        // With k equal to the number of classes every true class is a hit.
        let metric = TopKAccuracy::new(3);
        let predictions = array![[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]];

        let top_k = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((top_k - 1.0).abs() < 1e-12, "Expected 1.0, got {}", top_k);
    }

    #[test]
    fn test_top_k_accuracy_k_too_large() {
        let metric = TopKAccuracy::new(4);
        let predictions = array![[0.5, 0.3, 0.2]];
        let targets = array![[1.0, 0.0, 0.0]];

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_top_k_accuracy_shape_mismatch() {
        let metric = TopKAccuracy::new(1);
        let predictions = array![[0.5, 0.5]];
        let targets = array![[1.0, 0.0, 0.0]]; // Different shape

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_ndcg_perfect_ranking() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        // Perfect ranking: predicted order matches true relevance order
        let predictions = array![
            [5.0, 4.0, 3.0, 2.0, 1.0], // Pred scores: highest to lowest
        ];
        let targets = array![
            [5.0, 4.0, 3.0, 2.0, 1.0], // True relevances: match pred order
        ];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Perfect ranking should give NDCG = 1.0
        assert!(
            (ndcg - 1.0).abs() < 1e-6,
            "Perfect ranking should have NDCG ≈ 1.0, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_worst_ranking() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        // Worst ranking: predicted order is reverse of true relevance
        let predictions = array![
            [1.0, 2.0, 3.0, 4.0, 5.0], // Pred scores: lowest to highest
        ];
        let targets = array![
            [5.0, 4.0, 3.0, 2.0, 1.0], // True relevances: highest to lowest
        ];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Worst ranking should give low NDCG
        assert!(
            ndcg < 0.8,
            "Worst ranking should have low NDCG, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_only_order_matters() {
        let metric = NormalizedDiscountedCumulativeGain::new(3);

        // Scores and relevances differ in value, but the predicted order
        // [1, 0, 3, 2, 4] lays out relevances as [5, 3, 2, 1, 0], which is
        // already the ideal (descending) order.
        let predictions = array![[4.0, 5.0, 2.0, 3.0, 1.0]];
        let targets = array![[3.0, 5.0, 1.0, 2.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(
            (ndcg - 1.0).abs() < 1e-12,
            "Ideal ordering should have NDCG = 1.0, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_partial_match() {
        let metric = NormalizedDiscountedCumulativeGain::new(3);

        // Ranked relevances are [1, 0, 1]; the ideal order is [1, 1, 0].
        // DCG  = 1/log2(2) + 0/log2(3) + 1/log2(4) = 1.5
        // IDCG = 1/log2(2) + 1/log2(3)
        let predictions = array![[3.0, 2.0, 1.0]];
        let targets = array![[1.0, 0.0, 1.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        let expected = 1.5 / (1.0 + 1.0 / 3.0_f64.log2());
        assert!(
            (ndcg - expected).abs() < 1e-12,
            "Expected NDCG {}, got {}",
            expected,
            ndcg
        );
        assert!(ndcg < 1.0);
    }

    #[test]
    fn test_ndcg_multiple_samples() {
        let metric = NormalizedDiscountedCumulativeGain::new(3);

        // Two samples: one perfect, one reversed
        let predictions = array![[5.0, 4.0, 3.0, 2.0], [2.0, 3.0, 4.0, 5.0],];
        let targets = array![[5.0, 4.0, 3.0, 2.0], [5.0, 4.0, 3.0, 2.0],];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Average of perfect (1.0) and poor ranking
        assert!((0.0..=1.0).contains(&ndcg));
        assert!(ndcg > 0.4 && ndcg < 0.9); // Should be somewhere in between
    }

    #[test]
    fn test_ndcg_different_k_values() {
        let metric_k3 = NormalizedDiscountedCumulativeGain::new(3);
        let metric_k5 = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[5.0, 4.0, 3.0, 1.0, 2.0]];
        let targets = array![[5.0, 4.0, 3.0, 2.0, 1.0]];

        let ndcg_k3 = metric_k3
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let ndcg_k5 = metric_k5
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // k=3 should be perfect (top 3 are correct)
        assert!((ndcg_k3 - 1.0).abs() < 1e-6);

        // k=5 should be lower (last 2 are swapped)
        assert!(ndcg_k5 < ndcg_k3);
        assert!(ndcg_k5 > 0.9); // Still very good
    }

    #[test]
    fn test_ndcg_zero_relevances() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        // All zero relevances
        let predictions = array![[1.0, 2.0, 3.0]];
        let targets = array![[0.0, 0.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Should handle gracefully (IDCG = 0)
        assert!(ndcg.is_finite());
        assert_eq!(ndcg, 0.0);
    }

    #[test]
    fn test_ndcg_empty_input() {
        let metric = NormalizedDiscountedCumulativeGain::default();

        use scirs2_core::ndarray::Array;
        let empty_predictions: Array<f64, _> = Array::zeros((0, 5));
        let empty_targets: Array<f64, _> = Array::zeros((0, 5));

        let ndcg = metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();

        assert_eq!(ndcg, 0.0);
    }

    #[test]
    fn test_ndcg_shape_mismatch() {
        let metric = NormalizedDiscountedCumulativeGain::default();

        let predictions = array![[1.0, 2.0, 3.0]];
        let targets = array![[1.0, 2.0]]; // Different shape

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_ndcg_binary_relevance() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        // Binary relevance (0 or 1)
        let predictions = array![[0.9, 0.7, 0.5, 0.3, 0.1]];
        let targets = array![[1.0, 1.0, 0.0, 1.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        // Should be in valid range
        assert!((0.0..=1.0).contains(&ndcg));

        // Top 2 are relevant, so should have decent NDCG
        assert!(
            ndcg > 0.6,
            "Should have decent NDCG with top-2 relevant, got {}",
            ndcg
        );
    }
}