//! Crate root for a recurrent-neural-network library: LSTM/GRU cells and
//! networks, loss functions, optimizers, learning-rate schedulers, training
//! loops, and model persistence.
//!
//! The `pub use` re-exports below flatten the most commonly used types to the
//! crate root so downstream code can write `use <crate>::LSTMNetwork;` instead
//! of spelling out the full module path.

pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// Core network types (forward caches included).
// NOTE: `gru_network` also defines a `LayerDropoutConfig`; it is re-exported
// under the alias `GRULayerDropoutConfig` to avoid clashing with the LSTM one.
pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
// Individual layer building blocks plus their cache/gradient companions.
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
pub use layers::linear::{LinearLayer, LinearGradients};
// Training entry points: trainers, configuration, early stopping, and
// convenience constructors for common trainer/scheduler combinations.
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer
};
// Optimizers and learning-rate schedulers.
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer
};
// Loss functions and model (de)serialization.
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
70
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a freshly built LSTM (input 2, hidden 3, 1 layer) yields
    /// hidden/cell states shaped `[hidden_size, batch]` for a batch of one.
    #[test]
    fn test_library_integration() {
        // Use the crate-root re-export rather than the full module path.
        let mut net = LSTMNetwork::new(2, 3, 1);

        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        let (h1, c1) = net.forward(&x, &h0, &c0);

        assert_eq!(h1.shape(), &[3, 1]);
        assert_eq!(c1.shape(), &[3, 1]);
    }

    /// Dropout configuration must leave output shapes unchanged, in both
    /// training and evaluation mode.
    #[test]
    fn test_library_with_dropout() {
        let mut net = LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        // Forward once with dropout active...
        net.train();
        let (h_train, c_train) = net.forward(&x, &h0, &c0);
        assert_eq!(h_train.shape(), &[3, 1]);
        assert_eq!(c_train.shape(), &[3, 1]);

        // ...and once with dropout disabled; shapes must match either way.
        net.eval();
        let (h_eval, c_eval) = net.forward(&x, &h0, &c0);
        assert_eq!(h_eval.shape(), &[3, 1]);
        assert_eq!(c_eval.shape(), &[3, 1]);
    }
}