//! Library root: declares the crate's modules and re-exports the most commonly
//! used types so downstream code can `use <crate>::LSTMNetwork;` directly
//! instead of spelling out full module paths.

// Module tree: low-level building blocks (layers), compositions (models),
// and the supporting training machinery.
pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// Network-level models. `LayerDropoutConfig` is re-exported twice: under its
// own name for the LSTM network, and aliased as `GRULayerDropoutConfig` for
// the GRU variant to avoid a name clash at the crate root.
pub use models::lstm_network::{LSTMNetwork, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
// Individual recurrent cells and layer-level utilities (dropout/zoneout).
pub use layers::lstm_cell::LSTMCell;
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
// Trainer types plus convenience constructors for common scheduler setups.
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, TrainingConfig,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer
};
// Optimizers and the full set of learning-rate schedulers.
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy
};
// Loss functions and model save/load support.
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
65
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a minimal 2-input / 3-hidden / 1-layer LSTM network
    /// produces hidden and cell states of shape `[hidden_size, batch]`.
    #[test]
    fn test_library_integration() {
        let mut net = models::lstm_network::LSTMNetwork::new(2, 3, 1);

        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        let (h1, c1) = net.forward(&x, &h0, &c0);

        for state in [&h1, &c1] {
            assert_eq!(state.shape(), &[3, 1]);
        }
    }

    /// Same smoke test with input, recurrent, and output dropout enabled:
    /// output shapes must be stable in both train and eval modes.
    #[test]
    fn test_library_with_dropout() {
        let mut net = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        net.train();
        let (h_train, c_train) = net.forward(&x, &h0, &c0);

        net.eval();
        let (h_eval, c_eval) = net.forward(&x, &h0, &c0);

        for state in [&h_train, &c_train, &h_eval, &c_eval] {
            assert_eq!(state.shape(), &[3, 1]);
        }
    }
}