data_generators/
data_generators.rs

use scirs2_datasets::{
    make_blobs, make_classification, make_regression, make_time_series, utils::normalize,
};
4
5fn main() -> Result<(), Box<dyn std::error::Error>> {
6    println!("Creating synthetic datasets...\n");
7
8    // Generate classification dataset
9    let n_samples = 100;
10    let n_features = 5;
11
12    let classification_data = make_classification(
13        n_samples,
14        n_features,
15        3,        // 3 classes
16        2,        // 2 clusters per class
17        3,        // 3 informative features
18        Some(42), // random seed
19    )?;
20
21    // Train-test split
22    let (train, test) = classification_data.train_test_split(0.2, Some(42))?;
23
24    println!("Classification dataset:");
25    println!("  Total samples: {}", classification_data.n_samples());
26    println!("  Features: {}", classification_data.n_features());
27    println!("  Training samples: {}", train.n_samples());
28    println!("  Test samples: {}", test.n_samples());
29
30    // Generate regression dataset
31    let regression_data = make_regression(
32        n_samples,
33        n_features,
34        3,   // 3 informative features
35        0.5, // noise level
36        Some(42),
37    )?;
38
39    println!("\nRegression dataset:");
40    println!("  Samples: {}", regression_data.n_samples());
41    println!("  Features: {}", regression_data.n_features());
42
43    // Normalize the data (in-place)
44    let mut data_copy = regression_data.data.clone();
45    normalize(&mut data_copy);
46    println!("  Data normalized successfully");
47
48    // Generate clustering data (blobs)
49    let clustering_data = make_blobs(
50        n_samples,
51        2,   // 2 features for easy visualization
52        4,   // 4 clusters
53        0.8, // cluster standard deviation
54        Some(42),
55    )?;
56
57    println!("\nClustering dataset (blobs):");
58    println!("  Samples: {}", clustering_data.n_samples());
59    println!("  Features: {}", clustering_data.n_features());
60
61    // Find the number of clusters by finding the max value of target
62    let num_clusters = clustering_data.target.as_ref().map_or(0, |t| {
63        let mut max_val = -1.0;
64        for &val in t.iter() {
65            if val > max_val {
66                max_val = val;
67            }
68        }
69        (max_val as usize) + 1
70    });
71
72    println!("  Clusters: {}", num_clusters);
73
74    // Generate time series data
75    let time_series = make_time_series(
76        100,  // 100 time steps
77        3,    // 3 features/variables
78        true, // with trend
79        true, // with seasonality
80        0.2,  // noise level
81        Some(42),
82    )?;
83
84    println!("\nTime series dataset:");
85    println!("  Time steps: {}", time_series.n_samples());
86    println!("  Features: {}", time_series.n_features());
87
88    Ok(())
89}