Skip to main content

torsh_python/optim/
base.rs

1//! Base optimizer implementation - Foundation for all PyTorch-compatible optimizers
2
3use crate::{error::PyResult, tensor::PyTensor};
4use pyo3::prelude::*;
5use pyo3::types::PyAny;
6use std::collections::HashMap;
7
8/// Base optimizer class - foundation for all optimizers
9#[pyclass(name = "Optimizer", subclass)]
10pub struct PyOptimizer {
11    // This will be overridden by subclasses
12}
13
14#[pymethods]
15impl PyOptimizer {
16    #[new]
17    fn new() -> Self {
18        Self {}
19    }
20
21    /// Perform a single optimization step - must be implemented by subclasses
22    fn step(&mut self) -> PyResult<()> {
23        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
24            "Subclasses must implement step method",
25        ))
26    }
27
28    /// Zero out gradients of all parameters
29    fn zero_grad(&mut self, set_to_none: Option<bool>) {
30        // Default implementation - subclasses should override
31        let _set_to_none = set_to_none.unwrap_or(false);
32        // Subclasses should implement actual gradient zeroing
33    }
34
35    /// Get the state dictionary (optimizer state and hyperparameters)
36    fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
37        // Default implementation - subclasses should override
38        Ok(HashMap::new())
39    }
40
41    /// Load state dictionary
42    fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
43        // Default implementation - subclasses should override
44        let _state_dict = state_dict;
45        Ok(())
46    }
47
48    /// Get parameter groups
49    fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
50        // Default implementation - subclasses should override
51        Ok(Vec::new())
52    }
53
54    /// Get current state
55    fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
56        // Default implementation - subclasses should override
57        Ok(HashMap::new())
58    }
59
60    /// Add a new parameter group
61    fn add_param_group(&mut self, param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
62        // Default implementation - subclasses should override
63        let _param_group = param_group;
64        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
65            "Subclasses must implement add_param_group method",
66        ))
67    }
68
69    /// String representation
70    fn __repr__(&self) -> String {
71        "Optimizer()".to_string()
72    }
73
74    /// Get defaults (default hyperparameters)
75    fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
76        // Default implementation - subclasses should override
77        Ok(HashMap::new())
78    }
79}
80
81/// Helper function to extract parameters from Python objects
82pub fn extract_parameters(params: Vec<PyTensor>) -> PyResult<Vec<torsh_tensor::Tensor<f32>>> {
83    params.into_iter().map(|p| Ok(p.tensor)).collect()
84}
85
86/// Helper function to create parameter group
87pub fn create_param_group(
88    params: Vec<PyTensor>,
89    lr: f32,
90    extra_params: HashMap<String, Py<PyAny>>,
91) -> PyResult<HashMap<String, Py<PyAny>>> {
92    let mut param_group = HashMap::new();
93
94    Python::attach(|py| {
95        // Add parameters
96        let py_params: Vec<Py<PyAny>> = params
97            .into_iter()
98            .map(|p| {
99                p.into_pyobject(py)
100                    .expect("Python object conversion should succeed")
101                    .into()
102            })
103            .collect();
104        param_group.insert(
105            "params".to_string(),
106            py_params
107                .into_pyobject(py)
108                .expect("Python object conversion should succeed")
109                .into(),
110        );
111
112        // Add learning rate
113        param_group.insert(
114            "lr".to_string(),
115            lr.into_pyobject(py)
116                .expect("Python object conversion should succeed")
117                .into(),
118        );
119
120        // Add extra parameters
121        for (key, value) in extra_params {
122            param_group.insert(key, value);
123        }
124
125        Ok(param_group)
126    })
127}