pub struct AdamW<F: Float + ScalarOperand + Debug> { /* private fields */ }
AdamW optimizer for neural networks
Implements the AdamW optimizer from the paper: “Decoupled Weight Decay Regularization” by Loshchilov and Hutter (2017).
AdamW is a variant of Adam that decouples weight decay from the adaptive gradient update, which improves generalization and training stability.
The key difference from Adam is that weight decay is applied directly to the weights rather than being added to the gradients as an L2 penalty.
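The distinction can be written as a minimal per-parameter sketch (illustrative only, not the crate's internal code; m_hat and v_hat stand for the bias-corrected first and second moment estimates that Adam already maintains):

// Adam with L2 regularization folds the decay term into the gradient, so it
// is rescaled by the adaptive denominator along with the gradient:
//     g = grad + weight_decay * theta;   // m and v are then built from g
// AdamW applies the decay directly to the weight, outside the adaptive step:
fn adamw_step(theta: f64, m_hat: f64, v_hat: f64, lr: f64, eps: f64, wd: f64) -> f64 {
    theta - lr * (m_hat / (v_hat.sqrt() + eps) + wd * theta)
}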
§Examples
use ndarray::Array1;
use scirs2_neural::optimizers::{AdamW, Optimizer};
// Create a simple AdamW optimizer with default parameters
let mut adamw = AdamW::<f64>::default_with_lr(0.001).unwrap();
// or with custom configuration
let mut adamw_custom = AdamW::new(0.001, 0.9, 0.999, 1e-8, 0.01);
Implementations§
impl<F: Float + ScalarOperand + Debug> AdamW<F>
pub fn new(
    learning_rate: F,
    beta1: F,
    beta2: F,
    epsilon: F,
    weight_decay: F,
) -> Self
Creates a new AdamW optimizer with the given hyperparameters
§Arguments
learning_rate - The learning rate for parameter updates
beta1 - Exponential decay rate for the first moment estimates (default: 0.9)
beta2 - Exponential decay rate for the second moment estimates (default: 0.999)
epsilon - Small constant for numerical stability (default: 1e-8)
weight_decay - Weight decay factor (default: 0.01)
Examples found in repository
examples/advanced_optimizers_example.rs (line 124)
11fn main() -> Result<(), Box<dyn std::error::Error>> {
12 println!("Advanced Optimizers Example");
13
14 // Initialize random number generator
15 let mut rng = SmallRng::seed_from_u64(42);
16
17 // Create a synthetic binary classification dataset
18 let num_samples = 1000;
19 let num_features = 20;
20 let num_classes = 2;
21
22 println!(
23 "Generating synthetic dataset with {} samples, {} features...",
24 num_samples, num_features
25 );
26
27 // Generate random input features
28 let mut x_data = Array2::<f32>::zeros((num_samples, num_features));
29 for i in 0..num_samples {
30 for j in 0..num_features {
31 x_data[[i, j]] = rng.random_range(-1.0..1.0);
32 }
33 }
34
35 // Create true weights and bias for data generation
36 let mut true_weights = Array2::<f32>::zeros((num_features, 1));
37 for i in 0..num_features {
38 true_weights[[i, 0]] = rng.random_range(-1.0..1.0);
39 }
40 let true_bias = rng.random_range(-1.0..1.0);
41
42 // Generate binary labels (0 or 1) based on linear model with logistic function
43 let mut y_data = Array2::<f32>::zeros((num_samples, num_classes));
44 for i in 0..num_samples {
45 let mut logit = true_bias;
46 for j in 0..num_features {
47 logit += x_data[[i, j]] * true_weights[[j, 0]];
48 }
49
50 // Apply sigmoid to get probability
51 let prob = 1.0 / (1.0 + (-logit).exp());
52
53 // Convert to one-hot encoding
54 if prob > 0.5 {
55 y_data[[i, 1]] = 1.0; // Class 1
56 } else {
57 y_data[[i, 0]] = 1.0; // Class 0
58 }
59 }
60
61 // Split into train and test sets (80% train, 20% test)
62 let train_size = (num_samples as f32 * 0.8) as usize;
63 let test_size = num_samples - train_size;
64
65 let x_train = x_data.slice(ndarray::s![0..train_size, ..]).to_owned();
66 let y_train = y_data.slice(ndarray::s![0..train_size, ..]).to_owned();
67 let x_test = x_data.slice(ndarray::s![train_size.., ..]).to_owned();
68 let y_test = y_data.slice(ndarray::s![train_size.., ..]).to_owned();
69
70 println!("Training set: {} samples", train_size);
71 println!("Test set: {} samples", test_size);
72
73 // Create a simple neural network model
74 let hidden_size = 64;
75 let dropout_rate = 0.2;
76 let seed_rng = SmallRng::seed_from_u64(42);
77
78 // Shared function to create identical model architectures for fair comparison
79 let create_model = || -> Result<Sequential<f32>, Box<dyn std::error::Error>> {
80 let mut model = Sequential::new();
81
82 // Input to hidden layer
83 let dense1 = Dense::new(
84 num_features,
85 hidden_size,
86 Some("relu"),
87 &mut seed_rng.clone(),
88 )?;
89 model.add_layer(dense1);
90
91 // Dropout for regularization
92 let dropout = Dropout::new(dropout_rate, &mut seed_rng.clone())?;
93 model.add_layer(dropout);
94
95 // Hidden to output layer
96 let dense2 = Dense::new(
97 hidden_size,
98 num_classes,
99 Some("softmax"),
100 &mut seed_rng.clone(),
101 )?;
102 model.add_layer(dense2);
103
104 Ok(model)
105 };
106
107 // Create models for each optimizer
108 let mut sgd_model = create_model()?;
109 let mut adam_model = create_model()?;
110 let mut adamw_model = create_model()?;
111 let mut radam_model = create_model()?;
112 let mut rmsprop_model = create_model()?;
113
114 // Create the loss function
115 let loss_fn = CrossEntropyLoss::new(1e-10);
116
117 // Create optimizers
118 let learning_rate = 0.001;
119 let batch_size = 32;
120 let epochs = 20;
121
122 let mut sgd_optimizer = SGD::new_with_config(learning_rate, 0.9, 0.0);
123 let mut adam_optimizer = Adam::new(learning_rate, 0.9, 0.999, 1e-8);
124 let mut adamw_optimizer = AdamW::new(learning_rate, 0.9, 0.999, 1e-8, 0.01);
125 let mut radam_optimizer = RAdam::new(learning_rate, 0.9, 0.999, 1e-8, 0.0);
126 let mut rmsprop_optimizer = RMSprop::new_with_config(learning_rate, 0.9, 1e-8, 0.0);
127
128 // Helper function to compute accuracy
129 let compute_accuracy = |model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>| -> f32 {
130 let predictions = model.forward(&x.clone().into_dyn()).unwrap();
131 let mut correct = 0;
132
133 for i in 0..x.shape()[0] {
134 let mut max_idx = 0;
135 let mut max_val = predictions[[i, 0]];
136
137 for j in 1..num_classes {
138 if predictions[[i, j]] > max_val {
139 max_val = predictions[[i, j]];
140 max_idx = j;
141 }
142 }
143
144                // The true class is 1 when the one-hot label's larger value is in column 1.
145                let true_idx = if y[[i, 1]] > y[[i, 0]] { 1 } else { 0 };
146                if max_idx == true_idx {
147 correct += 1;
148 }
149 }
150
151 correct as f32 / x.shape()[0] as f32
152 };
153
154 // Helper function to train model
155 let mut train_model =
156 |model: &mut Sequential<f32>, optimizer: &mut dyn Optimizer<f32>, name: &str| -> Vec<f32> {
157 println!("\nTraining with {} optimizer...", name);
158 let start_time = Instant::now();
159
160 let mut train_losses = Vec::new();
161 let num_batches = train_size.div_ceil(batch_size);
162
163 for epoch in 0..epochs {
164 let mut epoch_loss = 0.0;
165
166 // Create a permutation for shuffling the data
167 let mut indices: Vec<usize> = (0..train_size).collect();
168 indices.shuffle(&mut rng);
169
170 for batch_idx in 0..num_batches {
171 let start = batch_idx * batch_size;
172 let end = (start + batch_size).min(train_size);
173 let batch_indices = &indices[start..end];
174
175 // Create batch data
176 let mut x_batch = Array2::<f32>::zeros((batch_indices.len(), num_features));
177 let mut y_batch = Array2::<f32>::zeros((batch_indices.len(), num_classes));
178
179 for (i, &idx) in batch_indices.iter().enumerate() {
180 for j in 0..num_features {
181 x_batch[[i, j]] = x_train[[idx, j]];
182 }
183 for j in 0..num_classes {
184 y_batch[[i, j]] = y_train[[idx, j]];
185 }
186 }
187
188 // Convert to dynamic dimension arrays
189 let x_batch_dyn = x_batch.into_dyn();
190 let y_batch_dyn = y_batch.into_dyn();
191
192 // Perform a training step
193 let batch_loss = model
194 .train_batch(&x_batch_dyn, &y_batch_dyn, &loss_fn, optimizer)
195 .unwrap();
196 epoch_loss += batch_loss;
197 }
198
199 epoch_loss /= num_batches as f32;
200 train_losses.push(epoch_loss);
201
202 // Calculate and print metrics every few epochs
203 if epoch % 5 == 0 || epoch == epochs - 1 {
204 let train_accuracy = compute_accuracy(model, &x_train, &y_train);
205 let test_accuracy = compute_accuracy(model, &x_test, &y_test);
206
207 println!(
208 "Epoch {}/{}: loss = {:.6}, train_acc = {:.2}%, test_acc = {:.2}%",
209 epoch + 1,
210 epochs,
211 epoch_loss,
212 train_accuracy * 100.0,
213 test_accuracy * 100.0
214 );
215 }
216 }
217
218 let elapsed = start_time.elapsed();
219 println!("{} training completed in {:.2?}", name, elapsed);
220
221 // Final evaluation
222 let train_accuracy = compute_accuracy(model, &x_train, &y_train);
223 let test_accuracy = compute_accuracy(model, &x_test, &y_test);
224 println!("Final metrics for {}:", name);
225 println!(" Train accuracy: {:.2}%", train_accuracy * 100.0);
226 println!(" Test accuracy: {:.2}%", test_accuracy * 100.0);
227
228 train_losses
229 };
230
231 // Train models with different optimizers
232 let sgd_losses = train_model(&mut sgd_model, &mut sgd_optimizer, "SGD");
233 let adam_losses = train_model(&mut adam_model, &mut adam_optimizer, "Adam");
234 let adamw_losses = train_model(&mut adamw_model, &mut adamw_optimizer, "AdamW");
235 let radam_losses = train_model(&mut radam_model, &mut radam_optimizer, "RAdam");
236 let rmsprop_losses = train_model(&mut rmsprop_model, &mut rmsprop_optimizer, "RMSprop");
237
238 // Print comparison summary
239 println!("\nOptimizer Comparison Summary:");
240 println!("----------------------------");
241 println!("Initial learning rate: {}", learning_rate);
242 println!("Batch size: {}", batch_size);
243 println!("Epochs: {}", epochs);
244 println!();
245
246 println!("Final Loss Values:");
247 println!(" SGD: {:.6}", sgd_losses.last().unwrap());
248 println!(" Adam: {:.6}", adam_losses.last().unwrap());
249 println!(" AdamW: {:.6}", adamw_losses.last().unwrap());
250 println!(" RAdam: {:.6}", radam_losses.last().unwrap());
251 println!(" RMSprop: {:.6}", rmsprop_losses.last().unwrap());
252
253 println!("\nLoss progression (first value, middle value, last value):");
254 println!(
255 " SGD: {:.6}, {:.6}, {:.6}",
256 sgd_losses.first().unwrap(),
257 sgd_losses[epochs / 2],
258 sgd_losses.last().unwrap()
259 );
260 println!(
261 " Adam: {:.6}, {:.6}, {:.6}",
262 adam_losses.first().unwrap(),
263 adam_losses[epochs / 2],
264 adam_losses.last().unwrap()
265 );
266 println!(
267 " AdamW: {:.6}, {:.6}, {:.6}",
268 adamw_losses.first().unwrap(),
269 adamw_losses[epochs / 2],
270 adamw_losses.last().unwrap()
271 );
272 println!(
273 " RAdam: {:.6}, {:.6}, {:.6}",
274 radam_losses.first().unwrap(),
275 radam_losses[epochs / 2],
276 radam_losses.last().unwrap()
277 );
278 println!(
279 " RMSprop: {:.6}, {:.6}, {:.6}",
280 rmsprop_losses.first().unwrap(),
281 rmsprop_losses[epochs / 2],
282 rmsprop_losses.last().unwrap()
283 );
284
285 println!("\nLoss improvement ratio (first loss / last loss):");
286 println!(
287 " SGD: {:.2}x",
288 sgd_losses.first().unwrap() / sgd_losses.last().unwrap()
289 );
290 println!(
291 " Adam: {:.2}x",
292 adam_losses.first().unwrap() / adam_losses.last().unwrap()
293 );
294 println!(
295 " AdamW: {:.2}x",
296 adamw_losses.first().unwrap() / adamw_losses.last().unwrap()
297 );
298 println!(
299 " RAdam: {:.2}x",
300 radam_losses.first().unwrap() / radam_losses.last().unwrap()
301 );
302 println!(
303 " RMSprop: {:.2}x",
304 rmsprop_losses.first().unwrap() / rmsprop_losses.last().unwrap()
305 );
306
307 println!("\nAdvanced optimizers demo completed successfully!");
308
309 Ok(())
310}
pub fn default_with_lr(learning_rate: F) -> Result<Self>
Creates a new AdamW optimizer with default hyperparameters
§Arguments
learning_rate - The learning rate for parameter updates
pub fn get_epsilon(&self) -> F
Gets the epsilon parameter
pub fn set_epsilon(&mut self, epsilon: F) -> &mut Self
Sets the epsilon parameter
pub fn get_weight_decay(&self) -> F
Gets the weight decay parameter
pub fn set_weight_decay(&mut self, weight_decay: F) -> &mut Self
Sets the weight decay parameter
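Because the setters return &mut Self, they can be chained after construction. A short sketch (the hyperparameter values below are illustrative, not recommendations):

use scirs2_neural::optimizers::AdamW;

let mut adamw = AdamW::<f64>::default_with_lr(0.001).unwrap();
adamw.set_epsilon(1e-7).set_weight_decay(0.05);
assert_eq!(adamw.get_weight_decay(), 0.05);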
Trait Implementations§
impl<F: Float + ScalarOperand + Debug> Optimizer<F> for AdamW<F>
fn update(
    &mut self,
    params: &mut [Array<F, IxDyn>],
    grads: &[Array<F, IxDyn>],
) -> Result<()>
Update parameters based on gradients
fn get_learning_rate(&self) -> F
Get the current learning rate
fn set_learning_rate(&mut self, lr: F)
Set the learning rate
fn step_model(&mut self, model: &mut dyn ParamLayer<F>) -> Result<()>
Update model parameters using the optimizer (non-generic version for trait objects)
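As a rough sketch of driving the Optimizer trait manually (outside of Sequential::train_batch), assuming the gradients have already been computed elsewhere:

use ndarray::Array1;
use scirs2_neural::optimizers::{AdamW, Optimizer};

let mut adamw = AdamW::<f64>::new(0.001, 0.9, 0.999, 1e-8, 0.01);

// One flat parameter tensor and a matching gradient, both as dynamic-dimension arrays.
let mut params = vec![Array1::from_vec(vec![0.5, -0.3, 0.8]).into_dyn()];
let grads = vec![Array1::from_vec(vec![0.1, -0.2, 0.05]).into_dyn()];

// One optimization step: parameters are updated in place with the AdamW rule.
adamw.update(&mut params, &grads).unwrap();
adamw.set_learning_rate(0.0005); // e.g. decay the learning rate between steps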
Auto Trait Implementations§
impl<F> Freeze for AdamW<F> where F: Freeze
impl<F> RefUnwindSafe for AdamW<F> where F: RefUnwindSafe
impl<F> Send for AdamW<F> where F: Send
impl<F> Sync for AdamW<F> where F: Sync
impl<F> Unpin for AdamW<F> where F: Unpin
impl<F> UnwindSafe for AdamW<F> where F: UnwindSafe + RefUnwindSafe
Blanket Implementations§
impl<T> BorrowMut<T> for T where T: ?Sized
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.
impl<T> CloneToUninit for T where T: Clone
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
impl<F, O> OptimizerStep<F> for O
fn step<L>(&mut self, model: &mut L) -> Result<(), NeuralError> where L: ParamLayer<F> + ?Sized
Update model parameters using the optimizer