//! Recurrent layers.
//!
//! This module declares the `bidirectional`, `gru`, `lstm`, and `rnn` submodules,
//! re-exports `Bidirectional`, `GRU`, `LSTM`, and `RNN` from them, and defines the
//! type aliases for the shared caches and step outputs named after those layers.

pub mod bidirectional;
pub mod gru;
pub mod lstm;
pub mod rnn;

pub use bidirectional::Bidirectional;
pub use gru::GRU;
pub use lstm::LSTM;
pub use rnn::RNN;

use std::sync::{Arc, RwLock};

use scirs2_core::ndarray::{Array, IxDyn};

18pub type LstmStateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>)>>>;
20
21pub type LstmGateCache<F> = Arc<
23 RwLock<
24 Option<(
25 Array<F, IxDyn>,
26 Array<F, IxDyn>,
27 Array<F, IxDyn>,
28 Array<F, IxDyn>,
29 )>,
30 >,
31>;
32
33pub type GruStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
35
36pub type GruGateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>)>>>;
38
39pub type RnnStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
41
42pub type LstmStepOutput<F> = (
44 Array<F, IxDyn>,
45 Array<F, IxDyn>,
46 (
47 Array<F, IxDyn>,
48 Array<F, IxDyn>,
49 Array<F, IxDyn>,
50 Array<F, IxDyn>,
51 ),
52);
53
54pub type GruForwardOutput<F> = (
56 Array<F, IxDyn>,
57 (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>),
58);