1use crate::error::DataError;
7use scirs2_core::random::{Random, Rng, SeedableRng};
9use torsh_core::error::TorshError;
10use torsh_tensor::Tensor;
/// Tags for the datasets that [`load_builtin_dataset`] knows how to produce.
///
/// The enum carries no payload, so it is `Copy` and can be compared, hashed and used as a
/// map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BuiltinDataset {
    /// Dispatches to the Iris loader (classification).
    Iris,
    /// Dispatches to the Boston loader (regression).
    Boston,
    /// Dispatches to the Diabetes loader (regression).
    Diabetes,
    /// Dispatches to the Wine loader (classification).
    Wine,
    /// Dispatches to the Breast Cancer loader (classification).
    BreastCancer,
    /// Dispatches to the Digits loader (classification).
    Digits,
}
24
/// General-purpose knobs for synthetic data generation.
///
/// NOTE(review): no generator visible in this part of the file consumes this struct, so the
/// field descriptions below are inferred from names and defaults only — confirm against the
/// actual consumer.
#[derive(Debug, Clone)]
pub struct SyntheticDataConfig {
    /// Requested number of samples (the `Default` impl uses 100).
    pub n_samples: usize,
    /// Requested number of features per sample (the `Default` impl uses 2).
    pub n_features: usize,
    /// Requested number of classes, if any (the `Default` impl uses `Some(2)`).
    pub n_classes: Option<usize>,
    /// Optional RNG seed (the `Default` impl uses `None`).
    pub seed: Option<u64>,
    /// Optional noise level (the `Default` impl uses `Some(0.1)`); units are presumably those of
    /// the features — TODO confirm.
    pub noise: Option<f64>,
    /// Optional scaling strategy (the `Default` impl uses `Some(ScalingMethod::StandardScaler)`).
    pub scale: Option<ScalingMethod>,
}
41
/// Feature-scaling strategy selectable through `SyntheticDataConfig::scale`.
///
/// The enum carries no payload, so it is `Copy` and can be compared and hashed.
///
/// NOTE(review): no code visible in this part of the file applies these strategies; the variant
/// semantics presumably mirror the scikit-learn preprocessors of the same names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScalingMethod {
    /// Presumably zero-mean / unit-variance scaling — confirm against the consumer.
    StandardScaler,
    /// Presumably rescaling into a fixed min/max range — confirm against the consumer.
    MinMaxScaler,
    /// Presumably median / quantile based scaling — confirm against the consumer.
    RobustScaler,
    /// Presumably per-sample norm scaling — confirm against the consumer.
    Normalizer,
}
50
/// Parameters accepted by [`make_regression`].
///
/// Derives `PartialEq` so configurations can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct RegressionConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// How many leading features contribute to the target; `None` means all of them.
    pub n_informative: Option<usize>,
    /// Magnitude of the noise added to each target; `None` is treated as `0.0`.
    pub noise: Option<f64>,
    /// Constant offset requested for the targets.
    pub bias: Option<f64>,
    /// Optional RNG seed requested by the caller.
    pub random_state: Option<u64>,
}
61
/// Parameters accepted by [`make_classification`].
///
/// Derives `PartialEq` so configurations can be compared directly (e.g. in tests).
///
/// NOTE(review): in the part of the file visible here, `make_classification` reads only
/// `n_samples`, `n_features`, `n_classes` and `random_state`; the remaining fields are carried
/// but not consulted — confirm whether that is intentional.
#[derive(Debug, Clone, PartialEq)]
pub struct ClassificationConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// Number of distinct class labels.
    pub n_classes: usize,
    /// Requested number of informative features.
    pub n_informative: Option<usize>,
    /// Requested number of redundant features.
    pub n_redundant: Option<usize>,
    /// Requested number of clusters per class.
    pub n_clusters_per_class: Option<usize>,
    /// Requested class separation factor.
    pub class_sep: Option<f64>,
    /// Optional RNG seed; with `Some(seed)` generation is reproducible.
    pub random_state: Option<u64>,
}
74
/// Parameters accepted by [`make_blobs`].
///
/// Derives `PartialEq` so configurations can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct ClusteringConfig {
    /// Total number of samples requested.
    pub n_samples: usize,
    /// Number of cluster centers to generate.
    pub centers: usize,
    /// Number of features per sample; `None` is treated as `2`.
    pub n_features: Option<usize>,
    /// Spread of the samples around their center; `None` is treated as `1.0`.
    pub cluster_std: Option<f64>,
    /// Requested `(low, high)` bounds for the coordinates of the cluster centers.
    pub center_box: Option<(f64, f64)>,
    /// Optional RNG seed; with `Some(seed)` generation is reproducible.
    pub random_state: Option<u64>,
}
85
/// Bundle returned by every loader and generator in this file.
#[derive(Debug, Clone)]
pub struct DatasetResult {
    /// Feature tensor; the generators in this file build it as a 2-D tensor with one row per
    /// sample and one column per feature.
    pub features: Tensor,
    /// Target tensor; the generators in this file build it as a 1-D tensor with one entry per
    /// sample (class and cluster ids are stored as `f32`).
    pub targets: Tensor,
    /// Optional column names; the generators in this file fill it with `feature_{i}`.
    pub feature_names: Option<Vec<String>>,
    /// Optional target / class / cluster names (`target`, `class_{i}` or `cluster_{i}` for the
    /// generators in this file).
    pub target_names: Option<Vec<String>>,
    /// Short human-readable description of how the data was produced.
    pub description: String,
}
95
96impl Default for SyntheticDataConfig {
97 fn default() -> Self {
98 Self {
99 n_samples: 100,
100 n_features: 2,
101 n_classes: Some(2),
102 seed: None,
103 noise: Some(0.1),
104 scale: Some(ScalingMethod::StandardScaler),
105 }
106 }
107}
108
109pub fn load_builtin_dataset(dataset: BuiltinDataset) -> Result<DatasetResult, DataError> {
111 match dataset {
112 BuiltinDataset::Iris => load_iris_dataset(),
113 BuiltinDataset::Boston => load_boston_dataset(),
114 BuiltinDataset::Diabetes => load_diabetes_dataset(),
115 BuiltinDataset::Wine => load_wine_dataset(),
116 BuiltinDataset::BreastCancer => load_breast_cancer_dataset(),
117 BuiltinDataset::Digits => load_digits_dataset(),
118 }
119}
120
121pub fn make_regression(config: RegressionConfig) -> Result<DatasetResult, DataError> {
123 let n_informative = config.n_informative.unwrap_or(config.n_features);
127 let noise_std = config.noise.unwrap_or(0.0);
128
129 let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
131 .map(|_| {
132 let mut rng = scirs2_core::random::thread_rng();
134 rng.gen_range(-1.0..1.0)
135 })
136 .collect();
137
138 let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
139
140 let targets_data: Vec<f32> = (0..config.n_samples)
142 .map(|i| {
143 let mut target = 0.0;
144 for j in 0..n_informative {
145 if let Ok(feature_vec) = features.to_vec() {
146 let idx = i * config.n_features + j;
147 if let Some(&feature_val) = feature_vec.get(idx) {
148 target += feature_val;
149 }
150 }
151 }
152
153 if noise_std > 0.0 {
155 let mut rng = scirs2_core::random::thread_rng();
157 target += rng.gen_range(-noise_std as f32..noise_std as f32);
158 }
159
160 target
161 })
162 .collect();
163
164 let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
165
166 Ok(DatasetResult {
167 features,
168 targets,
169 feature_names: Some(
170 (0..config.n_features)
171 .map(|i| format!("feature_{}", i))
172 .collect(),
173 ),
174 target_names: Some(vec!["target".to_string()]),
175 description: "Synthetic regression dataset".to_string(),
176 })
177}
178
179pub fn make_classification(config: ClassificationConfig) -> Result<DatasetResult, DataError> {
181 let mut rng = if let Some(seed) = config.random_state {
186 scirs2_core::random::StdRng::seed_from_u64(seed)
187 } else {
188 {
189 let mut thread_rng = scirs2_core::random::thread_rng();
190 scirs2_core::random::StdRng::from_rng(&mut thread_rng)
191 }
192 };
193
194 let features_data: Vec<f32> = (0..config.n_samples * config.n_features)
196 .map(|_| rng.gen_range(-1.0..1.0))
197 .collect();
198
199 let features = Tensor::from_vec(features_data, &[config.n_samples, config.n_features])?;
200
201 let targets_data: Vec<f32> = (0..config.n_samples)
203 .map(|_| rng.gen_range(0..config.n_classes) as f32)
204 .collect();
205
206 let targets = Tensor::from_vec(targets_data, &[config.n_samples])?;
207
208 Ok(DatasetResult {
209 features,
210 targets,
211 feature_names: Some(
212 (0..config.n_features)
213 .map(|i| format!("feature_{}", i))
214 .collect(),
215 ),
216 target_names: Some(
217 (0..config.n_classes)
218 .map(|i| format!("class_{}", i))
219 .collect(),
220 ),
221 description: "Synthetic classification dataset".to_string(),
222 })
223}
224
225pub fn make_blobs(config: ClusteringConfig) -> Result<DatasetResult, DataError> {
227 let mut rng = if let Some(seed) = config.random_state {
231 scirs2_core::random::StdRng::seed_from_u64(seed)
232 } else {
233 {
234 let mut thread_rng = scirs2_core::random::thread_rng();
235 scirs2_core::random::StdRng::from_rng(&mut thread_rng)
236 }
237 };
238
239 let n_features = config.n_features.unwrap_or(2);
240 let cluster_std = config.cluster_std.unwrap_or(1.0);
241
242 let centers: Vec<Vec<f32>> = (0..config.centers)
244 .map(|_| (0..n_features).map(|_| rng.gen_range(-5.0..5.0)).collect())
245 .collect();
246
247 let samples_per_cluster = config.n_samples / config.centers;
248 let mut features_data = Vec::new();
249 let mut targets_data = Vec::new();
250
251 for (cluster_id, center) in centers.iter().enumerate() {
252 for _ in 0..samples_per_cluster {
253 for ¢er_coord in center {
255 let noise: f32 = rng.gen_range(-cluster_std as f32..cluster_std as f32);
256 features_data.push(center_coord + noise);
257 }
258 targets_data.push(cluster_id as f32);
259 }
260 }
261
262 let features = Tensor::from_vec(
263 features_data,
264 &[samples_per_cluster * config.centers, n_features],
265 )?;
266
267 let targets = Tensor::from_vec(targets_data, &[samples_per_cluster * config.centers])?;
268
269 Ok(DatasetResult {
270 features,
271 targets,
272 feature_names: Some((0..n_features).map(|i| format!("feature_{}", i)).collect()),
273 target_names: Some(
274 (0..config.centers)
275 .map(|i| format!("cluster_{}", i))
276 .collect(),
277 ),
278 description: "Synthetic clustering dataset (blobs)".to_string(),
279 })
280}
281
282fn load_iris_dataset() -> Result<DatasetResult, DataError> {
284 make_classification(ClassificationConfig {
287 n_samples: 150,
288 n_features: 4,
289 n_classes: 3,
290 n_informative: Some(4),
291 random_state: Some(42),
292 ..Default::default()
293 })
294}
295
296fn load_boston_dataset() -> Result<DatasetResult, DataError> {
297 make_regression(RegressionConfig {
299 n_samples: 506,
300 n_features: 13,
301 n_informative: Some(13),
302 noise: Some(0.1),
303 random_state: Some(42),
304 bias: Some(0.0),
305 })
306}
307
308fn load_diabetes_dataset() -> Result<DatasetResult, DataError> {
309 make_regression(RegressionConfig {
310 n_samples: 442,
311 n_features: 10,
312 n_informative: Some(10),
313 noise: Some(0.1),
314 random_state: Some(42),
315 bias: Some(0.0),
316 })
317}
318
319fn load_wine_dataset() -> Result<DatasetResult, DataError> {
320 make_classification(ClassificationConfig {
321 n_samples: 178,
322 n_features: 13,
323 n_classes: 3,
324 n_informative: Some(13),
325 random_state: Some(42),
326 ..Default::default()
327 })
328}
329
330fn load_breast_cancer_dataset() -> Result<DatasetResult, DataError> {
331 make_classification(ClassificationConfig {
332 n_samples: 569,
333 n_features: 30,
334 n_classes: 2,
335 n_informative: Some(30),
336 random_state: Some(42),
337 ..Default::default()
338 })
339}
340
341fn load_digits_dataset() -> Result<DatasetResult, DataError> {
342 make_classification(ClassificationConfig {
343 n_samples: 1797,
344 n_features: 64,
345 n_classes: 10,
346 n_informative: Some(64),
347 random_state: Some(42),
348 ..Default::default()
349 })
350}
351
352impl Default for RegressionConfig {
353 fn default() -> Self {
354 Self {
355 n_samples: 100,
356 n_features: 1,
357 n_informative: None,
358 noise: Some(0.1),
359 bias: Some(0.0),
360 random_state: None,
361 }
362 }
363}
364
365impl Default for ClassificationConfig {
366 fn default() -> Self {
367 Self {
368 n_samples: 100,
369 n_features: 2,
370 n_classes: 2,
371 n_informative: None,
372 n_redundant: None,
373 n_clusters_per_class: None,
374 class_sep: Some(1.0),
375 random_state: None,
376 }
377 }
378}
379
380impl Default for ClusteringConfig {
381 fn default() -> Self {
382 Self {
383 n_samples: 100,
384 centers: 3,
385 n_features: Some(2),
386 cluster_std: Some(1.0),
387 center_box: Some((-10.0, 10.0)),
388 random_state: None,
389 }
390 }
391}
392
393#[derive(Debug, Default)]
395pub struct DatasetRegistry {
396 builtin_datasets: Vec<BuiltinDataset>,
397}
398
399impl DatasetRegistry {
400 pub fn new() -> Self {
402 Self {
403 builtin_datasets: vec![
404 BuiltinDataset::Iris,
405 BuiltinDataset::Boston,
406 BuiltinDataset::Diabetes,
407 BuiltinDataset::Wine,
408 BuiltinDataset::BreastCancer,
409 BuiltinDataset::Digits,
410 ],
411 }
412 }
413
414 pub fn list_builtin(&self) -> &[BuiltinDataset] {
416 &self.builtin_datasets
417 }
418
419 pub fn load_by_name(&self, name: &str) -> Result<DatasetResult, DataError> {
421 let dataset = match name.to_lowercase().as_str() {
422 "iris" => BuiltinDataset::Iris,
423 "boston" => BuiltinDataset::Boston,
424 "diabetes" => BuiltinDataset::Diabetes,
425 "wine" => BuiltinDataset::Wine,
426 "breast_cancer" | "breastcancer" => BuiltinDataset::BreastCancer,
427 "digits" => BuiltinDataset::Digits,
428 _ => {
429 return Err(DataError::dataset(
430 crate::error::DatasetErrorKind::UnsupportedFormat,
431 format!("Unknown dataset: {}", name),
432 ))
433 }
434 };
435
436 load_builtin_dataset(dataset)
437 }
438}