pub struct ValidationHandler<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync> {
    pub config: ValidationConfig,
    /* private fields */
}
Validation handler for model validation during training.
Fields
config: ValidationConfig
Configuration for validation.
Implementations
impl<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync> ValidationHandler<F>
pub fn new(config: ValidationConfig) -> Result<Self>
Create a new validation handler.
Examples found in repository:
examples/model_evaluation_example.rs (line 257)
165fn main() -> Result<()> {
166 println!("Model Evaluation Framework Example");
167 println!("---------------------------------");
168
169 // 1. Basic evaluation
170 println!("\n1. Basic Evaluation:");
171
172 // Generate synthetic regression dataset
173 let dataset = generate_regression_dataset::<f32>(1000, 5)?;
174
175 // Split into train, validation, and test sets
176 let n_samples = dataset.len();
177 let train_size = n_samples * 6 / 10;
178 let val_size = n_samples * 2 / 10;
179 let _test_size = n_samples - train_size - val_size;
180
181 let mut indices: Vec<usize> = (0..n_samples).collect();
182 use rand::seq::SliceRandom;
183 let mut shuffle_rng = SmallRng::seed_from_u64(44);
184 indices.shuffle(&mut shuffle_rng);
185
186 let train_indices = indices[0..train_size].to_vec();
187 let val_indices = indices[train_size..train_size + val_size].to_vec();
188 let test_indices = indices[train_size + val_size..].to_vec();
189
190 println!(
191 "Dataset splits: Train={}, Validation={}, Test={}",
192 train_indices.len(),
193 val_indices.len(),
194 test_indices.len()
195 );
196
197 // Create subset datasets
198 let _train_dataset = SubsetDataset::new(dataset.clone(), train_indices.clone())?;
199 let val_dataset = SubsetDataset::new(dataset.clone(), val_indices.clone())?;
200 let test_dataset = SubsetDataset::new(dataset.clone(), test_indices.clone())?;
201
202 // Build a simple model
203 let model_builder = SimpleModelBuilder::<f32>::new(5, 32, 1);
204 let mut model = model_builder.build()?;
205
206 // Create loss function
207 let loss_fn = MeanSquaredError::new();
208
209 // Create evaluator
210 let eval_config = EvaluationConfig {
211 batch_size: 32,
212 shuffle: false,
213 num_workers: 0,
214 metrics: vec![
215 MetricType::Loss,
216 MetricType::MeanSquaredError,
217 MetricType::MeanAbsoluteError,
218 MetricType::RSquared,
219 ],
220 steps: None,
221 verbose: 1,
222 };
223
224 let mut evaluator = Evaluator::new(eval_config)?;
225
226 // Evaluate model on validation set
227 println!("\nEvaluating model on validation set:");
228 let val_metrics = evaluator.evaluate(&mut model, &val_dataset, Some(&loss_fn))?;
229
230 println!("Validation metrics:");
231 for (name, value) in &val_metrics {
232 println!(" {}: {:.4}", name, value);
233 }
234
235 // 2. Validation with early stopping
236 println!("\n2. Validation with Early Stopping:");
237
238 // Configure early stopping
239 let early_stopping_config = EarlyStoppingConfig {
240 monitor: "val_loss".to_string(),
241 min_delta: 0.001,
242 patience: 5,
243 restore_best_weights: true,
244 mode: EarlyStoppingMode::Min,
245 };
246
247 let validation_config = ValidationConfig {
248 batch_size: 32,
249 shuffle: false,
250 num_workers: 0,
251 steps: None,
252 metrics: vec![MetricType::Loss, MetricType::MeanSquaredError],
253 verbose: 1,
254 early_stopping: Some(early_stopping_config),
255 };
256
257 let mut validation_handler = ValidationHandler::new(validation_config)?;
258
259 // Simulate training loop with validation
260 println!("\nSimulating training loop with validation:");
261 let num_epochs = 10;
262
263 for epoch in 0..num_epochs {
264 println!("Epoch {}/{}", epoch + 1, num_epochs);
265
266 // Simulate training step (not actually training the model)
267 println!("Training...");
268
269 // Validate model
270 let (val_metrics, should_stop) =
271 validation_handler.validate(&mut model, &val_dataset, Some(&loss_fn), epoch)?;
272
273 println!("Validation metrics:");
274 for (name, value) in &val_metrics {
275 println!(" {}: {:.4}", name, value);
276 }
277
278 if should_stop {
279 println!("Early stopping triggered!");
280 break;
281 }
282 }
283
284 // 3. Cross-validation
285 println!("\n3. Cross-Validation:");
286
287 // Configure cross-validation
288 let cv_config = CrossValidationConfig {
289 strategy: CrossValidationStrategy::KFold(5),
290 shuffle: true,
291 random_seed: Some(42),
292 batch_size: 32,
293 num_workers: 0,
294 metrics: vec![
295 MetricType::Loss,
296 MetricType::MeanSquaredError,
297 MetricType::RSquared,
298 ],
299 verbose: 1,
300 };
301
302 let mut cross_validator = CrossValidator::new(cv_config)?;
303
304 // Perform cross-validation
305 println!("\nPerforming 5-fold cross-validation:");
306 let cv_results = cross_validator.cross_validate(&model_builder, &dataset, Some(&loss_fn))?;
307
308 println!("Cross-validation results:");
309 for (name, values) in &cv_results {
310 // Calculate mean and std
311 let sum: f32 = values.iter().sum();
312 let mean = sum / values.len() as f32;
313
314 let variance_sum: f32 = values.iter().map(|&x| (x - mean).powi(2)).sum();
315 let std = (variance_sum / values.len() as f32).sqrt();
316
317 println!(" {}: {:.4} ± {:.4}", name, mean, std);
318 }
319
320 // 4. Test set evaluation
321 println!("\n4. Test Set Evaluation:");
322
323 // Configure test evaluator
324 let test_config = TestConfig {
325 batch_size: 32,
326 num_workers: 0,
327 metrics: vec![
328 MetricType::Loss,
329 MetricType::MeanSquaredError,
330 MetricType::MeanAbsoluteError,
331 MetricType::RSquared,
332 ],
333 steps: None,
334 verbose: 1,
335 generate_predictions: true,
336 save_outputs: false,
337 };
338
339 let mut test_evaluator = TestEvaluator::new(test_config)?;
340
341 // Evaluate model on test set
342 println!("\nEvaluating model on test set:");
343 let test_metrics = test_evaluator.evaluate(&mut model, &test_dataset, Some(&loss_fn))?;
344
345 println!("Test metrics:");
346 for (name, value) in &test_metrics {
347 println!(" {}: {:.4}", name, value);
348 }
349
350 // 5. Classification example
351 println!("\n5. Classification Example:");
352
353 // Generate synthetic classification dataset
354 let n_classes = 3;
355 let class_dataset = generate_classification_dataset::<f32>(1000, 5, n_classes)?;
356
357 // Split dataset
358 let _class_train_dataset = SubsetDataset::new(class_dataset.clone(), train_indices.clone())?;
359 let class_test_dataset = SubsetDataset::new(class_dataset.clone(), test_indices.clone())?;
360
361 // Build classification model
362 let class_model_builder = SimpleModelBuilder::<f32>::new(5, 32, n_classes);
363 let mut class_model = class_model_builder.build()?;
364
365 // Configure test evaluator for classification
366 let class_test_config = TestConfig {
367 batch_size: 32,
368 num_workers: 0,
369 metrics: vec![
370 MetricType::Accuracy,
371 MetricType::Precision,
372 MetricType::Recall,
373 MetricType::F1Score,
374 ],
375 steps: None,
376 verbose: 1,
377 generate_predictions: true,
378 save_outputs: false,
379 };
380
381 let mut class_test_evaluator = TestEvaluator::new(class_test_config)?;
382
383 // Evaluate classification model
384 println!("\nEvaluating classification model:");
385 let class_metrics =
386 class_test_evaluator.evaluate(&mut class_model, &class_test_dataset, None)?;
387
388 println!("Classification metrics:");
389 for (name, value) in &class_metrics {
390 println!(" {}: {:.4}", name, value);
391 }
392
393 // Generate classification report
394 println!("\nClassification Report:");
395 match class_test_evaluator.classification_report() {
396 Ok(report) => println!("{}", report),
397 Err(e) => println!("Could not generate classification report: {}", e),
398 }
399
400 // Generate confusion matrix
401 println!("\nConfusion Matrix:");
402 match class_test_evaluator.confusion_matrix() {
403 Ok(cm) => println!("{}", cm),
404 Err(e) => println!("Could not generate confusion matrix: {}", e),
405 }
406
407 println!("\nModel Evaluation Example Completed Successfully!");
408
409 Ok(())
410}
pub fn validate<L: Layer<F>>(
    &mut self,
    model: &mut L,
    dataset: &dyn Dataset<F>,
    loss_fn: Option<&dyn Loss<F>>,
    epoch: usize,
) -> Result<(HashMap<String, F>, bool)>
Validate a model on a dataset. Returns the metrics keyed by name together with a `bool`; the example below treats this flag as the signal to stop training early.
Examples found in repository:
examples/model_evaluation_example.rs (line 271)
165fn main() -> Result<()> {
166 println!("Model Evaluation Framework Example");
167 println!("---------------------------------");
168
169 // 1. Basic evaluation
170 println!("\n1. Basic Evaluation:");
171
172 // Generate synthetic regression dataset
173 let dataset = generate_regression_dataset::<f32>(1000, 5)?;
174
175 // Split into train, validation, and test sets
176 let n_samples = dataset.len();
177 let train_size = n_samples * 6 / 10;
178 let val_size = n_samples * 2 / 10;
179 let _test_size = n_samples - train_size - val_size;
180
181 let mut indices: Vec<usize> = (0..n_samples).collect();
182 use rand::seq::SliceRandom;
183 let mut shuffle_rng = SmallRng::seed_from_u64(44);
184 indices.shuffle(&mut shuffle_rng);
185
186 let train_indices = indices[0..train_size].to_vec();
187 let val_indices = indices[train_size..train_size + val_size].to_vec();
188 let test_indices = indices[train_size + val_size..].to_vec();
189
190 println!(
191 "Dataset splits: Train={}, Validation={}, Test={}",
192 train_indices.len(),
193 val_indices.len(),
194 test_indices.len()
195 );
196
197 // Create subset datasets
198 let _train_dataset = SubsetDataset::new(dataset.clone(), train_indices.clone())?;
199 let val_dataset = SubsetDataset::new(dataset.clone(), val_indices.clone())?;
200 let test_dataset = SubsetDataset::new(dataset.clone(), test_indices.clone())?;
201
202 // Build a simple model
203 let model_builder = SimpleModelBuilder::<f32>::new(5, 32, 1);
204 let mut model = model_builder.build()?;
205
206 // Create loss function
207 let loss_fn = MeanSquaredError::new();
208
209 // Create evaluator
210 let eval_config = EvaluationConfig {
211 batch_size: 32,
212 shuffle: false,
213 num_workers: 0,
214 metrics: vec![
215 MetricType::Loss,
216 MetricType::MeanSquaredError,
217 MetricType::MeanAbsoluteError,
218 MetricType::RSquared,
219 ],
220 steps: None,
221 verbose: 1,
222 };
223
224 let mut evaluator = Evaluator::new(eval_config)?;
225
226 // Evaluate model on validation set
227 println!("\nEvaluating model on validation set:");
228 let val_metrics = evaluator.evaluate(&mut model, &val_dataset, Some(&loss_fn))?;
229
230 println!("Validation metrics:");
231 for (name, value) in &val_metrics {
232 println!(" {}: {:.4}", name, value);
233 }
234
235 // 2. Validation with early stopping
236 println!("\n2. Validation with Early Stopping:");
237
238 // Configure early stopping
239 let early_stopping_config = EarlyStoppingConfig {
240 monitor: "val_loss".to_string(),
241 min_delta: 0.001,
242 patience: 5,
243 restore_best_weights: true,
244 mode: EarlyStoppingMode::Min,
245 };
246
247 let validation_config = ValidationConfig {
248 batch_size: 32,
249 shuffle: false,
250 num_workers: 0,
251 steps: None,
252 metrics: vec![MetricType::Loss, MetricType::MeanSquaredError],
253 verbose: 1,
254 early_stopping: Some(early_stopping_config),
255 };
256
257 let mut validation_handler = ValidationHandler::new(validation_config)?;
258
259 // Simulate training loop with validation
260 println!("\nSimulating training loop with validation:");
261 let num_epochs = 10;
262
263 for epoch in 0..num_epochs {
264 println!("Epoch {}/{}", epoch + 1, num_epochs);
265
266 // Simulate training step (not actually training the model)
267 println!("Training...");
268
269 // Validate model
270 let (val_metrics, should_stop) =
271 validation_handler.validate(&mut model, &val_dataset, Some(&loss_fn), epoch)?;
272
273 println!("Validation metrics:");
274 for (name, value) in &val_metrics {
275 println!(" {}: {:.4}", name, value);
276 }
277
278 if should_stop {
279 println!("Early stopping triggered!");
280 break;
281 }
282 }
283
284 // 3. Cross-validation
285 println!("\n3. Cross-Validation:");
286
287 // Configure cross-validation
288 let cv_config = CrossValidationConfig {
289 strategy: CrossValidationStrategy::KFold(5),
290 shuffle: true,
291 random_seed: Some(42),
292 batch_size: 32,
293 num_workers: 0,
294 metrics: vec![
295 MetricType::Loss,
296 MetricType::MeanSquaredError,
297 MetricType::RSquared,
298 ],
299 verbose: 1,
300 };
301
302 let mut cross_validator = CrossValidator::new(cv_config)?;
303
304 // Perform cross-validation
305 println!("\nPerforming 5-fold cross-validation:");
306 let cv_results = cross_validator.cross_validate(&model_builder, &dataset, Some(&loss_fn))?;
307
308 println!("Cross-validation results:");
309 for (name, values) in &cv_results {
310 // Calculate mean and std
311 let sum: f32 = values.iter().sum();
312 let mean = sum / values.len() as f32;
313
314 let variance_sum: f32 = values.iter().map(|&x| (x - mean).powi(2)).sum();
315 let std = (variance_sum / values.len() as f32).sqrt();
316
317 println!(" {}: {:.4} ± {:.4}", name, mean, std);
318 }
319
320 // 4. Test set evaluation
321 println!("\n4. Test Set Evaluation:");
322
323 // Configure test evaluator
324 let test_config = TestConfig {
325 batch_size: 32,
326 num_workers: 0,
327 metrics: vec![
328 MetricType::Loss,
329 MetricType::MeanSquaredError,
330 MetricType::MeanAbsoluteError,
331 MetricType::RSquared,
332 ],
333 steps: None,
334 verbose: 1,
335 generate_predictions: true,
336 save_outputs: false,
337 };
338
339 let mut test_evaluator = TestEvaluator::new(test_config)?;
340
341 // Evaluate model on test set
342 println!("\nEvaluating model on test set:");
343 let test_metrics = test_evaluator.evaluate(&mut model, &test_dataset, Some(&loss_fn))?;
344
345 println!("Test metrics:");
346 for (name, value) in &test_metrics {
347 println!(" {}: {:.4}", name, value);
348 }
349
350 // 5. Classification example
351 println!("\n5. Classification Example:");
352
353 // Generate synthetic classification dataset
354 let n_classes = 3;
355 let class_dataset = generate_classification_dataset::<f32>(1000, 5, n_classes)?;
356
357 // Split dataset
358 let _class_train_dataset = SubsetDataset::new(class_dataset.clone(), train_indices.clone())?;
359 let class_test_dataset = SubsetDataset::new(class_dataset.clone(), test_indices.clone())?;
360
361 // Build classification model
362 let class_model_builder = SimpleModelBuilder::<f32>::new(5, 32, n_classes);
363 let mut class_model = class_model_builder.build()?;
364
365 // Configure test evaluator for classification
366 let class_test_config = TestConfig {
367 batch_size: 32,
368 num_workers: 0,
369 metrics: vec![
370 MetricType::Accuracy,
371 MetricType::Precision,
372 MetricType::Recall,
373 MetricType::F1Score,
374 ],
375 steps: None,
376 verbose: 1,
377 generate_predictions: true,
378 save_outputs: false,
379 };
380
381 let mut class_test_evaluator = TestEvaluator::new(class_test_config)?;
382
383 // Evaluate classification model
384 println!("\nEvaluating classification model:");
385 let class_metrics =
386 class_test_evaluator.evaluate(&mut class_model, &class_test_dataset, None)?;
387
388 println!("Classification metrics:");
389 for (name, value) in &class_metrics {
390 println!(" {}: {:.4}", name, value);
391 }
392
393 // Generate classification report
394 println!("\nClassification Report:");
395 match class_test_evaluator.classification_report() {
396 Ok(report) => println!("{}", report),
397 Err(e) => println!("Could not generate classification report: {}", e),
398 }
399
400 // Generate confusion matrix
401 println!("\nConfusion Matrix:");
402 match class_test_evaluator.confusion_matrix() {
403 Ok(cm) => println!("{}", cm),
404 Err(e) => println!("Could not generate confusion matrix: {}", e),
405 }
406
407 println!("\nModel Evaluation Example Completed Successfully!");
408
409 Ok(())
410}
pub fn has_early_stopping(&self) -> bool
Check if early stopping is enabled.
pub fn get_early_stopping_state(&self) -> Option<&EarlyStoppingState<F>>
Get the current early stopping state, if any.
pub fn reset_early_stopping(&mut self)
Reset the early stopping state.
Trait Implementations
Auto Trait Implementations
impl<F> Freeze for ValidationHandler<F>
where
    F: Freeze,
impl<F> !RefUnwindSafe for ValidationHandler<F>
impl<F> !Send for ValidationHandler<F>
impl<F> !Sync for ValidationHandler<F>
impl<F> Unpin for ValidationHandler<F>
where
    F: Unpin,
impl<F> !UnwindSafe for ValidationHandler<F>
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`. Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more