// sklears_datasets/generators/performance.rs

//! Performance-optimized dataset generation
//!
//! This module provides high-performance dataset generation capabilities including:
//! - Streaming dataset generation with lazy evaluation
//! - Parallel data generation using standard-library threads (`std::thread`)
//! - Memory-efficient generation for large datasets
//! - Chunked processing for distributed systems
//! - Distributed dataset generation, where each node generates its own share of the samples
9
10use crate::generators::basic::{make_blobs, make_classification, make_regression};
11use scirs2_core::ndarray::{Array1, Array2};
12use std::collections::HashMap;
13use std::sync::mpsc;
14use std::thread;
15use std::time::{Duration, Instant};
16
17/// Configuration for streaming dataset generation
18#[derive(Debug, Clone)]
19pub struct StreamConfig {
20    /// Chunk size for streaming generation
21    pub chunk_size: usize,
22    /// Total number of samples to generate
23    pub total_samples: usize,
24    /// Random seed for reproducibility
25    pub random_state: Option<u64>,
26    /// Number of parallel workers
27    pub n_workers: usize,
28}
29
30impl Default for StreamConfig {
31    fn default() -> Self {
32        Self {
33            chunk_size: 1000,
34            total_samples: 10000,
35            random_state: None,
36            n_workers: num_cpus::get(),
37        }
38    }
39}
40
41/// Iterator for streaming dataset generation
42pub struct DatasetStream<T> {
43    config: StreamConfig,
44    current_chunk: usize,
45    total_chunks: usize,
46    generator_fn: Box<dyn Fn(usize, usize, Option<u64>) -> T + Send + Sync>,
47}
48
49impl<T> DatasetStream<T> {
50    fn new<F>(config: StreamConfig, generator_fn: F) -> Self
51    where
52        F: Fn(usize, usize, Option<u64>) -> T + Send + Sync + 'static,
53    {
54        let total_chunks = (config.total_samples + config.chunk_size - 1) / config.chunk_size;
55
56        Self {
57            config,
58            current_chunk: 0,
59            total_chunks,
60            generator_fn: Box::new(generator_fn),
61        }
62    }
63}
64
65impl<T> Iterator for DatasetStream<T> {
66    type Item = T;
67
68    fn next(&mut self) -> Option<Self::Item> {
69        if self.current_chunk >= self.total_chunks {
70            return None;
71        }
72
73        let chunk_start = self.current_chunk * self.config.chunk_size;
74        let chunk_end = std::cmp::min(
75            chunk_start + self.config.chunk_size,
76            self.config.total_samples,
77        );
78        let chunk_size = chunk_end - chunk_start;
79
80        // Generate unique seed for this chunk if random_state is provided
81        let chunk_seed = self
82            .config
83            .random_state
84            .map(|seed| seed + self.current_chunk as u64);
85
86        let result = (self.generator_fn)(chunk_size, self.current_chunk, chunk_seed);
87        self.current_chunk += 1;
88
89        Some(result)
90    }
91}
92
93/// Streaming classification dataset generator
94pub fn stream_classification(
95    n_features: usize,
96    n_classes: usize,
97    config: StreamConfig,
98) -> DatasetStream<(Array2<f64>, Array1<i32>)> {
99    DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
100        make_classification(
101            chunk_size, n_features, n_features, // n_informative
102            0,          // n_redundant
103            n_classes, seed,
104        )
105        .unwrap()
106    })
107}
108
109/// Streaming regression dataset generator
110pub fn stream_regression(
111    n_features: usize,
112    config: StreamConfig,
113) -> DatasetStream<(Array2<f64>, Array1<f64>)> {
114    DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
115        make_regression(
116            chunk_size, n_features, n_features, // n_informative
117            0.1,        // noise
118            seed,
119        )
120        .unwrap()
121    })
122}
123
124/// Streaming blob dataset generator
125pub fn stream_blobs(
126    n_features: usize,
127    centers: usize,
128    config: StreamConfig,
129) -> DatasetStream<(Array2<f64>, Array1<i32>)> {
130    DatasetStream::new(config, move |chunk_size, _chunk_idx, seed| {
131        make_blobs(
132            chunk_size, n_features, centers, 1.0, // cluster_std
133            seed,
134        )
135        .unwrap()
136    })
137}
138
/// Outcome of [`parallel_generate`]: the generated chunks plus timing metadata.
#[derive(Debug)]
pub struct ParallelGenerationResult<T> {
    /// Chunks in worker order (worker 0 first). Workers whose range is empty
    /// contribute nothing.
    pub chunks: Vec<T>,
    /// Wall-clock time from the start of the call until all chunks were collected.
    pub generation_time: std::time::Duration,
    /// Number of workers that produced a chunk.
    pub n_workers_used: usize,
}

/// Generate a dataset in parallel by splitting `n_samples` across up to
/// `n_workers` threads.
///
/// Samples are split into contiguous chunks of `ceil(n_samples / n_workers)`.
/// Trailing workers whose range would be empty are not spawned. Worker `i`
/// calls `generator_fn(chunk_len, Some(i * 12345))`, so the output is
/// deterministic for a given `(n_samples, n_workers)` pair.
///
/// # Errors
///
/// Returns an error in three cases:
/// - `n_workers` is zero;
/// - any worker's `generator_fn` returns an error (the first one received is
///   reported as `"Worker {id} failed: {error}"`);
/// - a worker thread panics.
///
/// All spawned threads are joined before this function returns, including on
/// the error paths.
pub fn parallel_generate<T, F>(
    n_samples: usize,
    n_workers: usize,
    generator_fn: F,
) -> Result<ParallelGenerationResult<T>, Box<dyn std::error::Error + Send + Sync>>
where
    T: Send + 'static,
    F: Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
        + Send
        + Sync
        + Copy
        + 'static,
{
    if n_workers == 0 {
        return Err("n_workers must be at least 1".into());
    }

    let start_time = Instant::now();

    // Ceiling division written so it cannot overflow.
    let chunk_size = n_samples / n_workers + usize::from(n_samples % n_workers != 0);
    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::with_capacity(n_workers);

    for worker_id in 0..n_workers {
        // Clamp the start as well as the end. With uneven splits (e.g. 5 samples
        // over 4 workers) the last workers start past `n_samples`, and an
        // unclamped start would underflow `chunk_end - chunk_start`.
        let chunk_start = worker_id.saturating_mul(chunk_size).min(n_samples);
        let chunk_end = chunk_start.saturating_add(chunk_size).min(n_samples);
        let actual_chunk_size = chunk_end - chunk_start;
        if actual_chunk_size == 0 {
            continue;
        }

        let tx = tx.clone();
        handles.push(thread::spawn(move || {
            // Use worker_id as part of the seed for reproducibility
            let seed = Some((worker_id as u64).wrapping_mul(12345));
            let outcome = generator_fn(actual_chunk_size, seed);
            let is_ok = outcome.is_ok();
            if tx.send((worker_id, outcome)).is_err() {
                if is_ok {
                    eprintln!("Failed to send result from worker {}", worker_id);
                } else {
                    eprintln!("Failed to send error from worker {}", worker_id);
                }
            }
        }));
    }

    // Drop the original sender so the receive loop ends once every worker is done.
    drop(tx);

    // Slot results by worker id so chunks come back in order regardless of
    // arrival order.
    let mut slots: Vec<Option<T>> = (0..n_workers).map(|_| None).collect();
    let mut successful_workers = 0;
    let mut first_error: Option<String> = None;

    for (worker_id, outcome) in rx {
        match outcome {
            Ok(data) => {
                slots[worker_id] = Some(data);
                successful_workers += 1;
            }
            Err(e) => {
                // Keep draining instead of returning early so no worker is left
                // detached with a dangling sender.
                if first_error.is_none() {
                    first_error = Some(format!("Worker {} failed: {}", worker_id, e));
                }
            }
        }
    }

    // Join every worker before deciding the outcome.
    let mut any_panicked = false;
    for handle in handles {
        any_panicked |= handle.join().is_err();
    }

    if let Some(message) = first_error {
        return Err(message.into());
    }
    if any_panicked {
        return Err("Thread panicked".into());
    }

    let chunks: Vec<T> = slots.into_iter().flatten().collect();

    Ok(ParallelGenerationResult {
        chunks,
        generation_time: start_time.elapsed(),
        n_workers_used: successful_workers,
    })
}
233
234/// Parallel classification dataset generation
235pub fn parallel_classification(
236    n_samples: usize,
237    n_features: usize,
238    n_classes: usize,
239    n_workers: usize,
240) -> Result<
241    ParallelGenerationResult<(Array2<f64>, Array1<i32>)>,
242    Box<dyn std::error::Error + Send + Sync>,
243> {
244    parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
245        make_classification(
246            chunk_size, n_features, n_features, // n_informative
247            0,          // n_redundant
248            n_classes, seed,
249        )
250        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
251    })
252}
253
254/// Parallel regression dataset generation
255pub fn parallel_regression(
256    n_samples: usize,
257    n_features: usize,
258    n_workers: usize,
259) -> Result<
260    ParallelGenerationResult<(Array2<f64>, Array1<f64>)>,
261    Box<dyn std::error::Error + Send + Sync>,
262> {
263    parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
264        make_regression(
265            chunk_size, n_features, n_features, // n_informative
266            0.1,        // noise
267            seed,
268        )
269        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
270    })
271}
272
273/// Parallel blob dataset generation
274pub fn parallel_blobs(
275    n_samples: usize,
276    n_features: usize,
277    centers: usize,
278    n_workers: usize,
279) -> Result<
280    ParallelGenerationResult<(Array2<f64>, Array1<i32>)>,
281    Box<dyn std::error::Error + Send + Sync>,
282> {
283    parallel_generate(n_samples, n_workers, move |chunk_size, seed| {
284        make_blobs(
285            chunk_size, n_features, centers, 1.0, // cluster_std
286            seed,
287        )
288        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
289    })
290}
291
/// Memory-efficient dataset generator that produces chunks on demand.
///
/// Each chunk is returned as a `Result`, so generation errors surface to the
/// caller instead of aborting the whole run.
pub struct LazyDatasetGenerator<T> {
    /// Maximum number of samples per chunk.
    chunk_size: usize,
    /// Total number of samples to produce.
    total_samples: usize,
    /// Samples accounted for so far (advanced even when a chunk fails).
    generated_samples: usize,
    /// Called as `generator_fn(chunk_len, chunk_seed)`.
    generator_fn:
        Box<dyn Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>>>,
    /// Base seed; each chunk uses `random_state + offset_of_its_first_sample`.
    random_state: Option<u64>,
}

impl<T> LazyDatasetGenerator<T> {
    /// Create a generator for `total_samples` samples produced in chunks of at
    /// most `chunk_size`, seeding each chunk from `random_state` when given.
    pub fn new<F>(
        total_samples: usize,
        chunk_size: usize,
        random_state: Option<u64>,
        generator_fn: F,
    ) -> Self
    where
        F: Fn(usize, Option<u64>) -> Result<T, Box<dyn std::error::Error + Send + Sync>> + 'static,
    {
        Self {
            chunk_size,
            total_samples,
            generated_samples: 0,
            generator_fn: Box::new(generator_fn),
            random_state,
        }
    }

    /// Generate the next chunk of data.
    ///
    /// Returns `None` once all samples have been produced. The sample counter
    /// advances even when the generator returns an error, so a failed chunk is
    /// skipped rather than retried.
    ///
    /// Returns `Some(Err(..))` without advancing when `chunk_size` is zero and
    /// samples remain. Such a generator could never make progress, and
    /// previously it would yield empty chunks forever.
    pub fn next_chunk(&mut self) -> Option<Result<T, Box<dyn std::error::Error + Send + Sync>>> {
        if self.generated_samples >= self.total_samples {
            return None;
        }
        if self.chunk_size == 0 {
            return Some(Err("chunk_size must be greater than zero".into()));
        }

        let remaining_samples = self.total_samples - self.generated_samples;
        let current_chunk_size = self.chunk_size.min(remaining_samples);

        // Seed from the offset of the chunk's first sample for reproducibility.
        // Wrapping matches release-build arithmetic and avoids a debug-build
        // overflow panic for seeds near u64::MAX.
        let seed = self
            .random_state
            .map(|s| s.wrapping_add(self.generated_samples as u64));

        let result = (self.generator_fn)(current_chunk_size, seed);
        self.generated_samples += current_chunk_size;

        Some(result)
    }

    /// Return `(generated_samples, total_samples, fraction_complete)`.
    ///
    /// The fraction lies in `0.0..=1.0`. An empty generator
    /// (`total_samples == 0`) reports `1.0` rather than `NaN`.
    pub fn progress(&self) -> (usize, usize, f64) {
        let progress_ratio = if self.total_samples == 0 {
            1.0
        } else {
            self.generated_samples as f64 / self.total_samples as f64
        };
        (self.generated_samples, self.total_samples, progress_ratio)
    }

    /// Whether every sample has been produced.
    pub fn is_complete(&self) -> bool {
        self.generated_samples >= self.total_samples
    }
}
350
351/// Create a lazy classification dataset generator
352pub fn lazy_classification(
353    total_samples: usize,
354    n_features: usize,
355    n_classes: usize,
356    chunk_size: usize,
357    random_state: Option<u64>,
358) -> LazyDatasetGenerator<(Array2<f64>, Array1<i32>)> {
359    LazyDatasetGenerator::new(
360        total_samples,
361        chunk_size,
362        random_state,
363        move |chunk_size, seed| {
364            make_classification(
365                chunk_size, n_features, n_features, // n_informative
366                0,          // n_redundant
367                n_classes, seed,
368            )
369            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
370        },
371    )
372}
373
374/// Create a lazy regression dataset generator
375pub fn lazy_regression(
376    total_samples: usize,
377    n_features: usize,
378    chunk_size: usize,
379    random_state: Option<u64>,
380) -> LazyDatasetGenerator<(Array2<f64>, Array1<f64>)> {
381    LazyDatasetGenerator::new(
382        total_samples,
383        chunk_size,
384        random_state,
385        move |chunk_size, seed| {
386            make_regression(
387                chunk_size, n_features, n_features, // n_informative
388                0.1,        // noise
389                seed,
390            )
391            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
392        },
393    )
394}
395
/// Settings for splitting dataset generation across several nodes.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
    /// Total number of samples across all nodes.
    pub total_samples: usize,
    /// Number of nodes participating in generation.
    pub n_nodes: usize,
    /// Identifier of the node running this process, expected in `0..n_nodes`.
    pub node_id: usize,
    /// Base random seed; each node derives its own seed from it.
    pub random_state: Option<u64>,
    /// Timeout for node communication.
    /// NOTE(review): not read by the generators visible in this module.
    pub timeout: Duration,
    /// How samples are divided among nodes.
    pub load_balancing: LoadBalancingStrategy,
}

/// Strategies for dividing samples among nodes.
#[derive(Debug, Clone)]
pub enum LoadBalancingStrategy {
    /// Split samples evenly; the first `total_samples % n_nodes` nodes get one extra.
    EqualSplit,
    /// Split in proportion to per-node weights (one weight per node).
    Weighted(Vec<f64>),
    /// Dynamic load balancing based on node performance. It currently starts
    /// from the same split as `EqualSplit`.
    Dynamic,
}

impl Default for DistributedConfig {
    /// A single-node setup (node 0 of 1) that generates 100 000 samples with no
    /// seed, a five-minute timeout and an even split.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        DistributedConfig {
            load_balancing: LoadBalancingStrategy::EqualSplit,
            timeout: five_minutes,
            random_state: None,
            node_id: 0,
            n_nodes: 1,
            total_samples: 100_000,
        }
    }
}
436
/// Bookkeeping for one node taking part in distributed generation.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Node identifier, in `0..n_nodes`.
    pub node_id: usize,
    /// Number of samples this node is responsible for.
    pub samples_assigned: usize,
    /// Number of samples the node reported as generated.
    pub samples_generated: usize,
    /// Current lifecycle state.
    pub status: NodeStatus,
    /// When the node started generating, if it has started.
    pub start_time: Option<Instant>,
    /// When the node finished generating, if it has finished.
    pub completion_time: Option<Instant>,
}

/// Lifecycle state of a node in distributed generation.
///
/// This is a plain field-less enum, so it also derives `Copy`, `Eq` and `Hash`.
/// Callers can then compare, copy and use statuses as map keys without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeStatus {
    /// Registered but has not started generating.
    Idle,
    /// Currently generating its share of samples.
    Working,
    /// Finished generating its share of samples.
    Completed,
    /// Generation failed on this node.
    Failed,
}
460
/// Result of distributed dataset generation
///
/// The `distributed_*` functions in this module generate only the current
/// node's samples, so `data` holds that node's data and `node_results` has a
/// single entry.
#[derive(Debug)]
pub struct DistributedGenerationResult<T> {
    /// Data generated on this node.
    pub data: T,
    /// Per-node results keyed by node id. NOTE: the `distributed_*` functions
    /// store a clone of `data` here, so the generated data is held twice.
    pub node_results: HashMap<usize, NodeResult<T>>,
    /// Wall-clock time of the whole call, including setup.
    pub total_generation_time: Duration,
    /// `total_generation_time` minus the time spent inside the data generator.
    pub coordination_overhead: Duration,
    /// Number of nodes whose results are included.
    pub n_nodes_used: usize,
    /// Ratio of average to maximum generation time over completed nodes
    /// (1.0 means perfectly balanced; 0.0 when no timing is available).
    pub load_balance_efficiency: f64,
}

/// Result from a single node
#[derive(Debug)]
pub struct NodeResult<T> {
    /// Node that produced the data.
    pub node_id: usize,
    /// Data generated by the node.
    pub data: T,
    /// Time spent inside the data generator on this node.
    pub generation_time: Duration,
    /// Number of samples the node generated.
    pub samples_generated: usize,
}
480
481/// Distributed dataset generator coordinator
482#[derive(Debug)]
483pub struct DistributedGenerator {
484    config: DistributedConfig,
485    nodes: HashMap<usize, NodeInfo>,
486}
487
488impl DistributedGenerator {
489    /// Create a new distributed generator
490    pub fn new(
491        config: DistributedConfig,
492    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
493        if config.node_id >= config.n_nodes {
494            return Err("Node ID must be less than total number of nodes".into());
495        }
496
497        let mut nodes = HashMap::new();
498        for i in 0..config.n_nodes {
499            nodes.insert(
500                i,
501                NodeInfo {
502                    node_id: i,
503                    samples_assigned: 0,
504                    samples_generated: 0,
505                    status: NodeStatus::Idle,
506                    start_time: None,
507                    completion_time: None,
508                },
509            );
510        }
511
512        Ok(Self { config, nodes })
513    }
514
515    /// Calculate sample distribution across nodes
516    pub fn calculate_sample_distribution(
517        &mut self,
518    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
519        match &self.config.load_balancing {
520            LoadBalancingStrategy::EqualSplit => {
521                let base_samples = self.config.total_samples / self.config.n_nodes;
522                let remainder = self.config.total_samples % self.config.n_nodes;
523
524                for i in 0..self.config.n_nodes {
525                    if let Some(node) = self.nodes.get_mut(&i) {
526                        node.samples_assigned = base_samples + if i < remainder { 1 } else { 0 };
527                    }
528                }
529            }
530            LoadBalancingStrategy::Weighted(weights) => {
531                if weights.len() != self.config.n_nodes {
532                    return Err("Number of weights must match number of nodes".into());
533                }
534
535                let total_weight: f64 = weights.iter().sum();
536                if total_weight <= 0.0 {
537                    return Err("Total weight must be positive".into());
538                }
539
540                let mut assigned_samples = 0;
541                // Iterate through nodes in order (0 to n_nodes-1) to ensure deterministic assignment
542                for i in 0..self.config.n_nodes {
543                    if let Some(node) = self.nodes.get_mut(&i) {
544                        if i < weights.len() - 1 {
545                            node.samples_assigned = ((weights[i] / total_weight)
546                                * self.config.total_samples as f64)
547                                as usize;
548                            assigned_samples += node.samples_assigned;
549                        } else {
550                            // Assign remaining samples to last node to ensure exact total
551                            node.samples_assigned = self.config.total_samples - assigned_samples;
552                        }
553                    }
554                }
555            }
556            LoadBalancingStrategy::Dynamic => {
557                // Start with equal split, can be adjusted during runtime
558                let base_samples = self.config.total_samples / self.config.n_nodes;
559                let remainder = self.config.total_samples % self.config.n_nodes;
560
561                for i in 0..self.config.n_nodes {
562                    if let Some(node) = self.nodes.get_mut(&i) {
563                        node.samples_assigned = base_samples + if i < remainder { 1 } else { 0 };
564                    }
565                }
566            }
567        }
568
569        Ok(())
570    }
571
572    /// Get the samples assigned to current node
573    pub fn get_current_node_samples(&self) -> usize {
574        self.nodes
575            .get(&self.config.node_id)
576            .map(|node| node.samples_assigned)
577            .unwrap_or(0)
578    }
579
580    /// Mark current node as working
581    pub fn start_generation(&mut self) {
582        if let Some(node) = self.nodes.get_mut(&self.config.node_id) {
583            node.status = NodeStatus::Working;
584            node.start_time = Some(Instant::now());
585        }
586    }
587
588    /// Mark current node as completed
589    pub fn complete_generation(&mut self, samples_generated: usize) {
590        if let Some(node) = self.nodes.get_mut(&self.config.node_id) {
591            node.status = NodeStatus::Completed;
592            node.samples_generated = samples_generated;
593            node.completion_time = Some(Instant::now());
594        }
595    }
596
597    /// Calculate load balance efficiency
598    pub fn calculate_load_balance_efficiency(&self) -> f64 {
599        let completed_nodes: Vec<_> = self
600            .nodes
601            .values()
602            .filter(|node| node.status == NodeStatus::Completed)
603            .collect();
604
605        if completed_nodes.is_empty() {
606            return 0.0;
607        }
608
609        let generation_times: Vec<Duration> = completed_nodes
610            .iter()
611            .filter_map(|node| {
612                if let (Some(start), Some(end)) = (node.start_time, node.completion_time) {
613                    Some(end - start)
614                } else {
615                    None
616                }
617            })
618            .collect();
619
620        if generation_times.is_empty() {
621            return 0.0;
622        }
623
624        let total_time: Duration = generation_times.iter().sum();
625        let avg_time = total_time / generation_times.len() as u32;
626        let max_time = generation_times.iter().max().unwrap();
627
628        if max_time.as_nanos() == 0 {
629            return 1.0;
630        }
631
632        (avg_time.as_nanos() as f64) / (max_time.as_nanos() as f64)
633    }
634}
635
636/// Generate distributed classification dataset
637pub fn distributed_classification(
638    n_features: usize,
639    n_classes: usize,
640    config: DistributedConfig,
641) -> Result<
642    DistributedGenerationResult<(Array2<f64>, Array1<i32>)>,
643    Box<dyn std::error::Error + Send + Sync>,
644> {
645    let start_time = Instant::now();
646
647    let mut generator = DistributedGenerator::new(config.clone())?;
648    generator.calculate_sample_distribution()?;
649
650    let samples_for_this_node = generator.get_current_node_samples();
651
652    // Generate unique seed for this node
653    let node_seed = config
654        .random_state
655        .map(|seed| seed + config.node_id as u64 * 12345);
656
657    generator.start_generation();
658
659    // Generate data for this node
660    let generation_start = Instant::now();
661    let (x, y) = make_classification(
662        samples_for_this_node,
663        n_features,
664        n_features, // n_informative
665        0,          // n_redundant
666        n_classes,
667        node_seed,
668    )?;
669    let generation_time = generation_start.elapsed();
670
671    generator.complete_generation(samples_for_this_node);
672
673    // Create node result
674    let node_result = NodeResult {
675        node_id: config.node_id,
676        data: (x.clone(), y.clone()),
677        generation_time,
678        samples_generated: samples_for_this_node,
679    };
680
681    let mut node_results = HashMap::new();
682    node_results.insert(config.node_id, node_result);
683
684    let total_generation_time = start_time.elapsed();
685    let coordination_overhead = total_generation_time - generation_time;
686    let load_balance_efficiency = generator.calculate_load_balance_efficiency();
687
688    Ok(DistributedGenerationResult {
689        data: (x, y),
690        node_results,
691        total_generation_time,
692        coordination_overhead,
693        n_nodes_used: 1, // Only current node in this implementation
694        load_balance_efficiency,
695    })
696}
697
698/// Generate distributed regression dataset
699pub fn distributed_regression(
700    n_features: usize,
701    config: DistributedConfig,
702) -> Result<
703    DistributedGenerationResult<(Array2<f64>, Array1<f64>)>,
704    Box<dyn std::error::Error + Send + Sync>,
705> {
706    let start_time = Instant::now();
707
708    let mut generator = DistributedGenerator::new(config.clone())?;
709    generator.calculate_sample_distribution()?;
710
711    let samples_for_this_node = generator.get_current_node_samples();
712
713    // Generate unique seed for this node
714    let node_seed = config
715        .random_state
716        .map(|seed| seed + config.node_id as u64 * 12345);
717
718    generator.start_generation();
719
720    // Generate data for this node
721    let generation_start = Instant::now();
722    let (x, y) = make_regression(
723        samples_for_this_node,
724        n_features,
725        n_features, // n_informative
726        0.1,        // noise
727        node_seed,
728    )?;
729    let generation_time = generation_start.elapsed();
730
731    generator.complete_generation(samples_for_this_node);
732
733    // Create node result
734    let node_result = NodeResult {
735        node_id: config.node_id,
736        data: (x.clone(), y.clone()),
737        generation_time,
738        samples_generated: samples_for_this_node,
739    };
740
741    let mut node_results = HashMap::new();
742    node_results.insert(config.node_id, node_result);
743
744    let total_generation_time = start_time.elapsed();
745    let coordination_overhead = total_generation_time - generation_time;
746    let load_balance_efficiency = generator.calculate_load_balance_efficiency();
747
748    Ok(DistributedGenerationResult {
749        data: (x, y),
750        node_results,
751        total_generation_time,
752        coordination_overhead,
753        n_nodes_used: 1, // Only current node in this implementation
754        load_balance_efficiency,
755    })
756}
757
758/// Generate distributed blob dataset
759pub fn distributed_blobs(
760    n_features: usize,
761    centers: usize,
762    config: DistributedConfig,
763) -> Result<
764    DistributedGenerationResult<(Array2<f64>, Array1<i32>)>,
765    Box<dyn std::error::Error + Send + Sync>,
766> {
767    let start_time = Instant::now();
768
769    let mut generator = DistributedGenerator::new(config.clone())?;
770    generator.calculate_sample_distribution()?;
771
772    let samples_for_this_node = generator.get_current_node_samples();
773
774    // Generate unique seed for this node
775    let node_seed = config
776        .random_state
777        .map(|seed| seed + config.node_id as u64 * 12345);
778
779    generator.start_generation();
780
781    // Generate data for this node
782    let generation_start = Instant::now();
783    let (x, y) = make_blobs(
784        samples_for_this_node,
785        n_features,
786        centers,
787        1.0, // cluster_std
788        node_seed,
789    )?;
790    let generation_time = generation_start.elapsed();
791
792    generator.complete_generation(samples_for_this_node);
793
794    // Create node result
795    let node_result = NodeResult {
796        node_id: config.node_id,
797        data: (x.clone(), y.clone()),
798        generation_time,
799        samples_generated: samples_for_this_node,
800    };
801
802    let mut node_results = HashMap::new();
803    node_results.insert(config.node_id, node_result);
804
805    let total_generation_time = start_time.elapsed();
806    let coordination_overhead = total_generation_time - generation_time;
807    let load_balance_efficiency = generator.calculate_load_balance_efficiency();
808
809    Ok(DistributedGenerationResult {
810        data: (x, y),
811        node_results,
812        total_generation_time,
813        coordination_overhead,
814        n_nodes_used: 1, // Only current node in this implementation
815        load_balance_efficiency,
816    })
817}
818
819#[allow(non_snake_case)]
820#[cfg(test)]
821mod tests {
822    use super::*;
823
824    #[test]
825    fn test_stream_classification() {
826        let config = StreamConfig {
827            chunk_size: 100,
828            total_samples: 300,
829            random_state: Some(42),
830            n_workers: 2,
831        };
832
833        let stream = stream_classification(4, 3, config);
834        let mut total_samples = 0;
835
836        for (i, (x, y)) in stream.enumerate() {
837            assert_eq!(x.ncols(), 4); // 4 features
838            assert!(y.iter().all(|&label| label < 3)); // 3 classes (0, 1, 2)
839
840            if i < 2 {
841                assert_eq!(x.nrows(), 100); // First two chunks should be full size
842                assert_eq!(y.len(), 100);
843            } else {
844                assert_eq!(x.nrows(), 100); // Last chunk should be remaining samples
845                assert_eq!(y.len(), 100);
846            }
847
848            total_samples += x.nrows();
849        }
850
851        assert_eq!(total_samples, 300);
852    }
853
854    #[test]
855    fn test_parallel_classification() {
856        let result = parallel_classification(1000, 5, 3, 4).unwrap();
857
858        assert_eq!(result.n_workers_used, 4);
859        assert_eq!(result.chunks.len(), 4);
860
861        let total_samples: usize = result.chunks.iter().map(|(x, _)| x.nrows()).sum();
862        assert_eq!(total_samples, 1000);
863
864        // Verify all chunks have correct number of features
865        for (x, y) in &result.chunks {
866            assert_eq!(x.ncols(), 5);
867            assert!(y.iter().all(|&label| label < 3));
868        }
869    }
870
871    #[test]
872    fn test_lazy_generator() {
873        let mut generator = lazy_classification(500, 3, 2, 150, Some(42));
874
875        let mut total_samples = 0;
876        let mut chunk_count = 0;
877
878        while !generator.is_complete() {
879            if let Some(result) = generator.next_chunk() {
880                let (x, y) = result.unwrap();
881                assert_eq!(x.ncols(), 3);
882                assert!(y.iter().all(|&label| label < 2));
883
884                total_samples += x.nrows();
885                chunk_count += 1;
886
887                let (generated, total, progress) = generator.progress();
888                assert_eq!(generated, total_samples);
889                assert_eq!(total, 500);
890                assert!((0.0..=1.0).contains(&progress));
891            } else {
892                break;
893            }
894        }
895
896        assert_eq!(total_samples, 500);
897        assert_eq!(chunk_count, 4); // 150 + 150 + 150 + 50
898        assert!(generator.is_complete());
899    }
900
901    #[test]
902    fn test_stream_config_default() {
903        let config = StreamConfig::default();
904        assert_eq!(config.chunk_size, 1000);
905        assert_eq!(config.total_samples, 10000);
906        assert!(config.random_state.is_none());
907        assert!(config.n_workers > 0);
908    }
909
910    #[test]
911    fn test_parallel_generation_timing() {
912        let start = std::time::Instant::now();
913        let result = parallel_regression(2000, 10, 2).unwrap();
914        let sequential_time = start.elapsed();
915
916        assert!(result.generation_time <= sequential_time * 2); // Should be reasonably fast
917        assert_eq!(result.n_workers_used, 2);
918
919        let total_samples: usize = result.chunks.iter().map(|(x, _)| x.nrows()).sum();
920        assert_eq!(total_samples, 2000);
921    }
922
923    #[test]
924    fn test_distributed_config_default() {
925        let config = DistributedConfig::default();
926        assert_eq!(config.total_samples, 100000);
927        assert_eq!(config.n_nodes, 1);
928        assert_eq!(config.node_id, 0);
929        assert!(config.random_state.is_none());
930        assert_eq!(config.timeout, Duration::from_secs(300));
931        assert!(matches!(
932            config.load_balancing,
933            LoadBalancingStrategy::EqualSplit
934        ));
935    }
936
937    #[test]
938    fn test_distributed_generator_sample_distribution() {
939        let mut config = DistributedConfig::default();
940        config.total_samples = 1000;
941        config.n_nodes = 3;
942        config.node_id = 0;
943
944        let mut generator = DistributedGenerator::new(config).unwrap();
945        generator.calculate_sample_distribution().unwrap();
946
947        // Check equal split: 1000 / 3 = 333 remainder 1
948        // Node 0: 334, Node 1: 333, Node 2: 333
949        assert_eq!(generator.nodes[&0].samples_assigned, 334);
950        assert_eq!(generator.nodes[&1].samples_assigned, 333);
951        assert_eq!(generator.nodes[&2].samples_assigned, 333);
952
953        let total_assigned: usize = generator
954            .nodes
955            .values()
956            .map(|node| node.samples_assigned)
957            .sum();
958        assert_eq!(total_assigned, 1000);
959    }
960
961    #[test]
962    fn test_distributed_generator_weighted_distribution() {
963        let mut config = DistributedConfig::default();
964        config.total_samples = 1000;
965        config.n_nodes = 3;
966        config.node_id = 0;
967        config.load_balancing = LoadBalancingStrategy::Weighted(vec![0.5, 0.3, 0.2]);
968
969        let mut generator = DistributedGenerator::new(config).unwrap();
970        generator.calculate_sample_distribution().unwrap();
971
972        // Check weighted split: 500, 300, 200
973        assert_eq!(generator.nodes[&0].samples_assigned, 500);
974        assert_eq!(generator.nodes[&1].samples_assigned, 300);
975        assert_eq!(generator.nodes[&2].samples_assigned, 200);
976
977        let total_assigned: usize = generator
978            .nodes
979            .values()
980            .map(|node| node.samples_assigned)
981            .sum();
982        assert_eq!(total_assigned, 1000);
983    }
984
985    #[test]
986    fn test_distributed_classification() {
987        let config = DistributedConfig {
988            total_samples: 1000,
989            n_nodes: 4,
990            node_id: 1,
991            random_state: Some(42),
992            ..Default::default()
993        };
994
995        let result = distributed_classification(5, 3, config).unwrap();
996
997        // Check that node 1 gets approximately 250 samples (1000 / 4)
998        assert_eq!(result.data.0.nrows(), 250);
999        assert_eq!(result.data.1.len(), 250);
1000        assert_eq!(result.data.0.ncols(), 5);
1001        assert!(result.data.1.iter().all(|&label| label < 3));
1002
1003        assert_eq!(result.n_nodes_used, 1);
1004        assert!(result.node_results.contains_key(&1));
1005        assert_eq!(result.node_results[&1].samples_generated, 250);
1006        assert!(result.total_generation_time > Duration::from_nanos(0));
1007    }
1008
1009    #[test]
1010    fn test_distributed_regression() {
1011        let config = DistributedConfig {
1012            total_samples: 800,
1013            n_nodes: 2,
1014            node_id: 0,
1015            random_state: Some(123),
1016            ..Default::default()
1017        };
1018
1019        let result = distributed_regression(7, config).unwrap();
1020
1021        // Check that node 0 gets 400 samples (800 / 2)
1022        assert_eq!(result.data.0.nrows(), 400);
1023        assert_eq!(result.data.1.len(), 400);
1024        assert_eq!(result.data.0.ncols(), 7);
1025
1026        assert_eq!(result.n_nodes_used, 1);
1027        assert!(result.node_results.contains_key(&0));
1028        assert_eq!(result.node_results[&0].samples_generated, 400);
1029        assert!(result.load_balance_efficiency >= 0.0);
1030        assert!(result.load_balance_efficiency <= 1.0);
1031    }
1032
1033    #[test]
1034    fn test_distributed_blobs() {
1035        let config = DistributedConfig {
1036            total_samples: 600,
1037            n_nodes: 3,
1038            node_id: 2,
1039            random_state: Some(456),
1040            ..Default::default()
1041        };
1042
1043        let result = distributed_blobs(4, 5, config).unwrap();
1044
1045        // Check that node 2 gets 200 samples (600 / 3)
1046        assert_eq!(result.data.0.nrows(), 200);
1047        assert_eq!(result.data.1.len(), 200);
1048        assert_eq!(result.data.0.ncols(), 4);
1049        assert!(result.data.1.iter().all(|&label| label < 5));
1050
1051        assert_eq!(result.n_nodes_used, 1);
1052        assert!(result.node_results.contains_key(&2));
1053        assert_eq!(result.node_results[&2].samples_generated, 200);
1054        assert!(result.coordination_overhead >= Duration::from_nanos(0));
1055    }
1056
1057    #[test]
1058    fn test_distributed_generator_invalid_node_id() {
1059        let config = DistributedConfig {
1060            total_samples: 1000,
1061            n_nodes: 3,
1062            node_id: 3, // Invalid: should be 0, 1, or 2
1063            ..Default::default()
1064        };
1065
1066        let result = DistributedGenerator::new(config);
1067        assert!(result.is_err());
1068        assert!(result
1069            .unwrap_err()
1070            .to_string()
1071            .contains("Node ID must be less than total number of nodes"));
1072    }
1073
1074    #[test]
1075    fn test_distributed_generator_weighted_validation() {
1076        let mut config = DistributedConfig::default();
1077        config.n_nodes = 3;
1078        config.load_balancing = LoadBalancingStrategy::Weighted(vec![0.5, 0.3]); // Wrong length
1079
1080        let mut generator = DistributedGenerator::new(config).unwrap();
1081        let result = generator.calculate_sample_distribution();
1082        assert!(result.is_err());
1083        assert!(result
1084            .unwrap_err()
1085            .to_string()
1086            .contains("Number of weights must match number of nodes"));
1087    }
1088
1089    #[test]
1090    fn test_load_balance_efficiency_calculation() {
1091        let mut config = DistributedConfig::default();
1092        config.n_nodes = 2;
1093        config.node_id = 0;
1094
1095        let mut generator = DistributedGenerator::new(config).unwrap();
1096
1097        // Simulate nodes completing at different times
1098        generator.nodes.get_mut(&0).unwrap().status = NodeStatus::Completed;
1099        generator.nodes.get_mut(&0).unwrap().start_time =
1100            Some(Instant::now() - Duration::from_millis(100));
1101        generator.nodes.get_mut(&0).unwrap().completion_time = Some(Instant::now());
1102
1103        generator.nodes.get_mut(&1).unwrap().status = NodeStatus::Completed;
1104        generator.nodes.get_mut(&1).unwrap().start_time =
1105            Some(Instant::now() - Duration::from_millis(200));
1106        generator.nodes.get_mut(&1).unwrap().completion_time = Some(Instant::now());
1107
1108        let efficiency = generator.calculate_load_balance_efficiency();
1109        assert!(efficiency >= 0.0);
1110        assert!(efficiency <= 1.0);
1111    }
1112}