Skip to main content

torsh_python/nn/
activation.rs

1//! Activation functions for neural networks
2
3use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6
7/// ReLU activation function
8#[pyclass(name = "ReLU", extends = PyModule)]
9pub struct PyReLU {
10    inplace: bool,
11    training: bool,
12}
13
14#[pymethods]
15impl PyReLU {
16    #[new]
17    fn new(inplace: Option<bool>) -> (Self, PyModule) {
18        (
19            Self {
20                inplace: inplace.unwrap_or(false),
21                training: true,
22            },
23            PyModule::new(),
24        )
25    }
26
27    /// Forward pass through ReLU
28    fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
29        let result = py_result!(input.tensor.relu())?;
30        Ok(PyTensor { tensor: result })
31    }
32
33    /// String representation
34    fn __repr__(&self) -> String {
35        if self.inplace {
36            "ReLU(inplace=True)".to_string()
37        } else {
38            "ReLU()".to_string()
39        }
40    }
41
42    /// Set training mode
43    fn train(&mut self, mode: Option<bool>) {
44        self.training = mode.unwrap_or(true);
45    }
46
47    /// Set evaluation mode
48    fn eval(&mut self) {
49        self.training = false;
50    }
51
52    /// Check if module is in training mode
53    fn training(&self) -> bool {
54        self.training
55    }
56}
57
58/// Sigmoid activation function
59#[pyclass(name = "Sigmoid", extends = PyModule)]
60pub struct PySigmoid {
61    training: bool,
62}
63
64#[pymethods]
65impl PySigmoid {
66    #[new]
67    fn new() -> (Self, PyModule) {
68        (Self { training: true }, PyModule::new())
69    }
70
71    /// Forward pass through Sigmoid
72    fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
73        let result = py_result!(input.tensor.sigmoid())?;
74        Ok(PyTensor { tensor: result })
75    }
76
77    /// String representation
78    fn __repr__(&self) -> String {
79        "Sigmoid()".to_string()
80    }
81
82    /// Set training mode
83    fn train(&mut self, mode: Option<bool>) {
84        self.training = mode.unwrap_or(true);
85    }
86
87    /// Set evaluation mode
88    fn eval(&mut self) {
89        self.training = false;
90    }
91
92    /// Check if module is in training mode
93    fn training(&self) -> bool {
94        self.training
95    }
96}
97
98/// Tanh activation function
99#[pyclass(name = "Tanh", extends = PyModule)]
100pub struct PyTanh {
101    training: bool,
102}
103
104#[pymethods]
105impl PyTanh {
106    #[new]
107    fn new() -> (Self, PyModule) {
108        (Self { training: true }, PyModule::new())
109    }
110
111    /// Forward pass through Tanh
112    fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
113        let result = py_result!(input.tensor.tanh())?;
114        Ok(PyTensor { tensor: result })
115    }
116
117    /// String representation
118    fn __repr__(&self) -> String {
119        "Tanh()".to_string()
120    }
121
122    /// Set training mode
123    fn train(&mut self, mode: Option<bool>) {
124        self.training = mode.unwrap_or(true);
125    }
126
127    /// Set evaluation mode
128    fn eval(&mut self) {
129        self.training = false;
130    }
131
132    /// Check if module is in training mode
133    fn training(&self) -> bool {
134        self.training
135    }
136}