torsh_python/optim/
base.rs1use crate::{error::PyResult, tensor::PyTensor};
4use pyo3::prelude::*;
5use pyo3::types::PyAny;
6use std::collections::HashMap;
7
8#[pyclass(name = "Optimizer", subclass)]
10pub struct PyOptimizer {
11 }
13
14#[pymethods]
15impl PyOptimizer {
16 #[new]
17 fn new() -> Self {
18 Self {}
19 }
20
21 fn step(&mut self) -> PyResult<()> {
23 Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
24 "Subclasses must implement step method",
25 ))
26 }
27
28 fn zero_grad(&mut self, set_to_none: Option<bool>) {
30 let _set_to_none = set_to_none.unwrap_or(false);
32 }
34
35 fn state_dict(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
37 Ok(HashMap::new())
39 }
40
41 fn load_state_dict(&mut self, state_dict: HashMap<String, Py<PyAny>>) -> PyResult<()> {
43 let _state_dict = state_dict;
45 Ok(())
46 }
47
48 fn param_groups(&self) -> PyResult<Vec<HashMap<String, Py<PyAny>>>> {
50 Ok(Vec::new())
52 }
53
54 fn state(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
56 Ok(HashMap::new())
58 }
59
60 fn add_param_group(&mut self, param_group: HashMap<String, Py<PyAny>>) -> PyResult<()> {
62 let _param_group = param_group;
64 Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
65 "Subclasses must implement add_param_group method",
66 ))
67 }
68
69 fn __repr__(&self) -> String {
71 "Optimizer()".to_string()
72 }
73
74 fn defaults(&self) -> PyResult<HashMap<String, Py<PyAny>>> {
76 Ok(HashMap::new())
78 }
79}
80
81pub fn extract_parameters(params: Vec<PyTensor>) -> PyResult<Vec<torsh_tensor::Tensor<f32>>> {
83 params.into_iter().map(|p| Ok(p.tensor)).collect()
84}
85
86pub fn create_param_group(
88 params: Vec<PyTensor>,
89 lr: f32,
90 extra_params: HashMap<String, Py<PyAny>>,
91) -> PyResult<HashMap<String, Py<PyAny>>> {
92 let mut param_group = HashMap::new();
93
94 Python::attach(|py| {
95 let py_params: Vec<Py<PyAny>> = params
97 .into_iter()
98 .map(|p| p.into_pyobject(py).unwrap().into())
99 .collect();
100 param_group.insert(
101 "params".to_string(),
102 py_params.into_pyobject(py).unwrap().into(),
103 );
104
105 param_group.insert("lr".to_string(), lr.into_pyobject(py).unwrap().into());
107
108 for (key, value) in extra_params {
110 param_group.insert(key, value);
111 }
112
113 Ok(param_group)
114 })
115}