scirs2_datasets/generators/
basic.rs

1//! Basic dataset generators (classification, regression, blobs, etc.)
2
3use crate::error::{DatasetsError, Result};
4use crate::utils::Dataset;
5use scirs2_core::ndarray::{Array1, Array2};
6use scirs2_core::random::prelude::*;
7use scirs2_core::random::rand_distributions::{Distribution, Uniform};
8use std::f64::consts::PI;
9
10/// Generate a random classification dataset with clusters
11#[allow(dead_code)]
12#[allow(clippy::too_many_arguments)]
13pub fn make_classification(
14    n_samples: usize,
15    n_features: usize,
16    n_classes: usize,
17    n_clusters_per_class: usize,
18    n_informative: usize,
19    randomseed: Option<u64>,
20) -> Result<Dataset> {
21    // Validate input parameters
22    if n_samples == 0 {
23        return Err(DatasetsError::InvalidFormat(
24            "n_samples must be > 0".to_string(),
25        ));
26    }
27
28    if n_features == 0 {
29        return Err(DatasetsError::InvalidFormat(
30            "n_features must be > 0".to_string(),
31        ));
32    }
33
34    if n_informative == 0 {
35        return Err(DatasetsError::InvalidFormat(
36            "n_informative must be > 0".to_string(),
37        ));
38    }
39
40    if n_features < n_informative {
41        return Err(DatasetsError::InvalidFormat(format!(
42            "n_features ({n_features}) must be >= n_informative ({n_informative})"
43        )));
44    }
45
46    if n_classes < 2 {
47        return Err(DatasetsError::InvalidFormat(
48            "n_classes must be >= 2".to_string(),
49        ));
50    }
51
52    if n_clusters_per_class == 0 {
53        return Err(DatasetsError::InvalidFormat(
54            "n_clusters_per_class must be > 0".to_string(),
55        ));
56    }
57
58    let mut rng = match randomseed {
59        Some(_seed) => StdRng::seed_from_u64(_seed),
60        None => {
61            let mut r = thread_rng();
62            StdRng::seed_from_u64(r.next_u64())
63        }
64    };
65
66    // Generate centroids for each _class and cluster
67    let n_centroids = n_classes * n_clusters_per_class;
68    let mut centroids = Array2::zeros((n_centroids, n_informative));
69    let scale = 2.0;
70
71    for i in 0..n_centroids {
72        for j in 0..n_informative {
73            centroids[[i, j]] = scale * rng.gen_range(-1.0f64..1.0f64);
74        }
75    }
76
77    // Generate _samples
78    let mut data = Array2::zeros((n_samples, n_features));
79    let mut target = Array1::zeros(n_samples);
80
81    let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
82
83    // Samples per _class
84    let samples_per_class = n_samples / n_classes;
85    let remainder = n_samples % n_classes;
86
87    let mut sample_idx = 0;
88
89    for _class in 0..n_classes {
90        let n_samples_class = if _class < remainder {
91            samples_per_class + 1
92        } else {
93            samples_per_class
94        };
95
96        // Assign clusters within this _class
97        let samples_per_cluster = n_samples_class / n_clusters_per_class;
98        let cluster_remainder = n_samples_class % n_clusters_per_class;
99
100        for cluster in 0..n_clusters_per_class {
101            let n_samples_cluster = if cluster < cluster_remainder {
102                samples_per_cluster + 1
103            } else {
104                samples_per_cluster
105            };
106
107            let centroid_idx = _class * n_clusters_per_class + cluster;
108
109            for _ in 0..n_samples_cluster {
110                // Randomly select a point near the cluster centroid
111                for j in 0..n_informative {
112                    data[[sample_idx, j]] =
113                        centroids[[centroid_idx, j]] + 0.3 * normal.sample(&mut rng);
114                }
115
116                // Add noise _features
117                for j in n_informative..n_features {
118                    data[[sample_idx, j]] = normal.sample(&mut rng);
119                }
120
121                target[sample_idx] = _class as f64;
122                sample_idx += 1;
123            }
124        }
125    }
126
127    // Create dataset
128    let mut dataset = Dataset::new(data, Some(target));
129
130    // Create feature names
131    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
132
133    // Create _class names
134    let classnames: Vec<String> = (0..n_classes).map(|i| format!("class_{i}")).collect();
135
136    dataset = dataset
137        .with_featurenames(featurenames)
138        .with_targetnames(classnames)
139        .with_description(format!(
140            "Synthetic classification dataset with {n_classes} _classes and {n_features} _features"
141        ));
142
143    Ok(dataset)
144}
145
146/// Generate a random regression dataset
147#[allow(dead_code)]
148pub fn make_regression(
149    n_samples: usize,
150    n_features: usize,
151    n_informative: usize,
152    noise: f64,
153    randomseed: Option<u64>,
154) -> Result<Dataset> {
155    // Validate input parameters
156    if n_samples == 0 {
157        return Err(DatasetsError::InvalidFormat(
158            "n_samples must be > 0".to_string(),
159        ));
160    }
161
162    if n_features == 0 {
163        return Err(DatasetsError::InvalidFormat(
164            "n_features must be > 0".to_string(),
165        ));
166    }
167
168    if n_informative == 0 {
169        return Err(DatasetsError::InvalidFormat(
170            "n_informative must be > 0".to_string(),
171        ));
172    }
173
174    if n_features < n_informative {
175        return Err(DatasetsError::InvalidFormat(format!(
176            "n_features ({n_features}) must be >= n_informative ({n_informative})"
177        )));
178    }
179
180    if noise < 0.0 {
181        return Err(DatasetsError::InvalidFormat(
182            "noise must be >= 0.0".to_string(),
183        ));
184    }
185
186    let mut rng = match randomseed {
187        Some(_seed) => StdRng::seed_from_u64(_seed),
188        None => {
189            let mut r = thread_rng();
190            StdRng::seed_from_u64(r.next_u64())
191        }
192    };
193
194    // Generate the coefficients for the _informative _features
195    let mut coef = Array1::zeros(n_features);
196    let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
197
198    for i in 0..n_informative {
199        coef[i] = 100.0 * normal.sample(&mut rng);
200    }
201
202    // Generate the _features
203    let mut data = Array2::zeros((n_samples, n_features));
204
205    for i in 0..n_samples {
206        for j in 0..n_features {
207            data[[i, j]] = normal.sample(&mut rng);
208        }
209    }
210
211    // Generate the target
212    let mut target = Array1::zeros(n_samples);
213
214    for i in 0..n_samples {
215        let mut y = 0.0;
216        for j in 0..n_features {
217            y += data[[i, j]] * coef[j];
218        }
219
220        // Add noise
221        if noise > 0.0 {
222            y += normal.sample(&mut rng) * noise;
223        }
224
225        target[i] = y;
226    }
227
228    // Create dataset
229    let mut dataset = Dataset::new(data, Some(target));
230
231    // Create feature names
232    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
233
234    dataset = dataset
235        .with_featurenames(featurenames)
236        .with_description(format!(
237            "Synthetic regression dataset with {n_features} _features ({n_informative} informative)"
238        ))
239        .with_metadata("noise", &noise.to_string())
240        .with_metadata("coefficients", &format!("{coef:?}"));
241
242    Ok(dataset)
243}
244
245/// Generate a random time series dataset
246#[allow(dead_code)]
247pub fn make_time_series(
248    n_samples: usize,
249    n_features: usize,
250    trend: bool,
251    seasonality: bool,
252    noise: f64,
253    randomseed: Option<u64>,
254) -> Result<Dataset> {
255    // Validate input parameters
256    if n_samples == 0 {
257        return Err(DatasetsError::InvalidFormat(
258            "n_samples must be > 0".to_string(),
259        ));
260    }
261
262    if n_features == 0 {
263        return Err(DatasetsError::InvalidFormat(
264            "n_features must be > 0".to_string(),
265        ));
266    }
267
268    if noise < 0.0 {
269        return Err(DatasetsError::InvalidFormat(
270            "noise must be >= 0.0".to_string(),
271        ));
272    }
273
274    let mut rng = match randomseed {
275        Some(_seed) => StdRng::seed_from_u64(_seed),
276        None => {
277            let mut r = thread_rng();
278            StdRng::seed_from_u64(r.next_u64())
279        }
280    };
281
282    let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
283    let mut data = Array2::zeros((n_samples, n_features));
284
285    for feature in 0..n_features {
286        let trend_coef = if trend {
287            rng.gen_range(0.01f64..0.1f64)
288        } else {
289            0.0
290        };
291        let seasonality_period = rng.sample(Uniform::new(10, 50).unwrap()) as f64;
292        let seasonality_amplitude = if seasonality {
293            rng.gen_range(1.0f64..5.0f64)
294        } else {
295            0.0
296        };
297
298        let base_value = rng.gen_range(-10.0f64..10.0f64);
299
300        for i in 0..n_samples {
301            let t = i as f64;
302
303            // Add base value
304            let mut value = base_value;
305
306            // Add trend
307            if trend {
308                value += trend_coef * t;
309            }
310
311            // Add seasonality
312            if seasonality {
313                value += seasonality_amplitude * (2.0 * PI * t / seasonality_period).sin();
314            }
315
316            // Add noise
317            if noise > 0.0 {
318                value += normal.sample(&mut rng) * noise;
319            }
320
321            data[[i, feature]] = value;
322        }
323    }
324
325    // Create time index (unused for now but can be useful for plotting)
326    let time_index: Vec<f64> = (0..n_samples).map(|i| i as f64).collect();
327    let _time_array = Array1::from(time_index);
328
329    // Create dataset
330    let mut dataset = Dataset::new(data, None);
331
332    // Create feature names
333    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
334
335    dataset = dataset
336        .with_featurenames(featurenames)
337        .with_description(format!(
338            "Synthetic time series dataset with {n_features} _features"
339        ))
340        .with_metadata("trend", &trend.to_string())
341        .with_metadata("seasonality", &seasonality.to_string())
342        .with_metadata("noise", &noise.to_string());
343
344    Ok(dataset)
345}
346
347/// Generate a random blobs dataset for clustering
348#[allow(dead_code)]
349pub fn make_blobs(
350    n_samples: usize,
351    n_features: usize,
352    centers: usize,
353    cluster_std: f64,
354    randomseed: Option<u64>,
355) -> Result<Dataset> {
356    // Validate input parameters
357    if n_samples == 0 {
358        return Err(DatasetsError::InvalidFormat(
359            "n_samples must be > 0".to_string(),
360        ));
361    }
362
363    if n_features == 0 {
364        return Err(DatasetsError::InvalidFormat(
365            "n_features must be > 0".to_string(),
366        ));
367    }
368
369    if centers == 0 {
370        return Err(DatasetsError::InvalidFormat(
371            "centers must be > 0".to_string(),
372        ));
373    }
374
375    if cluster_std <= 0.0 {
376        return Err(DatasetsError::InvalidFormat(
377            "cluster_std must be > 0.0".to_string(),
378        ));
379    }
380
381    let mut rng = match randomseed {
382        Some(_seed) => StdRng::seed_from_u64(_seed),
383        None => {
384            let mut r = thread_rng();
385            StdRng::seed_from_u64(r.next_u64())
386        }
387    };
388
389    // Generate random centers
390    let mut cluster_centers = Array2::zeros((centers, n_features));
391    let center_box = 10.0;
392
393    for i in 0..centers {
394        for j in 0..n_features {
395            cluster_centers[[i, j]] = rng.gen_range(-center_box..center_box);
396        }
397    }
398
399    // Generate _samples around centers
400    let mut data = Array2::zeros((n_samples, n_features));
401    let mut target = Array1::zeros(n_samples);
402
403    let normal = scirs2_core::random::Normal::new(0.0, cluster_std).unwrap();
404
405    // Samples per center
406    let samples_per_center = n_samples / centers;
407    let remainder = n_samples % centers;
408
409    let mut sample_idx = 0;
410
411    for center_idx in 0..centers {
412        let n_samples_center = if center_idx < remainder {
413            samples_per_center + 1
414        } else {
415            samples_per_center
416        };
417
418        for _ in 0..n_samples_center {
419            for j in 0..n_features {
420                data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + normal.sample(&mut rng);
421            }
422
423            target[sample_idx] = center_idx as f64;
424            sample_idx += 1;
425        }
426    }
427
428    // Create dataset
429    let mut dataset = Dataset::new(data, Some(target));
430
431    // Create feature names
432    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
433
434    dataset = dataset
435        .with_featurenames(featurenames)
436        .with_description(format!(
437            "Synthetic clustering dataset with {centers} clusters and {n_features} _features"
438        ))
439        .with_metadata("centers", &centers.to_string())
440        .with_metadata("cluster_std", &cluster_std.to_string());
441
442    Ok(dataset)
443}
444
445/// Generate a spiral dataset for non-linear classification
446#[allow(dead_code)]
447pub fn make_spirals(
448    n_samples: usize,
449    n_spirals: usize,
450    noise: f64,
451    randomseed: Option<u64>,
452) -> Result<Dataset> {
453    // Validate input parameters
454    if n_samples == 0 {
455        return Err(DatasetsError::InvalidFormat(
456            "n_samples must be > 0".to_string(),
457        ));
458    }
459
460    if n_spirals == 0 {
461        return Err(DatasetsError::InvalidFormat(
462            "n_spirals must be > 0".to_string(),
463        ));
464    }
465
466    if noise < 0.0 {
467        return Err(DatasetsError::InvalidFormat(
468            "noise must be >= 0.0".to_string(),
469        ));
470    }
471
472    let mut rng = match randomseed {
473        Some(_seed) => StdRng::seed_from_u64(_seed),
474        None => {
475            let mut r = thread_rng();
476            StdRng::seed_from_u64(r.next_u64())
477        }
478    };
479
480    let mut data = Array2::zeros((n_samples, 2));
481    let mut target = Array1::zeros(n_samples);
482
483    let normal = if noise > 0.0 {
484        Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
485    } else {
486        None
487    };
488
489    let samples_per_spiral = n_samples / n_spirals;
490    let remainder = n_samples % n_spirals;
491
492    let mut sample_idx = 0;
493
494    for spiral in 0..n_spirals {
495        let n_samples_spiral = if spiral < remainder {
496            samples_per_spiral + 1
497        } else {
498            samples_per_spiral
499        };
500
501        let spiral_offset = 2.0 * PI * spiral as f64 / n_spirals as f64;
502
503        for i in 0..n_samples_spiral {
504            let t = 2.0 * PI * i as f64 / n_samples_spiral as f64;
505            let radius = t / (2.0 * PI);
506
507            let mut x = radius * (t + spiral_offset).cos();
508            let mut y = radius * (t + spiral_offset).sin();
509
510            // Add noise if specified
511            if let Some(ref normal_dist) = normal {
512                x += normal_dist.sample(&mut rng);
513                y += normal_dist.sample(&mut rng);
514            }
515
516            data[[sample_idx, 0]] = x;
517            data[[sample_idx, 1]] = y;
518            target[sample_idx] = spiral as f64;
519            sample_idx += 1;
520        }
521    }
522
523    let mut dataset = Dataset::new(data, Some(target));
524    dataset = dataset
525        .with_featurenames(vec!["x".to_string(), "y".to_string()])
526        .with_targetnames((0..n_spirals).map(|i| format!("spiral_{i}")).collect())
527        .with_description(format!("Spiral dataset with {n_spirals} _spirals"))
528        .with_metadata("noise", &noise.to_string());
529
530    Ok(dataset)
531}
532
533/// Generate a moons dataset for non-linear classification
534#[allow(dead_code)]
535pub fn make_moons(n_samples: usize, noise: f64, randomseed: Option<u64>) -> Result<Dataset> {
536    // Validate input parameters
537    if n_samples == 0 {
538        return Err(DatasetsError::InvalidFormat(
539            "n_samples must be > 0".to_string(),
540        ));
541    }
542
543    if noise < 0.0 {
544        return Err(DatasetsError::InvalidFormat(
545            "noise must be >= 0.0".to_string(),
546        ));
547    }
548
549    let mut rng = match randomseed {
550        Some(_seed) => StdRng::seed_from_u64(_seed),
551        None => {
552            let mut r = thread_rng();
553            StdRng::seed_from_u64(r.next_u64())
554        }
555    };
556
557    let mut data = Array2::zeros((n_samples, 2));
558    let mut target = Array1::zeros(n_samples);
559
560    let normal = if noise > 0.0 {
561        Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
562    } else {
563        None
564    };
565
566    let samples_per_moon = n_samples / 2;
567    let remainder = n_samples % 2;
568
569    let mut sample_idx = 0;
570
571    // Generate first moon (upper crescent)
572    for i in 0..(samples_per_moon + remainder) {
573        let t = PI * i as f64 / (samples_per_moon + remainder) as f64;
574
575        let mut x = t.cos();
576        let mut y = t.sin();
577
578        // Add noise if specified
579        if let Some(ref normal_dist) = normal {
580            x += normal_dist.sample(&mut rng);
581            y += normal_dist.sample(&mut rng);
582        }
583
584        data[[sample_idx, 0]] = x;
585        data[[sample_idx, 1]] = y;
586        target[sample_idx] = 0.0;
587        sample_idx += 1;
588    }
589
590    // Generate second moon (lower crescent, flipped)
591    for i in 0..samples_per_moon {
592        let t = PI * i as f64 / samples_per_moon as f64;
593
594        let mut x = 1.0 - t.cos();
595        let mut y = 0.5 - t.sin(); // Offset vertically and flip
596
597        // Add noise if specified
598        if let Some(ref normal_dist) = normal {
599            x += normal_dist.sample(&mut rng);
600            y += normal_dist.sample(&mut rng);
601        }
602
603        data[[sample_idx, 0]] = x;
604        data[[sample_idx, 1]] = y;
605        target[sample_idx] = 1.0;
606        sample_idx += 1;
607    }
608
609    let mut dataset = Dataset::new(data, Some(target));
610    dataset = dataset
611        .with_featurenames(vec!["x".to_string(), "y".to_string()])
612        .with_targetnames(vec!["moon_0".to_string(), "moon_1".to_string()])
613        .with_description("Two moons dataset for non-linear classification".to_string())
614        .with_metadata("noise", &noise.to_string());
615
616    Ok(dataset)
617}
618
619/// Generate a circles dataset for non-linear classification
620#[allow(dead_code)]
621pub fn make_circles(
622    n_samples: usize,
623    factor: f64,
624    noise: f64,
625    randomseed: Option<u64>,
626) -> Result<Dataset> {
627    // Validate input parameters
628    if n_samples == 0 {
629        return Err(DatasetsError::InvalidFormat(
630            "n_samples must be > 0".to_string(),
631        ));
632    }
633
634    if factor <= 0.0 || factor >= 1.0 {
635        return Err(DatasetsError::InvalidFormat(
636            "factor must be between 0.0 and 1.0".to_string(),
637        ));
638    }
639
640    if noise < 0.0 {
641        return Err(DatasetsError::InvalidFormat(
642            "noise must be >= 0.0".to_string(),
643        ));
644    }
645
646    let mut rng = match randomseed {
647        Some(_seed) => StdRng::seed_from_u64(_seed),
648        None => {
649            let mut r = thread_rng();
650            StdRng::seed_from_u64(r.next_u64())
651        }
652    };
653
654    let mut data = Array2::zeros((n_samples, 2));
655    let mut target = Array1::zeros(n_samples);
656
657    let normal = if noise > 0.0 {
658        Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
659    } else {
660        None
661    };
662
663    let samples_per_circle = n_samples / 2;
664    let remainder = n_samples % 2;
665
666    let mut sample_idx = 0;
667
668    // Generate outer circle
669    for i in 0..(samples_per_circle + remainder) {
670        let angle = 2.0 * PI * i as f64 / (samples_per_circle + remainder) as f64;
671
672        let mut x = angle.cos();
673        let mut y = angle.sin();
674
675        // Add noise if specified
676        if let Some(ref normal_dist) = normal {
677            x += normal_dist.sample(&mut rng);
678            y += normal_dist.sample(&mut rng);
679        }
680
681        data[[sample_idx, 0]] = x;
682        data[[sample_idx, 1]] = y;
683        target[sample_idx] = 0.0;
684        sample_idx += 1;
685    }
686
687    // Generate inner circle (scaled by factor)
688    for i in 0..samples_per_circle {
689        let angle = 2.0 * PI * i as f64 / samples_per_circle as f64;
690
691        let mut x = factor * angle.cos();
692        let mut y = factor * angle.sin();
693
694        // Add noise if specified
695        if let Some(ref normal_dist) = normal {
696            x += normal_dist.sample(&mut rng);
697            y += normal_dist.sample(&mut rng);
698        }
699
700        data[[sample_idx, 0]] = x;
701        data[[sample_idx, 1]] = y;
702        target[sample_idx] = 1.0;
703        sample_idx += 1;
704    }
705
706    let mut dataset = Dataset::new(data, Some(target));
707    dataset = dataset
708        .with_featurenames(vec!["x".to_string(), "y".to_string()])
709        .with_targetnames(vec!["outer_circle".to_string(), "inner_circle".to_string()])
710        .with_description("Concentric circles dataset for non-linear classification".to_string())
711        .with_metadata("factor", &factor.to_string())
712        .with_metadata("noise", &noise.to_string());
713
714    Ok(dataset)
715}
716
717/// Generate anisotropic (elongated) clusters
718#[allow(clippy::too_many_arguments)]
719#[allow(dead_code)]
720pub fn make_anisotropic_blobs(
721    n_samples: usize,
722    n_features: usize,
723    centers: usize,
724    cluster_std: f64,
725    anisotropy_factor: f64,
726    randomseed: Option<u64>,
727) -> Result<Dataset> {
728    // Validate input parameters
729    if n_samples == 0 {
730        return Err(DatasetsError::InvalidFormat(
731            "n_samples must be > 0".to_string(),
732        ));
733    }
734
735    if n_features < 2 {
736        return Err(DatasetsError::InvalidFormat(
737            "n_features must be >= 2 for anisotropic clusters".to_string(),
738        ));
739    }
740
741    if centers == 0 {
742        return Err(DatasetsError::InvalidFormat(
743            "centers must be > 0".to_string(),
744        ));
745    }
746
747    if cluster_std <= 0.0 {
748        return Err(DatasetsError::InvalidFormat(
749            "cluster_std must be > 0.0".to_string(),
750        ));
751    }
752
753    if anisotropy_factor <= 0.0 {
754        return Err(DatasetsError::InvalidFormat(
755            "anisotropy_factor must be > 0.0".to_string(),
756        ));
757    }
758
759    let mut rng = match randomseed {
760        Some(_seed) => StdRng::seed_from_u64(_seed),
761        None => {
762            let mut r = thread_rng();
763            StdRng::seed_from_u64(r.next_u64())
764        }
765    };
766
767    // Generate random centers
768    let mut cluster_centers = Array2::zeros((centers, n_features));
769    let center_box = 10.0;
770
771    for i in 0..centers {
772        for j in 0..n_features {
773            cluster_centers[[i, j]] = rng.gen_range(-center_box..center_box);
774        }
775    }
776
777    // Generate _samples around centers with anisotropic distribution
778    let mut data = Array2::zeros((n_samples, n_features));
779    let mut target = Array1::zeros(n_samples);
780
781    let normal = scirs2_core::random::Normal::new(0.0, cluster_std).unwrap();
782
783    let samples_per_center = n_samples / centers;
784    let remainder = n_samples % centers;
785
786    let mut sample_idx = 0;
787
788    for center_idx in 0..centers {
789        let n_samples_center = if center_idx < remainder {
790            samples_per_center + 1
791        } else {
792            samples_per_center
793        };
794
795        // Generate a random rotation angle for this cluster
796        let rotation_angle = rng.gen_range(0.0..(2.0 * PI));
797
798        for _ in 0..n_samples_center {
799            // Generate point with anisotropic distribution (elongated along first axis)
800            let mut point = vec![0.0; n_features];
801
802            // First axis has normal std..second axis has reduced _std (anisotropy)
803            point[0] = normal.sample(&mut rng);
804            point[1] = normal.sample(&mut rng) / anisotropy_factor;
805
806            // Remaining axes have normal _std
807            for item in point.iter_mut().take(n_features).skip(2) {
808                *item = normal.sample(&mut rng);
809            }
810
811            // Apply rotation for 2D case
812            if n_features >= 2 {
813                let cos_theta = rotation_angle.cos();
814                let sin_theta = rotation_angle.sin();
815
816                let x_rot = cos_theta * point[0] - sin_theta * point[1];
817                let y_rot = sin_theta * point[0] + cos_theta * point[1];
818
819                point[0] = x_rot;
820                point[1] = y_rot;
821            }
822
823            // Translate to cluster center
824            for j in 0..n_features {
825                data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + point[j];
826            }
827
828            target[sample_idx] = center_idx as f64;
829            sample_idx += 1;
830        }
831    }
832
833    let mut dataset = Dataset::new(data, Some(target));
834    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
835
836    dataset = dataset
837        .with_featurenames(featurenames)
838        .with_description(format!(
839            "Anisotropic clustering dataset with {centers} elongated clusters and {n_features} _features"
840        ))
841        .with_metadata("centers", &centers.to_string())
842        .with_metadata("cluster_std", &cluster_std.to_string())
843        .with_metadata("anisotropy_factor", &anisotropy_factor.to_string());
844
845    Ok(dataset)
846}
847
848/// Generate hierarchical clusters (clusters within clusters)
849#[allow(clippy::too_many_arguments)]
850#[allow(dead_code)]
851pub fn make_hierarchical_clusters(
852    n_samples: usize,
853    n_features: usize,
854    n_main_clusters: usize,
855    n_sub_clusters: usize,
856    main_cluster_std: f64,
857    sub_cluster_std: f64,
858    randomseed: Option<u64>,
859) -> Result<Dataset> {
860    // Validate input parameters
861    if n_samples == 0 {
862        return Err(DatasetsError::InvalidFormat(
863            "n_samples must be > 0".to_string(),
864        ));
865    }
866
867    if n_features == 0 {
868        return Err(DatasetsError::InvalidFormat(
869            "n_features must be > 0".to_string(),
870        ));
871    }
872
873    if n_main_clusters == 0 {
874        return Err(DatasetsError::InvalidFormat(
875            "n_main_clusters must be > 0".to_string(),
876        ));
877    }
878
879    if n_sub_clusters == 0 {
880        return Err(DatasetsError::InvalidFormat(
881            "n_sub_clusters must be > 0".to_string(),
882        ));
883    }
884
885    if main_cluster_std <= 0.0 {
886        return Err(DatasetsError::InvalidFormat(
887            "main_cluster_std must be > 0.0".to_string(),
888        ));
889    }
890
891    if sub_cluster_std <= 0.0 {
892        return Err(DatasetsError::InvalidFormat(
893            "sub_cluster_std must be > 0.0".to_string(),
894        ));
895    }
896
897    let mut rng = match randomseed {
898        Some(_seed) => StdRng::seed_from_u64(_seed),
899        None => {
900            let mut r = thread_rng();
901            StdRng::seed_from_u64(r.next_u64())
902        }
903    };
904
905    // Generate main cluster centers
906    let mut main_centers = Array2::zeros((n_main_clusters, n_features));
907    let center_box = 20.0;
908
909    for i in 0..n_main_clusters {
910        for j in 0..n_features {
911            main_centers[[i, j]] = rng.gen_range(-center_box..center_box);
912        }
913    }
914
915    let mut data = Array2::zeros((n_samples, n_features));
916    let mut main_target = Array1::zeros(n_samples);
917    let mut sub_target = Array1::zeros(n_samples);
918
919    let main_normal = scirs2_core::random::Normal::new(0.0, main_cluster_std).unwrap();
920    let sub_normal = scirs2_core::random::Normal::new(0.0, sub_cluster_std).unwrap();
921
922    let samples_per_main = n_samples / n_main_clusters;
923    let remainder = n_samples % n_main_clusters;
924
925    let mut sample_idx = 0;
926
927    for main_idx in 0..n_main_clusters {
928        let n_samples_main = if main_idx < remainder {
929            samples_per_main + 1
930        } else {
931            samples_per_main
932        };
933
934        // Generate sub-cluster centers within this main cluster
935        let mut sub_centers = Array2::zeros((n_sub_clusters, n_features));
936        for i in 0..n_sub_clusters {
937            for j in 0..n_features {
938                sub_centers[[i, j]] = main_centers[[main_idx, j]] + main_normal.sample(&mut rng);
939            }
940        }
941
942        let samples_per_sub = n_samples_main / n_sub_clusters;
943        let sub_remainder = n_samples_main % n_sub_clusters;
944
945        for sub_idx in 0..n_sub_clusters {
946            let n_samples_sub = if sub_idx < sub_remainder {
947                samples_per_sub + 1
948            } else {
949                samples_per_sub
950            };
951
952            for _ in 0..n_samples_sub {
953                for j in 0..n_features {
954                    data[[sample_idx, j]] = sub_centers[[sub_idx, j]] + sub_normal.sample(&mut rng);
955                }
956
957                main_target[sample_idx] = main_idx as f64;
958                sub_target[sample_idx] = (main_idx * n_sub_clusters + sub_idx) as f64;
959                sample_idx += 1;
960            }
961        }
962    }
963
964    let mut dataset = Dataset::new(data, Some(main_target));
965    let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
966
967    dataset = dataset
968        .with_featurenames(featurenames)
969        .with_description(format!(
970            "Hierarchical clustering dataset with {n_main_clusters} main clusters, {n_sub_clusters} sub-_clusters each"
971        ))
972        .with_metadata("n_main_clusters", &n_main_clusters.to_string())
973        .with_metadata("n_sub_clusters", &n_sub_clusters.to_string())
974        .with_metadata("main_cluster_std", &main_cluster_std.to_string())
975        .with_metadata("sub_cluster_std", &sub_cluster_std.to_string());
976
977    let sub_target_vec = sub_target.to_vec();
978    dataset = dataset.with_metadata("sub_cluster_labels", &format!("{sub_target_vec:?}"));
979
980    Ok(dataset)
981}