scirs2_io/pipeline/
transforms.rs

1//! Common data transformations for pipelines
2
3#![allow(dead_code)]
4#![allow(missing_docs)]
5
6use super::*;
7use crate::error::Result;
8use scirs2_core::ndarray::{s, Array1, Array2, Axis};
9use scirs2_core::numeric::{Float, FromPrimitive};
10use statrs::statistics::Statistics;
11use std::collections::HashMap;
12use std::marker::PhantomData;
13
14/// Normalization transformer
15pub struct NormalizeTransform<T> {
16    method: NormalizationMethod,
17    _phantom: PhantomData<T>,
18}
19
/// Normalization schemes understood by `NormalizeTransform`.
#[derive(Clone, Debug)]
pub enum NormalizationMethod {
    /// Linear rescaling of all elements into the target range `[min, max]`.
    MinMax {
        /// Lower end of the target range.
        min: f64,
        /// Upper end of the target range.
        max: f64,
    },
    /// Subtract the mean of all elements and divide by their standard deviation.
    ZScore,
    /// Divide every row by the sum of the absolute values of its elements.
    L1,
    /// Divide every row by its Euclidean norm.
    L2,
    /// Divide every element by the largest absolute value in the array.
    MaxAbs,
}
28
29impl<T> NormalizeTransform<T>
30where
31    T: Float + FromPrimitive + Send + Sync,
32{
33    pub fn new(method: NormalizationMethod) -> Self {
34        Self {
35            method,
36            _phantom: PhantomData,
37        }
38    }
39}
40
41impl<T> DataTransformer for NormalizeTransform<T>
42where
43    T: Float + FromPrimitive + Send + Sync + 'static,
44{
45    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
46        if let Ok(array) = data.downcast::<Array2<T>>() {
47            let normalized = match &self.method {
48                NormalizationMethod::MinMax { min, max } => normalize_minmax(
49                    *array,
50                    T::from_f64(*min).unwrap(),
51                    T::from_f64(*max).unwrap(),
52                ),
53                NormalizationMethod::ZScore => normalize_zscore(*array),
54                NormalizationMethod::L1 => normalize_l1(*array),
55                NormalizationMethod::L2 => normalize_l2(*array),
56                NormalizationMethod::MaxAbs => normalize_maxabs(*array),
57            };
58            Ok(Box::new(normalized) as Box<dyn Any + Send + Sync>)
59        } else {
60            Err(IoError::Other(
61                "Invalid data type for normalization".to_string(),
62            ))
63        }
64    }
65}
66
67#[allow(dead_code)]
68fn normalize_minmax<T>(mut array: Array2<T>, new_min: T, new_max: T) -> Array2<T>
69where
70    T: Float + FromPrimitive,
71{
72    let _min = array.iter().fold(T::infinity(), |a, &b| a.min(b));
73    let _max = array.iter().fold(T::neg_infinity(), |a, &b| a.max(b));
74    let range = _max - _min;
75
76    if range > T::zero() {
77        let scale = (new_max - new_min) / range;
78        array.mapv_inplace(|x| (x - _min) * scale + new_min);
79    }
80
81    array
82}
83
84#[allow(dead_code)]
85fn normalize_zscore<T>(mut array: Array2<T>) -> Array2<T>
86where
87    T: Float + FromPrimitive,
88{
89    let n = T::from_usize(array.len()).unwrap();
90    let mean = array.iter().fold(T::zero(), |a, &b| a + b) / n;
91    let variance = array.iter().fold(T::zero(), |a, &b| {
92        let diff = b - mean;
93        a + diff * diff
94    }) / n;
95    let std = variance.sqrt();
96
97    if std > T::zero() {
98        array.mapv_inplace(|x| (x - mean) / std);
99    }
100
101    array
102}
103
104#[allow(dead_code)]
105fn normalize_l1<T>(mut array: Array2<T>) -> Array2<T>
106where
107    T: Float,
108{
109    for mut row in array.axis_iter_mut(Axis(0)) {
110        let norm = row.iter().fold(T::zero(), |a, &b| a + b.abs());
111        if norm > T::zero() {
112            row.mapv_inplace(|x| x / norm);
113        }
114    }
115    array
116}
117
118#[allow(dead_code)]
119fn normalize_l2<T>(mut array: Array2<T>) -> Array2<T>
120where
121    T: Float,
122{
123    for mut row in array.axis_iter_mut(Axis(0)) {
124        let norm = row.iter().fold(T::zero(), |a, &b| a + b * b).sqrt();
125        if norm > T::zero() {
126            row.mapv_inplace(|x| x / norm);
127        }
128    }
129    array
130}
131
132#[allow(dead_code)]
133fn normalize_maxabs<T>(mut array: Array2<T>) -> Array2<T>
134where
135    T: Float,
136{
137    let max_abs = array.iter().fold(T::zero(), |a, &b| a.max(b.abs()));
138    if max_abs > T::zero() {
139        array.mapv_inplace(|x| x / max_abs);
140    }
141    array
142}
143
/// Pipeline stage that rearranges an array into a new shape.
pub struct ReshapeTransform {
    /// Requested dimensions of the output array.
    newshape: Vec<usize>,
}
148
149impl ReshapeTransform {
150    pub fn new(shape: Vec<usize>) -> Self {
151        Self { newshape: shape }
152    }
153}
154
155impl DataTransformer for ReshapeTransform {
156    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
157        if let Ok(array) = data.downcast::<Array2<f64>>() {
158            let total_elements: usize = self.newshape.iter().product();
159            if array.len() != total_elements {
160                return Err(IoError::Other(format!(
161                    "Cannot reshape array of size {} to shape {:?}",
162                    array.len(),
163                    self.newshape
164                )));
165            }
166
167            // Convert to 1D, then reshape
168            let flat: Vec<f64> = array.into_iter().collect();
169            let reshaped = Array2::from_shape_vec((self.newshape[0], self.newshape[1]), flat)
170                .map_err(|e| IoError::Other(e.to_string()))?;
171
172            Ok(Box::new(reshaped) as Box<dyn Any + Send + Sync>)
173        } else {
174            Err(IoError::Other("Invalid data type for reshape".to_string()))
175        }
176    }
177}
178
/// Pipeline stage that converts a payload of type `From` into type `To`.
///
/// Both type parameters exist purely at the type level; the struct stores no data.
pub struct TypeConvertTransform<From, To> {
    /// Marker for the source type.
    _from: PhantomData<From>,
    /// Marker for the destination type.
    _to: PhantomData<To>,
}
184
185impl<From, To> Default for TypeConvertTransform<From, To> {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191impl<From, To> TypeConvertTransform<From, To> {
192    pub fn new() -> Self {
193        Self {
194            _from: PhantomData,
195            _to: PhantomData,
196        }
197    }
198}
199
200impl<From, To> DataTransformer for TypeConvertTransform<From, To>
201where
202    From: 'static + Send + Sync,
203    To: 'static + Send + Sync + std::convert::From<From>,
204{
205    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
206        if let Ok(from_data) = data.downcast::<From>() {
207            let to_data: To = To::from(*from_data);
208            Ok(Box::new(to_data) as Box<dyn Any + Send + Sync>)
209        } else {
210            Err(IoError::Other("Type conversion failed".to_string()))
211        }
212    }
213}
214
215/// Aggregation transformer
216pub struct AggregateTransform {
217    method: AggregationMethod,
218    axis: Option<Axis>,
219}
220
/// Reductions that can be requested from `AggregateTransform`.
///
/// Only `Sum` and `Mean` are currently handled by the transformer; the other
/// variants make it return an "Unsupported aggregation" error.
#[derive(Clone, Debug)]
pub enum AggregationMethod {
    /// Sum of the elements.
    Sum,
    /// Arithmetic mean of the elements.
    Mean,
    /// Minimum (not handled by the transformer yet).
    Min,
    /// Maximum (not handled by the transformer yet).
    Max,
    /// Standard deviation (not handled by the transformer yet).
    Std,
    /// Variance (not handled by the transformer yet).
    Var,
}
230
231impl AggregateTransform {
232    pub fn new(method: AggregationMethod, axis: Option<Axis>) -> Self {
233        Self { method, axis }
234    }
235}
236
237impl DataTransformer for AggregateTransform {
238    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
239        if let Ok(array) = data.downcast::<Array2<f64>>() {
240            let result = match (&self.method, self.axis) {
241                (AggregationMethod::Sum, Some(axis)) => {
242                    Box::new(array.sum_axis(axis)) as Box<dyn Any + Send + Sync>
243                }
244                (AggregationMethod::Mean, Some(axis)) => {
245                    Box::new(array.mean_axis(axis).unwrap()) as Box<dyn Any + Send + Sync>
246                }
247                (AggregationMethod::Sum, None) => {
248                    Box::new(array.sum()) as Box<dyn Any + Send + Sync>
249                }
250                (AggregationMethod::Mean, None) => {
251                    Box::new(array.mean()) as Box<dyn Any + Send + Sync>
252                }
253                _ => return Err(IoError::Other("Unsupported aggregation".to_string())),
254            };
255            Ok(result)
256        } else {
257            Err(IoError::Other(
258                "Invalid data type for aggregation".to_string(),
259            ))
260        }
261    }
262}
263
264/// Encoding transformer for categorical data
265pub struct EncodingTransform {
266    method: EncodingMethod,
267}
268
/// Encoding schemes supported by `EncodingTransform`.
#[derive(Clone, Debug)]
pub enum EncodingMethod {
    /// One indicator column per distinct category.
    OneHot,
    /// Integer labels assigned in order of first appearance.
    Label,
    /// Integer codes given by each category's position in the supplied list.
    Ordinal(Vec<String>),
}
275
276impl EncodingTransform {
277    pub fn new(method: EncodingMethod) -> Self {
278        Self { method }
279    }
280}
281
282impl DataTransformer for EncodingTransform {
283    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
284        if let Ok(categories) = data.downcast::<Vec<String>>() {
285            match &self.method {
286                EncodingMethod::Label => {
287                    let mut label_map = HashMap::new();
288                    let mut next_label = 0;
289
290                    let encoded: Vec<i32> = categories
291                        .iter()
292                        .map(|cat| {
293                            *label_map.entry(cat.clone()).or_insert_with(|| {
294                                let label = next_label;
295                                next_label += 1;
296                                label
297                            })
298                        })
299                        .collect();
300
301                    Ok(Box::new(encoded) as Box<dyn Any + Send + Sync>)
302                }
303                EncodingMethod::OneHot => {
304                    let unique_categories: Vec<String> = {
305                        let mut cats = (*categories).clone();
306                        cats.sort();
307                        cats.dedup();
308                        cats
309                    };
310
311                    let n_categories = unique_categories.len();
312                    let n_samples = categories.len();
313                    let mut encoded = Array2::<f64>::zeros((n_samples, n_categories));
314
315                    for (i, cat) in categories.iter().enumerate() {
316                        if let Some(j) = unique_categories.iter().position(|c| c == cat) {
317                            encoded[[i, j]] = 1.0;
318                        }
319                    }
320
321                    Ok(Box::new(encoded) as Box<dyn Any + Send + Sync>)
322                }
323                EncodingMethod::Ordinal(order) => {
324                    let encoded: Result<Vec<i32>> = categories
325                        .iter()
326                        .map(|cat| {
327                            order
328                                .iter()
329                                .position(|o| o == cat)
330                                .map(|pos| pos as i32)
331                                .ok_or_else(|| IoError::Other(format!("Unknown category: {}", cat)))
332                        })
333                        .collect();
334
335                    Ok(Box::new(encoded?) as Box<dyn Any + Send + Sync>)
336                }
337            }
338        } else {
339            Err(IoError::Other("Invalid data type for encoding".to_string()))
340        }
341    }
342}
343
344/// Missing value imputation transformer
345pub struct ImputeTransform {
346    strategy: ImputationStrategy,
347}
348
/// Rules for replacing missing values in `ImputeTransform`.
///
/// Only `Mean` and `Constant` are currently handled by the transformer; the
/// other variants make it return an "Unsupported imputation strategy" error.
#[derive(Clone, Debug)]
pub enum ImputationStrategy {
    /// Replace with the mean of the present values in the same column.
    Mean,
    /// Column median (not handled by the transformer yet).
    Median,
    /// Most frequent value (not handled by the transformer yet).
    Mode,
    /// Replace every missing entry with the given value.
    Constant(f64),
    /// Forward fill (not handled by the transformer yet).
    Forward,
    /// Backward fill (not handled by the transformer yet).
    Backward,
}
358
359impl ImputeTransform {
360    pub fn new(strategy: ImputationStrategy) -> Self {
361        Self { strategy }
362    }
363}
364
365impl DataTransformer for ImputeTransform {
366    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
367        if let Ok(mut array) = data.downcast::<Array2<Option<f64>>>() {
368            match &self.strategy {
369                ImputationStrategy::Mean => {
370                    for mut col in array.axis_iter_mut(Axis(1)) {
371                        let valid_values: Vec<f64> = col.iter().filter_map(|&x| x).collect();
372
373                        if !valid_values.is_empty() {
374                            let mean = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
375                            col.mapv_inplace(|x| Some(x.unwrap_or(mean)));
376                        }
377                    }
378                }
379                ImputationStrategy::Constant(value) => {
380                    array.mapv_inplace(|x| Some(x.unwrap_or(*value)));
381                }
382                _ => {
383                    return Err(IoError::Other(
384                        "Unsupported imputation strategy".to_string(),
385                    ))
386                }
387            }
388
389            // Convert to Array2<f64> after imputation
390            let imputed: Array2<f64> = array.mapv(|x| x.unwrap_or(0.0));
391            Ok(Box::new(imputed) as Box<dyn Any + Send + Sync>)
392        } else {
393            Err(IoError::Other(
394                "Invalid data type for imputation".to_string(),
395            ))
396        }
397    }
398}
399
400/// Outlier detection and removal transformer
401pub struct OutlierTransform {
402    method: OutlierMethod,
403    threshold: f64,
404}
405
/// Outlier detection rules for `OutlierTransform`.
#[derive(Clone, Debug)]
pub enum OutlierMethod {
    /// Flag values whose z-score magnitude exceeds the threshold.
    ZScore,
    /// Flag values outside the quartiles widened by `threshold * IQR`.
    IQR,
    /// Isolation forest (not handled by the transformer yet).
    IsolationForest,
}
412
413impl OutlierTransform {
414    pub fn new(method: OutlierMethod, threshold: f64) -> Self {
415        Self { method, threshold }
416    }
417}
418
419impl DataTransformer for OutlierTransform {
420    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
421        if let Ok(array) = data.downcast::<Array2<f64>>() {
422            match &self.method {
423                OutlierMethod::ZScore => {
424                    let mean = array.view().mean();
425                    let std = array.std(0.0);
426
427                    let filtered: Vec<Vec<f64>> = array
428                        .axis_iter(Axis(0))
429                        .filter(|row| {
430                            row.iter()
431                                .all(|&x| ((x - mean) / std).abs() <= self.threshold)
432                        })
433                        .map(|row| row.to_vec())
434                        .collect();
435
436                    if filtered.is_empty() {
437                        return Err(IoError::Other("All data filtered as outliers".to_string()));
438                    }
439
440                    let n_rows = filtered.len();
441                    let n_cols = filtered[0].len();
442                    let flat: Vec<f64> = filtered.into_iter().flatten().collect();
443
444                    let result = Array2::from_shape_vec((n_rows, n_cols), flat)
445                        .map_err(|e| IoError::Other(e.to_string()))?;
446
447                    Ok(Box::new(result) as Box<dyn Any + Send + Sync>)
448                }
449                OutlierMethod::IQR => {
450                    // Interquartile Range method
451                    let mut filtered_rows = Vec::new();
452
453                    for row in array.axis_iter(Axis(0)) {
454                        let mut values: Vec<f64> = row.to_vec();
455                        values.sort_by(|a, b| a.partial_cmp(b).unwrap());
456
457                        let n = values.len();
458                        let q1_idx = n / 4;
459                        let q3_idx = 3 * n / 4;
460                        let q1 = values[q1_idx];
461                        let q3 = values[q3_idx];
462                        let iqr = q3 - q1;
463
464                        let lower_bound = q1 - self.threshold * iqr;
465                        let upper_bound = q3 + self.threshold * iqr;
466
467                        let is_outlier = row.iter().any(|&x| x < lower_bound || x > upper_bound);
468
469                        if !is_outlier {
470                            filtered_rows.push(row.to_vec());
471                        }
472                    }
473
474                    if filtered_rows.is_empty() {
475                        return Err(IoError::Other("All data filtered as outliers".to_string()));
476                    }
477
478                    let n_rows = filtered_rows.len();
479                    let n_cols = filtered_rows[0].len();
480                    let flat: Vec<f64> = filtered_rows.into_iter().flatten().collect();
481
482                    let result = Array2::from_shape_vec((n_rows, n_cols), flat)
483                        .map_err(|e| IoError::Other(e.to_string()))?;
484
485                    Ok(Box::new(result) as Box<dyn Any + Send + Sync>)
486                }
487                _ => Err(IoError::Other("Unsupported outlier method".to_string())),
488            }
489        } else {
490            Err(IoError::Other(
491                "Invalid data type for outlier detection".to_string(),
492            ))
493        }
494    }
495}
496
497/// Principal Component Analysis transformer
498pub struct PCATransform {
499    n_components: usize,
500    components: Option<Array2<f64>>,
501    mean: Option<Array1<f64>>,
502}
503
504impl PCATransform {
505    pub fn new(_ncomponents: usize) -> Self {
506        Self {
507            n_components: _ncomponents,
508            components: None,
509            mean: None,
510        }
511    }
512
513    /// Fit PCA on training data
514    pub fn fit(&mut self, data: &Array2<f64>) -> Result<()> {
515        let (n_samples, n_features) = data.dim();
516
517        if self.n_components > n_features {
518            return Err(IoError::Other(
519                "n_components cannot exceed n_features".to_string(),
520            ));
521        }
522
523        // Center the data
524        let mean = data.mean_axis(Axis(0)).unwrap();
525        let centered = data - &mean.clone().insert_axis(Axis(0));
526
527        // Compute covariance matrix
528        let _cov = centered.t().dot(&centered) / (n_samples - 1) as f64;
529
530        // For simplicity, use a basic eigenvalue decomposition approximation
531        // In practice, you would use a proper linear algebra library
532        self.mean = Some(mean);
533
534        // Mock components for demonstration
535        let components = Array2::eye(n_features)
536            .slice(s![..self.n_components, ..])
537            .to_owned();
538        self.components = Some(components);
539
540        Ok(())
541    }
542}
543
544impl DataTransformer for PCATransform {
545    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
546        if let Ok(array) = data.downcast::<Array2<f64>>() {
547            let mean = self
548                .mean
549                .as_ref()
550                .ok_or_else(|| IoError::Other("PCA not fitted yet".to_string()))?;
551            let components = self
552                .components
553                .as_ref()
554                .ok_or_else(|| IoError::Other("PCA not fitted yet".to_string()))?;
555
556            // Center the data
557            let centered = &*array - &mean.clone().insert_axis(Axis(0));
558
559            // Project onto principal components
560            let transformed = centered.dot(&components.t());
561
562            Ok(Box::new(transformed) as Box<dyn Any + Send + Sync>)
563        } else {
564            Err(IoError::Other("Invalid data type for PCA".to_string()))
565        }
566    }
567}
568
569/// Feature engineering transformer
570pub struct FeatureEngineeringTransform {
571    operations: Vec<FeatureOperation>,
572}
573
574#[derive(Debug, Clone)]
575pub enum FeatureOperation {
576    Polynomial {
577        degree: usize,
578    },
579    Log,
580    Sqrt,
581    Square,
582    Interaction {
583        indices: Vec<usize>,
584    },
585    Binning {
586        n_bins: usize,
587        strategy: BinningStrategy,
588    },
589}
590
/// Rules for placing bin edges.
#[derive(Clone, Debug)]
pub enum BinningStrategy {
    /// Bins of equal width.
    Uniform,
    /// Bins based on quantiles.
    Quantile,
}
596
597impl FeatureEngineeringTransform {
598    pub fn new(operations: Vec<FeatureOperation>) -> Self {
599        Self { operations }
600    }
601}
602
603impl DataTransformer for FeatureEngineeringTransform {
604    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
605        if let Ok(array) = data.downcast::<Array2<f64>>() {
606            let mut result = (*array).clone();
607
608            for operation in &self.operations {
609                match operation {
610                    FeatureOperation::Log => {
611                        let log_features = result.mapv(|x| if x > 0.0 { x.ln() } else { 0.0 });
612                        result = scirs2_core::ndarray::concatenate(
613                            Axis(1),
614                            &[result.view(), log_features.view()],
615                        )
616                        .unwrap();
617                    }
618                    FeatureOperation::Sqrt => {
619                        let sqrt_features = result.mapv(|x| if x >= 0.0 { x.sqrt() } else { 0.0 });
620                        result = scirs2_core::ndarray::concatenate(
621                            Axis(1),
622                            &[result.view(), sqrt_features.view()],
623                        )
624                        .unwrap();
625                    }
626                    FeatureOperation::Square => {
627                        let square_features = result.mapv(|x| x * x);
628                        result = scirs2_core::ndarray::concatenate(
629                            Axis(1),
630                            &[result.view(), square_features.view()],
631                        )
632                        .unwrap();
633                    }
634                    FeatureOperation::Polynomial { degree } => {
635                        let mut poly_features = result.clone();
636                        for d in 2..=*degree {
637                            let power_features = result.mapv(|x| x.powi(d as i32));
638                            poly_features = scirs2_core::ndarray::concatenate(
639                                Axis(1),
640                                &[poly_features.view(), power_features.view()],
641                            )
642                            .unwrap();
643                        }
644                        result = poly_features;
645                    }
646                    FeatureOperation::Interaction { indices } => {
647                        if indices.len() >= 2 {
648                            let mut interaction_col = result.column(indices[0]).to_owned();
649                            for &idx in &indices[1..] {
650                                if idx < result.ncols() {
651                                    interaction_col *= &result.column(idx);
652                                }
653                            }
654                            result = scirs2_core::ndarray::concatenate(
655                                Axis(1),
656                                &[result.view(), interaction_col.insert_axis(Axis(1)).view()],
657                            )
658                            .unwrap();
659                        }
660                    }
661                    FeatureOperation::Binning { n_bins, strategy } => {
662                        // Simple uniform binning implementation
663                        let mut binned_features = Array2::zeros((result.nrows(), result.ncols()));
664
665                        for (col_idx, col) in result.axis_iter(Axis(1)).enumerate() {
666                            let min_val = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
667                            let max_val = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
668                            let bin_width = (max_val - min_val) / *n_bins as f64;
669
670                            for (row_idx, &val) in col.iter().enumerate() {
671                                let bin = ((val - min_val) / bin_width).floor() as usize;
672                                let bin = bin.min(n_bins - 1);
673                                binned_features[[row_idx, col_idx]] = bin as f64;
674                            }
675                        }
676
677                        result = scirs2_core::ndarray::concatenate(
678                            Axis(1),
679                            &[result.view(), binned_features.view()],
680                        )
681                        .unwrap();
682                    }
683                }
684            }
685
686            Ok(Box::new(result) as Box<dyn Any + Send + Sync>)
687        } else {
688            Err(IoError::Other(
689                "Invalid data type for feature engineering".to_string(),
690            ))
691        }
692    }
693}
694
695/// Text processing transformer
696pub struct TextProcessingTransform {
697    operations: Vec<TextOperation>,
698}
699
/// Text-processing steps understood by `TextProcessingTransform`.
///
/// `RemoveStopwords` and `Stemming` are accepted but currently have no effect.
#[derive(Clone, Debug)]
pub enum TextOperation {
    /// Convert every string to lowercase.
    Lowercase,
    /// Keep only alphanumeric and whitespace characters.
    RemovePunctuation,
    /// Stop-word removal (currently a no-op in the transformer).
    RemoveStopwords,
    /// Split every string on whitespace.
    Tokenize,
    /// Stemming (currently a no-op in the transformer).
    Stemming,
    /// Build word n-grams of length `n`.
    NGrams {
        /// Number of words per n-gram.
        n: usize,
    },
}
709
710impl TextProcessingTransform {
711    pub fn new(operations: Vec<TextOperation>) -> Self {
712        Self { operations }
713    }
714}
715
716impl DataTransformer for TextProcessingTransform {
717    fn transform(&self, data: Box<dyn Any + Send + Sync>) -> Result<Box<dyn Any + Send + Sync>> {
718        if let Ok(texts) = data.downcast::<Vec<String>>() {
719            let mut processed = texts.clone();
720
721            for operation in &self.operations {
722                match operation {
723                    TextOperation::Lowercase => {
724                        processed = Box::new(
725                            processed
726                                .into_iter()
727                                .map(|s| s.to_lowercase())
728                                .collect::<Vec<_>>(),
729                        );
730                    }
731                    TextOperation::RemovePunctuation => {
732                        processed = Box::new(
733                            processed
734                                .into_iter()
735                                .map(|s| {
736                                    s.chars()
737                                        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
738                                        .collect()
739                                })
740                                .collect::<Vec<_>>(),
741                        );
742                    }
743                    TextOperation::Tokenize => {
744                        let tokens: Vec<Vec<String>> = processed
745                            .into_iter()
746                            .map(|s| s.split_whitespace().map(|w| w.to_string()).collect())
747                            .collect();
748                        return Ok(Box::new(tokens) as Box<dyn Any + Send + Sync>);
749                    }
750                    TextOperation::NGrams { n } => {
751                        let ngrams: Vec<Vec<String>> = processed
752                            .into_iter()
753                            .map(|s| {
754                                let words: Vec<&str> = s.split_whitespace().collect();
755                                words.windows(*n).map(|window| window.join(" ")).collect()
756                            })
757                            .collect();
758                        return Ok(Box::new(ngrams) as Box<dyn Any + Send + Sync>);
759                    }
760                    _ => {}
761                }
762            }
763
764            Ok(Box::new(processed) as Box<dyn Any + Send + Sync>)
765        } else {
766            Err(IoError::Other(
767                "Invalid data type for text processing".to_string(),
768            ))
769        }
770    }
771}
772
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Min-max scaling into [0, 1] sends the smallest element to 0 and the largest to 1.
    #[test]
    fn test_normalize_minmax() {
        let method = NormalizationMethod::MinMax { min: 0.0, max: 1.0 };
        let transform = NormalizeTransform::<f64>::new(method);

        let input: Box<dyn Any + Send + Sync> = Box::new(arr2(&[[1.0, 2.0], [3.0, 4.0]]));
        let output = transform.transform(input).unwrap();
        let normalized = output.downcast::<Array2<f64>>().unwrap();

        let smallest = normalized[[0, 0]];
        let largest = normalized[[1, 1]];
        assert!((smallest - 0.0).abs() < 1e-6);
        assert!((largest - 1.0).abs() < 1e-6);
    }

    /// Label encoding numbers categories in order of first appearance.
    #[test]
    fn test_encoding_label() {
        let transform = EncodingTransform::new(EncodingMethod::Label);

        let categories: Vec<String> = ["cat", "dog", "cat"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        let input: Box<dyn Any + Send + Sync> = Box::new(categories);

        let output = transform.transform(input).unwrap();
        let encoded = output.downcast::<Vec<i32>>().unwrap();

        assert_eq!(*encoded, vec![0, 1, 0]);
    }
}
803}