use std::collections::HashMap;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;

use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};

use crate::cache::DatasetCache;
use crate::error::{DatasetsError, Result};
use crate::utils::Dataset;

#[derive(Debug, Clone, Serialize, Deserialize)]
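/// Configuration for distributed dataset processing.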
pub struct DistributedConfig {
    /// Number of worker threads to spawn.
    pub num_workers: usize,
    /// Maximum number of samples per chunk.
    pub chunk_size: usize,
    /// Timeout for distributed operations.
    pub timeout: u64,
    /// Whether to use shared memory between workers.
    pub use_shared_memory: bool,
    /// Memory limit per worker, in megabytes.
    pub memory_limit_mb: usize,
}

impl Default for DistributedConfig {
    fn default() -> Self {
        let num_cpus = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);

        Self {
            num_workers: num_cpus,
            chunk_size: 10000,
            timeout: 300,
            use_shared_memory: false,
            memory_limit_mb: 1024,
        }
    }
}

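/// Processor that runs dataset operations in parallel across worker threads.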
pub struct DistributedProcessor {
    /// Configuration controlling chunking and worker behavior.
    config: DistributedConfig,
    /// Local dataset cache (currently unused by the processor).
    #[allow(dead_code)]
    cache: DatasetCache,
}

impl DistributedProcessor {
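    /// Create a new processor with the given configuration.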
    pub fn new(config: DistributedConfig) -> Result<Self> {
        let cachedir = dirs::cache_dir()
            .ok_or_else(|| DatasetsError::Other("Could not determine cache directory".to_string()))?
            .join("scirs2-datasets");
        let cache = DatasetCache::new(cachedir);

        Ok(Self { config, cache })
    }

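    /// Create a processor with the default configuration.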
    pub fn default_config() -> Result<Self> {
        Self::new(DistributedConfig::default())
    }

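    /// Split the dataset into chunks and apply `processor` to each chunk in parallel.
    ///
    /// Results are collected in the order chunks finish, which may differ from chunk order.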
    pub fn process_dataset_parallel<F, R>(&self, dataset: &Dataset, processor: F) -> Result<Vec<R>>
    where
        F: Fn(&Dataset) -> Result<R> + Send + Sync + Clone + 'static,
        R: Send + 'static,
    {
        let chunks = self.split_dataset_into_chunks(dataset)?;
        let processor = Arc::new(processor);

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        // Process each chunk on its own thread and send the result back over the channel.
        for chunk in chunks {
            let tx = tx.clone();
            let processor = Arc::clone(&processor);

            let handle = thread::spawn(move || {
                let result = processor(&chunk);
                let _ = tx.send(result);
            });

            handles.push(handle);
        }

        // Drop the original sender so the receive loop ends once all workers are done.
        drop(tx);

        let mut results = Vec::new();
        for result in rx {
            results.push(result?);
        }

        for handle in handles {
            let _ = handle.join();
        }

        Ok(results)
    }

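    /// Apply `mapper` to every chunk in parallel, then pass the flattened results to `reducer`.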
    pub fn map_reduce_dataset<M, R, C>(&self, dataset: &Dataset, mapper: M, reducer: R) -> Result<C>
    where
        M: Fn(&Dataset) -> Result<Vec<C>> + Send + Sync + Clone + 'static,
        R: Fn(Vec<C>) -> Result<C> + Send + Sync + 'static,
        C: Send + 'static,
    {
        let map_results = self.process_dataset_parallel(dataset, mapper)?;

        let flattened: Vec<C> = map_results.into_iter().flatten().collect();
        reducer(flattened)
    }

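    /// Split the dataset into row-wise chunks sized for the configured number of workers.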
    pub fn split_dataset_into_chunks(&self, dataset: &Dataset) -> Result<Vec<Dataset>> {
        let n_samples = dataset.n_samples();
        // Cap the chunk size so the work is spread across all configured workers.
        let chunk_size = self
            .config
            .chunk_size
            .min(n_samples / self.config.num_workers + 1);

        let mut chunks = Vec::new();

        for start in (0..n_samples).step_by(chunk_size) {
            let end = (start + chunk_size).min(n_samples);
            let chunk_data = dataset.data.slice(s![start..end, ..]).to_owned();

            let chunk_target = dataset
                .target
                .as_ref()
                .map(|target| target.slice(s![start..end]).to_owned());

            let chunk = Dataset {
                data: chunk_data,
                target: chunk_target,
                featurenames: dataset.featurenames.clone(),
                targetnames: dataset.targetnames.clone(),
                feature_descriptions: dataset.feature_descriptions.clone(),
                description: Some(format!("Chunk {start}-{end} of distributed dataset")),
                metadata: dataset.metadata.clone(),
            };

            chunks.push(chunk);
        }

        Ok(chunks)
    }

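    /// Draw a random sample of `n_samples` rows by sampling each chunk in parallel.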
    pub fn distributed_sample(
        &self,
        dataset: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        if n_samples >= dataset.n_samples() {
            return Ok(dataset.clone());
        }

        let chunks = self.split_dataset_into_chunks(dataset)?;
        // Spread the requested sample count over the actual number of chunks produced,
        // so the totals add up even when the chunk count differs from num_workers.
        let samples_per_chunk = n_samples / chunks.len();
        let remainder = n_samples % chunks.len();

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for (i, chunk) in chunks.into_iter().enumerate() {
            let tx = tx.clone();
            let chunk_samples = if i < remainder {
                samples_per_chunk + 1
            } else {
                samples_per_chunk
            };

            let seed = random_state.map(|s| s + i as u64);

            let handle = thread::spawn(move || {
                let sampled = Self::sample_chunk(&chunk, chunk_samples, seed);
                let _ = tx.send(sampled);
            });

            handles.push(handle);
        }

        drop(tx);

        let mut sampled_chunks = Vec::new();
        for result in rx {
            sampled_chunks.push(result?);
        }

        for handle in handles {
            let _ = handle.join();
        }

        self.combine_datasets(&sampled_chunks)
    }

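    /// Build `k` train/test folds, optionally shuffling the sample indices first.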
    pub fn distributed_k_fold(
        &self,
        dataset: &Dataset,
        k: usize,
        shuffle: bool,
        random_state: Option<u64>,
    ) -> Result<Vec<(Dataset, Dataset)>> {
        let n_samples = dataset.n_samples();
        let fold_size = n_samples / k;

        let mut indices: Vec<usize> = (0..n_samples).collect();

        if shuffle {
            use scirs2_core::random::seq::SliceRandom;
            use scirs2_core::random::SeedableRng;

            // Without an explicit seed, fall back to a fixed one so runs stay reproducible.
            let mut rng = if let Some(seed) = random_state {
                scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
            } else {
                scirs2_core::random::rngs::StdRng::seed_from_u64(42)
            };

            indices.shuffle(&mut rng);
        }

        let mut folds = Vec::new();

        for fold_idx in 0..k {
            let test_start = fold_idx * fold_size;
            // The last fold absorbs any remainder samples.
            let test_end = if fold_idx == k - 1 {
                n_samples
            } else {
                (fold_idx + 1) * fold_size
            };

            let test_indices = &indices[test_start..test_end];
            let train_indices: Vec<usize> = indices[..test_start]
                .iter()
                .chain(indices[test_end..].iter())
                .copied()
                .collect();

            let train_data = self.select_samples(dataset, &train_indices)?;
            let test_data = self.select_samples(dataset, test_indices)?;

            folds.push((train_data, test_data));
        }

        Ok(folds)
    }

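    /// Draw a class-stratified random sample of `n_samples` rows, sampling each class in parallel.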
    pub fn distributed_stratified_sample(
        &self,
        dataset: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        let target = dataset.target.as_ref().ok_or_else(|| {
            DatasetsError::InvalidFormat("Stratified sampling requires target values".to_string())
        })?;

        // Group sample indices by class label.
        let mut class_groups: HashMap<i32, Vec<usize>> = HashMap::new();
        for (idx, &value) in target.iter().enumerate() {
            let class = value as i32;
            class_groups.entry(class).or_default().push(idx);
        }

        // Distribute the requested sample count across classes.
        let n_classes = class_groups.len();
        let base_samples_per_class = n_samples / n_classes;
        let remainder = n_samples % n_classes;

        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for (class_idx, (class, indices)) in class_groups.into_iter().enumerate() {
            let tx = tx.clone();
            let class_samples = if class_idx < remainder {
                base_samples_per_class + 1
            } else {
                base_samples_per_class
            };

            let seed = random_state.map(|s| s + class_idx as u64);

            let handle = thread::spawn(move || {
                let sampled_indices = Self::sample_indices(&indices, class_samples, seed);
                let _ = tx.send((class, sampled_indices));
            });

            handles.push(handle);
        }

        drop(tx);

        let mut all_sampled_indices = Vec::new();
        for (_, indices) in rx {
            all_sampled_indices.extend(indices?);
        }

        for handle in handles {
            let _ = handle.join();
        }

        self.select_samples(dataset, &all_sampled_indices)
    }

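    /// Scale the dataset feature-wise using statistics aggregated across all chunks.
    ///
    /// Returns the scaled dataset together with the scaling parameters that were applied.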
    pub fn distributed_scale(
        &self,
        dataset: &Dataset,
        method: ScalingMethod,
    ) -> Result<(Dataset, ScalingParameters)> {
        let n_features = dataset.n_features();
        let chunks = self.split_dataset_into_chunks(dataset)?;

        // First pass: compute per-chunk statistics in parallel.
        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for chunk in chunks.iter() {
            let tx = tx.clone();
            let chunk = chunk.clone();

            let handle = thread::spawn(move || {
                let stats = Self::compute_chunk_statistics(&chunk);
                let _ = tx.send(stats);
            });

            handles.push(handle);
        }

        drop(tx);

        let mut all_stats = Vec::new();
        for stats in rx {
            all_stats.push(stats?);
        }

        for handle in handles {
            let _ = handle.join();
        }

        // Merge the chunk statistics into global scaling parameters.
        let global_stats = Self::combine_statistics(&all_stats, n_features)?;
        let scaling_params = ScalingParameters::from_statistics(&global_stats, method);

        // Second pass: apply the global parameters to every chunk in parallel.
        let (tx, rx) = mpsc::channel();
        let mut handles = Vec::new();

        for chunk in chunks {
            let tx = tx.clone();
            let params = scaling_params.clone();

            let handle = thread::spawn(move || {
                let scaled_chunk = Self::apply_scaling(&chunk, &params);
                let _ = tx.send(scaled_chunk);
            });

            handles.push(handle);
        }

        drop(tx);

        let mut scaled_chunks = Vec::new();
        for result in rx {
            scaled_chunks.push(result?);
        }

        for handle in handles {
            let _ = handle.join();
        }

        let scaled_dataset = self.combine_datasets(&scaled_chunks)?;
        Ok((scaled_dataset, scaling_params))
    }

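    /// Randomly sample up to `n_samples` rows from a single chunk.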
    fn sample_chunk(
        chunk: &Dataset,
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Dataset> {
        if n_samples >= chunk.n_samples() {
            return Ok(chunk.clone());
        }

        use scirs2_core::random::seq::SliceRandom;
        use scirs2_core::random::SeedableRng;

        let mut rng = if let Some(seed) = random_state {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        } else {
            scirs2_core::random::rngs::StdRng::seed_from_u64(42)
        };

        let mut indices: Vec<usize> = (0..chunk.n_samples()).collect();
        indices.shuffle(&mut rng);
        indices.truncate(n_samples);

        Self::select_samples_static(chunk, &indices)
    }

    /// Randomly sample up to `n_samples` entries from an index slice.
    fn sample_indices(
        indices: &[usize],
        n_samples: usize,
        random_state: Option<u64>,
    ) -> Result<Vec<usize>> {
        if n_samples >= indices.len() {
            return Ok(indices.to_vec());
        }

        use scirs2_core::random::seq::SliceRandom;
        use scirs2_core::random::SeedableRng;

        let mut rng = if let Some(seed) = random_state {
            scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
        } else {
            scirs2_core::random::rngs::StdRng::seed_from_u64(42)
        };

        let mut sampled = indices.to_vec();
        sampled.shuffle(&mut rng);
        sampled.truncate(n_samples);

        Ok(sampled)
    }

    /// Select the given rows from a dataset.
    fn select_samples(&self, dataset: &Dataset, indices: &[usize]) -> Result<Dataset> {
        Self::select_samples_static(dataset, indices)
    }

    /// Select the given rows from a dataset without borrowing `self`.
    fn select_samples_static(dataset: &Dataset, indices: &[usize]) -> Result<Dataset> {
        let selected_data = dataset.data.select(Axis(0), indices);
        let selected_target = dataset
            .target
            .as_ref()
            .map(|target| target.select(Axis(0), indices));

        Ok(Dataset {
            data: selected_data,
            target: selected_target,
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Distributed sample".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }

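    /// Concatenate chunk datasets row-wise back into a single dataset.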
    fn combine_datasets(&self, datasets: &[Dataset]) -> Result<Dataset> {
        if datasets.is_empty() {
            return Err(DatasetsError::InvalidFormat(
                "Cannot combine empty dataset list".to_string(),
            ));
        }

        let n_features = datasets[0].n_features();
        let total_samples: usize = datasets.iter().map(|d| d.n_samples()).sum();

        // Flatten all rows (and targets, if present) into contiguous buffers.
        let mut combined_data = Vec::with_capacity(total_samples * n_features);
        let mut combined_target = if datasets[0].target.is_some() {
            Some(Vec::with_capacity(total_samples))
        } else {
            None
        };

        for dataset in datasets {
            for row in dataset.data.rows() {
                combined_data.extend(row.iter());
            }

            if let (Some(ref mut combined), Some(ref target)) =
                (&mut combined_target, &dataset.target)
            {
                combined.extend(target.iter());
            }
        }

        let data = Array2::from_shape_vec((total_samples, n_features), combined_data)
            .map_err(|e| DatasetsError::FormatError(e.to_string()))?;

        let target = combined_target.map(Array1::from_vec);

        Ok(Dataset {
            data,
            target,
            featurenames: datasets[0].featurenames.clone(),
            targetnames: datasets[0].targetnames.clone(),
            feature_descriptions: datasets[0].feature_descriptions.clone(),
            description: Some("Combined distributed dataset".to_string()),
            metadata: datasets[0].metadata.clone(),
        })
    }

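    /// Compute per-feature mean, min, max, and sum of squares for one chunk.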
    fn compute_chunk_statistics(chunk: &Dataset) -> Result<ChunkStatistics> {
        let data = &chunk.data;
        let n_features = data.ncols();
        let n_samples = data.nrows() as f64;

        let mut means = vec![0.0; n_features];
        let mut mins = vec![f64::INFINITY; n_features];
        let mut maxs = vec![f64::NEG_INFINITY; n_features];
        let mut sum_squares = vec![0.0; n_features];

        for col in 0..n_features {
            let column = data.column(col);

            let sum: f64 = column.sum();
            means[col] = sum / n_samples;

            for &value in column.iter() {
                mins[col] = mins[col].min(value);
                maxs[col] = maxs[col].max(value);
                sum_squares[col] += value * value;
            }
        }

        Ok(ChunkStatistics {
            n_samples: n_samples as usize,
            means,
            mins,
            maxs,
            sum_squares,
        })
    }

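    /// Merge per-chunk statistics into global per-feature statistics.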
    fn combine_statistics(
        stats: &[ChunkStatistics],
        n_features: usize,
    ) -> Result<GlobalStatistics> {
        let total_samples: usize = stats.iter().map(|s| s.n_samples).sum();
        let mut global_means = vec![0.0; n_features];
        let mut global_mins = vec![f64::INFINITY; n_features];
        let mut global_maxs = vec![f64::NEG_INFINITY; n_features];
        let mut global_stds = vec![0.0; n_features];

        // Global mean is the sample-count-weighted average of the chunk means.
        for (feature, global_mean) in global_means.iter_mut().enumerate().take(n_features) {
            let weighted_sum: f64 = stats
                .iter()
                .map(|s| s.means[feature] * s.n_samples as f64)
                .sum();
            *global_mean = weighted_sum / total_samples as f64;
        }

        // Global min/max are the extremes over all chunks.
        for feature in 0..n_features {
            for chunk_stats in stats {
                global_mins[feature] = global_mins[feature].min(chunk_stats.mins[feature]);
                global_maxs[feature] = global_maxs[feature].max(chunk_stats.maxs[feature]);
            }
        }

        // Population std: per chunk, sum((x - m)^2) = sum(x^2) - 2*m*sum(x) + n*m^2,
        // where m is the global mean and sum(x) = n * chunk_mean.
        for feature in 0..n_features {
            let sum_squared_deviations: f64 = stats
                .iter()
                .map(|s| {
                    let chunk_mean = s.means[feature];
                    let global_mean = global_means[feature];
                    let n = s.n_samples as f64;

                    s.sum_squares[feature] - 2.0 * chunk_mean * n * global_mean
                        + n * global_mean * global_mean
                })
                .sum();

            global_stds[feature] = (sum_squared_deviations / total_samples as f64).sqrt();
        }

        Ok(GlobalStatistics {
            means: global_means,
            stds: global_stds,
            mins: global_mins,
            maxs: global_maxs,
        })
    }

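    /// Apply the given scaling parameters to every feature column of a dataset.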
    fn apply_scaling(dataset: &Dataset, params: &ScalingParameters) -> Result<Dataset> {
        let mut scaled_data = dataset.data.clone();

        match params.method {
            ScalingMethod::StandardScaler => {
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let mean = params.means[col_idx];
                    let std = params.stds[col_idx];

                    // Skip near-constant columns to avoid dividing by ~zero.
                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            *value = (*value - mean) / std;
                        }
                    }
                }
            }
            ScalingMethod::MinMaxScaler => {
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let min = params.mins[col_idx];
                    let max = params.maxs[col_idx];
                    let range = max - min;

                    if range > 1e-8 {
                        for value in column.iter_mut() {
                            *value = (*value - min) / range;
                        }
                    }
                }
            }
            ScalingMethod::RobustScaler => {
                // Note: this arm currently reuses mean/std scaling; a true robust scaler
                // would center on the median and scale by the interquartile range.
                for (col_idx, mut column) in scaled_data.columns_mut().into_iter().enumerate() {
                    let mean = params.means[col_idx];
                    let std = params.stds[col_idx];

                    if std > 1e-8 {
                        for value in column.iter_mut() {
                            *value = (*value - mean) / std;
                        }
                    }
                }
            }
        }

        Ok(Dataset {
            data: scaled_data,
            target: dataset.target.clone(),
            featurenames: dataset.featurenames.clone(),
            targetnames: dataset.targetnames.clone(),
            feature_descriptions: dataset.feature_descriptions.clone(),
            description: Some("Distributed scaled dataset".to_string()),
            metadata: dataset.metadata.clone(),
        })
    }
}

#[derive(Debug, Clone)]
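/// Per-feature statistics computed on a single chunk.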
struct ChunkStatistics {
    n_samples: usize,
    means: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
    sum_squares: Vec<f64>,
}

#[derive(Debug, Clone)]
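/// Per-feature statistics aggregated over all chunks.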
struct GlobalStatistics {
    means: Vec<f64>,
    stds: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
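/// Supported feature scaling methods.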
pub enum ScalingMethod {
    /// Standardize each feature to zero mean and unit variance.
    StandardScaler,
    /// Rescale each feature to the `[0, 1]` range.
    MinMaxScaler,
    /// Robust scaling (currently implemented as mean/std scaling).
    RobustScaler,
}

#[derive(Debug, Clone)]
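/// Scaling parameters derived from the global statistics and shared with workers.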
pub struct ScalingParameters {
    method: ScalingMethod,
    means: Vec<f64>,
    stds: Vec<f64>,
    mins: Vec<f64>,
    maxs: Vec<f64>,
}

impl ScalingParameters {
    fn from_statistics(stats: &GlobalStatistics, method: ScalingMethod) -> Self {
        Self {
            method,
            means: stats.means.clone(),
            stds: stats.stds.clone(),
            mins: stats.mins.clone(),
            maxs: stats.maxs.clone(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generators::make_classification;

    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert!(config.num_workers > 0);
        assert!(config.chunk_size > 0);
    }

    #[test]
    fn test_split_dataset_into_chunks() {
        let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let chunks = processor.split_dataset_into_chunks(&dataset).unwrap();

        assert!(!chunks.is_empty());

        let total_samples: usize = chunks.iter().map(|c| c.n_samples()).sum();
        assert_eq!(total_samples, dataset.n_samples());
    }

    #[test]
    fn test_distributed_sample() {
        let dataset = make_classification(1000, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let sampled = processor
            .distributed_sample(&dataset, 100, Some(42))
            .unwrap();

        assert_eq!(sampled.n_samples(), 100);
        assert_eq!(sampled.n_features(), dataset.n_features());
    }

    #[test]
    fn test_distributed_k_fold() {
        let dataset = make_classification(100, 5, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let folds = processor
            .distributed_k_fold(&dataset, 5, true, Some(42))
            .unwrap();

        assert_eq!(folds.len(), 5);

        for (train, test) in folds {
            assert!(train.n_samples() > 0);
            assert!(test.n_samples() > 0);
            assert_eq!(train.n_features(), dataset.n_features());
            assert_eq!(test.n_features(), dataset.n_features());
        }
    }

    #[test]
    fn test_combine_datasets() {
        let dataset1 = make_classification(50, 3, 2, 2, 1, Some(42)).unwrap();
        let dataset2 = make_classification(30, 3, 2, 2, 1, Some(43)).unwrap();

        let processor = DistributedProcessor::default_config().unwrap();
        let combined = processor.combine_datasets(&[dataset1, dataset2]).unwrap();

        assert_eq!(combined.n_samples(), 80);
        assert_eq!(combined.n_features(), 3);
    }

    #[test]
    fn test_parallel_processing() {
        let dataset = make_classification(200, 4, 2, 3, 1, Some(42)).unwrap();
        let processor = DistributedProcessor::default_config().unwrap();

        let counter = |chunk: &Dataset| -> Result<usize> { Ok(chunk.n_samples()) };

        let results = processor
            .process_dataset_parallel(&dataset, counter)
            .unwrap();

        let total_processed: usize = results.iter().sum();
        assert_eq!(total_processed, dataset.n_samples());
    }
}