1use std::collections::HashMap;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::prelude::*;
16use scirs2_core::random::{Distribution, Uniform};
17use serde::{Deserialize, Serialize};
18
19use crate::cache::DatasetCache;
20use crate::error::{DatasetsError, Result};
21use crate::external::ExternalClient;
22use crate::utils::Dataset;
23
/// Configuration for a domain-specific dataset source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainConfig {
    /// Base URL of the remote data service, if any.
    pub base_url: Option<String>,
    /// API key used when the service requires authentication.
    pub api_key: Option<String>,
    /// File formats to prefer when downloading, in priority order.
    pub preferred_formats: Vec<String>,
    /// Quality thresholds applied when filtering candidate datasets.
    pub quality_filters: QualityFilters,
}
36
37impl Default for DomainConfig {
38 fn default() -> Self {
39 Self {
40 base_url: None,
41 api_key: None,
42 preferred_formats: vec!["csv".to_string(), "fits".to_string(), "hdf5".to_string()],
43 quality_filters: QualityFilters::default(),
44 }
45 }
46}
47
/// Quality thresholds used to filter candidate datasets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityFilters {
    /// Minimum number of samples a dataset must contain.
    pub min_samples: Option<usize>,
    /// Maximum allowed missing-data level (default 0.1 — appears to be a
    /// fraction despite the "percent" name; confirm with callers).
    pub max_missing_percent: Option<f64>,
    /// Minimum required completeness fraction (0.0–1.0).
    pub min_completeness: Option<f64>,
    /// Earliest acceptable dataset year.
    pub min_year: Option<u32>,
}
60
61impl Default for QualityFilters {
62 fn default() -> Self {
63 Self {
64 min_samples: Some(100),
65 max_missing_percent: Some(0.1),
66 min_completeness: Some(0.9),
67 min_year: Some(2000),
68 }
69 }
70}
71
72pub mod astronomy {
74 use super::*;
75
    /// Loader for astronomy datasets (currently all data is synthesized
    /// locally; no network access is performed).
    pub struct StellarDatasets {
        // HTTP client for remote catalog services; currently unused
        // (hence the allow) since all loaders generate synthetic data.
        #[allow(dead_code)]
        client: ExternalClient,
        // On-disk dataset cache; currently unused for the same reason.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
83
84 impl StellarDatasets {
85 pub fn new() -> Result<Self> {
87 let cachedir = dirs::cache_dir()
88 .ok_or_else(|| {
89 DatasetsError::Other("Could not determine cache directory".to_string())
90 })?
91 .join("scirs2-datasets");
92 Ok(Self {
93 client: ExternalClient::new()?,
94 cache: DatasetCache::new(cachedir),
95 })
96 }
97
        /// Load a synthetic stand-in for the Hipparcos star catalog
        /// (118,218 stars — the size of the real catalog).
        pub fn load_hipparcos_catalog(&self) -> Result<Dataset> {
            self.load_synthetic_stellar_data("hipparcos", 118218)
        }
102
        /// Load a synthetic 50,000-star sample styled after Gaia DR3.
        pub fn load_gaia_dr3_sample(&self) -> Result<Dataset> {
            self.load_synthetic_stellar_data("gaia_dr3", 50000)
        }
107
        /// Load a synthetic exoplanet catalog of 5,000 planets.
        pub fn load_exoplanet_catalog(&self) -> Result<Dataset> {
            self.load_synthetic_exoplanet_data(5000)
        }
112
        /// Load a synthetic supernova photometry catalog of 1,000 events.
        pub fn load_supernova_photometry(&self) -> Result<Dataset> {
            self.load_synthetic_supernova_data(1000)
        }
117
118 fn load_synthetic_stellar_data(&self, catalog: &str, nstars: usize) -> Result<Dataset> {
119 use scirs2_core::random::{Distribution, Normal};
120
121 let mut rng = thread_rng();
122
123 let mut data = Vec::with_capacity(nstars * 8);
125 let mut spectral_classes = Vec::with_capacity(nstars);
126
127 let ra_dist = scirs2_core::random::Uniform::new(0.0, 360.0).unwrap();
129 let dec_dist = scirs2_core::random::Uniform::new(-90.0, 90.0).unwrap();
130 let magnitude_dist = Normal::new(8.0, 3.0).unwrap();
131 let color_dist = Normal::new(0.5, 0.3).unwrap();
132 let parallax_dist = Normal::new(10.0, 5.0).unwrap();
133 let proper_motion_dist = Normal::new(0.0, 50.0).unwrap();
134 let radial_velocity_dist = Normal::new(0.0, 30.0).unwrap();
135
136 for _ in 0..nstars {
137 data.push(ra_dist.sample(&mut rng));
139 data.push(dec_dist.sample(&mut rng));
141 data.push(magnitude_dist.sample(&mut rng));
143 data.push(color_dist.sample(&mut rng));
145 data.push((parallax_dist.sample(&mut rng) as f64).max(0.1f64));
147 data.push(proper_motion_dist.sample(&mut rng));
149 data.push(proper_motion_dist.sample(&mut rng));
151 data.push(radial_velocity_dist.sample(&mut rng));
153
154 let color = data[data.len() - 5];
156 let spectral_class = match color {
157 c if c < -0.3 => 0, c if c < -0.1 => 1, c if c < 0.2 => 2, c if c < 0.5 => 3, c if c < 0.8 => 4, c if c < 1.2 => 5, _ => 6, };
165 spectral_classes.push(spectral_class as f64);
166 }
167
168 let data_array = Array2::from_shape_vec((nstars, 8), data)
169 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
170
171 let target = Array1::from_vec(spectral_classes);
172
173 Ok(Dataset {
174 data: data_array,
175 target: Some(target),
176 featurenames: Some(vec![
177 "ra".to_string(),
178 "dec".to_string(),
179 "magnitude".to_string(),
180 "color_bv".to_string(),
181 "parallax".to_string(),
182 "pm_ra".to_string(),
183 "pm_dec".to_string(),
184 "radial_velocity".to_string(),
185 ]),
186 targetnames: Some(vec![
187 "O".to_string(),
188 "B".to_string(),
189 "A".to_string(),
190 "F".to_string(),
191 "G".to_string(),
192 "K".to_string(),
193 "M".to_string(),
194 ]),
195 feature_descriptions: Some(vec![
196 "Right Ascension (degrees)".to_string(),
197 "Declination (degrees)".to_string(),
198 "Apparent magnitude (visual)".to_string(),
199 "B-V color index".to_string(),
200 "Parallax (arcseconds)".to_string(),
201 "Proper motion RA (mas/year)".to_string(),
202 "Proper motion Dec (mas/year)".to_string(),
203 "Radial velocity (km/s)".to_string(),
204 ]),
205 description: Some(format!(
206 "Synthetic {catalog} stellar catalog with {nstars} _stars"
207 )),
208 metadata: std::collections::HashMap::new(),
209 })
210 }
211
212 fn load_synthetic_exoplanet_data(&self, nplanets: usize) -> Result<Dataset> {
213 use scirs2_core::random::{Distribution, LogNormal, Normal};
214
215 let mut rng = thread_rng();
216
217 let mut data = Vec::with_capacity(nplanets * 6);
219 let mut planet_types = Vec::with_capacity(nplanets);
220
221 let period_dist = LogNormal::new(1.0, 1.5).unwrap();
223 let radius_dist = LogNormal::new(0.0, 0.8).unwrap();
224 let mass_dist = LogNormal::new(1.0, 1.2).unwrap();
225 let stellar_mass_dist = Normal::new(1.0, 0.3).unwrap();
226 let stellar_temp_dist = Normal::new(5800.0, 1000.0).unwrap();
227 let metallicity_dist = Normal::new(0.0, 0.3).unwrap();
228
229 for _ in 0..nplanets {
230 data.push(period_dist.sample(&mut rng));
232 data.push(radius_dist.sample(&mut rng));
234 data.push(mass_dist.sample(&mut rng));
236 data.push((stellar_mass_dist.sample(&mut rng) as f64).max(0.1f64));
238 data.push(stellar_temp_dist.sample(&mut rng));
240 data.push(metallicity_dist.sample(&mut rng));
242
243 let radius = data[data.len() - 5];
245 let planet_type = match radius {
246 r if r < 1.25 => 0, r if r < 2.0 => 1, r if r < 4.0 => 2, r if r < 11.0 => 3, _ => 4, };
252 planet_types.push(planet_type as f64);
253 }
254
255 let data_array = Array2::from_shape_vec((nplanets, 6), data)
256 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
257
258 let target = Array1::from_vec(planet_types);
259
260 Ok(Dataset {
261 data: data_array,
262 target: Some(target),
263 featurenames: Some(vec![
264 "period".to_string(),
265 "radius".to_string(),
266 "mass".to_string(),
267 "stellar_mass".to_string(),
268 "stellar_temp".to_string(),
269 "metallicity".to_string(),
270 ]),
271 targetnames: Some(vec![
272 "Rocky".to_string(),
273 "Super-Earth".to_string(),
274 "Sub-Neptune".to_string(),
275 "Neptune".to_string(),
276 "Jupiter".to_string(),
277 ]),
278 feature_descriptions: Some(vec![
279 "Orbital period (days)".to_string(),
280 "Planet radius (Earth radii)".to_string(),
281 "Planet mass (Earth masses)".to_string(),
282 "Stellar mass (Solar masses)".to_string(),
283 "Stellar temperature (K)".to_string(),
284 "Stellar metallicity [Fe/H]".to_string(),
285 ]),
286 description: Some(format!(
287 "Synthetic exoplanet catalog with {nplanets} _planets"
288 )),
289 metadata: std::collections::HashMap::new(),
290 })
291 }
292
293 fn load_synthetic_supernova_data(&self, nsupernovae: usize) -> Result<Dataset> {
294 use scirs2_core::random::{Distribution, Normal};
295
296 let mut rng = thread_rng();
297
298 let mut data = Vec::with_capacity(nsupernovae * 10);
300 let mut sn_types = Vec::with_capacity(nsupernovae);
301
302 let _type_probs = [0.7, 0.15, 0.10, 0.05]; for _ in 0..nsupernovae {
306 let sn_type = rng.sample(Uniform::new(0, 4).unwrap());
307
308 let (peak_mag, decline_rate, color_evolution, host_mass) = match sn_type {
309 0 => (-19.3, 1.1, 0.2, 10.5), 1 => (-18.5, 1.8, 0.5, 9.8), 2 => (-16.8, 0.8, 0.3, 9.2), _ => (-17.5, 1.2, 0.4, 9.0), };
314
315 let peak_noise = Normal::new(0.0, 0.3).unwrap();
317 let decline_noise = Normal::new(0.0, 0.2).unwrap();
318 let color_noise = Normal::new(0.0, 0.1).unwrap();
319 let host_noise = Normal::new(0.0, 0.5).unwrap();
320
321 data.push(peak_mag + peak_noise.sample(&mut rng));
323 data.push(decline_rate + decline_noise.sample(&mut rng));
325 data.push(color_evolution + color_noise.sample(&mut rng));
327 data.push(host_mass + host_noise.sample(&mut rng));
329 data.push(rng.gen_range(0.01..0.3));
331 data.push(rng.gen_range(20.0..200.0));
333 data.push(rng.gen_range(0.7..1.3));
335 data.push(rng.gen_range(0.0..0.5));
337 data.push(rng.gen_range(15.0..22.0));
339 data.push(rng.gen_range(-90.0..90.0));
341
342 sn_types.push(sn_type as f64);
343 }
344
345 let data_array = Array2::from_shape_vec((nsupernovae, 10), data)
346 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
347
348 let target = Array1::from_vec(sn_types);
349
350 Ok(Dataset {
351 data: data_array,
352 target: Some(target),
353 featurenames: Some(vec![
354 "peak_magnitude".to_string(),
355 "decline_rate".to_string(),
356 "color_max".to_string(),
357 "host_mass".to_string(),
358 "redshift".to_string(),
359 "duration".to_string(),
360 "stretch".to_string(),
361 "color_excess".to_string(),
362 "discovery_mag".to_string(),
363 "galactic_lat".to_string(),
364 ]),
365 targetnames: Some(vec![
366 "Type Ia".to_string(),
367 "Type Ib/c".to_string(),
368 "Type II-P".to_string(),
369 "Type II-L".to_string(),
370 ]),
371 feature_descriptions: Some(vec![
372 "Peak apparent magnitude".to_string(),
373 "Magnitude decline rate (mag/day)".to_string(),
374 "Maximum color index".to_string(),
375 "Host galaxy stellar mass (log10 M_sun)".to_string(),
376 "Cosmological redshift".to_string(),
377 "Light curve duration (days)".to_string(),
378 "Light curve stretch factor".to_string(),
379 "Host galaxy color excess E(B-V)".to_string(),
380 "Discovery magnitude".to_string(),
381 "Galactic latitude (degrees)".to_string(),
382 ]),
383 description: Some(format!(
384 "Synthetic supernova catalog with {nsupernovae} events"
385 )),
386 metadata: std::collections::HashMap::new(),
387 })
388 }
389 }
390}
391
392pub mod genomics {
394 use super::*;
395
    /// Loader for genomics datasets (currently all data is synthesized
    /// locally; no network access is performed).
    pub struct GenomicsDatasets {
        // HTTP client for remote services; currently unused (hence the
        // allow) since all loaders generate synthetic data.
        #[allow(dead_code)]
        client: ExternalClient,
        // On-disk dataset cache; currently unused for the same reason.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
403
404 impl GenomicsDatasets {
405 pub fn new() -> Result<Self> {
407 let cachedir = dirs::cache_dir()
408 .ok_or_else(|| {
409 DatasetsError::Other("Could not determine cache directory".to_string())
410 })?
411 .join("scirs2-datasets");
412 Ok(Self {
413 client: ExternalClient::new()?,
414 cache: DatasetCache::new(cachedir),
415 })
416 }
417
418 pub fn load_gene_expression(&self, n_samples: usize, ngenes: usize) -> Result<Dataset> {
420 use scirs2_core::random::{Distribution, LogNormal, Normal};
421
422 let mut rng = thread_rng();
423
424 let mut data = Vec::with_capacity(n_samples * ngenes);
426 let mut phenotypes = Vec::with_capacity(n_samples);
427
428 let condition_effects = [1.0, 2.5, 0.4, 1.8, 0.7]; for sample_idx in 0..n_samples {
432 let condition = sample_idx % condition_effects.len();
433 let base_effect = condition_effects[condition];
434
435 for gene_idx in 0..ngenes {
436 let base_expr = LogNormal::new(5.0, 2.0).unwrap().sample(&mut rng);
438
439 let gene_effect = if gene_idx < ngenes / 10 {
441 base_effect
443 } else {
444 1.0
445 };
446
447 let noise = Normal::new(1.0, 0.2).unwrap().sample(&mut rng);
449
450 let expression: f64 = base_expr * gene_effect * noise;
451 data.push(expression.ln()); }
453
454 phenotypes.push(condition as f64);
455 }
456
457 let data_array = Array2::from_shape_vec((n_samples, ngenes), data)
458 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
459
460 let target = Array1::from_vec(phenotypes);
461
462 let featurenames: Vec<String> = (0..ngenes).map(|i| format!("GENE_{i:06}")).collect();
464
465 Ok(Dataset {
466 data: data_array,
467 target: Some(target),
468 featurenames: Some(featurenames.clone()),
469 targetnames: Some(vec![
470 "Control".to_string(),
471 "Treatment_A".to_string(),
472 "Treatment_B".to_string(),
473 "Disease_X".to_string(),
474 "Disease_Y".to_string(),
475 ]),
476 feature_descriptions: Some(
477 featurenames
478 .iter()
479 .map(|name| format!("Expression level of {name}"))
480 .collect(),
481 ),
482 description: Some(format!(
483 "Synthetic gene expression data: {n_samples} _samples × {ngenes} _genes"
484 )),
485 metadata: std::collections::HashMap::new(),
486 })
487 }
488
489 pub fn load_dnasequences(
491 &self,
492 nsequences: usize,
493 sequence_length: usize,
494 ) -> Result<Dataset> {
495 let mut rng = thread_rng();
496 let nucleotides = ['A', 'T', 'G', 'C'];
497
498 let mut sequences = Vec::new();
499 let mut sequence_types = Vec::with_capacity(nsequences);
500
501 for seq_idx in 0..nsequences {
502 let mut sequence = String::with_capacity(sequence_length);
503
504 let seq_type = seq_idx % 3; for _pos in 0..sequence_length {
508 let nucleotide = match seq_type {
509 0 => {
510 if rng.random::<f64>() < 0.6 {
512 if rng.random::<f64>() < 0.5 {
513 'G'
514 } else {
515 'C'
516 }
517 } else if rng.random::<f64>() < 0.5 {
518 'A'
519 } else {
520 'T'
521 }
522 }
523 1 => {
524 if rng.random::<f64>() < 0.6 {
526 if rng.random::<f64>() < 0.5 {
527 'A'
528 } else {
529 'T'
530 }
531 } else if rng.random::<f64>() < 0.5 {
532 'G'
533 } else {
534 'C'
535 }
536 }
537 _ => {
538 nucleotides[rng.sample(Uniform::new(0, 4).unwrap())]
540 }
541 };
542
543 sequence.push(nucleotide);
544 }
545
546 sequences.push(sequence);
547 sequence_types.push(seq_type as f64);
548 }
549
550 let mut data = Vec::new();
552 let k = 3;
553 let kmers = Self::generate_kmers(k);
554
555 for sequence in &sequences {
556 let kmer_counts = Self::count_kmers(sequence, k, &kmers);
557 data.extend(kmer_counts);
558 }
559
560 let n_features = 4_usize.pow(k as u32); let data_array = Array2::from_shape_vec((nsequences, n_features), data)
562 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
563
564 let target = Array1::from_vec(sequence_types);
565
566 Ok(Dataset {
567 data: data_array,
568 target: Some(target),
569 featurenames: Some(kmers.clone()),
570 targetnames: Some(vec![
571 "GC-rich".to_string(),
572 "AT-rich".to_string(),
573 "Random".to_string(),
574 ]),
575 feature_descriptions: Some(
576 kmers
577 .iter()
578 .map(|kmer| format!("Frequency of {k}-mer: {kmer}"))
579 .collect(),
580 ),
581 description: Some(format!(
582 "DNA sequences: {nsequences} seqs × {k}-mer features"
583 )),
584 metadata: std::collections::HashMap::new(),
585 })
586 }
587
588 fn generate_kmers(k: usize) -> Vec<String> {
589 let nucleotides = vec!['A', 'T', 'G', 'C'];
590 let mut kmers = Vec::new();
591
592 fn generate_recursive(
593 current: String,
594 remaining: usize,
595 nucleotides: &[char],
596 kmers: &mut Vec<String>,
597 ) {
598 if remaining == 0 {
599 kmers.push(current);
600 return;
601 }
602
603 for &nucleotide in nucleotides {
604 let mut new_current = current.clone();
605 new_current.push(nucleotide);
606 generate_recursive(new_current, remaining - 1, nucleotides, kmers);
607 }
608 }
609
610 generate_recursive(String::new(), k, &nucleotides, &mut kmers);
611 kmers
612 }
613
614 fn count_kmers(sequence: &str, k: usize, kmers: &[String]) -> Vec<f64> {
615 let mut counts = vec![0.0; kmers.len()];
616 let kmer_to_idx: HashMap<&str, usize> = kmers
617 .iter()
618 .enumerate()
619 .map(|(i, k)| (k.as_str(), i))
620 .collect();
621
622 for i in 0..=sequence.len().saturating_sub(k) {
623 let kmer = &sequence[i..i + k];
624 if let Some(&idx) = kmer_to_idx.get(kmer) {
625 counts[idx] += 1.0;
626 }
627 }
628
629 let total: f64 = counts.iter().sum();
631 if total > 0.0 {
632 for count in &mut counts {
633 *count /= total;
634 }
635 }
636
637 counts
638 }
639 }
640}
641
642pub mod climate {
644 use super::*;
645
    /// Loader for climate datasets (currently all data is synthesized
    /// locally; no network access is performed).
    pub struct ClimateDatasets {
        // HTTP client for remote services; currently unused (hence the
        // allow) since all loaders generate synthetic data.
        #[allow(dead_code)]
        client: ExternalClient,
        // On-disk dataset cache; currently unused for the same reason.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
653
654 impl ClimateDatasets {
655 pub fn new() -> Result<Self> {
657 let cachedir = dirs::cache_dir()
658 .ok_or_else(|| {
659 DatasetsError::Other("Could not determine cache directory".to_string())
660 })?
661 .join("scirs2-datasets");
662 Ok(Self {
663 client: ExternalClient::new()?,
664 cache: DatasetCache::new(cachedir),
665 })
666 }
667
668 pub fn load_temperature_timeseries(
670 &self,
671 n_stations: usize,
672 n_years: usize,
673 ) -> Result<Dataset> {
674 use scirs2_core::random::{Distribution, Normal};
675
676 let mut rng = thread_rng();
677 let days_per_year = 365;
678 let total_days = n_years * days_per_year;
679
680 let mut data = Vec::with_capacity(n_stations * 8); let mut climate_zones = Vec::with_capacity(n_stations);
682
683 for station_idx in 0..n_stations {
684 let zone = station_idx % 5; climate_zones.push(zone as f64);
687
688 let (base_temp, temp_amplitude, annual_precip, humidity) = match zone {
690 0 => (25.0, 5.0, 2000.0, 80.0), 1 => (15.0, 15.0, 800.0, 60.0), 2 => (-5.0, 20.0, 400.0, 70.0), 3 => (5.0, 8.0, 200.0, 40.0), _ => (-10.0, 25.0, 300.0, 75.0), };
696
697 let mut temperatures = Vec::with_capacity(total_days);
699 let mut precipitation = Vec::with_capacity(total_days);
700
701 for day in 0..total_days {
702 let year_progress = (day % days_per_year) as f64 / days_per_year as f64;
703 let seasonal_temp = base_temp
704 + temp_amplitude * (year_progress * 2.0 * std::f64::consts::PI).cos();
705
706 let temp_noise = Normal::new(0.0, 2.0).unwrap();
708 let temp = seasonal_temp + temp_noise.sample(&mut rng);
709 temperatures.push(temp);
710
711 let seasonal_precip_factor = match zone {
713 0 => {
714 1.0 + 0.3
715 * (year_progress * 2.0 * std::f64::consts::PI
716 + std::f64::consts::PI)
717 .cos()
718 }
719 1 => 1.0 + 0.2 * (year_progress * 2.0 * std::f64::consts::PI).sin(),
720 _ => 1.0,
721 };
722
723 let precip = if rng.random::<f64>() < 0.3 {
724 rng.gen_range(0.0..20.0) * seasonal_precip_factor
726 } else {
727 0.0
728 };
729 precipitation.push(precip);
730 }
731
732 let mean_temp = temperatures.iter().sum::<f64>() / temperatures.len() as f64;
734 let max_temp = temperatures
735 .iter()
736 .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
737 let min_temp = temperatures.iter().fold(f64::INFINITY, |a, &b| a.min(b));
738 let temp_range = max_temp - min_temp;
739
740 let total_precip = precipitation.iter().sum::<f64>();
741 let precip_days = precipitation.iter().filter(|&&p| p > 0.0).count() as f64;
742
743 let avg_humidity = humidity + Normal::new(0.0, 5.0).unwrap().sample(&mut rng);
745 let wind_speed = rng.gen_range(2.0..15.0);
746
747 data.extend(vec![
748 mean_temp,
749 temp_range,
750 total_precip,
751 precip_days,
752 avg_humidity,
753 wind_speed,
754 base_temp, annual_precip / 365.0, ]);
757 }
758
759 let data_array = Array2::from_shape_vec((n_stations, 8), data)
760 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
761
762 let target = Array1::from_vec(climate_zones);
763
764 Ok(Dataset {
765 data: data_array,
766 target: Some(target),
767 featurenames: Some(vec![
768 "mean_temperature".to_string(),
769 "temperature_range".to_string(),
770 "annual_precipitation".to_string(),
771 "precipitation_days".to_string(),
772 "avg_humidity".to_string(),
773 "avg_wind_speed".to_string(),
774 "latitude_proxy".to_string(),
775 "daily_precip_avg".to_string(),
776 ]),
777 targetnames: Some(vec![
778 "Tropical".to_string(),
779 "Temperate".to_string(),
780 "Continental".to_string(),
781 "Desert".to_string(),
782 "Arctic".to_string(),
783 ]),
784 feature_descriptions: Some(vec![
785 "Mean annual temperature (°C)".to_string(),
786 "Temperature range (max-min, °C)".to_string(),
787 "Total annual precipitation (mm)".to_string(),
788 "Number of precipitation days per year".to_string(),
789 "Average humidity (%)".to_string(),
790 "Average wind speed (m/s)".to_string(),
791 "Latitude proxy (normalized)".to_string(),
792 "Average daily precipitation (mm/day)".to_string(),
793 ]),
794 description: Some(format!(
795 "Climate data: {n_stations} _stations × {n_years} _years"
796 )),
797 metadata: std::collections::HashMap::new(),
798 })
799 }
800
801 pub fn load_atmospheric_chemistry(&self, nmeasurements: usize) -> Result<Dataset> {
803 use scirs2_core::random::{Distribution, LogNormal, Normal};
804
805 let mut rng = thread_rng();
806
807 let mut data = Vec::with_capacity(nmeasurements * 12);
808 let mut air_quality_index = Vec::with_capacity(nmeasurements);
809
810 for _ in 0..nmeasurements {
811 let base_pollution = rng.gen_range(0.0..1.0);
813
814 let pm25: f64 = LogNormal::new(2.0 + base_pollution, 0.5)
816 .unwrap()
817 .sample(&mut rng);
818 let pm10 = pm25 * rng.gen_range(1.5..2.5);
819 let no2 = LogNormal::new(3.0 + base_pollution * 0.5, 0.3)
820 .unwrap()
821 .sample(&mut rng);
822 let so2 = LogNormal::new(1.0 + base_pollution * 0.3, 0.4)
823 .unwrap()
824 .sample(&mut rng);
825 let o3 = LogNormal::new(4.0 - base_pollution * 0.2, 0.2)
826 .unwrap()
827 .sample(&mut rng);
828 let co = LogNormal::new(0.5 + base_pollution * 0.4, 0.3)
829 .unwrap()
830 .sample(&mut rng);
831
832 let temperature = Normal::new(20.0, 10.0).unwrap().sample(&mut rng);
834 let humidity = rng.gen_range(30.0..90.0);
835 let wind_speed = rng.gen_range(0.5..12.0);
836 let pressure = Normal::new(1013.0, 15.0).unwrap().sample(&mut rng);
837
838 let visibility = (50.0 - pm25.ln() * 5.0).max(1.0);
840 let uv_index = rng.gen_range(0.0..12.0);
841
842 data.extend(vec![
843 pm25,
844 pm10,
845 no2,
846 so2,
847 o3,
848 co,
849 temperature,
850 humidity,
851 wind_speed,
852 pressure,
853 visibility,
854 uv_index,
855 ]);
856
857 let aqi = Self::calculate_aqi(pm25, pm10, no2, so2, o3, co);
859 air_quality_index.push(aqi);
860 }
861
862 let data_array = Array2::from_shape_vec((nmeasurements, 12), data)
863 .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
864
865 let target = Array1::from_vec(air_quality_index);
866
867 Ok(Dataset {
868 data: data_array,
869 target: Some(target),
870 featurenames: Some(vec![
871 "pm2_5".to_string(),
872 "pm10".to_string(),
873 "no2".to_string(),
874 "so2".to_string(),
875 "o3".to_string(),
876 "co".to_string(),
877 "temperature".to_string(),
878 "humidity".to_string(),
879 "wind_speed".to_string(),
880 "pressure".to_string(),
881 "visibility".to_string(),
882 "uv_index".to_string(),
883 ]),
884 targetnames: None,
885 feature_descriptions: Some(vec![
886 "PM2.5 concentration (µg/m³)".to_string(),
887 "PM10 concentration (µg/m³)".to_string(),
888 "NO2 concentration (µg/m³)".to_string(),
889 "SO2 concentration (µg/m³)".to_string(),
890 "O3 concentration (µg/m³)".to_string(),
891 "CO concentration (µg/m³)".to_string(),
892 "Temperature (°C)".to_string(),
893 "Relative humidity (%)".to_string(),
894 "Wind speed (m/s)".to_string(),
895 "Atmospheric pressure (hPa)".to_string(),
896 "Visibility (km)".to_string(),
897 "UV index".to_string(),
898 ]),
899 description: Some(format!(
900 "Atmospheric chemistry _measurements: {nmeasurements} samples"
901 )),
902 metadata: std::collections::HashMap::new(),
903 })
904 }
905
906 #[allow(clippy::too_many_arguments)]
907 fn calculate_aqi(pm25: f64, pm10: f64, no2: f64, so2: f64, o3: f64, co: f64) -> f64 {
908 let pm25_aqi = (pm25 / 35.0 * 100.0).min(300.0);
910 let pm10_aqi = (pm10 / 150.0 * 100.0).min(300.0);
911 let no2_aqi = (no2 / 100.0 * 100.0).min(300.0);
912 let so2_aqi = (so2 / 75.0 * 100.0).min(300.0);
913 let o3_aqi = (o3 / 120.0 * 100.0).min(300.0);
914 let co_aqi = (co / 9.0 * 100.0).min(300.0);
915
916 [pm25_aqi, pm10_aqi, no2_aqi, so2_aqi, o3_aqi, co_aqi]
918 .iter()
919 .fold(0.0f64, |a, &b| a.max(b))
920 }
921 }
922}
923
924pub mod convenience {
926 use super::astronomy::StellarDatasets;
927 use super::climate::ClimateDatasets;
928 use super::genomics::GenomicsDatasets;
929 use super::*;
930
    /// Convenience loader: synthetic Hipparcos-sized stellar catalog with
    /// spectral-class targets.
    pub fn load_stellar_classification() -> Result<Dataset> {
        let datasets = StellarDatasets::new()?;
        datasets.load_hipparcos_catalog()
    }
936
    /// Convenience loader: synthetic exoplanet catalog with size-class
    /// targets.
    pub fn load_exoplanets() -> Result<Dataset> {
        let datasets = StellarDatasets::new()?;
        datasets.load_exoplanet_catalog()
    }
942
    /// Convenience loader: synthetic gene-expression data.
    ///
    /// Defaults to 200 samples × 1000 genes when `None` is given.
    pub fn load_gene_expression(
        n_samples: Option<usize>,
        ngenes: Option<usize>,
    ) -> Result<Dataset> {
        let datasets = GenomicsDatasets::new()?;
        datasets.load_gene_expression(n_samples.unwrap_or(200), ngenes.unwrap_or(1000))
    }
951
952 pub fn load_climate_data(
954 _n_stations: Option<usize>,
955 n_years: Option<usize>,
956 ) -> Result<Dataset> {
957 let datasets = ClimateDatasets::new()?;
958 datasets.load_temperature_timeseries(_n_stations.unwrap_or(100), n_years.unwrap_or(10))
959 }
960
961 pub fn load_atmospheric_chemistry(_nmeasurements: Option<usize>) -> Result<Dataset> {
963 let datasets = ClimateDatasets::new()?;
964 datasets.load_atmospheric_chemistry(_nmeasurements.unwrap_or(1000))
965 }
966
967 pub fn list_domain_datasets() -> Vec<(&'static str, &'static str)> {
969 vec![
970 ("astronomy", "stellar_classification"),
971 ("astronomy", "exoplanets"),
972 ("astronomy", "supernovae"),
973 ("astronomy", "gaia_dr3"),
974 ("genomics", "gene_expression"),
975 ("genomics", "dnasequences"),
976 ("climate", "temperature_timeseries"),
977 ("climate", "atmospheric_chemistry"),
978 ]
979 }
980}
981
#[cfg(test)]
mod tests {
    // Removed `use scirs2_core::random::Uniform;` — no test below uses it,
    // and the unused import triggers a warning under `cfg(test)`.
    use super::convenience::*;

    #[test]
    fn test_load_stellar_classification() {
        let dataset = load_stellar_classification().unwrap();
        assert!(dataset.n_samples() > 1000);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_exoplanets() {
        let dataset = load_exoplanets().unwrap();
        assert!(dataset.n_samples() > 100);
        assert_eq!(dataset.n_features(), 6);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_gene_expression() {
        let dataset = load_gene_expression(Some(50), Some(100)).unwrap();
        assert_eq!(dataset.n_samples(), 50);
        assert_eq!(dataset.n_features(), 100);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_climate_data() {
        let dataset = load_climate_data(Some(20), Some(5)).unwrap();
        assert_eq!(dataset.n_samples(), 20);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_atmospheric_chemistry() {
        let dataset = load_atmospheric_chemistry(Some(100)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 12);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_list_domain_datasets() {
        let datasets = list_domain_datasets();
        assert!(!datasets.is_empty());
        assert!(datasets.iter().any(|(domain_, _)| *domain_ == "astronomy"));
        assert!(datasets.iter().any(|(domain_, _)| *domain_ == "genomics"));
        assert!(datasets.iter().any(|(domain_, _)| *domain_ == "climate"));
    }
}
1035}