scirs2_datasets/
domain_specific.rs

1//! Domain-specific datasets for scientific research
2//!
3//! This module provides specialized datasets for various scientific domains:
4//! - Astronomy and astrophysics
5//! - Genomics and bioinformatics
6//! - Climate science and meteorology
7//! - Materials science
8//! - Finance and economics
9//! - Computer vision and image processing
10//! - Natural language processing
11
12use std::collections::HashMap;
13
14use scirs2_core::ndarray::{Array1, Array2};
15use scirs2_core::random::prelude::*;
16use scirs2_core::random::{Distribution, Uniform};
17use serde::{Deserialize, Serialize};
18
19use crate::cache::DatasetCache;
20use crate::error::{DatasetsError, Result};
21use crate::external::ExternalClient;
22use crate::utils::Dataset;
23
24/// Configuration for domain-specific dataset loading
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct DomainConfig {
27    /// Base URL for dataset repository
28    pub base_url: Option<String>,
29    /// API key for authenticated access
30    pub api_key: Option<String>,
31    /// Data format preferences
32    pub preferred_formats: Vec<String>,
33    /// Quality filters
34    pub quality_filters: QualityFilters,
35}
36
37impl Default for DomainConfig {
38    fn default() -> Self {
39        Self {
40            base_url: None,
41            api_key: None,
42            preferred_formats: vec!["csv".to_string(), "fits".to_string(), "hdf5".to_string()],
43            quality_filters: QualityFilters::default(),
44        }
45    }
46}
47
/// Quality filters for dataset selection.
///
/// Every criterion is optional. Presumably `None` means "do not filter on
/// this criterion" — TODO confirm against the code that applies these filters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityFilters {
    /// Minimum number of samples
    pub min_samples: Option<usize>,
    /// Maximum missing data percentage.
    ///
    /// NOTE(review): the default value is `0.1`, which reads like a fraction
    /// (10%) rather than a percentage (0.1%); confirm the intended scale.
    pub max_missing_percent: Option<f64>,
    /// Required data completeness (the default `0.9` suggests a fraction in `[0, 1]`)
    pub min_completeness: Option<f64>,
    /// Minimum publication year
    pub min_year: Option<u32>,
}
60
61impl Default for QualityFilters {
62    fn default() -> Self {
63        Self {
64            min_samples: Some(100),
65            max_missing_percent: Some(0.1),
66            min_completeness: Some(0.9),
67            min_year: Some(2000),
68        }
69    }
70}
71
72/// Astronomy and astrophysics datasets
73pub mod astronomy {
74    use super::*;
75
76    /// Stellar classification and properties
77    pub struct StellarDatasets {
78        #[allow(dead_code)]
79        client: ExternalClient,
80        #[allow(dead_code)]
81        cache: DatasetCache,
82    }
83
84    impl StellarDatasets {
85        /// Create a new stellar datasets client
86        pub fn new() -> Result<Self> {
87            let cachedir = dirs::cache_dir()
88                .ok_or_else(|| {
89                    DatasetsError::Other("Could not determine cache directory".to_string())
90                })?
91                .join("scirs2-datasets");
92            Ok(Self {
93                client: ExternalClient::new()?,
94                cache: DatasetCache::new(cachedir),
95            })
96        }
97
98        /// Load Hipparcos stellar catalog data
99        pub fn load_hipparcos_catalog(&self) -> Result<Dataset> {
100            self.load_synthetic_stellar_data("hipparcos", 118218)
101        }
102
103        /// Load Gaia DR3 stellar data (synthetic for demonstration)
104        pub fn load_gaia_dr3_sample(&self) -> Result<Dataset> {
105            self.load_synthetic_stellar_data("gaia_dr3", 50000)
106        }
107
108        /// Load exoplanet catalog
109        pub fn load_exoplanet_catalog(&self) -> Result<Dataset> {
110            self.load_synthetic_exoplanet_data(5000)
111        }
112
113        /// Load supernova photometry data
114        pub fn load_supernova_photometry(&self) -> Result<Dataset> {
115            self.load_synthetic_supernova_data(1000)
116        }
117
118        fn load_synthetic_stellar_data(&self, catalog: &str, nstars: usize) -> Result<Dataset> {
119            use scirs2_core::random::{Distribution, Normal};
120
121            let mut rng = thread_rng();
122
123            // Generate synthetic stellar parameters
124            let mut data = Vec::with_capacity(nstars * 8);
125            let mut spectral_classes = Vec::with_capacity(nstars);
126
127            // Distributions for stellar parameters
128            let ra_dist = scirs2_core::random::Uniform::new(0.0, 360.0).unwrap();
129            let dec_dist = scirs2_core::random::Uniform::new(-90.0, 90.0).unwrap();
130            let magnitude_dist = Normal::new(8.0, 3.0).unwrap();
131            let color_dist = Normal::new(0.5, 0.3).unwrap();
132            let parallax_dist = Normal::new(10.0, 5.0).unwrap();
133            let proper_motion_dist = Normal::new(0.0, 50.0).unwrap();
134            let radial_velocity_dist = Normal::new(0.0, 30.0).unwrap();
135
136            for _ in 0..nstars {
137                // Right ascension (degrees)
138                data.push(ra_dist.sample(&mut rng));
139                // Declination (degrees)
140                data.push(dec_dist.sample(&mut rng));
141                // Apparent magnitude
142                data.push(magnitude_dist.sample(&mut rng));
143                // Color index (B-V)
144                data.push(color_dist.sample(&mut rng));
145                // Parallax (mas)
146                data.push((parallax_dist.sample(&mut rng) as f64).max(0.1f64));
147                // Proper motion RA (mas/yr)
148                data.push(proper_motion_dist.sample(&mut rng));
149                // Proper motion Dec (mas/yr)
150                data.push(proper_motion_dist.sample(&mut rng));
151                // Radial velocity (km/s)
152                data.push(radial_velocity_dist.sample(&mut rng));
153
154                // Assign spectral class based on color
155                let color = data[data.len() - 5];
156                let spectral_class = match color {
157                    c if c < -0.3 => 0, // O
158                    c if c < -0.1 => 1, // B
159                    c if c < 0.2 => 2,  // A
160                    c if c < 0.5 => 3,  // F
161                    c if c < 0.8 => 4,  // G
162                    c if c < 1.2 => 5,  // K
163                    _ => 6,             // M
164                };
165                spectral_classes.push(spectral_class as f64);
166            }
167
168            let data_array = Array2::from_shape_vec((nstars, 8), data)
169                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
170
171            let target = Array1::from_vec(spectral_classes);
172
173            Ok(Dataset {
174                data: data_array,
175                target: Some(target),
176                featurenames: Some(vec![
177                    "ra".to_string(),
178                    "dec".to_string(),
179                    "magnitude".to_string(),
180                    "color_bv".to_string(),
181                    "parallax".to_string(),
182                    "pm_ra".to_string(),
183                    "pm_dec".to_string(),
184                    "radial_velocity".to_string(),
185                ]),
186                targetnames: Some(vec![
187                    "O".to_string(),
188                    "B".to_string(),
189                    "A".to_string(),
190                    "F".to_string(),
191                    "G".to_string(),
192                    "K".to_string(),
193                    "M".to_string(),
194                ]),
195                feature_descriptions: Some(vec![
196                    "Right Ascension (degrees)".to_string(),
197                    "Declination (degrees)".to_string(),
198                    "Apparent magnitude (visual)".to_string(),
199                    "B-V color index".to_string(),
200                    "Parallax (arcseconds)".to_string(),
201                    "Proper motion RA (mas/year)".to_string(),
202                    "Proper motion Dec (mas/year)".to_string(),
203                    "Radial velocity (km/s)".to_string(),
204                ]),
205                description: Some(format!(
206                    "Synthetic {catalog} stellar catalog with {nstars} _stars"
207                )),
208                metadata: std::collections::HashMap::new(),
209            })
210        }
211
212        fn load_synthetic_exoplanet_data(&self, nplanets: usize) -> Result<Dataset> {
213            use scirs2_core::random::{Distribution, LogNormal, Normal};
214
215            let mut rng = thread_rng();
216
217            // Generate synthetic exoplanet parameters
218            let mut data = Vec::with_capacity(nplanets * 6);
219            let mut planet_types = Vec::with_capacity(nplanets);
220
221            // Distributions for planetary parameters
222            let period_dist = LogNormal::new(1.0, 1.5).unwrap();
223            let radius_dist = LogNormal::new(0.0, 0.8).unwrap();
224            let mass_dist = LogNormal::new(1.0, 1.2).unwrap();
225            let stellar_mass_dist = Normal::new(1.0, 0.3).unwrap();
226            let stellar_temp_dist = Normal::new(5800.0, 1000.0).unwrap();
227            let metallicity_dist = Normal::new(0.0, 0.3).unwrap();
228
229            for _ in 0..nplanets {
230                // Orbital period (days)
231                data.push(period_dist.sample(&mut rng));
232                // Planet radius (Earth radii)
233                data.push(radius_dist.sample(&mut rng));
234                // Planet mass (Earth masses)
235                data.push(mass_dist.sample(&mut rng));
236                // Stellar mass (Solar masses)
237                data.push((stellar_mass_dist.sample(&mut rng) as f64).max(0.1f64));
238                // Stellar temperature (K)
239                data.push(stellar_temp_dist.sample(&mut rng));
240                // Stellar metallicity [Fe/H]
241                data.push(metallicity_dist.sample(&mut rng));
242
243                // Classify planet type based on radius
244                let radius = data[data.len() - 5];
245                let planet_type = match radius {
246                    r if r < 1.25 => 0, // Rocky
247                    r if r < 2.0 => 1,  // Super-Earth
248                    r if r < 4.0 => 2,  // Sub-Neptune
249                    r if r < 11.0 => 3, // Neptune
250                    _ => 4,             // Jupiter
251                };
252                planet_types.push(planet_type as f64);
253            }
254
255            let data_array = Array2::from_shape_vec((nplanets, 6), data)
256                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
257
258            let target = Array1::from_vec(planet_types);
259
260            Ok(Dataset {
261                data: data_array,
262                target: Some(target),
263                featurenames: Some(vec![
264                    "period".to_string(),
265                    "radius".to_string(),
266                    "mass".to_string(),
267                    "stellar_mass".to_string(),
268                    "stellar_temp".to_string(),
269                    "metallicity".to_string(),
270                ]),
271                targetnames: Some(vec![
272                    "Rocky".to_string(),
273                    "Super-Earth".to_string(),
274                    "Sub-Neptune".to_string(),
275                    "Neptune".to_string(),
276                    "Jupiter".to_string(),
277                ]),
278                feature_descriptions: Some(vec![
279                    "Orbital period (days)".to_string(),
280                    "Planet radius (Earth radii)".to_string(),
281                    "Planet mass (Earth masses)".to_string(),
282                    "Stellar mass (Solar masses)".to_string(),
283                    "Stellar temperature (K)".to_string(),
284                    "Stellar metallicity [Fe/H]".to_string(),
285                ]),
286                description: Some(format!(
287                    "Synthetic exoplanet catalog with {nplanets} _planets"
288                )),
289                metadata: std::collections::HashMap::new(),
290            })
291        }
292
293        fn load_synthetic_supernova_data(&self, nsupernovae: usize) -> Result<Dataset> {
294            use scirs2_core::random::{Distribution, Normal};
295
296            let mut rng = thread_rng();
297
298            // Generate synthetic supernova light curve features
299            let mut data = Vec::with_capacity(nsupernovae * 10);
300            let mut sn_types = Vec::with_capacity(nsupernovae);
301
302            // Different supernova types have different characteristics
303            let _type_probs = [0.7, 0.15, 0.10, 0.05]; // Ia, Ib/c, II-P, II-L
304
305            for _ in 0..nsupernovae {
306                let sn_type = rng.sample(Uniform::new(0, 4).unwrap());
307
308                let (peak_mag, decline_rate, color_evolution, host_mass) = match sn_type {
309                    0 => (-19.3, 1.1, 0.2, 10.5), // Type Ia
310                    1 => (-18.5, 1.8, 0.5, 9.8),  // Type Ib/c
311                    2 => (-16.8, 0.8, 0.3, 9.2),  // Type II-P
312                    _ => (-17.5, 1.2, 0.4, 9.0),  // Type II-L
313                };
314
315                // Add noise to base parameters
316                let peak_noise = Normal::new(0.0, 0.3).unwrap();
317                let decline_noise = Normal::new(0.0, 0.2).unwrap();
318                let color_noise = Normal::new(0.0, 0.1).unwrap();
319                let host_noise = Normal::new(0.0, 0.5).unwrap();
320
321                // Peak absolute magnitude
322                data.push(peak_mag + peak_noise.sample(&mut rng));
323                // Decline rate (mag/15 days)
324                data.push(decline_rate + decline_noise.sample(&mut rng));
325                // Color at maximum
326                data.push(color_evolution + color_noise.sample(&mut rng));
327                // Host galaxy mass (log M_sun)
328                data.push(host_mass + host_noise.sample(&mut rng));
329                // Redshift
330                data.push(rng.gen_range(0.01..0.3));
331                // Duration (days)
332                data.push(rng.gen_range(20.0..200.0));
333                // Stretch factor
334                data.push(rng.gen_range(0.7..1.3));
335                // Color excess E(B-V)
336                data.push(rng.gen_range(0.0..0.5));
337                // Discovery magnitude
338                data.push(rng.gen_range(15.0..22.0));
339                // Galactic latitude
340                data.push(rng.gen_range(-90.0..90.0));
341
342                sn_types.push(sn_type as f64);
343            }
344
345            let data_array = Array2::from_shape_vec((nsupernovae, 10), data)
346                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
347
348            let target = Array1::from_vec(sn_types);
349
350            Ok(Dataset {
351                data: data_array,
352                target: Some(target),
353                featurenames: Some(vec![
354                    "peak_magnitude".to_string(),
355                    "decline_rate".to_string(),
356                    "color_max".to_string(),
357                    "host_mass".to_string(),
358                    "redshift".to_string(),
359                    "duration".to_string(),
360                    "stretch".to_string(),
361                    "color_excess".to_string(),
362                    "discovery_mag".to_string(),
363                    "galactic_lat".to_string(),
364                ]),
365                targetnames: Some(vec![
366                    "Type Ia".to_string(),
367                    "Type Ib/c".to_string(),
368                    "Type II-P".to_string(),
369                    "Type II-L".to_string(),
370                ]),
371                feature_descriptions: Some(vec![
372                    "Peak apparent magnitude".to_string(),
373                    "Magnitude decline rate (mag/day)".to_string(),
374                    "Maximum color index".to_string(),
375                    "Host galaxy stellar mass (log10 M_sun)".to_string(),
376                    "Cosmological redshift".to_string(),
377                    "Light curve duration (days)".to_string(),
378                    "Light curve stretch factor".to_string(),
379                    "Host galaxy color excess E(B-V)".to_string(),
380                    "Discovery magnitude".to_string(),
381                    "Galactic latitude (degrees)".to_string(),
382                ]),
383                description: Some(format!(
384                    "Synthetic supernova catalog with {nsupernovae} events"
385                )),
386                metadata: std::collections::HashMap::new(),
387            })
388        }
389    }
390}
391
392/// Genomics and bioinformatics datasets
393pub mod genomics {
394    use super::*;
395
396    /// Genomic sequence and expression datasets
397    pub struct GenomicsDatasets {
398        #[allow(dead_code)]
399        client: ExternalClient,
400        #[allow(dead_code)]
401        cache: DatasetCache,
402    }
403
404    impl GenomicsDatasets {
405        /// Create a new genomics datasets client
406        pub fn new() -> Result<Self> {
407            let cachedir = dirs::cache_dir()
408                .ok_or_else(|| {
409                    DatasetsError::Other("Could not determine cache directory".to_string())
410                })?
411                .join("scirs2-datasets");
412            Ok(Self {
413                client: ExternalClient::new()?,
414                cache: DatasetCache::new(cachedir),
415            })
416        }
417
418        /// Load synthetic gene expression data
419        pub fn load_gene_expression(&self, n_samples: usize, ngenes: usize) -> Result<Dataset> {
420            use scirs2_core::random::{Distribution, LogNormal, Normal};
421
422            let mut rng = thread_rng();
423
424            // Generate synthetic gene expression matrix
425            let mut data = Vec::with_capacity(n_samples * ngenes);
426            let mut phenotypes = Vec::with_capacity(n_samples);
427
428            // Different expression patterns for different conditions
429            let condition_effects = [1.0, 2.5, 0.4, 1.8, 0.7]; // Log-fold changes
430
431            for sample_idx in 0..n_samples {
432                let condition = sample_idx % condition_effects.len();
433                let base_effect = condition_effects[condition];
434
435                for gene_idx in 0..ngenes {
436                    // Base expression level
437                    let base_expr = LogNormal::new(5.0, 2.0).unwrap().sample(&mut rng);
438
439                    // Condition-specific modulation
440                    let gene_effect = if gene_idx < ngenes / 10 {
441                        // 10% of _genes are differentially expressed
442                        base_effect
443                    } else {
444                        1.0
445                    };
446
447                    // Add noise
448                    let noise = Normal::new(1.0, 0.2).unwrap().sample(&mut rng);
449
450                    let expression: f64 = base_expr * gene_effect * noise;
451                    data.push(expression.ln()); // Log-transform
452                }
453
454                phenotypes.push(condition as f64);
455            }
456
457            let data_array = Array2::from_shape_vec((n_samples, ngenes), data)
458                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
459
460            let target = Array1::from_vec(phenotypes);
461
462            // Generate gene names
463            let featurenames: Vec<String> = (0..ngenes).map(|i| format!("GENE_{i:06}")).collect();
464
465            Ok(Dataset {
466                data: data_array,
467                target: Some(target),
468                featurenames: Some(featurenames.clone()),
469                targetnames: Some(vec![
470                    "Control".to_string(),
471                    "Treatment_A".to_string(),
472                    "Treatment_B".to_string(),
473                    "Disease_X".to_string(),
474                    "Disease_Y".to_string(),
475                ]),
476                feature_descriptions: Some(
477                    featurenames
478                        .iter()
479                        .map(|name| format!("Expression level of {name}"))
480                        .collect(),
481                ),
482                description: Some(format!(
483                    "Synthetic gene expression data: {n_samples} _samples × {ngenes} _genes"
484                )),
485                metadata: std::collections::HashMap::new(),
486            })
487        }
488
489        /// Load synthetic DNA sequence features
490        pub fn load_dnasequences(
491            &self,
492            nsequences: usize,
493            sequence_length: usize,
494        ) -> Result<Dataset> {
495            let mut rng = thread_rng();
496            let nucleotides = ['A', 'T', 'G', 'C'];
497
498            let mut sequences = Vec::new();
499            let mut sequence_types = Vec::with_capacity(nsequences);
500
501            for seq_idx in 0..nsequences {
502                let mut sequence = String::with_capacity(sequence_length);
503
504                // Generate sequence with some patterns
505                let seq_type = seq_idx % 3; // 3 different types
506
507                for _pos in 0..sequence_length {
508                    let nucleotide = match seq_type {
509                        0 => {
510                            // GC-rich sequences
511                            if rng.random::<f64>() < 0.6 {
512                                if rng.random::<f64>() < 0.5 {
513                                    'G'
514                                } else {
515                                    'C'
516                                }
517                            } else if rng.random::<f64>() < 0.5 {
518                                'A'
519                            } else {
520                                'T'
521                            }
522                        }
523                        1 => {
524                            // AT-rich sequences
525                            if rng.random::<f64>() < 0.6 {
526                                if rng.random::<f64>() < 0.5 {
527                                    'A'
528                                } else {
529                                    'T'
530                                }
531                            } else if rng.random::<f64>() < 0.5 {
532                                'G'
533                            } else {
534                                'C'
535                            }
536                        }
537                        _ => {
538                            // Random sequences
539                            nucleotides[rng.sample(Uniform::new(0, 4).unwrap())]
540                        }
541                    };
542
543                    sequence.push(nucleotide);
544                }
545
546                sequences.push(sequence);
547                sequence_types.push(seq_type as f64);
548            }
549
550            // Convert sequences to k-mer features (k=3)
551            let mut data = Vec::new();
552            let k = 3;
553            let kmers = Self::generate_kmers(k);
554
555            for sequence in &sequences {
556                let kmer_counts = Self::count_kmers(sequence, k, &kmers);
557                data.extend(kmer_counts);
558            }
559
560            let n_features = 4_usize.pow(k as u32); // 4^k possible k-mers
561            let data_array = Array2::from_shape_vec((nsequences, n_features), data)
562                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
563
564            let target = Array1::from_vec(sequence_types);
565
566            Ok(Dataset {
567                data: data_array,
568                target: Some(target),
569                featurenames: Some(kmers.clone()),
570                targetnames: Some(vec![
571                    "GC-rich".to_string(),
572                    "AT-rich".to_string(),
573                    "Random".to_string(),
574                ]),
575                feature_descriptions: Some(
576                    kmers
577                        .iter()
578                        .map(|kmer| format!("Frequency of {k}-mer: {kmer}"))
579                        .collect(),
580                ),
581                description: Some(format!(
582                    "DNA sequences: {nsequences} seqs × {k}-mer features"
583                )),
584                metadata: std::collections::HashMap::new(),
585            })
586        }
587
588        fn generate_kmers(k: usize) -> Vec<String> {
589            let nucleotides = vec!['A', 'T', 'G', 'C'];
590            let mut kmers = Vec::new();
591
592            fn generate_recursive(
593                current: String,
594                remaining: usize,
595                nucleotides: &[char],
596                kmers: &mut Vec<String>,
597            ) {
598                if remaining == 0 {
599                    kmers.push(current);
600                    return;
601                }
602
603                for &nucleotide in nucleotides {
604                    let mut new_current = current.clone();
605                    new_current.push(nucleotide);
606                    generate_recursive(new_current, remaining - 1, nucleotides, kmers);
607                }
608            }
609
610            generate_recursive(String::new(), k, &nucleotides, &mut kmers);
611            kmers
612        }
613
614        fn count_kmers(sequence: &str, k: usize, kmers: &[String]) -> Vec<f64> {
615            let mut counts = vec![0.0; kmers.len()];
616            let kmer_to_idx: HashMap<&str, usize> = kmers
617                .iter()
618                .enumerate()
619                .map(|(i, k)| (k.as_str(), i))
620                .collect();
621
622            for i in 0..=sequence.len().saturating_sub(k) {
623                let kmer = &sequence[i..i + k];
624                if let Some(&idx) = kmer_to_idx.get(kmer) {
625                    counts[idx] += 1.0;
626                }
627            }
628
629            // Normalize by sequence length
630            let total: f64 = counts.iter().sum();
631            if total > 0.0 {
632                for count in &mut counts {
633                    *count /= total;
634                }
635            }
636
637            counts
638        }
639    }
640}
641
642/// Climate science and meteorology datasets
643pub mod climate {
644    use super::*;
645
    /// Climate and weather datasets.
    ///
    /// Holds an external client and a local dataset cache, both built by `new`.
    pub struct ClimateDatasets {
        // Marked `allow(dead_code)`, so currently unused; presumably reserved
        // for remote downloads — TODO confirm.
        #[allow(dead_code)]
        client: ExternalClient,
        // Marked `allow(dead_code)`, so currently unused; rooted at
        // `<platform cache dir>/scirs2-datasets` by `new`.
        #[allow(dead_code)]
        cache: DatasetCache,
    }
653
654    impl ClimateDatasets {
655        /// Create a new climate datasets client
656        pub fn new() -> Result<Self> {
657            let cachedir = dirs::cache_dir()
658                .ok_or_else(|| {
659                    DatasetsError::Other("Could not determine cache directory".to_string())
660                })?
661                .join("scirs2-datasets");
662            Ok(Self {
663                client: ExternalClient::new()?,
664                cache: DatasetCache::new(cachedir),
665            })
666        }
667
668        /// Load synthetic temperature time series data
669        pub fn load_temperature_timeseries(
670            &self,
671            n_stations: usize,
672            n_years: usize,
673        ) -> Result<Dataset> {
674            use scirs2_core::random::{Distribution, Normal};
675
676            let mut rng = thread_rng();
677            let days_per_year = 365;
678            let total_days = n_years * days_per_year;
679
680            let mut data = Vec::with_capacity(n_stations * 8); // 8 climate features per station
681            let mut climate_zones = Vec::with_capacity(n_stations);
682
683            for station_idx in 0..n_stations {
684                // Assign climate zone
685                let zone = station_idx % 5; // 5 climate zones
686                climate_zones.push(zone as f64);
687
688                // Base parameters for different climate zones
689                let (base_temp, temp_amplitude, annual_precip, humidity) = match zone {
690                    0 => (25.0, 5.0, 2000.0, 80.0),  // Tropical
691                    1 => (15.0, 15.0, 800.0, 60.0),  // Temperate
692                    2 => (-5.0, 20.0, 400.0, 70.0),  // Continental
693                    3 => (5.0, 8.0, 200.0, 40.0),    // Desert
694                    _ => (-10.0, 25.0, 300.0, 75.0), // Arctic
695                };
696
697                // Simulate temperature time series and derive statistics
698                let mut temperatures = Vec::with_capacity(total_days);
699                let mut precipitation = Vec::with_capacity(total_days);
700
701                for day in 0..total_days {
702                    let year_progress = (day % days_per_year) as f64 / days_per_year as f64;
703                    let seasonal_temp = base_temp
704                        + temp_amplitude * (year_progress * 2.0 * std::f64::consts::PI).cos();
705
706                    // Add daily noise
707                    let temp_noise = Normal::new(0.0, 2.0).unwrap();
708                    let temp = seasonal_temp + temp_noise.sample(&mut rng);
709                    temperatures.push(temp);
710
711                    // Precipitation (more in summer for some zones)
712                    let seasonal_precip_factor = match zone {
713                        0 => {
714                            1.0 + 0.3
715                                * (year_progress * 2.0 * std::f64::consts::PI
716                                    + std::f64::consts::PI)
717                                    .cos()
718                        }
719                        1 => 1.0 + 0.2 * (year_progress * 2.0 * std::f64::consts::PI).sin(),
720                        _ => 1.0,
721                    };
722
723                    let precip = if rng.random::<f64>() < 0.3 {
724                        // 30% chance of precipitation
725                        rng.gen_range(0.0..20.0) * seasonal_precip_factor
726                    } else {
727                        0.0
728                    };
729                    precipitation.push(precip);
730                }
731
732                // Calculate summary statistics
733                let mean_temp = temperatures.iter().sum::<f64>() / temperatures.len() as f64;
734                let max_temp = temperatures
735                    .iter()
736                    .fold(f64::NEG_INFINITY, |a, &b| a.max(b));
737                let min_temp = temperatures.iter().fold(f64::INFINITY, |a, &b| a.min(b));
738                let temp_range = max_temp - min_temp;
739
740                let total_precip = precipitation.iter().sum::<f64>();
741                let precip_days = precipitation.iter().filter(|&&p| p > 0.0).count() as f64;
742
743                // Generate additional climate variables
744                let avg_humidity = humidity + Normal::new(0.0, 5.0).unwrap().sample(&mut rng);
745                let wind_speed = rng.gen_range(2.0..15.0);
746
747                data.extend(vec![
748                    mean_temp,
749                    temp_range,
750                    total_precip,
751                    precip_days,
752                    avg_humidity,
753                    wind_speed,
754                    base_temp,             // Latitude proxy
755                    annual_precip / 365.0, // Average daily precipitation
756                ]);
757            }
758
759            let data_array = Array2::from_shape_vec((n_stations, 8), data)
760                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
761
762            let target = Array1::from_vec(climate_zones);
763
764            Ok(Dataset {
765                data: data_array,
766                target: Some(target),
767                featurenames: Some(vec![
768                    "mean_temperature".to_string(),
769                    "temperature_range".to_string(),
770                    "annual_precipitation".to_string(),
771                    "precipitation_days".to_string(),
772                    "avg_humidity".to_string(),
773                    "avg_wind_speed".to_string(),
774                    "latitude_proxy".to_string(),
775                    "daily_precip_avg".to_string(),
776                ]),
777                targetnames: Some(vec![
778                    "Tropical".to_string(),
779                    "Temperate".to_string(),
780                    "Continental".to_string(),
781                    "Desert".to_string(),
782                    "Arctic".to_string(),
783                ]),
784                feature_descriptions: Some(vec![
785                    "Mean annual temperature (°C)".to_string(),
786                    "Temperature range (max-min, °C)".to_string(),
787                    "Total annual precipitation (mm)".to_string(),
788                    "Number of precipitation days per year".to_string(),
789                    "Average humidity (%)".to_string(),
790                    "Average wind speed (m/s)".to_string(),
791                    "Latitude proxy (normalized)".to_string(),
792                    "Average daily precipitation (mm/day)".to_string(),
793                ]),
794                description: Some(format!(
795                    "Climate data: {n_stations} _stations × {n_years} _years"
796                )),
797                metadata: std::collections::HashMap::new(),
798            })
799        }
800
801        /// Load atmospheric chemistry data
802        pub fn load_atmospheric_chemistry(&self, nmeasurements: usize) -> Result<Dataset> {
803            use scirs2_core::random::{Distribution, LogNormal, Normal};
804
805            let mut rng = thread_rng();
806
807            let mut data = Vec::with_capacity(nmeasurements * 12);
808            let mut air_quality_index = Vec::with_capacity(nmeasurements);
809
810            for _ in 0..nmeasurements {
811                // Generate correlated atmospheric _measurements
812                let base_pollution = rng.gen_range(0.0..1.0);
813
814                // Major pollutants (concentrations in µg/m³)
815                let pm25: f64 = LogNormal::new(2.0 + base_pollution, 0.5)
816                    .unwrap()
817                    .sample(&mut rng);
818                let pm10 = pm25 * rng.gen_range(1.5..2.5);
819                let no2 = LogNormal::new(3.0 + base_pollution * 0.5, 0.3)
820                    .unwrap()
821                    .sample(&mut rng);
822                let so2 = LogNormal::new(1.0 + base_pollution * 0.3, 0.4)
823                    .unwrap()
824                    .sample(&mut rng);
825                let o3 = LogNormal::new(4.0 - base_pollution * 0.2, 0.2)
826                    .unwrap()
827                    .sample(&mut rng);
828                let co = LogNormal::new(0.5 + base_pollution * 0.4, 0.3)
829                    .unwrap()
830                    .sample(&mut rng);
831
832                // Meteorological factors
833                let temperature = Normal::new(20.0, 10.0).unwrap().sample(&mut rng);
834                let humidity = rng.gen_range(30.0..90.0);
835                let wind_speed = rng.gen_range(0.5..12.0);
836                let pressure = Normal::new(1013.0, 15.0).unwrap().sample(&mut rng);
837
838                // Derived _measurements
839                let visibility = (50.0 - pm25.ln() * 5.0).max(1.0);
840                let uv_index = rng.gen_range(0.0..12.0);
841
842                data.extend(vec![
843                    pm25,
844                    pm10,
845                    no2,
846                    so2,
847                    o3,
848                    co,
849                    temperature,
850                    humidity,
851                    wind_speed,
852                    pressure,
853                    visibility,
854                    uv_index,
855                ]);
856
857                // Calculate air quality index
858                let aqi = Self::calculate_aqi(pm25, pm10, no2, so2, o3, co);
859                air_quality_index.push(aqi);
860            }
861
862            let data_array = Array2::from_shape_vec((nmeasurements, 12), data)
863                .map_err(|e| DatasetsError::FormatError(e.to_string()))?;
864
865            let target = Array1::from_vec(air_quality_index);
866
867            Ok(Dataset {
868                data: data_array,
869                target: Some(target),
870                featurenames: Some(vec![
871                    "pm2_5".to_string(),
872                    "pm10".to_string(),
873                    "no2".to_string(),
874                    "so2".to_string(),
875                    "o3".to_string(),
876                    "co".to_string(),
877                    "temperature".to_string(),
878                    "humidity".to_string(),
879                    "wind_speed".to_string(),
880                    "pressure".to_string(),
881                    "visibility".to_string(),
882                    "uv_index".to_string(),
883                ]),
884                targetnames: None,
885                feature_descriptions: Some(vec![
886                    "PM2.5 concentration (µg/m³)".to_string(),
887                    "PM10 concentration (µg/m³)".to_string(),
888                    "NO2 concentration (µg/m³)".to_string(),
889                    "SO2 concentration (µg/m³)".to_string(),
890                    "O3 concentration (µg/m³)".to_string(),
891                    "CO concentration (µg/m³)".to_string(),
892                    "Temperature (°C)".to_string(),
893                    "Relative humidity (%)".to_string(),
894                    "Wind speed (m/s)".to_string(),
895                    "Atmospheric pressure (hPa)".to_string(),
896                    "Visibility (km)".to_string(),
897                    "UV index".to_string(),
898                ]),
899                description: Some(format!(
900                    "Atmospheric chemistry _measurements: {nmeasurements} samples"
901                )),
902                metadata: std::collections::HashMap::new(),
903            })
904        }
905
906        #[allow(clippy::too_many_arguments)]
907        fn calculate_aqi(pm25: f64, pm10: f64, no2: f64, so2: f64, o3: f64, co: f64) -> f64 {
908            // Simplified AQI calculation
909            let pm25_aqi = (pm25 / 35.0 * 100.0).min(300.0);
910            let pm10_aqi = (pm10 / 150.0 * 100.0).min(300.0);
911            let no2_aqi = (no2 / 100.0 * 100.0).min(300.0);
912            let so2_aqi = (so2 / 75.0 * 100.0).min(300.0);
913            let o3_aqi = (o3 / 120.0 * 100.0).min(300.0);
914            let co_aqi = (co / 9.0 * 100.0).min(300.0);
915
916            // Return the maximum AQI (worst pollutant determines overall AQI)
917            [pm25_aqi, pm10_aqi, no2_aqi, so2_aqi, o3_aqi, co_aqi]
918                .iter()
919                .fold(0.0f64, |a, &b| a.max(b))
920        }
921    }
922}
923
924/// Convenience functions for loading domain-specific datasets
925pub mod convenience {
926    use super::astronomy::StellarDatasets;
927    use super::climate::ClimateDatasets;
928    use super::genomics::GenomicsDatasets;
929    use super::*;
930
931    /// Load a stellar classification dataset
932    pub fn load_stellar_classification() -> Result<Dataset> {
933        let datasets = StellarDatasets::new()?;
934        datasets.load_hipparcos_catalog()
935    }
936
937    /// Load an exoplanet dataset
938    pub fn load_exoplanets() -> Result<Dataset> {
939        let datasets = StellarDatasets::new()?;
940        datasets.load_exoplanet_catalog()
941    }
942
943    /// Load a gene expression dataset
944    pub fn load_gene_expression(
945        n_samples: Option<usize>,
946        ngenes: Option<usize>,
947    ) -> Result<Dataset> {
948        let datasets = GenomicsDatasets::new()?;
949        datasets.load_gene_expression(n_samples.unwrap_or(200), ngenes.unwrap_or(1000))
950    }
951
952    /// Load a climate dataset
953    pub fn load_climate_data(
954        _n_stations: Option<usize>,
955        n_years: Option<usize>,
956    ) -> Result<Dataset> {
957        let datasets = ClimateDatasets::new()?;
958        datasets.load_temperature_timeseries(_n_stations.unwrap_or(100), n_years.unwrap_or(10))
959    }
960
961    /// Load atmospheric chemistry data
962    pub fn load_atmospheric_chemistry(_nmeasurements: Option<usize>) -> Result<Dataset> {
963        let datasets = ClimateDatasets::new()?;
964        datasets.load_atmospheric_chemistry(_nmeasurements.unwrap_or(1000))
965    }
966
967    /// List all available domain-specific datasets
968    pub fn list_domain_datasets() -> Vec<(&'static str, &'static str)> {
969        vec![
970            ("astronomy", "stellar_classification"),
971            ("astronomy", "exoplanets"),
972            ("astronomy", "supernovae"),
973            ("astronomy", "gaia_dr3"),
974            ("genomics", "gene_expression"),
975            ("genomics", "dnasequences"),
976            ("climate", "temperature_timeseries"),
977            ("climate", "atmospheric_chemistry"),
978        ]
979    }
980}
981
#[cfg(test)]
mod tests {
    use super::convenience::*;

    #[test]
    fn test_load_stellar_classification() {
        let dataset = load_stellar_classification().unwrap();
        assert!(dataset.n_samples() > 1000);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_exoplanets() {
        let dataset = load_exoplanets().unwrap();
        assert!(dataset.n_samples() > 100);
        assert_eq!(dataset.n_features(), 6);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_gene_expression() {
        let dataset = load_gene_expression(Some(50), Some(100)).unwrap();
        assert_eq!(dataset.n_samples(), 50);
        assert_eq!(dataset.n_features(), 100);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_climate_data() {
        let dataset = load_climate_data(Some(20), Some(5)).unwrap();
        assert_eq!(dataset.n_samples(), 20);
        assert_eq!(dataset.n_features(), 8);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_atmospheric_chemistry() {
        let dataset = load_atmospheric_chemistry(Some(100)).unwrap();
        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 12);
        assert!(dataset.target.is_some());
    }

    #[test]
    fn test_load_atmospheric_chemistry_defaults_and_ranges() {
        // `None` falls back to the default of 1000 measurements.
        let dataset = load_atmospheric_chemistry(None).unwrap();
        assert_eq!(dataset.n_samples(), 1000);
        assert_eq!(dataset.n_features(), 12);
        assert!(dataset.data.iter().all(|v| v.is_finite()));

        // The simplified AQI caps each sub-index at 300 and never goes below 0.
        let target = dataset
            .target
            .as_ref()
            .expect("atmospheric chemistry dataset should carry an AQI target");
        assert!(target.iter().all(|&aqi| (0.0..=300.0).contains(&aqi)));
    }

    #[test]
    fn test_list_domain_datasets() {
        let datasets = list_domain_datasets();
        assert!(!datasets.is_empty());
        for expected in ["astronomy", "genomics", "climate"] {
            assert!(
                datasets.iter().any(|&(domain, _)| domain == expected),
                "missing domain {expected}"
            );
        }
    }
}