// ---------------------------------------------------------------------------
// Crate module tree: one module per major concern (shared utilities, layer
// primitives, network models, losses, optimizers, LR schedulers, training
// loops, and model persistence).
// ---------------------------------------------------------------------------
pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;

// ---------------------------------------------------------------------------
// Flattened public API: re-export the commonly used types at the crate root
// so downstream code can name them directly instead of spelling out the full
// module path.
// ---------------------------------------------------------------------------

// Recurrent network models: LSTM and GRU stacks plus their forward caches.
pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
// `LayerDropoutConfig` is re-exported under an alias here to avoid clashing
// with the identically named LSTM variant re-exported just above.
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
// Individual recurrent cells and regularization layers.
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
// Training-loop types, early-stopping support, and convenience constructors
// for common trainer/scheduler combinations.
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer
};
// Optimizers and learning-rate schedulers.
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer
};
// Loss functions and model save/load support.
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
69
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    /// Smoke test: a freshly built LSTM network runs one forward step and
    /// yields hidden/cell states with the same shape as the initial states.
    #[test]
    fn test_library_integration() {
        // Column-vector input and zero initial hidden/cell states.
        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        // NOTE(review): assuming `new(input_size, hidden_size, num_layers)`
        // from the shapes used here — confirm against the constructor docs.
        let mut net = LSTMNetwork::new(2, 3, 1);
        let (h1, c1) = net.forward(&x, &h0, &c0);

        assert_eq!(h1.shape(), &[3, 1]);
        assert_eq!(c1.shape(), &[3, 1]);
    }

    /// A dropout-configured network keeps its output shapes intact in both
    /// training mode (dropout active) and evaluation mode (dropout off).
    #[test]
    fn test_library_with_dropout() {
        let x = arr2(&[[1.0], [0.5]]);
        let h0 = arr2(&[[0.0], [0.0], [0.0]]);
        let c0 = arr2(&[[0.0], [0.0], [0.0]]);

        let mut net = LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        // Forward once in training mode, then once in evaluation mode.
        net.train();
        let (h_train, c_train) = net.forward(&x, &h0, &c0);

        net.eval();
        let (h_eval, c_eval) = net.forward(&x, &h0, &c0);

        // Shapes must be identical regardless of the dropout mode.
        for shape in [h_train.shape(), c_train.shape(), h_eval.shape(), c_eval.shape()] {
            assert_eq!(shape, &[3, 1]);
        }
    }
}