1use std::collections::HashMap;
5use std::fmt;
6
/// Averaging strategy for multi-class precision/recall/F1 metrics.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum Average {
    /// Report the metric for the positive class only (class index 1 in the
    /// sorted class order produced by `confusion_matrix`).
    Binary,
    /// Unweighted mean of the per-class metrics.
    Macro,
    /// Mean of the per-class metrics weighted by each class's support
    /// (number of true samples of that class).
    Weighted,
}
18
/// A confusion matrix: `matrix[i][j]` counts samples whose true class is `i`
/// and whose predicted class is `j`; `labels[i]` names class `i`.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ConfusionMatrix {
    /// Square count matrix; row = true class, column = predicted class.
    pub matrix: Vec<Vec<usize>>,
    /// String labels for each class, in the same order as the matrix
    /// rows/columns (sorted numeric order of the observed class values).
    pub labels: Vec<String>,
}
28
/// Precision/recall/F1 metrics for a single class.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ClassMetrics {
    /// tp / (tp + fp): fraction of predicted positives that are correct.
    pub precision: f64,
    /// tp / (tp + fn): fraction of actual positives that were found.
    pub recall: f64,
    /// Harmonic mean of precision and recall (0.0 when both are 0).
    pub f1: f64,
    /// Number of true samples of this class.
    pub support: usize,
}
42
/// Aggregate classification metrics, displayable in a table layout similar
/// to scikit-learn's `classification_report`.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ClassificationReport {
    /// Overall fraction of correct predictions.
    pub accuracy: f64,
    /// Per-class (label, metrics) pairs, in sorted class order.
    pub per_class: Vec<(String, ClassMetrics)>,
    /// Unweighted mean of the per-class metrics.
    pub macro_avg: ClassMetrics,
    /// Support-weighted mean of the per-class metrics.
    pub weighted_avg: ClassMetrics,
    /// Total number of samples across all classes.
    pub total_support: usize,
}
58
59impl fmt::Display for ClassificationReport {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 writeln!(
62 f,
63 "{:>15} {:>10} {:>10} {:>10} {:>10}",
64 "", "precision", "recall", "f1-score", "support"
65 )?;
66 writeln!(f)?;
67 for (label, m) in &self.per_class {
68 writeln!(
69 f,
70 "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
71 label, m.precision, m.recall, m.f1, m.support
72 )?;
73 }
74 writeln!(f)?;
75 writeln!(
76 f,
77 "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
78 "accuracy", "", "", self.accuracy, self.total_support
79 )?;
80 writeln!(
81 f,
82 "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
83 "macro avg",
84 self.macro_avg.precision,
85 self.macro_avg.recall,
86 self.macro_avg.f1,
87 self.total_support
88 )?;
89 writeln!(
90 f,
91 "{:>15} {:>10.4} {:>10.4} {:>10.4} {:>10}",
92 "weighted avg",
93 self.weighted_avg.precision,
94 self.weighted_avg.recall,
95 self.weighted_avg.f1,
96 self.total_support
97 )?;
98 Ok(())
99 }
100}
101
/// Fraction of positions where `y_pred` matches `y_true` (within a 1e-6
/// tolerance on the float labels).
///
/// Returns 0.0 for empty `y_true`. Pairs beyond the shorter of the two
/// slices are ignored, but the denominator is always `y_true.len()`.
pub fn accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
    if y_true.is_empty() {
        return 0.0;
    }
    let mut correct = 0usize;
    for (t, p) in y_true.iter().zip(y_pred) {
        if (t - p).abs() < 1e-6 {
            correct += 1;
        }
    }
    correct as f64 / y_true.len() as f64
}
114
115fn precision_from_cm(cm: &ConfusionMatrix, avg: Average) -> f64 {
117 let n = cm.matrix.len();
118 match avg {
119 Average::Binary => {
120 let tp = if n >= 2 { cm.matrix[1][1] } else { 0 };
121 let fp = (0..n)
122 .map(|i| if i == 1 { 0 } else { cm.matrix[i][1] })
123 .sum::<usize>();
124 if tp + fp == 0 {
125 0.0
126 } else {
127 tp as f64 / (tp + fp) as f64
128 }
129 }
130 Average::Macro => {
131 let mut total = 0.0;
132 for c in 0..n {
133 let tp = cm.matrix[c][c];
134 let fp: usize = (0..n)
135 .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
136 .sum();
137 total += if tp + fp == 0 {
138 0.0
139 } else {
140 tp as f64 / (tp + fp) as f64
141 };
142 }
143 total / n as f64
144 }
145 Average::Weighted => {
146 let mut total = 0.0;
147 let mut total_support = 0;
148 for c in 0..n {
149 let support: usize = cm.matrix[c].iter().sum();
150 let tp = cm.matrix[c][c];
151 let fp: usize = (0..n)
152 .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
153 .sum();
154 let p = if tp + fp == 0 {
155 0.0
156 } else {
157 tp as f64 / (tp + fp) as f64
158 };
159 total += p * support as f64;
160 total_support += support;
161 }
162 if total_support == 0 {
163 0.0
164 } else {
165 total / total_support as f64
166 }
167 }
168 }
169}
170
171fn recall_from_cm(cm: &ConfusionMatrix, avg: Average) -> f64 {
173 let n = cm.matrix.len();
174 match avg {
175 Average::Binary => {
176 let tp = if n >= 2 { cm.matrix[1][1] } else { 0 };
177 let fn_ = if n >= 2 {
178 (0..n)
179 .map(|j| if j == 1 { 0 } else { cm.matrix[1][j] })
180 .sum::<usize>()
181 } else {
182 0
183 };
184 if tp + fn_ == 0 {
185 0.0
186 } else {
187 tp as f64 / (tp + fn_) as f64
188 }
189 }
190 Average::Macro => {
191 let mut total = 0.0;
192 for c in 0..n {
193 let tp = cm.matrix[c][c];
194 let support: usize = cm.matrix[c].iter().sum();
195 total += if support == 0 {
196 0.0
197 } else {
198 tp as f64 / support as f64
199 };
200 }
201 total / n as f64
202 }
203 Average::Weighted => {
204 let mut total = 0.0;
205 let mut total_support = 0;
206 for c in 0..n {
207 let support: usize = cm.matrix[c].iter().sum();
208 let tp = cm.matrix[c][c];
209 let r = if support == 0 {
210 0.0
211 } else {
212 tp as f64 / support as f64
213 };
214 total += r * support as f64;
215 total_support += support;
216 }
217 if total_support == 0 {
218 0.0
219 } else {
220 total / total_support as f64
221 }
222 }
223 }
224}
225
226pub fn precision(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
228 let cm = confusion_matrix(y_true, y_pred);
229 precision_from_cm(&cm, avg)
230}
231
232pub fn recall(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
234 let cm = confusion_matrix(y_true, y_pred);
235 recall_from_cm(&cm, avg)
236}
237
238pub fn f1_score(y_true: &[f64], y_pred: &[f64], avg: Average) -> f64 {
244 let cm = confusion_matrix(y_true, y_pred);
245 let n = cm.matrix.len();
246
247 match avg {
248 Average::Binary => {
249 let p = precision_from_cm(&cm, Average::Binary);
250 let r = recall_from_cm(&cm, Average::Binary);
251 if p + r == 0.0 {
252 0.0
253 } else {
254 2.0 * p * r / (p + r)
255 }
256 }
257 Average::Macro => {
258 let mut total_f1 = 0.0;
259 for c in 0..n {
260 let tp = cm.matrix[c][c];
261 let fp: usize = (0..n)
262 .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
263 .sum();
264 let support: usize = cm.matrix[c].iter().sum();
265 let p = if tp + fp == 0 {
266 0.0
267 } else {
268 tp as f64 / (tp + fp) as f64
269 };
270 let r = if support == 0 {
271 0.0
272 } else {
273 tp as f64 / support as f64
274 };
275 total_f1 += if p + r == 0.0 {
276 0.0
277 } else {
278 2.0 * p * r / (p + r)
279 };
280 }
281 total_f1 / n as f64
282 }
283 Average::Weighted => {
284 let mut total_f1 = 0.0;
285 let mut total_support = 0;
286 for c in 0..n {
287 let tp = cm.matrix[c][c];
288 let fp: usize = (0..n)
289 .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
290 .sum();
291 let support: usize = cm.matrix[c].iter().sum();
292 let p = if tp + fp == 0 {
293 0.0
294 } else {
295 tp as f64 / (tp + fp) as f64
296 };
297 let r = if support == 0 {
298 0.0
299 } else {
300 tp as f64 / support as f64
301 };
302 let f = if p + r == 0.0 {
303 0.0
304 } else {
305 2.0 * p * r / (p + r)
306 };
307 total_f1 += f * support as f64;
308 total_support += support;
309 }
310 if total_support == 0 {
311 0.0
312 } else {
313 total_f1 / total_support as f64
314 }
315 }
316 }
317}
318
319pub fn confusion_matrix(y_true: &[f64], y_pred: &[f64]) -> ConfusionMatrix {
321 let mut classes: Vec<i64> = y_true
322 .iter()
323 .chain(y_pred.iter())
324 .map(|&v| v as i64)
325 .collect();
326 classes.sort_unstable();
327 classes.dedup();
328
329 let n = classes.len();
330 let mut matrix = vec![vec![0usize; n]; n];
331 let labels: Vec<String> = classes
332 .iter()
333 .map(std::string::ToString::to_string)
334 .collect();
335
336 let class_map: HashMap<i64, usize> = classes.iter().enumerate().map(|(i, &c)| (c, i)).collect();
338
339 for (&t, &p) in y_true.iter().zip(y_pred.iter()) {
340 let ti = class_map.get(&(t as i64)).copied().unwrap_or(0);
341 let pi = class_map.get(&(p as i64)).copied().unwrap_or(0);
342 matrix[ti][pi] += 1;
343 }
344
345 ConfusionMatrix { matrix, labels }
346}
347
348pub fn classification_report(y_true: &[f64], y_pred: &[f64]) -> ClassificationReport {
350 let cm = confusion_matrix(y_true, y_pred);
351 let n = cm.matrix.len();
352 let total: usize = cm.matrix.iter().flat_map(|r| r.iter()).sum();
353
354 let mut per_class = Vec::with_capacity(n);
355 let mut macro_p = 0.0;
356 let mut macro_r = 0.0;
357 let mut macro_f = 0.0;
358 let mut weighted_p = 0.0;
359 let mut weighted_r = 0.0;
360 let mut weighted_f = 0.0;
361
362 for c in 0..n {
363 let tp = cm.matrix[c][c];
364 let support: usize = cm.matrix[c].iter().sum();
365 let fp: usize = (0..n)
366 .map(|i| if i == c { 0 } else { cm.matrix[i][c] })
367 .sum();
368
369 let p = if tp + fp == 0 {
370 0.0
371 } else {
372 tp as f64 / (tp + fp) as f64
373 };
374 let r = if support == 0 {
375 0.0
376 } else {
377 tp as f64 / support as f64
378 };
379 let f = if p + r == 0.0 {
380 0.0
381 } else {
382 2.0 * p * r / (p + r)
383 };
384
385 per_class.push((
386 cm.labels[c].clone(),
387 ClassMetrics {
388 precision: p,
389 recall: r,
390 f1: f,
391 support,
392 },
393 ));
394
395 macro_p += p;
396 macro_r += r;
397 macro_f += f;
398 weighted_p += p * support as f64;
399 weighted_r += r * support as f64;
400 weighted_f += f * support as f64;
401 }
402
403 let n_f = n as f64;
404 let total_f = total as f64;
405
406 ClassificationReport {
407 accuracy: accuracy(y_true, y_pred),
408 per_class,
409 macro_avg: ClassMetrics {
410 precision: macro_p / n_f,
411 recall: macro_r / n_f,
412 f1: macro_f / n_f,
413 support: total,
414 },
415 weighted_avg: ClassMetrics {
416 precision: if total > 0 { weighted_p / total_f } else { 0.0 },
417 recall: if total > 0 { weighted_r / total_f } else { 0.0 },
418 f1: if total > 0 { weighted_f / total_f } else { 0.0 },
419 support: total,
420 },
421 total_support: total,
422 }
423}
424
/// Multi-class logarithmic loss (cross-entropy).
///
/// `y_true[i]` is the class index of sample `i` (cast to `usize`) and
/// `y_prob[i][c]` is the predicted probability of class `c`. Probabilities
/// are clamped to `[1e-15, 1 - 1e-15]` before taking logs. Returns 0.0 for
/// empty input.
///
/// Samples whose class index falls outside the corresponding probability
/// row contribute zero loss (as before), but the denominator is always
/// `y_true.len()`.
///
/// Fix: pairs beyond `y_prob.len()` are now skipped via `zip`; previously a
/// `y_prob` shorter than `y_true` caused an out-of-bounds panic.
pub fn log_loss(y_true: &[f64], y_prob: &[Vec<f64>]) -> f64 {
    if y_true.is_empty() || y_prob.is_empty() {
        return 0.0;
    }
    let eps = 1e-15;
    let n = y_true.len();
    let total: f64 = y_true
        .iter()
        .zip(y_prob)
        .filter_map(|(&label, probs)| {
            probs
                .get(label as usize)
                .map(|&p| -p.clamp(eps, 1.0 - eps).ln())
        })
        .sum();
    total / n as f64
}
448
449pub fn balanced_accuracy(y_true: &[f64], y_pred: &[f64]) -> f64 {
454 if y_true.is_empty() {
455 return 0.0;
456 }
457 let cm = confusion_matrix(y_true, y_pred);
458 let n = cm.matrix.len();
459 let mut total_recall = 0.0;
460 for c in 0..n {
461 let support: usize = cm.matrix[c].iter().sum();
462 let tp = cm.matrix[c][c];
463 total_recall += if support == 0 {
464 0.0
465 } else {
466 tp as f64 / support as f64
467 };
468 }
469 total_recall / n as f64
470}
471
472pub fn cohen_kappa_score(y_true: &[f64], y_pred: &[f64]) -> f64 {
478 if y_true.is_empty() {
479 return 0.0;
480 }
481 let cm = confusion_matrix(y_true, y_pred);
482 let n_classes = cm.matrix.len();
483 let total: f64 = cm.matrix.iter().flat_map(|r| r.iter()).sum::<usize>() as f64;
484 if total == 0.0 {
485 return 0.0;
486 }
487
488 let p_o: f64 = (0..n_classes).map(|c| cm.matrix[c][c] as f64).sum::<f64>() / total;
490
491 let mut p_e = 0.0;
493 for c in 0..n_classes {
494 let row_sum: f64 = cm.matrix[c].iter().sum::<usize>() as f64;
495 let col_sum: f64 = (0..n_classes).map(|r| cm.matrix[r][c] as f64).sum::<f64>();
496 p_e += (row_sum * col_sum) / (total * total);
497 }
498
499 if (1.0 - p_e).abs() < 1e-15 {
500 return if (p_o - 1.0).abs() < 1e-15 { 1.0 } else { 0.0 };
501 }
502
503 (p_o - p_e) / (1.0 - p_e)
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    // All predictions correct => accuracy 1.0.
    #[test]
    fn test_accuracy_perfect() {
        assert!((accuracy(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]) - 1.0).abs() < 1e-10);
    }

    // 3 of 4 predictions correct => 0.75 (name notwithstanding).
    #[test]
    fn test_accuracy_half() {
        assert!((accuracy(&[0.0, 1.0, 0.0, 1.0], &[0.0, 0.0, 0.0, 1.0]) - 0.75).abs() < 1e-10);
    }

    // One correct and one wrong prediction per class => all cells 1.
    #[test]
    fn test_confusion_matrix_binary() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 1.0, 0.0, 1.0];
        let cm = confusion_matrix(&y_true, &y_pred);
        assert_eq!(cm.matrix, vec![vec![1, 1], vec![1, 1]]);
    }

    // The Display impl must include the summary rows.
    #[test]
    fn test_classification_report_display() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0];
        let y_pred = vec![0.0, 0.0, 1.0, 2.0, 1.0, 2.0];
        let report = classification_report(&y_true, &y_pred);
        let output = format!("{report}");
        assert!(output.contains("accuracy"));
        assert!(output.contains("macro avg"));
    }

    // Positive class (1): tp=1, fp=1, fn=1 => p = r = 0.5 => F1 = 0.5.
    #[test]
    fn test_f1_binary() {
        let y_true = vec![0.0, 1.0, 1.0];
        let y_pred = vec![1.0, 1.0, 0.0];
        let f = f1_score(&y_true, &y_pred, Average::Binary);
        assert!((f - 0.5).abs() < 1e-6, "expected F1=0.5, got {f}");
    }

    // Probability 1.0 on the true class => loss ~0 (only clamping remains).
    #[test]
    fn test_log_loss_perfect() {
        let y_true = vec![0.0, 1.0, 2.0];
        let y_prob = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let ll = log_loss(&y_true, &y_prob);
        assert!(ll < 1e-10, "perfect log_loss should be ~0, got {ll}");
    }

    // Uniform probabilities over 3 classes => loss = ln(3).
    #[test]
    fn test_log_loss_random() {
        let y_true = vec![0.0, 1.0, 2.0];
        let y_prob = vec![
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
            vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
        ];
        let ll = log_loss(&y_true, &y_prob);
        assert!(ll > 0.5, "random log_loss should be positive, got {ll}");
        assert!(
            (ll - 3.0_f64.ln()).abs() < 1e-6,
            "expected ~ln(3), got {ll}"
        );
    }

    #[test]
    fn test_balanced_accuracy_perfect() {
        let ba = balanced_accuracy(&[0.0, 1.0, 2.0], &[0.0, 1.0, 2.0]);
        assert!((ba - 1.0).abs() < 1e-10);
    }

    // 90/10 class split with an all-majority predictor: raw accuracy is 0.9
    // but balanced accuracy averages recalls (1.0 and 0.0) => 0.5.
    #[test]
    fn test_balanced_accuracy_imbalanced() {
        let mut y_true = vec![0.0; 90];
        y_true.extend(vec![1.0; 10]);
        let y_pred = vec![0.0; 100];

        let raw = accuracy(&y_true, &y_pred);
        let bal = balanced_accuracy(&y_true, &y_pred);

        assert!((raw - 0.90).abs() < 1e-10);
        assert!((bal - 0.50).abs() < 1e-10);
        assert!(bal < raw, "balanced should be lower on imbalanced data");
    }

    #[test]
    fn test_cohen_kappa_perfect() {
        let kappa = cohen_kappa_score(&[0.0, 1.0, 2.0, 0.0, 1.0], &[0.0, 1.0, 2.0, 0.0, 1.0]);
        assert!(
            (kappa - 1.0).abs() < 1e-10,
            "perfect kappa should be 1.0, got {kappa}"
        );
    }

    // Always predicting the majority class: p_o = 0.5 equals p_e = 0.5,
    // so kappa is 0 (no better than chance).
    #[test]
    fn test_cohen_kappa_chance() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 0.0, 0.0, 0.0];
        let kappa = cohen_kappa_score(&y_true, &y_pred);
        assert!(
            kappa.abs() < 1e-10,
            "chance kappa should be ~0, got {kappa}"
        );
    }

    // p_o = 3/4, p_e = 1/2 => kappa = (0.75 - 0.5) / (1 - 0.5) = 0.5.
    #[test]
    fn test_cohen_kappa_partial() {
        let y_true = vec![0.0, 0.0, 1.0, 1.0];
        let y_pred = vec![0.0, 0.0, 0.0, 1.0];
        let kappa = cohen_kappa_score(&y_true, &y_pred);
        assert!(
            (kappa - 0.5).abs() < 1e-10,
            "expected kappa=0.5, got {kappa}"
        );
    }
}