use crate::error::FfiError;
use crate::python::tensor::PyTensor;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Base class for optimizers exposed to Python. Marked `subclass` so that
/// Python-side optimizers can inherit from it; concrete optimizers override
/// `step`.
#[pyclass(name = "Optimizer", subclass)]
#[derive(Clone)]
pub struct PyOptimizer {
    learning_rate: f32,
    name: String,
}

#[pymethods]
impl PyOptimizer {
    /// The base class cannot perform an update; subclasses override this.
    fn step(&mut self) -> PyResult<()> {
        Err(FfiError::UnsupportedOperation {
            operation: "step not implemented for base Optimizer".to_string(),
        }
        .into())
    }

    fn zero_grad(&mut self) -> PyResult<()> {
        // The base optimizer holds no parameters, so there is nothing to clear.
        Ok(())
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!("{}(lr={})", self.name, self.learning_rate)
    }
}

/// Stochastic gradient descent with optional momentum; the constructor
/// mirrors the argument set of `torch.optim.SGD`.
#[pyclass(name = "SGD")]
pub struct PySGD {
    momentum: f32,
    #[allow(dead_code)]
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
    learning_rate: f32,
}

#[pymethods]
impl PySGD {
    #[new]
    #[pyo3(signature = (_params, lr=0.01, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> PyResult<Self> {
        // Nesterov momentum is only defined for a non-zero momentum and zero
        // dampening. Raise ValueError instead of panicking across the FFI
        // boundary.
        if nesterov && (momentum <= 0.0 || dampening != 0.0) {
            return Err(PyValueError::new_err(
                "Nesterov momentum requires a non-zero momentum and zero dampening",
            ));
        }

        Ok(PySGD {
            momentum,
            dampening,
            weight_decay,
            nesterov,
            learning_rate: lr,
        })
    }

    fn step(&mut self) -> PyResult<()> {
        // Parameter updates are not yet implemented; this is a no-op stub.
        Ok(())
    }

    #[getter]
    fn momentum(&self) -> f32 {
        self.momentum
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn nesterov(&self) -> bool {
        self.nesterov
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!(
            "SGD(lr={}, momentum={}, weight_decay={}, nesterov={})",
            self.learning_rate, self.momentum, self.weight_decay, self.nesterov
        )
    }
}
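
// Illustrative sketch only (not wired into the bindings above): the
// per-parameter update a full `PySGD::step` would perform, following the
// standard SGD-with-momentum formulation. `velocity` is the momentum buffer
// carried between steps; all names here are local to this sketch.
#[allow(dead_code)]
fn sgd_update(
    param: &mut f32,
    grad: f32,
    velocity: &mut f32,
    lr: f32,
    momentum: f32,
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
) {
    // L2-style weight decay folds the decay term into the gradient.
    let mut g = grad + weight_decay * *param;
    if momentum != 0.0 {
        // v <- momentum * v + (1 - dampening) * g
        *velocity = momentum * *velocity + (1.0 - dampening) * g;
        // Nesterov looks ahead along the updated velocity.
        g = if nesterov { g + momentum * *velocity } else { *velocity };
    }
    *param -= lr * g;
}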

/// Adam optimizer state captured from the Python constructor; the arguments
/// mirror those of `torch.optim.Adam`.
#[pyclass(name = "Adam")]
pub struct PyAdam {
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    learning_rate: f32,
}

#[pymethods]
impl PyAdam {
    #[new]
    #[pyo3(signature = (_params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, amsgrad=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        PyAdam {
            betas,
            eps,
            weight_decay,
            amsgrad,
            learning_rate: lr,
        }
    }

    fn step(&mut self) -> PyResult<()> {
        // Parameter updates are not yet implemented; this is a no-op stub.
        Ok(())
    }

    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!(
            "Adam(lr={}, betas={:?}, eps={}, weight_decay={})",
            self.learning_rate, self.betas, self.eps, self.weight_decay
        )
    }
}
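
// Illustrative sketch only (not wired into the bindings above): one Adam step
// for a single parameter. `m` and `v` are the running first- and second-moment
// estimates and `t` is the 1-based step count; `PyAdam` does not keep this
// state yet. The `amsgrad` variant (max over v_hat) is omitted for brevity.
#[allow(dead_code)]
fn adam_update(
    param: &mut f32,
    grad: f32,
    m: &mut f32,
    v: &mut f32,
    t: i32,
    lr: f32,
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
) {
    // Classic (L2) weight decay enters through the gradient.
    let g = grad + weight_decay * *param;
    // Exponential moving averages of the gradient and its square.
    *m = betas.0 * *m + (1.0 - betas.0) * g;
    *v = betas.1 * *v + (1.0 - betas.1) * g * g;
    // Bias correction compensates for the zero initialization of m and v.
    let m_hat = *m / (1.0 - betas.0.powi(t));
    let v_hat = *v / (1.0 - betas.1.powi(t));
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}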

/// AdamW optimizer state; identical to `PyAdam` except that weight decay is
/// meant to be applied in decoupled form, as in `torch.optim.AdamW`.
#[pyclass(name = "AdamW")]
pub struct PyAdamW {
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    #[allow(dead_code)]
    amsgrad: bool,
    learning_rate: f32,
}

#[pymethods]
impl PyAdamW {
    #[new]
    #[pyo3(signature = (_params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, amsgrad=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        PyAdamW {
            betas,
            eps,
            weight_decay,
            amsgrad,
            learning_rate: lr,
        }
    }

    fn step(&mut self) -> PyResult<()> {
        // Parameter updates are not yet implemented; this is a no-op stub.
        Ok(())
    }

    fn __repr__(&self) -> String {
        format!(
            "AdamW(lr={}, betas={:?}, eps={}, weight_decay={})",
            self.learning_rate, self.betas, self.eps, self.weight_decay
        )
    }
}
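
// Illustrative sketch only: the step that distinguishes AdamW from Adam.
// Instead of folding `weight_decay * param` into the gradient (where the
// adaptive denominator rescales it), AdamW shrinks the parameter directly,
// keeping the decay decoupled from the moment estimates.
#[allow(dead_code)]
fn adamw_decoupled_decay(param: &mut f32, lr: f32, weight_decay: f32) {
    // p <- p - lr * weight_decay * p, applied alongside the usual Adam step.
    *param -= lr * weight_decay * *param;
}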

#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::types::PyList;
    use pyo3::Python;

    #[test]
    fn test_sgd_creation() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let sgd = PySGD::new(params, 0.01, 0.9, 0.0, 0.0, false).unwrap();
            assert_eq!(sgd.lr(), 0.01);
            assert_eq!(sgd.momentum(), 0.9);
        });
    }
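
    // Invalid Nesterov configurations must be rejected with a Python error;
    // an empty parameter list is enough to reach the check.
    #[test]
    fn test_sgd_nesterov_validation() {
        Python::initialize();
        Python::attach(|_py| {
            assert!(PySGD::new(vec![], 0.01, 0.0, 0.0, 0.0, true).is_err());
        });
    }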

    #[test]
    fn test_adam_creation() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let adam = PyAdam::new(params, 0.001, (0.9, 0.999), 1e-8, 0.0, false);
            assert_eq!(adam.lr(), 0.001);
            assert_eq!(adam.betas(), (0.9, 0.999));
            assert_eq!(adam.eps(), 1e-8);
        });
    }

    #[test]
    fn test_optimizer_step() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let mut sgd = PySGD::new(params, 0.01, 0.0, 0.0, 0.0, false).unwrap();

            assert!(sgd.step().is_ok());
        });
    }
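
    #[test]
    fn test_lr_setter_and_repr() {
        // Pure-Rust check of the lr round-trip and the repr format; no
        // interpreter state is needed since no tensors are involved.
        let mut sgd = PySGD::new(vec![], 0.1, 0.0, 0.0, 0.0, false).unwrap();
        assert_eq!(sgd.lr(), 0.1);
        sgd.set_lr(0.05);
        assert_eq!(sgd.lr(), 0.05);
        assert_eq!(
            sgd.__repr__(),
            "SGD(lr=0.05, momentum=0, weight_decay=0, nesterov=false)"
        );
    }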
}