1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::{Distribution, RandNormal, Random};
9use std::collections::HashMap;
10use thiserror::Error;
11
/// Error type shared by the dataset traits and their implementations in this module.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute (via `thiserror`).
#[derive(Error, Debug)]
pub enum DatasetTraitError {
    /// Failure while generating data; displays as `Generation error: <msg>`.
    #[error("Generation error: {0}")]
    Generation(String),
    /// Validation failure; displays as `Validation error: <msg>`.
    #[error("Validation error: {0}")]
    Validation(String),
    /// Invalid configuration, e.g. a rejection from
    /// `ClassificationGenerator::validate_config` or an unknown generator name passed to
    /// `GeneratorRegistry::generate`.
    #[error("Configuration error: {0}")]
    Configuration(String),
    /// I/O failure, carried as a message string rather than a `std::io::Error`.
    #[error("IO error: {0}")]
    Io(String),
    /// A size or index did not match what was expected. Both fields are free-form
    /// descriptions, e.g. `expected: "index < 10"`, `actual: "index = 12"`.
    #[error("Dimension mismatch: expected {expected}, got {actual}")]
    DimensionMismatch { expected: String, actual: String },
    /// The operation is not supported by this implementation
    /// (e.g. `InMemoryDataset::add_sample` / `remove_sample`).
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),
}

/// Convenience alias for results whose error type is [`DatasetTraitError`].
pub type DatasetTraitResult<T> = Result<T, DatasetTraitError>;
33
/// Read-only access to a tabular dataset: a feature matrix plus optional targets.
pub trait Dataset {
    /// Number of samples in the dataset.
    fn n_samples(&self) -> usize;

    /// Number of features per sample.
    fn n_features(&self) -> usize;

    /// Returns `(n_samples, n_features)`; the default combines the two accessors above.
    fn shape(&self) -> (usize, usize) {
        (self.n_samples(), self.n_features())
    }

    /// Borrowed 2-D view of all feature values.
    ///
    /// NOTE(review): the axis layout is up to the implementor; `InMemoryDataset` uses one
    /// row per sample.
    fn features(&self) -> DatasetTraitResult<ArrayView2<'_, f64>>;

    /// Borrowed view of the features of the sample at `index`.
    ///
    /// `InMemoryDataset` returns `DimensionMismatch` when `index >= n_samples()`; other
    /// implementations presumably behave the same — confirm per implementor.
    fn sample(&self, index: usize) -> DatasetTraitResult<ArrayView1<'_, f64>>;

    /// Whether the dataset carries target values.
    fn has_targets(&self) -> bool;

    /// Borrowed view of the targets, if any (`InMemoryDataset` returns `Ok(None)` when
    /// it has no targets).
    fn targets(&self) -> DatasetTraitResult<Option<ArrayView1<'_, f64>>>;

    /// Free-form string metadata, returned by value. The default is an empty map.
    fn metadata(&self) -> HashMap<String, String> {
        HashMap::new()
    }
}
64
/// Produces a synthetic [`Dataset`] from a configuration value.
pub trait DatasetGenerator {
    /// Configuration consumed by [`DatasetGenerator::generate`].
    type Config: Default + Clone;
    /// Dataset type produced.
    type Output: Dataset;

    /// Generates a dataset from `config` (taken by value).
    ///
    /// The trait does not require implementations to call `validate_config` first.
    fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output>;

    /// Short identifier; `GeneratorRegistry` uses it as the registration key.
    fn name(&self) -> &'static str;

    /// Human-readable description of what the generator produces.
    fn description(&self) -> &'static str;

    /// Checks `config` before generation. The default accepts every configuration.
    fn validate_config(&self, config: &Self::Config) -> DatasetTraitResult<()> {
        // Discard the parameter explicitly so the default body has no unused-variable warning.
        let _ = config;
        Ok(())
    }
}
85
86pub trait DatasetLoader {
88 type Config: Default + Clone;
89 type Output: Dataset;
90
91 fn load(&self, config: Self::Config) -> DatasetTraitResult<Self::Output>;
93
94 fn name(&self) -> &'static str;
96
97 fn available_datasets(&self) -> Vec<String>;
99
100 fn has_dataset(&self, name: &str) -> bool {
102 self.available_datasets().contains(&name.to_string())
103 }
104}
105
/// Converts one [`Dataset`] type into another.
pub trait DatasetTransformer {
    /// Configuration consumed by [`DatasetTransformer::transform`].
    type Config: Default + Clone;
    /// Dataset type accepted as input.
    type Input: Dataset;
    /// Dataset type produced.
    type Output: Dataset;

    /// Transforms `input` according to `config`; both are taken by value.
    fn transform(
        &self,
        input: Self::Input,
        config: Self::Config,
    ) -> DatasetTraitResult<Self::Output>;

    /// Short identifier for this transformer.
    fn name(&self) -> &'static str;

    /// Whether this transformer can be applied to `input`.
    ///
    /// NOTE(review): the trait does not require `transform` to consult this first;
    /// presumably callers should check it — confirm the intended contract.
    fn can_transform(&self, input: &Self::Input) -> bool;
}
125
/// Checks a dataset and produces a report.
pub trait DatasetValidator {
    /// Configuration consumed by [`DatasetValidator::validate`].
    type Config: Default + Clone;
    /// Report type produced by a validation run.
    type Report: Default;

    /// Validates `dataset`. Taking `&dyn Dataset` lets a single validator accept any
    /// dataset implementation.
    fn validate(
        &self,
        dataset: &dyn Dataset,
        config: Self::Config,
    ) -> DatasetTraitResult<Self::Report>;

    /// Short identifier for this validator.
    fn name(&self) -> &'static str;

    /// The criteria this validator checks (string format is up to the implementor).
    fn criteria(&self) -> Vec<String>;
}
144
/// A [`Dataset`] that can also be consumed in batches.
pub trait StreamingDataset: Dataset {
    /// Type of one batch; chosen by the implementor.
    type Batch;

    /// Returns the batch described by `start` and `size`.
    ///
    /// NOTE(review): presumably `size` samples starting at sample index `start`; behaviour
    /// near the end of the dataset (short batch vs. error) is not specified by the trait.
    fn batch(&self, start: usize, size: usize) -> DatasetTraitResult<Self::Batch>;

    /// Iterates over the dataset in batches of `batch_size`.
    ///
    /// The returned `Box<dyn Iterator ...>` has no explicit lifetime, so Rust gives the
    /// trait object an implicit `'static` bound: the iterator cannot borrow `self` and must
    /// own (or share, e.g. via `Arc`) the data it yields.
    fn batches(
        &self,
        batch_size: usize,
    ) -> Box<dyn Iterator<Item = DatasetTraitResult<Self::Batch>>>;

    /// Batch size suggested to callers; defaults to 1000.
    fn preferred_batch_size(&self) -> usize {
        1000
    }
}
163
/// A [`Dataset`] whose contents can be modified in place.
pub trait MutableDataset: Dataset {
    /// Overwrites the features of the sample at `index` with `sample`.
    fn set_sample(&mut self, index: usize, sample: ArrayView1<f64>) -> DatasetTraitResult<()>;

    /// Replaces all targets with `targets`.
    fn set_targets(&mut self, targets: ArrayView1<f64>) -> DatasetTraitResult<()>;

    /// Appends a sample with an optional target.
    ///
    /// Implementations may not support growing; `InMemoryDataset` currently returns
    /// `UnsupportedOperation` for a correctly sized sample.
    fn add_sample(
        &mut self,
        sample: ArrayView1<f64>,
        target: Option<f64>,
    ) -> DatasetTraitResult<()>;

    /// Removes the sample at `index`.
    ///
    /// `InMemoryDataset` currently returns `UnsupportedOperation` for an in-range index.
    fn remove_sample(&mut self, index: usize) -> DatasetTraitResult<()>;
}
182
/// A reusable adjustment that mutates a configuration using a random source.
pub trait GenerationStrategy {
    /// Configuration type this strategy modifies.
    type Config: Default + Clone;

    /// Mutates `config` in place, drawing any randomness from `rng`.
    fn apply(&self, config: &mut Self::Config, rng: &mut Random) -> DatasetTraitResult<()>;

    /// Short identifier for this strategy.
    fn name(&self) -> &'static str;

    /// Whether this strategy applies to `config`.
    fn is_applicable(&self, config: &Self::Config) -> bool;
}
196
/// A [`Dataset`] that owns its feature matrix, optional targets and metadata in memory.
#[derive(Debug, Clone)]
pub struct InMemoryDataset {
    // Feature matrix; the `Dataset` impl treats each row as one sample.
    features: Array2<f64>,
    // Optional target per sample. The constructors do not check that its length matches
    // the number of rows; `MutableDataset::set_targets` does.
    targets: Option<Array1<f64>>,
    // Free-form key/value metadata returned by `Dataset::metadata`.
    metadata: HashMap<String, String>,
}
204
205impl InMemoryDataset {
206 pub fn new(features: Array2<f64>, targets: Option<Array1<f64>>) -> Self {
208 Self {
209 features,
210 targets,
211 metadata: HashMap::new(),
212 }
213 }
214
215 pub fn with_metadata(
217 features: Array2<f64>,
218 targets: Option<Array1<f64>>,
219 metadata: HashMap<String, String>,
220 ) -> Self {
221 Self {
222 features,
223 targets,
224 metadata,
225 }
226 }
227
228 pub fn add_metadata(&mut self, key: String, value: String) {
230 self.metadata.insert(key, value);
231 }
232}
233
234impl Dataset for InMemoryDataset {
235 fn n_samples(&self) -> usize {
236 self.features.nrows()
237 }
238
239 fn n_features(&self) -> usize {
240 self.features.ncols()
241 }
242
243 fn features(&self) -> DatasetTraitResult<ArrayView2<'_, f64>> {
244 Ok(self.features.view())
245 }
246
247 fn sample(&self, index: usize) -> DatasetTraitResult<ArrayView1<'_, f64>> {
248 if index >= self.n_samples() {
249 return Err(DatasetTraitError::DimensionMismatch {
250 expected: format!("index < {}", self.n_samples()),
251 actual: format!("index = {}", index),
252 });
253 }
254 Ok(self.features.row(index))
255 }
256
257 fn has_targets(&self) -> bool {
258 self.targets.is_some()
259 }
260
261 fn targets(&self) -> DatasetTraitResult<Option<ArrayView1<'_, f64>>> {
262 Ok(self.targets.as_ref().map(|t| t.view()))
263 }
264
265 fn metadata(&self) -> HashMap<String, String> {
266 self.metadata.clone()
267 }
268}
269
270impl MutableDataset for InMemoryDataset {
271 fn set_sample(&mut self, index: usize, sample: ArrayView1<f64>) -> DatasetTraitResult<()> {
272 if index >= self.n_samples() {
273 return Err(DatasetTraitError::DimensionMismatch {
274 expected: format!("index < {}", self.n_samples()),
275 actual: format!("index = {}", index),
276 });
277 }
278 if sample.len() != self.n_features() {
279 return Err(DatasetTraitError::DimensionMismatch {
280 expected: format!("{} features", self.n_features()),
281 actual: format!("{} features", sample.len()),
282 });
283 }
284 self.features.row_mut(index).assign(&sample);
285 Ok(())
286 }
287
288 fn set_targets(&mut self, targets: ArrayView1<f64>) -> DatasetTraitResult<()> {
289 if targets.len() != self.n_samples() {
290 return Err(DatasetTraitError::DimensionMismatch {
291 expected: format!("{} targets", self.n_samples()),
292 actual: format!("{} targets", targets.len()),
293 });
294 }
295 self.targets = Some(targets.to_owned());
296 Ok(())
297 }
298
299 fn add_sample(
300 &mut self,
301 sample: ArrayView1<f64>,
302 _target: Option<f64>,
303 ) -> DatasetTraitResult<()> {
304 if sample.len() != self.n_features() {
305 return Err(DatasetTraitError::DimensionMismatch {
306 expected: format!("{} features", self.n_features()),
307 actual: format!("{} features", sample.len()),
308 });
309 }
310
311 Err(DatasetTraitError::UnsupportedOperation(
313 "Adding samples to fixed-size arrays not yet implemented".to_string(),
314 ))
315 }
316
317 fn remove_sample(&mut self, index: usize) -> DatasetTraitResult<()> {
318 if index >= self.n_samples() {
319 return Err(DatasetTraitError::DimensionMismatch {
320 expected: format!("index < {}", self.n_samples()),
321 actual: format!("index = {}", index),
322 });
323 }
324
325 Err(DatasetTraitError::UnsupportedOperation(
327 "Removing samples from fixed-size arrays not yet implemented".to_string(),
328 ))
329 }
330}
331
332pub struct GeneratorRegistry {
334 generators: HashMap<
335 String,
336 Box<dyn DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset>>,
337 >,
338}
339
340impl GeneratorRegistry {
341 pub fn new() -> Self {
343 Self {
344 generators: HashMap::new(),
345 }
346 }
347
348 pub fn register<G>(&mut self, generator: G)
350 where
351 G: DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset> + 'static,
352 {
353 self.generators
354 .insert(generator.name().to_string(), Box::new(generator));
355 }
356
357 pub fn get(
359 &self,
360 name: &str,
361 ) -> Option<&dyn DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset>> {
362 self.generators.get(name).map(|g| g.as_ref())
363 }
364
365 pub fn list(&self) -> Vec<String> {
367 self.generators.keys().cloned().collect()
368 }
369
370 pub fn generate(
372 &self,
373 name: &str,
374 config: GeneratorConfig,
375 ) -> DatasetTraitResult<InMemoryDataset> {
376 let generator = self.get(name).ok_or_else(|| {
377 DatasetTraitError::Configuration(format!("Unknown generator: {}", name))
378 })?;
379 generator.generate(config)
380 }
381}
382
383impl Default for GeneratorRegistry {
384 fn default() -> Self {
385 Self::new()
386 }
387}
388
/// Configuration shared by the generators stored in a `GeneratorRegistry`.
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) per sample.
    pub n_features: usize,
    /// Seed for the random source; the built-in generators in this module fall back to a
    /// fixed seed of 42 when this is `None`.
    pub random_state: Option<u64>,
    /// Generator-specific named parameters (e.g. `"n_classes"`, `"noise"`).
    pub parameters: HashMap<String, ConfigValue>,
}

impl Default for GeneratorConfig {
    /// 100 samples, 2 features, no seed, no parameters.
    fn default() -> Self {
        Self::new(100, 2)
    }
}

impl GeneratorConfig {
    /// Creates a config with the given dimensions, no seed and no extra parameters.
    pub fn new(n_samples: usize, n_features: usize) -> Self {
        let parameters = HashMap::new();
        GeneratorConfig {
            parameters,
            random_state: None,
            n_features,
            n_samples,
        }
    }

    /// Stores `value` under `key`, replacing any previous value for that key.
    pub fn set_parameter<T: Into<ConfigValue>>(&mut self, key: String, value: T) {
        let converted: ConfigValue = value.into();
        self.parameters.insert(key, converted);
    }

    /// Looks up the parameter stored under `key`.
    pub fn get_parameter(&self, key: &str) -> Option<&ConfigValue> {
        self.parameters.get(key)
    }

    /// Builder-style setter for `random_state`.
    pub fn with_random_state(self, seed: u64) -> Self {
        GeneratorConfig {
            random_state: Some(seed),
            ..self
        }
    }
}

/// A dynamically typed parameter value stored in `GeneratorConfig::parameters`.
#[derive(Debug, Clone)]
pub enum ConfigValue {
    /// Signed integer.
    Int(i64),
    /// Floating-point number.
    Float(f64),
    /// Text.
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// List of signed integers.
    IntArray(Vec<i64>),
    /// List of floating-point numbers.
    FloatArray(Vec<f64>),
}

// Generates one `From<source> for ConfigValue` impl per `source => Variant` pair, each
// wrapping the value in the matching variant unchanged.
macro_rules! impl_config_value_from {
    ($($source:ty => $variant:ident),* $(,)?) => {
        $(
            impl From<$source> for ConfigValue {
                fn from(value: $source) -> Self {
                    ConfigValue::$variant(value)
                }
            }
        )*
    };
}

impl_config_value_from! {
    i64 => Int,
    f64 => Float,
    String => String,
    bool => Bool,
    Vec<i64> => IntArray,
    Vec<f64> => FloatArray,
}
489
490pub struct ClassificationGenerator;
492
493impl DatasetGenerator for ClassificationGenerator {
494 type Config = GeneratorConfig;
495 type Output = InMemoryDataset;
496
497 fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output> {
498 let mut rng = match config.random_state {
499 Some(seed) => Random::seed(seed),
500 None => Random::seed(42),
501 };
502
503 let n_classes = config
505 .get_parameter("n_classes")
506 .and_then(|v| match v {
507 ConfigValue::Int(n) => Some(*n as usize),
508 _ => None,
509 })
510 .unwrap_or(2);
511
512 let mut features = Array2::<f64>::zeros((config.n_samples, config.n_features));
514 let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
515 for mut row in features.rows_mut() {
516 for val in row.iter_mut() {
517 *val = normal_dist.sample(&mut rng);
518 }
519 }
520
521 let targets: Array1<f64> =
523 Array1::from_shape_fn(config.n_samples, |_| rng.gen_range(0..n_classes) as f64);
524
525 let mut metadata = HashMap::new();
526 metadata.insert("generator".to_string(), "classification".to_string());
527 metadata.insert("n_classes".to_string(), n_classes.to_string());
528
529 Ok(InMemoryDataset::with_metadata(
530 features,
531 Some(targets),
532 metadata,
533 ))
534 }
535
536 fn name(&self) -> &'static str {
537 "classification"
538 }
539
540 fn description(&self) -> &'static str {
541 "Generates a classification dataset with Gaussian features"
542 }
543
544 fn validate_config(&self, config: &Self::Config) -> DatasetTraitResult<()> {
545 if config.n_samples == 0 {
546 return Err(DatasetTraitError::Configuration(
547 "n_samples must be > 0".to_string(),
548 ));
549 }
550 if config.n_features == 0 {
551 return Err(DatasetTraitError::Configuration(
552 "n_features must be > 0".to_string(),
553 ));
554 }
555
556 if let Some(ConfigValue::Int(n_classes)) = config.get_parameter("n_classes") {
558 if *n_classes <= 0 {
559 return Err(DatasetTraitError::Configuration(
560 "n_classes must be > 0".to_string(),
561 ));
562 }
563 }
564
565 Ok(())
566 }
567}
568
569pub struct RegressionGenerator;
571
572impl DatasetGenerator for RegressionGenerator {
573 type Config = GeneratorConfig;
574 type Output = InMemoryDataset;
575
576 fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output> {
577 let mut rng = match config.random_state {
578 Some(seed) => Random::seed(seed),
579 None => Random::seed(42),
580 };
581
582 let noise = config
584 .get_parameter("noise")
585 .and_then(|v| match v {
586 ConfigValue::Float(n) => Some(*n),
587 _ => None,
588 })
589 .unwrap_or(0.1);
590
591 let mut features = Array2::<f64>::zeros((config.n_samples, config.n_features));
593 let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
594 for mut row in features.rows_mut() {
595 for val in row.iter_mut() {
596 *val = normal_dist.sample(&mut rng);
597 }
598 }
599
600 let coefficients: Array1<f64> =
602 Array1::from_shape_fn(config.n_features, |_| rng.random_range(-1.0..1.0));
603
604 let mut targets = Array1::<f64>::zeros(config.n_samples);
606 for (i, target) in targets.iter_mut().enumerate() {
607 let feature_row = features.row(i);
608 let noise_dist = RandNormal::new(0.0, noise).unwrap();
609 *target = feature_row.dot(&coefficients) + noise_dist.sample(&mut rng);
610 }
611
612 let mut metadata = HashMap::new();
613 metadata.insert("generator".to_string(), "regression".to_string());
614 metadata.insert("noise".to_string(), noise.to_string());
615
616 Ok(InMemoryDataset::with_metadata(
617 features,
618 Some(targets),
619 metadata,
620 ))
621 }
622
623 fn name(&self) -> &'static str {
624 "regression"
625 }
626
627 fn description(&self) -> &'static str {
628 "Generates a regression dataset with linear relationship and noise"
629 }
630}
631
632pub fn create_default_registry() -> GeneratorRegistry {
634 let mut registry = GeneratorRegistry::new();
635 registry.register(ClassificationGenerator);
636 registry.register(RegressionGenerator);
637 registry
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a `rows x cols` matrix holding 0.0, 1.0, 2.0, ... in row-major order.
    fn counting_matrix(rows: usize, cols: usize) -> Array2<f64> {
        let values: Vec<f64> = (0..rows * cols).map(|x| x as f64).collect();
        Array::from_shape_vec((rows, cols), values).unwrap()
    }

    #[test]
    fn test_in_memory_dataset() {
        let targets = Array1::from_shape_vec(10, (0..10).map(|x| x as f64).collect()).unwrap();
        let dataset = InMemoryDataset::new(counting_matrix(10, 3), Some(targets));

        assert_eq!(dataset.n_samples(), 10);
        assert_eq!(dataset.n_features(), 3);
        assert_eq!(dataset.shape(), (10, 3));
        assert!(dataset.has_targets());
        assert_eq!(dataset.features().unwrap().dim(), (10, 3));

        // Row 5 starts at element 5 * 3 = 15.
        let row = dataset.sample(5).unwrap();
        assert_eq!(row.len(), 3);
        assert_eq!(row[0], 15.0);

        let stored_targets = dataset.targets().unwrap().unwrap();
        assert_eq!(stored_targets.len(), 10);
        assert_eq!(stored_targets[5], 5.0);
    }

    #[test]
    fn test_generator_registry() {
        let mut registry = GeneratorRegistry::new();
        registry.register(ClassificationGenerator);
        registry.register(RegressionGenerator);

        let names = registry.list();
        for expected in ["classification", "regression"] {
            assert!(names.iter().any(|name| name == expected));
        }

        let dataset = registry
            .generate("classification", GeneratorConfig::new(50, 4))
            .unwrap();
        assert_eq!(dataset.shape(), (50, 4));
        assert!(dataset.has_targets());
    }

    #[test]
    fn test_classification_generator() {
        let mut config = GeneratorConfig::new(100, 5);
        config.set_parameter("n_classes".to_string(), 3i64);
        config.random_state = Some(42);

        let dataset = ClassificationGenerator.generate(config).unwrap();
        assert_eq!(dataset.shape(), (100, 5));
        assert!(dataset.has_targets());

        let labels = dataset.targets().unwrap().unwrap();
        assert!(labels.iter().all(|&label| (0.0..3.0).contains(&label)));

        let metadata = dataset.metadata();
        assert_eq!(
            metadata.get("generator").map(String::as_str),
            Some("classification")
        );
        assert_eq!(metadata.get("n_classes").map(String::as_str), Some("3"));
    }

    #[test]
    fn test_regression_generator() {
        let mut config = GeneratorConfig::new(100, 3);
        config.set_parameter("noise".to_string(), 0.05);
        config.random_state = Some(42);

        let dataset = RegressionGenerator.generate(config).unwrap();
        assert_eq!(dataset.shape(), (100, 3));
        assert!(dataset.has_targets());

        let metadata = dataset.metadata();
        assert_eq!(
            metadata.get("generator").map(String::as_str),
            Some("regression")
        );
        assert_eq!(metadata.get("noise").map(String::as_str), Some("0.05"));
    }

    #[test]
    fn test_config_validation() {
        let generator = ClassificationGenerator;
        assert!(generator
            .validate_config(&GeneratorConfig::new(100, 5))
            .is_ok());

        // Zero samples or zero features must both be rejected.
        for (n_samples, n_features) in [(0, 5), (100, 0)] {
            let config = GeneratorConfig::new(n_samples, n_features);
            assert!(generator.validate_config(&config).is_err());
        }
    }

    #[test]
    fn test_config_parameters() {
        let mut config = GeneratorConfig::new(100, 5);
        config.set_parameter("n_classes".to_string(), 3i64);
        config.set_parameter("noise".to_string(), 0.1);
        config.set_parameter("seed".to_string(), "test".to_string());
        config.set_parameter("enabled".to_string(), true);

        assert!(matches!(
            config.get_parameter("n_classes"),
            Some(ConfigValue::Int(3))
        ));
        assert!(matches!(
            config.get_parameter("noise"),
            Some(ConfigValue::Float(x)) if *x == 0.1
        ));
        assert!(matches!(
            config.get_parameter("seed"),
            Some(ConfigValue::String(_))
        ));
        assert!(matches!(
            config.get_parameter("enabled"),
            Some(ConfigValue::Bool(true))
        ));
    }

    #[test]
    fn test_default_registry() {
        let names = create_default_registry().list();
        assert_eq!(names.len(), 2);
        assert!(names.iter().any(|name| name == "classification"));
        assert!(names.iter().any(|name| name == "regression"));
    }

    #[test]
    fn test_mutable_dataset() {
        let features = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from_shape_vec(3, vec![10.0, 20.0, 30.0]).unwrap();
        let mut dataset = InMemoryDataset::new(features, Some(targets));

        // Overwriting an existing row with a correctly sized sample succeeds.
        let replacement = Array1::from_vec(vec![99.0, 88.0]);
        assert!(dataset.set_sample(1, replacement.view()).is_ok());
        let row = dataset.sample(1).unwrap();
        assert_eq!((row[0], row[1]), (99.0, 88.0));

        // A sample of the wrong width is rejected.
        let too_wide = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(dataset.set_sample(0, too_wide.view()).is_err());

        // An out-of-range index is rejected even for a correctly sized sample.
        let in_shape = Array1::from_vec(vec![1.0, 2.0]);
        assert!(dataset.set_sample(10, in_shape.view()).is_err());
    }
}