// scirs2_datasets/distributed_core.rs
1//! Integration of `scirs2-core` distributed primitives with dataset loading.
2//!
3//! This module bridges `scirs2-datasets` distributed infrastructure to the
4//! production-grade primitives in `scirs2_core::distributed`:
5//!
6//! - [`par_map`] / [`par_fold`] from `scirs2_core::distributed::par_iter` —
7//!   lightweight OS-thread parallel iterators that need no async runtime.
8//! - [`distributed_map`] / [`distributed_map_reduce`] from
9//!   `scirs2_core::distributed::primitives` — a `WorkerPool`-backed parallel
10//!   map / map-reduce with result-order preservation.
11//!
12//! ## Relationship to existing modules
13//!
14//! `crate::distributed` (`distributed.rs`) provides `DistributedProcessor`,
15//! `DistributedConfig`, and `ScalingParameters` — higher-level machinery
16//! built on `std::thread::spawn + mpsc`.  This module provides lower-level
17//! **core-backed** helpers that use the same primitives as the rest of the
18//! SciRS2 ecosystem.
19//!
20//! ## Design goals
21//!
22//! 1. Zero new public types — pure functional helpers.
23//! 2. Composable: callers supply an `Fn` closure; results are returned in
24//!    input order.
25//! 3. Fallible: every public function returns `Result<_, DatasetsError>`.
26//! 4. No C/Fortran transitive deps (follows COOLJAPAN Pure Rust Policy).
27
28use scirs2_core::distributed::par_iter::{par_fold, par_map};
29use scirs2_core::distributed::primitives::{distributed_map, distributed_map_reduce};
30use scirs2_core::ndarray::{Array1, Array2};
31
32use crate::error::{DatasetsError, Result};
33use crate::utils::Dataset;
34
35// ─────────────────────────────────────────────────────────────────────────────
36// Public API
37// ─────────────────────────────────────────────────────────────────────────────
38
39/// Apply `f` to each row of `dataset.data` in parallel using
40/// `scirs2_core::distributed::par_iter::par_map`.
41///
42/// Rows are split into contiguous chunks; each chunk is processed on a
43/// separate OS thread.  Results are returned in input order.
44///
45/// # Arguments
46///
47/// * `dataset`     — Source dataset (only `.data` is used).
48/// * `f`           — Closure applied to each row (as a `Vec<f64>`).
49/// * `num_workers` — Override number of worker threads. `None` → logical
50///   CPU count.
51///
52/// # Returns
53///
54/// `Ok(Vec<U>)` with one element per row, in original order.
55pub fn par_map_rows<U, F>(dataset: &Dataset, f: F, num_workers: Option<usize>) -> Result<Vec<U>>
56where
57    U: Send + 'static,
58    F: Fn(Vec<f64>) -> U + Send + Sync + 'static,
59{
60    // Materialise rows as owned Vecs so we don't hold a reference to dataset
61    // across the thread boundary.
62    let rows: Vec<Vec<f64>> = dataset
63        .data
64        .rows()
65        .into_iter()
66        .map(|row| row.to_vec())
67        .collect();
68
69    let mapped = par_map(&rows, |row| f(row.clone()), num_workers);
70    Ok(mapped)
71}
72
73/// Reduce the rows of `dataset.data` in parallel using
74/// `scirs2_core::distributed::par_iter::par_fold`.
75///
76/// The fold is first applied within each chunk (on its own thread); partial
77/// accumulators are then combined with `combine_fn` on the calling thread.
78///
79/// # Type parameters
80///
81/// * `A`          — Accumulator type (must implement `Clone + Send + 'static`).
82/// * `FoldOp`     — Per-element fold operation.
83/// * `CombineOp`  — Accumulator combination operation.
84///
85/// # Arguments
86///
87/// * `dataset`    — Source dataset.
88/// * `identity`   — Identity value for the accumulator.
89/// * `fold_fn`    — `|acc, row| -> A` applied sequentially within a chunk.
90/// * `combine_fn` — `|acc_a, acc_b| -> A` used to merge chunk accumulators.
91/// * `num_workers` — Override thread count; `None` → CPU count.
92///
93/// # Returns
94///
95/// `Ok(A)` with the final reduced accumulator.
96pub fn par_fold_rows<A, FoldOp, CombineOp>(
97    dataset: &Dataset,
98    identity: A,
99    fold_fn: FoldOp,
100    combine_fn: CombineOp,
101    num_workers: Option<usize>,
102) -> Result<A>
103where
104    A: Clone + Send + Sync + 'static,
105    FoldOp: Fn(A, &Vec<f64>) -> A + Send + Sync + 'static,
106    CombineOp: Fn(A, A) -> A + Send + Sync + 'static,
107{
108    let rows: Vec<Vec<f64>> = dataset
109        .data
110        .rows()
111        .into_iter()
112        .map(|row| row.to_vec())
113        .collect();
114
115    let result = par_fold(&rows, identity, fold_fn, combine_fn, num_workers);
116    Ok(result)
117}
118
119/// Apply `f` to each dataset chunk (slice of rows) in parallel using
120/// `scirs2_core::distributed::primitives::distributed_map`.
121///
122/// This uses the `WorkerPool`-backed primitive from `scirs2-core` which
123/// preserves output ordering and supports arbitrary chunk sizes.
124///
125/// # Arguments
126///
127/// * `dataset`     — Source dataset to chunk.
128/// * `chunk_size`  — Number of rows per chunk.
129/// * `n_workers`   — Number of worker threads.
130/// * `f`           — Closure applied to each chunk (given as a sub-`Dataset`).
131///
132/// # Returns
133///
134/// `Ok(Vec<R>)` with one element per chunk, in original order.
135pub fn core_par_map_chunks<R, F>(
136    dataset: &Dataset,
137    chunk_size: usize,
138    n_workers: usize,
139    f: F,
140) -> Result<Vec<R>>
141where
142    R: Send + 'static,
143    F: Fn(Dataset) -> R + Send + Clone + 'static,
144{
145    let chunks = build_chunks(dataset, chunk_size)?;
146    let results = distributed_map(chunks, f, n_workers);
147    Ok(results)
148}
149
150/// Map-reduce over dataset chunks using
151/// `scirs2_core::distributed::primitives::distributed_map_reduce`.
152///
153/// The map phase processes chunks in parallel; the reduce phase combines the
154/// mapped results into a single accumulator on the calling thread.
155///
156/// # Type parameters
157///
158/// * `R` — Per-chunk map result.
159/// * `S` — Accumulator type.
160///
161/// # Arguments
162///
163/// * `dataset`     — Source dataset.
164/// * `chunk_size`  — Number of rows per chunk.
165/// * `n_workers`   — Worker thread count.
166/// * `map_fn`      — `|chunk: Dataset| -> R`.
167/// * `reduce_fn`   — `|acc: S, r: R| -> S`.
168/// * `initial`     — Initial accumulator value.
169///
170/// # Returns
171///
172/// `Ok(S)` with the final reduced accumulator.
173pub fn core_map_reduce_chunks<R, S, F, G>(
174    dataset: &Dataset,
175    chunk_size: usize,
176    n_workers: usize,
177    map_fn: F,
178    reduce_fn: G,
179    initial: S,
180) -> Result<S>
181where
182    R: Send + 'static,
183    S: Send + Clone + 'static,
184    F: Fn(Dataset) -> R + Send + Clone + 'static,
185    G: Fn(S, R) -> S + Send + Clone + 'static,
186{
187    let chunks = build_chunks(dataset, chunk_size)?;
188    let result = distributed_map_reduce(chunks, map_fn, reduce_fn, initial, n_workers);
189    Ok(result)
190}
191
192/// Compute per-feature column statistics (mean, min, max, variance) in parallel
193/// across dataset chunks, using `scirs2_core` distributed map-reduce.
194///
195/// This is intended as a practical demonstration of the core integration:
196/// each chunk computes partial Welford-style statistics; partial statistics are
197/// combined on the calling thread.
198///
199/// # Returns
200///
201/// `Ok(FeatureStats)` with per-feature mean, min, max, and population variance.
202pub fn par_feature_stats(
203    dataset: &Dataset,
204    chunk_size: usize,
205    n_workers: usize,
206) -> Result<FeatureStats> {
207    let n_features = dataset.n_features();
208    if n_features == 0 {
209        return Err(DatasetsError::InvalidFormat(
210            "Dataset has no features".to_string(),
211        ));
212    }
213
214    let chunks = build_chunks(dataset, chunk_size)?;
215    if chunks.is_empty() {
216        return Ok(FeatureStats::zeros(n_features));
217    }
218
219    // Map: compute partial stats per chunk
220    let partial_stats: Vec<PartialStats> = distributed_map(
221        chunks,
222        move |chunk| PartialStats::from_dataset(&chunk),
223        n_workers,
224    );
225
226    // Reduce: merge all partial stats
227    let merged = partial_stats
228        .into_iter()
229        .reduce(|a, b| a.merge(&b))
230        .ok_or_else(|| DatasetsError::InvalidFormat("No chunks to reduce".to_string()))?;
231
232    Ok(merged.finalise())
233}
234
235// ─────────────────────────────────────────────────────────────────────────────
236// Statistics helpers
237// ─────────────────────────────────────────────────────────────────────────────
238
/// Partial statistics computed over one chunk of rows.
///
/// Accumulates per-feature running sums (Σx and Σx²) plus running min/max;
/// mean and population variance are derived later in `finalise`.  Note this
/// is the naive sum-of-squares formulation, not Welford's update — fine for
/// well-scaled features, but it can lose precision when |mean| is large
/// relative to the spread.  All per-feature fields are indexed by feature
/// column and have length `n_features`.
#[derive(Debug, Clone)]
struct PartialStats {
    /// Number of rows folded into this accumulator.
    n: usize,
    /// Per-feature Σx.
    sums: Vec<f64>,
    /// Per-feature Σx².
    sum_sq: Vec<f64>,
    /// Per-feature running minimum (∞ when no rows seen).
    mins: Vec<f64>,
    /// Per-feature running maximum (−∞ when no rows seen).
    maxs: Vec<f64>,
}
251
252impl PartialStats {
253    fn from_dataset(ds: &Dataset) -> Self {
254        let n_features = ds.n_features();
255        let mut sums = vec![0.0f64; n_features];
256        let mut sum_sq = vec![0.0f64; n_features];
257        let mut mins = vec![f64::INFINITY; n_features];
258        let mut maxs = vec![f64::NEG_INFINITY; n_features];
259
260        for row in ds.data.rows() {
261            for (j, &v) in row.iter().enumerate() {
262                sums[j] += v;
263                sum_sq[j] += v * v;
264                if v < mins[j] {
265                    mins[j] = v;
266                }
267                if v > maxs[j] {
268                    maxs[j] = v;
269                }
270            }
271        }
272
273        Self {
274            n: ds.n_samples(),
275            sums,
276            sum_sq,
277            mins,
278            maxs,
279        }
280    }
281
282    /// Merge `other` into `self`, returning a new combined `PartialStats`.
283    fn merge(&self, other: &Self) -> Self {
284        let n_features = self.sums.len();
285        let mut sums = vec![0.0f64; n_features];
286        let mut sum_sq = vec![0.0f64; n_features];
287        let mut mins = vec![0.0f64; n_features];
288        let mut maxs = vec![0.0f64; n_features];
289
290        for j in 0..n_features {
291            sums[j] = self.sums[j] + other.sums[j];
292            sum_sq[j] = self.sum_sq[j] + other.sum_sq[j];
293            mins[j] = self.mins[j].min(other.mins[j]);
294            maxs[j] = self.maxs[j].max(other.maxs[j]);
295        }
296
297        Self {
298            n: self.n + other.n,
299            sums,
300            sum_sq,
301            mins,
302            maxs,
303        }
304    }
305
306    /// Compute final `FeatureStats` from the accumulated sums.
307    fn finalise(&self) -> FeatureStats {
308        let n = self.n as f64;
309        let n_features = self.sums.len();
310        let mut means = vec![0.0f64; n_features];
311        let mut variances = vec![0.0f64; n_features];
312
313        for j in 0..n_features {
314            let mean = if n > 0.0 { self.sums[j] / n } else { 0.0 };
315            means[j] = mean;
316            let variance = if n > 1.0 {
317                // Population variance: E[X^2] - E[X]^2
318                (self.sum_sq[j] / n) - mean * mean
319            } else {
320                0.0
321            };
322            variances[j] = variance.max(0.0); // clamp floating-point negatives
323        }
324
325        FeatureStats {
326            means,
327            variances,
328            mins: self.mins.clone(),
329            maxs: self.maxs.clone(),
330            n_samples: self.n,
331        }
332    }
333}
334
/// Per-feature statistics computed in parallel via `scirs2-core` primitives.
///
/// Produced by [`par_feature_stats`].  Variances are *population* variances
/// (divide by `n`, not `n − 1`).  All vectors have length `n_features`.
#[derive(Debug, Clone)]
pub struct FeatureStats {
    /// Per-feature arithmetic means.
    pub means: Vec<f64>,
    /// Per-feature population variances.
    pub variances: Vec<f64>,
    /// Per-feature minimums.
    pub mins: Vec<f64>,
    /// Per-feature maximums.
    pub maxs: Vec<f64>,
    /// Total number of samples processed.
    pub n_samples: usize,
}
349
350impl FeatureStats {
351    /// Return a zero-initialised `FeatureStats` for `n_features` features.
352    fn zeros(n_features: usize) -> Self {
353        Self {
354            means: vec![0.0; n_features],
355            variances: vec![0.0; n_features],
356            mins: vec![0.0; n_features],
357            maxs: vec![0.0; n_features],
358            n_samples: 0,
359        }
360    }
361
362    /// Standard deviations (square roots of the population variances).
363    pub fn stds(&self) -> Vec<f64> {
364        self.variances.iter().map(|v| v.sqrt()).collect()
365    }
366}
367
368// ─────────────────────────────────────────────────────────────────────────────
369// Internal helpers
370// ─────────────────────────────────────────────────────────────────────────────
371
372/// Split `dataset` into owned `Dataset` chunks of at most `chunk_size` rows.
373fn build_chunks(dataset: &Dataset, chunk_size: usize) -> Result<Vec<Dataset>> {
374    let chunk_size = chunk_size.max(1);
375    let n = dataset.n_samples();
376    let n_features = dataset.n_features();
377    let mut chunks = Vec::new();
378
379    let mut start = 0usize;
380    while start < n {
381        let end = (start + chunk_size).min(n);
382        let n_rows = end - start;
383
384        // Build data array for this chunk
385        let flat: Vec<f64> = dataset
386            .data
387            .rows()
388            .into_iter()
389            .skip(start)
390            .take(n_rows)
391            .flat_map(|row| row.to_vec())
392            .collect();
393
394        let data = Array2::from_shape_vec((n_rows, n_features), flat)
395            .map_err(|e| DatasetsError::InvalidFormat(format!("chunk build failed: {}", e)))?;
396
397        let target = dataset.target.as_ref().map(|t| {
398            let vals: Vec<f64> = t.iter().skip(start).take(n_rows).copied().collect();
399            Array1::from_vec(vals)
400        });
401
402        chunks.push(Dataset {
403            data,
404            target,
405            featurenames: dataset.featurenames.clone(),
406            targetnames: dataset.targetnames.clone(),
407            feature_descriptions: dataset.feature_descriptions.clone(),
408            description: Some(format!("chunk {start}..{end}")),
409            metadata: dataset.metadata.clone(),
410        });
411
412        start = end;
413    }
414
415    Ok(chunks)
416}
417
418// ─────────────────────────────────────────────────────────────────────────────
419// Tests
420// ─────────────────────────────────────────────────────────────────────────────
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    // These tests assert structural invariants (row counts, feature
    // dimensions, non-negativity) rather than exact generated values, so
    // they are robust to the RNG behind `make_classification`.
    //
    // NOTE(review): the argument order of `make_classification` is assumed
    // from usage here to be (n_samples, n_features, ..., seed) — confirm
    // against `crate::generators` before editing these fixtures.

    // ── build_chunks ──────────────────────────────────────────────────────────

    #[test]
    fn test_build_chunks_total_rows_preserved() {
        let ds = make_classification(47, 4, 2, 2, 1, Some(1)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");

        let total: usize = chunks.iter().map(|c| c.n_samples()).sum();
        assert_eq!(total, 47, "total rows across chunks must equal source rows");
    }

    #[test]
    fn test_build_chunks_exact_split() {
        let ds = make_classification(30, 3, 2, 2, 1, Some(2)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");
        assert_eq!(chunks.len(), 3, "30 rows / 10 per chunk = 3 chunks");
        for c in &chunks {
            assert_eq!(c.n_samples(), 10);
        }
    }

    #[test]
    fn test_build_chunks_remainder() {
        let ds = make_classification(25, 3, 2, 2, 1, Some(3)).expect("make_classification");
        let chunks = build_chunks(&ds, 10).expect("build_chunks");
        // 25 rows with chunk_size 10 → chunks of 10, 10, 5.
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[2].n_samples(), 5);
    }

    // ── par_map_rows ──────────────────────────────────────────────────────────

    #[test]
    fn test_par_map_rows_count_matches() {
        let ds = make_classification(60, 4, 2, 2, 1, Some(7)).expect("make_classification");
        let results =
            par_map_rows(&ds, |row| row.iter().copied().sum::<f64>(), None).expect("par_map_rows");
        assert_eq!(results.len(), 60, "one result per row");
    }

    #[test]
    fn test_par_map_rows_identity_feature_lengths() {
        let ds = make_classification(20, 5, 2, 2, 1, Some(11)).expect("make_classification");
        let lengths = par_map_rows(&ds, |row| row.len(), None).expect("par_map_rows");
        assert!(
            lengths.iter().all(|&l| l == 5),
            "each mapped row should have 5 features"
        );
    }

    // ── par_fold_rows ─────────────────────────────────────────────────────────

    #[test]
    fn test_par_fold_rows_row_count() {
        let ds = make_classification(80, 3, 2, 2, 1, Some(13)).expect("make_classification");
        // Counting rows is order-insensitive, so the chunked fold + combine
        // must agree with a plain sequential count.
        let count = par_fold_rows(&ds, 0usize, |acc, _row| acc + 1, |a, b| a + b, None)
            .expect("par_fold_rows");
        assert_eq!(count, 80, "fold should accumulate one per row");
    }

    // ── core_par_map_chunks ───────────────────────────────────────────────────

    #[test]
    fn test_core_par_map_chunks_total_samples() {
        let ds = make_classification(100, 4, 2, 3, 1, Some(17)).expect("make_classification");
        let chunk_sample_counts =
            core_par_map_chunks(&ds, 25, 2, |c| c.n_samples()).expect("core_par_map_chunks");
        let total: usize = chunk_sample_counts.iter().sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_core_par_map_chunks_feature_dim() {
        let ds = make_classification(50, 6, 2, 2, 1, Some(19)).expect("make_classification");
        let feature_counts =
            core_par_map_chunks(&ds, 15, 2, |c| c.n_features()).expect("core_par_map_chunks");
        assert!(
            feature_counts.iter().all(|&f| f == 6),
            "all chunks should have 6 features"
        );
    }

    // ── core_map_reduce_chunks ────────────────────────────────────────────────

    #[test]
    fn test_core_map_reduce_total_sample_count() {
        let ds = make_classification(120, 4, 2, 3, 1, Some(23)).expect("make_classification");
        let total = core_map_reduce_chunks(
            &ds,
            30,
            2,
            |chunk| chunk.n_samples(),
            |acc, r| acc + r,
            0usize,
        )
        .expect("core_map_reduce_chunks");
        assert_eq!(total, 120);
    }

    // ── par_feature_stats ─────────────────────────────────────────────────────

    #[test]
    fn test_par_feature_stats_n_samples() {
        let ds = make_classification(200, 4, 2, 3, 1, Some(29)).expect("make_classification");
        let stats = par_feature_stats(&ds, 50, 2).expect("par_feature_stats");
        assert_eq!(stats.n_samples, 200);
    }

    #[test]
    fn test_par_feature_stats_means_len() {
        let ds = make_classification(100, 5, 2, 3, 1, Some(31)).expect("make_classification");
        let stats = par_feature_stats(&ds, 25, 2).expect("par_feature_stats");
        assert_eq!(stats.means.len(), 5, "one mean per feature");
        assert_eq!(stats.variances.len(), 5);
        assert_eq!(stats.mins.len(), 5);
        assert_eq!(stats.maxs.len(), 5);
    }

    #[test]
    fn test_par_feature_stats_mins_le_maxs() {
        let ds = make_classification(80, 4, 2, 3, 1, Some(37)).expect("make_classification");
        let stats = par_feature_stats(&ds, 20, 2).expect("par_feature_stats");
        for j in 0..4 {
            assert!(
                stats.mins[j] <= stats.maxs[j],
                "min[{j}] must be <= max[{j}]"
            );
        }
    }

    #[test]
    fn test_par_feature_stats_variances_nonnegative() {
        // Guards the `.max(0.0)` clamp in `PartialStats::finalise`.
        let ds = make_classification(60, 3, 2, 2, 1, Some(41)).expect("make_classification");
        let stats = par_feature_stats(&ds, 20, 2).expect("par_feature_stats");
        for (j, &v) in stats.variances.iter().enumerate() {
            assert!(v >= 0.0, "variance[{j}] must be non-negative, got {v}");
        }
    }

    #[test]
    fn test_feature_stats_stds() {
        let ds = make_classification(40, 3, 2, 2, 1, Some(43)).expect("make_classification");
        let stats = par_feature_stats(&ds, 10, 2).expect("par_feature_stats");
        let stds = stats.stds();
        assert_eq!(stds.len(), 3);
        for (j, &s) in stds.iter().enumerate() {
            assert!(s >= 0.0, "std[{j}] must be non-negative, got {s}");
        }
    }
}