// tenflowers_dataset/transforms/normalization.rs

1//! Normalization transformations for datasets
2
3use crate::{transforms::Transform, Dataset};
4use tenflowers_core::{Result, Tensor, TensorError};
5
/// Normalize features by subtracting mean and dividing by standard deviation
pub struct Normalize<T> {
    /// Per-feature means subtracted in `apply`; padded/truncated to the
    /// tensor's feature dimension at application time.
    mean: Vec<T>,
    /// Per-feature standard deviations used as divisors; same length as `mean`
    /// (enforced by `Normalize::new`).
    std: Vec<T>,
}
11
12impl<T> Normalize<T>
13where
14    T: Clone
15        + Default
16        + scirs2_core::numeric::Float
17        + Send
18        + Sync
19        + 'static
20        + bytemuck::Pod
21        + bytemuck::Zeroable,
22{
23    pub fn new(mean: Vec<T>, std: Vec<T>) -> Result<Self> {
24        if mean.len() != std.len() {
25            return Err(TensorError::invalid_argument(
26                "Mean and std vectors must have the same length".to_string(),
27            ));
28        }
29        Ok(Self { mean, std })
30    }
31
32    /// Compute normalization parameters from a dataset
33    pub fn from_dataset<D: Dataset<T>>(dataset: &D) -> Result<Self> {
34        if dataset.is_empty() {
35            return Err(TensorError::invalid_argument(
36                "Cannot compute normalization from empty dataset".to_string(),
37            ));
38        }
39
40        // Get first sample to determine feature dimension
41        let (first_features, _) = dataset.get(0)?;
42        let feature_dim = first_features.shape().size();
43
44        let mut feature_sums = vec![T::zero(); feature_dim];
45        let mut feature_sq_sums = vec![T::zero(); feature_dim];
46        let n = T::from(dataset.len()).expect("dataset length should convert to float");
47
48        // Compute means and variances
49        for i in 0..dataset.len() {
50            let (features, _) = dataset.get(i)?;
51
52            // Flatten features to 1D for computation
53            let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
54
55            for j in 0..feature_dim {
56                if let Some(val) = flat_features.get(&[j]) {
57                    feature_sums[j] = feature_sums[j] + val;
58                    feature_sq_sums[j] = feature_sq_sums[j] + val * val;
59                }
60            }
61        }
62
63        let mut means = Vec::new();
64        let mut stds = Vec::new();
65
66        for i in 0..feature_dim {
67            let mean = feature_sums[i] / n;
68            let variance = (feature_sq_sums[i] / n) - (mean * mean);
69            let std = variance.sqrt();
70
71            means.push(mean);
72            stds.push(std);
73        }
74
75        Self::new(means, stds)
76    }
77
78    /// Create mean tensor for given feature dimension
79    fn create_mean_tensor(&self, feature_dim: usize) -> Result<Tensor<T>> {
80        // Extend or truncate mean vector to match feature dimension
81        let mut mean_vec = self.mean.clone();
82        match mean_vec.len().cmp(&feature_dim) {
83            std::cmp::Ordering::Less => {
84                // Repeat last value if we need more elements
85                if let Some(last_val) = mean_vec.last() {
86                    mean_vec.resize(feature_dim, *last_val);
87                } else {
88                    mean_vec.resize(feature_dim, T::zero());
89                }
90            }
91            std::cmp::Ordering::Greater => {
92                // Truncate if we have too many elements
93                mean_vec.truncate(feature_dim);
94            }
95            std::cmp::Ordering::Equal => {
96                // Perfect match, no changes needed
97            }
98        }
99        Tensor::from_vec(mean_vec, &[feature_dim])
100    }
101
102    /// Create std tensor for given feature dimension
103    fn create_std_tensor(&self, feature_dim: usize) -> Result<Tensor<T>> {
104        // Extend or truncate std vector to match feature dimension
105        let mut std_vec = self.std.clone();
106        match std_vec.len().cmp(&feature_dim) {
107            std::cmp::Ordering::Less => {
108                // Repeat last value if we need more elements
109                if let Some(last_val) = std_vec.last() {
110                    std_vec.resize(feature_dim, *last_val);
111                } else {
112                    std_vec.resize(feature_dim, T::one());
113                }
114            }
115            std::cmp::Ordering::Greater => {
116                // Truncate if we have too many elements
117                std_vec.truncate(feature_dim);
118            }
119            std::cmp::Ordering::Equal => {
120                // Perfect match, no changes needed
121            }
122        }
123        Tensor::from_vec(std_vec, &[feature_dim])
124    }
125}
126
127impl<T> Transform<T> for Normalize<T>
128where
129    T: Clone
130        + Default
131        + scirs2_core::numeric::Float
132        + Send
133        + Sync
134        + 'static
135        + bytemuck::Pod
136        + bytemuck::Zeroable,
137{
138    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
139        let (features, labels) = sample;
140        let original_shape = features.shape().dims().to_vec();
141        let feature_dim = features.shape().size();
142
143        // Flatten features for normalization
144        let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
145
146        // Create mean and std tensors using optimized helper methods
147        let mean_tensor = self.create_mean_tensor(feature_dim)?;
148        let std_tensor = self.create_std_tensor(feature_dim)?;
149
150        // Normalize: (x - mean) / std
151        let centered = flat_features.sub(&mean_tensor)?;
152        let normalized = centered.div(&std_tensor)?;
153
154        // Reshape back to original shape
155        let normalized_features = tenflowers_core::ops::reshape(&normalized, &original_shape)?;
156
157        Ok((normalized_features, labels))
158    }
159}
160
/// Scale features to a specific range [min_val, max_val]
pub struct MinMaxScale<T> {
    /// Per-feature minimum observed in (or supplied for) the fitting data.
    data_min: Vec<T>,
    /// Per-feature maximum; same length as `data_min` (enforced by `new`).
    data_max: Vec<T>,
    /// Target output range `(min, max)` features are mapped into.
    feature_range: (T, T),
}
167
168impl<T> MinMaxScale<T>
169where
170    T: Clone
171        + Default
172        + scirs2_core::numeric::Float
173        + Send
174        + Sync
175        + 'static
176        + bytemuck::Pod
177        + bytemuck::Zeroable,
178{
179    pub fn new(data_min: Vec<T>, data_max: Vec<T>, feature_range: (T, T)) -> Result<Self> {
180        if data_min.len() != data_max.len() {
181            return Err(TensorError::invalid_argument(
182                "Data min and max vectors must have the same length".to_string(),
183            ));
184        }
185        Ok(Self {
186            data_min,
187            data_max,
188            feature_range,
189        })
190    }
191
192    /// Compute scaling parameters from a dataset
193    pub fn from_dataset<D: Dataset<T>>(dataset: &D, feature_range: (T, T)) -> Result<Self> {
194        if dataset.is_empty() {
195            return Err(TensorError::invalid_argument(
196                "Cannot compute scaling from empty dataset".to_string(),
197            ));
198        }
199
200        // Get first sample to determine feature dimension
201        let (first_features, _) = dataset.get(0)?;
202        let feature_dim = first_features.shape().size();
203
204        let mut data_min = vec![T::infinity(); feature_dim];
205        let mut data_max = vec![T::neg_infinity(); feature_dim];
206
207        // Find min and max values for each feature
208        for i in 0..dataset.len() {
209            let (features, _) = dataset.get(i)?;
210            let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
211
212            for j in 0..feature_dim {
213                if let Some(val) = flat_features.get(&[j]) {
214                    if val < data_min[j] {
215                        data_min[j] = val;
216                    }
217                    if val > data_max[j] {
218                        data_max[j] = val;
219                    }
220                }
221            }
222        }
223
224        Self::new(data_min, data_max, feature_range)
225    }
226}
227
228impl<T> Transform<T> for MinMaxScale<T>
229where
230    T: Clone
231        + Default
232        + scirs2_core::numeric::Float
233        + Send
234        + Sync
235        + 'static
236        + bytemuck::Pod
237        + bytemuck::Zeroable,
238{
239    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
240        let (features, labels) = sample;
241        let original_shape = features.shape().dims().to_vec();
242        let feature_dim = features.shape().size();
243
244        // Flatten features for scaling
245        let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
246
247        let mut scaled_data = Vec::with_capacity(feature_dim);
248        let (min_range, max_range) = self.feature_range;
249        let range_scale = max_range - min_range;
250
251        for i in 0..feature_dim {
252            if let Some(val) = flat_features.get(&[i]) {
253                let data_range = self.data_max[i] - self.data_min[i];
254                let scaled = if data_range == T::zero() {
255                    min_range // If no variance, set to min of range
256                } else {
257                    min_range + (val - self.data_min[i]) / data_range * range_scale
258                };
259                scaled_data.push(scaled);
260            } else {
261                return Err(TensorError::invalid_argument(
262                    "Failed to get feature value".to_string(),
263                ));
264            }
265        }
266
267        let scaled_tensor = Tensor::from_vec(scaled_data, &[feature_dim])?;
268        let scaled_features = tenflowers_core::ops::reshape(&scaled_tensor, &original_shape)?;
269
270        Ok((scaled_features, labels))
271    }
272}
273
/// Robust scaler using median and IQR instead of mean and std
pub struct RobustScaler<T> {
    /// Per-feature medians subtracted in `apply`.
    medians: Vec<T>,
    /// Per-feature interquartile ranges used as divisors; same length as
    /// `medians` (enforced by `new`), with 0 replaced by 1 when fitted.
    iqrs: Vec<T>,
}
279
280impl<T> RobustScaler<T>
281where
282    T: Clone
283        + Default
284        + scirs2_core::numeric::Float
285        + Send
286        + Sync
287        + 'static
288        + bytemuck::Pod
289        + bytemuck::Zeroable,
290{
291    pub fn new(medians: Vec<T>, iqrs: Vec<T>) -> Result<Self> {
292        if medians.len() != iqrs.len() {
293            return Err(TensorError::invalid_argument(
294                "Medians and IQRs vectors must have the same length".to_string(),
295            ));
296        }
297        Ok(Self { medians, iqrs })
298    }
299
300    /// Compute robust scaling parameters from a dataset
301    pub fn from_dataset<D: Dataset<T>>(dataset: &D) -> Result<Self> {
302        if dataset.is_empty() {
303            return Err(TensorError::invalid_argument(
304                "Cannot compute robust scaling from empty dataset".to_string(),
305            ));
306        }
307
308        // Get first sample to determine feature dimension
309        let (first_features, _) = dataset.get(0)?;
310        let feature_dim = first_features.shape().size();
311
312        // Collect all feature values by dimension
313        let mut feature_values: Vec<Vec<T>> = vec![Vec::new(); feature_dim];
314
315        for i in 0..dataset.len() {
316            let (features, _) = dataset.get(i)?;
317            let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
318
319            for (j, feature_value) in feature_values.iter_mut().enumerate().take(feature_dim) {
320                if let Some(val) = flat_features.get(&[j]) {
321                    feature_value.push(val);
322                }
323            }
324        }
325
326        // Compute medians and IQRs for each dimension
327        let mut medians = Vec::new();
328        let mut iqrs = Vec::new();
329
330        for values in feature_values {
331            if values.is_empty() {
332                medians.push(T::zero());
333                iqrs.push(T::one());
334                continue;
335            }
336
337            let mut sorted_values = values;
338            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
339
340            let n = sorted_values.len();
341            let median = if n % 2 == 0 {
342                (sorted_values[n / 2 - 1] + sorted_values[n / 2])
343                    / T::from(2.0).expect("constant 2.0 should convert to float")
344            } else {
345                sorted_values[n / 2]
346            };
347
348            let q1_idx = n / 4;
349            let q3_idx = (3 * n) / 4;
350            let q1 = sorted_values[q1_idx];
351            let q3 = sorted_values[q3_idx];
352            let iqr = q3 - q1;
353
354            medians.push(median);
355            iqrs.push(if iqr > T::zero() { iqr } else { T::one() });
356        }
357
358        Ok(Self { medians, iqrs })
359    }
360}
361
362impl<T> Transform<T> for RobustScaler<T>
363where
364    T: Clone
365        + Default
366        + scirs2_core::numeric::Float
367        + Send
368        + Sync
369        + 'static
370        + bytemuck::Pod
371        + bytemuck::Zeroable,
372{
373    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
374        let (features, labels) = sample;
375        let original_shape = features.shape().dims().to_vec();
376        let feature_dim = features.shape().size();
377
378        // Flatten features for normalization
379        let flat_features = tenflowers_core::ops::reshape(&features, &[feature_dim])?;
380
381        // Apply robust scaling: (x - median) / IQR
382        let mut scaled_data = Vec::new();
383        for i in 0..feature_dim {
384            let idx = i % self.medians.len();
385            if let Some(val) = flat_features.get(&[i]) {
386                let scaled = (val - self.medians[idx]) / self.iqrs[idx];
387                scaled_data.push(scaled);
388            }
389        }
390
391        let scaled_tensor = Tensor::from_vec(scaled_data, &[feature_dim])?;
392        let reshaped_features = tenflowers_core::ops::reshape(&scaled_tensor, &original_shape)?;
393
394        Ok((reshaped_features, labels))
395    }
396}
397
/// Per-channel normalization for multi-channel data (e.g., RGB images)
pub struct PerChannelNormalize<T> {
    /// Mean subtracted from each channel, indexed by the leading axis.
    channel_means: Vec<T>,
    /// Std divisor per channel; same length as `channel_means` (enforced by
    /// `new`).
    channel_stds: Vec<T>,
}
403
404impl<T> PerChannelNormalize<T>
405where
406    T: Clone
407        + Default
408        + scirs2_core::numeric::Float
409        + Send
410        + Sync
411        + 'static
412        + bytemuck::Pod
413        + bytemuck::Zeroable,
414{
415    pub fn new(channel_means: Vec<T>, channel_stds: Vec<T>) -> Result<Self> {
416        if channel_means.len() != channel_stds.len() {
417            return Err(TensorError::invalid_argument(
418                "Channel means and stds must have the same length".to_string(),
419            ));
420        }
421        Ok(Self {
422            channel_means,
423            channel_stds,
424        })
425    }
426
427    /// Common ImageNet normalization values
428    pub fn imagenet() -> Self {
429        Self {
430            channel_means: vec![
431                T::from(0.485).expect("constant should convert to float"),
432                T::from(0.456).expect("constant should convert to float"),
433                T::from(0.406).expect("constant should convert to float"),
434            ],
435            channel_stds: vec![
436                T::from(0.229).expect("constant should convert to float"),
437                T::from(0.224).expect("constant should convert to float"),
438                T::from(0.225).expect("constant should convert to float"),
439            ],
440        }
441    }
442}
443
444impl<T> Transform<T> for PerChannelNormalize<T>
445where
446    T: Clone
447        + Default
448        + scirs2_core::numeric::Float
449        + Send
450        + Sync
451        + 'static
452        + bytemuck::Pod
453        + bytemuck::Zeroable,
454{
455    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
456        let (features, labels) = sample;
457        let shape = features.shape().dims();
458
459        // Assume features are in format [channels, height, width] or [channels, ...]
460        if shape.is_empty() {
461            return Ok((features, labels));
462        }
463
464        let channels = shape[0];
465        if channels != self.channel_means.len() {
466            return Err(TensorError::invalid_argument(format!(
467                "Expected {} channels, got {}",
468                self.channel_means.len(),
469                channels
470            )));
471        }
472
473        let data = features.as_slice().ok_or_else(|| {
474            TensorError::invalid_argument(
475                "Cannot access tensor data (GPU tensor not supported)".to_string(),
476            )
477        })?;
478        let mut normalized_data = Vec::new();
479
480        let channel_size = data.len() / channels;
481
482        for c in 0..channels {
483            let start = c * channel_size;
484            let end = start + channel_size;
485
486            for value in data.iter().skip(start).take(end - start) {
487                let normalized = (*value - self.channel_means[c]) / self.channel_stds[c];
488                normalized_data.push(normalized);
489            }
490        }
491
492        let normalized_tensor = Tensor::from_vec(normalized_data, shape)?;
493        Ok((normalized_tensor, labels))
494    }
495}
496
/// Global normalization across all samples in the dataset
pub struct GlobalNormalize<T> {
    /// Single mean subtracted from every element.
    global_mean: T,
    /// Single std divisor; `from_dataset` guarantees it is nonzero.
    global_std: T,
}
502
503impl<T> GlobalNormalize<T>
504where
505    T: Clone
506        + Default
507        + scirs2_core::numeric::Float
508        + Send
509        + Sync
510        + 'static
511        + bytemuck::Pod
512        + bytemuck::Zeroable,
513{
514    pub fn new(global_mean: T, global_std: T) -> Self {
515        Self {
516            global_mean,
517            global_std,
518        }
519    }
520
521    /// Compute global normalization parameters from a dataset
522    pub fn from_dataset<D: Dataset<T>>(dataset: &D) -> Result<Self> {
523        if dataset.is_empty() {
524            return Err(TensorError::invalid_argument(
525                "Cannot compute global normalization from empty dataset".to_string(),
526            ));
527        }
528
529        let mut total_sum = T::zero();
530        let mut total_sq_sum = T::zero();
531        let mut total_count = 0;
532
533        for i in 0..dataset.len() {
534            let (features, _) = dataset.get(i)?;
535            let data = features.as_slice().ok_or_else(|| {
536                TensorError::invalid_argument(
537                    "Cannot access tensor data (GPU tensor not supported)".to_string(),
538                )
539            })?;
540
541            for &val in data {
542                total_sum = total_sum + val;
543                total_sq_sum = total_sq_sum + val * val;
544                total_count += 1;
545            }
546        }
547
548        let n = T::from(total_count).expect("count should convert to float");
549        let mean = total_sum / n;
550        let variance = (total_sq_sum / n) - (mean * mean);
551        let std = variance.sqrt();
552
553        Ok(Self {
554            global_mean: mean,
555            global_std: if std > T::zero() { std } else { T::one() },
556        })
557    }
558}
559
560impl<T> Transform<T> for GlobalNormalize<T>
561where
562    T: Clone
563        + Default
564        + scirs2_core::numeric::Float
565        + Send
566        + Sync
567        + 'static
568        + bytemuck::Pod
569        + bytemuck::Zeroable,
570{
571    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
572        let (features, labels) = sample;
573        let shape = features.shape().dims();
574        let data = features.as_slice().ok_or_else(|| {
575            TensorError::invalid_argument(
576                "Cannot access tensor data (GPU tensor not supported)".to_string(),
577            )
578        })?;
579
580        let normalized_data: Vec<T> = data
581            .iter()
582            .map(|&val| (val - self.global_mean) / self.global_std)
583            .collect();
584
585        let normalized_tensor = Tensor::from_vec(normalized_data, shape)?;
586        Ok((normalized_tensor, labels))
587    }
588}