// rust_lstm/lib.rs

//! # Rust LSTM Library
//!
//! A complete LSTM implementation with training capabilities, multiple optimizers,
//! dropout regularization, and support for various architectures including peephole
//! connections and bidirectional processing.
//!
//! ## Core Components
//!
//! - **LSTM Cells**: Standard and peephole LSTM implementations with full backpropagation
//! - **Bidirectional LSTM**: Process sequences in both directions with flexible output combination
//! - **Networks**: Multi-layer LSTM networks for sequence modeling
//! - **Training**: Complete training system with BPTT, gradient clipping, and validation
//! - **Optimizers**: SGD, Adam, and RMSprop optimizers with adaptive learning rates
//! - **Loss Functions**: MSE, MAE, and Cross-Entropy with numerically stable implementations
//! - **Dropout**: Input, recurrent, output dropout and zoneout regularization
//!
//! ## Quick Start
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use rust_lstm::training::create_basic_trainer;
//!
//! // Create a 2-layer LSTM with 10 input features and 20 hidden units
//! let mut network = LSTMNetwork::new(10, 20, 2)
//!     .with_input_dropout(0.2, true)     // Variational input dropout
//!     .with_recurrent_dropout(0.3, true) // Variational recurrent dropout
//!     .with_output_dropout(0.1);         // Standard output dropout
//!
//! let mut trainer = create_basic_trainer(network, 0.001);
//!
//! // Train on your data
//! // trainer.train(&train_data, Some(&validation_data));
//! ```

35/// Main library module.
36pub mod utils;
37pub mod layers;
38pub mod models;
39pub mod loss;
40pub mod optimizers;
41pub mod schedulers;
42pub mod training;
43pub mod persistence;
44
45// Re-export commonly used items
46pub use models::lstm_network::{LSTMNetwork, LayerDropoutConfig};
47pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
48pub use layers::lstm_cell::LSTMCell;
49pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
50pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
51pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
52pub use layers::dropout::{Dropout, Zoneout};
53pub use training::{
54    LSTMTrainer, ScheduledLSTMTrainer, TrainingConfig, 
55    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer
56};
57pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
58pub use schedulers::{
59    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR, 
60    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR, 
61    ReduceLROnPlateau, LinearLR, AnnealStrategy
62};
63pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
64pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
65
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a 1-layer network (input=2, hidden=3) run on a single
    /// input column yields hidden/cell states of shape (hidden_size, 1).
    #[test]
    fn test_library_integration() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[3, 1]);
        assert_eq!(cy.shape(), &[3, 1]);
    }

    /// A dropout-configured network preserves output shapes in both
    /// training mode and evaluation mode.
    #[test]
    fn test_library_with_dropout() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        // Test training mode (dropout active).
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // Test evaluation mode (dropout disabled).
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[3, 1]);
        assert_eq!(cy_train.shape(), &[3, 1]);
        assert_eq!(hy_eval.shape(), &[3, 1]);
        assert_eq!(cy_eval.shape(), &[3, 1]);
    }
}