1use crate::error::Result;
7use crate::utils::Dataset;
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::random::prelude::*;
10
11#[allow(dead_code)]
13pub fn load_iris() -> Result<Dataset> {
14 #[rustfmt::skip]
16 let data = Array2::from_shape_vec((150, 4), vec![
17 5.1, 3.5, 1.4, 0.2, 4.9, 3.0, 1.4, 0.2, 4.7, 3.2, 1.3, 0.2, 4.6, 3.1, 1.5, 0.2, 5.0, 3.6, 1.4, 0.2,
18 5.4, 3.9, 1.7, 0.4, 4.6, 3.4, 1.4, 0.3, 5.0, 3.4, 1.5, 0.2, 4.4, 2.9, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
19 5.4, 3.7, 1.5, 0.2, 4.8, 3.4, 1.6, 0.2, 4.8, 3.0, 1.4, 0.1, 4.3, 3.0, 1.1, 0.1, 5.8, 4.0, 1.2, 0.2,
20 5.7, 4.4, 1.5, 0.4, 5.4, 3.9, 1.3, 0.4, 5.1, 3.5, 1.4, 0.3, 5.7, 3.8, 1.7, 0.3, 5.1, 3.8, 1.5, 0.3,
21 5.4, 3.4, 1.7, 0.2, 5.1, 3.7, 1.5, 0.4, 4.6, 3.6, 1.0, 0.2, 5.1, 3.3, 1.7, 0.5, 4.8, 3.4, 1.9, 0.2,
22 5.0, 3.0, 1.6, 0.2, 5.0, 3.4, 1.6, 0.4, 5.2, 3.5, 1.5, 0.2, 5.2, 3.4, 1.4, 0.2, 4.7, 3.2, 1.6, 0.2,
23 4.8, 3.1, 1.6, 0.2, 5.4, 3.4, 1.5, 0.4, 5.2, 4.1, 1.5, 0.1, 5.5, 4.2, 1.4, 0.2, 4.9, 3.1, 1.5, 0.1,
24 5.0, 3.2, 1.2, 0.2, 5.5, 3.5, 1.3, 0.2, 4.9, 3.1, 1.5, 0.1, 4.4, 3.0, 1.3, 0.2, 5.1, 3.4, 1.5, 0.2,
25 5.0, 3.5, 1.3, 0.3, 4.5, 2.3, 1.3, 0.3, 4.4, 3.2, 1.3, 0.2, 5.0, 3.5, 1.6, 0.6, 5.1, 3.8, 1.9, 0.4,
26 4.8, 3.0, 1.4, 0.3, 5.1, 3.8, 1.6, 0.2, 4.6, 3.2, 1.4, 0.2, 5.3, 3.7, 1.5, 0.2, 5.0, 3.3, 1.4, 0.2,
27 7.0, 3.2, 4.7, 1.4, 6.4, 3.2, 4.5, 1.5, 6.9, 3.1, 4.9, 1.5, 5.5, 2.3, 4.0, 1.3, 6.5, 2.8, 4.6, 1.5,
28 5.7, 2.8, 4.5, 1.3, 6.3, 3.3, 4.7, 1.6, 4.9, 2.4, 3.3, 1.0, 6.6, 2.9, 4.6, 1.3, 5.2, 2.7, 3.9, 1.4,
29 5.0, 2.0, 3.5, 1.0, 5.9, 3.0, 4.2, 1.5, 6.0, 2.2, 4.0, 1.0, 6.1, 2.9, 4.7, 1.4, 5.6, 2.9, 3.6, 1.3,
30 6.7, 3.1, 4.4, 1.4, 5.6, 3.0, 4.5, 1.5, 5.8, 2.7, 4.1, 1.0, 6.2, 2.2, 4.5, 1.5, 5.6, 2.5, 3.9, 1.1,
31 5.9, 3.2, 4.8, 1.8, 6.1, 2.8, 4.0, 1.3, 6.3, 2.5, 4.9, 1.5, 6.1, 2.8, 4.7, 1.2, 6.4, 2.9, 4.3, 1.3,
32 6.6, 3.0, 4.4, 1.4, 6.8, 2.8, 4.8, 1.4, 6.7, 3.0, 5.0, 1.7, 6.0, 2.9, 4.5, 1.5, 5.7, 2.6, 3.5, 1.0,
33 5.5, 2.4, 3.8, 1.1, 5.5, 2.4, 3.7, 1.0, 5.8, 2.7, 3.9, 1.2, 6.0, 2.7, 5.1, 1.6, 5.4, 3.0, 4.5, 1.5,
34 6.0, 3.4, 4.5, 1.6, 6.7, 3.1, 4.7, 1.5, 6.3, 2.3, 4.4, 1.3, 5.6, 3.0, 4.1, 1.3, 5.5, 2.5, 4.0, 1.3,
35 5.5, 2.6, 4.4, 1.2, 6.1, 3.0, 4.6, 1.4, 5.8, 2.6, 4.0, 1.2, 5.0, 2.3, 3.3, 1.0, 5.6, 2.7, 4.2, 1.3,
36 5.7, 3.0, 4.2, 1.2, 5.7, 2.9, 4.2, 1.3, 6.2, 2.9, 4.3, 1.3, 5.1, 2.5, 3.0, 1.1, 5.7, 2.8, 4.1, 1.3,
37 6.3, 3.3, 6.0, 2.5, 5.8, 2.7, 5.1, 1.9, 7.1, 3.0, 5.9, 2.1, 6.3, 2.9, 5.6, 1.8, 6.5, 3.0, 5.8, 2.2,
38 7.6, 3.0, 6.6, 2.1, 4.9, 2.5, 4.5, 1.7, 7.3, 2.9, 6.3, 1.8, 6.7, 2.5, 5.8, 1.8, 7.2, 3.6, 6.1, 2.5,
39 6.5, 3.2, 5.1, 2.0, 6.4, 2.7, 5.3, 1.9, 6.8, 3.0, 5.5, 2.1, 5.7, 2.5, 5.0, 2.0, 5.8, 2.8, 5.1, 2.4,
40 6.4, 3.2, 5.3, 2.3, 6.5, 3.0, 5.5, 1.8, 7.7, 3.8, 6.7, 2.2, 7.7, 2.6, 6.9, 2.3, 6.0, 2.2, 5.0, 1.5,
41 6.9, 3.2, 5.7, 2.3, 5.6, 2.8, 4.9, 2.0, 7.7, 2.8, 6.7, 2.0, 6.3, 2.7, 4.9, 1.8, 6.7, 3.3, 5.7, 2.1,
42 7.2, 3.2, 6.0, 1.8, 6.2, 2.8, 4.8, 1.8, 6.1, 3.0, 4.9, 1.8, 6.4, 2.8, 5.6, 2.1, 7.2, 3.0, 5.8, 1.6,
43 7.4, 2.8, 6.1, 1.9, 7.9, 3.8, 6.4, 2.0, 6.4, 2.8, 5.6, 2.2, 6.3, 2.8, 5.1, 1.5, 6.1, 2.6, 5.6, 1.4,
44 7.7, 3.0, 6.1, 2.3, 6.3, 3.4, 5.6, 2.4, 6.4, 3.1, 5.5, 1.8, 6.0, 3.0, 4.8, 1.8, 6.9, 3.1, 5.4, 2.1,
45 6.7, 3.1, 5.6, 2.4, 6.9, 3.1, 5.1, 2.3, 5.8, 2.7, 5.1, 1.9, 6.8, 3.2, 5.9, 2.3, 6.7, 3.3, 5.7, 2.5,
46 6.7, 3.0, 5.2, 2.3, 6.3, 2.5, 5.0, 1.9, 6.5, 3.0, 5.2, 2.0, 6.2, 3.4, 5.4, 2.3, 5.9, 3.0, 5.1, 1.8
47 ]).unwrap();
48
49 let targets = vec![
51 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
52 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
53 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0,
54 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
55 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
56 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
57 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
58 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
59 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
60 ];
61 let target = Array1::from(targets);
62
63 let mut dataset = Dataset::new(data, Some(target));
65
66 let featurenames = vec![
68 "sepal_length".to_string(),
69 "sepal_width".to_string(),
70 "petal_length".to_string(),
71 "petal_width".to_string(),
72 ];
73
74 let targetnames = vec![
75 "setosa".to_string(),
76 "versicolor".to_string(),
77 "virginica".to_string(),
78 ];
79
80 let description = "Iris dataset: classic dataset for classification, clustering, and machine learning
81
82The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
83One class is linearly separable from the other two; the latter are not linearly separable from each other.
84
85Attributes:
86- sepal length in cm
87- sepal width in cm
88- petal length in cm
89- petal width in cm
90
91Target:
92- Iris Setosa
93- Iris Versicolour
94- Iris Virginica".to_string();
95
96 dataset = dataset
97 .with_featurenames(featurenames)
98 .with_targetnames(targetnames)
99 .with_description(description);
100
101 Ok(dataset)
102}
103
104#[allow(dead_code)]
106pub fn load_breast_cancer() -> Result<Dataset> {
107 #[rustfmt::skip]
110 let data = Array2::from_shape_vec((30, 5), vec![
111 17.99, 10.38, 122.8, 1001.0, 0.1184,
112 20.57, 17.77, 132.9, 1326.0, 0.08474,
113 19.69, 21.25, 130.0, 1203.0, 0.1096,
114 11.42, 20.38, 77.58, 386.1, 0.1425,
115 20.29, 14.34, 135.1, 1297.0, 0.1003,
116 12.45, 15.7, 82.57, 477.1, 0.1278,
117 18.25, 19.98, 119.6, 1040.0, 0.09463,
118 13.71, 20.83, 90.2, 577.9, 0.1189,
119 13.0, 21.82, 87.5, 519.8, 0.1273,
120 12.46, 24.04, 83.97, 475.9, 0.1186,
121 16.02, 23.24, 102.7, 797.8, 0.08206,
122 15.78, 17.89, 103.6, 781.0, 0.0971,
123 19.17, 24.8, 132.4, 1123.0, 0.0974,
124 15.85, 23.95, 103.7, 782.7, 0.08401,
125 13.73, 22.61, 93.6, 578.3, 0.1131,
126 14.54, 27.54, 96.73, 658.8, 0.1139,
127 14.68, 20.13, 94.74, 684.5, 0.09867,
128 16.13, 20.68, 108.1, 798.8, 0.117,
129 19.81, 22.15, 130.0, 1260.0, 0.09831,
130 13.54, 14.36, 87.46, 566.3, 0.09779,
131 13.08, 15.71, 85.63, 520.0, 0.1075,
132 9.504, 12.44, 60.34, 273.9, 0.1024,
133 15.34, 14.26, 102.5, 704.4, 0.1073,
134 21.16, 23.04, 137.2, 1404.0, 0.09428,
135 16.65, 21.38, 110.0, 904.6, 0.1121,
136 17.14, 16.4, 116.0, 912.7, 0.1186,
137 14.58, 21.53, 97.41, 644.8, 0.1054,
138 18.61, 20.25, 122.1, 1094.0, 0.0944,
139 15.3, 25.27, 102.4, 732.4, 0.1082,
140 17.57, 15.05, 115.0, 955.1, 0.09847
141 ]).unwrap();
142
143 let targets = vec![
145 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0,
146 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
147 ];
148 let target = Array1::from(targets);
149
150 let mut dataset = Dataset::new(data, Some(target));
152
153 let featurenames = vec![
155 "mean_radius".to_string(),
156 "meantexture".to_string(),
157 "mean_perimeter".to_string(),
158 "mean_area".to_string(),
159 "mean_smoothness".to_string(),
160 ];
161
162 let targetnames = vec!["malignant".to_string(), "benign".to_string()];
163
164 let description = "Breast Cancer Wisconsin (Diagnostic) Database
165
166Features computed from a digitized image of a fine needle aspirate (FNA) of a breast mass.
167They describe characteristics of the cell nuclei present in the image.
168
169(This is a simplified version of the dataset with only 5 features and 30 samples)
170
171Target:
172- Malignant
173- Benign"
174 .to_string();
175
176 dataset = dataset
177 .with_featurenames(featurenames)
178 .with_targetnames(targetnames)
179 .with_description(description);
180
181 Ok(dataset)
182}
183
184#[allow(dead_code)]
186pub fn load_digits() -> Result<Dataset> {
187 let n_samples = 50; let n_features = 16; let mut data = Array2::zeros((n_samples, n_features));
193 let mut target = Array1::zeros(n_samples);
194
195 #[rustfmt::skip]
197 let digit_patterns = [
198 [0., 1., 1., 0.,
200 1., 0., 0., 1.,
201 1., 0., 0., 1.,
202 0., 1., 1., 0.],
203 [0., 1., 0., 0.,
205 0., 1., 0., 0.,
206 0., 1., 0., 0.,
207 0., 1., 0., 0.],
208 [1., 1., 1., 0.,
210 0., 0., 1., 0.,
211 0., 1., 0., 0.,
212 1., 1., 1., 1.],
213 [1., 1., 1., 0.,
215 0., 0., 1., 0.,
216 1., 1., 1., 0.,
217 0., 0., 1., 0.],
218 [1., 0., 1., 0.,
220 1., 0., 1., 0.,
221 1., 1., 1., 1.,
222 0., 0., 1., 0.],
223 [1., 1., 1., 1.,
225 1., 0., 0., 0.,
226 1., 1., 1., 0.,
227 0., 0., 1., 1.],
228 [0., 1., 1., 0.,
230 1., 0., 0., 0.,
231 1., 1., 1., 0.,
232 0., 1., 1., 0.],
233 [1., 1., 1., 1.,
235 0., 0., 0., 1.,
236 0., 0., 1., 0.,
237 0., 1., 0., 0.],
238 [0., 1., 1., 0.,
240 1., 0., 0., 1.,
241 0., 1., 1., 0.,
242 1., 0., 0., 1.],
243 [0., 1., 1., 0.,
245 1., 0., 0., 1.,
246 0., 1., 1., 1.,
247 0., 0., 1., 0.],
248 ];
249
250 let mut rng = thread_rng();
252 let noise_level = 0.1;
253
254 for (digit, &pattern) in digit_patterns.iter().enumerate() {
255 for sample in 0..5 {
256 let idx = digit * 5 + sample;
257 target[idx] = digit as f64;
258
259 for (j, &pixel) in pattern.iter().enumerate() {
261 let noise = if pixel > 0.5 {
262 -noise_level * rng.random::<f64>()
263 } else {
264 noise_level * rng.random::<f64>()
265 };
266
267 let val: f64 = pixel + noise;
268 data[[idx, j]] = val.clamp(0.0, 1.0);
269 }
270 }
271 }
272
273 let mut dataset = Dataset::new(data, Some(target));
275
276 let featurenames: Vec<String> = (0..n_features).map(|i| format!("pixel_{i}")).collect();
278
279 let targetnames: Vec<String> = (0..10).map(|i| format!("{i}")).collect();
280
281 let description = "Optical recognition of handwritten digits dataset
282
283A simplified version with 50 samples (5 for each digit 0-9) and 16 features (4x4 pixel images).
284Each feature is the grayscale value of a pixel in the image.
285
286Target: Digit identity (0-9)"
287 .to_string();
288
289 dataset = dataset
290 .with_featurenames(featurenames)
291 .with_targetnames(targetnames)
292 .with_description(description);
293
294 Ok(dataset)
295}
296
297#[allow(dead_code)]
299pub fn load_boston() -> Result<Dataset> {
300 let n_samples = 30;
302 let n_features = 5;
303
304 #[rustfmt::skip]
305 let data = Array2::from_shape_vec((n_samples, n_features), vec![
306 0.00632, 18.0, 2.31, 0.538, 6.575,
307 0.02731, 0.0, 7.07, 0.469, 6.421,
308 0.02729, 0.0, 7.07, 0.469, 7.185,
309 0.03237, 0.0, 2.18, 0.458, 6.998,
310 0.06905, 0.0, 2.18, 0.458, 7.147,
311 0.02985, 0.0, 2.18, 0.458, 6.430,
312 0.08829, 12.5, 7.87, 0.524, 6.012,
313 0.14455, 12.5, 7.87, 0.524, 6.172,
314 0.21124, 12.5, 7.87, 0.524, 5.631,
315 0.17004, 12.5, 7.87, 0.524, 6.004,
316 0.22489, 12.5, 7.87, 0.524, 6.377,
317 0.11747, 12.5, 7.87, 0.524, 6.009,
318 0.09378, 12.5, 7.87, 0.524, 5.889,
319 0.62976, 0.0, 8.14, 0.538, 5.949,
320 0.63796, 0.0, 8.14, 0.538, 6.096,
321 0.62739, 0.0, 8.14, 0.538, 5.834,
322 1.05393, 0.0, 8.14, 0.538, 5.935,
323 0.7842, 0.0, 8.14, 0.538, 5.990,
324 0.80271, 0.0, 8.14, 0.538, 5.456,
325 0.7258, 0.0, 8.14, 0.538, 5.727,
326 1.25179, 0.0, 8.14, 0.538, 5.570,
327 0.85204, 0.0, 8.14, 0.538, 5.965,
328 1.23247, 0.0, 8.14, 0.538, 6.142,
329 0.98843, 0.0, 8.14, 0.538, 5.813,
330 0.75026, 0.0, 8.14, 0.538, 5.924,
331 0.84054, 0.0, 8.14, 0.538, 5.599,
332 0.67191, 0.0, 8.14, 0.538, 5.813,
333 0.95577, 0.0, 8.14, 0.538, 6.047,
334 0.77299, 0.0, 8.14, 0.538, 6.495,
335 1.00245, 0.0, 8.14, 0.538, 6.674
336 ]).unwrap();
337
338 let targets = vec![
339 24.0, 21.6, 34.7, 33.4, 36.2, 28.7, 22.9, 27.1, 16.5, 18.9, 15.0, 18.9, 21.7, 20.4, 18.2,
340 19.9, 23.1, 17.5, 20.2, 18.2, 13.6, 19.6, 15.2, 14.5, 15.6, 13.9, 16.6, 14.8, 18.4, 21.0,
341 ];
342 let target = Array1::from(targets);
343
344 let mut dataset = Dataset::new(data, Some(target));
346
347 let featurenames = vec![
349 "CRIM".to_string(),
350 "ZN".to_string(),
351 "INDUS".to_string(),
352 "CHAS".to_string(),
353 "NOX".to_string(),
354 ];
355
356 let feature_descriptions = vec![
357 "per capita crime rate by town".to_string(),
358 "proportion of residential land zoned for lots over 25,000 sq.ft.".to_string(),
359 "proportion of non-retail business acres per town".to_string(),
360 "Charles River dummy variable (= 1 if tract bounds river; 0 otherwise)".to_string(),
361 "nitric oxides concentration (parts per 10 million)".to_string(),
362 ];
363
364 let description = "Boston Housing Dataset (Simplified)
365
366A simplified version of the Boston housing dataset with 30 samples and 5 features.
367The target variable is the median value of owner-occupied homes in $1000s.
368
369This is a regression dataset."
370 .to_string();
371
372 dataset = dataset
373 .with_featurenames(featurenames)
374 .with_feature_descriptions(feature_descriptions)
375 .with_description(description);
376
377 Ok(dataset)
378}
379
380#[allow(dead_code)]
385pub fn load_diabetes() -> Result<Dataset> {
386 let mut rng = StdRng::seed_from_u64(42);
388
389 let n_samples = 442;
390 let n_features = 10;
391
392 let mut data = Vec::with_capacity(n_samples * n_features);
394 let mut targets = Vec::with_capacity(n_samples);
395
396 for _ in 0..n_samples {
397 let age = rng.random::<f64>() * 0.1 - 0.05;
399 let sex = if rng.random::<f64>() < 0.5 {
400 -0.05
401 } else {
402 0.05
403 };
404 let bmi = (rng.random::<f64>() * 0.12 - 0.06) + age * 0.3;
405 let bp = (rng.random::<f64>() * 0.1 - 0.05) + bmi * 0.4;
406 let s1 = (rng.random::<f64>() * 0.14 - 0.07) + bmi * 0.2;
407 let s2 = (rng.random::<f64>() * 0.16 - 0.08) + s1 * 0.5;
408 let s3 = (rng.random::<f64>() * 0.12 - 0.06) + age * 0.2;
409 let s4 = (rng.random::<f64>() * 0.12 - 0.06) + s1 * 0.3;
410 let s5 = (rng.random::<f64>() * 0.14 - 0.07) + bmi * 0.25;
411 let s6 = (rng.random::<f64>() * 0.1 - 0.05) + s5 * 0.4;
412
413 data.extend_from_slice(&[age, sex, bmi, bp, s1, s2, s3, s4, s5, s6]);
414
415 let target = 152.0
417 + 938.0 * bmi
418 + 519.0 * bp
419 + 324.0 * s1
420 + 217.0 * s5
421 + (rng.random::<f64>() * 40.0 - 20.0);
422 targets.push(target);
423 }
424
425 let data_array = Array2::from_shape_vec((n_samples, n_features), data).unwrap();
426 let target_array = Array1::from_vec(targets);
427
428 let featurenames = vec![
429 "age".to_string(),
430 "sex".to_string(),
431 "bmi".to_string(),
432 "bp".to_string(),
433 "s1".to_string(),
434 "s2".to_string(),
435 "s3".to_string(),
436 "s4".to_string(),
437 "s5".to_string(),
438 "s6".to_string(),
439 ];
440
441 let feature_descriptions = vec![
442 "Age".to_string(),
443 "Sex".to_string(),
444 "Body mass index".to_string(),
445 "Average blood pressure".to_string(),
446 "Total serum cholesterol".to_string(),
447 "Low-density lipoproteins".to_string(),
448 "High-density lipoproteins".to_string(),
449 "Total cholesterol / HDL".to_string(),
450 "Log of serum triglycerides level".to_string(),
451 "Blood sugar level".to_string(),
452 ];
453
454 let description = "Diabetes dataset for regression. A synthetic version of the classic diabetes dataset with 442 samples and 10 physiological features.".to_string();
455
456 let dataset = Dataset::new(data_array, Some(target_array))
457 .with_featurenames(featurenames)
458 .with_feature_descriptions(feature_descriptions)
459 .with_description(description);
460
461 Ok(dataset)
462}
463
/// Unit tests for the embedded toy dataset loaders.
#[cfg(test)]
mod tests {
    use super::*;

    /// Iris: 150 x 4 with the expected feature/target names and labels in 0..=2.
    #[test]
    fn test_load_iris() {
        let dataset = load_iris().unwrap();

        assert_eq!(dataset.n_samples(), 150);
        assert_eq!(dataset.n_features(), 4);
        assert!(dataset.target.is_some());
        assert!(dataset.description.is_some());
        assert!(dataset.featurenames.is_some());
        assert!(dataset.targetnames.is_some());

        let featurenames = dataset.featurenames.as_ref().unwrap();
        assert_eq!(featurenames.len(), 4);
        assert_eq!(featurenames[0], "sepal_length");
        assert_eq!(featurenames[3], "petal_width");

        let targetnames = dataset.targetnames.as_ref().unwrap();
        assert_eq!(targetnames.len(), 3);
        assert!(targetnames.contains(&"setosa".to_string()));
        assert!(targetnames.contains(&"versicolor".to_string()));
        assert!(targetnames.contains(&"virginica".to_string()));

        let target = dataset.target.as_ref().unwrap();
        for &val in target.iter() {
            assert!((0.0..=2.0).contains(&val));
        }
    }

    /// Breast cancer: 30 x 5 with binary labels.
    #[test]
    fn test_load_breast_cancer() {
        let dataset = load_breast_cancer().unwrap();

        assert_eq!(dataset.n_samples(), 30);
        assert_eq!(dataset.n_features(), 5);
        assert!(dataset.target.is_some());
        assert!(dataset.description.is_some());
        assert!(dataset.featurenames.is_some());
        assert!(dataset.targetnames.is_some());

        let featurenames = dataset.featurenames.as_ref().unwrap();
        assert_eq!(featurenames.len(), 5);
        assert_eq!(featurenames[0], "mean_radius");
        assert_eq!(featurenames[4], "mean_smoothness");

        let targetnames = dataset.targetnames.as_ref().unwrap();
        assert_eq!(targetnames.len(), 2);
        assert!(targetnames.contains(&"malignant".to_string()));
        assert!(targetnames.contains(&"benign".to_string()));

        let target = dataset.target.as_ref().unwrap();
        for &val in target.iter() {
            assert!(val == 0.0 || val == 1.0);
        }
    }

    /// Digits: 50 x 16 with labels in 0..=9 and pixel values inside [0, 1].
    #[test]
    fn test_load_digits() {
        let dataset = load_digits().unwrap();

        assert_eq!(dataset.n_samples(), 50);
        assert_eq!(dataset.n_features(), 16);
        assert!(dataset.target.is_some());
        assert!(dataset.description.is_some());
        assert!(dataset.featurenames.is_some());
        assert!(dataset.targetnames.is_some());

        let featurenames = dataset.featurenames.as_ref().unwrap();
        assert_eq!(featurenames.len(), 16);
        assert_eq!(featurenames[0], "pixel_0");
        assert_eq!(featurenames[15], "pixel_15");

        let targetnames = dataset.targetnames.as_ref().unwrap();
        assert_eq!(targetnames.len(), 10);
        for i in 0..10 {
            assert!(targetnames.contains(&i.to_string()));
        }

        let target = dataset.target.as_ref().unwrap();
        for &val in target.iter() {
            assert!((0.0..=9.0).contains(&val));
        }

        for row in dataset.data.rows() {
            for &pixel in row.iter() {
                assert!((0.0..=1.0).contains(&pixel));
            }
        }
    }

    /// Boston: 30 x 5 with feature descriptions and plausible home values.
    #[test]
    fn test_load_boston() {
        let dataset = load_boston().unwrap();

        assert_eq!(dataset.n_samples(), 30);
        assert_eq!(dataset.n_features(), 5);
        assert!(dataset.target.is_some());
        assert!(dataset.description.is_some());
        assert!(dataset.featurenames.is_some());
        assert!(dataset.feature_descriptions.is_some());

        let featurenames = dataset.featurenames.as_ref().unwrap();
        assert_eq!(featurenames.len(), 5);
        assert_eq!(featurenames[0], "CRIM");
        assert_eq!(featurenames[4], "NOX");

        let feature_descriptions = dataset.feature_descriptions.as_ref().unwrap();
        assert_eq!(feature_descriptions.len(), 5);
        assert!(feature_descriptions[0].contains("crime rate"));

        let target = dataset.target.as_ref().unwrap();
        for &val in target.iter() {
            assert!(val > 0.0 && val < 100.0);
        }
    }

    /// Diabetes: 442 x 10, seeded generation so repeated loads are identical.
    #[test]
    fn test_load_diabetes() {
        let dataset = load_diabetes().unwrap();

        assert_eq!(dataset.n_samples(), 442);
        assert_eq!(dataset.n_features(), 10);
        assert!(dataset.target.is_some());
        assert!(dataset.description.is_some());
        assert!(dataset.featurenames.is_some());
        assert!(dataset.feature_descriptions.is_some());

        let featurenames = dataset.featurenames.as_ref().unwrap();
        assert_eq!(featurenames.len(), 10);
        assert_eq!(featurenames[0], "age");
        assert_eq!(featurenames[9], "s6");

        // The sex column is generated as one of two fixed codes.
        for &sex in dataset.data.column(1).iter() {
            assert!(sex == -0.05 || sex == 0.05);
        }

        // The generator is seeded with a constant, so a second load must match.
        let again = load_diabetes().unwrap();
        assert_eq!(dataset.data, again.data);
        assert_eq!(dataset.target, again.target);
    }

    /// Cross-checks that metadata lengths agree with the data shape for every loader.
    #[test]
    fn test_all_datasets_have_consistentshapes() {
        let datasets = vec![
            ("iris", load_iris().unwrap()),
            ("breast_cancer", load_breast_cancer().unwrap()),
            ("digits", load_digits().unwrap()),
            ("boston", load_boston().unwrap()),
            ("diabetes", load_diabetes().unwrap()),
        ];

        for (name, dataset) in datasets {
            if let Some(ref target) = dataset.target {
                assert_eq!(
                    dataset.data.nrows(),
                    target.len(),
                    "Dataset '{name}' has inconsistent sample counts"
                );
            }

            if let Some(ref featurenames) = dataset.featurenames {
                assert_eq!(
                    dataset.data.ncols(),
                    featurenames.len(),
                    "Dataset '{name}' has inconsistent feature count"
                );
            }

            if let Some(ref feature_descriptions) = dataset.feature_descriptions {
                assert_eq!(
                    dataset.data.ncols(),
                    feature_descriptions.len(),
                    "Dataset '{name}' has inconsistent feature description count"
                );
            }

            assert!(
                dataset.description.is_some(),
                "Dataset '{name}' missing description"
            );
        }
    }
}