//! Optimizers to be used for gradient-descent based training.
use super::var_store::{VarStore, Variables};
use crate::wrappers::optimizer::COptimizer;
use crate::{TchError, Tensor};
use std::sync::{Arc, Mutex};

/// An optimizer to run gradient descent.
#[derive(Debug)]
pub struct Optimizer {
    opt: COptimizer,
    variables: Arc<Mutex<Variables>>,
    variables_in_optimizer: usize,
}

/// Optimizer configurations. These configs can be used to build an optimizer.
pub trait OptimizerConfig
where
    Self: std::marker::Sized,
{
    /// Builds the underlying C++ optimizer with the given learning rate.
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError>;

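    /// Builds an optimizer with the specified learning rate, handling all the
    /// trainable variables currently registered in `vs`.
    ///
    /// A minimal usage sketch (the variable store is empty here; in practice it
    /// would already hold the model parameters):
    ///
    /// ```no_run
    /// use tch::nn::OptimizerConfig;
    ///
    /// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
    /// // Build an SGD optimizer with a 0.01 learning rate.
    /// let _opt = tch::nn::Sgd::default().build(&vs, 1e-2).unwrap();
    /// ```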
    fn build(self, vs: &VarStore, lr: f64) -> Result<Optimizer, TchError> {
        let mut opt = self.build_copt(lr)?;
        let v = vs.variables_.lock().unwrap();
        for var in &v.trainable_variables {
            opt.add_parameters(&var.tensor, var.group)?;
        }
        Ok(Optimizer {
            opt,
            variables: vs.variables_.clone(),
            variables_in_optimizer: v.trainable_variables.len(),
        })
    }
}

/// Parameters for the SGD optimizer.
#[derive(Debug, Copy, Clone)]
pub struct Sgd {
    pub momentum: f64,
    pub dampening: f64,
    pub wd: f64,
    pub nesterov: bool,
}

impl Default for Sgd {
    fn default() -> Self {
        Sgd { momentum: 0., dampening: 0., wd: 0., nesterov: false }
    }
}

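/// Creates an SGD optimizer configuration with the given momentum, dampening,
/// weight decay, and Nesterov flag.
///
/// A short sketch (the variable store is empty here; in practice it would hold
/// the model parameters):
///
/// ```no_run
/// use tch::nn::OptimizerConfig;
///
/// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
/// // Nesterov momentum SGD with a small weight decay and a 0.1 learning rate.
/// let _opt = tch::nn::sgd(0.9, 0., 5e-4, true).build(&vs, 0.1).unwrap();
/// ```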
pub fn sgd(momentum: f64, dampening: f64, wd: f64, nesterov: bool) -> Sgd {
    Sgd { momentum, dampening, wd, nesterov }
}

impl OptimizerConfig for Sgd {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::sgd(lr, self.momentum, self.dampening, self.wd, self.nesterov)
    }
}

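/// Parameters for the Adam optimizer.
///
/// The builder methods below can be chained to tweak individual fields, e.g.
/// (a sketch; the variable store would normally hold the model parameters):
///
/// ```no_run
/// use tch::nn::OptimizerConfig;
///
/// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
/// // Default Adam, overriding only the weight decay and the AMSGrad flag.
/// let _opt = tch::nn::Adam::default().wd(1e-4).amsgrad(true).build(&vs, 1e-3).unwrap();
/// ```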
#[derive(Debug, Copy, Clone)]
pub struct Adam {
    pub beta1: f64,
    pub beta2: f64,
    pub wd: f64,
    pub eps: f64,
    pub amsgrad: bool,
}

impl Default for Adam {
    fn default() -> Self {
        Adam { beta1: 0.9, beta2: 0.999, wd: 0., eps: 1e-8, amsgrad: false }
    }
}

/// Creates an Adam optimizer configuration with the given beta parameters and weight decay.
pub fn adam(beta1: f64, beta2: f64, wd: f64) -> Adam {
    Adam { beta1, beta2, wd, eps: 1e-8, amsgrad: false }
}

impl Adam {
    /// Sets the beta1 coefficient used for the running average of the gradient.
    pub fn beta1(mut self, b: f64) -> Self {
        self.beta1 = b;
        self
    }

    /// Sets the beta2 coefficient used for the running average of the squared gradient.
    pub fn beta2(mut self, b: f64) -> Self {
        self.beta2 = b;
        self
    }

    /// Sets the weight decay.
    pub fn wd(mut self, w: f64) -> Self {
        self.wd = w;
        self
    }

    /// Sets the epsilon term added to the denominator for numerical stability.
    pub fn eps(mut self, e: f64) -> Self {
        self.eps = e;
        self
    }

    /// Enables or disables the AMSGrad variant.
    pub fn amsgrad(mut self, a: bool) -> Self {
        self.amsgrad = a;
        self
    }
}

impl OptimizerConfig for Adam {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::adam(lr, self.beta1, self.beta2, self.wd, self.eps, self.amsgrad)
    }
}

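/// Parameters for the AdamW optimizer, i.e. Adam with decoupled weight decay.
///
/// A usage sketch (note the non-zero default weight decay, unlike `Adam`):
///
/// ```no_run
/// use tch::nn::OptimizerConfig;
///
/// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
/// let _opt = tch::nn::AdamW::default().build(&vs, 1e-3).unwrap();
/// ```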
#[derive(Debug, Copy, Clone)]
pub struct AdamW {
    pub beta1: f64,
    pub beta2: f64,
    pub wd: f64,
    pub eps: f64,
    pub amsgrad: bool,
}

impl Default for AdamW {
    fn default() -> Self {
        AdamW { beta1: 0.9, beta2: 0.999, wd: 0.01, eps: 1e-8, amsgrad: false }
    }
}

/// Creates an AdamW optimizer configuration with the given beta parameters and weight decay.
pub fn adamw(beta1: f64, beta2: f64, wd: f64) -> AdamW {
    AdamW { beta1, beta2, wd, eps: 1e-8, amsgrad: false }
}

impl AdamW {
    /// Sets the beta1 coefficient used for the running average of the gradient.
    pub fn beta1(mut self, b: f64) -> Self {
        self.beta1 = b;
        self
    }

    /// Sets the beta2 coefficient used for the running average of the squared gradient.
    pub fn beta2(mut self, b: f64) -> Self {
        self.beta2 = b;
        self
    }

    /// Sets the weight decay.
    pub fn wd(mut self, w: f64) -> Self {
        self.wd = w;
        self
    }

    /// Sets the epsilon term added to the denominator for numerical stability.
    pub fn eps(mut self, e: f64) -> Self {
        self.eps = e;
        self
    }

    /// Enables or disables the AMSGrad variant.
    pub fn amsgrad(mut self, a: bool) -> Self {
        self.amsgrad = a;
        self
    }
}

impl OptimizerConfig for AdamW {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::adamw(lr, self.beta1, self.beta2, self.wd, self.eps, self.amsgrad)
    }
}

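/// Parameters for the RMSprop optimizer.
///
/// A usage sketch with the default smoothing constant and epsilon:
///
/// ```no_run
/// use tch::nn::OptimizerConfig;
///
/// let vs = tch::nn::VarStore::new(tch::Device::Cpu);
/// let _opt = tch::nn::RmsProp::default().build(&vs, 1e-2).unwrap();
/// ```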
#[derive(Debug, Copy, Clone)]
pub struct RmsProp {
    pub alpha: f64,
    pub eps: f64,
    pub wd: f64,
    pub momentum: f64,
    pub centered: bool,
}

impl Default for RmsProp {
    fn default() -> Self {
        RmsProp { alpha: 0.99, eps: 1e-8, wd: 0., momentum: 0., centered: false }
    }
}

/// Creates an RMSprop optimizer configuration.
pub fn rms_prop(alpha: f64, eps: f64, wd: f64, momentum: f64, centered: bool) -> RmsProp {
    RmsProp { alpha, eps, wd, momentum, centered }
}

impl OptimizerConfig for RmsProp {
    fn build_copt(&self, lr: f64) -> Result<COptimizer, TchError> {
        COptimizer::rms_prop(lr, self.alpha, self.eps, self.wd, self.momentum, self.centered)
    }
}

impl Optimizer {
    /// Adds to the optimizer any trainable variable that was registered in the
    /// variable store after the optimizer was created.
    fn add_missing_variables(&mut self) {
        let v = self.variables.lock().unwrap();
        if v.trainable_variables.len() > self.variables_in_optimizer {
            for var in &v.trainable_variables[self.variables_in_optimizer..] {
                self.opt.add_parameters(&var.tensor, var.group).unwrap();
            }
            self.variables_in_optimizer = v.trainable_variables.len();
        }
    }

    /// Zeroes the gradients for the tensors tracked by this optimizer.
    pub fn zero_grad(&mut self) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap()
    }

    /// Clips gradient values at some maximum value.
    pub fn clip_grad_value(&self, max: f64) {
        let v = self.variables.lock().unwrap();
        for var in v.trainable_variables.iter() {
            let mut grad = var.tensor.grad();
            if grad.defined() {
                let _t = grad.clamp_(-max, max);
            }
        }
    }

    /// Clips gradients so that their global L2 norm does not exceed `max`. The
    /// norm is computed over all gradients together, as if they were
    /// concatenated into a single vector.
    pub fn clip_grad_norm(&self, max: f64) {
        crate::no_grad(|| {
            let v = self.variables.lock().unwrap();
            let mut norms = vec![];
            for var in v.trainable_variables.iter() {
                let grad = var.tensor.grad();
                if grad.defined() {
                    norms.push(grad.norm());
                }
            }
            let total_norm = f64::try_from(Tensor::stack(&norms, 0).norm()).unwrap();
            let clip_coef = max / (total_norm + 1e-6);
            if clip_coef < 1.0 {
                for var in v.trainable_variables.iter() {
                    let mut grad = var.tensor.grad();
                    if grad.defined() {
                        let _t = grad.g_mul_scalar_(clip_coef);
                    }
                }
            }
        })
    }

    /// Performs an optimization step, updating the tracked tensors based on
    /// their gradients.
    pub fn step(&mut self) {
        self.add_missing_variables();
        self.opt.step().unwrap()
    }

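    /// Applies a backward step pass, zeroing the gradients, computing new
    /// gradients via the backward pass on `loss`, and performing an
    /// optimization step.
    ///
    /// A minimal training-loop sketch (the tiny linear model and the random
    /// data below are just placeholders for a real model and dataset):
    ///
    /// ```no_run
    /// use tch::nn::{Module, OptimizerConfig};
    /// use tch::{nn, Device, Kind, Tensor};
    ///
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// let model = nn::linear(vs.root(), 4, 1, Default::default());
    /// let mut opt = nn::Adam::default().build(&vs, 1e-3).unwrap();
    /// let xs = Tensor::randn(&[64, 4], (Kind::Float, Device::Cpu));
    /// let ys = Tensor::randn(&[64, 1], (Kind::Float, Device::Cpu));
    /// for _epoch in 1..=10 {
    ///     // Zero the gradients, backpropagate the loss, and update the weights.
    ///     let loss = model.forward(&xs).mse_loss(&ys, tch::Reduction::Mean);
    ///     opt.backward_step(&loss);
    /// }
    /// ```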
    pub fn backward_step(&mut self, loss: &Tensor) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.opt.step().unwrap()
    }

    /// Applies a backward step pass, zeroing the gradients, computing new
    /// gradients via the backward pass on `loss`, clipping the gradient values
    /// at `max`, and performing an optimization step.
    pub fn backward_step_clip(&mut self, loss: &Tensor, max: f64) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.clip_grad_value(max);
        self.opt.step().unwrap()
    }

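    /// Applies a backward step pass, zeroing the gradients, computing new
    /// gradients via the backward pass on `loss`, clipping the gradients so
    /// that their global L2 norm does not exceed `max`, and performing an
    /// optimization step.
    ///
    /// A sketch of the call site (the hidden setup builds a placeholder model
    /// and optimizer; any scalar loss produced by the model works):
    ///
    /// ```no_run
    /// # use tch::nn::{Module, OptimizerConfig};
    /// # use tch::{nn, Device, Kind, Tensor};
    /// # let vs = nn::VarStore::new(Device::Cpu);
    /// # let model = nn::linear(vs.root(), 4, 1, Default::default());
    /// # let mut opt = nn::Adam::default().build(&vs, 1e-3).unwrap();
    /// # let xs = Tensor::randn(&[8, 4], (Kind::Float, Device::Cpu));
    /// # let ys = Tensor::randn(&[8, 1], (Kind::Float, Device::Cpu));
    /// let loss = model.forward(&xs).mse_loss(&ys, tch::Reduction::Mean);
    /// // Clip the global gradient norm at 1.0 before applying the update.
    /// opt.backward_step_clip_norm(&loss, 1.0);
    /// ```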
    pub fn backward_step_clip_norm(&mut self, loss: &Tensor, max: f64) {
        self.add_missing_variables();
        self.opt.zero_grad().unwrap();
        loss.backward();
        self.clip_grad_norm(max);
        self.opt.step().unwrap()
    }

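    /// Sets the optimizer learning rate.
    ///
    /// A sketch of a simple step-decay schedule (assuming `opt` was built with
    /// an initial learning rate of 0.1):
    ///
    /// ```no_run
    /// # use tch::nn::OptimizerConfig;
    /// # let vs = tch::nn::VarStore::new(tch::Device::Cpu);
    /// # let mut opt = tch::nn::Sgd::default().build(&vs, 0.1).unwrap();
    /// for epoch in 0..30 {
    ///     // Divide the learning rate by 10 every 10 epochs.
    ///     opt.set_lr(0.1 * 0.1f64.powi(epoch / 10));
    ///     // ... run the training steps for this epoch ...
    /// }
    /// ```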
    pub fn set_lr(&mut self, lr: f64) {
        self.opt.set_learning_rate(lr).unwrap()
    }

    /// Sets the optimizer momentum.
    pub fn set_momentum(&mut self, m: f64) {
        self.opt.set_momentum(m).unwrap()
    }

    /// Sets the learning rate for a parameter group.
    pub fn set_lr_group(&mut self, group: usize, lr: f64) {
        self.opt.set_learning_rate_group(group, lr).unwrap()
    }

    /// Sets the momentum for a parameter group.
    pub fn set_momentum_group(&mut self, group: usize, m: f64) {
        self.opt.set_momentum_group(group, m).unwrap()
    }

    /// Returns all the trainable variables handled by this optimizer.
    pub fn trainable_variables(&self) -> Vec<Tensor> {
        let variables = self.variables.lock().unwrap();
        variables.trainable_variables.iter().map(|v| v.tensor.shallow_clone()).collect()
    }

    /// Sets the optimizer weight decay.
    pub fn set_weight_decay(&mut self, weight_decay: f64) {
        self.opt.set_weight_decay(weight_decay).unwrap()
    }

    /// Sets the weight decay for a parameter group.
    pub fn set_weight_decay_group(&mut self, group: usize, weight_decay: f64) {
        self.opt.set_weight_decay_group(group, weight_decay).unwrap()
    }
}