//! zenu_layer/layers/rnn/rnn.rs

1use rand_distr::{Distribution, StandardNormal};
2use zenu_autograd::{
3    nn::rnns::{
4        rnn::naive::{rnn_relu, rnn_tanh},
5        weights::RNNCell,
6    },
7    Variable,
8};
9
10#[cfg(feature = "nvidia")]
11use zenu_autograd::nn::rnns::rnn::{cudnn::cudnn_rnn_fwd, RNNOutput};
12
13use zenu_matrix::{device::Device, num::Num};
14
15use crate::{Module, ModuleParameters, Parameters};
16#[cfg(feature = "nvidia")]
17use zenu_matrix::device::nvidia::Nvidia;
18
19use super::{
20    builder::RNNSLayerBuilder,
21    inner::{Activation, RNNInner},
22};
23
/// Input bundle for a forward pass through an [`RNN`] layer.
pub struct RNNLayerInput<T: Num, D: Device> {
    /// Input sequence tensor.
    /// NOTE(review): the tests feed `[5, 1, 5]` — presumably
    /// `(seq_len, batch, input_size)`; confirm against the kernel docs.
    pub x: Variable<T, D>,
    /// Initial hidden state.
    /// NOTE(review): the tests feed `[num_layers, batch, hidden_size]` — TODO confirm.
    pub hx: Variable<T, D>,
}
28
// Marker impl: the input bundle itself carries no trainable parameters,
// so the trait's default behavior is used (presumably empty — see
// `ModuleParameters`; confirm in the trait definition).
impl<T: Num, D: Device> ModuleParameters<T, D> for RNNLayerInput<T, D> {}
30
31impl<T: Num, D: Device> RNNInner<T, D, RNNCell> {
32    fn forward(&self, input: RNNLayerInput<T, D>) -> Variable<T, D> {
33        #[cfg(feature = "nvidia")]
34        if self.is_cudnn {
35            let desc = self.desc.as_ref().unwrap();
36            let weights = self.cudnn_weights.as_ref().unwrap();
37
38            let out: RNNOutput<T, Nvidia> = cudnn_rnn_fwd(
39                desc.clone(),
40                input.x.to(),
41                Some(input.hx.to()),
42                weights.to(),
43                self.is_training,
44            );
45
46            return out.y.to();
47        }
48
49        let activation = self.activation.unwrap();
50        if activation == Activation::ReLU {
51            rnn_relu(
52                input.x,
53                input.hx,
54                self.weights.as_ref().unwrap(),
55                self.is_bidirectional,
56            )
57        } else {
58            rnn_tanh(
59                input.x,
60                input.hx,
61                self.weights.as_ref().unwrap(),
62                self.is_bidirectional,
63            )
64        }
65    }
66}
67
/// Vanilla RNN layer: a thin newtype over [`RNNInner`] specialised to
/// [`RNNCell`] weights.
pub struct RNN<T: Num, D: Device>(RNNInner<T, D, RNNCell>);
69
70impl<T: Num, D: Device> Parameters<T, D> for RNN<T, D> {
71    fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> {
72        self.0.weights()
73    }
74
75    fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> {
76        self.0.biases()
77    }
78
79    fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) {
80        self.0.load_parameters(parameters);
81    }
82}
83
84impl<T: Num, D: Device> Module<T, D> for RNN<T, D> {
85    type Input = RNNLayerInput<T, D>;
86    type Output = Variable<T, D>;
87
88    fn call(&self, input: Self::Input) -> Self::Output {
89        self.0.forward(input)
90    }
91}
92
93impl<T: Num, D: Device> RNNSLayerBuilder<T, D, RNNCell> {
94    pub fn build_rnn(self) -> RNN<T, D>
95    where
96        StandardNormal: Distribution<T>,
97    {
98        RNN(self.build_inner())
99    }
100}
101
/// Convenience alias: the generic RNNS builder specialised to vanilla RNN cells.
pub type RNNBuilder<T, D> = RNNSLayerBuilder<T, D, RNNCell>;
103
#[cfg(test)]
mod rnn_layer_test {
    use zenu_autograd::creator::{rand::uniform, zeros::zeros};
    use zenu_matrix::{device::Device, dim::DimDyn};
    use zenu_test::{assert_val_eq, run_test};

    use crate::{Module, Parameters};

    use super::RNNBuilder;

    /// Save/load round-trip on the naive (non-cuDNN) path: copy one layer's
    /// parameters into an identically-shaped fresh layer and check both
    /// layers produce the same output for the same input.
    fn layer_save_load_test_not_cudnn<D: Device>() {
        let layer = RNNBuilder::<f32, D>::default()
            .hidden_size(10)
            .num_layers(2)
            .input_size(5)
            .batch_size(1)
            .build_rnn();

        // Random input and a zero initial hidden state.
        // NOTE(review): shapes look like (seq, batch, input) and
        // (layers, batch, hidden) — confirm against the kernel contract.
        let input = uniform(-1., 1., None, DimDyn::from([5, 1, 5]));
        let hidden = zeros([2, 1, 10]);

        let output = layer.call(super::RNNLayerInput {
            x: input.clone(),
            hx: hidden.clone(),
        });

        let parameters = layer.parameters();

        // Second layer built with the same hyper-parameters but fresh
        // (different) random weights.
        let new_layer = RNNBuilder::<f32, D>::default()
            .hidden_size(10)
            .num_layers(2)
            .input_size(5)
            .batch_size(1)
            .build_rnn();

        let new_layer_parameters = new_layer.parameters();

        // Manually copy every parameter tensor by key, in place.
        for (key, value) in &parameters {
            new_layer_parameters
                .get(key)
                .unwrap()
                .get_as_mut()
                .copy_from(&value.get_as_ref());
        }

        let new_output = new_layer.call(super::RNNLayerInput {
            x: input,
            hx: hidden,
        });

        // Identical parameters + identical input must give identical output.
        assert_val_eq!(output, new_output.get_as_ref(), 1e-4);
    }
    // Generates CPU and GPU variants of the test above.
    run_test!(
        layer_save_load_test_not_cudnn,
        layer_save_load_test_not_cudnn_cpu,
        layer_save_load_test_not_cudnn_gpu
    );

    /// Save/load round-trip on the cuDNN path: load one layer's parameters
    /// into another via `load_parameters` and verify the stored values match
    /// key-by-key (no forward pass is run here).
    #[cfg(feature = "nvidia")]
    #[test]
    fn layer_save_load_test_cudnn() {
        use zenu_matrix::device::nvidia::Nvidia;

        let layer = RNNBuilder::<f32, Nvidia>::default()
            .hidden_size(10)
            .num_layers(3)
            .input_size(5)
            .batch_size(5)
            .set_is_cudnn(true)
            .build_rnn();

        let mut new_layer = RNNBuilder::<f32, Nvidia>::default()
            .hidden_size(10)
            .num_layers(3)
            .input_size(5)
            .batch_size(5)
            .set_is_cudnn(true)
            .build_rnn();

        let layer_parameters = layer.parameters();

        new_layer.load_parameters(layer_parameters.clone());

        let new_layer_parameters = new_layer.parameters();

        // Every parameter loaded into the new layer must equal the source.
        for (key, value) in &layer_parameters {
            assert_val_eq!(
                value,
                new_layer_parameters.get(key).unwrap().get_as_ref(),
                1e-4
            );
        }
    }
}