pub struct GradientAccumulator { /* private fields */ }
Gradient accumulator managing multiple parameter gradients.
Provides micro-batching support by accumulating gradients across multiple forward/backward passes before applying an optimizer step.
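The arithmetic behind micro-batching can be checked directly. Split a batch into k equally sized micro-batches, and let each backward pass yield the mean gradient over its micro-batch. Summing the k micro-batch gradients and dividing by k then reproduces the full-batch mean gradient. The sketch below is independent of this crate. `mean_gradient` and `accumulate_micro_batches` are hypothetical helpers, and the per-example gradients are made-up numbers.

```rust
/// Element-wise mean of per-example gradients (all of the same length).
fn mean_gradient(per_example: &[Vec<f64>]) -> Vec<f64> {
    let n = per_example.len() as f64;
    let mut out = vec![0.0; per_example[0].len()];
    for g in per_example {
        for (o, x) in out.iter_mut().zip(g) {
            *o += x;
        }
    }
    out.iter_mut().for_each(|o| *o /= n);
    out
}

/// Hypothetical accumulation: sum the mean gradient of each micro-batch,
/// then divide by the number of micro-batches.
fn accumulate_micro_batches(batch: &[Vec<f64>], micro_batch_size: usize) -> Vec<f64> {
    let chunks: Vec<&[Vec<f64>]> = batch.chunks(micro_batch_size).collect();
    let k = chunks.len() as f64;
    let mut acc = vec![0.0; batch[0].len()];
    for chunk in &chunks {
        for (a, g) in acc.iter_mut().zip(mean_gradient(chunk)) {
            *a += g;
        }
    }
    acc.iter_mut().for_each(|a| *a /= k);
    acc
}

fn main() {
    // Four made-up per-example gradients for a 2-element parameter.
    let batch = vec![
        vec![1.0, 2.0],
        vec![3.0, 4.0],
        vec![5.0, 6.0],
        vec![7.0, 8.0],
    ];
    println!("full-batch mean: {:?}", mean_gradient(&batch));
    println!("accumulated:     {:?}", accumulate_micro_batches(&batch, 2));
}
```

The equality holds only for equal-sized micro-batches. Splitting three examples into chunks of 2 and 1 makes the plain average of micro-batch means overweight the smaller chunk. Weighting each micro-batch mean by its size restores the full-batch mean.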
Implementations
impl GradientAccumulator
pub fn new(config: AccumulationConfig) -> Self
Create a new gradient accumulator with the given configuration.
pub fn register(&mut self, name: impl Into<String>, shape: Vec<usize>)
Register a parameter with its gradient shape.
If the parameter is already registered, this is a no-op.
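A no-op re-registration maps naturally onto `HashMap`'s entry API, which inserts only when the key is vacant. The sketch below assumes one zero-initialized flat buffer per parameter, whose length is the product of the shape's dimensions. `ParamBuffer` and `Registry` are hypothetical names, not this type's private fields. The sketch also assumes a conflicting shape on re-registration is silently ignored, which is one reading of "no-op".

```rust
use std::collections::HashMap;

/// Hypothetical per-parameter state: the registered shape and a flat buffer.
#[derive(Debug)]
struct ParamBuffer {
    shape: Vec<usize>,
    grad: Vec<f64>,
}

#[derive(Default)]
struct Registry {
    params: HashMap<String, ParamBuffer>,
}

impl Registry {
    /// Register `name` with `shape`. A second call with the same name leaves
    /// the existing entry (including its shape) untouched.
    fn register(&mut self, name: impl Into<String>, shape: Vec<usize>) {
        self.params.entry(name.into()).or_insert_with(|| {
            // Flat length = product of dimensions (1 for an empty shape).
            let len: usize = shape.iter().product();
            ParamBuffer { grad: vec![0.0; len], shape }
        });
    }
}

fn main() {
    let mut reg = Registry::default();
    reg.register("layer1.weight", vec![3, 4]);
    reg.register("layer1.weight", vec![10]); // ignored: already registered
    let p = &reg.params["layer1.weight"];
    println!("shape = {:?}, buffer len = {}", p.shape, p.grad.len());
}
```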
pub fn accumulate(&mut self, name: &str, grad: &[f64]) -> Result<(), AccumulationError>
Accumulate a gradient for a named parameter.
Returns an error if the parameter has not been registered or if the gradient size does not match the registered shape.
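Both documented failure modes, an unregistered name and a size mismatch, can be checked before any arithmetic touches the buffer. A rejected call then leaves the running sum unchanged. In this sketch, `SketchError` is a hypothetical stand-in for `AccumulationError`, whose variants are not shown on this page. The sketch also assumes that accumulating means element-wise addition into a flat buffer of the registered size.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for `AccumulationError`.
#[derive(Debug, PartialEq)]
enum SketchError {
    UnknownParameter(String),
    SizeMismatch { expected: usize, got: usize },
}

/// Add `grad` element-wise into the buffer registered under `name`.
fn accumulate(
    buffers: &mut HashMap<String, Vec<f64>>,
    name: &str,
    grad: &[f64],
) -> Result<(), SketchError> {
    let buf = buffers
        .get_mut(name)
        .ok_or_else(|| SketchError::UnknownParameter(name.to_string()))?;
    // Validate before mutating so a bad call has no side effects.
    if buf.len() != grad.len() {
        return Err(SketchError::SizeMismatch { expected: buf.len(), got: grad.len() });
    }
    for (b, g) in buf.iter_mut().zip(grad) {
        *b += g;
    }
    Ok(())
}

fn main() {
    // One registered parameter with shape [2, 2], stored flat (4 elements).
    let mut buffers = HashMap::from([("w".to_string(), vec![0.0; 4])]);
    accumulate(&mut buffers, "w", &[1.0, 1.0, 1.0, 1.0]).unwrap();
    accumulate(&mut buffers, "w", &[0.5, 0.5, 0.5, 0.5]).unwrap();
    println!("{:?}", buffers["w"]);
    println!("{:?}", accumulate(&mut buffers, "b", &[1.0]));
    println!("{:?}", accumulate(&mut buffers, "w", &[1.0, 2.0]));
}
```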
pub fn should_update(&self) -> bool
Check if enough micro-batches have been accumulated to trigger an update.
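One plausible reading of "enough micro-batches" is a counter compared against a configured number of accumulation steps. `UpdateSchedule`, `accumulation_steps`, and `micro_batches` below are hypothetical names. This page does not list `AccumulationConfig`'s fields, and it does not say whether the counter resets after an update.

```rust
/// Hypothetical counter state, not this type's actual fields.
struct UpdateSchedule {
    accumulation_steps: usize, // micro-batches per optimizer step (assumed config value)
    micro_batches: usize,      // micro-batches recorded since the last reset
}

impl UpdateSchedule {
    fn record_micro_batch(&mut self) {
        self.micro_batches += 1;
    }

    /// True once at least `accumulation_steps` micro-batches are pending.
    /// `>=` (rather than `==`) keeps signalling if an extra micro-batch
    /// is recorded before the caller acts.
    fn should_update(&self) -> bool {
        self.micro_batches >= self.accumulation_steps
    }

    /// Assumed reset once the optimizer has consumed the gradients.
    fn reset(&mut self) {
        self.micro_batches = 0;
    }
}

fn main() {
    let mut sched = UpdateSchedule { accumulation_steps: 4, micro_batches: 0 };
    for i in 1..=8 {
        sched.record_micro_batch();
        if sched.should_update() {
            println!("optimizer step after micro-batch {i}");
            sched.reset();
        }
    }
}
```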
pub fn get_gradients(&self) -> Result<HashMap<String, Vec<f64>>, AccumulationError>
Get all accumulated gradients, optionally normalized and clipped according to the AccumulationConfig passed to new.
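This page does not define "normalized" or "clipped". The sketch assumes a common pair of choices. First, divide each running sum by the number of accumulated micro-batches, turning the sum into a mean. Then rescale all gradients together so that their global L2 norm does not exceed a threshold. `normalize_and_clip`, `num_micro_batches`, and `max_norm` are hypothetical names. The real type may instead clip per parameter or by value.

```rust
use std::collections::HashMap;

/// Average the accumulated sums, then apply global L2-norm clipping.
/// Returns a new map; the input sums are not modified.
fn normalize_and_clip(
    sums: &HashMap<String, Vec<f64>>,
    num_micro_batches: usize,
    max_norm: Option<f64>,
) -> HashMap<String, Vec<f64>> {
    let n = num_micro_batches.max(1) as f64; // avoid dividing by zero
    let mut out: HashMap<String, Vec<f64>> = sums
        .iter()
        .map(|(k, v)| (k.clone(), v.iter().map(|x| x / n).collect::<Vec<f64>>()))
        .collect();

    if let Some(max) = max_norm {
        // Norm over all parameters together, not per parameter.
        let total = out.values().flatten().map(|x| x * x).sum::<f64>().sqrt();
        if total > max {
            let scale = max / total;
            for v in out.values_mut() {
                v.iter_mut().for_each(|x| *x *= scale);
            }
        }
    }
    out
}

fn main() {
    // Sums after 2 micro-batches: means are [3, 0] and [0, 4], global norm 5.
    let sums = HashMap::from([
        ("a".to_string(), vec![6.0, 0.0]),
        ("b".to_string(), vec![0.0, 8.0]),
    ]);
    println!("{:?}", normalize_and_clip(&sums, 2, Some(1.0)));
}
```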
pub fn step(&mut self, gradients: &HashMap<String, Vec<f64>>) -> Result<bool, AccumulationError>
Accumulate a full micro-batch of gradients, returning true if an update should now be applied.
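A typical driver loop computes one gradient map per micro-batch and feeds it to `step`. It runs the optimizer only when `step` reports `true`. The sketch models that control flow with a hypothetical `MiniAccumulator` and a plain SGD update. This page lists no reset method, so it is unclear whether the real accumulator clears its state after `get_gradients`. The sketch simply assumes its hypothetical `take_mean` clears the buffers. Error handling is omitted.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in mirroring the shape of `step`'s contract.
struct MiniAccumulator {
    steps: usize,
    seen: usize,
    sums: HashMap<String, Vec<f64>>,
}

impl MiniAccumulator {
    /// Add one micro-batch of gradients; return true when an update is due.
    fn step(&mut self, grads: &HashMap<String, Vec<f64>>) -> bool {
        for (name, g) in grads {
            let sum = self.sums.entry(name.clone()).or_insert_with(|| vec![0.0; g.len()]);
            for (s, x) in sum.iter_mut().zip(g) {
                *s += x;
            }
        }
        self.seen += 1;
        self.seen >= self.steps
    }

    /// Return mean gradients and clear all state (assumed behavior).
    fn take_mean(&mut self) -> HashMap<String, Vec<f64>> {
        let n = self.seen as f64;
        self.seen = 0;
        self.sums
            .drain()
            .map(|(k, v)| (k, v.into_iter().map(|x| x / n).collect::<Vec<f64>>()))
            .collect()
    }
}

/// Run SGD over `micro_batches`, returning the number of optimizer steps.
fn train(
    weights: &mut HashMap<String, Vec<f64>>,
    micro_batches: &[HashMap<String, Vec<f64>>],
    steps: usize,
    lr: f64,
) -> usize {
    let mut acc = MiniAccumulator { steps, seen: 0, sums: HashMap::new() };
    let mut updates = 0;
    for grads in micro_batches {
        if acc.step(grads) {
            // Plain SGD: w <- w - lr * mean_grad.
            for (name, g) in acc.take_mean() {
                if let Some(w) = weights.get_mut(&name) {
                    for (wi, gi) in w.iter_mut().zip(g) {
                        *wi -= lr * gi;
                    }
                }
            }
            updates += 1;
        }
    }
    updates
}

fn main() {
    let mut weights = HashMap::from([("w".to_string(), vec![1.0, 1.0])]);
    // Four made-up micro-batch gradients; steps = 2 gives two updates.
    let batches: Vec<HashMap<String, Vec<f64>>> = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0], [3.0, 3.0]]
        .iter()
        .map(|g| HashMap::from([("w".to_string(), g.to_vec())]))
        .collect();
    let updates = train(&mut weights, &batches, 2, 0.5);
    println!("{updates} updates, w = {:?}", weights["w"]);
}
```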
pub fn stats(&self) -> AccumulationStats
Get statistics about the accumulation state.
Auto Trait Implementations
impl Freeze for GradientAccumulator
impl RefUnwindSafe for GradientAccumulator
impl Send for GradientAccumulator
impl Sync for GradientAccumulator
impl Unpin for GradientAccumulator
impl UnsafeUnpin for GradientAccumulator
impl UnwindSafe for GradientAccumulator
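Because GradientAccumulator is Send and Sync, it can be moved into another thread or placed behind `Arc<Mutex<_>>`. Since `accumulate` and `step` take `&mut self`, producers on several threads still need a lock. The sketch below uses a hypothetical `SharedSums` type built only from `Send + Sync` standard-library types, because this type's private fields are not shown. `assert_send_sync` is a common compile-time idiom, not part of this API.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in built only from Send + Sync std types.
#[derive(Default)]
struct SharedSums {
    sums: HashMap<String, Vec<f64>>,
    micro_batches: usize, // one `add` call per micro-batch in this sketch
}

impl SharedSums {
    /// Takes `&mut self`, like `accumulate`, so shared use needs a lock.
    fn add(&mut self, name: &str, grad: &[f64]) {
        let sum = self.sums.entry(name.to_string()).or_insert_with(|| vec![0.0; grad.len()]);
        for (s, g) in sum.iter_mut().zip(grad) {
            *s += g;
        }
        self.micro_batches += 1;
    }
}

/// Compiles only if `T` is both Send and Sync.
fn assert_send_sync<T: Send + Sync>() {}

fn run_workers(workers: usize) -> SharedSums {
    assert_send_sync::<SharedSums>();
    let shared = Arc::new(Mutex::new(SharedSums::default()));
    let handles: Vec<_> = (0..workers)
        .map(|i| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                // Each worker contributes one made-up micro-batch gradient.
                let grad = [i as f64, 1.0];
                shared.lock().unwrap().add("w", &grad);
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    // Every clone was dropped when its thread finished, so this unwrap succeeds.
    Arc::try_unwrap(shared).ok().expect("other Arc clones still alive").into_inner().unwrap()
}

fn main() {
    let result = run_workers(4);
    println!("{} micro-batches, sum = {:?}", result.micro_batches, result.sums["w"]);
}
```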
Blanket Implementations
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise.