Training utilities and infrastructure
This module provides utilities for training neural networks, including gradient accumulation, mixed precision training, distributed training, and high-level management of the training loop.
§Overview
The training module consists of several key components:
- Trainer: High-level training orchestrator that manages the entire training process
- TrainingConfig: Configuration structure for customizing training behavior
- GradientAccumulator: For accumulating gradients across multiple batches
- MixedPrecisionManager: For memory-efficient mixed precision training
- ValidationSettings: For configuring validation during training
§Examples
§Basic Training Loop
use scirs2_neural::training::{Trainer, TrainingConfig, ValidationSettings};
use scirs2_neural::data::{DataLoader, Dataset};
use scirs2_neural::models::Sequential;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::optimizers::Adam;
use scirs2_neural::callbacks::CallbackManager;
use rand::rngs::SmallRng;
use rand::SeedableRng;
let mut rng = SmallRng::seed_from_u64(42);
// Create a simple model
let mut model: Sequential<f32> = Sequential::new();
model.add_layer(Dense::new(784, 128, Some("relu"), &mut rng)?);
model.add_layer(Dense::new(128, 10, Some("softmax"), &mut rng)?);
// Configure training
let config = TrainingConfig {
batch_size: 32,
epochs: 10,
learning_rate: 0.001,
shuffle: true,
verbose: 1,
validation: Some(ValidationSettings {
enabled: true,
validation_split: 0.2,
batch_size: 32,
num_workers: 1,
}),
..Default::default()
};
// Set up data, loss, optimizer, and callbacks
// let train_loader = DataLoader::new(...);
// let loss_fn = CrossEntropyLoss::new();
// let optimizer = Adam::new(0.001);
// let callbacks = CallbackManager::new();
// Create the trainer
// let trainer = Trainer::new(model, optimizer, loss_fn, config);
// Train the model
// let history = trainer.fit(&mut model, &train_loader, &loss_fn, &mut optimizer, &mut callbacks)?;
§Advanced Training with Gradient Accumulation
use scirs2_neural::training::{TrainingConfig, GradientAccumulationConfig};
let config = TrainingConfig {
batch_size: 8, // Smaller per-step batch size
gradient_accumulation: Some(GradientAccumulationConfig {
accumulation_steps: 4, // Accumulate over 4 steps (effective batch size: 32)
clip_gradients: true,
average_gradients: true,
zero_gradients_after_update: true,
max_gradient_norm: Some(1.0),
log_gradient_stats: false,
}),
..Default::default()
};
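Conceptually, accumulation sums the gradients from accumulation_steps consecutive micro-batches and applies one optimizer update for the whole group, optionally averaging and norm-clipping first. The sketch below illustrates that update rule under the settings above; it is not the crate's API, and the function and variable names are hypothetical.
// Illustrative sketch of the accumulation update rule (hypothetical names, not the crate's API)
fn accumulate(batch_grads: &[Vec<f32>], average_gradients: bool, max_gradient_norm: Option<f32>) -> Vec<f32> {
    let steps = batch_grads.len(); // one entry per micro-batch (accumulation_steps of them)
    let mut accum = vec![0.0f32; batch_grads[0].len()];
    for grads in batch_grads {
        for (a, g) in accum.iter_mut().zip(grads) {
            *a += *g; // sum gradients across micro-batches
        }
    }
    if average_gradients {
        for a in accum.iter_mut() {
            *a /= steps as f32; // average instead of summing
        }
    }
    if let Some(max_norm) = max_gradient_norm {
        let norm = accum.iter().map(|a| a * a).sum::<f32>().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm; // global L2-norm clipping
            for a in accum.iter_mut() {
                *a *= scale;
            }
        }
    }
    accum // consumed by one optimizer step in place of a single large-batch gradient
}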
§Mixed Precision Training
use scirs2_neural::training::{TrainingConfig, MixedPrecisionConfig};
let config = TrainingConfig {
mixed_precision: Some(MixedPrecisionConfig {
dynamic_loss_scaling: true,
initial_loss_scale: 1024.0,
scale_factor: 2.0,
scale_window: 2000,
..Default::default()
}),
..Default::default()
};
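With dynamic loss scaling, the loss is multiplied by the current scale before backpropagation so that small half-precision gradients do not underflow to zero. A common policy, sketched below with hypothetical names (this is a generic illustration, not the crate's implementation), is to shrink the scale when non-finite gradients are detected and to grow it again after scale_window consecutive clean steps.
// Generic dynamic loss-scaling policy (hypothetical sketch, not the crate's API)
struct LossScaler {
    scale: f32,       // current loss scale, starts at initial_loss_scale
    factor: f32,      // growth/backoff multiplier, scale_factor
    window: u32,      // clean steps required before growing, scale_window
    clean_steps: u32, // consecutive steps with finite gradients
}

impl LossScaler {
    fn update(&mut self, grads_finite: bool) {
        if grads_finite {
            self.clean_steps += 1;
            if self.clean_steps >= self.window {
                self.scale *= self.factor; // grow after a stable window
                self.clean_steps = 0;
            }
        } else {
            self.scale /= self.factor; // overflow: back off and skip this step
            self.clean_steps = 0;
        }
    }
}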
Re-exports§
pub use gradient_accumulation::*;
pub use mixed_precision::*;
Modules§
- gradient_accumulation - Gradient Accumulation utilities
- mixed_precision - Mixed Precision Training utilities
Structs§
- Trainer - Trainer for a neural network model
- TrainingConfig - Configuration for neural network training
- TrainingSession - Training session for tracking training history
- ValidationSettings - Configuration for validation during training