1use crate::error::DataError;
7use scirs2_core::Distribution; use torsh_tensor::Tensor;
9
10use scirs2_datasets::toy::{
12 load_boston as scirs2_load_boston, load_breast_cancer as scirs2_load_breast_cancer,
13 load_diabetes as scirs2_load_diabetes, load_digits as scirs2_load_digits,
14 load_iris as scirs2_load_iris,
15};
16
/// Identifiers for the datasets that `load_builtin_dataset` knows how to load.
///
/// `DatasetRegistry::load_by_name` maps lowercase string names onto these
/// variants.
#[derive(Debug, Clone)]
pub enum BuiltinDataset {
    /// Iris dataset, loaded through `scirs2_datasets::toy::load_iris`.
    Iris,
    /// Boston dataset, loaded through `scirs2_datasets::toy::load_boston`.
    Boston,
    /// Diabetes dataset, loaded through `scirs2_datasets::toy::load_diabetes`.
    Diabetes,
    /// Wine dataset. In this file it is served by the synthetic
    /// `make_classification` generator rather than by a scirs2 loader.
    Wine,
    /// Breast cancer dataset, loaded through
    /// `scirs2_datasets::toy::load_breast_cancer`.
    BreastCancer,
    /// Digits dataset, loaded through `scirs2_datasets::toy::load_digits`.
    Digits,
}
27
/// General-purpose configuration for synthetic data generation.
///
/// NOTE(review): no generator visible in this file consumes this struct; the
/// field descriptions below are inferred from the field names and from the
/// `Default` implementation — confirm against the actual consumer.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// Requested number of samples (defaults to 100).
    pub n_samples: usize,
    /// Requested number of features (defaults to 2).
    pub n_features: usize,
    /// Optional number of classes (defaults to `Some(2)`).
    pub n_classes: Option<usize>,
    /// Optional random seed (defaults to `None`).
    pub seed: Option<u64>,
    /// Optional noise level (defaults to `Some(0.1)`); presumably a standard
    /// deviation — TODO confirm.
    pub noise: Option<f64>,
    /// Optional scaling method (defaults to `Some(ScalingMethod::StandardScaler)`).
    pub scale: Option<ScalingMethod>,
}
44
/// Scaling strategies selectable through `SyntheticDataConfig::scale`.
///
/// NOTE(review): the scaling itself is not implemented in this file; the
/// variant names mirror common preprocessing scalers, but their exact
/// semantics must be confirmed against the code that applies them.
#[derive(Debug, Clone)]
pub enum ScalingMethod {
    /// Presumably zero-mean / unit-variance standardization — TODO confirm.
    StandardScaler,
    /// Presumably rescaling into a fixed min/max range — TODO confirm.
    MinMaxScaler,
    /// Presumably outlier-robust scaling — TODO confirm.
    RobustScaler,
    /// Presumably per-sample normalization — TODO confirm.
    Normalizer,
}
53
/// Configuration consumed by `make_regression`.
#[derive(Debug, Clone)]
pub struct RegressionConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of feature columns to generate.
    pub n_features: usize,
    /// Number of leading features that contribute to the target.
    /// `None` means all `n_features`; a value above `n_features` is rejected.
    pub n_informative: Option<usize>,
    /// Standard deviation of the zero-mean Gaussian noise added to each
    /// target. `None` means `0.0`; noise is only added when the value is > 0.
    pub noise: Option<f64>,
    /// Constant added to every target. `None` means `0.0`.
    pub bias: Option<f64>,
    /// Seed for the random number generator. `None` seeds from the thread RNG,
    /// so results are not reproducible.
    pub random_state: Option<u64>,
}
64
/// Configuration consumed by `make_classification`.
#[derive(Debug, Clone)]
pub struct ClassificationConfig {
    /// Total number of samples (rows) to generate.
    pub n_samples: usize,
    /// Total number of feature columns (informative + redundant + noise).
    pub n_features: usize,
    /// Number of distinct class labels.
    pub n_classes: usize,
    /// Number of informative features, i.e. the dimensionality of each cluster
    /// center. `None` means `min(n_features, 2)`.
    pub n_informative: Option<usize>,
    /// Number of redundant features, each a randomly weighted linear
    /// combination of the informative features. `None` means `0`.
    pub n_redundant: Option<usize>,
    /// Number of clusters generated for every class. `None` means `1`.
    pub n_clusters_per_class: Option<usize>,
    /// Controls the spread of cluster centers: every center coordinate is
    /// drawn uniformly from `(-class_sep, class_sep)` and multiplied by 10.
    /// `None` means `1.0`.
    pub class_sep: Option<f64>,
    /// Seed for the random number generator. `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
}
77
/// Configuration consumed by `make_blobs`.
#[derive(Debug, Clone)]
pub struct ClusteringConfig {
    /// Total number of samples (rows) to generate across all clusters.
    pub n_samples: usize,
    /// Number of cluster centers to generate.
    pub centers: usize,
    /// Number of feature columns. `None` means `2`.
    pub n_features: Option<usize>,
    /// Standard deviation of the Gaussian noise around each center.
    /// `None` means `1.0`.
    pub cluster_std: Option<f64>,
    /// `(min, max)` bounds from which every center coordinate is drawn
    /// uniformly. `None` means `(-10.0, 10.0)`; `min` must be below `max`.
    pub center_box: Option<(f64, f64)>,
    /// Seed for the random number generator. `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
}
88
/// A loaded or generated dataset together with its descriptive metadata.
#[derive(Debug, Clone)]
pub struct DatasetResult {
    /// Feature tensor. The generators in this file build it with shape
    /// `[n_samples, n_features]`.
    pub features: Tensor,
    /// Target tensor. The generators in this file build it with shape
    /// `[n_samples]`; a converted scirs2 dataset without targets yields an
    /// empty tensor of shape `[0]`.
    pub targets: Tensor,
    /// Optional names of the feature columns.
    pub feature_names: Option<Vec<String>>,
    /// Optional names of the targets / classes.
    pub target_names: Option<Vec<String>>,
    /// Human-readable description of the dataset.
    pub description: String,
}
98
99impl Default for SyntheticDataConfig {
100 fn default() -> Self {
101 Self {
102 n_samples: 100,
103 n_features: 2,
104 n_classes: Some(2),
105 seed: None,
106 noise: Some(0.1),
107 scale: Some(ScalingMethod::StandardScaler),
108 }
109 }
110}
111
112pub fn load_builtin_dataset(dataset: BuiltinDataset) -> Result<DatasetResult, DataError> {
114 match dataset {
115 BuiltinDataset::Iris => load_iris_dataset(),
116 BuiltinDataset::Boston => load_boston_dataset(),
117 BuiltinDataset::Diabetes => load_diabetes_dataset(),
118 BuiltinDataset::Wine => load_wine_dataset(),
119 BuiltinDataset::BreastCancer => load_breast_cancer_dataset(),
120 BuiltinDataset::Digits => load_digits_dataset(),
121 }
122}
123
124pub fn make_regression(config: RegressionConfig) -> Result<DatasetResult, DataError> {
135 use scirs2_core::random::{Normal, SeedableRng, StdRng};
136
137 let n_informative = config.n_informative.unwrap_or(config.n_features);
138 let noise_std = config.noise.unwrap_or(0.0);
139 let bias = config.bias.unwrap_or(0.0);
140
141 if n_informative > config.n_features {
142 return Err(DataError::dataset(
143 crate::error::DatasetErrorKind::CorruptedData,
144 format!(
145 "n_informative ({}) cannot exceed n_features ({})",
146 n_informative, config.n_features
147 ),
148 ));
149 }
150
151 let mut rng = if let Some(seed) = config.random_state {
153 StdRng::seed_from_u64(seed)
154 } else {
155 let mut thread_rng = scirs2_core::random::thread_rng();
156 StdRng::from_rng(&mut thread_rng)
157 };
158
159 let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
161 let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
162 .map(|_| normal.sample(&mut rng) as f32)
163 .collect();
164
165 let features = Tensor::from_vec(
166 features_data.clone(),
167 &[config.n_samples, config.n_features],
168 )?;
169
170 let coefficients: Vec<f32> = (0..n_informative)
172 .map(|_| rng.gen_range(-100.0..100.0))
173 .collect();
174
175 let noise_dist = Normal::new(0.0, noise_std).expect("valid Normal parameters");
177 let targets_data: Vec<f32> = (0..config.n_samples)
178 .map(|i| {
179 let mut target = bias as f32;
181 for j in 0..n_informative {
182 let idx = i * config.n_features + j;
183 target += coefficients[j] * features_data[idx];
184 }
185
186 if noise_std > 0.0 {
188 target += noise_dist.sample(&mut rng) as f32;
189 }
190
191 target
192 })
193 .collect();
194
195 let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
196
197 Ok(DatasetResult {
198 features,
199 targets,
200 feature_names: Some(
201 (0..config.n_features)
202 .map(|i| {
203 if i < n_informative {
204 format!("informative_{}", i)
205 } else {
206 format!("noise_{}", i - n_informative)
207 }
208 })
209 .collect(),
210 ),
211 target_names: Some(vec!["target".to_string()]),
212 description: format!(
213 "Synthetic regression dataset: {} samples, {} features ({} informative), noise_std={:.2}, bias={:.2}",
214 config.n_samples, config.n_features, n_informative, noise_std, bias
215 ),
216 })
217}
218
219pub fn make_classification(config: ClassificationConfig) -> Result<DatasetResult, DataError> {
229 use scirs2_core::random::{Normal, SeedableRng, StdRng};
230
231 let n_informative = config.n_informative.unwrap_or(config.n_features.min(2));
232 let n_redundant = config.n_redundant.unwrap_or(0);
233 let n_clusters_per_class = config.n_clusters_per_class.unwrap_or(1);
234 let class_sep = config.class_sep.unwrap_or(1.0);
235
236 if n_informative + n_redundant > config.n_features {
237 return Err(DataError::dataset(
238 crate::error::DatasetErrorKind::CorruptedData,
239 format!(
240 "n_informative ({}) + n_redundant ({}) cannot exceed n_features ({})",
241 n_informative, n_redundant, config.n_features
242 ),
243 ));
244 }
245
246 let mut rng = if let Some(seed) = config.random_state {
248 StdRng::seed_from_u64(seed)
249 } else {
250 let mut thread_rng = scirs2_core::random::thread_rng();
251 StdRng::from_rng(&mut thread_rng)
252 };
253
254 let total_clusters = config.n_classes * n_clusters_per_class;
256 let mut cluster_centers: Vec<Vec<f32>> = Vec::new();
257 let mut cluster_labels: Vec<usize> = Vec::new();
258
259 for class_id in 0..config.n_classes {
260 for _ in 0..n_clusters_per_class {
261 let center: Vec<f32> = (0..n_informative)
262 .map(|_| rng.gen_range(-class_sep as f32..class_sep as f32) * 10.0)
263 .collect();
264 cluster_centers.push(center);
265 cluster_labels.push(class_id);
266 }
267 }
268
269 let samples_per_cluster = config.n_samples / total_clusters;
271 let remainder = config.n_samples % total_clusters;
272
273 let mut features_data = Vec::new();
274 let mut targets_data = Vec::new();
275
276 let normal = Normal::new(0.0, 1.0).expect("valid Normal parameters");
277
278 for (cluster_idx, (center, &class_label)) in cluster_centers
279 .iter()
280 .zip(cluster_labels.iter())
281 .enumerate()
282 {
283 let n_samples_this_cluster =
284 samples_per_cluster + if cluster_idx < remainder { 1 } else { 0 };
285
286 for _ in 0..n_samples_this_cluster {
287 for ¢er_val in center.iter() {
289 let noise = normal.sample(&mut rng) as f32;
290 features_data.push(center_val + noise);
291 }
292
293 let start_idx = features_data.len() - n_informative;
295 for _ in 0..n_redundant {
296 let mut redundant = 0.0f32;
297 for j in 0..n_informative {
298 let weight = rng.gen_range(-1.0..1.0);
299 redundant += weight * features_data[start_idx + j];
300 }
301 features_data.push(redundant);
302 }
303
304 let n_noise = config.n_features - n_informative - n_redundant;
306 for _ in 0..n_noise {
307 features_data.push(rng.gen_range(-10.0..10.0));
308 }
309
310 targets_data.push(class_label as f32);
311 }
312 }
313
314 let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
315 let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
316
317 Ok(DatasetResult {
318 features,
319 targets,
320 feature_names: Some(
321 (0..config.n_features)
322 .map(|i| {
323 if i < n_informative {
324 format!("informative_{}", i)
325 } else if i < n_informative + n_redundant {
326 format!("redundant_{}", i - n_informative)
327 } else {
328 format!("noise_{}", i - n_informative - n_redundant)
329 }
330 })
331 .collect(),
332 ),
333 target_names: Some(
334 (0..config.n_classes)
335 .map(|i| format!("class_{}", i))
336 .collect(),
337 ),
338 description: format!(
339 "Synthetic classification dataset: {} samples, {} features ({} informative, {} redundant), {} classes, class_sep={:.2}",
340 config.n_samples, config.n_features, n_informative, n_redundant, config.n_classes, class_sep
341 ),
342 })
343}
344
345pub fn make_blobs(config: ClusteringConfig) -> Result<DatasetResult, DataError> {
355 use scirs2_core::random::{Normal, SeedableRng, StdRng};
356
357 let mut rng = if let Some(seed) = config.random_state {
359 StdRng::seed_from_u64(seed)
360 } else {
361 let mut thread_rng = scirs2_core::random::thread_rng();
362 StdRng::from_rng(&mut thread_rng)
363 };
364
365 let n_features = config.n_features.unwrap_or(2);
366 let cluster_std = config.cluster_std.unwrap_or(1.0);
367 let (box_min, box_max) = config.center_box.unwrap_or((-10.0, 10.0));
368
369 if box_min >= box_max {
370 return Err(DataError::dataset(
371 crate::error::DatasetErrorKind::CorruptedData,
372 format!(
373 "center_box min ({}) must be less than max ({})",
374 box_min, box_max
375 ),
376 ));
377 }
378
379 let centers: Vec<Vec<f32>> = (0..config.centers)
381 .map(|_| {
382 (0..n_features)
383 .map(|_| rng.gen_range(box_min as f32..box_max as f32))
384 .collect()
385 })
386 .collect();
387
388 let samples_per_cluster = config.n_samples / config.centers;
390 let remainder = config.n_samples % config.centers;
391
392 let mut features_data = Vec::new();
393 let mut targets_data = Vec::new();
394
395 let normal = Normal::new(0.0, cluster_std).expect("valid Normal parameters");
397
398 for (cluster_id, center) in centers.iter().enumerate() {
399 let n_samples_this_cluster =
400 samples_per_cluster + if cluster_id < remainder { 1 } else { 0 };
401
402 for _ in 0..n_samples_this_cluster {
403 for ¢er_coord in center {
405 let noise = normal.sample(&mut rng) as f32;
406 features_data.push(center_coord + noise);
407 }
408 targets_data.push(cluster_id as f32);
409 }
410 }
411
412 let features = Tensor::from_vec(features_data, &[config.n_samples, n_features])?;
413 let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
414
415 Ok(DatasetResult {
416 features,
417 targets,
418 feature_names: Some((0..n_features).map(|i| format!("feature_{}", i)).collect()),
419 target_names: Some(
420 (0..config.centers)
421 .map(|i| format!("cluster_{}", i))
422 .collect(),
423 ),
424 description: format!(
425 "Synthetic clustering dataset (blobs): {} samples, {} features, {} clusters, cluster_std={:.2}",
426 config.n_samples, n_features, config.centers, cluster_std
427 ),
428 })
429}
430
431fn convert_scirs2_dataset(
433 scirs2_dataset: scirs2_datasets::utils::Dataset,
434) -> Result<DatasetResult, DataError> {
435 let shape = scirs2_dataset.data.shape();
437 let features_data: Vec<f32> = scirs2_dataset.data.iter().map(|&x| x as f32).collect();
438 let features = Tensor::from_vec(features_data, &[shape[0], shape[1]])?;
439
440 let targets = if let Some(target_array) = scirs2_dataset.target {
442 let target_data: Vec<f32> = target_array.iter().map(|&x| x as f32).collect();
443 Tensor::from_vec(target_data, &[target_array.len()])?
444 } else {
445 Tensor::from_vec(vec![], &[0])?
447 };
448
449 Ok(DatasetResult {
450 features,
451 targets,
452 feature_names: scirs2_dataset.featurenames,
453 target_names: scirs2_dataset.targetnames,
454 description: scirs2_dataset
455 .description
456 .unwrap_or_else(|| "Dataset loaded from scirs2".to_string()),
457 })
458}
459
460fn load_iris_dataset() -> Result<DatasetResult, DataError> {
462 let scirs2_dataset = scirs2_load_iris().map_err(|e| {
464 DataError::dataset(
465 crate::error::DatasetErrorKind::CorruptedData,
466 format!("Failed to load Iris dataset from scirs2_datasets: {}", e),
467 )
468 })?;
469
470 convert_scirs2_dataset(scirs2_dataset)
471}
472
473fn load_boston_dataset() -> Result<DatasetResult, DataError> {
474 let scirs2_dataset = scirs2_load_boston().map_err(|e| {
476 DataError::dataset(
477 crate::error::DatasetErrorKind::CorruptedData,
478 format!("Failed to load Boston dataset from scirs2_datasets: {}", e),
479 )
480 })?;
481
482 convert_scirs2_dataset(scirs2_dataset)
483}
484
485fn load_diabetes_dataset() -> Result<DatasetResult, DataError> {
486 let scirs2_dataset = scirs2_load_diabetes().map_err(|e| {
488 DataError::dataset(
489 crate::error::DatasetErrorKind::CorruptedData,
490 format!(
491 "Failed to load Diabetes dataset from scirs2_datasets: {}",
492 e
493 ),
494 )
495 })?;
496
497 convert_scirs2_dataset(scirs2_dataset)
498}
499
500fn load_wine_dataset() -> Result<DatasetResult, DataError> {
501 make_classification(ClassificationConfig {
502 n_samples: 178,
503 n_features: 13,
504 n_classes: 3,
505 n_informative: Some(13),
506 random_state: Some(42),
507 ..Default::default()
508 })
509}
510
511fn load_breast_cancer_dataset() -> Result<DatasetResult, DataError> {
512 let scirs2_dataset = scirs2_load_breast_cancer().map_err(|e| {
514 DataError::dataset(
515 crate::error::DatasetErrorKind::CorruptedData,
516 format!(
517 "Failed to load Breast Cancer dataset from scirs2_datasets: {}",
518 e
519 ),
520 )
521 })?;
522
523 convert_scirs2_dataset(scirs2_dataset)
524}
525
526fn load_digits_dataset() -> Result<DatasetResult, DataError> {
527 let scirs2_dataset = scirs2_load_digits().map_err(|e| {
529 DataError::dataset(
530 crate::error::DatasetErrorKind::CorruptedData,
531 format!("Failed to load Digits dataset from scirs2_datasets: {}", e),
532 )
533 })?;
534
535 convert_scirs2_dataset(scirs2_dataset)
536}
537
538impl Default for RegressionConfig {
539 fn default() -> Self {
540 Self {
541 n_samples: 100,
542 n_features: 1,
543 n_informative: None,
544 noise: Some(0.1),
545 bias: Some(0.0),
546 random_state: None,
547 }
548 }
549}
550
551impl Default for ClassificationConfig {
552 fn default() -> Self {
553 Self {
554 n_samples: 100,
555 n_features: 2,
556 n_classes: 2,
557 n_informative: None,
558 n_redundant: None,
559 n_clusters_per_class: None,
560 class_sep: Some(1.0),
561 random_state: None,
562 }
563 }
564}
565
566impl Default for ClusteringConfig {
567 fn default() -> Self {
568 Self {
569 n_samples: 100,
570 centers: 3,
571 n_features: Some(2),
572 cluster_std: Some(1.0),
573 center_box: Some((-10.0, 10.0)),
574 random_state: None,
575 }
576 }
577}
578
579#[derive(Debug, Default)]
581pub struct DatasetRegistry {
582 builtin_datasets: Vec<BuiltinDataset>,
583}
584
585impl DatasetRegistry {
586 pub fn new() -> Self {
588 Self {
589 builtin_datasets: vec![
590 BuiltinDataset::Iris,
591 BuiltinDataset::Boston,
592 BuiltinDataset::Diabetes,
593 BuiltinDataset::Wine,
594 BuiltinDataset::BreastCancer,
595 BuiltinDataset::Digits,
596 ],
597 }
598 }
599
600 pub fn list_builtin(&self) -> &[BuiltinDataset] {
602 &self.builtin_datasets
603 }
604
605 pub fn load_by_name(&self, name: &str) -> Result<DatasetResult, DataError> {
607 let dataset = match name.to_lowercase().as_str() {
608 "iris" => BuiltinDataset::Iris,
609 "boston" => BuiltinDataset::Boston,
610 "diabetes" => BuiltinDataset::Diabetes,
611 "wine" => BuiltinDataset::Wine,
612 "breast_cancer" | "breastcancer" => BuiltinDataset::BreastCancer,
613 "digits" => BuiltinDataset::Digits,
614 _ => {
615 return Err(DataError::dataset(
616 crate::error::DatasetErrorKind::UnsupportedFormat,
617 format!("Unknown dataset: {}", name),
618 ))
619 }
620 };
621
622 load_builtin_dataset(dataset)
623 }
624}
625
// Unit tests for the built-in loaders, the registry and the synthetic generators.
#[cfg(test)]
mod tests {
    use super::*;

    // Iris is expected to have 150 samples, 4 features and 3 target names.
    #[test]
    fn test_load_iris_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Iris);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 4);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);

        // Metadata must be present.
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 4);
        assert!(feature_names.contains(&"sepal_length".to_string()));

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 3);
    }

    // The Boston data served by the scirs2 loader is expected to have
    // 30 samples and 5 features.
    #[test]
    fn test_load_boston_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Boston);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);

        // Metadata must be present.
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());
    }

    // Diabetes is expected to have 442 samples and 10 named features.
    #[test]
    fn test_load_diabetes_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 442);

        // Metadata must be present.
        assert!(dataset.feature_names.is_some());
        assert!(!dataset.description.is_empty());

        let feature_names = dataset.feature_names.unwrap();
        assert_eq!(feature_names.len(), 10);
        // Spot-check two well-known feature names.
        assert!(feature_names.contains(&"age".to_string()));
        assert!(feature_names.contains(&"bmi".to_string()));
    }

    // The Breast Cancer data served by the scirs2 loader is expected to have
    // 30 samples, 5 features and the two classes "malignant" / "benign".
    #[test]
    fn test_load_breast_cancer_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 30);

        // Metadata must be present.
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 2);
        assert!(target_names.contains(&"malignant".to_string()));
        assert!(target_names.contains(&"benign".to_string()));
    }

    // The Digits data served by the scirs2 loader is expected to have
    // 50 samples, 16 features and 10 target names.
    #[test]
    fn test_load_digits_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16);
        assert_eq!(dataset.targets.size(0).unwrap(), 50);

        // Metadata must be present.
        assert!(dataset.target_names.is_some());
        assert!(!dataset.description.is_empty());

        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 10);
    }

    // Wine is generated synthetically with 178 samples and 13 features.
    #[test]
    fn test_load_wine_dataset() {
        let result = load_builtin_dataset(BuiltinDataset::Wine);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 178);
        assert_eq!(dataset.features.size(1).unwrap(), 13);
        assert_eq!(dataset.targets.size(0).unwrap(), 178);

        assert!(!dataset.description.is_empty());
    }

    // A registry built with `new()` lists all six built-in datasets.
    #[test]
    fn test_dataset_registry() {
        let registry = DatasetRegistry::new();
        let builtin_datasets = registry.list_builtin();

        assert_eq!(builtin_datasets.len(), 6);
    }

    // Name lookup accepts every known name, is case-insensitive and rejects
    // unknown names.
    #[test]
    fn test_load_by_name() {
        let registry = DatasetRegistry::new();

        // Canonical names, including both spellings of breast cancer.
        assert!(registry.load_by_name("iris").is_ok());
        assert!(registry.load_by_name("boston").is_ok());
        assert!(registry.load_by_name("diabetes").is_ok());
        assert!(registry.load_by_name("wine").is_ok());
        assert!(registry.load_by_name("breast_cancer").is_ok());
        assert!(registry.load_by_name("breastcancer").is_ok());
        assert!(registry.load_by_name("digits").is_ok());

        // Case-insensitivity.
        assert!(registry.load_by_name("IRIS").is_ok());
        assert!(registry.load_by_name("Diabetes").is_ok());

        // Unknown names are errors.
        assert!(registry.load_by_name("unknown").is_err());
    }

    // `make_regression` honours the requested sample and feature counts.
    #[test]
    fn test_make_regression() {
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            n_informative: Some(3),
            noise: Some(0.1),
            bias: Some(1.0),
            random_state: Some(42),
        };

        let result = make_regression(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 100);
        assert_eq!(dataset.features.size(1).unwrap(), 5);
        assert_eq!(dataset.targets.size(0).unwrap(), 100);
    }

    // `make_classification` honours the requested sample and feature counts.
    #[test]
    fn test_make_classification() {
        let config = ClassificationConfig {
            n_samples: 200,
            n_features: 10,
            n_classes: 3,
            n_informative: Some(5),
            random_state: Some(42),
            ..Default::default()
        };

        let result = make_classification(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 200);
        assert_eq!(dataset.features.size(1).unwrap(), 10);
        assert_eq!(dataset.targets.size(0).unwrap(), 200);
    }

    // `make_blobs` honours the requested sample and feature counts.
    #[test]
    fn test_make_blobs() {
        let config = ClusteringConfig {
            n_samples: 150,
            centers: 3,
            n_features: Some(2),
            cluster_std: Some(0.5),
            random_state: Some(42),
            ..Default::default()
        };

        let result = make_blobs(config);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        assert_eq!(dataset.features.size(0).unwrap(), 150);
        assert_eq!(dataset.features.size(1).unwrap(), 2);
        assert_eq!(dataset.targets.size(0).unwrap(), 150);
    }

    // Requesting more informative features than features must be rejected.
    #[test]
    fn test_regression_config_validation() {
        let config = RegressionConfig {
            n_samples: 100,
            n_features: 5,
            // Deliberately larger than `n_features`.
            n_informative: Some(10),
            noise: Some(0.1),
            bias: Some(0.0),
            random_state: Some(42),
        };

        let result = make_regression(config);
        assert!(result.is_err());
    }

    // The Diabetes description coming from scirs2 should mention the dataset.
    #[test]
    fn test_scirs2_integration_diabetes() {
        let result = load_builtin_dataset(BuiltinDataset::Diabetes);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 442);
        assert_eq!(dataset.features.size(1).unwrap(), 10);

        // The description must name the dataset in either capitalization.
        assert!(
            dataset.description.contains("diabetes") || dataset.description.contains("Diabetes")
        );
    }

    // Breast Cancer metadata must survive the scirs2 conversion.
    #[test]
    fn test_scirs2_integration_breast_cancer() {
        let result = load_builtin_dataset(BuiltinDataset::BreastCancer);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 30);
        assert_eq!(dataset.features.size(1).unwrap(), 5);

        // Metadata must be present.
        assert!(dataset.feature_names.is_some());
        assert!(dataset.target_names.is_some());
    }

    // Digits target names must survive the scirs2 conversion.
    #[test]
    fn test_scirs2_integration_digits() {
        let result = load_builtin_dataset(BuiltinDataset::Digits);
        assert!(result.is_ok());
        let dataset = result.unwrap();

        // Shape checks.
        assert_eq!(dataset.features.size(0).unwrap(), 50);
        assert_eq!(dataset.features.size(1).unwrap(), 16);

        assert!(dataset.target_names.is_some());
        let target_names = dataset.target_names.unwrap();
        assert_eq!(target_names.len(), 10);
    }
}