// scry_learn/weights.rs
// SPDX-License-Identifier: MIT OR Apache-2.0
//! Class weighting for imbalanced datasets.
//!
//! Provides the [`ClassWeight`] enum and [`compute_sample_weights`] function
//! to generate per-sample weights that compensate for class imbalance.
//!
//! # Example
//!
//! ```
//! use scry_learn::weights::{ClassWeight, compute_sample_weights};
//!
//! let targets = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0];
//! let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);
//!
//! // Minority class (1) gets higher weight to compensate for imbalance.
//! assert!(weights[9] > weights[0]);
//! ```
use std::collections::HashMap;

/// Strategy for weighting classes during training.
///
/// Used by classifiers to handle imbalanced datasets. When set to `Balanced`,
/// minority classes receive higher weight, making the model pay more attention
/// to underrepresented classes.
///
/// # Example
///
/// ```
/// use scry_learn::weights::ClassWeight;
/// use scry_learn::tree::DecisionTreeClassifier;
///
/// let dt = DecisionTreeClassifier::new()
///     .class_weight(ClassWeight::Balanced);
/// ```
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ClassWeight {
    /// All classes weighted equally (weight = 1.0). This is the default.
    #[default]
    Uniform,
    /// Automatically adjust weights inversely proportional to class frequencies.
    ///
    /// Uses the sklearn formula: `weight_c = n_samples / (n_classes × n_c)`.
    Balanced,
    /// User-specified per-class weights (class label → weight).
    Custom(HashMap<usize, f64>),
}

51/// Compute per-sample weights from target labels and a class weighting strategy.
52///
53/// Returns a vector with one weight per sample. For `Uniform`, all weights are 1.0.
54/// For `Balanced`, uses the sklearn formula:
55/// `weight_c = n_samples / (n_classes × count_c)`.
56///
57/// # Example
58///
59/// ```
60/// use scry_learn::weights::{ClassWeight, compute_sample_weights};
61///
62/// // 8 samples of class 0, 2 samples of class 1
63/// let targets = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0];
64/// let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);
65///
66/// // n_samples=10, n_classes=2
67/// // weight_0 = 10 / (2 × 8) = 0.625
68/// // weight_1 = 10 / (2 × 2) = 2.5
69/// assert!((weights[0] - 0.625).abs() < 1e-6);
70/// assert!((weights[8] - 2.5).abs() < 1e-6);
71/// ```
72pub fn compute_sample_weights(targets: &[f64], class_weight: &ClassWeight) -> Vec<f64> {
73 let n = targets.len();
74 match class_weight {
75 ClassWeight::Uniform => vec![1.0; n],
76 ClassWeight::Balanced => {
77 // Count samples per class.
78 let mut counts: HashMap<usize, usize> = HashMap::new();
79 for &t in targets {
80 *counts.entry(t as usize).or_insert(0) += 1;
81 }
82 let n_classes = counts.len();
83 let n_f = n as f64;
84
85 // weight_c = n_samples / (n_classes × count_c)
86 let class_weights: HashMap<usize, f64> = counts
87 .iter()
88 .map(|(&cls, &count)| {
89 let w = n_f / (n_classes as f64 * count as f64);
90 (cls, w)
91 })
92 .collect();
93
94 targets
95 .iter()
96 .map(|&t| class_weights.get(&(t as usize)).copied().unwrap_or(1.0))
97 .collect()
98 }
99 ClassWeight::Custom(map) => targets
100 .iter()
101 .map(|&t| map.get(&(t as usize)).copied().unwrap_or(1.0))
102 .collect(),
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_weights() {
        // Uniform strategy ignores class frequencies entirely.
        let targets = [0.0, 0.0, 1.0, 1.0, 2.0];
        let weights = compute_sample_weights(&targets, &ClassWeight::Uniform);
        assert_eq!(weights.len(), 5);
        for &w in &weights {
            assert!((w - 1.0).abs() < 1e-12);
        }
    }

    #[test]
    fn test_balanced_weights_equal_classes() {
        // Two classes, five samples each: already balanced, so every
        // weight is 10 / (2 * 5) = 1.0.
        let targets: Vec<f64> = (0..10).map(|i| if i < 5 { 0.0 } else { 1.0 }).collect();
        let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);
        assert!(
            weights.iter().all(|&w| (w - 1.0).abs() < 1e-6),
            "expected all weights to be 1.0, got {weights:?}"
        );
    }

    #[test]
    fn test_balanced_weights_imbalanced() {
        // 90 samples of class 0 followed by 10 samples of class 1.
        let targets: Vec<f64> = std::iter::repeat(0.0)
            .take(90)
            .chain(std::iter::repeat(1.0).take(10))
            .collect();
        let weights = compute_sample_weights(&targets, &ClassWeight::Balanced);

        // weight_0 = 100/(2*90) ≈ 0.5556, weight_1 = 100/(2*10) = 5.0
        let expected_majority = 100.0 / 180.0;
        assert!(
            (weights[0] - expected_majority).abs() < 1e-6,
            "majority weight: expected {expected_majority}, got {}",
            weights[0]
        );
        assert!(
            (weights[90] - 5.0).abs() < 1e-6,
            "minority weight: expected 5.0, got {}",
            weights[90]
        );
        // The minority class must be weighted far more heavily.
        assert!(weights[90] > weights[0] * 5.0);
    }

    #[test]
    fn test_custom_weights() {
        let map: HashMap<usize, f64> = [(0_usize, 1.0), (1_usize, 10.0)].into_iter().collect();
        let targets = [0.0, 0.0, 1.0, 1.0];
        let weights = compute_sample_weights(&targets, &ClassWeight::Custom(map));
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!((weights[2] - 10.0).abs() < 1e-12);
    }

    #[test]
    fn test_custom_weights_missing_class_defaults_to_one() {
        // Class 0 is absent from the map, so its weight falls back to 1.0.
        let map: HashMap<usize, f64> = std::iter::once((1_usize, 5.0)).collect();
        let targets = [0.0, 1.0];
        let weights = compute_sample_weights(&targets, &ClassWeight::Custom(map));
        assert!((weights[0] - 1.0).abs() < 1e-12);
        assert!((weights[1] - 5.0).abs() < 1e-12);
    }
}
175}