Struct ValidationHandler

pub struct ValidationHandler<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync> {
    pub config: ValidationConfig,
    /* private fields */
}

Handler for validating a model during training

Fields

config: ValidationConfig

Configuration for validation

Implementations

impl<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync> ValidationHandler<F>

pub fn new(config: ValidationConfig) -> Result<Self>

Create a new validation handler
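
A minimal construction sketch, in addition to the repository example below. The configuration fields mirror that example; choosing f32 and leaving early stopping disabled are illustrative assumptions, not requirements:

let config = ValidationConfig {
    batch_size: 32,
    shuffle: false,
    num_workers: 0,
    steps: None,
    metrics: vec![MetricType::Loss],
    verbose: 1,
    early_stopping: None,
};
let handler = ValidationHandler::<f32>::new(config)?;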

Examples found in repository
examples/model_evaluation_example.rs (line 257)
165fn main() -> Result<()> {
166    println!("Model Evaluation Framework Example");
167    println!("---------------------------------");
168
169    // 1. Basic evaluation
170    println!("\n1. Basic Evaluation:");
171
172    // Generate synthetic regression dataset
173    let dataset = generate_regression_dataset::<f32>(1000, 5)?;
174
175    // Split into train, validation, and test sets
176    let n_samples = dataset.len();
177    let train_size = n_samples * 6 / 10;
178    let val_size = n_samples * 2 / 10;
179    let _test_size = n_samples - train_size - val_size;
180
181    let mut indices: Vec<usize> = (0..n_samples).collect();
182    use rand::seq::SliceRandom;
183    let mut shuffle_rng = SmallRng::seed_from_u64(44);
184    indices.shuffle(&mut shuffle_rng);
185
186    let train_indices = indices[0..train_size].to_vec();
187    let val_indices = indices[train_size..train_size + val_size].to_vec();
188    let test_indices = indices[train_size + val_size..].to_vec();
189
190    println!(
191        "Dataset splits: Train={}, Validation={}, Test={}",
192        train_indices.len(),
193        val_indices.len(),
194        test_indices.len()
195    );
196
197    // Create subset datasets
198    let _train_dataset = SubsetDataset::new(dataset.clone(), train_indices.clone())?;
199    let val_dataset = SubsetDataset::new(dataset.clone(), val_indices.clone())?;
200    let test_dataset = SubsetDataset::new(dataset.clone(), test_indices.clone())?;
201
202    // Build a simple model
203    let model_builder = SimpleModelBuilder::<f32>::new(5, 32, 1);
204    let mut model = model_builder.build()?;
205
206    // Create loss function
207    let loss_fn = MeanSquaredError::new();
208
209    // Create evaluator
210    let eval_config = EvaluationConfig {
211        batch_size: 32,
212        shuffle: false,
213        num_workers: 0,
214        metrics: vec![
215            MetricType::Loss,
216            MetricType::MeanSquaredError,
217            MetricType::MeanAbsoluteError,
218            MetricType::RSquared,
219        ],
220        steps: None,
221        verbose: 1,
222    };
223
224    let mut evaluator = Evaluator::new(eval_config)?;
225
226    // Evaluate model on validation set
227    println!("\nEvaluating model on validation set:");
228    let val_metrics = evaluator.evaluate(&mut model, &val_dataset, Some(&loss_fn))?;
229
230    println!("Validation metrics:");
231    for (name, value) in &val_metrics {
232        println!("  {}: {:.4}", name, value);
233    }
234
235    // 2. Validation with early stopping
236    println!("\n2. Validation with Early Stopping:");
237
238    // Configure early stopping
239    let early_stopping_config = EarlyStoppingConfig {
240        monitor: "val_loss".to_string(),
241        min_delta: 0.001,
242        patience: 5,
243        restore_best_weights: true,
244        mode: EarlyStoppingMode::Min,
245    };
246
247    let validation_config = ValidationConfig {
248        batch_size: 32,
249        shuffle: false,
250        num_workers: 0,
251        steps: None,
252        metrics: vec![MetricType::Loss, MetricType::MeanSquaredError],
253        verbose: 1,
254        early_stopping: Some(early_stopping_config),
255    };
256
257    let mut validation_handler = ValidationHandler::new(validation_config)?;
258
259    // Simulate training loop with validation
260    println!("\nSimulating training loop with validation:");
261    let num_epochs = 10;
262
263    for epoch in 0..num_epochs {
264        println!("Epoch {}/{}", epoch + 1, num_epochs);
265
266        // Simulate training step (not actually training the model)
267        println!("Training...");
268
269        // Validate model
270        let (val_metrics, should_stop) =
271            validation_handler.validate(&mut model, &val_dataset, Some(&loss_fn), epoch)?;
272
273        println!("Validation metrics:");
274        for (name, value) in &val_metrics {
275            println!("  {}: {:.4}", name, value);
276        }
277
278        if should_stop {
279            println!("Early stopping triggered!");
280            break;
281        }
282    }
283
284    // 3. Cross-validation
285    println!("\n3. Cross-Validation:");
286
287    // Configure cross-validation
288    let cv_config = CrossValidationConfig {
289        strategy: CrossValidationStrategy::KFold(5),
290        shuffle: true,
291        random_seed: Some(42),
292        batch_size: 32,
293        num_workers: 0,
294        metrics: vec![
295            MetricType::Loss,
296            MetricType::MeanSquaredError,
297            MetricType::RSquared,
298        ],
299        verbose: 1,
300    };
301
302    let mut cross_validator = CrossValidator::new(cv_config)?;
303
304    // Perform cross-validation
305    println!("\nPerforming 5-fold cross-validation:");
306    let cv_results = cross_validator.cross_validate(&model_builder, &dataset, Some(&loss_fn))?;
307
308    println!("Cross-validation results:");
309    for (name, values) in &cv_results {
310        // Calculate mean and std
311        let sum: f32 = values.iter().sum();
312        let mean = sum / values.len() as f32;
313
314        let variance_sum: f32 = values.iter().map(|&x| (x - mean).powi(2)).sum();
315        let std = (variance_sum / values.len() as f32).sqrt();
316
317        println!("  {}: {:.4} ± {:.4}", name, mean, std);
318    }
319
320    // 4. Test set evaluation
321    println!("\n4. Test Set Evaluation:");
322
323    // Configure test evaluator
324    let test_config = TestConfig {
325        batch_size: 32,
326        num_workers: 0,
327        metrics: vec![
328            MetricType::Loss,
329            MetricType::MeanSquaredError,
330            MetricType::MeanAbsoluteError,
331            MetricType::RSquared,
332        ],
333        steps: None,
334        verbose: 1,
335        generate_predictions: true,
336        save_outputs: false,
337    };
338
339    let mut test_evaluator = TestEvaluator::new(test_config)?;
340
341    // Evaluate model on test set
342    println!("\nEvaluating model on test set:");
343    let test_metrics = test_evaluator.evaluate(&mut model, &test_dataset, Some(&loss_fn))?;
344
345    println!("Test metrics:");
346    for (name, value) in &test_metrics {
347        println!("  {}: {:.4}", name, value);
348    }
349
350    // 5. Classification example
351    println!("\n5. Classification Example:");
352
353    // Generate synthetic classification dataset
354    let n_classes = 3;
355    let class_dataset = generate_classification_dataset::<f32>(1000, 5, n_classes)?;
356
357    // Split dataset
358    let _class_train_dataset = SubsetDataset::new(class_dataset.clone(), train_indices.clone())?;
359    let class_test_dataset = SubsetDataset::new(class_dataset.clone(), test_indices.clone())?;
360
361    // Build classification model
362    let class_model_builder = SimpleModelBuilder::<f32>::new(5, 32, n_classes);
363    let mut class_model = class_model_builder.build()?;
364
365    // Configure test evaluator for classification
366    let class_test_config = TestConfig {
367        batch_size: 32,
368        num_workers: 0,
369        metrics: vec![
370            MetricType::Accuracy,
371            MetricType::Precision,
372            MetricType::Recall,
373            MetricType::F1Score,
374        ],
375        steps: None,
376        verbose: 1,
377        generate_predictions: true,
378        save_outputs: false,
379    };
380
381    let mut class_test_evaluator = TestEvaluator::new(class_test_config)?;
382
383    // Evaluate classification model
384    println!("\nEvaluating classification model:");
385    let class_metrics =
386        class_test_evaluator.evaluate(&mut class_model, &class_test_dataset, None)?;
387
388    println!("Classification metrics:");
389    for (name, value) in &class_metrics {
390        println!("  {}: {:.4}", name, value);
391    }
392
393    // Generate classification report
394    println!("\nClassification Report:");
395    match class_test_evaluator.classification_report() {
396        Ok(report) => println!("{}", report),
397        Err(e) => println!("Could not generate classification report: {}", e),
398    }
399
400    // Generate confusion matrix
401    println!("\nConfusion Matrix:");
402    match class_test_evaluator.confusion_matrix() {
403        Ok(cm) => println!("{}", cm),
404        Err(e) => println!("Could not generate confusion matrix: {}", e),
405    }
406
407    println!("\nModel Evaluation Example Completed Successfully!");
408
409    Ok(())
410}

pub fn validate<L: Layer<F>>(
    &mut self,
    model: &mut L,
    dataset: &dyn Dataset<F>,
    loss_fn: Option<&dyn Loss<F>>,
    epoch: usize,
) -> Result<(HashMap<String, F>, bool)>

Validate a model on a dataset, returning the computed metrics together with a flag indicating whether early stopping should be triggered

Examples found in repository
examples/model_evaluation_example.rs (line 271)
The full example appears above under new; only the portion that exercises validate is repeated here:

257    let mut validation_handler = ValidationHandler::new(validation_config)?;
258
259    // Simulate training loop with validation
260    println!("\nSimulating training loop with validation:");
261    let num_epochs = 10;
262
263    for epoch in 0..num_epochs {
264        println!("Epoch {}/{}", epoch + 1, num_epochs);
265
266        // Simulate training step (not actually training the model)
267        println!("Training...");
268
269        // Validate model
270        let (val_metrics, should_stop) =
271            validation_handler.validate(&mut model, &val_dataset, Some(&loss_fn), epoch)?;
272
273        println!("Validation metrics:");
274        for (name, value) in &val_metrics {
275            println!("  {}: {:.4}", name, value);
276        }
277
278        if should_stop {
279            println!("Early stopping triggered!");
280            break;
281        }
282    }

pub fn has_early_stopping(&self) -> bool

Check if early stopping is enabled

pub fn get_early_stopping_state(&self) -> Option<&EarlyStoppingState<F>>

Get the current early stopping state

pub fn reset_early_stopping(&mut self)

Reset early stopping state
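
A minimal sketch of the three early-stopping helpers above, assuming validation_handler was created with an EarlyStoppingConfig as in the repository example:

if validation_handler.has_early_stopping() {
    // Inspect the internally tracked state (its fields are not shown here;
    // see EarlyStoppingState).
    let _state = validation_handler.get_early_stopping_state();
}

// Clear the tracked state, e.g. before starting a new training run or fold.
validation_handler.reset_early_stopping();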

Trait Implementations

impl<F: Float + Debug + ScalarOperand + Display + FromPrimitive + Send + Sync> Debug for ValidationHandler<F>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more

Auto Trait Implementations

impl<F> Freeze for ValidationHandler<F>
where F: Freeze,

impl<F> !RefUnwindSafe for ValidationHandler<F>

impl<F> !Send for ValidationHandler<F>

impl<F> !Sync for ValidationHandler<F>

impl<F> Unpin for ValidationHandler<F>
where F: Unpin,

impl<F> !UnwindSafe for ValidationHandler<F>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> IntoEither for T

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more

impl<T> Pointable for T

const ALIGN: usize

The alignment of the pointer.

type Init = T

The type for initializers.

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a pointer with the given initializer. Read more

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V