1use crate::error::{StatsError, StatsResult};
10use scirs2_core::ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2};
11use scirs2_core::numeric::{Float, NumCast};
12use scirs2_core::parallel_ops::{num_threads, par_chunks, IntoParallelIterator, ParallelIterator};
13use scirs2_core::validation::check_not_empty;
14use std::sync::Arc;
15
/// Configuration controlling when and how statistical computations
/// are executed in parallel.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Minimum input length before parallel execution is used
    /// (only consulted when `adaptive` is false).
    pub minsize: usize,
    /// Fixed chunk size; `None` derives one from input length and thread count.
    pub chunksize: Option<usize>,
    /// Cap on worker threads; `None` uses the pool's reported thread count.
    pub max_threads: Option<usize>,
    /// When true, decide parallelization from a size/overhead heuristic
    /// instead of the fixed `minsize` cutoff.
    pub adaptive: bool,
}
28
29impl Default for ParallelConfig {
30 fn default() -> Self {
31 Self {
32 minsize: 5_000, chunksize: None, max_threads: None, adaptive: true,
36 }
37 }
38}
39
40impl ParallelConfig {
41 pub fn with_threads(mut self, threads: usize) -> Self {
43 self.max_threads = Some(threads);
44 self
45 }
46
47 pub fn with_chunksize(mut self, size: usize) -> Self {
49 self.chunksize = Some(size);
50 self
51 }
52
53 pub fn should_parallelize(&self, n: usize) -> bool {
55 if self.adaptive {
56 let threads = self.max_threads.unwrap_or_else(num_threads);
58
59 let base_overhead = 800;
61 let overhead_factor = base_overhead + (threads.saturating_sub(1) * 200);
62
63 if n > 100_000 {
65 return true;
66 }
67
68 if n < 1_000 {
70 return false;
71 }
72
73 n > threads * overhead_factor
75 } else {
76 n >= self.minsize
77 }
78 }
79
80 pub fn get_chunksize(&self, n: usize) -> usize {
82 if let Some(size) = self.chunksize {
83 size
84 } else {
85 let threads = self.max_threads.unwrap_or(num_threads());
87 (n / threads).max(1000)
88 }
89 }
90}
91
92#[allow(dead_code)]
96pub fn mean_parallel_enhanced<F, D>(
97 x: &ArrayBase<D, Ix1>,
98 config: Option<ParallelConfig>,
99) -> StatsResult<F>
100where
101 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
102 D: Data<Elem = F> + Sync,
103{
104 check_not_empty(x, "x")
106 .map_err(|_| StatsError::invalid_argument("Cannot compute mean of empty array"))?;
107
108 let config = config.unwrap_or_default();
109 let n = x.len();
110
111 if !config.should_parallelize(n) {
112 let sum = x.iter().fold(F::zero(), |acc, &val| acc + val);
114 return Ok(sum / F::from(n).expect("Failed to convert to float"));
115 }
116
117 let sum = if let Some(slice) = x.as_slice() {
119 parallel_sum_slice(slice, &config)
121 } else {
122 parallel_sum_indexed(x, &config)
124 };
125
126 Ok(sum / F::from(n).expect("Failed to convert to float"))
127}
128
129#[allow(dead_code)]
133pub fn variance_parallel_enhanced<F, D>(
134 x: &ArrayBase<D, Ix1>,
135 ddof: usize,
136 config: Option<ParallelConfig>,
137) -> StatsResult<F>
138where
139 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
140 D: Data<Elem = F> + Sync,
141{
142 let n = x.len();
143 if n <= ddof {
144 return Err(StatsError::invalid_argument(
145 "Not enough data points for the given degrees of freedom",
146 ));
147 }
148
149 let config = config.unwrap_or_default();
150
151 if !config.should_parallelize(n) {
152 return variance_sequential_welford(x, ddof);
154 }
155
156 let chunksize = config.get_chunksize(n);
158 let n_chunks = n.div_ceil(chunksize);
159
160 let chunk_stats: Vec<(F, F, usize)> = (0..n_chunks)
162 .collect::<Vec<_>>()
163 .into_par_iter()
164 .map(|chunk_idx| {
165 let start = chunk_idx * chunksize;
166 let end = (start + chunksize).min(n);
167
168 let mut local_mean = F::zero();
169 let mut local_m2 = F::zero();
170 let mut count = 0;
171
172 for i in start..end {
173 count += 1;
174 let val = x[i];
175 let delta = val - local_mean;
176 local_mean =
177 local_mean + delta / F::from(count).expect("Failed to convert to float");
178 let delta2 = val - local_mean;
179 local_m2 = local_m2 + delta * delta2;
180 }
181
182 (local_mean, local_m2, count)
183 })
184 .collect();
185
186 let (_total_mean, total_m2__, total_count) = combine_welford_stats(&chunk_stats);
188
189 Ok(total_m2__ / F::from(n - ddof).expect("Failed to convert to float"))
190}
191
192#[allow(dead_code)]
196pub fn corrcoef_parallel_enhanced<F, D>(
197 data: &ArrayBase<D, Ix2>,
198 config: Option<ParallelConfig>,
199) -> StatsResult<Array2<F>>
200where
201 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
202 D: Data<Elem = F> + Sync,
203{
204 let (n_samples_, n_features) = data.dim();
205
206 if n_samples_ == 0 || n_features == 0 {
207 return Err(StatsError::invalid_argument("Empty data matrix"));
208 }
209
210 let config = config.unwrap_or_default();
211
212 let means: Vec<F> = (0..n_features)
214 .collect::<Vec<_>>()
215 .into_par_iter()
216 .map(|j| {
217 let col = data.column(j);
218 mean_parallel_enhanced(&col, Some(config.clone())).unwrap_or(F::zero())
219 })
220 .collect();
221
222 let mut corr_matrix = Array2::zeros((n_features, n_features));
224
225 let indices: Vec<(usize, usize)> = (0..n_features)
227 .flat_map(|i| (i..n_features).map(move |j| (i, j)))
228 .collect();
229
230 let correlations: Vec<((usize, usize), F)> = indices
231 .into_par_iter()
232 .map(|(i, j)| {
233 let corr = if i == j {
234 F::one() } else {
236 compute_correlation_pair(&data.column(i), &data.column(j), means[i], means[j])
237 };
238 ((i, j), corr)
239 })
240 .collect();
241
242 for ((i, j), corr) in correlations {
244 corr_matrix[(i, j)] = corr;
245 if i != j {
246 corr_matrix[(j, i)] = corr; }
248 }
249
250 Ok(corr_matrix)
251}
252
253#[allow(dead_code)]
257pub fn bootstrap_parallel_enhanced<F, D>(
258 data: &ArrayBase<D, Ix1>,
259 n_samples_: usize,
260 statistic_fn: impl Fn(&ArrayView1<F>) -> F + Send + Sync,
261 config: Option<ParallelConfig>,
262) -> StatsResult<Array1<F>>
263where
264 F: Float + NumCast + Send + Sync,
265 D: Data<Elem = F> + Sync,
266{
267 if data.is_empty() {
268 return Err(StatsError::invalid_argument("Cannot bootstrap empty data"));
269 }
270
271 let _config = config.unwrap_or_default();
272 let data_arc = Arc::new(data.to_owned());
273 let n = data.len();
274
275 let stats: Vec<F> = (0..n_samples_)
277 .collect::<Vec<_>>()
278 .into_par_iter()
279 .map(|sample_idx| {
280 use scirs2_core::random::rngs::StdRng;
281 use scirs2_core::random::{Rng, SeedableRng};
282
283 let mut rng = StdRng::seed_from_u64(sample_idx as u64);
285 let mut sample = Array1::zeros(n);
286
287 for i in 0..n {
289 let idx = rng.random_range(0..n);
290 sample[i] = data_arc[idx];
291 }
292
293 statistic_fn(&sample.view())
294 })
295 .collect();
296
297 Ok(Array1::from(stats))
298}
299
300#[allow(dead_code)]
302fn parallel_sum_slice<F>(slice: &[F], config: &ParallelConfig) -> F
303where
304 F: Float + NumCast + Send + Sync + std::iter::Sum + std::fmt::Display,
305{
306 let chunksize = config.get_chunksize(slice.len());
307
308 par_chunks(slice, chunksize)
309 .map(|chunk| chunk.iter().fold(F::zero(), |acc, &val| acc + val))
310 .reduce(|| F::zero(), |a, b| a + b)
311}
312
313#[allow(dead_code)]
315fn parallel_sum_indexed<F, D>(arr: &ArrayBase<D, Ix1>, config: &ParallelConfig) -> F
316where
317 F: Float + NumCast + Send + Sync + std::iter::Sum<F> + std::fmt::Display,
318 D: Data<Elem = F> + Sync,
319{
320 let n = arr.len();
321 let chunksize = config.get_chunksize(n);
322 let n_chunks = n.div_ceil(chunksize);
323
324 (0..n_chunks)
325 .collect::<Vec<_>>()
326 .into_par_iter()
327 .map(|chunk_idx| {
328 let start = chunk_idx * chunksize;
329 let end = (start + chunksize).min(n);
330
331 (start..end)
332 .map(|i| arr[i])
333 .fold(F::zero(), |acc, val| acc + val)
334 })
335 .reduce(|| F::zero(), |a, b| a + b)
336}
337
338#[allow(dead_code)]
340fn variance_sequential_welford<F, D>(x: &ArrayBase<D, Ix1>, ddof: usize) -> StatsResult<F>
341where
342 F: Float + NumCast,
343 D: Data<Elem = F>,
344{
345 let mut mean = F::zero();
346 let mut m2 = F::zero();
347 let mut count = 0;
348
349 for &val in x.iter() {
350 count += 1;
351 let delta = val - mean;
352 mean = mean + delta / F::from(count).expect("Failed to convert to float");
353 let delta2 = val - mean;
354 m2 = m2 + delta * delta2;
355 }
356
357 Ok(m2 / F::from(count - ddof).expect("Failed to convert to float"))
358}
359
360#[allow(dead_code)]
362fn combine_welford_stats<F>(stats: &[(F, F, usize)]) -> (F, F, usize)
363where
364 F: Float + NumCast + std::fmt::Display,
365{
366 stats.iter().fold(
367 (F::zero(), F::zero(), 0),
368 |(mean_a, m2_a, count_a), &(mean_b, m2_b, count_b)| {
369 let count = count_a + count_b;
370 let delta = mean_b - mean_a;
371 let mean = mean_a
372 + delta * F::from(count_b).expect("Failed to convert to float")
373 / F::from(count).expect("Failed to convert to float");
374 let m2 = m2_a
375 + m2_b
376 + delta
377 * delta
378 * F::from(count_a).expect("Failed to convert to float")
379 * F::from(count_b).expect("Failed to convert to float")
380 / F::from(count).expect("Failed to convert to float");
381 (mean, m2, count)
382 },
383 )
384}
385
386#[allow(dead_code)]
388fn compute_correlation_pair<F>(x: &ArrayView1<F>, y: &ArrayView1<F>, mean_x: F, meany: F) -> F
389where
390 F: Float + NumCast + std::fmt::Display,
391{
392 let n = x.len();
393 let mut cov = F::zero();
394 let mut var_x = F::zero();
395 let mut var_y = F::zero();
396
397 for i in 0..n {
398 let dx = x[i] - mean_x;
399 let dy = y[i] - meany;
400 cov = cov + dx * dy;
401 var_x = var_x + dx * dx;
402 var_y = var_y + dy * dy;
403 }
404
405 if var_x > F::epsilon() && var_y > F::epsilon() {
406 cov / (var_x * var_y).sqrt()
407 } else {
408 F::zero()
409 }
410}
411
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Exercises both the adaptive threshold heuristic and the
    // fixed chunk-size builder path.
    #[test]
    fn test_parallel_config() {
        let config = ParallelConfig::default();
        assert!(config.should_parallelize(100_000));
        assert!(!config.should_parallelize(100));

        let config_fixed = ParallelConfig::default()
            .with_threads(4)
            .with_chunksize(1000);
        assert_eq!(config_fixed.get_chunksize(10_000), 1000);
    }

    // Mean of 0..=9999 is 4999.5.
    #[test]
    fn test_mean_parallel_enhanced() {
        let data = Array1::from_vec((0..10_000).map(|i| i as f64).collect());
        let mean = mean_parallel_enhanced(&data.view(), None).expect("Operation failed");
        assert!((mean - 4999.5).abs() < 1e-10);
    }

    // Sample variance (ddof = 1) of 1..=5 is 2.5.
    #[test]
    fn test_variance_parallel_enhanced() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let var = variance_parallel_enhanced(&data.view(), 1, None).expect("Operation failed");
        assert!((var - 2.5).abs() < 1e-10);
    }
}