use super::tensor::Tensor;
use crate::TchError;

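/// A thin wrapper around a raw pointer to a libtorch optimizer. The
/// underlying C++ object is freed when the wrapper is dropped.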
pub struct COptimizer {
    c_optimizer: *mut torch_sys::C_optimizer,
}

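// Sending the optimizer to another thread is assumed to be safe here as the
// underlying C++ object is exclusively owned by this wrapper.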
unsafe impl Send for COptimizer {}

impl std::fmt::Debug for COptimizer {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "optimizer")
    }
}

impl COptimizer {
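    /// Creates an Adam optimizer with the given learning rate, beta
    /// coefficients, weight decay, epsilon, and AMSGrad flag.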
    pub fn adam(
        lr: f64,
        beta1: f64,
        beta2: f64,
        wd: f64,
        eps: f64,
        amsgrad: bool,
    ) -> Result<COptimizer, TchError> {
        let c_optimizer =
            unsafe_torch_err!(torch_sys::ato_adam(lr, beta1, beta2, wd, eps, amsgrad));
        Ok(COptimizer { c_optimizer })
    }

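    /// Creates an AdamW optimizer, i.e. Adam with decoupled weight decay.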
    pub fn adamw(
        lr: f64,
        beta1: f64,
        beta2: f64,
        wd: f64,
        eps: f64,
        amsgrad: bool,
    ) -> Result<COptimizer, TchError> {
        let c_optimizer =
            unsafe_torch_err!(torch_sys::ato_adamw(lr, beta1, beta2, wd, eps, amsgrad));
        Ok(COptimizer { c_optimizer })
    }

    // Maybe we should use the builder pattern to provide default values for these?
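    /// Creates an RMSprop optimizer. When `centered` is true, gradients are
    /// normalized by their estimated variance rather than by the uncentered
    /// second moment.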
    pub fn rms_prop(
        lr: f64,
        alpha: f64,
        eps: f64,
        wd: f64,
        momentum: f64,
        centered: bool,
    ) -> Result<COptimizer, TchError> {
        let centered = i32::from(centered);
        let c_optimizer =
            unsafe_torch_err!(torch_sys::ato_rms_prop(lr, alpha, eps, wd, momentum, centered));
        Ok(COptimizer { c_optimizer })
    }

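    /// Creates an SGD optimizer with optional momentum, dampening, weight
    /// decay, and Nesterov updates.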
    pub fn sgd(
        lr: f64,
        momentum: f64,
        dampening: f64,
        wd: f64,
        nesterov: bool,
    ) -> Result<COptimizer, TchError> {
        let nesterov = i32::from(nesterov);
        let c_optimizer =
            unsafe_torch_err!(torch_sys::ato_sgd(lr, momentum, dampening, wd, nesterov));
        Ok(COptimizer { c_optimizer })
    }

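    /// Registers a tensor as a parameter to be optimized, adding it to the
    /// parameter group with the given index.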
    pub fn add_parameters(&mut self, t: &Tensor, group: usize) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_add_parameters(self.c_optimizer, t.c_tensor, group));
        Ok(())
    }

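    /// Sets the learning rate for all parameter groups.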
    pub fn set_learning_rate(&mut self, lr: f64) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_learning_rate(self.c_optimizer, lr));
        Ok(())
    }

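    /// Sets the learning rate for a single parameter group.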
    pub fn set_learning_rate_group(&mut self, group: usize, lr: f64) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_learning_rate_group(self.c_optimizer, group, lr));
        Ok(())
    }

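    /// Sets the momentum for all parameter groups.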
    pub fn set_momentum(&mut self, m: f64) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_momentum(self.c_optimizer, m));
        Ok(())
    }

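    /// Sets the momentum for a single parameter group.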
    pub fn set_momentum_group(&mut self, group: usize, m: f64) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_momentum_group(self.c_optimizer, group, m));
        Ok(())
    }

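    /// Sets the weight decay for all parameter groups.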
    pub fn set_weight_decay(&mut self, weight_decay: f64) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_weight_decay(self.c_optimizer, weight_decay));
        Ok(())
    }

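    /// Sets the weight decay for a single parameter group.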
    pub fn set_weight_decay_group(
        &mut self,
        group: usize,
        weight_decay: f64,
    ) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_set_weight_decay_group(
            self.c_optimizer,
            group,
            weight_decay
        ));
        Ok(())
    }

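    /// Zeroes the gradients of all the tensors registered with this
    /// optimizer.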
    pub fn zero_grad(&self) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_zero_grad(self.c_optimizer));
        Ok(())
    }

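    /// Performs a single optimization step, updating the registered tensors
    /// in place based on their currently stored gradients.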
    pub fn step(&self) -> Result<(), TchError> {
        unsafe_torch_err!(torch_sys::ato_step(self.c_optimizer));
        Ok(())
    }
}

impl Drop for COptimizer {
    fn drop(&mut self) {
        unsafe { torch_sys::ato_free(self.c_optimizer) }
    }
}
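
#[cfg(test)]
mod tests {
    use super::COptimizer;

    // A minimal smoke-test sketch rather than a definitive test: it assumes
    // that creating an optimizer and adjusting its hyperparameters is valid
    // before any parameters have been registered, and it requires a working
    // libtorch installation at test time.
    #[test]
    fn create_and_configure_sgd() {
        let mut opt = COptimizer::sgd(1e-2, 0.9, 0.0, 1e-4, false).unwrap();
        opt.set_learning_rate(1e-3).unwrap();
        opt.set_momentum(0.5).unwrap();
    }
}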