scry_learn/preprocess/
normalizer.rs1use crate::dataset::Dataset;
5use crate::error::{Result, ScryLearnError};
6use crate::preprocess::Transformer;
7
/// Norm used to measure the length of each sample (row) before rescaling.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Norm {
    /// Sum of the absolute values of the row's entries.
    L1,
    /// Square root of the sum of the squared entries (Euclidean length).
    L2,
    /// Largest absolute value found in the row.
    Max,
}
20
21#[derive(Clone, Debug)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[non_exhaustive]
38pub struct Normalizer {
39 norm: Norm,
40 #[cfg_attr(feature = "serde", serde(default))]
41 _schema_version: u32,
42}
43
44impl Normalizer {
45 pub fn new(norm: Norm) -> Self {
47 Self {
48 norm,
49 _schema_version: crate::version::SCHEMA_VERSION,
50 }
51 }
52
53 pub fn l2() -> Self {
55 Self {
56 norm: Norm::L2,
57 _schema_version: crate::version::SCHEMA_VERSION,
58 }
59 }
60}
61
62impl Default for Normalizer {
63 fn default() -> Self {
64 Self::l2()
65 }
66}
67
68impl Transformer for Normalizer {
69 fn fit(&mut self, data: &Dataset) -> Result<()> {
70 data.validate_finite()?;
71 if data.n_samples() == 0 {
72 return Err(ScryLearnError::EmptyDataset);
73 }
74 Ok(())
76 }
77
78 fn transform(&self, data: &mut Dataset) -> Result<()> {
79 crate::version::check_schema_version(self._schema_version)?;
80 let n = data.n_samples();
81 let m = data.n_features();
82
83 for i in 0..n {
84 let norm_val = match self.norm {
86 Norm::L1 => {
87 let mut s = 0.0_f64;
88 for col in &data.features {
89 s += col[i].abs();
90 }
91 s
92 }
93 Norm::L2 => {
94 let mut s = 0.0_f64;
95 for col in &data.features {
96 s += col[i] * col[i];
97 }
98 s.sqrt()
99 }
100 Norm::Max => {
101 let mut mx = 0.0_f64;
102 for col in &data.features {
103 mx = mx.max(col[i].abs());
104 }
105 mx
106 }
107 };
108
109 if norm_val > 1e-12 {
110 for j in 0..m {
111 data.features[j][i] /= norm_val;
112 }
113 }
114 }
115
116 data.sync_matrix();
117 Ok(())
118 }
119
120 fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
121 Err(ScryLearnError::InvalidParameter(
122 "Normalizer is not invertible (row norms are lost)".into(),
123 ))
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a dataset from row-major `rows`, transposing them into the
    /// column-major layout passed to `Dataset::new`. Targets are all zero.
    fn make_ds(rows: &[Vec<f64>]) -> Dataset {
        let n_rows = rows.len();
        let n_cols = rows[0].len();

        let mut columns = vec![vec![0.0; n_rows]; n_cols];
        for (r, row) in rows.iter().enumerate() {
            for (c, value) in row.iter().copied().enumerate() {
                columns[c][r] = value;
            }
        }

        let mut names: Vec<String> = Vec::with_capacity(n_cols);
        for j in 0..n_cols {
            names.push(format!("f{j}"));
        }
        Dataset::new(columns, vec![0.0; n_rows], names, "y")
    }

    /// L2 normalization turns (3, 4) into (0.6, 0.8) and gives every row
    /// unit Euclidean length.
    #[test]
    fn test_normalizer_l2_unit_norm() {
        let mut ds = make_ds(&[vec![3.0, 4.0], vec![1.0, 0.0]]);
        let mut normalizer = Normalizer::new(Norm::L2);
        normalizer.fit_transform(&mut ds).unwrap();

        let first = ds.features[0][0];
        let second = ds.features[1][0];
        assert!((first - 0.6).abs() < 1e-10);
        assert!((second - 0.8).abs() < 1e-10);

        for i in 0..ds.n_samples() {
            let sq_sum = ds
                .features
                .iter()
                .map(|col| col[i] * col[i])
                .fold(0.0, |acc, sq| acc + sq);
            assert!(
                (sq_sum - 1.0).abs() < 1e-10,
                "row {i} L2 norm² = {sq_sum}, expected 1.0"
            );
        }
    }

    /// After L1 normalization the absolute values of a row sum to one.
    #[test]
    fn test_normalizer_l1() {
        let mut ds = make_ds(&[vec![1.0, 2.0, 3.0]]);
        let mut normalizer = Normalizer::new(Norm::L1);
        normalizer.fit_transform(&mut ds).unwrap();

        let mut abs_sum: f64 = 0.0;
        for col in &ds.features {
            abs_sum += col[0].abs();
        }
        assert!(
            (abs_sum - 1.0).abs() < 1e-10,
            "L1 norm should be 1.0, got {abs_sum}"
        );
    }

    /// Max normalization keeps signs and makes the largest magnitude one.
    #[test]
    fn test_normalizer_max() {
        let mut ds = make_ds(&[vec![-5.0, 2.0, 3.0]]);
        let mut normalizer = Normalizer::new(Norm::Max);
        normalizer.fit_transform(&mut ds).unwrap();

        // -5 is the largest magnitude, so it must map to exactly -1.
        assert!((ds.features[0][0] + 1.0).abs() < 1e-10);

        let mut max_abs = 0.0_f64;
        for col in &ds.features {
            max_abs = max_abs.max(col[0].abs());
        }
        assert!(
            (max_abs - 1.0).abs() < 1e-10,
            "Max norm should be 1.0, got {max_abs}"
        );
    }

    /// An all-zero row has no direction and must be left as zeros.
    #[test]
    fn test_normalizer_zero_row() {
        let mut ds = make_ds(&[vec![0.0, 0.0]]);
        let mut normalizer = Normalizer::new(Norm::L2);
        normalizer.fit_transform(&mut ds).unwrap();

        assert!(ds.features.iter().all(|col| col[0].abs() < 1e-10));
    }
}