1use crate::generators::basic::{make_blobs, make_classification, make_regression};
11use scirs2_core::ndarray::{Array1, Array2};
12use std::collections::HashMap;
13use std::sync::mpsc;
14use std::thread;
15use std::time::{Duration, Instant};
16
/// Configuration for chunked (streaming) dataset generation.
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Number of samples per emitted chunk (the final chunk may be smaller).
    pub chunk_size: usize,
    /// Total number of samples the stream will produce.
    pub total_samples: usize,
    /// Base RNG seed; each chunk offsets this by its chunk index.
    pub random_state: Option<u64>,
    /// Worker-count hint for parallel consumers; not read by `DatasetStream` itself.
    pub n_workers: usize,
}
29
impl Default for StreamConfig {
    /// Defaults: 1000-sample chunks, 10k samples total, unseeded RNG, and one
    /// worker per available logical CPU.
    fn default() -> Self {
        Self {
            chunk_size: 1000,
            total_samples: 10000,
            random_state: None,
            // Match the machine's logical CPU count.
            n_workers: num_cpus::get(),
        }
    }
}
40
/// Iterator that lazily generates a dataset chunk-by-chunk.
pub struct DatasetStream<T> {
    /// Stream parameters (chunk size, total samples, base seed).
    config: StreamConfig,
    /// Index of the next chunk to generate.
    current_chunk: usize,
    /// Total number of chunks this stream will yield.
    total_chunks: usize,
    /// Maps `(chunk_size, chunk_index, seed)` to one generated chunk.
    generator_fn: Box<dyn Fn(usize, usize, Option<u64>) -> T + Send + Sync>,
}
48
49impl<T> DatasetStream<T> {
50 fn new<F>(config: StreamConfig, generator_fn: F) -> Self
51 where
52 F: Fn(usize, usize, Option<u64>) -> T + Send + Sync + 'static,
53 {
54 let total_chunks = (config.total_samples + config.chunk_size - 1) / config.chunk_size;
55
56 Self {
57 config,
58 current_chunk: 0,
59 total_chunks,
60 generator_fn: Box::new(generator_fn),
61 }
62 }
63}
64
65impl<T> Iterator for DatasetStream<T> {
66 type Item = T;
67
68 fn next(&mut self) -> Option<Self::Item> {
69 if self.current_chunk >= self.total_chunks {
70 return None;
71 }
72
73 let chunk_start = self.current_chunk * self.config.chunk_size;
74 let chunk_end = std::cmp::min(
75 chunk_start + self.config.chunk_size,
76 self.config.total_samples,
77 );
78 let chunk_size = chunk_end - chunk_start;
79
80 let chunk_seed = self
82 .config
83 .random_state
84 .map(|seed| seed + self.current_chunk as u64);
85
86 let result = (self.generator_fn)(chunk_size, self.current_chunk, chunk_seed);
87 self.current_chunk += 1;
88
89 Some(result)
90 }
91}
92
93pub fn stream_classification(
95 n_features: usize,
96 n_classes: usize,
97 config: StreamConfig,
98) -> DatasetStream<(Array2<f64>, Array1<i32>)> {
99 DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
100 make_classification(
101 chunk_size, n_features, n_features, 0, n_classes, seed,
104 )
105 .unwrap()
106 })
107}
108
109pub fn stream_regression(
111 n_features: usize,
112 config: StreamConfig,
113) -> DatasetStream<(Array2<f64>, Array1<f64>)> {
114 DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
115 make_regression(
116 chunk_size, n_features, n_features, 0.1, seed,
119 )
120 .unwrap()
121 })
122}
123
124pub fn stream_blobs(
126 n_features: usize,
127 centers: usize,
128 config: StreamConfig,
129) -> DatasetStream<(Array2<f64>, Array1<i32>)> {
130 DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
131 make_blobs(
132 chunk_size, n_features, centers, 1.0, seed,
134 )
135 .unwrap()
136 })
137}
138
/// Result of generating a dataset across multiple worker threads.
#[derive(Debug)]
pub struct ParallelGenerationResult<T> {
    /// One generated chunk per contributing worker, ordered by worker id.
    pub chunks: Vec<T>,
    /// Wall-clock duration of the whole parallel run.
    pub generation_time: std::time::Duration,
    /// Number of workers that produced a chunk successfully.
    pub n_workers_used: usize,
}

/// Split `n_samples` across `n_workers` threads and generate chunks in parallel.
///
/// `generator_fn` receives `(chunk_size, seed)` and produces one chunk; each
/// worker gets a distinct deterministic seed (`worker_id * 12345`).
///
/// # Errors
///
/// Returns an error when `n_workers` is zero, when any worker's generator
/// fails, or when a worker thread panics.
pub fn parallel_generate<T, F>(
    n_samples: usize,
    n_workers: usize,
    generator_fn: F,
) -> Result<ParallelGenerationResult<T>, Box<dyn std::error::Error + Send + Sync>>
where
    T: Send + 'static,
    F: Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
        + Send
        + Sync
        + Copy
        + 'static,
{
    // Guard: the ceiling division below would otherwise panic with a
    // divide-by-zero.
    if n_workers == 0 {
        return Err("n_workers must be at least 1".into());
    }

    let start_time = std::time::Instant::now();

    // Ceiling division so the workers cover all samples; trailing workers may
    // receive a short (or empty) chunk.
    let chunk_size = (n_samples + n_workers - 1) / n_workers;
    let (tx, rx) = mpsc::channel();

    let mut handles = Vec::new();

    for worker_id in 0..n_workers {
        let tx = tx.clone();
        let handle = thread::spawn(move || {
            let chunk_start = worker_id * chunk_size;
            let chunk_end = std::cmp::min(chunk_start + chunk_size, n_samples);
            let actual_chunk_size = chunk_end - chunk_start;

            // Workers past the end of the sample range have nothing to do.
            if actual_chunk_size == 0 {
                return;
            }

            // Deterministic, distinct per-worker seed.
            let seed = Some(worker_id as u64 * 12345);

            // A send only fails when the receiver was dropped (early return on
            // another worker's error); log and exit in that case.
            match generator_fn(actual_chunk_size, seed) {
                Ok(result) => {
                    if tx.send((worker_id, Ok(result))).is_err() {
                        eprintln!("Failed to send result from worker {}", worker_id);
                    }
                }
                Err(e) => {
                    if tx.send((worker_id, Err(e))).is_err() {
                        eprintln!("Failed to send error from worker {}", worker_id);
                    }
                }
            }
        });
        handles.push(handle);
    }

    // Drop the original sender so the receive loop terminates once all
    // workers have finished.
    drop(tx);

    // Collect results indexed by worker id so the chunk order is
    // deterministic regardless of completion order.
    let mut results: Vec<Option<T>> = (0..n_workers).map(|_| None).collect();
    let mut successful_workers = 0;

    for (worker_id, result) in rx {
        match result {
            Ok(data) => {
                results[worker_id] = Some(data);
                successful_workers += 1;
            }
            Err(e) => {
                return Err(format!("Worker {} failed: {}", worker_id, e).into());
            }
        }
    }

    // Surface panics from any worker thread.
    for handle in handles {
        handle.join().map_err(|_| "Thread panicked")?;
    }

    // Empty slots (workers with nothing to do) are silently skipped.
    let chunks: Vec<T> = results.into_iter().flatten().collect();

    let generation_time = start_time.elapsed();

    Ok(ParallelGenerationResult {
        chunks,
        generation_time,
        n_workers_used: successful_workers,
    })
}
233
234pub fn parallel_classification(
236 n_samples: usize,
237 n_features: usize,
238 n_classes: usize,
239 n_workers: usize,
240) -> Result<
241 ParallelGenerationResult<(Array2<f64>, Array1<i32>)>,
242 Box<dyn std::error::Error + Send + Sync>,
243> {
244 parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
245 make_classification(
246 chunk_size, n_features, n_features, 0, n_classes, seed,
249 )
250 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
251 })
252}
253
254pub fn parallel_regression(
256 n_samples: usize,
257 n_features: usize,
258 n_workers: usize,
259) -> Result<
260 ParallelGenerationResult<(Array2<f64>, Array1<f64>)>,
261 Box<dyn std::error::Error + Send + Sync>,
262> {
263 parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
264 make_regression(
265 chunk_size, n_features, n_features, 0.1, seed,
268 )
269 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
270 })
271}
272
273pub fn parallel_blobs(
275 n_samples: usize,
276 n_features: usize,
277 centers: usize,
278 n_workers: usize,
279) -> Result<
280 ParallelGenerationResult<(Array2<f64>, Array1<i32>)>,
281 Box<dyn std::error::Error + Send + Sync>,
282> {
283 parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
284 make_blobs(
285 chunk_size, n_features, centers, 1.0, seed,
287 )
288 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
289 })
290}
291
/// Pull-based generator that produces dataset chunks on demand and tracks
/// how far through the requested total it has progressed.
pub struct LazyDatasetGenerator<T> {
    /// Maximum number of samples per chunk.
    chunk_size: usize,
    /// Total number of samples to generate overall.
    total_samples: usize,
    /// Samples generated so far.
    generated_samples: usize,
    /// Produces one chunk from `(chunk_size, seed)`.
    generator_fn:
        Box<dyn Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    /// Base seed; per-chunk seeds are offset by the running sample count.
    random_state: Option<u64>,
}

impl<T> LazyDatasetGenerator<T> {
    /// Create a generator that lazily produces `total_samples` samples in
    /// chunks of at most `chunk_size`.
    pub fn new<F>(
        total_samples: usize,
        chunk_size: usize,
        random_state: Option<u64>,
        generator_fn: F,
    ) -> Self
    where
        F: Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + 'static,
    {
        Self {
            chunk_size,
            total_samples,
            generated_samples: 0,
            generator_fn: Box::new(generator_fn),
            random_state,
        }
    }

    /// Generate the next chunk, or `None` when all samples have been produced.
    ///
    /// The final chunk is truncated to the remaining sample count. The chunk
    /// seed is `random_state + samples_generated_so_far`, so chunks are
    /// reproducible yet distinct.
    pub fn next_chunk(&mut self) -> Option<Result<T, Box<dyn std::error::Error + Send + Sync>>> {
        if self.generated_samples >= self.total_samples {
            return None;
        }

        let remaining_samples = self.total_samples - self.generated_samples;
        let current_chunk_size = std::cmp::min(self.chunk_size, remaining_samples);

        // A zero chunk size can never make progress; terminate instead of
        // handing callers an infinite stream of empty chunks.
        if current_chunk_size == 0 {
            return None;
        }

        let seed = self.random_state.map(|s| s + self.generated_samples as u64);

        let result = (self.generator_fn)(current_chunk_size, seed);
        // Progress advances even when the generator reports an error, so a
        // failing chunk is not retried forever.
        self.generated_samples += current_chunk_size;

        Some(result)
    }

    /// Report `(generated, total, ratio)` progress.
    ///
    /// A zero-sample generator is reported as fully complete (ratio 1.0)
    /// rather than producing a NaN from `0.0 / 0.0`.
    pub fn progress(&self) -> (usize, usize, f64) {
        let ratio = if self.total_samples == 0 {
            1.0
        } else {
            self.generated_samples as f64 / self.total_samples as f64
        };
        (self.generated_samples, self.total_samples, ratio)
    }

    /// True once every requested sample has been generated.
    pub fn is_complete(&self) -> bool {
        self.generated_samples >= self.total_samples
    }
}
350
351pub fn lazy_classification(
353 total_samples: usize,
354 n_features: usize,
355 n_classes: usize,
356 chunk_size: usize,
357 random_state: Option<u64>,
358) -> LazyDatasetGenerator<(Array2<f64>, Array1<i32>)> {
359 LazyDatasetGenerator::new(
360 total_samples,
361 chunk_size,
362 random_state,
363 move |chunk_size, seed| {
364 make_classification(
365 chunk_size, n_features, n_features, 0, n_classes, seed,
368 )
369 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
370 },
371 )
372}
373
374pub fn lazy_regression(
376 total_samples: usize,
377 n_features: usize,
378 chunk_size: usize,
379 random_state: Option<u64>,
380) -> LazyDatasetGenerator<(Array2<f64>, Array1<f64>)> {
381 LazyDatasetGenerator::new(
382 total_samples,
383 chunk_size,
384 random_state,
385 move |chunk_size, seed| {
386 make_regression(
387 chunk_size, n_features, n_features, 0.1, seed,
390 )
391 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
392 },
393 )
394}
395
/// Configuration for multi-node (distributed) dataset generation.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of samples to generate across all nodes.
    pub total_samples: usize,
    /// Number of nodes participating in the run.
    pub n_nodes: usize,
    /// Identity of the local node; must be in `0..n_nodes`.
    pub node_id: usize,
    /// Base RNG seed; each node derives a distinct seed from it.
    pub random_state: Option<u64>,
    /// Time budget for a generation round.
    /// NOTE(review): not enforced anywhere in this module yet.
    pub timeout: Duration,
    /// How the total sample count is split across nodes.
    pub load_balancing: LoadBalancingStrategy,
}

/// Strategy used to apportion samples across nodes.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Every node gets the same share (±1 sample for the remainder).
    EqualSplit,
    /// Shares proportional to the supplied per-node weights.
    Weighted(Vec<f64>),
    /// Intended for runtime-adaptive balancing; currently identical to
    /// `EqualSplit`.
    Dynamic,
}

impl Default for DistributedConfig {
    /// Single-node defaults: 100k samples, unseeded, 5-minute timeout,
    /// equal-split balancing.
    fn default() -> Self {
        Self {
            total_samples: 100000,
            n_nodes: 1,
            node_id: 0,
            random_state: None,
            timeout: Duration::from_secs(300),
            load_balancing: LoadBalancingStrategy::EqualSplit,
        }
    }
}

/// Per-node bookkeeping tracked by [`DistributedGenerator`].
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Node identifier.
    pub node_id: usize,
    /// Samples this node is responsible for generating.
    pub samples_assigned: usize,
    /// Samples the node actually reported as generated.
    pub samples_generated: usize,
    /// Current lifecycle state of the node.
    pub status: NodeStatus,
    /// When the node started generating, if it has.
    pub start_time: Option<Instant>,
    /// When the node finished generating, if it has.
    pub completion_time: Option<Instant>,
}

/// Lifecycle state of a node during a distributed run.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeStatus {
    /// Not yet started.
    Idle,
    /// Currently generating samples.
    Working,
    /// Finished successfully.
    Completed,
    /// Aborted with an error.
    Failed,
}

/// Outcome of a distributed generation run, including per-node details.
#[derive(Debug)]
pub struct DistributedGenerationResult<T> {
    /// The data generated by the local node.
    pub data: T,
    /// Per-node results keyed by node id.
    pub node_results: HashMap<usize, NodeResult<T>>,
    /// Total wall-clock time for the run.
    pub total_generation_time: Duration,
    /// Time spent outside raw data generation (setup/coordination).
    pub coordination_overhead: Duration,
    /// Number of nodes that contributed results.
    pub n_nodes_used: usize,
    /// Average-to-maximum generation-time ratio (1.0 = perfectly balanced).
    pub load_balance_efficiency: f64,
}

/// Result produced by a single node.
#[derive(Debug)]
pub struct NodeResult<T> {
    /// Node identifier.
    pub node_id: usize,
    /// Data generated by the node.
    pub data: T,
    /// Time the node spent generating.
    pub generation_time: Duration,
    /// Number of samples the node generated.
    pub samples_generated: usize,
}

/// Coordinates sample distribution and status tracking for one node of a
/// distributed generation run.
#[derive(Debug)]
pub struct DistributedGenerator {
    config: DistributedConfig,
    nodes: HashMap<usize, NodeInfo>,
}

impl DistributedGenerator {
    /// Create a coordinator for the given configuration.
    ///
    /// # Errors
    ///
    /// Fails when `config.node_id` is not a valid node index (this also
    /// rejects `n_nodes == 0`, preventing a later divide-by-zero).
    pub fn new(
        config: DistributedConfig,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        if config.node_id >= config.n_nodes {
            return Err("Node ID must be less than total number of nodes".into());
        }

        let nodes = (0..config.n_nodes)
            .map(|i| {
                (
                    i,
                    NodeInfo {
                        node_id: i,
                        samples_assigned: 0,
                        samples_generated: 0,
                        status: NodeStatus::Idle,
                        start_time: None,
                        completion_time: None,
                    },
                )
            })
            .collect();

        Ok(Self { config, nodes })
    }

    /// Evenly split `total_samples` over all nodes; the first
    /// `total_samples % n_nodes` nodes receive one extra sample.
    fn assign_equal_split(&mut self) {
        let base_samples = self.config.total_samples / self.config.n_nodes;
        let remainder = self.config.total_samples % self.config.n_nodes;

        for i in 0..self.config.n_nodes {
            if let Some(node) = self.nodes.get_mut(&i) {
                node.samples_assigned = base_samples + usize::from(i < remainder);
            }
        }
    }

    /// Compute each node's `samples_assigned` according to the configured
    /// [`LoadBalancingStrategy`].
    ///
    /// # Errors
    ///
    /// For `Weighted`: the weight count must equal `n_nodes`, every weight
    /// must be finite and non-negative, and the weights must sum to a
    /// positive value. (A negative weight previously inflated earlier shares
    /// and made the last node's `total - assigned` underflow and panic.)
    pub fn calculate_sample_distribution(
        &mut self,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Clone the strategy so `self` can be mutated freely inside the match.
        match self.config.load_balancing.clone() {
            // Dynamic currently falls back to the equal split.
            LoadBalancingStrategy::EqualSplit | LoadBalancingStrategy::Dynamic => {
                self.assign_equal_split();
            }
            LoadBalancingStrategy::Weighted(weights) => {
                if weights.len() != self.config.n_nodes {
                    return Err("Number of weights must match number of nodes".into());
                }
                if weights.iter().any(|w| !w.is_finite() || *w < 0.0) {
                    return Err("Weights must be finite and non-negative".into());
                }

                let total_weight: f64 = weights.iter().sum();
                if total_weight <= 0.0 {
                    return Err("Total weight must be positive".into());
                }

                // All nodes but the last take the floor of their proportional
                // share; the last absorbs the rounding remainder so the total
                // is exact.
                let mut assigned_samples = 0;
                for i in 0..self.config.n_nodes {
                    if let Some(node) = self.nodes.get_mut(&i) {
                        if i + 1 < weights.len() {
                            node.samples_assigned = ((weights[i] / total_weight)
                                * self.config.total_samples as f64)
                                as usize;
                            assigned_samples += node.samples_assigned;
                        } else {
                            node.samples_assigned = self.config.total_samples - assigned_samples;
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Samples assigned to the node this process represents.
    pub fn get_current_node_samples(&self) -> usize {
        self.nodes
            .get(&self.config.node_id)
            .map(|node| node.samples_assigned)
            .unwrap_or(0)
    }

    /// Mark the local node as working and record when it started.
    pub fn start_generation(&mut self) {
        if let Some(node) = self.nodes.get_mut(&self.config.node_id) {
            node.status = NodeStatus::Working;
            node.start_time = Some(Instant::now());
        }
    }

    /// Mark the local node as completed with its final sample count.
    pub fn complete_generation(&mut self, samples_generated: usize) {
        if let Some(node) = self.nodes.get_mut(&self.config.node_id) {
            node.status = NodeStatus::Completed;
            node.samples_generated = samples_generated;
            node.completion_time = Some(Instant::now());
        }
    }

    /// Ratio of the average to the maximum per-node generation time among
    /// completed nodes: 1.0 means perfectly balanced, 0.0 means no node has
    /// complete timing data.
    pub fn calculate_load_balance_efficiency(&self) -> f64 {
        let generation_times: Vec<Duration> = self
            .nodes
            .values()
            .filter(|node| node.status == NodeStatus::Completed)
            .filter_map(|node| match (node.start_time, node.completion_time) {
                (Some(start), Some(end)) => Some(end - start),
                _ => None,
            })
            .collect();

        if generation_times.is_empty() {
            return 0.0;
        }

        let total_time: Duration = generation_times.iter().sum();
        let avg_time = total_time / generation_times.len() as u32;
        let max_time = generation_times.iter().max().unwrap();

        // Instantaneous completion on every node counts as perfectly balanced.
        if max_time.as_nanos() == 0 {
            return 1.0;
        }

        (avg_time.as_nanos() as f64) / (max_time.as_nanos() as f64)
    }
}
635
636pub fn distributed_classification(
638 n_features: usize,
639 n_classes: usize,
640 config: DistributedConfig,
641) -> Result<
642 DistributedGenerationResult<(Array2<f64>, Array1<i32>)>,
643 Box<dyn std::error::Error + Send + Sync>,
644> {
645 let start_time = Instant::now();
646
647 let mut generator = DistributedGenerator::new(config.clone())?;
648 generator.calculate_sample_distribution()?;
649
650 let samples_for_this_node = generator.get_current_node_samples();
651
652 let node_seed = config
654 .random_state
655 .map(|seed| seed + config.node_id as u64 * 12345);
656
657 generator.start_generation();
658
659 let generation_start = Instant::now();
661 let (x, y) = make_classification(
662 samples_for_this_node,
663 n_features,
664 n_features, 0, n_classes,
667 node_seed,
668 )?;
669 let generation_time = generation_start.elapsed();
670
671 generator.complete_generation(samples_for_this_node);
672
673 let node_result = NodeResult {
675 node_id: config.node_id,
676 data: (x.clone(), y.clone()),
677 generation_time,
678 samples_generated: samples_for_this_node,
679 };
680
681 let mut node_results = HashMap::new();
682 node_results.insert(config.node_id, node_result);
683
684 let total_generation_time = start_time.elapsed();
685 let coordination_overhead = total_generation_time - generation_time;
686 let load_balance_efficiency = generator.calculate_load_balance_efficiency();
687
688 Ok(DistributedGenerationResult {
689 data: (x, y),
690 node_results,
691 total_generation_time,
692 coordination_overhead,
693 n_nodes_used: 1, load_balance_efficiency,
695 })
696}
697
698pub fn distributed_regression(
700 n_features: usize,
701 config: DistributedConfig,
702) -> Result<
703 DistributedGenerationResult<(Array2<f64>, Array1<f64>)>,
704 Box<dyn std::error::Error + Send + Sync>,
705> {
706 let start_time = Instant::now();
707
708 let mut generator = DistributedGenerator::new(config.clone())?;
709 generator.calculate_sample_distribution()?;
710
711 let samples_for_this_node = generator.get_current_node_samples();
712
713 let node_seed = config
715 .random_state
716 .map(|seed| seed + config.node_id as u64 * 12345);
717
718 generator.start_generation();
719
720 let generation_start = Instant::now();
722 let (x, y) = make_regression(
723 samples_for_this_node,
724 n_features,
725 n_features, 0.1, node_seed,
728 )?;
729 let generation_time = generation_start.elapsed();
730
731 generator.complete_generation(samples_for_this_node);
732
733 let node_result = NodeResult {
735 node_id: config.node_id,
736 data: (x.clone(), y.clone()),
737 generation_time,
738 samples_generated: samples_for_this_node,
739 };
740
741 let mut node_results = HashMap::new();
742 node_results.insert(config.node_id, node_result);
743
744 let total_generation_time = start_time.elapsed();
745 let coordination_overhead = total_generation_time - generation_time;
746 let load_balance_efficiency = generator.calculate_load_balance_efficiency();
747
748 Ok(DistributedGenerationResult {
749 data: (x, y),
750 node_results,
751 total_generation_time,
752 coordination_overhead,
753 n_nodes_used: 1, load_balance_efficiency,
755 })
756}
757
758pub fn distributed_blobs(
760 n_features: usize,
761 centers: usize,
762 config: DistributedConfig,
763) -> Result<
764 DistributedGenerationResult<(Array2<f64>, Array1<i32>)>,
765 Box<dyn std::error::Error + Send + Sync>,
766> {
767 let start_time = Instant::now();
768
769 let mut generator = DistributedGenerator::new(config.clone())?;
770 generator.calculate_sample_distribution()?;
771
772 let samples_for_this_node = generator.get_current_node_samples();
773
774 let node_seed = config
776 .random_state
777 .map(|seed| seed + config.node_id as u64 * 12345);
778
779 generator.start_generation();
780
781 let generation_start = Instant::now();
783 let (x, y) = make_blobs(
784 samples_for_this_node,
785 n_features,
786 centers,
787 1.0, node_seed,
789 )?;
790 let generation_time = generation_start.elapsed();
791
792 generator.complete_generation(samples_for_this_node);
793
794 let node_result = NodeResult {
796 node_id: config.node_id,
797 data: (x.clone(), y.clone()),
798 generation_time,
799 samples_generated: samples_for_this_node,
800 };
801
802 let mut node_results = HashMap::new();
803 node_results.insert(config.node_id, node_result);
804
805 let total_generation_time = start_time.elapsed();
806 let coordination_overhead = total_generation_time - generation_time;
807 let load_balance_efficiency = generator.calculate_load_balance_efficiency();
808
809 Ok(DistributedGenerationResult {
810 data: (x, y),
811 node_results,
812 total_generation_time,
813 coordination_overhead,
814 n_nodes_used: 1, load_balance_efficiency,
816 })
817}
818
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Streaming 300 samples in 100-sample chunks yields three full chunks.
    #[test]
    fn test_stream_classification() {
        let config = StreamConfig {
            chunk_size: 100,
            total_samples: 300,
            random_state: Some(42),
            n_workers: 2,
        };

        let stream = stream_classification(4, 3, config);
        let mut total_samples = 0;

        for (i, (x, y)) in stream.enumerate() {
            assert_eq!(x.ncols(), 4); // requested feature count
            assert!(y.iter().all(|&label| label < 3)); // labels within n_classes
            // 300 divides evenly by 100, so every chunk (including the last)
            // is full-size; both branches intentionally assert the same thing.
            if i < 2 {
                assert_eq!(x.nrows(), 100);
                assert_eq!(y.len(), 100);
            } else {
                assert_eq!(x.nrows(), 100);
                assert_eq!(y.len(), 100);
            }

            total_samples += x.nrows();
        }

        assert_eq!(total_samples, 300);
    }

    /// Four workers each produce one chunk, together covering 1000 samples.
    #[test]
    fn test_parallel_classification() {
        let result = parallel_classification(1000, 5, 3, 4).unwrap();

        assert_eq!(result.n_workers_used, 4);
        assert_eq!(result.chunks.len(), 4);

        let total_samples: usize = result.chunks.iter().map(|(x, _)| x.nrows()).sum();
        assert_eq!(total_samples, 1000);

        for (x, y) in &result.chunks {
            assert_eq!(x.ncols(), 5);
            assert!(y.iter().all(|&label| label < 3));
        }
    }

    /// 500 samples in chunks of 150 => chunks of 150/150/150/50 with
    /// monotonically increasing progress.
    #[test]
    fn test_lazy_generator() {
        let mut generator = lazy_classification(500, 3, 2, 150, Some(42));

        let mut total_samples = 0;
        let mut chunk_count = 0;

        while !generator.is_complete() {
            if let Some(result) = generator.next_chunk() {
                let (x, y) = result.unwrap();
                assert_eq!(x.ncols(), 3);
                assert!(y.iter().all(|&label| label < 2));

                total_samples += x.nrows();
                chunk_count += 1;

                // Progress must exactly track the samples seen so far.
                let (generated, total, progress) = generator.progress();
                assert_eq!(generated, total_samples);
                assert_eq!(total, 500);
                assert!((0.0..=1.0).contains(&progress));
            } else {
                break;
            }
        }

        assert_eq!(total_samples, 500);
        assert_eq!(chunk_count, 4); // 150 + 150 + 150 + 50
        assert!(generator.is_complete());
    }

    /// Sanity-check the documented default configuration values.
    #[test]
    fn test_stream_config_default() {
        let config = StreamConfig::default();
        assert_eq!(config.chunk_size, 1000);
        assert_eq!(config.total_samples, 10000);
        assert!(config.random_state.is_none());
        assert!(config.n_workers > 0);
    }

    /// The internally measured generation time should stay within a loose
    /// factor of the externally measured wall-clock time.
    #[test]
    fn test_parallel_generation_timing() {
        let start = std::time::Instant::now();
        let result = parallel_regression(2000, 10, 2).unwrap();
        let sequential_time = start.elapsed();

        // Loose sanity bound, not a benchmark assertion.
        assert!(result.generation_time <= sequential_time * 2);
        assert_eq!(result.n_workers_used, 2);

        let total_samples: usize = result.chunks.iter().map(|(x, _)| x.nrows()).sum();
        assert_eq!(total_samples, 2000);
    }

    /// Sanity-check the distributed defaults (single node, equal split).
    #[test]
    fn test_distributed_config_default() {
        let config = DistributedConfig::default();
        assert_eq!(config.total_samples, 100000);
        assert_eq!(config.n_nodes, 1);
        assert_eq!(config.node_id, 0);
        assert!(config.random_state.is_none());
        assert_eq!(config.timeout, Duration::from_secs(300));
        assert!(matches!(
            config.load_balancing,
            LoadBalancingStrategy::EqualSplit
        ));
    }

    /// Equal split of 1000 over 3 nodes: the remainder goes to node 0.
    #[test]
    fn test_distributed_generator_sample_distribution() {
        let mut config = DistributedConfig::default();
        config.total_samples = 1000;
        config.n_nodes = 3;
        config.node_id = 0;

        let mut generator = DistributedGenerator::new(config).unwrap();
        generator.calculate_sample_distribution().unwrap();

        assert_eq!(generator.nodes[&0].samples_assigned, 334);
        assert_eq!(generator.nodes[&1].samples_assigned, 333);
        assert_eq!(generator.nodes[&2].samples_assigned, 333);

        // No samples may be lost or duplicated by the split.
        let total_assigned: usize = generator
            .nodes
            .values()
            .map(|node| node.samples_assigned)
            .sum();
        assert_eq!(total_assigned, 1000);
    }

    /// Weighted split 0.5/0.3/0.2 of 1000 yields exact proportional shares.
    #[test]
    fn test_distributed_generator_weighted_distribution() {
        let mut config = DistributedConfig::default();
        config.total_samples = 1000;
        config.n_nodes = 3;
        config.node_id = 0;
        config.load_balancing = LoadBalancingStrategy::Weighted(vec![0.5, 0.3, 0.2]);

        let mut generator = DistributedGenerator::new(config).unwrap();
        generator.calculate_sample_distribution().unwrap();

        assert_eq!(generator.nodes[&0].samples_assigned, 500);
        assert_eq!(generator.nodes[&1].samples_assigned, 300);
        assert_eq!(generator.nodes[&2].samples_assigned, 200);

        let total_assigned: usize = generator
            .nodes
            .values()
            .map(|node| node.samples_assigned)
            .sum();
        assert_eq!(total_assigned, 1000);
    }

    /// Node 1 of 4 generates exactly its quarter (250) of 1000 samples.
    #[test]
    fn test_distributed_classification() {
        let config = DistributedConfig {
            total_samples: 1000,
            n_nodes: 4,
            node_id: 1,
            random_state: Some(42),
            ..Default::default()
        };

        let result = distributed_classification(5, 3, config).unwrap();

        assert_eq!(result.data.0.nrows(), 250);
        assert_eq!(result.data.1.len(), 250);
        assert_eq!(result.data.0.ncols(), 5);
        assert!(result.data.1.iter().all(|&label| label < 3));

        assert_eq!(result.n_nodes_used, 1);
        assert!(result.node_results.contains_key(&1));
        assert_eq!(result.node_results[&1].samples_generated, 250);
        assert!(result.total_generation_time > Duration::from_nanos(0));
    }

    /// Node 0 of 2 generates exactly its half (400) of 800 samples.
    #[test]
    fn test_distributed_regression() {
        let config = DistributedConfig {
            total_samples: 800,
            n_nodes: 2,
            node_id: 0,
            random_state: Some(123),
            ..Default::default()
        };

        let result = distributed_regression(7, config).unwrap();

        assert_eq!(result.data.0.nrows(), 400);
        assert_eq!(result.data.1.len(), 400);
        assert_eq!(result.data.0.ncols(), 7);

        assert_eq!(result.n_nodes_used, 1);
        assert!(result.node_results.contains_key(&0));
        assert_eq!(result.node_results[&0].samples_generated, 400);
        assert!(result.load_balance_efficiency >= 0.0);
        assert!(result.load_balance_efficiency <= 1.0);
    }

    /// Node 2 of 3 generates exactly its third (200) of 600 samples.
    #[test]
    fn test_distributed_blobs() {
        let config = DistributedConfig {
            total_samples: 600,
            n_nodes: 3,
            node_id: 2,
            random_state: Some(456),
            ..Default::default()
        };

        let result = distributed_blobs(4, 5, config).unwrap();

        assert_eq!(result.data.0.nrows(), 200);
        assert_eq!(result.data.1.len(), 200);
        assert_eq!(result.data.0.ncols(), 4);
        assert!(result.data.1.iter().all(|&label| label < 5));

        assert_eq!(result.n_nodes_used, 1);
        assert!(result.node_results.contains_key(&2));
        assert_eq!(result.node_results[&2].samples_generated, 200);
        assert!(result.coordination_overhead >= Duration::from_nanos(0));
    }

    /// A node id equal to n_nodes is out of range and must be rejected.
    #[test]
    fn test_distributed_generator_invalid_node_id() {
        let config = DistributedConfig {
            total_samples: 1000,
            n_nodes: 3,
            node_id: 3,
            ..Default::default()
        };

        let result = DistributedGenerator::new(config);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Node ID must be less than total number of nodes"));
    }

    /// A weight vector whose length differs from n_nodes must be rejected.
    #[test]
    fn test_distributed_generator_weighted_validation() {
        let mut config = DistributedConfig::default();
        config.n_nodes = 3;
        config.load_balancing = LoadBalancingStrategy::Weighted(vec![0.5, 0.3]);
        let mut generator = DistributedGenerator::new(config).unwrap();
        let result = generator.calculate_sample_distribution();
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Number of weights must match number of nodes"));
    }

    /// With two completed nodes of unequal duration, the efficiency ratio
    /// must land in [0, 1].
    #[test]
    fn test_load_balance_efficiency_calculation() {
        let mut config = DistributedConfig::default();
        config.n_nodes = 2;
        config.node_id = 0;

        let mut generator = DistributedGenerator::new(config).unwrap();

        // Fabricate timing data: node 0 ran ~100ms, node 1 ran ~200ms.
        generator.nodes.get_mut(&0).unwrap().status = NodeStatus::Completed;
        generator.nodes.get_mut(&0).unwrap().start_time =
            Some(Instant::now() - Duration::from_millis(100));
        generator.nodes.get_mut(&0).unwrap().completion_time = Some(Instant::now());

        generator.nodes.get_mut(&1).unwrap().status = NodeStatus::Completed;
        generator.nodes.get_mut(&1).unwrap().start_time =
            Some(Instant::now() - Duration::from_millis(200));
        generator.nodes.get_mut(&1).unwrap().completion_time = Some(Instant::now());

        let efficiency = generator.calculate_load_balance_efficiency();
        assert!(efficiency >= 0.0);
        assert!(efficiency <= 1.0);
    }
}
1112}