sklears_datasets/generators/
basic.rs

1//! Basic synthetic data generators
2//!
3//! This module contains fundamental dataset generation functions including
4//! blobs, classification, regression, circles, moons, and other basic patterns.
5
6use scirs2_core::ndarray::{s, Array1, Array2};
7use scirs2_core::random::prelude::*;
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::{Normal, StandardNormal};
10use sklears_core::error::{Result, SklearsError};
11use std::f64::consts::PI;
12
13pub fn make_blobs(
14    n_samples: usize,
15    n_features: usize,
16    centers: usize,
17    cluster_std: f64,
18    random_state: Option<u64>,
19) -> Result<(Array2<f64>, Array1<i32>)> {
20    if n_samples == 0 || n_features == 0 || centers == 0 {
21        return Err(SklearsError::InvalidInput(
22            "n_samples, n_features, and centers must be positive".to_string(),
23        ));
24    }
25
26    let mut rng = if let Some(seed) = random_state {
27        StdRng::seed_from_u64(seed)
28    } else {
29        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
30    };
31
32    // Generate random centers
33    let mut center_points = Array2::zeros((centers, n_features));
34    for i in 0..centers {
35        for j in 0..n_features {
36            center_points[[i, j]] = rng.gen_range(-10.0..10.0);
37        }
38    }
39
40    // Assign samples to centers
41    let samples_per_center = n_samples / centers;
42    let extra_samples = n_samples % centers;
43
44    let mut x = Array2::zeros((n_samples, n_features));
45    let mut y = Array1::zeros(n_samples);
46
47    let mut sample_idx = 0;
48
49    for center_idx in 0..centers {
50        let n_samples_for_center = if center_idx < extra_samples {
51            samples_per_center + 1
52        } else {
53            samples_per_center
54        };
55
56        let normal = Normal::new(0.0, cluster_std).unwrap();
57
58        for _ in 0..n_samples_for_center {
59            y[sample_idx] = center_idx as i32;
60
61            for feature_idx in 0..n_features {
62                let center_value = center_points[[center_idx, feature_idx]];
63                let noise: f64 = rng.sample(normal);
64                x[[sample_idx, feature_idx]] = center_value + noise;
65            }
66
67            sample_idx += 1;
68        }
69    }
70
71    Ok((x, y))
72}
73
74pub fn make_classification(
75    n_samples: usize,
76    n_features: usize,
77    n_informative: usize,
78    n_redundant: usize,
79    n_classes: usize,
80    random_state: Option<u64>,
81) -> Result<(Array2<f64>, Array1<i32>)> {
82    if n_samples == 0 || n_features == 0 || n_classes < 2 {
83        return Err(SklearsError::InvalidInput(
84            "n_samples and n_features must be positive, n_classes must be >= 2".to_string(),
85        ));
86    }
87
88    if n_informative + n_redundant > n_features {
89        return Err(SklearsError::InvalidInput(
90            "n_informative + n_redundant cannot exceed n_features".to_string(),
91        ));
92    }
93
94    let mut rng = if let Some(seed) = random_state {
95        StdRng::seed_from_u64(seed)
96    } else {
97        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
98    };
99
100    let normal = StandardNormal;
101
102    // Generate informative features
103    let mut x = Array2::zeros((n_samples, n_features));
104    let mut y = Array1::zeros(n_samples);
105
106    // Assign classes randomly
107    for i in 0..n_samples {
108        y[i] = rng.gen_range(0..n_classes) as i32;
109    }
110
111    // Generate informative features based on class
112    for i in 0..n_samples {
113        let class = y[i] as usize;
114        for j in 0..n_informative {
115            let class_offset = (class as f64 - (n_classes as f64 - 1.0) / 2.0) * 2.0;
116            x[[i, j]] = rng.sample::<f64, _>(normal) + class_offset;
117        }
118    }
119
120    // Generate redundant features as linear combinations of informative features
121    for j in n_informative..(n_informative + n_redundant) {
122        let informative_idx = rng.gen_range(0..n_informative);
123        let weight = rng.gen_range(-1.0..1.0);
124        for i in 0..n_samples {
125            x[[i, j]] = x[[i, informative_idx]] * weight + rng.sample::<f64, _>(normal) * 0.1;
126        }
127    }
128
129    // Fill remaining features with random noise
130    for j in (n_informative + n_redundant)..n_features {
131        for i in 0..n_samples {
132            x[[i, j]] = rng.sample::<f64, _>(normal);
133        }
134    }
135
136    Ok((x, y))
137}
138
139pub fn make_regression(
140    n_samples: usize,
141    n_features: usize,
142    n_informative: usize,
143    noise: f64,
144    random_state: Option<u64>,
145) -> Result<(Array2<f64>, Array1<f64>)> {
146    if n_samples == 0 || n_features == 0 {
147        return Err(SklearsError::InvalidInput(
148            "n_samples and n_features must be positive".to_string(),
149        ));
150    }
151
152    if n_informative > n_features {
153        return Err(SklearsError::InvalidInput(
154            "n_informative cannot exceed n_features".to_string(),
155        ));
156    }
157
158    let mut rng = if let Some(seed) = random_state {
159        StdRng::seed_from_u64(seed)
160    } else {
161        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
162    };
163
164    let normal = StandardNormal;
165
166    // Generate feature matrix
167    let mut x = Array2::zeros((n_samples, n_features));
168    for i in 0..n_samples {
169        for j in 0..n_features {
170            x[[i, j]] = rng.sample::<f64, _>(normal);
171        }
172    }
173
174    // Generate true coefficients for informative features
175    let mut coef = Array1::zeros(n_features);
176    for i in 0..n_informative {
177        coef[i] = rng.gen_range(-1.0..1.0) * 100.0;
178    }
179
180    // Compute target values
181    let mut y = Array1::zeros(n_samples);
182    for i in 0..n_samples {
183        let mut target = 0.0;
184        for j in 0..n_informative {
185            target += x[[i, j]] * coef[j];
186        }
187
188        // Add noise
189        if noise > 0.0 {
190            let noise_dist = Normal::new(0.0, noise).unwrap();
191            target += rng.sample(noise_dist);
192        }
193
194        y[i] = target;
195    }
196
197    Ok((x, y))
198}
199
200pub fn make_circles(
201    n_samples: usize,
202    noise: Option<f64>,
203    factor: f64,
204    random_state: Option<u64>,
205) -> Result<(Array2<f64>, Array1<i32>)> {
206    if n_samples == 0 {
207        return Err(SklearsError::InvalidInput(
208            "n_samples must be positive".to_string(),
209        ));
210    }
211
212    if factor <= 0.0 || factor >= 1.0 {
213        return Err(SklearsError::InvalidInput(
214            "factor must be between 0 and 1".to_string(),
215        ));
216    }
217
218    let mut rng = if let Some(seed) = random_state {
219        StdRng::seed_from_u64(seed)
220    } else {
221        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
222    };
223
224    let n_samples_out = n_samples / 2;
225    let n_samples_in = n_samples - n_samples_out;
226
227    let mut x = Array2::zeros((n_samples, 2));
228    let mut y = Array1::zeros(n_samples);
229
230    // Generate outer circle
231    for i in 0..n_samples_out {
232        let angle = rng.gen::<f64>() * 2.0 * PI;
233        x[[i, 0]] = angle.cos();
234        x[[i, 1]] = angle.sin();
235        y[i] = 0;
236    }
237
238    // Generate inner circle
239    for i in 0..n_samples_in {
240        let angle = rng.gen::<f64>() * 2.0 * PI;
241        x[[n_samples_out + i, 0]] = factor * angle.cos();
242        x[[n_samples_out + i, 1]] = factor * angle.sin();
243        y[n_samples_out + i] = 1;
244    }
245
246    // Add noise if specified
247    if let Some(noise_level) = noise {
248        if noise_level > 0.0 {
249            let noise_dist = Normal::new(0.0, noise_level).unwrap();
250            for i in 0..n_samples {
251                x[[i, 0]] += rng.sample(noise_dist);
252                x[[i, 1]] += rng.sample(noise_dist);
253            }
254        }
255    }
256
257    Ok((x, y))
258}
259
260pub fn make_moons(
261    n_samples: usize,
262    noise: Option<f64>,
263    random_state: Option<u64>,
264) -> Result<(Array2<f64>, Array1<i32>)> {
265    if n_samples == 0 {
266        return Err(SklearsError::InvalidInput(
267            "n_samples must be positive".to_string(),
268        ));
269    }
270
271    let mut rng = if let Some(seed) = random_state {
272        StdRng::seed_from_u64(seed)
273    } else {
274        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
275    };
276
277    let n_samples_out = n_samples / 2;
278    let n_samples_in = n_samples - n_samples_out;
279
280    let mut x = Array2::zeros((n_samples, 2));
281    let mut y = Array1::zeros(n_samples);
282
283    // Generate outer moon
284    for i in 0..n_samples_out {
285        let t = rng.gen::<f64>() * PI;
286        x[[i, 0]] = t.cos();
287        x[[i, 1]] = t.sin();
288        y[i] = 0;
289    }
290
291    // Generate inner moon
292    for i in 0..n_samples_in {
293        let t = rng.gen::<f64>() * PI;
294        x[[n_samples_out + i, 0]] = 1.0 - t.cos();
295        x[[n_samples_out + i, 1]] = 1.0 - t.sin() - 0.5;
296        y[n_samples_out + i] = 1;
297    }
298
299    // Add noise if specified
300    if let Some(noise_level) = noise {
301        if noise_level > 0.0 {
302            let noise_dist = Normal::new(0.0, noise_level).unwrap();
303            for i in 0..n_samples {
304                x[[i, 0]] += rng.sample(noise_dist);
305                x[[i, 1]] += rng.sample(noise_dist);
306            }
307        }
308    }
309
310    Ok((x, y))
311}
312
313pub fn make_gaussian_quantiles(
314    n_samples: usize,
315    n_features: usize,
316    n_classes: usize,
317    random_state: Option<u64>,
318) -> Result<(Array2<f64>, Array1<i32>)> {
319    if n_samples == 0 || n_features == 0 || n_classes < 2 {
320        return Err(SklearsError::InvalidInput(
321            "n_samples, n_features must be positive, n_classes must be >= 2".to_string(),
322        ));
323    }
324
325    let mut rng = if let Some(seed) = random_state {
326        StdRng::seed_from_u64(seed)
327    } else {
328        StdRng::from_rng(&mut scirs2_core::random::thread_rng())
329    };
330
331    let normal = StandardNormal;
332
333    // Generate multivariate normal data
334    let mut x = Array2::zeros((n_samples, n_features));
335    for i in 0..n_samples {
336        for j in 0..n_features {
337            x[[i, j]] = rng.sample::<f64, _>(normal);
338        }
339    }
340
341    // Calculate norm for each sample
342    let mut norms = Array1::zeros(n_samples);
343    for i in 0..n_samples {
344        let norm = x.slice(s![i, ..]).mapv(|x| x * x).sum().sqrt();
345        norms[i] = norm;
346    }
347
348    // Sort by norm to create quantiles
349    let mut indices: Vec<usize> = (0..n_samples).collect();
350    indices.sort_by(|&a, &b| norms[a].partial_cmp(&norms[b]).unwrap());
351
352    // Assign classes based on quantiles
353    let mut y = Array1::zeros(n_samples);
354    let samples_per_class = n_samples / n_classes;
355
356    for (class_idx, chunk) in indices.chunks(samples_per_class).enumerate() {
357        let class = (class_idx.min(n_classes - 1)) as i32;
358        for &sample_idx in chunk {
359            y[sample_idx] = class;
360        }
361    }
362
363    Ok((x, y))
364}
365
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the sorted, de-duplicated label values found in `labels`.
    fn distinct_labels(labels: &Array1<i32>) -> Vec<i32> {
        let mut seen: Vec<i32> = labels.iter().cloned().collect();
        seen.sort();
        seen.dedup();
        seen
    }

    #[test]
    fn test_make_blobs() {
        let (features, labels) = make_blobs(100, 2, 3, 1.0, Some(42)).unwrap();
        assert_eq!(features.shape(), &[100, 2]);
        assert_eq!(labels.len(), 100);

        // Every one of the three clusters must be represented.
        assert_eq!(distinct_labels(&labels).len(), 3);
    }

    #[test]
    fn test_make_classification() {
        let (features, labels) = make_classification(100, 20, 10, 5, 3, Some(42)).unwrap();
        assert_eq!(features.shape(), &[100, 20]);
        assert_eq!(labels.len(), 100);

        // Labels are drawn at random, so only an upper bound is asserted.
        assert!(distinct_labels(&labels).len() <= 3);
    }

    #[test]
    fn test_make_regression() {
        let (features, targets) = make_regression(50, 10, 5, 0.1, Some(42)).unwrap();
        assert_eq!(features.shape(), &[50, 10]);
        assert_eq!(targets.len(), 50);

        // The targets must not all be identical.
        let mean = targets.mean().unwrap();
        let squared_deviations = targets.mapv(|v| (v - mean).powi(2));
        let variance = squared_deviations.mean().unwrap();
        assert!(variance > 0.0);
    }

    #[test]
    fn test_make_circles() {
        let (features, labels) = make_circles(100, Some(0.1), 0.4, Some(42)).unwrap();
        assert_eq!(features.shape(), &[100, 2]);
        assert_eq!(labels.len(), 100);

        // Both the outer and the inner circle must be present.
        assert_eq!(distinct_labels(&labels).len(), 2);
    }

    #[test]
    fn test_make_moons() {
        let (features, labels) = make_moons(80, Some(0.15), Some(42)).unwrap();
        assert_eq!(features.shape(), &[80, 2]);
        assert_eq!(labels.len(), 80);

        // Both moons must be present.
        assert_eq!(distinct_labels(&labels).len(), 2);
    }

    #[test]
    fn test_make_gaussian_quantiles() {
        let (features, labels) = make_gaussian_quantiles(120, 5, 3, Some(42)).unwrap();
        assert_eq!(features.shape(), &[120, 5]);
        assert_eq!(labels.len(), 120);

        // No more labels than requested classes may appear.
        assert!(distinct_labels(&labels).len() <= 3);
    }

    #[test]
    fn test_invalid_inputs() {
        // Zero samples is rejected.
        let zero_samples = make_blobs(0, 2, 3, 1.0, Some(42));
        assert!(zero_samples.is_err());

        // A circle factor outside (0, 1) is rejected.
        let bad_factor = make_circles(100, Some(0.1), 1.5, Some(42));
        assert!(bad_factor.is_err());

        // More informative features than total features is rejected.
        let too_many_informative = make_classification(100, 5, 10, 0, 3, Some(42));
        assert!(too_many_informative.is_err());
    }
}