scirs2_cluster/preprocess/mod.rs

//! Data preprocessing utilities for clustering algorithms
//!
//! This module provides functions for preprocessing data before applying
//! clustering algorithms, such as:
//! - Whitening: Scaling features to have unit variance
//! - Normalization: Scaling data to a specific range or norm
//! - Standardization: Transforming data to have zero mean and unit variance
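//!
//! # Examples
//!
//! A minimal, illustrative sketch of standardizing data before clustering
//! (the clustering step itself is assumed and not shown here):
//!
//! ```
//! use scirs2_core::ndarray::Array2;
//! use scirs2_cluster::preprocess::standardize;
//!
//! let data = Array2::from_shape_vec((3, 2), vec![
//!     1.0, 10.0,
//!     2.0, 20.0,
//!     3.0, 30.0,
//! ]).unwrap();
//!
//! // Each feature now has zero mean and unit variance
//! let prepared = standardize(data.view(), true).unwrap();
//! assert_eq!(prepared.dim(), (3, 2));
//! ```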

use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{ClusteringError, Result};

/// Whiten a dataset by rescaling each feature to have unit variance
///
/// This is useful for preprocessing before applying certain clustering algorithms
/// like K-means. Each feature is divided by its standard deviation to give it
/// unit variance.
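///
/// The standard deviation is computed with the sample (n - 1) denominator, and
/// features with (near-)zero variance are returned unchanged.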
///
/// # Arguments
///
/// * `data` - Input data as a 2D array (n_samples × n_features)
/// * `check_finite` - Whether to check for NaN or infinite values
///
/// # Returns
///
/// * `Result<Array2<F>>` - The whitened data
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::preprocess::whiten;
///
/// // Example data with 3 features
/// let data = Array2::from_shape_vec((3, 3), vec![
///     1.9, 2.3, 1.7,
///     1.5, 2.5, 2.2,
///     0.8, 0.6, 1.7,
/// ]).unwrap();
///
/// // Whiten the data
/// let whitened = whiten(data.view(), true).unwrap();
/// ```
#[allow(dead_code)]
pub fn whiten<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    check_finite: bool,
) -> Result<Array2<F>> {
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(ClusteringError::InvalidInput("Input data is empty".into()));
    }

    if check_finite {
        // Check for NaN or infinite values
        for element in data.iter() {
            if !element.is_finite() {
                return Err(ClusteringError::InvalidInput(
                    "Input data contains NaN or infinite values".into(),
                ));
            }
        }
    }

    // Calculate the standard deviation for each feature
    let std_dev = standard_deviation(data, Axis(0))?;

    // Create the whitened data
    let mut result = Array2::zeros(data.dim());

    // Scale each feature by its standard deviation
    for j in 0..n_features {
        let std_j = std_dev[j];
        if std_j <= F::epsilon() {
            // If the standard deviation is close to zero, don't scale
            for i in 0..n_samples {
                result[[i, j]] = data[[i, j]];
            }
        } else {
            for i in 0..n_samples {
                result[[i, j]] = data[[i, j]] / std_j;
            }
        }
    }

    Ok(result)
}

/// Standardize a dataset by rescaling each feature to have zero mean and unit variance
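///
/// Features with (near-)zero variance are only mean-centered, since there is
/// no meaningful scale to divide by.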
///
/// # Arguments
///
/// * `data` - Input data as a 2D array (n_samples × n_features)
/// * `check_finite` - Whether to check for NaN or infinite values
///
/// # Returns
///
/// * `Result<Array2<F>>` - The standardized data
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::preprocess::standardize;
///
/// // Example data with 3 features
/// let data = Array2::from_shape_vec((3, 3), vec![
///     1.9, 2.3, 1.7,
///     1.5, 2.5, 2.2,
///     0.8, 0.6, 1.7,
/// ]).unwrap();
///
/// // Standardize the data
/// let standardized = standardize(data.view(), true).unwrap();
/// ```
#[allow(dead_code)]
pub fn standardize<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    check_finite: bool,
) -> Result<Array2<F>> {
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(ClusteringError::InvalidInput("Input data is empty".into()));
    }

    if check_finite {
        // Check for NaN or infinite values
        for element in data.iter() {
            if !element.is_finite() {
                return Err(ClusteringError::InvalidInput(
                    "Input data contains NaN or infinite values".into(),
                ));
            }
        }
    }

    // Calculate the mean for each feature
    let mean = data.mean_axis(Axis(0)).unwrap();

    // Calculate the standard deviation for each feature
    let std_dev = standard_deviation(data, Axis(0))?;

    // Create the standardized data
    let mut result = Array2::zeros(data.dim());

    // Scale each feature to zero mean and unit variance
    for j in 0..n_features {
        let mean_j = mean[j];
        let std_j = std_dev[j];

        if std_j <= F::epsilon() {
            // If the standard deviation is close to zero, just subtract the mean
            for i in 0..n_samples {
                result[[i, j]] = data[[i, j]] - mean_j;
            }
        } else {
            for i in 0..n_samples {
                result[[i, j]] = (data[[i, j]] - mean_j) / std_j;
            }
        }
    }

    Ok(result)
}

/// Normalize each sample to a given norm
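///
/// Samples whose norm is (near) zero are returned unchanged.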
///
/// # Arguments
///
/// * `data` - Input data as a 2D array (n_samples × n_features)
/// * `norm` - Type of normalization: L1, L2, or Max
/// * `check_finite` - Whether to check for NaN or infinite values
///
/// # Returns
///
/// * `Result<Array2<F>>` - The normalized data
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::preprocess::{normalize, NormType};
///
/// // Example data with 3 features
/// let data = Array2::from_shape_vec((3, 3), vec![
///     1.9, 2.3, 1.7,
///     1.5, 2.5, 2.2,
///     0.8, 0.6, 1.7,
/// ]).unwrap();
///
/// // Normalize the data using L2 norm
/// let normalized = normalize(data.view(), NormType::L2, true).unwrap();
/// ```
#[allow(dead_code)]
pub fn normalize<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    norm: NormType,
    check_finite: bool,
) -> Result<Array2<F>> {
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(ClusteringError::InvalidInput("Input data is empty".into()));
    }

    if check_finite {
        // Check for NaN or infinite values
        for element in data.iter() {
            if !element.is_finite() {
                return Err(ClusteringError::InvalidInput(
                    "Input data contains NaN or infinite values".into(),
                ));
            }
        }
    }

    // Calculate the norm for each sample
    let norms = match norm {
        NormType::L1 => {
            // L1 norm (sum of absolute values)
            let mut norms = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let row = data.row(i);
                let row_norm = row.iter().fold(F::zero(), |acc, &x| acc + x.abs());
                norms[i] = row_norm;
            }
            norms
        }
        NormType::L2 => {
            // L2 norm (square root of sum of squares)
            let mut norms = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let row = data.row(i);
                let row_norm = row.iter().fold(F::zero(), |acc, &x| acc + x * x).sqrt();
                norms[i] = row_norm;
            }
            norms
        }
        NormType::Max => {
            // Max norm (maximum absolute value)
            let mut norms = Array1::zeros(n_samples);
            for i in 0..n_samples {
                let row = data.row(i);
                let row_norm = row.iter().fold(F::zero(), |acc, &x| acc.max(x.abs()));
                norms[i] = row_norm;
            }
            norms
        }
    };

    // Create the normalized data
    let mut result = Array2::zeros(data.dim());

    // Scale each sample by its norm
    for i in 0..n_samples {
        let norm_i = norms[i];
        if norm_i <= F::epsilon() {
            // If the norm is close to zero, don't normalize
            for j in 0..n_features {
                result[[i, j]] = data[[i, j]];
            }
        } else {
            for j in 0..n_features {
                result[[i, j]] = data[[i, j]] / norm_i;
            }
        }
    }

    Ok(result)
}

/// Scale each feature to a specified range (min-max scaling)
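///
/// Features with no variation are mapped to the midpoint of `feature_range`.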
///
/// # Arguments
///
/// * `data` - Input data as a 2D array (n_samples × n_features)
/// * `feature_range` - Tuple of (min, max) values to scale to
/// * `check_finite` - Whether to check for NaN or infinite values
///
/// # Returns
///
/// * `Result<Array2<F>>` - The scaled data
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use scirs2_cluster::preprocess::min_max_scale;
///
/// // Example data with 3 features
/// let data = Array2::from_shape_vec((3, 3), vec![
///     1.9, 2.3, 1.7,
///     1.5, 2.5, 2.2,
///     0.8, 0.6, 1.7,
/// ]).unwrap();
///
/// // Scale the data to the range [0, 1]
/// let scaled = min_max_scale(data.view(), (0.0, 1.0), true).unwrap();
/// ```
#[allow(dead_code)]
pub fn min_max_scale<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    feature_range: (f64, f64),
    check_finite: bool,
) -> Result<Array2<F>> {
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(ClusteringError::InvalidInput("Input data is empty".into()));
    }

    if check_finite {
        // Check for NaN or infinite values
        for element in data.iter() {
            if !element.is_finite() {
                return Err(ClusteringError::InvalidInput(
                    "Input data contains NaN or infinite values".into(),
                ));
            }
        }
    }

    let (min_val, max_val) = feature_range;
    if min_val >= max_val {
        return Err(ClusteringError::InvalidInput(
            "Feature range minimum must be less than maximum".into(),
        ));
    }

    let feature_min = F::from_f64(min_val).unwrap();
    let feature_max = F::from_f64(max_val).unwrap();

    // Calculate the min and max for each feature
    let mut min_values = Array1::zeros(n_features);
    let mut max_values = Array1::zeros(n_features);

    for j in 0..n_features {
        let column = data.column(j);
        let (min_j, max_j) = column.iter().fold(
            (F::infinity(), F::neg_infinity()),
            |(min_val, max_val), &x| (min_val.min(x), max_val.max(x)),
        );
        min_values[j] = min_j;
        max_values[j] = max_j;
    }

    // Create the scaled data
    let mut result = Array2::zeros(data.dim());

    // Scale each feature to the specified range
    for j in 0..n_features {
        let min_j = min_values[j];
        let max_j = max_values[j];
        let range_j = max_j - min_j;

        if range_j <= F::epsilon() {
            // If the feature has no variation, set to the middle of the feature range
            let middle = (feature_min + feature_max) / F::from_f64(2.0).unwrap();
            for i in 0..n_samples {
                result[[i, j]] = middle;
            }
        } else {
            for i in 0..n_samples {
                // Scale to [0, 1]
                let scaled = (data[[i, j]] - min_j) / range_j;
                // Scale to [feature_min, feature_max]
                result[[i, j]] = scaled * (feature_max - feature_min) + feature_min;
            }
        }
    }

    Ok(result)
}

/// Normalization types for the normalize function
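///
/// As an illustrative worked example: for the sample `[3.0, -4.0]`, the L1
/// norm is `|3| + |-4| = 7`, the L2 norm is `sqrt(3*3 + 4*4) = 5`, and the
/// Max norm is `max(|3|, |-4|) = 4`.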
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// L1 norm (sum of absolute values)
    L1,
    /// L2 norm (square root of sum of squares)
    L2,
    /// Max norm (maximum absolute value)
    Max,
}

/// Calculate the standard deviation along the specified axis
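///
/// Uses the sample (n - 1) denominator. Standard deviations at or below
/// `F::epsilon()` are replaced with one so that callers can divide by the
/// result without special-casing constant features.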
#[allow(dead_code)]
fn standard_deviation<F: Float + FromPrimitive + Debug>(
    data: ArrayView2<F>,
    axis: Axis,
) -> Result<Array1<F>> {
    let mean = data.mean_axis(axis).unwrap();
    let n = F::from_usize(match axis {
        Axis(0) => data.shape()[0],
        Axis(1) => data.shape()[1],
        _ => return Err(ClusteringError::InvalidInput("Invalid axis".into())),
    })
    .unwrap();

    let mut variance = match axis {
        Axis(0) => Array1::zeros(data.shape()[1]),
        Axis(1) => Array1::zeros(data.shape()[0]),
        _ => return Err(ClusteringError::InvalidInput("Invalid axis".into())),
    };

    if axis == Axis(0) {
        // Variance of each feature, aggregating over the samples (rows)
        let n_features = data.shape()[1];
        for j in 0..n_features {
            let mut sum_squared_diff = F::zero();
            for i in 0..data.shape()[0] {
                let diff = data[[i, j]] - mean[j];
                sum_squared_diff = sum_squared_diff + diff * diff;
            }
            // Avoid division by zero for a single sample
            if n > F::one() {
                variance[j] = sum_squared_diff / (n - F::one());
            } else {
                variance[j] = F::zero();
            }
        }
    } else {
        // Variance of each sample, aggregating over the features (columns)
        let n_samples = data.shape()[0];
        for i in 0..n_samples {
            let mut sum_squared_diff = F::zero();
            for j in 0..data.shape()[1] {
                let diff = data[[i, j]] - mean[i];
                sum_squared_diff = sum_squared_diff + diff * diff;
            }
            // Avoid division by zero for a single feature
            if n > F::one() {
                variance[i] = sum_squared_diff / (n - F::one());
            } else {
                variance[i] = F::zero();
            }
        }
    }

    // Calculate standard deviation
    let std_dev = variance.mapv(|x| x.sqrt());

    // Replace (near-)zero values with ones so callers can divide safely
    let std_dev = std_dev.mapv(|x| if x <= F::epsilon() { F::one() } else { x });

    Ok(std_dev)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_whiten() {
        let data =
            Array2::from_shape_vec((3, 3), vec![1.9, 2.3, 1.7, 1.5, 2.5, 2.2, 0.8, 0.6, 1.7])
                .unwrap();

        let whitened = whiten(data.view(), true).unwrap();

        // Check that each feature has approximately unit variance
        let std_dev = standard_deviation(whitened.view(), Axis(0)).unwrap();
        for &std in std_dev.iter() {
            assert_abs_diff_eq!(std, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_standardize() {
        let data =
            Array2::from_shape_vec((3, 3), vec![1.9, 2.3, 1.7, 1.5, 2.5, 2.2, 0.8, 0.6, 1.7])
                .unwrap();

        let standardized = standardize(data.view(), true).unwrap();

        // Check that each feature has approximately zero mean
        let mean = standardized.mean_axis(Axis(0)).unwrap();
        for mean_val in mean.iter() {
            assert_abs_diff_eq!(*mean_val, 0.0, epsilon = 1e-10);
        }

        // Check that each feature has approximately unit variance
        let std_dev = standard_deviation(standardized.view(), Axis(0)).unwrap();
        for std_val in std_dev.iter() {
            assert_abs_diff_eq!(*std_val, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_normalize_l2() {
        let data =
            Array2::from_shape_vec((3, 3), vec![1.9, 2.3, 1.7, 1.5, 2.5, 2.2, 0.8, 0.6, 1.7])
                .unwrap();

        let normalized = normalize(data.view(), NormType::L2, true).unwrap();

        // Check that each sample has L2 norm approximately equal to 1
        for i in 0..data.shape()[0] {
            let row = normalized.row(i);
            let norm_sq: f64 = row.iter().fold(0.0, |acc, &x| acc + x * x);
            let norm = norm_sq.sqrt();
            assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-10);
        }
    }
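
    // Illustrative extra coverage (a sketch, not part of the original suite):
    // after L1 normalization each row's absolute values should sum to 1, and
    // after Max normalization each row's largest absolute value should be 1.
    #[test]
    fn test_normalize_l1_and_max() {
        let data =
            Array2::from_shape_vec((2, 3), vec![1.0, -2.0, 3.0, 4.0, 0.5, -0.5]).unwrap();

        let l1_normalized = normalize(data.view(), NormType::L1, true).unwrap();
        for i in 0..data.shape()[0] {
            let sum_abs: f64 = l1_normalized.row(i).iter().fold(0.0, |acc, &x| acc + x.abs());
            assert_abs_diff_eq!(sum_abs, 1.0, epsilon = 1e-10);
        }

        let max_normalized = normalize(data.view(), NormType::Max, true).unwrap();
        for i in 0..data.shape()[0] {
            let max_abs = max_normalized
                .row(i)
                .iter()
                .fold(0.0_f64, |acc, &x| acc.max(x.abs()));
            assert_abs_diff_eq!(max_abs, 1.0, epsilon = 1e-10);
        }
    }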

    #[test]
    fn test_min_max_scale() {
        let data =
            Array2::from_shape_vec((3, 3), vec![1.9, 2.3, 1.7, 1.5, 2.5, 2.2, 0.8, 0.6, 1.7])
                .unwrap();

        let scaled = min_max_scale(data.view(), (0.0, 1.0), true).unwrap();

        // Check that all values are in the range [0, 1]
        for val in scaled.iter() {
            assert!(*val >= 0.0 && *val <= 1.0);
        }

        // Check that for each feature, the minimum value is 0.0 and the maximum is 1.0
        for j in 0..data.shape()[1] {
            let column = scaled.column(j);

            // Collect the column values so we can compute their min and max
            let column_values: Vec<f64> = column.iter().copied().collect();

            if !column_values.is_empty() {
                let min_val = column_values
                    .iter()
                    .fold(f64::INFINITY, |min, &x| min.min(x));
                let max_val = column_values
                    .iter()
                    .fold(f64::NEG_INFINITY, |max, &x| max.max(x));

                // Only check if the feature had different values in the original data
                if data.column(j).iter().any(|&x| x != data[[0, j]]) {
                    assert_abs_diff_eq!(min_val, 0.0, epsilon = 1e-10);
                    assert_abs_diff_eq!(max_val, 1.0, epsilon = 1e-10);
                }
            }
        }
    }
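
    // Illustrative sketch (not part of the original suite): with
    // `check_finite = true`, non-finite input should be rejected.
    #[test]
    fn test_rejects_non_finite_input() {
        let data = Array2::from_shape_vec((2, 2), vec![1.0, f64::NAN, 3.0, 4.0]).unwrap();
        assert!(whiten(data.view(), true).is_err());
        assert!(standardize(data.view(), true).is_err());
    }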
}