pub struct SamOptimizer<O: Optimizer> { /* private fields */ }
SAM optimizer (Sharpness Aware Minimization).
SAM seeks parameters that lie in neighborhoods having uniformly low loss, improving model generalization. It requires two forward-backward passes per step: one to compute the adversarial perturbation, and one to compute the actual gradient.
Reference: Foret et al. “Sharpness-Aware Minimization for Efficiently Improving Generalization” (ICLR 2021)
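For orientation, the objective and the first-order perturbation from the referenced paper can be written as follows. This is a summary of the paper's L2-norm formulation with the weight-decay term omitted, not a statement about how this crate computes its norms (for example, jointly over all tensors or per tensor), which the page does not specify.

```latex
\min_{w} \; \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon),
\qquad
\hat{\epsilon}(w) = \rho \, \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2},
\qquad
g_{\text{SAM}} = \nabla_w L(w) \big|_{w + \hat{\epsilon}(w)}
```

Here $\rho$ is the perturbation radius passed to the constructor, and $g_{\text{SAM}}$ is the gradient handed to the base optimizer.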
Note: This is a wrapper optimizer. SAM requires special handling in the training loop to perform two gradient computations per step. The typical usage is:
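- Compute gradients at the current parameters
- Call first_step with those gradients to apply the adversarial perturbation
- Compute gradients at the perturbed parameters
- Call second_step with the new gradients to remove the perturbation and update through the base optimizer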
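The four steps above can be sketched end to end on a toy problem. This is a minimal, standard-library-only illustration, not this crate's implementation: it uses `Vec<f64>` in place of `Array<f64, Ix2>`, a hypothetical loss L(w) = 0.5 * sum(w_i^2) whose gradient is w itself, plain SGD as the base update, and a gradient norm taken jointly over all tensors. The names `Params`, `toy_grad`, and `sam_sgd_step` are assumptions made for the example.

```rust
use std::collections::HashMap;

// Assumption: flat Vec<f64> tensors stand in for Array<f64, Ix2>.
type Params = HashMap<String, Vec<f64>>;

/// Gradient of the toy loss L(w) = 0.5 * sum(w_i^2), which is w itself.
fn toy_grad(params: &Params) -> Params {
    params.clone()
}

/// L2 norm taken jointly over every tensor in the map.
fn global_norm(grads: &Params) -> f64 {
    grads
        .values()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt()
}

/// One SAM step with plain SGD as the base update (illustrative only).
fn sam_sgd_step(params: &mut Params, rho: f64, lr: f64) {
    // 1. Gradients at the current parameters.
    let g1 = toy_grad(params);

    // 2. Apply the perturbation eps = rho * g / ||g|| and remember it.
    let norm = global_norm(&g1) + 1e-12; // small constant avoids division by zero
    let mut eps: Params = HashMap::new();
    for (name, g) in &g1 {
        let e: Vec<f64> = g.iter().map(|x| rho * x / norm).collect();
        let w = params.get_mut(name).expect("parameter missing");
        for (wi, ei) in w.iter_mut().zip(&e) {
            *wi += ei;
        }
        eps.insert(name.clone(), e);
    }

    // 3. Gradients at the perturbed parameters.
    let g2 = toy_grad(params);

    // 4. Undo the perturbation, then apply the SGD update using g2.
    for (name, w) in params.iter_mut() {
        let e = &eps[name];
        let g = &g2[name];
        for ((wi, ei), gi) in w.iter_mut().zip(e).zip(g) {
            *wi = *wi - ei - lr * gi;
        }
    }
}
```

With the real optimizer, steps 2 and 4 would be calls to `first_step` and `second_step`, and steps 1 and 3 would be two forward-backward passes of the model.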
Implementations
impl<O: Optimizer> SamOptimizer<O>
pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self>
Create a new SAM optimizer.
Arguments
- base_optimizer - The base optimizer to use (SGD, Adam, etc.)
- rho - Perturbation radius (typically 0.05)
pub fn first_step(
    &mut self,
    parameters: &mut HashMap<String, Array<f64, Ix2>>,
    gradients: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<()>
Compute adversarial perturbations.
This should be called with the first set of gradients to compute the perturbation direction.
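As a rough illustration of what computing the perturbation involves, the sketch below scales the gradients so that the combined perturbation has L2 norm `rho`, following the formula in Foret et al. It is an assumption-laden stand-in rather than the body of `first_step`: `Tensors` and `sam_perturbation` are hypothetical names, tensors are flat `Vec<f64>` values, the norm is taken over all tensors jointly, and a zero gradient is reported as `None` instead of whatever this crate does in that case.

```rust
use std::collections::HashMap;

// Assumption: flat Vec<f64> tensors stand in for Array<f64, Ix2>.
type Tensors = HashMap<String, Vec<f64>>;

/// Returns eps = rho * g / ||g||_2, with the norm taken over all tensors jointly.
/// Returns None when the gradient norm is zero, since no ascent direction exists.
fn sam_perturbation(grads: &Tensors, rho: f64) -> Option<Tensors> {
    let norm = grads
        .values()
        .flat_map(|g| g.iter())
        .map(|x| x * x)
        .sum::<f64>()
        .sqrt();
    if norm == 0.0 {
        return None;
    }
    let scale = rho / norm;
    Some(
        grads
            .iter()
            .map(|(name, g)| {
                let e = g.iter().map(|x| x * scale).collect::<Vec<f64>>();
                (name.clone(), e)
            })
            .collect(),
    )
}
```

Adding the returned values to the parameters moves them to the point where the second gradient is evaluated.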
pub fn second_step(
    &mut self,
    parameters: &mut HashMap<String, Array<f64, Ix2>>,
    gradients: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<()>
Perform the actual optimization step.
This should be called with the second set of gradients (computed at the perturbed parameters). It will remove the perturbations and update the parameters using the base optimizer.
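The restore-then-delegate behavior described here might look roughly like the following. Everything in this sketch is hypothetical: `BaseStep` is a stand-in for the crate's `Optimizer` trait (whose actual methods are not shown on this page), `Sgd` and `SamSketch` exist only to make the example concrete, and tensors are flat `Vec<f64>` values.

```rust
use std::collections::HashMap;

// Assumption: flat Vec<f64> tensors stand in for Array<f64, Ix2>.
type Tensors = HashMap<String, Vec<f64>>;

/// Hypothetical stand-in for the crate's `Optimizer` trait.
trait BaseStep {
    fn step(&mut self, params: &mut Tensors, grads: &Tensors);
}

/// Plain SGD, used only to make the sketch concrete.
struct Sgd {
    lr: f64,
}

impl BaseStep for Sgd {
    fn step(&mut self, params: &mut Tensors, grads: &Tensors) {
        for (name, w) in params.iter_mut() {
            if let Some(g) = grads.get(name) {
                for (wi, gi) in w.iter_mut().zip(g) {
                    *wi -= self.lr * gi;
                }
            }
        }
    }
}

/// Wrapper that remembers the perturbation applied during the first step.
struct SamSketch<B: BaseStep> {
    base: B,
    saved_eps: Tensors,
}

impl<B: BaseStep> SamSketch<B> {
    fn second_step(&mut self, params: &mut Tensors, perturbed_grads: &Tensors) {
        // Undo w + eps so the update starts from the original parameters.
        for (name, eps) in self.saved_eps.drain() {
            if let Some(w) = params.get_mut(&name) {
                for (wi, ei) in w.iter_mut().zip(&eps) {
                    *wi -= ei;
                }
            }
        }
        // Delegate the actual update to the base optimizer.
        self.base.step(params, perturbed_grads);
    }
}
```

Restoring before delegating matters because the base optimizer should move the original parameters; skipping the restore would leave a residual offset of up to `rho` in every step.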
Trait Implementations
Auto Trait Implementations
impl<O> Freeze for SamOptimizer<O> where
    O: Freeze,
impl<O> RefUnwindSafe for SamOptimizer<O> where
    O: RefUnwindSafe,
impl<O> Send for SamOptimizer<O> where
    O: Send,
impl<O> Sync for SamOptimizer<O> where
    O: Sync,
impl<O> Unpin for SamOptimizer<O> where
    O: Unpin,
impl<O> UnwindSafe for SamOptimizer<O> where
    O: UnwindSafe,
Blanket Implementations
impl<T> BorrowMut<T> for T where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.