tensorlogic_train/metrics/ranking.rs

use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{ArrayView, Ix2};

use super::Metric;

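/// Top-k classification accuracy.
///
/// A sample counts as correct when the argmax of its target row appears among
/// the `k` classes with the highest predicted scores.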
#[derive(Debug, Clone)]
pub struct TopKAccuracy {
    pub k: usize,
}

impl Default for TopKAccuracy {
    fn default() -> Self {
        Self { k: 5 }
    }
}

impl TopKAccuracy {
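    /// Creates a top-k accuracy metric with the given cutoff `k`.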
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}

impl Metric for TopKAccuracy {
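    // For each row, the true class is taken as the argmax of the target row and
    // counted as correct if it appears among the `k` highest prediction scores.
    // Ties in the predictions are resolved by the stable sort (lower index wins).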
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let num_classes = predictions.ncols();
        if self.k > num_classes {
            return Err(TrainError::MetricsError(format!(
                "K ({}) cannot be greater than number of classes ({})",
                self.k, num_classes
            )));
        }

        let mut correct = 0;
        let total = predictions.nrows();
        if total == 0 {
            // An empty batch yields 0.0, matching the NDCG metric below,
            // instead of NaN from 0 / 0.
            return Ok(0.0);
        }

        for i in 0..total {
            let mut true_class = 0;
            let mut max_true = targets[[i, 0]];
            for j in 1..num_classes {
                if targets[[i, j]] > max_true {
                    max_true = targets[[i, j]];
                    true_class = j;
                }
            }

            let mut indices: Vec<usize> = (0..num_classes).collect();
            indices.sort_by(|&a, &b| {
                predictions[[i, b]]
                    .partial_cmp(&predictions[[i, a]])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            if indices[..self.k].contains(&true_class) {
                correct += 1;
            }
        }

        Ok(correct as f64 / total as f64)
    }

    fn name(&self) -> &str {
        "top_k_accuracy"
    }
}

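/// Normalized Discounted Cumulative Gain at rank `k` (NDCG@k).
///
/// For each sample, items are ordered by descending prediction score and the
/// discounted cumulative gain, DCG@k = sum over 1-based ranks i <= k of
/// (2^rel_i - 1) / log2(i + 1), is divided by the ideal DCG computed from the
/// relevances sorted in descending order. The per-sample values are averaged
/// over the batch; a sample whose ideal DCG is zero contributes 0.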
#[derive(Debug, Clone)]
pub struct NormalizedDiscountedCumulativeGain {
    pub k: usize,
}

impl Default for NormalizedDiscountedCumulativeGain {
    fn default() -> Self {
        Self { k: 10 }
    }
}

impl NormalizedDiscountedCumulativeGain {
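    /// Creates an NDCG metric evaluated at the given cutoff `k`.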
    pub fn new(k: usize) -> Self {
        Self { k }
    }

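    /// Discounted cumulative gain of `relevances` taken in the given order,
    /// truncated at rank `k`: the sum of (2^rel - 1) / log2(rank + 1) over
    /// 1-based ranks.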
    fn compute_dcg(relevances: &[f64], k: usize) -> f64 {
        let k = k.min(relevances.len());
        let mut dcg = 0.0;

        for (i, &rel) in relevances.iter().take(k).enumerate() {
            let position = (i + 2) as f64;
            dcg += (2.0_f64.powf(rel) - 1.0) / position.log2();
        }

        dcg
    }

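    /// Ideal DCG at `k`: the DCG obtained when the relevances are sorted in
    /// descending order.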
    fn compute_idcg(relevances: &[f64], k: usize) -> f64 {
        let mut sorted_rel = relevances.to_vec();
        sorted_rel.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        Self::compute_dcg(&sorted_rel, k)
    }
}

impl Metric for NormalizedDiscountedCumulativeGain {
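    // Per-sample NDCG is the DCG of the relevances in predicted order divided
    // by the ideal DCG; the mean over all samples is returned, with samples
    // whose ideal DCG is (near) zero contributing 0.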
    fn compute(
        &self,
        predictions: &ArrayView<f64, Ix2>,
        targets: &ArrayView<f64, Ix2>,
    ) -> TrainResult<f64> {
        if predictions.shape() != targets.shape() {
            return Err(TrainError::MetricsError(format!(
                "Shape mismatch: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }

        let n_samples = predictions.nrows();
        if n_samples == 0 {
            return Ok(0.0);
        }

        let mut ndcg_sum = 0.0;

        for i in 0..n_samples {
            let pred_scores: Vec<f64> = predictions.row(i).iter().copied().collect();
            let true_relevances: Vec<f64> = targets.row(i).iter().copied().collect();

            let mut indices: Vec<usize> = (0..pred_scores.len()).collect();
            indices.sort_by(|&a, &b| {
                pred_scores[b]
                    .partial_cmp(&pred_scores[a])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });

            let ranked_relevances: Vec<f64> =
                indices.iter().map(|&idx| true_relevances[idx]).collect();

            let dcg = Self::compute_dcg(&ranked_relevances, self.k);

            let idcg = Self::compute_idcg(&true_relevances, self.k);

            let ndcg = if idcg > 1e-12 { dcg / idcg } else { 0.0 };

            ndcg_sum += ndcg;
        }

        Ok(ndcg_sum / n_samples as f64)
    }

    fn name(&self) -> &str {
        "ndcg"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_top_k_accuracy() {
        let metric = TopKAccuracy::new(2);

        let predictions = array![
            [0.7, 0.2, 0.1],
            [0.1, 0.6, 0.3],
            [0.3, 0.4, 0.3],
        ];
        let targets = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

        let top_k = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        assert!((0.0..=1.0).contains(&top_k));
        assert!(top_k >= 0.66);
    }

    #[test]
    fn test_ndcg_perfect_ranking() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[5.0, 4.0, 3.0, 2.0, 1.0]];
        let targets = array![[5.0, 4.0, 3.0, 2.0, 1.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(
            (ndcg - 1.0).abs() < 1e-6,
            "Perfect ranking should have NDCG ≈ 1.0, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_worst_ranking() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[1.0, 2.0, 3.0, 4.0, 5.0]];
        let targets = array![[5.0, 4.0, 3.0, 2.0, 1.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(
            ndcg < 0.8,
            "Worst ranking should have low NDCG, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_partial_match() {
        let metric = NormalizedDiscountedCumulativeGain::new(3);

        let predictions = array![[4.0, 5.0, 2.0, 3.0, 1.0]];
        let targets = array![[3.0, 5.0, 1.0, 2.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(
            (0.0..=1.0).contains(&ndcg),
            "NDCG should be in [0, 1], got {}",
            ndcg
        );

        assert!(
            ndcg > 0.7,
            "NDCG should be > 0.7 for this ranking, got {}",
            ndcg
        );
    }

    #[test]
    fn test_ndcg_multiple_samples() {
        let metric = NormalizedDiscountedCumulativeGain::new(3);

        let predictions = array![[5.0, 4.0, 3.0, 2.0], [2.0, 3.0, 4.0, 5.0]];
        let targets = array![[5.0, 4.0, 3.0, 2.0], [5.0, 4.0, 3.0, 2.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!((0.0..=1.0).contains(&ndcg));
        assert!(ndcg > 0.4 && ndcg < 0.9);
    }

    #[test]
    fn test_ndcg_different_k_values() {
        let metric_k3 = NormalizedDiscountedCumulativeGain::new(3);
        let metric_k5 = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[5.0, 4.0, 3.0, 1.0, 2.0]];
        let targets = array![[5.0, 4.0, 3.0, 2.0, 1.0]];

        let ndcg_k3 = metric_k3
            .compute(&predictions.view(), &targets.view())
            .unwrap();
        let ndcg_k5 = metric_k5
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!((ndcg_k3 - 1.0).abs() < 1e-6);

        assert!(ndcg_k5 < ndcg_k3);
        assert!(ndcg_k5 > 0.9);
    }

    #[test]
    fn test_ndcg_zero_relevances() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[1.0, 2.0, 3.0]];
        let targets = array![[0.0, 0.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!(ndcg.is_finite());
        assert_eq!(ndcg, 0.0);
    }

    #[test]
    fn test_ndcg_empty_input() {
        let metric = NormalizedDiscountedCumulativeGain::default();

        use scirs2_core::ndarray::Array;
        let empty_predictions: Array<f64, _> = Array::zeros((0, 5));
        let empty_targets: Array<f64, _> = Array::zeros((0, 5));

        let ndcg = metric
            .compute(&empty_predictions.view(), &empty_targets.view())
            .unwrap();

        assert_eq!(ndcg, 0.0);
    }

    #[test]
    fn test_ndcg_shape_mismatch() {
        let metric = NormalizedDiscountedCumulativeGain::default();

        let predictions = array![[1.0, 2.0, 3.0]];
        let targets = array![[1.0, 2.0]];

        let result = metric.compute(&predictions.view(), &targets.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_ndcg_binary_relevance() {
        let metric = NormalizedDiscountedCumulativeGain::new(5);

        let predictions = array![[0.9, 0.7, 0.5, 0.3, 0.1]];
        let targets = array![[1.0, 1.0, 0.0, 1.0, 0.0]];

        let ndcg = metric
            .compute(&predictions.view(), &targets.view())
            .unwrap();

        assert!((0.0..=1.0).contains(&ndcg));

        assert!(
            ndcg > 0.6,
            "Should have decent NDCG with top-2 relevant, got {}",
            ndcg
        );
    }
}