// zenu_layer/layers/rnn/gru.rs
use rand_distr::{Distribution, StandardNormal};
2use zenu_autograd::{
3 nn::rnns::{gru::naive::gru_naive, weights::GRUCell},
4 Variable,
5};
6
7#[cfg(feature = "nvidia")]
8use zenu_autograd::nn::rnns::gru::cudnn::gru_cudnn;
9
10use zenu_matrix::{device::Device, num::Num};
11
12use crate::{Module, ModuleParameters, Parameters};
13
14use super::{builder::RNNSLayerBuilder, inner::RNNInner};
15
16pub struct GRUInput<T: Num, D: Device> {
17 pub x: Variable<T, D>,
18 pub hx: Variable<T, D>,
19}
20
21impl<T: Num, D: Device> ModuleParameters<T, D> for GRUInput<T, D> {}
22
23impl<T: Num, D: Device> RNNInner<T, D, GRUCell> {
24 fn forward(&self, input: GRUInput<T, D>) -> Variable<T, D> {
25 #[cfg(feature = "nvidia")]
26 if self.is_cudnn {
27 let desc = self.desc.as_ref().unwrap();
28 let weights = self.cudnn_weights.as_ref().unwrap();
29
30 let out = gru_cudnn(
31 desc.clone(),
32 input.x.to(),
33 Some(input.hx.to()),
34 weights.to(),
35 self.is_training,
36 );
37
38 return out.y.to();
39 }
40
41 gru_naive(
42 input.x,
43 input.hx,
44 self.weights.as_ref().unwrap(),
45 self.is_bidirectional,
46 )
47 }
48}
49
50pub struct GRU<T: Num, D: Device>(RNNInner<T, D, GRUCell>);
51
52impl<T: Num, D: Device> Parameters<T, D> for GRU<T, D> {
53 fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> {
54 self.0.weights()
55 }
56
57 fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> {
58 self.0.biases()
59 }
60
61 fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) {
62 self.0.load_parameters(parameters);
63 }
64}
65
66impl<T: Num, D: Device> Module<T, D> for GRU<T, D> {
67 type Input = GRUInput<T, D>;
68 type Output = Variable<T, D>;
69
70 fn call(&self, input: Self::Input) -> Self::Output {
71 self.0.forward(input)
72 }
73}
74
75pub type GRUBuilder<T, D> = RNNSLayerBuilder<T, D, GRUCell>;
76
77impl<T: Num, D: Device> RNNSLayerBuilder<T, D, GRUCell>
78where
79 StandardNormal: Distribution<T>,
80{
81 pub fn build_gru(self) -> GRU<T, D> {
82 GRU(self.build_inner())
83 }
84}