pub struct PhaseTrainer { /* private fields */ }
Orchestrates phase-based training to speed up optimization.
This is the main training orchestrator and combines all of the crate's optimization techniques. It manages phase transitions and is designed to preserve convergence while maximizing training speed.
The trainer automatically:
- Tracks the current training phase
- Manages gradient prediction during the PREDICT phase
- Applies corrections to prevent drift
- Uses ternary accumulation for memory efficiency
- Optionally uses VSA compression for gradient storage
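Example

This example is not tested. It assumes that `Device`, `TrainingPhase`, and `HashMap` are in scope, and that `total_steps`, `gradients`, and `loss_value` are supplied by the surrounding training code.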
use vsa_optim_rs::phase::PhaseTrainer;
use vsa_optim_rs::PhaseConfig;
let shapes = vec![("layer.weight".to_string(), vec![64, 128])];
let config = PhaseConfig::default();
let mut trainer = PhaseTrainer::new(&shapes, config, &Device::Cpu)?;
// Training loop
for step in 0..total_steps {
let step_info = trainer.begin_step()?;
match step_info.phase {
TrainingPhase::Full | TrainingPhase::Correct => {
// Compute full gradients via backprop
trainer.record_full_gradients(&gradients)?;
}
TrainingPhase::Predict => {
// Use predicted gradients
let predicted = trainer.get_predicted_gradients()?;
}
}
trainer.end_step(loss_value)?;
}

Implementations
impl PhaseTrainer
pub fn new(
    param_shapes: &[(String, Vec<usize>)],
    config: PhaseConfig,
    device: &Device,
) -> Result<PhaseTrainer, OptimError>
pub fn begin_step(&mut self) -> Result<StepInfo, OptimError>
pub fn record_full_gradients(
    &mut self,
    gradients: &HashMap<String, Tensor>,
) -> Result<(), OptimError>
pub fn get_predicted_gradients(
    &mut self,
) -> Result<HashMap<String, Tensor>, OptimError>
pub fn apply_correction(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
) -> Result<(), OptimError>
pub const fn current_phase(&self) -> TrainingPhase
Returns the current training phase.
pub const fn total_step(&self) -> usize
Returns the total step count.
pub const fn cycle_count(&self) -> usize
Returns the cycle count.
pub const fn speedup_ratio(&self) -> f32
Returns the speedup ratio.
pub fn get_stats(&self) -> TrainerStats
Returns training statistics.
pub fn reset(&mut self) -> Result<(), OptimError>
Resets the trainer state.
pub fn vsa_compressor_mut(&mut self) -> &mut VSAGradientCompressor
Returns mutable access to the VSA compressor.
pub fn ternary_accumulator_mut(&mut self) -> &mut TernaryGradientAccumulator
Returns mutable access to the ternary accumulator.
pub fn should_compute_full(&self) -> bool
Returns `true` if full gradients should be computed.
Auto Trait Implementations
impl Freeze for PhaseTrainer
impl !RefUnwindSafe for PhaseTrainer
impl Send for PhaseTrainer
impl Sync for PhaseTrainer
impl Unpin for PhaseTrainer
impl !UnwindSafe for PhaseTrainer
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`; otherwise converts `self` into a `Right` variant of `Either<Self, Self>`.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`; otherwise converts `self` into a `Right` variant of `Either<Self, Self>`.