Skip to main content

scirs2_neural/data/
transforms.rs

1//! Data transforms for preprocessing inputs
2
3use crate::error::Result;
4use scirs2_core::ndarray::ArrayStatCompat;
5use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
6use scirs2_core::numeric::{Float, FromPrimitive};
7use statrs::statistics::Statistics;
8use std::fmt::Debug;
/// Trait for data transforms applied to dynamically-shaped arrays.
pub trait Transform<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync>:
    Send + Sync + Debug
{
    /// Apply the transform to the input, returning a new array.
    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>>;
    /// Get a short human-readable description of the transform.
    fn description(&self) -> String;
    /// Clone the transform (we need to implement it as a method since we can't derive Clone for trait objects)
    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync>;
}
20/// Standard scaler transform
/// Standard scaler transform: normalizes data to zero mean and unit variance.
#[derive(Debug, Clone)]
pub struct StandardScaler<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    /// Mean values used for centering; `None` until `fit` is called
    mean: Option<Array<F, IxDyn>>,
    /// Standard deviation values used for scaling; `None` until `fit` is called
    std: Option<Array<F, IxDyn>>,
    /// When true, statistics are computed per sample (reducing over axis 1);
    /// when false, per feature (reducing over axis 0)
    fit_per_sample: bool,
}
30
31impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> StandardScaler<F> {
32    /// Create a new standard scaler
33    pub fn new(fit_per_sample: bool) -> Self {
34        Self {
35            mean: None,
36            std: None,
37            fit_per_sample,
38        }
39    }
40    /// Fit the scaler to the data
41    pub fn fit(&mut self, data: &Array<F, IxDyn>) -> Result<&mut Self> {
42        let zero = F::from(0.0).unwrap_or(F::zero());
43        if data.ndim() < 2 {
44            // Just compute global mean and std
45            let mean = data.mean_or(F::zero());
46            let std = data.std(zero);
47            self.mean = Some(Array::from_elem(IxDyn(&[1]), mean));
48            self.std = Some(Array::from_elem(IxDyn(&[1]), std));
49        } else if self.fit_per_sample {
50            // Compute mean and std for each sample
51            let axis = 1; // Fit on feature dimension
52            let mean = data
53                .mean_axis(scirs2_core::ndarray::Axis(axis))
54                .unwrap_or(Array::zeros(IxDyn(&[data.shape()[0]])));
55            let std = data.std_axis(scirs2_core::ndarray::Axis(axis), zero);
56            self.mean = Some(mean);
57            self.std = Some(std);
58        } else {
59            // Compute mean and std for each feature
60            let axis = 0; // Fit on sample dimension
61            let mean = data
62                .mean_axis(scirs2_core::ndarray::Axis(axis))
63                .unwrap_or(Array::zeros(IxDyn(&[data.shape()[1]])));
64            let std = data.std_axis(scirs2_core::ndarray::Axis(axis), zero);
65            self.mean = Some(mean);
66            self.std = Some(std);
67        }
68        Ok(self)
69    }
70
71    /// Transform data using the fitted parameters
72    pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
73        if self.mean.is_none() || self.std.is_none() {
74            return Err(crate::error::NeuralError::InferenceError(
75                "StandardScaler has not been fitted".to_string(),
76            ));
77        }
78
79        let mean = self.mean.as_ref().expect("Operation failed");
80        let std = self.std.as_ref().expect("Operation failed");
81        let mut result = data.clone();
82
83        if data.ndim() < 2 {
84            // Apply global mean and std
85            let mean_val = mean[[0]];
86            let std_val = std[[0]].max(F::epsilon());
87            for item in result.iter_mut() {
88                *item = (*item - mean_val) / std_val;
89            }
90        } else if self.fit_per_sample {
91            // Apply per-sample normalization
92            for i in 0..data.shape()[0] {
93                let mean_val = mean[[i]];
94                let std_val = std[[i]].max(F::epsilon());
95                for j in 0..data.shape()[1] {
96                    result[[i, j]] = (data[[i, j]] - mean_val) / std_val;
97                }
98            }
99        } else {
100            // Apply per-feature normalization
101            for j in 0..data.shape()[1] {
102                let mean_val = mean[[j]];
103                let std_val = std[[j]].max(F::epsilon());
104                for i in 0..data.shape()[0] {
105                    result[[i, j]] = (data[[i, j]] - mean_val) / std_val;
106                }
107            }
108        }
109
110        Ok(result)
111    }
112
113    /// Fit to data and then transform it
114    pub fn fit_transform(&mut self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
115        self.fit(data)?;
116        self.transform(data)
117    }
118}
119
120impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
121    for StandardScaler<F>
122{
123    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
124        self.transform(input)
125    }
126
127    fn description(&self) -> String {
128        if self.fit_per_sample {
129            "StandardScaler (per-sample)".to_string()
130        } else {
131            "StandardScaler (per-feature)".to_string()
132        }
133    }
134
135    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
136        Box::new(self.clone())
137    }
138}
139
/// MinMax scaler transform: rescales data into a target range.
#[derive(Debug, Clone)]
pub struct MinMaxScaler<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    /// Minimum values observed during fitting; `None` until `fit` is called
    min: Option<Array<F, IxDyn>>,
    /// Maximum values observed during fitting; `None` until `fit` is called
    max: Option<Array<F, IxDyn>>,
    /// Target range for scaling (default: [0, 1])
    range: (F, F),
    /// When true, min/max are computed per sample (row);
    /// when false, per feature (column)
    fit_per_sample: bool,
}
152
153impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> MinMaxScaler<F> {
154    /// Create a new MinMax scaler with default range [0, 1]
155    pub fn new(fit_per_sample: bool) -> Self {
156        Self::with_range(F::zero(), F::one(), fit_per_sample)
157    }
158
159    /// Create a new MinMax scaler with custom range
160    pub fn with_range(min_val: F, max_val: F, fit_per_sample: bool) -> Self {
161        Self {
162            min: None,
163            max: None,
164            range: (min_val, max_val),
165            fit_per_sample,
166        }
167    }
168
169    /// Fit the scaler to the data
170    pub fn fit(&mut self, data: &Array<F, IxDyn>) -> Result<&mut Self> {
171        if data.ndim() < 2 {
172            // Just compute global min and max
173            let min = match data
174                .iter()
175                .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
176            {
177                Some(&val) => val,
178                None => F::zero(),
179            };
180            let max = match data
181                .iter()
182                .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
183            {
184                Some(&val) => val,
185                None => F::one(),
186            };
187            self.min = Some(Array::from_elem(IxDyn(&[1]), min));
188            self.max = Some(Array::from_elem(IxDyn(&[1]), max));
189        } else if self.fit_per_sample {
190            // Compute min and max for each sample
191            let mut min_vals = Array::zeros(IxDyn(&[data.shape()[0]]));
192            let mut max_vals = Array::zeros(IxDyn(&[data.shape()[0]]));
193            for i in 0..data.shape()[0] {
194                let row = data.slice(scirs2_core::ndarray::s![i, ..]);
195                let min = match row
196                    .iter()
197                    .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
198                {
199                    Some(&val) => val,
200                    None => F::zero(),
201                };
202                let max = match row
203                    .iter()
204                    .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
205                {
206                    Some(&val) => val,
207                    None => F::one(),
208                };
209                min_vals[[i]] = min;
210                max_vals[[i]] = max;
211            }
212            self.min = Some(min_vals);
213            self.max = Some(max_vals);
214        } else {
215            // Compute min and max for each feature
216            let mut min_vals = Array::zeros(IxDyn(&[data.shape()[1]]));
217            let mut max_vals = Array::zeros(IxDyn(&[data.shape()[1]]));
218            for j in 0..data.shape()[1] {
219                let col = data.slice(scirs2_core::ndarray::s![.., j]);
220                let min = match col
221                    .iter()
222                    .min_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
223                {
224                    Some(&val) => val,
225                    None => F::zero(),
226                };
227                let max = match col
228                    .iter()
229                    .max_by(|a, b| a.partial_cmp(b).expect("Operation failed"))
230                {
231                    Some(&val) => val,
232                    None => F::one(),
233                };
234                min_vals[[j]] = min;
235                max_vals[[j]] = max;
236            }
237            self.min = Some(min_vals);
238            self.max = Some(max_vals);
239        }
240        Ok(self)
241    }
242
243    /// Transform data using the fitted parameters
244    pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
245        if self.min.is_none() || self.max.is_none() {
246            return Err(crate::error::NeuralError::InferenceError(
247                "MinMaxScaler has not been fitted".to_string(),
248            ));
249        }
250
251        let min = self.min.as_ref().expect("Operation failed");
252        let max = self.max.as_ref().expect("Operation failed");
253        let (range_min, range_max) = self.range;
254        let range_diff = range_max - range_min;
255        let mut result = data.clone();
256
257        if data.ndim() < 2 {
258            // Apply global min and max
259            let min_val = min[[0]];
260            let max_val = max[[0]];
261            let scale = if max_val > min_val {
262                F::one() / (max_val - min_val)
263            } else {
264                F::one()
265            };
266            for item in result.iter_mut() {
267                *item = range_min + range_diff * ((*item - min_val) * scale);
268            }
269        } else if self.fit_per_sample {
270            // Apply per-sample normalization
271            for i in 0..data.shape()[0] {
272                let min_val = min[[i]];
273                let max_val = max[[i]];
274                let scale = if max_val > min_val {
275                    F::one() / (max_val - min_val)
276                } else {
277                    F::one()
278                };
279                for j in 0..data.shape()[1] {
280                    result[[i, j]] = range_min + range_diff * ((data[[i, j]] - min_val) * scale);
281                }
282            }
283        } else {
284            // Apply per-feature normalization
285            for j in 0..data.shape()[1] {
286                let min_val = min[[j]];
287                let max_val = max[[j]];
288                let scale = if max_val > min_val {
289                    F::one() / (max_val - min_val)
290                } else {
291                    F::one()
292                };
293                for i in 0..data.shape()[0] {
294                    result[[i, j]] = range_min + range_diff * ((data[[i, j]] - min_val) * scale);
295                }
296            }
297        }
298
299        Ok(result)
300    }
301
302    /// Fit to data and then transform it
303    pub fn fit_transform(&mut self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
304        self.fit(data)?;
305        self.transform(data)
306    }
307}
308
309impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
310    for MinMaxScaler<F>
311{
312    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
313        self.transform(input)
314    }
315
316    fn description(&self) -> String {
317        format!(
318            "MinMaxScaler (range: [{:.1}, {:.1}], {})",
319            self.range.0.to_f64().unwrap_or(0.0),
320            self.range.1.to_f64().unwrap_or(1.0),
321            if self.fit_per_sample {
322                "per-sample"
323            } else {
324                "per-feature"
325            }
326        )
327    }
328
329    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
330        Box::new(self.clone())
331    }
332}
333
/// One-hot encoder transform: maps class indices to indicator vectors.
#[derive(Debug, Clone)]
pub struct OneHotEncoder<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    /// Number of classes (the width of the encoded output)
    n_classes: usize,
    /// Phantom data tying the encoder to element type `F` without storing one
    _phantom: std::marker::PhantomData<F>,
}
342
343impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> OneHotEncoder<F> {
344    /// Create a new one-hot encoder
345    pub fn new(n_classes: usize) -> Self {
346        Self {
347            n_classes,
348            _phantom: std::marker::PhantomData,
349        }
350    }
351
352    /// Transform class indices to one-hot encoded vectors
353    pub fn transform(&self, data: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
354        let shape = data.shape();
355        let n_samples = shape[0];
356
357        // Create output array with shape [n_samples, n_classes]
358        let mut result = Array::zeros(IxDyn(&[n_samples, self.n_classes]));
359
360        // Fill one-hot encoded values
361        for i in 0..n_samples {
362            let class_idx = data[[i]].to_usize().unwrap_or(0);
363            if class_idx >= self.n_classes {
364                return Err(crate::error::NeuralError::InferenceError(format!(
365                    "Class index {} is out of bounds for {} classes",
366                    class_idx, self.n_classes
367                )));
368            }
369            result[[i, class_idx]] = F::one();
370        }
371
372        Ok(result)
373    }
374}
375
376impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
377    for OneHotEncoder<F>
378{
379    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
380        self.transform(input)
381    }
382
383    fn description(&self) -> String {
384        format!("OneHotEncoder (n_classes: {})", self.n_classes)
385    }
386
387    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
388        Box::new(self.clone())
389    }
390}
391
/// Compose multiple transforms into a single transform applied in sequence
pub struct ComposeTransform<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    /// Boxed transforms, applied in order by `apply`
    transforms: Vec<Box<dyn Transform<F> + Send + Sync>>,
}
397
/// Debug wrapper for a trait-object transform, letting `ComposeTransform`
/// render its boxed transforms via their `description()`
struct DebugTransformWrapper<'a, F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> {
    /// Reference to the wrapped transform
    inner: &'a (dyn Transform<F> + Send + Sync),
}
403
404impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Debug
405    for DebugTransformWrapper<'_, F>
406{
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        write!(f, "Transform({})", self.inner.description())
409    }
410}
411
412impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Debug for ComposeTransform<F> {
413    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        let mut debug_list = f.debug_list();
415        for transform in &self.transforms {
416            debug_list.entry(&DebugTransformWrapper {
417                inner: transform.as_ref(),
418            });
419        }
420        debug_list.finish()
421    }
422}
423
424impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Clone for ComposeTransform<F> {
425    fn clone(&self) -> Self {
426        Self {
427            transforms: self
428                .transforms
429                .iter()
430                .map(|transform| transform.box_clone())
431                .collect(),
432        }
433    }
434}
435
436impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> ComposeTransform<F> {
437    /// Create a new composition of transforms
438    pub fn new(transforms: Vec<Box<dyn Transform<F> + Send + Sync>>) -> Self {
439        Self { transforms }
440    }
441}
442
443impl<F: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> Transform<F>
444    for ComposeTransform<F>
445{
446    fn apply(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
447        let mut data = input.clone();
448        for transform in &self.transforms {
449            data = transform.apply(&data)?;
450        }
451        Ok(data)
452    }
453
454    fn description(&self) -> String {
455        let descriptions: Vec<String> = self.transforms.iter().map(|t| t.description()).collect();
456        format!("Compose({})", descriptions.join(", "))
457    }
458
459    fn box_clone(&self) -> Box<dyn Transform<F> + Send + Sync> {
460        Box::new(self.clone())
461    }
462}