// torsh_python/nn/module.rs
1//! Base neural network module - Foundation for all PyTorch-compatible layers
2
3use crate::{device::PyDevice, error::PyResult, tensor::PyTensor};
4use pyo3::prelude::*;
5use pyo3::types::PyAny;
6use std::collections::HashMap;
7
8/// Base class for all neural network modules
9#[pyclass(name = "Module", subclass)]
10pub struct PyModule {
11    // This will be overridden by subclasses
12}
13
14#[pymethods]
15impl PyModule {
16    #[new]
17    pub fn new() -> Self {
18        Self {}
19    }
20
21    /// Get all parameters of the module
22    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
23        // Default implementation - subclasses should override
24        Ok(Vec::new())
25    }
26
27    /// Get all named parameters of the module
28    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
29        // Default implementation - subclasses should override
30        Ok(HashMap::new())
31    }
32
33    /// Set the module in training mode
34    fn train(&mut self, mode: Option<bool>) {
35        // Default implementation - subclasses should override
36        let _mode = mode.unwrap_or(true);
37        // Subclasses should implement actual training mode logic
38    }
39
40    /// Set the module in evaluation mode
41    fn eval(&mut self) {
42        // Default implementation - subclasses should override
43        // Subclasses should implement actual evaluation mode logic
44    }
45
46    /// Move module to specified device
47    fn to(&mut self, device: PyDevice) -> PyResult<()> {
48        // Default implementation - subclasses should override
49        let _device = device;
50        Ok(())
51    }
52
53    /// Zero out gradients of all parameters
54    fn zero_grad(&mut self) {
55        // Default implementation - subclasses should override
56        // Subclasses should implement actual gradient zeroing
57    }
58
59    /// Make module callable (forward pass)
60    fn __call__(&self, input: &PyTensor) -> PyResult<PyTensor> {
61        self.forward(input)
62    }
63
64    /// Forward pass - must be implemented by subclasses
65    fn forward(&self, _input: &PyTensor) -> PyResult<PyTensor> {
66        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
67            "Subclasses must implement forward method",
68        ))
69    }
70
71    /// String representation
72    fn __repr__(&self) -> String {
73        "Module()".to_string()
74    }
75
76    /// Apply a function to all submodules
77    fn apply(&mut self, _func: Py<PyAny>) -> PyResult<()> {
78        // Default implementation - subclasses should override
79        Ok(())
80    }
81
82    /// Get the state dict (parameters and buffers)
83    fn state_dict(&self) -> PyResult<HashMap<String, PyTensor>> {
84        // Default implementation returns named parameters
85        self.named_parameters()
86    }
87
88    /// Load state dict (parameters and buffers)
89    fn load_state_dict(&mut self, _state_dict: HashMap<String, PyTensor>) -> PyResult<()> {
90        // Default implementation - subclasses should override
91        Err(PyErr::new::<pyo3::exceptions::PyNotImplementedError, _>(
92            "Subclasses must implement load_state_dict method",
93        ))
94    }
95
96    /// Get number of parameters
97    fn num_parameters(&self) -> PyResult<usize> {
98        let params = self.parameters()?;
99        Ok(params.iter().map(|p| p.numel()).sum())
100    }
101
102    /// Check if module is in training mode
103    fn training(&self) -> bool {
104        // Default implementation - subclasses should track this
105        true
106    }
107}