Function train_test_split

pub fn train_test_split(
    dataset: &Dataset,
    test_size: f64,
    random_seed: Option<u64>,
) -> Result<(Dataset, Dataset)>

Split a dataset into training and test sets

This function creates a random split of the dataset while preserving the metadata and feature information in both resulting datasets.

§Arguments

  • dataset - The dataset to split
  • test_size - Fraction of samples to include in the test set (between 0.0 and 1.0)
  • random_seed - Optional random seed for reproducible splits

§Returns

A tuple of (train_dataset, test_dataset)
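
Since test_size is a fraction of the sample count, a value outside the documented range presumably cannot produce a valid split. A minimal sketch of that expectation, assuming the function returns an Err (rather than panicking) when test_size is out of range:

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

// Assumption: a fraction above 1.0 is rejected with an Err instead of a panic.
assert!(train_test_split(&dataset, 1.5, None).is_err());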

§Examples

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

// Build a small 10x3 dataset with no target values.
let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

// Hold out 30% of the samples, seeding the RNG for a reproducible split.
let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
assert_eq!(train.n_samples() + test.n_samples(), 10);
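
The same seed should also reproduce the same split, per the random_seed documentation above. A short sketch of that property; comparing the data arrays directly assumes the public data field seen in examples/data_generators.rs below:

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

// Same seed, same split: both calls should select identical rows.
let (train_a, _) = train_test_split(&dataset, 0.3, Some(7)).unwrap();
let (train_b, _) = train_test_split(&dataset, 0.3, Some(7)).unwrap();
assert_eq!(train_a.n_samples(), train_b.n_samples());
assert_eq!(train_a.data, train_b.data); // assumes `data` is a public Array2 field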
Examples found in repository
examples/dataset_loaders.rs (line 31)
fn main() {
    // Check if a CSV file is provided as a command-line argument
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        println!("Usage: {} <path_to_csv_file>", args[0]);
        println!("Example: {} examples/sample_data.csv", args[0]);
        return;
    }

    let file_path = &args[1];

    // Verify the file exists
    if !Path::new(file_path).exists() {
        println!("Error: File '{}' does not exist", file_path);
        return;
    }

    // Load CSV file
    println!("Loading CSV file: {}", file_path);
    match loaders::load_csv(file_path, true, None) {
        Ok(dataset) => {
            print_dataset_info(&dataset, "Loaded CSV");

            // Split the dataset for demonstration
            println!("\nDemonstrating train-test split...");
            match train_test_split(&dataset, 0.2, Some(42)) {
                Ok((train, test)) => {
                    println!("Training set: {} samples", train.n_samples());
                    println!("Test set: {} samples", test.n_samples());

                    // Save as JSON for demonstration
                    let json_path = format!("{}.json", file_path);
                    println!("\nSaving training dataset to JSON: {}", json_path);
                    if let Err(e) = loaders::save_json(&train, &json_path) {
                        println!("Error saving JSON: {}", e);
                    } else {
                        println!("Successfully saved JSON file");

                        // Load back the JSON file
                        println!("\nLoading back from JSON file...");
                        match loaders::load_json(&json_path) {
                            Ok(loaded) => {
                                print_dataset_info(&loaded, "Loaded JSON");
                            }
                            Err(e) => println!("Error loading JSON: {}", e),
                        }
                    }
                }
                Err(e) => println!("Error splitting dataset: {}", e),
            }
        }
        Err(e) => println!("Error loading CSV: {}", e),
    }
}
More examples
examples/data_generators.rs (line 23)
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Creating synthetic datasets...\n");

    // Generate classification dataset
    let n_samples = 100;
    let n_features = 5;

    let classification_data = make_classification(
        n_samples,
        n_features,
        3,        // 3 classes
        2,        // 2 clusters per class
        3,        // 3 informative features
        Some(42), // random seed
    )?;

    // Train-test split
    let (train, test) = train_test_split(&classification_data, 0.2, Some(42))?;

    println!("Classification dataset:");
    println!("  Total samples: {}", classification_data.n_samples());
    println!("  Features: {}", classification_data.n_features());
    println!("  Training samples: {}", train.n_samples());
    println!("  Test samples: {}", test.n_samples());

    // Generate regression dataset
    let regression_data = make_regression(
        n_samples,
        n_features,
        3,   // 3 informative features
        0.5, // noise level
        Some(42),
    )?;

    println!("\nRegression dataset:");
    println!("  Samples: {}", regression_data.n_samples());
    println!("  Features: {}", regression_data.n_features());

    // Normalize the data (in-place)
    let mut data_copy = regression_data.data.clone();
    normalize(&mut data_copy);
    println!("  Data normalized successfully");

    // Generate clustering data (blobs)
    let clustering_data = make_blobs(
        n_samples,
        2,   // 2 features for easy visualization
        4,   // 4 clusters
        0.8, // cluster standard deviation
        Some(42),
    )?;

    println!("\nClustering dataset (blobs):");
    println!("  Samples: {}", clustering_data.n_samples());
    println!("  Features: {}", clustering_data.n_features());

    // Determine the number of clusters from the maximum target label
    let num_clusters = clustering_data.target.as_ref().map_or(0, |t| {
        let mut max_val = -1.0;
        for &val in t.iter() {
            if val > max_val {
                max_val = val;
            }
        }
        (max_val as usize) + 1
    });

    println!("  Clusters: {}", num_clusters);

    // Generate time series data
    let time_series = make_time_series(
        100,  // 100 time steps
        3,    // 3 features/variables
        true, // with trend
        true, // with seasonality
        0.2,  // noise level
        Some(42),
    )?;

    println!("\nTime series dataset:");
    println!("  Time steps: {}", time_series.n_samples());
    println!("  Features: {}", time_series.n_features());

    Ok(())
}