Module training

Training utilities and infrastructure

This module provides utilities for training neural networks, including gradient accumulation, mixed precision training, distributed training, and full training loop management.

§Overview

The training module consists of several key components:

  • Trainer: High-level orchestrator that manages the entire training process
  • TrainingConfig: Configuration structure for customizing training behavior
  • GradientAccumulator: Accumulates gradients across multiple batches to simulate larger effective batch sizes
  • MixedPrecisionManager: Manages memory-efficient mixed precision training
  • ValidationSettings: Configures validation during training

§Examples

§Basic Training Loop

use scirs2_neural::training::{Trainer, TrainingConfig, ValidationSettings};
use scirs2_neural::data::{DataLoader, Dataset};
use scirs2_neural::models::Sequential;
use scirs2_neural::layers::Dense;
use scirs2_neural::losses::CrossEntropyLoss;
use scirs2_neural::optimizers::Adam;
use scirs2_neural::callbacks::CallbackManager;
use rand::rngs::SmallRng;
use rand::SeedableRng;

let mut rng = SmallRng::seed_from_u64(42);

// Create a simple model
let mut model: Sequential<f32> = Sequential::new();
model.add_layer(Dense::new(784, 128, Some("relu"), &mut rng)?);
model.add_layer(Dense::new(128, 10, Some("softmax"), &mut rng)?);

// Configure training
let config = TrainingConfig {
    batch_size: 32,
    epochs: 10,
    learning_rate: 0.001,
    shuffle: true,
    verbose: 1,
    validation: Some(ValidationSettings {
        enabled: true,
        validation_split: 0.2,
        batch_size: 32,
        num_workers: 1,
    }),
    ..Default::default()
};

// Set up data, loss, optimizer, and callbacks
// let train_loader = DataLoader::new(...);
// let loss_fn = CrossEntropyLoss::new();
// let optimizer = Adam::new(0.001);
// let callbacks = CallbackManager::new();

// Create the trainer
// let trainer = Trainer::new(model, optimizer, loss_fn, config);

// Train the model
// let history = trainer.fit(&mut model, &train_loader, &loss_fn, &mut optimizer, &mut callbacks)?;
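
The loop that Trainer::fit manages can be pictured with a self-contained sketch. The snippet below uses no scirs2-neural types at all: a single made-up weight, a toy dataset, squared-error loss, and plain SGD, purely to illustrate the forward/backward/update cycle the trainer automates.

let data = [(1.0_f32, 2.0_f32), (2.0, 4.0), (3.0, 6.0)]; // toy dataset: y = 2x
let learning_rate = 0.05_f32;
let mut w = 0.0_f32; // single trainable parameter

for epoch in 0..10 {
    let mut epoch_loss = 0.0;
    for (x, y) in data {
        let pred = w * x;          // forward pass
        let err = pred - y;
        epoch_loss += err * err;   // squared-error loss
        let grad = 2.0 * err * x;  // gradient of the loss w.r.t. w
        w -= learning_rate * grad; // optimizer step (plain SGD)
    }
    println!("epoch {epoch}: loss = {epoch_loss:.4}");
    // A full Trainer would also shuffle, run validation, and fire callbacks here.
}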

§Advanced Training with Gradient Accumulation

use scirs2_neural::training::{TrainingConfig, GradientAccumulationConfig};

let config = TrainingConfig {
    batch_size: 8,  // Per-step micro-batch size
    gradient_accumulation: Some(GradientAccumulationConfig {
        accumulation_steps: 4,  // Accumulate over 4 steps (effective batch size: 32)
        clip_gradients: true,
        average_gradients: true,
        zero_gradients_after_update: true,
        max_gradient_norm: Some(1.0),
        log_gradient_stats: false,
    }),
    ..Default::default()
};
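
The accumulation cycle these options control can be sketched without any crate types. The Accumulator below is an illustrative stand-in, not the crate's GradientAccumulator API; it shows how the config fields map onto the technique: sum per-micro-batch gradients, average them (average_gradients), clip to a global L2 norm (clip_gradients with max_gradient_norm), and reset after the update (zero_gradients_after_update).

struct Accumulator {
    sums: Vec<f32>, // running gradient sums, one per parameter
    steps_taken: usize,
}

impl Accumulator {
    fn new(n_params: usize) -> Self {
        Self { sums: vec![0.0; n_params], steps_taken: 0 }
    }

    // Add one micro-batch's gradients to the running sums.
    fn accumulate(&mut self, grads: &[f32]) {
        for (s, g) in self.sums.iter_mut().zip(grads) {
            *s += *g;
        }
        self.steps_taken += 1;
    }

    // Average, clip to a maximum global L2 norm, hand the result to the
    // optimizer, and reset for the next accumulation cycle.
    fn finish(&mut self, max_norm: f32) -> Vec<f32> {
        let n = self.steps_taken.max(1) as f32;
        for s in self.sums.iter_mut() {
            *s /= n; // average_gradients: true
        }
        let norm = self.sums.iter().map(|s| s * s).sum::<f32>().sqrt();
        if norm > max_norm {
            let k = max_norm / norm; // clip_gradients + max_gradient_norm
            for s in self.sums.iter_mut() {
                *s *= k;
            }
        }
        let averaged = self.sums.clone();
        self.sums.fill(0.0); // zero_gradients_after_update: true
        self.steps_taken = 0;
        averaged
    }
}

With accumulation_steps: 4, finish would be called after every fourth micro-batch, yielding the effective batch size of 32 noted above.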

§Mixed Precision Training

use scirs2_neural::training::{TrainingConfig, MixedPrecisionConfig};

let config = TrainingConfig {
    mixed_precision: Some(MixedPrecisionConfig {
        dynamic_loss_scaling: true,
        initial_loss_scale: 1024.0,
        scale_factor: 2.0,
        scale_window: 2000,
        ..Default::default()
    }),
    ..Default::default()
};
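
Dynamic loss scaling grows the scale after a sustained run of overflow-free steps and shrinks it when scaled gradients overflow, keeping small half-precision gradients representable. The LossScaler below is an illustrative sketch of that policy under the semantics the fields above suggest; it is not the crate's MixedPrecisionManager API.

struct LossScaler {
    scale: f32,     // current loss scale
    factor: f32,    // multiplicative step (scale_factor)
    window: usize,  // overflow-free steps before growing (scale_window)
    good_steps: usize,
}

impl LossScaler {
    fn new(initial_loss_scale: f32, scale_factor: f32, scale_window: usize) -> Self {
        Self { scale: initial_loss_scale, factor: scale_factor, window: scale_window, good_steps: 0 }
    }

    // Call once per optimizer step with whether the scaled gradients
    // contained an inf/NaN (overflow).
    fn update(&mut self, overflowed: bool) {
        if overflowed {
            // Skip this update and back off to a safer scale.
            self.scale /= self.factor;
            self.good_steps = 0;
        } else {
            self.good_steps += 1;
            if self.good_steps >= self.window {
                // A full window without overflow: try a larger scale.
                self.scale *= self.factor;
                self.good_steps = 0;
            }
        }
    }
}

With the configuration above, the scale starts at 1024, doubles (scale_factor: 2.0) after every 2000 overflow-free steps (scale_window), and halves whenever an overflow is detected.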

Re-exports§

pub use gradient_accumulation::*;
pub use mixed_precision::*;

Modules§

gradient_accumulation
Gradient accumulation utilities
mixed_precision
Mixed precision training utilities

Structs§

Trainer
Trainer for a neural network model
TrainingConfig
Configuration for neural network training
TrainingSession
Training session for tracking training history
ValidationSettings
Configuration for validation during training