1use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::parallel_ops::{num_threads, par_chunks, IntoParallelIterator, ParallelIterator};
13use scirs2_core::validation::check_not_empty;
14use std::sync::Arc;
15
/// Tuning knobs that control when and how the statistical kernels in this
/// module switch from sequential to parallel execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Minimum element count before parallelizing (used when `adaptive` is false).
    pub minsize: usize,
    /// Explicit chunk size for splitting work; `None` derives one from input length and thread count.
    pub chunksize: Option<usize>,
    /// Upper bound on worker threads; `None` falls back to the runtime's thread count.
    pub max_threads: Option<usize>,
    /// When true, `should_parallelize` uses a thread-aware heuristic instead of `minsize`.
    pub adaptive: bool,
}
28
29impl Default for ParallelConfig {
30 fn default() -> Self {
31 Self {
32 minsize: 5_000, chunksize: None, max_threads: None, adaptive: true,
36 }
37 }
38}
39
40impl ParallelConfig {
41 pub fn with_threads(mut self, threads: usize) -> Self {
43 self.max_threads = Some(threads);
44 self
45 }
46
47 pub fn with_chunksize(mut self, size: usize) -> Self {
49 self.chunksize = Some(size);
50 self
51 }
52
53 pub fn should_parallelize(&self, n: usize) -> bool {
55 if self.adaptive {
56 let threads = self.max_threads.unwrap_or_else(num_threads);
58
59 let base_overhead = 800;
61 let overhead_factor = base_overhead + (threads.saturating_sub(1) * 200);
62
63 if n > 100_000 {
65 return true;
66 }
67
68 if n < 1_000 {
70 return false;
71 }
72
73 n > threads * overhead_factor
75 } else {
76 n >= self.minsize
77 }
78 }
79
80 pub fn get_chunksize(&self, n: usize) -> usize {
82 if let Some(size) = self.chunksize {
83 size
84 } else {
85 let threads = self.max_threads.unwrap_or(num_threads());
87 (n / threads).max(1000)
88 }
89 }
90}
91
92#[allow(dead_code)]
96pub fn mean_parallel_enhanced<F, D>(
97 x: &ArrayBase<D, Ix1>,
98 config: Option<ParallelConfig>,
99) -> StatsResult<F>
100where
101 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
102 D: Data<Elem = F> + Sync,
103{
104 check_not_empty(x, "x")
106 .map_err(|_| StatsError::invalid_argument("Cannot compute mean of empty array"))?;
107
108 let config = config.unwrap_or_default();
109 let n = x.len();
110
111 if !config.should_parallelize(n) {
112 let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
114 return Ok(sum / F::from(n).unwrap());
115 }
116
117 let sum = if let Some(slice) = x.as_slice() {
119 parallel_sum_slice(slice, &config)
121 } else {
122 parallel_sum_indexed(x, &config)
124 };
125
126 Ok(sum / F::from(n).unwrap())
127}
128
129#[allow(dead_code)]
133pub fn variance_parallel_enhanced<F, D>(
134 x: &ArrayBase<D, Ix1>,
135 ddof: usize,
136 config: Option<ParallelConfig>,
137) -> StatsResult<F>
138where
139 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
140 D: Data<Elem = F> + Sync,
141{
142 let n = x.len();
143 if n <= ddof {
144 return Err(StatsError::invalid_argument(
145 "Not enough data points for the given degrees of freedom",
146 ));
147 }
148
149 let config = config.unwrap_or_default();
150
151 if !config.should_parallelize(n) {
152 return variance_sequential_welford(x, ddof);
154 }
155
156 let chunksize = config.get_chunksize(n);
158 let n_chunks = n.div_ceil(chunksize);
159
160 let chunk_stats: Vec<(F, F, usize)> = (0..n_chunks)
162 .collect::<Vec<_>>()
163 .into_par_iter()
164 .map(|chunk_idx| {
165 let start = chunk_idx * chunksize;
166 let end = (start + chunksize).min(n);
167
168 let mut local_mean = F::zero();
169 let mut local_m2 = F::zero();
170 let mut count = 0;
171
172 for i in start..end {
173 count += 1;
174 let val = x[i];
175 let delta = val - local_mean;
176 local_mean = local_mean + delta / F::from(count).unwrap();
177 let delta2 = val - local_mean;
178 local_m2 = local_m2 + delta * delta2;
179 }
180
181 (local_mean, local_m2, count)
182 })
183 .collect();
184
185 let (_total_mean, total_m2__, total_count) = combine_welford_stats(&chunk_stats);
187
188 Ok(total_m2__ / F::from(n - ddof).unwrap())
189}
190
191#[allow(dead_code)]
195pub fn corrcoef_parallel_enhanced<F, D>(
196 data: &ArrayBase<D, Ix2>,
197 config: Option<ParallelConfig>,
198) -> StatsResult<Array2<F>>
199where
200 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
201 D: Data<Elem = F> + Sync,
202{
203 let (n_samples_, n_features) = data.dim();
204
205 if n_samples_ == 0 || n_features == 0 {
206 return Err(StatsError::invalid_argument("Empty data matrix"));
207 }
208
209 let config = config.unwrap_or_default();
210
211 let means: Vec<F> = (0..n_features)
213 .collect::<Vec<_>>()
214 .into_par_iter()
215 .map(|j| {
216 let col = data.column(j);
217 mean_parallel_enhanced(&col, Some(config.clone())).unwrap_or(F::zero())
218 })
219 .collect();
220
221 let mut corr_matrix = Array2::zeros((n_features, n_features));
223
224 let indices: Vec<(usize, usize)> = (0..n_features)
226 .flat_map(|i| (i..n_features).map(move |j| (i, j)))
227 .collect();
228
229 let correlations: Vec<((usize, usize), F)> = indices
230 .into_par_iter()
231 .map(|(i, j)| {
232 let corr = if i == j {
233 F::one() } else {
235 compute_correlation_pair(&data.column(i), &data.column(j), means[i], means[j])
236 };
237 ((i, j), corr)
238 })
239 .collect();
240
241 for ((i, j), corr) in correlations {
243 corr_matrix[(i, j)] = corr;
244 if i != j {
245 corr_matrix[(j, i)] = corr; }
247 }
248
249 Ok(corr_matrix)
250}
251
252#[allow(dead_code)]
256pub fn bootstrap_parallel_enhanced<F, D>(
257 data: &ArrayBase<D, Ix1>,
258 n_samples_: usize,
259 statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
260 config: Option<ParallelConfig>,
261) -> StatsResult<Array1<F>>
262where
263 F: Float + NumCast + Send + Sync,
264 D: Data<Elem = F> + Sync,
265{
266 if data.is_empty() {
267 return Err(StatsError::invalid_argument("Cannot bootstrap empty data"));
268 }
269
270 let _config = config.unwrap_or_default();
271 let data_arc = Arc::new(data.to_owned());
272 let n = data.len();
273
274 let stats: Vec<F> = (0..n_samples_)
276 .collect::<Vec<_>>()
277 .into_par_iter()
278 .map(|sample_idx| {
279 use scirs2_core::random::rngs::StdRng;
280 use scirs2_core::random::{Rng, SeedableRng};
281
282 let mut rng = StdRng::seed_from_u64(sample_idx as u64);
284 let mut sample = Array1::zeros(n);
285
286 for i in 0..n {
288 let idx = rng.gen_range(0..n);
289 sample[i] = data_arc[idx];
290 }
291
292 statistic_fn(&sample.view())
293 })
294 .collect();
295
296 Ok(Array1::from(stats))
297}
298
299#[allow(dead_code)]
301fn parallel_sum_slice<F>(slice: &[F], config: &ParallelConfig) -> F
302where
303 F: Float + NumCast + Send + Sync + std::iter::Sum + std::fmt::Display,
304{
305 let chunksize = config.get_chunksize(slice.len());
306
307 par_chunks(slice, chunksize)
308 .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + val))
309 .reduce(|| F::zero(), |a, b| a + b)
310}
311
312#[allow(dead_code)]
314fn parallel_sum_indexed<F, D>(arr: &ArrayBase<D, Ix1>, config: &ParallelConfig) -> F
315where
316 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
317 D: Data<Elem = F> + Sync,
318{
319 let n = arr.len();
320 let chunksize = config.get_chunksize(n);
321 let n_chunks = n.div_ceil(chunksize);
322
323 (0..n_chunks)
324 .collect::<Vec<_>>()
325 .into_par_iter()
326 .map(|chunk_idx| {
327 let start = chunk_idx * chunksize;
328 let end = (start + chunksize).min(n);
329
330 (start..end)
331 .map(|i| arr[i])
332 .fold(F::zero(), |acc, val| acc + val)
333 })
334 .reduce(|| F::zero(), |a, b| a + b)
335}
336
337#[allow(dead_code)]
339fn variance_sequential_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
340where
341 F: Float + NumCast,
342 D: Data<Elem = F>,
343{
344 let mut mean = F::zero();
345 let mut m2 = F::zero();
346 let mut count = 0;
347
348 for &val in x.iter() {
349 count += 1;
350 let delta = val - mean;
351 mean = mean + delta / F::from(count).unwrap();
352 let delta2 = val - mean;
353 m2 = m2 + delta * delta2;
354 }
355
356 Ok(m2 / F::from(count - ddof).unwrap())
357}
358
359#[allow(dead_code)]
361fn combine_welford_stats<F>(stats: &[(F, F, usize)]) -> (F, F, usize)
362where
363 F: Float + NumCast + std::fmt::Display,
364{
365 stats.iter().fold(
366 (F::zero(), F::zero(), 0),
367 |(mean_a, m2_a, count_a), &(mean_b, m2_b, count_b)| {
368 let count = count_a + count_b;
369 let delta = mean_b - mean_a;
370 let mean = mean_a + delta * F::from(count_b).unwrap() / F::from(count).unwrap();
371 let m2 = m2_a
372 + m2_b
373 + delta * delta * F::from(count_a).unwrap() * F::from(count_b).unwrap()
374 / F::from(count).unwrap();
375 (mean, m2, count)
376 },
377 )
378}
379
380#[allow(dead_code)]
382fn compute_correlation_pair<F>(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F
383where
384 F: Float + NumCast + std::fmt::Display,
385{
386 let n = x.len();
387 let mut cov = F::zero();
388 let mut var_x = F::zero();
389 let mut var_y = F::zero();
390
391 for i in 0..n {
392 let dx = x[i] - mean_x;
393 let dy = y[i] - meany;
394 cov = cov + dx * dy;
395 var_x = var_x + dx * dx;
396 var_y = var_y + dy * dy;
397 }
398
399 if var_x > F::epsilon() && var_y > F::epsilon() {
400 cov / (var_x * var_y).sqrt()
401 } else {
402 F::zero()
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Heuristic thresholds: very large inputs parallelize, very small ones
    // don't; an explicit chunk size overrides the derived one.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.should_parallelize(100_000));
        assert!(!config.should_parallelize(100));

        let config_fixed = ParallelConfig::default()
            .with_threads(4)
            .with_chunksize(1000);
        assert_eq!(config_fixed.get_chunksize(10_000), 1000);
    }

    // Mean of 0..=9999 is (0 + 9999) / 2 = 4999.5.
    #[test]
    fn test_mean_parallel_enhanced() {
        let data = Array1::from_vec((0..10_000).map(|i| i as f64).collect());
        let mean = mean_parallel_enhanced(&data.view(), None).unwrap();
        assert!((mean - 4999.5).abs() < 1e-10);
    }

    // Sample variance (ddof = 1) of 1..=5 is 2.5.
    #[test]
    fn test_variance_parallel_enhanced() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = variance_parallel_enhanced(&data.view(), 1, None).unwrap();
        assert!((var - 2.5).abs() < 1e-10);
    }
}