// tenflowers_dataset/simd_transforms/normalization.rs

#![allow(unsafe_code)]
8use crate::Transform;
9use std::marker::PhantomData;
10use tenflowers_core::{Result, Tensor, TensorError};
11
12#[cfg(target_arch = "x86_64")]
13use std::arch::x86_64::*;
14
/// Per-feature normalization transform computing `(x - mean[f]) / std[f]`,
/// with an AVX2 fast path used when the element type is 4 bytes wide
/// (i.e. `f32`) and the CPU supports it.
pub struct SimdNormalize<T> {
    // Per-feature means; `mean.len()` defines the feature count used by `apply`.
    mean: Vec<T>,
    // Per-feature standard deviations; features whose std is zero are
    // skipped during `apply` to avoid division by zero.
    std: Vec<T>,
    // True when AVX2 was detected at construction AND size_of::<T>() == 4;
    // selects the SIMD kernel over the scalar fallback.
    use_simd: bool,
}
24
25impl<T> SimdNormalize<T>
26where
27 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
28{
29 pub fn new(mean: Vec<T>, std: Vec<T>) -> Self {
31 #[cfg(target_arch = "x86_64")]
32 let use_simd = is_x86_feature_detected!("avx2") && std::mem::size_of::<T>() == 4;
33
34 #[cfg(not(target_arch = "x86_64"))]
35 let use_simd = false;
36
37 Self {
38 mean,
39 std,
40 use_simd,
41 }
42 }
43
44 pub fn is_simd_enabled(&self) -> bool {
46 self.use_simd
47 }
48
49 #[cfg(target_arch = "x86_64")]
51 unsafe fn normalize_f32_simd(&self, data: &mut [f32], mean: f32, std: f32) {
52 if !self.use_simd || data.len() < 8 {
53 self.normalize_scalar_f32(data, mean, std);
55 return;
56 }
57
58 let mean_vec = _mm256_set1_ps(mean);
59 let inv_std_vec = _mm256_set1_ps(1.0 / std);
60
61 let chunks = data.len() / 8;
62 let remainder = data.len() % 8;
63
64 for i in 0..chunks {
66 let offset = i * 8;
67 let values = _mm256_loadu_ps(data.as_ptr().add(offset));
68
69 let centered = _mm256_sub_ps(values, mean_vec);
71
72 let normalized = _mm256_mul_ps(centered, inv_std_vec);
74
75 _mm256_storeu_ps(data.as_mut_ptr().add(offset), normalized);
76 }
77
78 if remainder > 0 {
80 let start = chunks * 8;
81 self.normalize_scalar_f32(&mut data[start..], mean, std);
82 }
83 }
84
85 fn normalize_scalar(&self, data: &mut [T], mean: T, std: T)
87 where
88 T: scirs2_core::numeric::Float,
89 {
90 for value in data.iter_mut() {
91 *value = (*value - mean) / std;
92 }
93 }
94
95 #[allow(dead_code)]
97 fn normalize_scalar_f32(&self, data: &mut [f32], mean: f32, std: f32) {
98 for value in data.iter_mut() {
99 *value = (*value - mean) / std;
100 }
101 }
102}
103
104impl<T> Transform<T> for SimdNormalize<T>
105where
106 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
107{
108 fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
109 let (features, labels) = sample;
110
111 if let Some(data) = features.as_slice() {
114 let mut mutable_data = data.to_vec();
115 let feature_count = self.mean.len();
116
117 if mutable_data.len() % feature_count != 0 {
118 return Err(TensorError::invalid_argument(
119 "Feature tensor size must be divisible by number of features".to_string(),
120 ));
121 }
122
123 let samples = mutable_data.len() / feature_count;
124
125 for feature_idx in 0..feature_count {
127 let mean = self.mean[feature_idx];
128 let std = self.std[feature_idx];
129
130 if std == T::zero() {
132 continue;
133 }
134
135 let mut feature_values: Vec<T> = (0..samples)
137 .map(|sample_idx| mutable_data[sample_idx * feature_count + feature_idx])
138 .collect();
139
140 #[cfg(target_arch = "x86_64")]
142 {
143 if self.use_simd && std::mem::size_of::<T>() == 4 {
144 let mean_f32 = unsafe { std::mem::transmute_copy::<T, f32>(&mean) };
145 let std_f32 = unsafe { std::mem::transmute_copy::<T, f32>(&std) };
146 let feature_f32 = unsafe {
147 std::slice::from_raw_parts_mut(
148 feature_values.as_mut_ptr() as *mut f32,
149 feature_values.len(),
150 )
151 };
152
153 unsafe {
154 self.normalize_f32_simd(feature_f32, mean_f32, std_f32);
155 }
156 } else {
157 self.normalize_scalar(&mut feature_values, mean, std);
158 }
159 }
160 #[cfg(not(target_arch = "x86_64"))]
161 {
162 self.normalize_scalar(&mut feature_values, mean, std);
163 }
164
165 for (sample_idx, &normalized_value) in feature_values.iter().enumerate() {
167 mutable_data[sample_idx * feature_count + feature_idx] = normalized_value;
168 }
169 }
170
171 let new_features = Tensor::from_vec(mutable_data, features.shape().dims())?;
173 Ok((new_features, labels))
174 } else {
175 Err(TensorError::invalid_argument(
176 "Cannot access tensor data for normalization".to_string(),
177 ))
178 }
179 }
180}
181
/// Scalar-only fallback transform that z-score normalizes an entire
/// feature tensor using a single global mean and standard deviation
/// computed from the data itself (no per-feature statistics, no SIMD).
pub struct SimdNormalizeScalarOnly<T> {
    // Zero-sized marker tying the element type `T` to this stateless transform.
    _marker: PhantomData<T>,
}
188
189impl<T> SimdNormalizeScalarOnly<T>
190where
191 T: Clone + Default + Send + Sync + 'static,
192{
193 pub fn new() -> Self {
195 Self {
196 _marker: PhantomData,
197 }
198 }
199}
200
201impl<T> Default for SimdNormalizeScalarOnly<T>
202where
203 T: Clone + Default + Send + Sync + 'static,
204{
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
210impl<T> Transform<T> for SimdNormalizeScalarOnly<T>
211where
212 T: Clone + Default + scirs2_core::numeric::Float + Send + Sync + 'static,
213{
214 fn apply(&self, sample: (Tensor<T>, Tensor<T>)) -> Result<(Tensor<T>, Tensor<T>)> {
215 let (features, labels) = sample;
216
217 if let Some(data) = features.as_slice() {
218 let mut values = data.to_vec();
220 let n = T::from(values.len()).unwrap_or(T::one());
221
222 let sum = values.iter().fold(T::zero(), |acc, &x| acc + x);
224 let mean = sum / n;
225
226 let variance = values
228 .iter()
229 .map(|&x| {
230 let diff = x - mean;
231 diff * diff
232 })
233 .fold(T::zero(), |acc, x| acc + x)
234 / n;
235
236 let std = variance.sqrt();
237
238 if std > T::zero() {
240 for value in &mut values {
241 *value = (*value - mean) / std;
242 }
243 }
244
245 let normalized_features = Tensor::from_vec(values, features.shape().dims())?;
246 Ok((normalized_features, labels))
247 } else {
248 Err(TensorError::invalid_argument(
249 "Cannot access tensor data for scalar normalization".to_string(),
250 ))
251 }
252 }
253}