1use crate::error::StatsResult;
10use scirs2_core::ndarray::{ArrayBase, Data, Ix1};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
13use scirs2_core::validation::check_not_empty;
14
/// Tuning parameters accepted by the SIMD-accelerated statistics routines.
#[derive(Debug, Clone, Copy)]
pub struct SimdConfig {
    /// Element-count threshold: arrays shorter than this are handled by the
    /// scalar fallback paths instead of the chunked/SIMD ones.
    pub minsize: usize,
    /// Alignment preference flag.
    /// NOTE(review): not read by any routine visible in this part of the file.
    pub use_aligned: bool,
    /// Requested loop-unrolling factor.
    /// NOTE(review): not read by any routine visible in this part of the file.
    pub unroll_factor: usize,
}

impl Default for SimdConfig {
    /// Default tuning: threshold of 128 elements, aligned flag set, unroll factor 4.
    fn default() -> Self {
        SimdConfig {
            minsize: 128,
            use_aligned: true,
            unroll_factor: 4,
        }
    }
}
39
40#[allow(dead_code)]
45pub fn mean_simd_optimized<F, D>(
46 x: &ArrayBase<D, Ix1>,
47 config: Option<SimdConfig>,
48) -> StatsResult<F>
49where
50 F: Float + NumCast + SimdUnifiedOps,
51 D: Data<Elem = F>,
52{
53 check_not_empty(x, "x").map_err(|_| {
55 crate::error::StatsError::invalid_argument("Cannot compute mean of empty array")
56 })?;
57
58 let config = config.unwrap_or_default();
59 let n = x.len();
60
61 if n < config.minsize {
62 let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
64 return Ok(sum / F::from(n).expect("Failed to convert to float"));
65 }
66
67 let sum = chunked_simd_sum(x, &config)?;
69 Ok(sum / F::from(n).expect("Failed to convert to float"))
70}
71
72#[allow(dead_code)]
76pub fn variance_simd_optimized<F, D>(
77 x: &ArrayBase<D, Ix1>,
78 ddof: usize,
79 config: Option<SimdConfig>,
80) -> StatsResult<F>
81where
82 F: Float + NumCast + SimdUnifiedOps,
83 D: Data<Elem = F>,
84{
85 let n = x.len();
86 if n <= ddof {
87 return Err(crate::error::StatsError::invalid_argument(
88 "Not enough data points for the given degrees of freedom",
89 ));
90 }
91
92 let config = config.unwrap_or_default();
93
94 if n < config.minsize {
95 return variance_scalar_welford(x, ddof);
97 }
98
99 let mean = mean_simd_optimized(x, Some(config))?;
101 let sum_sq_dev = chunked_simd_sum_squared_deviations(x, mean, &config)?;
102
103 Ok(sum_sq_dev / F::from(n - ddof).expect("Failed to convert to float"))
104}
105
106#[allow(dead_code)]
110pub fn stats_simd_single_pass<F, D>(
111 x: &ArrayBase<D, Ix1>,
112 config: Option<SimdConfig>,
113) -> StatsResult<(F, F, F, F, F, F)>
114where
115 F: Float + NumCast + SimdUnifiedOps,
116 D: Data<Elem = F>,
117{
118 if x.is_empty() {
119 return Err(crate::error::StatsError::invalid_argument(
120 "Cannot compute statistics of empty array",
121 ));
122 }
123
124 let config = config.unwrap_or_default();
125 let n = x.len();
126 let n_f = F::from(n).expect("Failed to convert to float");
127
128 if n < config.minsize {
129 return stats_scalar_single_pass(x);
131 }
132
133 let capabilities = PlatformCapabilities::detect();
135 let simd_width = if capabilities.simd_available { 8 } else { 1 };
136
137 let mut m1 = F::zero(); let mut m2 = F::zero(); let mut m3 = F::zero(); let mut m4 = F::zero(); let mut min = x[0];
143 let mut max = x[0];
144
145 let chunks = x.len() / simd_width;
147 let _remainder = x.len() % simd_width;
148
149 for chunk_idx in 0..chunks {
150 let start = chunk_idx * simd_width;
151 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
152
153 let chunk_sum = F::simd_sum(&chunk);
155 m1 = m1 + chunk_sum;
156
157 let chunk_min = F::simd_min_element(&chunk);
159 let chunk_max = F::simd_max_element(&chunk);
160 if chunk_min < min {
161 min = chunk_min;
162 }
163 if chunk_max > max {
164 max = chunk_max;
165 }
166 }
167
168 let remainder_start = chunks * simd_width;
170 for i in remainder_start..x.len() {
171 let val = x[i];
172 m1 = m1 + val;
173 if val < min {
174 min = val;
175 }
176 if val > max {
177 max = val;
178 }
179 }
180
181 let mean = m1 / n_f;
183
184 for chunk_idx in 0..chunks {
186 let start = chunk_idx * simd_width;
187 let chunk = x.slice(scirs2_core::ndarray::s![start..start + simd_width]);
188
189 for &val in chunk.iter() {
191 let dev = val - mean;
192 let dev2 = dev * dev;
193 let dev3 = dev2 * dev;
194 let dev4 = dev3 * dev;
195
196 m2 = m2 + dev2;
197 m3 = m3 + dev3;
198 m4 = m4 + dev4;
199 }
200 }
201
202 for i in remainder_start..x.len() {
204 let dev = x[i] - mean;
205 let dev2 = dev * dev;
206 let dev3 = dev2 * dev;
207 let dev4 = dev3 * dev;
208
209 m2 = m2 + dev2;
210 m3 = m3 + dev3;
211 m4 = m4 + dev4;
212 }
213
214 let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
216 let std_dev = variance.sqrt();
217
218 let skewness = if std_dev > F::epsilon() {
219 (m3 / n_f) / (std_dev * std_dev * std_dev)
220 } else {
221 F::zero()
222 };
223
224 let kurtosis = if variance > F::epsilon() {
225 (m4 / n_f) / (variance * variance)
226 - F::from(3).expect("Failed to convert constant to float")
227 } else {
228 F::zero()
229 };
230
231 Ok((mean, variance, min, max, skewness, kurtosis))
232}
233
234#[allow(dead_code)]
236fn chunked_simd_sum<F, D>(x: &ArrayBase<D, Ix1>, config: &SimdConfig) -> StatsResult<F>
237where
238 F: Float + NumCast + SimdUnifiedOps,
239 D: Data<Elem = F>,
240{
241 let capabilities = PlatformCapabilities::detect();
242 let _simd_width = if capabilities.simd_available { 8 } else { 1 };
243
244 const CHUNK_SIZE: usize = 1024;
246 let mut total_sum = F::zero();
247
248 for chunk in x.windows(CHUNK_SIZE) {
249 let chunk_sum = F::simd_sum(&chunk.view());
250 total_sum = total_sum + chunk_sum;
251 }
252
253 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
255 if processed < x.len() {
256 let remainder = x.slice(scirs2_core::ndarray::s![processed..]);
257 let remainder_sum = F::simd_sum(&remainder);
258 total_sum = total_sum + remainder_sum;
259 }
260
261 Ok(total_sum)
262}
263
264#[allow(dead_code)]
266fn chunked_simd_sum_squared_deviations<F, D>(
267 x: &ArrayBase<D, Ix1>,
268 mean: F,
269 config: &SimdConfig,
270) -> StatsResult<F>
271where
272 F: Float + NumCast + SimdUnifiedOps,
273 D: Data<Elem = F>,
274{
275 const CHUNK_SIZE: usize = 1024;
276 let mut total_sum = F::zero();
277
278 for chunk in x.windows(CHUNK_SIZE) {
280 let chunk_sum = chunk
281 .iter()
282 .map(|&val| {
283 let dev = val - mean;
284 dev * dev
285 })
286 .fold(F::zero(), |acc, val| acc + val);
287 total_sum = total_sum + chunk_sum;
288 }
289
290 let processed = (x.len() / CHUNK_SIZE) * CHUNK_SIZE;
292 if processed < x.len() {
293 for i in processed..x.len() {
294 let dev = x[i] - mean;
295 total_sum = total_sum + dev * dev;
296 }
297 }
298
299 Ok(total_sum)
300}
301
302#[allow(dead_code)]
304fn variance_scalar_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
305where
306 F: Float + NumCast,
307 D: Data<Elem = F>,
308{
309 let mut mean = F::zero();
310 let mut m2 = F::zero();
311 let mut count = 0;
312
313 for &val in x.iter() {
314 count += 1;
315 let delta = val - mean;
316 mean = mean + delta / F::from(count).expect("Failed to convert to float");
317 let delta2 = val - mean;
318 m2 = m2 + delta * delta2;
319 }
320
321 Ok(m2 / F::from(count - ddof).expect("Failed to convert to float"))
322}
323
324#[allow(dead_code)]
326fn stats_scalar_single_pass<F, D>(x: &ArrayBase<D, Ix1>) -> StatsResult<(F, F, F, F, F, F)>
327where
328 F: Float + NumCast,
329 D: Data<Elem = F>,
330{
331 let n = x.len();
332 let n_f = F::from(n).expect("Failed to convert to float");
333
334 let mean = x.iter().fold(F::zero(), |acc, &val| acc + val) / n_f;
336
337 let mut m2 = F::zero();
339 let mut m3 = F::zero();
340 let mut m4 = F::zero();
341 let mut min = x[0];
342 let mut max = x[0];
343
344 for &val in x.iter() {
345 let dev = val - mean;
346 let dev2 = dev * dev;
347 let dev3 = dev2 * dev;
348 let dev4 = dev3 * dev;
349
350 m2 = m2 + dev2;
351 m3 = m3 + dev3;
352 m4 = m4 + dev4;
353
354 if val < min {
355 min = val;
356 }
357 if val > max {
358 max = val;
359 }
360 }
361
362 let variance = m2 / F::from(n - 1).expect("Failed to convert to float");
363 let std_dev = variance.sqrt();
364
365 let skewness = if std_dev > F::epsilon() {
366 (m3 / n_f) / (std_dev * std_dev * std_dev)
367 } else {
368 F::zero()
369 };
370
371 let kurtosis = if variance > F::epsilon() {
372 (m4 / n_f) / (variance * variance)
373 - F::from(3).expect("Failed to convert constant to float")
374 } else {
375 F::zero()
376 };
377
378 Ok((mean, variance, min, max, skewness, kurtosis))
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Builds the array `[1.0, 2.0, ..., n]`.
    ///
    /// For this ramp the mean is `(n + 1) / 2` and the sample variance
    /// (`ddof = 1`) is `n (n + 1) / 12`.
    fn ramp(n: u32) -> Array1<f64> {
        Array1::from_iter((1..=n).map(f64::from))
    }

    #[test]
    fn test_mean_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let mean = mean_simd_optimized(&data.view(), None).expect("Operation failed");
        assert!((mean - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_variance_simd_optimized() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let var = variance_simd_optimized(&data.view(), 1, None).expect("Operation failed");
        assert!((var - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_stats_single_pass() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        let (mean, var, min, max, skew, kurt) =
            stats_simd_single_pass(&data.view(), None).expect("Operation failed");

        assert!((mean - 3.0).abs() < 1e-10);
        assert!((var - 2.5).abs() < 1e-10);
        assert!((min - 1.0).abs() < 1e-10);
        assert!((max - 5.0).abs() < 1e-10);
        // Symmetric data: the third central moment vanishes.
        assert!(skew.abs() < 1e-10);
        // (m4 / n) / var^2 - 3 = (34 / 5) / 6.25 - 3 = -1.912
        assert!((kurt - (-1.912)).abs() < 1e-10);
    }

    // The tests below use 200 elements, which is above the default `minsize`
    // of 128, so the non-scalar code paths are exercised.

    #[test]
    fn test_mean_above_simd_threshold() {
        let data = ramp(200);
        let mean = mean_simd_optimized(&data.view(), None).expect("Operation failed");
        assert!((mean - 100.5).abs() < 1e-9);
    }

    #[test]
    fn test_variance_above_simd_threshold() {
        let data = ramp(200);
        let var = variance_simd_optimized(&data.view(), 1, None).expect("Operation failed");
        // 200 * 201 / 12 = 3350
        assert!((var - 3350.0).abs() < 1e-9);
    }

    #[test]
    fn test_stats_above_simd_threshold() {
        let data = ramp(200);
        let (mean, var, min, max, skew, _kurt) =
            stats_simd_single_pass(&data.view(), None).expect("Operation failed");

        assert!((mean - 100.5).abs() < 1e-9);
        assert!((var - 3350.0).abs() < 1e-9);
        assert!((min - 1.0).abs() < 1e-10);
        assert!((max - 200.0).abs() < 1e-10);
        assert!(skew.abs() < 1e-9);
    }

    #[test]
    fn test_empty_input_is_rejected() {
        let empty = Array1::<f64>::from_vec(Vec::new());
        assert!(mean_simd_optimized(&empty.view(), None).is_err());
        assert!(variance_simd_optimized(&empty.view(), 0, None).is_err());
        assert!(stats_simd_single_pass(&empty.view(), None).is_err());
    }

    #[test]
    fn test_variance_rejects_excessive_ddof() {
        let data = array![1.0f64, 2.0, 3.0, 4.0, 5.0];
        assert!(variance_simd_optimized(&data.view(), 5, None).is_err());
    }
}