sklears_datasets/generators/basic.rs

use scirs2_core::ndarray::{s, Array1, Array2};
use scirs2_core::random::prelude::*;
use scirs2_core::random::rngs::StdRng;
use scirs2_core::random::{Normal, StandardNormal};
use sklears_core::error::{Result, SklearsError};
use std::f64::consts::PI;

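/// Generate isotropic Gaussian blobs for clustering.
///
/// Draws `centers` cluster centers uniformly from `[-10, 10)` in every feature,
/// then samples each point from a normal distribution with standard deviation
/// `cluster_std` around its assigned center. Samples are split as evenly as
/// possible across the centers; any remainder goes to the first centers.
/// Returns the `(n_samples, n_features)` data matrix and the integer cluster
/// label of each row.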
pub fn make_blobs(
    n_samples: usize,
    n_features: usize,
    centers: usize,
    cluster_std: f64,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<i32>)> {
    if n_samples == 0 || n_features == 0 || centers == 0 {
        return Err(SklearsError::InvalidInput(
            "n_samples, n_features, and centers must be positive".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    // Draw the cluster centers uniformly in [-10, 10) for every feature.
    let mut center_points = Array2::zeros((centers, n_features));
    for i in 0..centers {
        for j in 0..n_features {
            center_points[[i, j]] = rng.gen_range(-10.0..10.0);
        }
    }

    // Split samples as evenly as possible; the first `extra_samples` centers
    // receive one additional sample each.
    let samples_per_center = n_samples / centers;
    let extra_samples = n_samples % centers;

    let mut x = Array2::zeros((n_samples, n_features));
    let mut y = Array1::zeros(n_samples);

    // Shared Gaussian noise distribution; an invalid `cluster_std` (negative
    // or NaN) is reported as an input error.
    let normal = Normal::new(0.0, cluster_std).map_err(|_| {
        SklearsError::InvalidInput("cluster_std must be non-negative".to_string())
    })?;

    let mut sample_idx = 0;

    for center_idx in 0..centers {
        let n_samples_for_center = if center_idx < extra_samples {
            samples_per_center + 1
        } else {
            samples_per_center
        };

        for _ in 0..n_samples_for_center {
            y[sample_idx] = center_idx as i32;

            for feature_idx in 0..n_features {
                let center_value = center_points[[center_idx, feature_idx]];
                let noise: f64 = rng.sample(normal);
                x[[sample_idx, feature_idx]] = center_value + noise;
            }

            sample_idx += 1;
        }
    }

    Ok((x, y))
}

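/// Generate a random classification problem.
///
/// Labels are drawn uniformly from `0..n_classes`. The first `n_informative`
/// features are standard normal draws shifted by a class-dependent offset, the
/// next `n_redundant` features are noisy linear scalings of single informative
/// features, and any remaining features are pure noise.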
pub fn make_classification(
    n_samples: usize,
    n_features: usize,
    n_informative: usize,
    n_redundant: usize,
    n_classes: usize,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<i32>)> {
    if n_samples == 0 || n_features == 0 || n_classes < 2 {
        return Err(SklearsError::InvalidInput(
            "n_samples and n_features must be positive, n_classes must be >= 2".to_string(),
        ));
    }

    if n_informative + n_redundant > n_features {
        return Err(SklearsError::InvalidInput(
            "n_informative + n_redundant cannot exceed n_features".to_string(),
        ));
    }

    // Redundant features are built from informative ones, so at least one
    // informative feature is required when redundant features are requested.
    if n_redundant > 0 && n_informative == 0 {
        return Err(SklearsError::InvalidInput(
            "n_informative must be positive when n_redundant > 0".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    let normal = StandardNormal;

    let mut x = Array2::zeros((n_samples, n_features));
    let mut y = Array1::zeros(n_samples);

    // Assign class labels uniformly at random.
    for i in 0..n_samples {
        y[i] = rng.gen_range(0..n_classes) as i32;
    }

    // Informative features: standard normal noise shifted by a class-dependent offset.
    for i in 0..n_samples {
        let class = y[i] as usize;
        for j in 0..n_informative {
            let class_offset = (class as f64 - (n_classes as f64 - 1.0) / 2.0) * 2.0;
            x[[i, j]] = rng.sample::<f64, _>(normal) + class_offset;
        }
    }

    // Redundant features: a random scaling of one informative feature plus small noise.
    for j in n_informative..(n_informative + n_redundant) {
        let informative_idx = rng.gen_range(0..n_informative);
        let weight = rng.gen_range(-1.0..1.0);
        for i in 0..n_samples {
            x[[i, j]] = x[[i, informative_idx]] * weight + rng.sample::<f64, _>(normal) * 0.1;
        }
    }

    // Remaining features carry no class signal.
    for j in (n_informative + n_redundant)..n_features {
        for i in 0..n_samples {
            x[[i, j]] = rng.sample::<f64, _>(normal);
        }
    }

    Ok((x, y))
}

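/// Generate a random regression problem.
///
/// Features are standard normal draws. The target is a linear combination of
/// the first `n_informative` features with coefficients drawn uniformly from
/// `[-100, 100)`, plus optional Gaussian noise with standard deviation `noise`.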
pub fn make_regression(
    n_samples: usize,
    n_features: usize,
    n_informative: usize,
    noise: f64,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<f64>)> {
    if n_samples == 0 || n_features == 0 {
        return Err(SklearsError::InvalidInput(
            "n_samples and n_features must be positive".to_string(),
        ));
    }

    if n_informative > n_features {
        return Err(SklearsError::InvalidInput(
            "n_informative cannot exceed n_features".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    let normal = StandardNormal;

    // Design matrix of standard normal features.
    let mut x = Array2::zeros((n_samples, n_features));
    for i in 0..n_samples {
        for j in 0..n_features {
            x[[i, j]] = rng.sample::<f64, _>(normal);
        }
    }

    // Only the first `n_informative` coefficients are non-zero.
    let mut coef = Array1::zeros(n_features);
    for i in 0..n_informative {
        coef[i] = rng.gen_range(-1.0..1.0) * 100.0;
    }

    // Target is the linear model plus optional Gaussian noise.
    let mut y = Array1::zeros(n_samples);
    for i in 0..n_samples {
        let mut target = 0.0;
        for j in 0..n_informative {
            target += x[[i, j]] * coef[j];
        }

        if noise > 0.0 {
            let noise_dist = Normal::new(0.0, noise).unwrap();
            target += rng.sample(noise_dist);
        }

        y[i] = target;
    }

    Ok((x, y))
}

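/// Generate a large circle containing a smaller circle in 2D.
///
/// Roughly half the samples (label 0) lie on the unit circle and the rest
/// (label 1) on a concentric circle scaled by `factor`. Optional Gaussian
/// noise with the given standard deviation is added to both coordinates.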
pub fn make_circles(
    n_samples: usize,
    noise: Option<f64>,
    factor: f64,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<i32>)> {
    if n_samples == 0 {
        return Err(SklearsError::InvalidInput(
            "n_samples must be positive".to_string(),
        ));
    }

    if factor <= 0.0 || factor >= 1.0 {
        return Err(SklearsError::InvalidInput(
            "factor must be between 0 and 1".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    let n_samples_out = n_samples / 2;
    let n_samples_in = n_samples - n_samples_out;

    let mut x = Array2::zeros((n_samples, 2));
    let mut y = Array1::zeros(n_samples);

    // Outer circle: points at random angles on the unit circle.
    for i in 0..n_samples_out {
        let angle = rng.gen::<f64>() * 2.0 * PI;
        x[[i, 0]] = angle.cos();
        x[[i, 1]] = angle.sin();
        y[i] = 0;
    }

    // Inner circle: same construction, scaled by `factor`.
    for i in 0..n_samples_in {
        let angle = rng.gen::<f64>() * 2.0 * PI;
        x[[n_samples_out + i, 0]] = factor * angle.cos();
        x[[n_samples_out + i, 1]] = factor * angle.sin();
        y[n_samples_out + i] = 1;
    }

    // Optionally perturb both coordinates with Gaussian noise.
    if let Some(noise_level) = noise {
        if noise_level > 0.0 {
            let noise_dist = Normal::new(0.0, noise_level).unwrap();
            for i in 0..n_samples {
                x[[i, 0]] += rng.sample(noise_dist);
                x[[i, 1]] += rng.sample(noise_dist);
            }
        }
    }

    Ok((x, y))
}

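/// Generate two interleaving half circles ("moons") in 2D.
///
/// Roughly half the samples (label 0) trace the upper half circle and the rest
/// (label 1) a shifted, mirrored half circle. Optional Gaussian noise with the
/// given standard deviation is added to both coordinates.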
pub fn make_moons(
    n_samples: usize,
    noise: Option<f64>,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<i32>)> {
    if n_samples == 0 {
        return Err(SklearsError::InvalidInput(
            "n_samples must be positive".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    let n_samples_out = n_samples / 2;
    let n_samples_in = n_samples - n_samples_out;

    let mut x = Array2::zeros((n_samples, 2));
    let mut y = Array1::zeros(n_samples);

    // First moon: upper half circle.
    for i in 0..n_samples_out {
        let t = rng.gen::<f64>() * PI;
        x[[i, 0]] = t.cos();
        x[[i, 1]] = t.sin();
        y[i] = 0;
    }

    // Second moon: mirrored half circle shifted to interleave with the first.
    for i in 0..n_samples_in {
        let t = rng.gen::<f64>() * PI;
        x[[n_samples_out + i, 0]] = 1.0 - t.cos();
        x[[n_samples_out + i, 1]] = 1.0 - t.sin() - 0.5;
        y[n_samples_out + i] = 1;
    }

    // Optionally perturb both coordinates with Gaussian noise.
    if let Some(noise_level) = noise {
        if noise_level > 0.0 {
            let noise_dist = Normal::new(0.0, noise_level).unwrap();
            for i in 0..n_samples {
                x[[i, 0]] += rng.sample(noise_dist);
                x[[i, 1]] += rng.sample(noise_dist);
            }
        }
    }

    Ok((x, y))
}

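/// Generate isotropic Gaussian samples labelled by quantiles of their norm.
///
/// Samples are standard normal in every feature and are ranked by Euclidean
/// norm: the `n_samples / n_classes` points closest to the origin get class 0,
/// the next shell class 1, and so on, with any remainder assigned to the last class.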
pub fn make_gaussian_quantiles(
    n_samples: usize,
    n_features: usize,
    n_classes: usize,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array1<i32>)> {
    if n_samples == 0 || n_features == 0 || n_classes < 2 {
        return Err(SklearsError::InvalidInput(
            "n_samples, n_features must be positive, n_classes must be >= 2".to_string(),
        ));
    }

    // Each class needs at least one sample for the quantile split below.
    if n_samples < n_classes {
        return Err(SklearsError::InvalidInput(
            "n_samples must be at least n_classes".to_string(),
        ));
    }

    let mut rng = if let Some(seed) = random_state {
        StdRng::seed_from_u64(seed)
    } else {
        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
    };

    let normal = StandardNormal;

    // Standard normal samples in every feature.
    let mut x = Array2::zeros((n_samples, n_features));
    for i in 0..n_samples {
        for j in 0..n_features {
            x[[i, j]] = rng.sample::<f64, _>(normal);
        }
    }

    // Euclidean norm of every sample.
    let mut norms = Array1::zeros(n_samples);
    for i in 0..n_samples {
        let norm = x.slice(s![i, ..]).mapv(|v| v * v).sum().sqrt();
        norms[i] = norm;
    }

    // Sort sample indices by norm so classes correspond to concentric shells.
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.sort_by(|&a, &b| norms[a].partial_cmp(&norms[b]).unwrap());

    // Assign equal-sized quantile chunks to classes; any remainder joins the last class.
    let mut y = Array1::zeros(n_samples);
    let samples_per_class = n_samples / n_classes;

    for (class_idx, chunk) in indices.chunks(samples_per_class).enumerate() {
        let class = (class_idx.min(n_classes - 1)) as i32;
        for &sample_idx in chunk {
            y[sample_idx] = class;
        }
    }

    Ok((x, y))
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_make_blobs() {
        let (x, y) = make_blobs(100, 2, 3, 1.0, Some(42)).unwrap();
        assert_eq!(x.shape(), &[100, 2]);
        assert_eq!(y.len(), 100);

        let mut classes = y.iter().cloned().collect::<Vec<_>>();
        classes.sort();
        classes.dedup();
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn test_make_classification() {
        let (x, y) = make_classification(100, 20, 10, 5, 3, Some(42)).unwrap();
        assert_eq!(x.shape(), &[100, 20]);
        assert_eq!(y.len(), 100);

        let mut classes = y.iter().cloned().collect::<Vec<_>>();
        classes.sort();
        classes.dedup();
        assert!(classes.len() <= 3);
    }

    #[test]
    fn test_make_regression() {
        let (x, y) = make_regression(50, 10, 5, 0.1, Some(42)).unwrap();
        assert_eq!(x.shape(), &[50, 10]);
        assert_eq!(y.len(), 50);

        let mean = y.mean().unwrap();
        let variance = y.mapv(|v| (v - mean).powi(2)).mean().unwrap();
        assert!(variance > 0.0);
    }

    #[test]
    fn test_make_circles() {
        let (x, y) = make_circles(100, Some(0.1), 0.4, Some(42)).unwrap();
        assert_eq!(x.shape(), &[100, 2]);
        assert_eq!(y.len(), 100);

        let mut classes = y.iter().cloned().collect::<Vec<_>>();
        classes.sort();
        classes.dedup();
        assert_eq!(classes.len(), 2);
    }

    #[test]
    fn test_make_moons() {
        let (x, y) = make_moons(80, Some(0.15), Some(42)).unwrap();
        assert_eq!(x.shape(), &[80, 2]);
        assert_eq!(y.len(), 80);

        let mut classes = y.iter().cloned().collect::<Vec<_>>();
        classes.sort();
        classes.dedup();
        assert_eq!(classes.len(), 2);
    }

    #[test]
    fn test_make_gaussian_quantiles() {
        let (x, y) = make_gaussian_quantiles(120, 5, 3, Some(42)).unwrap();
        assert_eq!(x.shape(), &[120, 5]);
        assert_eq!(y.len(), 120);

        let mut classes = y.iter().cloned().collect::<Vec<_>>();
        classes.sort();
        classes.dedup();
        assert!(classes.len() <= 3);
    }

    #[test]
    fn test_invalid_inputs() {
        assert!(make_blobs(0, 2, 3, 1.0, Some(42)).is_err());

        assert!(make_circles(100, Some(0.1), 1.5, Some(42)).is_err());

        assert!(make_classification(100, 5, 10, 0, 3, Some(42)).is_err());
    }
}