torsh_ffi/python/optimizer.rs

//! Python optimizer wrappers

use crate::error::FfiError;
use crate::python::tensor::PyTensor;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

/// Base optimizer class
#[pyclass(name = "Optimizer", subclass)]
#[derive(Clone)]
pub struct PyOptimizer {
    learning_rate: f32,
    name: String,
}

#[pymethods]
impl PyOptimizer {
    /// Perform a single optimization step
    fn step(&mut self) -> PyResult<()> {
        Err(FfiError::UnsupportedOperation {
            operation: "step not implemented for base Optimizer".to_string(),
        }
        .into())
    }

    /// Zero all gradients
    fn zero_grad(&mut self) -> PyResult<()> {
        // In a real implementation, this would zero the gradients of all
        // registered parameters before the next backward pass.
        Ok(())
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!("{}(lr={})", self.name, self.learning_rate)
    }
}

/// SGD optimizer with optional momentum, dampening, weight decay, and
/// Nesterov momentum
#[pyclass(name = "SGD")]
pub struct PySGD {
    momentum: f32,
    #[allow(dead_code)]
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
    learning_rate: f32,
    // A real implementation would also store per-parameter momentum buffers.
}

#[pymethods]
impl PySGD {
    #[new]
    #[pyo3(signature = (_params, lr=0.01, momentum=0.0, dampening=0.0, weight_decay=0.0, nesterov=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> PyResult<Self> {
        // Raise a proper Python ValueError instead of panicking across the
        // FFI boundary.
        if nesterov && (momentum <= 0.0 || dampening != 0.0) {
            return Err(PyValueError::new_err(
                "Nesterov momentum requires a momentum and zero dampening",
            ));
        }

        Ok(PySGD {
            momentum,
            dampening,
            weight_decay,
            nesterov,
            learning_rate: lr,
        })
    }

    fn step(&mut self) -> PyResult<()> {
        // Simplified SGD step; a real implementation would update every
        // registered parameter. For each parameter:
        // 1. Apply weight decay:   grad = grad + weight_decay * param
        // 2. Apply momentum (with dampening, optionally Nesterov) if momentum > 0
        // 3. Update the parameter: param = param - lr * grad
        // The arithmetic is sketched concretely in `sgd_update_sketch` below.
        Ok(())
    }

    #[getter]
    fn momentum(&self) -> f32 {
        self.momentum
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn nesterov(&self) -> bool {
        self.nesterov
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!(
            "SGD(lr={}, momentum={}, weight_decay={}, nesterov={})",
            self.learning_rate, self.momentum, self.weight_decay, self.nesterov
        )
    }
}
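
// A minimal sketch of the SGD update that `PySGD::step` describes, assuming
// parameters, gradients, and momentum buffers are exposed as plain `f32`
// slices (the real `PyTensor` type may not offer this). `sgd_update_sketch`
// is a hypothetical helper, not part of the bindings; first-step buffer
// initialization is simplified relative to common implementations.
#[allow(dead_code)]
fn sgd_update_sketch(
    param: &mut [f32],
    grad: &[f32],
    momentum_buf: &mut [f32],
    lr: f32,
    momentum: f32,
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
) {
    for ((p, &g), m) in param.iter_mut().zip(grad).zip(momentum_buf.iter_mut()) {
        // 1. Fold L2 weight decay into the gradient.
        let mut d = g + weight_decay * *p;
        // 2. Momentum buffer: m = momentum * m + (1 - dampening) * d.
        if momentum != 0.0 {
            *m = momentum * *m + (1.0 - dampening) * d;
            // Nesterov looks ahead along the momentum direction.
            d = if nesterov { d + momentum * *m } else { *m };
        }
        // 3. Gradient-descent step on the parameter.
        *p -= lr * d;
    }
}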

/// Adam optimizer
#[pyclass(name = "Adam")]
pub struct PyAdam {
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    amsgrad: bool,
    learning_rate: f32,
    // A real implementation would store exp_avg and exp_avg_sq buffers for
    // each parameter.
}

#[pymethods]
impl PyAdam {
    #[new]
    #[pyo3(signature = (_params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, amsgrad=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        PyAdam {
            betas,
            eps,
            weight_decay,
            amsgrad,
            learning_rate: lr,
        }
    }

    fn step(&mut self) -> PyResult<()> {
        // Simplified Adam step; a real implementation would update every
        // parameter using the Adam algorithm. For each parameter:
        // 1. Apply weight decay to the gradient if weight_decay > 0
        // 2. First moment:  m_t = beta1 * m_{t-1} + (1 - beta1) * grad
        // 3. Second moment: v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
        // 4. Bias correction: m_hat = m_t / (1 - beta1^t), v_hat = v_t / (1 - beta2^t)
        // 5. Update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
        // The arithmetic is sketched concretely in `adam_update_sketch` below.
        Ok(())
    }

    #[getter]
    fn betas(&self) -> (f32, f32) {
        self.betas
    }

    #[getter]
    fn eps(&self) -> f32 {
        self.eps
    }

    #[getter]
    fn weight_decay(&self) -> f32 {
        self.weight_decay
    }

    #[getter]
    fn amsgrad(&self) -> bool {
        self.amsgrad
    }

    #[getter]
    fn lr(&self) -> f32 {
        self.learning_rate
    }

    #[setter]
    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn __repr__(&self) -> String {
        format!(
            "Adam(lr={}, betas={:?}, eps={}, weight_decay={})",
            self.learning_rate, self.betas, self.eps, self.weight_decay
        )
    }
}
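
// A minimal sketch of the Adam update described in `PyAdam::step`, under the
// same assumption as `sgd_update_sketch`: parameters and per-parameter state
// live in plain `f32` slices. `t` is the 1-based step count; the `amsgrad`
// variant (tracking the running max of v_t) is omitted for brevity. This is
// a hypothetical helper, not part of the bindings.
#[allow(dead_code)]
fn adam_update_sketch(
    param: &mut [f32],
    grad: &[f32],
    exp_avg: &mut [f32],
    exp_avg_sq: &mut [f32],
    t: i32,
    lr: f32,
    (beta1, beta2): (f32, f32),
    eps: f32,
    weight_decay: f32,
) {
    // Bias-correction factors 1 - beta^t, shared by every element.
    let bias_c1 = 1.0 - beta1.powi(t);
    let bias_c2 = 1.0 - beta2.powi(t);
    for (((p, &g), m), v) in param
        .iter_mut()
        .zip(grad)
        .zip(exp_avg.iter_mut())
        .zip(exp_avg_sq.iter_mut())
    {
        // Classic Adam couples weight decay into the gradient.
        let g = g + weight_decay * *p;
        // Biased first and second moment estimates.
        *m = beta1 * *m + (1.0 - beta1) * g;
        *v = beta2 * *v + (1.0 - beta2) * g * g;
        // Bias-corrected estimates, then the parameter update.
        *p -= lr * (*m / bias_c1) / ((*v / bias_c2).sqrt() + eps);
    }
}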

/// AdamW optimizer (Adam with decoupled weight decay)
#[pyclass(name = "AdamW")]
pub struct PyAdamW {
    betas: (f32, f32),
    eps: f32,
    weight_decay: f32,
    #[allow(dead_code)]
    amsgrad: bool,
    learning_rate: f32,
}

#[pymethods]
impl PyAdamW {
    #[new]
    #[pyo3(signature = (_params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, amsgrad=false))]
    fn new(
        _params: Vec<PyTensor>,
        lr: f32,
        betas: (f32, f32),
        eps: f32,
        weight_decay: f32,
        amsgrad: bool,
    ) -> Self {
        PyAdamW {
            betas,
            eps,
            weight_decay,
            amsgrad,
            learning_rate: lr,
        }
    }

    fn step(&mut self) -> PyResult<()> {
        // AdamW step with decoupled weight decay. The key difference from
        // Adam is that weight decay is applied directly to the parameters
        // rather than being folded into the gradients; see
        // `adamw_decay_sketch` below.
        Ok(())
    }

    fn __repr__(&self) -> String {
        format!(
            "AdamW(lr={}, betas={:?}, eps={}, weight_decay={})",
            self.learning_rate, self.betas, self.eps, self.weight_decay
        )
    }
}
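
// A minimal sketch of the AdamW difference only: decay the parameter directly
// (decoupled weight decay), then run the usual Adam update with
// weight_decay = 0, e.g. via `adam_update_sketch` above. Hypothetical helper,
// same `f32`-slice assumption as the other sketches.
#[allow(dead_code)]
fn adamw_decay_sketch(param: &mut [f32], lr: f32, weight_decay: f32) {
    for p in param.iter_mut() {
        // param = param * (1 - lr * weight_decay), applied before the Adam
        // moment update instead of being added to the gradient.
        *p *= 1.0 - lr * weight_decay;
    }
}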

#[cfg(test)]
mod tests {
    use super::*;
    use pyo3::types::PyList;
    use pyo3::Python;

    #[test]
    fn test_sgd_creation() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let sgd = PySGD::new(params, 0.01, 0.9, 0.0, 0.0, false).unwrap();
            assert_eq!(sgd.lr(), 0.01);
            assert_eq!(sgd.momentum(), 0.9);
        });
    }

    #[test]
    fn test_adam_creation() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let adam = PyAdam::new(params, 0.001, (0.9, 0.999), 1e-8, 0.0, false);
            assert_eq!(adam.lr(), 0.001);
            assert_eq!(adam.betas(), (0.9, 0.999));
            assert_eq!(adam.eps(), 1e-8);
        });
    }

    #[test]
    fn test_optimizer_step() {
        Python::initialize();
        Python::attach(|py| {
            let data = PyList::new(py, vec![1.0, 2.0, 3.0]).unwrap();
            let tensor = PyTensor::new(data.as_ref(), None, None, true).unwrap();
            let params = vec![tensor];

            let mut sgd = PySGD::new(params, 0.01, 0.0, 0.0, 0.0, false).unwrap();

            // Should not error
            assert!(sgd.step().is_ok());
        });
    }
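
    // Numeric check of the hypothetical `sgd_update_sketch` helper: one plain
    // SGD step (no momentum, no weight decay) must move each parameter by
    // exactly -lr * grad. Values are chosen to be exact in f32.
    #[test]
    fn test_sgd_update_sketch_math() {
        let mut param = [1.0f32, 2.0, 3.0];
        let grad = [0.5f32; 3];
        let mut buf = [0.0f32; 3];
        sgd_update_sketch(&mut param, &grad, &mut buf, 0.5, 0.0, 0.0, 0.0, false);
        assert_eq!(param, [0.75, 1.75, 2.75]);
    }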
}