scirs2_datasets/utils/
dataset.rs

//! Core `Dataset` structure and basic methods
//!
//! This module provides the main [`Dataset`] struct used throughout the datasets
//! crate, along with its core methods for creation, metadata management, and
//! basic shape/property queries.
6
7use crate::utils::serialization;
8use ndarray::{Array1, Array2};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11
12/// Represents a dataset with features, optional targets, and metadata
13///
14/// The Dataset struct is the core data structure for managing machine learning
15/// datasets. It stores the feature matrix, optional target values, and rich
16/// metadata including feature names, descriptions, and arbitrary key-value pairs.
17///
18/// # Examples
19///
20/// ```rust
21/// use ndarray::Array2;
22/// use scirs2_datasets::utils::Dataset;
23///
24/// let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
25/// let dataset = Dataset::new(data, None)
26///     .with_feature_names(vec!["feature1".to_string(), "feature2".to_string()])
27///     .with_description("Sample dataset".to_string());
28///
29/// assert_eq!(dataset.n_samples(), 3);
30/// assert_eq!(dataset.n_features(), 2);
31/// ```
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct Dataset {
34    /// Features/data matrix (n_samples, n_features)
35    #[serde(
36        serialize_with = "serialization::serialize_array2",
37        deserialize_with = "serialization::deserialize_array2"
38    )]
39    pub data: Array2<f64>,
40
41    /// Optional target values
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub target: Option<Array1<f64>>,
44
45    /// Optional target names for classification problems
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub target_names: Option<Vec<String>>,
48
49    /// Optional feature names
50    #[serde(skip_serializing_if = "Option::is_none")]
51    pub feature_names: Option<Vec<String>>,
52
53    /// Optional descriptions for each feature
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub feature_descriptions: Option<Vec<String>>,
56
57    /// Optional dataset description
58    #[serde(skip_serializing_if = "Option::is_none")]
59    pub description: Option<String>,
60
61    /// Optional dataset metadata
62    pub metadata: HashMap<String, String>,
63}
64
65impl Dataset {
66    /// Create a new dataset with the given data and target
67    ///
68    /// # Arguments
69    ///
70    /// * `data` - The feature matrix (n_samples, n_features)
71    /// * `target` - Optional target values (n_samples,)
72    ///
73    /// # Returns
74    ///
75    /// A new Dataset instance with empty metadata
76    ///
77    /// # Examples
78    ///
79    /// ```rust
80    /// use ndarray::{Array1, Array2};
81    /// use scirs2_datasets::utils::Dataset;
82    ///
83    /// let data = Array2::zeros((100, 5));
84    /// let target = Some(Array1::zeros(100));
85    /// let dataset = Dataset::new(data, target);
86    /// ```
87    pub fn new(data: Array2<f64>, target: Option<Array1<f64>>) -> Self {
88        Dataset {
89            data,
90            target,
91            target_names: None,
92            feature_names: None,
93            feature_descriptions: None,
94            description: None,
95            metadata: HashMap::new(),
96        }
97    }
98
99    /// Add target names to the dataset (builder pattern)
100    ///
101    /// # Arguments
102    ///
103    /// * `target_names` - Vector of target class names
104    ///
105    /// # Returns
106    ///
107    /// Self for method chaining
108    pub fn with_target_names(mut self, target_names: Vec<String>) -> Self {
109        self.target_names = Some(target_names);
110        self
111    }
112
113    /// Add feature names to the dataset (builder pattern)
114    ///
115    /// # Arguments
116    ///
117    /// * `feature_names` - Vector of feature names
118    ///
119    /// # Returns
120    ///
121    /// Self for method chaining
122    pub fn with_feature_names(mut self, feature_names: Vec<String>) -> Self {
123        self.feature_names = Some(feature_names);
124        self
125    }
126
127    /// Add feature descriptions to the dataset (builder pattern)
128    ///
129    /// # Arguments
130    ///
131    /// * `feature_descriptions` - Vector of feature descriptions
132    ///
133    /// # Returns
134    ///
135    /// Self for method chaining
136    pub fn with_feature_descriptions(mut self, feature_descriptions: Vec<String>) -> Self {
137        self.feature_descriptions = Some(feature_descriptions);
138        self
139    }
140
141    /// Add a description to the dataset (builder pattern)
142    ///
143    /// # Arguments
144    ///
145    /// * `description` - Dataset description
146    ///
147    /// # Returns
148    ///
149    /// Self for method chaining
150    pub fn with_description(mut self, description: String) -> Self {
151        self.description = Some(description);
152        self
153    }
154
155    /// Add metadata to the dataset (builder pattern)
156    ///
157    /// # Arguments
158    ///
159    /// * `key` - Metadata key
160    /// * `value` - Metadata value
161    ///
162    /// # Returns
163    ///
164    /// Self for method chaining
165    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
166        self.metadata.insert(key.to_string(), value.to_string());
167        self
168    }
169
170    /// Get the number of samples in the dataset
171    ///
172    /// # Returns
173    ///
174    /// Number of samples (rows) in the dataset
175    pub fn n_samples(&self) -> usize {
176        self.data.nrows()
177    }
178
179    /// Get the number of features in the dataset
180    ///
181    /// # Returns
182    ///
183    /// Number of features (columns) in the dataset
184    pub fn n_features(&self) -> usize {
185        self.data.ncols()
186    }
187
188    /// Get dataset shape as (n_samples, n_features)
189    ///
190    /// # Returns
191    ///
192    /// Tuple of (n_samples, n_features)
193    pub fn shape(&self) -> (usize, usize) {
194        (self.n_samples(), self.n_features())
195    }
196
197    /// Check if the dataset has target values
198    ///
199    /// # Returns
200    ///
201    /// True if target values are present, false otherwise
202    pub fn has_target(&self) -> bool {
203        self.target.is_some()
204    }
205
206    /// Get a reference to the feature names if available
207    ///
208    /// # Returns
209    ///
210    /// Optional reference to feature names vector
211    pub fn feature_names(&self) -> Option<&Vec<String>> {
212        self.feature_names.as_ref()
213    }
214
215    /// Get a reference to the target names if available
216    ///
217    /// # Returns
218    ///
219    /// Optional reference to target names vector  
220    pub fn target_names(&self) -> Option<&Vec<String>> {
221        self.target_names.as_ref()
222    }
223
224    /// Get a reference to the dataset description if available
225    ///
226    /// # Returns
227    ///
228    /// Optional reference to dataset description
229    pub fn description(&self) -> Option<&String> {
230        self.description.as_ref()
231    }
232
233    /// Get a reference to the metadata
234    ///
235    /// # Returns
236    ///
237    /// Reference to metadata HashMap
238    pub fn metadata(&self) -> &HashMap<String, String> {
239        &self.metadata
240    }
241
242    /// Add or update a metadata entry
243    ///
244    /// # Arguments
245    ///
246    /// * `key` - Metadata key
247    /// * `value` - Metadata value
248    pub fn set_metadata(&mut self, key: &str, value: &str) {
249        self.metadata.insert(key.to_string(), value.to_string());
250    }
251
252    /// Get a metadata value by key
253    ///
254    /// # Arguments
255    ///
256    /// * `key` - Metadata key to lookup
257    ///
258    /// # Returns
259    ///
260    /// Optional reference to the metadata value
261    pub fn get_metadata(&self, key: &str) -> Option<&String> {
262        self.metadata.get(key)
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_dataset_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let target = Some(array![0.0, 1.0, 0.0]);

        let dataset = Dataset::new(data.clone(), target.clone());

        assert_eq!(dataset.n_samples(), 3);
        assert_eq!(dataset.n_features(), 2);
        assert_eq!(dataset.shape(), (3, 2));
        assert!(dataset.has_target());
        assert_eq!(dataset.data, data);
        assert_eq!(dataset.target, target);
    }

    #[test]
    fn test_dataset_builder_pattern() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let dataset = Dataset::new(data, None)
            .with_feature_names(vec!["feat1".to_string(), "feat2".to_string()])
            .with_description("Test dataset".to_string())
            .with_metadata("version", "1.0")
            .with_metadata("author", "test");

        assert_eq!(dataset.feature_names().unwrap().len(), 2);
        assert_eq!(dataset.description().unwrap(), "Test dataset");
        assert_eq!(dataset.get_metadata("version").unwrap(), "1.0");
        assert_eq!(dataset.get_metadata("author").unwrap(), "test");
    }

    #[test]
    fn test_dataset_without_target() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let dataset = Dataset::new(data, None);

        assert!(!dataset.has_target());
        assert!(dataset.target.is_none());
    }

    #[test]
    fn test_metadata_operations() {
        let data = array![[1.0, 2.0]];
        let mut dataset = Dataset::new(data, None);

        dataset.set_metadata("key1", "value1");
        dataset.set_metadata("key2", "value2");

        assert_eq!(dataset.get_metadata("key1").unwrap(), "value1");
        assert_eq!(dataset.get_metadata("key2").unwrap(), "value2");
        assert!(dataset.get_metadata("nonexistent").is_none());

        // Update existing key
        dataset.set_metadata("key1", "updated_value");
        assert_eq!(dataset.get_metadata("key1").unwrap(), "updated_value");
    }

    /// A freshly constructed dataset carries no names, descriptions, or metadata.
    #[test]
    fn test_new_dataset_has_empty_optional_fields() {
        let dataset = Dataset::new(array![[1.0, 2.0, 3.0]], None);

        assert_eq!(dataset.shape(), (1, 3));
        assert!(dataset.metadata().is_empty());
        assert!(dataset.feature_names().is_none());
        assert!(dataset.target_names().is_none());
        assert!(dataset.description().is_none());
        assert!(dataset.feature_descriptions.is_none());
    }

    /// Builders not covered above store exactly what they are given.
    #[test]
    fn test_target_names_and_feature_descriptions() {
        let dataset = Dataset::new(array![[0.5, 1.5], [2.5, 3.5]], Some(array![0.0, 1.0]))
            .with_target_names(vec!["neg".to_string(), "pos".to_string()])
            .with_feature_descriptions(vec!["first".to_string(), "second".to_string()]);

        let names = dataset.target_names().expect("target names were set");
        assert_eq!(names.len(), 2);
        assert_eq!(names[1], "pos");
        assert_eq!(
            dataset.feature_descriptions,
            Some(vec!["first".to_string(), "second".to_string()])
        );
    }

    /// Chained `with_metadata` calls on the same key keep only the last value.
    #[test]
    fn test_with_metadata_overwrites_existing_key() {
        let dataset = Dataset::new(array![[1.0]], None)
            .with_metadata("source", "a")
            .with_metadata("source", "b");

        assert_eq!(dataset.metadata().len(), 1);
        assert_eq!(
            dataset.get_metadata("source").map(String::as_str),
            Some("b")
        );
    }

    /// A matrix with zero rows still reports its column count.
    #[test]
    fn test_zero_sample_dataset_shape() {
        let dataset = Dataset::new(Array2::<f64>::zeros((0, 4)), None);

        assert_eq!(dataset.n_samples(), 0);
        assert_eq!(dataset.n_features(), 4);
        assert_eq!(dataset.shape(), (0, 4));
    }
}