1use svod_dtype::DType;
2
3use crate::Tensor;
4use crate::nn::Layer;
5
6type Result<T> = crate::Result<T>;
7
/// Convolution layer holding a weight tensor, an optional bias, and the
/// stride/padding that [`Layer::forward`] passes to the convolution builder.
pub struct Conv1d {
    /// Convolution kernel. `Conv1d::with_dims` builds it with shape
    /// `[out_channels, in_channels, kernel]`; `Conv1d::new` stores whatever it is given.
    pub weight: Tensor,
    /// Optional bias, forwarded to the convolution builder via `maybe_bias`.
    pub bias: Option<Tensor>,
    /// Stride along the single spatial axis; both constructors set it to 1.
    pub stride: usize,
    /// Padding for the single spatial axis; both constructors set it to `(0, 0)`.
    /// NOTE(review): presumably `(before, after)` — TODO confirm against the conv builder.
    pub padding: (isize, isize),
}
18
19impl Conv1d {
20 pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
22 Self { weight, bias, stride: 1, padding: (0, 0) }
23 }
24
25 pub fn with_dims(in_channels: usize, out_channels: usize, kernel: usize, dtype: DType) -> Self {
27 let weight_data: Vec<f32> =
28 (0..in_channels * out_channels * kernel).map(|i| ((i as f32) * 0.1).sin() * 0.1).collect();
29 let weight = Tensor::from_slice(&weight_data)
30 .try_reshape([out_channels as isize, in_channels as isize, kernel as isize])
31 .expect("conv1d weight reshape failed");
32 let bias = Tensor::full(&[out_channels], 0.0, dtype).expect("conv1d bias creation failed");
33 Self { weight, bias: Some(bias), stride: 1, padding: (0, 0) }
34 }
35
36 pub fn with_stride(mut self, stride: usize) -> Self {
37 self.stride = stride;
38 self
39 }
40
41 pub fn with_padding(mut self, padding: (isize, isize)) -> Self {
42 self.padding = padding;
43 self
44 }
45}
46
47impl Layer for Conv1d {
48 fn forward(&self, x: &Tensor) -> Result<Tensor> {
49 x.conv2d()
50 .weight(&self.weight)
51 .maybe_bias(self.bias.as_ref())
52 .stride(&[self.stride])
53 .padding(&[self.padding])
54 .call()
55 }
56}