zenu_layer/layers/rnn/
gru.rs

1use rand_distr::{Distribution, StandardNormal};
2use zenu_autograd::{
3    nn::rnns::{gru::naive::gru_naive, weights::GRUCell},
4    Variable,
5};
6
7#[cfg(feature = "nvidia")]
8use zenu_autograd::nn::rnns::gru::cudnn::gru_cudnn;
9
10use zenu_matrix::{device::Device, num::Num};
11
12use crate::{Module, ModuleParameters, Parameters};
13
14use super::{builder::RNNSLayerBuilder, inner::RNNInner};
15
16pub struct GRUInput<T: Num, D: Device> {
17    pub x: Variable<T, D>,
18    pub hx: Variable<T, D>,
19}
20
21impl<T: Num, D: Device> ModuleParameters<T, D> for GRUInput<T, D> {}
22
23impl<T: Num, D: Device> RNNInner<T, D, GRUCell> {
24    fn forward(&self, input: GRUInput<T, D>) -> Variable<T, D> {
25        #[cfg(feature = "nvidia")]
26        if self.is_cudnn {
27            let desc = self.desc.as_ref().unwrap();
28            let weights = self.cudnn_weights.as_ref().unwrap();
29
30            let out = gru_cudnn(
31                desc.clone(),
32                input.x.to(),
33                Some(input.hx.to()),
34                weights.to(),
35                self.is_training,
36            );
37
38            return out.y.to();
39        }
40
41        gru_naive(
42            input.x,
43            input.hx,
44            self.weights.as_ref().unwrap(),
45            self.is_bidirectional,
46        )
47    }
48}
49
50pub struct GRU<T: Num, D: Device>(RNNInner<T, D, GRUCell>);
51
52impl<T: Num, D: Device> Parameters<T, D> for GRU<T, D> {
53    fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> {
54        self.0.weights()
55    }
56
57    fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> {
58        self.0.biases()
59    }
60
61    fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) {
62        self.0.load_parameters(parameters);
63    }
64}
65
66impl<T: Num, D: Device> Module<T, D> for GRU<T, D> {
67    type Input = GRUInput<T, D>;
68    type Output = Variable<T, D>;
69
70    fn call(&self, input: Self::Input) -> Self::Output {
71        self.0.forward(input)
72    }
73}
74
75pub type GRUBuilder<T, D> = RNNSLayerBuilder<T, D, GRUCell>;
76
77impl<T: Num, D: Device> RNNSLayerBuilder<T, D, GRUCell>
78where
79    StandardNormal: Distribution<T>,
80{
81    pub fn build_gru(self) -> GRU<T, D> {
82        GRU(self.build_inner())
83    }
84}