torsh_data/
builtin.rs

1//! Built-in datasets powered by SciRS2
2//!
3//! This module provides access to toy datasets, synthetic data generators,
4//! and other built-in data sources from the SciRS2 ecosystem.
5
6use crate::error::DataError;
7// ✅ SciRS2 Policy Compliant - Using scirs2_core::random instead of direct rand
8use scirs2_core::random::{Random, Rng, SeedableRng};
9use torsh_core::error::TorshError;
10use torsh_tensor::Tensor;
11// Direct SciRS2 datasets integration
12// use scirs2_datasets::{load_iris, load_boston, make_classification}; // Will be uncommented when API stabilizes
13
/// Identifiers for the datasets this module can load by variant.
///
/// Pass a variant to [`load_builtin_dataset`] to obtain a [`DatasetResult`].
/// The type is comparable and hashable so it can be used as a map/set key or
/// checked with `==` in tests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuiltinDataset {
    /// Iris classification dataset.
    Iris,
    /// Boston housing regression dataset.
    Boston,
    /// Diabetes regression dataset.
    Diabetes,
    /// Wine classification dataset.
    Wine,
    /// Breast-cancer classification dataset.
    BreastCancer,
    /// Handwritten digits classification dataset.
    Digits,
}
24
25/// Synthetic data generation configuration
26#[derive(Debug, Clone)]
27pub struct SyntheticDataConfig {
28    /// Number of samples to generate
29    pub n_samples: usize,
30    /// Number of features
31    pub n_features: usize,
32    /// Number of classes (for classification)
33    pub n_classes: Option<usize>,
34    /// Random seed for reproducibility
35    pub seed: Option<u64>,
36    /// Whether to add noise
37    pub noise: Option<f64>,
38    /// Feature scaling method
39    pub scale: Option<ScalingMethod>,
40}
41
/// Feature scaling strategies selectable through [`SyntheticDataConfig`].
///
/// The variant names mirror the scikit-learn preprocessing classes of the
/// same names. NOTE(review): no scaling implementation is visible in this
/// file, so the precise arithmetic behind each variant is not documented here.
///
/// The type is comparable and hashable so a selected method can be checked
/// with `==` or used as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScalingMethod {
    /// Standard scaling.
    StandardScaler,
    /// Min/max scaling.
    MinMaxScaler,
    /// Robust scaling.
    RobustScaler,
    /// Per-sample normalization.
    Normalizer,
}
50
/// Parameters for [`make_regression`].
///
/// Implements `PartialEq` so configurations can be compared directly
/// (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct RegressionConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// Number of leading features that contribute to the target;
    /// `None` means all `n_features`.
    pub n_informative: Option<usize>,
    /// Magnitude of the noise mixed into the targets; `None` means no noise.
    pub noise: Option<f64>,
    /// Requested constant offset (intercept) for the targets.
    pub bias: Option<f64>,
    /// Requested seed for reproducible generation.
    pub random_state: Option<u64>,
}
61
/// Parameters for [`make_classification`].
///
/// NOTE(review): the generator in this file reads only `n_samples`,
/// `n_features`, `n_classes` and `random_state`; the remaining fields are
/// accepted but not yet consulted.
///
/// Implements `PartialEq` so configurations can be compared directly
/// (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// Number of distinct class labels.
    pub n_classes: usize,
    /// Requested number of informative features.
    pub n_informative: Option<usize>,
    /// Requested number of redundant features.
    pub n_redundant: Option<usize>,
    /// Requested number of clusters per class.
    pub n_clusters_per_class: Option<usize>,
    /// Requested class separation factor.
    pub class_sep: Option<f64>,
    /// Seed for reproducible generation; `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
}
74
/// Parameters for [`make_blobs`].
///
/// Implements `PartialEq` so configurations can be compared directly
/// (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteringConfig {
    /// Total number of samples requested across all clusters.
    pub n_samples: usize,
    /// Number of cluster centers to generate.
    pub centers: usize,
    /// Number of features per sample; `None` means 2.
    pub n_features: Option<usize>,
    /// Spread of the points around each center; `None` means 1.0.
    pub cluster_std: Option<f64>,
    /// Requested `(low, high)` bounding box for the cluster centers.
    pub center_box: Option<(f64, f64)>,
    /// Seed for reproducible generation; `None` seeds from the thread RNG.
    pub random_state: Option<u64>,
}
85
86/// Dataset result containing features and targets
87#[derive(Debug, Clone)]
88pub struct DatasetResult {
89    pub features: Tensor,
90    pub targets: Tensor,
91    pub feature_names: Option<Vec<String>>,
92    pub target_names: Option<Vec<String>>,
93    pub description: String,
94}
95
96impl Default for SyntheticDataConfig {
97    fn default() -> Self {
98        Self {
99            n_samples: 100,
100            n_features: 2,
101            n_classes: Some(2),
102            seed: None,
103            noise: Some(0.1),
104            scale: Some(ScalingMethod::StandardScaler),
105        }
106    }
107}
108
109/// Load a built-in dataset
110pub fn load_builtin_dataset(dataset: BuiltinDataset) -> Result<DatasetResult, DataError> {
111    match dataset {
112        BuiltinDataset::Iris => load_iris_dataset(),
113        BuiltinDataset::Boston => load_boston_dataset(),
114        BuiltinDataset::Diabetes => load_diabetes_dataset(),
115        BuiltinDataset::Wine => load_wine_dataset(),
116        BuiltinDataset::BreastCancer => load_breast_cancer_dataset(),
117        BuiltinDataset::Digits => load_digits_dataset(),
118    }
119}
120
121/// Generate synthetic regression data
122pub fn make_regression(config: RegressionConfig) -> Result<DatasetResult, DataError> {
123    // TODO: Use scirs2_datasets::make_regression when API is available
124    // For now, implement basic synthetic data generation
125
126    let n_informative = config.n_informative.unwrap_or(config.n_features);
127    let noise_std = config.noise.unwrap_or(0.0);
128
129    // Generate random features
130    let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
131        .map(|_| {
132            // ✅ SciRS2 Policy Compliant
133            let mut rng = scirs2_core::random::thread_rng();
134            rng.gen_range(-1.0..1.0)
135        })
136        .collect();
137
138    let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
139
140    // Generate targets as linear combination of informative features
141    let targets_data: Vec<f32> = (0..config.n_samples)
142        .map(|i| {
143            let mut target = 0.0;
144            for j in 0..n_informative {
145                if let Ok(feature_vec) = features.to_vec() {
146                    let idx = i * config.n_features + j;
147                    if let Some(&feature_val) = feature_vec.get(idx) {
148                        target += feature_val;
149                    }
150                }
151            }
152
153            // Add noise
154            if noise_std > 0.0 {
155                // ✅ SciRS2 Policy Compliant
156                let mut rng = scirs2_core::random::thread_rng();
157                target += rng.gen_range(-noise_std as f32..noise_std as f32);
158            }
159
160            target
161        })
162        .collect();
163
164    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
165
166    Ok(DatasetResult {
167        features,
168        targets,
169        feature_names: Some(
170            (0..config.n_features)
171                .map(|i| format!("feature_{}", i))
172                .collect(),
173        ),
174        target_names: Some(vec!["target".to_string()]),
175        description: "Synthetic regression dataset".to_string(),
176    })
177}
178
179/// Generate synthetic classification data
180pub fn make_classification(config: ClassificationConfig) -> Result<DatasetResult, DataError> {
181    // TODO: Use scirs2_datasets::make_classification when API is available
182    // For now, implement basic synthetic data generation
183
184    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random
185    let mut rng = if let Some(seed) = config.random_state {
186        scirs2_core::random::StdRng::seed_from_u64(seed)
187    } else {
188        {
189            let mut thread_rng = scirs2_core::random::thread_rng();
190            scirs2_core::random::StdRng::from_rng(&mut thread_rng)
191        }
192    };
193
194    // Generate random features
195    let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
196        .map(|_| rng.gen_range(-1.0..1.0))
197        .collect();
198
199    let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
200
201    // Generate targets
202    let targets_data: Vec<f32> = (0..config.n_samples)
203        .map(|_| rng.gen_range(0..config.n_classes) as f32)
204        .collect();
205
206    let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
207
208    Ok(DatasetResult {
209        features,
210        targets,
211        feature_names: Some(
212            (0..config.n_features)
213                .map(|i| format!("feature_{}", i))
214                .collect(),
215        ),
216        target_names: Some(
217            (0..config.n_classes)
218                .map(|i| format!("class_{}", i))
219                .collect(),
220        ),
221        description: "Synthetic classification dataset".to_string(),
222    })
223}
224
225/// Generate synthetic clustering data (blobs)
226pub fn make_blobs(config: ClusteringConfig) -> Result<DatasetResult, DataError> {
227    // TODO: Use scirs2_datasets::make_blobs when API is available
228
229    // ✅ SciRS2 Policy Compliant - Using scirs2_core::random
230    let mut rng = if let Some(seed) = config.random_state {
231        scirs2_core::random::StdRng::seed_from_u64(seed)
232    } else {
233        {
234            let mut thread_rng = scirs2_core::random::thread_rng();
235            scirs2_core::random::StdRng::from_rng(&mut thread_rng)
236        }
237    };
238
239    let n_features = config.n_features.unwrap_or(2);
240    let cluster_std = config.cluster_std.unwrap_or(1.0);
241
242    // Generate cluster centers
243    let centers: Vec<Vec<f32>> = (0..config.centers)
244        .map(|_| (0..n_features).map(|_| rng.gen_range(-5.0..5.0)).collect())
245        .collect();
246
247    let samples_per_cluster = config.n_samples / config.centers;
248    let mut features_data = Vec::new();
249    let mut targets_data = Vec::new();
250
251    for (cluster_id, center) in centers.iter().enumerate() {
252        for _ in 0..samples_per_cluster {
253            // Generate point around cluster center
254            for &center_coord in center {
255                let noise: f32 = rng.gen_range(-cluster_std as f32..cluster_std as f32);
256                features_data.push(center_coord + noise);
257            }
258            targets_data.push(cluster_id as f32);
259        }
260    }
261
262    let features = Tensor::from_vec(
263        features_data,
264        &[samples_per_cluster * config.centers, n_features],
265    )?;
266
267    let targets = Tensor::from_vec(targets_data, &[samples_per_cluster * config.centers])?;
268
269    Ok(DatasetResult {
270        features,
271        targets,
272        feature_names: Some((0..n_features).map(|i| format!("feature_{}", i)).collect()),
273        target_names: Some(
274            (0..config.centers)
275                .map(|i| format!("cluster_{}", i))
276                .collect(),
277        ),
278        description: "Synthetic clustering dataset (blobs)".to_string(),
279    })
280}
281
282// Built-in dataset implementations (placeholders for now)
283fn load_iris_dataset() -> Result<DatasetResult, DataError> {
284    // TODO: Use scirs2_datasets::load_iris() when available
285    // For now, create a minimal iris-like dataset
286    make_classification(ClassificationConfig {
287        n_samples: 150,
288        n_features: 4,
289        n_classes: 3,
290        n_informative: Some(4),
291        random_state: Some(42),
292        ..Default::default()
293    })
294}
295
296fn load_boston_dataset() -> Result<DatasetResult, DataError> {
297    // TODO: Use scirs2_datasets::load_boston() when available
298    make_regression(RegressionConfig {
299        n_samples: 506,
300        n_features: 13,
301        n_informative: Some(13),
302        noise: Some(0.1),
303        random_state: Some(42),
304        bias: Some(0.0),
305    })
306}
307
308fn load_diabetes_dataset() -> Result<DatasetResult, DataError> {
309    make_regression(RegressionConfig {
310        n_samples: 442,
311        n_features: 10,
312        n_informative: Some(10),
313        noise: Some(0.1),
314        random_state: Some(42),
315        bias: Some(0.0),
316    })
317}
318
319fn load_wine_dataset() -> Result<DatasetResult, DataError> {
320    make_classification(ClassificationConfig {
321        n_samples: 178,
322        n_features: 13,
323        n_classes: 3,
324        n_informative: Some(13),
325        random_state: Some(42),
326        ..Default::default()
327    })
328}
329
330fn load_breast_cancer_dataset() -> Result<DatasetResult, DataError> {
331    make_classification(ClassificationConfig {
332        n_samples: 569,
333        n_features: 30,
334        n_classes: 2,
335        n_informative: Some(30),
336        random_state: Some(42),
337        ..Default::default()
338    })
339}
340
341fn load_digits_dataset() -> Result<DatasetResult, DataError> {
342    make_classification(ClassificationConfig {
343        n_samples: 1797,
344        n_features: 64,
345        n_classes: 10,
346        n_informative: Some(64),
347        random_state: Some(42),
348        ..Default::default()
349    })
350}
351
352impl Default for RegressionConfig {
353    fn default() -> Self {
354        Self {
355            n_samples: 100,
356            n_features: 1,
357            n_informative: None,
358            noise: Some(0.1),
359            bias: Some(0.0),
360            random_state: None,
361        }
362    }
363}
364
365impl Default for ClassificationConfig {
366    fn default() -> Self {
367        Self {
368            n_samples: 100,
369            n_features: 2,
370            n_classes: 2,
371            n_informative: None,
372            n_redundant: None,
373            n_clusters_per_class: None,
374            class_sep: Some(1.0),
375            random_state: None,
376        }
377    }
378}
379
380impl Default for ClusteringConfig {
381    fn default() -> Self {
382        Self {
383            n_samples: 100,
384            centers: 3,
385            n_features: Some(2),
386            cluster_std: Some(1.0),
387            center_box: Some((-10.0, 10.0)),
388            random_state: None,
389        }
390    }
391}
392
393/// Dataset registry for managing available datasets
394#[derive(Debug, Default)]
395pub struct DatasetRegistry {
396    builtin_datasets: Vec<BuiltinDataset>,
397}
398
399impl DatasetRegistry {
400    /// Create a new dataset registry
401    pub fn new() -> Self {
402        Self {
403            builtin_datasets: vec![
404                BuiltinDataset::Iris,
405                BuiltinDataset::Boston,
406                BuiltinDataset::Diabetes,
407                BuiltinDataset::Wine,
408                BuiltinDataset::BreastCancer,
409                BuiltinDataset::Digits,
410            ],
411        }
412    }
413
414    /// List all available built-in datasets
415    pub fn list_builtin(&self) -> &[BuiltinDataset] {
416        &self.builtin_datasets
417    }
418
419    /// Load a dataset by name
420    pub fn load_by_name(&self, name: &str) -> Result<DatasetResult, DataError> {
421        let dataset = match name.to_lowercase().as_str() {
422            "iris" => BuiltinDataset::Iris,
423            "boston" => BuiltinDataset::Boston,
424            "diabetes" => BuiltinDataset::Diabetes,
425            "wine" => BuiltinDataset::Wine,
426            "breast_cancer" | "breastcancer" => BuiltinDataset::BreastCancer,
427            "digits" => BuiltinDataset::Digits,
428            _ => {
429                return Err(DataError::dataset(
430                    crate::error::DatasetErrorKind::UnsupportedFormat,
431                    format!("Unknown dataset: {}", name),
432                ))
433            }
434        };
435
436        load_builtin_dataset(dataset)
437    }
438}