// sklears_datasets/traits.rs

1//! Trait-based dataset framework for extensible dataset operations
2//!
3//! This module provides a comprehensive trait system for dataset generation,
4//! loading, transformation, and validation. It enables pluggable generators
5//! and composable generation strategies.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::random::{Distribution, RandNormal, Random};
9use std::collections::HashMap;
10use thiserror::Error;
11
12// Note: We inline the distribution sampling rather than using a helper function
13// because Random and Random<StdRng> are different types and hard to make generic
14
15/// Errors in the trait-based dataset framework
16#[derive(Error, Debug)]
17pub enum DatasetTraitError {
18    #[error("Generation error: {0}")]
19    Generation(String),
20    #[error("Validation error: {0}")]
21    Validation(String),
22    #[error("Configuration error: {0}")]
23    Configuration(String),
24    #[error("IO error: {0}")]
25    Io(String),
26    #[error("Dimension mismatch: expected {expected}, got {actual}")]
27    DimensionMismatch { expected: String, actual: String },
28    #[error("Unsupported operation: {0}")]
29    UnsupportedOperation(String),
30}
31
32pub type DatasetTraitResult<T> = Result<T, DatasetTraitError>;
33
34/// Core trait for any dataset representation
35pub trait Dataset {
36    /// Get the number of samples in the dataset
37    fn n_samples(&self) -> usize;
38
39    /// Get the number of features in the dataset
40    fn n_features(&self) -> usize;
41
42    /// Get the shape of the dataset as (n_samples, n_features)
43    fn shape(&self) -> (usize, usize) {
44        (self.n_samples(), self.n_features())
45    }
46
47    /// Get features as an array view
48    fn features(&self) -> DatasetTraitResult<ArrayView2<'_, f64>>;
49
50    /// Get a specific sample (row) by index
51    fn sample(&self, index: usize) -> DatasetTraitResult<ArrayView1<'_, f64>>;
52
53    /// Check if the dataset has target values
54    fn has_targets(&self) -> bool;
55
56    /// Get targets as an array view (if available)
57    fn targets(&self) -> DatasetTraitResult<Option<ArrayView1<'_, f64>>>;
58
59    /// Get dataset metadata
60    fn metadata(&self) -> HashMap<String, String> {
61        HashMap::new()
62    }
63}
64
65/// Trait for datasets that can be generated
66pub trait DatasetGenerator {
67    type Config: Default + Clone;
68    type Output: Dataset;
69
70    /// Generate a dataset with the given configuration
71    fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output>;
72
73    /// Get the name of this generator
74    fn name(&self) -> &'static str;
75
76    /// Get a description of what this generator produces
77    fn description(&self) -> &'static str;
78
79    /// Validate the configuration before generation
80    fn validate_config(&self, config: &Self::Config) -> DatasetTraitResult<()> {
81        let _ = config;
82        Ok(())
83    }
84}
85
86/// Trait for loading datasets from external sources
87pub trait DatasetLoader {
88    type Config: Default + Clone;
89    type Output: Dataset;
90
91    /// Load a dataset with the given configuration
92    fn load(&self, config: Self::Config) -> DatasetTraitResult<Self::Output>;
93
94    /// Get the name of this loader
95    fn name(&self) -> &'static str;
96
97    /// Get available datasets that this loader can handle
98    fn available_datasets(&self) -> Vec<String>;
99
100    /// Check if a dataset is available
101    fn has_dataset(&self, name: &str) -> bool {
102        self.available_datasets().contains(&name.to_string())
103    }
104}
105
106/// Trait for transforming datasets
107pub trait DatasetTransformer {
108    type Config: Default + Clone;
109    type Input: Dataset;
110    type Output: Dataset;
111
112    /// Transform a dataset
113    fn transform(
114        &self,
115        input: Self::Input,
116        config: Self::Config,
117    ) -> DatasetTraitResult<Self::Output>;
118
119    /// Get the name of this transformer
120    fn name(&self) -> &'static str;
121
122    /// Check if this transformer can handle the given input
123    fn can_transform(&self, input: &Self::Input) -> bool;
124}
125
126/// Trait for validating datasets
127pub trait DatasetValidator {
128    type Config: Default + Clone;
129    type Report: Default;
130
131    /// Validate a dataset and return a validation report
132    fn validate(
133        &self,
134        dataset: &dyn Dataset,
135        config: Self::Config,
136    ) -> DatasetTraitResult<Self::Report>;
137
138    /// Get the name of this validator
139    fn name(&self) -> &'static str;
140
141    /// Get validation criteria
142    fn criteria(&self) -> Vec<String>;
143}
144
145/// Trait for streaming dataset access
146pub trait StreamingDataset: Dataset {
147    type Batch;
148
149    /// Get a batch starting at the given index with the specified size
150    fn batch(&self, start: usize, size: usize) -> DatasetTraitResult<Self::Batch>;
151
152    /// Create an iterator over batches
153    fn batches(
154        &self,
155        batch_size: usize,
156    ) -> Box<dyn Iterator<Item = DatasetTraitResult<Self::Batch>>>;
157
158    /// Get the preferred batch size for this dataset
159    fn preferred_batch_size(&self) -> usize {
160        1000
161    }
162}
163
164/// Trait for mutable datasets
165pub trait MutableDataset: Dataset {
166    /// Set a specific sample
167    fn set_sample(&mut self, index: usize, sample: ArrayView1<f64>) -> DatasetTraitResult<()>;
168
169    /// Set targets (if supported)
170    fn set_targets(&mut self, targets: ArrayView1<f64>) -> DatasetTraitResult<()>;
171
172    /// Add a new sample to the dataset
173    fn add_sample(
174        &mut self,
175        sample: ArrayView1<f64>,
176        target: Option<f64>,
177    ) -> DatasetTraitResult<()>;
178
179    /// Remove a sample from the dataset
180    fn remove_sample(&mut self, index: usize) -> DatasetTraitResult<()>;
181}
182
183/// Configuration for composable generation strategies
184pub trait GenerationStrategy {
185    type Config: Default + Clone;
186
187    /// Apply this strategy to modify generation parameters
188    fn apply(&self, config: &mut Self::Config, rng: &mut Random) -> DatasetTraitResult<()>;
189
190    /// Get the name of this strategy
191    fn name(&self) -> &'static str;
192
193    /// Check if this strategy is applicable to the given configuration
194    fn is_applicable(&self, config: &Self::Config) -> bool;
195}
196
197/// A concrete implementation of Dataset trait for in-memory datasets
198#[derive(Debug, Clone)]
199pub struct InMemoryDataset {
200    features: Array2<f64>,
201    targets: Option<Array1<f64>>,
202    metadata: HashMap<String, String>,
203}
204
205impl InMemoryDataset {
206    /// Create a new in-memory dataset
207    pub fn new(features: Array2<f64>, targets: Option<Array1<f64>>) -> Self {
208        Self {
209            features,
210            targets,
211            metadata: HashMap::new(),
212        }
213    }
214
215    /// Create with metadata
216    pub fn with_metadata(
217        features: Array2<f64>,
218        targets: Option<Array1<f64>>,
219        metadata: HashMap<String, String>,
220    ) -> Self {
221        Self {
222            features,
223            targets,
224            metadata,
225        }
226    }
227
228    /// Add metadata entry
229    pub fn add_metadata(&mut self, key: String, value: String) {
230        self.metadata.insert(key, value);
231    }
232}
233
234impl Dataset for InMemoryDataset {
235    fn n_samples(&self) -> usize {
236        self.features.nrows()
237    }
238
239    fn n_features(&self) -> usize {
240        self.features.ncols()
241    }
242
243    fn features(&self) -> DatasetTraitResult<ArrayView2<'_, f64>> {
244        Ok(self.features.view())
245    }
246
247    fn sample(&self, index: usize) -> DatasetTraitResult<ArrayView1<'_, f64>> {
248        if index >= self.n_samples() {
249            return Err(DatasetTraitError::DimensionMismatch {
250                expected: format!("index < {}", self.n_samples()),
251                actual: format!("index = {}", index),
252            });
253        }
254        Ok(self.features.row(index))
255    }
256
257    fn has_targets(&self) -> bool {
258        self.targets.is_some()
259    }
260
261    fn targets(&self) -> DatasetTraitResult<Option<ArrayView1<'_, f64>>> {
262        Ok(self.targets.as_ref().map(|t| t.view()))
263    }
264
265    fn metadata(&self) -> HashMap<String, String> {
266        self.metadata.clone()
267    }
268}
269
270impl MutableDataset for InMemoryDataset {
271    fn set_sample(&mut self, index: usize, sample: ArrayView1<f64>) -> DatasetTraitResult<()> {
272        if index >= self.n_samples() {
273            return Err(DatasetTraitError::DimensionMismatch {
274                expected: format!("index < {}", self.n_samples()),
275                actual: format!("index = {}", index),
276            });
277        }
278        if sample.len() != self.n_features() {
279            return Err(DatasetTraitError::DimensionMismatch {
280                expected: format!("{} features", self.n_features()),
281                actual: format!("{} features", sample.len()),
282            });
283        }
284        self.features.row_mut(index).assign(&sample);
285        Ok(())
286    }
287
288    fn set_targets(&mut self, targets: ArrayView1<f64>) -> DatasetTraitResult<()> {
289        if targets.len() != self.n_samples() {
290            return Err(DatasetTraitError::DimensionMismatch {
291                expected: format!("{} targets", self.n_samples()),
292                actual: format!("{} targets", targets.len()),
293            });
294        }
295        self.targets = Some(targets.to_owned());
296        Ok(())
297    }
298
299    fn add_sample(
300        &mut self,
301        sample: ArrayView1<f64>,
302        _target: Option<f64>,
303    ) -> DatasetTraitResult<()> {
304        if sample.len() != self.n_features() {
305            return Err(DatasetTraitError::DimensionMismatch {
306                expected: format!("{} features", self.n_features()),
307                actual: format!("{} features", sample.len()),
308            });
309        }
310
311        // This is a simplified implementation - in practice, you'd need to resize the arrays
312        Err(DatasetTraitError::UnsupportedOperation(
313            "Adding samples to fixed-size arrays not yet implemented".to_string(),
314        ))
315    }
316
317    fn remove_sample(&mut self, index: usize) -> DatasetTraitResult<()> {
318        if index >= self.n_samples() {
319            return Err(DatasetTraitError::DimensionMismatch {
320                expected: format!("index < {}", self.n_samples()),
321                actual: format!("index = {}", index),
322            });
323        }
324
325        // This is a simplified implementation - in practice, you'd need to resize the arrays
326        Err(DatasetTraitError::UnsupportedOperation(
327            "Removing samples from fixed-size arrays not yet implemented".to_string(),
328        ))
329    }
330}
331
332/// Registry for dataset generators
333pub struct GeneratorRegistry {
334    generators: HashMap<
335        String,
336        Box<dyn DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset>>,
337    >,
338}
339
340impl GeneratorRegistry {
341    /// Create a new registry
342    pub fn new() -> Self {
343        Self {
344            generators: HashMap::new(),
345        }
346    }
347
348    /// Register a generator
349    pub fn register<G>(&mut self, generator: G)
350    where
351        G: DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset> + 'static,
352    {
353        self.generators
354            .insert(generator.name().to_string(), Box::new(generator));
355    }
356
357    /// Get a generator by name
358    pub fn get(
359        &self,
360        name: &str,
361    ) -> Option<&dyn DatasetGenerator<Config = GeneratorConfig, Output = InMemoryDataset>> {
362        self.generators.get(name).map(|g| g.as_ref())
363    }
364
365    /// List all available generators
366    pub fn list(&self) -> Vec<String> {
367        self.generators.keys().cloned().collect()
368    }
369
370    /// Generate a dataset using a named generator
371    pub fn generate(
372        &self,
373        name: &str,
374        config: GeneratorConfig,
375    ) -> DatasetTraitResult<InMemoryDataset> {
376        let generator = self.get(name).ok_or_else(|| {
377            DatasetTraitError::Configuration(format!("Unknown generator: {}", name))
378        })?;
379        generator.generate(config)
380    }
381}
382
383impl Default for GeneratorRegistry {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
/// Configuration shared by the generators in this module: dataset
/// dimensions, an optional seed, and free-form named parameters.
#[derive(Debug, Clone)]
pub struct GeneratorConfig {
    /// Number of samples (rows) to generate.
    pub n_samples: usize,
    /// Number of features (columns) to generate.
    pub n_features: usize,
    /// Seed for reproducible output; `None` leaves seeding to the generator.
    pub random_state: Option<u64>,
    /// Generator-specific parameters such as `"n_classes"` or `"noise"`.
    pub parameters: HashMap<String, ConfigValue>,
}

impl Default for GeneratorConfig {
    /// 100 samples, 2 features, no seed, no parameters.
    fn default() -> Self {
        Self::new(100, 2)
    }
}

impl GeneratorConfig {
    /// Create a configuration with the given dimensions, no seed and no
    /// parameters.
    pub fn new(n_samples: usize, n_features: usize) -> Self {
        Self {
            n_samples,
            n_features,
            random_state: None,
            parameters: HashMap::new(),
        }
    }

    /// Insert or overwrite the parameter `key`.
    ///
    /// `key` accepts `&str` as well as `String`; `value` accepts anything with
    /// a `From` conversion into [`ConfigValue`].
    pub fn set_parameter<T: Into<ConfigValue>>(&mut self, key: impl Into<String>, value: T) {
        self.parameters.insert(key.into(), value.into());
    }

    /// Look up the parameter `key`.
    pub fn get_parameter(&self, key: &str) -> Option<&ConfigValue> {
        self.parameters.get(key)
    }

    /// Builder-style setter for `random_state`.
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }
}

/// A dynamically typed parameter value stored in `GeneratorConfig::parameters`.
#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValue {
    /// Signed integer.
    Int(i64),
    /// Floating-point number.
    Float(f64),
    /// Text.
    String(String),
    /// Boolean flag.
    Bool(bool),
    /// List of signed integers.
    IntArray(Vec<i64>),
    /// List of floating-point numbers.
    FloatArray(Vec<f64>),
}

impl From<i64> for ConfigValue {
    fn from(value: i64) -> Self {
        ConfigValue::Int(value)
    }
}

impl From<f64> for ConfigValue {
    fn from(value: f64) -> Self {
        ConfigValue::Float(value)
    }
}

impl From<String> for ConfigValue {
    fn from(value: String) -> Self {
        ConfigValue::String(value)
    }
}

/// Lets string literals be passed directly, e.g. `set_parameter("mode", "fast")`.
impl From<&str> for ConfigValue {
    fn from(value: &str) -> Self {
        ConfigValue::String(value.to_owned())
    }
}

impl From<bool> for ConfigValue {
    fn from(value: bool) -> Self {
        ConfigValue::Bool(value)
    }
}

impl From<Vec<i64>> for ConfigValue {
    fn from(value: Vec<i64>) -> Self {
        ConfigValue::IntArray(value)
    }
}

impl From<Vec<f64>> for ConfigValue {
    fn from(value: Vec<f64>) -> Self {
        ConfigValue::FloatArray(value)
    }
}
489
490/// Example implementation: Classification generator
491pub struct ClassificationGenerator;
492
493impl DatasetGenerator for ClassificationGenerator {
494    type Config = GeneratorConfig;
495    type Output = InMemoryDataset;
496
497    fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output> {
498        let mut rng = match config.random_state {
499            Some(seed) => Random::seed(seed),
500            None => Random::seed(42),
501        };
502
503        // Get number of classes from parameters
504        let n_classes = config
505            .get_parameter("n_classes")
506            .and_then(|v| match v {
507                ConfigValue::Int(n) => Some(*n as usize),
508                _ => None,
509            })
510            .unwrap_or(2);
511
512        // Generate features
513        let mut features = Array2::<f64>::zeros((config.n_samples, config.n_features));
514        let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
515        for mut row in features.rows_mut() {
516            for val in row.iter_mut() {
517                *val = normal_dist.sample(&mut rng);
518            }
519        }
520
521        // Generate targets
522        let targets: Array1<f64> =
523            Array1::from_shape_fn(config.n_samples, |_| rng.gen_range(0..n_classes) as f64);
524
525        let mut metadata = HashMap::new();
526        metadata.insert("generator".to_string(), "classification".to_string());
527        metadata.insert("n_classes".to_string(), n_classes.to_string());
528
529        Ok(InMemoryDataset::with_metadata(
530            features,
531            Some(targets),
532            metadata,
533        ))
534    }
535
536    fn name(&self) -> &'static str {
537        "classification"
538    }
539
540    fn description(&self) -> &'static str {
541        "Generates a classification dataset with Gaussian features"
542    }
543
544    fn validate_config(&self, config: &Self::Config) -> DatasetTraitResult<()> {
545        if config.n_samples == 0 {
546            return Err(DatasetTraitError::Configuration(
547                "n_samples must be > 0".to_string(),
548            ));
549        }
550        if config.n_features == 0 {
551            return Err(DatasetTraitError::Configuration(
552                "n_features must be > 0".to_string(),
553            ));
554        }
555
556        // Validate n_classes parameter
557        if let Some(ConfigValue::Int(n_classes)) = config.get_parameter("n_classes") {
558            if *n_classes <= 0 {
559                return Err(DatasetTraitError::Configuration(
560                    "n_classes must be > 0".to_string(),
561                ));
562            }
563        }
564
565        Ok(())
566    }
567}
568
569/// Example implementation: Regression generator
570pub struct RegressionGenerator;
571
572impl DatasetGenerator for RegressionGenerator {
573    type Config = GeneratorConfig;
574    type Output = InMemoryDataset;
575
576    fn generate(&self, config: Self::Config) -> DatasetTraitResult<Self::Output> {
577        let mut rng = match config.random_state {
578            Some(seed) => Random::seed(seed),
579            None => Random::seed(42),
580        };
581
582        // Get noise level from parameters
583        let noise = config
584            .get_parameter("noise")
585            .and_then(|v| match v {
586                ConfigValue::Float(n) => Some(*n),
587                _ => None,
588            })
589            .unwrap_or(0.1);
590
591        // Generate features
592        let mut features = Array2::<f64>::zeros((config.n_samples, config.n_features));
593        let normal_dist = RandNormal::new(0.0, 1.0).unwrap();
594        for mut row in features.rows_mut() {
595            for val in row.iter_mut() {
596                *val = normal_dist.sample(&mut rng);
597            }
598        }
599
600        // Generate random coefficients
601        let coefficients: Array1<f64> =
602            Array1::from_shape_fn(config.n_features, |_| rng.random_range(-1.0..1.0));
603
604        // Generate targets
605        let mut targets = Array1::<f64>::zeros(config.n_samples);
606        for (i, target) in targets.iter_mut().enumerate() {
607            let feature_row = features.row(i);
608            let noise_dist = RandNormal::new(0.0, noise).unwrap();
609            *target = feature_row.dot(&coefficients) + noise_dist.sample(&mut rng);
610        }
611
612        let mut metadata = HashMap::new();
613        metadata.insert("generator".to_string(), "regression".to_string());
614        metadata.insert("noise".to_string(), noise.to_string());
615
616        Ok(InMemoryDataset::with_metadata(
617            features,
618            Some(targets),
619            metadata,
620        ))
621    }
622
623    fn name(&self) -> &'static str {
624        "regression"
625    }
626
627    fn description(&self) -> &'static str {
628        "Generates a regression dataset with linear relationship and noise"
629    }
630}
631
632/// Factory function to create a default registry with standard generators
633pub fn create_default_registry() -> GeneratorRegistry {
634    let mut registry = GeneratorRegistry::new();
635    registry.register(ClassificationGenerator);
636    registry.register(RegressionGenerator);
637    registry
638}
639
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_in_memory_dataset() {
        let features = Array::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
        let targets = Array1::from_shape_vec(10, (0..10).map(|x| x as f64).collect()).unwrap();

        let dataset = InMemoryDataset::new(features, Some(targets));

        assert_eq!(dataset.n_samples(), 10);
        assert_eq!(dataset.n_features(), 3);
        assert_eq!(dataset.shape(), (10, 3));
        assert!(dataset.has_targets());

        let features_view = dataset.features().unwrap();
        assert_eq!(features_view.dim(), (10, 3));

        let sample = dataset.sample(5).unwrap();
        assert_eq!(sample.len(), 3);
        assert_eq!(sample[0], 15.0); // 5 * 3 + 0

        let targets_view = dataset.targets().unwrap().unwrap();
        assert_eq!(targets_view.len(), 10);
        assert_eq!(targets_view[5], 5.0);
    }

    #[test]
    fn test_sample_out_of_bounds() {
        let features = Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let dataset = InMemoryDataset::new(features, None);

        assert!(!dataset.has_targets());
        assert!(dataset.targets().unwrap().is_none());
        assert!(matches!(
            dataset.sample(2),
            Err(DatasetTraitError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_set_targets_length_checked() {
        let features = Array::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let mut dataset = InMemoryDataset::new(features, None);

        let too_short = Array1::from_vec(vec![1.0, 2.0]);
        assert!(dataset.set_targets(too_short.view()).is_err());
        assert!(!dataset.has_targets());

        let targets = Array1::from_vec(vec![7.0, 8.0, 9.0]);
        assert!(dataset.set_targets(targets.view()).is_ok());
        assert_eq!(dataset.targets().unwrap().unwrap()[2], 9.0);
    }

    #[test]
    fn test_generator_registry() {
        let mut registry = GeneratorRegistry::new();
        registry.register(ClassificationGenerator);
        registry.register(RegressionGenerator);

        let generators = registry.list();
        assert!(generators.contains(&"classification".to_string()));
        assert!(generators.contains(&"regression".to_string()));

        let config = GeneratorConfig::new(50, 4);
        let dataset = registry.generate("classification", config).unwrap();

        assert_eq!(dataset.n_samples(), 50);
        assert_eq!(dataset.n_features(), 4);
        assert!(dataset.has_targets());
    }

    #[test]
    fn test_registry_unknown_generator() {
        let registry = create_default_registry();
        let result = registry.generate("does-not-exist", GeneratorConfig::default());
        assert!(matches!(result, Err(DatasetTraitError::Configuration(_))));
    }

    #[test]
    fn test_classification_generator() {
        let generator = ClassificationGenerator;
        let mut config = GeneratorConfig::new(100, 5);
        config.set_parameter("n_classes".to_string(), 3i64);
        config.random_state = Some(42);

        let dataset = generator.generate(config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 5);
        assert!(dataset.has_targets());

        let targets = dataset.targets().unwrap().unwrap();
        assert!(targets.iter().all(|&t| t >= 0.0 && t < 3.0));

        let metadata = dataset.metadata();
        assert_eq!(
            metadata.get("generator"),
            Some(&"classification".to_string())
        );
        assert_eq!(metadata.get("n_classes"), Some(&"3".to_string()));
    }

    #[test]
    fn test_seeded_generation_is_reproducible() {
        let generator = ClassificationGenerator;
        let config = GeneratorConfig::new(20, 3).with_random_state(7);

        let first = generator.generate(config.clone()).unwrap();
        let second = generator.generate(config).unwrap();

        assert_eq!(first.features().unwrap(), second.features().unwrap());
        assert_eq!(first.targets().unwrap(), second.targets().unwrap());
    }

    #[test]
    fn test_regression_generator() {
        let generator = RegressionGenerator;
        let mut config = GeneratorConfig::new(100, 3);
        config.set_parameter("noise".to_string(), 0.05);
        config.random_state = Some(42);

        let dataset = generator.generate(config).unwrap();

        assert_eq!(dataset.n_samples(), 100);
        assert_eq!(dataset.n_features(), 3);
        assert!(dataset.has_targets());

        let metadata = dataset.metadata();
        assert_eq!(metadata.get("generator"), Some(&"regression".to_string()));
        assert_eq!(metadata.get("noise"), Some(&"0.05".to_string()));
    }

    #[test]
    fn test_config_validation() {
        let generator = ClassificationGenerator;

        // Valid config
        let valid_config = GeneratorConfig::new(100, 5);
        assert!(generator.validate_config(&valid_config).is_ok());

        // Invalid configs
        let invalid_config = GeneratorConfig::new(0, 5);
        assert!(generator.validate_config(&invalid_config).is_err());

        let invalid_config = GeneratorConfig::new(100, 0);
        assert!(generator.validate_config(&invalid_config).is_err());
    }

    #[test]
    fn test_config_parameters() {
        let mut config = GeneratorConfig::new(100, 5);

        config.set_parameter("n_classes".to_string(), 3i64);
        config.set_parameter("noise".to_string(), 0.1);
        config.set_parameter("seed".to_string(), "test".to_string());
        config.set_parameter("enabled".to_string(), true);

        assert!(matches!(
            config.get_parameter("n_classes"),
            Some(ConfigValue::Int(3))
        ));
        assert!(matches!(
            config.get_parameter("noise"),
            Some(ConfigValue::Float(0.1))
        ));
        assert!(matches!(
            config.get_parameter("seed"),
            Some(ConfigValue::String(_))
        ));
        assert!(matches!(
            config.get_parameter("enabled"),
            Some(ConfigValue::Bool(true))
        ));
    }

    #[test]
    fn test_default_registry() {
        let registry = create_default_registry();
        let generators = registry.list();

        assert!(generators.contains(&"classification".to_string()));
        assert!(generators.contains(&"regression".to_string()));
        assert_eq!(generators.len(), 2);
    }

    #[test]
    fn test_mutable_dataset() {
        let features = Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let targets = Array1::from_shape_vec(3, vec![10.0, 20.0, 30.0]).unwrap();

        let mut dataset = InMemoryDataset::new(features, Some(targets));

        // Test setting a sample
        let new_sample = Array1::from_vec(vec![99.0, 88.0]);
        assert!(dataset.set_sample(1, new_sample.view()).is_ok());

        let updated_sample = dataset.sample(1).unwrap();
        assert_eq!(updated_sample[0], 99.0);
        assert_eq!(updated_sample[1], 88.0);

        // Test dimension mismatch
        let wrong_sample = Array1::from_vec(vec![1.0, 2.0, 3.0]); // Wrong size
        assert!(dataset.set_sample(0, wrong_sample.view()).is_err());

        // Test index out of bounds
        let sample = Array1::from_vec(vec![1.0, 2.0]);
        assert!(dataset.set_sample(10, sample.view()).is_err());
    }
}
806}