scirs2_neural/layers/recurrent/
mod.rs

1//! Recurrent layer implementations
2
3pub mod bidirectional;
4pub mod gru;
5pub mod lstm;
6pub mod rnn;
7
8// Re-export main types
9pub use bidirectional::Bidirectional;
10pub use gru::GRU;
11pub use lstm::LSTM;
12pub use rnn::RNN;
13
14// Type aliases for compatibility
15use scirs2_core::ndarray::{Array, IxDyn};
16use std::sync::{Arc, RwLock};
17
18/// Type alias for LSTM state cache (hidden, cell)
19pub type LstmStateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>)>>>;
20
21/// Type alias for LSTM gate cache (input, forget, output, cell gates)
22pub type LstmGateCache<F> = Arc<
23    RwLock<
24        Option<(
25            Array<F, IxDyn>,
26            Array<F, IxDyn>,
27            Array<F, IxDyn>,
28            Array<F, IxDyn>,
29        )>,
30    >,
31>;
32
33/// Type alias for GRU state cache
34pub type GruStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
35
36/// Type alias for GRU gate cache (reset, update, new gates)
37pub type GruGateCache<F> = Arc<RwLock<Option<(Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>)>>>;
38
39/// Type alias for RNN state cache
40pub type RnnStateCache<F> = Arc<RwLock<Option<Array<F, IxDyn>>>>;
41
42/// Type alias for LSTM step output (new_h, new_c, (input_gate, forget_gate, cell_gate, output_gate))
43pub type LstmStepOutput<F> = (
44    Array<F, IxDyn>,
45    Array<F, IxDyn>,
46    (
47        Array<F, IxDyn>,
48        Array<F, IxDyn>,
49        Array<F, IxDyn>,
50        Array<F, IxDyn>,
51    ),
52);
53
54/// Type alias for GRU forward output (new_h, (reset_gate, update_gate, new_gate))
55pub type GruForwardOutput<F> = (
56    Array<F, IxDyn>,
57    (Array<F, IxDyn>, Array<F, IxDyn>, Array<F, IxDyn>),
58);