tch_plus/wrappers/
optimizer.rs1use super::tensor::Tensor;
2use crate::TchError;
3
4pub struct COptimizer {
5 c_optimizer: *mut torch_sys_plus::C_optimizer,
6}
7
8unsafe impl Send for COptimizer {}
9
10impl std::fmt::Debug for COptimizer {
11 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
12 write!(f, "optimizer")
13 }
14}
15
16impl COptimizer {
17 pub fn adam(
18 lr: f64,
19 beta1: f64,
20 beta2: f64,
21 wd: f64,
22 eps: f64,
23 amsgrad: bool,
24 ) -> Result<COptimizer, TchError> {
25 let c_optimizer =
26 unsafe_torch_err!(torch_sys_plus::ato_adam(lr, beta1, beta2, wd, eps, amsgrad));
27 Ok(COptimizer { c_optimizer })
28 }
29
30 pub fn adamw(
31 lr: f64,
32 beta1: f64,
33 beta2: f64,
34 wd: f64,
35 eps: f64,
36 amsgrad: bool,
37 ) -> Result<COptimizer, TchError> {
38 let c_optimizer =
39 unsafe_torch_err!(torch_sys_plus::ato_adamw(lr, beta1, beta2, wd, eps, amsgrad));
40 Ok(COptimizer { c_optimizer })
41 }
42
43 pub fn rms_prop(
45 lr: f64,
46 alpha: f64,
47 eps: f64,
48 wd: f64,
49 momentum: f64,
50 centered: bool,
51 ) -> Result<COptimizer, TchError> {
52 let centered = i32::from(centered);
53 let c_optimizer =
54 unsafe_torch_err!(torch_sys_plus::ato_rms_prop(lr, alpha, eps, wd, momentum, centered));
55 Ok(COptimizer { c_optimizer })
56 }
57
58 pub fn sgd(
59 lr: f64,
60 momentum: f64,
61 dampening: f64,
62 wd: f64,
63 nesterov: bool,
64 ) -> Result<COptimizer, TchError> {
65 let nesterov = i32::from(nesterov);
66 let c_optimizer =
67 unsafe_torch_err!(torch_sys_plus::ato_sgd(lr, momentum, dampening, wd, nesterov));
68 Ok(COptimizer { c_optimizer })
69 }
70
71 pub fn add_parameters(&mut self, t: &Tensor, group: usize) -> Result<(), TchError> {
72 unsafe_torch_err!(torch_sys_plus::ato_add_parameters(self.c_optimizer, t.c_tensor, group));
73 Ok(())
74 }
75
76 pub fn set_learning_rate(&mut self, lr: f64) -> Result<(), TchError> {
77 unsafe_torch_err!(torch_sys_plus::ato_set_learning_rate(self.c_optimizer, lr));
78 Ok(())
79 }
80
81 pub fn set_learning_rate_group(&mut self, group: usize, lr: f64) -> Result<(), TchError> {
82 unsafe_torch_err!(torch_sys_plus::ato_set_learning_rate_group(self.c_optimizer, group, lr));
83 Ok(())
84 }
85
86 pub fn set_momentum(&mut self, m: f64) -> Result<(), TchError> {
87 unsafe_torch_err!(torch_sys_plus::ato_set_momentum(self.c_optimizer, m));
88 Ok(())
89 }
90
91 pub fn set_momentum_group(&mut self, group: usize, m: f64) -> Result<(), TchError> {
92 unsafe_torch_err!(torch_sys_plus::ato_set_momentum_group(self.c_optimizer, group, m));
93 Ok(())
94 }
95
96 pub fn set_weight_decay(&mut self, weight_decay: f64) -> Result<(), TchError> {
97 unsafe_torch_err!(torch_sys_plus::ato_set_weight_decay(self.c_optimizer, weight_decay));
98 Ok(())
99 }
100
101 pub fn set_weight_decay_group(
102 &mut self,
103 group: usize,
104 weight_decay: f64,
105 ) -> Result<(), TchError> {
106 unsafe_torch_err!(torch_sys_plus::ato_set_weight_decay_group(
107 self.c_optimizer,
108 group,
109 weight_decay
110 ));
111 Ok(())
112 }
113
114 pub fn zero_grad(&self) -> Result<(), TchError> {
115 unsafe_torch_err!(torch_sys_plus::ato_zero_grad(self.c_optimizer));
116 Ok(())
117 }
118
119 pub fn step(&self) -> Result<(), TchError> {
120 unsafe_torch_err!(torch_sys_plus::ato_step(self.c_optimizer));
121 Ok(())
122 }
123}
124
125impl Drop for COptimizer {
126 fn drop(&mut self) {
127 unsafe { torch_sys_plus::ato_free(self.c_optimizer) }
128 }
129}