//! # Rust LSTM Library
//!
//! A complete LSTM implementation with training capabilities, multiple optimizers,
//! dropout regularization, and support for various architectures including peephole
//! connections and bidirectional processing.
//!
//! ## Core Components
//!
//! - **LSTM Cells**: Standard and peephole LSTM implementations with full backpropagation
//! - **Bidirectional LSTM**: Process sequences in both directions with flexible output combination
//! - **Networks**: Multi-layer LSTM networks for sequence modeling
//! - **Training**: Complete training system with BPTT, gradient clipping, and validation
//! - **Optimizers**: SGD, Adam, and RMSprop optimizers with adaptive learning rates
//! - **Loss Functions**: MSE, MAE, and Cross-Entropy with numerically stable implementations
//! - **Dropout**: Input, recurrent, output dropout and zoneout regularization
//!
//! ## Quick Start
//!
//! ```rust
//! use rust_lstm::models::lstm_network::LSTMNetwork;
//! use rust_lstm::training::create_basic_trainer;
//!
//! // Create a 2-layer LSTM with 10 input features and 20 hidden units
//! let mut network = LSTMNetwork::new(10, 20, 2)
//!     .with_input_dropout(0.2, true)     // Variational input dropout
//!     .with_recurrent_dropout(0.3, true) // Variational recurrent dropout
//!     .with_output_dropout(0.1);         // Standard output dropout
//!
//! let mut trainer = create_basic_trainer(network, 0.001);
//!
//! // Train on your data
//! // trainer.train(&train_data, Some(&validation_data));
//! ```

35/// Main library module.
36pub mod utils;
37pub mod layers;
38pub mod models;
39pub mod loss;
40pub mod optimizers;
41pub mod schedulers;
42pub mod training;
43pub mod persistence;
44pub mod text;
45
46// Re-export commonly used items
47pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
48pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
49pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
50pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
51pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
52pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
53pub use layers::dropout::{Dropout, Zoneout};
54pub use layers::linear::{LinearLayer, LinearGradients};
55pub use training::{
56    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
57    EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper,
58    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
59    create_basic_batch_trainer, create_adam_batch_trainer
60};
61pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
62pub use schedulers::{
63    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR, 
64    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR, 
65    ReduceLROnPlateau, LinearLR, AnnealStrategy,
66    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
67    LRScheduleVisualizer
68};
69pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
70pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
71pub use text::{
72    TextVocabulary, CharacterEmbedding, EmbeddingGradients,
73    sample_with_temperature, sample_top_k, sample_nucleus, argmax, softmax
74};
75
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a minimal network (2 inputs, 3 hidden units, 1 layer)
    /// produces hidden/cell state outputs of the expected shape.
    #[test]
    fn test_library_integration() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        // Column-vector input (2 features) and zeroed hidden/cell states (3 units).
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        assert_eq!(hy.shape(), &[3, 1]);
        assert_eq!(cy.shape(), &[3, 1]);
    }

    /// A dropout-configured network must keep the same output shapes in both
    /// training and evaluation modes (dropout affects values, not dimensions).
    #[test]
    fn test_library_with_dropout() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        // Test training mode (dropout active).
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // Test evaluation mode (dropout disabled).
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        assert_eq!(hy_train.shape(), &[3, 1]);
        assert_eq!(cy_train.shape(), &[3, 1]);
        assert_eq!(hy_eval.shape(), &[3, 1]);
        assert_eq!(cy_eval.shape(), &[3, 1]);
    }
}