// scirs2_transform/scaling.rs
//! Advanced scaling and transformation methods
//!
//! This module provides sophisticated scaling methods that go beyond basic normalization,
//! including quantile transformations and robust scaling methods.

6use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
7use scirs2_core::numeric::{Float, NumCast};
8
9use crate::error::{Result, TransformError};
10
11/// Small epsilon value for numerical stability and comparison with zero
12pub const EPSILON: f64 = 1e-10;
13
14/// QuantileTransformer for non-linear transformations
15///
16/// This transformer transforms features to follow a uniform or normal distribution
17/// using quantiles information. This method reduces the impact of outliers.
18pub struct QuantileTransformer {
19    /// Number of quantiles to estimate
20    n_quantiles: usize,
21    /// Output distribution ('uniform' or 'normal')
22    output_distribution: String,
23    /// Whether to clip transformed values to bounds [0, 1] for uniform distribution
24    clip: bool,
25    /// The quantiles for each feature
26    quantiles: Option<Array2<f64>>,
27    /// References values for each quantile
28    references: Option<Array1<f64>>,
29}
30
31impl QuantileTransformer {
32    /// Creates a new QuantileTransformer
33    ///
34    /// # Arguments
35    /// * `n_quantiles` - Number of quantiles to estimate (default: 1000)
36    /// * `output_distribution` - Target distribution ('uniform' or 'normal')
37    /// * `clip` - Whether to clip transformed values
38    ///
39    /// # Returns
40    /// * A new QuantileTransformer instance
41    pub fn new(n_quantiles: usize, outputdistribution: &str, clip: bool) -> Result<Self> {
42        if n_quantiles < 2 {
43            return Err(TransformError::InvalidInput(
44                "n_quantiles must be at least 2".to_string(),
45            ));
46        }
47
48        if outputdistribution != "uniform" && outputdistribution != "normal" {
49            return Err(TransformError::InvalidInput(
50                "output_distribution must be 'uniform' or 'normal'".to_string(),
51            ));
52        }
53
54        Ok(QuantileTransformer {
55            n_quantiles,
56            output_distribution: outputdistribution.to_string(),
57            clip,
58            quantiles: None,
59            references: None,
60        })
61    }
62
63    /// Fits the QuantileTransformer to the input data
64    ///
65    /// # Arguments
66    /// * `x` - The input data, shape (n_samples, n_features)
67    ///
68    /// # Returns
69    /// * `Result<()>` - Ok if successful, Err otherwise
70    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
71    where
72        S: Data,
73        S::Elem: Float + NumCast,
74    {
75        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
76
77        let n_samples = x_f64.shape()[0];
78        let n_features = x_f64.shape()[1];
79
80        if n_samples == 0 || n_features == 0 {
81            return Err(TransformError::InvalidInput("Empty input data".to_string()));
82        }
83
84        if self.n_quantiles > n_samples {
85            return Err(TransformError::InvalidInput(format!(
86                "n_quantiles ({}) cannot be greater than n_samples ({})",
87                self.n_quantiles, n_samples
88            )));
89        }
90
91        // Compute quantiles for each feature
92        let mut quantiles = Array2::zeros((n_features, self.n_quantiles));
93
94        for j in 0..n_features {
95            // Extract feature data and sort it
96            let mut feature_data: Vec<f64> = x_f64.column(j).to_vec();
97            feature_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
98
99            // Compute quantiles
100            for i in 0..self.n_quantiles {
101                let q = i as f64 / (self.n_quantiles - 1) as f64;
102                let idx = (q * (feature_data.len() - 1) as f64).round() as usize;
103                quantiles[[j, i]] = feature_data[idx];
104            }
105        }
106
107        // Generate reference distribution
108        let references = if self.output_distribution == "uniform" {
109            // Uniform distribution references
110            Array1::from_shape_fn(self.n_quantiles, |i| {
111                i as f64 / (self.n_quantiles - 1) as f64
112            })
113        } else {
114            // Normal distribution references (using inverse normal CDF approximation)
115            Array1::from_shape_fn(self.n_quantiles, |i| {
116                let u = i as f64 / (self.n_quantiles - 1) as f64;
117                // Clamp u to avoid extreme values
118                let u_clamped = u.clamp(1e-7, 1.0 - 1e-7);
119                inverse_normal_cdf(u_clamped)
120            })
121        };
122
123        self.quantiles = Some(quantiles);
124        self.references = Some(references);
125
126        Ok(())
127    }
128
129    /// Transforms the input data using the fitted QuantileTransformer
130    ///
131    /// # Arguments
132    /// * `x` - The input data, shape (n_samples, n_features)
133    ///
134    /// # Returns
135    /// * `Result<Array2<f64>>` - The transformed data
136    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
137    where
138        S: Data,
139        S::Elem: Float + NumCast,
140    {
141        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
142
143        let n_samples = x_f64.shape()[0];
144        let n_features = x_f64.shape()[1];
145
146        if self.quantiles.is_none() || self.references.is_none() {
147            return Err(TransformError::TransformationError(
148                "QuantileTransformer has not been fitted".to_string(),
149            ));
150        }
151
152        let quantiles = self.quantiles.as_ref().expect("Operation failed");
153        let references = self.references.as_ref().expect("Operation failed");
154
155        if n_features != quantiles.shape()[0] {
156            return Err(TransformError::InvalidInput(format!(
157                "x has {} features, but QuantileTransformer was fitted with {} features",
158                n_features,
159                quantiles.shape()[0]
160            )));
161        }
162
163        let mut transformed = Array2::zeros((n_samples, n_features));
164
165        for i in 0..n_samples {
166            for j in 0..n_features {
167                let value = x_f64[[i, j]];
168
169                // Find the position of the value in the quantiles
170                let feature_quantiles = quantiles.row(j);
171
172                // Find the index where value would be inserted
173                let mut lower_idx = 0;
174                let mut upper_idx = self.n_quantiles - 1;
175
176                // Handle edge cases
177                if value <= feature_quantiles[0] {
178                    transformed[[i, j]] = references[0];
179                    continue;
180                }
181                if value >= feature_quantiles[self.n_quantiles - 1] {
182                    transformed[[i, j]] = references[self.n_quantiles - 1];
183                    continue;
184                }
185
186                // Binary search to find the interval
187                while upper_idx - lower_idx > 1 {
188                    let mid = (lower_idx + upper_idx) / 2;
189                    if value <= feature_quantiles[mid] {
190                        upper_idx = mid;
191                    } else {
192                        lower_idx = mid;
193                    }
194                }
195
196                // Linear interpolation between reference values
197                let lower_quantile = feature_quantiles[lower_idx];
198                let upper_quantile = feature_quantiles[upper_idx];
199                let lower_ref = references[lower_idx];
200                let upper_ref = references[upper_idx];
201
202                if (upper_quantile - lower_quantile).abs() < EPSILON {
203                    transformed[[i, j]] = lower_ref;
204                } else {
205                    let ratio = (value - lower_quantile) / (upper_quantile - lower_quantile);
206                    transformed[[i, j]] = lower_ref + ratio * (upper_ref - lower_ref);
207                }
208            }
209        }
210
211        // Apply clipping if requested and output distribution is uniform
212        if self.clip && self.output_distribution == "uniform" {
213            for i in 0..n_samples {
214                for j in 0..n_features {
215                    transformed[[i, j]] = transformed[[i, j]].clamp(0.0, 1.0);
216                }
217            }
218        }
219
220        Ok(transformed)
221    }
222
223    /// Fits the QuantileTransformer to the input data and transforms it
224    ///
225    /// # Arguments
226    /// * `x` - The input data, shape (n_samples, n_features)
227    ///
228    /// # Returns
229    /// * `Result<Array2<f64>>` - The transformed data
230    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
231    where
232        S: Data,
233        S::Elem: Float + NumCast,
234    {
235        self.fit(x)?;
236        self.transform(x)
237    }
238
239    /// Returns the quantiles for each feature
240    ///
241    /// # Returns
242    /// * `Option<&Array2<f64>>` - The quantiles, shape (n_features, n_quantiles)
243    pub fn quantiles(&self) -> Option<&Array2<f64>> {
244        self.quantiles.as_ref()
245    }
246}
247
/// Approximation of the inverse normal cumulative distribution function
///
/// Uses the Beasley-Springer-Moro algorithm: a rational approximation in the
/// central region |u - 0.5| < 0.42 and a polynomial in ln(-ln(r)) in the tails.
#[allow(dead_code)]
fn inverse_normal_cdf(u: f64) -> f64 {
    // Numerator / denominator coefficients of the central rational approximation
    const NUM: [f64; 4] = [2.50662823884, -18.61500062529, 41.39119773534, -25.44106049637];
    const DEN: [f64; 4] = [-8.47351093090, 23.08336743743, -21.06224101826, 3.13082909833];
    // Tail polynomial coefficients (Moro), lowest order first
    const TAIL: [f64; 9] = [
        0.3374754822726147,
        0.9761690190917186,
        0.1607979714918209,
        0.0276438810333863,
        0.0038405729373609,
        0.0003951896511919,
        0.0000321767881768,
        0.0000002888167364,
        0.0000003960315187,
    ];

    let y = u - 0.5;

    if y.abs() < 0.42 {
        // Central region: odd rational function of y
        let r = y * y;
        let numerator = ((NUM[3] * r + NUM[2]) * r + NUM[1]) * r + NUM[0];
        let denominator = (((DEN[3] * r + DEN[2]) * r + DEN[1]) * r + DEN[0]) * r + 1.0;
        y * numerator / denominator
    } else {
        // Tail region: work with the probability mass in the nearer tail
        let tail_prob = if y > 0.0 { 1.0 - u } else { u };
        let s = (-tail_prob.ln()).ln();

        // Horner evaluation of the tail polynomial in s
        let mut poly = TAIL[8];
        for &coeff in TAIL[..8].iter().rev() {
            poly = poly * s + coeff;
        }

        // Mirror the result for the lower tail
        if y < 0.0 {
            -poly
        } else {
            poly
        }
    }
}

294/// MaxAbsScaler for scaling features by their maximum absolute value
295///
296/// This scaler scales each feature individually such that the maximal absolute value
297/// of each feature in the training set will be 1.0. It does not shift/center the data,
298/// and thus does not destroy any sparsity.
299pub struct MaxAbsScaler {
300    /// Maximum absolute values for each feature (learned during fit)
301    max_abs_: Option<Array1<f64>>,
302    /// Scale factors for each feature (1 / max_abs_)
303    scale_: Option<Array1<f64>>,
304}
305
306impl MaxAbsScaler {
307    /// Creates a new MaxAbsScaler
308    ///
309    /// # Returns
310    /// * A new MaxAbsScaler instance
311    ///
312    /// # Examples
313    /// ```
314    /// use scirs2_transform::scaling::MaxAbsScaler;
315    ///
316    /// let scaler = MaxAbsScaler::new();
317    /// ```
318    pub fn new() -> Self {
319        MaxAbsScaler {
320            max_abs_: None,
321            scale_: None,
322        }
323    }
324
325    /// Creates a MaxAbsScaler with default settings (same as new())
326    #[allow(dead_code)]
327    pub fn with_defaults() -> Self {
328        Self::new()
329    }
330
331    /// Fits the MaxAbsScaler to the input data
332    ///
333    /// # Arguments
334    /// * `x` - The input data, shape (n_samples, n_features)
335    ///
336    /// # Returns
337    /// * `Result<()>` - Ok if successful, Err otherwise
338    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
339    where
340        S: Data,
341        S::Elem: Float + NumCast,
342    {
343        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
344
345        let n_samples = x_f64.shape()[0];
346        let n_features = x_f64.shape()[1];
347
348        if n_samples == 0 || n_features == 0 {
349            return Err(TransformError::InvalidInput("Empty input data".to_string()));
350        }
351
352        // Compute maximum absolute value for each feature
353        let mut max_abs = Array1::zeros(n_features);
354
355        for j in 0..n_features {
356            let feature_data = x_f64.column(j);
357            let max_abs_value = feature_data
358                .iter()
359                .map(|&x| x.abs())
360                .fold(0.0, |acc, x| acc.max(x));
361
362            max_abs[j] = max_abs_value;
363        }
364
365        // Compute scale factors (avoid division by zero)
366        let scale = max_abs.mapv(|max_abs_val| {
367            if max_abs_val > EPSILON {
368                1.0 / max_abs_val
369            } else {
370                1.0 // If max_abs is 0, don't scale (feature is constant zero)
371            }
372        });
373
374        self.max_abs_ = Some(max_abs);
375        self.scale_ = Some(scale);
376
377        Ok(())
378    }
379
380    /// Transforms the input data using the fitted MaxAbsScaler
381    ///
382    /// # Arguments
383    /// * `x` - The input data, shape (n_samples, n_features)
384    ///
385    /// # Returns
386    /// * `Result<Array2<f64>>` - The scaled data
387    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
388    where
389        S: Data,
390        S::Elem: Float + NumCast,
391    {
392        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
393
394        let n_samples = x_f64.shape()[0];
395        let n_features = x_f64.shape()[1];
396
397        if self.scale_.is_none() {
398            return Err(TransformError::TransformationError(
399                "MaxAbsScaler has not been fitted".to_string(),
400            ));
401        }
402
403        let scale = self.scale_.as_ref().expect("Operation failed");
404
405        if n_features != scale.len() {
406            return Err(TransformError::InvalidInput(format!(
407                "x has {} features, but MaxAbsScaler was fitted with {} features",
408                n_features,
409                scale.len()
410            )));
411        }
412
413        let mut transformed = Array2::zeros((n_samples, n_features));
414
415        // Scale each feature by its scale factor
416        for i in 0..n_samples {
417            for j in 0..n_features {
418                transformed[[i, j]] = x_f64[[i, j]] * scale[j];
419            }
420        }
421
422        Ok(transformed)
423    }
424
425    /// Fits the MaxAbsScaler to the input data and transforms it
426    ///
427    /// # Arguments
428    /// * `x` - The input data, shape (n_samples, n_features)
429    ///
430    /// # Returns
431    /// * `Result<Array2<f64>>` - The scaled data
432    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
433    where
434        S: Data,
435        S::Elem: Float + NumCast,
436    {
437        self.fit(x)?;
438        self.transform(x)
439    }
440
441    /// Inverse transforms the scaled data back to original scale
442    ///
443    /// # Arguments
444    /// * `x` - The scaled data, shape (n_samples, n_features)
445    ///
446    /// # Returns
447    /// * `Result<Array2<f64>>` - The data in original scale
448    pub fn inverse_transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
449    where
450        S: Data,
451        S::Elem: Float + NumCast,
452    {
453        let x_f64 = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
454
455        let n_samples = x_f64.shape()[0];
456        let n_features = x_f64.shape()[1];
457
458        if self.max_abs_.is_none() {
459            return Err(TransformError::TransformationError(
460                "MaxAbsScaler has not been fitted".to_string(),
461            ));
462        }
463
464        let max_abs = self.max_abs_.as_ref().expect("Operation failed");
465
466        if n_features != max_abs.len() {
467            return Err(TransformError::InvalidInput(format!(
468                "x has {} features, but MaxAbsScaler was fitted with {} features",
469                n_features,
470                max_abs.len()
471            )));
472        }
473
474        let mut transformed = Array2::zeros((n_samples, n_features));
475
476        // Scale back by multiplying with max_abs values
477        for i in 0..n_samples {
478            for j in 0..n_features {
479                transformed[[i, j]] = x_f64[[i, j]] * max_abs[j];
480            }
481        }
482
483        Ok(transformed)
484    }
485
486    /// Returns the maximum absolute values for each feature
487    ///
488    /// # Returns
489    /// * `Option<&Array1<f64>>` - The maximum absolute values
490    pub fn max_abs(&self) -> Option<&Array1<f64>> {
491        self.max_abs_.as_ref()
492    }
493
494    /// Returns the scale factors for each feature
495    ///
496    /// # Returns
497    /// * `Option<&Array1<f64>>` - The scale factors (1 / max_abs)
498    pub fn scale(&self) -> Option<&Array1<f64>> {
499        self.scale_.as_ref()
500    }
501}
502
503impl Default for MaxAbsScaler {
504    fn default() -> Self {
505        Self::new()
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_quantile_transformer_uniform() {
        // Create test data with different distributions
        let data = Array::from_shape_vec(
            (6, 2),
            vec![
                1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0, 100.0, 1000.0,
            ], // Last row has outliers
        )
        .expect("Operation failed");

        let mut transformer =
            QuantileTransformer::new(5, "uniform", true).expect("Operation failed");
        let transformed = transformer.fit_transform(&data).expect("Operation failed");

        // Check that the shape is preserved
        assert_eq!(transformed.shape(), &[6, 2]);

        // For uniform distribution, values should be between 0 and 1
        for i in 0..6 {
            for j in 0..2 {
                assert!(
                    transformed[[i, j]] >= 0.0 && transformed[[i, j]] <= 1.0,
                    "Value at [{}, {}] = {} is not in [0, 1]",
                    i,
                    j,
                    transformed[[i, j]]
                );
            }
        }

        // The smallest value should map to 0 and largest to 1
        assert_abs_diff_eq!(transformed[[0, 0]], 0.0, epsilon = 1e-10); // min of column 0
        assert_abs_diff_eq!(transformed[[5, 0]], 1.0, epsilon = 1e-10); // max of column 0
        assert_abs_diff_eq!(transformed[[0, 1]], 0.0, epsilon = 1e-10); // min of column 1
        assert_abs_diff_eq!(transformed[[5, 1]], 1.0, epsilon = 1e-10); // max of column 1
    }

    #[test]
    fn test_quantile_transformer_normal() {
        // Create test data
        let data =
            Array::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).expect("Operation failed");

        let mut transformer =
            QuantileTransformer::new(5, "normal", false).expect("Operation failed");
        let transformed = transformer.fit_transform(&data).expect("Operation failed");

        // Check that the shape is preserved
        assert_eq!(transformed.shape(), &[5, 1]);

        // The middle value should be close to 0 (median of normal distribution)
        assert_abs_diff_eq!(transformed[[2, 0]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_quantile_transformer_errors() {
        // Test invalid n_quantiles
        assert!(QuantileTransformer::new(1, "uniform", true).is_err());

        // Test invalid output_distribution
        assert!(QuantileTransformer::new(100, "invalid", true).is_err());

        // Test fitting with insufficient data
        let small_data = Array::from_shape_vec((2, 1), vec![1.0, 2.0]).expect("Operation failed");
        let mut transformer =
            QuantileTransformer::new(10, "uniform", true).expect("Operation failed");
        assert!(transformer.fit(&small_data).is_err());
    }

    #[test]
    fn test_inverse_normal_cdf() {
        // Test some known values
        assert_abs_diff_eq!(inverse_normal_cdf(0.5), 0.0, epsilon = 1e-6);
        assert!(inverse_normal_cdf(0.1) < 0.0); // Should be negative
        assert!(inverse_normal_cdf(0.9) > 0.0); // Should be positive
    }

    #[test]
    fn test_max_abs_scaler_basic() {
        // Create test data with different ranges
        // Feature 0: [-4, -2, 0, 2, 4] -> max_abs = 4
        // Feature 1: [-10, -5, 0, 5, 10] -> max_abs = 10
        let data = Array::from_shape_vec(
            (5, 2),
            vec![-4.0, -10.0, -2.0, -5.0, 0.0, 0.0, 2.0, 5.0, 4.0, 10.0],
        )
        .expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");

        // Check that the shape is preserved
        assert_eq!(scaled.shape(), &[5, 2]);

        // Check the maximum absolute values
        let max_abs = scaler.max_abs().expect("Operation failed");
        assert_abs_diff_eq!(max_abs[0], 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(max_abs[1], 10.0, epsilon = 1e-10);

        // Check the scale factors
        let scale = scaler.scale().expect("Operation failed");
        assert_abs_diff_eq!(scale[0], 0.25, epsilon = 1e-10); // 1/4
        assert_abs_diff_eq!(scale[1], 0.1, epsilon = 1e-10); // 1/10

        // Check that the maximum absolute value in each feature is 1.0
        for j in 0..2 {
            let feature_max = scaled
                .column(j)
                .iter()
                .map(|&x| x.abs())
                .fold(0.0, f64::max);
            assert_abs_diff_eq!(feature_max, 1.0, epsilon = 1e-10);
        }

        // Check specific scaled values
        assert_abs_diff_eq!(scaled[[0, 0]], -1.0, epsilon = 1e-10); // -4 / 4 = -1
        assert_abs_diff_eq!(scaled[[0, 1]], -1.0, epsilon = 1e-10); // -10 / 10 = -1
        assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10); // 0 / 4 = 0
        assert_abs_diff_eq!(scaled[[2, 1]], 0.0, epsilon = 1e-10); // 0 / 10 = 0
        assert_abs_diff_eq!(scaled[[4, 0]], 1.0, epsilon = 1e-10); // 4 / 4 = 1
        assert_abs_diff_eq!(scaled[[4, 1]], 1.0, epsilon = 1e-10); // 10 / 10 = 1
    }

    #[test]
    fn test_max_abs_scaler_positive_only() {
        // Test with positive-only data
        let data = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 6.0, 5.0, 10.0])
            .expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");

        // Check maximum absolute values
        let max_abs = scaler.max_abs().expect("Operation failed");
        assert_abs_diff_eq!(max_abs[0], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(max_abs[1], 10.0, epsilon = 1e-10);

        // Check scaled values
        assert_abs_diff_eq!(scaled[[0, 0]], 0.2, epsilon = 1e-10); // 1 / 5
        assert_abs_diff_eq!(scaled[[0, 1]], 0.2, epsilon = 1e-10); // 2 / 10
        assert_abs_diff_eq!(scaled[[2, 0]], 1.0, epsilon = 1e-10); // 5 / 5
        assert_abs_diff_eq!(scaled[[2, 1]], 1.0, epsilon = 1e-10); // 10 / 10
    }

    #[test]
    fn test_max_abs_scaler_inverse_transform() {
        let data = Array::from_shape_vec((3, 2), vec![-6.0, 8.0, 0.0, -4.0, 3.0, 12.0])
            .expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");
        let inverse = scaler.inverse_transform(&scaled).expect("Operation failed");

        // Check that inverse transform recovers original data
        assert_eq!(inverse.shape(), data.shape());
        for i in 0..3 {
            for j in 0..2 {
                assert_abs_diff_eq!(inverse[[i, j]], data[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_max_abs_scaler_constant_feature() {
        // Test with a constant feature (all zeros)
        let data = Array::from_shape_vec((3, 2), vec![0.0, 5.0, 0.0, 10.0, 0.0, 15.0])
            .expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");

        // Constant zero feature should remain zero
        for i in 0..3 {
            assert_abs_diff_eq!(scaled[[i, 0]], 0.0, epsilon = 1e-10);
        }

        // Second feature should be scaled normally
        assert_abs_diff_eq!(scaled[[0, 1]], 1.0 / 3.0, epsilon = 1e-10); // 5 / 15
        assert_abs_diff_eq!(scaled[[2, 1]], 1.0, epsilon = 1e-10); // 15 / 15
    }

    #[test]
    fn test_max_abs_scaler_errors() {
        // Test with empty data
        let empty_data = Array2::<f64>::zeros((0, 2));
        let mut scaler = MaxAbsScaler::new();
        assert!(scaler.fit(&empty_data).is_err());

        // Test transform before fit
        let data =
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed");
        let unfitted_scaler = MaxAbsScaler::new();
        assert!(unfitted_scaler.transform(&data).is_err());
        assert!(unfitted_scaler.inverse_transform(&data).is_err());

        // Test feature dimension mismatch
        let train_data = Array::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Operation failed");
        let test_data =
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        scaler.fit(&train_data).expect("Operation failed");
        assert!(scaler.transform(&test_data).is_err());
        assert!(scaler.inverse_transform(&test_data).is_err());
    }

    #[test]
    fn test_max_abs_scaler_single_feature() {
        // Test with single feature
        let data =
            Array::from_shape_vec((4, 1), vec![-8.0, -2.0, 4.0, 6.0]).expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");

        // Maximum absolute value should be 8.0
        let max_abs = scaler.max_abs().expect("Operation failed");
        assert_abs_diff_eq!(max_abs[0], 8.0, epsilon = 1e-10);

        // Check scaled values
        assert_abs_diff_eq!(scaled[[0, 0]], -1.0, epsilon = 1e-10); // -8 / 8
        assert_abs_diff_eq!(scaled[[1, 0]], -0.25, epsilon = 1e-10); // -2 / 8
        assert_abs_diff_eq!(scaled[[2, 0]], 0.5, epsilon = 1e-10); // 4 / 8
        assert_abs_diff_eq!(scaled[[3, 0]], 0.75, epsilon = 1e-10); // 6 / 8
    }

    #[test]
    fn test_max_abs_scaler_sparse_preservation() {
        // Test that zero values remain zero (sparsity preservation)
        let data = Array::from_shape_vec(
            (4, 3),
            vec![
                0.0, 5.0, 0.0, // Row with zeros
                10.0, 0.0, -8.0, // Another row with zeros
                0.0, 0.0, 4.0, // Row with multiple zeros
                -5.0, 10.0, 0.0, // Row with zero at end
            ],
        )
        .expect("Operation failed");

        let mut scaler = MaxAbsScaler::new();
        let scaled = scaler.fit_transform(&data).expect("Operation failed");

        // Check that zeros remain zeros
        assert_abs_diff_eq!(scaled[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[0, 2]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[1, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[2, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(scaled[[3, 2]], 0.0, epsilon = 1e-10);

        // Check that non-zero values are scaled correctly
        // Feature 0: max_abs = 10, Feature 1: max_abs = 10, Feature 2: max_abs = 8
        assert_abs_diff_eq!(scaled[[0, 1]], 0.5, epsilon = 1e-10); // 5 / 10
        assert_abs_diff_eq!(scaled[[1, 0]], 1.0, epsilon = 1e-10); // 10 / 10
        assert_abs_diff_eq!(scaled[[1, 2]], -1.0, epsilon = 1e-10); // -8 / 8
        assert_abs_diff_eq!(scaled[[2, 2]], 0.5, epsilon = 1e-10); // 4 / 8
        assert_abs_diff_eq!(scaled[[3, 0]], -0.5, epsilon = 1e-10); // -5 / 10
        assert_abs_diff_eq!(scaled[[3, 1]], 1.0, epsilon = 1e-10); // 10 / 10
    }
}