torsh_python/nn/linear.rs
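
//! Python bindings for a fully connected (linear) layer, `y = x @ W (+ b)`.
//!
//! Illustrative Python-side usage (the import path and tensor construction are
//! assumptions; they depend on how the torsh_python extension module is laid out):
//!
//! ```python
//! # layer = Linear(128, 64, bias=True)   # hypothetical: imported from the torsh nn module
//! # y = layer.forward(x)                 # x: Tensor of shape [batch, 128] -> y: [batch, 64]
//! # params = layer.named_parameters()    # {"weight": Tensor, "bias": Tensor}
//! ```
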
use super::module::PyModule;
use crate::{device::PyDevice, error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use std::collections::HashMap;
use torsh_tensor::Tensor;

/// A fully connected (linear) layer with a learnable weight and an optional
/// learnable bias.
#[pyclass(name = "Linear", extends = PyModule)]
pub struct PyLinear {
    weight: Tensor<f32>,
    bias: Option<Tensor<f32>>,
    in_features: usize,
    out_features: usize,
    has_bias: bool,
    training: bool,
}

#[pymethods]
impl PyLinear {
    #[new]
    fn new(
        in_features: usize,
        out_features: usize,
        bias: Option<bool>,
    ) -> PyResult<(Self, PyModule)> {
        let has_bias = bias.unwrap_or(true);

        // Store the weight as [in_features, out_features] so that `forward` can compute
        // `input.matmul(&weight)` directly, mapping [batch, in_features] -> [batch, out_features]
        // without an extra transpose.
        let weight_shape = vec![in_features, out_features];
        let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);

        // The bias is a single [out_features] vector, broadcast over the batch dimension.
        let bias = if has_bias {
            let bias_shape = vec![out_features];
            Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
        } else {
            None
        };

        Ok((
            Self {
                weight,
                bias,
                in_features,
                out_features,
                has_bias,
                training: true,
            },
            PyModule::new(),
        ))
    }

    /// Apply the layer: `input` is expected to be of shape [batch, in_features].
    fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
        // [batch, in_features] x [in_features, out_features] -> [batch, out_features]
        let result = py_result!(input.tensor.matmul(&self.weight))?;

        // Broadcast-add the bias over the batch dimension when it is present.
        let result = if let Some(ref bias) = self.bias {
            py_result!(result.add(bias))?
        } else {
            result
        };

        Ok(PyTensor { tensor: result })
    }

    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        let mut params = Vec::new();

        params.push(PyTensor {
            tensor: self.weight.clone(),
        });

        if let Some(ref bias) = self.bias {
            params.push(PyTensor {
                tensor: bias.clone(),
            });
        }

        Ok(params)
    }

    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        let mut named_params = HashMap::new();

        named_params.insert(
            "weight".to_string(),
            PyTensor {
                tensor: self.weight.clone(),
            },
        );

        if let Some(ref bias) = self.bias {
            named_params.insert(
                "bias".to_string(),
                PyTensor {
                    tensor: bias.clone(),
                },
            );
        }

        Ok(named_params)
    }

    /// Switch to training mode; `train(False)` is equivalent to `eval()`.
    fn train(&mut self, mode: Option<bool>) {
        self.training = mode.unwrap_or(true);
    }

    /// Switch to evaluation mode.
    fn eval(&mut self) {
        self.training = false;
    }

    fn to(&mut self, device: PyDevice) -> PyResult<()> {
        self.weight = py_result!(self.weight.clone().to(device.device))?;

        if let Some(ref bias) = self.bias {
            self.bias = Some(py_result!(bias.clone().to(device.device))?);
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Clearing gradients is best-effort; errors are deliberately discarded.
        let _ = self.weight.zero_grad();
        if let Some(ref mut bias) = self.bias {
            let _ = bias.zero_grad();
        }
    }

    fn __repr__(&self) -> String {
        format!(
            "Linear(in_features={}, out_features={}, bias={})",
            self.in_features, self.out_features, self.has_bias
        )
    }

    #[getter]
    fn in_features(&self) -> usize {
        self.in_features
    }

    #[getter]
    fn out_features(&self) -> usize {
        self.out_features
    }

    /// Whether the layer was constructed with a bias term (a flag, not the bias tensor).
    #[getter]
    fn bias(&self) -> bool {
        self.has_bias
    }

    #[getter]
    fn training(&self) -> bool {
        self.training
    }

    #[getter]
    fn weight(&self) -> PyResult<PyTensor> {
        Ok(PyTensor {
            tensor: self.weight.clone(),
        })
    }

    fn load_state_dict(&mut self, state_dict: HashMap<String, PyTensor>) -> PyResult<()> {
        // Entries missing from the state dict leave the current parameters unchanged.
        if let Some(weight_tensor) = state_dict.get("weight") {
            self.weight = weight_tensor.tensor.clone();
        }

        if self.has_bias {
            if let Some(bias_tensor) = state_dict.get("bias") {
                self.bias = Some(bias_tensor.tensor.clone());
            }
        }

        Ok(())
    }
}
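
// Note: exposing this class to Python also requires registering it with the parent
// extension module, typically via `m.add_class::<PyLinear>()?` in the crate's
// `#[pymodule]` init function. That wiring lives elsewhere in torsh_python; the exact
// module layout is an assumption.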