Skip to main content

PhaseTrainer

Struct PhaseTrainer 

Source
pub struct PhaseTrainer { /* private fields */ }
Expand description

Orchestrates phase-based training to accelerate optimization.

This is the main training orchestrator that combines all of the crate's optimization techniques. It manages phase transitions and is designed to preserve convergence while maximizing training speed.

The trainer automatically:

  1. Tracks the current training phase
  2. Manages gradient prediction during the `Predict` phase
  3. Applies corrections to prevent drift
  4. Uses ternary accumulation for memory efficiency
  5. Optionally uses VSA compression for gradient storage

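§Example

In this sketch, `total_steps`, `gradients`, and `loss_value` stand in for values supplied by your own training loop, and `TrainingPhase` and `Device` must also be in scope.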

use vsa_optim_rs::phase::PhaseTrainer;
use vsa_optim_rs::PhaseConfig;

let shapes = vec![("layer.weight".to_string(), vec![64, 128])];
let config = PhaseConfig::default();
let mut trainer = PhaseTrainer::new(&shapes, config, &Device::Cpu)?;

// Training loop
for step in 0..total_steps {
    let step_info = trainer.begin_step()?;

    match step_info.phase {
        TrainingPhase::Full | TrainingPhase::Correct => {
            // Compute full gradients via backprop
            trainer.record_full_gradients(&gradients)?;
        }
        TrainingPhase::Predict => {
            // Use predicted gradients
            let predicted = trainer.get_predicted_gradients()?;
        }
    }

    trainer.end_step(loss_value)?;
}

Implementations§

Source§

impl PhaseTrainer

Source

pub fn new( param_shapes: &[(String, Vec<usize>)], config: PhaseConfig, device: &Device, ) -> Result<PhaseTrainer, OptimError>

Create a new phase trainer.

§Arguments
  • param_shapes - List of (name, shape) tuples for parameters
  • config - Phase training configuration
  • device - Device for tensor storage
§Errors

Returns an error if component initialization fails.

Source

pub fn begin_step(&mut self) -> Result<StepInfo, OptimError>

Begin a training step and return information about the current phase.

§Returns

Step information including phase and whether phase changed.

§Errors

Returns an error if the phase transition fails.

Source

pub fn record_full_gradients( &mut self, gradients: &HashMap<String, Tensor>, ) -> Result<(), OptimError>

Record full gradients computed by backpropagation (for the `Full` or `Correct` phase).

§Arguments
  • gradients - Map of parameter names to gradient tensors
§Errors

Returns an error if recording fails.

Source

pub fn get_predicted_gradients( &mut self, ) -> Result<HashMap<String, Tensor>, OptimError>

Get predicted gradients (for the `Predict` phase).

§Returns

Map of parameter names to predicted gradient tensors.

§Errors

Returns an error if prediction fails.

Source

pub fn apply_correction( &mut self, gradients: &mut HashMap<String, Tensor>, ) -> Result<(), OptimError>

Apply a correction to the given gradients, modifying them in place.

§Arguments
  • gradients - Mutable map of gradients to modify in-place
§Errors

Returns an error if the correction fails.

Source

pub fn end_step(&mut self, loss: f32) -> Result<(), OptimError>

End the training step.

§Arguments
  • loss - Loss value for this step
§Errors

Returns an error if tracking fails.

Source

pub const fn current_phase(&self) -> TrainingPhase

Get current training phase.

Source

pub const fn total_step(&self) -> usize

Get total step count.

Source

pub const fn cycle_count(&self) -> usize

Get cycle count.

Source

pub const fn speedup_ratio(&self) -> f32

Get speedup ratio.

Source

pub fn get_stats(&self) -> TrainerStats

Get training statistics.

Source

pub fn reset(&mut self) -> Result<(), OptimError>

Reset trainer state.

Source

pub fn vsa_compressor_mut(&mut self) -> &mut VSAGradientCompressor

Get mutable access to VSA compressor.

Source

pub fn ternary_accumulator_mut(&mut self) -> &mut TernaryGradientAccumulator

Get mutable access to ternary accumulator.

Source

pub fn should_compute_full(&self) -> bool

Check whether full gradients should be computed.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T> Instrument for T

Source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
Source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V

Source§

impl<T> WithSubscriber for T

Source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self>
where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more
Source§

impl<T> ErasedDestructor for T
where T: 'static,

Source§

impl<T> Ungil for T
where T: Send,