tch_plus/nn/optimizer.rs

//! Optimizers to be used for gradient-descent based training.
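//!
//! A minimal usage sketch (it assumes this crate re-exports `nn::VarStore`,
//! `nn::Sgd`, `nn::Init`, and `Device` the way `tch` does):
//!
//! ```no_run
//! use tch_plus::nn::{self, OptimizerConfig};
//! use tch_plus::{Device, TchError};
//!
//! fn main() -> Result<(), TchError> {
//!     let vs = nn::VarStore::new(Device::Cpu);
//!     // A single trainable variable standing in for a model's parameters.
//!     let w = vs.root().var("w", &[3], nn::Init::Const(1.0));
//!     let mut opt = nn::Sgd::default().build(&vs, 1e-2)?;
//!     // One training step: zero the gradients, backpropagate, update `w`.
//!     let loss = w.dot(&w);
//!     opt.backward_step(&loss);
//!     Ok(())
//! }
//! ```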
use super::var_store::{VarStore, Variables};
use crate::wrappers::optimizer::COptimizer;
use crate::{TchError, Tensor};
use std::sync::{Arc, Mutex};

/// An optimizer to run gradient descent.
#[derive(Debug)]
pub struct Optimizer {
    /// The underlying C optimizer.
    opt: COptimizer,
    /// Variables shared with the `VarStore` this optimizer was built from.
    variables: Arc<Mutex<Variables>>,
    /// The number of trainable variables already registered with `opt`.
    variables_in_optimizer: usize,
}

/// Optimizer configurations. These configs can be used to build an optimizer.
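///
/// Custom configurations can also implement this trait by delegating to one of
/// the `COptimizer` constructors. A hedged sketch, written as if from inside
/// this crate (`PlainSgd` is illustrative and not part of the API):
///
/// ```ignore
/// use crate::wrappers::optimizer::COptimizer;
/// use crate::TchError;
///
/// #[derive(Debug, Copy, Clone)]
/// struct PlainSgd;
///
/// impl OptimizerConfig for PlainSgd {
///     fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
///         // Vanilla SGD: no momentum, no dampening, no weight decay, no Nesterov.
///         COptimizer::sgd(lr, 0., 0., 0., false)
///     }
/// }
/// ```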
pub trait OptimizerConfig
where
    Self: std::marker::Sized,
{
    /// Builds the underlying C optimizer for the given learning rate.
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError>;

    /// Builds an optimizer with the specified learning rate handling variables stored in `vs`.
    fn build(self, vs: &VarStore, lr: f64) -> Result<Optimizer, TchError> {
        let mut opt = self.build_copt(lr)?;
        let v = vs.variables_.lock().unwrap();
        for var in &v.trainable_variables {
            opt.add_parameters(&var.tensor, var.group)?;
        }
        Ok(Optimizer {
            opt,
            variables: vs.variables_.clone(),
            variables_in_optimizer: v.trainable_variables.len(),
        })
    }
}

/// Parameters for the SGD optimizer.
#[derive(Debug, Copy, Clone)]
pub struct Sgd {
    pub momentum: f64,
    pub dampening: f64,
    pub wd: f64,
    pub nesterov: bool,
}

impl Default for Sgd {
    fn default() -> Self {
        Sgd { momentum: 0., dampening: 0., wd: 0., nesterov: false }
    }
}

/// Creates the configuration for a Stochastic Gradient Descent (SGD) optimizer.
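///
/// A sketch of building one (assuming re-exports that mirror `tch`):
///
/// ```no_run
/// # use tch_plus::nn::{self, OptimizerConfig};
/// # use tch_plus::Device;
/// # fn main() -> Result<(), tch_plus::TchError> {
/// # let vs = nn::VarStore::new(Device::Cpu);
/// // Momentum 0.9, no dampening, no weight decay, Nesterov enabled.
/// let mut opt = nn::sgd(0.9, 0., 0., true).build(&vs, 1e-2)?;
/// # Ok(())
/// # }
/// ```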
pub fn sgd(momentum: f64, dampening: f64, wd: f64, nesterov: bool) -> Sgd {
    Sgd { momentum, dampening, wd, nesterov }
}

impl OptimizerConfig for Sgd {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::sgd(lr, self.momentum, self.dampening, self.wd, self.nesterov)
    }
}

/// Parameters for the Adam optimizer.
#[derive(Debug, Copy, Clone)]
pub struct Adam {
    pub beta1: f64,
    pub beta2: f64,
    pub wd: f64,
    pub eps: f64,
    pub amsgrad: bool,
}

impl Default for Adam {
    fn default() -> Self {
        Adam { beta1: 0.9, beta2: 0.999, wd: 0., eps: 1e-8, amsgrad: false }
    }
}

/// Creates the configuration for the Adam optimizer.
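///
/// The setter methods on [`Adam`] can be chained to adjust individual fields.
/// A sketch (assuming re-exports that mirror `tch`):
///
/// ```no_run
/// # use tch_plus::nn::{self, OptimizerConfig};
/// # use tch_plus::Device;
/// # fn main() -> Result<(), tch_plus::TchError> {
/// # let vs = nn::VarStore::new(Device::Cpu);
/// let mut opt = nn::adam(0.9, 0.999, 0.).eps(1e-6).amsgrad(true).build(&vs, 1e-3)?;
/// # Ok(())
/// # }
/// ```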
pub fn adam(beta1: f64, beta2: f64, wd: f64) -> Adam {
    Adam { beta1, beta2, wd, eps: 1e-8, amsgrad: false }
}

impl Adam {
    pub fn beta1(mut self, b: f64) -> Self {
        self.beta1 = b;
        self
    }

    pub fn beta2(mut self, b: f64) -> Self {
        self.beta2 = b;
        self
    }

    pub fn wd(mut self, w: f64) -> Self {
        self.wd = w;
        self
    }

    pub fn eps(mut self, e: f64) -> Self {
        self.eps = e;
        self
    }

    pub fn amsgrad(mut self, a: bool) -> Self {
        self.amsgrad = a;
        self
    }
}

impl OptimizerConfig for Adam {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::adam(lr, self.beta1, self.beta2, self.wd, self.eps, self.amsgrad)
    }
}

/// Parameters for the AdamW optimizer, which applies weight decay decoupled
/// from the gradient update.
#[derive(Debug, Copy, Clone)]
pub struct AdamW {
    pub beta1: f64,
    pub beta2: f64,
    pub wd: f64,
    pub eps: f64,
    pub amsgrad: bool,
}

impl Default for AdamW {
    fn default() -> Self {
        AdamW { beta1: 0.9, beta2: 0.999, wd: 0.01, eps: 1e-8, amsgrad: false }
    }
}

/// Creates the configuration for the AdamW optimizer.
pub fn adamw(beta1: f64, beta2: f64, wd: f64) -> AdamW {
    AdamW { beta1, beta2, wd, eps: 1e-8, amsgrad: false }
}

impl AdamW {
    pub fn beta1(mut self, b: f64) -> Self {
        self.beta1 = b;
        self
    }

    pub fn beta2(mut self, b: f64) -> Self {
        self.beta2 = b;
        self
    }

    pub fn wd(mut self, w: f64) -> Self {
        self.wd = w;
        self
    }

    pub fn eps(mut self, e: f64) -> Self {
        self.eps = e;
        self
    }

    pub fn amsgrad(mut self, a: bool) -> Self {
        self.amsgrad = a;
        self
    }
}

impl OptimizerConfig for AdamW {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::adamw(lr, self.beta1, self.beta2, self.wd, self.eps, self.amsgrad)
    }
}

/// Parameters for the RmsProp optimizer.
#[derive(Debug, Copy, Clone)]
pub struct RmsProp {
    pub alpha: f64,
    pub eps: f64,
    pub wd: f64,
    pub momentum: f64,
    pub centered: bool,
}

impl Default for RmsProp {
    fn default() -> Self {
        RmsProp { alpha: 0.99, eps: 1e-8, wd: 0., momentum: 0., centered: false }
    }
}

/// Creates the configuration for the RmsProp optimizer.
pub fn rms_prop(alpha: f64, eps: f64, wd: f64, momentum: f64, centered: bool) -> RmsProp {
    RmsProp { alpha, eps, wd, momentum, centered }
}

impl OptimizerConfig for RmsProp {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::rms_prop(lr, self.alpha, self.eps, self.wd, self.momentum, self.centered)
    }
}

impl Optimizer {
    /// Registers with the C optimizer any trainable variables that were added
    /// to the var store after this optimizer was created.
    fn add_missing_variables(&mut self) {
        let v = self.variables.lock().unwrap();
        if v.trainable_variables.len() > self.variables_in_optimizer {
            for var in &v.trainable_variables[self.variables_in_optimizer..] {
                self.opt.add_parameters(&var.tensor, var.group).unwrap();
            }
            self.variables_in_optimizer = v.trainable_variables.len();
        }
    }

    /// Zeroes the gradient for the tensors tracked by this optimizer.
    pub fn zero_grad(&mut self) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap()
    }

    /// Clips gradient value at some specified maximum value.
    pub fn clip_grad_value(&self, max: f64) {
        let v = self.variables.lock().unwrap();
        for var in v.trainable_variables.iter() {
            let mut grad = var.tensor.grad();
            if grad.defined() {
                let _t = grad.clamp_(-max, max);
            }
        }
    }

    /// Clips gradient L2 norm over all trainable parameters.
    ///
    /// The norm is computed over all gradients together, as if they were
    /// concatenated into a single vector.
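    ///
    /// More precisely, `total_norm` is the square root of the sum of the
    /// squared L2 norms of all defined gradients; whenever
    /// `max / (total_norm + 1e-6)` is below `1.0`, every gradient is scaled
    /// in place by that factor.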
    pub fn clip_grad_norm(&self, max: f64) {
        crate::no_grad(|| {
            let v = self.variables.lock().unwrap();
            let mut norms = vec![];
            for var in v.trainable_variables.iter() {
                let grad = var.tensor.grad();
                if grad.defined() {
                    norms.push(grad.norm());
                }
            }
            let total_norm = f64::try_from(Tensor::stack(&norms, 0).norm()).unwrap();
            let clip_coef = max / (total_norm + 1e-6);
            if clip_coef < 1.0 {
                for var in v.trainable_variables.iter() {
                    let mut grad = var.tensor.grad();
                    if grad.defined() {
                        let _t = grad.g_mul_scalar_(clip_coef);
                    }
                }
            }
        })
    }

    /// Performs an optimization step, updating the tracked tensors based on their gradients.
    pub fn step(&mut self) {
        self.add_missing_variables();
        self.opt.step().unwrap()
    }

    /// Applies a backward pass, updates the gradients, and performs an optimization step.
    ///
    /// This is equivalent to calling `zero_grad`, `loss.backward()`, and `step` in sequence.
    pub fn backward_step(&mut self, loss: &Tensor) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.opt.step().unwrap()
    }

    /// Applies a backward pass, updates the gradients, and performs an optimization step.
    ///
    /// The gradient values are clipped to `[-max, max]` before being applied.
    pub fn backward_step_clip(&mut self, loss: &Tensor, max: f64) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.clip_grad_value(max);
        self.opt.step().unwrap()
    }

    /// Applies a backward pass, updates the gradients, and performs an optimization step.
    ///
    /// The gradients' L2 norm is clipped based on `max` before the step is applied.
    pub fn backward_step_clip_norm(&mut self, loss: &Tensor, max: f64) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.clip_grad_norm(max);
        self.opt.step().unwrap()
    }

    /// Sets the optimizer learning rate.
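    ///
    /// A sketch of a simple exponential decay schedule driven from the
    /// training loop (assuming re-exports that mirror `tch`):
    ///
    /// ```no_run
    /// # use tch_plus::nn::{self, OptimizerConfig};
    /// # use tch_plus::Device;
    /// # fn main() -> Result<(), tch_plus::TchError> {
    /// # let vs = nn::VarStore::new(Device::Cpu);
    /// let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    /// for epoch in 0..30 {
    ///     opt.set_lr(1e-3 * 0.9f64.powi(epoch));
    ///     // ... run the training steps for this epoch ...
    /// }
    /// # Ok(())
    /// # }
    /// ```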
    pub fn set_lr(&mut self, lr: f64) {
        self.opt.set_learning_rate(lr).unwrap()
    }

    /// Sets the optimizer momentum.
    pub fn set_momentum(&mut self, m: f64) {
        self.opt.set_momentum(m).unwrap()
    }

    /// Sets the optimizer learning rate for a parameter group.
    pub fn set_lr_group(&mut self, group: usize, lr: f64) {
        self.opt.set_learning_rate_group(group, lr).unwrap()
    }

    /// Sets the optimizer momentum for a parameter group.
    pub fn set_momentum_group(&mut self, group: usize, m: f64) {
        self.opt.set_momentum_group(group, m).unwrap()
    }

    /// Returns all the trainable variables for this optimizer.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        let variables = self.variables.lock().unwrap();
        variables.trainable_variables.iter().map(|v| v.tensor.shallow_clone()).collect()
    }

    /// Sets the optimizer weight decay.
    pub fn set_weight_decay(&mut self, weight_decay: f64) {
        self.opt.set_weight_decay(weight_decay).unwrap()
    }

    /// Sets the optimizer weight decay for a parameter group.
    pub fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) {
        self.opt.set_weight_decay_group(group, weight_decay).unwrap()
    }
}