torsh_python/nn/
container.rs1use super::module::PyModule;
4use crate::{device::PyDevice, error::PyResult, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::PyAny;
7use std::collections::HashMap;
8
9#[pyclass(name = "Sequential", extends = PyModule)]
11pub struct PySequential {
12 modules: Vec<Py<PyAny>>,
13 training: bool,
14}
15
16#[pymethods]
17impl PySequential {
18 #[new]
19 fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
20 let modules = modules.unwrap_or_default();
21 (
22 Self {
23 modules,
24 training: true,
25 },
26 PyModule::new(),
27 )
28 }
29
30 fn add_module(&mut self, _name: &str, module: Py<PyAny>) {
32 self.modules.push(module);
34 }
35
36 fn forward(&self, mut input: PyTensor) -> PyResult<PyTensor> {
38 Python::attach(|py| {
39 for module in &self.modules {
40 if let Ok(forward_method) = module.getattr(py, "forward") {
42 let result = forward_method.call1(py, (input.clone(),))?;
43 input = result.extract::<PyTensor>(py)?;
44 } else {
45 let result = module.call1(py, (input.clone(),))?;
47 input = result.extract::<PyTensor>(py)?;
48 }
49 }
50 Ok(input)
51 })
52 }
53
54 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
56 let mut all_params = Vec::new();
57 Python::attach(|py| {
58 for module in &self.modules {
59 if let Ok(params_method) = module.getattr(py, "parameters") {
60 let params_result = params_method.call0(py)?;
61 if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
62 all_params.extend(params);
63 }
64 }
65 }
66 Ok(all_params)
67 })
68 }
69
70 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
72 let mut all_named_params = HashMap::new();
73 Python::attach(|py| {
74 for (i, module) in self.modules.iter().enumerate() {
75 if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
76 let named_params_result = named_params_method.call0(py)?;
77 if let Ok(named_params) =
78 named_params_result.extract::<HashMap<String, PyTensor>>(py)
79 {
80 for (name, param) in named_params {
81 all_named_params.insert(format!("{}.{}", i, name), param);
82 }
83 }
84 }
85 }
86 Ok(all_named_params)
87 })
88 }
89
90 fn train(&mut self, mode: Option<bool>) {
92 let mode = mode.unwrap_or(true);
93 self.training = mode;
94 Python::attach(|py| {
95 for module in &self.modules {
96 if let Ok(train_method) = module.getattr(py, "train") {
97 let _ = train_method.call1(py, (mode,));
98 }
99 }
100 });
101 }
102
103 fn eval(&mut self) {
105 self.training = false;
106 Python::attach(|py| {
107 for module in &self.modules {
108 if let Ok(eval_method) = module.getattr(py, "eval") {
109 let _ = eval_method.call0(py);
110 }
111 }
112 });
113 }
114
115 fn to(&mut self, device: PyDevice) -> PyResult<()> {
117 Python::attach(|py| {
118 for module in &self.modules {
119 if let Ok(to_method) = module.getattr(py, "to") {
120 to_method.call1(py, (device.clone(),))?;
121 }
122 }
123 Ok(())
124 })
125 }
126
127 fn zero_grad(&mut self) {
129 Python::attach(|py| {
130 for module in &self.modules {
131 if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
132 let _ = zero_grad_method.call0(py);
133 }
134 }
135 });
136 }
137
138 fn __repr__(&self) -> String {
140 format!("Sequential({} modules)", self.modules.len())
141 }
142
143 fn __len__(&self) -> usize {
145 self.modules.len()
146 }
147
148 fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
150 Python::attach(|py| {
151 self.modules
152 .get(index)
153 .map(|obj| obj.clone_ref(py))
154 .ok_or_else(|| {
155 PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
156 })
157 })
158 }
159
160 fn training(&self) -> bool {
162 self.training
163 }
164}
165
166#[pyclass(name = "ModuleList", extends = PyModule)]
168pub struct PyModuleList {
169 modules: Vec<Py<PyAny>>,
170 training: bool,
171}
172
173#[pymethods]
174impl PyModuleList {
175 #[new]
176 fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
177 let modules = modules.unwrap_or_default();
178 (
179 Self {
180 modules,
181 training: true,
182 },
183 PyModule::new(),
184 )
185 }
186
187 fn append(&mut self, module: Py<PyAny>) {
189 self.modules.push(module);
190 }
191
192 fn extend(&mut self, modules: Vec<Py<PyAny>>) {
194 self.modules.extend(modules);
195 }
196
197 fn insert(&mut self, index: usize, module: Py<PyAny>) {
199 if index <= self.modules.len() {
200 self.modules.insert(index, module);
201 }
202 }
203
204 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
206 let mut all_params = Vec::new();
207 Python::attach(|py| {
208 for module in &self.modules {
209 if let Ok(params_method) = module.getattr(py, "parameters") {
210 let params_result = params_method.call0(py)?;
211 if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
212 all_params.extend(params);
213 }
214 }
215 }
216 Ok(all_params)
217 })
218 }
219
220 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
222 let mut all_named_params = HashMap::new();
223 Python::attach(|py| {
224 for (i, module) in self.modules.iter().enumerate() {
225 if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
226 let named_params_result = named_params_method.call0(py)?;
227 if let Ok(named_params) =
228 named_params_result.extract::<HashMap<String, PyTensor>>(py)
229 {
230 for (name, param) in named_params {
231 all_named_params.insert(format!("{}.{}", i, name), param);
232 }
233 }
234 }
235 }
236 Ok(all_named_params)
237 })
238 }
239
240 fn train(&mut self, mode: Option<bool>) {
242 let mode = mode.unwrap_or(true);
243 self.training = mode;
244 Python::attach(|py| {
245 for module in &self.modules {
246 if let Ok(train_method) = module.getattr(py, "train") {
247 let _ = train_method.call1(py, (mode,));
248 }
249 }
250 });
251 }
252
253 fn eval(&mut self) {
255 self.training = false;
256 Python::attach(|py| {
257 for module in &self.modules {
258 if let Ok(eval_method) = module.getattr(py, "eval") {
259 let _ = eval_method.call0(py);
260 }
261 }
262 });
263 }
264
265 fn to(&mut self, device: PyDevice) -> PyResult<()> {
267 Python::attach(|py| {
268 for module in &self.modules {
269 if let Ok(to_method) = module.getattr(py, "to") {
270 to_method.call1(py, (device.clone(),))?;
271 }
272 }
273 Ok(())
274 })
275 }
276
277 fn zero_grad(&mut self) {
279 Python::attach(|py| {
280 for module in &self.modules {
281 if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
282 let _ = zero_grad_method.call0(py);
283 }
284 }
285 });
286 }
287
288 fn __repr__(&self) -> String {
290 format!("ModuleList({} modules)", self.modules.len())
291 }
292
293 fn __len__(&self) -> usize {
295 self.modules.len()
296 }
297
298 fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
300 Python::attach(|py| {
301 self.modules
302 .get(index)
303 .map(|obj| obj.clone_ref(py))
304 .ok_or_else(|| {
305 PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
306 })
307 })
308 }
309
310 fn __setitem__(&mut self, index: usize, module: Py<PyAny>) -> PyResult<()> {
312 if index < self.modules.len() {
313 self.modules[index] = module;
314 Ok(())
315 } else {
316 Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
317 "Index out of range",
318 ))
319 }
320 }
321
322 fn training(&self) -> bool {
324 self.training
325 }
326}