Skip to main content

torsh_python/nn/
container.rs

1//! Neural network containers - Sequential, ModuleList, etc.
2
3use super::module::PyModule;
4use crate::{device::PyDevice, error::PyResult, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::PyAny;
7use std::collections::HashMap;
8
9/// Sequential container - applies modules in sequence
10#[pyclass(name = "Sequential", extends = PyModule)]
11pub struct PySequential {
12    modules: Vec<Py<PyAny>>,
13    training: bool,
14}
15
16#[pymethods]
17impl PySequential {
18    #[new]
19    fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
20        let modules = modules.unwrap_or_default();
21        (
22            Self {
23                modules,
24                training: true,
25            },
26            PyModule::new(),
27        )
28    }
29
30    /// Add a module to the sequential container
31    fn add_module(&mut self, _name: &str, module: Py<PyAny>) {
32        // For now, just add to the list (ignoring name)
33        self.modules.push(module);
34    }
35
36    /// Forward pass through all modules in sequence
37    fn forward(&self, mut input: PyTensor) -> PyResult<PyTensor> {
38        Python::attach(|py| {
39            for module in &self.modules {
40                // Call the forward method on each module
41                if let Ok(forward_method) = module.getattr(py, "forward") {
42                    let result = forward_method.call1(py, (input.clone(),))?;
43                    input = result.extract::<PyTensor>(py)?;
44                } else {
45                    // Try calling the module directly (__call__)
46                    let result = module.call1(py, (input.clone(),))?;
47                    input = result.extract::<PyTensor>(py)?;
48                }
49            }
50            Ok(input)
51        })
52    }
53
54    /// Get all parameters from all modules
55    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
56        let mut all_params = Vec::new();
57        Python::attach(|py| {
58            for module in &self.modules {
59                if let Ok(params_method) = module.getattr(py, "parameters") {
60                    let params_result = params_method.call0(py)?;
61                    if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
62                        all_params.extend(params);
63                    }
64                }
65            }
66            Ok(all_params)
67        })
68    }
69
70    /// Get all named parameters from all modules
71    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
72        let mut all_named_params = HashMap::new();
73        Python::attach(|py| {
74            for (i, module) in self.modules.iter().enumerate() {
75                if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
76                    let named_params_result = named_params_method.call0(py)?;
77                    if let Ok(named_params) =
78                        named_params_result.extract::<HashMap<String, PyTensor>>(py)
79                    {
80                        for (name, param) in named_params {
81                            all_named_params.insert(format!("{}.{}", i, name), param);
82                        }
83                    }
84                }
85            }
86            Ok(all_named_params)
87        })
88    }
89
90    /// Set training mode for all modules
91    fn train(&mut self, mode: Option<bool>) {
92        let mode = mode.unwrap_or(true);
93        self.training = mode;
94        Python::attach(|py| {
95            for module in &self.modules {
96                if let Ok(train_method) = module.getattr(py, "train") {
97                    let _ = train_method.call1(py, (mode,));
98                }
99            }
100        });
101    }
102
103    /// Set evaluation mode for all modules
104    fn eval(&mut self) {
105        self.training = false;
106        Python::attach(|py| {
107            for module in &self.modules {
108                if let Ok(eval_method) = module.getattr(py, "eval") {
109                    let _ = eval_method.call0(py);
110                }
111            }
112        });
113    }
114
115    /// Move all modules to specified device
116    fn to(&mut self, device: PyDevice) -> PyResult<()> {
117        Python::attach(|py| {
118            for module in &self.modules {
119                if let Ok(to_method) = module.getattr(py, "to") {
120                    to_method.call1(py, (device.clone(),))?;
121                }
122            }
123            Ok(())
124        })
125    }
126
127    /// Zero gradients for all modules
128    fn zero_grad(&mut self) {
129        Python::attach(|py| {
130            for module in &self.modules {
131                if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
132                    let _ = zero_grad_method.call0(py);
133                }
134            }
135        });
136    }
137
138    /// String representation
139    fn __repr__(&self) -> String {
140        format!("Sequential({} modules)", self.modules.len())
141    }
142
143    /// Get length (number of modules)
144    fn __len__(&self) -> usize {
145        self.modules.len()
146    }
147
148    /// Get module by index
149    fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
150        Python::attach(|py| {
151            self.modules
152                .get(index)
153                .map(|obj| obj.clone_ref(py))
154                .ok_or_else(|| {
155                    PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
156                })
157        })
158    }
159
160    /// Check if module is in training mode
161    fn training(&self) -> bool {
162        self.training
163    }
164}
165
166/// ModuleList container - holds modules in a list
167#[pyclass(name = "ModuleList", extends = PyModule)]
168pub struct PyModuleList {
169    modules: Vec<Py<PyAny>>,
170    training: bool,
171}
172
173#[pymethods]
174impl PyModuleList {
175    #[new]
176    fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
177        let modules = modules.unwrap_or_default();
178        (
179            Self {
180                modules,
181                training: true,
182            },
183            PyModule::new(),
184        )
185    }
186
187    /// Append a module to the list
188    fn append(&mut self, module: Py<PyAny>) {
189        self.modules.push(module);
190    }
191
192    /// Extend the list with modules from another iterable
193    fn extend(&mut self, modules: Vec<Py<PyAny>>) {
194        self.modules.extend(modules);
195    }
196
197    /// Insert a module at the specified index
198    fn insert(&mut self, index: usize, module: Py<PyAny>) {
199        if index <= self.modules.len() {
200            self.modules.insert(index, module);
201        }
202    }
203
204    /// Get all parameters from all modules
205    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
206        let mut all_params = Vec::new();
207        Python::attach(|py| {
208            for module in &self.modules {
209                if let Ok(params_method) = module.getattr(py, "parameters") {
210                    let params_result = params_method.call0(py)?;
211                    if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
212                        all_params.extend(params);
213                    }
214                }
215            }
216            Ok(all_params)
217        })
218    }
219
220    /// Get all named parameters from all modules
221    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
222        let mut all_named_params = HashMap::new();
223        Python::attach(|py| {
224            for (i, module) in self.modules.iter().enumerate() {
225                if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
226                    let named_params_result = named_params_method.call0(py)?;
227                    if let Ok(named_params) =
228                        named_params_result.extract::<HashMap<String, PyTensor>>(py)
229                    {
230                        for (name, param) in named_params {
231                            all_named_params.insert(format!("{}.{}", i, name), param);
232                        }
233                    }
234                }
235            }
236            Ok(all_named_params)
237        })
238    }
239
240    /// Set training mode for all modules
241    fn train(&mut self, mode: Option<bool>) {
242        let mode = mode.unwrap_or(true);
243        self.training = mode;
244        Python::attach(|py| {
245            for module in &self.modules {
246                if let Ok(train_method) = module.getattr(py, "train") {
247                    let _ = train_method.call1(py, (mode,));
248                }
249            }
250        });
251    }
252
253    /// Set evaluation mode for all modules
254    fn eval(&mut self) {
255        self.training = false;
256        Python::attach(|py| {
257            for module in &self.modules {
258                if let Ok(eval_method) = module.getattr(py, "eval") {
259                    let _ = eval_method.call0(py);
260                }
261            }
262        });
263    }
264
265    /// Move all modules to specified device
266    fn to(&mut self, device: PyDevice) -> PyResult<()> {
267        Python::attach(|py| {
268            for module in &self.modules {
269                if let Ok(to_method) = module.getattr(py, "to") {
270                    to_method.call1(py, (device.clone(),))?;
271                }
272            }
273            Ok(())
274        })
275    }
276
277    /// Zero gradients for all modules
278    fn zero_grad(&mut self) {
279        Python::attach(|py| {
280            for module in &self.modules {
281                if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
282                    let _ = zero_grad_method.call0(py);
283                }
284            }
285        });
286    }
287
288    /// String representation
289    fn __repr__(&self) -> String {
290        format!("ModuleList({} modules)", self.modules.len())
291    }
292
293    /// Get length (number of modules)
294    fn __len__(&self) -> usize {
295        self.modules.len()
296    }
297
298    /// Get module by index
299    fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
300        Python::attach(|py| {
301            self.modules
302                .get(index)
303                .map(|obj| obj.clone_ref(py))
304                .ok_or_else(|| {
305                    PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
306                })
307        })
308    }
309
310    /// Set module at index
311    fn __setitem__(&mut self, index: usize, module: Py<PyAny>) -> PyResult<()> {
312        if index < self.modules.len() {
313            self.modules[index] = module;
314            Ok(())
315        } else {
316            Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
317                "Index out of range",
318            ))
319        }
320    }
321
322    /// Check if module is in training mode
323    fn training(&self) -> bool {
324        self.training
325    }
326}