1use crate::error::Result;
7use crate::utils::Dataset;
8use ndarray::{Array1, Array2};
9use rand::prelude::*;
10use rand::rng;
11
12pub fn load_iris() -> Result<Dataset> {
14 #[rustfmt::skip]
16 let data = Array2::from_shape_vec((150, 4), vec![
17 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2,
18 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
19 5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0, 1.1, 0.1, 5.8, 4.0, 1.2, 0.2,
20 5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3, 5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3,
21 5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6, 1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2,
22 5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4, 5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2,
23 4.8, 3.1, 1.6, 0.2, 5.4, 3.4, 1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
24 5.0, 3.2, 1.2, 0.2, 5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2,
25 5.0, 3.5, 1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
26 4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3, 1.4, 0.2,
27 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5,
28 5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, 3.9, 1.4,
29 5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0, 6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3,
30 6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7, 4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1,
31 5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3, 6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3,
32 6.6, 3.0, 4.4, 1.4, 6.8, 2.8, 4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0,
33 5.5, 2.4, 3.8, 1.1, 5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5,
34 6.0, 3.4, 4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
35 5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7, 4.2, 1.3,
36 5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1, 5.7, 2.8, 4.1, 1.3,
37 6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9, 5.6, 1.8, 6.5, 3.0, 5.8, 2.2,
38 7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8, 6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5,
39 6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0, 5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4,
40 6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8, 7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5,
41 6.9, 3.2, 5.7, 2.3, 5.6, 2.8, 4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1,
42 7.2, 3.2, 6.0, 1.8, 6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6,
43 7.4, 2.8, 6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
44 7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1, 5.4, 2.1,
45 6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3, 6.7, 3.3, 5.7, 2.5,
46 6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4, 5.4, 2.3, 5.9, 3.0, 5.1, 1.8
47 ]).unwrap();
48
49 let targets = vec![
51 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
52 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
53 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
54 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
55 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
56 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
57 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
58 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
59 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
60 ];
61 let target = Array1::from(targets);
62
63 let mut dataset = Dataset::new(data, Some(target));
65
66 let feature_names = vec![
68 "sepal_length".to_string(),
69 "sepal_width".to_string(),
70 "petal_length".to_string(),
71 "petal_width".to_string(),
72 ];
73
74 let target_names = vec![
75 "setosa".to_string(),
76 "versicolor".to_string(),
77 "virginica".to_string(),
78 ];
79
80 let description = "Iris dataset: classic dataset for classification, clustering, and machine learning
81
82The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
83One class is linearly separable from the other two; the latter are not linearly separable from each other.
84
85Attributes:
86- sepal length in cm
87- sepal width in cm
88- petal length in cm
89- petal width in cm
90
91Target:
92- Iris Setosa
93- Iris Versicolour
94- Iris Virginica".to_string();
95
96 dataset = dataset
97 .with_feature_names(feature_names)
98 .with_target_names(target_names)
99 .with_description(description);
100
101 Ok(dataset)
102}
103
104pub fn load_breast_cancer() -> Result<Dataset> {
106 #[rustfmt::skip]
109 let data = Array2::from_shape_vec((30, 5), vec![
110 17.99, 10.38, 122.8, 1001.0, 0.1184,
111 20.57, 17.77, 132.9, 1326.0, 0.08474,
112 19.69, 21.25, 130.0, 1203.0, 0.1096,
113 11.42, 20.38, 77.58, 386.1, 0.1425,
114 20.29, 14.34, 135.1, 1297.0, 0.1003,
115 12.45, 15.7, 82.57, 477.1, 0.1278,
116 18.25, 19.98, 119.6, 1040.0, 0.09463,
117 13.71, 20.83, 90.2, 577.9, 0.1189,
118 13.0, 21.82, 87.5, 519.8, 0.1273,
119 12.46, 24.04, 83.97, 475.9, 0.1186,
120 16.02, 23.24, 102.7, 797.8, 0.08206,
121 15.78, 17.89, 103.6, 781.0, 0.0971,
122 19.17, 24.8, 132.4, 1123.0, 0.0974,
123 15.85, 23.95, 103.7, 782.7, 0.08401,
124 13.73, 22.61, 93.6, 578.3, 0.1131,
125 14.54, 27.54, 96.73, 658.8, 0.1139,
126 14.68, 20.13, 94.74, 684.5, 0.09867,
127 16.13, 20.68, 108.1, 798.8, 0.117,
128 19.81, 22.15, 130.0, 1260.0, 0.09831,
129 13.54, 14.36, 87.46, 566.3, 0.09779,
130 13.08, 15.71, 85.63, 520.0, 0.1075,
131 9.504, 12.44, 60.34, 273.9, 0.1024,
132 15.34, 14.26, 102.5, 704.4, 0.1073,
133 21.16, 23.04, 137.2, 1404.0, 0.09428,
134 16.65, 21.38, 110.0, 904.6, 0.1121,
135 17.14, 16.4, 116.0, 912.7, 0.1186,
136 14.58, 21.53, 97.41, 644.8, 0.1054,
137 18.61, 20.25, 122.1, 1094.0, 0.0944,
138 15.3, 25.27, 102.4, 732.4, 0.1082,
139 17.57, 15.05, 115.0, 955.1, 0.09847
140 ]).unwrap();
141
142 let targets = vec![
144 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
145 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
146 ];
147 let target = Array1::from(targets);
148
149 let mut dataset = Dataset::new(data, Some(target));
151
152 let feature_names = vec![
154 "mean_radius".to_string(),
155 "mean_texture".to_string(),
156 "mean_perimeter".to_string(),
157 "mean_area".to_string(),
158 "mean_smoothness".to_string(),
159 ];
160
161 let target_names = vec!["malignant".to_string(), "benign".to_string()];
162
163 let description = "Breast Cancer Wisconsin (Diagnostic) Database
164
165Features computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.
166They describe characteristics of the cell nuclei present in the image.
167
168(This is a simplified version of the dataset with only 5 features and 30 samples)
169
170Target:
171- Malignant
172- Benign"
173 .to_string();
174
175 dataset = dataset
176 .with_feature_names(feature_names)
177 .with_target_names(target_names)
178 .with_description(description);
179
180 Ok(dataset)
181}
182
183pub fn load_digits() -> Result<Dataset> {
185 let n_samples = 50; let n_features = 16; let mut data = Array2::zeros((n_samples, n_features));
191 let mut target = Array1::zeros(n_samples);
192
193 #[rustfmt::skip]
195 let digit_patterns = [
196 [0., 1., 1., 0.,
198 1., 0., 0., 1.,
199 1., 0., 0., 1.,
200 0., 1., 1., 0.],
201 [0., 1., 0., 0.,
203 0., 1., 0., 0.,
204 0., 1., 0., 0.,
205 0., 1., 0., 0.],
206 [1., 1., 1., 0.,
208 0., 0., 1., 0.,
209 0., 1., 0., 0.,
210 1., 1., 1., 1.],
211 [1., 1., 1., 0.,
213 0., 0., 1., 0.,
214 1., 1., 1., 0.,
215 0., 0., 1., 0.],
216 [1., 0., 1., 0.,
218 1., 0., 1., 0.,
219 1., 1., 1., 1.,
220 0., 0., 1., 0.],
221 [1., 1., 1., 1.,
223 1., 0., 0., 0.,
224 1., 1., 1., 0.,
225 0., 0., 1., 1.],
226 [0., 1., 1., 0.,
228 1., 0., 0., 0.,
229 1., 1., 1., 0.,
230 0., 1., 1., 0.],
231 [1., 1., 1., 1.,
233 0., 0., 0., 1.,
234 0., 0., 1., 0.,
235 0., 1., 0., 0.],
236 [0., 1., 1., 0.,
238 1., 0., 0., 1.,
239 0., 1., 1., 0.,
240 1., 0., 0., 1.],
241 [0., 1., 1., 0.,
243 1., 0., 0., 1.,
244 0., 1., 1., 1.,
245 0., 0., 1., 0.],
246 ];
247
248 let mut rng = rng();
250 let noise_level = 0.1;
251
252 for (digit, &pattern) in digit_patterns.iter().enumerate() {
253 for sample in 0..5 {
254 let idx = digit * 5 + sample;
255 target[idx] = digit as f64;
256
257 for (j, &pixel) in pattern.iter().enumerate() {
259 let noise = if pixel > 0.5 {
260 -noise_level * rng.random::<f64>()
261 } else {
262 noise_level * rng.random::<f64>()
263 };
264
265 let val = pixel + noise;
266 data[[idx, j]] = val.clamp(0.0, 1.0);
267 }
268 }
269 }
270
271 let mut dataset = Dataset::new(data, Some(target));
273
274 let feature_names: Vec<String> = (0..n_features).map(|i| format!("pixel_{}", i)).collect();
276
277 let target_names: Vec<String> = (0..10).map(|i| format!("{}", i)).collect();
278
279 let description = "Optical recognition of handwritten digits dataset
280
281A simplified version with 50 samples (5 for each digit 0-9) and 16 features (4x4 pixel images).
282Each feature is the grayscale value of a pixel in the image.
283
284Target: Digit identity (0-9)"
285 .to_string();
286
287 dataset = dataset
288 .with_feature_names(feature_names)
289 .with_target_names(target_names)
290 .with_description(description);
291
292 Ok(dataset)
293}
294
295pub fn load_boston() -> Result<Dataset> {
297 let n_samples = 30;
299 let n_features = 5;
300
301 #[rustfmt::skip]
302 let data = Array2::from_shape_vec((n_samples, n_features), vec![
303 0.00632, 18.0, 2.31, 0.538, 6.575,
304 0.02731, 0.0, 7.07, 0.469, 6.421,
305 0.02729, 0.0, 7.07, 0.469, 7.185,
306 0.03237, 0.0, 2.18, 0.458, 6.998,
307 0.06905, 0.0, 2.18, 0.458, 7.147,
308 0.02985, 0.0, 2.18, 0.458, 6.430,
309 0.08829, 12.5, 7.87, 0.524, 6.012,
310 0.14455, 12.5, 7.87, 0.524, 6.172,
311 0.21124, 12.5, 7.87, 0.524, 5.631,
312 0.17004, 12.5, 7.87, 0.524, 6.004,
313 0.22489, 12.5, 7.87, 0.524, 6.377,
314 0.11747, 12.5, 7.87, 0.524, 6.009,
315 0.09378, 12.5, 7.87, 0.524, 5.889,
316 0.62976, 0.0, 8.14, 0.538, 5.949,
317 0.63796, 0.0, 8.14, 0.538, 6.096,
318 0.62739, 0.0, 8.14, 0.538, 5.834,
319 1.05393, 0.0, 8.14, 0.538, 5.935,
320 0.7842, 0.0, 8.14, 0.538, 5.990,
321 0.80271, 0.0, 8.14, 0.538, 5.456,
322 0.7258, 0.0, 8.14, 0.538, 5.727,
323 1.25179, 0.0, 8.14, 0.538, 5.570,
324 0.85204, 0.0, 8.14, 0.538, 5.965,
325 1.23247, 0.0, 8.14, 0.538, 6.142,
326 0.98843, 0.0, 8.14, 0.538, 5.813,
327 0.75026, 0.0, 8.14, 0.538, 5.924,
328 0.84054, 0.0, 8.14, 0.538, 5.599,
329 0.67191, 0.0, 8.14, 0.538, 5.813,
330 0.95577, 0.0, 8.14, 0.538, 6.047,
331 0.77299, 0.0, 8.14, 0.538, 6.495,
332 1.00245, 0.0, 8.14, 0.538, 6.674
333 ]).unwrap();
334
335 let targets = vec![
336 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0, 18.9, 21.7, 20.4, 18.2,
337 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6, 15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21.0,
338 ];
339 let target = Array1::from(targets);
340
341 let mut dataset = Dataset::new(data, Some(target));
343
344 let feature_names = vec![
346 "CRIM".to_string(),
347 "ZN".to_string(),
348 "INDUS".to_string(),
349 "CHAS".to_string(),
350 "NOX".to_string(),
351 ];
352
353 let feature_descriptions = vec![
354 "per capita crime rate by town".to_string(),
355 "proportion of residential land zoned for lots over 25,000 sq.ft.".to_string(),
356 "proportion of non-retail business acres per town".to_string(),
357 "Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)".to_string(),
358 "nitric oxides concentration (parts per 10 million)".to_string(),
359 ];
360
361 let description = "Boston Housing Dataset (Simplified)
362
363A simplified version of the Boston housing dataset with 30 samples and 5 features.
364The target variable is the median value of owner-occupied homes in $1000s.
365
366This is a regression dataset."
367 .to_string();
368
369 dataset = dataset
370 .with_feature_names(feature_names)
371 .with_feature_descriptions(feature_descriptions)
372 .with_description(description);
373
374 Ok(dataset)
375}
376
377#[cfg(test)]
378mod tests {
379 use super::*;
380
381 #[test]
382 fn test_load_iris() {
383 let dataset = load_iris().unwrap();
384
385 assert_eq!(dataset.n_samples(), 150);
386 assert_eq!(dataset.n_features(), 4);
387 assert!(dataset.target.is_some());
388 assert!(dataset.description.is_some());
389 assert!(dataset.feature_names.is_some());
390 assert!(dataset.target_names.is_some());
391
392 let feature_names = dataset.feature_names.as_ref().unwrap();
393 assert_eq!(feature_names.len(), 4);
394 assert_eq!(feature_names[0], "sepal_length");
395 assert_eq!(feature_names[3], "petal_width");
396
397 let target_names = dataset.target_names.as_ref().unwrap();
398 assert_eq!(target_names.len(), 3);
399 assert!(target_names.contains(&"setosa".to_string()));
400 assert!(target_names.contains(&"versicolor".to_string()));
401 assert!(target_names.contains(&"virginica".to_string()));
402
403 let target = dataset.target.as_ref().unwrap();
405 for &val in target.iter() {
406 assert!((0.0..=2.0).contains(&val));
407 }
408 }
409
410 #[test]
411 fn test_load_breast_cancer() {
412 let dataset = load_breast_cancer().unwrap();
413
414 assert_eq!(dataset.n_samples(), 30);
415 assert_eq!(dataset.n_features(), 5);
416 assert!(dataset.target.is_some());
417 assert!(dataset.description.is_some());
418 assert!(dataset.feature_names.is_some());
419 assert!(dataset.target_names.is_some());
420
421 let feature_names = dataset.feature_names.as_ref().unwrap();
422 assert_eq!(feature_names.len(), 5);
423 assert_eq!(feature_names[0], "mean_radius");
424 assert_eq!(feature_names[4], "mean_smoothness");
425
426 let target_names = dataset.target_names.as_ref().unwrap();
427 assert_eq!(target_names.len(), 2);
428 assert!(target_names.contains(&"malignant".to_string()));
429 assert!(target_names.contains(&"benign".to_string()));
430
431 let target = dataset.target.as_ref().unwrap();
433 for &val in target.iter() {
434 assert!(val == 0.0 || val == 1.0);
435 }
436 }
437
438 #[test]
439 fn test_load_digits() {
440 let dataset = load_digits().unwrap();
441
442 assert_eq!(dataset.n_samples(), 50);
443 assert_eq!(dataset.n_features(), 16);
444 assert!(dataset.target.is_some());
445 assert!(dataset.description.is_some());
446 assert!(dataset.feature_names.is_some());
447 assert!(dataset.target_names.is_some());
448
449 let feature_names = dataset.feature_names.as_ref().unwrap();
450 assert_eq!(feature_names.len(), 16);
451 assert_eq!(feature_names[0], "pixel_0");
452 assert_eq!(feature_names[15], "pixel_15");
453
454 let target_names = dataset.target_names.as_ref().unwrap();
455 assert_eq!(target_names.len(), 10);
456 for i in 0..10 {
457 assert!(target_names.contains(&i.to_string()));
458 }
459
460 let target = dataset.target.as_ref().unwrap();
462 for &val in target.iter() {
463 assert!((0.0..=9.0).contains(&val));
464 }
465
466 for row in dataset.data.rows() {
468 for &pixel in row.iter() {
469 assert!((0.0..=1.0).contains(&pixel));
470 }
471 }
472 }
473
474 #[test]
475 fn test_load_boston() {
476 let dataset = load_boston().unwrap();
477
478 assert_eq!(dataset.n_samples(), 30);
479 assert_eq!(dataset.n_features(), 5);
480 assert!(dataset.target.is_some());
481 assert!(dataset.description.is_some());
482 assert!(dataset.feature_names.is_some());
483 assert!(dataset.feature_descriptions.is_some());
484
485 let feature_names = dataset.feature_names.as_ref().unwrap();
486 assert_eq!(feature_names.len(), 5);
487 assert_eq!(feature_names[0], "CRIM");
488 assert_eq!(feature_names[4], "NOX");
489
490 let feature_descriptions = dataset.feature_descriptions.as_ref().unwrap();
491 assert_eq!(feature_descriptions.len(), 5);
492 assert!(feature_descriptions[0].contains("crime rate"));
493
494 let target = dataset.target.as_ref().unwrap();
496 for &val in target.iter() {
497 assert!(val > 0.0 && val < 100.0); }
499 }
500
501 #[test]
502 fn test_all_datasets_have_consistent_shapes() {
503 let datasets = vec![
504 ("iris", load_iris().unwrap()),
505 ("breast_cancer", load_breast_cancer().unwrap()),
506 ("digits", load_digits().unwrap()),
507 ("boston", load_boston().unwrap()),
508 ];
509
510 for (name, dataset) in datasets {
511 if let Some(ref target) = dataset.target {
513 assert_eq!(
514 dataset.data.nrows(),
515 target.len(),
516 "Dataset '{}' has inconsistent sample counts",
517 name
518 );
519 }
520
521 if let Some(ref feature_names) = dataset.feature_names {
523 assert_eq!(
524 dataset.data.ncols(),
525 feature_names.len(),
526 "Dataset '{}' has inconsistent feature count",
527 name
528 );
529 }
530
531 if let Some(ref feature_descriptions) = dataset.feature_descriptions {
533 assert_eq!(
534 dataset.data.ncols(),
535 feature_descriptions.len(),
536 "Dataset '{}' has inconsistent feature description count",
537 name
538 );
539 }
540
541 assert!(
543 dataset.description.is_some(),
544 "Dataset '{}' missing description",
545 name
546 );
547 }
548 }
549}