// zenu_layer/layers/linear.rs

1use std::collections::HashMap;
2
3use crate::{Module, Parameters};
4use rand_distr::{Distribution, StandardNormal};
5use zenu_autograd::{
6    creator::{rand::normal, zeros::zeros},
7    functions::{matmul::matmul, transpose::transpose},
8    Variable,
9};
10use zenu_matrix::{device::Device, num::Num};
11
/// A fully connected (affine) layer computing `y = x · Wᵀ (+ b)`.
///
/// Constructed via [`Linear::new`], which initializes `weight` with shape
/// `[out_features, in_features]` and `bias` (when enabled) with shape
/// `[out_features]`.
pub struct Linear<T: Num, D: Device> {
    // Input feature dimension; retained for device transfer via `to`.
    in_features: usize,
    // Output feature dimension; retained for device transfer via `to`.
    out_features: usize,
    /// Weight matrix, shape `[out_features, in_features]` (see `new`).
    pub weight: Variable<T, D>,
    /// Optional bias vector, shape `[out_features]`; `None` when the layer
    /// was built with `use_bias == false`.
    pub bias: Option<Variable<T, D>>,
}
18
19impl<T: Num, D: Device> Module<T, D> for Linear<T, D> {
20    type Input = Variable<T, D>;
21    type Output = Variable<T, D>;
22    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
23        let weight_t = transpose(self.weight.clone());
24        let output = matmul(input, weight_t);
25        if let Some(bias) = &self.bias {
26            output.set_name("linear.intermediate_output");
27            output + bias.clone()
28        } else {
29            output
30        }
31    }
32}
33
34impl<T: Num, D: Device> Parameters<T, D> for Linear<T, D> {
35    fn weights(&self) -> HashMap<String, Variable<T, D>> {
36        let mut weights = HashMap::new();
37        weights.insert("linear.weight".to_string(), self.weight.clone());
38        weights
39    }
40
41    fn biases(&self) -> HashMap<String, Variable<T, D>> {
42        let mut biases = HashMap::new();
43        if let Some(bias) = &self.bias {
44            biases.insert("linear.bias".to_string(), bias.clone());
45        }
46        biases
47    }
48}
49
50impl<T: Num, D: Device> Linear<T, D> {
51    #[must_use]
52    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self
53    where
54        StandardNormal: Distribution<T>,
55    {
56        let weight = normal(T::zero(), T::one(), None, [out_features, in_features]);
57        weight
58            .get_data_mut()
59            .to_ref_mut()
60            .div_scalar_assign(T::from_usize(in_features).sqrt());
61        let bias = if use_bias {
62            let bias = zeros([out_features]);
63            bias.set_name("linear.bias");
64            bias.set_is_train(true);
65            Some(bias)
66        } else {
67            None
68        };
69
70        weight.set_is_train(true);
71        weight.set_name("linear.weight");
72
73        Self {
74            in_features,
75            out_features,
76            weight,
77            bias,
78        }
79    }
80
81    #[must_use]
82    pub fn to<Dout: Device>(self) -> Linear<T, Dout> {
83        Linear {
84            in_features: self.in_features,
85            out_features: self.out_features,
86            weight: self.weight.to(),
87            bias: self.bias.map(|b| b.to()),
88        }
89    }
90}
91
92// #[cfg(test)]
93// mod linear {
94//     use zenu_autograd::creator::rand::normal;
95//     use zenu_matrix::{device::Device, dim::DimTrait, operation::mul::matmul};
96//     use zenu_test::{assert_mat_eq_epsilon, assert_val_eq, run_test};
97//
98//     use crate::{Module, StateDict};
99//
100//     use super::Linear;
101//
102//     fn with_bias<D: Device>() {
103//         let layer = Linear::<f32, D>::new(3, 2, true);
104//         let input = normal::<_, _, D>(0., 1., None, [5, 3]);
105//         let output = layer.call(input.clone());
106//         assert_eq!(output.get_data().shape().slice(), [5, 2]);
107//
108//         let parameters = layer.to_json();
109//
110//         let ans = matmul(
111//             &input.get_data().to_ref(),
112//             &layer.weight.get_data().to_ref(),
113//         ) + &layer.bias.unwrap().get_data().to_ref();
114//
115//         assert_val_eq!(output.clone(), ans, 1e-4);
116//
117//         let new_layer = Linear::<f32, D>::from_json(&parameters);
118//         let new_output = new_layer.call(input.clone());
119//
120//         assert_mat_eq_epsilon!(output.get_data(), new_output.get_data(), 1e-4);
121//     }
122//     run_test!(with_bias, with_bias_cpu, with_bias_gpu);
123//
124//     fn without_bias<D: Device>() {
125//         let layer = Linear::<f32, D>::new(3, 2, false);
126//         let input = normal::<_, _, D>(0., 1., None, [5, 3]);
127//         let output = layer.call(input.clone());
128//         assert_eq!(output.get_data().shape().slice(), [5, 2]);
129//
130//         let parameters = layer.to_json();
131//         let weight = layer.weight.clone();
132//
133//         let ans = matmul(&input.get_data().to_ref(), &weight.get_data().to_ref());
134//         assert_val_eq!(output.clone(), ans, 1e-4);
135//
136//         let new_layer = Linear::<f32, D>::from_json(&parameters);
137//         let new_output = new_layer.call(input.clone());
138//
139//         assert_mat_eq_epsilon!(output.get_data(), new_output.get_data(), 1e-4);
140//     }
141//     run_test!(without_bias, without_bias_cpu, without_bias_gpu);
142// }