scirs2_datasets/
toy.rs

1//! Toy datasets for testing and examples
2//!
3//! This module provides small, synthetic datasets that are useful for
4//! testing algorithms and illustrating concepts.
5
6use crate::error::Result;
7use crate::utils::Dataset;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10
11/// Generate the classic Iris dataset
12#[allow(dead_code)]
13pub fn load_iris() -> Result<Dataset> {
14    // Define the data
15    #[rustfmt::skip]
16    let data = Array2::from_shape_vec((150, 4), vec![
17        5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2,
18        5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
19        5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0, 1.1, 0.1, 5.8, 4.0, 1.2, 0.2,
20        5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3, 5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3,
21        5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6, 1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2,
22        5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4, 5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2,
23        4.8, 3.1, 1.6, 0.2, 5.4, 3.4, 1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
24        5.0, 3.2, 1.2, 0.2, 5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2,
25        5.0, 3.5, 1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
26        4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3, 1.4, 0.2,
27        7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5,
28        5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, 3.9, 1.4,
29        5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0, 6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3,
30        6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7, 4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1,
31        5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3, 6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3,
32        6.6, 3.0, 4.4, 1.4, 6.8, 2.8, 4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0,
33        5.5, 2.4, 3.8, 1.1, 5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5,
34        6.0, 3.4, 4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
35        5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7, 4.2, 1.3,
36        5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1, 5.7, 2.8, 4.1, 1.3,
37        6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9, 5.6, 1.8, 6.5, 3.0, 5.8, 2.2,
38        7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8, 6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5,
39        6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0, 5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4,
40        6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8, 7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5,
41        6.9, 3.2, 5.7, 2.3, 5.6, 2.8, 4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1,
42        7.2, 3.2, 6.0, 1.8, 6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6,
43        7.4, 2.8, 6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
44        7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1, 5.4, 2.1,
45        6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3, 6.7, 3.3, 5.7, 2.5,
46        6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4, 5.4, 2.3, 5.9, 3.0, 5.1, 1.8
47    ]).unwrap();
48
49    // Define the target (class labels)
50    let targets = vec![
51        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
52        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
53        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
54        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
55        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
56        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
57        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
58        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
59        2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
60    ];
61    let target = Array1::from(targets);
62
63    // Create dataset
64    let mut dataset = Dataset::new(data, Some(target));
65
66    // Add metadata
67    let featurenames = vec![
68        "sepal_length".to_string(),
69        "sepal_width".to_string(),
70        "petal_length".to_string(),
71        "petal_width".to_string(),
72    ];
73
74    let targetnames = vec![
75        "setosa".to_string(),
76        "versicolor".to_string(),
77        "virginica".to_string(),
78    ];
79
80    let description = "Iris dataset: classic dataset for classification, clustering, and machine learning
81    
82The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
83One class is linearly separable from the other two; the latter are not linearly separable from each other.
84
85Attributes:
86- sepal length in cm
87- sepal width in cm
88- petal length in cm
89- petal width in cm
90
91Target: 
92- Iris Setosa
93- Iris Versicolour
94- Iris Virginica".to_string();
95
96    dataset = dataset
97        .with_featurenames(featurenames)
98        .with_targetnames(targetnames)
99        .with_description(description);
100
101    Ok(dataset)
102}
103
104/// Generate the breast cancer dataset
105#[allow(dead_code)]
106pub fn load_breast_cancer() -> Result<Dataset> {
107    // This is a simplified version with only 30 samples
108    // In a real implementation, include the full dataset
109    #[rustfmt::skip]
110    let data = Array2::from_shape_vec((30, 5), vec![
111        17.99, 10.38, 122.8, 1001.0, 0.1184,
112        20.57, 17.77, 132.9, 1326.0, 0.08474,
113        19.69, 21.25, 130.0, 1203.0, 0.1096,
114        11.42, 20.38, 77.58, 386.1, 0.1425,
115        20.29, 14.34, 135.1, 1297.0, 0.1003,
116        12.45, 15.7, 82.57, 477.1, 0.1278,
117        18.25, 19.98, 119.6, 1040.0, 0.09463,
118        13.71, 20.83, 90.2, 577.9, 0.1189,
119        13.0, 21.82, 87.5, 519.8, 0.1273,
120        12.46, 24.04, 83.97, 475.9, 0.1186,
121        16.02, 23.24, 102.7, 797.8, 0.08206,
122        15.78, 17.89, 103.6, 781.0, 0.0971,
123        19.17, 24.8, 132.4, 1123.0, 0.0974,
124        15.85, 23.95, 103.7, 782.7, 0.08401,
125        13.73, 22.61, 93.6, 578.3, 0.1131,
126        14.54, 27.54, 96.73, 658.8, 0.1139,
127        14.68, 20.13, 94.74, 684.5, 0.09867,
128        16.13, 20.68, 108.1, 798.8, 0.117,
129        19.81, 22.15, 130.0, 1260.0, 0.09831,
130        13.54, 14.36, 87.46, 566.3, 0.09779,
131        13.08, 15.71, 85.63, 520.0, 0.1075,
132        9.504, 12.44, 60.34, 273.9, 0.1024,
133        15.34, 14.26, 102.5, 704.4, 0.1073,
134        21.16, 23.04, 137.2, 1404.0, 0.09428,
135        16.65, 21.38, 110.0, 904.6, 0.1121,
136        17.14, 16.4, 116.0, 912.7, 0.1186,
137        14.58, 21.53, 97.41, 644.8, 0.1054,
138        18.61, 20.25, 122.1, 1094.0, 0.0944,
139        15.3, 25.27, 102.4, 732.4, 0.1082,
140        17.57, 15.05, 115.0, 955.1, 0.09847
141    ]).unwrap();
142
143    // Define the target (0 = malignant, 1 = benign)
144    let targets = vec![
145        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
146        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
147    ];
148    let target = Array1::from(targets);
149
150    // Create dataset
151    let mut dataset = Dataset::new(data, Some(target));
152
153    // Add metadata
154    let featurenames = vec![
155        "mean_radius".to_string(),
156        "meantexture".to_string(),
157        "mean_perimeter".to_string(),
158        "mean_area".to_string(),
159        "mean_smoothness".to_string(),
160    ];
161
162    let targetnames = vec!["malignant".to_string(), "benign".to_string()];
163
164    let description = "Breast Cancer Wisconsin (Diagnostic) Database
165    
166Features computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.
167They describe characteristics of the cell nuclei present in the image.
168
169(This is a simplified version of the dataset with only 5 features and 30 samples)
170
171Target:
172- Malignant
173- Benign"
174        .to_string();
175
176    dataset = dataset
177        .with_featurenames(featurenames)
178        .with_targetnames(targetnames)
179        .with_description(description);
180
181    Ok(dataset)
182}
183
184/// Generate the digits dataset
185#[allow(dead_code)]
186pub fn load_digits() -> Result<Dataset> {
187    // Use a simplified version with fewer samples and features
188    // Each digit is represented as a 4x4 image flattened to 16 features
189    let n_samples = 50; // 5 samples per digit (0-9)
190    let n_features = 16; // 4x4 pixels
191
192    let mut data = Array2::zeros((n_samples, n_features));
193    let mut target = Array1::zeros(n_samples);
194
195    // Sample digit patterns (4x4 pixel representations of digits 0-9)
196    #[rustfmt::skip]
197    let digit_patterns = [
198        // Digit 0
199        [0., 1., 1., 0.,
200         1., 0., 0., 1.,
201         1., 0., 0., 1.,
202         0., 1., 1., 0.],
203        // Digit 1
204        [0., 1., 0., 0.,
205         0., 1., 0., 0.,
206         0., 1., 0., 0.,
207         0., 1., 0., 0.],
208        // Digit 2
209        [1., 1., 1., 0.,
210         0., 0., 1., 0.,
211         0., 1., 0., 0.,
212         1., 1., 1., 1.],
213        // Digit 3
214        [1., 1., 1., 0.,
215         0., 0., 1., 0.,
216         1., 1., 1., 0.,
217         0., 0., 1., 0.],
218        // Digit 4
219        [1., 0., 1., 0.,
220         1., 0., 1., 0.,
221         1., 1., 1., 1.,
222         0., 0., 1., 0.],
223        // Digit 5
224        [1., 1., 1., 1.,
225         1., 0., 0., 0.,
226         1., 1., 1., 0.,
227         0., 0., 1., 1.],
228        // Digit 6
229        [0., 1., 1., 0.,
230         1., 0., 0., 0.,
231         1., 1., 1., 0.,
232         0., 1., 1., 0.],
233        // Digit 7
234        [1., 1., 1., 1.,
235         0., 0., 0., 1.,
236         0., 0., 1., 0.,
237         0., 1., 0., 0.],
238        // Digit 8
239        [0., 1., 1., 0.,
240         1., 0., 0., 1.,
241         0., 1., 1., 0.,
242         1., 0., 0., 1.],
243        // Digit 9
244        [0., 1., 1., 0.,
245         1., 0., 0., 1.,
246         0., 1., 1., 1.,
247         0., 0., 1., 0.],
248    ];
249
250    // Create 5 samples per digit with small random variations
251    let mut rng = thread_rng();
252    let noise_level = 0.1;
253
254    for (digit, &pattern) in digit_patterns.iter().enumerate() {
255        for sample in 0..5 {
256            let idx = digit * 5 + sample;
257            target[idx] = digit as f64;
258
259            // Copy the pattern with noise
260            for (j, &pixel) in pattern.iter().enumerate() {
261                let noise = if pixel > 0.5 {
262                    -noise_level * rng.random::<f64>()
263                } else {
264                    noise_level * rng.random::<f64>()
265                };
266
267                let val: f64 = pixel + noise;
268                data[[idx, j]] = val.clamp(0.0, 1.0);
269            }
270        }
271    }
272
273    // Create dataset
274    let mut dataset = Dataset::new(data, Some(target));
275
276    // Create feature names
277    let featurenames: Vec<String> = (0..n_features).map(|i| format!("pixel_{i}")).collect();
278
279    let targetnames: Vec<String> = (0..10).map(|i| format!("{i}")).collect();
280
281    let description = "Optical recognition of handwritten digits dataset
282    
283A simplified version with 50 samples (5 for each digit 0-9) and 16 features (4x4 pixel images).
284Each feature is the grayscale value of a pixel in the image.
285
286Target: Digit identity (0-9)"
287        .to_string();
288
289    dataset = dataset
290        .with_featurenames(featurenames)
291        .with_targetnames(targetnames)
292        .with_description(description);
293
294    Ok(dataset)
295}
296
297/// Generate the Boston housing dataset
298#[allow(dead_code)]
299pub fn load_boston() -> Result<Dataset> {
300    // Simplified version with fewer samples and features
301    let n_samples = 30;
302    let n_features = 5;
303
304    #[rustfmt::skip]
305    let data = Array2::from_shape_vec((n_samples, n_features), vec![
306        0.00632, 18.0, 2.31, 0.538, 6.575,
307        0.02731, 0.0, 7.07, 0.469, 6.421,
308        0.02729, 0.0, 7.07, 0.469, 7.185,
309        0.03237, 0.0, 2.18, 0.458, 6.998,
310        0.06905, 0.0, 2.18, 0.458, 7.147,
311        0.02985, 0.0, 2.18, 0.458, 6.430,
312        0.08829, 12.5, 7.87, 0.524, 6.012,
313        0.14455, 12.5, 7.87, 0.524, 6.172,
314        0.21124, 12.5, 7.87, 0.524, 5.631,
315        0.17004, 12.5, 7.87, 0.524, 6.004,
316        0.22489, 12.5, 7.87, 0.524, 6.377,
317        0.11747, 12.5, 7.87, 0.524, 6.009,
318        0.09378, 12.5, 7.87, 0.524, 5.889,
319        0.62976, 0.0, 8.14, 0.538, 5.949,
320        0.63796, 0.0, 8.14, 0.538, 6.096,
321        0.62739, 0.0, 8.14, 0.538, 5.834,
322        1.05393, 0.0, 8.14, 0.538, 5.935,
323        0.7842, 0.0, 8.14, 0.538, 5.990,
324        0.80271, 0.0, 8.14, 0.538, 5.456,
325        0.7258, 0.0, 8.14, 0.538, 5.727,
326        1.25179, 0.0, 8.14, 0.538, 5.570,
327        0.85204, 0.0, 8.14, 0.538, 5.965,
328        1.23247, 0.0, 8.14, 0.538, 6.142,
329        0.98843, 0.0, 8.14, 0.538, 5.813,
330        0.75026, 0.0, 8.14, 0.538, 5.924,
331        0.84054, 0.0, 8.14, 0.538, 5.599,
332        0.67191, 0.0, 8.14, 0.538, 5.813,
333        0.95577, 0.0, 8.14, 0.538, 6.047,
334        0.77299, 0.0, 8.14, 0.538, 6.495,
335        1.00245, 0.0, 8.14, 0.538, 6.674
336    ]).unwrap();
337
338    let targets = vec![
339        24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0, 18.9, 21.7, 20.4, 18.2,
340        19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6, 15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21.0,
341    ];
342    let target = Array1::from(targets);
343
344    // Create dataset
345    let mut dataset = Dataset::new(data, Some(target));
346
347    // Add metadata
348    let featurenames = vec![
349        "CRIM".to_string(),
350        "ZN".to_string(),
351        "INDUS".to_string(),
352        "CHAS".to_string(),
353        "NOX".to_string(),
354    ];
355
356    let feature_descriptions = vec![
357        "per capita crime rate by town".to_string(),
358        "proportion of residential land zoned for lots over 25,000 sq.ft.".to_string(),
359        "proportion of non-retail business acres per town".to_string(),
360        "Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)".to_string(),
361        "nitric oxides concentration (parts per 10 million)".to_string(),
362    ];
363
364    let description = "Boston Housing Dataset (Simplified)
365    
366A simplified version of the Boston housing dataset with 30 samples and 5 features.
367The target variable is the median value of owner-occupied homes in $1000s.
368
369This is a regression dataset."
370        .to_string();
371
372    dataset = dataset
373        .with_featurenames(featurenames)
374        .with_feature_descriptions(feature_descriptions)
375        .with_description(description);
376
377    Ok(dataset)
378}
379
380/// Generate a synthetic diabetes dataset for regression
381///
382/// This is a simplified version of the classic diabetes dataset with 442 samples
383/// and 10 features, suitable for regression tasks.
384#[allow(dead_code)]
385pub fn load_diabetes() -> Result<Dataset> {
386    // Use a fixed seed for reproducibility
387    let mut rng = StdRng::seed_from_u64(42);
388
389    let n_samples = 442;
390    let n_features = 10;
391
392    // Generate synthetic data that resembles the diabetes dataset structure
393    let mut data = Vec::with_capacity(n_samples * n_features);
394    let mut targets = Vec::with_capacity(n_samples);
395
396    for _ in 0..n_samples {
397        // Generate correlated features (representing biomarkers)
398        let age = rng.random::<f64>() * 0.1 - 0.05;
399        let sex = if rng.random::<f64>() < 0.5 {
400            -0.05
401        } else {
402            0.05
403        };
404        let bmi = (rng.random::<f64>() * 0.12 - 0.06) + age * 0.3;
405        let bp = (rng.random::<f64>() * 0.1 - 0.05) + bmi * 0.4;
406        let s1 = (rng.random::<f64>() * 0.14 - 0.07) + bmi * 0.2;
407        let s2 = (rng.random::<f64>() * 0.16 - 0.08) + s1 * 0.5;
408        let s3 = (rng.random::<f64>() * 0.12 - 0.06) + age * 0.2;
409        let s4 = (rng.random::<f64>() * 0.12 - 0.06) + s1 * 0.3;
410        let s5 = (rng.random::<f64>() * 0.14 - 0.07) + bmi * 0.25;
411        let s6 = (rng.random::<f64>() * 0.1 - 0.05) + s5 * 0.4;
412
413        data.extend_from_slice(&[age, sex, bmi, bp, s1, s2, s3, s4, s5, s6]);
414
415        // Generate target as a linear combination with noise
416        let target = 152.0
417            + 938.0 * bmi
418            + 519.0 * bp
419            + 324.0 * s1
420            + 217.0 * s5
421            + (rng.random::<f64>() * 40.0 - 20.0);
422        targets.push(target);
423    }
424
425    let data_array = Array2::from_shape_vec((n_samples, n_features), data).unwrap();
426    let target_array = Array1::from_vec(targets);
427
428    let featurenames = vec![
429        "age".to_string(),
430        "sex".to_string(),
431        "bmi".to_string(),
432        "bp".to_string(),
433        "s1".to_string(),
434        "s2".to_string(),
435        "s3".to_string(),
436        "s4".to_string(),
437        "s5".to_string(),
438        "s6".to_string(),
439    ];
440
441    let feature_descriptions = vec![
442        "Age".to_string(),
443        "Sex".to_string(),
444        "Body mass index".to_string(),
445        "Average blood pressure".to_string(),
446        "Total serum cholesterol".to_string(),
447        "Low-density lipoproteins".to_string(),
448        "High-density lipoproteins".to_string(),
449        "Total cholesterol / HDL".to_string(),
450        "Log of serum triglycerides level".to_string(),
451        "Blood sugar level".to_string(),
452    ];
453
454    let description = "Diabetes dataset for regression. A synthetic version of the classic diabetes dataset with 442 samples and 10 physiological features.".to_string();
455
456    let dataset = Dataset::new(data_array, Some(target_array))
457        .with_featurenames(featurenames)
458        .with_feature_descriptions(feature_descriptions)
459        .with_description(description);
460
461    Ok(dataset)
462}
463
464#[cfg(test)]
465mod tests {
466    use super::*;
467
468    #[test]
469    fn test_load_iris() {
470        let dataset = load_iris().unwrap();
471
472        assert_eq!(dataset.n_samples(), 150);
473        assert_eq!(dataset.n_features(), 4);
474        assert!(dataset.target.is_some());
475        assert!(dataset.description.is_some());
476        assert!(dataset.featurenames.is_some());
477        assert!(dataset.targetnames.is_some());
478
479        let featurenames = dataset.featurenames.as_ref().unwrap();
480        assert_eq!(featurenames.len(), 4);
481        assert_eq!(featurenames[0], "sepal_length");
482        assert_eq!(featurenames[3], "petal_width");
483
484        let targetnames = dataset.targetnames.as_ref().unwrap();
485        assert_eq!(targetnames.len(), 3);
486        assert!(targetnames.contains(&"setosa".to_string()));
487        assert!(targetnames.contains(&"versicolor".to_string()));
488        assert!(targetnames.contains(&"virginica".to_string()));
489
490        // Check target values are in valid range (0, 1, 2)
491        let target = dataset.target.as_ref().unwrap();
492        for &val in target.iter() {
493            assert!((0.0..=2.0).contains(&val));
494        }
495    }
496
497    #[test]
498    fn test_load_breast_cancer() {
499        let dataset = load_breast_cancer().unwrap();
500
501        assert_eq!(dataset.n_samples(), 30);
502        assert_eq!(dataset.n_features(), 5);
503        assert!(dataset.target.is_some());
504        assert!(dataset.description.is_some());
505        assert!(dataset.featurenames.is_some());
506        assert!(dataset.targetnames.is_some());
507
508        let featurenames = dataset.featurenames.as_ref().unwrap();
509        assert_eq!(featurenames.len(), 5);
510        assert_eq!(featurenames[0], "mean_radius");
511        assert_eq!(featurenames[4], "mean_smoothness");
512
513        let targetnames = dataset.targetnames.as_ref().unwrap();
514        assert_eq!(targetnames.len(), 2);
515        assert!(targetnames.contains(&"malignant".to_string()));
516        assert!(targetnames.contains(&"benign".to_string()));
517
518        // Check target values are binary (0 or 1)
519        let target = dataset.target.as_ref().unwrap();
520        for &val in target.iter() {
521            assert!(val == 0.0 || val == 1.0);
522        }
523    }
524
525    #[test]
526    fn test_load_digits() {
527        let dataset = load_digits().unwrap();
528
529        assert_eq!(dataset.n_samples(), 50);
530        assert_eq!(dataset.n_features(), 16);
531        assert!(dataset.target.is_some());
532        assert!(dataset.description.is_some());
533        assert!(dataset.featurenames.is_some());
534        assert!(dataset.targetnames.is_some());
535
536        let featurenames = dataset.featurenames.as_ref().unwrap();
537        assert_eq!(featurenames.len(), 16);
538        assert_eq!(featurenames[0], "pixel_0");
539        assert_eq!(featurenames[15], "pixel_15");
540
541        let targetnames = dataset.targetnames.as_ref().unwrap();
542        assert_eq!(targetnames.len(), 10);
543        for i in 0..10 {
544            assert!(targetnames.contains(&i.to_string()));
545        }
546
547        // Check target values are digits (0-9)
548        let target = dataset.target.as_ref().unwrap();
549        for &val in target.iter() {
550            assert!((0.0..=9.0).contains(&val));
551        }
552
553        // Check pixel values are in valid range [0, 1]
554        for row in dataset.data.rows() {
555            for &pixel in row.iter() {
556                assert!((0.0..=1.0).contains(&pixel));
557            }
558        }
559    }
560
561    #[test]
562    fn test_load_boston() {
563        let dataset = load_boston().unwrap();
564
565        assert_eq!(dataset.n_samples(), 30);
566        assert_eq!(dataset.n_features(), 5);
567        assert!(dataset.target.is_some());
568        assert!(dataset.description.is_some());
569        assert!(dataset.featurenames.is_some());
570        assert!(dataset.feature_descriptions.is_some());
571
572        let featurenames = dataset.featurenames.as_ref().unwrap();
573        assert_eq!(featurenames.len(), 5);
574        assert_eq!(featurenames[0], "CRIM");
575        assert_eq!(featurenames[4], "NOX");
576
577        let feature_descriptions = dataset.feature_descriptions.as_ref().unwrap();
578        assert_eq!(feature_descriptions.len(), 5);
579        assert!(feature_descriptions[0].contains("crime rate"));
580
581        // Check target values are reasonable housing prices
582        let target = dataset.target.as_ref().unwrap();
583        for &val in target.iter() {
584            assert!(val > 0.0 && val < 100.0); // Reasonable housing prices in $1000s
585        }
586    }
587
588    #[test]
589    fn test_all_datasets_have_consistentshapes() {
590        let datasets = vec![
591            ("iris", load_iris().unwrap()),
592            ("breast_cancer", load_breast_cancer().unwrap()),
593            ("digits", load_digits().unwrap()),
594            ("boston", load_boston().unwrap()),
595            ("diabetes", load_diabetes().unwrap()),
596        ];
597
598        for (name, dataset) in datasets {
599            // Check that data and target have consistent sample counts
600            if let Some(ref target) = dataset.target {
601                assert_eq!(
602                    dataset.data.nrows(),
603                    target.len(),
604                    "Dataset '{name}' has inconsistent sample counts"
605                );
606            }
607
608            // Check that feature names match feature count (if present)
609            if let Some(ref featurenames) = dataset.featurenames {
610                assert_eq!(
611                    dataset.data.ncols(),
612                    featurenames.len(),
613                    "Dataset '{name}' has inconsistent feature count"
614                );
615            }
616
617            // Check that feature descriptions match feature count (if present)
618            if let Some(ref feature_descriptions) = dataset.feature_descriptions {
619                assert_eq!(
620                    dataset.data.ncols(),
621                    feature_descriptions.len(),
622                    "Dataset '{name}' has inconsistent feature description count"
623                );
624            }
625
626            // Check that dataset has a description
627            assert!(
628                dataset.description.is_some(),
629                "Dataset '{name}' missing description"
630            );
631        }
632    }
633}