// scry_learn/preprocess/normalizer.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Row-wise sample normalization.

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::preprocess::Transformer;

/// Which per-row norm a [`Normalizer`] scales by.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Norm {
    /// Scale each row by the sum of its absolute values (Manhattan norm).
    L1,
    /// Scale each row by the square root of the sum of squares (Euclidean norm).
    L2,
    /// Scale each row by its largest absolute value.
    Max,
}

21/// Normalize samples individually to unit norm.
22///
23/// Each sample (row) is scaled independently so that its chosen norm
24/// equals 1.0. This is useful for text classification or clustering
25/// where the direction of the feature vector matters more than magnitude.
26///
27/// `fit()` is a no-op — normalizer is stateless.
28///
29/// # Example
30///
31/// ```ignore
32/// let mut norm = Normalizer::new(Norm::L2);
33/// norm.transform(&mut ds)?;
34/// ```
35#[derive(Clone, Debug)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[non_exhaustive]
38pub struct Normalizer {
39    norm: Norm,
40    #[cfg_attr(feature = "serde", serde(default))]
41    _schema_version: u32,
42}
43
44impl Normalizer {
45    /// Create a normalizer with the given norm type.
46    pub fn new(norm: Norm) -> Self {
47        Self {
48            norm,
49            _schema_version: crate::version::SCHEMA_VERSION,
50        }
51    }
52
53    /// Create a normalizer with L2 norm (default).
54    pub fn l2() -> Self {
55        Self {
56            norm: Norm::L2,
57            _schema_version: crate::version::SCHEMA_VERSION,
58        }
59    }
60}
61
62impl Default for Normalizer {
63    fn default() -> Self {
64        Self::l2()
65    }
66}
67
68impl Transformer for Normalizer {
69    fn fit(&mut self, data: &Dataset) -> Result<()> {
70        data.validate_finite()?;
71        if data.n_samples() == 0 {
72            return Err(ScryLearnError::EmptyDataset);
73        }
74        // No-op: normalizer is stateless.
75        Ok(())
76    }
77
78    fn transform(&self, data: &mut Dataset) -> Result<()> {
79        crate::version::check_schema_version(self._schema_version)?;
80        let n = data.n_samples();
81        let m = data.n_features();
82
83        for i in 0..n {
84            // Compute the norm for this row.
85            let norm_val = match self.norm {
86                Norm::L1 => {
87                    let mut s = 0.0_f64;
88                    for col in &data.features {
89                        s += col[i].abs();
90                    }
91                    s
92                }
93                Norm::L2 => {
94                    let mut s = 0.0_f64;
95                    for col in &data.features {
96                        s += col[i] * col[i];
97                    }
98                    s.sqrt()
99                }
100                Norm::Max => {
101                    let mut mx = 0.0_f64;
102                    for col in &data.features {
103                        mx = mx.max(col[i].abs());
104                    }
105                    mx
106                }
107            };
108
109            if norm_val > 1e-12 {
110                for j in 0..m {
111                    data.features[j][i] /= norm_val;
112                }
113            }
114        }
115
116        data.sync_matrix();
117        Ok(())
118    }
119
120    fn inverse_transform(&self, _data: &mut Dataset) -> Result<()> {
121        Err(ScryLearnError::InvalidParameter(
122            "Normalizer is not invertible (row norms are lost)".into(),
123        ))
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a column-major `Dataset` from row-major literal data.
    fn make_ds(rows: &[Vec<f64>]) -> Dataset {
        let n_rows = rows.len();
        let n_cols = rows[0].len();
        let mut features = vec![vec![0.0; n_rows]; n_cols];
        for (r, row) in rows.iter().enumerate() {
            for (c, &v) in row.iter().enumerate() {
                features[c][r] = v;
            }
        }
        let names: Vec<String> = (0..n_cols).map(|c| format!("f{c}")).collect();
        Dataset::new(features, vec![0.0; n_rows], names, "y")
    }

    /// Collect row `i` of the (column-major) feature matrix.
    fn row_values(ds: &Dataset, i: usize) -> Vec<f64> {
        ds.features.iter().map(|col| col[i]).collect()
    }

    #[test]
    fn test_normalizer_l2_unit_norm() {
        let mut ds = make_ds(&[vec![3.0, 4.0], vec![1.0, 0.0]]);
        Normalizer::new(Norm::L2).fit_transform(&mut ds).unwrap();

        // [3, 4] has norm 5, so the first row becomes [0.6, 0.8].
        assert!((ds.features[0][0] - 0.6).abs() < 1e-10);
        assert!((ds.features[1][0] - 0.8).abs() < 1e-10);

        // Every row should now have squared L2 norm exactly 1.
        for i in 0..ds.n_samples() {
            let sq_sum: f64 = row_values(&ds, i).iter().map(|v| v * v).sum();
            assert!(
                (sq_sum - 1.0).abs() < 1e-10,
                "row {i} L2 norm² = {sq_sum}, expected 1.0"
            );
        }
    }

    #[test]
    fn test_normalizer_l1() {
        let mut ds = make_ds(&[vec![1.0, 2.0, 3.0]]);
        Normalizer::new(Norm::L1).fit_transform(&mut ds).unwrap();

        // 1 + 2 + 3 = 6, so the row becomes [1/6, 2/6, 3/6] and sums to 1.
        let abs_sum: f64 = row_values(&ds, 0).iter().map(|v| v.abs()).sum();
        assert!(
            (abs_sum - 1.0).abs() < 1e-10,
            "L1 norm should be 1.0, got {abs_sum}"
        );
    }

    #[test]
    fn test_normalizer_max() {
        let mut ds = make_ds(&[vec![-5.0, 2.0, 3.0]]);
        Normalizer::new(Norm::Max).fit_transform(&mut ds).unwrap();

        // max |x| = 5, so the row becomes [-1, 0.4, 0.6].
        assert!((ds.features[0][0] + 1.0).abs() < 1e-10);
        let max_abs = row_values(&ds, 0)
            .iter()
            .fold(0.0_f64, |acc, v| acc.max(v.abs()));
        assert!(
            (max_abs - 1.0).abs() < 1e-10,
            "Max norm should be 1.0, got {max_abs}"
        );
    }

    #[test]
    fn test_normalizer_zero_row() {
        // An all-zero row must pass through untouched (no division by zero).
        let mut ds = make_ds(&[vec![0.0, 0.0]]);
        Normalizer::new(Norm::L2).fit_transform(&mut ds).unwrap();

        for v in row_values(&ds, 0) {
            assert!(v.abs() < 1e-10);
        }
    }
}