1use rand_distr::{Distribution, StandardNormal};
2use zenu_autograd::{
3 nn::rnns::{
4 rnn::naive::{rnn_relu, rnn_tanh},
5 weights::RNNCell,
6 },
7 Variable,
8};
9
10#[cfg(feature = "nvidia")]
11use zenu_autograd::nn::rnns::rnn::{cudnn::cudnn_rnn_fwd, RNNOutput};
12
13use zenu_matrix::{device::Device, num::Num};
14
15use crate::{Module, ModuleParameters, Parameters};
16#[cfg(feature = "nvidia")]
17use zenu_matrix::device::nvidia::Nvidia;
18
19use super::{
20 builder::RNNSLayerBuilder,
21 inner::{Activation, RNNInner},
22};
23
/// Input bundle for a forward pass through an [`RNN`] layer.
pub struct RNNLayerInput<T: Num, D: Device> {
    /// Input sequence tensor fed to the RNN.
    pub x: Variable<T, D>,
    /// Initial hidden state.
    pub hx: Variable<T, D>,
}
28
29impl<T: Num, D: Device> ModuleParameters<T, D> for RNNLayerInput<T, D> {}
30
31impl<T: Num, D: Device> RNNInner<T, D, RNNCell> {
32 fn forward(&self, input: RNNLayerInput<T, D>) -> Variable<T, D> {
33 #[cfg(feature = "nvidia")]
34 if self.is_cudnn {
35 let desc = self.desc.as_ref().unwrap();
36 let weights = self.cudnn_weights.as_ref().unwrap();
37
38 let out: RNNOutput<T, Nvidia> = cudnn_rnn_fwd(
39 desc.clone(),
40 input.x.to(),
41 Some(input.hx.to()),
42 weights.to(),
43 self.is_training,
44 );
45
46 return out.y.to();
47 }
48
49 let activation = self.activation.unwrap();
50 if activation == Activation::ReLU {
51 rnn_relu(
52 input.x,
53 input.hx,
54 self.weights.as_ref().unwrap(),
55 self.is_bidirectional,
56 )
57 } else {
58 rnn_tanh(
59 input.x,
60 input.hx,
61 self.weights.as_ref().unwrap(),
62 self.is_bidirectional,
63 )
64 }
65 }
66}
67
68pub struct RNN<T: Num, D: Device>(RNNInner<T, D, RNNCell>);
69
70impl<T: Num, D: Device> Parameters<T, D> for RNN<T, D> {
71 fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> {
72 self.0.weights()
73 }
74
75 fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> {
76 self.0.biases()
77 }
78
79 fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) {
80 self.0.load_parameters(parameters);
81 }
82}
83
84impl<T: Num, D: Device> Module<T, D> for RNN<T, D> {
85 type Input = RNNLayerInput<T, D>;
86 type Output = Variable<T, D>;
87
88 fn call(&self, input: Self::Input) -> Self::Output {
89 self.0.forward(input)
90 }
91}
92
93impl<T: Num, D: Device> RNNSLayerBuilder<T, D, RNNCell> {
94 pub fn build_rnn(self) -> RNN<T, D>
95 where
96 StandardNormal: Distribution<T>,
97 {
98 RNN(self.build_inner())
99 }
100}
101
102pub type RNNBuilder<T, D> = RNNSLayerBuilder<T, D, RNNCell>;
103
#[cfg(test)]
mod rnn_layer_test {
    use zenu_autograd::creator::{rand::uniform, zeros::zeros};
    use zenu_matrix::{device::Device, dim::DimDyn};
    use zenu_test::{assert_val_eq, run_test};

    use crate::{Module, Parameters};

    use super::RNNBuilder;

    /// Builds a layer, copies its parameters into a second, identically
    /// configured layer, and checks both produce the same output on the
    /// naive (non-cuDNN) path.
    fn layer_save_load_test_not_cudnn<D: Device>() {
        let layer = RNNBuilder::<f32, D>::default()
            .hidden_size(10)
            .num_layers(2)
            .input_size(5)
            .batch_size(1)
            .build_rnn();

        // NOTE(review): shapes appear to be x = [seq, batch, input_size] and
        // hx = [num_layers, batch, hidden_size] — confirm against rnn_tanh.
        let input = uniform(-1., 1., None, DimDyn::from([5, 1, 5]));
        let hidden = zeros([2, 1, 10]);

        let output = layer.call(super::RNNLayerInput {
            x: input.clone(),
            hx: hidden.clone(),
        });

        let parameters = layer.parameters();

        let new_layer = RNNBuilder::<f32, D>::default()
            .hidden_size(10)
            .num_layers(2)
            .input_size(5)
            .batch_size(1)
            .build_rnn();

        let new_layer_parameters = new_layer.parameters();

        // Fix: iterate over `&parameters` — the original text contained a
        // mis-encoded `¶meters` (garbled `&parameters`) that does not compile.
        for (key, value) in &parameters {
            new_layer_parameters
                .get(key)
                .unwrap()
                .get_as_mut()
                .copy_from(&value.get_as_ref());
        }

        let new_output = new_layer.call(super::RNNLayerInput {
            x: input,
            hx: hidden,
        });

        assert_val_eq!(output, new_output.get_as_ref(), 1e-4);
    }
    run_test!(
        layer_save_load_test_not_cudnn,
        layer_save_load_test_not_cudnn_cpu,
        layer_save_load_test_not_cudnn_gpu
    );

    /// Round-trips parameters through `load_parameters` on the cuDNN path
    /// and verifies the copied values match the source layer key-for-key.
    #[cfg(feature = "nvidia")]
    #[test]
    fn layer_save_load_test_cudnn() {
        use zenu_matrix::device::nvidia::Nvidia;

        let layer = RNNBuilder::<f32, Nvidia>::default()
            .hidden_size(10)
            .num_layers(3)
            .input_size(5)
            .batch_size(5)
            .set_is_cudnn(true)
            .build_rnn();

        let mut new_layer = RNNBuilder::<f32, Nvidia>::default()
            .hidden_size(10)
            .num_layers(3)
            .input_size(5)
            .batch_size(5)
            .set_is_cudnn(true)
            .build_rnn();

        let layer_parameters = layer.parameters();

        new_layer.load_parameters(layer_parameters.clone());

        let new_layer_parameters = new_layer.parameters();

        for (key, value) in &layer_parameters {
            assert_val_eq!(
                value,
                new_layer_parameters.get(key).unwrap().get_as_ref(),
                1e-4
            );
        }
    }
}