pub mod utils;
pub mod layers;
pub mod models;
pub mod loss;
pub mod optimizers;
pub mod schedulers;
pub mod training;
pub mod persistence;
pub mod text;

pub use models::lstm_network::{LSTMNetwork, LSTMNetworkCache, LSTMNetworkBatchCache, LayerDropoutConfig};
pub use models::gru_network::{GRUNetwork, LayerDropoutConfig as GRULayerDropoutConfig, GRUNetworkCache};
pub use layers::lstm_cell::{LSTMCell, LSTMCellCache, LSTMCellBatchCache, LSTMCellGradients};
pub use layers::peephole_lstm_cell::PeepholeLSTMCell;
pub use layers::gru_cell::{GRUCell, GRUCellGradients, GRUCellCache};
pub use layers::bilstm_network::{BiLSTMNetwork, CombineMode, BiLSTMNetworkCache};
pub use layers::dropout::{Dropout, Zoneout};
pub use layers::linear::{LinearLayer, LinearGradients};
pub use training::{
    LSTMTrainer, ScheduledLSTMTrainer, LSTMBatchTrainer, TrainingConfig, TrainingMetrics,
    EarlyStoppingConfig, EarlyStoppingMetric, EarlyStopper,
    create_basic_trainer, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    create_basic_batch_trainer, create_adam_batch_trainer,
};
pub use optimizers::{SGD, Adam, RMSprop, ScheduledOptimizer};
pub use schedulers::{
    LearningRateScheduler, ConstantLR, StepLR, MultiStepLR, ExponentialLR,
    CosineAnnealingLR, CosineAnnealingWarmRestarts, OneCycleLR,
    ReduceLROnPlateau, LinearLR, AnnealStrategy,
    PolynomialLR, CyclicalLR, CyclicalMode, ScaleMode, WarmupScheduler,
    LRScheduleVisualizer,
};
pub use loss::{LossFunction, MSELoss, MAELoss, CrossEntropyLoss};
pub use persistence::{ModelPersistence, PersistentModel, ModelMetadata, PersistenceError};
pub use text::{
    TextVocabulary, CharacterEmbedding, EmbeddingGradients,
    sample_with_temperature, sample_top_k, sample_nucleus, argmax, softmax,
};

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_library_integration() {
        // Construct a network whose input (2 rows) and hidden state (3 rows)
        // match the test data below, then run one forward step from zeroed
        // hidden and cell states.
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        // Hidden and cell outputs keep the hidden-state shape: 3 rows, 1 column.
        assert_eq!(hy.shape(), &[3, 1]);
        assert_eq!(cy.shape(), &[3, 1]);
    }
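
    // A hypothetical numerical-sanity sketch (our addition, not part of the
    // original suite): assuming `forward` returns `ndarray` arrays of `f64`,
    // as the `arr2` literals above suggest, a freshly initialized network
    // should produce finite activations through its gate nonlinearities.
    #[test]
    fn test_forward_outputs_are_finite() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1);
        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        let (hy, cy) = network.forward(&input, &hx, &cx);

        // Every element of the hidden and cell outputs should be finite.
        assert!(hy.iter().all(|v| v.is_finite()));
        assert!(cy.iter().all(|v| v.is_finite()));
    }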

    #[test]
    fn test_library_with_dropout() {
        // Attach dropout at the input, recurrent, and output positions via
        // the builder-style configuration methods.
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.2, false)
            .with_recurrent_dropout(0.3, true)
            .with_output_dropout(0.1);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        // Forward once in training mode (dropout active)...
        network.train();
        let (hy_train, cy_train) = network.forward(&input, &hx, &cx);

        // ...and once in evaluation mode (dropout disabled).
        network.eval();
        let (hy_eval, cy_eval) = network.forward(&input, &hx, &cx);

        // Output shapes must be identical in both modes.
        assert_eq!(hy_train.shape(), &[3, 1]);
        assert_eq!(cy_train.shape(), &[3, 1]);
        assert_eq!(hy_eval.shape(), &[3, 1]);
        assert_eq!(cy_eval.shape(), &[3, 1]);
    }
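
    // A hypothetical determinism sketch, assuming eval-mode `forward` is a
    // pure function of its inputs (dropout disabled, no internal RNG or
    // state carried between calls); the test name and dropout rate are
    // illustrative.
    #[test]
    fn test_eval_forward_is_deterministic() {
        let mut network = models::lstm_network::LSTMNetwork::new(2, 3, 1)
            .with_input_dropout(0.5, false);

        let input = arr2(&[[1.0], [0.5]]);
        let hx = arr2(&[[0.0], [0.0], [0.0]]);
        let cx = arr2(&[[0.0], [0.0], [0.0]]);

        network.eval();
        let (hy_first, cy_first) = network.forward(&input, &hx, &cx);
        let (hy_second, cy_second) = network.forward(&input, &hx, &cx);

        // With dropout inactive, repeated forwards over the same inputs
        // should agree exactly.
        assert_eq!(hy_first, hy_second);
        assert_eq!(cy_first, cy_second);
    }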
}