scirs2_datasets/
generators.rs

1//! Dataset generators
2
3use crate::error::{DatasetsError, Result};
4use crate::utils::Dataset;
5use ndarray::{Array1, Array2};
6use rand::prelude::*;
7use rand::rng;
8use rand::rngs::StdRng;
9use rand_distr::Distribution;
10use std::f64::consts::PI;
11
12/// Generate a random classification dataset with clusters
13pub fn make_classification(
14    n_samples: usize,
15    n_features: usize,
16    n_classes: usize,
17    n_clusters_per_class: usize,
18    n_informative: usize,
19    random_seed: Option<u64>,
20) -> Result<Dataset> {
21    // Validate input parameters
22    if n_samples == 0 {
23        return Err(DatasetsError::InvalidFormat(
24            "n_samples must be > 0".to_string(),
25        ));
26    }
27
28    if n_features == 0 {
29        return Err(DatasetsError::InvalidFormat(
30            "n_features must be > 0".to_string(),
31        ));
32    }
33
34    if n_informative == 0 {
35        return Err(DatasetsError::InvalidFormat(
36            "n_informative must be > 0".to_string(),
37        ));
38    }
39
40    if n_features < n_informative {
41        return Err(DatasetsError::InvalidFormat(format!(
42            "n_features ({}) must be >= n_informative ({})",
43            n_features, n_informative
44        )));
45    }
46
47    if n_classes < 2 {
48        return Err(DatasetsError::InvalidFormat(
49            "n_classes must be >= 2".to_string(),
50        ));
51    }
52
53    if n_clusters_per_class == 0 {
54        return Err(DatasetsError::InvalidFormat(
55            "n_clusters_per_class must be > 0".to_string(),
56        ));
57    }
58
59    let mut rng = match random_seed {
60        Some(seed) => StdRng::seed_from_u64(seed),
61        None => {
62            let mut r = rng();
63            StdRng::seed_from_u64(r.next_u64())
64        }
65    };
66
67    // Generate centroids for each class and cluster
68    let n_centroids = n_classes * n_clusters_per_class;
69    let mut centroids = Array2::zeros((n_centroids, n_informative));
70    let scale = 2.0;
71
72    for i in 0..n_centroids {
73        for j in 0..n_informative {
74            centroids[[i, j]] = scale * rng.random_range(-1.0f64..1.0f64);
75        }
76    }
77
78    // Generate samples
79    let mut data = Array2::zeros((n_samples, n_features));
80    let mut target = Array1::zeros(n_samples);
81
82    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
83
84    // Samples per class
85    let samples_per_class = n_samples / n_classes;
86    let remainder = n_samples % n_classes;
87
88    let mut sample_idx = 0;
89
90    for class in 0..n_classes {
91        let n_samples_class = if class < remainder {
92            samples_per_class + 1
93        } else {
94            samples_per_class
95        };
96
97        // Assign clusters within this class
98        let samples_per_cluster = n_samples_class / n_clusters_per_class;
99        let cluster_remainder = n_samples_class % n_clusters_per_class;
100
101        for cluster in 0..n_clusters_per_class {
102            let n_samples_cluster = if cluster < cluster_remainder {
103                samples_per_cluster + 1
104            } else {
105                samples_per_cluster
106            };
107
108            let centroid_idx = class * n_clusters_per_class + cluster;
109
110            for _ in 0..n_samples_cluster {
111                // Randomly select a point near the cluster centroid
112                for j in 0..n_informative {
113                    data[[sample_idx, j]] =
114                        centroids[[centroid_idx, j]] + 0.3 * normal.sample(&mut rng);
115                }
116
117                // Add noise features
118                for j in n_informative..n_features {
119                    data[[sample_idx, j]] = normal.sample(&mut rng);
120                }
121
122                target[sample_idx] = class as f64;
123                sample_idx += 1;
124            }
125        }
126    }
127
128    // Create dataset
129    let mut dataset = Dataset::new(data, Some(target));
130
131    // Create feature names
132    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
133
134    // Create class names
135    let class_names: Vec<String> = (0..n_classes).map(|i| format!("class_{}", i)).collect();
136
137    dataset = dataset
138        .with_feature_names(feature_names)
139        .with_target_names(class_names)
140        .with_description(format!(
141            "Synthetic classification dataset with {} classes and {} features",
142            n_classes, n_features
143        ));
144
145    Ok(dataset)
146}
147
148/// Generate a random regression dataset
149pub fn make_regression(
150    n_samples: usize,
151    n_features: usize,
152    n_informative: usize,
153    noise: f64,
154    random_seed: Option<u64>,
155) -> Result<Dataset> {
156    // Validate input parameters
157    if n_samples == 0 {
158        return Err(DatasetsError::InvalidFormat(
159            "n_samples must be > 0".to_string(),
160        ));
161    }
162
163    if n_features == 0 {
164        return Err(DatasetsError::InvalidFormat(
165            "n_features must be > 0".to_string(),
166        ));
167    }
168
169    if n_informative == 0 {
170        return Err(DatasetsError::InvalidFormat(
171            "n_informative must be > 0".to_string(),
172        ));
173    }
174
175    if n_features < n_informative {
176        return Err(DatasetsError::InvalidFormat(format!(
177            "n_features ({}) must be >= n_informative ({})",
178            n_features, n_informative
179        )));
180    }
181
182    if noise < 0.0 {
183        return Err(DatasetsError::InvalidFormat(
184            "noise must be >= 0.0".to_string(),
185        ));
186    }
187
188    let mut rng = match random_seed {
189        Some(seed) => StdRng::seed_from_u64(seed),
190        None => {
191            let mut r = rng();
192            StdRng::seed_from_u64(r.next_u64())
193        }
194    };
195
196    // Generate the coefficients for the informative features
197    let mut coef = Array1::zeros(n_features);
198    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
199
200    for i in 0..n_informative {
201        coef[i] = 100.0 * normal.sample(&mut rng);
202    }
203
204    // Generate the features
205    let mut data = Array2::zeros((n_samples, n_features));
206
207    for i in 0..n_samples {
208        for j in 0..n_features {
209            data[[i, j]] = normal.sample(&mut rng);
210        }
211    }
212
213    // Generate the target
214    let mut target = Array1::zeros(n_samples);
215
216    for i in 0..n_samples {
217        let mut y = 0.0;
218        for j in 0..n_features {
219            y += data[[i, j]] * coef[j];
220        }
221
222        // Add noise
223        if noise > 0.0 {
224            y += normal.sample(&mut rng) * noise;
225        }
226
227        target[i] = y;
228    }
229
230    // Create dataset
231    let mut dataset = Dataset::new(data, Some(target));
232
233    // Create feature names
234    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
235
236    dataset = dataset
237        .with_feature_names(feature_names)
238        .with_description(format!(
239            "Synthetic regression dataset with {} features ({} informative)",
240            n_features, n_informative
241        ))
242        .with_metadata("noise", &noise.to_string())
243        .with_metadata("coefficients", &format!("{:?}", coef));
244
245    Ok(dataset)
246}
247
248/// Generate a random time series dataset
249pub fn make_time_series(
250    n_samples: usize,
251    n_features: usize,
252    trend: bool,
253    seasonality: bool,
254    noise: f64,
255    random_seed: Option<u64>,
256) -> Result<Dataset> {
257    // Validate input parameters
258    if n_samples == 0 {
259        return Err(DatasetsError::InvalidFormat(
260            "n_samples must be > 0".to_string(),
261        ));
262    }
263
264    if n_features == 0 {
265        return Err(DatasetsError::InvalidFormat(
266            "n_features must be > 0".to_string(),
267        ));
268    }
269
270    if noise < 0.0 {
271        return Err(DatasetsError::InvalidFormat(
272            "noise must be >= 0.0".to_string(),
273        ));
274    }
275
276    let mut rng = match random_seed {
277        Some(seed) => StdRng::seed_from_u64(seed),
278        None => {
279            let mut r = rng();
280            StdRng::seed_from_u64(r.next_u64())
281        }
282    };
283
284    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
285    let mut data = Array2::zeros((n_samples, n_features));
286
287    for feature in 0..n_features {
288        let trend_coef = if trend {
289            rng.random_range(0.01f64..0.1f64)
290        } else {
291            0.0
292        };
293        let seasonality_period = rng.random_range(10..=50) as f64;
294        let seasonality_amplitude = if seasonality {
295            rng.random_range(1.0f64..5.0f64)
296        } else {
297            0.0
298        };
299
300        let base_value = rng.random_range(-10.0f64..10.0f64);
301
302        for i in 0..n_samples {
303            let t = i as f64;
304
305            // Add base value
306            let mut value = base_value;
307
308            // Add trend
309            if trend {
310                value += trend_coef * t;
311            }
312
313            // Add seasonality
314            if seasonality {
315                value += seasonality_amplitude * (2.0 * PI * t / seasonality_period).sin();
316            }
317
318            // Add noise
319            if noise > 0.0 {
320                value += normal.sample(&mut rng) * noise;
321            }
322
323            data[[i, feature]] = value;
324        }
325    }
326
327    // Create time index (unused for now but can be useful for plotting)
328    let time_index: Vec<f64> = (0..n_samples).map(|i| i as f64).collect();
329    let _time_array = Array1::from(time_index);
330
331    // Create dataset
332    let mut dataset = Dataset::new(data, None);
333
334    // Create feature names
335    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
336
337    dataset = dataset
338        .with_feature_names(feature_names)
339        .with_description(format!(
340            "Synthetic time series dataset with {} features",
341            n_features
342        ))
343        .with_metadata("trend", &trend.to_string())
344        .with_metadata("seasonality", &seasonality.to_string())
345        .with_metadata("noise", &noise.to_string());
346
347    Ok(dataset)
348}
349
350/// Generate a random blobs dataset for clustering
351pub fn make_blobs(
352    n_samples: usize,
353    n_features: usize,
354    centers: usize,
355    cluster_std: f64,
356    random_seed: Option<u64>,
357) -> Result<Dataset> {
358    // Validate input parameters
359    if n_samples == 0 {
360        return Err(DatasetsError::InvalidFormat(
361            "n_samples must be > 0".to_string(),
362        ));
363    }
364
365    if n_features == 0 {
366        return Err(DatasetsError::InvalidFormat(
367            "n_features must be > 0".to_string(),
368        ));
369    }
370
371    if centers == 0 {
372        return Err(DatasetsError::InvalidFormat(
373            "centers must be > 0".to_string(),
374        ));
375    }
376
377    if cluster_std <= 0.0 {
378        return Err(DatasetsError::InvalidFormat(
379            "cluster_std must be > 0.0".to_string(),
380        ));
381    }
382
383    let mut rng = match random_seed {
384        Some(seed) => StdRng::seed_from_u64(seed),
385        None => {
386            let mut r = rng();
387            StdRng::seed_from_u64(r.next_u64())
388        }
389    };
390
391    // Generate random centers
392    let mut cluster_centers = Array2::zeros((centers, n_features));
393    let center_box = 10.0;
394
395    for i in 0..centers {
396        for j in 0..n_features {
397            cluster_centers[[i, j]] = rng.random_range(-center_box..=center_box);
398        }
399    }
400
401    // Generate samples around centers
402    let mut data = Array2::zeros((n_samples, n_features));
403    let mut target = Array1::zeros(n_samples);
404
405    let normal = rand_distr::Normal::new(0.0, cluster_std).unwrap();
406
407    // Samples per center
408    let samples_per_center = n_samples / centers;
409    let remainder = n_samples % centers;
410
411    let mut sample_idx = 0;
412
413    for center_idx in 0..centers {
414        let n_samples_center = if center_idx < remainder {
415            samples_per_center + 1
416        } else {
417            samples_per_center
418        };
419
420        for _ in 0..n_samples_center {
421            for j in 0..n_features {
422                data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + normal.sample(&mut rng);
423            }
424
425            target[sample_idx] = center_idx as f64;
426            sample_idx += 1;
427        }
428    }
429
430    // Create dataset
431    let mut dataset = Dataset::new(data, Some(target));
432
433    // Create feature names
434    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
435
436    dataset = dataset
437        .with_feature_names(feature_names)
438        .with_description(format!(
439            "Synthetic clustering dataset with {} clusters and {} features",
440            centers, n_features
441        ))
442        .with_metadata("centers", &centers.to_string())
443        .with_metadata("cluster_std", &cluster_std.to_string());
444
445    Ok(dataset)
446}
447
448/// Generate a spiral dataset for non-linear classification
449pub fn make_spirals(
450    n_samples: usize,
451    n_spirals: usize,
452    noise: f64,
453    random_seed: Option<u64>,
454) -> Result<Dataset> {
455    // Validate input parameters
456    if n_samples == 0 {
457        return Err(DatasetsError::InvalidFormat(
458            "n_samples must be > 0".to_string(),
459        ));
460    }
461
462    if n_spirals == 0 {
463        return Err(DatasetsError::InvalidFormat(
464            "n_spirals must be > 0".to_string(),
465        ));
466    }
467
468    if noise < 0.0 {
469        return Err(DatasetsError::InvalidFormat(
470            "noise must be >= 0.0".to_string(),
471        ));
472    }
473
474    let mut rng = match random_seed {
475        Some(seed) => StdRng::seed_from_u64(seed),
476        None => {
477            let mut r = rng();
478            StdRng::seed_from_u64(r.next_u64())
479        }
480    };
481
482    let mut data = Array2::zeros((n_samples, 2));
483    let mut target = Array1::zeros(n_samples);
484
485    let normal = if noise > 0.0 {
486        Some(rand_distr::Normal::new(0.0, noise).unwrap())
487    } else {
488        None
489    };
490
491    let samples_per_spiral = n_samples / n_spirals;
492    let remainder = n_samples % n_spirals;
493
494    let mut sample_idx = 0;
495
496    for spiral in 0..n_spirals {
497        let n_samples_spiral = if spiral < remainder {
498            samples_per_spiral + 1
499        } else {
500            samples_per_spiral
501        };
502
503        let spiral_offset = 2.0 * PI * spiral as f64 / n_spirals as f64;
504
505        for i in 0..n_samples_spiral {
506            let t = 2.0 * PI * i as f64 / n_samples_spiral as f64;
507            let radius = t / (2.0 * PI);
508
509            let mut x = radius * (t + spiral_offset).cos();
510            let mut y = radius * (t + spiral_offset).sin();
511
512            // Add noise if specified
513            if let Some(ref normal_dist) = normal {
514                x += normal_dist.sample(&mut rng);
515                y += normal_dist.sample(&mut rng);
516            }
517
518            data[[sample_idx, 0]] = x;
519            data[[sample_idx, 1]] = y;
520            target[sample_idx] = spiral as f64;
521            sample_idx += 1;
522        }
523    }
524
525    let mut dataset = Dataset::new(data, Some(target));
526    dataset = dataset
527        .with_feature_names(vec!["x".to_string(), "y".to_string()])
528        .with_target_names((0..n_spirals).map(|i| format!("spiral_{}", i)).collect())
529        .with_description(format!("Spiral dataset with {} spirals", n_spirals))
530        .with_metadata("noise", &noise.to_string());
531
532    Ok(dataset)
533}
534
535/// Generate a moons dataset for non-linear classification
536pub fn make_moons(n_samples: usize, noise: f64, random_seed: Option<u64>) -> Result<Dataset> {
537    // Validate input parameters
538    if n_samples == 0 {
539        return Err(DatasetsError::InvalidFormat(
540            "n_samples must be > 0".to_string(),
541        ));
542    }
543
544    if noise < 0.0 {
545        return Err(DatasetsError::InvalidFormat(
546            "noise must be >= 0.0".to_string(),
547        ));
548    }
549
550    let mut rng = match random_seed {
551        Some(seed) => StdRng::seed_from_u64(seed),
552        None => {
553            let mut r = rng();
554            StdRng::seed_from_u64(r.next_u64())
555        }
556    };
557
558    let mut data = Array2::zeros((n_samples, 2));
559    let mut target = Array1::zeros(n_samples);
560
561    let normal = if noise > 0.0 {
562        Some(rand_distr::Normal::new(0.0, noise).unwrap())
563    } else {
564        None
565    };
566
567    let samples_per_moon = n_samples / 2;
568    let remainder = n_samples % 2;
569
570    let mut sample_idx = 0;
571
572    // Generate first moon (upper crescent)
573    for i in 0..(samples_per_moon + remainder) {
574        let t = PI * i as f64 / (samples_per_moon + remainder) as f64;
575
576        let mut x = t.cos();
577        let mut y = t.sin();
578
579        // Add noise if specified
580        if let Some(ref normal_dist) = normal {
581            x += normal_dist.sample(&mut rng);
582            y += normal_dist.sample(&mut rng);
583        }
584
585        data[[sample_idx, 0]] = x;
586        data[[sample_idx, 1]] = y;
587        target[sample_idx] = 0.0;
588        sample_idx += 1;
589    }
590
591    // Generate second moon (lower crescent, flipped)
592    for i in 0..samples_per_moon {
593        let t = PI * i as f64 / samples_per_moon as f64;
594
595        let mut x = 1.0 - t.cos();
596        let mut y = 0.5 - t.sin(); // Offset vertically and flip
597
598        // Add noise if specified
599        if let Some(ref normal_dist) = normal {
600            x += normal_dist.sample(&mut rng);
601            y += normal_dist.sample(&mut rng);
602        }
603
604        data[[sample_idx, 0]] = x;
605        data[[sample_idx, 1]] = y;
606        target[sample_idx] = 1.0;
607        sample_idx += 1;
608    }
609
610    let mut dataset = Dataset::new(data, Some(target));
611    dataset = dataset
612        .with_feature_names(vec!["x".to_string(), "y".to_string()])
613        .with_target_names(vec!["moon_0".to_string(), "moon_1".to_string()])
614        .with_description("Two moons dataset for non-linear classification".to_string())
615        .with_metadata("noise", &noise.to_string());
616
617    Ok(dataset)
618}
619
620/// Generate a circles dataset for non-linear classification
621pub fn make_circles(
622    n_samples: usize,
623    factor: f64,
624    noise: f64,
625    random_seed: Option<u64>,
626) -> Result<Dataset> {
627    // Validate input parameters
628    if n_samples == 0 {
629        return Err(DatasetsError::InvalidFormat(
630            "n_samples must be > 0".to_string(),
631        ));
632    }
633
634    if factor <= 0.0 || factor >= 1.0 {
635        return Err(DatasetsError::InvalidFormat(
636            "factor must be between 0.0 and 1.0".to_string(),
637        ));
638    }
639
640    if noise < 0.0 {
641        return Err(DatasetsError::InvalidFormat(
642            "noise must be >= 0.0".to_string(),
643        ));
644    }
645
646    let mut rng = match random_seed {
647        Some(seed) => StdRng::seed_from_u64(seed),
648        None => {
649            let mut r = rng();
650            StdRng::seed_from_u64(r.next_u64())
651        }
652    };
653
654    let mut data = Array2::zeros((n_samples, 2));
655    let mut target = Array1::zeros(n_samples);
656
657    let normal = if noise > 0.0 {
658        Some(rand_distr::Normal::new(0.0, noise).unwrap())
659    } else {
660        None
661    };
662
663    let samples_per_circle = n_samples / 2;
664    let remainder = n_samples % 2;
665
666    let mut sample_idx = 0;
667
668    // Generate outer circle
669    for i in 0..(samples_per_circle + remainder) {
670        let angle = 2.0 * PI * i as f64 / (samples_per_circle + remainder) as f64;
671
672        let mut x = angle.cos();
673        let mut y = angle.sin();
674
675        // Add noise if specified
676        if let Some(ref normal_dist) = normal {
677            x += normal_dist.sample(&mut rng);
678            y += normal_dist.sample(&mut rng);
679        }
680
681        data[[sample_idx, 0]] = x;
682        data[[sample_idx, 1]] = y;
683        target[sample_idx] = 0.0;
684        sample_idx += 1;
685    }
686
687    // Generate inner circle (scaled by factor)
688    for i in 0..samples_per_circle {
689        let angle = 2.0 * PI * i as f64 / samples_per_circle as f64;
690
691        let mut x = factor * angle.cos();
692        let mut y = factor * angle.sin();
693
694        // Add noise if specified
695        if let Some(ref normal_dist) = normal {
696            x += normal_dist.sample(&mut rng);
697            y += normal_dist.sample(&mut rng);
698        }
699
700        data[[sample_idx, 0]] = x;
701        data[[sample_idx, 1]] = y;
702        target[sample_idx] = 1.0;
703        sample_idx += 1;
704    }
705
706    let mut dataset = Dataset::new(data, Some(target));
707    dataset = dataset
708        .with_feature_names(vec!["x".to_string(), "y".to_string()])
709        .with_target_names(vec!["outer_circle".to_string(), "inner_circle".to_string()])
710        .with_description("Concentric circles dataset for non-linear classification".to_string())
711        .with_metadata("factor", &factor.to_string())
712        .with_metadata("noise", &noise.to_string());
713
714    Ok(dataset)
715}
716
717/// Generate a Swiss roll dataset for dimensionality reduction
718pub fn make_swiss_roll(n_samples: usize, noise: f64, random_seed: Option<u64>) -> Result<Dataset> {
719    // Validate input parameters
720    if n_samples == 0 {
721        return Err(DatasetsError::InvalidFormat(
722            "n_samples must be > 0".to_string(),
723        ));
724    }
725
726    if noise < 0.0 {
727        return Err(DatasetsError::InvalidFormat(
728            "noise must be >= 0.0".to_string(),
729        ));
730    }
731
732    let mut rng = match random_seed {
733        Some(seed) => StdRng::seed_from_u64(seed),
734        None => {
735            let mut r = rng();
736            StdRng::seed_from_u64(r.next_u64())
737        }
738    };
739
740    let mut data = Array2::zeros((n_samples, 3));
741    let mut color = Array1::zeros(n_samples); // Color parameter for visualization
742
743    let normal = if noise > 0.0 {
744        Some(rand_distr::Normal::new(0.0, noise).unwrap())
745    } else {
746        None
747    };
748
749    for i in 0..n_samples {
750        // Parameter along the roll
751        let t = 1.5 * PI * (1.0 + 2.0 * i as f64 / n_samples as f64);
752
753        // Height parameter
754        let height = 21.0 * i as f64 / n_samples as f64;
755
756        let mut x = t * t.cos();
757        let mut y = height;
758        let mut z = t * t.sin();
759
760        // Add noise if specified
761        if let Some(ref normal_dist) = normal {
762            x += normal_dist.sample(&mut rng);
763            y += normal_dist.sample(&mut rng);
764            z += normal_dist.sample(&mut rng);
765        }
766
767        data[[i, 0]] = x;
768        data[[i, 1]] = y;
769        data[[i, 2]] = z;
770        color[i] = t; // Color based on parameter for visualization
771    }
772
773    let mut dataset = Dataset::new(data, Some(color));
774    dataset = dataset
775        .with_feature_names(vec!["x".to_string(), "y".to_string(), "z".to_string()])
776        .with_description("Swiss roll manifold dataset for dimensionality reduction".to_string())
777        .with_metadata("noise", &noise.to_string())
778        .with_metadata("dimensions", "3")
779        .with_metadata("manifold_dim", "2");
780
781    Ok(dataset)
782}
783
784/// Generate anisotropic (elongated) clusters
785pub fn make_anisotropic_blobs(
786    n_samples: usize,
787    n_features: usize,
788    centers: usize,
789    cluster_std: f64,
790    anisotropy_factor: f64,
791    random_seed: Option<u64>,
792) -> Result<Dataset> {
793    // Validate input parameters
794    if n_samples == 0 {
795        return Err(DatasetsError::InvalidFormat(
796            "n_samples must be > 0".to_string(),
797        ));
798    }
799
800    if n_features < 2 {
801        return Err(DatasetsError::InvalidFormat(
802            "n_features must be >= 2 for anisotropic clusters".to_string(),
803        ));
804    }
805
806    if centers == 0 {
807        return Err(DatasetsError::InvalidFormat(
808            "centers must be > 0".to_string(),
809        ));
810    }
811
812    if cluster_std <= 0.0 {
813        return Err(DatasetsError::InvalidFormat(
814            "cluster_std must be > 0.0".to_string(),
815        ));
816    }
817
818    if anisotropy_factor <= 0.0 {
819        return Err(DatasetsError::InvalidFormat(
820            "anisotropy_factor must be > 0.0".to_string(),
821        ));
822    }
823
824    let mut rng = match random_seed {
825        Some(seed) => StdRng::seed_from_u64(seed),
826        None => {
827            let mut r = rng();
828            StdRng::seed_from_u64(r.next_u64())
829        }
830    };
831
832    // Generate random centers
833    let mut cluster_centers = Array2::zeros((centers, n_features));
834    let center_box = 10.0;
835
836    for i in 0..centers {
837        for j in 0..n_features {
838            cluster_centers[[i, j]] = rng.random_range(-center_box..=center_box);
839        }
840    }
841
842    // Generate samples around centers with anisotropic distribution
843    let mut data = Array2::zeros((n_samples, n_features));
844    let mut target = Array1::zeros(n_samples);
845
846    let normal = rand_distr::Normal::new(0.0, cluster_std).unwrap();
847
848    let samples_per_center = n_samples / centers;
849    let remainder = n_samples % centers;
850
851    let mut sample_idx = 0;
852
853    for center_idx in 0..centers {
854        let n_samples_center = if center_idx < remainder {
855            samples_per_center + 1
856        } else {
857            samples_per_center
858        };
859
860        // Generate a random rotation angle for this cluster
861        let rotation_angle = rng.random_range(0.0..(2.0 * PI));
862
863        for _ in 0..n_samples_center {
864            // Generate point with anisotropic distribution (elongated along first axis)
865            let mut point = vec![0.0; n_features];
866
867            // First axis has normal std, second axis has reduced std (anisotropy)
868            point[0] = normal.sample(&mut rng);
869            point[1] = normal.sample(&mut rng) / anisotropy_factor;
870
871            // Remaining axes have normal std
872            for item in point.iter_mut().take(n_features).skip(2) {
873                *item = normal.sample(&mut rng);
874            }
875
876            // Apply rotation for 2D case
877            if n_features >= 2 {
878                let cos_theta = rotation_angle.cos();
879                let sin_theta = rotation_angle.sin();
880
881                let x_rot = cos_theta * point[0] - sin_theta * point[1];
882                let y_rot = sin_theta * point[0] + cos_theta * point[1];
883
884                point[0] = x_rot;
885                point[1] = y_rot;
886            }
887
888            // Translate to cluster center
889            for j in 0..n_features {
890                data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + point[j];
891            }
892
893            target[sample_idx] = center_idx as f64;
894            sample_idx += 1;
895        }
896    }
897
898    let mut dataset = Dataset::new(data, Some(target));
899    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
900
901    dataset = dataset
902        .with_feature_names(feature_names)
903        .with_description(format!(
904            "Anisotropic clustering dataset with {} elongated clusters and {} features",
905            centers, n_features
906        ))
907        .with_metadata("centers", &centers.to_string())
908        .with_metadata("cluster_std", &cluster_std.to_string())
909        .with_metadata("anisotropy_factor", &anisotropy_factor.to_string());
910
911    Ok(dataset)
912}
913
914/// Generate hierarchical clusters (clusters within clusters)
915pub fn make_hierarchical_clusters(
916    n_samples: usize,
917    n_features: usize,
918    n_main_clusters: usize,
919    n_sub_clusters: usize,
920    main_cluster_std: f64,
921    sub_cluster_std: f64,
922    random_seed: Option<u64>,
923) -> Result<Dataset> {
924    // Validate input parameters
925    if n_samples == 0 {
926        return Err(DatasetsError::InvalidFormat(
927            "n_samples must be > 0".to_string(),
928        ));
929    }
930
931    if n_features == 0 {
932        return Err(DatasetsError::InvalidFormat(
933            "n_features must be > 0".to_string(),
934        ));
935    }
936
937    if n_main_clusters == 0 {
938        return Err(DatasetsError::InvalidFormat(
939            "n_main_clusters must be > 0".to_string(),
940        ));
941    }
942
943    if n_sub_clusters == 0 {
944        return Err(DatasetsError::InvalidFormat(
945            "n_sub_clusters must be > 0".to_string(),
946        ));
947    }
948
949    if main_cluster_std <= 0.0 {
950        return Err(DatasetsError::InvalidFormat(
951            "main_cluster_std must be > 0.0".to_string(),
952        ));
953    }
954
955    if sub_cluster_std <= 0.0 {
956        return Err(DatasetsError::InvalidFormat(
957            "sub_cluster_std must be > 0.0".to_string(),
958        ));
959    }
960
961    let mut rng = match random_seed {
962        Some(seed) => StdRng::seed_from_u64(seed),
963        None => {
964            let mut r = rng();
965            StdRng::seed_from_u64(r.next_u64())
966        }
967    };
968
969    // Generate main cluster centers
970    let mut main_centers = Array2::zeros((n_main_clusters, n_features));
971    let center_box = 20.0;
972
973    for i in 0..n_main_clusters {
974        for j in 0..n_features {
975            main_centers[[i, j]] = rng.random_range(-center_box..=center_box);
976        }
977    }
978
979    let mut data = Array2::zeros((n_samples, n_features));
980    let mut main_target = Array1::zeros(n_samples);
981    let mut sub_target = Array1::zeros(n_samples);
982
983    let main_normal = rand_distr::Normal::new(0.0, main_cluster_std).unwrap();
984    let sub_normal = rand_distr::Normal::new(0.0, sub_cluster_std).unwrap();
985
986    let samples_per_main = n_samples / n_main_clusters;
987    let remainder = n_samples % n_main_clusters;
988
989    let mut sample_idx = 0;
990
991    for main_idx in 0..n_main_clusters {
992        let n_samples_main = if main_idx < remainder {
993            samples_per_main + 1
994        } else {
995            samples_per_main
996        };
997
998        // Generate sub-cluster centers within this main cluster
999        let mut sub_centers = Array2::zeros((n_sub_clusters, n_features));
1000        for i in 0..n_sub_clusters {
1001            for j in 0..n_features {
1002                sub_centers[[i, j]] = main_centers[[main_idx, j]] + main_normal.sample(&mut rng);
1003            }
1004        }
1005
1006        let samples_per_sub = n_samples_main / n_sub_clusters;
1007        let sub_remainder = n_samples_main % n_sub_clusters;
1008
1009        for sub_idx in 0..n_sub_clusters {
1010            let n_samples_sub = if sub_idx < sub_remainder {
1011                samples_per_sub + 1
1012            } else {
1013                samples_per_sub
1014            };
1015
1016            for _ in 0..n_samples_sub {
1017                for j in 0..n_features {
1018                    data[[sample_idx, j]] = sub_centers[[sub_idx, j]] + sub_normal.sample(&mut rng);
1019                }
1020
1021                main_target[sample_idx] = main_idx as f64;
1022                sub_target[sample_idx] = (main_idx * n_sub_clusters + sub_idx) as f64;
1023                sample_idx += 1;
1024            }
1025        }
1026    }
1027
1028    let mut dataset = Dataset::new(data, Some(main_target));
1029    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
1030
1031    dataset = dataset
1032        .with_feature_names(feature_names)
1033        .with_description(format!(
1034            "Hierarchical clustering dataset with {} main clusters, {} sub-clusters each",
1035            n_main_clusters, n_sub_clusters
1036        ))
1037        .with_metadata("n_main_clusters", &n_main_clusters.to_string())
1038        .with_metadata("n_sub_clusters", &n_sub_clusters.to_string())
1039        .with_metadata("main_cluster_std", &main_cluster_std.to_string())
1040        .with_metadata("sub_cluster_std", &sub_cluster_std.to_string())
1041        .with_metadata("sub_cluster_labels", &format!("{:?}", sub_target.to_vec()));
1042
1043    Ok(dataset)
1044}
1045
/// Missing-data mechanisms understood by [`inject_missing_data`].
///
/// The textbook MCAR/MAR/MNAR definitions are only approximated; each variant's
/// doc describes how the injector realises it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MissingPattern {
    /// Missing Completely at Random: every cell is removed independently with
    /// probability `missing_rate`.
    MCAR,
    /// Missing at Random: the removal probability of a row's remaining cells is
    /// scaled by that row's (observed) value in the first feature, which itself
    /// is never removed.
    MAR,
    /// Missing Not at Random: the removal probability of each cell is scaled by
    /// the cell's own value, so larger values are more likely to go missing.
    MNAR,
    /// Block-wise missing: rectangular blocks of consecutive samples and
    /// features are removed together.
    Block,
}
1058
/// Kinds of outliers produced by [`inject_outliers`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutlierType {
    /// Point outliers: every feature of a selected sample is pushed far away
    /// from the feature mean.
    Point,
    /// Contextual outliers: only a random subset of a selected sample's
    /// features is pushed away, so the point is anomalous in some dimensions
    /// but not others.
    Contextual,
    /// Collective outliers: groups of samples are placed close to a shared
    /// anomalous center, so the group as a whole forms the anomaly.
    Collective,
}
1069
1070/// Inject missing data into a dataset with realistic patterns
1071pub fn inject_missing_data(
1072    data: &mut Array2<f64>,
1073    missing_rate: f64,
1074    pattern: MissingPattern,
1075    random_seed: Option<u64>,
1076) -> Result<Array2<bool>> {
1077    // Validate input parameters
1078    if !(0.0..=1.0).contains(&missing_rate) {
1079        return Err(DatasetsError::InvalidFormat(
1080            "missing_rate must be between 0.0 and 1.0".to_string(),
1081        ));
1082    }
1083
1084    let mut rng = match random_seed {
1085        Some(seed) => StdRng::seed_from_u64(seed),
1086        None => {
1087            let mut r = rng();
1088            StdRng::seed_from_u64(r.next_u64())
1089        }
1090    };
1091
1092    let (n_samples, n_features) = data.dim();
1093    let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
1094
1095    match pattern {
1096        MissingPattern::MCAR => {
1097            // Missing Completely at Random - uniform probability
1098            for i in 0..n_samples {
1099                for j in 0..n_features {
1100                    if rng.random_range(0.0f64..1.0) < missing_rate {
1101                        missing_mask[[i, j]] = true;
1102                        data[[i, j]] = f64::NAN;
1103                    }
1104                }
1105            }
1106        }
1107        MissingPattern::MAR => {
1108            // Missing at Random - probability depends on first feature
1109            for i in 0..n_samples {
1110                let first_feature_val = data[[i, 0]];
1111                let normalized_val = (first_feature_val + 10.0) / 20.0; // Normalize roughly to [0,1]
1112                let adjusted_rate = missing_rate * normalized_val.clamp(0.1, 2.0);
1113
1114                for j in 1..n_features {
1115                    // Skip first feature
1116                    if rng.random_range(0.0f64..1.0) < adjusted_rate {
1117                        missing_mask[[i, j]] = true;
1118                        data[[i, j]] = f64::NAN;
1119                    }
1120                }
1121            }
1122        }
1123        MissingPattern::MNAR => {
1124            // Missing Not at Random - higher values more likely to be missing
1125            for i in 0..n_samples {
1126                for j in 0..n_features {
1127                    let value = data[[i, j]];
1128                    let normalized_val = (value + 10.0) / 20.0; // Normalize roughly to [0,1]
1129                    let adjusted_rate = missing_rate * normalized_val.clamp(0.1, 3.0);
1130
1131                    if rng.random_range(0.0f64..1.0) < adjusted_rate {
1132                        missing_mask[[i, j]] = true;
1133                        data[[i, j]] = f64::NAN;
1134                    }
1135                }
1136            }
1137        }
1138        MissingPattern::Block => {
1139            // Block-wise missing - entire blocks are missing
1140            let block_size = (n_features as f64 * missing_rate).ceil() as usize;
1141            let n_blocks = (missing_rate * n_samples as f64).ceil() as usize;
1142
1143            for _ in 0..n_blocks {
1144                let start_row = rng.random_range(0..n_samples);
1145                let start_col = rng.random_range(0..n_features.saturating_sub(block_size));
1146
1147                for i in start_row..n_samples.min(start_row + block_size) {
1148                    for j in start_col..n_features.min(start_col + block_size) {
1149                        missing_mask[[i, j]] = true;
1150                        data[[i, j]] = f64::NAN;
1151                    }
1152                }
1153            }
1154        }
1155    }
1156
1157    Ok(missing_mask)
1158}
1159
1160/// Inject outliers into a dataset
1161pub fn inject_outliers(
1162    data: &mut Array2<f64>,
1163    outlier_rate: f64,
1164    outlier_type: OutlierType,
1165    outlier_strength: f64,
1166    random_seed: Option<u64>,
1167) -> Result<Array1<bool>> {
1168    // Validate input parameters
1169    if !(0.0..=1.0).contains(&outlier_rate) {
1170        return Err(DatasetsError::InvalidFormat(
1171            "outlier_rate must be between 0.0 and 1.0".to_string(),
1172        ));
1173    }
1174
1175    if outlier_strength <= 0.0 {
1176        return Err(DatasetsError::InvalidFormat(
1177            "outlier_strength must be > 0.0".to_string(),
1178        ));
1179    }
1180
1181    let mut rng = match random_seed {
1182        Some(seed) => StdRng::seed_from_u64(seed),
1183        None => {
1184            let mut r = rng();
1185            StdRng::seed_from_u64(r.next_u64())
1186        }
1187    };
1188
1189    let (n_samples, n_features) = data.dim();
1190    let n_outliers = (n_samples as f64 * outlier_rate).ceil() as usize;
1191    let mut outlier_mask = Array1::from_elem(n_samples, false);
1192
1193    // Calculate data statistics for outlier generation
1194    let mut feature_means = vec![0.0; n_features];
1195    let mut feature_stds = vec![0.0; n_features];
1196
1197    for j in 0..n_features {
1198        let column = data.column(j);
1199        let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).cloned().collect();
1200
1201        if !valid_values.is_empty() {
1202            feature_means[j] = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
1203            let variance = valid_values
1204                .iter()
1205                .map(|&x| (x - feature_means[j]).powi(2))
1206                .sum::<f64>()
1207                / valid_values.len() as f64;
1208            feature_stds[j] = variance.sqrt().max(1.0); // Use minimum std of 1.0 to ensure outliers can be created
1209        }
1210    }
1211
1212    match outlier_type {
1213        OutlierType::Point => {
1214            // Point outliers - individual anomalous points
1215            for _ in 0..n_outliers {
1216                let outlier_idx = rng.random_range(0..n_samples);
1217                outlier_mask[outlier_idx] = true;
1218
1219                // Modify each feature to be an outlier
1220                for j in 0..n_features {
1221                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
1222                        -1.0
1223                    } else {
1224                        1.0
1225                    };
1226                    data[[outlier_idx, j]] =
1227                        feature_means[j] + direction * outlier_strength * feature_stds[j];
1228                }
1229            }
1230        }
1231        OutlierType::Contextual => {
1232            // Contextual outliers - anomalous in specific feature combinations
1233            for _ in 0..n_outliers {
1234                let outlier_idx = rng.random_range(0..n_samples);
1235                outlier_mask[outlier_idx] = true;
1236
1237                // Only modify a subset of features to create contextual anomaly
1238                let n_features_to_modify = rng.random_range(1..=(n_features / 2).max(1));
1239                let mut features_to_modify: Vec<usize> = (0..n_features).collect();
1240                features_to_modify.shuffle(&mut rng);
1241                features_to_modify.truncate(n_features_to_modify);
1242
1243                for &j in &features_to_modify {
1244                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
1245                        -1.0
1246                    } else {
1247                        1.0
1248                    };
1249                    data[[outlier_idx, j]] =
1250                        feature_means[j] + direction * outlier_strength * feature_stds[j];
1251                }
1252            }
1253        }
1254        OutlierType::Collective => {
1255            // Collective outliers - groups of points that together form anomalies
1256            let outliers_per_group = (n_outliers / 3).max(2); // At least 2 per group
1257            let n_groups = (n_outliers / outliers_per_group).max(1);
1258
1259            for _ in 0..n_groups {
1260                // Generate cluster center for this collective outlier
1261                let mut outlier_center = vec![0.0; n_features];
1262                for j in 0..n_features {
1263                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
1264                        -1.0
1265                    } else {
1266                        1.0
1267                    };
1268                    outlier_center[j] =
1269                        feature_means[j] + direction * outlier_strength * feature_stds[j];
1270                }
1271
1272                // Generate points around this center
1273                for _ in 0..outliers_per_group {
1274                    let outlier_idx = rng.random_range(0..n_samples);
1275                    outlier_mask[outlier_idx] = true;
1276
1277                    for j in 0..n_features {
1278                        let noise = rng.random_range(-0.5f64..0.5f64) * feature_stds[j];
1279                        data[[outlier_idx, j]] = outlier_center[j] + noise;
1280                    }
1281                }
1282            }
1283        }
1284    }
1285
1286    Ok(outlier_mask)
1287}
1288
1289/// Add realistic noise patterns to time series data
1290pub fn add_time_series_noise(
1291    data: &mut Array2<f64>,
1292    noise_types: &[(&str, f64)], // (noise_type, strength)
1293    random_seed: Option<u64>,
1294) -> Result<()> {
1295    let mut rng = match random_seed {
1296        Some(seed) => StdRng::seed_from_u64(seed),
1297        None => {
1298            let mut r = rng();
1299            StdRng::seed_from_u64(r.next_u64())
1300        }
1301    };
1302
1303    let (n_samples, n_features) = data.dim();
1304    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
1305
1306    for &(noise_type, strength) in noise_types {
1307        match noise_type {
1308            "gaussian" => {
1309                // Add Gaussian white noise
1310                for i in 0..n_samples {
1311                    for j in 0..n_features {
1312                        data[[i, j]] += strength * normal.sample(&mut rng);
1313                    }
1314                }
1315            }
1316            "spikes" => {
1317                // Add random spikes (impulse noise)
1318                let n_spikes = (n_samples as f64 * strength * 0.1).ceil() as usize;
1319                for _ in 0..n_spikes {
1320                    let spike_idx = rng.random_range(0..n_samples);
1321                    let feature_idx = rng.random_range(0..n_features);
1322                    let spike_magnitude = rng.random_range(5.0..=15.0) * strength;
1323                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
1324                        -1.0
1325                    } else {
1326                        1.0
1327                    };
1328
1329                    data[[spike_idx, feature_idx]] += direction * spike_magnitude;
1330                }
1331            }
1332            "drift" => {
1333                // Add gradual drift over time
1334                for i in 0..n_samples {
1335                    let drift_amount = strength * (i as f64 / n_samples as f64);
1336                    for j in 0..n_features {
1337                        data[[i, j]] += drift_amount;
1338                    }
1339                }
1340            }
1341            "seasonal" => {
1342                // Add seasonal pattern noise
1343                let period = n_samples as f64 / 4.0; // 4 seasons
1344                for i in 0..n_samples {
1345                    let seasonal_component = strength * (2.0 * PI * i as f64 / period).sin();
1346                    for j in 0..n_features {
1347                        data[[i, j]] += seasonal_component;
1348                    }
1349                }
1350            }
1351            "autocorrelated" => {
1352                // Add autocorrelated noise (AR(1) process)
1353                let ar_coeff = 0.7; // Autocorrelation coefficient
1354                for j in 0..n_features {
1355                    let mut prev_noise = 0.0;
1356                    for i in 0..n_samples {
1357                        let new_noise = ar_coeff * prev_noise + strength * normal.sample(&mut rng);
1358                        data[[i, j]] += new_noise;
1359                        prev_noise = new_noise;
1360                    }
1361                }
1362            }
1363            "heteroscedastic" => {
1364                // Add heteroscedastic noise (variance changes over time)
1365                for i in 0..n_samples {
1366                    let variance_factor = 1.0 + strength * (i as f64 / n_samples as f64);
1367                    for j in 0..n_features {
1368                        data[[i, j]] += variance_factor * strength * normal.sample(&mut rng);
1369                    }
1370                }
1371            }
1372            _ => {
1373                return Err(DatasetsError::InvalidFormat(format!(
1374                    "Unknown noise type: {}. Supported types: gaussian, spikes, drift, seasonal, autocorrelated, heteroscedastic",
1375                    noise_type
1376                )));
1377            }
1378        }
1379    }
1380
1381    Ok(())
1382}
1383
1384/// Generate a dataset with controlled corruption patterns
1385pub fn make_corrupted_dataset(
1386    base_dataset: &Dataset,
1387    missing_rate: f64,
1388    missing_pattern: MissingPattern,
1389    outlier_rate: f64,
1390    outlier_type: OutlierType,
1391    outlier_strength: f64,
1392    random_seed: Option<u64>,
1393) -> Result<Dataset> {
1394    // Validate inputs
1395    if !(0.0..=1.0).contains(&missing_rate) {
1396        return Err(DatasetsError::InvalidFormat(
1397            "missing_rate must be between 0.0 and 1.0".to_string(),
1398        ));
1399    }
1400
1401    if !(0.0..=1.0).contains(&outlier_rate) {
1402        return Err(DatasetsError::InvalidFormat(
1403            "outlier_rate must be between 0.0 and 1.0".to_string(),
1404        ));
1405    }
1406
1407    // Clone the base dataset
1408    let mut corrupted_data = base_dataset.data.clone();
1409    let corrupted_target = base_dataset.target.clone();
1410
1411    // Apply missing data
1412    let missing_mask = inject_missing_data(
1413        &mut corrupted_data,
1414        missing_rate,
1415        missing_pattern,
1416        random_seed,
1417    )?;
1418
1419    // Apply outliers
1420    let outlier_mask = inject_outliers(
1421        &mut corrupted_data,
1422        outlier_rate,
1423        outlier_type,
1424        outlier_strength,
1425        random_seed,
1426    )?;
1427
1428    // Create new dataset with corruption metadata
1429    let mut corrupted_dataset = Dataset::new(corrupted_data, corrupted_target);
1430
1431    if let Some(feature_names) = &base_dataset.feature_names {
1432        corrupted_dataset = corrupted_dataset.with_feature_names(feature_names.clone());
1433    }
1434
1435    if let Some(target_names) = &base_dataset.target_names {
1436        corrupted_dataset = corrupted_dataset.with_target_names(target_names.clone());
1437    }
1438
1439    corrupted_dataset = corrupted_dataset
1440        .with_description(format!(
1441            "Corrupted version of: {}",
1442            base_dataset
1443                .description
1444                .as_deref()
1445                .unwrap_or("Unknown dataset")
1446        ))
1447        .with_metadata("missing_rate", &missing_rate.to_string())
1448        .with_metadata("missing_pattern", &format!("{:?}", missing_pattern))
1449        .with_metadata("outlier_rate", &outlier_rate.to_string())
1450        .with_metadata("outlier_type", &format!("{:?}", outlier_type))
1451        .with_metadata("outlier_strength", &outlier_strength.to_string())
1452        .with_metadata(
1453            "missing_count",
1454            &missing_mask.iter().filter(|&&x| x).count().to_string(),
1455        )
1456        .with_metadata(
1457            "outlier_count",
1458            &outlier_mask.iter().filter(|&&x| x).count().to_string(),
1459        );
1460
1461    Ok(corrupted_dataset)
1462}
1463
1464#[cfg(test)]
1465mod tests {
1466    use super::*;
1467
1468    #[test]
1469    fn test_make_classification_invalid_params() {
1470        // Test zero n_samples
1471        assert!(make_classification(0, 5, 2, 1, 3, None).is_err());
1472
1473        // Test zero n_features
1474        assert!(make_classification(10, 0, 2, 1, 3, None).is_err());
1475
1476        // Test zero n_informative
1477        assert!(make_classification(10, 5, 2, 1, 0, None).is_err());
1478
1479        // Test n_features < n_informative
1480        assert!(make_classification(10, 3, 2, 1, 5, None).is_err());
1481
1482        // Test n_classes < 2
1483        assert!(make_classification(10, 5, 1, 1, 3, None).is_err());
1484
1485        // Test zero n_clusters_per_class
1486        assert!(make_classification(10, 5, 2, 0, 3, None).is_err());
1487    }
1488
1489    #[test]
1490    fn test_make_regression_invalid_params() {
1491        // Test zero n_samples
1492        assert!(make_regression(0, 5, 3, 1.0, None).is_err());
1493
1494        // Test zero n_features
1495        assert!(make_regression(10, 0, 3, 1.0, None).is_err());
1496
1497        // Test zero n_informative
1498        assert!(make_regression(10, 5, 0, 1.0, None).is_err());
1499
1500        // Test n_features < n_informative
1501        assert!(make_regression(10, 3, 5, 1.0, None).is_err());
1502
1503        // Test negative noise
1504        assert!(make_regression(10, 5, 3, -1.0, None).is_err());
1505    }
1506
1507    #[test]
1508    fn test_make_time_series_invalid_params() {
1509        // Test zero n_samples
1510        assert!(make_time_series(0, 3, false, false, 1.0, None).is_err());
1511
1512        // Test zero n_features
1513        assert!(make_time_series(10, 0, false, false, 1.0, None).is_err());
1514
1515        // Test negative noise
1516        assert!(make_time_series(10, 3, false, false, -1.0, None).is_err());
1517    }
1518
1519    #[test]
1520    fn test_make_blobs_invalid_params() {
1521        // Test zero n_samples
1522        assert!(make_blobs(0, 3, 2, 1.0, None).is_err());
1523
1524        // Test zero n_features
1525        assert!(make_blobs(10, 0, 2, 1.0, None).is_err());
1526
1527        // Test zero centers
1528        assert!(make_blobs(10, 3, 0, 1.0, None).is_err());
1529
1530        // Test zero or negative cluster_std
1531        assert!(make_blobs(10, 3, 2, 0.0, None).is_err());
1532        assert!(make_blobs(10, 3, 2, -1.0, None).is_err());
1533    }
1534
1535    #[test]
1536    fn test_make_classification_valid_params() {
1537        let dataset = make_classification(20, 5, 3, 2, 4, Some(42)).unwrap();
1538        assert_eq!(dataset.n_samples(), 20);
1539        assert_eq!(dataset.n_features(), 5);
1540        assert!(dataset.target.is_some());
1541        assert!(dataset.feature_names.is_some());
1542        assert!(dataset.target_names.is_some());
1543    }
1544
1545    #[test]
1546    fn test_make_regression_valid_params() {
1547        let dataset = make_regression(15, 4, 3, 0.5, Some(42)).unwrap();
1548        assert_eq!(dataset.n_samples(), 15);
1549        assert_eq!(dataset.n_features(), 4);
1550        assert!(dataset.target.is_some());
1551        assert!(dataset.feature_names.is_some());
1552    }
1553
1554    #[test]
1555    fn test_make_time_series_valid_params() {
1556        let dataset = make_time_series(25, 3, true, true, 0.1, Some(42)).unwrap();
1557        assert_eq!(dataset.n_samples(), 25);
1558        assert_eq!(dataset.n_features(), 3);
1559        assert!(dataset.feature_names.is_some());
1560        // Time series doesn't have targets by default
1561        assert!(dataset.target.is_none());
1562    }
1563
1564    #[test]
1565    fn test_make_blobs_valid_params() {
1566        let dataset = make_blobs(30, 4, 3, 1.5, Some(42)).unwrap();
1567        assert_eq!(dataset.n_samples(), 30);
1568        assert_eq!(dataset.n_features(), 4);
1569        assert!(dataset.target.is_some());
1570        assert!(dataset.feature_names.is_some());
1571    }
1572
1573    #[test]
1574    fn test_make_spirals_invalid_params() {
1575        // Test zero n_samples
1576        assert!(make_spirals(0, 2, 0.1, None).is_err());
1577
1578        // Test zero n_spirals
1579        assert!(make_spirals(100, 0, 0.1, None).is_err());
1580
1581        // Test negative noise
1582        assert!(make_spirals(100, 2, -0.1, None).is_err());
1583    }
1584
1585    #[test]
1586    fn test_make_spirals_valid_params() {
1587        let dataset = make_spirals(100, 2, 0.1, Some(42)).unwrap();
1588        assert_eq!(dataset.n_samples(), 100);
1589        assert_eq!(dataset.n_features(), 2);
1590        assert!(dataset.target.is_some());
1591        assert!(dataset.feature_names.is_some());
1592
1593        // Check that we have the right number of spirals
1594        if let Some(target) = &dataset.target {
1595            let unique_labels: std::collections::HashSet<_> =
1596                target.iter().map(|&x| x as i32).collect();
1597            assert_eq!(unique_labels.len(), 2);
1598        }
1599    }
1600
1601    #[test]
1602    fn test_make_moons_invalid_params() {
1603        // Test zero n_samples
1604        assert!(make_moons(0, 0.1, None).is_err());
1605
1606        // Test negative noise
1607        assert!(make_moons(100, -0.1, None).is_err());
1608    }
1609
1610    #[test]
1611    fn test_make_moons_valid_params() {
1612        let dataset = make_moons(100, 0.1, Some(42)).unwrap();
1613        assert_eq!(dataset.n_samples(), 100);
1614        assert_eq!(dataset.n_features(), 2);
1615        assert!(dataset.target.is_some());
1616        assert!(dataset.feature_names.is_some());
1617
1618        // Check that we have exactly 2 classes (2 moons)
1619        if let Some(target) = &dataset.target {
1620            let unique_labels: std::collections::HashSet<_> =
1621                target.iter().map(|&x| x as i32).collect();
1622            assert_eq!(unique_labels.len(), 2);
1623        }
1624    }
1625
1626    #[test]
1627    fn test_make_circles_invalid_params() {
1628        // Test zero n_samples
1629        assert!(make_circles(0, 0.5, 0.1, None).is_err());
1630
1631        // Test invalid factor (must be between 0 and 1)
1632        assert!(make_circles(100, 0.0, 0.1, None).is_err());
1633        assert!(make_circles(100, 1.0, 0.1, None).is_err());
1634        assert!(make_circles(100, 1.5, 0.1, None).is_err());
1635
1636        // Test negative noise
1637        assert!(make_circles(100, 0.5, -0.1, None).is_err());
1638    }
1639
1640    #[test]
1641    fn test_make_circles_valid_params() {
1642        let dataset = make_circles(100, 0.5, 0.1, Some(42)).unwrap();
1643        assert_eq!(dataset.n_samples(), 100);
1644        assert_eq!(dataset.n_features(), 2);
1645        assert!(dataset.target.is_some());
1646        assert!(dataset.feature_names.is_some());
1647
1648        // Check that we have exactly 2 classes (inner and outer circle)
1649        if let Some(target) = &dataset.target {
1650            let unique_labels: std::collections::HashSet<_> =
1651                target.iter().map(|&x| x as i32).collect();
1652            assert_eq!(unique_labels.len(), 2);
1653        }
1654    }
1655
1656    #[test]
1657    fn test_make_swiss_roll_invalid_params() {
1658        // Test zero n_samples
1659        assert!(make_swiss_roll(0, 0.1, None).is_err());
1660
1661        // Test negative noise
1662        assert!(make_swiss_roll(100, -0.1, None).is_err());
1663    }
1664
1665    #[test]
1666    fn test_make_swiss_roll_valid_params() {
1667        let dataset = make_swiss_roll(100, 0.1, Some(42)).unwrap();
1668        assert_eq!(dataset.n_samples(), 100);
1669        assert_eq!(dataset.n_features(), 3);
1670        assert!(dataset.target.is_some()); // Color parameter
1671        assert!(dataset.feature_names.is_some());
1672    }
1673
1674    #[test]
1675    fn test_make_anisotropic_blobs_invalid_params() {
1676        // Test zero n_samples
1677        assert!(make_anisotropic_blobs(0, 3, 2, 1.0, 2.0, None).is_err());
1678
1679        // Test insufficient features
1680        assert!(make_anisotropic_blobs(100, 1, 2, 1.0, 2.0, None).is_err());
1681
1682        // Test zero centers
1683        assert!(make_anisotropic_blobs(100, 3, 0, 1.0, 2.0, None).is_err());
1684
1685        // Test invalid std
1686        assert!(make_anisotropic_blobs(100, 3, 2, 0.0, 2.0, None).is_err());
1687
1688        // Test invalid anisotropy factor
1689        assert!(make_anisotropic_blobs(100, 3, 2, 1.0, 0.0, None).is_err());
1690    }
1691
1692    #[test]
1693    fn test_make_anisotropic_blobs_valid_params() {
1694        let dataset = make_anisotropic_blobs(100, 3, 2, 1.0, 3.0, Some(42)).unwrap();
1695        assert_eq!(dataset.n_samples(), 100);
1696        assert_eq!(dataset.n_features(), 3);
1697        assert!(dataset.target.is_some());
1698        assert!(dataset.feature_names.is_some());
1699
1700        // Check that we have the right number of clusters
1701        if let Some(target) = &dataset.target {
1702            let unique_labels: std::collections::HashSet<_> =
1703                target.iter().map(|&x| x as i32).collect();
1704            assert_eq!(unique_labels.len(), 2);
1705        }
1706    }
1707
1708    #[test]
1709    fn test_make_hierarchical_clusters_invalid_params() {
1710        // Test zero n_samples
1711        assert!(make_hierarchical_clusters(0, 3, 2, 3, 1.0, 0.5, None).is_err());
1712
1713        // Test zero features
1714        assert!(make_hierarchical_clusters(100, 0, 2, 3, 1.0, 0.5, None).is_err());
1715
1716        // Test zero main clusters
1717        assert!(make_hierarchical_clusters(100, 3, 0, 3, 1.0, 0.5, None).is_err());
1718
1719        // Test zero sub clusters
1720        assert!(make_hierarchical_clusters(100, 3, 2, 0, 1.0, 0.5, None).is_err());
1721
1722        // Test invalid main cluster std
1723        assert!(make_hierarchical_clusters(100, 3, 2, 3, 0.0, 0.5, None).is_err());
1724
1725        // Test invalid sub cluster std
1726        assert!(make_hierarchical_clusters(100, 3, 2, 3, 1.0, 0.0, None).is_err());
1727    }
1728
1729    #[test]
1730    fn test_make_hierarchical_clusters_valid_params() {
1731        let dataset = make_hierarchical_clusters(120, 3, 2, 3, 2.0, 0.5, Some(42)).unwrap();
1732        assert_eq!(dataset.n_samples(), 120);
1733        assert_eq!(dataset.n_features(), 3);
1734        assert!(dataset.target.is_some());
1735        assert!(dataset.feature_names.is_some());
1736
1737        // Check that we have the right number of main clusters
1738        if let Some(target) = &dataset.target {
1739            let unique_labels: std::collections::HashSet<_> =
1740                target.iter().map(|&x| x as i32).collect();
1741            assert_eq!(unique_labels.len(), 2); // 2 main clusters
1742        }
1743
1744        // Check metadata contains sub-cluster information
1745        assert!(dataset.metadata.contains_key("sub_cluster_labels"));
1746    }
1747
1748    #[test]
1749    fn test_inject_missing_data_invalid_params() {
1750        let mut data = Array2::from_shape_vec((5, 3), vec![1.0; 15]).unwrap();
1751
1752        // Test invalid missing rate
1753        assert!(inject_missing_data(&mut data, -0.1, MissingPattern::MCAR, None).is_err());
1754        assert!(inject_missing_data(&mut data, 1.5, MissingPattern::MCAR, None).is_err());
1755    }
1756
1757    #[test]
1758    fn test_inject_missing_data_mcar() {
1759        let mut data =
1760            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
1761        let original_data = data.clone();
1762
1763        let missing_mask =
1764            inject_missing_data(&mut data, 0.3, MissingPattern::MCAR, Some(42)).unwrap();
1765
1766        // Check that some data is missing
1767        let missing_count = missing_mask.iter().filter(|&&x| x).count();
1768        assert!(missing_count > 0);
1769
1770        // Check that missing values are NaN
1771        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
1772            if is_missing {
1773                assert!(data[[i, j]].is_nan());
1774            } else {
1775                assert_eq!(data[[i, j]], original_data[[i, j]]);
1776            }
1777        }
1778    }
1779
1780    #[test]
1781    fn test_inject_outliers_invalid_params() {
1782        let mut data = Array2::from_shape_vec((5, 3), vec![1.0; 15]).unwrap();
1783
1784        // Test invalid outlier rate
1785        assert!(inject_outliers(&mut data, -0.1, OutlierType::Point, 2.0, None).is_err());
1786        assert!(inject_outliers(&mut data, 1.5, OutlierType::Point, 2.0, None).is_err());
1787
1788        // Test invalid outlier strength
1789        assert!(inject_outliers(&mut data, 0.1, OutlierType::Point, 0.0, None).is_err());
1790        assert!(inject_outliers(&mut data, 0.1, OutlierType::Point, -1.0, None).is_err());
1791    }
1792
1793    #[test]
1794    fn test_inject_outliers_point() {
1795        let mut data = Array2::from_shape_vec((20, 2), vec![1.0; 40]).unwrap();
1796
1797        let outlier_mask =
1798            inject_outliers(&mut data, 0.2, OutlierType::Point, 3.0, Some(42)).unwrap();
1799
1800        // Check that some outliers were created
1801        let outlier_count = outlier_mask.iter().filter(|&&x| x).count();
1802        assert!(outlier_count > 0);
1803
1804        // Check that outliers are different from original values
1805        for (i, &is_outlier) in outlier_mask.iter().enumerate() {
1806            if is_outlier {
1807                // At least one feature should be different from 1.0
1808                let row = data.row(i);
1809                assert!(row.iter().any(|&x| (x - 1.0).abs() > 1.0));
1810            }
1811        }
1812    }
1813
1814    #[test]
1815    fn test_add_time_series_noise() {
1816        let mut data = Array2::zeros((100, 2));
1817
1818        let noise_types = [("gaussian", 0.1), ("spikes", 0.05), ("drift", 0.2)];
1819
1820        let original_data = data.clone();
1821        add_time_series_noise(&mut data, &noise_types, Some(42)).unwrap();
1822
1823        // Check that data has been modified
1824        assert!(!data
1825            .iter()
1826            .zip(original_data.iter())
1827            .all(|(&a, &b)| (a - b).abs() < 1e-10));
1828
1829        // Test invalid noise type
1830        let invalid_noise = [("invalid_type", 0.1)];
1831        let mut test_data = Array2::zeros((10, 2));
1832        assert!(add_time_series_noise(&mut test_data, &invalid_noise, Some(42)).is_err());
1833    }
1834
1835    #[test]
1836    fn test_make_corrupted_dataset() {
1837        let base_dataset = make_blobs(50, 3, 2, 1.0, Some(42)).unwrap();
1838
1839        let corrupted = make_corrupted_dataset(
1840            &base_dataset,
1841            0.1, // 10% missing
1842            MissingPattern::MCAR,
1843            0.05, // 5% outliers
1844            OutlierType::Point,
1845            2.0, // outlier strength
1846            Some(42),
1847        )
1848        .unwrap();
1849
1850        // Check basic properties
1851        assert_eq!(corrupted.n_samples(), base_dataset.n_samples());
1852        assert_eq!(corrupted.n_features(), base_dataset.n_features());
1853
1854        // Check metadata
1855        assert!(corrupted.metadata.contains_key("missing_rate"));
1856        assert!(corrupted.metadata.contains_key("outlier_rate"));
1857        assert!(corrupted.metadata.contains_key("missing_count"));
1858        assert!(corrupted.metadata.contains_key("outlier_count"));
1859
1860        // Check some data is corrupted
1861        let has_nan = corrupted.data.iter().any(|&x| x.is_nan());
1862        assert!(has_nan, "Dataset should have some missing values");
1863    }
1864
1865    #[test]
1866    fn test_make_corrupted_dataset_invalid_params() {
1867        let base_dataset = make_blobs(20, 2, 2, 1.0, Some(42)).unwrap();
1868
1869        // Test invalid missing rate
1870        assert!(make_corrupted_dataset(
1871            &base_dataset,
1872            -0.1,
1873            MissingPattern::MCAR,
1874            0.0,
1875            OutlierType::Point,
1876            1.0,
1877            None
1878        )
1879        .is_err());
1880        assert!(make_corrupted_dataset(
1881            &base_dataset,
1882            1.5,
1883            MissingPattern::MCAR,
1884            0.0,
1885            OutlierType::Point,
1886            1.0,
1887            None
1888        )
1889        .is_err());
1890
1891        // Test invalid outlier rate
1892        assert!(make_corrupted_dataset(
1893            &base_dataset,
1894            0.0,
1895            MissingPattern::MCAR,
1896            -0.1,
1897            OutlierType::Point,
1898            1.0,
1899            None
1900        )
1901        .is_err());
1902        assert!(make_corrupted_dataset(
1903            &base_dataset,
1904            0.0,
1905            MissingPattern::MCAR,
1906            1.5,
1907            OutlierType::Point,
1908            1.0,
1909            None
1910        )
1911        .is_err());
1912    }
1913
1914    #[test]
1915    fn test_missing_patterns() {
1916        let data = Array2::from_shape_vec((20, 4), (0..80).map(|x| x as f64).collect()).unwrap();
1917
1918        // Test different missing patterns
1919        for pattern in [
1920            MissingPattern::MCAR,
1921            MissingPattern::MAR,
1922            MissingPattern::MNAR,
1923            MissingPattern::Block,
1924        ] {
1925            let mut test_data = data.clone();
1926            let missing_mask = inject_missing_data(&mut test_data, 0.2, pattern, Some(42)).unwrap();
1927
1928            let missing_count = missing_mask.iter().filter(|&&x| x).count();
1929            assert!(
1930                missing_count > 0,
1931                "Pattern {:?} should create some missing values",
1932                pattern
1933            );
1934        }
1935    }
1936
1937    #[test]
1938    fn test_outlier_types() {
1939        let data = Array2::ones((30, 3));
1940
1941        // Test different outlier types
1942        for outlier_type in [
1943            OutlierType::Point,
1944            OutlierType::Contextual,
1945            OutlierType::Collective,
1946        ] {
1947            let mut test_data = data.clone();
1948            let outlier_mask =
1949                inject_outliers(&mut test_data, 0.2, outlier_type, 3.0, Some(42)).unwrap();
1950
1951            let outlier_count = outlier_mask.iter().filter(|&&x| x).count();
1952            assert!(
1953                outlier_count > 0,
1954                "Outlier type {:?} should create some outliers",
1955                outlier_type
1956            );
1957        }
1958    }
1959}