scirs2_datasets/
toy.rs

1//! Toy datasets for testing and examples
2//!
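//! This module provides small datasets that are useful for testing algorithms and
//! illustrating concepts: embedded copies (or small subsets) of classic real-world
//! datasets such as Iris, plus a synthetic, noise-perturbed digits dataset.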
5
6use crate::error::Result;
7use crate::utils::Dataset;
8use ndarray::{Array1, Array2};
9use rand::prelude::*;
10use rand::rng;
11
12/// Generate the classic Iris dataset
13pub fn load_iris() -> Result<Dataset> {
14    // Define the data
15    #[rustfmt::skip]
16    let data = Array2::from_shape_vec((150, 4), vec![
17        5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2,
18        5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
19        5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0, 1.1, 0.1, 5.8, 4.0, 1.2, 0.2,
20        5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3, 5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3,
21        5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6, 1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2,
22        5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4, 5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2,
23        4.8, 3.1, 1.6, 0.2, 5.4, 3.4, 1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
24        5.0, 3.2, 1.2, 0.2, 5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2,
25        5.0, 3.5, 1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
26        4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3, 1.4, 0.2,
27        7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5,
28        5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, 3.9, 1.4,
29        5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0, 6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3,
30        6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7, 4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1,
31        5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3, 6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3,
32        6.6, 3.0, 4.4, 1.4, 6.8, 2.8, 4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0,
33        5.5, 2.4, 3.8, 1.1, 5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5,
34        6.0, 3.4, 4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
35        5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7, 4.2, 1.3,
36        5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1, 5.7, 2.8, 4.1, 1.3,
37        6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9, 5.6, 1.8, 6.5, 3.0, 5.8, 2.2,
38        7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8, 6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5,
39        6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0, 5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4,
40        6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8, 7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5,
41        6.9, 3.2, 5.7, 2.3, 5.6, 2.8, 4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1,
42        7.2, 3.2, 6.0, 1.8, 6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6,
43        7.4, 2.8, 6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
44        7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1, 5.4, 2.1,
45        6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3, 6.7, 3.3, 5.7, 2.5,
46        6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4, 5.4, 2.3, 5.9, 3.0, 5.1, 1.8
47    ]).unwrap();
48
49    // Define the target (class labels)
50    let targets = vec![
51        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
52        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
53        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
54        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
55        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
56        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
57        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
58        2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
59        2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
60    ];
61    let target = Array1::from(targets);
62
63    // Create dataset
64    let mut dataset = Dataset::new(data, Some(target));
65
66    // Add metadata
67    let feature_names = vec![
68        "sepal_length".to_string(),
69        "sepal_width".to_string(),
70        "petal_length".to_string(),
71        "petal_width".to_string(),
72    ];
73
74    let target_names = vec![
75        "setosa".to_string(),
76        "versicolor".to_string(),
77        "virginica".to_string(),
78    ];
79
80    let description = "Iris dataset: classic dataset for classification, clustering, and machine learning
81    
82The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
83One class is linearly separable from the other two; the latter are not linearly separable from each other.
84
85Attributes:
86- sepal length in cm
87- sepal width in cm
88- petal length in cm
89- petal width in cm
90
91Target: 
92- Iris Setosa
93- Iris Versicolour
94- Iris Virginica".to_string();
95
96    dataset = dataset
97        .with_feature_names(feature_names)
98        .with_target_names(target_names)
99        .with_description(description);
100
101    Ok(dataset)
102}
103
104/// Generate the breast cancer dataset
105pub fn load_breast_cancer() -> Result<Dataset> {
106    // This is a simplified version with only 30 samples
107    // In a real implementation, include the full dataset
108    #[rustfmt::skip]
109    let data = Array2::from_shape_vec((30, 5), vec![
110        17.99, 10.38, 122.8, 1001.0, 0.1184,
111        20.57, 17.77, 132.9, 1326.0, 0.08474,
112        19.69, 21.25, 130.0, 1203.0, 0.1096,
113        11.42, 20.38, 77.58, 386.1, 0.1425,
114        20.29, 14.34, 135.1, 1297.0, 0.1003,
115        12.45, 15.7, 82.57, 477.1, 0.1278,
116        18.25, 19.98, 119.6, 1040.0, 0.09463,
117        13.71, 20.83, 90.2, 577.9, 0.1189,
118        13.0, 21.82, 87.5, 519.8, 0.1273,
119        12.46, 24.04, 83.97, 475.9, 0.1186,
120        16.02, 23.24, 102.7, 797.8, 0.08206,
121        15.78, 17.89, 103.6, 781.0, 0.0971,
122        19.17, 24.8, 132.4, 1123.0, 0.0974,
123        15.85, 23.95, 103.7, 782.7, 0.08401,
124        13.73, 22.61, 93.6, 578.3, 0.1131,
125        14.54, 27.54, 96.73, 658.8, 0.1139,
126        14.68, 20.13, 94.74, 684.5, 0.09867,
127        16.13, 20.68, 108.1, 798.8, 0.117,
128        19.81, 22.15, 130.0, 1260.0, 0.09831,
129        13.54, 14.36, 87.46, 566.3, 0.09779,
130        13.08, 15.71, 85.63, 520.0, 0.1075,
131        9.504, 12.44, 60.34, 273.9, 0.1024,
132        15.34, 14.26, 102.5, 704.4, 0.1073,
133        21.16, 23.04, 137.2, 1404.0, 0.09428,
134        16.65, 21.38, 110.0, 904.6, 0.1121,
135        17.14, 16.4, 116.0, 912.7, 0.1186,
136        14.58, 21.53, 97.41, 644.8, 0.1054,
137        18.61, 20.25, 122.1, 1094.0, 0.0944,
138        15.3, 25.27, 102.4, 732.4, 0.1082,
139        17.57, 15.05, 115.0, 955.1, 0.09847
140    ]).unwrap();
141
142    // Define the target (0 = malignant, 1 = benign)
143    let targets = vec![
144        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
145        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
146    ];
147    let target = Array1::from(targets);
148
149    // Create dataset
150    let mut dataset = Dataset::new(data, Some(target));
151
152    // Add metadata
153    let feature_names = vec![
154        "mean_radius".to_string(),
155        "mean_texture".to_string(),
156        "mean_perimeter".to_string(),
157        "mean_area".to_string(),
158        "mean_smoothness".to_string(),
159    ];
160
161    let target_names = vec!["malignant".to_string(), "benign".to_string()];
162
163    let description = "Breast Cancer Wisconsin (Diagnostic) Database
164    
165Features computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.
166They describe characteristics of the cell nuclei present in the image.
167
168(This is a simplified version of the dataset with only 5 features and 30 samples)
169
170Target:
171- Malignant
172- Benign"
173        .to_string();
174
175    dataset = dataset
176        .with_feature_names(feature_names)
177        .with_target_names(target_names)
178        .with_description(description);
179
180    Ok(dataset)
181}
182
183/// Generate the digits dataset
184pub fn load_digits() -> Result<Dataset> {
185    // Use a simplified version with fewer samples and features
186    // Each digit is represented as a 4x4 image flattened to 16 features
187    let n_samples = 50; // 5 samples per digit (0-9)
188    let n_features = 16; // 4x4 pixels
189
190    let mut data = Array2::zeros((n_samples, n_features));
191    let mut target = Array1::zeros(n_samples);
192
193    // Sample digit patterns (4x4 pixel representations of digits 0-9)
194    #[rustfmt::skip]
195    let digit_patterns = [
196        // Digit 0
197        [0., 1., 1., 0.,
198         1., 0., 0., 1.,
199         1., 0., 0., 1.,
200         0., 1., 1., 0.],
201        // Digit 1
202        [0., 1., 0., 0.,
203         0., 1., 0., 0.,
204         0., 1., 0., 0.,
205         0., 1., 0., 0.],
206        // Digit 2
207        [1., 1., 1., 0.,
208         0., 0., 1., 0.,
209         0., 1., 0., 0.,
210         1., 1., 1., 1.],
211        // Digit 3
212        [1., 1., 1., 0.,
213         0., 0., 1., 0.,
214         1., 1., 1., 0.,
215         0., 0., 1., 0.],
216        // Digit 4
217        [1., 0., 1., 0.,
218         1., 0., 1., 0.,
219         1., 1., 1., 1.,
220         0., 0., 1., 0.],
221        // Digit 5
222        [1., 1., 1., 1.,
223         1., 0., 0., 0.,
224         1., 1., 1., 0.,
225         0., 0., 1., 1.],
226        // Digit 6
227        [0., 1., 1., 0.,
228         1., 0., 0., 0.,
229         1., 1., 1., 0.,
230         0., 1., 1., 0.],
231        // Digit 7
232        [1., 1., 1., 1.,
233         0., 0., 0., 1.,
234         0., 0., 1., 0.,
235         0., 1., 0., 0.],
236        // Digit 8
237        [0., 1., 1., 0.,
238         1., 0., 0., 1.,
239         0., 1., 1., 0.,
240         1., 0., 0., 1.],
241        // Digit 9
242        [0., 1., 1., 0.,
243         1., 0., 0., 1.,
244         0., 1., 1., 1.,
245         0., 0., 1., 0.],
246    ];
247
248    // Create 5 samples per digit with small random variations
249    let mut rng = rng();
250    let noise_level = 0.1;
251
252    for (digit, &pattern) in digit_patterns.iter().enumerate() {
253        for sample in 0..5 {
254            let idx = digit * 5 + sample;
255            target[idx] = digit as f64;
256
257            // Copy the pattern with noise
258            for (j, &pixel) in pattern.iter().enumerate() {
259                let noise = if pixel > 0.5 {
260                    -noise_level * rng.random::<f64>()
261                } else {
262                    noise_level * rng.random::<f64>()
263                };
264
265                let val = pixel + noise;
266                data[[idx, j]] = val.clamp(0.0, 1.0);
267            }
268        }
269    }
270
271    // Create dataset
272    let mut dataset = Dataset::new(data, Some(target));
273
274    // Create feature names
275    let feature_names: Vec<String> = (0..n_features).map(|i| format!("pixel_{}", i)).collect();
276
277    let target_names: Vec<String> = (0..10).map(|i| format!("{}", i)).collect();
278
279    let description = "Optical recognition of handwritten digits dataset
280    
281A simplified version with 50 samples (5 for each digit 0-9) and 16 features (4x4 pixel images).
282Each feature is the grayscale value of a pixel in the image.
283
284Target: Digit identity (0-9)"
285        .to_string();
286
287    dataset = dataset
288        .with_feature_names(feature_names)
289        .with_target_names(target_names)
290        .with_description(description);
291
292    Ok(dataset)
293}
294
295/// Generate the Boston housing dataset
296pub fn load_boston() -> Result<Dataset> {
297    // Simplified version with fewer samples and features
298    let n_samples = 30;
299    let n_features = 5;
300
301    #[rustfmt::skip]
302    let data = Array2::from_shape_vec((n_samples, n_features), vec![
303        0.00632, 18.0, 2.31, 0.538, 6.575,
304        0.02731, 0.0, 7.07, 0.469, 6.421,
305        0.02729, 0.0, 7.07, 0.469, 7.185,
306        0.03237, 0.0, 2.18, 0.458, 6.998,
307        0.06905, 0.0, 2.18, 0.458, 7.147,
308        0.02985, 0.0, 2.18, 0.458, 6.430,
309        0.08829, 12.5, 7.87, 0.524, 6.012,
310        0.14455, 12.5, 7.87, 0.524, 6.172,
311        0.21124, 12.5, 7.87, 0.524, 5.631,
312        0.17004, 12.5, 7.87, 0.524, 6.004,
313        0.22489, 12.5, 7.87, 0.524, 6.377,
314        0.11747, 12.5, 7.87, 0.524, 6.009,
315        0.09378, 12.5, 7.87, 0.524, 5.889,
316        0.62976, 0.0, 8.14, 0.538, 5.949,
317        0.63796, 0.0, 8.14, 0.538, 6.096,
318        0.62739, 0.0, 8.14, 0.538, 5.834,
319        1.05393, 0.0, 8.14, 0.538, 5.935,
320        0.7842, 0.0, 8.14, 0.538, 5.990,
321        0.80271, 0.0, 8.14, 0.538, 5.456,
322        0.7258, 0.0, 8.14, 0.538, 5.727,
323        1.25179, 0.0, 8.14, 0.538, 5.570,
324        0.85204, 0.0, 8.14, 0.538, 5.965,
325        1.23247, 0.0, 8.14, 0.538, 6.142,
326        0.98843, 0.0, 8.14, 0.538, 5.813,
327        0.75026, 0.0, 8.14, 0.538, 5.924,
328        0.84054, 0.0, 8.14, 0.538, 5.599,
329        0.67191, 0.0, 8.14, 0.538, 5.813,
330        0.95577, 0.0, 8.14, 0.538, 6.047,
331        0.77299, 0.0, 8.14, 0.538, 6.495,
332        1.00245, 0.0, 8.14, 0.538, 6.674
333    ]).unwrap();
334
335    let targets = vec![
336        24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0, 18.9, 21.7, 20.4, 18.2,
337        19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6, 15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21.0,
338    ];
339    let target = Array1::from(targets);
340
341    // Create dataset
342    let mut dataset = Dataset::new(data, Some(target));
343
344    // Add metadata
345    let feature_names = vec![
346        "CRIM".to_string(),
347        "ZN".to_string(),
348        "INDUS".to_string(),
349        "CHAS".to_string(),
350        "NOX".to_string(),
351    ];
352
353    let feature_descriptions = vec![
354        "per capita crime rate by town".to_string(),
355        "proportion of residential land zoned for lots over 25,000 sq.ft.".to_string(),
356        "proportion of non-retail business acres per town".to_string(),
357        "Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)".to_string(),
358        "nitric oxides concentration (parts per 10 million)".to_string(),
359    ];
360
361    let description = "Boston Housing Dataset (Simplified)
362    
363A simplified version of the Boston housing dataset with 30 samples and 5 features.
364The target variable is the median value of owner-occupied homes in $1000s.
365
366This is a regression dataset."
367        .to_string();
368
369    dataset = dataset
370        .with_feature_names(feature_names)
371        .with_feature_descriptions(feature_descriptions)
372        .with_description(description);
373
374    Ok(dataset)
375}
376
377#[cfg(test)]
378mod tests {
379    use super::*;
380
381    #[test]
382    fn test_load_iris() {
383        let dataset = load_iris().unwrap();
384
385        assert_eq!(dataset.n_samples(), 150);
386        assert_eq!(dataset.n_features(), 4);
387        assert!(dataset.target.is_some());
388        assert!(dataset.description.is_some());
389        assert!(dataset.feature_names.is_some());
390        assert!(dataset.target_names.is_some());
391
392        let feature_names = dataset.feature_names.as_ref().unwrap();
393        assert_eq!(feature_names.len(), 4);
394        assert_eq!(feature_names[0], "sepal_length");
395        assert_eq!(feature_names[3], "petal_width");
396
397        let target_names = dataset.target_names.as_ref().unwrap();
398        assert_eq!(target_names.len(), 3);
399        assert!(target_names.contains(&"setosa".to_string()));
400        assert!(target_names.contains(&"versicolor".to_string()));
401        assert!(target_names.contains(&"virginica".to_string()));
402
403        // Check target values are in valid range (0, 1, 2)
404        let target = dataset.target.as_ref().unwrap();
405        for &val in target.iter() {
406            assert!((0.0..=2.0).contains(&val));
407        }
408    }
409
410    #[test]
411    fn test_load_breast_cancer() {
412        let dataset = load_breast_cancer().unwrap();
413
414        assert_eq!(dataset.n_samples(), 30);
415        assert_eq!(dataset.n_features(), 5);
416        assert!(dataset.target.is_some());
417        assert!(dataset.description.is_some());
418        assert!(dataset.feature_names.is_some());
419        assert!(dataset.target_names.is_some());
420
421        let feature_names = dataset.feature_names.as_ref().unwrap();
422        assert_eq!(feature_names.len(), 5);
423        assert_eq!(feature_names[0], "mean_radius");
424        assert_eq!(feature_names[4], "mean_smoothness");
425
426        let target_names = dataset.target_names.as_ref().unwrap();
427        assert_eq!(target_names.len(), 2);
428        assert!(target_names.contains(&"malignant".to_string()));
429        assert!(target_names.contains(&"benign".to_string()));
430
431        // Check target values are binary (0 or 1)
432        let target = dataset.target.as_ref().unwrap();
433        for &val in target.iter() {
434            assert!(val == 0.0 || val == 1.0);
435        }
436    }
437
438    #[test]
439    fn test_load_digits() {
440        let dataset = load_digits().unwrap();
441
442        assert_eq!(dataset.n_samples(), 50);
443        assert_eq!(dataset.n_features(), 16);
444        assert!(dataset.target.is_some());
445        assert!(dataset.description.is_some());
446        assert!(dataset.feature_names.is_some());
447        assert!(dataset.target_names.is_some());
448
449        let feature_names = dataset.feature_names.as_ref().unwrap();
450        assert_eq!(feature_names.len(), 16);
451        assert_eq!(feature_names[0], "pixel_0");
452        assert_eq!(feature_names[15], "pixel_15");
453
454        let target_names = dataset.target_names.as_ref().unwrap();
455        assert_eq!(target_names.len(), 10);
456        for i in 0..10 {
457            assert!(target_names.contains(&i.to_string()));
458        }
459
460        // Check target values are digits (0-9)
461        let target = dataset.target.as_ref().unwrap();
462        for &val in target.iter() {
463            assert!((0.0..=9.0).contains(&val));
464        }
465
466        // Check pixel values are in valid range [0, 1]
467        for row in dataset.data.rows() {
468            for &pixel in row.iter() {
469                assert!((0.0..=1.0).contains(&pixel));
470            }
471        }
472    }
473
474    #[test]
475    fn test_load_boston() {
476        let dataset = load_boston().unwrap();
477
478        assert_eq!(dataset.n_samples(), 30);
479        assert_eq!(dataset.n_features(), 5);
480        assert!(dataset.target.is_some());
481        assert!(dataset.description.is_some());
482        assert!(dataset.feature_names.is_some());
483        assert!(dataset.feature_descriptions.is_some());
484
485        let feature_names = dataset.feature_names.as_ref().unwrap();
486        assert_eq!(feature_names.len(), 5);
487        assert_eq!(feature_names[0], "CRIM");
488        assert_eq!(feature_names[4], "NOX");
489
490        let feature_descriptions = dataset.feature_descriptions.as_ref().unwrap();
491        assert_eq!(feature_descriptions.len(), 5);
492        assert!(feature_descriptions[0].contains("crime rate"));
493
494        // Check target values are reasonable housing prices
495        let target = dataset.target.as_ref().unwrap();
496        for &val in target.iter() {
497            assert!(val > 0.0 && val < 100.0); // Reasonable housing prices in $1000s
498        }
499    }
500
501    #[test]
502    fn test_all_datasets_have_consistent_shapes() {
503        let datasets = vec![
504            ("iris", load_iris().unwrap()),
505            ("breast_cancer", load_breast_cancer().unwrap()),
506            ("digits", load_digits().unwrap()),
507            ("boston", load_boston().unwrap()),
508        ];
509
510        for (name, dataset) in datasets {
511            // Check that data and target have consistent sample counts
512            if let Some(ref target) = dataset.target {
513                assert_eq!(
514                    dataset.data.nrows(),
515                    target.len(),
516                    "Dataset '{}' has inconsistent sample counts",
517                    name
518                );
519            }
520
521            // Check that feature names match feature count (if present)
522            if let Some(ref feature_names) = dataset.feature_names {
523                assert_eq!(
524                    dataset.data.ncols(),
525                    feature_names.len(),
526                    "Dataset '{}' has inconsistent feature count",
527                    name
528                );
529            }
530
531            // Check that feature descriptions match feature count (if present)
532            if let Some(ref feature_descriptions) = dataset.feature_descriptions {
533                assert_eq!(
534                    dataset.data.ncols(),
535                    feature_descriptions.len(),
536                    "Dataset '{}' has inconsistent feature description count",
537                    name
538                );
539            }
540
541            // Check that dataset has a description
542            assert!(
543                dataset.description.is_some(),
544                "Dataset '{}' missing description",
545                name
546            );
547        }
548    }
549}