torsh_python/nn/
activation.rs1use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6
7#[pyclass(name = "ReLU", extends = PyModule)]
9pub struct PyReLU {
10 inplace: bool,
11 training: bool,
12}
13
14#[pymethods]
15impl PyReLU {
16 #[new]
17 fn new(inplace: Option<bool>) -> (Self, PyModule) {
18 (
19 Self {
20 inplace: inplace.unwrap_or(false),
21 training: true,
22 },
23 PyModule::new(),
24 )
25 }
26
27 fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
29 let result = py_result!(input.tensor.relu())?;
30 Ok(PyTensor { tensor: result })
31 }
32
33 fn __repr__(&self) -> String {
35 if self.inplace {
36 "ReLU(inplace=True)".to_string()
37 } else {
38 "ReLU()".to_string()
39 }
40 }
41
42 fn train(&mut self, mode: Option<bool>) {
44 self.training = mode.unwrap_or(true);
45 }
46
47 fn eval(&mut self) {
49 self.training = false;
50 }
51
52 fn training(&self) -> bool {
54 self.training
55 }
56}
57
58#[pyclass(name = "Sigmoid", extends = PyModule)]
60pub struct PySigmoid {
61 training: bool,
62}
63
64#[pymethods]
65impl PySigmoid {
66 #[new]
67 fn new() -> (Self, PyModule) {
68 (Self { training: true }, PyModule::new())
69 }
70
71 fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
73 let result = py_result!(input.tensor.sigmoid())?;
74 Ok(PyTensor { tensor: result })
75 }
76
77 fn __repr__(&self) -> String {
79 "Sigmoid()".to_string()
80 }
81
82 fn train(&mut self, mode: Option<bool>) {
84 self.training = mode.unwrap_or(true);
85 }
86
87 fn eval(&mut self) {
89 self.training = false;
90 }
91
92 fn training(&self) -> bool {
94 self.training
95 }
96}
97
98#[pyclass(name = "Tanh", extends = PyModule)]
100pub struct PyTanh {
101 training: bool,
102}
103
104#[pymethods]
105impl PyTanh {
106 #[new]
107 fn new() -> (Self, PyModule) {
108 (Self { training: true }, PyModule::new())
109 }
110
111 fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
113 let result = py_result!(input.tensor.tanh())?;
114 Ok(PyTensor { tensor: result })
115 }
116
117 fn __repr__(&self) -> String {
119 "Tanh()".to_string()
120 }
121
122 fn train(&mut self, mode: Option<bool>) {
124 self.training = mode.unwrap_or(true);
125 }
126
127 fn eval(&mut self) {
129 self.training = false;
130 }
131
132 fn training(&self) -> bool {
134 self.training
135 }
136}