tensorlogic_train/optimizers/sam.rs

use super::common::{compute_gradient_norm, Optimizer};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

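/// Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.
///
/// SAM (Foret et al., 2021) looks for parameters that sit in flat regions of
/// the loss landscape. Each update is two-phased: ascend along the normalized
/// gradient by radius `rho`, recompute gradients at the perturbed point, then
/// undo the perturbation and let the base optimizer take its step there.
///
/// Typical usage (a sketch; `compute_grads` stands in for the caller's own
/// gradient computation):
///
/// ```ignore
/// sam.first_step(&mut params, &grads)?;          // ascend to w + e(w)
/// let grads_at_peak = compute_grads(&params);    // gradients at perturbed w
/// sam.second_step(&mut params, &grads_at_peak)?; // restore w, base step
/// ```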
#[derive(Debug)]
pub struct SamOptimizer<O: Optimizer> {
    /// Optimizer that performs the actual descent step.
    base_optimizer: O,
    /// Neighborhood radius for the gradient-ascent perturbation.
    rho: f64,
    /// Perturbations applied in `first_step`, removed again in `second_step`.
    perturbations: HashMap<String, Array<f64, Ix2>>,
}

impl<O: Optimizer> SamOptimizer<O> {
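    /// Creates a SAM optimizer wrapping `base_optimizer`.
    ///
    /// `rho` controls how far `first_step` climbs along the normalized
    /// gradient; it must be strictly positive.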
    pub fn new(base_optimizer: O, rho: f64) -> TrainResult<Self> {
        if rho <= 0.0 {
            return Err(TrainError::OptimizerError(
                "SAM rho must be positive".to_string(),
            ));
        }
        Ok(Self {
            base_optimizer,
            rho,
            perturbations: HashMap::new(),
        })
    }

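    /// First SAM phase: perturbs each parameter by `rho * g / ||g||`, where
    /// `||g||` is the global gradient norm across all parameters, and records
    /// the perturbations so `second_step` can undo them. A zero gradient norm
    /// leaves the parameters untouched.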
    pub fn first_step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let grad_norm = compute_gradient_norm(gradients);
        if grad_norm == 0.0 {
            // Nothing to ascend along; avoid dividing by zero.
            return Ok(());
        }
        for (name, param) in parameters.iter_mut() {
            let grad = gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;
            // e(w) = rho * g / ||g||: step towards the local loss maximum.
            let perturbation = grad.mapv(|g| self.rho * g / grad_norm);
            *param = &*param + &perturbation;
            self.perturbations.insert(name.clone(), perturbation);
        }
        Ok(())
    }

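    /// Second SAM phase: removes the stored perturbations to restore the
    /// original parameters, then applies the base optimizer's step using the
    /// gradients computed at the perturbed point.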
    pub fn second_step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        for (name, param) in parameters.iter_mut() {
            if let Some(perturbation) = self.perturbations.get(name) {
                *param = &*param - perturbation;
            }
        }
        self.perturbations.clear();
        self.base_optimizer.step(parameters, gradients)
    }
}

impl<O: Optimizer> Optimizer for SamOptimizer<O> {
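    /// Completes a SAM update; equivalent to `second_step`. Callers are
    /// expected to have invoked `first_step` and recomputed gradients at the
    /// perturbed parameters beforehand.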
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        self.second_step(parameters, gradients)
    }

    fn zero_grad(&mut self) {
        self.base_optimizer.zero_grad();
    }

    fn get_lr(&self) -> f64 {
        self.base_optimizer.get_lr()
    }

    fn set_lr(&mut self, lr: f64) {
        self.base_optimizer.set_lr(lr);
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = self.base_optimizer.state_dict();
        state.insert("rho".to_string(), vec![self.rho]);
        // Perturbations are flattened row-major; their shapes are not
        // serialized.
        for (name, perturbation) in &self.perturbations {
            state.insert(
                format!("perturbation_{}", name),
                perturbation.iter().copied().collect(),
            );
        }
        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(&rho_val) = state.get("rho").and_then(|v| v.first()) {
            self.rho = rho_val;
        }
        self.base_optimizer.load_state_dict(state.clone());
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("perturbation_") {
                // Shapes are not stored in the state dict, so a perturbation
                // can only be restored over an existing entry of the same
                // name, whose shape supplies the dimensions.
                if let Some(pert) = self.perturbations.get(name) {
                    let shape = pert.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.perturbations.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::super::common::OptimizerConfig;
    use super::super::sgd::SgdOptimizer;
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sam_optimizer() {
        let inner_config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };
        let inner_optimizer = SgdOptimizer::new(inner_config);
        let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1]]);
        let original_w = params.get("w").unwrap().clone();

        // Phase one must move the parameters away from their original values.
        optimizer.first_step(&mut params, &grads).unwrap();
        let perturbed_w = params.get("w").unwrap();
        assert_ne!(perturbed_w[[0, 0]], original_w[[0, 0]]);

        // Phase two restores the parameters and applies the SGD step, so the
        // weight should end up below its original value for a positive
        // gradient.
        optimizer.second_step(&mut params, &grads).unwrap();
        let final_w = params.get("w").unwrap();
        assert!(final_w[[0, 0]] < original_w[[0, 0]]);

        let state = optimizer.state_dict();
        assert!(state.contains_key("rho"));
    }

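    // Additional check (a sketch; it assumes `compute_gradient_norm` returns
    // the L2 norm, so an all-zero gradient yields 0.0): `first_step` should
    // leave the parameters untouched when the gradient norm is zero.
    #[test]
    fn test_sam_zero_gradient_first_step() {
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let mut optimizer = SamOptimizer::new(inner_optimizer, 0.05).unwrap();
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.0, 0.0]]);
        optimizer.first_step(&mut params, &grads).unwrap();
        // No perturbation applied, so the weights are unchanged.
        assert_eq!(params.get("w").unwrap()[[0, 0]], 1.0);
    }
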
    #[test]
    fn test_sam_invalid_rho() {
        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let result = SamOptimizer::new(inner_optimizer, 0.0);
        assert!(result.is_err());

        let inner_optimizer = SgdOptimizer::new(OptimizerConfig::default());
        let result = SamOptimizer::new(inner_optimizer, -0.1);
        assert!(result.is_err());
    }
}