1use super::module::PyModule;
4use crate::{error::PyResult, py_result, tensor::PyTensor};
5use pyo3::prelude::*;
6use pyo3::types::PyAny;
7use std::collections::HashMap;
8use torsh_tensor::Tensor;
9
10#[pyclass(name = "Conv2d", extends = PyModule)]
12pub struct PyConv2d {
13 weight: Tensor<f32>,
14 bias: Option<Tensor<f32>>,
15 in_channels: usize,
16 out_channels: usize,
17 kernel_size: (usize, usize),
18 stride: (usize, usize),
19 padding: (usize, usize),
20 dilation: (usize, usize),
21 groups: usize,
22 has_bias: bool,
23 training: bool,
24}
25
26#[pymethods]
27impl PyConv2d {
28 #[new]
29 fn new(
30 in_channels: usize,
31 out_channels: usize,
32 kernel_size: Py<PyAny>,
33 stride: Option<Py<PyAny>>,
34 padding: Option<Py<PyAny>>,
35 dilation: Option<Py<PyAny>>,
36 groups: Option<usize>,
37 bias: Option<bool>,
38 ) -> PyResult<(Self, PyModule)> {
39 let has_bias = bias.unwrap_or(true);
40 let groups = groups.unwrap_or(1);
41
42 let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
44 if let Ok(size) = kernel_size.extract::<usize>(py) {
45 Ok((size, size))
46 } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
47 Ok(tuple)
48 } else {
49 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
50 "kernel_size must be an integer or tuple of integers",
51 ))
52 }
53 })?;
54
55 let stride = if let Some(stride_obj) = stride {
57 Python::attach(|py| -> PyResult<(usize, usize)> {
58 if let Ok(stride) = stride_obj.extract::<usize>(py) {
59 Ok((stride, stride))
60 } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
61 Ok(tuple)
62 } else {
63 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
64 "stride must be an integer or tuple of integers",
65 ))
66 }
67 })?
68 } else {
69 (1, 1)
70 };
71
72 let padding = if let Some(padding_obj) = padding {
74 Python::attach(|py| -> PyResult<(usize, usize)> {
75 if let Ok(padding) = padding_obj.extract::<usize>(py) {
76 Ok((padding, padding))
77 } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
78 Ok(tuple)
79 } else {
80 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
81 "padding must be an integer or tuple of integers",
82 ))
83 }
84 })?
85 } else {
86 (0, 0)
87 };
88
89 let dilation = if let Some(dilation_obj) = dilation {
91 Python::attach(|py| -> PyResult<(usize, usize)> {
92 if let Ok(dilation) = dilation_obj.extract::<usize>(py) {
93 Ok((dilation, dilation))
94 } else if let Ok(tuple) = dilation_obj.extract::<(usize, usize)>(py) {
95 Ok(tuple)
96 } else {
97 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
98 "dilation must be an integer or tuple of integers",
99 ))
100 }
101 })?
102 } else {
103 (1, 1)
104 };
105
106 let weight_shape = vec![
108 out_channels,
109 in_channels / groups,
110 kernel_size.0,
111 kernel_size.1,
112 ];
113 let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);
114
115 let bias = if has_bias {
117 let bias_shape = vec![out_channels];
118 Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
119 } else {
120 None
121 };
122
123 Ok((
124 Self {
125 weight,
126 bias,
127 in_channels,
128 out_channels,
129 kernel_size,
130 stride,
131 padding,
132 dilation,
133 groups,
134 has_bias,
135 training: true,
136 },
137 PyModule::new(),
138 ))
139 }
140
141 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
143 let result = py_result!(input.tensor.conv2d(
145 &self.weight,
146 self.bias.as_ref(),
147 self.stride,
148 self.padding,
149 self.dilation,
150 self.groups
151 ))?;
152
153 Ok(PyTensor { tensor: result })
154 }
155
156 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
158 let mut params = vec![PyTensor {
159 tensor: self.weight.clone(),
160 }];
161 if let Some(ref bias) = self.bias {
162 params.push(PyTensor {
163 tensor: bias.clone(),
164 });
165 }
166 Ok(params)
167 }
168
169 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
171 let mut params = HashMap::new();
172 params.insert(
173 "weight".to_string(),
174 PyTensor {
175 tensor: self.weight.clone(),
176 },
177 );
178 if let Some(ref bias) = self.bias {
179 params.insert(
180 "bias".to_string(),
181 PyTensor {
182 tensor: bias.clone(),
183 },
184 );
185 }
186 Ok(params)
187 }
188
189 fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
191 self.training = mode.unwrap_or(true);
192 Ok(())
193 }
194
195 fn eval(&mut self) -> PyResult<()> {
197 self.training = false;
198 Ok(())
199 }
200
201 fn extra_repr(&self) -> String {
203 let bias_str = if self.has_bias {
204 "bias=True"
205 } else {
206 "bias=False"
207 };
208 format!(
209 "{}, {}, kernel_size={:?}, stride={:?}, padding={:?}, dilation={:?}, groups={}, {}",
210 self.in_channels,
211 self.out_channels,
212 self.kernel_size,
213 self.stride,
214 self.padding,
215 self.dilation,
216 self.groups,
217 bias_str
218 )
219 }
220
221 fn __repr__(&self) -> String {
223 format!("Conv2d({})", self.extra_repr())
224 }
225}
226
227#[pyclass(name = "Conv1d", extends = PyModule)]
229pub struct PyConv1d {
230 weight: Tensor<f32>,
231 bias: Option<Tensor<f32>>,
232 in_channels: usize,
233 out_channels: usize,
234 kernel_size: usize,
235 stride: usize,
236 padding: usize,
237 dilation: usize,
238 groups: usize,
239 has_bias: bool,
240 training: bool,
241}
242
243#[pymethods]
244impl PyConv1d {
245 #[new]
246 fn new(
247 in_channels: usize,
248 out_channels: usize,
249 kernel_size: usize,
250 stride: Option<usize>,
251 padding: Option<usize>,
252 dilation: Option<usize>,
253 groups: Option<usize>,
254 bias: Option<bool>,
255 ) -> PyResult<(Self, PyModule)> {
256 let has_bias = bias.unwrap_or(true);
257 let stride = stride.unwrap_or(1);
258 let padding = padding.unwrap_or(0);
259 let dilation = dilation.unwrap_or(1);
260 let groups = groups.unwrap_or(1);
261
262 let weight_shape = vec![out_channels, in_channels / groups, kernel_size];
264 let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);
265
266 let bias = if has_bias {
268 let bias_shape = vec![out_channels];
269 Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
270 } else {
271 None
272 };
273
274 Ok((
275 Self {
276 weight,
277 bias,
278 in_channels,
279 out_channels,
280 kernel_size,
281 stride,
282 padding,
283 dilation,
284 groups,
285 has_bias,
286 training: true,
287 },
288 PyModule::new(),
289 ))
290 }
291
292 fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
294 let result = py_result!(input.tensor.conv1d(
296 &self.weight,
297 self.bias.as_ref(),
298 self.stride,
299 self.padding,
300 self.dilation,
301 self.groups
302 ))?;
303
304 Ok(PyTensor { tensor: result })
305 }
306
307 fn parameters(&self) -> PyResult<Vec<PyTensor>> {
309 let mut params = vec![PyTensor {
310 tensor: self.weight.clone(),
311 }];
312 if let Some(ref bias) = self.bias {
313 params.push(PyTensor {
314 tensor: bias.clone(),
315 });
316 }
317 Ok(params)
318 }
319
320 fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
322 let mut params = HashMap::new();
323 params.insert(
324 "weight".to_string(),
325 PyTensor {
326 tensor: self.weight.clone(),
327 },
328 );
329 if let Some(ref bias) = self.bias {
330 params.insert(
331 "bias".to_string(),
332 PyTensor {
333 tensor: bias.clone(),
334 },
335 );
336 }
337 Ok(params)
338 }
339
340 fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
342 self.training = mode.unwrap_or(true);
343 Ok(())
344 }
345
346 fn eval(&mut self) -> PyResult<()> {
348 self.training = false;
349 Ok(())
350 }
351
352 fn __repr__(&self) -> String {
354 let bias_str = if self.has_bias {
355 "bias=True"
356 } else {
357 "bias=False"
358 };
359 format!(
360 "Conv1d({}, {}, kernel_size={}, stride={}, padding={}, dilation={}, groups={}, {})",
361 self.in_channels,
362 self.out_channels,
363 self.kernel_size,
364 self.stride,
365 self.padding,
366 self.dilation,
367 self.groups,
368 bias_str
369 )
370 }
371}