1use scirs2_core::distributed::par_iter::{par_fold, par_map};
29use scirs2_core::distributed::primitives::{distributed_map, distributed_map_reduce};
30use scirs2_core::ndarray::{Array1, Array2};
31
32use crate::error::{DatasetsError, Result};
33use crate::utils::Dataset;
34
35pub fn par_map_rows<U, F>(dataset: &Dataset, f: F, num_workers: Option<usize>) -> Result<Vec<U>>
56where
57 U: Send + 'static,
58 F: Fn(Vec<f64>) -> U + Send + Sync + 'static,
59{
60 let rows: Vec<Vec<f64>> = dataset
63 .data
64 .rows()
65 .into_iter()
66 .map(|row| row.to_vec())
67 .collect();
68
69 let mapped = par_map(&rows, |row| f(row.clone()), num_workers);
70 Ok(mapped)
71}
72
73pub fn par_fold_rows<A, FoldOp, CombineOp>(
97 dataset: &Dataset,
98 identity: A,
99 fold_fn: FoldOp,
100 combine_fn: CombineOp,
101 num_workers: Option<usize>,
102) -> Result<A>
103where
104 A: Clone + Send + Sync + 'static,
105 FoldOp: Fn(A, &Vec<f64>) -> A + Send + Sync + 'static,
106 CombineOp: Fn(A, A) -> A + Send + Sync + 'static,
107{
108 let rows: Vec<Vec<f64>> = dataset
109 .data
110 .rows()
111 .into_iter()
112 .map(|row| row.to_vec())
113 .collect();
114
115 let result = par_fold(&rows, identity, fold_fn, combine_fn, num_workers);
116 Ok(result)
117}
118
119pub fn core_par_map_chunks<R, F>(
136 dataset: &Dataset,
137 chunk_size: usize,
138 n_workers: usize,
139 f: F,
140) -> Result<Vec<R>>
141where
142 R: Send + 'static,
143 F: Fn(Dataset) -> R + Send + Clone + 'static,
144{
145 let chunks = build_chunks(dataset, chunk_size)?;
146 let results = distributed_map(chunks, f, n_workers);
147 Ok(results)
148}
149
150pub fn core_map_reduce_chunks<R, S, F, G>(
174 dataset: &Dataset,
175 chunk_size: usize,
176 n_workers: usize,
177 map_fn: F,
178 reduce_fn: G,
179 initial: S,
180) -> Result<S>
181where
182 R: Send + 'static,
183 S: Send + Clone + 'static,
184 F: Fn(Dataset) -> R + Send + Clone + 'static,
185 G: Fn(S, R) -> S + Send + Clone + 'static,
186{
187 let chunks = build_chunks(dataset, chunk_size)?;
188 let result = distributed_map_reduce(chunks, map_fn, reduce_fn, initial, n_workers);
189 Ok(result)
190}
191
192pub fn par_feature_stats(
203 dataset: &Dataset,
204 chunk_size: usize,
205 n_workers: usize,
206) -> Result<FeatureStats> {
207 let n_features = dataset.n_features();
208 if n_features == 0 {
209 return Err(DatasetsError::InvalidFormat(
210 "Dataset has no features".to_string(),
211 ));
212 }
213
214 let chunks = build_chunks(dataset, chunk_size)?;
215 if chunks.is_empty() {
216 return Ok(FeatureStats::zeros(n_features));
217 }
218
219 let partial_stats: Vec<PartialStats> = distributed_map(
221 chunks,
222 move |chunk| PartialStats::from_dataset(&chunk),
223 n_workers,
224 );
225
226 let merged = partial_stats
228 .into_iter()
229 .reduce(|a, b| a.merge(&b))
230 .ok_or_else(|| DatasetsError::InvalidFormat("No chunks to reduce".to_string()))?;
231
232 Ok(merged.finalise())
233}
234
/// Mergeable per-chunk accumulator for feature statistics: a sample count
/// plus per-feature running sums, sums of squares, and elementwise extrema.
/// Chunk accumulators are combined with `merge` and converted into the
/// public `FeatureStats` by `finalise`.
#[derive(Debug, Clone)]
struct PartialStats {
    // Number of samples (rows) accumulated in this partial.
    n: usize,
    // Per-feature sum of values.
    sums: Vec<f64>,
    // Per-feature sum of squared values (used for the variance).
    sum_sq: Vec<f64>,
    // Per-feature minimum observed value (starts at +infinity).
    mins: Vec<f64>,
    // Per-feature maximum observed value (starts at -infinity).
    maxs: Vec<f64>,
}
251
252impl PartialStats {
253 fn from_dataset(ds: &Dataset) -> Self {
254 let n_features = ds.n_features();
255 let mut sums = vec![0.0f64; n_features];
256 let mut sum_sq = vec![0.0f64; n_features];
257 let mut mins = vec![f64::INFINITY; n_features];
258 let mut maxs = vec![f64::NEG_INFINITY; n_features];
259
260 for row in ds.data.rows() {
261 for (j, &v) in row.iter().enumerate() {
262 sums[j] += v;
263 sum_sq[j] += v * v;
264 if v < mins[j] {
265 mins[j] = v;
266 }
267 if v > maxs[j] {
268 maxs[j] = v;
269 }
270 }
271 }
272
273 Self {
274 n: ds.n_samples(),
275 sums,
276 sum_sq,
277 mins,
278 maxs,
279 }
280 }
281
282 fn merge(&self, other: &Self) -> Self {
284 let n_features = self.sums.len();
285 let mut sums = vec![0.0f64; n_features];
286 let mut sum_sq = vec![0.0f64; n_features];
287 let mut mins = vec![0.0f64; n_features];
288 let mut maxs = vec![0.0f64; n_features];
289
290 for j in 0..n_features {
291 sums[j] = self.sums[j] + other.sums[j];
292 sum_sq[j] = self.sum_sq[j] + other.sum_sq[j];
293 mins[j] = self.mins[j].min(other.mins[j]);
294 maxs[j] = self.maxs[j].max(other.maxs[j]);
295 }
296
297 Self {
298 n: self.n + other.n,
299 sums,
300 sum_sq,
301 mins,
302 maxs,
303 }
304 }
305
306 fn finalise(&self) -> FeatureStats {
308 let n = self.n as f64;
309 let n_features = self.sums.len();
310 let mut means = vec![0.0f64; n_features];
311 let mut variances = vec![0.0f64; n_features];
312
313 for j in 0..n_features {
314 let mean = if n > 0.0 { self.sums[j] / n } else { 0.0 };
315 means[j] = mean;
316 let variance = if n > 1.0 {
317 (self.sum_sq[j] / n) - mean * mean
319 } else {
320 0.0
321 };
322 variances[j] = variance.max(0.0); }
324
325 FeatureStats {
326 means,
327 variances,
328 mins: self.mins.clone(),
329 maxs: self.maxs.clone(),
330 n_samples: self.n,
331 }
332 }
333}
334
/// Aggregated per-feature statistics over a dataset, as produced by
/// `par_feature_stats`.
#[derive(Debug, Clone)]
pub struct FeatureStats {
    /// Per-feature arithmetic mean.
    pub means: Vec<f64>,
    /// Per-feature population variance (E[x^2] - mean^2, clamped at zero).
    pub variances: Vec<f64>,
    /// Per-feature minimum value.
    pub mins: Vec<f64>,
    /// Per-feature maximum value.
    pub maxs: Vec<f64>,
    /// Total number of samples the statistics were computed over.
    pub n_samples: usize,
}
349
350impl FeatureStats {
351 fn zeros(n_features: usize) -> Self {
353 Self {
354 means: vec![0.0; n_features],
355 variances: vec![0.0; n_features],
356 mins: vec![0.0; n_features],
357 maxs: vec![0.0; n_features],
358 n_samples: 0,
359 }
360 }
361
362 pub fn stds(&self) -> Vec<f64> {
364 self.variances.iter().map(|v| v.sqrt()).collect()
365 }
366}
367
368fn build_chunks(dataset: &Dataset, chunk_size: usize) -> Result<Vec<Dataset>> {
374 let chunk_size = chunk_size.max(1);
375 let n = dataset.n_samples();
376 let n_features = dataset.n_features();
377 let mut chunks = Vec::new();
378
379 let mut start = 0usize;
380 while start < n {
381 let end = (start + chunk_size).min(n);
382 let n_rows = end - start;
383
384 let flat: Vec<f64> = dataset
386 .data
387 .rows()
388 .into_iter()
389 .skip(start)
390 .take(n_rows)
391 .flat_map(|row| row.to_vec())
392 .collect();
393
394 let data = Array2::from_shape_vec((n_rows, n_features), flat)
395 .map_err(|e| DatasetsError::InvalidFormat(format!("chunk build failed: {}", e)))?;
396
397 let target = dataset.target.as_ref().map(|t| {
398 let vals: Vec<f64> = t.iter().skip(start).take(n_rows).copied().collect();
399 Array1::from_vec(vals)
400 });
401
402 chunks.push(Dataset {
403 data,
404 target,
405 featurenames: dataset.featurenames.clone(),
406 targetnames: dataset.targetnames.clone(),
407 feature_descriptions: dataset.feature_descriptions.clone(),
408 description: Some(format!("chunk {start}..{end}")),
409 metadata: dataset.metadata.clone(),
410 });
411
412 start = end;
413 }
414
415 Ok(chunks)
416}
417
// Unit tests: chunking invariants, parallel row map/fold, chunked
// map-reduce, and aggregated feature statistics. Datasets come from
// `make_classification` with fixed seeds for determinism.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    // Chunking must neither drop nor duplicate rows.
    #[test]
    fn test_build_chunks_total_rows_preserved() {
        let ds = make_classification(47, 4, 2, 2, 1, Some(1)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");

        let total: usize = chunks.iter().map(|c| c.n_samples()).sum();
        assert_eq!(total, 47, "total rows across chunks must equal source rows");
    }

    // When the row count divides evenly, every chunk is full-sized.
    #[test]
    fn test_build_chunks_exact_split() {
        let ds = make_classification(30, 3, 2, 2, 1, Some(2)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");
        assert_eq!(chunks.len(), 3, "30 rows / 10 per chunk = 3 chunks");
        for c in &chunks {
            assert_eq!(c.n_samples(), 10);
        }
    }

    // A non-even split leaves a short final chunk (25 = 10 + 10 + 5).
    #[test]
    fn test_build_chunks_remainder() {
        let ds = make_classification(25, 3, 2, 2, 1, Some(3)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[2].n_samples(), 5);
    }

    // par_map_rows yields exactly one result per input row.
    #[test]
    fn test_par_map_rows_count_matches() {
        let ds = make_classification(60, 4, 2, 2, 1, Some(7)).expect("make_classification");
        let results =
            par_map_rows(&ds, |row| row.iter().copied().sum::<f64>(), None).expect("par_map_rows");
        assert_eq!(results.len(), 60, "one result per row");
    }

    // Each row handed to the mapper carries the full feature width.
    #[test]
    fn test_par_map_rows_identity_feature_lengths() {
        let ds = make_classification(20, 5, 2, 2, 1, Some(11)).expect("make_classification");
        let lengths = par_map_rows(&ds, |row| row.len(), None).expect("par_map_rows");
        assert!(
            lengths.iter().all(|&l| l == 5),
            "each mapped row should have 5 features"
        );
    }

    // Counting via fold must visit every row exactly once.
    #[test]
    fn test_par_fold_rows_row_count() {
        let ds = make_classification(80, 3, 2, 2, 1, Some(13)).expect("make_classification");
        let count = par_fold_rows(&ds, 0usize, |acc, _row| acc + 1, |a, b| a + b, None)
            .expect("par_fold_rows");
        assert_eq!(count, 80, "fold should accumulate one per row");
    }

    // Chunked map preserves the total sample count across workers.
    #[test]
    fn test_core_par_map_chunks_total_samples() {
        let ds = make_classification(100, 4, 2, 3, 1, Some(17)).expect("make_classification");
        let chunk_sample_counts =
            core_par_map_chunks(&ds, 25, 2, |c| c.n_samples()).expect("core_par_map_chunks");
        let total: usize = chunk_sample_counts.iter().sum();
        assert_eq!(total, 100);
    }

    // Chunking must not alter the feature dimension of any chunk.
    #[test]
    fn test_core_par_map_chunks_feature_dim() {
        let ds = make_classification(50, 6, 2, 2, 1, Some(19)).expect("make_classification");
        let feature_counts =
            core_par_map_chunks(&ds, 15, 2, |c| c.n_features()).expect("core_par_map_chunks");
        assert!(
            feature_counts.iter().all(|&f| f == 6),
            "all chunks should have 6 features"
        );
    }

    // Map-reduce over chunks recovers the total sample count.
    #[test]
    fn test_core_map_reduce_total_sample_count() {
        let ds = make_classification(120, 4, 2, 3, 1, Some(23)).expect("make_classification");
        let total = core_map_reduce_chunks(
            &ds,
            30,
            2,
            |chunk| chunk.n_samples(),
            |acc, r| acc + r,
            0usize,
        )
        .expect("core_map_reduce_chunks");
        assert_eq!(total, 120);
    }

    // Merged partials must report the full sample count.
    #[test]
    fn test_par_feature_stats_n_samples() {
        let ds = make_classification(200, 4, 2, 3, 1, Some(29)).expect("make_classification");
        let stats = par_feature_stats(&ds, 50, 2).expect("par_feature_stats");
        assert_eq!(stats.n_samples, 200);
    }

    // All statistic vectors are sized per feature.
    #[test]
    fn test_par_feature_stats_means_len() {
        let ds = make_classification(100, 5, 2, 3, 1, Some(31)).expect("make_classification");
        let stats = par_feature_stats(&ds, 25, 2).expect("par_feature_stats");
        assert_eq!(stats.means.len(), 5, "one mean per feature");
        assert_eq!(stats.variances.len(), 5);
        assert_eq!(stats.mins.len(), 5);
        assert_eq!(stats.maxs.len(), 5);
    }

    // Elementwise min can never exceed the corresponding max.
    #[test]
    fn test_par_feature_stats_mins_le_maxs() {
        let ds = make_classification(80, 4, 2, 3, 1, Some(37)).expect("make_classification");
        let stats = par_feature_stats(&ds, 20, 2).expect("par_feature_stats");
        for j in 0..4 {
            assert!(
                stats.mins[j] <= stats.maxs[j],
                "min[{j}] must be <= max[{j}]"
            );
        }
    }

    // finalise clamps variances at zero; none may be negative.
    #[test]
    fn test_par_feature_stats_variances_nonnegative() {
        let ds = make_classification(60, 3, 2, 2, 1, Some(41)).expect("make_classification");
        let stats = par_feature_stats(&ds, 20, 2).expect("par_feature_stats");
        for (j, &v) in stats.variances.iter().enumerate() {
            assert!(v >= 0.0, "variance[{j}] must be non-negative, got {v}");
        }
    }

    // Standard deviations are sqrt of variances and hence non-negative.
    #[test]
    fn test_feature_stats_stds() {
        let ds = make_classification(40, 3, 2, 2, 1, Some(43)).expect("make_classification");
        let stats = par_feature_stats(&ds, 10, 2).expect("par_feature_stats");
        let stds = stats.stds();
        assert_eq!(stds.len(), 3);
        for (j, &s) in stds.iter().enumerate() {
            assert!(s >= 0.0, "std[{j}] must be non-negative, got {s}");
        }
    }
}