sklears_preprocessing/
functional.rs

1//! Functional APIs for preprocessing
2//!
3//! This module provides functional APIs that directly transform data without
4//! needing to create and fit transformer objects.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{
8    error::{Result, SklearsError},
9    traits::{Fit, Transform},
10    types::Float,
11};
12
13use crate::{Binarizer, LabelBinarizer, NormType, Normalizer};
14
15/// Standardize a dataset along any axis
16///
17/// Center to the mean and component wise scale to unit variance.
18///
19/// # Arguments
20/// * `x` - The data to scale
21/// * `axis` - Axis along which to compute mean and std (0 for features, 1 for samples)
22/// * `with_mean` - If true, center the data before scaling
23/// * `with_std` - If true, scale the data to unit variance
24///
25/// # Returns
26/// Scaled data
27pub fn scale(
28    x: &Array2<Float>,
29    axis: usize,
30    with_mean: bool,
31    with_std: bool,
32) -> Result<Array2<Float>> {
33    // Manual implementation since StandardScaler is a placeholder
34    let mut result = x.clone();
35
36    if axis == 0 {
37        // Scale along features
38        for j in 0..x.ncols() {
39            let column = x.column(j);
40            let mean = if with_mean {
41                column.mean().unwrap_or(0.0)
42            } else {
43                0.0
44            };
45            let std = if with_std {
46                column.std(0.0).max(1e-8)
47            } else {
48                1.0
49            };
50
51            for i in 0..x.nrows() {
52                result[[i, j]] = (x[[i, j]] - mean) / std;
53            }
54        }
55    } else if axis == 1 {
56        // Scale along samples
57        for i in 0..x.nrows() {
58            let row = x.row(i);
59            let mean = if with_mean {
60                row.mean().unwrap_or(0.0)
61            } else {
62                0.0
63            };
64            let std = if with_std {
65                row.std(0.0).max(1e-8)
66            } else {
67                1.0
68            };
69
70            for j in 0..x.ncols() {
71                result[[i, j]] = (x[[i, j]] - mean) / std;
72            }
73        }
74    } else {
75        return Err(SklearsError::InvalidInput(format!(
76            "axis must be 0 or 1, got {axis}"
77        )));
78    }
79
80    Ok(result)
81}
82
83/// Scale samples individually to unit norm
84///
85/// Each sample (i.e. each row of the data matrix) with at least one
86/// non-zero component is rescaled independently of other samples so
87/// that its norm equals one.
88///
89/// # Arguments
90/// * `x` - The data to normalize
91/// * `norm` - The norm to use ('l1', 'l2', or 'max')
92/// * `axis` - Axis along which to normalize (1 for samples, 0 for features)
93///
94/// # Returns
95/// Normalized data
96pub fn normalize(x: &Array2<Float>, norm: NormType, axis: usize) -> Result<Array2<Float>> {
97    if axis == 1 {
98        // Normalize along samples (standard behavior)
99        let normalizer = Normalizer::new().norm(norm);
100        normalizer.transform(x)
101    } else if axis == 0 {
102        // Normalize along features (transpose, normalize, transpose back)
103        let x_t = x.t().to_owned();
104        let normalizer = Normalizer::new().norm(norm);
105        let normalized = normalizer.transform(&x_t)?;
106        Ok(normalized.t().to_owned())
107    } else {
108        Err(SklearsError::InvalidInput(format!(
109            "axis must be 0 or 1, got {axis}"
110        )))
111    }
112}
113
114/// Boolean thresholding of array-like or scipy.sparse matrix
115///
116/// # Arguments
117/// * `x` - The data to binarize
118/// * `threshold` - Feature values below or equal to this are replaced by 0, above it by 1
119///
120/// # Returns
121/// Binarized data
122pub fn binarize(x: &Array2<Float>, threshold: Float) -> Result<Array2<Float>> {
123    let binarizer = Binarizer::new().threshold(threshold);
124    let fitted = binarizer.fit(x, &())?;
125    fitted.transform(x)
126}
127
128/// Scale each feature by its maximum absolute value
129///
130/// # Arguments
131/// * `x` - The data to scale
132/// * `axis` - Axis along which to scale (0 for features)
133///
134/// # Returns
135/// Scaled data
136pub fn maxabs_scale(x: &Array2<Float>, axis: usize) -> Result<Array2<Float>> {
137    if axis != 0 {
138        return Err(SklearsError::InvalidInput(
139            "maxabs_scale only supports axis=0".to_string(),
140        ));
141    }
142
143    let mut result = x.clone();
144
145    for j in 0..x.ncols() {
146        let column = x.column(j);
147        let max_abs = column.iter().map(|&v| v.abs()).fold(0.0, Float::max);
148
149        if max_abs > 1e-8 {
150            for i in 0..x.nrows() {
151                result[[i, j]] = x[[i, j]] / max_abs;
152            }
153        }
154    }
155
156    Ok(result)
157}
158
159/// Transform features to range [0, 1]
160///
161/// # Arguments
162/// * `x` - The data to scale
163/// * `feature_range` - Desired range of transformed data
164/// * `axis` - Axis along which to scale (0 for features)
165///
166/// # Returns
167/// Scaled data
168pub fn minmax_scale(
169    x: &Array2<Float>,
170    feature_range: (Float, Float),
171    axis: usize,
172) -> Result<Array2<Float>> {
173    if axis != 0 {
174        return Err(SklearsError::InvalidInput(
175            "minmax_scale only supports axis=0".to_string(),
176        ));
177    }
178
179    let mut result = x.clone();
180    let (min_range, max_range) = feature_range;
181
182    for j in 0..x.ncols() {
183        let column = x.column(j);
184        let min_val = column.iter().fold(Float::INFINITY, |a, &b| a.min(b));
185        let max_val = column.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
186        let range = max_val - min_val;
187
188        if range > 1e-8 {
189            for i in 0..x.nrows() {
190                let normalized = (x[[i, j]] - min_val) / range;
191                result[[i, j]] = normalized * (max_range - min_range) + min_range;
192            }
193        } else {
194            // If range is zero, set all values to midpoint of desired range
195            let midpoint = (min_range + max_range) / 2.0;
196            for i in 0..x.nrows() {
197                result[[i, j]] = midpoint;
198            }
199        }
200    }
201
202    Ok(result)
203}
204
205/// Scale features using statistics that are robust to outliers
206///
207/// # Arguments
208/// * `x` - The data to scale
209/// * `axis` - Axis along which to scale (0 for features)
210/// * `with_centering` - If true, center the data before scaling
211/// * `with_scaling` - If true, scale the data to interquartile range
212/// * `quantile_range` - Quantile range used to calculate scale
213///
214/// # Returns
215/// Scaled data
216pub fn robust_scale(
217    x: &Array2<Float>,
218    axis: usize,
219    with_centering: bool,
220    with_scaling: bool,
221    quantile_range: (Float, Float),
222) -> Result<Array2<Float>> {
223    if axis != 0 {
224        return Err(SklearsError::InvalidInput(
225            "robust_scale only supports axis=0".to_string(),
226        ));
227    }
228
229    let mut result = x.clone();
230
231    for j in 0..x.ncols() {
232        let mut column: Vec<Float> = x.column(j).to_vec();
233        column.sort_by(|a, b| a.partial_cmp(b).unwrap());
234
235        let n = column.len();
236        let q1_idx = ((n as Float) * quantile_range.0) as usize;
237        let q3_idx = ((n as Float) * quantile_range.1) as usize;
238
239        let q1 = column[q1_idx.min(n - 1)];
240        let q3 = column[q3_idx.min(n - 1)];
241        let median = column[n / 2];
242
243        let center = if with_centering { median } else { 0.0 };
244        let scale = if with_scaling && (q3 - q1) > 1e-8 {
245            q3 - q1
246        } else {
247            1.0
248        };
249
250        for i in 0..x.nrows() {
251            result[[i, j]] = (x[[i, j]] - center) / scale;
252        }
253    }
254
255    Ok(result)
256}
257
258// FIXME: Commenting out complex transformations until proper implementations are available
259
260// /// Transform features to uniform or normal distribution
261// ///
262// /// # Arguments
263// /// * `x` - The data to transform
264// /// * `n_quantiles` - Number of quantiles to estimate
265// /// * `output_distribution` - Marginal distribution for transformed data
266// /// * `subsample` - Maximum number of samples to use for quantile estimation
267// ///
268// /// # Returns
269// /// Transformed data
270// pub fn quantile_transform(
271//     x: &Array2<Float>,
272//     n_quantiles: usize,
273//     output_distribution: QuantileOutput,
274//     subsample: Option<usize>,
275// ) -> Result<Array2<Float>> {
276//     let transformer = QuantileTransformer::new()
277//         .n_quantiles(n_quantiles)
278//         .output_distribution(output_distribution)
279//         .subsample(subsample);
280//     let fitted = transformer.fit(x, &())?;
281//     fitted.transform(x)
282// }
283
284// /// Apply a power transform to make data more Gaussian-like
285// ///
286// /// # Arguments
287// /// * `x` - The data to transform
288// /// * `method` - The power transform method ('yeo-johnson' or 'box-cox')
289// /// * `standardize` - Apply zero-mean, unit-variance normalization
290// ///
291// /// # Returns
292// /// Transformed data
293// pub fn power_transform(
294//     x: &Array2<Float>,
295//     method: PowerMethod,
296//     standardize: bool,
297// ) -> Result<Array2<Float>> {
298//     let transformer = PowerTransformer::new()
299//         .method(method)
300//         .standardize(standardize);
301//     let fitted = transformer.fit(x, &())?;
302//     fitted.transform(x)
303// }
304
305/// Add a dummy feature to the data
306///
307/// This is useful for fitting an intercept term with implementations which
308/// cannot otherwise fit it directly.
309///
310/// # Arguments
311/// * `x` - The data to add dummy feature to
312/// * `value` - Value of the dummy feature
313///
314/// # Returns
315/// Data with dummy feature added as first column
316pub fn add_dummy_feature(x: &Array2<Float>, value: Float) -> Result<Array2<Float>> {
317    let n_samples = x.nrows();
318    let n_features = x.ncols();
319
320    // Create new array with extra column
321    let mut x_with_dummy = Array2::zeros((n_samples, n_features + 1));
322
323    // Set dummy feature values
324    x_with_dummy.column_mut(0).fill(value);
325
326    // Copy original features
327    x_with_dummy
328        .slice_mut(scirs2_core::ndarray::s![.., 1..])
329        .assign(x);
330
331    Ok(x_with_dummy)
332}
333
334/// Binarize labels in a one-vs-all fashion
335///
336/// # Arguments
337/// * `y` - Target values (labels)
338/// * `neg_label` - Value for negative labels
339/// * `pos_label` - Value for positive labels
340///
341/// # Returns
342/// Binarized labels
343pub fn label_binarize<T>(y: &Array1<T>, neg_label: i32, pos_label: i32) -> Result<Array2<Float>>
344where
345    T: std::hash::Hash + Eq + Clone + std::fmt::Debug + Ord + Send + Sync,
346{
347    let binarizer = LabelBinarizer::<T>::new()
348        .neg_label(neg_label)
349        .pos_label(pos_label);
350    let fitted = binarizer.fit(y, &())?;
351    fitted.transform(y)
352}
353
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{arr1, arr2};

    // `scale` with axis=0, with_mean and with_std should leave every feature
    // with mean 0 and (population, ddof = 0) standard deviation 1.
    #[test]
    fn test_scale() {
        let x = arr2(&[[0.0, 0.0], [0.0, 0.0], [1.0, 1.0], [1.0, 1.0]]);

        // Scale along features
        let scaled = scale(&x, 0, true, true).unwrap();

        // Check mean is 0
        for j in 0..x.ncols() {
            let col_mean = scaled.column(j).mean().unwrap();
            assert_abs_diff_eq!(col_mean, 0.0, epsilon = 1e-10);
        }

        // Check std is 1
        for j in 0..x.ncols() {
            let col = scaled.column(j);
            let std = col.std(0.0);
            assert_abs_diff_eq!(std, 1.0, epsilon = 1e-10);
        }
    }

    // After L2 normalization along axis=1, every row should be a unit vector.
    #[test]
    fn test_normalize() {
        let x = arr2(&[[4.0, 3.0], [1.0, 2.0]]);

        // L2 normalize along samples
        let normalized = normalize(&x, NormType::L2, 1).unwrap();

        // Check L2 norm is 1 for each row
        for i in 0..x.nrows() {
            let row = normalized.row(i);
            let norm = row.dot(&row).sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }

    // `binarize` maps values <= threshold to 0 and values > threshold to 1;
    // here the threshold (2.0) splits the four inputs two-and-two.
    #[test]
    fn test_binarize() {
        let x = arr2(&[[0.5, 1.5], [2.5, 3.5]]);

        let binarized = binarize(&x, 2.0).unwrap();

        assert_eq!(binarized[[0, 0]], 0.0);
        assert_eq!(binarized[[0, 1]], 0.0);
        assert_eq!(binarized[[1, 0]], 1.0);
        assert_eq!(binarized[[1, 1]], 1.0);
    }

    // `minmax_scale` to (0, 1) should map each feature's minimum to 0 and
    // maximum to 1.
    #[test]
    fn test_minmax_scale() {
        let x = arr2(&[[0.0, 0.0], [1.0, 2.0], [2.0, 4.0]]);

        let scaled = minmax_scale(&x, (0.0, 1.0), 0).unwrap();

        // Check min is 0 and max is 1 for each feature
        for j in 0..x.ncols() {
            let col = scaled.column(j);
            let min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
            let max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
            assert_abs_diff_eq!(min, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(max, 1.0, epsilon = 1e-10);
        }
    }

    // `add_dummy_feature` prepends a constant column, so a 2x2 input becomes
    // 2x3 with the dummy value in column 0 and the original data shifted right.
    #[test]
    fn test_add_dummy_feature() {
        let x = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let x_with_dummy = add_dummy_feature(&x, 1.0).unwrap();

        assert_eq!(x_with_dummy.shape(), &[2, 3]);
        assert_eq!(x_with_dummy[[0, 0]], 1.0);
        assert_eq!(x_with_dummy[[1, 0]], 1.0);
        assert_eq!(x_with_dummy[[0, 1]], 1.0);
        assert_eq!(x_with_dummy[[0, 2]], 2.0);
    }

    // `label_binarize` on 3 distinct classes yields a one-hot matrix of shape
    // (n_samples, n_classes); spot-check the hot positions for the first
    // occurrence of each class.
    #[test]
    fn test_label_binarize() {
        let y = arr1(&[0, 1, 2, 1, 0]);

        let binarized = label_binarize(&y, 0, 1).unwrap();

        // Should have shape (n_samples, n_classes)
        assert_eq!(binarized.shape(), &[5, 3]);

        // Check one-hot encoding
        assert_eq!(binarized[[0, 0]], 1.0); // class 0
        assert_eq!(binarized[[1, 1]], 1.0); // class 1
        assert_eq!(binarized[[2, 2]], 1.0); // class 2
    }
}