//! Crate root: an LSTM/GRU neural-network library built on `ndarray`.
//!
//! The crate is organized into submodules (layers, models, loss, optimizers,
//! schedulers, training, persistence); the `pub use` block below flattens the
//! most commonly used types to the crate root for ergonomic imports.

pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// --- Network models (LSTM / GRU) ---
// NOTE(review): both networks define a `LayerDropoutConfig`; the GRU one is
// re-exported under the alias `GRULayerDropoutConfig` to avoid a name clash.
pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
// --- Individual recurrent cells and layer utilities ---
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
// --- Training loop entry points and convenience constructors ---
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer
};
// --- Optimizers and learning-rate schedulers ---
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer
};
// --- Loss functions and model (de)serialization ---
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
68
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a forward pass through a freshly constructed network
    /// yields hidden and cell states of the expected shape.
    /// Exercises the crate-root re-export (`LSTMNetwork`) rather than the
    /// full module path.
    #[test]
    fn test_library_integration() {
        // 2 inputs -> 3 hidden units, 1 layer.
        let mut net = LSTMNetwork::new(2, 3, 1);

        // Single-sample column vectors: input is (2, 1), states are (3, 1).
        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        let (h1, c1) = net.forward(&x, &h0, &c0);

        assert_eq!(h1.shape(), &[3, 1]);
        assert_eq!(c1.shape(), &[3, 1]);
    }

    /// Smoke test: dropout-configured networks keep the output shapes stable
    /// in both training and evaluation modes.
    #[test]
    fn test_library_with_dropout() {
        let mut net = LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        // Training mode: dropout active.
        net.train();
        let (h_train, c_train) = net.forward(&x, &h0, &c0);

        // Eval mode: dropout disabled; shapes must be unchanged.
        net.eval();
        let (h_eval, c_eval) = net.forward(&x, &h0, &c0);

        assert_eq!(h_train.shape(), &[3, 1]);
        assert_eq!(c_train.shape(), &[3, 1]);
        assert_eq!(h_eval.shape(), &[3, 1]);
        assert_eq!(c_eval.shape(), &[3, 1]);
    }
}