smelte_rs/nn/layers/
linear.rs

1use crate::traits::{Tensor, TensorOps};
2use crate::SmeltError;
3
4/// Linear layer, applies matmul(x, W.T) + b
5#[derive(Clone)]
6pub struct Linear<T: Tensor> {
7    weight: T,
8    bias: T,
9}
10
11impl<T: Tensor + TensorOps<T>> Linear<T> {
12    /// Linear layer creation
13    pub fn new(weight: T, bias: T) -> Self {
14        Self { weight, bias }
15    }
16
17    /// Forward pass
18    pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
19        T::matmul_t(tensor, &self.weight, out)?;
20        T::add(&self.bias, out)?;
21        Ok(())
22    }
23
24    /// TODO
25    pub fn weight(&self) -> &T {
26        &self.weight
27    }
28
29    /// TODO
30    pub fn bias(&self) -> &T {
31        &self.bias
32    }
33}
34
35/// Linear layer, applies matmul(x, W) + b (also named conv1d sometimes)
36#[derive(Clone)]
37pub struct LinearT<T: Tensor> {
38    weight: T,
39    bias: T,
40}
41
42impl<T: Tensor + TensorOps<T>> LinearT<T> {
43    /// LinearT layer creation
44    pub fn new(weight: T, bias: T) -> Self {
45        Self { weight, bias }
46    }
47
48    /// Forward pass
49    pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
50        T::matmul_t(tensor, &self.weight, out)?;
51        T::add(&self.bias, out)?;
52        Ok(())
53    }
54}
55
56/// UnbiasedLinear layer, applies matmul(x, W.T)
57#[derive(Clone)]
58pub struct UnbiasedLinear<T: Tensor> {
59    weight: T,
60}
61
62impl<T: Tensor + TensorOps<T>> UnbiasedLinear<T> {
63    /// UnbiasedLinear layer creation
64    pub fn new(weight: T) -> Self {
65        Self { weight }
66    }
67
68    /// Forward pass
69    pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
70        T::matmul_t(tensor, &self.weight, out)?;
71        Ok(())
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu::f32::Tensor;

    /// Smoke test: `Linear::forward` succeeds on all-zero tensors with an
    /// input of shape `[2, 2]`, a weight of shape `[3, 2]`, a bias of shape
    /// `[3]` and an output buffer of shape `[2, 3]`.
    #[test]
    fn test_linear() {
        let input = Tensor::zeros(vec![2, 2]);
        let mut output = Tensor::zeros(vec![2, 3]);
        let layer = Linear::new(Tensor::zeros(vec![3, 2]), Tensor::zeros(vec![3]));

        layer.forward(&input, &mut output).unwrap();
    }
}