use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Which diagonal Hessian estimator Sophia uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SophiaVariant {
    /// Gauss-Newton-Bartlett estimator (Sophia-G).
    GaussNewtonBartlett,
    /// Hutchinson trace estimator (Sophia-H).
    Hutchinson,
}

/// Configuration for [`SophiaOptimizer`].
#[derive(Debug, Clone)]
pub struct SophiaConfig {
    /// Base optimizer settings (learning rate, betas, epsilon, weight decay, clipping).
    pub base: OptimizerConfig,
    /// Scale applied to the diagonal Hessian estimate in the update denominator.
    pub rho: f64,
    /// Number of steps between refreshes of the diagonal Hessian estimate.
    pub hessian_update_freq: usize,
    /// Hessian estimator variant.
    pub variant: SophiaVariant,
}

impl Default for SophiaConfig {
    fn default() -> Self {
        Self {
            base: OptimizerConfig {
                learning_rate: 2e-4,
                beta1: 0.965,
                beta2: 0.99,
                epsilon: 1e-8,
                weight_decay: 0.01,
                ..Default::default()
            },
            rho: 0.04,
            hessian_update_freq: 10,
            variant: SophiaVariant::GaussNewtonBartlett,
        }
    }
}

/// Sophia optimizer: Adam-style momentum preconditioned by a cheap diagonal
/// Hessian estimate, with element-wise update clipping.
pub struct SophiaOptimizer {
    /// Optimizer configuration.
    config: SophiaConfig,
    /// First-moment (momentum) estimate per parameter.
    m: HashMap<String, Array<f64, Ix2>>,
    /// Diagonal Hessian (curvature) estimate per parameter.
    h: HashMap<String, Array<f64, Ix2>>,
    /// Global step counter.
    t: usize,
    /// Steps taken since the Hessian estimate was last refreshed.
    steps_since_hessian_update: usize,
}

impl SophiaOptimizer {
    /// Create a Sophia optimizer from a base `OptimizerConfig`, filling the
    /// Sophia-specific fields with their defaults.
    pub fn new(config: OptimizerConfig) -> Self {
        Self::with_sophia_config(SophiaConfig {
            base: config,
            ..Default::default()
        })
    }

    /// Create a Sophia optimizer from a full `SophiaConfig`.
    pub fn with_sophia_config(config: SophiaConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            h: HashMap::new(),
            t: 0,
            steps_since_hessian_update: 0,
        }
    }

    /// Clip gradients in place, either element-wise by value or by global norm,
    /// according to the base configuration.
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.base.grad_clip {
            match self.config.base.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }

    /// Gauss-Newton-Bartlett style estimate of the diagonal Hessian: an
    /// exponential moving average of squared gradients, decayed by `beta2`.
    fn update_hessian_gnb(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        let beta2 = self.config.base.beta2;

        for (name, grad) in gradients {
            let grad_squared = grad.mapv(|g| g * g);

            if let Some(h_state) = self.h.get_mut(name) {
                *h_state = &*h_state * beta2 + &grad_squared * (1.0 - beta2);
            } else {
                self.h.insert(name.clone(), grad_squared * (1.0 - beta2));
            }
        }
    }

    /// Hutchinson-style estimate. A full Hutchinson estimator requires
    /// Hessian-vector products with random probe vectors, which this interface
    /// does not expose, so this currently falls back to the same
    /// squared-gradient diagonal estimate as the Gauss-Newton-Bartlett variant.
    fn update_hessian_hutchinson(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        self.update_hessian_gnb(gradients);
    }
}
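
// Summary of the update rule implemented in `step` below, per coordinate:
//
//   m_t    = beta1 * m_{t-1} + (1 - beta1) * g_t
//   m_hat  = m_t / (1 - beta1^t)
//   u_t    = clip(m_hat / (rho * h_t + eps), -1, 1)
//   w_t    = w_{t-1} - lr * u_t               (plus decoupled weight decay)
//
// where h_t is the diagonal Hessian estimate refreshed every
// `hessian_update_freq` steps.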

impl Optimizer for SophiaOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        self.steps_since_hessian_update += 1;

        let lr = self.config.base.learning_rate;
        let beta1 = self.config.base.beta1;
        let eps = self.config.base.epsilon;
        let rho = self.config.rho;
        let weight_decay = self.config.base.weight_decay;

        // Refresh the diagonal Hessian estimate every `hessian_update_freq` steps.
        if self.steps_since_hessian_update >= self.config.hessian_update_freq {
            match self.config.variant {
                SophiaVariant::GaussNewtonBartlett => {
                    self.update_hessian_gnb(&clipped_gradients);
                }
                SophiaVariant::Hutchinson => {
                    self.update_hessian_hutchinson(&clipped_gradients);
                }
            }
            self.steps_since_hessian_update = 0;
        }

        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily initialize optimizer state the first time a parameter is seen.
            if !self.m.contains_key(name) {
                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
                self.h
                    .insert(name.clone(), Array::ones(param.raw_dim()) * eps);
            }

            let m = self
                .m
                .get_mut(name)
                .expect("m initialized for all parameters");
            let h = self.h.get(name).expect("h initialized for all parameters");

            // Exponential moving average of the gradient (momentum).
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // Bias-corrected momentum.
            let m_hat = &*m / bias_correction1;

            // Precondition by the diagonal Hessian estimate, then clip each
            // coordinate of the update to [-1, 1].
            let denominator = h * rho + eps;
            let update_direction = &m_hat / &denominator;
            let clipped_update = update_direction.mapv(|x| x.clamp(-1.0, 1.0));

            *param = &*param - &(&clipped_update * lr);

            // Decoupled weight decay.
            if weight_decay > 0.0 {
                *param = &*param - &(&*param * (weight_decay * lr));
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are owned by the caller and passed into `step`, so there is
        // no per-parameter gradient state to reset here.
    }

    fn get_lr(&self) -> f64 {
        self.config.base.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.base.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        state.insert(
            "steps_since_hessian_update".to_string(),
            vec![self.steps_since_hessian_update as f64],
        );

        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, h_val) in &self.h {
            state.insert(format!("h_{}", name), h_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_vals) = state.get("t") {
            self.t = t_vals[0] as usize;
        }
        if let Some(steps_vals) = state.get("steps_since_hessian_update") {
            self.steps_since_hessian_update = steps_vals[0] as usize;
        }

        // Moment and Hessian buffers are only restored for parameters whose
        // shapes are already known to this optimizer.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("h_") {
                if let Some(h) = self.h.get(name) {
                    let shape = h.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.h.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}
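
// Typical usage (illustrative sketch only; `params`, `grads`, `num_steps`, and
// `compute_gradients` are hypothetical and come from the surrounding training code):
//
//     let mut opt = SophiaOptimizer::new(OptimizerConfig::default());
//     for _ in 0..num_steps {
//         let grads = compute_gradients(&params);
//         opt.step(&mut params, &grads)?;
//     }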

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sophia_initialization() {
        let config = OptimizerConfig::default();
        let optimizer = SophiaOptimizer::new(config);

        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.h.is_empty());
    }

    #[test]
    fn test_sophia_custom_config() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 1e-4,
                beta1: 0.965,
                beta2: 0.99,
                ..Default::default()
            },
            rho: 0.04,
            ..Default::default()
        };

        let optimizer = SophiaOptimizer::with_sophia_config(config);
        assert_relative_eq!(optimizer.get_lr(), 1e-4);
    }

    #[test]
    fn test_sophia_single_step() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2, 0.3]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // Positive gradients should decrease every parameter.
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[5.0], [-3.0], [2.0]]);

        // Minimize f(w) = sum(w_i^2); its gradient is 2w.
        for _ in 0..50 {
            let mut grads = HashMap::new();
            grads.insert("w".to_string(), &params["w"] * 2.0);
            optimizer.step(&mut params, &grads).expect("unwrap");
        }

        // All parameters should have moved close to the minimum at zero.
        for &p in params["w"].iter() {
            assert!(p.abs() < 0.5);
        }
    }

    #[test]
    fn test_sophia_2d_parameters() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1, 0.1], [-0.1, -0.1, -0.1]]);

        let initial_shape = params["w"].shape().to_vec();
        optimizer.step(&mut params, &grads).expect("unwrap");

        assert_eq!(params["w"].shape(), &initial_shape[..]);
    }

    #[test]
    fn test_sophia_reset_and_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(!optimizer.m.is_empty());
        assert_eq!(optimizer.t, 1);

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("h_w"));
    }

    #[test]
    fn test_sophia_hessian_update_frequency() {
        let config = SophiaConfig {
            hessian_update_freq: 5,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // First step: the counter advances but the Hessian refresh has not fired yet.
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert_eq!(optimizer.steps_since_hessian_update, 1);

        // Four more steps reach the update frequency, which resets the counter.
        for _ in 0..4 {
            optimizer.step(&mut params, &grads).expect("unwrap");
        }
        assert_eq!(optimizer.steps_since_hessian_update, 0);
        assert!(optimizer.h.contains_key("w"));
    }

    #[test]
    fn test_sophia_weight_decay() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                weight_decay: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        // With zero gradients, any parameter shrinkage comes from weight decay alone.
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.0, 0.0, 0.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_gradient_clipping_value() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(0.5),
                grad_clip_mode: GradClipMode::Value,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0, -2.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // With value clipping and the element-wise update clamp, a single step
        // moves each parameter by roughly at most lr.
        let update_mag = (initial[[0, 0]] - params["w"][[0, 0]]).abs();
        assert!(update_mag < 0.2);
    }

    #[test]
    fn test_sophia_gradient_clipping_norm() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(1.0),
                grad_clip_mode: GradClipMode::Norm,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        // Large gradients should be rescaled so their global norm is at most 1.0.
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0, 10.0, 10.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        let total_update: f64 = initial
            .iter()
            .zip(params["w"].iter())
            .map(|(&p, &u)| (p - u).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(total_update < 1.0);
    }

    #[test]
    fn test_sophia_learning_rate_getter_setter() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        optimizer.set_lr(0.001);
        assert_relative_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.1);
        assert_relative_eq!(optimizer.get_lr(), 0.1);
    }

    #[test]
    fn test_sophia_variant_gnb() {
        let config = SophiaConfig {
            variant: SophiaVariant::GaussNewtonBartlett,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
    }

    #[test]
    fn test_sophia_variant_hutchinson() {
        let config = SophiaConfig {
            variant: SophiaVariant::Hutchinson,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
    }

    #[test]
    fn test_sophia_update_clipping() {
        // A tiny rho makes the raw preconditioned update huge, so the element-wise
        // clamp to [-1, 1] should dominate and bound the step to about lr.
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                ..Default::default()
            },
            rho: 0.001,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[10.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[100.0]]);

        let initial = params["w"][[0, 0]];
        optimizer.step(&mut params, &grads).expect("unwrap");

        let update_size = (initial - params["w"][[0, 0]]).abs();
        assert!(update_size <= 0.12);
    }

    #[test]
    fn test_sophia_load_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer1 = SophiaOptimizer::new(config.clone());
        let mut optimizer2 = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        for _ in 0..5 {
            optimizer1.step(&mut params, &grads).expect("unwrap");
        }

        let state = optimizer1.state_dict();
        optimizer2.load_state_dict(state);

        assert_eq!(optimizer2.t, optimizer1.t);
        assert_eq!(
            optimizer2.steps_since_hessian_update,
            optimizer1.steps_since_hessian_update
        );
    }
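
    // Round-trip sanity check (illustrative sketch): after loading a state dict,
    // the moment and Hessian buffers should match the source optimizer. This
    // assumes `load_state_dict` only restores buffers whose shapes are already
    // known, so the target optimizer takes one step first.
    #[test]
    fn test_sophia_state_dict_round_trip_buffers() {
        let config = OptimizerConfig::default();
        let mut optimizer1 = SophiaOptimizer::new(config.clone());
        let mut optimizer2 = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        for _ in 0..3 {
            optimizer1.step(&mut params, &grads).expect("unwrap");
        }
        // Initialize optimizer2's buffers so their shapes are known before loading.
        optimizer2.step(&mut params, &grads).expect("unwrap");

        optimizer2.load_state_dict(optimizer1.state_dict());

        assert_eq!(optimizer2.t, optimizer1.t);
        assert_eq!(optimizer2.m["w"], optimizer1.m["w"]);
        assert_eq!(optimizer2.h["w"], optimizer1.h["w"]);
    }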
}