zenu_layer/layers/rnn/lstm.rs
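//! LSTM layer built on the shared RNN layer infrastructure.
//!
//! The forward pass dispatches to the cuDNN implementation when the crate is
//! built with the `nvidia` feature and the layer was constructed for cuDNN,
//! and otherwise falls back to the naive implementation.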

use rand_distr::{Distribution, StandardNormal};
use zenu_autograd::{
    nn::rnns::{lstm::naive::lstm_naive, weights::LSTMCell},
    Variable,
};

#[cfg(feature = "nvidia")]
use zenu_autograd::nn::rnns::lstm::cudnn::lstm_cudnn;

use zenu_matrix::{device::Device, num::Num};

use crate::{Module, ModuleParameters, Parameters};

use super::{builder::RNNSLayerBuilder, inner::RNNInner};

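/// Input to the LSTM layer: the input sequence `x`, the initial hidden
/// state `hx`, and the initial cell state `cx`.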
pub struct LSTMInput<T: Num, D: Device> {
    pub x: Variable<T, D>,
    pub hx: Variable<T, D>,
    pub cx: Variable<T, D>,
}

impl<T: Num, D: Device> ModuleParameters<T, D> for LSTMInput<T, D> {}

impl<T: Num, D: Device> RNNInner<T, D, LSTMCell> {
    fn forward(&self, input: LSTMInput<T, D>) -> Variable<T, D> {
        // Dispatch to the cuDNN kernel when this layer was built for cuDNN.
        #[cfg(feature = "nvidia")]
        if self.is_cudnn {
            let desc = self.desc.as_ref().unwrap();
            let weights = self.cudnn_weights.as_ref().unwrap();

            let out = lstm_cudnn(
                desc.clone(),
                input.x.to(),
                Some(input.hx.to()),
                Some(input.cx.to()),
                weights.to(),
                self.is_training,
            );

            return out.to();
        }

        // Otherwise fall back to the naive implementation from zenu_autograd.
        lstm_naive(
            input.x,
            input.hx,
            input.cx,
            self.weights.as_ref().unwrap(),
            self.is_bidirectional,
        )
    }
}

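/// LSTM layer: a thin wrapper around `RNNInner` specialised for `LSTMCell`.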
pub struct LSTM<T: Num, D: Device>(RNNInner<T, D, LSTMCell>);

impl<T: Num, D: Device> Parameters<T, D> for LSTM<T, D> {
    fn weights(&self) -> std::collections::HashMap<String, Variable<T, D>> {
        self.0.weights()
    }

    fn biases(&self) -> std::collections::HashMap<String, Variable<T, D>> {
        self.0.biases()
    }

    fn load_parameters(&mut self, parameters: std::collections::HashMap<String, Variable<T, D>>) {
        self.0.load_parameters(parameters);
    }
}

impl<T: Num, D: Device> Module<T, D> for LSTM<T, D> {
    type Input = LSTMInput<T, D>;
    type Output = Variable<T, D>;

    fn call(&self, input: Self::Input) -> Self::Output {
        self.0.forward(input)
    }
}

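/// Builder alias for constructing an `LSTM` layer through `RNNSLayerBuilder`.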
pub type LSTMBuilder<T, D> = RNNSLayerBuilder<T, D, LSTMCell>;

impl<T: Num, D: Device> RNNSLayerBuilder<T, D, LSTMCell>
where
    StandardNormal: Distribution<T>,
{
    pub fn build_lstm(self) -> LSTM<T, D> {
        LSTM(self.build_inner())
    }
}
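
// A minimal usage sketch. A forward pass on a built `LSTM` layer is a single
// `Module::call` with an `LSTMInput` bundle; how the `Variable`s are created
// (and how the builder is configured before `build_lstm`) is assumed to
// happen elsewhere and is not shown here.
#[allow(dead_code)]
fn lstm_forward_sketch<T: Num, D: Device>(
    lstm: &LSTM<T, D>,
    x: Variable<T, D>,  // input sequence
    hx: Variable<T, D>, // initial hidden state
    cx: Variable<T, D>, // initial cell state
) -> Variable<T, D> {
    lstm.call(LSTMInput { x, hx, cx })
}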