// zenu_layer/layers/linear.rs
1use std::collections::HashMap;
2
3use crate::{Module, Parameters};
4use rand_distr::{Distribution, StandardNormal};
5use zenu_autograd::{
6 creator::{rand::normal, zeros::zeros},
7 functions::{matmul::matmul, transpose::transpose},
8 Variable,
9};
10use zenu_matrix::{device::Device, num::Num};
11
/// Fully connected (affine) layer computing `y = x · Wᵀ (+ b)`.
///
/// The weight is stored with shape `[out_features, in_features]` (see `new`),
/// and `call` transposes it before the matrix multiply. `bias` is `Some` only
/// when the layer was constructed with `use_bias = true`.
pub struct Linear<T: Num, D: Device> {
    // Size of the last input dimension (columns of the weight matrix).
    in_features: usize,
    // Size of the output dimension (rows of the weight matrix).
    out_features: usize,
    pub weight: Variable<T, D>,
    pub bias: Option<Variable<T, D>>,
}
18
19impl<T: Num, D: Device> Module<T, D> for Linear<T, D> {
20 type Input = Variable<T, D>;
21 type Output = Variable<T, D>;
22 fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
23 let weight_t = transpose(self.weight.clone());
24 let output = matmul(input, weight_t);
25 if let Some(bias) = &self.bias {
26 output.set_name("linear.intermediate_output");
27 output + bias.clone()
28 } else {
29 output
30 }
31 }
32}
33
34impl<T: Num, D: Device> Parameters<T, D> for Linear<T, D> {
35 fn weights(&self) -> HashMap<String, Variable<T, D>> {
36 let mut weights = HashMap::new();
37 weights.insert("linear.weight".to_string(), self.weight.clone());
38 weights
39 }
40
41 fn biases(&self) -> HashMap<String, Variable<T, D>> {
42 let mut biases = HashMap::new();
43 if let Some(bias) = &self.bias {
44 biases.insert("linear.bias".to_string(), bias.clone());
45 }
46 biases
47 }
48}
49
50impl<T: Num, D: Device> Linear<T, D> {
51 #[must_use]
52 pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self
53 where
54 StandardNormal: Distribution<T>,
55 {
56 let weight = normal(T::zero(), T::one(), None, [out_features, in_features]);
57 weight
58 .get_data_mut()
59 .to_ref_mut()
60 .div_scalar_assign(T::from_usize(in_features).sqrt());
61 let bias = if use_bias {
62 let bias = zeros([out_features]);
63 bias.set_name("linear.bias");
64 bias.set_is_train(true);
65 Some(bias)
66 } else {
67 None
68 };
69
70 weight.set_is_train(true);
71 weight.set_name("linear.weight");
72
73 Self {
74 in_features,
75 out_features,
76 weight,
77 bias,
78 }
79 }
80
81 #[must_use]
82 pub fn to<Dout: Device>(self) -> Linear<T, Dout> {
83 Linear {
84 in_features: self.in_features,
85 out_features: self.out_features,
86 weight: self.weight.to(),
87 bias: self.bias.map(|b| b.to()),
88 }
89 }
90}
91
92// #[cfg(test)]
93// mod linear {
94// use zenu_autograd::creator::rand::normal;
95// use zenu_matrix::{device::Device, dim::DimTrait, operation::mul::matmul};
96// use zenu_test::{assert_mat_eq_epsilon, assert_val_eq, run_test};
97//
98// use crate::{Module, StateDict};
99//
100// use super::Linear;
101//
102// fn with_bias<D: Device>() {
103// let layer = Linear::<f32, D>::new(3, 2, true);
104// let input = normal::<_, _, D>(0., 1., None, [5, 3]);
105// let output = layer.call(input.clone());
106// assert_eq!(output.get_data().shape().slice(), [5, 2]);
107//
108// let parameters = layer.to_json();
109//
110// let ans = matmul(
111// &input.get_data().to_ref(),
112// &layer.weight.get_data().to_ref(),
113// ) + &layer.bias.unwrap().get_data().to_ref();
114//
115// assert_val_eq!(output.clone(), ans, 1e-4);
116//
117// let new_layer = Linear::<f32, D>::from_json(¶meters);
118// let new_output = new_layer.call(input.clone());
119//
120// assert_mat_eq_epsilon!(output.get_data(), new_output.get_data(), 1e-4);
121// }
122// run_test!(with_bias, with_bias_cpu, with_bias_gpu);
123//
124// fn without_bias<D: Device>() {
125// let layer = Linear::<f32, D>::new(3, 2, false);
126// let input = normal::<_, _, D>(0., 1., None, [5, 3]);
127// let output = layer.call(input.clone());
128// assert_eq!(output.get_data().shape().slice(), [5, 2]);
129//
130// let parameters = layer.to_json();
131// let weight = layer.weight.clone();
132//
133// let ans = matmul(&input.get_data().to_ref(), &weight.get_data().to_ref());
134// assert_val_eq!(output.clone(), ans, 1e-4);
135//
136// let new_layer = Linear::<f32, D>::from_json(¶meters);
137// let new_output = new_layer.call(input.clone());
138//
139// assert_mat_eq_epsilon!(output.get_data(), new_output.get_data(), 1e-4);
140// }
141// run_test!(without_bias, without_bias_cpu, without_bias_gpu);
142// }