// torsh_python/nn/module.rs
1//! Base neural network module - Foundation for all PyTorch-compatible layers
2
3use crate::{device::PyDevice, error::PyResult, tensor::PyTensor};
4use pyo3::prelude::*;
5use pyo3::types::PyAny;
6use std::collections::HashMap;
7
8/// Base class for all neural network modules
9#[pyclass(name = "Module", subclass)]
10pub struct PyModule {
11 // This will be overridden by subclasses
12}
13
#[pymethods]
impl PyModule {
    /// Construct a new, stateless base module.
    #[new]
    pub fn new() -> Self {
        Self {}
    }

    /// Get all parameters of the module.
    ///
    /// Default implementation returns an empty list; subclasses that own
    /// tensors must override this so optimizers and `num_parameters` see them.
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        // Default implementation - subclasses should override
        Ok(Vec::new())
    }

    /// Get all named parameters of the module.
    ///
    /// Default implementation returns an empty map; subclasses should
    /// override. Also serves as the default source for `state_dict` below.
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        // Default implementation - subclasses should override
        Ok(HashMap::new())
    }

    /// Set the module in training mode (`mode` defaults to `true`).
    ///
    /// NOTE(review): the base keeps no training flag (see `training` below),
    /// so this default is a no-op; subclasses must implement the actual
    /// mode switch. Also, depending on the pyo3 version in use, a trailing
    /// `Option<bool>` argument may require an explicit
    /// `#[pyo3(signature = (mode=None))]` to stay optional from Python —
    /// confirm against the pinned pyo3 release.
    fn train(&mut self, mode: Option<bool>) {
        // Default implementation - subclasses should override
        let _mode = mode.unwrap_or(true);
        // Subclasses should implement actual training mode logic
    }

    /// Set the module in evaluation mode.
    ///
    /// No-op in the base (no training flag to clear); subclasses should
    /// override. PyTorch's `eval()` is equivalent to `train(False)`.
    fn eval(&mut self) {
        // Default implementation - subclasses should override
        // Subclasses should implement actual evaluation mode logic
    }

    /// Move module to specified device.
    ///
    /// No-op in the base (no parameters to move); subclasses holding tensors
    /// must override and relocate them.
    fn to(&mut self, device: PyDevice) -> PyResult<()> {
        // Default implementation - subclasses should override
        let _device = device;
        Ok(())
    }

    /// Zero out gradients of all parameters.
    ///
    /// No-op in the base (no parameters); subclasses should override.
    fn zero_grad(&mut self) {
        // Default implementation - subclasses should override
        // Subclasses should implement actual gradient zeroing
    }

    /// Make module callable (forward pass), matching `module(input)` usage.
    ///
    /// NOTE(review): this is a *static* Rust call to `Self::forward`, not a
    /// dynamic Python attribute lookup. If a Python-level subclass overrides
    /// `forward` in Python code, it is likely NOT dispatched to from here —
    /// verify against pyo3's subclassing docs; dispatching through
    /// `call_method1("forward", ...)` on the Python object may be needed.
    fn __call__(&self, input: &PyTensor) -> PyResult<PyTensor> {
        self.forward(input)
    }

    /// Forward pass - must be implemented by subclasses.
    ///
    /// The base raises Python `NotImplementedError` unconditionally.
    fn forward(&self, _input: &PyTensor) -> PyResult<PyTensor> {
        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
            "Subclasses must implement forward method",
        ))
    }

    /// String representation shown by Python's `repr()`.
    fn __repr__(&self) -> String {
        "Module()".to_string()
    }

    /// Apply a function to all submodules.
    ///
    /// No-op in the base (no submodule registry); subclasses that track
    /// children should override. The callable is currently ignored.
    fn apply(&mut self, _func: Py<PyAny>) -> PyResult<()> {
        // Default implementation - subclasses should override
        Ok(())
    }

    /// Get the state dict (parameters and buffers).
    ///
    /// Default implementation returns `named_parameters()` only — buffers
    /// are not included until a subclass overrides this.
    fn state_dict(&self) -> PyResult<HashMap<String, PyTensor>> {
        // Default implementation returns named parameters
        self.named_parameters()
    }

    /// Load state dict (parameters and buffers).
    ///
    /// Raises Python `NotImplementedError` in the base; asymmetric with
    /// `state_dict` (which succeeds with an empty dict) — intentional, so a
    /// load into a module with no storage fails loudly instead of silently
    /// dropping the provided tensors.
    fn load_state_dict(&mut self, _state_dict: HashMap<String, PyTensor>) -> PyResult<()> {
        // Default implementation - subclasses should override
        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
            "Subclasses must implement load_state_dict method",
        ))
    }

    /// Get number of parameters (sum of `numel()` over `parameters()`).
    ///
    /// Works automatically for any subclass that overrides `parameters`
    /// (0 for the base, which returns none).
    fn num_parameters(&self) -> PyResult<usize> {
        let params = self.parameters()?;
        Ok(params.iter().map(|p| p.numel()).sum())
    }

    /// Check if module is in training mode.
    ///
    /// Hard-coded `true` in the base since no flag is stored; subclasses
    /// should track the real state. NOTE(review): PyTorch exposes `training`
    /// as an *attribute*, not a method — a `#[getter]` may be wanted for
    /// drop-in compatibility; confirm the intended Python API.
    fn training(&self) -> bool {
        // Default implementation - subclasses should track this
        true
    }
}
107}