torsh_python/nn/conv.rs

//! Convolutional neural network layers

use super::module::PyModule;
use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;
use torsh_tensor::Tensor;

/// 2D Convolutional layer
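///
/// `kernel_size`, `stride`, `padding`, and `dilation` accept either a single integer
/// (applied to both spatial dimensions) or an `(h, w)` tuple; see the parsing in `new` below.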
#[pyclass(name = "Conv2d", extends = PyModule)]
pub struct PyConv2d {
    weight: Tensor<f32>,
    bias: Option<Tensor<f32>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
    groups: usize,
    has_bias: bool,
    training: bool,
}

#[pymethods]
impl PyConv2d {
    #[new]
    fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: Py<PyAny>,
        stride: Option<Py<PyAny>>,
        padding: Option<Py<PyAny>>,
        dilation: Option<Py<PyAny>>,
        groups: Option<usize>,
        bias: Option<bool>,
    ) -> PyResult<(Self, PyModule)> {
        let has_bias = bias.unwrap_or(true);
        let groups = groups.unwrap_or(1);

        // Parse kernel size
        let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
            if let Ok(size) = kernel_size.extract::<usize>(py) {
                Ok((size, size))
            } else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
                Ok(tuple)
            } else {
                Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                    "kernel_size must be an integer or tuple of integers",
                ))
            }
        })?;

        // Parse stride (default to 1)
        let stride = if let Some(stride_obj) = stride {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(stride) = stride_obj.extract::<usize>(py) {
                    Ok((stride, stride))
                } else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "stride must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (1, 1)
        };

        // Parse padding (default to 0)
        let padding = if let Some(padding_obj) = padding {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(padding) = padding_obj.extract::<usize>(py) {
                    Ok((padding, padding))
                } else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "padding must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (0, 0)
        };

        // Parse dilation (default to 1)
        let dilation = if let Some(dilation_obj) = dilation {
            Python::attach(|py| -> PyResult<(usize, usize)> {
                if let Ok(dilation) = dilation_obj.extract::<usize>(py) {
                    Ok((dilation, dilation))
                } else if let Ok(tuple) = dilation_obj.extract::<(usize, usize)>(py) {
                    Ok(tuple)
                } else {
                    Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
                        "dilation must be an integer or tuple of integers",
                    ))
                }
            })?
        } else {
            (1, 1)
        };

        // Initialize weight with a standard-normal draw (Kaiming-style scaling sketched below)
        let weight_shape = vec![
            out_channels,
            in_channels / groups,
            kernel_size.0,
            kernel_size.1,
        ];
        let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);
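        // NOTE: `randn` samples unit-variance values. A Kaiming-style init would scale the
        // draw by sqrt(1 / fan_in) so the weight variance becomes 1 / fan_in. Minimal sketch,
        // assuming Tensor exposes a scalar multiply (hypothetical here, not part of this file):
        //   let fan_in = (in_channels / groups) * kernel_size.0 * kernel_size.1;
        //   let scale = (1.0 / fan_in as f32).sqrt();
        //   let weight = weight.mul_scalar(scale)?;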

        // Initialize bias if needed
        let bias = if has_bias {
            let bias_shape = vec![out_channels];
            Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
        } else {
            None
        };

        Ok((
            Self {
                weight,
                bias,
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation,
                groups,
                has_bias,
                training: true,
            },
            PyModule::new(),
        ))
    }

    /// Forward pass through the convolutional layer
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Perform 2D convolution
        let result = py_result!(input.tensor.conv2d(
            &self.weight,
            self.bias.as_ref(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        ))?;

        Ok(PyTensor { tensor: result })
    }
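
    // The spatial output size of `forward` above follows standard convolution arithmetic:
    //   H_out = (H_in + 2*padding.0 - dilation.0*(kernel_size.0 - 1) - 1) / stride.0 + 1
    //   W_out = (W_in + 2*padding.1 - dilation.1*(kernel_size.1 - 1) - 1) / stride.1 + 1
    // Worked example: a (N, in_channels, 32, 32) input with kernel_size=3, stride=1,
    // padding=1, dilation=1 yields a (N, out_channels, 32, 32) output.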

    /// Get layer parameters
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        let mut params = vec![PyTensor {
            tensor: self.weight.clone(),
        }];
        if let Some(ref bias) = self.bias {
            params.push(PyTensor {
                tensor: bias.clone(),
            });
        }
        Ok(params)
    }

    /// Get named parameters
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        let mut params = HashMap::new();
        params.insert(
            "weight".to_string(),
            PyTensor {
                tensor: self.weight.clone(),
            },
        );
        if let Some(ref bias) = self.bias {
            params.insert(
                "bias".to_string(),
                PyTensor {
                    tensor: bias.clone(),
                },
            );
        }
        Ok(params)
    }

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
        self.training = mode.unwrap_or(true);
        Ok(())
    }

    /// Set evaluation mode
    fn eval(&mut self) -> PyResult<()> {
        self.training = false;
        Ok(())
    }

    /// Get layer info string
    fn extra_repr(&self) -> String {
        let bias_str = if self.has_bias {
            "bias=True"
        } else {
            "bias=False"
        };
        format!(
            "{}, {}, kernel_size={:?}, stride={:?}, padding={:?}, dilation={:?}, groups={}, {}",
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            bias_str
        )
    }

    /// String representation
    fn __repr__(&self) -> String {
        format!("Conv2d({})", self.extra_repr())
    }
}

/// 1D Convolutional layer
#[pyclass(name = "Conv1d", extends = PyModule)]
pub struct PyConv1d {
    weight: Tensor<f32>,
    bias: Option<Tensor<f32>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    groups: usize,
    has_bias: bool,
    training: bool,
}

#[pymethods]
impl PyConv1d {
    #[new]
    fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: Option<usize>,
        padding: Option<usize>,
        dilation: Option<usize>,
        groups: Option<usize>,
        bias: Option<bool>,
    ) -> PyResult<(Self, PyModule)> {
        let has_bias = bias.unwrap_or(true);
        let stride = stride.unwrap_or(1);
        let padding = padding.unwrap_or(0);
        let dilation = dilation.unwrap_or(1);
        let groups = groups.unwrap_or(1);

        // Initialize weight with a standard-normal draw (see the Kaiming note in PyConv2d::new;
        // fan_in here is (in_channels / groups) * kernel_size)
        let weight_shape = vec![out_channels, in_channels / groups, kernel_size];
        let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);

        // Initialize bias if needed
        let bias = if has_bias {
            let bias_shape = vec![out_channels];
            Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
        } else {
            None
        };

        Ok((
            Self {
                weight,
                bias,
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation,
                groups,
                has_bias,
                training: true,
            },
            PyModule::new(),
        ))
    }

    /// Forward pass through the 1D convolutional layer
    fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
        // Perform 1D convolution
        let result = py_result!(input.tensor.conv1d(
            &self.weight,
            self.bias.as_ref(),
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        ))?;

        Ok(PyTensor { tensor: result })
    }
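
    // Output length of the 1D convolution above:
    //   L_out = (L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1
    // Worked example: L_in=100, kernel_size=3, stride=1, padding=0, dilation=1 gives L_out=98.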

    /// Get layer parameters
    fn parameters(&self) -> PyResult<Vec<PyTensor>> {
        let mut params = vec![PyTensor {
            tensor: self.weight.clone(),
        }];
        if let Some(ref bias) = self.bias {
            params.push(PyTensor {
                tensor: bias.clone(),
            });
        }
        Ok(params)
    }

    /// Get named parameters
    fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
        let mut params = HashMap::new();
        params.insert(
            "weight".to_string(),
            PyTensor {
                tensor: self.weight.clone(),
            },
        );
        if let Some(ref bias) = self.bias {
            params.insert(
                "bias".to_string(),
                PyTensor {
                    tensor: bias.clone(),
                },
            );
        }
        Ok(params)
    }

    /// Set training mode
    fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
        self.training = mode.unwrap_or(true);
        Ok(())
    }

    /// Set evaluation mode
    fn eval(&mut self) -> PyResult<()> {
        self.training = false;
        Ok(())
    }

    /// String representation
    fn __repr__(&self) -> String {
        let bias_str = if self.has_bias {
            "bias=True"
        } else {
            "bias=False"
        };
        format!(
            "Conv1d({}, {}, kernel_size={}, stride={}, padding={}, dilation={}, groups={}, {})",
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            bias_str
        )
    }
}
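
// Registration sketch (hypothetical; the actual Python module wiring lives elsewhere in
// torsh_python). With a Bound<'_, pyo3::types::PyModule> named `m`, the classes above
// would be exposed as torsh.nn.Conv2d / torsh.nn.Conv1d via:
//   m.add_class::<PyConv2d>()?;
//   m.add_class::<PyConv1d>()?;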