pub struct Adam { /* private fields */ }
Adam optimizer for neural network parameter optimization
Implements the Adam optimization algorithm with a PyTorch-compatible interface. The optimizer maintains per-parameter state for first-moment (momentum) and second-moment (velocity) estimates, giving each parameter an adaptive learning rate that improves convergence across diverse architectures.
§Usage Pattern
The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
- Parameters are linked to the optimizer via add_parameter or add_parameters
- The step method takes mutable references to parameters for thread-safe updates
- Parameter states are maintained by tensor ID, allowing for dynamic parameter management
- Supports serialization and deserialization with parameter re-linking
§Dynamic Parameter Management
Parameters can be added, removed, or re-linked at runtime:
- add_parameter: Link a single parameter
- add_parameters: Link multiple parameters at once
- unlink_parameter: Remove parameter state by ID
- clear_states: Remove all parameter states
- is_parameter_linked: Check if a parameter is linked
§Serialization Support
The optimizer supports full serialization and deserialization with state preservation:
- Parameter states are saved with their shapes and insertion order for validation
- After deserialization, use relink_parameters to restore saved states to new tensors
- Parameters must be re-linked in the same order they were originally added
- Shape validation ensures consistency between saved and current parameters
§Features
- ID-Based Parameter Linking: Dynamic parameter management via tensor IDs
- Thread-Safe Step Method: Takes mutable references, guaranteeing exclusive access during updates
- Per-Parameter State: Each parameter maintains its own momentum and velocity buffers
- Bias Correction: Automatically corrects initialization bias in moment estimates
- Weight Decay: Optional L2 regularization with efficient implementation
- AMSGrad Support: Optional AMSGrad variant for improved convergence stability
- SIMD Optimization: AVX2-optimized updates for maximum performance
- Full Serialization: Complete state persistence and restoration
§Thread Safety
This type is thread-safe and can be shared between threads. The step method takes mutable references to parameters, ensuring exclusive access during updates.
Implementations
impl Adam
pub fn saved_parameter_count(&self) -> usize
Get the number of saved parameter states for checkpoint validation
This method returns the count of parameter states currently stored in the optimizer, which is essential for validating checkpoint integrity and ensuring proper parameter re-linking after deserialization. The count includes all parameters that have been linked to the optimizer and have accumulated optimization state.
§Returns
Number of parameter states currently stored in the optimizer
§Usage Patterns
§Checkpoint Validation
After deserializing an optimizer, this method helps verify that the expected number of parameters were saved and can guide the re-linking process.
§Training Resumption
When resuming training, compare this count with the number of parameters in your model to ensure checkpoint compatibility.
§State Management
Use this method to monitor optimizer state growth and memory usage during training with dynamic parameter addition.
§Examples
use train_station::{Tensor, optimizers::Adam};
use train_station::serialization::Serializable;

let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
let bias = Tensor::zeros(vec![5]).with_requires_grad();

let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
optimizer.add_parameter(&bias);

// Check parameter count before serialization
assert_eq!(optimizer.saved_parameter_count(), 2);

// Serialize and deserialize
let json = optimizer.to_json().unwrap();
let loaded_optimizer = Adam::from_json(&json).unwrap();

// Verify parameter count is preserved
assert_eq!(loaded_optimizer.saved_parameter_count(), 2);
§Performance
- Time Complexity: O(1) - Direct access to internal state count
- Memory Usage: No additional memory allocation
- Thread Safety: Safe to call from multiple threads concurrently
impl Adam
pub fn new() -> Self
Create a new Adam optimizer with default configuration
Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
Parameters must be linked separately using add_parameter or add_parameters.
§Returns
A new Adam optimizer instance with default hyperparameters
Examples found in repository
47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}
More examples
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85 println!("--- Default Adam Configuration ---");
86
87 // Create a simple regression problem: y = 2*x + 1
88 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91 // Create model parameters
92 let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95 // Create Adam optimizer with default configuration
96 let mut optimizer = Adam::new();
97 optimizer.add_parameter(&weight);
98 optimizer.add_parameter(&bias);
99
100 println!("Default Adam configuration:");
101 println!(" Learning rate: {}", optimizer.learning_rate());
102 println!(" Initial weight: {:.6}", weight.value());
103 println!(" Initial bias: {:.6}", bias.value());
104
105 // Training loop
106 let num_epochs = 50;
107 let mut losses = Vec::new();
108
109 for epoch in 0..num_epochs {
110 // Forward pass
111 let y_pred = x_data.matmul(&weight) + &bias;
112 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114 // Backward pass
115 loss.backward(None);
116
117 // Optimizer step
118 optimizer.step(&mut [&mut weight, &mut bias]);
119 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121 losses.push(loss.value());
122
123 if epoch % 10 == 0 || epoch == num_epochs - 1 {
124 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125 }
126 }
127
128 // Evaluate final model
129 let _final_predictions = x_data.matmul(&weight) + &bias;
130 println!("\nFinal model:");
131 println!(" Learned weight: {:.6} (target: 2.0)", weight.value());
132 println!(" Learned bias: {:.6} (target: 1.0)", bias.value());
133 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
134
135 Ok(())
136}
pub fn with_config(config: AdamConfig) -> Self
Create a new Adam optimizer with custom configuration
Allows full control over all Adam hyperparameters for specialized training
scenarios such as fine-tuning, transfer learning, or research applications.
Parameters must be linked separately using add_parameter or add_parameters.
§Arguments
config- Adam configuration with custom hyperparameters
§Returns
A new Adam optimizer instance with the specified configuration
Examples found in repository
47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}
More examples
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 // Create an optimizer with some parameters
113 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 // Simulate some training steps
136 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 // Save optimizer state
144 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 // Load optimizer state
149 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 // Verify optimizer state
157 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 // Create training data
319 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 // Create model parameters
323 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 // Create optimizer with custom configuration
327 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 // Training loop
341 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 // Forward pass
346 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 // Backward pass
350 loss.backward(None);
351
352 // Optimizer step
353 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 // Check for convergence (loss < 0.01)
360 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372}
223fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
224 println!("\n--- Training Loop ---");
225
226 // Create layer and training data
227 let mut layer = LinearLayer::new(2, 1, Some(45));
228
229 // Simple regression task: y = 2*x1 + 3*x2 + 1
230 let x_data = Tensor::from_slice(
231 &[
232 1.0, 1.0, // x1=1, x2=1 -> y=6
233 2.0, 1.0, // x1=2, x2=1 -> y=8
234 1.0, 2.0, // x1=1, x2=2 -> y=9
235 2.0, 2.0, // x1=2, x2=2 -> y=11
236 ],
237 vec![4, 2],
238 )
239 .unwrap();
240
241 let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();
242
243 println!("Training data:");
244 println!(" X shape: {:?}", x_data.shape().dims());
245 println!(" Y shape: {:?}", y_true.shape().dims());
246 println!(" Target function: y = 2*x1 + 3*x2 + 1");
247
248 // Create optimizer
249 let config = AdamConfig {
250 learning_rate: 0.01,
251 beta1: 0.9,
252 beta2: 0.999,
253 eps: 1e-8,
254 weight_decay: 0.0,
255 amsgrad: false,
256 };
257
258 let mut optimizer = Adam::with_config(config);
259 let params = layer.parameters();
260 for param in &params {
261 optimizer.add_parameter(param);
262 }
263
264 println!("Optimizer setup complete. Starting training...");
265
266 // Training loop
267 let num_epochs = 100;
268 let mut losses = Vec::new();
269
270 for epoch in 0..num_epochs {
271 // Forward pass
272 let y_pred = layer.forward(&x_data);
273
274 // Compute loss: MSE
275 let diff = y_pred.sub_tensor(&y_true);
276 let mut loss = diff.pow_scalar(2.0).mean();
277
278 // Backward pass
279 loss.backward(None);
280
281 // Optimizer step
282 let mut params = layer.parameters();
283 optimizer.step(&mut params);
284 optimizer.zero_grad(&mut params);
285
286 losses.push(loss.value());
287
288 // Print progress
289 if epoch % 20 == 0 || epoch == num_epochs - 1 {
290 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
291 }
292 }
293
294 // Evaluate final model
295 let final_predictions = layer.forward_no_grad(&x_data);
296
297 println!("\nFinal model evaluation:");
298 println!(" Learned weights: {:?}", layer.weight.data());
299 println!(" Learned bias: {:?}", layer.bias.data());
300 println!(" Target weights: [2.0, 3.0]");
301 println!(" Target bias: [1.0]");
302
303 println!(" Predictions vs True:");
304 for i in 0..4 {
305 let pred = final_predictions.data()[i];
306 let true_val = y_true.data()[i];
307 println!(
308 " Sample {}: pred={:.3}, true={:.1}, error={:.3}",
309 i + 1,
310 pred,
311 true_val,
312 (pred - true_val).abs()
313 );
314 }
315
316 // Training analysis
317 let initial_loss = losses[0];
318 let final_loss = losses[losses.len() - 1];
319 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
320
321 println!("\nTraining Analysis:");
322 println!(" Initial loss: {:.6}", initial_loss);
323 println!(" Final loss: {:.6}", final_loss);
324 println!(" Loss reduction: {:.1}%", loss_reduction);
325
326 Ok(())
327}
pub fn with_learning_rate(learning_rate: f32) -> Self
Create a new Adam optimizer with custom learning rate
A convenience constructor that allows setting only the learning rate while
using default values for all other hyperparameters. Parameters must be
linked separately using add_parameter or add_parameters.
§Arguments
learning_rate- Learning rate for optimization
§Returns
A new Adam optimizer instance with the specified learning rate and default values for all other hyperparameters
Examples found in repository
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
More examples
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102 for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}
148 pub fn train_non_autoregressive_steps(
149 &mut self,
150 src: &Tensor,
151 tgt: &Tensor,
152 steps: usize,
153 lr: f32,
154 ) {
155 let mut opt = Adam::with_learning_rate(lr);
156 {
157 let params_once = self.parameters();
158 for p in &params_once {
159 opt.add_parameter(p);
160 }
161 }
162 for step in 0..steps {
163 // forward + backward scope (immutable borrow)
164 {
165 let pred = self.forward(src, tgt);
166 let diff = pred.sub_tensor(tgt);
167 let mut loss = diff.pow_scalar(2.0).mean();
168 if step == 0 || step + 1 == steps {
169 println!("NAR train step {}: loss={:.6}", step, loss.value());
170 }
171 loss.backward(None);
172 }
173 // step + zero_grad scope (mutable borrow)
174 let mut params_step = self.parameters();
175 opt.step(&mut params_step);
176 opt.zero_grad(&mut params_step);
177 }
178 }
179
180 /// Auto-regressive training (teacher forcing): predict next token with causal mask
181 pub fn train_autoregressive_steps(
182 &mut self,
183 src: &Tensor,
184 tgt: &Tensor,
185 steps: usize,
186 lr: f32,
187 ) {
188 let mut opt = Adam::with_learning_rate(lr);
189 {
190 let params_once = self.parameters();
191 for p in &params_once {
192 opt.add_parameter(p);
193 }
194 }
195
196 // Build encoder memory once (static dataset demo)
197 let mut memory = src.clone();
198 for enc in &self.encoders {
199 memory = enc.forward(&memory, None);
200 }
201
202 let (b, t, _e) = Self::triple(tgt);
203 // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204 let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205 for step in 0..steps {
206 // forward + backward scope
207 {
208 let mut out = tgt.clone();
209 for dec in &self.decoders {
210 out = dec.forward(&out, &memory, Some(&causal), None);
211 }
212 let diff = out.sub_tensor(tgt);
213 let mut loss = diff.pow_scalar(2.0).mean();
214 if step == 0 || step + 1 == steps {
215 println!("AR train step {}: loss={:.6}", step, loss.value());
216 }
217 loss.backward(None);
218 }
219 let mut params_step = self.parameters();
220 opt.step(&mut params_step);
221 opt.zero_grad(&mut params_step);
222 }
223 }
224
225 fn triple(t: &Tensor) -> (usize, usize, usize) {
226 let d = t.shape().dims();
227 (d[0], d[1], d[2])
228 }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232 println!("=== Basic Transformer Example ===");
233
234 let batch = 2usize;
235 let src_len = 8usize;
236 let tgt_len = 6usize;
237 let embed = 32usize;
238 let heads = 4usize;
239 let layers = 2usize;
240
241 let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242 let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244 let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245 let out = trf.forward(&src, &tgt);
246 println!("Output shape: {:?}", out.shape().dims());
247
248 // Quick optimization step
249 let mut opt = Adam::with_learning_rate(0.005);
250 let mut params = trf.parameters();
251 for p in &params {
252 opt.add_parameter(p);
253 }
254 let mut loss = out.mean();
255 loss.backward(None);
256 opt.step(&mut params);
257 opt.zero_grad(&mut params);
258 println!("Loss: {:.6}", loss.value());
259
260 // Demo: non auto-regressive inference (single pass)
261 let nar = trf.infer_non_autoregressive(&src, tgt_len);
262 println!("NAR output shape: {:?}", nar.shape().dims());
263
264 // Demo: auto-regressive inference (toy)
265 let ar = trf.infer_autoregressive(&src, 3);
266 println!("AR output shape: {:?}", ar.shape().dims());
267
268 // NAR training demo
269 let nar_tgt = tgt.clone();
270 trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272 // AR training demo (teacher-forced)
273 let ar_tgt = tgt.clone();
274 trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275 println!("=== Done ===");
276 Ok(())
277}
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180 // Simple causal mask for target self-attention shape [b, h, tq, tk]
181 let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182 // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183 // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203 for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}
319fn train_with_scheduler(
320 scheduler: &mut dyn LearningRateScheduler,
321 num_epochs: usize,
322) -> Result<TrainingStats, Box<dyn std::error::Error>> {
323 // Create training data: y = 2*x + 1
324 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
325 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
326
327 // Create model parameters
328 let mut weight = Tensor::randn(vec![1, 1], Some(456)).with_requires_grad();
329 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
330
331 // Create optimizer with initial learning rate
332 let mut optimizer = Adam::with_learning_rate(0.05);
333 optimizer.add_parameter(&weight);
334 optimizer.add_parameter(&bias);
335
336 // Training loop
337 let mut losses = Vec::new();
338 let mut lr_history = Vec::new();
339 let mut convergence_epoch = num_epochs;
340
341 for epoch in 0..num_epochs {
342 // Forward pass
343 let y_pred = x_data.matmul(&weight) + &bias;
344 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
345
346 // Backward pass
347 loss.backward(None);
348
349 // Update learning rate using scheduler
350 let current_lr = optimizer.learning_rate();
351 let new_lr = scheduler.step(current_lr, epoch, loss.value());
352
353 if (new_lr - current_lr).abs() > 1e-8 {
354 optimizer.set_learning_rate(new_lr);
355 }
356
357 // Optimizer step
358 optimizer.step(&mut [&mut weight, &mut bias]);
359 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
360
361 let loss_value = loss.value();
362 losses.push(loss_value);
363 lr_history.push(new_lr);
364
365 // Check for convergence
366 if loss_value < 0.01 && convergence_epoch == num_epochs {
367 convergence_epoch = epoch;
368 }
369 }
370
371 Ok(TrainingStats {
372 scheduler_name: scheduler.name().to_string(),
373 final_loss: losses[losses.len() - 1],
374 lr_history,
375 loss_history: losses,
376 convergence_epoch,
377 })
378}
376fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
377 println!("\n--- Serialization ---");
378
379 // Create and train a simple layer
380 let mut original_layer = LinearLayer::new(2, 1, Some(47));
381
382 // Simple training data
383 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
384 let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();
385
386 let mut optimizer = Adam::with_learning_rate(0.01);
387 let params = original_layer.parameters();
388 for param in &params {
389 optimizer.add_parameter(param);
390 }
391
392 // Train for a few epochs
393 for _ in 0..10 {
394 let y_pred = original_layer.forward(&x_data);
395 let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
396 loss.backward(None);
397
398 let mut params = original_layer.parameters();
399 optimizer.step(&mut params);
400 optimizer.zero_grad(&mut params);
401 }
402
403 println!("Original layer trained");
404 println!(" Weight: {:?}", original_layer.weight.data());
405 println!(" Bias: {:?}", original_layer.bias.data());
406
407 // Save layer
408 original_layer.save_json("temp_linear_layer")?;
409
410 // Load layer
411 let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;
412
413 println!("Loaded layer");
414 println!(" Weight: {:?}", loaded_layer.weight.data());
415 println!(" Bias: {:?}", loaded_layer.bias.data());
416
417 // Verify consistency
418 let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
419 let original_output = original_layer.forward_no_grad(&test_input);
420 let loaded_output = loaded_layer.forward_no_grad(&test_input);
421
422 println!("Consistency check:");
423 println!(" Original output: {:?}", original_output.data());
424 println!(" Loaded output: {:?}", loaded_output.data());
425 println!(
426 " Match: {}",
427 original_output
428 .data()
429 .iter()
430 .zip(loaded_output.data().iter())
431 .all(|(a, b)| (a - b).abs() < 1e-6)
432 );
433
434 println!("Serialization verification: PASSED");
435
436 Ok(())
437}
- examples/getting_started/serialization_basics.rs
- examples/getting_started/optimizer_basics.rs
- examples/supervised_training/../neural_networks/feedforward_network.rs
- examples/supervised_training/supervised_bce.rs
- examples/supervised_training/supervised_classification.rs
- examples/supervised_training/supervised_regression.rs
- examples/RL_training/dqn.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/ppo_continuous.rs
- examples/RL_training/td3.rs
pub fn add_parameter(&mut self, parameter: &Tensor)
Add a single parameter to the optimizer
Links a parameter to the optimizer by creating a new parameter state
indexed by the tensor’s ID. The parameter must have requires_grad set to true.
§Arguments
parameter- Reference to the tensor to link
§Panics
Panics if the parameter does not have requires_grad set to true
Examples found in repository
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
More examples
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102 for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}
148 pub fn train_non_autoregressive_steps(
149 &mut self,
150 src: &Tensor,
151 tgt: &Tensor,
152 steps: usize,
153 lr: f32,
154 ) {
155 let mut opt = Adam::with_learning_rate(lr);
156 {
157 let params_once = self.parameters();
158 for p in &params_once {
159 opt.add_parameter(p);
160 }
161 }
162 for step in 0..steps {
163 // forward + backward scope (immutable borrow)
164 {
165 let pred = self.forward(src, tgt);
166 let diff = pred.sub_tensor(tgt);
167 let mut loss = diff.pow_scalar(2.0).mean();
168 if step == 0 || step + 1 == steps {
169 println!("NAR train step {}: loss={:.6}", step, loss.value());
170 }
171 loss.backward(None);
172 }
173 // step + zero_grad scope (mutable borrow)
174 let mut params_step = self.parameters();
175 opt.step(&mut params_step);
176 opt.zero_grad(&mut params_step);
177 }
178 }
179
180 /// Auto-regressive training (teacher forcing): predict next token with causal mask
181 pub fn train_autoregressive_steps(
182 &mut self,
183 src: &Tensor,
184 tgt: &Tensor,
185 steps: usize,
186 lr: f32,
187 ) {
188 let mut opt = Adam::with_learning_rate(lr);
189 {
190 let params_once = self.parameters();
191            for p in &params_once {
192 opt.add_parameter(p);
193 }
194 }
195
196 // Build encoder memory once (static dataset demo)
197 let mut memory = src.clone();
198 for enc in &self.encoders {
199 memory = enc.forward(&memory, None);
200 }
201
202 let (b, t, _e) = Self::triple(tgt);
203 // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204 let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205 for step in 0..steps {
206 // forward + backward scope
207 {
208 let mut out = tgt.clone();
209 for dec in &self.decoders {
210 out = dec.forward(&out, &memory, Some(&causal), None);
211 }
212 let diff = out.sub_tensor(tgt);
213 let mut loss = diff.pow_scalar(2.0).mean();
214 if step == 0 || step + 1 == steps {
215 println!("AR train step {}: loss={:.6}", step, loss.value());
216 }
217 loss.backward(None);
218 }
219 let mut params_step = self.parameters();
220 opt.step(&mut params_step);
221 opt.zero_grad(&mut params_step);
222 }
223 }
224
225 fn triple(t: &Tensor) -> (usize, usize, usize) {
226 let d = t.shape().dims();
227 (d[0], d[1], d[2])
228 }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232 println!("=== Basic Transformer Example ===");
233
234 let batch = 2usize;
235 let src_len = 8usize;
236 let tgt_len = 6usize;
237 let embed = 32usize;
238 let heads = 4usize;
239 let layers = 2usize;
240
241 let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242 let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244 let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245 let out = trf.forward(&src, &tgt);
246 println!("Output shape: {:?}", out.shape().dims());
247
248 // Quick optimization step
249 let mut opt = Adam::with_learning_rate(0.005);
250 let mut params = trf.parameters();
251    for p in &params {
252 opt.add_parameter(p);
253 }
254 let mut loss = out.mean();
255 loss.backward(None);
256 opt.step(&mut params);
257 opt.zero_grad(&mut params);
258 println!("Loss: {:.6}", loss.value());
259
260 // Demo: non auto-regressive inference (single pass)
261 let nar = trf.infer_non_autoregressive(&src, tgt_len);
262 println!("NAR output shape: {:?}", nar.shape().dims());
263
264 // Demo: auto-regressive inference (toy)
265 let ar = trf.infer_autoregressive(&src, 3);
266 println!("AR output shape: {:?}", ar.shape().dims());
267
268 // NAR training demo
269 let nar_tgt = tgt.clone();
270 trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272 // AR training demo (teacher-forced)
273 let ar_tgt = tgt.clone();
274 trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275 println!("=== Done ===");
276 Ok(())
277}
47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}
103
104/// Demonstrate simple linear regression training
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106 println!("\n--- Linear Regression Training ---");
107
108 // Create model parameters
109 let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112 // Create optimizer
113 let mut optimizer = Adam::with_learning_rate(0.01);
114 optimizer.add_parameter(&weight);
115 optimizer.add_parameter(&bias);
116
117 // Create simple training data: y = 2*x + 1
118 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121 println!("Training data:");
122 println!(" X: {:?}", x_data.data());
123 println!(" Y: {:?}", y_true.data());
124 println!(" Target: y = 2*x + 1");
125
126 // Training loop
127 let num_epochs = 100;
128 let mut losses = Vec::new();
129
130 for epoch in 0..num_epochs {
131 // Forward pass: y_pred = x * weight + bias
132 let y_pred = x_data.matmul(&weight) + &bias;
133
134 // Compute loss: MSE
135 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137 // Backward pass
138 loss.backward(None);
139
140 // Optimizer step
141 optimizer.step(&mut [&mut weight, &mut bias]);
142 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144 losses.push(loss.value());
145
146 // Print progress every 20 epochs
147 if epoch % 20 == 0 || epoch == num_epochs - 1 {
148 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149 }
150 }
151
152 // Evaluate final model
153 let final_predictions = x_data.matmul(&weight) + &bias;
154 println!("\nFinal model evaluation:");
155 println!(" Learned weight: {:.6}", weight.value());
156 println!(" Learned bias: {:.6}", bias.value());
157 println!(" Predictions vs True:");
158
159 for i in 0..5 {
160 let x1 = x_data.data()[i];
161 let pred = final_predictions.data()[i];
162 let true_val = y_true.data()[i];
163 println!(
164 " x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165 x1,
166 pred,
167 true_val,
168 (pred - true_val).abs()
169 );
170 }
171
172 Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177 println!("\n--- Advanced Training Patterns ---");
178
179 // Create a more complex model
180 let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181 let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183 // Create optimizer with different learning rate
184 let mut optimizer = Adam::with_learning_rate(0.005);
185 optimizer.add_parameter(&weight);
186 optimizer.add_parameter(&bias);
187
188    // Create training data: y = 2*x + [1, 4] (targets stored row-major, one [y1, y2] pair per x)
189    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190    let y_true = Tensor::from_slice(
191        &[3.0, 6.0, 5.0, 8.0, 7.0, 10.0, 9.0, 12.0, 11.0, 14.0],
192        vec![5, 2],
193    )
194    .unwrap();
195
196 println!("Advanced training with monitoring:");
197 println!(" Initial learning rate: {}", optimizer.learning_rate());
198
199 // Training loop with monitoring
200 let num_epochs = 50;
201 let mut losses = Vec::new();
202 let mut weight_norms = Vec::new();
203 let mut gradient_norms = Vec::new();
204
205 for epoch in 0..num_epochs {
206 // Forward pass
207 let y_pred = x_data.matmul(&weight) + &bias;
208 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210 // Backward pass
211 loss.backward(None);
212
213 // Compute gradient norm before optimizer step
214 let gradient_norm = weight.grad_owned().unwrap().norm();
215
216 // Optimizer step
217 optimizer.step(&mut [&mut weight, &mut bias]);
218 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220 // Learning rate scheduling: reduce every 10 epochs
221 if epoch > 0 && epoch % 10 == 0 {
222 let current_lr = optimizer.learning_rate();
223 let new_lr = current_lr * 0.5;
224 optimizer.set_learning_rate(new_lr);
225 println!(
226 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227 epoch, current_lr, new_lr
228 );
229 }
230
231 // Record metrics
232 losses.push(loss.value());
233 weight_norms.push(weight.norm().value());
234 gradient_norms.push(gradient_norm.value());
235
236 // Print detailed progress
237 if epoch % 10 == 0 || epoch == num_epochs - 1 {
238 println!(
239 "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240 epoch,
241 loss.value(),
242 weight.norm().value(),
243 gradient_norm.value()
244 );
245 }
246 }
247
248 println!("Final learning rate: {}", optimizer.learning_rate());
249
250 // Analyze training progression
251 let initial_loss = losses[0];
252 let final_loss = losses[losses.len() - 1];
253 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255 println!("\nTraining Analysis:");
256 println!(" Initial loss: {:.6}", initial_loss);
257 println!(" Final loss: {:.6}", final_loss);
258 println!(" Loss reduction: {:.1}%", loss_reduction);
259 println!(" Final weight norm: {:.6}", weight.norm().value());
260 println!(" Final bias: {:?}", bias.data());
261
262 Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267 println!("\n--- Learning Rate Scheduling ---");
268
269 // Create simple model
270 let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273 // Create optimizer with high initial learning rate
274 let mut optimizer = Adam::with_learning_rate(0.1);
275 optimizer.add_parameter(&weight);
276 optimizer.add_parameter(&bias);
277
278 // Simple data
279 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280 let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282 println!("Initial learning rate: {}", optimizer.learning_rate());
283
284 // Training loop with learning rate scheduling
285 let num_epochs = 50;
286 let mut losses = Vec::new();
287
288 for epoch in 0..num_epochs {
289 // Forward pass
290 let y_pred = x_data.matmul(&weight) + &bias;
291 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293 // Backward pass
294 loss.backward(None);
295
296 // Optimizer step
297 optimizer.step(&mut [&mut weight, &mut bias]);
298 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300 // Learning rate scheduling: reduce every 10 epochs
301 if epoch > 0 && epoch % 10 == 0 {
302 let current_lr = optimizer.learning_rate();
303 let new_lr = current_lr * 0.5;
304 optimizer.set_learning_rate(new_lr);
305 println!(
306 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307 epoch, current_lr, new_lr
308 );
309 }
310
311 losses.push(loss.value());
312
313 // Print progress
314 if epoch % 10 == 0 || epoch == num_epochs - 1 {
315 println!(
316 "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317 epoch,
318 loss.value(),
319 optimizer.learning_rate()
320 );
321 }
322 }
323
324 println!("Final learning rate: {}", optimizer.learning_rate());
325
326 Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331 println!("\n--- Training Monitoring ---");
332
333 // Create model
334 let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337 // Create optimizer
338 let mut optimizer = Adam::with_learning_rate(0.01);
339 optimizer.add_parameter(&weight);
340 optimizer.add_parameter(&bias);
341
342 // Training data
343 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346 // Training loop with comprehensive monitoring
347 let num_epochs = 30;
348 let mut losses = Vec::new();
349 let mut weight_history = Vec::new();
350 let mut bias_history = Vec::new();
351
352 for epoch in 0..num_epochs {
353 // Forward pass
354 let y_pred = x_data.matmul(&weight) + &bias;
355 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357 // Backward pass
358 loss.backward(None);
359
360 // Optimizer step
361 optimizer.step(&mut [&mut weight, &mut bias]);
362 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364 // Record history
365 losses.push(loss.value());
366 weight_history.push(weight.value());
367 bias_history.push(bias.value());
368
369 // Print detailed monitoring
370 if epoch % 5 == 0 || epoch == num_epochs - 1 {
371 println!(
372 "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373 epoch,
374 loss.value(),
375 weight.value(),
376 bias.value()
377 );
378 }
379 }
380
381 // Analyze training progression
382 println!("\nTraining Analysis:");
383 println!(" Initial loss: {:.6}", losses[0]);
384 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
385 println!(
386 " Loss reduction: {:.1}%",
387 (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388 );
389
390 // Compute statistics
391 let loss_mean = compute_mean(&losses);
392 let loss_std = compute_std(&losses);
393 let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394 let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396 println!(" Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397 println!(" Weight change: {:.6}", weight_change);
398 println!(" Bias change: {:.6}", bias_change);
399 println!(" Final weight norm: {:.6}", weight.norm().value());
400 println!(" Final bias: {:.6}", bias.value());
401
402 Ok(())
403}
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 // Create an optimizer with some parameters
113 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 // Simulate some training steps
136 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 // Save optimizer state
144 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 // Load optimizer state
149 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 // Verify optimizer state
157 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169 println!("\n--- Format Comparison ---");
170
171 // Create a larger tensor for comparison
172 let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174 // Save in both formats
175 tensor.save_json("temp_comparison.json")?;
176 tensor.save_binary("temp_comparison.bin")?;
177
178 // Compare file sizes
179 let json_size = fs::metadata("temp_comparison.json")?.len();
180 let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182 println!("JSON file size: {} bytes", json_size);
183 println!("Binary file size: {} bytes", binary_size);
184 println!(
185 "Compression ratio: {:.2}x",
186 json_size as f64 / binary_size as f64
187 );
188
189 // Load and verify both formats
190 let json_tensor = Tensor::load_json("temp_comparison.json")?;
191 let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193 assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194 assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195 assert_eq!(tensor.data(), json_tensor.data());
196 assert_eq!(tensor.data(), binary_tensor.data());
197
198 println!("Format comparison verification: PASSED");
199
200 Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205 println!("\n--- Model Checkpointing ---");
206
207 // Create a simple model (weights and bias)
208 let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209 let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211 // Create optimizer
212 let mut optimizer = Adam::with_learning_rate(0.01);
213 optimizer.add_parameter(&weights);
214 optimizer.add_parameter(&bias);
215
216 println!("Initial weights: {:?}", weights.data());
217 println!("Initial bias: {:?}", bias.data());
218
219 // Simulate training
220 for epoch in 0..5 {
221 let mut loss = weights.sum() + bias.sum();
222 loss.backward(None);
223 optimizer.step(&mut [&mut weights, &mut bias]);
224 optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226 if epoch % 2 == 0 {
227 // Save checkpoint
228 let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229 fs::create_dir_all(&checkpoint_dir)?;
230
231 weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232 bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233 optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235 println!("Saved checkpoint for epoch {}", epoch);
236 }
237 }
238
239 // Load from checkpoint
240 let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241 let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242 let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244 println!("Loaded weights: {:?}", loaded_weights.data());
245 println!("Loaded bias: {:?}", loaded_bias.data());
246 println!(
247 "Loaded optimizer learning rate: {}",
248 loaded_optimizer.learning_rate()
249 );
250
251 // Verify checkpoint integrity
252 assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253 assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256 println!("Checkpointing verification: PASSED");
257
258 Ok(())
259}
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal-style mask with shape [b, h, tgt_len, src_len]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on head 0
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203    for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}
- examples/optimizers/adam_configurations.rs
- examples/optimizers/learning_rate_scheduling.rs
- examples/supervised_training/../neural_networks/feedforward_network.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/supervised_training/supervised_bce.rs
- examples/supervised_training/supervised_classification.rs
- examples/supervised_training/supervised_regression.rs
- examples/RL_training/dqn.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/ppo_continuous.rs
- examples/RL_training/td3.rs
pub fn add_parameters(&mut self, parameters: &[&Tensor])
Add multiple parameters to the optimizer
Links multiple parameters to the optimizer by creating parameter states
indexed by each tensor’s ID. All parameters must have requires_grad set to true.
§Arguments
parameters - Slice of references to tensors to link
§Panics
Panics if any parameter does not have requires_grad set to true
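A minimal usage sketch, assuming the `Tensor` constructors shown in the examples on this page (not runnable without the crate):

```rust
use train_station::{optimizers::Adam, Tensor};

// Both tensors must have requires_grad set, or add_parameters panics
let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
let bias = Tensor::zeros(vec![2]).with_requires_grad();

let mut optimizer = Adam::new();
// Links both parameters in one call; equivalent to two add_parameter calls
optimizer.add_parameters(&[&weight, &bias]);
assert_eq!(optimizer.parameter_count(), 2);
```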
pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool
Remove a parameter's state from the optimizer by tensor ID
pub fn clear_states(&mut self)
Remove all parameter states from the optimizer
Clears all parameter states, effectively unlinking all parameters. This is useful for resetting the optimizer or preparing for parameter re-linking.
pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool
Check whether a parameter is currently linked to the optimizer
pub fn parameter_count(&self) -> usize
Get the number of linked parameters
Returns the count of parameters currently linked to the optimizer.
§Returns
Number of linked parameters
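Taken together, the dynamic-management methods above compose as in this sketch (only calls documented on this page; not runnable without the crate):

```rust
use train_station::{optimizers::Adam, Tensor};

let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
let bias = Tensor::zeros(vec![2]).with_requires_grad();

let mut optimizer = Adam::new();
optimizer.add_parameters(&[&weight, &bias]);
assert!(optimizer.is_parameter_linked(&weight));

// Drop a single parameter's state by ID
optimizer.unlink_parameter(&bias);
assert_eq!(optimizer.parameter_count(), 1);

// Drop everything, e.g. before re-linking after deserialization
optimizer.clear_states();
assert_eq!(optimizer.parameter_count(), 0);
```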
pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String>
Re-link parameters to saved optimizer states in chronological order
After deserializing an optimizer, use this method to restore saved parameter states to new tensors. Parameters must be provided in the same chronological order they were originally added to the optimizer. Shape validation ensures parameter compatibility.
§Arguments
parameters - Slice of parameter references in chronological order
§Returns
Result indicating success or failure with detailed error message
§Panics
Panics if any parameter does not have requires_grad set to true
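A sketch of the intended save/restore cycle, assuming the `Serializable` methods documented on this page and new tensors that match the saved shapes (not runnable without the crate):

```rust
use train_station::serialization::Serializable;
use train_station::{optimizers::Adam, Tensor};

let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
let bias = Tensor::ones(vec![3]).with_requires_grad();

let mut optimizer = Adam::new();
optimizer.add_parameter(&weight); // linked first
optimizer.add_parameter(&bias);   // linked second

let json = optimizer.to_json().unwrap();
let mut restored = Adam::from_json(&json).unwrap();

// Re-link in the same chronological order; shapes are validated
restored
    .relink_parameters(&[&weight, &bias])
    .expect("shapes and order must match the saved states");
```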
pub fn config(&self) -> &AdamConfig
Get the current optimizer configuration
Returns a reference to the current configuration, allowing inspection of all hyperparameters without modification.
§Returns
Reference to the current Adam configuration
Trait Implementations§
impl FromFieldValue for Adam
fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self>
impl Optimizer for Adam
fn step(&mut self, parameters: &mut [&mut Tensor])
Perform a single optimization step
Updates all provided parameters from their accumulated gradients using the Adam algorithm: each parameter follows the Adam update rule with bias correction, plus the AMSGrad variant when enabled. All parameters must be linked to the optimizer before calling this method.
§Arguments
parameters - Mutable slice of parameter references to update
§Thread Safety
This method is thread-safe as it takes mutable references to parameters, ensuring exclusive access during updates.
§Performance
- Uses SIMD optimization (AVX2) when available for 8x vectorization
- Processes parameters in sequence for optimal cache usage
- Maintains per-parameter state for momentum and velocity estimates
§Panics
Panics if any parameter is not linked to the optimizer
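For reference, the per-element update this method applies can be written out in plain Rust. This is the textbook Adam rule with bias correction (Kingma & Ba), not this crate's SIMD implementation; the scalar helper below is purely illustrative:

```rust
// Illustrative scalar Adam update with bias correction.
// Names (m, v, beta1, beta2, eps) follow the standard Adam paper,
// not this crate's internal field names.
fn adam_step(
    param: &mut f32,
    grad: f32,
    m: &mut f32,
    v: &mut f32,
    t: u64, // 1-based step count
    lr: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
) {
    // Update biased first and second moment estimates
    *m = beta1 * *m + (1.0 - beta1) * grad;
    *v = beta2 * *v + (1.0 - beta2) * grad * grad;
    // Bias correction compensates for zero-initialized moments
    let m_hat = *m / (1.0 - beta1.powi(t as i32));
    let v_hat = *v / (1.0 - beta2.powi(t as i32));
    // Parameter update
    *param -= lr * m_hat / (v_hat.sqrt() + eps);
}

fn main() {
    // Minimize f(x) = x^2, whose gradient is 2x
    let (mut x, mut m, mut v) = (1.0_f32, 0.0_f32, 0.0_f32);
    for t in 1..=200_u64 {
        let grad = 2.0 * x;
        adam_step(&mut x, grad, &mut m, &mut v, t, 0.1, 0.9, 0.999, 1e-8);
    }
    println!("x after 200 Adam steps: {x}");
}
```

Running `main` drives `x` from 1.0 toward the minimum of x² at 0; the weight-decay and AMSGrad extensions mentioned above add terms on top of this core rule.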
fn zero_grad(&mut self, parameters: &mut [&mut Tensor])
Zero out all parameter gradients
Clears accumulated gradients for all provided parameters. This should be called before each backward pass to prevent gradient accumulation across multiple forward/backward passes. Also clears the global autograd gradient map.
§Arguments
parameters - Mutable slice of parameter references to clear gradients for
§Performance
- Efficiently clears gradients using optimized tensor operations
- Clears both per-tensor gradients and global autograd state
- Thread-safe as it takes mutable references to parameters
fn learning_rate(&self) -> f32
Get the current learning rate
Returns the current learning rate used for parameter updates.
§Returns
Current learning rate as f32
fn set_learning_rate(&mut self, lr: f32)
Set the learning rate for all parameters
Updates the learning rate for all parameters in the optimizer. This allows dynamic learning rate scheduling during training.
§Arguments
lr - New learning rate value
impl Serializable for Adam
fn to_json(&self) -> SerializationResult<String>
Serialize the Adam optimizer to JSON format
This method converts the Adam optimizer into a human-readable JSON string representation that includes all optimizer state, configuration, parameter states, and step counts. The JSON format is suitable for debugging, configuration files, and cross-language interoperability.
§Returns
JSON string representation of the optimizer on success, or SerializationError on failure
§Examples
use train_station::{Tensor, optimizers::Adam};
use train_station::serialization::Serializable;
let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
let json = optimizer.to_json().unwrap();
assert!(!json.is_empty());
fn from_json(json: &str) -> SerializationResult<Self>
Deserialize an Adam optimizer from JSON format
This method parses a JSON string and reconstructs an Adam optimizer with all
saved state. Parameters must be re-linked after deserialization using
add_parameter or relink_parameters.
§Arguments
json - JSON string containing serialized optimizer
§Returns
The deserialized optimizer on success, or SerializationError on failure
§Examples
use train_station::{Tensor, optimizers::Adam};
use train_station::serialization::Serializable;
let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
let json = optimizer.to_json().unwrap();
let loaded_optimizer = Adam::from_json(&json).unwrap();
assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
fn to_binary(&self) -> SerializationResult<Vec<u8>>
Serialize the Adam optimizer to binary format
This method converts the optimizer into a compact binary representation optimized for storage and transmission. The binary format provides maximum performance and minimal file sizes compared to JSON.
§Returns
Binary representation of the optimizer on success, or SerializationError on failure
§Examples
use train_station::{Tensor, optimizers::Adam};
use train_station::serialization::Serializable;
let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
let binary = optimizer.to_binary().unwrap();
assert!(!binary.is_empty());
fn from_binary(data: &[u8]) -> SerializationResult<Self>
Deserialize an Adam optimizer from binary format
This method parses binary data and reconstructs an Adam optimizer with all
saved state. Parameters must be re-linked after deserialization using
add_parameter or relink_parameters.
§Arguments
data - Binary data containing the serialized optimizer
§Returns
The deserialized optimizer on success, or SerializationError on failure
§Examples
use train_station::{Tensor, optimizers::Adam};
use train_station::serialization::Serializable;
let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut optimizer = Adam::new();
optimizer.add_parameter(&weight);
let binary = optimizer.to_binary().unwrap();
let loaded_optimizer = Adam::from_binary(&binary).unwrap();
assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
fn save<P: AsRef<Path>>(&self, path: P, format: Format) -> SerializationResult<()>
fn save_to_writer<W: Write>(&self, writer: &mut W, format: Format) -> SerializationResult<()>
fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self>
fn load_from_reader<R: Read>(reader: &mut R, format: Format) -> SerializationResult<Self>
impl StructSerializable for Adam
fn to_serializer(&self) -> StructSerializer
Convert Adam to StructSerializer for serialization
Serializes all optimizer state including configuration, parameter states, and global step count. Parameter linking is not serialized and must be done after deserialization.
§Returns
StructSerializer containing all serializable optimizer state
fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self>
Create Adam from StructDeserializer
Reconstructs Adam optimizer from serialized state. Parameters must be
linked separately using add_parameter or add_parameters.
§Arguments
deserializer - StructDeserializer containing the optimizer data
§Returns
Reconstructed Adam instance without parameter links, or error if deserialization fails