Struct AdamW

Source
pub struct AdamW<F: Float + ScalarOperand + Debug> { /* private fields */ }
Expand description

AdamW optimizer for neural networks

Implements the AdamW optimizer from the paper: “Decoupled Weight Decay Regularization” by Loshchilov and Hutter (2017).

AdamW is a variant of Adam that correctly implements weight decay regularization, which helps improve generalization and training stability.

The key difference from Adam is that weight decay is decoupled from the gradient-based update: it is applied directly to the weights rather than being added to the gradients.

§Examples

use ndarray::Array1;
use scirs2_neural::optimizers::{AdamW, Optimizer};

// Create a simple AdamW optimizer with default parameters
let mut adamw = AdamW::<f64>::default_with_lr(0.001).unwrap();

// or with custom configuration
let mut adamw_custom = AdamW::new(0.001, 0.9, 0.999, 1e-8, 0.01);

Implementations§

Source§

impl<F: Float + ScalarOperand + Debug> AdamW<F>

Source

pub fn new( learning_rate: F, beta1: F, beta2: F, epsilon: F, weight_decay: F, ) -> Self

Creates a new AdamW optimizer with the given hyperparameters

§Arguments
  • learning_rate - The learning rate for parameter updates
  • beta1 - Exponential decay rate for the first moment estimates (default: 0.9)
  • beta2 - Exponential decay rate for the second moment estimates (default: 0.999)
  • epsilon - Small constant for numerical stability (default: 1e-8)
  • weight_decay - Weight decay factor (default: 0.01)
Examples found in repository?
examples/advanced_optimizers_example.rs (line 124)
11fn main() -> Result<(), Box<dyn std::error::Error>> {
12    println!("Advanced Optimizers Example");
13
14    // Initialize random number generator
15    let mut rng = SmallRng::seed_from_u64(42);
16
17    // Create a synthetic binary classification dataset
18    let num_samples = 1000;
19    let num_features = 20;
20    let num_classes = 2;
21
22    println!(
23        "Generating synthetic dataset with {} samples, {} features...",
24        num_samples, num_features
25    );
26
27    // Generate random input features
28    let mut x_data = Array2::<f32>::zeros((num_samples, num_features));
29    for i in 0..num_samples {
30        for j in 0..num_features {
31            x_data[[i, j]] = rng.random_range(-1.0..1.0);
32        }
33    }
34
35    // Create true weights and bias for data generation
36    let mut true_weights = Array2::<f32>::zeros((num_features, 1));
37    for i in 0..num_features {
38        true_weights[[i, 0]] = rng.random_range(-1.0..1.0);
39    }
40    let true_bias = rng.random_range(-1.0..1.0);
41
42    // Generate binary labels (0 or 1) based on linear model with logistic function
43    let mut y_data = Array2::<f32>::zeros((num_samples, num_classes));
44    for i in 0..num_samples {
45        let mut logit = true_bias;
46        for j in 0..num_features {
47            logit += x_data[[i, j]] * true_weights[[j, 0]];
48        }
49
50        // Apply sigmoid to get probability
51        let prob = 1.0 / (1.0 + (-logit).exp());
52
53        // Convert to one-hot encoding
54        if prob > 0.5 {
55            y_data[[i, 1]] = 1.0; // Class 1
56        } else {
57            y_data[[i, 0]] = 1.0; // Class 0
58        }
59    }
60
61    // Split into train and test sets (80% train, 20% test)
62    let train_size = (num_samples as f32 * 0.8) as usize;
63    let test_size = num_samples - train_size;
64
65    let x_train = x_data.slice(ndarray::s![0..train_size, ..]).to_owned();
66    let y_train = y_data.slice(ndarray::s![0..train_size, ..]).to_owned();
67    let x_test = x_data.slice(ndarray::s![train_size.., ..]).to_owned();
68    let y_test = y_data.slice(ndarray::s![train_size.., ..]).to_owned();
69
70    println!("Training set: {} samples", train_size);
71    println!("Test set: {} samples", test_size);
72
73    // Create a simple neural network model
74    let hidden_size = 64;
75    let dropout_rate = 0.2;
76    let seed_rng = SmallRng::seed_from_u64(42);
77
78    // Shared function to create identical model architectures for fair comparison
79    let create_model = || -> Result<Sequential<f32>, Box<dyn std::error::Error>> {
80        let mut model = Sequential::new();
81
82        // Input to hidden layer
83        let dense1 = Dense::new(
84            num_features,
85            hidden_size,
86            Some("relu"),
87            &mut seed_rng.clone(),
88        )?;
89        model.add_layer(dense1);
90
91        // Dropout for regularization
92        let dropout = Dropout::new(dropout_rate, &mut seed_rng.clone())?;
93        model.add_layer(dropout);
94
95        // Hidden to output layer
96        let dense2 = Dense::new(
97            hidden_size,
98            num_classes,
99            Some("softmax"),
100            &mut seed_rng.clone(),
101        )?;
102        model.add_layer(dense2);
103
104        Ok(model)
105    };
106
107    // Create models for each optimizer
108    let mut sgd_model = create_model()?;
109    let mut adam_model = create_model()?;
110    let mut adamw_model = create_model()?;
111    let mut radam_model = create_model()?;
112    let mut rmsprop_model = create_model()?;
113
114    // Create the loss function
115    let loss_fn = CrossEntropyLoss::new(1e-10);
116
117    // Create optimizers
118    let learning_rate = 0.001;
119    let batch_size = 32;
120    let epochs = 20;
121
122    let mut sgd_optimizer = SGD::new_with_config(learning_rate, 0.9, 0.0);
123    let mut adam_optimizer = Adam::new(learning_rate, 0.9, 0.999, 1e-8);
124    let mut adamw_optimizer = AdamW::new(learning_rate, 0.9, 0.999, 1e-8, 0.01);
125    let mut radam_optimizer = RAdam::new(learning_rate, 0.9, 0.999, 1e-8, 0.0);
126    let mut rmsprop_optimizer = RMSprop::new_with_config(learning_rate, 0.9, 1e-8, 0.0);
127
128    // Helper function to compute accuracy
129    let compute_accuracy = |model: &Sequential<f32>, x: &Array2<f32>, y: &Array2<f32>| -> f32 {
130        let predictions = model.forward(&x.clone().into_dyn()).unwrap();
131        let mut correct = 0;
132
133        for i in 0..x.shape()[0] {
134            let mut max_idx = 0;
135            let mut max_val = predictions[[i, 0]];
136
137            for j in 1..num_classes {
138                if predictions[[i, j]] > max_val {
139                    max_val = predictions[[i, j]];
140                    max_idx = j;
141                }
142            }
143
144            let true_idx =
145                y[[i, 0]] < y[[i, 1]] as usize as u8 as i8 as usize as isize as i32 as f32;
146            if max_idx as i32 == true_idx as i32 {
147                correct += 1;
148            }
149        }
150
151        correct as f32 / x.shape()[0] as f32
152    };
153
154    // Helper function to train model
155    let mut train_model =
156        |model: &mut Sequential<f32>, optimizer: &mut dyn Optimizer<f32>, name: &str| -> Vec<f32> {
157            println!("\nTraining with {} optimizer...", name);
158            let start_time = Instant::now();
159
160            let mut train_losses = Vec::new();
161            let num_batches = train_size.div_ceil(batch_size);
162
163            for epoch in 0..epochs {
164                let mut epoch_loss = 0.0;
165
166                // Create a permutation for shuffling the data
167                let mut indices: Vec<usize> = (0..train_size).collect();
168                indices.shuffle(&mut rng);
169
170                for batch_idx in 0..num_batches {
171                    let start = batch_idx * batch_size;
172                    let end = (start + batch_size).min(train_size);
173                    let batch_indices = &indices[start..end];
174
175                    // Create batch data
176                    let mut x_batch = Array2::<f32>::zeros((batch_indices.len(), num_features));
177                    let mut y_batch = Array2::<f32>::zeros((batch_indices.len(), num_classes));
178
179                    for (i, &idx) in batch_indices.iter().enumerate() {
180                        for j in 0..num_features {
181                            x_batch[[i, j]] = x_train[[idx, j]];
182                        }
183                        for j in 0..num_classes {
184                            y_batch[[i, j]] = y_train[[idx, j]];
185                        }
186                    }
187
188                    // Convert to dynamic dimension arrays
189                    let x_batch_dyn = x_batch.into_dyn();
190                    let y_batch_dyn = y_batch.into_dyn();
191
192                    // Perform a training step
193                    let batch_loss = model
194                        .train_batch(&x_batch_dyn, &y_batch_dyn, &loss_fn, optimizer)
195                        .unwrap();
196                    epoch_loss += batch_loss;
197                }
198
199                epoch_loss /= num_batches as f32;
200                train_losses.push(epoch_loss);
201
202                // Calculate and print metrics every few epochs
203                if epoch % 5 == 0 || epoch == epochs - 1 {
204                    let train_accuracy = compute_accuracy(model, &x_train, &y_train);
205                    let test_accuracy = compute_accuracy(model, &x_test, &y_test);
206
207                    println!(
208                        "Epoch {}/{}: loss = {:.6}, train_acc = {:.2}%, test_acc = {:.2}%",
209                        epoch + 1,
210                        epochs,
211                        epoch_loss,
212                        train_accuracy * 100.0,
213                        test_accuracy * 100.0
214                    );
215                }
216            }
217
218            let elapsed = start_time.elapsed();
219            println!("{} training completed in {:.2?}", name, elapsed);
220
221            // Final evaluation
222            let train_accuracy = compute_accuracy(model, &x_train, &y_train);
223            let test_accuracy = compute_accuracy(model, &x_test, &y_test);
224            println!("Final metrics for {}:", name);
225            println!("  Train accuracy: {:.2}%", train_accuracy * 100.0);
226            println!("  Test accuracy:  {:.2}%", test_accuracy * 100.0);
227
228            train_losses
229        };
230
231    // Train models with different optimizers
232    let sgd_losses = train_model(&mut sgd_model, &mut sgd_optimizer, "SGD");
233    let adam_losses = train_model(&mut adam_model, &mut adam_optimizer, "Adam");
234    let adamw_losses = train_model(&mut adamw_model, &mut adamw_optimizer, "AdamW");
235    let radam_losses = train_model(&mut radam_model, &mut radam_optimizer, "RAdam");
236    let rmsprop_losses = train_model(&mut rmsprop_model, &mut rmsprop_optimizer, "RMSprop");
237
238    // Print comparison summary
239    println!("\nOptimizer Comparison Summary:");
240    println!("----------------------------");
241    println!("Initial learning rate: {}", learning_rate);
242    println!("Batch size: {}", batch_size);
243    println!("Epochs: {}", epochs);
244    println!();
245
246    println!("Final Loss Values:");
247    println!("  SGD:     {:.6}", sgd_losses.last().unwrap());
248    println!("  Adam:    {:.6}", adam_losses.last().unwrap());
249    println!("  AdamW:   {:.6}", adamw_losses.last().unwrap());
250    println!("  RAdam:   {:.6}", radam_losses.last().unwrap());
251    println!("  RMSprop: {:.6}", rmsprop_losses.last().unwrap());
252
253    println!("\nLoss progression (first value, middle value, last value):");
254    println!(
255        "  SGD:     {:.6}, {:.6}, {:.6}",
256        sgd_losses.first().unwrap(),
257        sgd_losses[epochs / 2],
258        sgd_losses.last().unwrap()
259    );
260    println!(
261        "  Adam:    {:.6}, {:.6}, {:.6}",
262        adam_losses.first().unwrap(),
263        adam_losses[epochs / 2],
264        adam_losses.last().unwrap()
265    );
266    println!(
267        "  AdamW:   {:.6}, {:.6}, {:.6}",
268        adamw_losses.first().unwrap(),
269        adamw_losses[epochs / 2],
270        adamw_losses.last().unwrap()
271    );
272    println!(
273        "  RAdam:   {:.6}, {:.6}, {:.6}",
274        radam_losses.first().unwrap(),
275        radam_losses[epochs / 2],
276        radam_losses.last().unwrap()
277    );
278    println!(
279        "  RMSprop: {:.6}, {:.6}, {:.6}",
280        rmsprop_losses.first().unwrap(),
281        rmsprop_losses[epochs / 2],
282        rmsprop_losses.last().unwrap()
283    );
284
285    println!("\nLoss improvement ratio (first loss / last loss):");
286    println!(
287        "  SGD:     {:.2}x",
288        sgd_losses.first().unwrap() / sgd_losses.last().unwrap()
289    );
290    println!(
291        "  Adam:    {:.2}x",
292        adam_losses.first().unwrap() / adam_losses.last().unwrap()
293    );
294    println!(
295        "  AdamW:   {:.2}x",
296        adamw_losses.first().unwrap() / adamw_losses.last().unwrap()
297    );
298    println!(
299        "  RAdam:   {:.2}x",
300        radam_losses.first().unwrap() / radam_losses.last().unwrap()
301    );
302    println!(
303        "  RMSprop: {:.2}x",
304        rmsprop_losses.first().unwrap() / rmsprop_losses.last().unwrap()
305    );
306
307    println!("\nAdvanced optimizers demo completed successfully!");
308
309    Ok(())
310}
Source

pub fn default_with_lr(learning_rate: F) -> Result<Self>

Creates a new AdamW optimizer with the given learning rate and default values for all other hyperparameters

§Arguments
  • learning_rate - The learning rate for parameter updates
Source

pub fn get_beta1(&self) -> F

Gets the beta1 parameter

Source

pub fn set_beta1(&mut self, beta1: F) -> &mut Self

Sets the beta1 parameter

Source

pub fn get_beta2(&self) -> F

Gets the beta2 parameter

Source

pub fn set_beta2(&mut self, beta2: F) -> &mut Self

Sets the beta2 parameter

Source

pub fn get_epsilon(&self) -> F

Gets the epsilon parameter

Source

pub fn set_epsilon(&mut self, epsilon: F) -> &mut Self

Sets the epsilon parameter

Source

pub fn get_weight_decay(&self) -> F

Gets the weight decay parameter

Source

pub fn set_weight_decay(&mut self, weight_decay: F) -> &mut Self

Sets the weight decay parameter

Source

pub fn reset(&mut self)

Resets the internal state of the optimizer

Trait Implementations§

Source§

impl<F: Clone + Float + ScalarOperand + Debug> Clone for AdamW<F>

Source§

fn clone(&self) -> AdamW<F>

Returns a duplicate of the value. Read more
1.0.0 · Source§

const fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl<F: Debug + Float + ScalarOperand + Debug> Debug for AdamW<F>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl<F: Float + ScalarOperand + Debug> Optimizer<F> for AdamW<F>

Source§

fn update( &mut self, params: &mut [Array<F, IxDyn>], grads: &[Array<F, IxDyn>], ) -> Result<()>

Update parameters based on gradients
Source§

fn get_learning_rate(&self) -> F

Get the current learning rate
Source§

fn set_learning_rate(&mut self, lr: F)

Set the learning rate
Source§

fn step_model(&mut self, model: &mut dyn ParamLayer<F>) -> Result<()>

Update model parameters using the optimizer (non-generic version for trait objects)
Source§

fn reset(&mut self)

Reset the optimizer’s internal state
Source§

fn name(&self) -> &'static str

Get the optimizer’s name

Auto Trait Implementations§

§

impl<F> Freeze for AdamW<F>
where F: Freeze,

§

impl<F> RefUnwindSafe for AdamW<F>
where F: RefUnwindSafe,

§

impl<F> Send for AdamW<F>
where F: Send,

§

impl<F> Sync for AdamW<F>
where F: Sync,

§

impl<F> Unpin for AdamW<F>
where F: Unpin,

§

impl<F> UnwindSafe for AdamW<F>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<F, O> OptimizerStep<F> for O
where F: Float + Debug + ScalarOperand, O: Optimizer<F>,

Source§

fn step<L>(&mut self, model: &mut L) -> Result<(), NeuralError>
where L: ParamLayer<F> + ?Sized,

Update model parameters using the optimizer
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of the pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes the object with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V