1use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tuning knobs for the SIMD-accelerated statistics routines.
///
/// Construct with [`Default::default`] and override individual fields with
/// struct-update syntax, e.g. `SimdConfig { minsize: 0, ..Default::default() }`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SimdConfig {
    /// Inputs with fewer than `minsize` elements take a plain scalar path
    /// instead of the chunked SIMD path.
    pub minsize: usize,
    /// Requested use of aligned memory access.
    /// NOTE(review): not read by any routine in this file — confirm intended use.
    pub use_aligned: bool,
    /// Requested loop-unroll factor.
    /// NOTE(review): not read by any routine in this file — confirm intended use.
    pub unroll_factor: usize,
}
25
26impl Default for SimdConfig {
27 fn default() -> Self {
28 let minsize = 128; Self {
33 minsize,
34 use_aligned: true, unroll_factor: 4, }
37 }
38}
39
40#[allow(dead_code)]
45pub fn mean_simd_optimized<F, D>(
46 x: &ArrayBase<D, Ix1>,
47 config: Option<SimdConfig>,
48) -> StatsResult<F>
49where
50 F: Float + NumCast + SimdUnifiedOps,
51 D: Data<Elem = F>,
52{
53 check_not_empty(x, "x").map_err(|_| {
55 crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
56 })?;
57
58 let config = config.unwrap_or_default();
59 let n = x.len();
60
61 if n < config.minsize {
62 let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
64 return Ok(sum / F::from(n).unwrap());
65 }
66
67 let sum = chunked_simd_sum(x, &config)?;
69 Ok(sum / F::from(n).unwrap())
70}
71
72#[allow(dead_code)]
76pub fn variance_simd_optimized<F, D>(
77 x: &ArrayBase<D, Ix1>,
78 ddof: usize,
79 config: Option<SimdConfig>,
80) -> StatsResult<F>
81where
82 F: Float + NumCast + SimdUnifiedOps,
83 D: Data<Elem = F>,
84{
85 let n = x.len();
86 if n <= ddof {
87 return Err(crate::error::StatsError::invalid_argument(
88 "Not enough data points for the given degrees of freedom",
89 ));
90 }
91
92 let config = config.unwrap_or_default();
93
94 if n < config.minsize {
95 return variance_scalar_welford(x, ddof);
97 }
98
99 let mean = mean_simd_optimized(x, Some(config))?;
101 let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
102
103 Ok(sum_sq_dev / F::from(n - ddof).unwrap())
104}
105
106#[allow(dead_code)]
110pub fn stats_simd_single_pass<F, D>(
111 x: &ArrayBase<D, Ix1>,
112 config: Option<SimdConfig>,
113) -> StatsResult<(F, F, F, F, F, F)>
114where
115 F: Float + NumCast + SimdUnifiedOps,
116 D: Data<Elem = F>,
117{
118 if x.is_empty() {
119 return Err(crate::error::StatsError::invalid_argument(
120 "Cannot compute statistics of empty array",
121 ));
122 }
123
124 let config = config.unwrap_or_default();
125 let n = x.len();
126 let n_f = F::from(n).unwrap();
127
128 if n < config.minsize {
129 return stats_scalar_single_pass(x);
131 }
132
133 let capabilities = PlatformCapabilities::detect();
135 let simd_width = if capabilities.simd_available { 8 } else { 1 };
136
137 let mut m1 = F::zero(); let mut m2 = F::zero(); let mut m3 = F::zero(); let mut m4 = F::zero(); let mut min = x[0];
143 let mut max = x[0];
144
145 let chunks = x.len() / simd_width;
147 let _remainder = x.len() % simd_width;
148
149 for chunk_idx in 0..chunks {
150 let start = chunk_idx * simd_width;
151 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
152
153 let chunk_sum = F::simd_sum(&chunk);
155 m1 = m1 + chunk_sum;
156
157 let chunk_min = F::simd_min_element(&chunk);
159 let chunk_max = F::simd_max_element(&chunk);
160 if chunk_min < min {
161 min = chunk_min;
162 }
163 if chunk_max > max {
164 max = chunk_max;
165 }
166 }
167
168 let remainder_start = chunks * simd_width;
170 for i in remainder_start..x.len() {
171 let val = x[i];
172 m1 = m1 + val;
173 if val < min {
174 min = val;
175 }
176 if val > max {
177 max = val;
178 }
179 }
180
181 let mean = m1 / n_f;
183
184 for chunk_idx in 0..chunks {
186 let start = chunk_idx * simd_width;
187 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
188
189 for &val in chunk.iter() {
191 let dev = val - mean;
192 let dev2 = dev * dev;
193 let dev3 = dev2 * dev;
194 let dev4 = dev3 * dev;
195
196 m2 = m2 + dev2;
197 m3 = m3 + dev3;
198 m4 = m4 + dev4;
199 }
200 }
201
202 for i in remainder_start..x.len() {
204 let dev = x[i] - mean;
205 let dev2 = dev * dev;
206 let dev3 = dev2 * dev;
207 let dev4 = dev3 * dev;
208
209 m2 = m2 + dev2;
210 m3 = m3 + dev3;
211 m4 = m4 + dev4;
212 }
213
214 let variance = m2 / F::from(n - 1).unwrap();
216 let std_dev = variance.sqrt();
217
218 let skewness = if std_dev > F::epsilon() {
219 (m3 / n_f) / (std_dev * std_dev * std_dev)
220 } else {
221 F::zero()
222 };
223
224 let kurtosis = if variance > F::epsilon() {
225 (m4 / n_f) / (variance * variance) - F::from(3).unwrap()
226 } else {
227 F::zero()
228 };
229
230 Ok((mean, variance, min, max, skewness, kurtosis))
231}
232
233#[allow(dead_code)]
235fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
236where
237 F: Float + NumCast + SimdUnifiedOps,
238 D: Data<Elem = F>,
239{
240 let capabilities = PlatformCapabilities::detect();
241 let _simd_width = if capabilities.simd_available { 8 } else { 1 };
242
243 const CHUNK_SIZE: usize = 1024;
245 let mut total_sum = F::zero();
246
247 for chunk in x.windows(CHUNK_SIZE) {
248 let chunk_sum = F::simd_sum(&chunk.view());
249 total_sum = total_sum + chunk_sum;
250 }
251
252 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
254 if processed < x.len() {
255 let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
256 let remainder_sum = F::simd_sum(&remainder);
257 total_sum = total_sum + remainder_sum;
258 }
259
260 Ok(total_sum)
261}
262
263#[allow(dead_code)]
265fn chunked_simd_sum_squared_deviations<F, D>(
266 x: &ArrayBase<D, Ix1>,
267 mean: F,
268 config: &SimdConfig,
269) -> StatsResult<F>
270where
271 F: Float + NumCast + SimdUnifiedOps,
272 D: Data<Elem = F>,
273{
274 const CHUNK_SIZE: usize = 1024;
275 let mut total_sum = F::zero();
276
277 for chunk in x.windows(CHUNK_SIZE) {
279 let chunk_sum = chunk
280 .iter()
281 .map(|&val| {
282 let dev = val - mean;
283 dev * dev
284 })
285 .fold(F::zero(), |acc, val| acc + val);
286 total_sum = total_sum + chunk_sum;
287 }
288
289 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
291 if processed < x.len() {
292 for i in processed..x.len() {
293 let dev = x[i] - mean;
294 total_sum = total_sum + dev * dev;
295 }
296 }
297
298 Ok(total_sum)
299}
300
301#[allow(dead_code)]
303fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
304where
305 F: Float + NumCast,
306 D: Data<Elem = F>,
307{
308 let mut mean = F::zero();
309 let mut m2 = F::zero();
310 let mut count = 0;
311
312 for &val in x.iter() {
313 count += 1;
314 let delta = val - mean;
315 mean = mean + delta / F::from(count).unwrap();
316 let delta2 = val - mean;
317 m2 = m2 + delta * delta2;
318 }
319
320 Ok(m2 / F::from(count - ddof).unwrap())
321}
322
323#[allow(dead_code)]
325fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
326where
327 F: Float + NumCast,
328 D: Data<Elem = F>,
329{
330 let n = x.len();
331 let n_f = F::from(n).unwrap();
332
333 let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
335
336 let mut m2 = F::zero();
338 let mut m3 = F::zero();
339 let mut m4 = F::zero();
340 let mut min = x[0];
341 let mut max = x[0];
342
343 for &val in x.iter() {
344 let dev = val - mean;
345 let dev2 = dev * dev;
346 let dev3 = dev2 * dev;
347 let dev4 = dev3 * dev;
348
349 m2 = m2 + dev2;
350 m3 = m3 + dev3;
351 m4 = m4 + dev4;
352
353 if val < min {
354 min = val;
355 }
356 if val > max {
357 max = val;
358 }
359 }
360
361 let variance = m2 / F::from(n - 1).unwrap();
362 let std_dev = variance.sqrt();
363
364 let skewness = if std_dev > F::epsilon() {
365 (m3 / n_f) / (std_dev * std_dev * std_dev)
366 } else {
367 F::zero()
368 };
369
370 let kurtosis = if variance > F::epsilon() {
371 (m4 / n_f) / (variance * variance) - F::from(3).unwrap()
372 } else {
373 F::zero()
374 };
375
376 Ok((mean, variance, min, max, skewness, kurtosis))
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Config that forces the SIMD code paths even for tiny inputs.
    fn simd_always() -> SimdConfig {
        SimdConfig {
            minsize: 0,
            ..SimdConfig::default()
        }
    }

    /// 0.0, 1.0, ..., 36.0 (37 elements). The length is not a multiple of 8,
    /// so both the blocked and the tail loops run.
    fn ramp37() -> Array1<f64> {
        (0..37_u32).map(f64::from).collect()
    }

    #[test]
    fn test_mean_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), None).unwrap();
        assert!((mean - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_simd_path() {
        let data = ramp37();
        let mean = mean_simd_optimized(&data.view(), Some(simd_always())).unwrap();
        assert!((mean - 18.0).abs() < 1e-10);
    }

    #[test]
    fn test_mean_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(mean_simd_optimized(&data.view(), None).is_err());
    }

    #[test]
    fn test_variance_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, None).unwrap();
        assert!((var - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_variance_simd_path() {
        let data = ramp37();
        let var = variance_simd_optimized(&data.view(), 1, Some(simd_always())).unwrap();
        // The sample variance of 0..N-1 is N(N+1)/12.
        assert!((var - 37.0 * 38.0 / 12.0).abs() < 1e-9);
    }

    #[test]
    fn test_variance_insufficient_dof() {
        let data = array![1.0f64, 2.0, 3.0];
        assert!(variance_simd_optimized(&data.view(), 3, None).is_err());
    }

    #[test]
    fn test_stats_single_pass() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max, skew, kurt) =
            stats_simd_single_pass(&data.view(), None).unwrap();

        assert!((mean - 3.0).abs() < 1e-10);
        assert!((var - 2.5).abs() < 1e-10);
        assert!((min - 1.0).abs() < 1e-10);
        assert!((max - 5.0).abs() < 1e-10);
        // Symmetric data: the third central moment is zero.
        assert!(skew.abs() < 1e-10);
        // (m4 / n) / var^2 - 3 = (34 / 5) / 6.25 - 3 = -1.912
        assert!((kurt - (-1.912)).abs() < 1e-10);
    }

    #[test]
    fn test_stats_single_pass_simd_path() {
        let data = ramp37();
        let (mean, var, min, max, skew, _kurt) =
            stats_simd_single_pass(&data.view(), Some(simd_always())).unwrap();

        assert!((mean - 18.0).abs() < 1e-10);
        assert!((var - 37.0 * 38.0 / 12.0).abs() < 1e-9);
        assert!((min - 0.0).abs() < 1e-10);
        assert!((max - 36.0).abs() < 1e-10);
        assert!(skew.abs() < 1e-10);
    }

    #[test]
    fn test_stats_empty_is_error() {
        let data = Array1::<f64>::zeros(0);
        assert!(stats_simd_single_pass(&data.view(), None).is_err());
    }
}