pub struct Optimizer { /* private fields */ }
Expand description
An optimizer to run gradient descent.
Implementations§
source§impl Optimizer
impl Optimizer
sourcepub fn clip_grad_value(&self, max: f64)
pub fn clip_grad_value(&self, max: f64)
Clips each gradient value at the specified maximum value `max`.
sourcepub fn clip_grad_norm(&self, max: f64)
pub fn clip_grad_norm(&self, max: f64)
Clips the gradient L2 norm over all trainable parameters.
The norm is computed over all gradients together, as if they were concatenated into a single vector.
sourcepub fn step(&mut self)
pub fn step(&mut self)
Performs an optimization step, updating the tracked tensors based on their gradients.
sourcepub fn backward_step(&mut self, loss: &Tensor)
pub fn backward_step(&mut self, loss: &Tensor)
Applies a backward pass to `loss`, updates the gradients, and performs an optimization step.
sourcepub fn backward_step_clip(&mut self, loss: &Tensor, max: f64)
pub fn backward_step_clip(&mut self, loss: &Tensor, max: f64)
Applies a backward pass to `loss`, updates the gradients, and performs an optimization step.
The gradients are clipped based on `max` before being applied.
sourcepub fn backward_step_clip_norm(&mut self, loss: &Tensor, max: f64)
pub fn backward_step_clip_norm(&mut self, loss: &Tensor, max: f64)
Applies a backward pass to `loss`, updates the gradients, and performs an optimization step.
The L2 norm of the gradients is clipped based on `max`.
sourcepub fn set_momentum(&mut self, m: f64)
pub fn set_momentum(&mut self, m: f64)
Sets the optimizer momentum.
sourcepub fn set_lr_group(&mut self, group: usize, lr: f64)
pub fn set_lr_group(&mut self, group: usize, lr: f64)
Sets the optimizer learning rate for a parameter group.
sourcepub fn set_momentum_group(&mut self, group: usize, m: f64)
pub fn set_momentum_group(&mut self, group: usize, m: f64)
Sets the optimizer momentum for a parameter group.
sourcepub fn trainable_variables(&self) -> Vec<Tensor>
pub fn trainable_variables(&self) -> Vec<Tensor>
Returns all the trainable variables for this optimizer.
sourcepub fn set_weight_decay(&mut self, weight_decay: f64)
pub fn set_weight_decay(&mut self, weight_decay: f64)
Sets the optimizer weight decay.
sourcepub fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64)
pub fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64)
Sets the optimizer weight decay for a parameter group.