// tch_plus/wrappers/optimizer.rs

1use super::tensor::Tensor;
2use crate::TchError;
3
/// Owning wrapper around a native optimizer handle obtained from `torch_sys_plus`.
///
/// Values are built by the `adam`, `adamw`, `rms_prop` and `sgd` constructors, and the
/// handle is released with `torch_sys_plus::ato_free` when the wrapper is dropped.
pub struct COptimizer {
    /// Raw pointer to the C-side optimizer. Private: the methods in this file only pass it
    /// to `torch_sys_plus::ato_*` functions and never hand it out.
    c_optimizer: *mut torch_sys_plus::C_optimizer,
}
7
// SAFETY: `COptimizer` is the only owner of its native handle. There is no `Clone` impl
// and no method hands the pointer out, so moving the wrapper to another thread moves that
// ownership with it.
// `Sync` is not implemented: the raw-pointer field keeps the type `!Sync`, so the `&self`
// methods (`zero_grad`, `step`) cannot be called concurrently from several threads.
// NOTE(review): this assumes the underlying libtorch optimizer has no thread affinity —
// confirm against the C++ side.
unsafe impl Send for COptimizer {}
9
10impl std::fmt::Debug for COptimizer {
11    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
12        write!(f, "optimizer")
13    }
14}
15
16impl COptimizer {
17    pub fn adam(
18        lr: f64,
19        beta1: f64,
20        beta2: f64,
21        wd: f64,
22        eps: f64,
23        amsgrad: bool,
24    ) -> Result<COptimizer, TchError> {
25        let c_optimizer =
26            unsafe_torch_err!(torch_sys_plus::ato_adam(lr, beta1, beta2, wd, eps, amsgrad));
27        Ok(COptimizer { c_optimizer })
28    }
29
30    pub fn adamw(
31        lr: f64,
32        beta1: f64,
33        beta2: f64,
34        wd: f64,
35        eps: f64,
36        amsgrad: bool,
37    ) -> Result<COptimizer, TchError> {
38        let c_optimizer =
39            unsafe_torch_err!(torch_sys_plus::ato_adamw(lr, beta1, beta2, wd, eps, amsgrad));
40        Ok(COptimizer { c_optimizer })
41    }
42
43    // Maybe we should use the builder pattern to provide default values for these ?
44    pub fn rms_prop(
45        lr: f64,
46        alpha: f64,
47        eps: f64,
48        wd: f64,
49        momentum: f64,
50        centered: bool,
51    ) -> Result<COptimizer, TchError> {
52        let centered = i32::from(centered);
53        let c_optimizer =
54            unsafe_torch_err!(torch_sys_plus::ato_rms_prop(lr, alpha, eps, wd, momentum, centered));
55        Ok(COptimizer { c_optimizer })
56    }
57
58    pub fn sgd(
59        lr: f64,
60        momentum: f64,
61        dampening: f64,
62        wd: f64,
63        nesterov: bool,
64    ) -> Result<COptimizer, TchError> {
65        let nesterov = i32::from(nesterov);
66        let c_optimizer =
67            unsafe_torch_err!(torch_sys_plus::ato_sgd(lr, momentum, dampening, wd, nesterov));
68        Ok(COptimizer { c_optimizer })
69    }
70
71    pub fn add_parameters(&mut self, t: &Tensor, group: usize) -> Result<(), TchError> {
72        unsafe_torch_err!(torch_sys_plus::ato_add_parameters(self.c_optimizer, t.c_tensor, group));
73        Ok(())
74    }
75
76    pub fn set_learning_rate(&mut self, lr: f64) -> Result<(), TchError> {
77        unsafe_torch_err!(torch_sys_plus::ato_set_learning_rate(self.c_optimizer, lr));
78        Ok(())
79    }
80
81    pub fn set_learning_rate_group(&mut self, group: usize, lr: f64) -> Result<(), TchError> {
82        unsafe_torch_err!(torch_sys_plus::ato_set_learning_rate_group(self.c_optimizer, group, lr));
83        Ok(())
84    }
85
86    pub fn set_momentum(&mut self, m: f64) -> Result<(), TchError> {
87        unsafe_torch_err!(torch_sys_plus::ato_set_momentum(self.c_optimizer, m));
88        Ok(())
89    }
90
91    pub fn set_momentum_group(&mut self, group: usize, m: f64) -> Result<(), TchError> {
92        unsafe_torch_err!(torch_sys_plus::ato_set_momentum_group(self.c_optimizer, group, m));
93        Ok(())
94    }
95
96    pub fn set_weight_decay(&mut self, weight_decay: f64) -> Result<(), TchError> {
97        unsafe_torch_err!(torch_sys_plus::ato_set_weight_decay(self.c_optimizer, weight_decay));
98        Ok(())
99    }
100
101    pub fn set_weight_decay_group(
102        &mut self,
103        group: usize,
104        weight_decay: f64,
105    ) -> Result<(), TchError> {
106        unsafe_torch_err!(torch_sys_plus::ato_set_weight_decay_group(
107            self.c_optimizer,
108            group,
109            weight_decay
110        ));
111        Ok(())
112    }
113
114    pub fn zero_grad(&self) -> Result<(), TchError> {
115        unsafe_torch_err!(torch_sys_plus::ato_zero_grad(self.c_optimizer));
116        Ok(())
117    }
118
119    pub fn step(&self) -> Result<(), TchError> {
120        unsafe_torch_err!(torch_sys_plus::ato_step(self.c_optimizer));
121        Ok(())
122    }
123}
124
125impl Drop for COptimizer {
126    fn drop(&mut self) {
127        unsafe { torch_sys_plus::ato_free(self.c_optimizer) }
128    }
129}