make_regression

Function make_regression 

Source
pub fn make_regression(
    n_samples: usize,
    n_features: usize,
    n_informative: usize,
    noise: f64,
    randomseed: Option<u64>,
) -> Result<Dataset>
Expand description

Generate a random regression dataset.

Examples found in repository?
examples/data_generators.rs (lines 33-39)
7fn main() -> Result<(), Box<dyn std::error::Error>> {
8    println!("Creating synthetic datasets...\n");
9
10    // Generate classification dataset
11    let n_samples = 100;
12    let n_features = 5;
13
14    let classificationdata = make_classification(
15        n_samples,
16        n_features,
17        3,        // 3 classes
18        2,        // 2 clusters per class
19        3,        // 3 informative features
20        Some(42), // random seed
21    )?;
22
23    // Train-test split
24    let (train, test) = train_test_split(&classificationdata, 0.2, Some(42))?;
25
26    println!("Classification dataset:");
27    println!("  Total samples: {}", classificationdata.n_samples());
28    println!("  Features: {}", classificationdata.n_features());
29    println!("  Training samples: {}", train.n_samples());
30    println!("  Test samples: {}", test.n_samples());
31
32    // Generate regression dataset
33    let regressiondata = make_regression(
34        n_samples,
35        n_features,
36        3,   // 3 informative features
37        0.5, // noise level
38        Some(42),
39    )?;
40
41    println!("\nRegression dataset:");
42    println!("  Samples: {}", regressiondata.n_samples());
43    println!("  Features: {}", regressiondata.n_features());
44
45    // Normalize the data (in-place)
46    let mut data_copy = regressiondata.data.clone();
47    normalize(&mut data_copy);
48    println!("  Data normalized successfully");
49
50    // Generate clustering data (blobs)
51    let clusteringdata = make_blobs(
52        n_samples,
53        2,   // 2 features for easy visualization
54        4,   // 4 clusters
55        0.8, // cluster standard deviation
56        Some(42),
57    )?;
58
59    println!("\nClustering dataset (blobs):");
60    println!("  Samples: {}", clusteringdata.n_samples());
61    println!("  Features: {}", clusteringdata.n_features());
62
63    // Find the number of clusters by finding the max value of target
64    let num_clusters = clusteringdata.target.as_ref().map_or(0, |t| {
65        let mut max_val = -1.0;
66        for &val in t.iter() {
67            if val > max_val {
68                max_val = val;
69            }
70        }
71        (max_val as usize) + 1
72    });
73
74    println!("  Clusters: {num_clusters}");
75
76    // Generate time series data
77    let time_series = make_time_series(
78        100,  // 100 time steps
79        3,    // 3 features/variables
80        true, // with trend
81        true, // with seasonality
82        0.2,  // noise level
83        Some(42),
84    )?;
85
86    println!("\nTime series dataset:");
87    println!("  Time steps: {}", time_series.n_samples());
88    println!("  Features: {}", time_series.n_features());
89
90    Ok(())
91}
More examples
Hide additional examples
examples/scikit_learn_benchmark.rs (line 365)
346fn run_sklearn_generation_comparison() {
347    println!("\n  🔬 Data Generation Comparison:");
348
349    let configs = vec![
350        (1000, 10, "classification"),
351        (5000, 20, "classification"),
352        (1000, 10, "regression"),
353        (5000, 20, "regression"),
354    ];
355
356    for (n_samples, n_features, gen_type) in configs {
357        #[allow(clippy::type_complexity)]
358        let (python_code, scirs2_fn): (&str, Box<dyn Fn() -> Result<Dataset, Box<dyn std::error::Error>>>) = match gen_type {
359            "classification" => (
360                &format!("from sklearn.datasets import make_classification; make_classification(n_samples={n_samples}, n_features={n_features}, random_state=42)"),
361                Box::new(move || make_classification(n_samples, n_features, 3, 2, 4, Some(42)).map_err(|e| Box::new(e) as Box<dyn std::error::Error>))
362            ),
363            "regression" => (
364                &format!("from sklearn.datasets import make_regression; make_regression(n_samples={n_samples}, n_features={n_features}, random_state=42)"),
365                Box::new(move || make_regression(n_samples, n_features, 3, 0.1, Some(42)).map_err(|e| Box::new(e) as Box<dyn std::error::Error>))
366            ),
367            _ => continue,
368        };
369
370        // Time Python execution
371        let python_result = Command::new("python3")
372            .arg("-c")
373            .arg(format!(
374                "import time; start=time.time(); {python_code}; print(f'{{:.4f}}', time.time()-start)"
375            ))
376            .output();
377
378        match python_result {
379            Ok(output) if output.status.success() => {
380                let python_time = String::from_utf8_lossy(&output.stdout)
381                    .trim()
382                    .parse::<f64>()
383                    .unwrap_or(0.0);
384
385                // Time SciRS2 execution
386                let scirs2_start = Instant::now();
387                let _scirs2_result = scirs2_fn();
388                let scirs2_time = scirs2_start.elapsed().as_secs_f64();
389
390                let speedup = python_time / scirs2_time;
391                let status = if speedup > 1.2 {
392                    "🚀 FASTER"
393                } else if speedup > 0.8 {
394                    "≈ SIMILAR"
395                } else {
396                    "🐌 SLOWER"
397                };
398
399                println!(
400                    "    {} {}x{}: SciRS2 {:.2}ms vs sklearn {:.2}ms ({:.1}x {})",
401                    gen_type,
402                    n_samples,
403                    n_features,
404                    scirs2_time * 1000.0,
405                    python_time * 1000.0,
406                    speedup,
407                    status
408                );
409            }
410            _ => {
411                println!(
412                    "    {gen_type} {n_samples}x{n_features}: Failed to benchmark Python version"
413                );
414            }
415        }
416    }
417}