1use crate::error::{DatasetsError, Result};
4use crate::utils::Dataset;
5use scirs2_core::ndarray::{Array1, Array2};
6use scirs2_core::random::prelude::*;
7use scirs2_core::random::rand_distributions::{Distribution, Uniform};
8use std::f64::consts::PI;
9
10#[allow(dead_code)]
12#[allow(clippy::too_many_arguments)]
13pub fn make_classification(
14 n_samples: usize,
15 n_features: usize,
16 n_classes: usize,
17 n_clusters_per_class: usize,
18 n_informative: usize,
19 randomseed: Option<u64>,
20) -> Result<Dataset> {
21 if n_samples == 0 {
23 return Err(DatasetsError::InvalidFormat(
24 "n_samples must be > 0".to_string(),
25 ));
26 }
27
28 if n_features == 0 {
29 return Err(DatasetsError::InvalidFormat(
30 "n_features must be > 0".to_string(),
31 ));
32 }
33
34 if n_informative == 0 {
35 return Err(DatasetsError::InvalidFormat(
36 "n_informative must be > 0".to_string(),
37 ));
38 }
39
40 if n_features < n_informative {
41 return Err(DatasetsError::InvalidFormat(format!(
42 "n_features ({n_features}) must be >= n_informative ({n_informative})"
43 )));
44 }
45
46 if n_classes < 2 {
47 return Err(DatasetsError::InvalidFormat(
48 "n_classes must be >= 2".to_string(),
49 ));
50 }
51
52 if n_clusters_per_class == 0 {
53 return Err(DatasetsError::InvalidFormat(
54 "n_clusters_per_class must be > 0".to_string(),
55 ));
56 }
57
58 let mut rng = match randomseed {
59 Some(_seed) => StdRng::seed_from_u64(_seed),
60 None => {
61 let mut r = thread_rng();
62 StdRng::seed_from_u64(r.next_u64())
63 }
64 };
65
66 let n_centroids = n_classes * n_clusters_per_class;
68 let mut centroids = Array2::zeros((n_centroids, n_informative));
69 let scale = 2.0;
70
71 for i in 0..n_centroids {
72 for j in 0..n_informative {
73 centroids[[i, j]] = scale * rng.gen_range(-1.0f64..1.0f64);
74 }
75 }
76
77 let mut data = Array2::zeros((n_samples, n_features));
79 let mut target = Array1::zeros(n_samples);
80
81 let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
82
83 let samples_per_class = n_samples / n_classes;
85 let remainder = n_samples % n_classes;
86
87 let mut sample_idx = 0;
88
89 for _class in 0..n_classes {
90 let n_samples_class = if _class < remainder {
91 samples_per_class + 1
92 } else {
93 samples_per_class
94 };
95
96 let samples_per_cluster = n_samples_class / n_clusters_per_class;
98 let cluster_remainder = n_samples_class % n_clusters_per_class;
99
100 for cluster in 0..n_clusters_per_class {
101 let n_samples_cluster = if cluster < cluster_remainder {
102 samples_per_cluster + 1
103 } else {
104 samples_per_cluster
105 };
106
107 let centroid_idx = _class * n_clusters_per_class + cluster;
108
109 for _ in 0..n_samples_cluster {
110 for j in 0..n_informative {
112 data[[sample_idx, j]] =
113 centroids[[centroid_idx, j]] + 0.3 * normal.sample(&mut rng);
114 }
115
116 for j in n_informative..n_features {
118 data[[sample_idx, j]] = normal.sample(&mut rng);
119 }
120
121 target[sample_idx] = _class as f64;
122 sample_idx += 1;
123 }
124 }
125 }
126
127 let mut dataset = Dataset::new(data, Some(target));
129
130 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
132
133 let classnames: Vec<String> = (0..n_classes).map(|i| format!("class_{i}")).collect();
135
136 dataset = dataset
137 .with_featurenames(featurenames)
138 .with_targetnames(classnames)
139 .with_description(format!(
140 "Synthetic classification dataset with {n_classes} _classes and {n_features} _features"
141 ));
142
143 Ok(dataset)
144}
145
146#[allow(dead_code)]
148pub fn make_regression(
149 n_samples: usize,
150 n_features: usize,
151 n_informative: usize,
152 noise: f64,
153 randomseed: Option<u64>,
154) -> Result<Dataset> {
155 if n_samples == 0 {
157 return Err(DatasetsError::InvalidFormat(
158 "n_samples must be > 0".to_string(),
159 ));
160 }
161
162 if n_features == 0 {
163 return Err(DatasetsError::InvalidFormat(
164 "n_features must be > 0".to_string(),
165 ));
166 }
167
168 if n_informative == 0 {
169 return Err(DatasetsError::InvalidFormat(
170 "n_informative must be > 0".to_string(),
171 ));
172 }
173
174 if n_features < n_informative {
175 return Err(DatasetsError::InvalidFormat(format!(
176 "n_features ({n_features}) must be >= n_informative ({n_informative})"
177 )));
178 }
179
180 if noise < 0.0 {
181 return Err(DatasetsError::InvalidFormat(
182 "noise must be >= 0.0".to_string(),
183 ));
184 }
185
186 let mut rng = match randomseed {
187 Some(_seed) => StdRng::seed_from_u64(_seed),
188 None => {
189 let mut r = thread_rng();
190 StdRng::seed_from_u64(r.next_u64())
191 }
192 };
193
194 let mut coef = Array1::zeros(n_features);
196 let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
197
198 for i in 0..n_informative {
199 coef[i] = 100.0 * normal.sample(&mut rng);
200 }
201
202 let mut data = Array2::zeros((n_samples, n_features));
204
205 for i in 0..n_samples {
206 for j in 0..n_features {
207 data[[i, j]] = normal.sample(&mut rng);
208 }
209 }
210
211 let mut target = Array1::zeros(n_samples);
213
214 for i in 0..n_samples {
215 let mut y = 0.0;
216 for j in 0..n_features {
217 y += data[[i, j]] * coef[j];
218 }
219
220 if noise > 0.0 {
222 y += normal.sample(&mut rng) * noise;
223 }
224
225 target[i] = y;
226 }
227
228 let mut dataset = Dataset::new(data, Some(target));
230
231 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
233
234 dataset = dataset
235 .with_featurenames(featurenames)
236 .with_description(format!(
237 "Synthetic regression dataset with {n_features} _features ({n_informative} informative)"
238 ))
239 .with_metadata("noise", &noise.to_string())
240 .with_metadata("coefficients", &format!("{coef:?}"));
241
242 Ok(dataset)
243}
244
245#[allow(dead_code)]
247pub fn make_time_series(
248 n_samples: usize,
249 n_features: usize,
250 trend: bool,
251 seasonality: bool,
252 noise: f64,
253 randomseed: Option<u64>,
254) -> Result<Dataset> {
255 if n_samples == 0 {
257 return Err(DatasetsError::InvalidFormat(
258 "n_samples must be > 0".to_string(),
259 ));
260 }
261
262 if n_features == 0 {
263 return Err(DatasetsError::InvalidFormat(
264 "n_features must be > 0".to_string(),
265 ));
266 }
267
268 if noise < 0.0 {
269 return Err(DatasetsError::InvalidFormat(
270 "noise must be >= 0.0".to_string(),
271 ));
272 }
273
274 let mut rng = match randomseed {
275 Some(_seed) => StdRng::seed_from_u64(_seed),
276 None => {
277 let mut r = thread_rng();
278 StdRng::seed_from_u64(r.next_u64())
279 }
280 };
281
282 let normal = scirs2_core::random::Normal::new(0.0, 1.0).unwrap();
283 let mut data = Array2::zeros((n_samples, n_features));
284
285 for feature in 0..n_features {
286 let trend_coef = if trend {
287 rng.gen_range(0.01f64..0.1f64)
288 } else {
289 0.0
290 };
291 let seasonality_period = rng.sample(Uniform::new(10, 50).unwrap()) as f64;
292 let seasonality_amplitude = if seasonality {
293 rng.gen_range(1.0f64..5.0f64)
294 } else {
295 0.0
296 };
297
298 let base_value = rng.gen_range(-10.0f64..10.0f64);
299
300 for i in 0..n_samples {
301 let t = i as f64;
302
303 let mut value = base_value;
305
306 if trend {
308 value += trend_coef * t;
309 }
310
311 if seasonality {
313 value += seasonality_amplitude * (2.0 * PI * t / seasonality_period).sin();
314 }
315
316 if noise > 0.0 {
318 value += normal.sample(&mut rng) * noise;
319 }
320
321 data[[i, feature]] = value;
322 }
323 }
324
325 let time_index: Vec<f64> = (0..n_samples).map(|i| i as f64).collect();
327 let _time_array = Array1::from(time_index);
328
329 let mut dataset = Dataset::new(data, None);
331
332 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
334
335 dataset = dataset
336 .with_featurenames(featurenames)
337 .with_description(format!(
338 "Synthetic time series dataset with {n_features} _features"
339 ))
340 .with_metadata("trend", &trend.to_string())
341 .with_metadata("seasonality", &seasonality.to_string())
342 .with_metadata("noise", &noise.to_string());
343
344 Ok(dataset)
345}
346
347#[allow(dead_code)]
349pub fn make_blobs(
350 n_samples: usize,
351 n_features: usize,
352 centers: usize,
353 cluster_std: f64,
354 randomseed: Option<u64>,
355) -> Result<Dataset> {
356 if n_samples == 0 {
358 return Err(DatasetsError::InvalidFormat(
359 "n_samples must be > 0".to_string(),
360 ));
361 }
362
363 if n_features == 0 {
364 return Err(DatasetsError::InvalidFormat(
365 "n_features must be > 0".to_string(),
366 ));
367 }
368
369 if centers == 0 {
370 return Err(DatasetsError::InvalidFormat(
371 "centers must be > 0".to_string(),
372 ));
373 }
374
375 if cluster_std <= 0.0 {
376 return Err(DatasetsError::InvalidFormat(
377 "cluster_std must be > 0.0".to_string(),
378 ));
379 }
380
381 let mut rng = match randomseed {
382 Some(_seed) => StdRng::seed_from_u64(_seed),
383 None => {
384 let mut r = thread_rng();
385 StdRng::seed_from_u64(r.next_u64())
386 }
387 };
388
389 let mut cluster_centers = Array2::zeros((centers, n_features));
391 let center_box = 10.0;
392
393 for i in 0..centers {
394 for j in 0..n_features {
395 cluster_centers[[i, j]] = rng.gen_range(-center_box..center_box);
396 }
397 }
398
399 let mut data = Array2::zeros((n_samples, n_features));
401 let mut target = Array1::zeros(n_samples);
402
403 let normal = scirs2_core::random::Normal::new(0.0, cluster_std).unwrap();
404
405 let samples_per_center = n_samples / centers;
407 let remainder = n_samples % centers;
408
409 let mut sample_idx = 0;
410
411 for center_idx in 0..centers {
412 let n_samples_center = if center_idx < remainder {
413 samples_per_center + 1
414 } else {
415 samples_per_center
416 };
417
418 for _ in 0..n_samples_center {
419 for j in 0..n_features {
420 data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + normal.sample(&mut rng);
421 }
422
423 target[sample_idx] = center_idx as f64;
424 sample_idx += 1;
425 }
426 }
427
428 let mut dataset = Dataset::new(data, Some(target));
430
431 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
433
434 dataset = dataset
435 .with_featurenames(featurenames)
436 .with_description(format!(
437 "Synthetic clustering dataset with {centers} clusters and {n_features} _features"
438 ))
439 .with_metadata("centers", ¢ers.to_string())
440 .with_metadata("cluster_std", &cluster_std.to_string());
441
442 Ok(dataset)
443}
444
445#[allow(dead_code)]
447pub fn make_spirals(
448 n_samples: usize,
449 n_spirals: usize,
450 noise: f64,
451 randomseed: Option<u64>,
452) -> Result<Dataset> {
453 if n_samples == 0 {
455 return Err(DatasetsError::InvalidFormat(
456 "n_samples must be > 0".to_string(),
457 ));
458 }
459
460 if n_spirals == 0 {
461 return Err(DatasetsError::InvalidFormat(
462 "n_spirals must be > 0".to_string(),
463 ));
464 }
465
466 if noise < 0.0 {
467 return Err(DatasetsError::InvalidFormat(
468 "noise must be >= 0.0".to_string(),
469 ));
470 }
471
472 let mut rng = match randomseed {
473 Some(_seed) => StdRng::seed_from_u64(_seed),
474 None => {
475 let mut r = thread_rng();
476 StdRng::seed_from_u64(r.next_u64())
477 }
478 };
479
480 let mut data = Array2::zeros((n_samples, 2));
481 let mut target = Array1::zeros(n_samples);
482
483 let normal = if noise > 0.0 {
484 Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
485 } else {
486 None
487 };
488
489 let samples_per_spiral = n_samples / n_spirals;
490 let remainder = n_samples % n_spirals;
491
492 let mut sample_idx = 0;
493
494 for spiral in 0..n_spirals {
495 let n_samples_spiral = if spiral < remainder {
496 samples_per_spiral + 1
497 } else {
498 samples_per_spiral
499 };
500
501 let spiral_offset = 2.0 * PI * spiral as f64 / n_spirals as f64;
502
503 for i in 0..n_samples_spiral {
504 let t = 2.0 * PI * i as f64 / n_samples_spiral as f64;
505 let radius = t / (2.0 * PI);
506
507 let mut x = radius * (t + spiral_offset).cos();
508 let mut y = radius * (t + spiral_offset).sin();
509
510 if let Some(ref normal_dist) = normal {
512 x += normal_dist.sample(&mut rng);
513 y += normal_dist.sample(&mut rng);
514 }
515
516 data[[sample_idx, 0]] = x;
517 data[[sample_idx, 1]] = y;
518 target[sample_idx] = spiral as f64;
519 sample_idx += 1;
520 }
521 }
522
523 let mut dataset = Dataset::new(data, Some(target));
524 dataset = dataset
525 .with_featurenames(vec!["x".to_string(), "y".to_string()])
526 .with_targetnames((0..n_spirals).map(|i| format!("spiral_{i}")).collect())
527 .with_description(format!("Spiral dataset with {n_spirals} _spirals"))
528 .with_metadata("noise", &noise.to_string());
529
530 Ok(dataset)
531}
532
533#[allow(dead_code)]
535pub fn make_moons(n_samples: usize, noise: f64, randomseed: Option<u64>) -> Result<Dataset> {
536 if n_samples == 0 {
538 return Err(DatasetsError::InvalidFormat(
539 "n_samples must be > 0".to_string(),
540 ));
541 }
542
543 if noise < 0.0 {
544 return Err(DatasetsError::InvalidFormat(
545 "noise must be >= 0.0".to_string(),
546 ));
547 }
548
549 let mut rng = match randomseed {
550 Some(_seed) => StdRng::seed_from_u64(_seed),
551 None => {
552 let mut r = thread_rng();
553 StdRng::seed_from_u64(r.next_u64())
554 }
555 };
556
557 let mut data = Array2::zeros((n_samples, 2));
558 let mut target = Array1::zeros(n_samples);
559
560 let normal = if noise > 0.0 {
561 Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
562 } else {
563 None
564 };
565
566 let samples_per_moon = n_samples / 2;
567 let remainder = n_samples % 2;
568
569 let mut sample_idx = 0;
570
571 for i in 0..(samples_per_moon + remainder) {
573 let t = PI * i as f64 / (samples_per_moon + remainder) as f64;
574
575 let mut x = t.cos();
576 let mut y = t.sin();
577
578 if let Some(ref normal_dist) = normal {
580 x += normal_dist.sample(&mut rng);
581 y += normal_dist.sample(&mut rng);
582 }
583
584 data[[sample_idx, 0]] = x;
585 data[[sample_idx, 1]] = y;
586 target[sample_idx] = 0.0;
587 sample_idx += 1;
588 }
589
590 for i in 0..samples_per_moon {
592 let t = PI * i as f64 / samples_per_moon as f64;
593
594 let mut x = 1.0 - t.cos();
595 let mut y = 0.5 - t.sin(); if let Some(ref normal_dist) = normal {
599 x += normal_dist.sample(&mut rng);
600 y += normal_dist.sample(&mut rng);
601 }
602
603 data[[sample_idx, 0]] = x;
604 data[[sample_idx, 1]] = y;
605 target[sample_idx] = 1.0;
606 sample_idx += 1;
607 }
608
609 let mut dataset = Dataset::new(data, Some(target));
610 dataset = dataset
611 .with_featurenames(vec!["x".to_string(), "y".to_string()])
612 .with_targetnames(vec!["moon_0".to_string(), "moon_1".to_string()])
613 .with_description("Two moons dataset for non-linear classification".to_string())
614 .with_metadata("noise", &noise.to_string());
615
616 Ok(dataset)
617}
618
619#[allow(dead_code)]
621pub fn make_circles(
622 n_samples: usize,
623 factor: f64,
624 noise: f64,
625 randomseed: Option<u64>,
626) -> Result<Dataset> {
627 if n_samples == 0 {
629 return Err(DatasetsError::InvalidFormat(
630 "n_samples must be > 0".to_string(),
631 ));
632 }
633
634 if factor <= 0.0 || factor >= 1.0 {
635 return Err(DatasetsError::InvalidFormat(
636 "factor must be between 0.0 and 1.0".to_string(),
637 ));
638 }
639
640 if noise < 0.0 {
641 return Err(DatasetsError::InvalidFormat(
642 "noise must be >= 0.0".to_string(),
643 ));
644 }
645
646 let mut rng = match randomseed {
647 Some(_seed) => StdRng::seed_from_u64(_seed),
648 None => {
649 let mut r = thread_rng();
650 StdRng::seed_from_u64(r.next_u64())
651 }
652 };
653
654 let mut data = Array2::zeros((n_samples, 2));
655 let mut target = Array1::zeros(n_samples);
656
657 let normal = if noise > 0.0 {
658 Some(scirs2_core::random::Normal::new(0.0, noise).unwrap())
659 } else {
660 None
661 };
662
663 let samples_per_circle = n_samples / 2;
664 let remainder = n_samples % 2;
665
666 let mut sample_idx = 0;
667
668 for i in 0..(samples_per_circle + remainder) {
670 let angle = 2.0 * PI * i as f64 / (samples_per_circle + remainder) as f64;
671
672 let mut x = angle.cos();
673 let mut y = angle.sin();
674
675 if let Some(ref normal_dist) = normal {
677 x += normal_dist.sample(&mut rng);
678 y += normal_dist.sample(&mut rng);
679 }
680
681 data[[sample_idx, 0]] = x;
682 data[[sample_idx, 1]] = y;
683 target[sample_idx] = 0.0;
684 sample_idx += 1;
685 }
686
687 for i in 0..samples_per_circle {
689 let angle = 2.0 * PI * i as f64 / samples_per_circle as f64;
690
691 let mut x = factor * angle.cos();
692 let mut y = factor * angle.sin();
693
694 if let Some(ref normal_dist) = normal {
696 x += normal_dist.sample(&mut rng);
697 y += normal_dist.sample(&mut rng);
698 }
699
700 data[[sample_idx, 0]] = x;
701 data[[sample_idx, 1]] = y;
702 target[sample_idx] = 1.0;
703 sample_idx += 1;
704 }
705
706 let mut dataset = Dataset::new(data, Some(target));
707 dataset = dataset
708 .with_featurenames(vec!["x".to_string(), "y".to_string()])
709 .with_targetnames(vec!["outer_circle".to_string(), "inner_circle".to_string()])
710 .with_description("Concentric circles dataset for non-linear classification".to_string())
711 .with_metadata("factor", &factor.to_string())
712 .with_metadata("noise", &noise.to_string());
713
714 Ok(dataset)
715}
716
717#[allow(clippy::too_many_arguments)]
719#[allow(dead_code)]
720pub fn make_anisotropic_blobs(
721 n_samples: usize,
722 n_features: usize,
723 centers: usize,
724 cluster_std: f64,
725 anisotropy_factor: f64,
726 randomseed: Option<u64>,
727) -> Result<Dataset> {
728 if n_samples == 0 {
730 return Err(DatasetsError::InvalidFormat(
731 "n_samples must be > 0".to_string(),
732 ));
733 }
734
735 if n_features < 2 {
736 return Err(DatasetsError::InvalidFormat(
737 "n_features must be >= 2 for anisotropic clusters".to_string(),
738 ));
739 }
740
741 if centers == 0 {
742 return Err(DatasetsError::InvalidFormat(
743 "centers must be > 0".to_string(),
744 ));
745 }
746
747 if cluster_std <= 0.0 {
748 return Err(DatasetsError::InvalidFormat(
749 "cluster_std must be > 0.0".to_string(),
750 ));
751 }
752
753 if anisotropy_factor <= 0.0 {
754 return Err(DatasetsError::InvalidFormat(
755 "anisotropy_factor must be > 0.0".to_string(),
756 ));
757 }
758
759 let mut rng = match randomseed {
760 Some(_seed) => StdRng::seed_from_u64(_seed),
761 None => {
762 let mut r = thread_rng();
763 StdRng::seed_from_u64(r.next_u64())
764 }
765 };
766
767 let mut cluster_centers = Array2::zeros((centers, n_features));
769 let center_box = 10.0;
770
771 for i in 0..centers {
772 for j in 0..n_features {
773 cluster_centers[[i, j]] = rng.gen_range(-center_box..center_box);
774 }
775 }
776
777 let mut data = Array2::zeros((n_samples, n_features));
779 let mut target = Array1::zeros(n_samples);
780
781 let normal = scirs2_core::random::Normal::new(0.0, cluster_std).unwrap();
782
783 let samples_per_center = n_samples / centers;
784 let remainder = n_samples % centers;
785
786 let mut sample_idx = 0;
787
788 for center_idx in 0..centers {
789 let n_samples_center = if center_idx < remainder {
790 samples_per_center + 1
791 } else {
792 samples_per_center
793 };
794
795 let rotation_angle = rng.gen_range(0.0..(2.0 * PI));
797
798 for _ in 0..n_samples_center {
799 let mut point = vec![0.0; n_features];
801
802 point[0] = normal.sample(&mut rng);
804 point[1] = normal.sample(&mut rng) / anisotropy_factor;
805
806 for item in point.iter_mut().take(n_features).skip(2) {
808 *item = normal.sample(&mut rng);
809 }
810
811 if n_features >= 2 {
813 let cos_theta = rotation_angle.cos();
814 let sin_theta = rotation_angle.sin();
815
816 let x_rot = cos_theta * point[0] - sin_theta * point[1];
817 let y_rot = sin_theta * point[0] + cos_theta * point[1];
818
819 point[0] = x_rot;
820 point[1] = y_rot;
821 }
822
823 for j in 0..n_features {
825 data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + point[j];
826 }
827
828 target[sample_idx] = center_idx as f64;
829 sample_idx += 1;
830 }
831 }
832
833 let mut dataset = Dataset::new(data, Some(target));
834 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
835
836 dataset = dataset
837 .with_featurenames(featurenames)
838 .with_description(format!(
839 "Anisotropic clustering dataset with {centers} elongated clusters and {n_features} _features"
840 ))
841 .with_metadata("centers", ¢ers.to_string())
842 .with_metadata("cluster_std", &cluster_std.to_string())
843 .with_metadata("anisotropy_factor", &anisotropy_factor.to_string());
844
845 Ok(dataset)
846}
847
848#[allow(clippy::too_many_arguments)]
850#[allow(dead_code)]
851pub fn make_hierarchical_clusters(
852 n_samples: usize,
853 n_features: usize,
854 n_main_clusters: usize,
855 n_sub_clusters: usize,
856 main_cluster_std: f64,
857 sub_cluster_std: f64,
858 randomseed: Option<u64>,
859) -> Result<Dataset> {
860 if n_samples == 0 {
862 return Err(DatasetsError::InvalidFormat(
863 "n_samples must be > 0".to_string(),
864 ));
865 }
866
867 if n_features == 0 {
868 return Err(DatasetsError::InvalidFormat(
869 "n_features must be > 0".to_string(),
870 ));
871 }
872
873 if n_main_clusters == 0 {
874 return Err(DatasetsError::InvalidFormat(
875 "n_main_clusters must be > 0".to_string(),
876 ));
877 }
878
879 if n_sub_clusters == 0 {
880 return Err(DatasetsError::InvalidFormat(
881 "n_sub_clusters must be > 0".to_string(),
882 ));
883 }
884
885 if main_cluster_std <= 0.0 {
886 return Err(DatasetsError::InvalidFormat(
887 "main_cluster_std must be > 0.0".to_string(),
888 ));
889 }
890
891 if sub_cluster_std <= 0.0 {
892 return Err(DatasetsError::InvalidFormat(
893 "sub_cluster_std must be > 0.0".to_string(),
894 ));
895 }
896
897 let mut rng = match randomseed {
898 Some(_seed) => StdRng::seed_from_u64(_seed),
899 None => {
900 let mut r = thread_rng();
901 StdRng::seed_from_u64(r.next_u64())
902 }
903 };
904
905 let mut main_centers = Array2::zeros((n_main_clusters, n_features));
907 let center_box = 20.0;
908
909 for i in 0..n_main_clusters {
910 for j in 0..n_features {
911 main_centers[[i, j]] = rng.gen_range(-center_box..center_box);
912 }
913 }
914
915 let mut data = Array2::zeros((n_samples, n_features));
916 let mut main_target = Array1::zeros(n_samples);
917 let mut sub_target = Array1::zeros(n_samples);
918
919 let main_normal = scirs2_core::random::Normal::new(0.0, main_cluster_std).unwrap();
920 let sub_normal = scirs2_core::random::Normal::new(0.0, sub_cluster_std).unwrap();
921
922 let samples_per_main = n_samples / n_main_clusters;
923 let remainder = n_samples % n_main_clusters;
924
925 let mut sample_idx = 0;
926
927 for main_idx in 0..n_main_clusters {
928 let n_samples_main = if main_idx < remainder {
929 samples_per_main + 1
930 } else {
931 samples_per_main
932 };
933
934 let mut sub_centers = Array2::zeros((n_sub_clusters, n_features));
936 for i in 0..n_sub_clusters {
937 for j in 0..n_features {
938 sub_centers[[i, j]] = main_centers[[main_idx, j]] + main_normal.sample(&mut rng);
939 }
940 }
941
942 let samples_per_sub = n_samples_main / n_sub_clusters;
943 let sub_remainder = n_samples_main % n_sub_clusters;
944
945 for sub_idx in 0..n_sub_clusters {
946 let n_samples_sub = if sub_idx < sub_remainder {
947 samples_per_sub + 1
948 } else {
949 samples_per_sub
950 };
951
952 for _ in 0..n_samples_sub {
953 for j in 0..n_features {
954 data[[sample_idx, j]] = sub_centers[[sub_idx, j]] + sub_normal.sample(&mut rng);
955 }
956
957 main_target[sample_idx] = main_idx as f64;
958 sub_target[sample_idx] = (main_idx * n_sub_clusters + sub_idx) as f64;
959 sample_idx += 1;
960 }
961 }
962 }
963
964 let mut dataset = Dataset::new(data, Some(main_target));
965 let featurenames: Vec<String> = (0..n_features).map(|i| format!("feature_{i}")).collect();
966
967 dataset = dataset
968 .with_featurenames(featurenames)
969 .with_description(format!(
970 "Hierarchical clustering dataset with {n_main_clusters} main clusters, {n_sub_clusters} sub-_clusters each"
971 ))
972 .with_metadata("n_main_clusters", &n_main_clusters.to_string())
973 .with_metadata("n_sub_clusters", &n_sub_clusters.to_string())
974 .with_metadata("main_cluster_std", &main_cluster_std.to_string())
975 .with_metadata("sub_cluster_std", &sub_cluster_std.to_string());
976
977 let sub_target_vec = sub_target.to_vec();
978 dataset = dataset.with_metadata("sub_cluster_labels", &format!("{sub_target_vec:?}"));
979
980 Ok(dataset)
981}