1use crate::cache::{CacheKey, CacheManager};
8use crate::error::{DatasetsError, Result};
9use crate::registry::{DatasetMetadata, DatasetRegistry};
10use crate::utils::Dataset;
11use scirs2_core::ndarray::{Array1, Array2};
12use scirs2_core::random::prelude::*;
13use scirs2_core::random::Uniform;
14
/// Configuration options controlling how real-world datasets are loaded.
#[derive(Debug, Clone)]
pub struct RealWorldConfig {
    /// Serve and store loaded datasets through the on-disk cache.
    pub use_cache: bool,
    /// Allow downloading a dataset that is not available locally.
    pub download_if_missing: bool,
    /// Override for the data directory; `None` uses the default location.
    /// NOTE(review): not consulted by the loaders visible in this file — confirm usage elsewhere.
    pub data_home: Option<String>,
    /// Request a preprocessed variant of the data.
    /// NOTE(review): not consulted by the loaders visible in this file — confirm usage elsewhere.
    pub return_preprocessed: bool,
    /// Optional named subset of a dataset to load.
    /// NOTE(review): not consulted by the loaders visible in this file — confirm usage elsewhere.
    pub subset: Option<String>,
    /// Seed for reproducible sampling.
    /// NOTE(review): the synthetic generators here use `thread_rng()` and do not read this — confirm.
    pub random_state: Option<u64>,
}
31
32impl Default for RealWorldConfig {
33 fn default() -> Self {
34 Self {
35 use_cache: true,
36 download_if_missing: true,
37 data_home: None,
38 return_preprocessed: false,
39 subset: None,
40 random_state: None,
41 }
42 }
43}
44
/// Central loader for the bundled real-world (and synthetic stand-in) datasets.
pub struct RealWorldDatasets {
    // On-disk cache consulted/updated when `config.use_cache` is set.
    cache: CacheManager,
    // Metadata registry backing `get_dataset_info`.
    registry: DatasetRegistry,
    // Loader behavior flags (caching, download policy, etc.).
    config: RealWorldConfig,
}
51
52impl RealWorldDatasets {
53 pub fn new(config: RealWorldConfig) -> Result<Self> {
55 let cache = CacheManager::new()?;
56 let registry = DatasetRegistry::new();
57
58 Ok(Self {
59 cache,
60 registry,
61 config,
62 })
63 }
64
    /// Loads a dataset by its registry name, dispatching to the matching
    /// `load_*` method.
    ///
    /// # Errors
    /// Returns `DatasetsError::NotFound` for an unrecognized `name`; the
    /// individual loaders may fail with cache or download errors.
    pub fn load_dataset(&mut self, name: &str) -> Result<Dataset> {
        match name {
            // Classification datasets.
            "adult" => self.load_adult(),
            "bank_marketing" => self.load_bank_marketing(),
            "credit_approval" => self.load_credit_approval(),
            "german_credit" => self.load_german_credit(),
            "mushroom" => self.load_mushroom(),
            "spam" => self.load_spam(),
            "titanic" => self.load_titanic(),

            // Regression datasets.
            "auto_mpg" => self.load_auto_mpg(),
            "california_housing" => self.load_california_housing(),
            "concrete_strength" => self.load_concrete_strength(),
            "energy_efficiency" => self.load_energy_efficiency(),
            "red_wine_quality" => self.load_red_wine_quality(),
            "white_wine_quality" => self.load_white_wine_quality(),

            // Time-series datasets.
            "air_passengers" => self.load_air_passengers(),
            "bitcoin_prices" => self.load_bitcoin_prices(),
            "electricity_load" => self.load_electricity_load(),
            "stock_prices" => self.load_stock_prices(),

            // Image datasets.
            "cifar10_subset" => self.load_cifar10_subset(),
            "fashion_mnist_subset" => self.load_fashion_mnist_subset(),

            // Text datasets.
            "imdb_reviews" => self.load_imdb_reviews(),
            "news_articles" => self.load_news_articles(),

            // Healthcare datasets.
            "diabetes_readmission" => self.load_diabetes_readmission(),
            "heart_disease" => self.load_heart_disease(),

            // Finance datasets.
            "credit_card_fraud" => self.load_credit_card_fraud(),
            "loan_default" => self.load_loan_default(),
            _ => Err(DatasetsError::NotFound(format!("Unknown dataset: {name}"))),
        }
    }
109
110 pub fn list_datasets(&self) -> Vec<String> {
112 vec![
113 "adult".to_string(),
115 "bank_marketing".to_string(),
116 "credit_approval".to_string(),
117 "german_credit".to_string(),
118 "mushroom".to_string(),
119 "spam".to_string(),
120 "titanic".to_string(),
121 "auto_mpg".to_string(),
123 "california_housing".to_string(),
124 "concrete_strength".to_string(),
125 "energy_efficiency".to_string(),
126 "red_wine_quality".to_string(),
127 "white_wine_quality".to_string(),
128 "air_passengers".to_string(),
130 "bitcoin_prices".to_string(),
131 "electricity_load".to_string(),
132 "stock_prices".to_string(),
133 "cifar10_subset".to_string(),
135 "fashion_mnist_subset".to_string(),
136 "imdb_reviews".to_string(),
138 "news_articles".to_string(),
139 "diabetes_readmission".to_string(),
141 "heart_disease".to_string(),
142 "credit_card_fraud".to_string(),
144 "loan_default".to_string(),
145 ]
146 }
147
    /// Looks up registry metadata for `name` without loading any data.
    ///
    /// # Errors
    /// Propagates the registry's error when `name` is unknown.
    pub fn get_dataset_info(&self, name: &str) -> Result<DatasetMetadata> {
        self.registry.get_metadata(name)
    }
152}
153
154impl RealWorldDatasets {
156 pub fn load_adult(&mut self) -> Result<Dataset> {
159 let cache_key = CacheKey::new("adult", &self.config);
160
161 if self.config.use_cache {
162 if let Some(dataset) = self.cache.get(&cache_key)? {
163 return Ok(dataset);
164 }
165 }
166
167 let url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data";
168 let dataset = self.download_and_parse_csv(
169 url,
170 "adult",
171 &[
172 "age",
173 "workclass",
174 "fnlwgt",
175 "education",
176 "education_num",
177 "marital_status",
178 "occupation",
179 "relationship",
180 "race",
181 "sex",
182 "capital_gain",
183 "capital_loss",
184 "hours_per_week",
185 "native_country",
186 "income",
187 ],
188 Some("income"),
189 true, )?;
191
192 if self.config.use_cache {
193 self.cache.put(&cache_key, &dataset)?;
194 }
195
196 Ok(dataset)
197 }
198
199 pub fn load_bank_marketing(&mut self) -> Result<Dataset> {
202 let cache_key = CacheKey::new("bank_marketing", &self.config);
203
204 if self.config.use_cache {
205 if let Some(dataset) = self.cache.get(&cache_key)? {
206 return Ok(dataset);
207 }
208 }
209
210 let (data, target) = self.create_synthetic_bank_data(4521, 16)?;
213
214 let metadata = DatasetMetadata {
215 name: "Bank Marketing".to_string(),
216 description: "Direct marketing campaigns of a Portuguese banking institution"
217 .to_string(),
218 n_samples: 4521,
219 n_features: 16,
220 task_type: "classification".to_string(),
221 targetnames: Some(vec!["no".to_string(), "yes".to_string()]),
222 featurenames: None,
223 url: None,
224 checksum: None,
225 };
226
227 let dataset = Dataset::from_metadata(data, Some(target), metadata);
228
229 if self.config.use_cache {
230 self.cache.put(&cache_key, &dataset)?;
231 }
232
233 Ok(dataset)
234 }
235
236 pub fn load_titanic(&mut self) -> Result<Dataset> {
239 let cache_key = CacheKey::new("titanic", &self.config);
240
241 if self.config.use_cache {
242 if let Some(dataset) = self.cache.get(&cache_key)? {
243 return Ok(dataset);
244 }
245 }
246
247 let (data, target) = self.create_synthetic_titanic_data(891, 7)?;
248
249 let metadata = DatasetMetadata {
250 name: "Titanic".to_string(),
251 description: "Passenger survival data from the Titanic disaster".to_string(),
252 n_samples: 891,
253 n_features: 7,
254 task_type: "classification".to_string(),
255 targetnames: Some(vec!["died".to_string(), "survived".to_string()]),
256 featurenames: None,
257 url: None,
258 checksum: None,
259 };
260
261 let dataset = Dataset::from_metadata(data, Some(target), metadata);
262
263 if self.config.use_cache {
264 self.cache.put(&cache_key, &dataset)?;
265 }
266
267 Ok(dataset)
268 }
269
270 pub fn load_german_credit(&mut self) -> Result<Dataset> {
273 let (data, target) = self.create_synthetic_credit_data(1000, 20)?;
274
275 let metadata = DatasetMetadata {
276 name: "German Credit".to_string(),
277 description: "Credit risk classification dataset".to_string(),
278 n_samples: 1000,
279 n_features: 20,
280 task_type: "classification".to_string(),
281 targetnames: Some(vec!["bad_credit".to_string(), "good_credit".to_string()]),
282 featurenames: None,
283 url: None,
284 checksum: None,
285 };
286
287 Ok(Dataset::from_metadata(data, Some(target), metadata))
288 }
289}
290
291impl RealWorldDatasets {
293 pub fn load_california_housing(&mut self) -> Result<Dataset> {
296 let (data, target) = self.create_synthetic_housing_data(20640, 8)?;
297
298 let metadata = DatasetMetadata {
299 name: "California Housing".to_string(),
300 description: "Median house values for California districts from 1990 census"
301 .to_string(),
302 n_samples: 20640,
303 n_features: 8,
304 task_type: "regression".to_string(),
305 targetnames: None, featurenames: None,
307 url: None,
308 checksum: None,
309 };
310
311 Ok(Dataset::from_metadata(data, Some(target), metadata))
312 }
313
314 pub fn load_red_wine_quality(&mut self) -> Result<Dataset> {
317 let (data, target) = self.create_synthetic_wine_data(1599, 11)?;
318
319 let metadata = DatasetMetadata {
320 name: "Red Wine Quality".to_string(),
321 description: "Red wine quality based on physicochemical tests".to_string(),
322 n_samples: 1599,
323 n_features: 11,
324 task_type: "regression".to_string(),
325 targetnames: None, featurenames: None,
327 url: None,
328 checksum: None,
329 };
330
331 Ok(Dataset::from_metadata(data, Some(target), metadata))
332 }
333
334 pub fn load_energy_efficiency(&mut self) -> Result<Dataset> {
337 let (data, target) = self.create_synthetic_energy_data(768, 8)?;
338
339 let metadata = DatasetMetadata {
340 name: "Energy Efficiency".to_string(),
341 description: "Energy efficiency of buildings based on building parameters".to_string(),
342 n_samples: 768,
343 n_features: 8,
344 task_type: "regression".to_string(),
345 targetnames: None, featurenames: None,
347 url: None,
348 checksum: None,
349 };
350
351 Ok(Dataset::from_metadata(data, Some(target), metadata))
352 }
353}
354
355impl RealWorldDatasets {
357 pub fn load_air_passengers(&mut self) -> Result<Dataset> {
360 let (data, target) = self.create_air_passengers_data(144)?;
361
362 let metadata = DatasetMetadata {
363 name: "Air Passengers".to_string(),
364 description: "Monthly airline passenger numbers 1949-1960".to_string(),
365 n_samples: 144,
366 n_features: 1,
367 task_type: "time_series".to_string(),
368 targetnames: None, featurenames: None,
370 url: None,
371 checksum: None,
372 };
373
374 Ok(Dataset::from_metadata(data, target, metadata))
375 }
376
377 pub fn load_bitcoin_prices(&mut self) -> Result<Dataset> {
380 let (data, target) = self.create_bitcoin_price_data(1000)?;
381
382 let metadata = DatasetMetadata {
383 name: "Bitcoin Prices".to_string(),
384 description: "Historical Bitcoin price data with technical indicators".to_string(),
385 n_samples: 1000,
386 n_features: 6,
387 task_type: "time_series".to_string(),
388 targetnames: None, featurenames: None,
390 url: None,
391 checksum: None,
392 };
393
394 Ok(Dataset::from_metadata(data, target, metadata))
395 }
396}
397
398impl RealWorldDatasets {
400 pub fn load_heart_disease(&mut self) -> Result<Dataset> {
403 let (data, target) = self.create_heart_disease_data(303, 13)?;
404
405 let metadata = DatasetMetadata {
406 name: "Heart Disease".to_string(),
407 description: "Heart disease prediction based on clinical parameters".to_string(),
408 n_samples: 303,
409 n_features: 13,
410 task_type: "classification".to_string(),
411 targetnames: Some(vec!["no_disease".to_string(), "disease".to_string()]),
412 featurenames: None,
413 url: None,
414 checksum: None,
415 };
416
417 Ok(Dataset::from_metadata(data, Some(target), metadata))
418 }
419
420 pub fn load_diabetes_readmission(&mut self) -> Result<Dataset> {
423 let (data, target) = self.create_diabetes_readmission_data(101766, 49)?;
424
425 let metadata = DatasetMetadata {
426 name: "Diabetes Readmission".to_string(),
427 description: "Hospital readmission prediction for diabetic patients".to_string(),
428 n_samples: 101766,
429 n_features: 49,
430 task_type: "classification".to_string(),
431 targetnames: Some(vec![
432 "no_readmission".to_string(),
433 "readmission".to_string(),
434 ]),
435 featurenames: None,
436 url: None,
437 checksum: None,
438 };
439
440 Ok(Dataset::from_metadata(data, Some(target), metadata))
441 }
442
443 pub fn load_credit_approval(&mut self) -> Result<Dataset> {
445 let cache_key = CacheKey::new("credit_approval", &self.config);
446
447 if self.config.use_cache {
448 if let Some(dataset) = self.cache.get(&cache_key)? {
449 return Ok(dataset);
450 }
451 }
452
453 let url =
455 "https://archive.ics.uci.edu/ml/machine-learning-databases/credit-screening/crx.data";
456 let columns = &[
457 "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10", "A11", "A12", "A13",
458 "A14", "A15", "class",
459 ];
460
461 let dataset =
462 self.download_and_parse_csv(url, "credit_approval", columns, Some("class"), true)?;
463
464 if self.config.use_cache {
465 self.cache.put(&cache_key, &dataset)?;
466 }
467
468 Ok(dataset)
469 }
470
471 pub fn load_mushroom(&mut self) -> Result<Dataset> {
473 let cache_key = CacheKey::new("mushroom", &self.config);
474
475 if self.config.use_cache {
476 if let Some(dataset) = self.cache.get(&cache_key)? {
477 return Ok(dataset);
478 }
479 }
480
481 let url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data";
483 let columns = &[
484 "class",
485 "cap-shape",
486 "cap-surface",
487 "cap-color",
488 "bruises",
489 "odor",
490 "gill-attachment",
491 "gill-spacing",
492 "gill-size",
493 "gill-color",
494 "stalk-shape",
495 "stalk-root",
496 "stalk-surface-above-ring",
497 "stalk-surface-below-ring",
498 "stalk-color-above-ring",
499 "stalk-color-below-ring",
500 "veil-type",
501 "veil-color",
502 "ring-number",
503 "ring-type",
504 "spore-print-color",
505 "population",
506 "habitat",
507 ];
508
509 let dataset = self.download_and_parse_csv(url, "mushroom", columns, Some("class"), true)?;
510
511 if self.config.use_cache {
512 self.cache.put(&cache_key, &dataset)?;
513 }
514
515 Ok(dataset)
516 }
517
518 pub fn load_spam(&mut self) -> Result<Dataset> {
520 let cache_key = CacheKey::new("spam", &self.config);
521
522 if self.config.use_cache {
523 if let Some(dataset) = self.cache.get(&cache_key)? {
524 return Ok(dataset);
525 }
526 }
527
528 let url =
530 "https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.data";
531 let mut columns: Vec<String> = Vec::new();
532
533 for i in 0..48 {
535 columns.push(format!("word_freq_{i}"));
536 }
537 for i in 0..6 {
538 columns.push(format!("char_freq_{i}"));
539 }
540 columns.push("capital_run_length_average".to_string());
541 columns.push("capital_run_length_longest".to_string());
542 columns.push("capital_run_length_total".to_string());
543 columns.push("spam".to_string());
544
545 let column_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
546
547 let dataset =
548 self.download_and_parse_csv(url, "spam", &column_refs, Some("spam"), false)?;
549
550 if self.config.use_cache {
551 self.cache.put(&cache_key, &dataset)?;
552 }
553
554 Ok(dataset)
555 }
556
557 pub fn load_auto_mpg(&mut self) -> Result<Dataset> {
559 let cache_key = CacheKey::new("auto_mpg", &self.config);
560
561 if self.config.use_cache {
562 if let Some(dataset) = self.cache.get(&cache_key)? {
563 return Ok(dataset);
564 }
565 }
566
567 let (data, target) = self.create_synthetic_auto_mpg_data(392, 7)?;
568
569 let metadata = DatasetMetadata {
570 name: "Auto MPG".to_string(),
571 description:
572 "Predict car fuel efficiency (miles per gallon) from technical specifications"
573 .to_string(),
574 n_samples: 392,
575 n_features: 7,
576 task_type: "regression".to_string(),
577 targetnames: None, featurenames: None,
579 url: None,
580 checksum: None,
581 };
582
583 let dataset = Dataset::from_metadata(data, Some(target), metadata);
584
585 if self.config.use_cache {
586 self.cache.put(&cache_key, &dataset)?;
587 }
588
589 Ok(dataset)
590 }
591
592 pub fn load_concrete_strength(&mut self) -> Result<Dataset> {
594 let cache_key = CacheKey::new("concrete_strength", &self.config);
595
596 if self.config.use_cache {
597 if let Some(dataset) = self.cache.get(&cache_key)? {
598 return Ok(dataset);
599 }
600 }
601
602 let (data, target) = self.create_synthetic_concrete_data(1030, 8)?;
603
604 let metadata = DatasetMetadata {
605 name: "Concrete Compressive Strength".to_string(),
606 description: "Predict concrete compressive strength from mixture components"
607 .to_string(),
608 n_samples: 1030,
609 n_features: 8,
610 task_type: "regression".to_string(),
611 targetnames: None, featurenames: None,
613 url: None,
614 checksum: None,
615 };
616
617 let dataset = Dataset::from_metadata(data, Some(target), metadata);
618
619 if self.config.use_cache {
620 self.cache.put(&cache_key, &dataset)?;
621 }
622
623 Ok(dataset)
624 }
625
626 pub fn load_white_wine_quality(&mut self) -> Result<Dataset> {
628 let cache_key = CacheKey::new("white_wine_quality", &self.config);
629
630 if self.config.use_cache {
631 if let Some(dataset) = self.cache.get(&cache_key)? {
632 return Ok(dataset);
633 }
634 }
635
636 let (data, target) = self.create_synthetic_wine_data(4898, 11)?;
637
638 let metadata = DatasetMetadata {
639 name: "White Wine Quality".to_string(),
640 description: "White wine quality based on physicochemical tests".to_string(),
641 n_samples: 4898,
642 n_features: 11,
643 task_type: "regression".to_string(),
644 targetnames: None, featurenames: None,
646 url: None,
647 checksum: None,
648 };
649
650 let dataset = Dataset::from_metadata(data, Some(target), metadata);
651
652 if self.config.use_cache {
653 self.cache.put(&cache_key, &dataset)?;
654 }
655
656 Ok(dataset)
657 }
658
659 pub fn load_electricity_load(&mut self) -> Result<Dataset> {
661 let cache_key = CacheKey::new("electricity_load", &self.config);
662
663 if self.config.use_cache {
664 if let Some(dataset) = self.cache.get(&cache_key)? {
665 return Ok(dataset);
666 }
667 }
668
669 let (data, target) = self.create_synthetic_electricity_data(26304, 3)?; let metadata = DatasetMetadata {
672 name: "Electricity Load".to_string(),
673 description: "Hourly electricity consumption forecasting with weather factors"
674 .to_string(),
675 n_samples: 26304,
676 n_features: 3,
677 task_type: "time_series".to_string(),
678 targetnames: None, featurenames: None,
680 url: None,
681 checksum: None,
682 };
683
684 let dataset = Dataset::from_metadata(data, Some(target), metadata);
685
686 if self.config.use_cache {
687 self.cache.put(&cache_key, &dataset)?;
688 }
689
690 Ok(dataset)
691 }
692
693 pub fn load_stock_prices(&mut self) -> Result<Dataset> {
695 let cache_key = CacheKey::new("stock_prices", &self.config);
696
697 if self.config.use_cache {
698 if let Some(dataset) = self.cache.get(&cache_key)? {
699 return Ok(dataset);
700 }
701 }
702
703 let (data, target) = self.create_synthetic_stock_data(1260, 5)?; let metadata = DatasetMetadata {
706 name: "Stock Prices".to_string(),
707 description: "Daily stock price prediction with technical indicators".to_string(),
708 n_samples: 1260,
709 n_features: 5,
710 task_type: "time_series".to_string(),
711 targetnames: None, featurenames: None,
713 url: None,
714 checksum: None,
715 };
716
717 let dataset = Dataset::from_metadata(data, Some(target), metadata);
718
719 if self.config.use_cache {
720 self.cache.put(&cache_key, &dataset)?;
721 }
722
723 Ok(dataset)
724 }
725
726 pub fn load_cifar10_subset(&mut self) -> Result<Dataset> {
728 let cache_key = CacheKey::new("cifar10_subset", &self.config);
729
730 if self.config.use_cache {
731 if let Some(dataset) = self.cache.get(&cache_key)? {
732 return Ok(dataset);
733 }
734 }
735
736 let (data, target) = self.create_synthetic_cifar10_data(1000, 3072)?; let metadata = DatasetMetadata {
739 name: "CIFAR-10 Subset".to_string(),
740 description: "Subset of CIFAR-10 32x32 color images in 10 classes".to_string(),
741 n_samples: 1000,
742 n_features: 3072,
743 task_type: "classification".to_string(),
744 targetnames: Some(vec![
745 "airplane".to_string(),
746 "automobile".to_string(),
747 "bird".to_string(),
748 "cat".to_string(),
749 "deer".to_string(),
750 "dog".to_string(),
751 "frog".to_string(),
752 "horse".to_string(),
753 "ship".to_string(),
754 "truck".to_string(),
755 ]),
756 featurenames: None,
757 url: None,
758 checksum: None,
759 };
760
761 let dataset = Dataset::from_metadata(data, Some(target), metadata);
762
763 if self.config.use_cache {
764 self.cache.put(&cache_key, &dataset)?;
765 }
766
767 Ok(dataset)
768 }
769
770 pub fn load_fashion_mnist_subset(&mut self) -> Result<Dataset> {
772 let cache_key = CacheKey::new("fashion_mnist_subset", &self.config);
773
774 if self.config.use_cache {
775 if let Some(dataset) = self.cache.get(&cache_key)? {
776 return Ok(dataset);
777 }
778 }
779
780 let (data, target) = self.create_synthetic_fashion_mnist_data(1000, 784)?; let metadata = DatasetMetadata {
783 name: "Fashion-MNIST Subset".to_string(),
784 description: "Subset of Fashion-MNIST 28x28 grayscale images of fashion items"
785 .to_string(),
786 n_samples: 1000,
787 n_features: 784,
788 task_type: "classification".to_string(),
789 targetnames: Some(vec![
790 "T-shirt/top".to_string(),
791 "Trouser".to_string(),
792 "Pullover".to_string(),
793 "Dress".to_string(),
794 "Coat".to_string(),
795 "Sandal".to_string(),
796 "Shirt".to_string(),
797 "Sneaker".to_string(),
798 "Bag".to_string(),
799 "Ankle boot".to_string(),
800 ]),
801 featurenames: None,
802 url: None,
803 checksum: None,
804 };
805
806 let dataset = Dataset::from_metadata(data, Some(target), metadata);
807
808 if self.config.use_cache {
809 self.cache.put(&cache_key, &dataset)?;
810 }
811
812 Ok(dataset)
813 }
814
815 pub fn load_imdb_reviews(&mut self) -> Result<Dataset> {
817 let cache_key = CacheKey::new("imdb_reviews", &self.config);
818
819 if self.config.use_cache {
820 if let Some(dataset) = self.cache.get(&cache_key)? {
821 return Ok(dataset);
822 }
823 }
824
825 let (data, target) = self.create_synthetic_imdb_data(5000, 1000)?; let metadata = DatasetMetadata {
828 name: "IMDB Movie Reviews".to_string(),
829 description: "Subset of IMDB movie reviews for sentiment classification".to_string(),
830 n_samples: 5000,
831 n_features: 1000,
832 task_type: "classification".to_string(),
833 targetnames: Some(vec!["negative".to_string(), "positive".to_string()]),
834 featurenames: None,
835 url: None,
836 checksum: None,
837 };
838
839 let dataset = Dataset::from_metadata(data, Some(target), metadata);
840
841 if self.config.use_cache {
842 self.cache.put(&cache_key, &dataset)?;
843 }
844
845 Ok(dataset)
846 }
847
848 pub fn load_news_articles(&mut self) -> Result<Dataset> {
850 let cache_key = CacheKey::new("news_articles", &self.config);
851
852 if self.config.use_cache {
853 if let Some(dataset) = self.cache.get(&cache_key)? {
854 return Ok(dataset);
855 }
856 }
857
858 let (data, target) = self.create_synthetic_news_data(2000, 500)?; let metadata = DatasetMetadata {
861 name: "News Articles".to_string(),
862 description: "News articles categorized by topic for text classification".to_string(),
863 n_samples: 2000,
864 n_features: 500,
865 task_type: "classification".to_string(),
866 targetnames: Some(vec![
867 "business".to_string(),
868 "entertainment".to_string(),
869 "politics".to_string(),
870 "sport".to_string(),
871 "tech".to_string(),
872 ]),
873 featurenames: None,
874 url: None,
875 checksum: None,
876 };
877
878 let dataset = Dataset::from_metadata(data, Some(target), metadata);
879
880 if self.config.use_cache {
881 self.cache.put(&cache_key, &dataset)?;
882 }
883
884 Ok(dataset)
885 }
886
887 pub fn load_credit_card_fraud(&mut self) -> Result<Dataset> {
889 let cache_key = CacheKey::new("credit_card_fraud", &self.config);
890
891 if self.config.use_cache {
892 if let Some(dataset) = self.cache.get(&cache_key)? {
893 return Ok(dataset);
894 }
895 }
896
897 let (data, target) = self.create_synthetic_fraud_data(284807, 28)?;
898
899 let metadata = DatasetMetadata {
900 name: "Credit Card Fraud Detection".to_string(),
901 description: "Detect fraudulent credit card transactions from anonymized features"
902 .to_string(),
903 n_samples: 284807,
904 n_features: 28,
905 task_type: "classification".to_string(),
906 targetnames: Some(vec!["legitimate".to_string(), "fraud".to_string()]),
907 featurenames: None,
908 url: None,
909 checksum: None,
910 };
911
912 let dataset = Dataset::from_metadata(data, Some(target), metadata);
913
914 if self.config.use_cache {
915 self.cache.put(&cache_key, &dataset)?;
916 }
917
918 Ok(dataset)
919 }
920
921 pub fn load_loan_default(&mut self) -> Result<Dataset> {
923 let cache_key = CacheKey::new("loan_default", &self.config);
924
925 if self.config.use_cache {
926 if let Some(dataset) = self.cache.get(&cache_key)? {
927 return Ok(dataset);
928 }
929 }
930
931 let (data, target) = self.create_synthetic_loan_data(10000, 15)?;
932
933 let metadata = DatasetMetadata {
934 name: "Loan Default Prediction".to_string(),
935 description: "Predict loan default risk from borrower characteristics and loan details"
936 .to_string(),
937 n_samples: 10000,
938 n_features: 15,
939 task_type: "classification".to_string(),
940 targetnames: Some(vec!["no_default".to_string(), "default".to_string()]),
941 featurenames: None,
942 url: None,
943 checksum: None,
944 };
945
946 let dataset = Dataset::from_metadata(data, Some(target), metadata);
947
948 if self.config.use_cache {
949 self.cache.put(&cache_key, &dataset)?;
950 }
951
952 Ok(dataset)
953 }
954}
955
956impl RealWorldDatasets {
    /// Downloads and parses a CSV dataset, falling back to synthetic data.
    ///
    /// When the `download` feature is enabled, a real download is attempted
    /// first; on failure a warning is printed and a synthetic dataset with
    /// the same column layout is generated instead. With the feature
    /// disabled, only the synthetic path runs.
    ///
    /// # Errors
    /// Returns `DatasetsError::DownloadError` when downloads are disabled in
    /// the configuration; otherwise propagates generator errors.
    fn download_and_parse_csv(
        &self,
        url: &str,
        name: &str,
        columns: &[&str],
        target_col: Option<&str>,
        has_categorical: bool,
    ) -> Result<Dataset> {
        // Respect the configuration flag before touching the network.
        if !self.config.download_if_missing {
            return Err(DatasetsError::DownloadError(
                "Download disabled in configuration".to_string(),
            ));
        }

        // Try the real download first; a failure falls through to the
        // synthetic branch below rather than erroring out.
        #[cfg(feature = "download")]
        {
            match self.download_real_dataset(url, name, columns, target_col, has_categorical) {
                Ok(dataset) => return Ok(dataset),
                Err(e) => {
                    eprintln!("Warning: Failed to download real dataset from {}: {}. Falling back to synthetic data.", url, e);
                }
            }
        }

        match name {
            // "adult" gets a dedicated generator sized like the real file.
            "adult" => {
                let (data, target) = self.create_synthetic_adult_dataset(32561, 14)?;

                let featurenames = vec![
                    "age".to_string(),
                    "workclass".to_string(),
                    "fnlwgt".to_string(),
                    "education".to_string(),
                    "education_num".to_string(),
                    "marital_status".to_string(),
                    "occupation".to_string(),
                    "relationship".to_string(),
                    "race".to_string(),
                    "sex".to_string(),
                    "capital_gain".to_string(),
                    "capital_loss".to_string(),
                    "hours_per_week".to_string(),
                    "native_country".to_string(),
                ];

                let metadata = crate::registry::DatasetMetadata {
                    name: "Adult Census Income".to_string(),
                    description: "Predict whether income exceeds $50K/yr based on census data"
                        .to_string(),
                    n_samples: 32561,
                    n_features: 14,
                    task_type: "classification".to_string(),
                    targetnames: Some(vec!["<=50K".to_string(), ">50K".to_string()]),
                    featurenames: Some(featurenames),
                    url: Some(url.to_string()),
                    checksum: None,
                };

                Ok(Dataset::from_metadata(data, Some(target), metadata))
            }
            // Every other dataset gets a 1000-sample generic stand-in.
            _ => {
                // The target column (if any) is not a feature.
                let n_features = columns.len() - if target_col.is_some() { 1 } else { 0 };
                let (data, target) =
                    self.create_generic_synthetic_dataset(1000, n_features, has_categorical)?;

                // Feature names are the columns minus the target column.
                let featurenames: Vec<String> = columns
                    .iter()
                    .filter(|&&_col| Some(_col) != target_col)
                    .map(|&_col| _col.to_string())
                    .collect();

                let metadata = crate::registry::DatasetMetadata {
                    name: format!("Synthetic {name}"),
                    description: format!("Synthetic version of {name} dataset"),
                    n_samples: 1000,
                    n_features,
                    // Presence of a target column is used as a proxy for the
                    // task type here.
                    task_type: if target_col.is_some() {
                        "classification"
                    } else {
                        "regression"
                    }
                    .to_string(),
                    targetnames: None,
                    featurenames: Some(featurenames),
                    url: Some(url.to_string()),
                    checksum: None,
                };

                Ok(Dataset::from_metadata(data, target, metadata))
            }
        }
    }
1054
    /// Downloads a CSV file and converts it into a numeric `Dataset`.
    ///
    /// Rows whose field count does not match `columns` are skipped. A single
    /// leading header row is detected heuristically (field names matching
    /// the expected column names, case-insensitively) and dropped. Fields
    /// that fail to parse as `f64` are label-encoded: each column gets its
    /// own category map assigning consecutive ids in first-seen order.
    ///
    /// # Errors
    /// Returns `DatasetsError::FormatError` when a line cannot be read or no
    /// valid data rows are found; propagates download errors.
    #[cfg(feature = "download")]
    fn download_real_dataset(
        &self,
        url: &str,
        name: &str,
        columns: &[&str],
        target_col: Option<&str>,
        _has_categorical: bool,
    ) -> Result<Dataset> {
        use crate::cache::download_data;
        use std::collections::HashMap;
        use std::io::{BufRead, BufReader, Cursor};

        let data_bytes = download_data(url, false)?;

        let cursor = Cursor::new(data_bytes);
        let reader = BufReader::new(cursor);

        let mut rows: Vec<Vec<String>> = Vec::new();
        let mut header_found = false;

        for line_result in reader.lines() {
            let line = line_result
                .map_err(|e| DatasetsError::FormatError(format!("Failed to read line: {}", e)))?;
            let line = line.trim();

            if line.is_empty() {
                continue;
            }

            // Split on commas and strip surrounding quotes/whitespace.
            let fields: Vec<String> = line
                .split(',')
                .map(|s| s.trim().trim_matches('"').to_string())
                .collect();

            // Heuristic header detection: every field must overlap with the
            // expected column name (in either containment direction).
            if !header_found && fields.len() == columns.len() {
                let is_header = fields.iter().enumerate().all(|(i, field)| {
                    field.to_lowercase().contains(&columns[i].to_lowercase())
                        || columns[i].to_lowercase().contains(&field.to_lowercase())
                });
                if is_header {
                    header_found = true;
                    continue;
                }
            }

            // Keep only rows with the expected number of fields.
            if fields.len() == columns.len() {
                rows.push(fields);
            }
        }

        if rows.is_empty() {
            return Err(DatasetsError::FormatError(
                "No valid data rows found in CSV".to_string(),
            ));
        }

        let n_samples = rows.len();
        // The target column (if any) is excluded from the feature matrix.
        let n_features = if target_col.is_some() {
            columns.len() - 1
        } else {
            columns.len()
        };

        let mut data = Array2::<f64>::zeros((n_samples, n_features));
        let mut target = if target_col.is_some() {
            Some(Array1::<f64>::zeros(n_samples))
        } else {
            None
        };

        // Per-column label-encoding maps for non-numeric fields.
        let mut category_maps: HashMap<usize, HashMap<String, f64>> = HashMap::new();

        for (row_idx, row) in rows.iter().enumerate() {
            let mut feature_idx = 0;

            for (col_idx, value) in row.iter().enumerate() {
                if Some(columns[col_idx]) == target_col {
                    if let Some(ref mut target_array) = target {
                        // Numeric target used as-is; otherwise label-encode.
                        let numeric_value = match value.parse::<f64>() {
                            Ok(v) => v,
                            Err(_) => {
                                let category_map = category_maps.entry(col_idx).or_default();
                                let next_id = category_map.len() as f64;
                                *category_map.entry(value.clone()).or_insert(next_id)
                            }
                        };
                        target_array[row_idx] = numeric_value;
                    }
                } else {
                    // Numeric feature used as-is; otherwise label-encode.
                    let numeric_value = match value.parse::<f64>() {
                        Ok(v) => v,
                        Err(_) => {
                            let category_map = category_maps.entry(col_idx).or_default();
                            let next_id = category_map.len() as f64;
                            *category_map.entry(value.clone()).or_insert(next_id)
                        }
                    };
                    data[[row_idx, feature_idx]] = numeric_value;
                    feature_idx += 1;
                }
            }
        }

        // Feature names are the columns minus the target column.
        let featurenames: Vec<String> = columns
            .iter()
            .filter(|&&_col| Some(_col) != target_col)
            .map(|&col| col.to_string())
            .collect();

        let metadata = crate::registry::DatasetMetadata {
            name: name.to_string(),
            description: format!("Real-world dataset: {}", name),
            n_samples,
            n_features,
            // Presence of a target is used as a proxy for the task type.
            task_type: if target.is_some() {
                "classification".to_string()
            } else {
                "unsupervised".to_string()
            },
            targetnames: None,
            featurenames: Some(featurenames),
            url: Some(url.to_string()),
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, target, metadata))
    }
1196
1197 fn create_synthetic_bank_data(
1198 &self,
1199 n_samples: usize,
1200 n_features: usize,
1201 ) -> Result<(Array2<f64>, Array1<f64>)> {
1202 use scirs2_core::random::Rng;
1203 let mut rng = thread_rng();
1204
1205 let mut data = Array2::zeros((n_samples, n_features));
1206 let mut target = Array1::zeros(n_samples);
1207
1208 for i in 0..n_samples {
1209 for j in 0..n_features {
1210 data[[i, j]] = rng.gen_range(0.0..1.0);
1211 }
1212 target[i] = if data.row(i).iter().take(3).sum::<f64>() > 1.5 {
1214 1.0
1215 } else {
1216 0.0
1217 };
1218 }
1219
1220 Ok((data, target))
1221 }
1222
    /// Builds a synthetic credit-approval dataset (690 samples, 15 features).
    ///
    /// Features are drawn uniformly from plausible ranges; the binary target
    /// is derived from a weighted "approval score" of a few features plus
    /// uniform noise, thresholded at 0.5.
    #[allow(dead_code)]
    fn create_synthetic_credit_approval_data(&self) -> Result<Dataset> {
        use scirs2_core::random::Rng;
        let mut rng = thread_rng();

        let n_samples = 690;
        let n_features = 15;

        let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        let featurenames = vec![
            "credit_score".to_string(),
            "annual_income".to_string(),
            "debt_to_income_ratio".to_string(),
            "employment_length".to_string(),
            "age".to_string(),
            "home_ownership".to_string(),
            "loan_amount".to_string(),
            "loan_purpose".to_string(),
            "credit_history_length".to_string(),
            "number_of_credit_lines".to_string(),
            "utilization_rate".to_string(),
            "delinquency_count".to_string(),
            "education_level".to_string(),
            "marital_status".to_string(),
            "verification_status".to_string(),
        ];

        for i in 0..n_samples {
            // Draw each feature from its range (order matches featurenames);
            // `.floor()` turns some draws into small categorical codes.
            data[[i, 0]] = rng.gen_range(300.0..850.0); // credit_score
            data[[i, 1]] = rng.gen_range(20000.0..200000.0); // annual_income
            data[[i, 2]] = rng.gen_range(0.0..0.6); // debt_to_income_ratio
            data[[i, 3]] = rng.gen_range(0.0..30.0); // employment_length
            data[[i, 4]] = rng.gen_range(18.0..80.0); // age
            data[[i, 5]] = rng.gen_range(0.0f64..3.0).floor(); // home_ownership (categorical)
            data[[i, 6]] = rng.gen_range(1000.0..50000.0); // loan_amount
            data[[i, 7]] = rng.gen_range(0.0f64..7.0).floor(); // loan_purpose (categorical)
            data[[i, 8]] = rng.gen_range(0.0..40.0); // credit_history_length
            data[[i, 9]] = rng.gen_range(0.0..20.0); // number_of_credit_lines
            data[[i, 10]] = rng.gen_range(0.0..1.0); // utilization_rate
            data[[i, 11]] = rng.gen_range(0.0f64..11.0).floor(); // delinquency_count
            data[[i, 12]] = rng.gen_range(0.0f64..4.0).floor(); // education_level (categorical)
            data[[i, 13]] = rng.gen_range(0.0f64..3.0).floor(); // marital_status (categorical)
            data[[i, 14]] = if rng.gen_bool(0.7) { 1.0 } else { 0.0 }; // verification_status

            // Normalize a few features into [0, 1]-ish factors and combine
            // them into a weighted approval score.
            let credit_score_factor = (data[[i, 0]] - 300.0) / 550.0;
            let income_factor = (data[[i, 1]] / 100000.0).min(1.0);
            let debt_factor = 1.0 - data[[i, 2]];
            let employment_factor = (data[[i, 3]] / 10.0).min(1.0);
            let delinquency_penalty = data[[i, 11]] * 0.1;
            let approval_score = credit_score_factor * 0.4
                + income_factor * 0.3
                + debt_factor * 0.2
                + employment_factor * 0.1
                - delinquency_penalty;

            // Noisy threshold decides approval.
            let noise = rng.gen_range(-0.2..0.2);
            target[i] = if (approval_score + noise) > 0.5 {
                1.0
            } else {
                0.0
            };
        }

        let metadata = crate::registry::DatasetMetadata {
            name: "Credit Approval Dataset".to_string(),
            description: "Synthetic credit approval dataset with realistic financial features for binary classification".to_string(),
            n_samples,
            n_features,
            task_type: "classification".to_string(),
            targetnames: Some(vec!["denied".to_string(), "approved".to_string()]),
            featurenames: Some(featurenames),
            url: None,
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, Some(target), metadata))
    }
1320
    /// Build a fully synthetic stand-in for the UCI "Mushroom" dataset:
    /// 8124 samples x 22 integer-coded morphological features with a binary
    /// edible/poisonous label.
    ///
    /// The label is driven mainly by the odor code (col 4), with smaller
    /// contributions from spore print color (col 19) and stalk root (col 10),
    /// plus uniform noise around the 0.5 decision threshold.
    #[allow(dead_code)]
    fn create_synthetic_mushroom_data(&self) -> Result<Dataset> {
        use scirs2_core::random::Rng;
        let mut rng = thread_rng();

        // Dimensions match the real UCI Mushroom dataset (8124 x 22).
        let n_samples = 8124; let n_features = 22;

        let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        let featurenames = vec![
            "capshape".to_string(),
            "cap_surface".to_string(),
            "cap_color".to_string(),
            "bruises".to_string(),
            "odor".to_string(),
            "gill_attachment".to_string(),
            "gill_spacing".to_string(),
            "gill_size".to_string(),
            "gill_color".to_string(),
            "stalkshape".to_string(),
            "stalk_root".to_string(),
            "stalk_surface_above_ring".to_string(),
            "stalk_surface_below_ring".to_string(),
            "stalk_color_above_ring".to_string(),
            "stalk_color_below_ring".to_string(),
            "veil_type".to_string(),
            "veil_color".to_string(),
            "ring_number".to_string(),
            "ring_type".to_string(),
            "spore_print_color".to_string(),
            "population".to_string(),
            "habitat".to_string(),
        ];

        for i in 0..n_samples {
            // Multi-valued categorical features are integer codes (`floor` of
            // a uniform draw); two-valued features use `gen_bool`.
            data[[i, 0]] = rng.gen_range(0.0f64..6.0).floor();
            data[[i, 1]] = rng.gen_range(0.0f64..4.0).floor();
            data[[i, 2]] = rng.gen_range(0.0f64..10.0).floor();
            data[[i, 3]] = if rng.gen_bool(0.6) { 1.0 } else { 0.0 };
            data[[i, 4]] = rng.gen_range(0.0f64..9.0).floor();
            data[[i, 5]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 };
            data[[i, 6]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 };
            data[[i, 7]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 };
            data[[i, 8]] = rng.gen_range(0.0f64..12.0).floor();
            data[[i, 9]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 };
            data[[i, 10]] = rng.gen_range(0.0f64..5.0).floor();
            data[[i, 11]] = rng.gen_range(0.0f64..4.0).floor();
            data[[i, 12]] = rng.gen_range(0.0f64..4.0).floor();
            data[[i, 13]] = rng.gen_range(0.0f64..9.0).floor();
            data[[i, 14]] = rng.gen_range(0.0f64..9.0).floor();
            // Veil type: constant zero for every sample.
            data[[i, 15]] = 0.0;
            data[[i, 16]] = rng.gen_range(0.0f64..4.0).floor();
            data[[i, 17]] = rng.gen_range(0.0f64..3.0).floor();
            data[[i, 18]] = rng.gen_range(0.0f64..8.0).floor();
            data[[i, 19]] = rng.gen_range(0.0f64..9.0).floor();
            data[[i, 20]] = rng.gen_range(0.0f64..6.0).floor();
            data[[i, 21]] = rng.gen_range(0.0f64..7.0).floor();

            // Accumulate poison evidence from the indicative feature codes.
            let mut poison_score = 0.0;

            // Odor codes 2-4: strong poison indicators.
            if data[[i, 4]] == 2.0 || data[[i, 4]] == 3.0 || data[[i, 4]] == 4.0 {
                poison_score += 0.8;
            }
            // Odor codes 5 and 7: weaker indicators.
            if data[[i, 4]] == 5.0 || data[[i, 4]] == 7.0 {
                poison_score += 0.4;
            }

            // Certain spore-print-color codes add moderate evidence.
            if data[[i, 19]] == 2.0 || data[[i, 19]] == 4.0 {
                poison_score += 0.3;
            }

            // Stalk-root code 0 adds slight evidence.
            if data[[i, 10]] == 0.0 {
                poison_score += 0.2;
            }

            // Uniform noise fuzzes the boundary; scores above 0.5 → poisonous.
            let noise = rng.gen_range(-0.3..0.3);
            target[i] = if (poison_score + noise) > 0.5 {
                1.0
            } else {
                0.0
            }; }

        let metadata = crate::registry::DatasetMetadata {
            name: "Mushroom Dataset".to_string(),
            description: "Synthetic mushroom classification dataset with morphological features for edibility prediction".to_string(),
            n_samples,
            n_features,
            task_type: "classification".to_string(),
            targetnames: Some(vec!["edible".to_string(), "poisonous".to_string()]),
            featurenames: Some(featurenames),
            url: None,
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, Some(target), metadata))
    }
1452
    /// Build a fully synthetic stand-in for the UCI "Spambase" dataset:
    /// 4601 samples x 57 features (54 word-frequency columns plus 3
    /// character/capital-run statistics) with a binary ham/spam label.
    ///
    /// Frequencies are drawn from wider ranges for the column groups typical
    /// of the sample's class, so the label is learnable but noisy.
    #[allow(dead_code)]
    fn create_synthetic_spam_data(&self) -> Result<Dataset> {
        use scirs2_core::random::Rng;
        let mut rng = thread_rng();

        // Dimensions match the real Spambase dataset (4601 x 57).
        let n_samples = 4601; let n_features = 57; let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        let mut featurenames = Vec::with_capacity(n_features);

        // Vocabulary mirrors the Spambase attribute list.
        let spam_words = vec![
            "make",
            "address",
            "all",
            "3d",
            "our",
            "over",
            "remove",
            "internet",
            "order",
            "mail",
            "receive",
            "will",
            "people",
            "report",
            "addresses",
            "free",
            "business",
            "email",
            "you",
            "credit",
            "your",
            "font",
            "000",
            "money",
            "hp",
            "hpl",
            "george",
            "650",
            "lab",
            "labs",
            "telnet",
            "857",
            "data",
            "415",
            "85",
            "technology",
            "1999",
            "parts",
            "pm",
            "direct",
            "cs",
            "meeting",
            "original",
            "project",
            "re",
            "edu",
            "table",
            "conference",
            "char_freq_semicolon",
            "char_freq_parenthesis",
            "char_freq_bracket",
            "char_freq_exclamation",
            "char_freq_dollar",
            "char_freq_hash",
            "capital_run_length_average",
            "capital_run_length_longest",
            "capital_run_length_total",
        ];

        // Name the word-frequency columns; the `i < n_features` guard protects
        // against a vocabulary longer than the feature count.
        for (i, word) in spam_words.iter().enumerate() {
            if i < n_features {
                featurenames.push(format!("word_freq_{word}"));
            }
        }

        // Pad with generic names if the vocabulary is shorter than n_features.
        while featurenames.len() < n_features {
            featurenames.push(format!("feature_{}", featurenames.len()));
        }

        for i in 0..n_samples {
            // ~40% spam prevalence, roughly matching the real dataset.
            let is_spam = rng.gen_bool(0.4); for j in 0..54 {
                if is_spam {
                    // Spam rows: elevated frequencies in the leading word groups.
                    match j {
                        0..=7 => data[[i, j]] = rng.gen_range(0.0..5.0), 8..=15 => data[[i, j]] = rng.gen_range(0.0..3.0), 16..=25 => data[[i, j]] = rng.gen_range(0.0..4.0), _ => data[[i, j]] = rng.gen_range(0.0..1.0), }
                } else {
                    // Ham rows: frequencies skew toward columns 26-45 instead.
                    match j {
                        26..=35 => data[[i, j]] = rng.gen_range(0.0..2.0), 36..=45 => data[[i, j]] = rng.gen_range(0.0..1.5), _ => data[[i, j]] = rng.gen_range(0.0..0.5), }
                }
            }

            // Columns 54-56: capital-run statistics; spam skews higher.
            if is_spam {
                data[[i, 54]] = rng.gen_range(0.0..0.2); data[[i, 55]] = rng.gen_range(0.0..0.5); data[[i, 56]] = rng.gen_range(0.0..0.3); } else {
                data[[i, 54]] = rng.gen_range(0.0..0.1);
                data[[i, 55]] = rng.gen_range(0.0..0.2);
                data[[i, 56]] = rng.gen_range(0.0..0.1);
            }

            target[i] = if is_spam { 1.0 } else { 0.0 };
        }

        let metadata = crate::registry::DatasetMetadata {
            name: "Spam Email Dataset".to_string(),
            description: "Synthetic spam email classification dataset with word and character frequency features".to_string(),
            n_samples,
            n_features,
            task_type: "classification".to_string(),
            targetnames: Some(vec!["ham".to_string(), "spam".to_string()]),
            featurenames: Some(featurenames),
            url: None,
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, Some(target), metadata))
    }
1590
1591 fn create_synthetic_titanic_data(
1592 &self,
1593 n_samples: usize,
1594 n_features: usize,
1595 ) -> Result<(Array2<f64>, Array1<f64>)> {
1596 use scirs2_core::random::Rng;
1597 let mut rng = thread_rng();
1598
1599 let mut data = Array2::zeros((n_samples, n_features));
1600 let mut target = Array1::zeros(n_samples);
1601
1602 for i in 0..n_samples {
1603 data[[i, 0]] = rng.gen_range(1.0f64..4.0).floor();
1605 data[[i, 1]] = if rng.gen_bool(0.5) { 0.0 } else { 1.0 };
1607 data[[i, 2]] = rng.gen_range(1.0..80.0);
1609 data[[i, 3]] = rng.gen_range(0.0f64..6.0).floor();
1611 data[[i, 4]] = rng.gen_range(0.0f64..4.0).floor();
1613 data[[i, 5]] = rng.gen_range(0.0..512.0);
1615 data[[i, 6]] = rng.gen_range(0.0f64..3.0).floor();
1617
1618 let survival_score = (4.0 - data[[i, 0]]) * 0.3 + (1.0 - data[[i, 1]]) * 0.4 + (80.0 - data[[i, 2]]) / 80.0 * 0.3; target[i] = if survival_score > 0.5 { 1.0 } else { 0.0 };
1624 }
1625
1626 Ok((data, target))
1627 }
1628
1629 fn create_synthetic_credit_data(
1630 &self,
1631 n_samples: usize,
1632 n_features: usize,
1633 ) -> Result<(Array2<f64>, Array1<f64>)> {
1634 use scirs2_core::random::Rng;
1635 let mut rng = thread_rng();
1636
1637 let mut data = Array2::zeros((n_samples, n_features));
1638 let mut target = Array1::zeros(n_samples);
1639
1640 for i in 0..n_samples {
1641 for j in 0..n_features {
1642 data[[i, j]] = rng.gen_range(0.0..1.0);
1643 }
1644 let score = data.row(i).iter().sum::<f64>() / n_features as f64;
1646 target[i] = if score > 0.6 { 1.0 } else { 0.0 };
1647 }
1648
1649 Ok((data, target))
1650 }
1651
1652 fn create_synthetic_housing_data(
1653 &self,
1654 n_samples: usize,
1655 n_features: usize,
1656 ) -> Result<(Array2<f64>, Array1<f64>)> {
1657 use scirs2_core::random::Rng;
1658 let mut rng = thread_rng();
1659
1660 let mut data = Array2::zeros((n_samples, n_features));
1661 let mut target = Array1::zeros(n_samples);
1662
1663 for i in 0..n_samples {
1664 data[[i, 0]] = rng.gen_range(0.5..15.0);
1666 data[[i, 1]] = rng.gen_range(1.0..52.0);
1668 data[[i, 2]] = rng.gen_range(3.0..20.0);
1670 data[[i, 3]] = rng.gen_range(0.8..6.0);
1672 data[[i, 4]] = rng.gen_range(3.0..35682.0);
1674 data[[i, 5]] = rng.gen_range(0.7..1243.0);
1676 data[[i, 6]] = rng.gen_range(32.0..42.0);
1678 data[[i, 7]] = rng.gen_range(-124.0..-114.0);
1680
1681 let house_value = data[[i, 0]] * 50000.0 + data[[i, 2]] * 10000.0 + (40.0 - data[[i, 6]]) * 5000.0; target[i] = house_value / 100000.0; }
1688
1689 Ok((data, target))
1690 }
1691
1692 fn create_synthetic_wine_data(
1693 &self,
1694 n_samples: usize,
1695 n_features: usize,
1696 ) -> Result<(Array2<f64>, Array1<f64>)> {
1697 use scirs2_core::random::Rng;
1698 let mut rng = thread_rng();
1699
1700 let mut data = Array2::zeros((n_samples, n_features));
1701 let mut target = Array1::zeros(n_samples);
1702
1703 for i in 0..n_samples {
1704 data[[i, 0]] = rng.gen_range(4.6..15.9); data[[i, 1]] = rng.gen_range(0.12..1.58); data[[i, 2]] = rng.gen_range(0.0..1.0); data[[i, 3]] = rng.gen_range(0.9..15.5); data[[i, 4]] = rng.gen_range(0.012..0.611); data[[i, 5]] = rng.gen_range(1.0..72.0); data[[i, 6]] = rng.gen_range(6.0..289.0); data[[i, 7]] = rng.gen_range(0.99007..1.00369); data[[i, 8]] = rng.gen_range(2.74..4.01); data[[i, 9]] = rng.gen_range(0.33..2.0); data[[i, 10]] = rng.gen_range(8.4..14.9); let quality: f64 = 3.0 +
1719 (data[[i, 10]] - 8.0) * 0.5 + (1.0 - data[[i, 1]]) * 2.0 + data[[i, 2]] * 2.0 + rng.gen_range(-0.5..0.5); target[i] = quality.clamp(3.0, 8.0);
1725 }
1726
1727 Ok((data, target))
1728 }
1729
1730 fn create_synthetic_energy_data(
1731 &self,
1732 n_samples: usize,
1733 n_features: usize,
1734 ) -> Result<(Array2<f64>, Array1<f64>)> {
1735 use scirs2_core::random::Rng;
1736 let mut rng = thread_rng();
1737
1738 let mut data = Array2::zeros((n_samples, n_features));
1739 let mut target = Array1::zeros(n_samples);
1740
1741 for i in 0..n_samples {
1742 for j in 0..n_features {
1743 data[[i, j]] = rng.gen_range(0.0..1.0);
1744 }
1745
1746 let efficiency = data.row(i).iter().sum::<f64>() / n_features as f64;
1748 target[i] = efficiency * 40.0 + 10.0; }
1750
1751 Ok((data, target))
1752 }
1753
1754 fn create_air_passengers_data(
1755 &self,
1756 n_timesteps: usize,
1757 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
1758 use scirs2_core::random::Rng;
1759 let mut rng = thread_rng();
1760 let mut data = Array2::zeros((n_timesteps, 1));
1761
1762 for i in 0..n_timesteps {
1763 let t = i as f64;
1764 let trend = 100.0 + t * 2.0;
1765 let seasonal = 20.0 * (2.0 * std::f64::consts::PI * t / 12.0).sin();
1766 let noise = rng.random::<f64>() * 10.0 - 5.0;
1767
1768 data[[i, 0]] = trend + seasonal + noise;
1769 }
1770
1771 Ok((data, None))
1772 }
1773
1774 fn create_bitcoin_price_data(
1775 &self,
1776 n_timesteps: usize,
1777 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
1778 use scirs2_core::random::Rng;
1779 let mut rng = thread_rng();
1780
1781 let mut data = Array2::zeros((n_timesteps, 6));
1782 let mut price = 30000.0; for i in 0..n_timesteps {
1785 let change = rng.gen_range(-0.05..0.05);
1787 price *= 1.0 + change;
1788
1789 let high = price * (1.0 + rng.gen_range(0.0..0.02));
1790 let low = price * (1.0 - rng.gen_range(0.0..0.02));
1791 let volume = rng.gen_range(1000000.0..10000000.0);
1792
1793 data[[i, 0]] = price; data[[i, 1]] = high;
1795 data[[i, 2]] = low;
1796 data[[i, 3]] = price; data[[i, 4]] = volume;
1798 data[[i, 5]] = price * volume; }
1800
1801 Ok((data, None))
1802 }
1803
1804 fn create_heart_disease_data(
1805 &self,
1806 n_samples: usize,
1807 n_features: usize,
1808 ) -> Result<(Array2<f64>, Array1<f64>)> {
1809 use scirs2_core::random::Rng;
1810 let mut rng = thread_rng();
1811
1812 let mut data = Array2::zeros((n_samples, n_features));
1813 let mut target = Array1::zeros(n_samples);
1814
1815 for i in 0..n_samples {
1816 data[[i, 0]] = rng.gen_range(29.0..77.0);
1818 data[[i, 1]] = if rng.gen_bool(0.68) { 1.0 } else { 0.0 };
1820 data[[i, 2]] = rng.gen_range(0.0f64..4.0).floor();
1822 data[[i, 3]] = rng.gen_range(94.0..200.0);
1824 data[[i, 4]] = rng.gen_range(126.0..564.0);
1826
1827 for j in 5..n_features {
1829 data[[i, j]] = rng.gen_range(0.0..1.0);
1830 }
1831
1832 let risk_score = (data[[i, 0]] - 29.0) / 48.0 * 0.3 + data[[i, 1]] * 0.2 + (data[[i, 3]] - 94.0) / 106.0 * 0.2 + (data[[i, 4]] - 126.0) / 438.0 * 0.3; target[i] = if risk_score > 0.5 { 1.0 } else { 0.0 };
1839 }
1840
1841 Ok((data, target))
1842 }
1843
1844 fn create_diabetes_readmission_data(
1845 &self,
1846 n_samples: usize,
1847 n_features: usize,
1848 ) -> Result<(Array2<f64>, Array1<f64>)> {
1849 use scirs2_core::random::Rng;
1850 let mut rng = thread_rng();
1851
1852 let mut data = Array2::zeros((n_samples, n_features));
1853 let mut target = Array1::zeros(n_samples);
1854
1855 for i in 0..n_samples {
1856 for j in 0..n_features {
1857 data[[i, j]] = rng.gen_range(0.0..1.0);
1858 }
1859
1860 let readmission_score = data.row(i).iter().take(10).sum::<f64>() / 10.0;
1862 target[i] = if readmission_score > 0.6 { 1.0 } else { 0.0 };
1863 }
1864
1865 Ok((data, target))
1866 }
1867
1868 fn create_synthetic_auto_mpg_data(
1869 &self,
1870 n_samples: usize,
1871 n_features: usize,
1872 ) -> Result<(Array2<f64>, Array1<f64>)> {
1873 use scirs2_core::random::Rng;
1874 let mut rng = thread_rng();
1875
1876 let mut data = Array2::zeros((n_samples, n_features));
1877 let mut target = Array1::zeros(n_samples);
1878
1879 for i in 0..n_samples {
1880 data[[i, 0]] = [4.0, 6.0, 8.0][rng.sample(Uniform::new(0, 3).unwrap())];
1882 data[[i, 1]] = rng.gen_range(68.0..455.0);
1884 data[[i, 2]] = rng.gen_range(46.0..230.0);
1886 data[[i, 3]] = rng.gen_range(1613.0..5140.0);
1888 data[[i, 4]] = rng.gen_range(8.0..24.8);
1890 data[[i, 5]] = rng.gen_range(70.0..82.0);
1892 data[[i, 6]] = (rng.gen_range(1.0f64..4.0f64)).floor();
1894
1895 let mpg: f64 = 45.0 - (data[[i, 3]] / 5140.0) * 20.0 - (data[[i, 1]] / 455.0) * 15.0
1897 + (data[[i, 4]] / 24.8) * 10.0
1898 + rng.gen_range(-3.0..3.0);
1899 target[i] = mpg.clamp(9.0, 46.6);
1900 }
1901
1902 Ok((data, target))
1903 }
1904
1905 fn create_synthetic_concrete_data(
1906 &self,
1907 n_samples: usize,
1908 n_features: usize,
1909 ) -> Result<(Array2<f64>, Array1<f64>)> {
1910 use scirs2_core::random::Rng;
1911 let mut rng = thread_rng();
1912
1913 let mut data = Array2::zeros((n_samples, n_features));
1914 let mut target = Array1::zeros(n_samples);
1915
1916 for i in 0..n_samples {
1917 data[[i, 0]] = rng.gen_range(102.0..540.0);
1919 data[[i, 1]] = rng.gen_range(0.0..359.4);
1921 data[[i, 2]] = rng.gen_range(0.0..200.1);
1923 data[[i, 3]] = rng.gen_range(121.8..247.0);
1925 data[[i, 4]] = rng.gen_range(0.0..32.2);
1927 data[[i, 5]] = rng.gen_range(801.0..1145.0);
1929 data[[i, 6]] = rng.gen_range(594.0..992.6);
1931 data[[i, 7]] = rng.gen_range(1.0..365.0);
1933
1934 let strength: f64 = (data[[i, 0]] / 540.0) * 30.0 + (data[[i, 1]] / 359.4) * 15.0 + (data[[i, 3]] / 247.0) * (-20.0) + (data[[i, 7]] / 365.0_f64).ln() * 10.0 + rng.gen_range(-5.0..5.0); target[i] = strength.clamp(2.33, 82.6);
1942 }
1943
1944 Ok((data, target))
1945 }
1946
1947 fn create_synthetic_electricity_data(
1948 &self,
1949 n_samples: usize,
1950 n_features: usize,
1951 ) -> Result<(Array2<f64>, Array1<f64>)> {
1952 use scirs2_core::random::Rng;
1953 let mut rng = thread_rng();
1954
1955 let mut data = Array2::zeros((n_samples, n_features));
1956 let mut target = Array1::zeros(n_samples);
1957
1958 for i in 0..n_samples {
1959 let hour = (i % 24) as f64;
1960 let day_of_year = (i / 24) % 365;
1961
1962 data[[i, 0]] = 20.0
1964 + 15.0 * (day_of_year as f64 * 2.0 * std::f64::consts::PI / 365.0).sin()
1965 + rng.gen_range(-5.0..5.0);
1966 data[[i, 1]] = 50.0 + 30.0 * rng.gen_range(0.0..1.0);
1968 data[[i, 2]] = hour;
1970
1971 let seasonal = 50.0
1973 + 30.0
1974 * (day_of_year as f64 * 2.0 * std::f64::consts::PI / 365.0
1975 + std::f64::consts::PI)
1976 .cos();
1977 let daily = 40.0 + 60.0 * ((hour - 12.0) * std::f64::consts::PI / 12.0).cos();
1978 let temp_effect = (data[[i, 0]] - 20.0).abs() * 2.0; target[i] = seasonal + daily + temp_effect + rng.gen_range(-10.0..10.0);
1981 }
1982
1983 Ok((data, target))
1984 }
1985
1986 fn create_synthetic_stock_data(
1987 &self,
1988 n_samples: usize,
1989 n_features: usize,
1990 ) -> Result<(Array2<f64>, Array1<f64>)> {
1991 use scirs2_core::random::Rng;
1992 let mut rng = thread_rng();
1993
1994 let mut data = Array2::zeros((n_samples, n_features));
1995 let mut target = Array1::zeros(n_samples);
1996
1997 let mut price = 100.0; for i in 0..n_samples {
2000 let change = rng.gen_range(-0.05..0.05);
2002 price *= 1.0 + change;
2003
2004 let high = price * (1.0 + rng.gen_range(0.0..0.02));
2006 let low = price * (1.0 - rng.gen_range(0.0..0.02));
2007 let volume = rng.gen_range(1000000.0..10000000.0);
2008
2009 data[[i, 0]] = price; data[[i, 1]] = high;
2011 data[[i, 2]] = low;
2012 data[[i, 3]] = volume;
2013 data[[i, 4]] = (high - low) / price; let next_change = rng.gen_range(-0.05..0.05);
2017 target[i] = next_change;
2018 }
2019
2020 Ok((data, target))
2021 }
2022
2023 fn create_synthetic_fraud_data(
2024 &self,
2025 n_samples: usize,
2026 n_features: usize,
2027 ) -> Result<(Array2<f64>, Array1<f64>)> {
2028 use scirs2_core::random::Rng;
2029 let mut rng = thread_rng();
2030
2031 let mut data = Array2::zeros((n_samples, n_features));
2032 let mut target = Array1::zeros(n_samples);
2033
2034 for i in 0..n_samples {
2035 let is_fraud = rng.gen_range(0.0..1.0) < 0.001728; for j in 0..n_features {
2038 if j < 28 {
2039 if is_fraud {
2041 data[[i, j]] = rng.gen_range(-5.0..5.0) * 2.0; } else {
2044 data[[i, j]] = rng.gen_range(-3.0..3.0);
2046 }
2047 }
2048 }
2049
2050 target[i] = if is_fraud { 1.0 } else { 0.0 };
2051 }
2052
2053 Ok((data, target))
2054 }
2055
2056 fn create_synthetic_loan_data(
2057 &self,
2058 n_samples: usize,
2059 n_features: usize,
2060 ) -> Result<(Array2<f64>, Array1<f64>)> {
2061 use scirs2_core::random::Rng;
2062 let mut rng = thread_rng();
2063
2064 let mut data = Array2::zeros((n_samples, n_features));
2065 let mut target = Array1::zeros(n_samples);
2066
2067 for i in 0..n_samples {
2068 data[[i, 0]] = rng.gen_range(1000.0..50000.0);
2070 data[[i, 1]] = rng.gen_range(5.0..25.0);
2072 data[[i, 2]] = [12.0, 24.0, 36.0, 48.0, 60.0][rng.sample(Uniform::new(0, 5).unwrap())];
2074 data[[i, 3]] = rng.gen_range(20000.0..200000.0);
2076 data[[i, 4]] = rng.gen_range(300.0..850.0);
2078 data[[i, 5]] = rng.gen_range(0.0..40.0);
2080 data[[i, 6]] = rng.gen_range(0.0..0.4);
2082
2083 for j in 7..n_features {
2085 data[[i, j]] = rng.gen_range(0.0..1.0);
2086 }
2087
2088 let risk_score = (850.0 - data[[i, 4]]) / 550.0 * 0.4 + data[[i, 6]] * 0.3 + (data[[i, 1]] - 5.0) / 20.0 * 0.2 + (50000.0 - data[[i, 3]]) / 180000.0 * 0.1; target[i] = if risk_score > 0.3 { 1.0 } else { 0.0 };
2095 }
2096
2097 Ok((data, target))
2098 }
2099
2100 fn create_synthetic_adult_dataset(
2101 &self,
2102 n_samples: usize,
2103 n_features: usize,
2104 ) -> Result<(Array2<f64>, Array1<f64>)> {
2105 use scirs2_core::random::Rng;
2106 let mut rng = thread_rng();
2107
2108 let mut data = Array2::zeros((n_samples, n_features));
2109 let mut target = Array1::zeros(n_samples);
2110
2111 for i in 0..n_samples {
2112 data[[i, 0]] = rng.gen_range(17.0..90.0);
2114 data[[i, 1]] = rng.gen_range(0.0f64..9.0).floor();
2116 data[[i, 2]] = rng.gen_range(12285.0..1484705.0);
2118 data[[i, 3]] = rng.gen_range(0.0f64..16.0).floor();
2120 data[[i, 4]] = rng.gen_range(1.0..17.0);
2122 data[[i, 5]] = rng.gen_range(0.0f64..7.0).floor();
2124 data[[i, 6]] = rng.gen_range(0.0f64..14.0).floor();
2126 data[[i, 7]] = rng.gen_range(0.0f64..6.0).floor();
2128 data[[i, 8]] = rng.gen_range(0.0f64..5.0).floor();
2130 data[[i, 9]] = if rng.gen_bool(0.67) { 1.0 } else { 0.0 };
2132 data[[i, 10]] = if rng.gen_bool(0.9) {
2134 0.0
2135 } else {
2136 rng.gen_range(1.0..99999.0)
2137 };
2138 data[[i, 11]] = if rng.gen_bool(0.95) {
2140 0.0
2141 } else {
2142 rng.gen_range(1.0..4356.0)
2143 };
2144 data[[i, 12]] = rng.gen_range(1.0..99.0);
2146 data[[i, 13]] = rng.gen_range(0.0f64..41.0).floor();
2148
2149 let income_score = (data[[i, 0]] - 17.0) / 73.0 * 0.2 + data[[i, 4]] / 16.0 * 0.3 + data[[i, 9]] * 0.2 + (data[[i, 12]] - 1.0) / 98.0 * 0.2 + (data[[i, 10]] + data[[i, 11]]) / 100000.0 * 0.1; let noise = rng.gen_range(-0.15..0.15);
2158 target[i] = if (income_score + noise) > 0.5 {
2159 1.0
2160 } else {
2161 0.0
2162 };
2163 }
2164
2165 Ok((data, target))
2166 }
2167
2168 fn create_generic_synthetic_dataset(
2169 &self,
2170 n_samples: usize,
2171 n_features: usize,
2172 has_categorical: bool,
2173 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
2174 use scirs2_core::random::Rng;
2175 let mut rng = thread_rng();
2176
2177 let mut data = Array2::zeros((n_samples, n_features));
2178
2179 for i in 0..n_samples {
2180 for j in 0..n_features {
2181 if has_categorical && j < n_features / 3 {
2182 data[[i, j]] = rng.gen_range(0.0f64..10.0).floor();
2184 } else {
2185 data[[i, j]] = rng.gen_range(-2.0..2.0);
2187 }
2188 }
2189 }
2190
2191 let mut target = Array1::zeros(n_samples);
2193 for i in 0..n_samples {
2194 let feature_sum = data.row(i).iter().sum::<f64>();
2195 let score = feature_sum / n_features as f64;
2196 target[i] = if score > 0.0 { 1.0 } else { 0.0 };
2197 }
2198
2199 Ok((data, Some(target)))
2200 }
2201
2202 fn create_synthetic_cifar10_data(
2203 &self,
2204 n_samples: usize,
2205 n_features: usize,
2206 ) -> Result<(Array2<f64>, Array1<f64>)> {
2207 use scirs2_core::random::Rng;
2208 let mut rng = thread_rng();
2209
2210 let mut data = Array2::zeros((n_samples, n_features));
2211 let mut target = Array1::zeros(n_samples);
2212
2213 for i in 0..n_samples {
2214 let class = rng.sample(Uniform::new(0, 10).unwrap()) as f64;
2215 target[i] = class;
2216
2217 for j in 0..n_features {
2219 let base_intensity = match class as i32 {
2220 0 => 0.6, 1 => 0.3, 2 => 0.8, 3 => 0.5, 4 => 0.7, 5 => 0.4, 6 => 0.9, 7 => 0.6, 8 => 0.2, 9 => 0.3, _ => 0.5,
2231 };
2232
2233 data[[i, j]] = base_intensity + rng.gen_range(-0.3f64..0.3f64);
2235 data[[i, j]] = data[[i, j]].clamp(0.0f64, 1.0f64); }
2237 }
2238
2239 Ok((data, target))
2240 }
2241
2242 fn create_synthetic_fashion_mnist_data(
2243 &self,
2244 n_samples: usize,
2245 n_features: usize,
2246 ) -> Result<(Array2<f64>, Array1<f64>)> {
2247 use scirs2_core::random::Rng;
2248 let mut rng = thread_rng();
2249
2250 let mut data = Array2::zeros((n_samples, n_features));
2251 let mut target = Array1::zeros(n_samples);
2252
2253 for i in 0..n_samples {
2254 let class = rng.sample(Uniform::new(0, 10).unwrap()) as f64;
2255 target[i] = class;
2256
2257 for j in 0..n_features {
2259 let base_intensity = match class as i32 {
2260 0 => 0.3, 1 => 0.4, 2 => 0.5, 3 => 0.6, 4 => 0.7, 5 => 0.2, 6 => 0.4, 7 => 0.3, 8 => 0.5, 9 => 0.4, _ => 0.4,
2271 };
2272
2273 let texture_noise = rng.gen_range(-0.2f64..0.2f64);
2275 data[[i, j]] = base_intensity + texture_noise;
2276 data[[i, j]] = data[[i, j]].clamp(0.0f64, 1.0f64); }
2278 }
2279
2280 Ok((data, target))
2281 }
2282
2283 fn create_synthetic_imdb_data(
2284 &self,
2285 n_samples: usize,
2286 n_features: usize,
2287 ) -> Result<(Array2<f64>, Array1<f64>)> {
2288 use scirs2_core::random::Rng;
2289 let mut rng = thread_rng();
2290
2291 let mut data = Array2::zeros((n_samples, n_features));
2292 let mut target = Array1::zeros(n_samples);
2293
2294 let positive_words = 0..n_features / 3; let negative_words = n_features / 3..2 * n_features / 3; let _neutral_words = 2 * n_features / 3..n_features; for i in 0..n_samples {
2300 let is_positive = rng.gen_bool(0.5);
2301 target[i] = if is_positive { 1.0 } else { 0.0 };
2302
2303 for j in 0..n_features {
2304 let base_freq = if positive_words.contains(&j) {
2305 if is_positive {
2306 rng.gen_range(0.5..2.0) } else {
2308 rng.gen_range(0.0..0.5) }
2310 } else if negative_words.contains(&j) {
2311 if is_positive {
2312 rng.gen_range(0.0..0.5) } else {
2314 rng.gen_range(0.5..2.0) }
2316 } else {
2317 rng.gen_range(0.2..1.0)
2319 };
2320
2321 data[[i, j]] = base_freq;
2322 }
2323 }
2324
2325 Ok((data, target))
2326 }
2327
2328 fn create_synthetic_news_data(
2329 &self,
2330 n_samples: usize,
2331 n_features: usize,
2332 ) -> Result<(Array2<f64>, Array1<f64>)> {
2333 use scirs2_core::random::Rng;
2334 let mut rng = thread_rng();
2335
2336 let mut data = Array2::zeros((n_samples, n_features));
2337 let mut target = Array1::zeros(n_samples);
2338
2339 let words_per_topic = n_features / 5; for i in 0..n_samples {
2343 let topic = rng.sample(Uniform::new(0, 5).unwrap()) as f64;
2344 target[i] = topic;
2345
2346 for j in 0..n_features {
2347 let word_topic = j / words_per_topic;
2348
2349 let base_freq = if word_topic == topic as usize {
2350 rng.gen_range(1.0..3.0)
2352 } else {
2353 rng.gen_range(0.0..0.8)
2355 };
2356
2357 let noise = rng.gen_range(-0.2f64..0.2f64);
2359 data[[i, j]] = (base_freq + noise).max(0.0f64);
2360 }
2361 }
2362
2363 Ok((data, target))
2364 }
2365}
2366
2367#[allow(dead_code)]
2369pub fn load_adult() -> Result<Dataset> {
2370 let config = RealWorldConfig::default();
2371 let mut loader = RealWorldDatasets::new(config)?;
2372 loader.load_adult()
2373}
2374
2375#[allow(dead_code)]
2377pub fn load_titanic() -> Result<Dataset> {
2378 let config = RealWorldConfig::default();
2379 let mut loader = RealWorldDatasets::new(config)?;
2380 loader.load_titanic()
2381}
2382
2383#[allow(dead_code)]
2385pub fn load_california_housing() -> Result<Dataset> {
2386 let config = RealWorldConfig::default();
2387 let mut loader = RealWorldDatasets::new(config)?;
2388 loader.load_california_housing()
2389}
2390
2391#[allow(dead_code)]
2393pub fn load_heart_disease() -> Result<Dataset> {
2394 let config = RealWorldConfig::default();
2395 let mut loader = RealWorldDatasets::new(config)?;
2396 loader.load_heart_disease()
2397}
2398
2399#[allow(dead_code)]
2401pub fn load_red_wine_quality() -> Result<Dataset> {
2402 let config = RealWorldConfig::default();
2403 let mut loader = RealWorldDatasets::new(config)?;
2404 loader.load_red_wine_quality()
2405}
2406
2407#[allow(dead_code)]
2409pub fn list_real_world_datasets() -> Vec<String> {
2410 let config = RealWorldConfig::default();
2411 let loader = RealWorldDatasets::new(config).unwrap();
2412 loader.list_datasets()
2413}
2414
#[cfg(test)]
mod tests {
    // `super::*` brings in the loaders under test; the previous explicit
    // `use scirs2_core::random::Uniform;` was unused by every test and
    // triggered an unused-import warning, so it is removed.
    use super::*;

    #[test]
    fn test_load_titanic() {
        let dataset = load_titanic().unwrap();
        assert_eq!(dataset.n_samples(), 891);
        assert_eq!(dataset.n_features(), 7);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_california_housing() {
        let dataset = load_california_housing().unwrap();
        assert_eq!(dataset.n_samples(), 20640);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_heart_disease() {
        let dataset = load_heart_disease().unwrap();
        assert_eq!(dataset.n_samples(), 303);
        assert_eq!(dataset.n_features(), 13);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_list_datasets() {
        let datasets = list_real_world_datasets();
        assert!(!datasets.is_empty());
        assert!(datasets.contains(&"titanic".to_string()));
        assert!(datasets.contains(&"california_housing".to_string()));
    }

    #[test]
    fn test_real_world_config() {
        let config = RealWorldConfig {
            use_cache: false,
            download_if_missing: false,
            ..Default::default()
        };

        assert!(!config.use_cache);
        assert!(!config.download_if_missing);
    }
}