
tensorlogic_train/optimizers/sam.rs

//! SAM optimizer (Sharpness-Aware Minimization).
//!
//! SAM seeks parameters that lie in neighborhoods having uniformly low loss,
//! improving model generalization. It requires two forward-backward passes per step:
//! one to compute the adversarial perturbation, and one to compute the gradient
//! actually used for the update.
//!
//! Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021).
//!
//! Note: this is a wrapper optimizer. SAM requires special handling in the training
//! loop to perform two gradient computations per step. The typical usage is:
//! 1. Compute gradients at the current parameters.
//! 2. Compute and apply the adversarial perturbation (`first_step`).
//! 3. Compute gradients at the perturbed parameters.
//! 4. Remove the perturbation and update the original parameters using the
//!    gradients from step 3 (`second_step`).

use super::common::{compute_gradient_norm, Optimizer};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// SAM optimizer (Sharpness-Aware Minimization).
///
/// SAM seeks parameters that lie in neighborhoods having uniformly low loss,
/// improving model generalization. It requires two forward-backward passes per step:
/// one to compute the adversarial perturbation, and one to compute the gradient
/// actually used for the update.
///
/// Reference: Foret et al., "Sharpness-Aware Minimization for Efficiently Improving Generalization" (ICLR 2021).
///
/// Note: this is a wrapper optimizer. SAM requires special handling in the training
/// loop to perform two gradient computations per step. The typical usage is:
/// 1. Compute gradients at the current parameters.
/// 2. Compute and apply the adversarial perturbation (`first_step`).
/// 3. Compute gradients at the perturbed parameters.
/// 4. Remove the perturbation and update the original parameters using the
///    gradients from step 3 (`second_step`), as shown in the example below.
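///
/// # Example
///
/// A minimal sketch of the two-step usage, assuming a hypothetical
/// `compute_gradients` helper that evaluates gradients for the given
/// parameters (not part of this crate):
///
/// ```ignore
/// let sgd = SgdOptimizer::new(OptimizerConfig::default());
/// let mut sam = SamOptimizer::new(sgd, 0.05)?;
///
/// // First pass: gradients at the current parameters.
/// let grads = compute_gradients(&params);
/// sam.first_step(&mut params, &grads)?;
///
/// // Second pass: gradients at the perturbed parameters.
/// let grads_at_perturbed = compute_gradients(&params);
/// sam.second_step(&mut params, &grads_at_perturbed)?;
/// ```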
#[derive(Debug)]
pub struct SamOptimizer<O: Optimizer> {
    /// Base optimizer (e.g., SGD, Adam).
    base_optimizer: O,
    /// Perturbation radius (rho).
    rho: f64,
    /// Stored perturbations for each parameter.
    perturbations: HashMap<String, Array<f64, Ix2>>,
}

impl<O: Optimizer> SamOptimizer<O> {
    /// Create a new SAM optimizer.
    ///
    /// # Arguments
    /// * `base_optimizer` - The base optimizer to use (SGD, Adam, etc.)
    /// * `rho` - Perturbation radius (typically 0.05)
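    ///
    /// # Example
    ///
    /// A minimal sketch wrapping the SGD optimizer from this crate; `rho` must
    /// be strictly positive or construction fails:
    ///
    /// ```ignore
    /// let sgd = SgdOptimizer::new(OptimizerConfig::default());
    /// let sam = SamOptimizer::new(sgd, 0.05)?;
    /// assert!(SamOptimizer::new(SgdOptimizer::new(OptimizerConfig::default()), 0.0).is_err());
    /// ```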
    pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self> {
        if rho <= 0.0 {
            return Err(TrainError::OptimizerError(
                "SAM rho must be positive".to_string(),
            ));
        }
        Ok(Self {
            base_optimizer,
            rho,
            perturbations: HashMap::new(),
        })
    }

    /// Compute and apply the adversarial perturbations (SAM's ascent step).
    ///
    /// Call this with the gradients computed at the current parameters; it
    /// perturbs each parameter in place and stores the perturbation so that
    /// `second_step` can later remove it.
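    ///
    /// Each parameter `w` is moved to `w + e`, where `e = rho * g / ||g||`,
    /// `g` is that parameter's gradient, and `||g||` is the global norm over
    /// all gradients. If the global norm is zero, the step is a no-op.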
    pub fn first_step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let grad_norm = compute_gradient_norm(gradients);
        if grad_norm == 0.0 {
            return Ok(());
        }
        for (name, param) in parameters.iter_mut() {
            let grad = gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;
            let perturbation = grad.mapv(|g| self.rho * g / grad_norm);
            *param = &*param + &perturbation;
            self.perturbations.insert(name.clone(), perturbation);
        }
        Ok(())
    }

    /// Perform the actual optimization step.
    ///
    /// This should be called with the second set of gradients (computed at the perturbed parameters).
    /// It will remove the perturbations and update the parameters using the base optimizer.
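    ///
    /// Concretely, each parameter is restored from `w + e` back to `w`, the
    /// stored perturbations are cleared, and `base_optimizer.step` is applied
    /// with the gradients evaluated at `w + e`.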
    pub fn second_step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        for (name, param) in parameters.iter_mut() {
            if let Some(perturbation) = self.perturbations.get(name) {
                *param = &*param - perturbation;
            }
        }
        self.perturbations.clear();
        self.base_optimizer.step(parameters, gradients)
    }
}

impl<O: Optimizer> Optimizer for SamOptimizer<O> {
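    // `step` performs SAM's second half: it expects `first_step` to have been
    // called with the first-pass gradients, and `gradients` here to have been
    // computed at the perturbed parameters.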
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        self.second_step(parameters, gradients)
    }

    fn zero_grad(&mut self) {
        self.base_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.base_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = self.base_optimizer.state_dict();
        state.insert("rho".to_string(), vec![self.rho]);
        for (name, perturbation) in &self.perturbations {
            state.insert(
                format!("perturbation_{}", name),
                perturbation.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(rho_val) = state.get("rho") {
            self.rho = rho_val[0];
        }
        self.base_optimizer.load_state_dict(state.clone());
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("perturbation_") {
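                // Shapes are not serialized, so a perturbation can only be
                // restored when an entry with the same name already exists to
                // supply the target dimensions; otherwise the values are skipped.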
                if let Some(pert) = self.perturbations.get(name) {
                    let shape = pert.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.perturbations.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::common::OptimizerConfig;
    use super::super::sgd::SgdOptimizer;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sam_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);
        let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        let original_w = params.get("w").unwrap().clone();
        optimizer.first_step(&mut params, &grads).unwrap();
        let perturbed_w = params.get("w").unwrap();
        assert_ne!(perturbed_w[[0, 0]], original_w[[0, 0]]);
        optimizer.second_step(&mut params, &grads).unwrap();
        let final_w = params.get("w").unwrap();
        assert!(final_w[[0, 0]] < original_w[[0, 0]]);
        let state = optimizer.state_dict();
        assert!(state.contains_key("rho"));
    }

    #[test]
    fn test_sam_invalid_rho() {
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let result = SamOptimizer::new(inner_optimizer, 0.0);
        assert!(result.is_err());
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let result = SamOptimizer::new(inner_optimizer, -0.1);
        assert!(result.is_err());
    }
}