smelte_rs/nn/layers/
linear.rs1use crate::traits::{Tensor, TensorOps};
2use crate::SmeltError;
3
4#[derive(Clone)]
6pub struct Linear<T: Tensor> {
7 weight: T,
8 bias: T,
9}
10
11impl<T: Tensor + TensorOps<T>> Linear<T> {
12 pub fn new(weight: T, bias: T) -> Self {
14 Self { weight, bias }
15 }
16
17 pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
19 T::matmul_t(tensor, &self.weight, out)?;
20 T::add(&self.bias, out)?;
21 Ok(())
22 }
23
24 pub fn weight(&self) -> &T {
26 &self.weight
27 }
28
29 pub fn bias(&self) -> &T {
31 &self.bias
32 }
33}
34
35#[derive(Clone)]
37pub struct LinearT<T: Tensor> {
38 weight: T,
39 bias: T,
40}
41
42impl<T: Tensor + TensorOps<T>> LinearT<T> {
43 pub fn new(weight: T, bias: T) -> Self {
45 Self { weight, bias }
46 }
47
48 pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
50 T::matmul_t(tensor, &self.weight, out)?;
51 T::add(&self.bias, out)?;
52 Ok(())
53 }
54}
55
56#[derive(Clone)]
58pub struct UnbiasedLinear<T: Tensor> {
59 weight: T,
60}
61
62impl<T: Tensor + TensorOps<T>> UnbiasedLinear<T> {
63 pub fn new(weight: T) -> Self {
65 Self { weight }
66 }
67
68 pub fn forward(&self, tensor: &T, out: &mut T) -> Result<(), SmeltError> {
70 T::matmul_t(tensor, &self.weight, out)?;
71 Ok(())
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cpu::f32::Tensor;

    /// Smoke test: `Linear::forward` succeeds for a `[2, 2]` input, a
    /// `[3, 2]` weight and a `[3]` bias writing into a `[2, 3]` output.
    #[test]
    fn test_linear() {
        let zeros = Tensor::zeros(vec![2, 2]);
        let weights = Tensor::zeros(vec![3, 2]);
        let bias = Tensor::zeros(vec![3]);
        let mut out = Tensor::zeros(vec![2, 3]);

        let linear = Linear::new(weights, bias);

        linear.forward(&zeros, &mut out).unwrap();
    }

    /// Smoke test: `UnbiasedLinear::forward` succeeds with the same input,
    /// weight and output shapes as `test_linear`, minus the bias.
    #[test]
    fn test_unbiased_linear() {
        let input = Tensor::zeros(vec![2, 2]);
        let weights = Tensor::zeros(vec![3, 2]);
        let mut out = Tensor::zeros(vec![2, 3]);

        let linear = UnbiasedLinear::new(weights);

        linear.forward(&input, &mut out).unwrap();
    }
}