// tenflowers_dataset/simd_transforms/normalization.rs
//! SIMD-accelerated normalization transforms
//!
//! This module provides vectorized implementations of data normalization
//! using SIMD instructions for significant performance improvements.

#![allow(unsafe_code)]

use std::any::TypeId;
use std::marker::PhantomData;

use tenflowers_core::{Result, Tensor, TensorError};

use crate::Transform;

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
/// SIMD-accelerated normalization transform
///
/// Uses AVX2 instructions when available for up to 8x performance improvement
/// over scalar normalization on compatible hardware.
pub struct SimdNormalize<T> {
    /// Per-feature means subtracted during normalization.
    mean: Vec<T>,
    /// Per-feature standard deviations used as divisors.
    std: Vec<T>,
    /// True when AVX2 was detected at construction and T is a 4-byte type.
    use_simd: bool,
}
24
25impl<T> SimdNormalize<T>
26where
27    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
28{
29    /// Create a new SIMD-accelerated normalization transform
30    pub fn new(mean: Vec<T>, std: Vec<T>) -> Self {
31        #[cfg(target_arch = "x86_64")]
32        let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
33
34        #[cfg(not(target_arch = "x86_64"))]
35        let use_simd = false;
36
37        Self {
38            mean,
39            std,
40            use_simd,
41        }
42    }
43
44    /// Get SIMD capability status
45    pub fn is_simd_enabled(&self) -> bool {
46        self.use_simd
47    }
48
49    /// SIMD-accelerated normalization for f32 data
50    #[cfg(target_arch = "x86_64")]
51    unsafe fn normalize_f32_simd(&self, data: &mut [f32], mean: f32, std: f32) {
52        if !self.use_simd || data.len() < 8 {
53            // Fall back to scalar for small arrays
54            self.normalize_scalar_f32(data, mean, std);
55            return;
56        }
57
58        let mean_vec = _mm256_set1_ps(mean);
59        let inv_std_vec = _mm256_set1_ps(1.0 / std);
60
61        let chunks = data.len() / 8;
62        let remainder = data.len() % 8;
63
64        // Process 8 elements at a time using AVX2
65        for i in 0..chunks {
66            let offset = i * 8;
67            let values = _mm256_loadu_ps(data.as_ptr().add(offset));
68
69            // Subtract mean
70            let centered = _mm256_sub_ps(values, mean_vec);
71
72            // Multiply by inverse std
73            let normalized = _mm256_mul_ps(centered, inv_std_vec);
74
75            _mm256_storeu_ps(data.as_mut_ptr().add(offset), normalized);
76        }
77
78        // Handle remaining elements with scalar operations
79        if remainder > 0 {
80            let start = chunks * 8;
81            self.normalize_scalar_f32(&mut data[start..], mean, std);
82        }
83    }
84
85    /// Scalar fallback for normalization
86    fn normalize_scalar(&self, data: &mut [T], mean: T, std: T)
87    where
88        T: scirs2_core::numeric::Float,
89    {
90        for value in data.iter_mut() {
91            *value = (*value - mean) / std;
92        }
93    }
94
95    /// Scalar fallback for f32 normalization
96    #[allow(dead_code)]
97    fn normalize_scalar_f32(&self, data: &mut [f32], mean: f32, std: f32) {
98        for value in data.iter_mut() {
99            *value = (*value - mean) / std;
100        }
101    }
102}
103
104impl<T> Transform<T> for SimdNormalize<T>
105where
106    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
107{
108    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
109        let (features, labels) = sample;
110
111        // Note: For now we'll work with immutable data and create a new tensor
112        // In a real implementation, we'd need mutable tensor access
113        if let Some(data) = features.as_slice() {
114            let mut mutable_data = data.to_vec();
115            let feature_count = self.mean.len();
116
117            if mutable_data.len() % feature_count != 0 {
118                return Err(TensorError::invalid_argument(
119                    "Feature tensor size must be divisible by number of features".to_string(),
120                ));
121            }
122
123            let samples = mutable_data.len() / feature_count;
124
125            // Normalize each feature dimension
126            for feature_idx in 0..feature_count {
127                let mean = self.mean[feature_idx];
128                let std = self.std[feature_idx];
129
130                // Skip normalization if std is zero
131                if std == T::zero() {
132                    continue;
133                }
134
135                // Extract feature values across all samples
136                let mut feature_values: Vec<T> = (0..samples)
137                    .map(|sample_idx| mutable_data[sample_idx * feature_count + feature_idx])
138                    .collect();
139
140                // Apply SIMD normalization if available and appropriate
141                #[cfg(target_arch = "x86_64")]
142                {
143                    if self.use_simd && std::mem::size_of::<T>() == 4 {
144                        let mean_f32 = unsafe { std::mem::transmute_copy::<T, f32>(&mean) };
145                        let std_f32 = unsafe { std::mem::transmute_copy::<T, f32>(&std) };
146                        let feature_f32 = unsafe {
147                            std::slice::from_raw_parts_mut(
148                                feature_values.as_mut_ptr() as *mut f32,
149                                feature_values.len(),
150                            )
151                        };
152
153                        unsafe {
154                            self.normalize_f32_simd(feature_f32, mean_f32, std_f32);
155                        }
156                    } else {
157                        self.normalize_scalar(&mut feature_values, mean, std);
158                    }
159                }
160                #[cfg(not(target_arch = "x86_64"))]
161                {
162                    self.normalize_scalar(&mut feature_values, mean, std);
163                }
164
165                // Write normalized values back
166                for (sample_idx, &normalized_value) in feature_values.iter().enumerate() {
167                    mutable_data[sample_idx * feature_count + feature_idx] = normalized_value;
168                }
169            }
170
171            // Create new tensor with normalized data
172            let new_features = Tensor::from_vec(mutable_data, features.shape().dims())?;
173            Ok((new_features, labels))
174        } else {
175            Err(TensorError::invalid_argument(
176                "Cannot access tensor data for normalization".to_string(),
177            ))
178        }
179    }
180}
181
/// SIMD-accelerated normalization for scalar-only operations
///
/// Simplified version that only supports scalar operations for compatibility.
pub struct SimdNormalizeScalarOnly<T> {
    // Zero-sized marker tying the element type T to the transform.
    _marker: PhantomData<T>,
}
188
189impl<T> SimdNormalizeScalarOnly<T>
190where
191    T: Clone + Default + Send + Sync + 'static,
192{
193    /// Create a new scalar-only normalization transform
194    pub fn new() -> Self {
195        Self {
196            _marker: PhantomData,
197        }
198    }
199}
200
201impl<T> Default for SimdNormalizeScalarOnly<T>
202where
203    T: Clone + Default + Send + Sync + 'static,
204{
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
210impl<T> Transform<T> for SimdNormalizeScalarOnly<T>
211where
212    T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
213{
214    fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
215        let (features, labels) = sample;
216
217        if let Some(data) = features.as_slice() {
218            // Simple z-score normalization
219            let mut values = data.to_vec();
220            let n = T::from(values.len()).unwrap_or(T::one());
221
222            // Calculate mean
223            let sum = values.iter().fold(T::zero(), |acc, &x| acc + x);
224            let mean = sum / n;
225
226            // Calculate standard deviation
227            let variance = values
228                .iter()
229                .map(|&x| {
230                    let diff = x - mean;
231                    diff * diff
232                })
233                .fold(T::zero(), |acc, x| acc + x)
234                / n;
235
236            let std = variance.sqrt();
237
238            // Apply normalization if std is not zero
239            if std > T::zero() {
240                for value in &mut values {
241                    *value = (*value - mean) / std;
242                }
243            }
244
245            let normalized_features = Tensor::from_vec(values, features.shape().dims())?;
246            Ok((normalized_features, labels))
247        } else {
248            Err(TensorError::invalid_argument(
249                "Cannot access tensor data for scalar normalization".to_string(),
250            ))
251        }
252    }
253}