rust_lstm/lib.rs

//! # Rust LSTM Library
//!
//! A complete LSTM implementation with training capabilities, multiple optimizers,
//! dropout regularization, and support for various architectures, including peephole
//! connections and bidirectional processing.
//!
//! ## Core Components
//!
//! - **LSTM Cells**: Standard and peephole LSTM implementations with full backpropagation
//! - **Bidirectional LSTM**: Process sequences in both directions with flexible output combination
//! - **Networks**: Multi-layer LSTM networks for sequence modeling
//! - **Training**: Complete training system with BPTT, gradient clipping, and validation
//! - **Optimizers**: SGD, Adam, and RMSprop optimizers with adaptive learning rates
//! - **Loss Functions**: MSE, MAE, and Cross-Entropy with numerically stable implementations
//! - **Dropout**: Input, recurrent, and output dropout, plus zoneout regularization
//!
//! ## Quick Start
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use rust_lstm::training::create_basic_trainer;
//!
//! // Create a 2-layer LSTM with 10 input features and 20 hidden units
//! let mut network = LSTMNetwork::new(10, 20, 2)
//!     .with_input_dropout(0.2, true)     // Variational input dropout
//!     .with_recurrent_dropout(0.3, true) // Variational recurrent dropout
//!     .with_output_dropout(0.1);         // Standard output dropout
//!
//! let mut trainer = create_basic_trainer(network, 0.001);
//!
//! // Train on your data
//! // trainer.train(&train_data, Some(&validation_data));
//! ```
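//!
//! ## Running a Forward Pass
//!
//! A minimal single-step sketch, mirroring the integration test at the bottom of
//! this file: inputs and states are `ndarray` column vectors of shape `[size, 1]`,
//! and `forward` returns the updated hidden and cell states.
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use ndarray::arr2;
//!
//! // 2 input features, 3 hidden units, 1 layer
//! let mut network = LSTMNetwork::new(2, 3, 1);
//!
//! let input = arr2(&[[1.0], [0.5]]);     // input, shape [2, 1]
//! let hx = arr2(&[[0.0], [0.0], [0.0]]); // initial hidden state, shape [3, 1]
//! let cx = arr2(&[[0.0], [0.0], [0.0]]); // initial cell state, shape [3, 1]
//!
//! let (hy, cy) = network.forward(&input, &hx, &cx);
//! assert_eq!(hy.shape(), &[3, 1]);
//! assert_eq!(cy.shape(), &[3, 1]);
//! ```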
//!
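//! ## Dropout in Training vs. Evaluation Mode
//!
//! Dropout should only be active while training. A short sketch, following the
//! dropout test at the bottom of this file: call `train()` to enable dropout and
//! `eval()` to disable it before inference.
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use ndarray::arr2;
//!
//! let mut network = LSTMNetwork::new(2, 3, 1)
//!     .with_input_dropout(0.2, false)
//!     .with_recurrent_dropout(0.3, true)
//!     .with_output_dropout(0.1);
//!
//! let input = arr2(&[[1.0], [0.5]]);
//! let hx = arr2(&[[0.0], [0.0], [0.0]]);
//! let cx = arr2(&[[0.0], [0.0], [0.0]]);
//!
//! network.train(); // dropout active
//! let (hy_train, _cy_train) = network.forward(&input, &hx, &cx);
//!
//! network.eval(); // dropout disabled
//! let (hy_eval, _cy_eval) = network.forward(&input, &hx, &cx);
//!
//! assert_eq!(hy_train.shape(), &[3, 1]);
//! assert_eq!(hy_eval.shape(), &[3, 1]);
//! ```
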
// Library modules
pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// Re-export commonly used items
pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer
};
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer
};
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_library_integration() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[3, 1]);
        assert_eq!(cy.shape(), &[3, 1]);
    }

    #[test]
    fn test_library_with_dropout() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        // Test training mode (dropout active)
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // Test evaluation mode (dropout disabled)
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[3, 1]);
        assert_eq!(cy_train.shape(), &[3, 1]);
        assert_eq!(hy_eval.shape(), &[3, 1]);
        assert_eq!(cy_eval.shape(), &[3, 1]);
    }
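
    // A minimal construction smoke test. This is a sketch assuming
    // `create_basic_trainer` takes the network and a learning rate, matching
    // the Quick Start example in the crate docs above; it only checks that a
    // trainer can be built, not that training converges.
    #[test]
    fn test_basic_trainer_construction() {
        let network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let _trainer = create_basic_trainer(network, 0.01);
    }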
}