pub fn train_test_split(
dataset: &Dataset,
test_size: f64,
random_seed: Option<u64>,
) -> Result<(Dataset, Dataset)>
Expand description
Split a dataset into training and test sets
This function creates a random split of the dataset while preserving the metadata and feature information in both resulting datasets.
§Arguments
- `dataset` - The dataset to split
- `test_size` - Fraction of samples to include in the test set (0.0 to 1.0)
- `random_seed` - Optional random seed for reproducible splits
§Returns
A tuple of (train_dataset, test_dataset)
§Examples
// Build a 10x3 dataset of sequential values, then split 30% of the rows into the test set.
use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};
let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);
// Passing Some(42) seeds the RNG so the split is reproducible across runs.
let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
// The two partitions always account for every original sample.
assert_eq!(train.n_samples() + test.n_samples(), 10);
Examples found in repository:
examples/dataset_loaders.rs (line 31)
// Example program: load a CSV file named on the command line, split it into
// train/test sets with a fixed seed, save the training set to JSON, and load
// it back. Every fallible step is matched explicitly so failures are reported
// without panicking. (Leading numbers are the source-file line numbers
// rendered by rustdoc.)
6fn main() {
7    // Check if a CSV file is provided as a command-line argument
8    let args: Vec<String> = env::args().collect();
9    if args.len() < 2 {
10        println!("Usage: {} <path_to_csv_file>", args[0]);
11        println!("Example: {} examples/sample_data.csv", args[0]);
12        return;
13    }
14
15    let file_path = &args[1];
16
17    // Verify the file exists
18    if !Path::new(file_path).exists() {
19        println!("Error: File '{}' does not exist", file_path);
20        return;
21    }
22
23    // Load CSV file
24    println!("Loading CSV file: {}", file_path);
     // `true` presumably means "first row is a header"; `None` likely selects no
     // target column — TODO confirm against loaders::load_csv's signature.
25    match loaders::load_csv(file_path, true, None) {
26        Ok(dataset) => {
27            print_dataset_info(&dataset, "Loaded CSV");
28
29            // Split the dataset for demonstration
30            println!("\nDemonstrating train-test split...");
              // 20% of samples go to the test set; Some(42) makes the split reproducible.
31            match train_test_split(&dataset, 0.2, Some(42)) {
32                Ok((train, test)) => {
33                    println!("Training set: {} samples", train.n_samples());
34                    println!("Test set: {} samples", test.n_samples());
35
36                    // Save as JSON for demonstration
                      // The JSON file is written next to the input as "<input>.json".
37                    let json_path = format!("{}.json", file_path);
38                    println!("\nSaving training dataset to JSON: {}", json_path);
39                    if let Err(e) = loaders::save_json(&train, &json_path) {
40                        println!("Error saving JSON: {}", e);
41                    } else {
42                        println!("Successfully saved JSON file");
43
44                        // Load back the JSON file
                          // Round-trip check: reload what was just written and print its info.
45                        println!("\nLoading back from JSON file...");
46                        match loaders::load_json(&json_path) {
47                            Ok(loaded) => {
48                                print_dataset_info(&loaded, "Loaded JSON");
49                            }
50                            Err(e) => println!("Error loading JSON: {}", e),
51                        }
52                    }
53                }
54                Err(e) => println!("Error splitting dataset: {}", e),
55            }
56        }
57        Err(e) => println!("Error loading CSV: {}", e),
58    }
59}
More examples
examples/data_generators.rs (line 23)
// Example program: exercise each synthetic-data generator (classification,
// regression, blobs, time series) with a fixed seed, demonstrate
// train_test_split and in-place normalization, and print summary statistics.
// Errors propagate via `?` to the boxed-error return type. (Leading numbers
// are the source-file line numbers rendered by rustdoc.)
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7    println!("Creating synthetic datasets...\n");
8
9    // Generate classification dataset
10    let n_samples = 100;
11    let n_features = 5;
12
     // Every generator below takes Some(42) so all outputs are reproducible.
13    let classification_data = make_classification(
14        n_samples,
15        n_features,
16        3, // 3 classes
17        2, // 2 clusters per class
18        3, // 3 informative features
19        Some(42), // random seed
20    )?;
21
22    // Train-test split
      // 20% of the 100 samples go to the test set.
23    let (train, test) = train_test_split(&classification_data, 0.2, Some(42))?;
24
25    println!("Classification dataset:");
26    println!("  Total samples: {}", classification_data.n_samples());
27    println!("  Features: {}", classification_data.n_features());
28    println!("  Training samples: {}", train.n_samples());
29    println!("  Test samples: {}", test.n_samples());
30
31    // Generate regression dataset
32    let regression_data = make_regression(
33        n_samples,
34        n_features,
35        3, // 3 informative features
36        0.5, // noise level
37        Some(42),
38    )?;
39
40    println!("\nRegression dataset:");
41    println!("  Samples: {}", regression_data.n_samples());
42    println!("  Features: {}", regression_data.n_features());
43
44    // Normalize the data (in-place)
      // The clone keeps the original dataset untouched; only the copy is normalized.
45    let mut data_copy = regression_data.data.clone();
46    normalize(&mut data_copy);
47    println!("  Data normalized successfully");
48
49    // Generate clustering data (blobs)
50    let clustering_data = make_blobs(
51        n_samples,
52        2, // 2 features for easy visualization
53        4, // 4 clusters
54        0.8, // cluster standard deviation
55        Some(42),
56    )?;
57
58    println!("\nClustering dataset (blobs):");
59    println!("  Samples: {}", clustering_data.n_samples());
60    println!("  Features: {}", clustering_data.n_features());
61
62    // Find the number of clusters by finding the max value of target
      // Cluster labels are stored as f64, and f64 is not Ord, so Iterator::max()
      // is unavailable; the maximum label is found with a manual scan instead.
      // map_or(0, ...) yields 0 when the dataset has no target at all.
63    let num_clusters = clustering_data.target.as_ref().map_or(0, |t| {
64        let mut max_val = -1.0;
65        for &val in t.iter() {
66            if val > max_val {
67                max_val = val;
68            }
69        }
          // Labels are assumed to be non-negative integers stored as f64, so
          // max label + 1 equals the cluster count — TODO confirm make_blobs's
          // label convention.
70        (max_val as usize) + 1
71    });
72
73    println!("  Clusters: {}", num_clusters);
74
75    // Generate time series data
76    let time_series = make_time_series(
77        100, // 100 time steps
78        3, // 3 features/variables
79        true, // with trend
80        true, // with seasonality
81        0.2, // noise level
82        Some(42),
83    )?;
84
85    println!("\nTime series dataset:");
86    println!("  Time steps: {}", time_series.n_samples());
87    println!("  Features: {}", time_series.n_features());
88
89    Ok(())
90}