Function train_test_split

pub fn train_test_split(
    dataset: &Dataset,
    test_size: f64,
    random_seed: Option<u64>,
) -> Result<(Dataset, Dataset)>

Split a dataset into training and test sets

This function creates a random split of the dataset while preserving the metadata and feature information in both resulting datasets.

§Arguments

  • dataset - The dataset to split
  • test_size - Fraction of samples to include in the test set (between 0.0 and 1.0)
  • random_seed - Optional random seed for reproducible splits

§Returns

A tuple of (train_dataset, test_dataset), in that order. For example, with 10 samples and test_size = 0.3, roughly 3 samples go to the test set and the remaining 7 to the training set (exact counts depend on rounding).

§Examples

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
assert_eq!(train.n_samples() + test.n_samples(), 10);
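
Because random_seed makes splits reproducible (see §Arguments above), repeating the call with the same seed should select the same rows. A minimal sketch of that property, assuming reproducibility means identical row selection:

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

// Two splits with the same seed are expected to pick identical rows.
let (train_a, _test_a) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
let (train_b, _test_b) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
assert_eq!(train_a.data, train_b.data);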
Examples found in repository
examples/dataset_loaders.rs (line 37)
fn main() {
    // Check if a CSV file is provided as a command-line argument
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        println!("Usage: {} <path_to_csv_file>", args[0]);
        println!("Example: {} examples/sampledata.csv", args[0]);
        return;
    }

    let filepath = &args[1];

    // Verify the file exists
    if !Path::new(filepath).exists() {
        println!("Error: File '{filepath}' does not exist");
        return;
    }

    // Load CSV file
    println!("Loading CSV file: {filepath}");
    let csv_config = loaders::CsvConfig {
        has_header: true,
        target_column: None,
        ..Default::default()
    };
    match loaders::load_csv(filepath, csv_config) {
        Ok(dataset) => {
            print_dataset_info(&dataset, "Loaded CSV");

            // Split the dataset for demonstration
            println!("\nDemonstrating train-test split...");
            match train_test_split(&dataset, 0.2, Some(42)) {
                Ok((train, test)) => {
                    println!("Training set: {} samples", train.n_samples());
                    println!("Test set: {} samples", test.n_samples());

                    // Save as JSON for demonstration
                    let jsonpath = format!("{filepath}.json");
                    println!("\nSaving training dataset to JSON: {jsonpath}");
                    if let Err(e) = loaders::save_json(&train, &jsonpath) {
                        println!("Error saving JSON: {e}");
                    } else {
                        println!("Successfully saved JSON file");

                        // Load back the JSON file
                        println!("\nLoading back from JSON file...");
                        match loaders::load_json(&jsonpath) {
                            Ok(loaded) => {
                                print_dataset_info(&loaded, "Loaded JSON");
                            }
                            Err(e) => println!("Error loading JSON: {e}"),
                        }
                    }
                }
                Err(e) => println!("Error splitting dataset: {e}"),
            }
        }
        Err(e) => println!("Error loading CSV: {e}"),
    }
}
More examples
examples/real_world_datasets.rs (line 147)
fn demonstrate_classification_datasets() -> Result<(), Box<dyn std::error::Error>> {
    println!("🎯 CLASSIFICATION DATASETS");
    println!("{}", "-".repeat(40));

    // Titanic dataset
    println!("Loading Titanic dataset...");
    let titanic = load_titanic()?;

    println!("Titanic Dataset:");
    println!(
        "  Description: {}",
        titanic
            .metadata
            .get("description")
            .unwrap_or(&"Unknown".to_string())
    );
    println!("  Samples: {}", titanic.n_samples());
    println!("  Features: {}", titanic.n_features());

    if let Some(featurenames) = titanic.featurenames() {
        println!("  Features: {featurenames:?}");
    }

    if let Some(targetnames) = titanic.targetnames() {
        println!("  Classes: {targetnames:?}");
    }

    // Analyze class distribution
    if let Some(target) = &titanic.target {
        let mut class_counts = HashMap::new();
        for &class in target.iter() {
            *class_counts.entry(class as i32).or_insert(0) += 1;
        }
        println!("  Class distribution: {class_counts:?}");

        // Calculate survival rate
        let survived = class_counts.get(&1).unwrap_or(&0);
        let total = titanic.n_samples();
        println!(
            "  Survival rate: {:.1}%",
            (*survived as f64 / total as f64) * 100.0
        );
    }

    // Demonstrate train/test split
    let (train, test) = train_test_split(&titanic, 0.2, Some(42))?;
    println!(
        "  Train/test split: {} train, {} test",
        train.n_samples(),
        test.n_samples()
    );

    // Adult (Census Income) dataset
    println!("\nLoading Adult (Census Income) dataset...");
    match load_adult() {
        Ok(adult) => {
            println!("Adult Dataset:");
            println!(
                "  Description: {}",
                adult
                    .metadata
                    .get("description")
                    .unwrap_or(&"Unknown".to_string())
            );
            println!("  Samples: {}", adult.n_samples());
            println!("  Features: {}", adult.n_features());
            println!("  Task: Predict income >$50K based on census data");
        }
        Err(e) => {
            println!("  Note: Adult dataset requires download: {e}");
            println!("  This is expected for the demonstration");
        }
    }

    println!();
    Ok(())
}

#[allow(dead_code)]
fn demonstrate_regression_datasets() -> Result<(), Box<dyn std::error::Error>> {
    println!("📈 REGRESSION DATASETS");
    println!("{}", "-".repeat(40));

    // California Housing dataset
    println!("Loading California Housing dataset...");
    let housing = load_california_housing()?;

    println!("California Housing Dataset:");
    println!(
        "  Description: {}",
        housing
            .metadata
            .get("description")
            .unwrap_or(&"Unknown".to_string())
    );
    println!("  Samples: {}", housing.n_samples());
    println!("  Features: {}", housing.n_features());

    if let Some(featurenames) = housing.featurenames() {
        println!("  Features: {featurenames:?}");
    }

    // Analyze target distribution
    if let Some(target) = &housing.target {
        let mean = target.mean().unwrap_or(0.0);
        let std = target.std(0.0);
        let min = target.iter().fold(f64::INFINITY, |a, &b| a.min(b));
        let max = target.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

        println!("  Target (house value) statistics:");
        println!("    Mean: {mean:.2} (hundreds of thousands)");
        println!("    Std:  {std:.2}");
        println!("    Range: [{min:.2}, {max:.2}]");
    }

    // Red Wine Quality dataset
    println!("\nLoading Red Wine Quality dataset...");
    let wine = load_red_wine_quality()?;

    println!("Red Wine Quality Dataset:");
    println!(
        "  Description: {}",
        wine.metadata
            .get("description")
            .unwrap_or(&"Unknown".to_string())
    );
    println!("  Samples: {}", wine.n_samples());
    println!("  Features: {}", wine.n_features());

    if let Some(target) = &wine.target {
        let mean_quality = target.mean().unwrap_or(0.0);
        println!("  Average wine quality: {mean_quality:.1}/10");

        // Quality distribution
        let mut quality_counts = HashMap::new();
        for &quality in target.iter() {
            let q = quality.round() as i32;
            *quality_counts.entry(q).or_insert(0) += 1;
        }
        println!("  Quality distribution: {quality_counts:?}");
    }

    println!();
    Ok(())
}

#[allow(dead_code)]
fn demonstrate_healthcare_datasets() -> Result<(), Box<dyn std::error::Error>> {
    println!("🏥 HEALTHCARE DATASETS");
    println!("{}", "-".repeat(40));

    // Heart Disease dataset
    println!("Loading Heart Disease dataset...");
    let heart = load_heart_disease()?;

    println!("Heart Disease Dataset:");
    println!(
        "  Description: {}",
        heart
            .metadata
            .get("description")
            .unwrap_or(&"Unknown".to_string())
    );
    println!("  Samples: {}", heart.n_samples());
    println!("  Features: {}", heart.n_features());

    if let Some(featurenames) = heart.featurenames() {
        println!("  Clinical features: {:?}", &featurenames[..5]); // Show first 5
        println!("  ... and {} more features", featurenames.len() - 5);
    }

    // Analyze risk factors
    if let Some(target) = &heart.target {
        let mut disease_counts = HashMap::new();
        for &disease in target.iter() {
            *disease_counts.entry(disease as i32).or_insert(0) += 1;
        }

        let with_disease = disease_counts.get(&1).unwrap_or(&0);
        let total = heart.n_samples();
        println!(
            "  Disease prevalence: {:.1}% ({}/{})",
            (*with_disease as f64 / total as f64) * 100.0,
            with_disease,
            total
        );
    }

    // Demonstrate feature analysis
    println!("  Sample clinical parameter ranges:");
    let age_col = heart.data.column(0);
    let age_mean = age_col.mean();
    let age_std = age_col.std(0.0);
    println!("    Age: {age_mean:.1} ± {age_std:.1} years");

    println!();
    Ok(())
}

#[allow(dead_code)]
fn demonstrate_advanced_operations() -> Result<(), Box<dyn std::error::Error>> {
    println!("🔧 ADVANCED DATASET OPERATIONS");
    println!("{}", "-".repeat(40));

    let housing = load_california_housing()?;

    // Data preprocessing pipeline
    println!("Preprocessing pipeline for California Housing:");

    // 1. Train/test split
    let (mut train, test) = train_test_split(&housing, 0.2, Some(42))?;
    println!(
        "  1. Split: {} train, {} test",
        train.n_samples(),
        test.n_samples()
    );

    // 2. Feature scaling
    let mut pipeline = MLPipeline::default();
    train = pipeline.prepare_dataset(&train)?;
    println!("  2. Standardized features");

    // 3. Cross-validation setup
    let cv_folds = k_fold_split(train.n_samples(), 5, true, Some(42))?;
    println!("  3. Created {} CV folds", cv_folds.len());

    // Feature correlation analysis (simplified)
    println!("  4. Feature analysis:");
    println!("     • {} numerical features", train.n_features());
    println!("     • Ready for machine learning models");

    // Custom dataset configuration
    println!("\nCustom dataset loading configuration:");
    let config = RealWorldConfig {
        use_cache: true,
        download_if_missing: false, // Don't download in demo
        return_preprocessed: true,
        subset: Some("small".to_string()),
        random_state: Some(42),
        ..Default::default()
    };

    println!("  • Caching: {}", config.use_cache);
    println!("  • Download missing: {}", config.download_if_missing);
    println!("  • Preprocessed: {}", config.return_preprocessed);
    println!("  • Subset: {:?}", config.subset);

    println!();
    Ok(())
}
examples/data_generators.rs (line 24)
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Creating synthetic datasets...\n");

    // Generate classification dataset
    let n_samples = 100;
    let n_features = 5;

    let classificationdata = make_classification(
        n_samples,
        n_features,
        3,        // 3 classes
        2,        // 2 clusters per class
        3,        // 3 informative features
        Some(42), // random seed
    )?;

    // Train-test split
    let (train, test) = train_test_split(&classificationdata, 0.2, Some(42))?;

    println!("Classification dataset:");
    println!("  Total samples: {}", classificationdata.n_samples());
    println!("  Features: {}", classificationdata.n_features());
    println!("  Training samples: {}", train.n_samples());
    println!("  Test samples: {}", test.n_samples());

    // Generate regression dataset
    let regressiondata = make_regression(
        n_samples,
        n_features,
        3,   // 3 informative features
        0.5, // noise level
        Some(42),
    )?;

    println!("\nRegression dataset:");
    println!("  Samples: {}", regressiondata.n_samples());
    println!("  Features: {}", regressiondata.n_features());

    // Normalize the data (in-place)
    let mut data_copy = regressiondata.data.clone();
    normalize(&mut data_copy);
    println!("  Data normalized successfully");

    // Generate clustering data (blobs)
    let clusteringdata = make_blobs(
        n_samples,
        2,   // 2 features for easy visualization
        4,   // 4 clusters
        0.8, // cluster standard deviation
        Some(42),
    )?;

    println!("\nClustering dataset (blobs):");
    println!("  Samples: {}", clusteringdata.n_samples());
    println!("  Features: {}", clusteringdata.n_features());

    // Find the number of clusters by finding the max value of target
    let num_clusters = clusteringdata.target.as_ref().map_or(0, |t| {
        let mut max_val = -1.0;
        for &val in t.iter() {
            if val > max_val {
                max_val = val;
            }
        }
        (max_val as usize) + 1
    });

    println!("  Clusters: {num_clusters}");

    // Generate time series data
    let time_series = make_time_series(
        100,  // 100 time steps
        3,    // 3 features/variables
        true, // with trend
        true, // with seasonality
        0.2,  // noise level
        Some(42),
    )?;

    println!("\nTime series dataset:");
    println!("  Time steps: {}", time_series.n_samples());
    println!("  Features: {}", time_series.n_features());

    Ok(())
}
examples/datasets_streaming_demo.rs (line 148)
fn demonstrate_memory_efficient_processing() -> Result<(), Box<dyn std::error::Error>> {
    println!("💾 MEMORY-EFFICIENT PROCESSING");
    println!("{}", "-".repeat(40));

    // Compare memory usage: streaming vs. in-memory
    let datasetsize = 50_000;
    let n_features = 50;

    println!("Comparing memory usage for {datasetsize} samples with {n_features} features");

    // In-memory approach (for comparison)
    println!("\n1. In-memory approach:");
    let start_mem = get_memory_usage();
    let start_time = Instant::now();

    let in_memorydataset = make_classification(datasetsize, n_features, 5, 2, 25, Some(42))?;
    let (train, test) = train_test_split(&in_memorydataset, 0.2, Some(42))?;

    let in_memory_time = start_time.elapsed();
    let in_memory_mem = get_memory_usage() - start_mem;

    println!("  Time: {:.2}s", in_memory_time.as_secs_f64());
    println!("  Memory usage: ~{in_memory_mem:.1} MB");
    println!("  Train samples: {}", train.n_samples());
    println!("  Test samples: {}", test.n_samples());

    // Streaming approach
    println!("\n2. Streaming approach:");
    let stream_start_time = Instant::now();
    let stream_start_mem = get_memory_usage();

    let config = StreamConfig {
        chunk_size: 5_000, // Smaller chunks for memory efficiency
        buffer_size: 2,    // Smaller buffer
        num_workers: 2,
        memory_limit_mb: Some(50),
        ..Default::default()
    };

    let mut stream = stream_classification(datasetsize, n_features, 5, config)?;

    let mut total_processed = 0;
    let mut train_samples = 0;
    let mut test_samples = 0;

    while let Some(chunk) = stream.next_chunk()? {
        total_processed += chunk.n_samples();

        // Simulate train/test split on chunk level
        let chunk_trainsize = (chunk.n_samples() as f64 * 0.8) as usize;
        train_samples += chunk_trainsize;
        test_samples += chunk.n_samples() - chunk_trainsize;

        // Process chunk (simulate some computation)
        let _mean = chunk.data.mean_axis(ndarray::Axis(0));
        let _std = chunk.data.std_axis(ndarray::Axis(0), 0.0);

        if chunk.is_last {
            break;
        }
    }

    let stream_time = stream_start_time.elapsed();
    let stream_mem = get_memory_usage() - stream_start_mem;

    println!("  Time: {:.2}s", stream_time.as_secs_f64());
    println!("  Memory usage: ~{stream_mem:.1} MB");
    println!("  Train samples: {train_samples}");
    println!("  Test samples: {test_samples}");
    println!("  Total processed: {total_processed}");

    // Comparison
    println!("\n3. Comparison:");
    println!(
        "  Memory savings: {:.1}x less memory",
        in_memory_mem / stream_mem.max(1.0)
    );
    println!(
        "  Time overhead: {:.1}x",
        stream_time.as_secs_f64() / in_memory_time.as_secs_f64()
    );
    println!("  Streaming is beneficial for large datasets that don't fit in memory");

    println!();
    Ok(())
}