pub struct GradientPredictor { /* private fields */ }
Predict future gradients from history.
Gradient prediction cuts the number of full gradient computations by about 80% (4 predicted steps for every computed step) while maintaining convergence quality through periodic correction cycles.
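The ~80% figure is the share of steps that skip full computation. With k predicted steps per computed step, that share is k/(k+1), so k = 4 gives 4/5 = 0.8. The tiny sketch below only checks this arithmetic; the helper name is illustrative and not part of vsa_optim_rs. Warm-up steps and quality-triggered full steps only add full computations, so k/(k+1) is an upper bound on the predicted share.

```rust
/// Fraction of training steps that use a predicted gradient when every
/// computed step is followed by `predicted_per_computed` predicted steps.
/// Illustrative helper; ignores warm-up and quality-triggered full steps.
fn predicted_fraction(predicted_per_computed: u32) -> f64 {
    let k = f64::from(predicted_per_computed);
    k / (k + 1.0)
}

fn main() {
    // 4 predicted steps per computed step -> 4 / 5
    println!("predicted fraction: {}", predicted_fraction(4));
}
```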
The predictor maintains a history of recent gradients and uses a momentum-based extrapolation to predict future gradients. Corrections are computed as the difference between predicted and actual gradients.
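To make the extrapolation idea concrete, here is a minimal sketch on flat Vec<f32> gradients instead of tensors. It assumes one plausible momentum scheme: keep an exponential moving average of the change between consecutive recorded gradients and add it to the latest gradient. The struct, its fields and the formula are hypothetical illustrations, not the crate's internals, whose exact extrapolation rule is not documented here.

```rust
use std::collections::VecDeque;

/// Hypothetical momentum-based extrapolator over flat f32 gradients.
/// Assumed rule: v = m * v + (1 - m) * (g_t - g_{t-1}); prediction = g_last + v.
struct MomentumExtrapolator {
    history: VecDeque<Vec<f32>>, // most recent recorded gradients, oldest first
    max_history: usize,
    momentum: f32,              // EMA coefficient m for the delta average
    velocity: Option<Vec<f32>>, // EMA of deltas between consecutive gradients
}

impl MomentumExtrapolator {
    fn new(max_history: usize, momentum: f32) -> Self {
        Self { history: VecDeque::new(), max_history, momentum, velocity: None }
    }

    /// Record an actually computed gradient and update the delta average.
    fn record(&mut self, grad: Vec<f32>) {
        if let Some(last) = self.history.back() {
            let delta: Vec<f32> = grad.iter().zip(last).map(|(g, l)| g - l).collect();
            let m = self.momentum;
            self.velocity = Some(match self.velocity.take() {
                None => delta, // first delta seeds the average
                Some(v) => v.iter().zip(&delta).map(|(vi, di)| m * vi + (1.0 - m) * di).collect(),
            });
        }
        self.history.push_back(grad);
        if self.history.len() > self.max_history {
            self.history.pop_front(); // keep only the most recent gradients
        }
    }

    /// Predicted next gradient; None until at least two gradients are recorded.
    fn predict(&self) -> Option<Vec<f32>> {
        let last = self.history.back()?;
        let v = self.velocity.as_ref()?;
        Some(last.iter().zip(v).map(|(g, vi)| g + vi).collect())
    }
}

fn main() {
    let mut ex = MomentumExtrapolator::new(4, 0.5);
    ex.record(vec![1.0, 2.0]);
    ex.record(vec![2.0, 4.0]);
    ex.record(vec![4.0, 6.0]);
    println!("{:?}", ex.predict()); // Some([5.5, 8.0])
}
```

Using a moving average of deltas, rather than only the latest delta, smooths out noise in individual minibatch gradients at the cost of reacting more slowly to real changes.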
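Example
use std::collections::HashMap;
use candle_core::{Device, Tensor}; // assumed: Device and Tensor come from candle_core
use vsa_optim_rs::prediction::GradientPredictor;
use vsa_optim_rs::PredictionConfig;

// Runs inside a function that returns Result<_, OptimError>.
let shapes = vec![
    ("layer1.weight".to_string(), vec![64, 128]),
];
let mut predictor = GradientPredictor::new(&shapes, PredictionConfig::default(), &Device::Cpu)?;

// Training loop; `total_steps` is supplied by your training code.
for _step in 0..total_steps {
    if predictor.should_compute_full() {
        // Full step: compute_gradients() is a hypothetical helper that runs
        // loss.backward() and collects the per-parameter gradients.
        let mut gradients: HashMap<String, Tensor> = compute_gradients()?;
        predictor.record_gradient(&gradients)?;
        predictor.apply_correction(&mut gradients)?;
        // Use `gradients` for the optimizer step
    } else {
        // Predicted step: no backward pass
        let predicted = predictor.predict_gradient()?;
        // Use `predicted` for the optimizer step
    }
}
Implementations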
impl GradientPredictor
pub fn new(
    param_shapes: &[(String, Vec<usize>)],
    config: PredictionConfig,
    device: &Device,
) -> Result<GradientPredictor, OptimError>
pub fn should_compute_full(&self) -> bool
Check if full gradient computation is needed.
Full computation is needed:
- At the start (insufficient history)
- After prediction_steps predicted steps (correction cycle)
- When prediction quality degrades below a threshold
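The three conditions above combine with a logical OR. A minimal sketch of such a scheduler follows; the struct, its field names and the quality measure (for example cosine similarity between a prediction and the actual gradient) are assumptions for illustration, not the crate's private fields.

```rust
/// Hypothetical scheduler mirroring the three documented triggers for a full
/// gradient computation. Field names and the quality measure are assumptions.
struct Schedule {
    min_history: usize,        // recorded gradients needed before predicting
    prediction_steps: usize,   // predicted steps allowed between full steps
    quality_threshold: f32,    // assumed: minimum acceptable prediction quality
    history_len: usize,        // gradients recorded so far
    steps_since_full: usize,   // predicted steps since the last full step
    last_quality: Option<f32>, // most recent quality measurement, if any
}

impl Schedule {
    fn new(min_history: usize, prediction_steps: usize, quality_threshold: f32) -> Self {
        Self {
            min_history, prediction_steps, quality_threshold,
            history_len: 0, steps_since_full: 0, last_quality: None,
        }
    }

    fn should_compute_full(&self) -> bool {
        self.history_len < self.min_history                          // insufficient history
            || self.steps_since_full >= self.prediction_steps        // correction cycle due
            || self.last_quality.map_or(false, |q| q < self.quality_threshold) // degraded
    }

    /// Store a quality measurement (prediction vs actual); replaces the previous one.
    fn report_quality(&mut self, q: f32) {
        self.last_quality = Some(q);
    }

    /// Advance the counters after a step.
    fn step(&mut self, was_full: bool) {
        if was_full {
            self.history_len += 1; // a full step records one more gradient
            self.steps_since_full = 0;
        } else {
            self.steps_since_full += 1;
        }
    }
}

/// Run `n` steps and return the schedule: 'F' = full step, 'P' = predicted step.
fn simulate(s: &mut Schedule, n: usize) -> String {
    let mut pattern = String::with_capacity(n);
    for _ in 0..n {
        let full = s.should_compute_full();
        pattern.push(if full { 'F' } else { 'P' });
        s.step(full);
    }
    pattern
}

fn main() {
    // Two warm-up full steps, then one full step after every 4 predicted steps.
    println!("{}", simulate(&mut Schedule::new(2, 4, 0.5), 12)); // FFPPPPFPPPPF
}
```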
pub fn record_gradient(
    &mut self,
    gradients: &HashMap<String, Tensor>,
) -> Result<(), OptimError>
pub fn predict_gradient(
    &mut self,
) -> Result<HashMap<String, Tensor>, OptimError>
pub fn compute_correction(
    &mut self,
    actual_gradients: &HashMap<String, Tensor>,
) -> Result<HashMap<String, Tensor>, OptimError>
Compute the correction between predicted and actual gradients.
The correction term captures the prediction error and is accumulated to apply a “catch-up” adjustment.
Arguments
- actual_gradients - the actual computed gradients
Returns
A map from parameter name to its correction term.
Errors
Returns an error if a tensor operation fails.
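The compute-accumulate-apply cycle described above can be sketched on flat Vec<f32> gradients. This is a hypothetical illustration, not the crate's implementation: it assumes the correction is actual minus predicted, that corrections add into a per-parameter running total, and that applying the total adds it to a gradient map and then resets it. The struct and method bodies are invented for the example.

```rust
use std::collections::HashMap;

/// Hypothetical correction accumulator on flat f32 gradients.
/// Assumed sign convention: correction = actual - predicted.
#[derive(Default)]
struct CorrectionAccumulator {
    accumulated: HashMap<String, Vec<f32>>, // running sum of corrections per parameter
}

impl CorrectionAccumulator {
    /// Per-parameter prediction error, also added into the running total.
    fn compute_correction(
        &mut self,
        predicted: &HashMap<String, Vec<f32>>,
        actual: &HashMap<String, Vec<f32>>,
    ) -> HashMap<String, Vec<f32>> {
        let mut out = HashMap::new();
        for (name, a) in actual {
            // Parameters without a prediction contribute no correction.
            let Some(p) = predicted.get(name) else { continue };
            let corr: Vec<f32> = a.iter().zip(p).map(|(x, y)| x - y).collect();
            let total = self
                .accumulated
                .entry(name.clone())
                .or_insert_with(|| vec![0.0; corr.len()]);
            for (t, c) in total.iter_mut().zip(&corr) {
                *t += c;
            }
            out.insert(name.clone(), corr);
        }
        out
    }

    /// Add the accumulated "catch-up" term to `gradients`, then reset it.
    fn apply_correction(&mut self, gradients: &mut HashMap<String, Vec<f32>>) {
        for (name, g) in gradients.iter_mut() {
            if let Some(total) = self.accumulated.remove(name) {
                for (gi, ci) in g.iter_mut().zip(&total) {
                    *gi += ci;
                }
            }
        }
    }
}

fn main() {
    let mut acc = CorrectionAccumulator::default();
    let predicted = HashMap::from([("w".to_string(), vec![1.0_f32, 2.0])]);
    let actual = HashMap::from([("w".to_string(), vec![1.5_f32, 1.0])]);
    let corr = acc.compute_correction(&predicted, &actual);
    println!("correction: {:?}", corr["w"]); // [0.5, -1.0]

    let mut grads = HashMap::from([("w".to_string(), vec![2.0_f32, 2.0])]);
    acc.apply_correction(&mut grads);
    println!("corrected: {:?}", grads["w"]); // [2.5, 1.0]
}
```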
pub fn apply_correction(
    &mut self,
    gradients: &mut HashMap<String, Tensor>,
) -> Result<(), OptimError>
pub fn get_stats(&self) -> PredictorStats
Get prediction statistics.
pub const fn total_steps(&self) -> usize
Get total steps.
Auto Trait Implementations
impl Freeze for GradientPredictor
impl !RefUnwindSafe for GradientPredictor
impl Send for GradientPredictor
impl Sync for GradientPredictor
impl Unpin for GradientPredictor
impl !UnwindSafe for GradientPredictor
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> Instrument for T
fn instrument(self, span: Span) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.