1use crate::error::{DatasetsError, Result};
4use crate::utils::Dataset;
5use ndarray::{Array1, Array2};
6use rand::prelude::*;
7use rand::rng;
8use rand::rngs::StdRng;
9use rand_distr::Distribution;
10use std::f64::consts::PI;
11
/// Generates a synthetic classification dataset of Gaussian clusters.
///
/// Each class is represented by `n_clusters_per_class` centroids placed in the
/// first `n_informative` feature dimensions; samples are drawn around their
/// centroid with tight Gaussian jitter, while the remaining
/// `n_features - n_informative` features are pure standard-normal noise.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when any size parameter is zero,
/// `n_features < n_informative`, or `n_classes < 2`.
pub fn make_classification(
    n_samples: usize,
    n_features: usize,
    n_classes: usize,
    n_clusters_per_class: usize,
    n_informative: usize,
    random_seed: Option<u64>,
) -> Result<Dataset> {
    // --- parameter validation ---
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_samples must be > 0".to_string(),
        ));
    }

    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be > 0".to_string(),
        ));
    }

    if n_informative == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_informative must be > 0".to_string(),
        ));
    }

    if n_features < n_informative {
        return Err(DatasetsError::InvalidFormat(format!(
            "n_features ({}) must be >= n_informative ({})",
            n_features, n_informative
        )));
    }

    if n_classes < 2 {
        return Err(DatasetsError::InvalidFormat(
            "n_classes must be >= 2".to_string(),
        ));
    }

    if n_clusters_per_class == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_clusters_per_class must be > 0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    // One centroid per (class, cluster) pair, uniform in [-2, 2) per dimension.
    let n_centroids = n_classes * n_clusters_per_class;
    let mut centroids = Array2::zeros((n_centroids, n_informative));
    let scale = 2.0;

    for i in 0..n_centroids {
        for j in 0..n_informative {
            centroids[[i, j]] = scale * rng.random_range(-1.0f64..1.0f64);
        }
    }

    let mut data = Array2::zeros((n_samples, n_features));
    let mut target = Array1::zeros(n_samples);

    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();

    // Spread samples as evenly as possible: the first `remainder` classes
    // receive one extra sample each.
    let samples_per_class = n_samples / n_classes;
    let remainder = n_samples % n_classes;

    let mut sample_idx = 0;

    for class in 0..n_classes {
        let n_samples_class = if class < remainder {
            samples_per_class + 1
        } else {
            samples_per_class
        };

        // Same even split again, across this class's clusters.
        let samples_per_cluster = n_samples_class / n_clusters_per_class;
        let cluster_remainder = n_samples_class % n_clusters_per_class;

        for cluster in 0..n_clusters_per_class {
            let n_samples_cluster = if cluster < cluster_remainder {
                samples_per_cluster + 1
            } else {
                samples_per_cluster
            };

            let centroid_idx = class * n_clusters_per_class + cluster;

            for _ in 0..n_samples_cluster {
                // Informative features: centroid plus small Gaussian jitter
                // (std 0.3, well below the centroid scale of 2.0).
                for j in 0..n_informative {
                    data[[sample_idx, j]] =
                        centroids[[centroid_idx, j]] + 0.3 * normal.sample(&mut rng);
                }

                // Non-informative features: independent standard-normal noise.
                for j in n_informative..n_features {
                    data[[sample_idx, j]] = normal.sample(&mut rng);
                }

                target[sample_idx] = class as f64;
                sample_idx += 1;
            }
        }
    }

    let mut dataset = Dataset::new(data, Some(target));

    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();

    let class_names: Vec<String> = (0..n_classes).map(|i| format!("class_{}", i)).collect();

    dataset = dataset
        .with_feature_names(feature_names)
        .with_target_names(class_names)
        .with_description(format!(
            "Synthetic classification dataset with {} classes and {} features",
            n_classes, n_features
        ));

    Ok(dataset)
}
147
148pub fn make_regression(
150 n_samples: usize,
151 n_features: usize,
152 n_informative: usize,
153 noise: f64,
154 random_seed: Option<u64>,
155) -> Result<Dataset> {
156 if n_samples == 0 {
158 return Err(DatasetsError::InvalidFormat(
159 "n_samples must be > 0".to_string(),
160 ));
161 }
162
163 if n_features == 0 {
164 return Err(DatasetsError::InvalidFormat(
165 "n_features must be > 0".to_string(),
166 ));
167 }
168
169 if n_informative == 0 {
170 return Err(DatasetsError::InvalidFormat(
171 "n_informative must be > 0".to_string(),
172 ));
173 }
174
175 if n_features < n_informative {
176 return Err(DatasetsError::InvalidFormat(format!(
177 "n_features ({}) must be >= n_informative ({})",
178 n_features, n_informative
179 )));
180 }
181
182 if noise < 0.0 {
183 return Err(DatasetsError::InvalidFormat(
184 "noise must be >= 0.0".to_string(),
185 ));
186 }
187
188 let mut rng = match random_seed {
189 Some(seed) => StdRng::seed_from_u64(seed),
190 None => {
191 let mut r = rng();
192 StdRng::seed_from_u64(r.next_u64())
193 }
194 };
195
196 let mut coef = Array1::zeros(n_features);
198 let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
199
200 for i in 0..n_informative {
201 coef[i] = 100.0 * normal.sample(&mut rng);
202 }
203
204 let mut data = Array2::zeros((n_samples, n_features));
206
207 for i in 0..n_samples {
208 for j in 0..n_features {
209 data[[i, j]] = normal.sample(&mut rng);
210 }
211 }
212
213 let mut target = Array1::zeros(n_samples);
215
216 for i in 0..n_samples {
217 let mut y = 0.0;
218 for j in 0..n_features {
219 y += data[[i, j]] * coef[j];
220 }
221
222 if noise > 0.0 {
224 y += normal.sample(&mut rng) * noise;
225 }
226
227 target[i] = y;
228 }
229
230 let mut dataset = Dataset::new(data, Some(target));
232
233 let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
235
236 dataset = dataset
237 .with_feature_names(feature_names)
238 .with_description(format!(
239 "Synthetic regression dataset with {} features ({} informative)",
240 n_features, n_informative
241 ))
242 .with_metadata("noise", &noise.to_string())
243 .with_metadata("coefficients", &format!("{:?}", coef));
244
245 Ok(dataset)
246}
247
248pub fn make_time_series(
250 n_samples: usize,
251 n_features: usize,
252 trend: bool,
253 seasonality: bool,
254 noise: f64,
255 random_seed: Option<u64>,
256) -> Result<Dataset> {
257 if n_samples == 0 {
259 return Err(DatasetsError::InvalidFormat(
260 "n_samples must be > 0".to_string(),
261 ));
262 }
263
264 if n_features == 0 {
265 return Err(DatasetsError::InvalidFormat(
266 "n_features must be > 0".to_string(),
267 ));
268 }
269
270 if noise < 0.0 {
271 return Err(DatasetsError::InvalidFormat(
272 "noise must be >= 0.0".to_string(),
273 ));
274 }
275
276 let mut rng = match random_seed {
277 Some(seed) => StdRng::seed_from_u64(seed),
278 None => {
279 let mut r = rng();
280 StdRng::seed_from_u64(r.next_u64())
281 }
282 };
283
284 let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();
285 let mut data = Array2::zeros((n_samples, n_features));
286
287 for feature in 0..n_features {
288 let trend_coef = if trend {
289 rng.random_range(0.01f64..0.1f64)
290 } else {
291 0.0
292 };
293 let seasonality_period = rng.random_range(10..=50) as f64;
294 let seasonality_amplitude = if seasonality {
295 rng.random_range(1.0f64..5.0f64)
296 } else {
297 0.0
298 };
299
300 let base_value = rng.random_range(-10.0f64..10.0f64);
301
302 for i in 0..n_samples {
303 let t = i as f64;
304
305 let mut value = base_value;
307
308 if trend {
310 value += trend_coef * t;
311 }
312
313 if seasonality {
315 value += seasonality_amplitude * (2.0 * PI * t / seasonality_period).sin();
316 }
317
318 if noise > 0.0 {
320 value += normal.sample(&mut rng) * noise;
321 }
322
323 data[[i, feature]] = value;
324 }
325 }
326
327 let time_index: Vec<f64> = (0..n_samples).map(|i| i as f64).collect();
329 let _time_array = Array1::from(time_index);
330
331 let mut dataset = Dataset::new(data, None);
333
334 let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
336
337 dataset = dataset
338 .with_feature_names(feature_names)
339 .with_description(format!(
340 "Synthetic time series dataset with {} features",
341 n_features
342 ))
343 .with_metadata("trend", &trend.to_string())
344 .with_metadata("seasonality", &seasonality.to_string())
345 .with_metadata("noise", &noise.to_string());
346
347 Ok(dataset)
348}
349
350pub fn make_blobs(
352 n_samples: usize,
353 n_features: usize,
354 centers: usize,
355 cluster_std: f64,
356 random_seed: Option<u64>,
357) -> Result<Dataset> {
358 if n_samples == 0 {
360 return Err(DatasetsError::InvalidFormat(
361 "n_samples must be > 0".to_string(),
362 ));
363 }
364
365 if n_features == 0 {
366 return Err(DatasetsError::InvalidFormat(
367 "n_features must be > 0".to_string(),
368 ));
369 }
370
371 if centers == 0 {
372 return Err(DatasetsError::InvalidFormat(
373 "centers must be > 0".to_string(),
374 ));
375 }
376
377 if cluster_std <= 0.0 {
378 return Err(DatasetsError::InvalidFormat(
379 "cluster_std must be > 0.0".to_string(),
380 ));
381 }
382
383 let mut rng = match random_seed {
384 Some(seed) => StdRng::seed_from_u64(seed),
385 None => {
386 let mut r = rng();
387 StdRng::seed_from_u64(r.next_u64())
388 }
389 };
390
391 let mut cluster_centers = Array2::zeros((centers, n_features));
393 let center_box = 10.0;
394
395 for i in 0..centers {
396 for j in 0..n_features {
397 cluster_centers[[i, j]] = rng.random_range(-center_box..=center_box);
398 }
399 }
400
401 let mut data = Array2::zeros((n_samples, n_features));
403 let mut target = Array1::zeros(n_samples);
404
405 let normal = rand_distr::Normal::new(0.0, cluster_std).unwrap();
406
407 let samples_per_center = n_samples / centers;
409 let remainder = n_samples % centers;
410
411 let mut sample_idx = 0;
412
413 for center_idx in 0..centers {
414 let n_samples_center = if center_idx < remainder {
415 samples_per_center + 1
416 } else {
417 samples_per_center
418 };
419
420 for _ in 0..n_samples_center {
421 for j in 0..n_features {
422 data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + normal.sample(&mut rng);
423 }
424
425 target[sample_idx] = center_idx as f64;
426 sample_idx += 1;
427 }
428 }
429
430 let mut dataset = Dataset::new(data, Some(target));
432
433 let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();
435
436 dataset = dataset
437 .with_feature_names(feature_names)
438 .with_description(format!(
439 "Synthetic clustering dataset with {} clusters and {} features",
440 centers, n_features
441 ))
442 .with_metadata("centers", ¢ers.to_string())
443 .with_metadata("cluster_std", &cluster_std.to_string());
444
445 Ok(dataset)
446}
447
/// Generates a 2-D dataset of interleaved spirals for non-linear
/// classification.
///
/// Each spiral is an Archimedean-style arc (radius grows linearly with the
/// angle) rotated by an even angular offset; optional Gaussian noise of
/// standard deviation `noise` is added to both coordinates.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n_samples` or `n_spirals` is
/// zero, or `noise` is negative.
pub fn make_spirals(
    n_samples: usize,
    n_spirals: usize,
    noise: f64,
    random_seed: Option<u64>,
) -> Result<Dataset> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_samples must be > 0".to_string(),
        ));
    }

    if n_spirals == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_spirals must be > 0".to_string(),
        ));
    }

    if noise < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "noise must be >= 0.0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let mut data = Array2::zeros((n_samples, 2));
    let mut target = Array1::zeros(n_samples);

    // Build the noise distribution only when it will actually be used.
    let normal = if noise > 0.0 {
        Some(rand_distr::Normal::new(0.0, noise).unwrap())
    } else {
        None
    };

    // Spread samples as evenly as possible: the first `remainder` spirals
    // receive one extra sample each.
    let samples_per_spiral = n_samples / n_spirals;
    let remainder = n_samples % n_spirals;

    let mut sample_idx = 0;

    for spiral in 0..n_spirals {
        let n_samples_spiral = if spiral < remainder {
            samples_per_spiral + 1
        } else {
            samples_per_spiral
        };

        // Evenly rotate each spiral's starting angle around the circle.
        let spiral_offset = 2.0 * PI * spiral as f64 / n_spirals as f64;

        for i in 0..n_samples_spiral {
            // t sweeps [0, 2π); radius grows linearly from 0 toward 1.
            let t = 2.0 * PI * i as f64 / n_samples_spiral as f64;
            let radius = t / (2.0 * PI);

            let mut x = radius * (t + spiral_offset).cos();
            let mut y = radius * (t + spiral_offset).sin();

            if let Some(ref normal_dist) = normal {
                x += normal_dist.sample(&mut rng);
                y += normal_dist.sample(&mut rng);
            }

            data[[sample_idx, 0]] = x;
            data[[sample_idx, 1]] = y;
            target[sample_idx] = spiral as f64;
            sample_idx += 1;
        }
    }

    let mut dataset = Dataset::new(data, Some(target));
    dataset = dataset
        .with_feature_names(vec!["x".to_string(), "y".to_string()])
        .with_target_names((0..n_spirals).map(|i| format!("spiral_{}", i)).collect())
        .with_description(format!("Spiral dataset with {} spirals", n_spirals))
        .with_metadata("noise", &noise.to_string());

    Ok(dataset)
}
534
535pub fn make_moons(n_samples: usize, noise: f64, random_seed: Option<u64>) -> Result<Dataset> {
537 if n_samples == 0 {
539 return Err(DatasetsError::InvalidFormat(
540 "n_samples must be > 0".to_string(),
541 ));
542 }
543
544 if noise < 0.0 {
545 return Err(DatasetsError::InvalidFormat(
546 "noise must be >= 0.0".to_string(),
547 ));
548 }
549
550 let mut rng = match random_seed {
551 Some(seed) => StdRng::seed_from_u64(seed),
552 None => {
553 let mut r = rng();
554 StdRng::seed_from_u64(r.next_u64())
555 }
556 };
557
558 let mut data = Array2::zeros((n_samples, 2));
559 let mut target = Array1::zeros(n_samples);
560
561 let normal = if noise > 0.0 {
562 Some(rand_distr::Normal::new(0.0, noise).unwrap())
563 } else {
564 None
565 };
566
567 let samples_per_moon = n_samples / 2;
568 let remainder = n_samples % 2;
569
570 let mut sample_idx = 0;
571
572 for i in 0..(samples_per_moon + remainder) {
574 let t = PI * i as f64 / (samples_per_moon + remainder) as f64;
575
576 let mut x = t.cos();
577 let mut y = t.sin();
578
579 if let Some(ref normal_dist) = normal {
581 x += normal_dist.sample(&mut rng);
582 y += normal_dist.sample(&mut rng);
583 }
584
585 data[[sample_idx, 0]] = x;
586 data[[sample_idx, 1]] = y;
587 target[sample_idx] = 0.0;
588 sample_idx += 1;
589 }
590
591 for i in 0..samples_per_moon {
593 let t = PI * i as f64 / samples_per_moon as f64;
594
595 let mut x = 1.0 - t.cos();
596 let mut y = 0.5 - t.sin(); if let Some(ref normal_dist) = normal {
600 x += normal_dist.sample(&mut rng);
601 y += normal_dist.sample(&mut rng);
602 }
603
604 data[[sample_idx, 0]] = x;
605 data[[sample_idx, 1]] = y;
606 target[sample_idx] = 1.0;
607 sample_idx += 1;
608 }
609
610 let mut dataset = Dataset::new(data, Some(target));
611 dataset = dataset
612 .with_feature_names(vec!["x".to_string(), "y".to_string()])
613 .with_target_names(vec!["moon_0".to_string(), "moon_1".to_string()])
614 .with_description("Two moons dataset for non-linear classification".to_string())
615 .with_metadata("noise", &noise.to_string());
616
617 Ok(dataset)
618}
619
/// Generates a 2-D dataset of two concentric circles for non-linear
/// classification.
///
/// The outer circle has radius 1 and the inner circle radius `factor`
/// (0 < factor < 1). Optional Gaussian noise of standard deviation `noise`
/// perturbs both coordinates.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n_samples` is zero, `factor`
/// is outside (0, 1), or `noise` is negative.
pub fn make_circles(
    n_samples: usize,
    factor: f64,
    noise: f64,
    random_seed: Option<u64>,
) -> Result<Dataset> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_samples must be > 0".to_string(),
        ));
    }

    if factor <= 0.0 || factor >= 1.0 {
        return Err(DatasetsError::InvalidFormat(
            "factor must be between 0.0 and 1.0".to_string(),
        ));
    }

    if noise < 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "noise must be >= 0.0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let mut data = Array2::zeros((n_samples, 2));
    let mut target = Array1::zeros(n_samples);

    // Build the noise distribution only when it will actually be used.
    let normal = if noise > 0.0 {
        Some(rand_distr::Normal::new(0.0, noise).unwrap())
    } else {
        None
    };

    // The outer circle absorbs the odd sample when n_samples is odd.
    let samples_per_circle = n_samples / 2;
    let remainder = n_samples % 2;

    let mut sample_idx = 0;

    // Outer circle (radius 1), labelled 0.
    for i in 0..(samples_per_circle + remainder) {
        let angle = 2.0 * PI * i as f64 / (samples_per_circle + remainder) as f64;

        let mut x = angle.cos();
        let mut y = angle.sin();

        if let Some(ref normal_dist) = normal {
            x += normal_dist.sample(&mut rng);
            y += normal_dist.sample(&mut rng);
        }

        data[[sample_idx, 0]] = x;
        data[[sample_idx, 1]] = y;
        target[sample_idx] = 0.0;
        sample_idx += 1;
    }

    // Inner circle (radius `factor`), labelled 1.
    for i in 0..samples_per_circle {
        let angle = 2.0 * PI * i as f64 / samples_per_circle as f64;

        let mut x = factor * angle.cos();
        let mut y = factor * angle.sin();

        if let Some(ref normal_dist) = normal {
            x += normal_dist.sample(&mut rng);
            y += normal_dist.sample(&mut rng);
        }

        data[[sample_idx, 0]] = x;
        data[[sample_idx, 1]] = y;
        target[sample_idx] = 1.0;
        sample_idx += 1;
    }

    let mut dataset = Dataset::new(data, Some(target));
    dataset = dataset
        .with_feature_names(vec!["x".to_string(), "y".to_string()])
        .with_target_names(vec!["outer_circle".to_string(), "inner_circle".to_string()])
        .with_description("Concentric circles dataset for non-linear classification".to_string())
        .with_metadata("factor", &factor.to_string())
        .with_metadata("noise", &noise.to_string());

    Ok(dataset)
}
716
717pub fn make_swiss_roll(n_samples: usize, noise: f64, random_seed: Option<u64>) -> Result<Dataset> {
719 if n_samples == 0 {
721 return Err(DatasetsError::InvalidFormat(
722 "n_samples must be > 0".to_string(),
723 ));
724 }
725
726 if noise < 0.0 {
727 return Err(DatasetsError::InvalidFormat(
728 "noise must be >= 0.0".to_string(),
729 ));
730 }
731
732 let mut rng = match random_seed {
733 Some(seed) => StdRng::seed_from_u64(seed),
734 None => {
735 let mut r = rng();
736 StdRng::seed_from_u64(r.next_u64())
737 }
738 };
739
740 let mut data = Array2::zeros((n_samples, 3));
741 let mut color = Array1::zeros(n_samples); let normal = if noise > 0.0 {
744 Some(rand_distr::Normal::new(0.0, noise).unwrap())
745 } else {
746 None
747 };
748
749 for i in 0..n_samples {
750 let t = 1.5 * PI * (1.0 + 2.0 * i as f64 / n_samples as f64);
752
753 let height = 21.0 * i as f64 / n_samples as f64;
755
756 let mut x = t * t.cos();
757 let mut y = height;
758 let mut z = t * t.sin();
759
760 if let Some(ref normal_dist) = normal {
762 x += normal_dist.sample(&mut rng);
763 y += normal_dist.sample(&mut rng);
764 z += normal_dist.sample(&mut rng);
765 }
766
767 data[[i, 0]] = x;
768 data[[i, 1]] = y;
769 data[[i, 2]] = z;
770 color[i] = t; }
772
773 let mut dataset = Dataset::new(data, Some(color));
774 dataset = dataset
775 .with_feature_names(vec!["x".to_string(), "y".to_string(), "z".to_string()])
776 .with_description("Swiss roll manifold dataset for dimensionality reduction".to_string())
777 .with_metadata("noise", &noise.to_string())
778 .with_metadata("dimensions", "3")
779 .with_metadata("manifold_dim", "2");
780
781 Ok(dataset)
782}
783
/// Generates Gaussian blobs that are elongated and randomly rotated in the
/// first two feature dimensions.
///
/// Each cluster is sampled isotropically, squashed along the second axis by
/// `anisotropy_factor`, rotated by a random per-cluster angle in the (0, 1)
/// feature plane, then translated to a uniformly drawn centre. Dimensions
/// beyond the first two remain isotropic.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `n_samples`/`centers` is zero,
/// `n_features < 2`, or `cluster_std`/`anisotropy_factor` is not positive.
pub fn make_anisotropic_blobs(
    n_samples: usize,
    n_features: usize,
    centers: usize,
    cluster_std: f64,
    anisotropy_factor: f64,
    random_seed: Option<u64>,
) -> Result<Dataset> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_samples must be > 0".to_string(),
        ));
    }

    if n_features < 2 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be >= 2 for anisotropic clusters".to_string(),
        ));
    }

    if centers == 0 {
        return Err(DatasetsError::InvalidFormat(
            "centers must be > 0".to_string(),
        ));
    }

    if cluster_std <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "cluster_std must be > 0.0".to_string(),
        ));
    }

    if anisotropy_factor <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "anisotropy_factor must be > 0.0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    // Cluster centres drawn uniformly from [-10, 10] per dimension.
    let mut cluster_centers = Array2::zeros((centers, n_features));
    let center_box = 10.0;

    for i in 0..centers {
        for j in 0..n_features {
            cluster_centers[[i, j]] = rng.random_range(-center_box..=center_box);
        }
    }

    let mut data = Array2::zeros((n_samples, n_features));
    let mut target = Array1::zeros(n_samples);

    let normal = rand_distr::Normal::new(0.0, cluster_std).unwrap();

    // Spread samples as evenly as possible: the first `remainder` centres
    // receive one extra sample each.
    let samples_per_center = n_samples / centers;
    let remainder = n_samples % centers;

    let mut sample_idx = 0;

    for center_idx in 0..centers {
        let n_samples_center = if center_idx < remainder {
            samples_per_center + 1
        } else {
            samples_per_center
        };

        // Each cluster gets its own random orientation in the (0, 1) plane.
        let rotation_angle = rng.random_range(0.0..(2.0 * PI));

        for _ in 0..n_samples_center {
            let mut point = vec![0.0; n_features];

            // Anisotropy: the second axis is compressed before rotation.
            point[0] = normal.sample(&mut rng);
            point[1] = normal.sample(&mut rng) / anisotropy_factor;

            // Remaining dimensions stay isotropic.
            for item in point.iter_mut().take(n_features).skip(2) {
                *item = normal.sample(&mut rng);
            }

            // 2-D rotation of the first two coordinates. (Always true here
            // since n_features >= 2 was validated above.)
            if n_features >= 2 {
                let cos_theta = rotation_angle.cos();
                let sin_theta = rotation_angle.sin();

                let x_rot = cos_theta * point[0] - sin_theta * point[1];
                let y_rot = sin_theta * point[0] + cos_theta * point[1];

                point[0] = x_rot;
                point[1] = y_rot;
            }

            // Translate the rotated sample onto its cluster centre.
            for j in 0..n_features {
                data[[sample_idx, j]] = cluster_centers[[center_idx, j]] + point[j];
            }

            target[sample_idx] = center_idx as f64;
            sample_idx += 1;
        }
    }

    let mut dataset = Dataset::new(data, Some(target));
    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();

    dataset = dataset
        .with_feature_names(feature_names)
        .with_description(format!(
            "Anisotropic clustering dataset with {} elongated clusters and {} features",
            centers, n_features
        ))
        .with_metadata("centers", &centers.to_string())
        .with_metadata("cluster_std", &cluster_std.to_string())
        .with_metadata("anisotropy_factor", &anisotropy_factor.to_string());

    Ok(dataset)
}
913
/// Generates a two-level hierarchical clustering dataset.
///
/// Main-cluster centres are drawn uniformly from [-20, 20]; each main cluster
/// then spawns `n_sub_clusters` sub-centres scattered around it with std
/// `main_cluster_std`, and samples are scattered around those sub-centres
/// with std `sub_cluster_std`. The dataset target holds the main-cluster
/// label; the fine-grained sub-cluster labels are only exposed as a
/// debug-formatted string in the `"sub_cluster_labels"` metadata entry.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when any size parameter is zero or
/// either std is not strictly positive.
pub fn make_hierarchical_clusters(
    n_samples: usize,
    n_features: usize,
    n_main_clusters: usize,
    n_sub_clusters: usize,
    main_cluster_std: f64,
    sub_cluster_std: f64,
    random_seed: Option<u64>,
) -> Result<Dataset> {
    if n_samples == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_samples must be > 0".to_string(),
        ));
    }

    if n_features == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_features must be > 0".to_string(),
        ));
    }

    if n_main_clusters == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_main_clusters must be > 0".to_string(),
        ));
    }

    if n_sub_clusters == 0 {
        return Err(DatasetsError::InvalidFormat(
            "n_sub_clusters must be > 0".to_string(),
        ));
    }

    if main_cluster_std <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "main_cluster_std must be > 0.0".to_string(),
        ));
    }

    if sub_cluster_std <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "sub_cluster_std must be > 0.0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    // Top-level centres, uniform in [-20, 20] per dimension.
    let mut main_centers = Array2::zeros((n_main_clusters, n_features));
    let center_box = 20.0;

    for i in 0..n_main_clusters {
        for j in 0..n_features {
            main_centers[[i, j]] = rng.random_range(-center_box..=center_box);
        }
    }

    let mut data = Array2::zeros((n_samples, n_features));
    let mut main_target = Array1::zeros(n_samples);
    let mut sub_target = Array1::zeros(n_samples);

    let main_normal = rand_distr::Normal::new(0.0, main_cluster_std).unwrap();
    let sub_normal = rand_distr::Normal::new(0.0, sub_cluster_std).unwrap();

    // Spread samples as evenly as possible over the main clusters.
    let samples_per_main = n_samples / n_main_clusters;
    let remainder = n_samples % n_main_clusters;

    let mut sample_idx = 0;

    for main_idx in 0..n_main_clusters {
        let n_samples_main = if main_idx < remainder {
            samples_per_main + 1
        } else {
            samples_per_main
        };

        // Sub-centres scatter around this main centre with the main std.
        let mut sub_centers = Array2::zeros((n_sub_clusters, n_features));
        for i in 0..n_sub_clusters {
            for j in 0..n_features {
                sub_centers[[i, j]] = main_centers[[main_idx, j]] + main_normal.sample(&mut rng);
            }
        }

        // Even split again, over this main cluster's sub-clusters.
        let samples_per_sub = n_samples_main / n_sub_clusters;
        let sub_remainder = n_samples_main % n_sub_clusters;

        for sub_idx in 0..n_sub_clusters {
            let n_samples_sub = if sub_idx < sub_remainder {
                samples_per_sub + 1
            } else {
                samples_per_sub
            };

            for _ in 0..n_samples_sub {
                for j in 0..n_features {
                    data[[sample_idx, j]] = sub_centers[[sub_idx, j]] + sub_normal.sample(&mut rng);
                }

                main_target[sample_idx] = main_idx as f64;
                // Global sub-cluster id: flattened (main, sub) index.
                sub_target[sample_idx] = (main_idx * n_sub_clusters + sub_idx) as f64;
                sample_idx += 1;
            }
        }
    }

    let mut dataset = Dataset::new(data, Some(main_target));
    let feature_names: Vec<String> = (0..n_features).map(|i| format!("feature_{}", i)).collect();

    dataset = dataset
        .with_feature_names(feature_names)
        .with_description(format!(
            "Hierarchical clustering dataset with {} main clusters, {} sub-clusters each",
            n_main_clusters, n_sub_clusters
        ))
        .with_metadata("n_main_clusters", &n_main_clusters.to_string())
        .with_metadata("n_sub_clusters", &n_sub_clusters.to_string())
        .with_metadata("main_cluster_std", &main_cluster_std.to_string())
        .with_metadata("sub_cluster_std", &sub_cluster_std.to_string())
        .with_metadata("sub_cluster_labels", &format!("{:?}", sub_target.to_vec()));

    Ok(dataset)
}
1045
/// Missingness mechanisms understood by [`inject_missing_data`].
#[derive(Debug, Clone, Copy)]
pub enum MissingPattern {
    /// Missing Completely At Random: every cell is dropped independently
    /// with the same probability.
    MCAR,
    /// Missing At Random: the drop rate of features 1.. depends on the
    /// observed value of feature 0 (which itself is never dropped).
    MAR,
    /// Missing Not At Random: each cell's drop rate depends on its own value.
    MNAR,
    /// Contiguous rectangular blocks of cells are dropped.
    Block,
}
1058
/// Outlier shapes understood by [`inject_outliers`].
#[derive(Debug, Clone, Copy)]
pub enum OutlierType {
    /// Isolated rows made extreme in every feature.
    Point,
    /// Rows made extreme in only a random subset of features.
    Contextual,
    /// Small groups of rows clustered around a shared extreme centre.
    Collective,
}
1069
1070pub fn inject_missing_data(
1072 data: &mut Array2<f64>,
1073 missing_rate: f64,
1074 pattern: MissingPattern,
1075 random_seed: Option<u64>,
1076) -> Result<Array2<bool>> {
1077 if !(0.0..=1.0).contains(&missing_rate) {
1079 return Err(DatasetsError::InvalidFormat(
1080 "missing_rate must be between 0.0 and 1.0".to_string(),
1081 ));
1082 }
1083
1084 let mut rng = match random_seed {
1085 Some(seed) => StdRng::seed_from_u64(seed),
1086 None => {
1087 let mut r = rng();
1088 StdRng::seed_from_u64(r.next_u64())
1089 }
1090 };
1091
1092 let (n_samples, n_features) = data.dim();
1093 let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
1094
1095 match pattern {
1096 MissingPattern::MCAR => {
1097 for i in 0..n_samples {
1099 for j in 0..n_features {
1100 if rng.random_range(0.0f64..1.0) < missing_rate {
1101 missing_mask[[i, j]] = true;
1102 data[[i, j]] = f64::NAN;
1103 }
1104 }
1105 }
1106 }
1107 MissingPattern::MAR => {
1108 for i in 0..n_samples {
1110 let first_feature_val = data[[i, 0]];
1111 let normalized_val = (first_feature_val + 10.0) / 20.0; let adjusted_rate = missing_rate * normalized_val.clamp(0.1, 2.0);
1113
1114 for j in 1..n_features {
1115 if rng.random_range(0.0f64..1.0) < adjusted_rate {
1117 missing_mask[[i, j]] = true;
1118 data[[i, j]] = f64::NAN;
1119 }
1120 }
1121 }
1122 }
1123 MissingPattern::MNAR => {
1124 for i in 0..n_samples {
1126 for j in 0..n_features {
1127 let value = data[[i, j]];
1128 let normalized_val = (value + 10.0) / 20.0; let adjusted_rate = missing_rate * normalized_val.clamp(0.1, 3.0);
1130
1131 if rng.random_range(0.0f64..1.0) < adjusted_rate {
1132 missing_mask[[i, j]] = true;
1133 data[[i, j]] = f64::NAN;
1134 }
1135 }
1136 }
1137 }
1138 MissingPattern::Block => {
1139 let block_size = (n_features as f64 * missing_rate).ceil() as usize;
1141 let n_blocks = (missing_rate * n_samples as f64).ceil() as usize;
1142
1143 for _ in 0..n_blocks {
1144 let start_row = rng.random_range(0..n_samples);
1145 let start_col = rng.random_range(0..n_features.saturating_sub(block_size));
1146
1147 for i in start_row..n_samples.min(start_row + block_size) {
1148 for j in start_col..n_features.min(start_col + block_size) {
1149 missing_mask[[i, j]] = true;
1150 data[[i, j]] = f64::NAN;
1151 }
1152 }
1153 }
1154 }
1155 }
1156
1157 Ok(missing_mask)
1158}
1159
/// Overwrites randomly chosen rows of `data` with synthetic outliers and
/// returns a per-row mask marking the rows that were modified.
///
/// Outliers are placed `outlier_strength` standard deviations away from each
/// feature's mean (statistics are computed ignoring NaNs, with the std
/// floored at 1.0 so near-constant features still get a visible offset).
/// Rows are chosen with replacement, so fewer than
/// `ceil(n_samples * outlier_rate)` distinct rows may end up marked.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` when `outlier_rate` is outside
/// [0.0, 1.0] or `outlier_strength` is not strictly positive.
pub fn inject_outliers(
    data: &mut Array2<f64>,
    outlier_rate: f64,
    outlier_type: OutlierType,
    outlier_strength: f64,
    random_seed: Option<u64>,
) -> Result<Array1<bool>> {
    if !(0.0..=1.0).contains(&outlier_rate) {
        return Err(DatasetsError::InvalidFormat(
            "outlier_rate must be between 0.0 and 1.0".to_string(),
        ));
    }

    if outlier_strength <= 0.0 {
        return Err(DatasetsError::InvalidFormat(
            "outlier_strength must be > 0.0".to_string(),
        ));
    }

    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let (n_samples, n_features) = data.dim();
    let n_outliers = (n_samples as f64 * outlier_rate).ceil() as usize;
    let mut outlier_mask = Array1::from_elem(n_samples, false);

    // Per-feature mean and (population) std, computed over non-NaN values.
    let mut feature_means = vec![0.0; n_features];
    let mut feature_stds = vec![0.0; n_features];

    for j in 0..n_features {
        let column = data.column(j);
        let valid_values: Vec<f64> = column.iter().filter(|&&x| !x.is_nan()).cloned().collect();

        if !valid_values.is_empty() {
            feature_means[j] = valid_values.iter().sum::<f64>() / valid_values.len() as f64;
            let variance = valid_values
                .iter()
                .map(|&x| (x - feature_means[j]).powi(2))
                .sum::<f64>()
                / valid_values.len() as f64;
            // Floor the std at 1.0 so the offset below is never degenerate.
            feature_stds[j] = variance.sqrt().max(1.0);
        }
    }

    match outlier_type {
        OutlierType::Point => {
            // Every feature of the chosen row is pushed to an extreme, with
            // a random sign per feature.
            for _ in 0..n_outliers {
                let outlier_idx = rng.random_range(0..n_samples);
                outlier_mask[outlier_idx] = true;

                for j in 0..n_features {
                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
                        -1.0
                    } else {
                        1.0
                    };
                    data[[outlier_idx, j]] =
                        feature_means[j] + direction * outlier_strength * feature_stds[j];
                }
            }
        }
        OutlierType::Contextual => {
            // Only a random subset (1 to n_features/2) of features is pushed
            // to an extreme; the rest of the row is left untouched.
            for _ in 0..n_outliers {
                let outlier_idx = rng.random_range(0..n_samples);
                outlier_mask[outlier_idx] = true;

                let n_features_to_modify = rng.random_range(1..=(n_features / 2).max(1));
                let mut features_to_modify: Vec<usize> = (0..n_features).collect();
                features_to_modify.shuffle(&mut rng);
                features_to_modify.truncate(n_features_to_modify);

                for &j in &features_to_modify {
                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
                        -1.0
                    } else {
                        1.0
                    };
                    data[[outlier_idx, j]] =
                        feature_means[j] + direction * outlier_strength * feature_stds[j];
                }
            }
        }
        OutlierType::Collective => {
            // Groups of rows are moved near a shared extreme centre, with
            // small uniform jitter so the group forms a tight cluster.
            let outliers_per_group = (n_outliers / 3).max(2);
            let n_groups = (n_outliers / outliers_per_group).max(1);

            for _ in 0..n_groups {
                // One extreme centre per group, random sign per feature.
                let mut outlier_center = vec![0.0; n_features];
                for j in 0..n_features {
                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
                        -1.0
                    } else {
                        1.0
                    };
                    outlier_center[j] =
                        feature_means[j] + direction * outlier_strength * feature_stds[j];
                }

                for _ in 0..outliers_per_group {
                    let outlier_idx = rng.random_range(0..n_samples);
                    outlier_mask[outlier_idx] = true;

                    for j in 0..n_features {
                        let noise = rng.random_range(-0.5f64..0.5f64) * feature_stds[j];
                        data[[outlier_idx, j]] = outlier_center[j] + noise;
                    }
                }
            }
        }
    }

    Ok(outlier_mask)
}
1288
/// Adds one or more noise components to a time-series matrix in place.
///
/// `noise_types` is a list of `(name, strength)` pairs applied in order.
/// Supported names: `"gaussian"`, `"spikes"`, `"drift"`, `"seasonal"`,
/// `"autocorrelated"`, `"heteroscedastic"`. Rows are interpreted as time
/// steps and columns as individual series.
///
/// # Errors
///
/// Returns `DatasetsError::InvalidFormat` for an unrecognized noise type.
/// Note that any noise types listed *before* the invalid one have already
/// been applied to `data` when the error is returned.
pub fn add_time_series_noise(
    data: &mut Array2<f64>,
    noise_types: &[(&str, f64)],
    random_seed: Option<u64>,
) -> Result<()> {
    // Seeded RNG for reproducibility; otherwise bootstrap from the thread RNG.
    let mut rng = match random_seed {
        Some(seed) => StdRng::seed_from_u64(seed),
        None => {
            let mut r = rng();
            StdRng::seed_from_u64(r.next_u64())
        }
    };

    let (n_samples, n_features) = data.dim();
    let normal = rand_distr::Normal::new(0.0, 1.0).unwrap();

    for &(noise_type, strength) in noise_types {
        match noise_type {
            "gaussian" => {
                // Independent white noise on every cell.
                for i in 0..n_samples {
                    for j in 0..n_features {
                        data[[i, j]] += strength * normal.sample(&mut rng);
                    }
                }
            }
            "spikes" => {
                // A few large, isolated jumps; count scales with strength.
                let n_spikes = (n_samples as f64 * strength * 0.1).ceil() as usize;
                for _ in 0..n_spikes {
                    let spike_idx = rng.random_range(0..n_samples);
                    let feature_idx = rng.random_range(0..n_features);
                    let spike_magnitude = rng.random_range(5.0..=15.0) * strength;
                    let direction = if rng.random_range(0.0f64..1.0) < 0.5 {
                        -1.0
                    } else {
                        1.0
                    };

                    data[[spike_idx, feature_idx]] += direction * spike_magnitude;
                }
            }
            "drift" => {
                // Linear drift from 0 at t=0 up to `strength` at the end.
                for i in 0..n_samples {
                    let drift_amount = strength * (i as f64 / n_samples as f64);
                    for j in 0..n_features {
                        data[[i, j]] += drift_amount;
                    }
                }
            }
            "seasonal" => {
                // Sinusoid with four full cycles over the series length,
                // shared across all features.
                let period = n_samples as f64 / 4.0;
                for i in 0..n_samples {
                    let seasonal_component = strength * (2.0 * PI * i as f64 / period).sin();
                    for j in 0..n_features {
                        data[[i, j]] += seasonal_component;
                    }
                }
            }
            "autocorrelated" => {
                // AR(1) noise per feature with fixed coefficient 0.7.
                let ar_coeff = 0.7;
                for j in 0..n_features {
                    let mut prev_noise = 0.0;
                    for i in 0..n_samples {
                        let new_noise = ar_coeff * prev_noise + strength * normal.sample(&mut rng);
                        data[[i, j]] += new_noise;
                        prev_noise = new_noise;
                    }
                }
            }
            "heteroscedastic" => {
                // Gaussian noise whose variance grows linearly over time.
                for i in 0..n_samples {
                    let variance_factor = 1.0 + strength * (i as f64 / n_samples as f64);
                    for j in 0..n_features {
                        data[[i, j]] += variance_factor * strength * normal.sample(&mut rng);
                    }
                }
            }
            _ => {
                return Err(DatasetsError::InvalidFormat(format!(
                    "Unknown noise type: {}. Supported types: gaussian, spikes, drift, seasonal, autocorrelated, heteroscedastic",
                    noise_type
                )));
            }
        }
    }

    Ok(())
}
1383
1384pub fn make_corrupted_dataset(
1386 base_dataset: &Dataset,
1387 missing_rate: f64,
1388 missing_pattern: MissingPattern,
1389 outlier_rate: f64,
1390 outlier_type: OutlierType,
1391 outlier_strength: f64,
1392 random_seed: Option<u64>,
1393) -> Result<Dataset> {
1394 if !(0.0..=1.0).contains(&missing_rate) {
1396 return Err(DatasetsError::InvalidFormat(
1397 "missing_rate must be between 0.0 and 1.0".to_string(),
1398 ));
1399 }
1400
1401 if !(0.0..=1.0).contains(&outlier_rate) {
1402 return Err(DatasetsError::InvalidFormat(
1403 "outlier_rate must be between 0.0 and 1.0".to_string(),
1404 ));
1405 }
1406
1407 let mut corrupted_data = base_dataset.data.clone();
1409 let corrupted_target = base_dataset.target.clone();
1410
1411 let missing_mask = inject_missing_data(
1413 &mut corrupted_data,
1414 missing_rate,
1415 missing_pattern,
1416 random_seed,
1417 )?;
1418
1419 let outlier_mask = inject_outliers(
1421 &mut corrupted_data,
1422 outlier_rate,
1423 outlier_type,
1424 outlier_strength,
1425 random_seed,
1426 )?;
1427
1428 let mut corrupted_dataset = Dataset::new(corrupted_data, corrupted_target);
1430
1431 if let Some(feature_names) = &base_dataset.feature_names {
1432 corrupted_dataset = corrupted_dataset.with_feature_names(feature_names.clone());
1433 }
1434
1435 if let Some(target_names) = &base_dataset.target_names {
1436 corrupted_dataset = corrupted_dataset.with_target_names(target_names.clone());
1437 }
1438
1439 corrupted_dataset = corrupted_dataset
1440 .with_description(format!(
1441 "Corrupted version of: {}",
1442 base_dataset
1443 .description
1444 .as_deref()
1445 .unwrap_or("Unknown dataset")
1446 ))
1447 .with_metadata("missing_rate", &missing_rate.to_string())
1448 .with_metadata("missing_pattern", &format!("{:?}", missing_pattern))
1449 .with_metadata("outlier_rate", &outlier_rate.to_string())
1450 .with_metadata("outlier_type", &format!("{:?}", outlier_type))
1451 .with_metadata("outlier_strength", &outlier_strength.to_string())
1452 .with_metadata(
1453 "missing_count",
1454 &missing_mask.iter().filter(|&&x| x).count().to_string(),
1455 )
1456 .with_metadata(
1457 "outlier_count",
1458 &outlier_mask.iter().filter(|&&x| x).count().to_string(),
1459 );
1460
1461 Ok(corrupted_dataset)
1462}
1463
#[cfg(test)]
mod tests {
    use super::*;

    // Each call passes exactly one out-of-range argument; all must fail.
    // Parameter order: (n_samples, n_features, n_classes,
    // n_clusters_per_class, n_informative, random_seed).
    #[test]
    fn test_make_classification_invalid_params() {
        // n_samples == 0
        assert!(make_classification(0, 5, 2, 1, 3, None).is_err());
        // n_features == 0
        assert!(make_classification(10, 0, 2, 1, 3, None).is_err());
        // n_informative == 0
        assert!(make_classification(10, 5, 2, 1, 0, None).is_err());
        // n_features (3) < n_informative (5)
        assert!(make_classification(10, 3, 2, 1, 5, None).is_err());
        // n_classes < 2
        assert!(make_classification(10, 5, 1, 1, 3, None).is_err());
        // n_clusters_per_class == 0
        assert!(make_classification(10, 5, 2, 0, 3, None).is_err());
    }

    // Each call supplies one zero/negative/inconsistent argument and must be
    // rejected by make_regression's validation.
    #[test]
    fn test_make_regression_invalid_params() {
        assert!(make_regression(0, 5, 3, 1.0, None).is_err());
        assert!(make_regression(10, 0, 3, 1.0, None).is_err());
        assert!(make_regression(10, 5, 0, 1.0, None).is_err());
        assert!(make_regression(10, 3, 5, 1.0, None).is_err());
        // negative noise level
        assert!(make_regression(10, 5, 3, -1.0, None).is_err());
    }

    // Zero sizes and a negative noise level must be rejected.
    #[test]
    fn test_make_time_series_invalid_params() {
        assert!(make_time_series(0, 3, false, false, 1.0, None).is_err());
        assert!(make_time_series(10, 0, false, false, 1.0, None).is_err());
        assert!(make_time_series(10, 3, false, false, -1.0, None).is_err());
    }

    // Zero sizes/clusters and a non-positive spread must be rejected.
    #[test]
    fn test_make_blobs_invalid_params() {
        assert!(make_blobs(0, 3, 2, 1.0, None).is_err());
        assert!(make_blobs(10, 0, 2, 1.0, None).is_err());
        assert!(make_blobs(10, 3, 0, 1.0, None).is_err());
        assert!(make_blobs(10, 3, 2, 0.0, None).is_err());
        assert!(make_blobs(10, 3, 2, -1.0, None).is_err());
    }

    // Happy path: the generated dataset has the requested shape and carries
    // target values plus feature/target names.
    #[test]
    fn test_make_classification_valid_params() {
        let dataset = make_classification(20, 5, 3, 2, 4, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 20);
        assert_eq!(dataset.n_features(), 5);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
    }

    #[test]
    fn test_make_regression_valid_params() {
        let dataset = make_regression(15, 4, 3, 0.5, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 15);
        assert_eq!(dataset.n_features(), 4);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());
    }

    #[test]
    fn test_make_time_series_valid_params() {
        let dataset = make_time_series(25, 3, true, true, 0.1, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 25);
        assert_eq!(dataset.n_features(), 3);
        assert!(dataset.feature_names.is_some());
        // Time-series data is unsupervised: no target is produced.
        assert!(dataset.target.is_none());
    }

    #[test]
    fn test_make_blobs_valid_params() {
        let dataset = make_blobs(30, 4, 3, 1.5, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 30);
        assert_eq!(dataset.n_features(), 4);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());
    }

    #[test]
    fn test_make_spirals_invalid_params() {
        assert!(make_spirals(0, 2, 0.1, None).is_err());
        assert!(make_spirals(100, 0, 0.1, None).is_err());
        // negative noise
        assert!(make_spirals(100, 2, -0.1, None).is_err());
    }

    #[test]
    fn test_make_spirals_valid_params() {
        let dataset = make_spirals(100, 2, 0.1, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 2);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());

        // Spirals are a binary-classification shape: exactly two labels.
        if let Some(target) = &dataset.target {
            let unique_labels: std::collections::HashSet<_> =
                target.iter().map(|&x| x as i32).collect();
            assert_eq!(unique_labels.len(), 2);
        }
    }

    #[test]
    fn test_make_moons_invalid_params() {
        assert!(make_moons(0, 0.1, None).is_err());
        assert!(make_moons(100, -0.1, None).is_err());
    }

    #[test]
    fn test_make_moons_valid_params() {
        let dataset = make_moons(100, 0.1, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 2);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());

        // Two interleaving half-moons -> exactly two class labels.
        if let Some(target) = &dataset.target {
            let unique_labels: std::collections::HashSet<_> =
                target.iter().map(|&x| x as i32).collect();
            assert_eq!(unique_labels.len(), 2);
        }
    }

    #[test]
    fn test_make_circles_invalid_params() {
        assert!(make_circles(0, 0.5, 0.1, None).is_err());
        // scale factor must be strictly between 0 and 1
        assert!(make_circles(100, 0.0, 0.1, None).is_err());
        assert!(make_circles(100, 1.0, 0.1, None).is_err());
        assert!(make_circles(100, 1.5, 0.1, None).is_err());
        // negative noise
        assert!(make_circles(100, 0.5, -0.1, None).is_err());
    }

    #[test]
    fn test_make_circles_valid_params() {
        let dataset = make_circles(100, 0.5, 0.1, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 2);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());

        // Inner vs outer circle -> exactly two class labels.
        if let Some(target) = &dataset.target {
            let unique_labels: std::collections::HashSet<_> =
                target.iter().map(|&x| x as i32).collect();
            assert_eq!(unique_labels.len(), 2);
        }
    }

    #[test]
    fn test_make_swiss_roll_invalid_params() {
        assert!(make_swiss_roll(0, 0.1, None).is_err());
        assert!(make_swiss_roll(100, -0.1, None).is_err());
    }

    #[test]
    fn test_make_swiss_roll_valid_params() {
        let dataset = make_swiss_roll(100, 0.1, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        // Swiss roll is embedded in 3-D space.
        assert_eq!(dataset.n_features(), 3);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());
    }

    #[test]
    fn test_make_anisotropic_blobs_invalid_params() {
        assert!(make_anisotropic_blobs(0, 3, 2, 1.0, 2.0, None).is_err());
        // presumably anisotropy requires >= 2 features -- confirmed by this
        // call being expected to fail with n_features == 1
        assert!(make_anisotropic_blobs(100, 1, 2, 1.0, 2.0, None).is_err());
        assert!(make_anisotropic_blobs(100, 3, 0, 1.0, 2.0, None).is_err());
        assert!(make_anisotropic_blobs(100, 3, 2, 0.0, 2.0, None).is_err());
        assert!(make_anisotropic_blobs(100, 3, 2, 1.0, 0.0, None).is_err());
    }

    #[test]
    fn test_make_anisotropic_blobs_valid_params() {
        let dataset = make_anisotropic_blobs(100, 3, 2, 1.0, 3.0, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 3);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());

        // Two requested clusters -> two distinct labels.
        if let Some(target) = &dataset.target {
            let unique_labels: std::collections::HashSet<_> =
                target.iter().map(|&x| x as i32).collect();
            assert_eq!(unique_labels.len(), 2);
        }
    }

    #[test]
    fn test_make_hierarchical_clusters_invalid_params() {
        assert!(make_hierarchical_clusters(0, 3, 2, 3, 1.0, 0.5, None).is_err());
        assert!(make_hierarchical_clusters(100, 0, 2, 3, 1.0, 0.5, None).is_err());
        assert!(make_hierarchical_clusters(100, 3, 0, 3, 1.0, 0.5, None).is_err());
        assert!(make_hierarchical_clusters(100, 3, 2, 0, 1.0, 0.5, None).is_err());
        assert!(make_hierarchical_clusters(100, 3, 2, 3, 0.0, 0.5, None).is_err());
        assert!(make_hierarchical_clusters(100, 3, 2, 3, 1.0, 0.0, None).is_err());
    }

    #[test]
    fn test_make_hierarchical_clusters_valid_params() {
        let dataset = make_hierarchical_clusters(120, 3, 2, 3, 2.0, 0.5, Some(42)).unwrap();
        assert_eq!(dataset.n_samples(), 120);
        assert_eq!(dataset.n_features(), 3);
        assert!(dataset.target.is_some());
        assert!(dataset.feature_names.is_some());

        // The target labels the 2 top-level clusters, not the sub-clusters.
        if let Some(target) = &dataset.target {
            let unique_labels: std::collections::HashSet<_> =
                target.iter().map(|&x| x as i32).collect();
            assert_eq!(unique_labels.len(), 2);
        }

        // Sub-cluster assignments are exposed via metadata instead.
        assert!(dataset.metadata.contains_key("sub_cluster_labels"));
    }

    #[test]
    fn test_inject_missing_data_invalid_params() {
        let mut data = Array2::from_shape_vec((5, 3), vec![1.0; 15]).unwrap();

        // Rates outside [0.0, 1.0] must be rejected.
        assert!(inject_missing_data(&mut data, -0.1, MissingPattern::MCAR, None).is_err());
        assert!(inject_missing_data(&mut data, 1.5, MissingPattern::MCAR, None).is_err());
    }

    #[test]
    fn test_inject_missing_data_mcar() {
        let mut data =
            Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let original_data = data.clone();

        let missing_mask =
            inject_missing_data(&mut data, 0.3, MissingPattern::MCAR, Some(42)).unwrap();

        // A 30% rate over 30 cells should flag at least one cell.
        let missing_count = missing_mask.iter().filter(|&&x| x).count();
        assert!(missing_count > 0);

        // Masked cells become NaN; every other cell is left untouched.
        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
            if is_missing {
                assert!(data[[i, j]].is_nan());
            } else {
                assert_eq!(data[[i, j]], original_data[[i, j]]);
            }
        }
    }

    #[test]
    fn test_inject_outliers_invalid_params() {
        let mut data = Array2::from_shape_vec((5, 3), vec![1.0; 15]).unwrap();

        // Rate outside [0.0, 1.0].
        assert!(inject_outliers(&mut data, -0.1, OutlierType::Point, 2.0, None).is_err());
        assert!(inject_outliers(&mut data, 1.5, OutlierType::Point, 2.0, None).is_err());

        // Non-positive strength.
        assert!(inject_outliers(&mut data, 0.1, OutlierType::Point, 0.0, None).is_err());
        assert!(inject_outliers(&mut data, 0.1, OutlierType::Point, -1.0, None).is_err());
    }

    #[test]
    fn test_inject_outliers_point() {
        let mut data = Array2::from_shape_vec((20, 2), vec![1.0; 40]).unwrap();

        let outlier_mask =
            inject_outliers(&mut data, 0.2, OutlierType::Point, 3.0, Some(42)).unwrap();

        let outlier_count = outlier_mask.iter().filter(|&&x| x).count();
        assert!(outlier_count > 0);

        // Every flagged row must have moved noticeably away from the
        // constant baseline value of 1.0.
        for (i, &is_outlier) in outlier_mask.iter().enumerate() {
            if is_outlier {
                let row = data.row(i);
                assert!(row.iter().any(|&x| (x - 1.0).abs() > 1.0));
            }
        }
    }

    #[test]
    fn test_add_time_series_noise() {
        let mut data = Array2::zeros((100, 2));

        let noise_types = [("gaussian", 0.1), ("spikes", 0.05), ("drift", 0.2)];

        let original_data = data.clone();
        add_time_series_noise(&mut data, &noise_types, Some(42)).unwrap();

        // At least one cell must have changed after applying the noise.
        assert!(!data
            .iter()
            .zip(original_data.iter())
            .all(|(&a, &b)| (a - b).abs() < 1e-10));

        // An unrecognized noise-type name is an error.
        let invalid_noise = [("invalid_type", 0.1)];
        let mut test_data = Array2::zeros((10, 2));
        assert!(add_time_series_noise(&mut test_data, &invalid_noise, Some(42)).is_err());
    }

    #[test]
    fn test_make_corrupted_dataset() {
        let base_dataset = make_blobs(50, 3, 2, 1.0, Some(42)).unwrap();

        let corrupted = make_corrupted_dataset(
            &base_dataset,
            0.1,                  // missing_rate
            MissingPattern::MCAR,
            0.05,                 // outlier_rate
            OutlierType::Point,
            2.0,                  // outlier_strength
            Some(42),
        )
        .unwrap();

        // Corruption preserves the dataset's shape.
        assert_eq!(corrupted.n_samples(), base_dataset.n_samples());
        assert_eq!(corrupted.n_features(), base_dataset.n_features());

        // Corruption parameters and realized counts are recorded in metadata.
        assert!(corrupted.metadata.contains_key("missing_rate"));
        assert!(corrupted.metadata.contains_key("outlier_rate"));
        assert!(corrupted.metadata.contains_key("missing_count"));
        assert!(corrupted.metadata.contains_key("outlier_count"));

        // A 10% missing rate over 150 cells should produce at least one NaN.
        let has_nan = corrupted.data.iter().any(|&x| x.is_nan());
        assert!(has_nan, "Dataset should have some missing values");
    }

    #[test]
    fn test_make_corrupted_dataset_invalid_params() {
        let base_dataset = make_blobs(20, 2, 2, 1.0, Some(42)).unwrap();

        // missing_rate outside [0.0, 1.0].
        assert!(make_corrupted_dataset(
            &base_dataset,
            -0.1,
            MissingPattern::MCAR,
            0.0,
            OutlierType::Point,
            1.0,
            None
        )
        .is_err());
        assert!(make_corrupted_dataset(
            &base_dataset,
            1.5,
            MissingPattern::MCAR,
            0.0,
            OutlierType::Point,
            1.0,
            None
        )
        .is_err());

        // outlier_rate outside [0.0, 1.0].
        assert!(make_corrupted_dataset(
            &base_dataset,
            0.0,
            MissingPattern::MCAR,
            -0.1,
            OutlierType::Point,
            1.0,
            None
        )
        .is_err());
        assert!(make_corrupted_dataset(
            &base_dataset,
            0.0,
            MissingPattern::MCAR,
            1.5,
            OutlierType::Point,
            1.0,
            None
        )
        .is_err());
    }

    // Smoke test: every missing-data pattern flags at least one cell at a
    // 20% rate on a 20x4 matrix.
    #[test]
    fn test_missing_patterns() {
        let data = Array2::from_shape_vec((20, 4), (0..80).map(|x| x as f64).collect()).unwrap();

        for pattern in [
            MissingPattern::MCAR,
            MissingPattern::MAR,
            MissingPattern::MNAR,
            MissingPattern::Block,
        ] {
            let mut test_data = data.clone();
            let missing_mask = inject_missing_data(&mut test_data, 0.2, pattern, Some(42)).unwrap();

            let missing_count = missing_mask.iter().filter(|&&x| x).count();
            assert!(
                missing_count > 0,
                "Pattern {:?} should create some missing values",
                pattern
            );
        }
    }

    // Smoke test: every outlier type flags at least one row at a 20% rate.
    #[test]
    fn test_outlier_types() {
        let data = Array2::ones((30, 3));

        for outlier_type in [
            OutlierType::Point,
            OutlierType::Contextual,
            OutlierType::Collective,
        ] {
            let mut test_data = data.clone();
            let outlier_mask =
                inject_outliers(&mut test_data, 0.2, outlier_type, 3.0, Some(42)).unwrap();

            let outlier_count = outlier_mask.iter().filter(|&&x| x).count();
            assert!(
                outlier_count > 0,
                "Outlier type {:?} should create some outliers",
                outlier_type
            );
        }
    }
}