1use crate::cache::{CacheKey, CacheManager};
8use crate::error::{DatasetsError, Result};
9use crate::registry::{DatasetMetadata, DatasetRegistry};
10use crate::utils::Dataset;
11use ndarray::{Array1, Array2};
12use rand_distr::Uniform;
13use scirs2_core::rng;
14
/// Configuration options shared by all real-world dataset loaders.
#[derive(Debug, Clone)]
pub struct RealWorldConfig {
    /// When `true`, loaders first consult the local cache and write freshly
    /// built datasets back into it.
    pub use_cache: bool,
    /// When `false`, loaders that need a remote file return a
    /// `DownloadError` instead of downloading.
    pub download_if_missing: bool,
    /// Override for the data storage directory.
    /// NOTE(review): not referenced in this part of the file — confirm usage.
    pub data_home: Option<String>,
    /// Request preprocessed data.
    /// NOTE(review): not referenced in this part of the file — confirm usage.
    pub return_preprocessed: bool,
    /// Optional named subset of a dataset to load.
    /// NOTE(review): not referenced in this part of the file — confirm usage.
    pub subset: Option<String>,
    /// Seed for reproducible random generation.
    /// NOTE(review): not referenced in this part of the file — confirm usage.
    pub random_state: Option<u64>,
}
31
32impl Default for RealWorldConfig {
33 fn default() -> Self {
34 Self {
35 use_cache: true,
36 download_if_missing: true,
37 data_home: None,
38 return_preprocessed: false,
39 subset: None,
40 random_state: None,
41 }
42 }
43}
44
/// Loader for real-world (or synthetic stand-in) datasets, combining a
/// dataset cache, a metadata registry, and user-supplied configuration.
pub struct RealWorldDatasets {
    // Cache consulted/updated by the loaders, keyed by name + config.
    cache: CacheManager,
    // Answers metadata queries (see `get_dataset_info`).
    registry: DatasetRegistry,
    // Options controlling caching and download behavior.
    config: RealWorldConfig,
}
51
52impl RealWorldDatasets {
53 pub fn new(config: RealWorldConfig) -> Result<Self> {
55 let cache = CacheManager::new()?;
56 let registry = DatasetRegistry::new();
57
58 Ok(Self {
59 cache,
60 registry,
61 config,
62 })
63 }
64
65 pub fn load_dataset(&mut self, name: &str) -> Result<Dataset> {
67 match name {
68 "adult" => self.load_adult(),
70 "bank_marketing" => self.load_bank_marketing(),
71 "credit_approval" => self.load_credit_approval(),
72 "german_credit" => self.load_german_credit(),
73 "mushroom" => self.load_mushroom(),
74 "spam" => self.load_spam(),
75 "titanic" => self.load_titanic(),
76
77 "auto_mpg" => self.load_auto_mpg(),
79 "california_housing" => self.load_california_housing(),
80 "concrete_strength" => self.load_concrete_strength(),
81 "energy_efficiency" => self.load_energy_efficiency(),
82 "red_wine_quality" => self.load_red_wine_quality(),
83 "white_wine_quality" => self.load_white_wine_quality(),
84
85 "air_passengers" => self.load_air_passengers(),
87 "bitcoin_prices" => self.load_bitcoin_prices(),
88 "electricity_load" => self.load_electricity_load(),
89 "stock_prices" => self.load_stock_prices(),
90
91 "cifar10_subset" => self.load_cifar10_subset(),
93 "fashion_mnist_subset" => self.load_fashion_mnist_subset(),
94
95 "imdb_reviews" => self.load_imdb_reviews(),
97 "news_articles" => self.load_news_articles(),
98
99 "diabetes_readmission" => self.load_diabetes_readmission(),
101 "heart_disease" => self.load_heart_disease(),
102
103 "credit_card_fraud" => self.load_credit_card_fraud(),
105 "loan_default" => self.load_loan_default(),
106 _ => Err(DatasetsError::NotFound(format!("Unknown dataset: {name}"))),
107 }
108 }
109
110 pub fn list_datasets(&self) -> Vec<String> {
112 vec![
113 "adult".to_string(),
115 "bank_marketing".to_string(),
116 "credit_approval".to_string(),
117 "german_credit".to_string(),
118 "mushroom".to_string(),
119 "spam".to_string(),
120 "titanic".to_string(),
121 "auto_mpg".to_string(),
123 "california_housing".to_string(),
124 "concrete_strength".to_string(),
125 "energy_efficiency".to_string(),
126 "red_wine_quality".to_string(),
127 "white_wine_quality".to_string(),
128 "air_passengers".to_string(),
130 "bitcoin_prices".to_string(),
131 "electricity_load".to_string(),
132 "stock_prices".to_string(),
133 "cifar10_subset".to_string(),
135 "fashion_mnist_subset".to_string(),
136 "imdb_reviews".to_string(),
138 "news_articles".to_string(),
139 "diabetes_readmission".to_string(),
141 "heart_disease".to_string(),
142 "credit_card_fraud".to_string(),
144 "loan_default".to_string(),
145 ]
146 }
147
148 pub fn get_dataset_info(&self, name: &str) -> Result<DatasetMetadata> {
150 self.registry.get_metadata(name)
151 }
152}
153
154impl RealWorldDatasets {
156 pub fn load_adult(&mut self) -> Result<Dataset> {
159 let cache_key = CacheKey::new("adult", &self.config);
160
161 if self.config.use_cache {
162 if let Some(dataset) = self.cache.get(&cache_key)? {
163 return Ok(dataset);
164 }
165 }
166
167 let url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data";
168 let dataset = self.download_and_parse_csv(
169 url,
170 "adult",
171 &[
172 "age",
173 "workclass",
174 "fnlwgt",
175 "education",
176 "education_num",
177 "marital_status",
178 "occupation",
179 "relationship",
180 "race",
181 "sex",
182 "capital_gain",
183 "capital_loss",
184 "hours_per_week",
185 "native_country",
186 "income",
187 ],
188 Some("income"),
189 true, )?;
191
192 if self.config.use_cache {
193 self.cache.put(&cache_key, &dataset)?;
194 }
195
196 Ok(dataset)
197 }
198
199 pub fn load_bank_marketing(&mut self) -> Result<Dataset> {
202 let cache_key = CacheKey::new("bank_marketing", &self.config);
203
204 if self.config.use_cache {
205 if let Some(dataset) = self.cache.get(&cache_key)? {
206 return Ok(dataset);
207 }
208 }
209
210 let (data, target) = self.create_synthetic_bank_data(4521, 16)?;
213
214 let metadata = DatasetMetadata {
215 name: "Bank Marketing".to_string(),
216 description: "Direct marketing campaigns of a Portuguese banking institution"
217 .to_string(),
218 n_samples: 4521,
219 n_features: 16,
220 task_type: "classification".to_string(),
221 targetnames: Some(vec!["no".to_string(), "yes".to_string()]),
222 featurenames: None,
223 url: None,
224 checksum: None,
225 };
226
227 let dataset = Dataset::from_metadata(data, Some(target), metadata);
228
229 if self.config.use_cache {
230 self.cache.put(&cache_key, &dataset)?;
231 }
232
233 Ok(dataset)
234 }
235
236 pub fn load_titanic(&mut self) -> Result<Dataset> {
239 let cache_key = CacheKey::new("titanic", &self.config);
240
241 if self.config.use_cache {
242 if let Some(dataset) = self.cache.get(&cache_key)? {
243 return Ok(dataset);
244 }
245 }
246
247 let (data, target) = self.create_synthetic_titanic_data(891, 7)?;
248
249 let metadata = DatasetMetadata {
250 name: "Titanic".to_string(),
251 description: "Passenger survival data from the Titanic disaster".to_string(),
252 n_samples: 891,
253 n_features: 7,
254 task_type: "classification".to_string(),
255 targetnames: Some(vec!["died".to_string(), "survived".to_string()]),
256 featurenames: None,
257 url: None,
258 checksum: None,
259 };
260
261 let dataset = Dataset::from_metadata(data, Some(target), metadata);
262
263 if self.config.use_cache {
264 self.cache.put(&cache_key, &dataset)?;
265 }
266
267 Ok(dataset)
268 }
269
270 pub fn load_german_credit(&mut self) -> Result<Dataset> {
273 let (data, target) = self.create_synthetic_credit_data(1000, 20)?;
274
275 let metadata = DatasetMetadata {
276 name: "German Credit".to_string(),
277 description: "Credit risk classification dataset".to_string(),
278 n_samples: 1000,
279 n_features: 20,
280 task_type: "classification".to_string(),
281 targetnames: Some(vec!["bad_credit".to_string(), "good_credit".to_string()]),
282 featurenames: None,
283 url: None,
284 checksum: None,
285 };
286
287 Ok(Dataset::from_metadata(data, Some(target), metadata))
288 }
289}
290
291impl RealWorldDatasets {
293 pub fn load_california_housing(&mut self) -> Result<Dataset> {
296 let (data, target) = self.create_synthetic_housing_data(20640, 8)?;
297
298 let metadata = DatasetMetadata {
299 name: "California Housing".to_string(),
300 description: "Median house values for California districts from 1990 census"
301 .to_string(),
302 n_samples: 20640,
303 n_features: 8,
304 task_type: "regression".to_string(),
305 targetnames: None, featurenames: None,
307 url: None,
308 checksum: None,
309 };
310
311 Ok(Dataset::from_metadata(data, Some(target), metadata))
312 }
313
314 pub fn load_red_wine_quality(&mut self) -> Result<Dataset> {
317 let (data, target) = self.create_synthetic_wine_data(1599, 11)?;
318
319 let metadata = DatasetMetadata {
320 name: "Red Wine Quality".to_string(),
321 description: "Red wine quality based on physicochemical tests".to_string(),
322 n_samples: 1599,
323 n_features: 11,
324 task_type: "regression".to_string(),
325 targetnames: None, featurenames: None,
327 url: None,
328 checksum: None,
329 };
330
331 Ok(Dataset::from_metadata(data, Some(target), metadata))
332 }
333
334 pub fn load_energy_efficiency(&mut self) -> Result<Dataset> {
337 let (data, target) = self.create_synthetic_energy_data(768, 8)?;
338
339 let metadata = DatasetMetadata {
340 name: "Energy Efficiency".to_string(),
341 description: "Energy efficiency of buildings based on building parameters".to_string(),
342 n_samples: 768,
343 n_features: 8,
344 task_type: "regression".to_string(),
345 targetnames: None, featurenames: None,
347 url: None,
348 checksum: None,
349 };
350
351 Ok(Dataset::from_metadata(data, Some(target), metadata))
352 }
353}
354
355impl RealWorldDatasets {
357 pub fn load_air_passengers(&mut self) -> Result<Dataset> {
360 let (data, target) = self.create_air_passengers_data(144)?;
361
362 let metadata = DatasetMetadata {
363 name: "Air Passengers".to_string(),
364 description: "Monthly airline passenger numbers 1949-1960".to_string(),
365 n_samples: 144,
366 n_features: 1,
367 task_type: "time_series".to_string(),
368 targetnames: None, featurenames: None,
370 url: None,
371 checksum: None,
372 };
373
374 Ok(Dataset::from_metadata(data, target, metadata))
375 }
376
377 pub fn load_bitcoin_prices(&mut self) -> Result<Dataset> {
380 let (data, target) = self.create_bitcoin_price_data(1000)?;
381
382 let metadata = DatasetMetadata {
383 name: "Bitcoin Prices".to_string(),
384 description: "Historical Bitcoin price data with technical indicators".to_string(),
385 n_samples: 1000,
386 n_features: 6,
387 task_type: "time_series".to_string(),
388 targetnames: None, featurenames: None,
390 url: None,
391 checksum: None,
392 };
393
394 Ok(Dataset::from_metadata(data, target, metadata))
395 }
396}
397
398impl RealWorldDatasets {
400 pub fn load_heart_disease(&mut self) -> Result<Dataset> {
403 let (data, target) = self.create_heart_disease_data(303, 13)?;
404
405 let metadata = DatasetMetadata {
406 name: "Heart Disease".to_string(),
407 description: "Heart disease prediction based on clinical parameters".to_string(),
408 n_samples: 303,
409 n_features: 13,
410 task_type: "classification".to_string(),
411 targetnames: Some(vec!["no_disease".to_string(), "disease".to_string()]),
412 featurenames: None,
413 url: None,
414 checksum: None,
415 };
416
417 Ok(Dataset::from_metadata(data, Some(target), metadata))
418 }
419
420 pub fn load_diabetes_readmission(&mut self) -> Result<Dataset> {
423 let (data, target) = self.create_diabetes_readmission_data(101766, 49)?;
424
425 let metadata = DatasetMetadata {
426 name: "Diabetes Readmission".to_string(),
427 description: "Hospital readmission prediction for diabetic patients".to_string(),
428 n_samples: 101766,
429 n_features: 49,
430 task_type: "classification".to_string(),
431 targetnames: Some(vec![
432 "no_readmission".to_string(),
433 "readmission".to_string(),
434 ]),
435 featurenames: None,
436 url: None,
437 checksum: None,
438 };
439
440 Ok(Dataset::from_metadata(data, Some(target), metadata))
441 }
442
443 pub fn load_credit_approval(&mut self) -> Result<Dataset> {
445 let cache_key = CacheKey::new("credit_approval", &self.config);
446
447 if self.config.use_cache {
448 if let Some(dataset) = self.cache.get(&cache_key)? {
449 return Ok(dataset);
450 }
451 }
452
453 let url =
455 "https://archive.ics.uci.edu/ml/machine-learning-databases/credit-screening/crx.data";
456 let columns = &[
457 "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10", "A11", "A12", "A13",
458 "A14", "A15", "class",
459 ];
460
461 let dataset =
462 self.download_and_parse_csv(url, "credit_approval", columns, Some("class"), true)?;
463
464 if self.config.use_cache {
465 self.cache.put(&cache_key, &dataset)?;
466 }
467
468 Ok(dataset)
469 }
470
    /// Load the UCI Mushroom dataset (edible vs. poisonous), honoring the
    /// cache. Falls back to synthetic data inside `download_and_parse_csv`
    /// when the download fails.
    pub fn load_mushroom(&mut self) -> Result<Dataset> {
        let cache_key = CacheKey::new("mushroom", &self.config);

        // Serve a cached copy when caching is enabled and a hit exists.
        if self.config.use_cache {
            if let Some(dataset) = self.cache.get(&cache_key)? {
                return Ok(dataset);
            }
        }

        let url = "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data";
        // Column order follows the UCI file: the class label comes first.
        let columns = &[
            "class",
            "cap-shape",
            "cap-surface",
            "cap-color",
            "bruises",
            "odor",
            "gill-attachment",
            "gill-spacing",
            "gill-size",
            "gill-color",
            "stalk-shape",
            "stalk-root",
            "stalk-surface-above-ring",
            "stalk-surface-below-ring",
            "stalk-color-above-ring",
            "stalk-color-below-ring",
            "veil-type",
            "veil-color",
            "ring-number",
            "ring-type",
            "spore-print-color",
            "population",
            "habitat",
        ];

        // `true`: every column in this file is categorical and needs encoding.
        let dataset = self.download_and_parse_csv(url, "mushroom", columns, Some("class"), true)?;

        if self.config.use_cache {
            self.cache.put(&cache_key, &dataset)?;
        }

        Ok(dataset)
    }
517
518 pub fn load_spam(&mut self) -> Result<Dataset> {
520 let cache_key = CacheKey::new("spam", &self.config);
521
522 if self.config.use_cache {
523 if let Some(dataset) = self.cache.get(&cache_key)? {
524 return Ok(dataset);
525 }
526 }
527
528 let url =
530 "https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.data";
531 let mut columns: Vec<String> = Vec::new();
532
533 for i in 0..48 {
535 columns.push(format!("word_freq_{i}"));
536 }
537 for i in 0..6 {
538 columns.push(format!("char_freq_{i}"));
539 }
540 columns.push("capital_run_length_average".to_string());
541 columns.push("capital_run_length_longest".to_string());
542 columns.push("capital_run_length_total".to_string());
543 columns.push("spam".to_string());
544
545 let column_refs: Vec<&str> = columns.iter().map(|s| s.as_str()).collect();
546
547 let dataset =
548 self.download_and_parse_csv(url, "spam", &column_refs, Some("spam"), false)?;
549
550 if self.config.use_cache {
551 self.cache.put(&cache_key, &dataset)?;
552 }
553
554 Ok(dataset)
555 }
556
557 pub fn load_auto_mpg(&mut self) -> Result<Dataset> {
559 let cache_key = CacheKey::new("auto_mpg", &self.config);
560
561 if self.config.use_cache {
562 if let Some(dataset) = self.cache.get(&cache_key)? {
563 return Ok(dataset);
564 }
565 }
566
567 let (data, target) = self.create_synthetic_auto_mpg_data(392, 7)?;
568
569 let metadata = DatasetMetadata {
570 name: "Auto MPG".to_string(),
571 description:
572 "Predict car fuel efficiency (miles per gallon) from technical specifications"
573 .to_string(),
574 n_samples: 392,
575 n_features: 7,
576 task_type: "regression".to_string(),
577 targetnames: None, featurenames: None,
579 url: None,
580 checksum: None,
581 };
582
583 let dataset = Dataset::from_metadata(data, Some(target), metadata);
584
585 if self.config.use_cache {
586 self.cache.put(&cache_key, &dataset)?;
587 }
588
589 Ok(dataset)
590 }
591
592 pub fn load_concrete_strength(&mut self) -> Result<Dataset> {
594 let cache_key = CacheKey::new("concrete_strength", &self.config);
595
596 if self.config.use_cache {
597 if let Some(dataset) = self.cache.get(&cache_key)? {
598 return Ok(dataset);
599 }
600 }
601
602 let (data, target) = self.create_synthetic_concrete_data(1030, 8)?;
603
604 let metadata = DatasetMetadata {
605 name: "Concrete Compressive Strength".to_string(),
606 description: "Predict concrete compressive strength from mixture components"
607 .to_string(),
608 n_samples: 1030,
609 n_features: 8,
610 task_type: "regression".to_string(),
611 targetnames: None, featurenames: None,
613 url: None,
614 checksum: None,
615 };
616
617 let dataset = Dataset::from_metadata(data, Some(target), metadata);
618
619 if self.config.use_cache {
620 self.cache.put(&cache_key, &dataset)?;
621 }
622
623 Ok(dataset)
624 }
625
626 pub fn load_white_wine_quality(&mut self) -> Result<Dataset> {
628 let cache_key = CacheKey::new("white_wine_quality", &self.config);
629
630 if self.config.use_cache {
631 if let Some(dataset) = self.cache.get(&cache_key)? {
632 return Ok(dataset);
633 }
634 }
635
636 let (data, target) = self.create_synthetic_wine_data(4898, 11)?;
637
638 let metadata = DatasetMetadata {
639 name: "White Wine Quality".to_string(),
640 description: "White wine quality based on physicochemical tests".to_string(),
641 n_samples: 4898,
642 n_features: 11,
643 task_type: "regression".to_string(),
644 targetnames: None, featurenames: None,
646 url: None,
647 checksum: None,
648 };
649
650 let dataset = Dataset::from_metadata(data, Some(target), metadata);
651
652 if self.config.use_cache {
653 self.cache.put(&cache_key, &dataset)?;
654 }
655
656 Ok(dataset)
657 }
658
659 pub fn load_electricity_load(&mut self) -> Result<Dataset> {
661 let cache_key = CacheKey::new("electricity_load", &self.config);
662
663 if self.config.use_cache {
664 if let Some(dataset) = self.cache.get(&cache_key)? {
665 return Ok(dataset);
666 }
667 }
668
669 let (data, target) = self.create_synthetic_electricity_data(26304, 3)?; let metadata = DatasetMetadata {
672 name: "Electricity Load".to_string(),
673 description: "Hourly electricity consumption forecasting with weather factors"
674 .to_string(),
675 n_samples: 26304,
676 n_features: 3,
677 task_type: "time_series".to_string(),
678 targetnames: None, featurenames: None,
680 url: None,
681 checksum: None,
682 };
683
684 let dataset = Dataset::from_metadata(data, Some(target), metadata);
685
686 if self.config.use_cache {
687 self.cache.put(&cache_key, &dataset)?;
688 }
689
690 Ok(dataset)
691 }
692
693 pub fn load_stock_prices(&mut self) -> Result<Dataset> {
695 let cache_key = CacheKey::new("stock_prices", &self.config);
696
697 if self.config.use_cache {
698 if let Some(dataset) = self.cache.get(&cache_key)? {
699 return Ok(dataset);
700 }
701 }
702
703 let (data, target) = self.create_synthetic_stock_data(1260, 5)?; let metadata = DatasetMetadata {
706 name: "Stock Prices".to_string(),
707 description: "Daily stock price prediction with technical indicators".to_string(),
708 n_samples: 1260,
709 n_features: 5,
710 task_type: "time_series".to_string(),
711 targetnames: None, featurenames: None,
713 url: None,
714 checksum: None,
715 };
716
717 let dataset = Dataset::from_metadata(data, Some(target), metadata);
718
719 if self.config.use_cache {
720 self.cache.put(&cache_key, &dataset)?;
721 }
722
723 Ok(dataset)
724 }
725
726 pub fn load_cifar10_subset(&mut self) -> Result<Dataset> {
728 let cache_key = CacheKey::new("cifar10_subset", &self.config);
729
730 if self.config.use_cache {
731 if let Some(dataset) = self.cache.get(&cache_key)? {
732 return Ok(dataset);
733 }
734 }
735
736 let (data, target) = self.create_synthetic_cifar10_data(1000, 3072)?; let metadata = DatasetMetadata {
739 name: "CIFAR-10 Subset".to_string(),
740 description: "Subset of CIFAR-10 32x32 color images in 10 classes".to_string(),
741 n_samples: 1000,
742 n_features: 3072,
743 task_type: "classification".to_string(),
744 targetnames: Some(vec![
745 "airplane".to_string(),
746 "automobile".to_string(),
747 "bird".to_string(),
748 "cat".to_string(),
749 "deer".to_string(),
750 "dog".to_string(),
751 "frog".to_string(),
752 "horse".to_string(),
753 "ship".to_string(),
754 "truck".to_string(),
755 ]),
756 featurenames: None,
757 url: None,
758 checksum: None,
759 };
760
761 let dataset = Dataset::from_metadata(data, Some(target), metadata);
762
763 if self.config.use_cache {
764 self.cache.put(&cache_key, &dataset)?;
765 }
766
767 Ok(dataset)
768 }
769
770 pub fn load_fashion_mnist_subset(&mut self) -> Result<Dataset> {
772 let cache_key = CacheKey::new("fashion_mnist_subset", &self.config);
773
774 if self.config.use_cache {
775 if let Some(dataset) = self.cache.get(&cache_key)? {
776 return Ok(dataset);
777 }
778 }
779
780 let (data, target) = self.create_synthetic_fashion_mnist_data(1000, 784)?; let metadata = DatasetMetadata {
783 name: "Fashion-MNIST Subset".to_string(),
784 description: "Subset of Fashion-MNIST 28x28 grayscale images of fashion items"
785 .to_string(),
786 n_samples: 1000,
787 n_features: 784,
788 task_type: "classification".to_string(),
789 targetnames: Some(vec![
790 "T-shirt/top".to_string(),
791 "Trouser".to_string(),
792 "Pullover".to_string(),
793 "Dress".to_string(),
794 "Coat".to_string(),
795 "Sandal".to_string(),
796 "Shirt".to_string(),
797 "Sneaker".to_string(),
798 "Bag".to_string(),
799 "Ankle boot".to_string(),
800 ]),
801 featurenames: None,
802 url: None,
803 checksum: None,
804 };
805
806 let dataset = Dataset::from_metadata(data, Some(target), metadata);
807
808 if self.config.use_cache {
809 self.cache.put(&cache_key, &dataset)?;
810 }
811
812 Ok(dataset)
813 }
814
815 pub fn load_imdb_reviews(&mut self) -> Result<Dataset> {
817 let cache_key = CacheKey::new("imdb_reviews", &self.config);
818
819 if self.config.use_cache {
820 if let Some(dataset) = self.cache.get(&cache_key)? {
821 return Ok(dataset);
822 }
823 }
824
825 let (data, target) = self.create_synthetic_imdb_data(5000, 1000)?; let metadata = DatasetMetadata {
828 name: "IMDB Movie Reviews".to_string(),
829 description: "Subset of IMDB movie reviews for sentiment classification".to_string(),
830 n_samples: 5000,
831 n_features: 1000,
832 task_type: "classification".to_string(),
833 targetnames: Some(vec!["negative".to_string(), "positive".to_string()]),
834 featurenames: None,
835 url: None,
836 checksum: None,
837 };
838
839 let dataset = Dataset::from_metadata(data, Some(target), metadata);
840
841 if self.config.use_cache {
842 self.cache.put(&cache_key, &dataset)?;
843 }
844
845 Ok(dataset)
846 }
847
848 pub fn load_news_articles(&mut self) -> Result<Dataset> {
850 let cache_key = CacheKey::new("news_articles", &self.config);
851
852 if self.config.use_cache {
853 if let Some(dataset) = self.cache.get(&cache_key)? {
854 return Ok(dataset);
855 }
856 }
857
858 let (data, target) = self.create_synthetic_news_data(2000, 500)?; let metadata = DatasetMetadata {
861 name: "News Articles".to_string(),
862 description: "News articles categorized by topic for text classification".to_string(),
863 n_samples: 2000,
864 n_features: 500,
865 task_type: "classification".to_string(),
866 targetnames: Some(vec![
867 "business".to_string(),
868 "entertainment".to_string(),
869 "politics".to_string(),
870 "sport".to_string(),
871 "tech".to_string(),
872 ]),
873 featurenames: None,
874 url: None,
875 checksum: None,
876 };
877
878 let dataset = Dataset::from_metadata(data, Some(target), metadata);
879
880 if self.config.use_cache {
881 self.cache.put(&cache_key, &dataset)?;
882 }
883
884 Ok(dataset)
885 }
886
887 pub fn load_credit_card_fraud(&mut self) -> Result<Dataset> {
889 let cache_key = CacheKey::new("credit_card_fraud", &self.config);
890
891 if self.config.use_cache {
892 if let Some(dataset) = self.cache.get(&cache_key)? {
893 return Ok(dataset);
894 }
895 }
896
897 let (data, target) = self.create_synthetic_fraud_data(284807, 28)?;
898
899 let metadata = DatasetMetadata {
900 name: "Credit Card Fraud Detection".to_string(),
901 description: "Detect fraudulent credit card transactions from anonymized features"
902 .to_string(),
903 n_samples: 284807,
904 n_features: 28,
905 task_type: "classification".to_string(),
906 targetnames: Some(vec!["legitimate".to_string(), "fraud".to_string()]),
907 featurenames: None,
908 url: None,
909 checksum: None,
910 };
911
912 let dataset = Dataset::from_metadata(data, Some(target), metadata);
913
914 if self.config.use_cache {
915 self.cache.put(&cache_key, &dataset)?;
916 }
917
918 Ok(dataset)
919 }
920
921 pub fn load_loan_default(&mut self) -> Result<Dataset> {
923 let cache_key = CacheKey::new("loan_default", &self.config);
924
925 if self.config.use_cache {
926 if let Some(dataset) = self.cache.get(&cache_key)? {
927 return Ok(dataset);
928 }
929 }
930
931 let (data, target) = self.create_synthetic_loan_data(10000, 15)?;
932
933 let metadata = DatasetMetadata {
934 name: "Loan Default Prediction".to_string(),
935 description: "Predict loan default risk from borrower characteristics and loan details"
936 .to_string(),
937 n_samples: 10000,
938 n_features: 15,
939 task_type: "classification".to_string(),
940 targetnames: Some(vec!["no_default".to_string(), "default".to_string()]),
941 featurenames: None,
942 url: None,
943 checksum: None,
944 };
945
946 let dataset = Dataset::from_metadata(data, Some(target), metadata);
947
948 if self.config.use_cache {
949 self.cache.put(&cache_key, &dataset)?;
950 }
951
952 Ok(dataset)
953 }
954}
955
956impl RealWorldDatasets {
958 fn download_and_parse_csv(
959 &self,
960 url: &str,
961 name: &str,
962 columns: &[&str],
963 target_col: Option<&str>,
964 has_categorical: bool,
965 ) -> Result<Dataset> {
966 if !self.config.download_if_missing {
968 return Err(DatasetsError::DownloadError(
969 "Download disabled in configuration".to_string(),
970 ));
971 }
972
973 #[cfg(feature = "download")]
975 {
976 match self.download_real_dataset(url, name, columns, target_col, has_categorical) {
977 Ok(dataset) => return Ok(dataset),
978 Err(e) => {
979 eprintln!("Warning: Failed to download real dataset from {}: {}. Falling back to synthetic data.", url, e);
980 }
981 }
982 }
983
984 match name {
986 "adult" => {
987 let (data, target) = self.create_synthetic_adult_dataset(32561, 14)?;
988
989 let featurenames = vec![
990 "age".to_string(),
991 "workclass".to_string(),
992 "fnlwgt".to_string(),
993 "education".to_string(),
994 "education_num".to_string(),
995 "marital_status".to_string(),
996 "occupation".to_string(),
997 "relationship".to_string(),
998 "race".to_string(),
999 "sex".to_string(),
1000 "capital_gain".to_string(),
1001 "capital_loss".to_string(),
1002 "hours_per_week".to_string(),
1003 "native_country".to_string(),
1004 ];
1005
1006 let metadata = crate::registry::DatasetMetadata {
1007 name: "Adult Census Income".to_string(),
1008 description: "Predict whether income exceeds $50K/yr based on census data"
1009 .to_string(),
1010 n_samples: 32561,
1011 n_features: 14,
1012 task_type: "classification".to_string(),
1013 targetnames: Some(vec!["<=50K".to_string(), ">50K".to_string()]),
1014 featurenames: Some(featurenames),
1015 url: Some(url.to_string()),
1016 checksum: None,
1017 };
1018
1019 Ok(Dataset::from_metadata(data, Some(target), metadata))
1020 }
1021 _ => {
1022 let n_features = columns.len() - if target_col.is_some() { 1 } else { 0 };
1024 let (data, target) =
1025 self.create_generic_synthetic_dataset(1000, n_features, has_categorical)?;
1026
1027 let featurenames: Vec<String> = columns
1028 .iter()
1029 .filter(|&&_col| Some(_col) != target_col)
1030 .map(|&_col| _col.to_string())
1031 .collect();
1032
1033 let metadata = crate::registry::DatasetMetadata {
1034 name: format!("Synthetic {name}"),
1035 description: format!("Synthetic version of {name} dataset"),
1036 n_samples: 1000,
1037 n_features,
1038 task_type: if target_col.is_some() {
1039 "classification"
1040 } else {
1041 "regression"
1042 }
1043 .to_string(),
1044 targetnames: None,
1045 featurenames: Some(featurenames),
1046 url: Some(url.to_string()),
1047 checksum: None,
1048 };
1049
1050 Ok(Dataset::from_metadata(data, target, metadata))
1051 }
1052 }
1053 }
1054
1055 #[cfg(feature = "download")]
1057 fn download_real_dataset(
1058 &self,
1059 url: &str,
1060 name: &str,
1061 columns: &[&str],
1062 target_col: Option<&str>,
1063 _has_categorical: bool,
1064 ) -> Result<Dataset> {
1065 use crate::cache::download_data;
1066 use std::collections::HashMap;
1067 use std::io::{BufRead, BufReader, Cursor};
1068
1069 let data_bytes = download_data(url, false)?;
1071
1072 let cursor = Cursor::new(data_bytes);
1074 let reader = BufReader::new(cursor);
1075
1076 let mut rows: Vec<Vec<String>> = Vec::new();
1077 let mut header_found = false;
1078
1079 for line_result in reader.lines() {
1080 let line = line_result
1081 .map_err(|e| DatasetsError::FormatError(format!("Failed to read line: {}", e)))?;
1082 let line = line.trim();
1083
1084 if line.is_empty() {
1085 continue;
1086 }
1087
1088 let fields: Vec<String> = line
1090 .split(',')
1091 .map(|s| s.trim().trim_matches('"').to_string())
1092 .collect();
1093
1094 if !header_found && fields.len() == columns.len() {
1095 let is_header = fields.iter().enumerate().all(|(i, field)| {
1097 field.to_lowercase().contains(&columns[i].to_lowercase())
1098 || columns[i].to_lowercase().contains(&field.to_lowercase())
1099 });
1100 if is_header {
1101 header_found = true;
1102 continue;
1103 }
1104 }
1105
1106 if fields.len() == columns.len() {
1107 rows.push(fields);
1108 }
1109 }
1110
1111 if rows.is_empty() {
1112 return Err(DatasetsError::FormatError(
1113 "No valid data rows found in CSV".to_string(),
1114 ));
1115 }
1116
1117 let n_samples = rows.len();
1119 let n_features = if let Some(_) = target_col {
1120 columns.len() - 1
1121 } else {
1122 columns.len()
1123 };
1124
1125 let mut data = Array2::<f64>::zeros((n_samples, n_features));
1126 let mut target = if target_col.is_some() {
1127 Some(Array1::<f64>::zeros(n_samples))
1128 } else {
1129 None
1130 };
1131
1132 let mut category_maps: HashMap<usize, HashMap<String, f64>> = HashMap::new();
1134
1135 for (row_idx, row) in rows.iter().enumerate() {
1136 let mut feature_idx = 0;
1137
1138 for (col_idx, value) in row.iter().enumerate() {
1139 if Some(columns[col_idx]) == target_col {
1140 if let Some(ref mut target_array) = target {
1142 let numeric_value = match value.parse::<f64>() {
1143 Ok(v) => v,
1144 Err(_) => {
1145 let category_map =
1147 category_maps.entry(col_idx).or_insert_with(HashMap::new);
1148 let next_id = category_map.len() as f64;
1149 *category_map.entry(value.clone()).or_insert(next_id)
1150 }
1151 };
1152 target_array[row_idx] = numeric_value;
1153 }
1154 } else {
1155 let numeric_value = match value.parse::<f64>() {
1157 Ok(v) => v,
1158 Err(_) => {
1159 let category_map =
1161 category_maps.entry(col_idx).or_insert_with(HashMap::new);
1162 let next_id = category_map.len() as f64;
1163 *category_map.entry(value.clone()).or_insert(next_id)
1164 }
1165 };
1166 data[[row_idx, feature_idx]] = numeric_value;
1167 feature_idx += 1;
1168 }
1169 }
1170 }
1171
1172 let featurenames: Vec<String> = columns
1174 .iter()
1175 .filter(|&&_col| Some(_col) != target_col)
1176 .map(|&_col| col.to_string())
1177 .collect();
1178
1179 let metadata = crate::registry::DatasetMetadata {
1181 name: name.to_string(),
1182 description: format!("Real-world dataset: {}", name),
1183 n_samples,
1184 n_features,
1185 task_type: if target.is_some() {
1186 "classification".to_string()
1187 } else {
1188 "unsupervised".to_string()
1189 },
1190 targetnames: None,
1191 featurenames: Some(featurenames),
1192 url: Some(url.to_string()),
1193 checksum: None,
1194 };
1195
1196 Ok(Dataset::from_metadata(data, target, metadata))
1197 }
1198
1199 fn create_synthetic_bank_data(
1200 &self,
1201 n_samples: usize,
1202 n_features: usize,
1203 ) -> Result<(Array2<f64>, Array1<f64>)> {
1204 use rand::Rng;
1205 let mut rng = rng();
1206
1207 let mut data = Array2::zeros((n_samples, n_features));
1208 let mut target = Array1::zeros(n_samples);
1209
1210 for i in 0..n_samples {
1211 for j in 0..n_features {
1212 data[[i, j]] = rng.gen_range(0.0..1.0);
1213 }
1214 target[i] = if data.row(i).iter().take(3).sum::<f64>() > 1.5 {
1216 1.0
1217 } else {
1218 0.0
1219 };
1220 }
1221
1222 Ok((data, target))
1223 }
1224
    /// Build a synthetic credit-approval dataset (690 samples, 15 features)
    /// with a weighted, rule-based label plus noise.
    ///
    /// NOTE(review): 690 samples mirrors the UCI Credit Approval dataset's
    /// size — confirm that is the intent.
    #[allow(dead_code)]
    fn create_synthetic_credit_approval_data(&self) -> Result<Dataset> {
        use rand::Rng;
        let mut rng = rng();

        let n_samples = 690;
        let n_features = 15;

        let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        // Column order below matches the per-index assignments in the loop.
        let featurenames = vec![
            "credit_score".to_string(),
            "annual_income".to_string(),
            "debt_to_income_ratio".to_string(),
            "employment_length".to_string(),
            "age".to_string(),
            "home_ownership".to_string(),
            "loan_amount".to_string(),
            "loan_purpose".to_string(),
            "credit_history_length".to_string(),
            "number_of_credit_lines".to_string(),
            "utilization_rate".to_string(),
            "delinquency_count".to_string(),
            "education_level".to_string(),
            "marital_status".to_string(),
            "verification_status".to_string(),
        ];

        for i in 0..n_samples {
            // credit_score: 300-850 (FICO-like range)
            data[[i, 0]] = rng.gen_range(300.0..850.0);
            // annual_income: 20k-200k
            data[[i, 1]] = rng.gen_range(20000.0..200000.0);
            // debt_to_income_ratio: 0-0.6
            data[[i, 2]] = rng.gen_range(0.0..0.6);
            // employment_length: 0-30 years
            data[[i, 3]] = rng.gen_range(0.0..30.0);
            // age: 18-80
            data[[i, 4]] = rng.gen_range(18.0..80.0);
            // home_ownership: integer category 0-2
            data[[i, 5]] = rng.gen_range(0.0f64..3.0).floor();
            // loan_amount: 1k-50k
            data[[i, 6]] = rng.gen_range(1000.0..50000.0);
            // loan_purpose: integer category 0-6
            data[[i, 7]] = rng.gen_range(0.0f64..7.0).floor();
            // credit_history_length: 0-40 years
            data[[i, 8]] = rng.gen_range(0.0..40.0);
            // number_of_credit_lines: 0-20
            data[[i, 9]] = rng.gen_range(0.0..20.0);
            // utilization_rate: 0-1
            data[[i, 10]] = rng.gen_range(0.0..1.0);
            // delinquency_count: integer 0-10
            data[[i, 11]] = rng.gen_range(0.0f64..11.0).floor();
            // education_level: integer category 0-3
            data[[i, 12]] = rng.gen_range(0.0f64..4.0).floor();
            // marital_status: integer category 0-2
            data[[i, 13]] = rng.gen_range(0.0f64..3.0).floor();
            // verification_status: verified with 70% probability
            data[[i, 14]] = if rng.gen_bool(0.7) { 1.0 } else { 0.0 };

            // Normalize the main risk drivers to ~[0, 1] ...
            let credit_score_factor = (data[[i, 0]] - 300.0) / 550.0;
            let income_factor = (data[[i, 1]] / 100000.0).min(1.0);
            let debt_factor = 1.0 - data[[i, 2]];
            let employment_factor = (data[[i, 3]] / 10.0).min(1.0);
            let delinquency_penalty = data[[i, 11]] * 0.1;
            // ... and combine them with fixed weights (0.4/0.3/0.2/0.1)
            // minus a delinquency penalty.
            let approval_score = credit_score_factor * 0.4
                + income_factor * 0.3
                + debt_factor * 0.2
                + employment_factor * 0.1
                - delinquency_penalty;

            // Noise keeps the decision boundary from being perfectly
            // separable.
            let noise = rng.gen_range(-0.2..0.2);
            target[i] = if (approval_score + noise) > 0.5 {
                1.0
            } else {
                0.0
            };
        }

        let metadata = crate::registry::DatasetMetadata {
            name: "Credit Approval Dataset".to_string(),
            description: "Synthetic credit approval dataset with realistic financial features for binary classification".to_string(),
            n_samples,
            n_features,
            task_type: "classification".to_string(),
            targetnames: Some(vec!["denied".to_string(), "approved".to_string()]),
            featurenames: Some(featurenames),
            url: None,
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, Some(target), metadata))
    }
1322
    /// Builds a synthetic mushroom edibility dataset (8124 samples × 22
    /// integer-coded morphological features), sized like the UCI Mushroom data.
    ///
    /// The binary target (edible vs. poisonous) is derived from a heuristic
    /// poison score over odor, spore print color, and stalk root codes, plus
    /// uniform noise.
    #[allow(dead_code)]
    fn create_synthetic_mushroom_data(&self) -> Result<Dataset> {
        use rand::Rng;
        let mut rng = rng();

        let n_samples = 8124; let n_features = 22;

        let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        // Column order below matches the assignment indices in the loop.
        let featurenames = vec![
            "capshape".to_string(),
            "cap_surface".to_string(),
            "cap_color".to_string(),
            "bruises".to_string(),
            "odor".to_string(),
            "gill_attachment".to_string(),
            "gill_spacing".to_string(),
            "gill_size".to_string(),
            "gill_color".to_string(),
            "stalkshape".to_string(),
            "stalk_root".to_string(),
            "stalk_surface_above_ring".to_string(),
            "stalk_surface_below_ring".to_string(),
            "stalk_color_above_ring".to_string(),
            "stalk_color_below_ring".to_string(),
            "veil_type".to_string(),
            "veil_color".to_string(),
            "ring_number".to_string(),
            "ring_type".to_string(),
            "spore_print_color".to_string(),
            "population".to_string(),
            "habitat".to_string(),
        ];

        for i in 0..n_samples {
            // All categorical features are encoded as floored uniform draws,
            // i.e. integer codes in [0, upper).
            data[[i, 0]] = rng.gen_range(0.0f64..6.0).floor(); // capshape
            data[[i, 1]] = rng.gen_range(0.0f64..4.0).floor(); // cap_surface
            data[[i, 2]] = rng.gen_range(0.0f64..10.0).floor(); // cap_color
            data[[i, 3]] = if rng.gen_bool(0.6) { 1.0 } else { 0.0 }; // bruises (60% yes)
            data[[i, 4]] = rng.gen_range(0.0f64..9.0).floor(); // odor
            data[[i, 5]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 }; // gill_attachment
            data[[i, 6]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 }; // gill_spacing
            data[[i, 7]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 }; // gill_size
            data[[i, 8]] = rng.gen_range(0.0f64..12.0).floor(); // gill_color
            data[[i, 9]] = if rng.gen_bool(0.5) { 1.0 } else { 0.0 }; // stalkshape
            data[[i, 10]] = rng.gen_range(0.0f64..5.0).floor(); // stalk_root
            data[[i, 11]] = rng.gen_range(0.0f64..4.0).floor(); // stalk_surface_above_ring
            data[[i, 12]] = rng.gen_range(0.0f64..4.0).floor(); // stalk_surface_below_ring
            data[[i, 13]] = rng.gen_range(0.0f64..9.0).floor(); // stalk_color_above_ring
            data[[i, 14]] = rng.gen_range(0.0f64..9.0).floor(); // stalk_color_below_ring
            data[[i, 15]] = 0.0; // veil_type: constant (single category, as in the real data)
            data[[i, 16]] = rng.gen_range(0.0f64..4.0).floor(); // veil_color
            data[[i, 17]] = rng.gen_range(0.0f64..3.0).floor(); // ring_number
            data[[i, 18]] = rng.gen_range(0.0f64..8.0).floor(); // ring_type
            data[[i, 19]] = rng.gen_range(0.0f64..9.0).floor(); // spore_print_color
            data[[i, 20]] = rng.gen_range(0.0f64..6.0).floor(); // population
            data[[i, 21]] = rng.gen_range(0.0f64..7.0).floor(); // habitat

            // Heuristic toxicity score; the specific code values below are
            // arbitrary choices for this synthetic generator, not the real
            // UCI encodings.
            let mut poison_score = 0.0;

            // Certain odor codes are strong (+0.8) or moderate (+0.4) signals.
            if data[[i, 4]] == 2.0 || data[[i, 4]] == 3.0 || data[[i, 4]] == 4.0 {
                poison_score += 0.8;
            }
            if data[[i, 4]] == 5.0 || data[[i, 4]] == 7.0 {
                poison_score += 0.4;
            }

            // Certain spore print colors contribute +0.3.
            if data[[i, 19]] == 2.0 || data[[i, 19]] == 4.0 {
                poison_score += 0.3;
            }

            // Stalk root code 0 contributes +0.2.
            if data[[i, 10]] == 0.0 {
                poison_score += 0.2;
            }

            // Noise keeps the classes from being perfectly separable.
            let noise = rng.gen_range(-0.3..0.3);
            target[i] = if (poison_score + noise) > 0.5 {
                1.0
            } else {
                0.0
            }; }

        let metadata = crate::registry::DatasetMetadata {
            name: "Mushroom Dataset".to_string(),
            description: "Synthetic mushroom classification dataset with morphological features for edibility prediction".to_string(),
            n_samples,
            n_features,
            task_type: "classification".to_string(),
            targetnames: Some(vec!["edible".to_string(), "poisonous".to_string()]),
            featurenames: Some(featurenames),
            url: None,
            checksum: None,
        };

        Ok(Dataset::from_metadata(data, Some(target), metadata))
    }
1454
1455 #[allow(dead_code)]
1456 fn create_synthetic_spam_data(&self) -> Result<Dataset> {
1457 use rand::Rng;
1458 let mut rng = rng();
1459
1460 let n_samples = 4601; let n_features = 57; let mut data = Array2::zeros((n_samples, n_features));
1464 let mut target = Array1::zeros(n_samples);
1465
1466 let mut featurenames = Vec::with_capacity(n_features);
1468
1469 let spam_words = vec![
1471 "make",
1472 "address",
1473 "all",
1474 "3d",
1475 "our",
1476 "over",
1477 "remove",
1478 "internet",
1479 "order",
1480 "mail",
1481 "receive",
1482 "will",
1483 "people",
1484 "report",
1485 "addresses",
1486 "free",
1487 "business",
1488 "email",
1489 "you",
1490 "credit",
1491 "your",
1492 "font",
1493 "000",
1494 "money",
1495 "hp",
1496 "hpl",
1497 "george",
1498 "650",
1499 "lab",
1500 "labs",
1501 "telnet",
1502 "857",
1503 "data",
1504 "415",
1505 "85",
1506 "technology",
1507 "1999",
1508 "parts",
1509 "pm",
1510 "direct",
1511 "cs",
1512 "meeting",
1513 "original",
1514 "project",
1515 "re",
1516 "edu",
1517 "table",
1518 "conference",
1519 "char_freq_semicolon",
1520 "char_freq_parenthesis",
1521 "char_freq_bracket",
1522 "char_freq_exclamation",
1523 "char_freq_dollar",
1524 "char_freq_hash",
1525 "capital_run_length_average",
1526 "capital_run_length_longest",
1527 "capital_run_length_total",
1528 ];
1529
1530 for (i, word) in spam_words.iter().enumerate() {
1531 if i < n_features {
1532 featurenames.push(format!("word_freq_{word}"));
1533 }
1534 }
1535
1536 while featurenames.len() < n_features {
1538 featurenames.push(format!("feature_{}", featurenames.len()));
1539 }
1540
1541 for i in 0..n_samples {
1542 let is_spam = rng.gen_bool(0.4); for j in 0..54 {
1546 if is_spam {
1547 match j {
1549 0..=7 => data[[i, j]] = rng.gen_range(0.0..5.0), 8..=15 => data[[i, j]] = rng.gen_range(0.0..3.0), 16..=25 => data[[i, j]] = rng.gen_range(0.0..4.0), _ => data[[i, j]] = rng.gen_range(0.0..1.0), }
1554 } else {
1555 match j {
1557 26..=35 => data[[i, j]] = rng.gen_range(0.0..2.0), 36..=45 => data[[i, j]] = rng.gen_range(0.0..1.5), _ => data[[i, j]] = rng.gen_range(0.0..0.5), }
1561 }
1562 }
1563
1564 if is_spam {
1566 data[[i, 54]] = rng.gen_range(0.0..0.2); data[[i, 55]] = rng.gen_range(0.0..0.5); data[[i, 56]] = rng.gen_range(0.0..0.3); } else {
1570 data[[i, 54]] = rng.gen_range(0.0..0.1);
1571 data[[i, 55]] = rng.gen_range(0.0..0.2);
1572 data[[i, 56]] = rng.gen_range(0.0..0.1);
1573 }
1574
1575 target[i] = if is_spam { 1.0 } else { 0.0 };
1576 }
1577
1578 let metadata = crate::registry::DatasetMetadata {
1579 name: "Spam Email Dataset".to_string(),
1580 description: "Synthetic spam email classification dataset with word and character frequency features".to_string(),
1581 n_samples,
1582 n_features,
1583 task_type: "classification".to_string(),
1584 targetnames: Some(vec!["ham".to_string(), "spam".to_string()]),
1585 featurenames: Some(featurenames),
1586 url: None,
1587 checksum: None,
1588 };
1589
1590 Ok(Dataset::from_metadata(data, Some(target), metadata))
1591 }
1592
1593 fn create_synthetic_titanic_data(
1594 &self,
1595 n_samples: usize,
1596 n_features: usize,
1597 ) -> Result<(Array2<f64>, Array1<f64>)> {
1598 use rand::Rng;
1599 let mut rng = rng();
1600
1601 let mut data = Array2::zeros((n_samples, n_features));
1602 let mut target = Array1::zeros(n_samples);
1603
1604 for i in 0..n_samples {
1605 data[[i, 0]] = rng.gen_range(1.0f64..4.0).floor();
1607 data[[i, 1]] = if rng.gen_bool(0.5) { 0.0 } else { 1.0 };
1609 data[[i, 2]] = rng.gen_range(1.0..80.0);
1611 data[[i, 3]] = rng.gen_range(0.0f64..6.0).floor();
1613 data[[i, 4]] = rng.gen_range(0.0f64..4.0).floor();
1615 data[[i, 5]] = rng.gen_range(0.0..512.0);
1617 data[[i, 6]] = rng.gen_range(0.0f64..3.0).floor();
1619
1620 let survival_score = (4.0 - data[[i, 0]]) * 0.3 + (1.0 - data[[i, 1]]) * 0.4 + (80.0 - data[[i, 2]]) / 80.0 * 0.3; target[i] = if survival_score > 0.5 { 1.0 } else { 0.0 };
1626 }
1627
1628 Ok((data, target))
1629 }
1630
1631 fn create_synthetic_credit_data(
1632 &self,
1633 n_samples: usize,
1634 n_features: usize,
1635 ) -> Result<(Array2<f64>, Array1<f64>)> {
1636 use rand::Rng;
1637 let mut rng = rng();
1638
1639 let mut data = Array2::zeros((n_samples, n_features));
1640 let mut target = Array1::zeros(n_samples);
1641
1642 for i in 0..n_samples {
1643 for j in 0..n_features {
1644 data[[i, j]] = rng.gen_range(0.0..1.0);
1645 }
1646 let score = data.row(i).iter().sum::<f64>() / n_features as f64;
1648 target[i] = if score > 0.6 { 1.0 } else { 0.0 };
1649 }
1650
1651 Ok((data, target))
1652 }
1653
1654 fn create_synthetic_housing_data(
1655 &self,
1656 n_samples: usize,
1657 n_features: usize,
1658 ) -> Result<(Array2<f64>, Array1<f64>)> {
1659 use rand::Rng;
1660 let mut rng = rng();
1661
1662 let mut data = Array2::zeros((n_samples, n_features));
1663 let mut target = Array1::zeros(n_samples);
1664
1665 for i in 0..n_samples {
1666 data[[i, 0]] = rng.gen_range(0.5..15.0);
1668 data[[i, 1]] = rng.gen_range(1.0..52.0);
1670 data[[i, 2]] = rng.gen_range(3.0..20.0);
1672 data[[i, 3]] = rng.gen_range(0.8..6.0);
1674 data[[i, 4]] = rng.gen_range(3.0..35682.0);
1676 data[[i, 5]] = rng.gen_range(0.7..1243.0);
1678 data[[i, 6]] = rng.gen_range(32.0..42.0);
1680 data[[i, 7]] = rng.gen_range(-124.0..-114.0);
1682
1683 let house_value = data[[i, 0]] * 50000.0 + data[[i, 2]] * 10000.0 + (40.0 - data[[i, 6]]) * 5000.0; target[i] = house_value / 100000.0; }
1690
1691 Ok((data, target))
1692 }
1693
1694 fn create_synthetic_wine_data(
1695 &self,
1696 n_samples: usize,
1697 n_features: usize,
1698 ) -> Result<(Array2<f64>, Array1<f64>)> {
1699 use rand::Rng;
1700 let mut rng = rng();
1701
1702 let mut data = Array2::zeros((n_samples, n_features));
1703 let mut target = Array1::zeros(n_samples);
1704
1705 for i in 0..n_samples {
1706 data[[i, 0]] = rng.gen_range(4.6..15.9); data[[i, 1]] = rng.gen_range(0.12..1.58); data[[i, 2]] = rng.gen_range(0.0..1.0); data[[i, 3]] = rng.gen_range(0.9..15.5); data[[i, 4]] = rng.gen_range(0.012..0.611); data[[i, 5]] = rng.gen_range(1.0..72.0); data[[i, 6]] = rng.gen_range(6.0..289.0); data[[i, 7]] = rng.gen_range(0.99007..1.00369); data[[i, 8]] = rng.gen_range(2.74..4.01); data[[i, 9]] = rng.gen_range(0.33..2.0); data[[i, 10]] = rng.gen_range(8.4..14.9); let quality: f64 = 3.0 +
1721 (data[[i, 10]] - 8.0) * 0.5 + (1.0 - data[[i, 1]]) * 2.0 + data[[i, 2]] * 2.0 + rng.gen_range(-0.5..0.5); target[i] = quality.clamp(3.0, 8.0);
1727 }
1728
1729 Ok((data, target))
1730 }
1731
1732 fn create_synthetic_energy_data(
1733 &self,
1734 n_samples: usize,
1735 n_features: usize,
1736 ) -> Result<(Array2<f64>, Array1<f64>)> {
1737 use rand::Rng;
1738 let mut rng = rng();
1739
1740 let mut data = Array2::zeros((n_samples, n_features));
1741 let mut target = Array1::zeros(n_samples);
1742
1743 for i in 0..n_samples {
1744 for j in 0..n_features {
1745 data[[i, j]] = rng.gen_range(0.0..1.0);
1746 }
1747
1748 let efficiency = data.row(i).iter().sum::<f64>() / n_features as f64;
1750 target[i] = efficiency * 40.0 + 10.0; }
1752
1753 Ok((data, target))
1754 }
1755
1756 fn create_air_passengers_data(
1757 &self,
1758 n_timesteps: usize,
1759 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
1760 use rand::Rng;
1761 let mut rng = rng();
1762 let mut data = Array2::zeros((n_timesteps, 1));
1763
1764 for i in 0..n_timesteps {
1765 let t = i as f64;
1766 let trend = 100.0 + t * 2.0;
1767 let seasonal = 20.0 * (2.0 * std::f64::consts::PI * t / 12.0).sin();
1768 let noise = rng.random::<f64>() * 10.0 - 5.0;
1769
1770 data[[i, 0]] = trend + seasonal + noise;
1771 }
1772
1773 Ok((data, None))
1774 }
1775
1776 fn create_bitcoin_price_data(
1777 &self,
1778 n_timesteps: usize,
1779 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
1780 use rand::Rng;
1781 let mut rng = rng();
1782
1783 let mut data = Array2::zeros((n_timesteps, 6));
1784 let mut price = 30000.0; for i in 0..n_timesteps {
1787 let change = rng.gen_range(-0.05..0.05);
1789 price *= 1.0 + change;
1790
1791 let high = price * (1.0 + rng.gen_range(0.0..0.02));
1792 let low = price * (1.0 - rng.gen_range(0.0..0.02));
1793 let volume = rng.gen_range(1000000.0..10000000.0);
1794
1795 data[[i, 0]] = price; data[[i, 1]] = high;
1797 data[[i, 2]] = low;
1798 data[[i, 3]] = price; data[[i, 4]] = volume;
1800 data[[i, 5]] = price * volume; }
1802
1803 Ok((data, None))
1804 }
1805
1806 fn create_heart_disease_data(
1807 &self,
1808 n_samples: usize,
1809 n_features: usize,
1810 ) -> Result<(Array2<f64>, Array1<f64>)> {
1811 use rand::Rng;
1812 let mut rng = rng();
1813
1814 let mut data = Array2::zeros((n_samples, n_features));
1815 let mut target = Array1::zeros(n_samples);
1816
1817 for i in 0..n_samples {
1818 data[[i, 0]] = rng.gen_range(29.0..77.0);
1820 data[[i, 1]] = if rng.gen_bool(0.68) { 1.0 } else { 0.0 };
1822 data[[i, 2]] = rng.gen_range(0.0f64..4.0).floor();
1824 data[[i, 3]] = rng.gen_range(94.0..200.0);
1826 data[[i, 4]] = rng.gen_range(126.0..564.0);
1828
1829 for j in 5..n_features {
1831 data[[i, j]] = rng.gen_range(0.0..1.0);
1832 }
1833
1834 let risk_score = (data[[i, 0]] - 29.0) / 48.0 * 0.3 + data[[i, 1]] * 0.2 + (data[[i, 3]] - 94.0) / 106.0 * 0.2 + (data[[i, 4]] - 126.0) / 438.0 * 0.3; target[i] = if risk_score > 0.5 { 1.0 } else { 0.0 };
1841 }
1842
1843 Ok((data, target))
1844 }
1845
1846 fn create_diabetes_readmission_data(
1847 &self,
1848 n_samples: usize,
1849 n_features: usize,
1850 ) -> Result<(Array2<f64>, Array1<f64>)> {
1851 use rand::Rng;
1852 let mut rng = rng();
1853
1854 let mut data = Array2::zeros((n_samples, n_features));
1855 let mut target = Array1::zeros(n_samples);
1856
1857 for i in 0..n_samples {
1858 for j in 0..n_features {
1859 data[[i, j]] = rng.gen_range(0.0..1.0);
1860 }
1861
1862 let readmission_score = data.row(i).iter().take(10).sum::<f64>() / 10.0;
1864 target[i] = if readmission_score > 0.6 { 1.0 } else { 0.0 };
1865 }
1866
1867 Ok((data, target))
1868 }
1869
1870 fn create_synthetic_auto_mpg_data(
1871 &self,
1872 n_samples: usize,
1873 n_features: usize,
1874 ) -> Result<(Array2<f64>, Array1<f64>)> {
1875 use rand::Rng;
1876 let mut rng = rng();
1877
1878 let mut data = Array2::zeros((n_samples, n_features));
1879 let mut target = Array1::zeros(n_samples);
1880
1881 for i in 0..n_samples {
1882 data[[i, 0]] = [4.0, 6.0, 8.0][rng.sample(Uniform::new(0, 3).unwrap())];
1884 data[[i, 1]] = rng.gen_range(68.0..455.0);
1886 data[[i, 2]] = rng.gen_range(46.0..230.0);
1888 data[[i, 3]] = rng.gen_range(1613.0..5140.0);
1890 data[[i, 4]] = rng.gen_range(8.0..24.8);
1892 data[[i, 5]] = rng.gen_range(70.0..82.0);
1894 data[[i, 6]] = (rng.gen_range(1.0f64..4.0f64)).floor();
1896
1897 let mpg: f64 = 45.0 - (data[[i, 3]] / 5140.0) * 20.0 - (data[[i, 1]] / 455.0) * 15.0
1899 + (data[[i, 4]] / 24.8) * 10.0
1900 + rng.gen_range(-3.0..3.0);
1901 target[i] = mpg.clamp(9.0, 46.6);
1902 }
1903
1904 Ok((data, target))
1905 }
1906
1907 fn create_synthetic_concrete_data(
1908 &self,
1909 n_samples: usize,
1910 n_features: usize,
1911 ) -> Result<(Array2<f64>, Array1<f64>)> {
1912 use rand::Rng;
1913 let mut rng = rng();
1914
1915 let mut data = Array2::zeros((n_samples, n_features));
1916 let mut target = Array1::zeros(n_samples);
1917
1918 for i in 0..n_samples {
1919 data[[i, 0]] = rng.gen_range(102.0..540.0);
1921 data[[i, 1]] = rng.gen_range(0.0..359.4);
1923 data[[i, 2]] = rng.gen_range(0.0..200.1);
1925 data[[i, 3]] = rng.gen_range(121.8..247.0);
1927 data[[i, 4]] = rng.gen_range(0.0..32.2);
1929 data[[i, 5]] = rng.gen_range(801.0..1145.0);
1931 data[[i, 6]] = rng.gen_range(594.0..992.6);
1933 data[[i, 7]] = rng.gen_range(1.0..365.0);
1935
1936 let strength: f64 = (data[[i, 0]] / 540.0) * 30.0 + (data[[i, 1]] / 359.4) * 15.0 + (data[[i, 3]] / 247.0) * (-20.0) + (data[[i, 7]] / 365.0_f64).ln() * 10.0 + rng.gen_range(-5.0..5.0); target[i] = strength.clamp(2.33, 82.6);
1944 }
1945
1946 Ok((data, target))
1947 }
1948
1949 fn create_synthetic_electricity_data(
1950 &self,
1951 n_samples: usize,
1952 n_features: usize,
1953 ) -> Result<(Array2<f64>, Array1<f64>)> {
1954 use rand::Rng;
1955 let mut rng = rng();
1956
1957 let mut data = Array2::zeros((n_samples, n_features));
1958 let mut target = Array1::zeros(n_samples);
1959
1960 for i in 0..n_samples {
1961 let hour = (i % 24) as f64;
1962 let day_of_year = (i / 24) % 365;
1963
1964 data[[i, 0]] = 20.0
1966 + 15.0 * (day_of_year as f64 * 2.0 * std::f64::consts::PI / 365.0).sin()
1967 + rng.gen_range(-5.0..5.0);
1968 data[[i, 1]] = 50.0 + 30.0 * rng.gen_range(0.0..1.0);
1970 data[[i, 2]] = hour;
1972
1973 let seasonal = 50.0
1975 + 30.0
1976 * (day_of_year as f64 * 2.0 * std::f64::consts::PI / 365.0
1977 + std::f64::consts::PI)
1978 .cos();
1979 let daily = 40.0 + 60.0 * ((hour - 12.0) * std::f64::consts::PI / 12.0).cos();
1980 let temp_effect = (data[[i, 0]] - 20.0).abs() * 2.0; target[i] = seasonal + daily + temp_effect + rng.gen_range(-10.0..10.0);
1983 }
1984
1985 Ok((data, target))
1986 }
1987
1988 fn create_synthetic_stock_data(
1989 &self,
1990 n_samples: usize,
1991 n_features: usize,
1992 ) -> Result<(Array2<f64>, Array1<f64>)> {
1993 use rand::Rng;
1994 let mut rng = rng();
1995
1996 let mut data = Array2::zeros((n_samples, n_features));
1997 let mut target = Array1::zeros(n_samples);
1998
1999 let mut price = 100.0; for i in 0..n_samples {
2002 let change = rng.gen_range(-0.05..0.05);
2004 price *= 1.0 + change;
2005
2006 let high = price * (1.0 + rng.gen_range(0.0..0.02));
2008 let low = price * (1.0 - rng.gen_range(0.0..0.02));
2009 let volume = rng.gen_range(1000000.0..10000000.0);
2010
2011 data[[i, 0]] = price; data[[i, 1]] = high;
2013 data[[i, 2]] = low;
2014 data[[i, 3]] = volume;
2015 data[[i, 4]] = (high - low) / price; let next_change = rng.gen_range(-0.05..0.05);
2019 target[i] = next_change;
2020 }
2021
2022 Ok((data, target))
2023 }
2024
2025 fn create_synthetic_fraud_data(
2026 &self,
2027 n_samples: usize,
2028 n_features: usize,
2029 ) -> Result<(Array2<f64>, Array1<f64>)> {
2030 use rand::Rng;
2031 let mut rng = rng();
2032
2033 let mut data = Array2::zeros((n_samples, n_features));
2034 let mut target = Array1::zeros(n_samples);
2035
2036 for i in 0..n_samples {
2037 let is_fraud = rng.gen_range(0.0..1.0) < 0.001728; for j in 0..n_features {
2040 if j < 28 {
2041 if is_fraud {
2043 data[[i, j]] = rng.gen_range(-5.0..5.0) * 2.0; } else {
2046 data[[i, j]] = rng.gen_range(-3.0..3.0);
2048 }
2049 }
2050 }
2051
2052 target[i] = if is_fraud { 1.0 } else { 0.0 };
2053 }
2054
2055 Ok((data, target))
2056 }
2057
2058 fn create_synthetic_loan_data(
2059 &self,
2060 n_samples: usize,
2061 n_features: usize,
2062 ) -> Result<(Array2<f64>, Array1<f64>)> {
2063 use rand::Rng;
2064 let mut rng = rng();
2065
2066 let mut data = Array2::zeros((n_samples, n_features));
2067 let mut target = Array1::zeros(n_samples);
2068
2069 for i in 0..n_samples {
2070 data[[i, 0]] = rng.gen_range(1000.0..50000.0);
2072 data[[i, 1]] = rng.gen_range(5.0..25.0);
2074 data[[i, 2]] = [12.0, 24.0, 36.0, 48.0, 60.0][rng.sample(Uniform::new(0, 5).unwrap())];
2076 data[[i, 3]] = rng.gen_range(20000.0..200000.0);
2078 data[[i, 4]] = rng.gen_range(300.0..850.0);
2080 data[[i, 5]] = rng.gen_range(0.0..40.0);
2082 data[[i, 6]] = rng.gen_range(0.0..0.4);
2084
2085 for j in 7..n_features {
2087 data[[i, j]] = rng.gen_range(0.0..1.0);
2088 }
2089
2090 let risk_score = (850.0 - data[[i, 4]]) / 550.0 * 0.4 + data[[i, 6]] * 0.3 + (data[[i, 1]] - 5.0) / 20.0 * 0.2 + (50000.0 - data[[i, 3]]) / 180000.0 * 0.1; target[i] = if risk_score > 0.3 { 1.0 } else { 0.0 };
2097 }
2098
2099 Ok((data, target))
2100 }
2101
    /// Builds a synthetic Adult (census income) style dataset. Column order
    /// appears to mirror the UCI Adult schema (age, workclass, fnlwgt,
    /// education, education-num, …) — inferred from the value ranges; confirm
    /// against the loader's feature names. The binary target is a weighted
    /// income score plus uniform noise, thresholded at 0.5.
    fn create_synthetic_adult_dataset(
        &self,
        n_samples: usize,
        n_features: usize,
    ) -> Result<(Array2<f64>, Array1<f64>)> {
        use rand::Rng;
        let mut rng = rng();

        let mut data = Array2::zeros((n_samples, n_features));
        let mut target = Array1::zeros(n_samples);

        for i in 0..n_samples {
            data[[i, 0]] = rng.gen_range(17.0..90.0); // age
            data[[i, 1]] = rng.gen_range(0.0f64..9.0).floor(); // workclass code — presumed; confirm
            data[[i, 2]] = rng.gen_range(12285.0..1484705.0); // fnlwgt (census sampling weight) — presumed; confirm
            data[[i, 3]] = rng.gen_range(0.0f64..16.0).floor(); // education code
            data[[i, 4]] = rng.gen_range(1.0..17.0); // education-num (years)
            data[[i, 5]] = rng.gen_range(0.0f64..7.0).floor(); // marital-status code — presumed; confirm
            data[[i, 6]] = rng.gen_range(0.0f64..14.0).floor(); // occupation code — presumed; confirm
            data[[i, 7]] = rng.gen_range(0.0f64..6.0).floor(); // relationship code — presumed; confirm
            data[[i, 8]] = rng.gen_range(0.0f64..5.0).floor(); // race code — presumed; confirm
            data[[i, 9]] = if rng.gen_bool(0.67) { 1.0 } else { 0.0 }; // sex flag (~67% ones)
            // Capital gain: zero for 90% of samples, otherwise uniform.
            data[[i, 10]] = if rng.gen_bool(0.9) {
                0.0
            } else {
                rng.gen_range(1.0..99999.0)
            };
            // Capital loss: zero for 95% of samples, otherwise uniform.
            data[[i, 11]] = if rng.gen_bool(0.95) {
                0.0
            } else {
                rng.gen_range(1.0..4356.0)
            };
            data[[i, 12]] = rng.gen_range(1.0..99.0); // hours-per-week
            data[[i, 13]] = rng.gen_range(0.0f64..41.0).floor(); // native-country code — presumed; confirm

            // Weighted income score: age (0.2), education years (0.3), sex
            // flag (0.2), weekly hours (0.2), and capital gain/loss (0.1),
            // each normalized to roughly [0, 1], plus uniform noise.
            let income_score = (data[[i, 0]] - 17.0) / 73.0 * 0.2 + data[[i, 4]] / 16.0 * 0.3 + data[[i, 9]] * 0.2 + (data[[i, 12]] - 1.0) / 98.0 * 0.2 + (data[[i, 10]] + data[[i, 11]]) / 100000.0 * 0.1; let noise = rng.gen_range(-0.15..0.15);
            target[i] = if (income_score + noise) > 0.5 {
                1.0
            } else {
                0.0
            };
        }

        Ok((data, target))
    }
2169
2170 fn create_generic_synthetic_dataset(
2171 &self,
2172 n_samples: usize,
2173 n_features: usize,
2174 has_categorical: bool,
2175 ) -> Result<(Array2<f64>, Option<Array1<f64>>)> {
2176 use rand::Rng;
2177 let mut rng = rng();
2178
2179 let mut data = Array2::zeros((n_samples, n_features));
2180
2181 for i in 0..n_samples {
2182 for j in 0..n_features {
2183 if has_categorical && j < n_features / 3 {
2184 data[[i, j]] = rng.gen_range(0.0f64..10.0).floor();
2186 } else {
2187 data[[i, j]] = rng.gen_range(-2.0..2.0);
2189 }
2190 }
2191 }
2192
2193 let mut target = Array1::zeros(n_samples);
2195 for i in 0..n_samples {
2196 let feature_sum = data.row(i).iter().sum::<f64>();
2197 let score = feature_sum / n_features as f64;
2198 target[i] = if score > 0.0 { 1.0 } else { 0.0 };
2199 }
2200
2201 Ok((data, Some(target)))
2202 }
2203
2204 fn create_synthetic_cifar10_data(
2205 &self,
2206 n_samples: usize,
2207 n_features: usize,
2208 ) -> Result<(Array2<f64>, Array1<f64>)> {
2209 use rand::Rng;
2210 let mut rng = rng();
2211
2212 let mut data = Array2::zeros((n_samples, n_features));
2213 let mut target = Array1::zeros(n_samples);
2214
2215 for i in 0..n_samples {
2216 let class = rng.sample(Uniform::new(0, 10).unwrap()) as f64;
2217 target[i] = class;
2218
2219 for j in 0..n_features {
2221 let base_intensity = match class as i32 {
2222 0 => 0.6, 1 => 0.3, 2 => 0.8, 3 => 0.5, 4 => 0.7, 5 => 0.4, 6 => 0.9, 7 => 0.6, 8 => 0.2, 9 => 0.3, _ => 0.5,
2233 };
2234
2235 data[[i, j]] = base_intensity + rng.gen_range(-0.3f64..0.3f64);
2237 data[[i, j]] = data[[i, j]].clamp(0.0f64, 1.0f64); }
2239 }
2240
2241 Ok((data, target))
2242 }
2243
2244 fn create_synthetic_fashion_mnist_data(
2245 &self,
2246 n_samples: usize,
2247 n_features: usize,
2248 ) -> Result<(Array2<f64>, Array1<f64>)> {
2249 use rand::Rng;
2250 let mut rng = rng();
2251
2252 let mut data = Array2::zeros((n_samples, n_features));
2253 let mut target = Array1::zeros(n_samples);
2254
2255 for i in 0..n_samples {
2256 let class = rng.sample(Uniform::new(0, 10).unwrap()) as f64;
2257 target[i] = class;
2258
2259 for j in 0..n_features {
2261 let base_intensity = match class as i32 {
2262 0 => 0.3, 1 => 0.4, 2 => 0.5, 3 => 0.6, 4 => 0.7, 5 => 0.2, 6 => 0.4, 7 => 0.3, 8 => 0.5, 9 => 0.4, _ => 0.4,
2273 };
2274
2275 let texture_noise = rng.gen_range(-0.2f64..0.2f64);
2277 data[[i, j]] = base_intensity + texture_noise;
2278 data[[i, j]] = data[[i, j]].clamp(0.0f64, 1.0f64); }
2280 }
2281
2282 Ok((data, target))
2283 }
2284
2285 fn create_synthetic_imdb_data(
2286 &self,
2287 n_samples: usize,
2288 n_features: usize,
2289 ) -> Result<(Array2<f64>, Array1<f64>)> {
2290 use rand::Rng;
2291 let mut rng = rng();
2292
2293 let mut data = Array2::zeros((n_samples, n_features));
2294 let mut target = Array1::zeros(n_samples);
2295
2296 let positive_words = 0..n_features / 3; let negative_words = n_features / 3..2 * n_features / 3; let _neutral_words = 2 * n_features / 3..n_features; for i in 0..n_samples {
2302 let is_positive = rng.gen_bool(0.5);
2303 target[i] = if is_positive { 1.0 } else { 0.0 };
2304
2305 for j in 0..n_features {
2306 let base_freq = if positive_words.contains(&j) {
2307 if is_positive {
2308 rng.gen_range(0.5..2.0) } else {
2310 rng.gen_range(0.0..0.5) }
2312 } else if negative_words.contains(&j) {
2313 if is_positive {
2314 rng.gen_range(0.0..0.5) } else {
2316 rng.gen_range(0.5..2.0) }
2318 } else {
2319 rng.gen_range(0.2..1.0)
2321 };
2322
2323 data[[i, j]] = base_freq;
2324 }
2325 }
2326
2327 Ok((data, target))
2328 }
2329
2330 fn create_synthetic_news_data(
2331 &self,
2332 n_samples: usize,
2333 n_features: usize,
2334 ) -> Result<(Array2<f64>, Array1<f64>)> {
2335 use rand::Rng;
2336 let mut rng = rng();
2337
2338 let mut data = Array2::zeros((n_samples, n_features));
2339 let mut target = Array1::zeros(n_samples);
2340
2341 let words_per_topic = n_features / 5; for i in 0..n_samples {
2345 let topic = rng.sample(Uniform::new(0, 5).unwrap()) as f64;
2346 target[i] = topic;
2347
2348 for j in 0..n_features {
2349 let word_topic = j / words_per_topic;
2350
2351 let base_freq = if word_topic == topic as usize {
2352 rng.gen_range(1.0..3.0)
2354 } else {
2355 rng.gen_range(0.0..0.8)
2357 };
2358
2359 let noise = rng.gen_range(-0.2f64..0.2f64);
2361 data[[i, j]] = (base_freq + noise).max(0.0f64);
2362 }
2363 }
2364
2365 Ok((data, target))
2366 }
2367}
2368
2369#[allow(dead_code)]
2371pub fn load_adult() -> Result<Dataset> {
2372 let config = RealWorldConfig::default();
2373 let mut loader = RealWorldDatasets::new(config)?;
2374 loader.load_adult()
2375}
2376
2377#[allow(dead_code)]
2379pub fn load_titanic() -> Result<Dataset> {
2380 let config = RealWorldConfig::default();
2381 let mut loader = RealWorldDatasets::new(config)?;
2382 loader.load_titanic()
2383}
2384
2385#[allow(dead_code)]
2387pub fn load_california_housing() -> Result<Dataset> {
2388 let config = RealWorldConfig::default();
2389 let mut loader = RealWorldDatasets::new(config)?;
2390 loader.load_california_housing()
2391}
2392
2393#[allow(dead_code)]
2395pub fn load_heart_disease() -> Result<Dataset> {
2396 let config = RealWorldConfig::default();
2397 let mut loader = RealWorldDatasets::new(config)?;
2398 loader.load_heart_disease()
2399}
2400
2401#[allow(dead_code)]
2403pub fn load_red_wine_quality() -> Result<Dataset> {
2404 let config = RealWorldConfig::default();
2405 let mut loader = RealWorldDatasets::new(config)?;
2406 loader.load_red_wine_quality()
2407}
2408
2409#[allow(dead_code)]
2411pub fn list_real_world_datasets() -> Vec<String> {
2412 let config = RealWorldConfig::default();
2413 let loader = RealWorldDatasets::new(config).unwrap();
2414 loader.list_datasets()
2415}
2416
#[cfg(test)]
mod tests {
    // Fix: removed the unused `use rand_distr::Uniform;` import, which
    // triggered an unused-import warning (no test below references it).
    use super::*;

    // Expected sample/feature counts match the classic datasets these
    // loaders emulate (Titanic: 891×7, California housing: 20640×8,
    // Cleveland heart disease: 303×13).
    #[test]
    fn test_load_titanic() {
        let dataset = load_titanic().unwrap();
        assert_eq!(dataset.n_samples(), 891);
        assert_eq!(dataset.n_features(), 7);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_california_housing() {
        let dataset = load_california_housing().unwrap();
        assert_eq!(dataset.n_samples(), 20640);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_heart_disease() {
        let dataset = load_heart_disease().unwrap();
        assert_eq!(dataset.n_samples(), 303);
        assert_eq!(dataset.n_features(), 13);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_list_datasets() {
        let datasets = list_real_world_datasets();
        assert!(!datasets.is_empty());
        assert!(datasets.contains(&"titanic".to_string()));
        assert!(datasets.contains(&"california_housing".to_string()));
    }

    #[test]
    fn test_real_world_config() {
        let config = RealWorldConfig {
            use_cache: false,
            download_if_missing: false,
            ..Default::default()
        };

        assert!(!config.use_cache);
        assert!(!config.download_if_missing);
    }
}
2465}