Skip to main content

svod_tensor/nn/
conv1d.rs

1use svod_dtype::DType;
2
3use crate::Tensor;
4use crate::nn::Layer;
5
// Local shorthand for the crate-wide `crate::Result`, so signatures in this module can write `Result<T>`.
type Result<T> = crate::Result<T>;
7
/// 1D convolution: `y = conv1d(x, weight) + bias`.
///
/// Weight shape: `[out_channels, in_channels, kernel]`, optional bias shape: `[out_channels]`.
/// `stride` and `padding` are stored on the module so [`Layer::forward`] stays parameter-free.
pub struct Conv1d {
    /// Convolution kernel, expected shape `[out_channels, in_channels, kernel]`.
    pub weight: Tensor,
    /// Optional bias, expected shape `[out_channels]`; `None` is passed through to the
    /// convolution builder as "no bias".
    pub bias: Option<Tensor>,
    /// Stride forwarded (as the single per-dimension entry) to the convolution builder in
    /// `forward`; the constructors set it to `1`.
    pub stride: usize,
    /// Padding pair forwarded (as the single per-dimension entry) to the convolution builder in
    /// `forward`; the constructors set it to `(0, 0)`.
    /// NOTE(review): presumably `(before, after)` along the length axis — confirm against `conv2d`.
    pub padding: (isize, isize),
}
18
19impl Conv1d {
20    /// Create a Conv1d from existing weight (and optional bias) tensors.
21    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
22        Self { weight, bias, stride: 1, padding: (0, 0) }
23    }
24
25    /// Create a Conv1d with deterministic `sin()` initialization, zero bias.
26    pub fn with_dims(in_channels: usize, out_channels: usize, kernel: usize, dtype: DType) -> Self {
27        let weight_data: Vec<f32> =
28            (0..in_channels * out_channels * kernel).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
29        let weight = Tensor::from_slice(&weight_data)
30            .try_reshape([out_channels as isize, in_channels as isize, kernel as isize])
31            .expect("conv1d weight reshape failed");
32        let bias = Tensor::full(&[out_channels], 0.0, dtype).expect("conv1d bias creation failed");
33        Self { weight, bias: Some(bias), stride: 1, padding: (0, 0) }
34    }
35
36    pub fn with_stride(mut self, stride: usize) -> Self {
37        self.stride = stride;
38        self
39    }
40
41    pub fn with_padding(mut self, padding: (isize, isize)) -> Self {
42        self.padding = padding;
43        self
44    }
45}
46
47impl Layer for Conv1d {
48    fn forward(&self, x: &Tensor) -> Result<Tensor> {
49        x.conv2d()
50            .weight(&self.weight)
51            .maybe_bias(self.bias.as_ref())
52            .stride(&[self.stride])
53            .padding(&[self.padding])
54            .call()
55    }
56}