scirs2_datasets/distributed.rs

//! Distributed dataset processing capabilities
//!
//! This module provides functionality for processing datasets across multiple worker threads or processes:
//! - Parallel data loading and processing
//! - Dataset sharding and distribution
//! - Distributed sampling and cross-validation
//! - MapReduce-style operations on datasets
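//!
//! # Example
//!
//! A minimal sketch of the intended usage, assuming this module is exported as
//! `scirs2_datasets::distributed` and that the `make_classification` generator is
//! available as in the tests at the bottom of this file:
//!
//! ```ignore
//! use scirs2_datasets::distributed::{DistributedConfig, DistributedProcessor};
//! use scirs2_datasets::generators::make_classification;
//!
//! let dataset = make_classification(1_000, 5, 2, 3, 1, Some(42)).unwrap();
//! let processor = DistributedProcessor::new(DistributedConfig::default()).unwrap();
//!
//! // Draw a 100-sample subset, then split it into 5 folds, all chunk-parallel.
//! let sample = processor.distributed_sample(&dataset, 100, Some(42)).unwrap();
//! let folds = processor.distributed_k_fold(&sample, 5, true, Some(42)).unwrap();
//! assert_eq!(folds.len(), 5);
//! ```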

use std::collections::HashMap;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;

use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::cache::DatasetCache;
use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;

/// Configuration for distributed processing
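///
/// # Example
///
/// A sketch of overriding a couple of fields while keeping the remaining
/// defaults (the values shown are illustrative, not tuned recommendations):
///
/// ```ignore
/// use scirs2_datasets::distributed::DistributedConfig;
///
/// let config = DistributedConfig {
///     num_workers: 8,
///     chunk_size: 50_000,
///     ..DistributedConfig::default()
/// };
/// assert_eq!(config.num_workers, 8);
/// ```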
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedConfig {
    /// Number of worker processes/threads
    pub num_workers: usize,
    /// Chunk size for processing
    pub chunk_size: usize,
    /// Communication timeout (seconds)
    pub timeout: u64,
    /// Whether to use shared memory for large datasets
    pub use_shared_memory: bool,
    /// Maximum memory per worker (MB)
    pub memory_limit_mb: usize,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        let num_cpus = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        Self {
            num_workers: num_cpus,
            chunk_size: 10000,
            timeout: 300,
            use_shared_memory: false,
            memory_limit_mb: 1024,
        }
    }
}

/// Distributed dataset processor
pub struct DistributedProcessor {
    config: DistributedConfig,
    #[allow(dead_code)]
    cache: DatasetCache,
}

impl DistributedProcessor {
    /// Create a new distributed processor
    pub fn new(config: DistributedConfig) -> Result<Self> {
        let cachedir = dirs::cache_dir()
            .ok_or_else(|| DatasetsError::Other("Could not determine cache directory".to_string()))?
            .join("scirs2-datasets");
        let cache = DatasetCache::new(cachedir);

        Ok(Self { config, cache })
    }

    /// Create with default configuration
    pub fn default_config() -> Result<Self> {
        Self::new(DistributedConfig::default())
    }

    /// Process a large dataset in parallel chunks
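    ///
    /// # Example
    ///
    /// A minimal sketch (the `make_classification` call mirrors the tests below;
    /// crate paths are assumed, hence the ignored doc test):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    /// use scirs2_datasets::utils::Dataset;
    ///
    /// let dataset = make_classification(200, 4, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// // Each worker reports its chunk's sample count; summing recovers the total.
    /// let counts = processor
    ///     .process_dataset_parallel(&dataset, |chunk: &Dataset| Ok(chunk.n_samples()))
    ///     .unwrap();
    /// assert_eq!(counts.iter().sum::<usize>(), dataset.n_samples());
    /// ```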
    pub fn process_dataset_parallel<F, R>(&self, dataset: &Dataset, processor: F) -> Result<Vec<R>>
    where
        F: Fn(&Dataset) -> Result<R> + Send + Sync + Clone + 'static,
        R: Send + 'static,
    {
        let chunks = self.split_dataset_into_chunks(dataset)?;
        let processor = Arc::new(processor);

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        // Spawn one worker thread per chunk
        for chunk in chunks {
            let tx = tx.clone();
            let processor = Arc::clone(&processor);

            let handle = thread::spawn(move || {
                let result = processor(&chunk);
                let _ = tx.send(result);
            });

            handles.push(handle);
        }

        // Drop the original sender so the receiver loop below terminates once all workers are done
        drop(tx);

        // Collect results
        let mut results = Vec::new();
        for result in rx {
            results.push(result?);
        }

        // Wait for all workers to finish
        for handle in handles {
            let _ = handle.join();
        }

        Ok(results)
    }

    /// Distribute a dataset across multiple workers using a MapReduce-style pattern
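    ///
    /// # Example
    ///
    /// A sketch of computing a global sum of the first feature, assuming the same
    /// crate paths as in the tests below:
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    /// use scirs2_datasets::utils::Dataset;
    ///
    /// let dataset = make_classification(200, 4, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// let total: f64 = processor
    ///     .map_reduce_dataset(
    ///         &dataset,
    ///         // Map: each chunk contributes the sum of its first column.
    ///         |chunk: &Dataset| Ok(vec![chunk.data.column(0).sum()]),
    ///         // Reduce: add the partial sums together.
    ///         |partials: Vec<f64>| Ok(partials.into_iter().sum::<f64>()),
    ///     )
    ///     .unwrap();
    /// ```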
    pub fn map_reduce_dataset<M, R, C>(&self, dataset: &Dataset, mapper: M, reducer: R) -> Result<C>
    where
        M: Fn(&Dataset) -> Result<Vec<C>> + Send + Sync + Clone + 'static,
        R: Fn(Vec<C>) -> Result<C> + Send + Sync + 'static,
        C: Send + 'static,
    {
        // Map phase: process chunks in parallel
        let map_results = self.process_dataset_parallel(dataset, mapper)?;

        // Reduce phase: combine results
        let flattened: Vec<C> = map_results.into_iter().flatten().collect();
        reducer(flattened)
    }

    /// Split a dataset into balanced chunks for distribution
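    ///
    /// # Example
    ///
    /// A sketch showing that chunking preserves every sample (crate paths assumed,
    /// as in the tests below):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    ///
    /// let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// let chunks = processor.split_dataset_into_chunks(&dataset).unwrap();
    /// let total: usize = chunks.iter().map(|c| c.n_samples()).sum();
    /// assert_eq!(total, dataset.n_samples());
    /// ```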
    pub fn split_dataset_into_chunks(&self, dataset: &Dataset) -> Result<Vec<Dataset>> {
        let n_samples = dataset.n_samples();
        let chunk_size = self
            .config
            .chunk_size
            .min(n_samples / self.config.num_workers + 1);

        let mut chunks = Vec::new();

        for start in (0..n_samples).step_by(chunk_size) {
            let end = (start + chunk_size).min(n_samples);
            let chunk_data = dataset.data.slice(s![start..end, ..]).to_owned();

            let chunk_target = dataset
                .target
                .as_ref()
                .map(|target| target.slice(s![start..end]).to_owned());

            let chunk = Dataset {
                data: chunk_data,
                target: chunk_target,
                featurenames: dataset.featurenames.clone(),
                targetnames: dataset.targetnames.clone(),
                feature_descriptions: dataset.feature_descriptions.clone(),
                description: Some(format!("Chunk {start}-{end} of distributed dataset")),
                metadata: dataset.metadata.clone(),
            };

            chunks.push(chunk);
        }

        Ok(chunks)
    }

    /// Distributed random sampling across workers
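    ///
    /// # Example
    ///
    /// A sketch drawing a fixed-size random subset (crate paths assumed, as in
    /// the tests below):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    ///
    /// let dataset = make_classification(1_000, 5, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// let sample = processor.distributed_sample(&dataset, 100, Some(42)).unwrap();
    /// assert_eq!(sample.n_samples(), 100);
    /// assert_eq!(sample.n_features(), dataset.n_features());
    /// ```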
    pub fn distributed_sample(
        &self,
        dataset: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        if n_samples >= dataset.n_samples() {
            return Ok(dataset.clone());
        }

        let chunks = self.split_dataset_into_chunks(dataset)?;

        // Spread the requested sample count over the actual number of chunks
        // (which may differ from `num_workers`), so the per-chunk counts sum to `n_samples`.
        let n_chunks = chunks.len();
        let samples_per_chunk = n_samples / n_chunks;
        let remainder = n_samples % n_chunks;

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for (i, chunk) in chunks.into_iter().enumerate() {
            let tx = tx.clone();
            let chunk_samples = if i < remainder {
                samples_per_chunk + 1
            } else {
                samples_per_chunk
            };

            let seed = random_state.map(|s| s + i as u64);

            let handle = thread::spawn(move || {
                let sampled = Self::sample_chunk(&chunk, chunk_samples, seed);
                let _ = tx.send(sampled);
            });

            handles.push(handle);
        }

        drop(tx);

        // Collect sampled chunks
        let mut sampled_chunks = Vec::new();
        for result in rx {
            sampled_chunks.push(result?);
        }

        // Wait for workers
        for handle in handles {
            let _ = handle.join();
        }

        // Combine sampled chunks
        self.combine_datasets(&sampled_chunks)
    }

    /// Distributed cross-validation split
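    ///
    /// # Example
    ///
    /// A sketch of a shuffled 5-fold split (crate paths assumed, as in the tests
    /// below):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    ///
    /// let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// let folds = processor.distributed_k_fold(&dataset, 5, true, Some(42)).unwrap();
    /// assert_eq!(folds.len(), 5);
    /// for (train, test) in &folds {
    ///     assert_eq!(train.n_samples() + test.n_samples(), dataset.n_samples());
    /// }
    /// ```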
    pub fn distributed_k_fold(
        &self,
        dataset: &Dataset,
        k: usize,
        shuffle: bool,
        random_state: Option<u64>,
    ) -> Result<Vec<(Dataset, Dataset)>> {
        let n_samples = dataset.n_samples();
        let fold_size = n_samples / k;

        let mut indices: Vec<usize> = (0..n_samples).collect();

        if shuffle {
            use scirs2_core::random::seq::SliceRandom;
            use scirs2_core::random::SeedableRng;

            let mut rng = if let Some(seed) = random_state {
                scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
            } else {
                // For deterministic testing, use a fixed seed when no seed provided
                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
            };

            indices.shuffle(&mut rng);
        }

        let mut folds = Vec::new();

        for fold_idx in 0..k {
            let test_start = fold_idx * fold_size;
            let test_end = if fold_idx == k - 1 {
                n_samples
            } else {
                (fold_idx + 1) * fold_size
            };

            let test_indices = &indices[test_start..test_end];
            let train_indices: Vec<usize> = indices[..test_start]
                .iter()
                .chain(indices[test_end..].iter())
                .copied()
                .collect();

            let train_data = self.select_samples(dataset, &train_indices)?;
            let test_data = self.select_samples(dataset, test_indices)?;

            folds.push((train_data, test_data));
        }

        Ok(folds)
    }

    /// Distributed stratified sampling
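    ///
    /// # Example
    ///
    /// A sketch of class-balanced sampling; it requires a dataset with targets
    /// (crate paths assumed, as in the tests below):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::DistributedProcessor;
    /// use scirs2_datasets::generators::make_classification;
    ///
    /// let dataset = make_classification(300, 5, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// // Roughly equal numbers of samples are drawn from each class,
    /// // up to the requested total of 90.
    /// let sample = processor.distributed_stratified_sample(&dataset, 90, Some(42)).unwrap();
    /// assert!(sample.n_samples() <= 90);
    /// assert_eq!(sample.n_features(), dataset.n_features());
    /// ```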
    pub fn distributed_stratified_sample(
        &self,
        dataset: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Stratified sampling requires target values".to_string())
        })?;

        // Group samples by class
        let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
        for (idx, &value) in target.iter().enumerate() {
            let class = value as i32;
            class_groups.entry(class).or_default().push(idx);
        }

        // Calculate samples per class
        let n_classes = class_groups.len();
        let base_samples_per_class = n_samples / n_classes;
        let remainder = n_samples % n_classes;

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for (class_idx, (class, indices)) in class_groups.into_iter().enumerate() {
            let tx = tx.clone();
            let class_samples = if class_idx < remainder {
                base_samples_per_class + 1
            } else {
                base_samples_per_class
            };

            let seed = random_state.map(|s| s + class_idx as u64);

            let handle = thread::spawn(move || {
                let sampled_indices = Self::sample_indices(&indices, class_samples, seed);
                let _ = tx.send((class, sampled_indices));
            });

            handles.push(handle);
        }

        drop(tx);

        // Collect sampled indices
        let mut all_sampled_indices = Vec::new();
        for (_, indices) in rx {
            all_sampled_indices.extend(indices?);
        }

        // Wait for workers
        for handle in handles {
            let _ = handle.join();
        }

        // Create stratified sample
        self.select_samples(dataset, &all_sampled_indices)
    }

    /// Parallel feature scaling across workers
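    ///
    /// # Example
    ///
    /// A sketch of standard scaling with globally combined statistics (crate paths
    /// assumed, as in the tests below):
    ///
    /// ```ignore
    /// use scirs2_datasets::distributed::{DistributedProcessor, ScalingMethod};
    /// use scirs2_datasets::generators::make_classification;
    ///
    /// let dataset = make_classification(500, 5, 2, 3, 1, Some(42)).unwrap();
    /// let processor = DistributedProcessor::default_config().unwrap();
    ///
    /// let (scaled, params) = processor
    ///     .distributed_scale(&dataset, ScalingMethod::StandardScaler)
    ///     .unwrap();
    /// assert_eq!(scaled.n_samples(), dataset.n_samples());
    /// // `params` records the global means/stds/mins/maxs used for the transform.
    /// ```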
    pub fn distributed_scale(
        &self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<(Dataset, ScalingParameters)> {
        let n_features = dataset.n_features();
        let chunks = self.split_dataset_into_chunks(dataset)?;

        // Phase 1: Compute statistics in parallel
        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for chunk in chunks.iter() {
            let tx = tx.clone();
            let chunk = chunk.clone();

            let handle = thread::spawn(move || {
                let stats = Self::compute_chunk_statistics(&chunk);
                let _ = tx.send(stats);
            });

            handles.push(handle);
        }

        drop(tx);

        // Collect statistics
        let mut all_stats = Vec::new();
        for stats in rx {
            all_stats.push(stats?);
        }

        // Wait for workers
        for handle in handles {
            let _ = handle.join();
        }

        // Phase 2: Combine statistics
        let global_stats = Self::combine_statistics(&all_stats, n_features)?;
        let scaling_params = ScalingParameters::from_statistics(&global_stats, method);

        // Phase 3: Apply scaling in parallel
        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for chunk in chunks {
            let tx = tx.clone();
            let params = scaling_params.clone();

            let handle = thread::spawn(move || {
                let scaled_chunk = Self::apply_scaling(&chunk, &params);
                let _ = tx.send(scaled_chunk);
            });

            handles.push(handle);
        }

        drop(tx);

        // Collect scaled chunks
        let mut scaled_chunks = Vec::new();
        for result in rx {
            scaled_chunks.push(result?);
        }

        // Wait for workers
        for handle in handles {
            let _ = handle.join();
        }

        // Combine scaled chunks
        let scaled_dataset = self.combine_datasets(&scaled_chunks)?;
        Ok((scaled_dataset, scaling_params))
    }

    // Helper methods

    fn sample_chunk(
        chunk: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        if n_samples >= chunk.n_samples() {
            return Ok(chunk.clone());
        }

        use scirs2_core::random::seq::SliceRandom;
        use scirs2_core::random::SeedableRng;

        let mut rng = if let Some(seed) = random_state {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        } else {
            // For deterministic testing, use a fixed seed when no seed provided
            scirs2_core::random::rngs::StdRng::seed_from_u64(42)
        };

        let mut indices: Vec<usize> = (0..chunk.n_samples()).collect();
        indices.shuffle(&mut rng);
        indices.truncate(n_samples);

        Self::select_samples_static(chunk, &indices)
    }

    fn sample_indices(
        indices: &[usize],
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Vec<usize>> {
        if n_samples >= indices.len() {
            return Ok(indices.to_vec());
        }

        use scirs2_core::random::seq::SliceRandom;
        use scirs2_core::random::SeedableRng;

        let mut rng = if let Some(seed) = random_state {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        } else {
            // For deterministic testing, use a fixed seed when no seed provided
            scirs2_core::random::rngs::StdRng::seed_from_u64(42)
        };

        let mut sampled = indices.to_vec();
        sampled.shuffle(&mut rng);
        sampled.truncate(n_samples);

        Ok(sampled)
    }

    fn select_samples(&self, dataset: &Dataset, indices: &[usize]) -> Result<Dataset> {
        Self::select_samples_static(dataset, indices)
    }

    fn select_samples_static(dataset: &Dataset, indices: &[usize]) -> Result<Dataset> {
        let selected_data = dataset.data.select(Axis(0), indices);
        let selected_target = dataset
            .target
            .as_ref()
            .map(|target| target.select(Axis(0), indices));

        Ok(Dataset {
            data: selected_data,
            target: selected_target,
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Distributed sample".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

    fn combine_datasets(&self, datasets: &[Dataset]) -> Result<Dataset> {
        if datasets.is_empty() {
            return Err(DatasetsError::InvalidFormat(
                "Cannot combine empty dataset list".to_string(),
            ));
        }

        let n_features = datasets[0].n_features();
        let total_samples: usize = datasets.iter().map(|d| d.n_samples()).sum();

        // Combine data arrays
        let mut combined_data = Vec::with_capacity(total_samples * n_features);
        let mut combined_target = if datasets[0].target.is_some() {
            Some(Vec::with_capacity(total_samples))
        } else {
            None
        };

        for dataset in datasets {
            for row in dataset.data.rows() {
                combined_data.extend(row.iter());
            }

            if let (Some(ref mut combined), Some(ref target)) =
                (&mut combined_target, &dataset.target)
            {
                combined.extend(target.iter());
            }
        }

        let data = Array2::from_shape_vec((total_samples, n_features), combined_data)
            .map_err(|e| DatasetsError::InvalidFormat(e.to_string()))?;

        let target = combined_target.map(Array1::from_vec);

        Ok(Dataset {
            data,
            target,
            featurenames: datasets[0].featurenames.clone(),
            targetnames: datasets[0].targetnames.clone(),
            feature_descriptions: datasets[0].feature_descriptions.clone(),
            description: Some("Combined distributed dataset".to_string()),
            metadata: datasets[0].metadata.clone(),
        })
    }

    fn compute_chunk_statistics(chunk: &Dataset) -> Result<ChunkStatistics> {
        let data = &chunk.data;
        let n_features = data.ncols();
        let n_samples = data.nrows() as f64;

        let mut means = vec![0.0; n_features];
        let mut mins = vec![f64::INFINITY; n_features];
        let mut maxs = vec![f64::NEG_INFINITY; n_features];
        let mut sum_squares = vec![0.0; n_features];

        for col in 0..n_features {
            let column = data.column(col);

            let sum: f64 = column.sum();
            means[col] = sum / n_samples;

            for &value in column.iter() {
                mins[col] = mins[col].min(value);
                maxs[col] = maxs[col].max(value);
                sum_squares[col] += value * value;
            }
        }

        Ok(ChunkStatistics {
            n_samples: n_samples as usize,
            means,
            mins,
            maxs,
            sum_squares,
        })
    }

    fn combine_statistics(
        stats: &[ChunkStatistics],
        n_features: usize,
    ) -> Result<GlobalStatistics> {
        let total_samples: usize = stats.iter().map(|s| s.n_samples).sum();
        let mut global_means = vec![0.0; n_features];
        let mut global_mins = vec![f64::INFINITY; n_features];
        let mut global_maxs = vec![f64::NEG_INFINITY; n_features];
        let mut global_stds = vec![0.0; n_features];

        // Combine means
        for (feature, global_mean) in global_means.iter_mut().enumerate().take(n_features) {
            let weighted_sum: f64 = stats
                .iter()
                .map(|s| s.means[feature] * s.n_samples as f64)
                .sum();
            *global_mean = weighted_sum / total_samples as f64;
        }

        // Combine mins and maxs
        for feature in 0..n_features {
            for chunk_stats in stats {
                global_mins[feature] = global_mins[feature].min(chunk_stats.mins[feature]);
                global_maxs[feature] = global_maxs[feature].max(chunk_stats.maxs[feature]);
            }
        }

        // Compute global standard deviations
        for feature in 0..n_features {
            let sum_squared_deviations: f64 = stats
                .iter()
                .map(|s| {
                    let chunk_mean = s.means[feature];
                    let global_mean = global_means[feature];
                    let n = s.n_samples as f64;

                    // Per-chunk sum of squared deviations from the global mean:
                    // sum((x - gm)^2) = sum(x^2) - 2*gm*sum(x) + n*gm^2, with sum(x) = n*chunk_mean
                    s.sum_squares[feature] - 2.0 * chunk_mean * n * global_mean
                        + n * global_mean * global_mean
                })
                .sum();

            global_stds[feature] = (sum_squared_deviations / total_samples as f64).sqrt();
        }

        Ok(GlobalStatistics {
            means: global_means,
            stds: global_stds,
            mins: global_mins,
            maxs: global_maxs,
        })
    }

    fn apply_scaling(dataset: &Dataset, params: &ScalingParameters) -> Result<Dataset> {
        let mut scaled_data = dataset.data.clone();

        match params.method {
            ScalingMethod::StandardScaler => {
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let mean = params.means[col_idx];
                    let std = params.stds[col_idx];

                    if std > 1e-8 {
                        // Avoid division by zero
                        for value in column.iter_mut() {
                            *value = (*value - mean) / std;
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let min = params.mins[col_idx];
                    let max = params.maxs[col_idx];
                    let range = max - min;

                    if range > 1e-8 {
                        // Avoid division by zero
                        for value in column.iter_mut() {
                            *value = (*value - min) / range;
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                // For simplicity, fall back to standard scaling
                // In a full implementation, you'd compute medians and MAD
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let mean = params.means[col_idx];
                    let std = params.stds[col_idx];

                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            *value = (*value - mean) / std;
                        }
                    }
                }
            }
        }

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Distributed scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }
}

/// Statistics computed on a chunk of data
#[derive(Debug, Clone)]
struct ChunkStatistics {
    n_samples: usize,
    means: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
    sum_squares: Vec<f64>,
}

/// Global statistics combined from all chunks
#[derive(Debug, Clone)]
struct GlobalStatistics {
    means: Vec<f64>,
    stds: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
}

/// Scaling methods for distributed processing
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScalingMethod {
    /// Z-score normalization
    StandardScaler,
    /// Min-max scaling to [0, 1]
    MinMaxScaler,
    /// Robust scaling using median and MAD (the current implementation falls back to standard scaling)
    RobustScaler,
}

/// Parameters for scaling transformations
#[derive(Debug, Clone)]
pub struct ScalingParameters {
    method: ScalingMethod,
    means: Vec<f64>,
    stds: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
}

impl ScalingParameters {
    fn from_statistics(stats: &GlobalStatistics, method: ScalingMethod) -> Self {
        Self {
            method,
            means: stats.means.clone(),
            stds: stats.stds.clone(),
            mins: stats.mins.clone(),
            maxs: stats.maxs.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert!(config.num_workers > 0);
        assert!(config.chunk_size > 0);
    }

    #[test]
    fn test_split_dataset_into_chunks() {
        let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let chunks = processor.split_dataset_into_chunks(&dataset).unwrap();

        assert!(!chunks.is_empty());

        let total_samples: usize = chunks.iter().map(|c| c.n_samples()).sum();
        assert_eq!(total_samples, dataset.n_samples());
    }

    #[test]
    fn test_distributed_sample() {
        let dataset = make_classification(1000, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let sampled = processor
            .distributed_sample(&dataset, 100, Some(42))
            .unwrap();

        assert_eq!(sampled.n_samples(), 100);
        assert_eq!(sampled.n_features(), dataset.n_features());
    }

    #[test]
    fn test_distributed_k_fold() {
        let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let folds = processor
            .distributed_k_fold(&dataset, 5, true, Some(42))
            .unwrap();

        assert_eq!(folds.len(), 5);

        for (train, test) in folds {
            assert!(train.n_samples() > 0);
            assert!(test.n_samples() > 0);
            assert_eq!(train.n_features(), dataset.n_features());
            assert_eq!(test.n_features(), dataset.n_features());
        }
    }

    #[test]
    fn test_combine_datasets() {
        let dataset1 = make_classification(50, 3, 2, 2, 1, Some(42)).unwrap();
        let dataset2 = make_classification(30, 3, 2, 2, 1, Some(43)).unwrap();

        let processor = DistributedProcessor::default_config().unwrap();
        let combined = processor.combine_datasets(&[dataset1, dataset2]).unwrap();

        assert_eq!(combined.n_samples(), 80);
        assert_eq!(combined.n_features(), 3);
    }

    #[test]
    fn test_parallel_processing() {
        let dataset = make_classification(200, 4, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        // Simple processor that counts samples
        let counter = |chunk: &Dataset| -> Result<usize> { Ok(chunk.n_samples()) };

        let results = processor
            .process_dataset_parallel(&dataset, counter)
            .unwrap();

        let total_processed: usize = results.iter().sum();
        assert_eq!(total_processed, dataset.n_samples());
    }
}