use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

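/// Strategy used to estimate the diagonal of the Hessian in Sophia.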
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SophiaVariant {
    /// Gauss-Newton-Bartlett style estimator, approximated here by element-wise squared gradients.
    GaussNewtonBartlett,
    /// Hutchinson's randomized diagonal estimator.
    Hutchinson,
}

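/// Configuration for the Sophia optimizer: the shared base settings plus
/// Sophia-specific hyperparameters.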
#[derive(Debug, Clone)]
pub struct SophiaConfig {
    /// Shared optimizer settings (learning rate, betas, epsilon, weight decay, clipping).
    pub base: OptimizerConfig,
    /// Scaling factor applied to the diagonal Hessian estimate in the update denominator.
    pub rho: f64,
    /// Number of steps between refreshes of the Hessian estimate.
    pub hessian_update_freq: usize,
    /// Which Hessian estimator to use.
    pub variant: SophiaVariant,
}

impl Default for SophiaConfig {
    fn default() -> Self {
        Self {
            base: OptimizerConfig {
                learning_rate: 2e-4,
                beta1: 0.965,
                beta2: 0.99,
                epsilon: 1e-8,
                weight_decay: 0.01,
                ..Default::default()
            },
            rho: 0.04,
            hessian_update_freq: 10,
            variant: SophiaVariant::GaussNewtonBartlett,
        }
    }
}

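/// Sophia optimizer with a periodically refreshed diagonal Hessian estimate.
///
/// The step combines an Adam-style first-moment EMA with element-wise clipping of
/// the preconditioned update. A sketch of the update implemented in `step` below:
///
/// ```text
/// m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
/// m_hat_t = m_t / (1 - beta1^t)
/// h_t     = beta2 * h_prev + (1 - beta2) * g_t^2       (refreshed every hessian_update_freq steps)
/// theta   = theta - lr * clip(m_hat_t / (rho * h_t + eps), -1, 1)
///                 - lr * weight_decay * theta          (decoupled weight decay)
/// ```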
pub struct SophiaOptimizer {
    /// Full optimizer configuration.
    config: SophiaConfig,
    /// First moment (exponential moving average of gradients) per parameter.
    m: HashMap<String, Array<f64, Ix2>>,
    /// Diagonal Hessian estimate per parameter.
    h: HashMap<String, Array<f64, Ix2>>,
    /// Number of optimization steps taken so far.
    t: usize,
    /// Steps elapsed since the Hessian estimate was last refreshed.
    steps_since_hessian_update: usize,
}

impl SophiaOptimizer {
    /// Creates a Sophia optimizer from shared base settings, filling the
    /// Sophia-specific fields from `SophiaConfig::default()`.
    pub fn new(config: OptimizerConfig) -> Self {
        Self::with_sophia_config(SophiaConfig {
            base: config,
            ..Default::default()
        })
    }

    /// Creates a Sophia optimizer from a full `SophiaConfig`.
    pub fn with_sophia_config(config: SophiaConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            h: HashMap::new(),
            t: 0,
            steps_since_hessian_update: 0,
        }
    }

    /// Applies gradient clipping in place according to the configured mode.
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.base.grad_clip {
            match self.config.base.grad_clip_mode {
                GradClipMode::Value => {
                    // Clip every gradient element to [-clip_value, clip_value].
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    // Rescale all gradients so the global L2 norm does not exceed clip_value.
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }

    /// Updates the diagonal Hessian estimate using a Gauss-Newton-Bartlett style
    /// approximation: an exponential moving average of element-wise squared gradients.
    fn update_hessian_gnb(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        let beta2 = self.config.base.beta2;

        for (name, grad) in gradients {
            let grad_squared = grad.mapv(|g| g * g);

            if let Some(h_state) = self.h.get_mut(name) {
                // EMA update: h <- beta2 * h + (1 - beta2) * g^2
                *h_state = &*h_state * beta2 + &grad_squared * (1.0 - beta2);
            } else {
                self.h.insert(name.clone(), grad_squared * (1.0 - beta2));
            }
        }
    }

    /// Hutchinson-style Hessian estimation requires Hessian-vector products, which
    /// this interface does not provide, so this variant currently falls back to the
    /// squared-gradient (GNB-style) approximation.
    fn update_hessian_hutchinson(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        self.update_hessian_gnb(gradients);
    }
}

impl Optimizer for SophiaOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        self.steps_since_hessian_update += 1;

        let lr = self.config.base.learning_rate;
        let beta1 = self.config.base.beta1;
        let eps = self.config.base.epsilon;
        let rho = self.config.rho;
        let weight_decay = self.config.base.weight_decay;

        // Refresh the Hessian estimate every `hessian_update_freq` steps.
        if self.steps_since_hessian_update >= self.config.hessian_update_freq {
            match self.config.variant {
                SophiaVariant::GaussNewtonBartlett => {
                    self.update_hessian_gnb(&clipped_gradients);
                }
                SophiaVariant::Hutchinson => {
                    self.update_hessian_hutchinson(&clipped_gradients);
                }
            }
            self.steps_since_hessian_update = 0;
        }

        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);

        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Lazily initialize optimizer state for parameters seen for the first time.
            if !self.m.contains_key(name) {
                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
                self.h.insert(name.clone(), Array::ones(param.raw_dim()) * eps);
            }

            let m = self.m.get_mut(name).unwrap();
            let h = self.h.get(name).unwrap();

            // First-moment EMA: m <- beta1 * m + (1 - beta1) * g
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // Bias-corrected first moment.
            let m_hat = &*m / bias_correction1;

            // Precondition by the scaled Hessian estimate.
            let denominator = h * rho + eps;
            let update_direction = &m_hat / &denominator;

            // Element-wise clipping of the preconditioned update to [-1, 1].
            let clipped_update = update_direction.mapv(|x| x.clamp(-1.0, 1.0));

            *param = &*param - &(&clipped_update * lr);

            // Decoupled weight decay.
            if weight_decay > 0.0 {
                *param = &*param - &(&*param * (weight_decay * lr));
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are supplied externally on each step, so there is no internal
        // gradient state to reset.
    }

    fn get_lr(&self) -> f64 {
        self.config.base.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.base.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        state.insert(
            "steps_since_hessian_update".to_string(),
            vec![self.steps_since_hessian_update as f64],
        );

        // Flatten the per-parameter moment and Hessian buffers, keyed by name.
        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, h_val) in &self.h {
            state.insert(format!("h_{}", name), h_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_vals) = state.get("t") {
            self.t = t_vals[0] as usize;
        }
        if let Some(steps_vals) = state.get("steps_since_hessian_update") {
            self.steps_since_hessian_update = steps_vals[0] as usize;
        }

        // Moment and Hessian buffers are only restored for parameters that already
        // have state, because their shapes are needed to rebuild the arrays.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("h_") {
                if let Some(h) = self.h.get(name) {
                    let shape = h.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.h.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sophia_initialization() {
        let config = OptimizerConfig::default();
        let optimizer = SophiaOptimizer::new(config);

        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.h.is_empty());
    }

    #[test]
    fn test_sophia_custom_config() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 1e-4,
                beta1: 0.965,
                beta2: 0.99,
                ..Default::default()
            },
            rho: 0.04,
            ..Default::default()
        };

        let optimizer = SophiaOptimizer::with_sophia_config(config);
        assert_relative_eq!(optimizer.get_lr(), 1e-4);
    }

    #[test]
    fn test_sophia_single_step() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2, 0.3]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        // Positive gradients should move every parameter downward.
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[5.0], [-3.0], [2.0]]);

        for _ in 0..50 {
            // Gradient of the quadratic objective f(w) = sum(w^2) is 2 * w.
            let mut grads = HashMap::new();
            grads.insert("w".to_string(), &params["w"] * 2.0);
            optimizer.step(&mut params, &grads).unwrap();
        }

        // All parameters should have moved close to the minimum at zero.
        for &p in params["w"].iter() {
            assert!(p.abs() < 0.5);
        }
    }

    #[test]
    fn test_sophia_2d_parameters() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1, 0.1], [-0.1, -0.1, -0.1]]);

        let initial_shape = params["w"].shape().to_vec();
        optimizer.step(&mut params, &grads).unwrap();

        assert_eq!(params["w"].shape(), &initial_shape[..]);
    }

    #[test]
    fn test_sophia_reset_and_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).unwrap();
        assert!(!optimizer.m.is_empty());
        assert_eq!(optimizer.t, 1);

        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("h_w"));
    }

    #[test]
    fn test_sophia_hessian_update_frequency() {
        let config = SophiaConfig {
            hessian_update_freq: 5,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).unwrap();
        assert_eq!(optimizer.steps_since_hessian_update, 1);

        for _ in 0..4 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        // After hessian_update_freq steps the counter resets and the estimate exists.
        assert_eq!(optimizer.steps_since_hessian_update, 0);
        assert!(optimizer.h.contains_key("w"));
    }

    #[test]
    fn test_sophia_weight_decay() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                weight_decay: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        // With zero gradients, only the decoupled weight decay moves the parameters.
        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.0, 0.0, 0.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_gradient_clipping_value() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(0.5),
                grad_clip_mode: GradClipMode::Value,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0, -2.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        let update_mag = (initial[[0, 0]] - params["w"][[0, 0]]).abs();
        assert!(update_mag < 0.2);
    }

    #[test]
    fn test_sophia_gradient_clipping_norm() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(1.0),
                grad_clip_mode: GradClipMode::Norm,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0, 10.0, 10.0]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        let total_update: f64 = initial
            .iter()
            .zip(params["w"].iter())
            .map(|(&p, &u)| (p - u).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(total_update < 1.0);
    }

    #[test]
    fn test_sophia_learning_rate_getter_setter() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        optimizer.set_lr(0.001);
        assert_relative_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.1);
        assert_relative_eq!(optimizer.get_lr(), 0.1);
    }

    #[test]
    fn test_sophia_variant_gnb() {
        let config = SophiaConfig {
            variant: SophiaVariant::GaussNewtonBartlett,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
    }

    #[test]
    fn test_sophia_variant_hutchinson() {
        let config = SophiaConfig {
            variant: SophiaVariant::Hutchinson,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
    }

    #[test]
    fn test_sophia_update_clipping() {
        // A tiny rho makes the preconditioned update huge, so the element-wise
        // clipping should bound the step to roughly the learning rate.
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                ..Default::default()
            },
            rho: 0.001,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[10.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[100.0]]);

        let initial = params["w"][[0, 0]];
        optimizer.step(&mut params, &grads).unwrap();

        let update_size = (initial - params["w"][[0, 0]]).abs();
        assert!(update_size <= 0.12);
    }

    #[test]
    fn test_sophia_load_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer1 = SophiaOptimizer::new(config.clone());
        let mut optimizer2 = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        for _ in 0..5 {
            optimizer1.step(&mut params, &grads).unwrap();
        }

        let state = optimizer1.state_dict();
        optimizer2.load_state_dict(state);

        assert_eq!(optimizer2.t, optimizer1.t);
        assert_eq!(
            optimizer2.steps_since_hessian_update,
            optimizer1.steps_since_hessian_update
        );
    }
}