scirs2_transform/
normalize.rs

1//! Data normalization and standardization utilities
2//!
3//! This module provides functions for normalizing and standardizing data,
4//! which is often a preprocessing step for machine learning algorithms.
5
6use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
7use num_traits::{Float, NumCast};
8
9use crate::error::{Result, TransformError};
10
/// Magnitude at or below which a per-lane statistic (range, standard deviation,
/// norm or IQR) is treated as zero, so the lane is filled with a constant rather
/// than divided by a vanishing value.
const EPSILON: f64 = 1.0e-10;
13
/// Selects how data is rescaled by `normalize_array`, `normalize_vector` and `Normalizer`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NormalizationMethod {
    /// Linear rescaling of each lane onto `[0, 1]` using its minimum and maximum.
    MinMax,
    /// Linear rescaling onto a caller-chosen range, given as `(new_min, new_max)`.
    MinMaxCustom(f64, f64),
    /// Standardization to zero mean and unit (population) variance.
    ZScore,
    /// Division by the largest absolute value, so results lie in `[-1, 1]`.
    MaxAbs,
    /// Division by the sum of absolute values (the L1 norm).
    L1,
    /// Division by the Euclidean (L2) norm.
    L2,
    /// Centering on the median and scaling by the interquartile range,
    /// which is less sensitive to outliers than min-max or z-score scaling.
    Robust,
}
32
33/// Normalizes a 2D array along a specified axis
34///
35/// # Arguments
36/// * `array` - The input 2D array to normalize
37/// * `method` - The normalization method to apply
38/// * `axis` - The axis along which to normalize (0 for columns, 1 for rows)
39///
40/// # Returns
41/// * `Result<Array2<f64>>` - The normalized array
42///
43/// # Examples
44/// ```
45/// use ndarray::array;
46/// use scirs2_transform::normalize::{normalize_array, NormalizationMethod};
47///
48/// let data = array![[1.0, 2.0, 3.0],
49///                   [4.0, 5.0, 6.0],
50///                   [7.0, 8.0, 9.0]];
51///                   
52/// // Normalize columns (axis 0) using min-max normalization
53/// let normalized = normalize_array(&data, NormalizationMethod::MinMax, 0).unwrap();
54/// ```
55pub fn normalize_array<S>(
56    array: &ArrayBase<S, Ix2>,
57    method: NormalizationMethod,
58    axis: usize,
59) -> Result<Array2<f64>>
60where
61    S: Data,
62    S::Elem: Float + NumCast,
63{
64    let array_f64 = array.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
65
66    if !array_f64.is_standard_layout() {
67        return Err(TransformError::InvalidInput(
68            "Input array must be in standard memory layout".to_string(),
69        ));
70    }
71
72    if array_f64.ndim() != 2 {
73        return Err(TransformError::InvalidInput(
74            "Only 2D arrays are supported".to_string(),
75        ));
76    }
77
78    if axis >= array_f64.ndim() {
79        return Err(TransformError::InvalidInput(format!(
80            "Invalid axis {} for array with {} dimensions",
81            axis,
82            array_f64.ndim()
83        )));
84    }
85
86    let shape = array_f64.shape();
87    let mut normalized = Array2::zeros((shape[0], shape[1]));
88
89    match method {
90        NormalizationMethod::MinMax => {
91            let min = array_f64.map_axis(Axis(axis), |view| {
92                view.fold(f64::INFINITY, |acc, &x| acc.min(x))
93            });
94
95            let max = array_f64.map_axis(Axis(axis), |view| {
96                view.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
97            });
98
99            let range = &max - &min;
100
101            for i in 0..shape[0] {
102                for j in 0..shape[1] {
103                    let value = array_f64[[i, j]];
104                    let idx = if axis == 0 { j } else { i };
105
106                    if range[idx].abs() > EPSILON {
107                        normalized[[i, j]] = (value - min[idx]) / range[idx];
108                    } else {
109                        normalized[[i, j]] = 0.5; // Default for constant features
110                    }
111                }
112            }
113        }
114        NormalizationMethod::MinMaxCustom(new_min, new_max) => {
115            let min = array_f64.map_axis(Axis(axis), |view| {
116                view.fold(f64::INFINITY, |acc, &x| acc.min(x))
117            });
118
119            let max = array_f64.map_axis(Axis(axis), |view| {
120                view.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
121            });
122
123            let range = &max - &min;
124            let new_range = new_max - new_min;
125
126            for i in 0..shape[0] {
127                for j in 0..shape[1] {
128                    let value = array_f64[[i, j]];
129                    let idx = if axis == 0 { j } else { i };
130
131                    if range[idx].abs() > EPSILON {
132                        normalized[[i, j]] = (value - min[idx]) / range[idx] * new_range + new_min;
133                    } else {
134                        normalized[[i, j]] = (new_min + new_max) / 2.0; // Default for constant features
135                    }
136                }
137            }
138        }
139        NormalizationMethod::ZScore => {
140            let mean = array_f64.map_axis(Axis(axis), |view| {
141                view.iter().sum::<f64>() / view.len() as f64
142            });
143
144            let std_dev = array_f64.map_axis(Axis(axis), |view| {
145                let m = view.iter().sum::<f64>() / view.len() as f64;
146                let variance =
147                    view.iter().map(|&x| (x - m).powi(2)).sum::<f64>() / view.len() as f64;
148                variance.sqrt()
149            });
150
151            for i in 0..shape[0] {
152                for j in 0..shape[1] {
153                    let value = array_f64[[i, j]];
154                    let idx = if axis == 0 { j } else { i };
155
156                    if std_dev[idx] > EPSILON {
157                        normalized[[i, j]] = (value - mean[idx]) / std_dev[idx];
158                    } else {
159                        normalized[[i, j]] = 0.0; // Default for constant features
160                    }
161                }
162            }
163        }
164        NormalizationMethod::MaxAbs => {
165            let max_abs = array_f64.map_axis(Axis(axis), |view| {
166                view.fold(0.0, |acc, &x| acc.max(x.abs()))
167            });
168
169            for i in 0..shape[0] {
170                for j in 0..shape[1] {
171                    let value = array_f64[[i, j]];
172                    let idx = if axis == 0 { j } else { i };
173
174                    if max_abs[idx] > EPSILON {
175                        normalized[[i, j]] = value / max_abs[idx];
176                    } else {
177                        normalized[[i, j]] = 0.0; // Default for constant features
178                    }
179                }
180            }
181        }
182        NormalizationMethod::L1 => {
183            let l1_norm =
184                array_f64.map_axis(Axis(axis), |view| view.fold(0.0, |acc, &x| acc + x.abs()));
185
186            for i in 0..shape[0] {
187                for j in 0..shape[1] {
188                    let value = array_f64[[i, j]];
189                    let idx = if axis == 0 { j } else { i };
190
191                    if l1_norm[idx] > EPSILON {
192                        normalized[[i, j]] = value / l1_norm[idx];
193                    } else {
194                        normalized[[i, j]] = 0.0; // Default for constant features
195                    }
196                }
197            }
198        }
199        NormalizationMethod::L2 => {
200            let l2_norm = array_f64.map_axis(Axis(axis), |view| {
201                let sum_squares = view.iter().fold(0.0, |acc, &x| acc + x * x);
202                sum_squares.sqrt()
203            });
204
205            for i in 0..shape[0] {
206                for j in 0..shape[1] {
207                    let value = array_f64[[i, j]];
208                    let idx = if axis == 0 { j } else { i };
209
210                    if l2_norm[idx] > EPSILON {
211                        normalized[[i, j]] = value / l2_norm[idx];
212                    } else {
213                        normalized[[i, j]] = 0.0; // Default for constant features
214                    }
215                }
216            }
217        }
218        NormalizationMethod::Robust => {
219            let median = array_f64.map_axis(Axis(axis), |view| {
220                let mut data = view.to_vec();
221                data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
222                let n = data.len();
223                if n % 2 == 0 {
224                    (data[n / 2 - 1] + data[n / 2]) / 2.0
225                } else {
226                    data[n / 2]
227                }
228            });
229
230            let iqr = array_f64.map_axis(Axis(axis), |view| {
231                let mut data = view.to_vec();
232                data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
233                let n = data.len();
234
235                // Calculate Q1 (25th percentile)
236                let q1_pos = 0.25 * (n - 1) as f64;
237                let q1_idx_low = q1_pos.floor() as usize;
238                let q1_idx_high = q1_pos.ceil() as usize;
239                let q1 = if q1_idx_low == q1_idx_high {
240                    data[q1_idx_low]
241                } else {
242                    let weight = q1_pos - q1_idx_low as f64;
243                    data[q1_idx_low] * (1.0 - weight) + data[q1_idx_high] * weight
244                };
245
246                // Calculate Q3 (75th percentile)
247                let q3_pos = 0.75 * (n - 1) as f64;
248                let q3_idx_low = q3_pos.floor() as usize;
249                let q3_idx_high = q3_pos.ceil() as usize;
250                let q3 = if q3_idx_low == q3_idx_high {
251                    data[q3_idx_low]
252                } else {
253                    let weight = q3_pos - q3_idx_low as f64;
254                    data[q3_idx_low] * (1.0 - weight) + data[q3_idx_high] * weight
255                };
256
257                q3 - q1
258            });
259
260            for i in 0..shape[0] {
261                for j in 0..shape[1] {
262                    let value = array_f64[[i, j]];
263                    let idx = if axis == 0 { j } else { i };
264
265                    if iqr[idx] > EPSILON {
266                        normalized[[i, j]] = (value - median[idx]) / iqr[idx];
267                    } else {
268                        normalized[[i, j]] = 0.0; // Default for constant features
269                    }
270                }
271            }
272        }
273    }
274
275    Ok(normalized)
276}
277
278/// Normalizes a 1D array
279///
280/// # Arguments
281/// * `array` - The input 1D array to normalize
282/// * `method` - The normalization method to apply
283///
284/// # Returns
285/// * `Result<Array1<f64>>` - The normalized array
286///
287/// # Examples
288/// ```
289/// use ndarray::array;
290/// use scirs2_transform::normalize::{normalize_vector, NormalizationMethod};
291///
292/// let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
293///                   
294/// // Normalize vector using min-max normalization
295/// let normalized = normalize_vector(&data, NormalizationMethod::MinMax).unwrap();
296/// ```
297pub fn normalize_vector<S>(
298    array: &ArrayBase<S, Ix1>,
299    method: NormalizationMethod,
300) -> Result<Array1<f64>>
301where
302    S: Data,
303    S::Elem: Float + NumCast,
304{
305    let array_f64 = array.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
306
307    if array_f64.is_empty() {
308        return Err(TransformError::InvalidInput(
309            "Input array is empty".to_string(),
310        ));
311    }
312
313    let mut normalized = Array1::zeros(array_f64.len());
314
315    match method {
316        NormalizationMethod::MinMax => {
317            let min = array_f64.fold(f64::INFINITY, |acc, &x| acc.min(x));
318            let max = array_f64.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
319            let range = max - min;
320
321            if range.abs() > EPSILON {
322                for (i, &value) in array_f64.iter().enumerate() {
323                    normalized[i] = (value - min) / range;
324                }
325            } else {
326                normalized.fill(0.5); // Default for constant features
327            }
328        }
329        NormalizationMethod::MinMaxCustom(new_min, new_max) => {
330            let min = array_f64.fold(f64::INFINITY, |acc, &x| acc.min(x));
331            let max = array_f64.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
332            let range = max - min;
333            let new_range = new_max - new_min;
334
335            if range.abs() > EPSILON {
336                for (i, &value) in array_f64.iter().enumerate() {
337                    normalized[i] = (value - min) / range * new_range + new_min;
338                }
339            } else {
340                normalized.fill((new_min + new_max) / 2.0); // Default for constant features
341            }
342        }
343        NormalizationMethod::ZScore => {
344            let mean = array_f64.iter().sum::<f64>() / array_f64.len() as f64;
345            let variance =
346                array_f64.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / array_f64.len() as f64;
347            let std_dev = variance.sqrt();
348
349            if std_dev > EPSILON {
350                for (i, &value) in array_f64.iter().enumerate() {
351                    normalized[i] = (value - mean) / std_dev;
352                }
353            } else {
354                normalized.fill(0.0); // Default for constant features
355            }
356        }
357        NormalizationMethod::MaxAbs => {
358            let max_abs = array_f64.fold(0.0, |acc, &x| acc.max(x.abs()));
359
360            if max_abs > EPSILON {
361                for (i, &value) in array_f64.iter().enumerate() {
362                    normalized[i] = value / max_abs;
363                }
364            } else {
365                normalized.fill(0.0); // Default for constant features
366            }
367        }
368        NormalizationMethod::L1 => {
369            let l1_norm = array_f64.fold(0.0, |acc, &x| acc + x.abs());
370
371            if l1_norm > EPSILON {
372                for (i, &value) in array_f64.iter().enumerate() {
373                    normalized[i] = value / l1_norm;
374                }
375            } else {
376                normalized.fill(0.0); // Default for constant features
377            }
378        }
379        NormalizationMethod::L2 => {
380            let sum_squares = array_f64.iter().fold(0.0, |acc, &x| acc + x * x);
381            let l2_norm = sum_squares.sqrt();
382
383            if l2_norm > EPSILON {
384                for (i, &value) in array_f64.iter().enumerate() {
385                    normalized[i] = value / l2_norm;
386                }
387            } else {
388                normalized.fill(0.0); // Default for constant features
389            }
390        }
391        NormalizationMethod::Robust => {
392            let mut data = array_f64.to_vec();
393            data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
394            let n = data.len();
395
396            // Calculate median
397            let median = if n % 2 == 0 {
398                (data[n / 2 - 1] + data[n / 2]) / 2.0
399            } else {
400                data[n / 2]
401            };
402
403            // Calculate IQR (Interquartile Range)
404            // Calculate Q1 (25th percentile)
405            let q1_pos = 0.25 * (n - 1) as f64;
406            let q1_idx_low = q1_pos.floor() as usize;
407            let q1_idx_high = q1_pos.ceil() as usize;
408            let q1 = if q1_idx_low == q1_idx_high {
409                data[q1_idx_low]
410            } else {
411                let weight = q1_pos - q1_idx_low as f64;
412                data[q1_idx_low] * (1.0 - weight) + data[q1_idx_high] * weight
413            };
414
415            // Calculate Q3 (75th percentile)
416            let q3_pos = 0.75 * (n - 1) as f64;
417            let q3_idx_low = q3_pos.floor() as usize;
418            let q3_idx_high = q3_pos.ceil() as usize;
419            let q3 = if q3_idx_low == q3_idx_high {
420                data[q3_idx_low]
421            } else {
422                let weight = q3_pos - q3_idx_low as f64;
423                data[q3_idx_low] * (1.0 - weight) + data[q3_idx_high] * weight
424            };
425
426            let iqr = q3 - q1;
427
428            if iqr > EPSILON {
429                for (i, &value) in array_f64.iter().enumerate() {
430                    normalized[i] = (value - median) / iqr;
431                }
432            } else {
433                normalized.fill(0.0); // Default for constant features
434            }
435        }
436    }
437
438    Ok(normalized)
439}
440
441/// Represents a fitted normalization model that can transform new data
442pub struct Normalizer {
443    /// The normalization method to apply
444    #[allow(dead_code)]
445    method: NormalizationMethod,
446    /// The axis along which to normalize (0 for columns, 1 for rows)
447    axis: usize,
448    /// Parameters from the fit (depends on method)
449    params: NormalizerParams,
450}
451
452/// Parameters for different normalization methods
453#[derive(Clone)]
454enum NormalizerParams {
455    /// Min and max values for MinMax normalization
456    MinMax {
457        min: Array1<f64>,
458        max: Array1<f64>,
459        new_min: f64,
460        new_max: f64,
461    },
462    /// Mean and standard deviation for ZScore normalization
463    ZScore {
464        mean: Array1<f64>,
465        std_dev: Array1<f64>,
466    },
467    /// Maximum absolute values for MaxAbs normalization
468    MaxAbs { max_abs: Array1<f64> },
469    /// L1 norms for L1 normalization
470    L1 { l1_norm: Array1<f64> },
471    /// L2 norms for L2 normalization
472    L2 { l2_norm: Array1<f64> },
473    /// Median and IQR for Robust normalization
474    Robust {
475        median: Array1<f64>,
476        iqr: Array1<f64>,
477    },
478}
479
480impl Normalizer {
481    /// Creates a new Normalizer with the specified method and axis
482    ///
483    /// # Arguments
484    /// * `method` - The normalization method to apply
485    /// * `axis` - The axis along which to normalize (0 for columns, 1 for rows)
486    ///
487    /// # Returns
488    /// * A new Normalizer instance
489    pub fn new(method: NormalizationMethod, axis: usize) -> Self {
490        let params = match method {
491            NormalizationMethod::MinMax => NormalizerParams::MinMax {
492                min: Array1::zeros(0),
493                max: Array1::zeros(0),
494                new_min: 0.0,
495                new_max: 1.0,
496            },
497            NormalizationMethod::MinMaxCustom(min, max) => NormalizerParams::MinMax {
498                min: Array1::zeros(0),
499                max: Array1::zeros(0),
500                new_min: min,
501                new_max: max,
502            },
503            NormalizationMethod::ZScore => NormalizerParams::ZScore {
504                mean: Array1::zeros(0),
505                std_dev: Array1::zeros(0),
506            },
507            NormalizationMethod::MaxAbs => NormalizerParams::MaxAbs {
508                max_abs: Array1::zeros(0),
509            },
510            NormalizationMethod::L1 => NormalizerParams::L1 {
511                l1_norm: Array1::zeros(0),
512            },
513            NormalizationMethod::L2 => NormalizerParams::L2 {
514                l2_norm: Array1::zeros(0),
515            },
516            NormalizationMethod::Robust => NormalizerParams::Robust {
517                median: Array1::zeros(0),
518                iqr: Array1::zeros(0),
519            },
520        };
521
522        Normalizer {
523            method,
524            axis,
525            params,
526        }
527    }
528
529    /// Fits the normalizer to the input data
530    ///
531    /// # Arguments
532    /// * `array` - The input 2D array to fit the normalizer to
533    ///
534    /// # Returns
535    /// * `Result<()>` - Ok if successful, Err otherwise
536    pub fn fit<S>(&mut self, array: &ArrayBase<S, Ix2>) -> Result<()>
537    where
538        S: Data,
539        S::Elem: Float + NumCast,
540    {
541        let array_f64 = array.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
542
543        if !array_f64.is_standard_layout() {
544            return Err(TransformError::InvalidInput(
545                "Input array must be in standard memory layout".to_string(),
546            ));
547        }
548
549        if array_f64.ndim() != 2 {
550            return Err(TransformError::InvalidInput(
551                "Only 2D arrays are supported".to_string(),
552            ));
553        }
554
555        if self.axis >= array_f64.ndim() {
556            return Err(TransformError::InvalidInput(format!(
557                "Invalid axis {} for array with {} dimensions",
558                self.axis,
559                array_f64.ndim()
560            )));
561        }
562
563        let _size = if self.axis == 0 {
564            array_f64.shape()[1]
565        } else {
566            array_f64.shape()[0]
567        };
568
569        match &mut self.params {
570            NormalizerParams::MinMax {
571                min,
572                max,
573                new_min: _,
574                new_max: _,
575            } => {
576                *min = array_f64.map_axis(Axis(self.axis), |view| {
577                    view.fold(f64::INFINITY, |acc, &x| acc.min(x))
578                });
579
580                *max = array_f64.map_axis(Axis(self.axis), |view| {
581                    view.fold(f64::NEG_INFINITY, |acc, &x| acc.max(x))
582                });
583            }
584            NormalizerParams::ZScore { mean, std_dev } => {
585                *mean = array_f64.map_axis(Axis(self.axis), |view| {
586                    view.iter().sum::<f64>() / view.len() as f64
587                });
588
589                *std_dev = array_f64.map_axis(Axis(self.axis), |view| {
590                    let m = view.iter().sum::<f64>() / view.len() as f64;
591                    let variance =
592                        view.iter().map(|&x| (x - m).powi(2)).sum::<f64>() / view.len() as f64;
593                    variance.sqrt()
594                });
595            }
596            NormalizerParams::MaxAbs { max_abs } => {
597                *max_abs = array_f64.map_axis(Axis(self.axis), |view| {
598                    view.fold(0.0, |acc, &x| acc.max(x.abs()))
599                });
600            }
601            NormalizerParams::L1 { l1_norm } => {
602                *l1_norm = array_f64.map_axis(Axis(self.axis), |view| {
603                    view.fold(0.0, |acc, &x| acc + x.abs())
604                });
605            }
606            NormalizerParams::L2 { l2_norm } => {
607                *l2_norm = array_f64.map_axis(Axis(self.axis), |view| {
608                    let sum_squares = view.iter().fold(0.0, |acc, &x| acc + x * x);
609                    sum_squares.sqrt()
610                });
611            }
612            NormalizerParams::Robust { median, iqr } => {
613                *median = array_f64.map_axis(Axis(self.axis), |view| {
614                    let mut data = view.to_vec();
615                    data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
616                    let n = data.len();
617                    if n % 2 == 0 {
618                        (data[n / 2 - 1] + data[n / 2]) / 2.0
619                    } else {
620                        data[n / 2]
621                    }
622                });
623
624                *iqr = array_f64.map_axis(Axis(self.axis), |view| {
625                    let mut data = view.to_vec();
626                    data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
627                    let n = data.len();
628
629                    // Calculate Q1 (25th percentile)
630                    let q1_pos = 0.25 * (n - 1) as f64;
631                    let q1_idx_low = q1_pos.floor() as usize;
632                    let q1_idx_high = q1_pos.ceil() as usize;
633                    let q1 = if q1_idx_low == q1_idx_high {
634                        data[q1_idx_low]
635                    } else {
636                        let weight = q1_pos - q1_idx_low as f64;
637                        data[q1_idx_low] * (1.0 - weight) + data[q1_idx_high] * weight
638                    };
639
640                    // Calculate Q3 (75th percentile)
641                    let q3_pos = 0.75 * (n - 1) as f64;
642                    let q3_idx_low = q3_pos.floor() as usize;
643                    let q3_idx_high = q3_pos.ceil() as usize;
644                    let q3 = if q3_idx_low == q3_idx_high {
645                        data[q3_idx_low]
646                    } else {
647                        let weight = q3_pos - q3_idx_low as f64;
648                        data[q3_idx_low] * (1.0 - weight) + data[q3_idx_high] * weight
649                    };
650
651                    q3 - q1
652                });
653            }
654        }
655
656        Ok(())
657    }
658
659    /// Transforms the input data using the fitted normalizer
660    ///
661    /// # Arguments
662    /// * `array` - The input 2D array to transform
663    ///
664    /// # Returns
665    /// * `Result<Array2<f64>>` - The transformed array
666    pub fn transform<S>(&self, array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
667    where
668        S: Data,
669        S::Elem: Float + NumCast,
670    {
671        let array_f64 = array.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
672
673        if !array_f64.is_standard_layout() {
674            return Err(TransformError::InvalidInput(
675                "Input array must be in standard memory layout".to_string(),
676            ));
677        }
678
679        if array_f64.ndim() != 2 {
680            return Err(TransformError::InvalidInput(
681                "Only 2D arrays are supported".to_string(),
682            ));
683        }
684
685        // Check the dimension along the normalization axis
686        let expected_size = match &self.params {
687            NormalizerParams::MinMax { min, .. } => min.len(),
688            NormalizerParams::ZScore { mean, .. } => mean.len(),
689            NormalizerParams::MaxAbs { max_abs } => max_abs.len(),
690            NormalizerParams::L1 { l1_norm } => l1_norm.len(),
691            NormalizerParams::L2 { l2_norm } => l2_norm.len(),
692            NormalizerParams::Robust { median, .. } => median.len(),
693        };
694
695        let actual_size = if self.axis == 0 {
696            array_f64.shape()[1]
697        } else {
698            array_f64.shape()[0]
699        };
700
701        if expected_size != actual_size {
702            return Err(TransformError::InvalidInput(format!(
703                "Expected {} features, got {}",
704                expected_size, actual_size
705            )));
706        }
707
708        let shape = array_f64.shape();
709        let mut transformed = Array2::zeros((shape[0], shape[1]));
710
711        match &self.params {
712            NormalizerParams::MinMax {
713                min,
714                max,
715                new_min,
716                new_max,
717            } => {
718                let range = max - min;
719                let new_range = new_max - new_min;
720
721                for i in 0..shape[0] {
722                    for j in 0..shape[1] {
723                        let value = array_f64[[i, j]];
724                        let idx = if self.axis == 0 { j } else { i };
725
726                        if range[idx].abs() > EPSILON {
727                            transformed[[i, j]] =
728                                (value - min[idx]) / range[idx] * new_range + new_min;
729                        } else {
730                            transformed[[i, j]] = (new_min + new_max) / 2.0; // Default for constant features
731                        }
732                    }
733                }
734            }
735            NormalizerParams::ZScore { mean, std_dev } => {
736                for i in 0..shape[0] {
737                    for j in 0..shape[1] {
738                        let value = array_f64[[i, j]];
739                        let idx = if self.axis == 0 { j } else { i };
740
741                        if std_dev[idx] > EPSILON {
742                            transformed[[i, j]] = (value - mean[idx]) / std_dev[idx];
743                        } else {
744                            transformed[[i, j]] = 0.0; // Default for constant features
745                        }
746                    }
747                }
748            }
749            NormalizerParams::MaxAbs { max_abs } => {
750                for i in 0..shape[0] {
751                    for j in 0..shape[1] {
752                        let value = array_f64[[i, j]];
753                        let idx = if self.axis == 0 { j } else { i };
754
755                        if max_abs[idx] > EPSILON {
756                            transformed[[i, j]] = value / max_abs[idx];
757                        } else {
758                            transformed[[i, j]] = 0.0; // Default for constant features
759                        }
760                    }
761                }
762            }
763            NormalizerParams::L1 { l1_norm } => {
764                for i in 0..shape[0] {
765                    for j in 0..shape[1] {
766                        let value = array_f64[[i, j]];
767                        let idx = if self.axis == 0 { j } else { i };
768
769                        if l1_norm[idx] > EPSILON {
770                            transformed[[i, j]] = value / l1_norm[idx];
771                        } else {
772                            transformed[[i, j]] = 0.0; // Default for constant features
773                        }
774                    }
775                }
776            }
777            NormalizerParams::L2 { l2_norm } => {
778                for i in 0..shape[0] {
779                    for j in 0..shape[1] {
780                        let value = array_f64[[i, j]];
781                        let idx = if self.axis == 0 { j } else { i };
782
783                        if l2_norm[idx] > EPSILON {
784                            transformed[[i, j]] = value / l2_norm[idx];
785                        } else {
786                            transformed[[i, j]] = 0.0; // Default for constant features
787                        }
788                    }
789                }
790            }
791            NormalizerParams::Robust { median, iqr } => {
792                for i in 0..shape[0] {
793                    for j in 0..shape[1] {
794                        let value = array_f64[[i, j]];
795                        let idx = if self.axis == 0 { j } else { i };
796
797                        if iqr[idx] > EPSILON {
798                            transformed[[i, j]] = (value - median[idx]) / iqr[idx];
799                        } else {
800                            transformed[[i, j]] = 0.0; // Default for constant features
801                        }
802                    }
803                }
804            }
805        }
806
807        Ok(transformed)
808    }
809
810    /// Fits the normalizer to the input data and transforms it
811    ///
812    /// # Arguments
813    /// * `array` - The input 2D array to fit and transform
814    ///
815    /// # Returns
816    /// * `Result<Array2<f64>>` - The transformed array
817    pub fn fit_transform<S>(&mut self, array: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
818    where
819        S: Data,
820        S::Elem: Float + NumCast,
821    {
822        self.fit(array)?;
823        self.transform(array)
824    }
825}
826
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayBase, Data, Dimension};

    /// Absolute tolerance used for every element-wise comparison in this module.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` and `expected` have identical shapes and that every pair of
    /// corresponding elements differs by at most [`TOL`].
    ///
    /// The shape check comes first on purpose: zipping the two element iterators stops at the
    /// shorter one, so without it an empty or truncated result would pass vacuously.
    fn assert_close<S1, S2, D>(actual: &ArrayBase<S1, D>, expected: &ArrayBase<S2, D>)
    where
        S1: Data<Elem = f64>,
        S2: Data<Elem = f64>,
        D: Dimension,
    {
        assert_eq!(actual.shape(), expected.shape(), "shape mismatch");
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*a, *e, epsilon = TOL);
        }
    }

    #[test]
    fn test_normalize_vector_minmax() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let normalized = normalize_vector(&data, NormalizationMethod::MinMax).unwrap();

        assert_close(&normalized, &array![0.0, 0.25, 0.5, 0.75, 1.0]);
    }

    #[test]
    fn test_normalize_vector_zscore() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let normalized = normalize_vector(&data, NormalizationMethod::ZScore).unwrap();

        let mean = 3.0;
        // Population standard deviation: sqrt(sum((x - mean)^2) / n) = sqrt(10 / 5).
        let std_dev = (10.0 / 5.0_f64).sqrt();
        let expected = data.mapv(|x| (x - mean) / std_dev);

        assert_close(&normalized, &expected);
    }

    #[test]
    fn test_normalize_array_minmax() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Normalize columns (axis 0): every column spans [min, min + 6].
        let by_column = normalize_array(&data, NormalizationMethod::MinMax, 0).unwrap();
        assert_close(
            &by_column,
            &array![[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]],
        );

        // Normalize rows (axis 1): every row spans [min, min + 2].
        let by_row = normalize_array(&data, NormalizationMethod::MinMax, 1).unwrap();
        assert_close(
            &by_row,
            &array![[0.0, 0.5, 1.0], [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]],
        );
    }

    #[test]
    fn test_normalizer_fit_transform() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        // Test MinMax normalization
        let mut normalizer = Normalizer::new(NormalizationMethod::MinMax, 0);
        let transformed = normalizer.fit_transform(&data).unwrap();
        assert_close(
            &transformed,
            &array![[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [1.0, 1.0, 1.0]],
        );

        // Test with separate fit and transform: new data must reuse the fitted
        // per-column minimums (1, 2, 3) and ranges (6).
        let data2 = array![[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]];
        let transformed2 = normalizer.transform(&data2).unwrap();
        let expected2 = array![
            [1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0],
            [2.0 / 3.0, 2.0 / 3.0, 2.0 / 3.0]
        ];
        assert_close(&transformed2, &expected2);
    }

    #[test]
    fn test_normalize_vector_robust() {
        // Test with data containing outliers
        let data = array![1.0, 2.0, 3.0, 4.0, 100.0]; // 100 is an outlier
        let normalized = normalize_vector(&data, NormalizationMethod::Robust).unwrap();

        // For this data: sorted = [1, 2, 3, 4, 100]
        // median = 3.0 (middle value)
        // Q1 = 2.0 (at 25th percentile), Q3 = 4.0 (at 75th percentile), IQR = 2.0
        // Expected transformation: (x - 3) / 2
        let expected = array![
            (1.0 - 3.0) / 2.0,   // -1.0
            (2.0 - 3.0) / 2.0,   // -0.5
            (3.0 - 3.0) / 2.0,   // 0
            (4.0 - 3.0) / 2.0,   // 0.5
            (100.0 - 3.0) / 2.0  // 48.5
        ];

        assert_close(&normalized, &expected);
    }

    #[test]
    fn test_normalize_array_robust() {
        let data = array![[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]];

        // Normalize columns (axis 0)
        let normalized = normalize_array(&data, NormalizationMethod::Robust, 0).unwrap();

        // For column 0: [1, 2, 3] -> median=2, Q1=1.5, Q3=2.5, IQR=1.0
        // For column 1: [10, 20, 30] -> median=20, Q1=15, Q3=25, IQR=10
        let expected = array![
            [(1.0 - 2.0) / 1.0, (10.0 - 20.0) / 10.0], // [-1.0, -1.0]
            [(2.0 - 2.0) / 1.0, (20.0 - 20.0) / 10.0], // [0.0, 0.0]
            [(3.0 - 2.0) / 1.0, (30.0 - 20.0) / 10.0]  // [1.0, 1.0]
        ];

        assert_close(&normalized, &expected);
    }

    #[test]
    fn test_robust_normalizer() {
        let data = array![[1.0, 100.0], [2.0, 200.0], [3.0, 300.0], [4.0, 400.0]];

        let mut normalizer = Normalizer::new(NormalizationMethod::Robust, 0);
        let transformed = normalizer.fit_transform(&data).unwrap();

        // For column 0: [1, 2, 3, 4] -> median=2.5, Q1=1.75, Q3=3.25, IQR=1.5
        // For column 1: [100, 200, 300, 400] -> median=250, Q1=175, Q3=325, IQR=150
        let expected = array![
            [(1.0 - 2.5) / 1.5, (100.0 - 250.0) / 150.0], // [-1.0, -1.0]
            [(2.0 - 2.5) / 1.5, (200.0 - 250.0) / 150.0], // [-0.333..., -0.333...]
            [(3.0 - 2.5) / 1.5, (300.0 - 250.0) / 150.0], // [0.333..., 0.333...]
            [(4.0 - 2.5) / 1.5, (400.0 - 250.0) / 150.0]  // [1.0, 1.0]
        ];

        assert_close(&transformed, &expected);
    }
}