pub fn train_test_split(
    dataset: &Dataset,
    test_size: f64,
    random_seed: Option<u64>,
) -> Result<(Dataset, Dataset)>
Split a dataset into training and test sets.
This function creates a random split of the dataset while preserving the metadata and feature information in both resulting datasets.
§Arguments
dataset - The dataset to split
test_size - Fraction of samples to include in the test set (0.0 to 1.0)
random_seed - Optional random seed for reproducible splits
§Returns
A tuple of (train_dataset, test_dataset)
§Examples
use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};
let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);
let (train, test) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
assert_eq!(train.n_samples() + test.n_samples(), 10);
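Because the split is seeded, passing the same random_seed reproduces the same partition, and both halves keep the feature layout of the input. A minimal sketch of checking this (assuming, as the repository examples below suggest, that the data field is a public ndarray::Array2<f64> that can be compared for equality):

use ndarray::Array2;
use scirs2_datasets::utils::{Dataset, train_test_split};

let data = Array2::from_shape_vec((10, 3), (0..30).map(|x| x as f64).collect()).unwrap();
let dataset = Dataset::new(data, None);

// Same seed, same split (assumption: the seed fully determines the shuffle).
let (train_a, test_a) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
let (train_b, test_b) = train_test_split(&dataset, 0.3, Some(42)).unwrap();
assert_eq!(train_a.data, train_b.data);
assert_eq!(test_a.data, test_b.data);

// Feature count is preserved in both halves.
assert_eq!(train_a.n_features(), 3);
assert_eq!(test_a.n_features(), 3);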
Examples found in repository
examples/dataset_loaders.rs (line 37)
7 fn main() {
8 // Check if a CSV file is provided as a command-line argument
9 let args: Vec<String> = env::args().collect();
10 if args.len() < 2 {
11 println!("Usage: {} <path_to_csv_file>", args[0]);
12 println!("Example: {} examples/sampledata.csv", args[0]);
13 return;
14 }
15
16 let filepath = &args[1];
17
18 // Verify the file exists
19 if !Path::new(filepath).exists() {
20 println!("Error: File '{filepath}' does not exist");
21 return;
22 }
23
24 // Load CSV file
25 println!("Loading CSV file: {filepath}");
26 let csv_config = loaders::CsvConfig {
27 has_header: true,
28 target_column: None,
29 ..Default::default()
30 };
31 match loaders::load_csv(filepath, csv_config) {
32 Ok(dataset) => {
33 print_dataset_info(&dataset, "Loaded CSV");
34
35 // Split the dataset for demonstration
36 println!("\nDemonstrating train-test split...");
37 match train_test_split(&dataset, 0.2, Some(42)) {
38 Ok((train, test)) => {
39 println!("Training set: {} samples", train.n_samples());
40 println!("Test set: {} samples", test.n_samples());
41
42 // Save as JSON for demonstration
43 let jsonpath = format!("{filepath}.json");
44 println!("\nSaving training dataset to JSON: {jsonpath}");
45 if let Err(e) = loaders::save_json(&train, &jsonpath) {
46 println!("Error saving JSON: {e}");
47 } else {
48 println!("Successfully saved JSON file");
49
50 // Load back the JSON file
51 println!("\nLoading back from JSON file...");
52 match loaders::load_json(&jsonpath) {
53 Ok(loaded) => {
54 print_dataset_info(&loaded, "Loaded JSON");
55 }
56 Err(e) => println!("Error loading JSON: {e}"),
57 }
58 }
59 }
60 Err(e) => println!("Error splitting dataset: {e}"),
61 }
62 }
63 Err(e) => println!("Error loading CSV: {e}"),
64 }
65 }
More examples
examples/real_world_datasets.rs (line 147)
102 fn demonstrate_classification_datasets() -> Result<(), Box<dyn std::error::Error>> {
103 println!("🎯 CLASSIFICATION DATASETS");
104 println!("{}", "-".repeat(40));
105
106 // Titanic dataset
107 println!("Loading Titanic dataset...");
108 let titanic = load_titanic()?;
109
110 println!("Titanic Dataset:");
111 println!(
112 " Description: {}",
113 titanic
114 .metadata
115 .get("description")
116 .unwrap_or(&"Unknown".to_string())
117 );
118 println!(" Samples: {}", titanic.n_samples());
119 println!(" Features: {}", titanic.n_features());
120
121 if let Some(featurenames) = titanic.featurenames() {
122 println!(" Features: {featurenames:?}");
123 }
124
125 if let Some(targetnames) = titanic.targetnames() {
126 println!(" Classes: {targetnames:?}");
127 }
128
129 // Analyze class distribution
130 if let Some(target) = &titanic.target {
131 let mut class_counts = HashMap::new();
132 for &class in target.iter() {
133 *class_counts.entry(class as i32).or_insert(0) += 1;
134 }
135 println!(" Class distribution: {class_counts:?}");
136
137 // Calculate survival rate
138 let survived = class_counts.get(&1).unwrap_or(&0);
139 let total = titanic.n_samples();
140 println!(
141 " Survival rate: {:.1}%",
142 (*survived as f64 / total as f64) * 100.0
143 );
144 }
145
146 // Demonstrate train/test split
147 let (train, test) = train_test_split(&titanic, 0.2, Some(42))?;
148 println!(
149 " Train/test split: {} train, {} test",
150 train.n_samples(),
151 test.n_samples()
152 );
153
154 // Adult (Census Income) dataset
155 println!("\nLoading Adult (Census Income) dataset...");
156 match load_adult() {
157 Ok(adult) => {
158 println!("Adult Dataset:");
159 println!(
160 " Description: {}",
161 adult
162 .metadata
163 .get("description")
164 .unwrap_or(&"Unknown".to_string())
165 );
166 println!(" Samples: {}", adult.n_samples());
167 println!(" Features: {}", adult.n_features());
168 println!(" Task: Predict income >$50K based on census data");
169 }
170 Err(e) => {
171 println!(" Note: Adult dataset requires download: {e}");
172 println!(" This is expected for the demonstration");
173 }
174 }
175
176 println!();
177 Ok(())
178 }
179
180 #[allow(dead_code)]
181 fn demonstrate_regression_datasets() -> Result<(), Box<dyn std::error::Error>> {
182 println!("📈 REGRESSION DATASETS");
183 println!("{}", "-".repeat(40));
184
185 // California Housing dataset
186 println!("Loading California Housing dataset...");
187 let housing = load_california_housing()?;
188
189 println!("California Housing Dataset:");
190 println!(
191 " Description: {}",
192 housing
193 .metadata
194 .get("description")
195 .unwrap_or(&"Unknown".to_string())
196 );
197 println!(" Samples: {}", housing.n_samples());
198 println!(" Features: {}", housing.n_features());
199
200 if let Some(featurenames) = housing.featurenames() {
201 println!(" Features: {featurenames:?}");
202 }
203
204 // Analyze target distribution
205 if let Some(target) = &housing.target {
206 let mean = target.mean().unwrap_or(0.0);
207 let std = target.std(0.0);
208 let min = target.iter().fold(f64::INFINITY, |a, &b| a.min(b));
209 let max = target.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
210
211 println!(" Target (house value) statistics:");
212 println!(" Mean: {mean:.2} (hundreds of thousands)");
213 println!(" Std: {std:.2}");
214 println!(" Range: [{min:.2}, {max:.2}]");
215 }
216
217 // Red Wine Quality dataset
218 println!("\nLoading Red Wine Quality dataset...");
219 let wine = load_red_wine_quality()?;
220
221 println!("Red Wine Quality Dataset:");
222 println!(
223 " Description: {}",
224 wine.metadata
225 .get("description")
226 .unwrap_or(&"Unknown".to_string())
227 );
228 println!(" Samples: {}", wine.n_samples());
229 println!(" Features: {}", wine.n_features());
230
231 if let Some(target) = &wine.target {
232 let mean_quality = target.mean().unwrap_or(0.0);
233 println!(" Average wine quality: {mean_quality:.1}/10");
234
235 // Quality distribution
236 let mut quality_counts = HashMap::new();
237 for &quality in target.iter() {
238 let q = quality.round() as i32;
239 *quality_counts.entry(q).or_insert(0) += 1;
240 }
241 println!(" Quality distribution: {quality_counts:?}");
242 }
243
244 println!();
245 Ok(())
246 }
247
248 #[allow(dead_code)]
249 fn demonstrate_healthcare_datasets() -> Result<(), Box<dyn std::error::Error>> {
250 println!("🏥 HEALTHCARE DATASETS");
251 println!("{}", "-".repeat(40));
252
253 // Heart Disease dataset
254 println!("Loading Heart Disease dataset...");
255 let heart = load_heart_disease()?;
256
257 println!("Heart Disease Dataset:");
258 println!(
259 " Description: {}",
260 heart
261 .metadata
262 .get("description")
263 .unwrap_or(&"Unknown".to_string())
264 );
265 println!(" Samples: {}", heart.n_samples());
266 println!(" Features: {}", heart.n_features());
267
268 if let Some(featurenames) = heart.featurenames() {
269 println!(" Clinical features: {:?}", &featurenames[..5]); // Show first 5
270 println!(" ... and {} more features", featurenames.len() - 5);
271 }
272
273 // Analyze risk factors
274 if let Some(target) = &heart.target {
275 let mut disease_counts = HashMap::new();
276 for &disease in target.iter() {
277 *disease_counts.entry(disease as i32).or_insert(0) += 1;
278 }
279
280 let with_disease = disease_counts.get(&1).unwrap_or(&0);
281 let total = heart.n_samples();
282 println!(
283 " Disease prevalence: {:.1}% ({}/{})",
284 (*with_disease as f64 / total as f64) * 100.0,
285 with_disease,
286 total
287 );
288 }
289
290 // Demonstrate feature analysis
291 println!(" Sample clinical parameter ranges:");
292 let age_col = heart.data.column(0);
293 let age_mean = age_col.mean();
294 let age_std = age_col.std(0.0);
295 println!(" Age: {age_mean:.1} ± {age_std:.1} years");
296
297 println!();
298 Ok(())
299 }
300
301 #[allow(dead_code)]
302 fn demonstrate_advanced_operations() -> Result<(), Box<dyn std::error::Error>> {
303 println!("🔧 ADVANCED DATASET OPERATIONS");
304 println!("{}", "-".repeat(40));
305
306 let housing = load_california_housing()?;
307
308 // Data preprocessing pipeline
309 println!("Preprocessing pipeline for California Housing:");
310
311 // 1. Train/test split
312 let (mut train, test) = train_test_split(&housing, 0.2, Some(42))?;
313 println!(
314 " 1. Split: {} train, {} test",
315 train.n_samples(),
316 test.n_samples()
317 );
318
319 // 2. Feature scaling
320 let mut pipeline = MLPipeline::default();
321 train = pipeline.prepare_dataset(&train)?;
322 println!(" 2. Standardized features");
323
324 // 3. Cross-validation setup
325 let cv_folds = k_fold_split(train.n_samples(), 5, true, Some(42))?;
326 println!(" 3. Created {} CV folds", cv_folds.len());
327
328 // Feature correlation analysis (simplified)
329 println!(" 4. Feature analysis:");
330 println!(" • {} numerical features", train.n_features());
331 println!(" • Ready for machine learning models");
332
333 // Custom dataset configuration
334 println!("\nCustom dataset loading configuration:");
335 let config = RealWorldConfig {
336 use_cache: true,
337 download_if_missing: false, // Don't download in demo
338 return_preprocessed: true,
339 subset: Some("small".to_string()),
340 random_state: Some(42),
341 ..Default::default()
342 };
343
344 println!(" • Caching: {}", config.use_cache);
345 println!(" • Download missing: {}", config.download_if_missing);
346 println!(" • Preprocessed: {}", config.return_preprocessed);
347 println!(" • Subset: {:?}", config.subset);
348
349 println!();
350 Ok(())
351 }
examples/data_generators.rs (line 24)
7 fn main() -> Result<(), Box<dyn std::error::Error>> {
8 println!("Creating synthetic datasets...\n");
9
10 // Generate classification dataset
11 let n_samples = 100;
12 let n_features = 5;
13
14 let classificationdata = make_classification(
15 n_samples,
16 n_features,
17 3, // 3 classes
18 2, // 2 clusters per class
19 3, // 3 informative features
20 Some(42), // random seed
21 )?;
22
23 // Train-test split
24 let (train, test) = train_test_split(&classificationdata, 0.2, Some(42))?;
25
26 println!("Classification dataset:");
27 println!(" Total samples: {}", classificationdata.n_samples());
28 println!(" Features: {}", classificationdata.n_features());
29 println!(" Training samples: {}", train.n_samples());
30 println!(" Test samples: {}", test.n_samples());
31
32 // Generate regression dataset
33 let regressiondata = make_regression(
34 n_samples,
35 n_features,
36 3, // 3 informative features
37 0.5, // noise level
38 Some(42),
39 )?;
40
41 println!("\nRegression dataset:");
42 println!(" Samples: {}", regressiondata.n_samples());
43 println!(" Features: {}", regressiondata.n_features());
44
45 // Normalize the data (in-place)
46 let mut data_copy = regressiondata.data.clone();
47 normalize(&mut data_copy);
48 println!(" Data normalized successfully");
49
50 // Generate clustering data (blobs)
51 let clusteringdata = make_blobs(
52 n_samples,
53 2, // 2 features for easy visualization
54 4, // 4 clusters
55 0.8, // cluster standard deviation
56 Some(42),
57 )?;
58
59 println!("\nClustering dataset (blobs):");
60 println!(" Samples: {}", clusteringdata.n_samples());
61 println!(" Features: {}", clusteringdata.n_features());
62
63 // Find the number of clusters by finding the max value of target
64 let num_clusters = clusteringdata.target.as_ref().map_or(0, |t| {
65 let mut max_val = -1.0;
66 for &val in t.iter() {
67 if val > max_val {
68 max_val = val;
69 }
70 }
71 (max_val as usize) + 1
72 });
73
74 println!(" Clusters: {num_clusters}");
75
76 // Generate time series data
77 let time_series = make_time_series(
78 100, // 100 time steps
79 3, // 3 features/variables
80 true, // with trend
81 true, // with seasonality
82 0.2, // noise level
83 Some(42),
84 )?;
85
86 println!("\nTime series dataset:");
87 println!(" Time steps: {}", time_series.n_samples());
88 println!(" Features: {}", time_series.n_features());
89
90 Ok(())
91 }
examples/datasets_streaming_demo.rs (line 148)
132 fn demonstrate_memory_efficient_processing() -> Result<(), Box<dyn std::error::Error>> {
133 println!("💾 MEMORY-EFFICIENT PROCESSING");
134 println!("{}", "-".repeat(40));
135
136 // Compare memory usage: streaming vs. in-memory
137 let datasetsize = 50_000;
138 let n_features = 50;
139
140 println!("Comparing memory usage for {datasetsize} samples with {n_features} features");
141
142 // In-memory approach (for comparison)
143 println!("\n1. In-memory approach:");
144 let start_mem = get_memory_usage();
145 let start_time = Instant::now();
146
147 let in_memorydataset = make_classification(datasetsize, n_features, 5, 2, 25, Some(42))?;
148 let (train, test) = train_test_split(&in_memorydataset, 0.2, Some(42))?;
149
150 let in_memory_time = start_time.elapsed();
151 let in_memory_mem = get_memory_usage() - start_mem;
152
153 println!(" Time: {:.2}s", in_memory_time.as_secs_f64());
154 println!(" Memory usage: ~{in_memory_mem:.1} MB");
155 println!(" Train samples: {}", train.n_samples());
156 println!(" Test samples: {}", test.n_samples());
157
158 // Streaming approach
159 println!("\n2. Streaming approach:");
160 let stream_start_time = Instant::now();
161 let stream_start_mem = get_memory_usage();
162
163 let config = StreamConfig {
164 chunk_size: 5_000, // Smaller chunks for memory efficiency
165 buffer_size: 2, // Smaller buffer
166 num_workers: 2,
167 memory_limit_mb: Some(50),
168 ..Default::default()
169 };
170
171 let mut stream = stream_classification(datasetsize, n_features, 5, config)?;
172
173 let mut total_processed = 0;
174 let mut train_samples = 0;
175 let mut test_samples = 0;
176
177 while let Some(chunk) = stream.next_chunk()? {
178 total_processed += chunk.n_samples();
179
180 // Simulate train/test split on chunk level
181 let chunk_trainsize = (chunk.n_samples() as f64 * 0.8) as usize;
182 train_samples += chunk_trainsize;
183 test_samples += chunk.n_samples() - chunk_trainsize;
184
185 // Process chunk (simulate some computation)
186 let _mean = chunk.data.mean_axis(ndarray::Axis(0));
187 let _std = chunk.data.std_axis(ndarray::Axis(0), 0.0);
188
189 if chunk.is_last {
190 break;
191 }
192 }
193
194 let stream_time = stream_start_time.elapsed();
195 let stream_mem = get_memory_usage() - stream_start_mem;
196
197 println!(" Time: {:.2}s", stream_time.as_secs_f64());
198 println!(" Memory usage: ~{stream_mem:.1} MB");
199 println!(" Train samples: {train_samples}");
200 println!(" Test samples: {test_samples}");
201 println!(" Total processed: {total_processed}");
202
203 // Comparison
204 println!("\n3. Comparison:");
205 println!(
206 " Memory savings: {:.1}x less memory",
207 in_memory_mem / stream_mem.max(1.0)
208 );
209 println!(
210 " Time overhead: {:.1}x",
211 stream_time.as_secs_f64() / in_memory_time.as_secs_f64()
212 );
213 println!(" Streaming is beneficial for large datasets that don't fit in memory");
214
215 println!();
216 Ok(())
217 }