tensorlogic_train/optimizers/sophia.rs

//! Sophia optimizer - Scalable Stochastic Second-order Optimizer
//!
//! Implementation of the Sophia optimizer from:
//! "Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training"
//! Hong Liu, Zhiyuan Li, David Hall, Percy Liang, Tengyu Ma (2023)
//! <https://arxiv.org/abs/2305.14342>
//!
//! Sophia uses a lightweight Hessian diagonal estimate to adapt learning rates,
//! achieving faster convergence than Adam with similar memory requirements.
//!
//! # Key Features
//! - **Second-order information**: Uses Hessian diagonal estimates for better curvature awareness
//! - **Scalable**: Only requires tracking the Hessian diagonal (same memory as Adam)
//! - **Fast convergence**: Reported to be roughly 2x faster than Adam for language model pretraining
//! - **Two variants**: Sophia-G (Gauss-Newton-Bartlett) and Sophia-H (Hutchinson)
//!
//! # Usage
//! ```rust
//! use tensorlogic_train::{SophiaOptimizer, OptimizerConfig, Optimizer};
//! use scirs2_core::ndarray::Array2;
//! use std::collections::HashMap;
//!
//! let config = OptimizerConfig {
//!     learning_rate: 1e-4,
//!     ..Default::default()
//! };
//!
//! let mut optimizer = SophiaOptimizer::new(config);
//!
//! // During training with parameter HashMap:
//! // optimizer.step(&mut parameters, &gradients)?;
//! ```
//!
//! # Hyperparameter Recommendations
//! - Learning rate: 1e-4 to 2e-4 (higher than Adam's typical 1e-5)
//! - Beta1: 0.965 (momentum for gradients)
//! - Beta2: 0.99 (momentum for Hessian diagonal)
//! - Epsilon: 1e-8
//! - Rho: 0.04 (clipping parameter for update direction)
//! - Hessian update frequency: Every 10 steps (k=10)
//!
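//! A sketch mapping these recommendations onto [`SophiaConfig`]; it is marked `ignore`
//! because it assumes `SophiaConfig` and `SophiaVariant` are re-exported at the crate
//! root alongside `SophiaOptimizer`:
//! ```rust,ignore
//! use tensorlogic_train::{OptimizerConfig, SophiaConfig, SophiaOptimizer, SophiaVariant};
//!
//! let config = SophiaConfig {
//!     base: OptimizerConfig {
//!         learning_rate: 2e-4, // within the recommended 1e-4 to 2e-4 range
//!         beta1: 0.965,        // momentum for gradients
//!         beta2: 0.99,         // momentum for Hessian diagonal
//!         epsilon: 1e-8,
//!         ..Default::default()
//!     },
//!     rho: 0.04,               // update-direction clipping
//!     hessian_update_freq: 10, // k = 10
//!     variant: SophiaVariant::GaussNewtonBartlett,
//! };
//! let optimizer = SophiaOptimizer::with_sophia_config(config);
//! ```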

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Variant of Sophia optimizer to use
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SophiaVariant {
    /// Gauss-Newton-Bartlett estimator (more accurate, slightly more expensive)
    GaussNewtonBartlett,
    /// Hutchinson estimator (cheaper, uses random projections)
    Hutchinson,
}

/// Configuration for Sophia optimizer with additional Sophia-specific parameters
#[derive(Debug, Clone)]
pub struct SophiaConfig {
    /// Base optimizer configuration
    pub base: OptimizerConfig,
    /// Clipping parameter for update direction (typically 0.04)
    pub rho: f64,
    /// Frequency of Hessian updates (every k steps)
    pub hessian_update_freq: usize,
    /// Variant to use (G or H)
    pub variant: SophiaVariant,
}

impl Default for SophiaConfig {
    fn default() -> Self {
        Self {
            base: OptimizerConfig {
                learning_rate: 2e-4,
                beta1: 0.965,
                beta2: 0.99,
                epsilon: 1e-8,
                weight_decay: 0.01,
                ..Default::default()
            },
            rho: 0.04,
            hessian_update_freq: 10,
            variant: SophiaVariant::GaussNewtonBartlett,
        }
    }
}

/// Sophia optimizer - Second-order optimizer with Hessian diagonal estimation
///
/// Maintains two state tensors per parameter, plus a global step counter:
/// - m: First moment estimate (exponential moving average of gradients)
/// - h: Hessian diagonal estimate (EMA of element-wise gradient² or Hutchinson estimate)
/// - t: Global step counter used for bias correction
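///
/// # Update Rule
/// A summary of the per-step update implemented in `step` below, where η is the
/// learning rate, λ the decoupled weight decay, ρ the clipping parameter, ε the
/// epsilon, and k the Hessian update frequency:
/// ```text
/// m_t     = β₁ · m_{t-1} + (1 - β₁) · g_t
/// m̂_t     = m_t / (1 - β₁^t)
/// h_t     = β₂ · h_{t-k} + (1 - β₂) · g_t²              (refreshed every k steps)
/// θ'      = θ_t - η · clip(m̂_t / (ρ · h_t + ε), -1, 1)
/// θ_{t+1} = θ' - η · λ · θ'                             (decoupled weight decay)
/// ```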
pub struct SophiaOptimizer {
    config: SophiaConfig,
    /// First moment estimates (m_t)
    m: HashMap<String, Array<f64, Ix2>>,
    /// Hessian diagonal estimates (h_t)
    h: HashMap<String, Array<f64, Ix2>>,
    /// Timestep counter
    t: usize,
    /// Steps since last Hessian update
    steps_since_hessian_update: usize,
}

impl SophiaOptimizer {
    /// Create a new Sophia optimizer with default Sophia configuration
    pub fn new(config: OptimizerConfig) -> Self {
        Self::with_sophia_config(SophiaConfig {
            base: config,
            ..Default::default()
        })
    }

    /// Create a new Sophia optimizer with custom Sophia configuration
    pub fn with_sophia_config(config: SophiaConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            h: HashMap::new(),
            t: 0,
            steps_since_hessian_update: 0,
        }
    }

    /// Apply gradient clipping if configured
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.base.grad_clip {
            match self.config.base.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }

    /// Update Hessian diagonal estimate using Gauss-Newton-Bartlett method
    ///
    /// This uses the gradient itself as an approximation:
    /// h_t = β₂ * h_{t-1} + (1 - β₂) * g_t²
    fn update_hessian_gnb(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        let beta2 = self.config.base.beta2;

        for (name, grad) in gradients {
            let grad_squared = grad.mapv(|g| g * g);

            if let Some(h_state) = self.h.get_mut(name) {
                // h_t = β₂ * h_{t-1} + (1 - β₂) * g²
                *h_state = &*h_state * beta2 + &grad_squared * (1.0 - beta2);
            } else {
                self.h.insert(name.clone(), grad_squared * (1.0 - beta2));
            }
        }
    }

    /// Update Hessian diagonal estimate using the Hutchinson method
    ///
    /// Uses random Rademacher vectors for an unbiased diagonal estimate:
    /// h_t ≈ u ⊙ (∇²L · u) where u ~ Rademacher({-1, +1})
    ///
    /// Note: Full Hutchinson requires Hessian-vector products, which aren't available
    /// in this interface, so we use GNB as a reasonable approximation.
    fn update_hessian_hutchinson(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        // For a full Hutchinson implementation, we'd need:
        // 1. Sample u ~ Rademacher({-1, +1})
        // 2. Compute Hessian-vector product: Hv = ∇(g^T u)
        // 3. Estimate diagonal: h ≈ u ⊙ Hv
        //
        // Since we don't have access to Hessian-vector products in this interface,
        // we use GNB as a practical approximation
        self.update_hessian_gnb(gradients);
    }
}

impl Optimizer for SophiaOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        self.steps_since_hessian_update += 1;

        let lr = self.config.base.learning_rate;
        let beta1 = self.config.base.beta1;
        let eps = self.config.base.epsilon;
        let rho = self.config.rho;
        let weight_decay = self.config.base.weight_decay;

        // Update Hessian diagonal estimate (every k steps)
        if self.steps_since_hessian_update >= self.config.hessian_update_freq {
            match self.config.variant {
                SophiaVariant::GaussNewtonBartlett => {
                    self.update_hessian_gnb(&clipped_gradients);
                }
                SophiaVariant::Hutchinson => {
                    self.update_hessian_hutchinson(&clipped_gradients);
                }
            }
            self.steps_since_hessian_update = 0;
        }

        // Bias correction for first moment
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);

        // Update parameters
        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Initialize state if needed (checked independently so that a Hessian
            // estimate computed earlier in this same step is not overwritten)
            if !self.m.contains_key(name) {
                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
            }
            if !self.h.contains_key(name) {
                self.h
                    .insert(name.clone(), Array::ones(param.raw_dim()) * eps);
            }

            let m = self.m.get_mut(name).unwrap();
            let h = self.h.get(name).unwrap();

            // Update first moment (gradient EMA): m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // Bias-corrected first moment: m̂_t = m_t / (1 - β₁^t)
            let m_hat = &*m / bias_correction1;

            // Compute update direction: m̂ / (ρ * h + ε)
            let denominator = h * rho + eps;
            let update_direction = &m_hat / &denominator;

            // Clip update direction to [-1, 1]
            let clipped_update = update_direction.mapv(|x| x.clamp(-1.0, 1.0));

            // Apply update: θ_{t+1} = θ_t - lr * clip(m̂ / (ρ * h), -1, 1)
            *param = &*param - &(&clipped_update * lr);

            // Decoupled weight decay (AdamW-style), applied to the freshly updated parameter: θ ← θ - lr * λ * θ
            if weight_decay > 0.0 {
                *param = &*param - &(&*param * (weight_decay * lr));
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are passed in, not stored, so nothing to zero
    }

    fn get_lr(&self) -> f64 {
        self.config.base.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.base.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        state.insert(
            "steps_since_hessian_update".to_string(),
            vec![self.steps_since_hessian_update as f64],
        );

        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, h_val) in &self.h {
            state.insert(format!("h_{}", name), h_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_vals) = state.get("t") {
            self.t = t_vals[0] as usize;
        }
        if let Some(steps_vals) = state.get("steps_since_hessian_update") {
            self.steps_since_hessian_update = steps_vals[0] as usize;
        }

        // Note: m_/h_ tensors can only be restored for parameters whose state already
        // exists here, since the flattened vectors carry no shape information.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("h_") {
                if let Some(h) = self.h.get(name) {
                    let shape = h.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.h.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sophia_initialization() {
        let config = OptimizerConfig::default();
        let optimizer = SophiaOptimizer::new(config);

        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.h.is_empty());
    }

    #[test]
    fn test_sophia_custom_config() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 1e-4,
                beta1: 0.965,
                beta2: 0.99,
                ..Default::default()
            },
            rho: 0.04,
            ..Default::default()
        };

        let optimizer = SophiaOptimizer::with_sophia_config(config);
        assert_relative_eq!(optimizer.get_lr(), 1e-4);
    }

    #[test]
    fn test_sophia_single_step() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2, 0.3]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        // Parameters should be updated (decreased for positive gradients)
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[5.0], [-3.0], [2.0]]);

        // Simulate optimization to zero
        for _ in 0..50 {
            let mut grads = HashMap::new();
            grads.insert("w".to_string(), &params["w"] * 2.0); // Gradient of x²
            optimizer.step(&mut params, &grads).unwrap();
        }

        // Should converge close to zero
        for &p in params["w"].iter() {
            assert!(p.abs() < 0.5);
        }
    }

    #[test]
    fn test_sophia_2d_parameters() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1, 0.1], [-0.1, -0.1, -0.1]]);

        let initial_shape = params["w"].shape().to_vec();
        optimizer.step(&mut params, &grads).unwrap();

        assert_eq!(params["w"].shape(), &initial_shape[..]);
    }

    #[test]
    fn test_sophia_reset_and_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).unwrap();
        assert!(!optimizer.m.is_empty());
        assert_eq!(optimizer.t, 1);

        // Test state dict
        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("h_w"));
    }

    #[test]
    fn test_sophia_hessian_update_frequency() {
        let config = SophiaConfig {
            hessian_update_freq: 5,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // The first step does not yet refresh the Hessian (counter is below the frequency)
        optimizer.step(&mut params, &grads).unwrap();
        assert_eq!(optimizer.steps_since_hessian_update, 1);

        // Four more steps: the fifth step triggers the refresh and resets the counter
        for _ in 0..4 {
            optimizer.step(&mut params, &grads).unwrap();
        }
        assert_eq!(optimizer.steps_since_hessian_update, 0); // Reset after 5 steps

        // Hessian state should exist
        assert!(optimizer.h.contains_key("w"));
    }

    #[test]
    fn test_sophia_weight_decay() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                weight_decay: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.0, 0.0, 0.0]]); // Zero gradients

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        // With weight decay and zero gradients, parameters should decay
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_gradient_clipping_value() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(0.5),
                grad_clip_mode: GradClipMode::Value,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0, -2.0]]); // Should be clipped to [0.5, -0.5]

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        // The applied step stays small: the raw gradients are value-clipped, and the
        // update direction is further bounded to ±1 before the learning rate is applied
        let update_mag = (initial[[0, 0]] - params["w"][[0, 0]]).abs();
        assert!(update_mag < 0.2);
    }

    #[test]
    fn test_sophia_gradient_clipping_norm() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(1.0),
                grad_clip_mode: GradClipMode::Norm,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0, 10.0, 10.0]]); // Large gradients

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();

        // Norm clipping should limit the total update
        let total_update: f64 = initial
            .iter()
            .zip(params["w"].iter())
            .map(|(&p, &u)| (p - u).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(total_update < 1.0); // Should be limited
    }

    #[test]
    fn test_sophia_learning_rate_getter_setter() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        optimizer.set_lr(0.001);
        assert_relative_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.1);
        assert_relative_eq!(optimizer.get_lr(), 0.1);
    }

    #[test]
    fn test_sophia_variant_gnb() {
        let config = SophiaConfig {
            variant: SophiaVariant::GaussNewtonBartlett,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();
        assert!(params["w"][[0, 0]] < initial[[0, 0]]); // Should make progress
    }

    #[test]
    fn test_sophia_variant_hutchinson() {
        let config = SophiaConfig {
            variant: SophiaVariant::Hutchinson,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).unwrap();
        assert!(params["w"][[0, 0]] < initial[[0, 0]]); // Should make progress
    }

    #[test]
    fn test_sophia_update_clipping() {
        // Test that updates are clipped to [-1, 1] before applying learning rate
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                ..Default::default()
            },
            rho: 0.001, // Very small rho to create large update direction
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[10.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[100.0]]); // Large gradient

        let initial = params["w"][[0, 0]];
        optimizer.step(&mut params, &grads).unwrap();

        // Even with large gradient, update should be bounded
        let update_size = (initial - params["w"][[0, 0]]).abs();
        assert!(update_size <= 0.12); // lr * 1.0 (clipped) + small margin
    }

    #[test]
    fn test_sophia_load_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer1 = SophiaOptimizer::new(config.clone());
        let mut optimizer2 = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // Take several steps with optimizer1
        for _ in 0..5 {
            optimizer1.step(&mut params, &grads).unwrap();
        }

        // Save and load state
        let state = optimizer1.state_dict();
        optimizer2.load_state_dict(state);

        // Verify state was loaded
        assert_eq!(optimizer2.t, optimizer1.t);
        assert_eq!(
            optimizer2.steps_since_hessian_update,
            optimizer1.steps_since_hessian_update
        );
    }
}