pub fn make_regression(
n_samples: usize,
n_features: usize,
n_informative: usize,
noise: f64,
random_seed: Option<u64>,
) -> Result<Dataset>
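Generates a random regression dataset with `n_samples` samples and `n_features` features, of which `n_informative` are informative; `noise` sets the noise level and `random_seed` seeds the generator.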
Examples found in repository:
examples/data_generators.rs (lines 32-38)
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7 println!("Creating synthetic datasets...\n");
8
9 // Generate classification dataset
10 let n_samples = 100;
11 let n_features = 5;
12
13 let classification_data = make_classification(
14 n_samples,
15 n_features,
16 3, // 3 classes
17 2, // 2 clusters per class
18 3, // 3 informative features
19 Some(42), // random seed
20 )?;
21
22 // Train-test split
23 let (train, test) = train_test_split(&classification_data, 0.2, Some(42))?;
24
25 println!("Classification dataset:");
26 println!(" Total samples: {}", classification_data.n_samples());
27 println!(" Features: {}", classification_data.n_features());
28 println!(" Training samples: {}", train.n_samples());
29 println!(" Test samples: {}", test.n_samples());
30
31 // Generate regression dataset
32 let regression_data = make_regression(
33 n_samples,
34 n_features,
35 3, // 3 informative features
36 0.5, // noise level
37 Some(42),
38 )?;
39
40 println!("\nRegression dataset:");
41 println!(" Samples: {}", regression_data.n_samples());
42 println!(" Features: {}", regression_data.n_features());
43
44 // Normalize the data (in-place)
45 let mut data_copy = regression_data.data.clone();
46 normalize(&mut data_copy);
47 println!(" Data normalized successfully");
48
49 // Generate clustering data (blobs)
50 let clustering_data = make_blobs(
51 n_samples,
52 2, // 2 features for easy visualization
53 4, // 4 clusters
54 0.8, // cluster standard deviation
55 Some(42),
56 )?;
57
58 println!("\nClustering dataset (blobs):");
59 println!(" Samples: {}", clustering_data.n_samples());
60 println!(" Features: {}", clustering_data.n_features());
61
62 // Find the number of clusters by finding the max value of target
63 let num_clusters = clustering_data.target.as_ref().map_or(0, |t| {
64 let mut max_val = -1.0;
65 for &val in t.iter() {
66 if val > max_val {
67 max_val = val;
68 }
69 }
70 (max_val as usize) + 1
71 });
72
73 println!(" Clusters: {}", num_clusters);
74
75 // Generate time series data
76 let time_series = make_time_series(
77 100, // 100 time steps
78 3, // 3 features/variables
79 true, // with trend
80 true, // with seasonality
81 0.2, // noise level
82 Some(42),
83 )?;
84
85 println!("\nTime series dataset:");
86 println!(" Time steps: {}", time_series.n_samples());
87 println!(" Features: {}", time_series.n_features());
88
89 Ok(())
90}