data_generators/
data_generators.rs

1use scirs2_datasets::{
2    make_blobs, make_classification, make_regression, make_time_series, utils::normalize,
3    utils::train_test_split,
4};
5
6#[allow(dead_code)]
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8    println!("Creating synthetic datasets...\n");
9
10    // Generate classification dataset
11    let n_samples = 100;
12    let n_features = 5;
13
14    let classificationdata = make_classification(
15        n_samples,
16        n_features,
17        3,        // 3 classes
18        2,        // 2 clusters per class
19        3,        // 3 informative features
20        Some(42), // random seed
21    )?;
22
23    // Train-test split
24    let (train, test) = train_test_split(&classificationdata, 0.2, Some(42))?;
25
26    println!("Classification dataset:");
27    println!("  Total samples: {}", classificationdata.n_samples());
28    println!("  Features: {}", classificationdata.n_features());
29    println!("  Training samples: {}", train.n_samples());
30    println!("  Test samples: {}", test.n_samples());
31
32    // Generate regression dataset
33    let regressiondata = make_regression(
34        n_samples,
35        n_features,
36        3,   // 3 informative features
37        0.5, // noise level
38        Some(42),
39    )?;
40
41    println!("\nRegression dataset:");
42    println!("  Samples: {}", regressiondata.n_samples());
43    println!("  Features: {}", regressiondata.n_features());
44
45    // Normalize the data (in-place)
46    let mut data_copy = regressiondata.data.clone();
47    normalize(&mut data_copy);
48    println!("  Data normalized successfully");
49
50    // Generate clustering data (blobs)
51    let clusteringdata = make_blobs(
52        n_samples,
53        2,   // 2 features for easy visualization
54        4,   // 4 clusters
55        0.8, // cluster standard deviation
56        Some(42),
57    )?;
58
59    println!("\nClustering dataset (blobs):");
60    println!("  Samples: {}", clusteringdata.n_samples());
61    println!("  Features: {}", clusteringdata.n_features());
62
63    // Find the number of clusters by finding the max value of target
64    let num_clusters = clusteringdata.target.as_ref().map_or(0, |t| {
65        let mut max_val = -1.0;
66        for &val in t.iter() {
67            if val > max_val {
68                max_val = val;
69            }
70        }
71        (max_val as usize) + 1
72    });
73
74    println!("  Clusters: {num_clusters}");
75
76    // Generate time series data
77    let time_series = make_time_series(
78        100,  // 100 time steps
79        3,    // 3 features/variables
80        true, // with trend
81        true, // with seasonality
82        0.2,  // noise level
83        Some(42),
84    )?;
85
86    println!("\nTime series dataset:");
87    println!("  Time steps: {}", time_series.n_samples());
88    println!("  Features: {}", time_series.n_features());
89
90    Ok(())
91}