
tensorlogic_train/optimizers/sophia.rs

//! Sophia optimizer - Scalable Stochastic Second-order Optimizer
//!
//! Implementation of the Sophia optimizer from:
//! "Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training"
//! Hong Liu, Zhiyuan Li, David Hall, Percy Liang, Tengyu Ma (2023)
//! <https://arxiv.org/abs/2305.14342>
//!
//! Sophia uses a lightweight Hessian diagonal estimate to adapt learning rates,
//! achieving faster convergence than Adam with similar memory requirements.
//!
//! # Key Features
//! - **Second-order information**: Uses Hessian diagonal estimates for better curvature awareness
//! - **Scalable**: Only requires tracking Hessian diagonal (same memory as Adam)
//! - **Fast convergence**: Roughly 2x fewer steps than Adam for language model pretraining (as reported in the paper)
//! - **Two variants**: Sophia-G (Gauss-Newton-Bartlett) and Sophia-H (Hutchinson)
//!
//! # Usage
//! ```rust
//! use tensorlogic_train::{SophiaOptimizer, OptimizerConfig, Optimizer};
//! use scirs2_core::ndarray::Array2;
//! use std::collections::HashMap;
//!
//! let config = OptimizerConfig {
//!     learning_rate: 1e-4,
//!     ..Default::default()
//! };
//!
//! let mut optimizer = SophiaOptimizer::new(config);
//!
//! // During training with parameter HashMap:
//! // optimizer.step(&mut parameters, &gradients)?;
//! ```
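//!
//! For a fuller picture of the `step` call, here is a minimal end-to-end sketch
//! (illustrative only, not compiled as a doctest); it assumes parameters and gradients
//! are both stored as `HashMap<String, Array2<f64>>` keyed by parameter name:
//! ```rust,ignore
//! use scirs2_core::ndarray::{array, Array2};
//! use std::collections::HashMap;
//! use tensorlogic_train::{Optimizer, OptimizerConfig, SophiaOptimizer};
//!
//! let mut optimizer = SophiaOptimizer::new(OptimizerConfig::default());
//!
//! // Parameters and gradients share the same keys.
//! let mut parameters: HashMap<String, Array2<f64>> = HashMap::new();
//! parameters.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);
//!
//! let mut gradients: HashMap<String, Array2<f64>> = HashMap::new();
//! gradients.insert("w".to_string(), array![[0.1, 0.2, 0.3]]);
//!
//! // One optimizer step updates `parameters` in place.
//! optimizer.step(&mut parameters, &gradients).expect("Sophia step failed");
//! ```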
//!
//! # Hyperparameter Recommendations
//! - Learning rate: 1e-4 to 2e-4 (higher than Adam's typical 1e-5)
//! - Beta1: 0.965 (momentum for gradients)
//! - Beta2: 0.99 (momentum for Hessian diagonal)
//! - Epsilon: 1e-8
//! - Rho: 0.04 (clipping parameter for update direction)
//! - Hessian update frequency: Every 10 steps (k=10)
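//!
//! A `SophiaConfig` following these recommendations might look like the sketch below
//! (illustrative only; it assumes `SophiaConfig` and `SophiaVariant` are re-exported
//! from the crate root alongside `SophiaOptimizer`):
//! ```rust,ignore
//! use tensorlogic_train::{OptimizerConfig, SophiaConfig, SophiaOptimizer, SophiaVariant};
//!
//! let config = SophiaConfig {
//!     base: OptimizerConfig {
//!         learning_rate: 2e-4,
//!         beta1: 0.965,
//!         beta2: 0.99,
//!         epsilon: 1e-8,
//!         ..Default::default()
//!     },
//!     rho: 0.04,
//!     hessian_update_freq: 10,
//!     variant: SophiaVariant::GaussNewtonBartlett,
//! };
//! let optimizer = SophiaOptimizer::with_sophia_config(config);
//! ```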

use super::common::{compute_gradient_norm, GradClipMode, Optimizer, OptimizerConfig};
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Variant of Sophia optimizer to use
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SophiaVariant {
    /// Gauss-Newton-Bartlett estimator (more accurate, slightly more expensive)
    GaussNewtonBartlett,
    /// Hutchinson estimator (cheaper, uses random projections)
    Hutchinson,
}

/// Configuration for Sophia optimizer with additional Sophia-specific parameters
#[derive(Debug, Clone)]
pub struct SophiaConfig {
    /// Base optimizer configuration
    pub base: OptimizerConfig,
    /// Clipping parameter for update direction (typically 0.04)
    pub rho: f64,
    /// Frequency of Hessian updates (every k steps)
    pub hessian_update_freq: usize,
    /// Variant to use (G or H)
    pub variant: SophiaVariant,
}

impl Default for SophiaConfig {
    fn default() -> Self {
        Self {
            base: OptimizerConfig {
                learning_rate: 2e-4,
                beta1: 0.965,
                beta2: 0.99,
                epsilon: 1e-8,
                weight_decay: 0.01,
                ..Default::default()
            },
            rho: 0.04,
            hessian_update_freq: 10,
            variant: SophiaVariant::GaussNewtonBartlett,
        }
    }
}

/// Sophia optimizer - Second-order optimizer with Hessian diagonal estimation
///
/// Maintains two state tensors per parameter, plus a global step counter:
/// - m: First moment estimate (exponential moving average of gradients)
/// - h: Hessian diagonal estimate (EMA of element-wise gradient^2 or Hutchinson estimate)
/// - t: Global step counter used for bias correction (shared across all parameters)
pub struct SophiaOptimizer {
    config: SophiaConfig,
    /// First moment estimates (m_t)
    m: HashMap<String, Array<f64, Ix2>>,
    /// Hessian diagonal estimates (h_t)
    h: HashMap<String, Array<f64, Ix2>>,
    /// Timestep counter
    t: usize,
    /// Steps since last Hessian update
    steps_since_hessian_update: usize,
}

impl SophiaOptimizer {
    /// Create a new Sophia optimizer with default Sophia configuration
    pub fn new(config: OptimizerConfig) -> Self {
        Self::with_sophia_config(SophiaConfig {
            base: config,
            ..Default::default()
        })
    }

    /// Create a new Sophia optimizer with custom Sophia configuration
    pub fn with_sophia_config(config: SophiaConfig) -> Self {
        Self {
            config,
            m: HashMap::new(),
            h: HashMap::new(),
            t: 0,
            steps_since_hessian_update: 0,
        }
    }

    /// Apply gradient clipping if configured
    fn clip_gradients(&self, gradients: &mut HashMap<String, Array<f64, Ix2>>) {
        if let Some(clip_value) = self.config.base.grad_clip {
            match self.config.base.grad_clip_mode {
                GradClipMode::Value => {
                    for grad in gradients.values_mut() {
                        grad.mapv_inplace(|g| g.max(-clip_value).min(clip_value));
                    }
                }
                GradClipMode::Norm => {
                    let total_norm = compute_gradient_norm(gradients);
                    if total_norm > clip_value {
                        let scale = clip_value / total_norm;
                        for grad in gradients.values_mut() {
                            grad.mapv_inplace(|g| g * scale);
                        }
                    }
                }
            }
        }
    }

    /// Update Hessian diagonal estimate using Gauss-Newton-Bartlett method
    ///
    /// This uses the gradient itself as an approximation:
    /// h_t = β₂ * h_{t-1} + (1 - β₂) * g_t²
    fn update_hessian_gnb(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        let beta2 = self.config.base.beta2;

        for (name, grad) in gradients {
            let grad_squared = grad.mapv(|g| g * g);

            if let Some(h_state) = self.h.get_mut(name) {
                // h_t = β₂ * h_{t-1} + (1 - β₂) * g²
                *h_state = &*h_state * beta2 + &grad_squared * (1.0 - beta2);
            } else {
                self.h.insert(name.clone(), grad_squared * (1.0 - beta2));
            }
        }
    }

    /// Update Hessian diagonal estimate using Hutchinson method
    ///
    /// Uses random Rademacher vectors for unbiased estimation:
    /// h_t ≈ u ⊙ (∇²L · u) where u ~ Rademacher({-1, +1})
    ///
    /// Note: Full Hutchinson requires Hessian-vector products which aren't available
    /// in this interface, so we use GNB as a reasonable approximation.
    fn update_hessian_hutchinson(&mut self, gradients: &HashMap<String, Array<f64, Ix2>>) {
        // For a full Hutchinson implementation, we'd need:
        // 1. Sample u ~ Rademacher({-1, +1})
        // 2. Compute Hessian-vector product: Hv = ∇(g^T u)
        // 3. Estimate diagonal: h ≈ u ⊙ Hv
        //
        // Since we don't have access to Hessian-vector products in this interface,
        // we use GNB as a practical approximation
        self.update_hessian_gnb(gradients);
    }
}

impl Optimizer for SophiaOptimizer {
    fn step(
        &mut self,
        parameters: &mut HashMap<String, Array<f64, Ix2>>,
        gradients: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<()> {
        let mut clipped_gradients = gradients.clone();
        self.clip_gradients(&mut clipped_gradients);

        self.t += 1;
        self.steps_since_hessian_update += 1;

        let lr = self.config.base.learning_rate;
        let beta1 = self.config.base.beta1;
        let eps = self.config.base.epsilon;
        let rho = self.config.rho;
        let weight_decay = self.config.base.weight_decay;

        // Update Hessian diagonal estimate (every k steps)
        if self.steps_since_hessian_update >= self.config.hessian_update_freq {
            match self.config.variant {
                SophiaVariant::GaussNewtonBartlett => {
                    self.update_hessian_gnb(&clipped_gradients);
                }
                SophiaVariant::Hutchinson => {
                    self.update_hessian_hutchinson(&clipped_gradients);
                }
            }
            self.steps_since_hessian_update = 0;
        }

        // Bias correction for first moment
        let bias_correction1 = 1.0 - beta1.powi(self.t as i32);

        // Update parameters
        for (name, param) in parameters.iter_mut() {
            let grad = clipped_gradients.get(name).ok_or_else(|| {
                TrainError::OptimizerError(format!("Missing gradient for parameter: {}", name))
            })?;

            // Initialize state if needed
            if !self.m.contains_key(name) {
                self.m.insert(name.clone(), Array::zeros(param.raw_dim()));
                self.h
                    .insert(name.clone(), Array::ones(param.raw_dim()) * eps);
            }

            let m = self
                .m
                .get_mut(name)
                .expect("m initialized for all parameters");
            let h = self.h.get(name).expect("h initialized for all parameters");

            // Update first moment (gradient EMA): m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
            *m = &*m * beta1 + &(grad * (1.0 - beta1));

            // Bias-corrected first moment: m̂_t = m_t / (1 - β₁^t)
            let m_hat = &*m / bias_correction1;

            // Compute update direction: m̂ / (ρ * h + ε)
            let denominator = h * rho + eps;
            let update_direction = &m_hat / &denominator;

            // Clip update direction to [-1, 1]
            let clipped_update = update_direction.mapv(|x| x.clamp(-1.0, 1.0));

            // Apply update: θ_{t+1} = θ_t - lr * clip(m̂ / (ρ * h + ε), -1, 1)
            *param = &*param - &(&clipped_update * lr);

            // Weight decay (decoupled, like AdamW): θ_{t+1} -= lr * λ * θ_t
            if weight_decay > 0.0 {
                *param = &*param - &(&*param * (weight_decay * lr));
            }
        }

        Ok(())
    }

    fn zero_grad(&mut self) {
        // Gradients are passed in, not stored, so nothing to zero
    }

    fn get_lr(&self) -> f64 {
        self.config.base.learning_rate
    }

    fn set_lr(&mut self, lr: f64) {
        self.config.base.learning_rate = lr;
    }

    fn state_dict(&self) -> HashMap<String, Vec<f64>> {
        let mut state = HashMap::new();
        state.insert("t".to_string(), vec![self.t as f64]);
        state.insert(
            "steps_since_hessian_update".to_string(),
            vec![self.steps_since_hessian_update as f64],
        );

        for (name, m_val) in &self.m {
            state.insert(format!("m_{}", name), m_val.iter().copied().collect());
        }
        for (name, h_val) in &self.h {
            state.insert(format!("h_{}", name), h_val.iter().copied().collect());
        }

        state
    }

    fn load_state_dict(&mut self, state: HashMap<String, Vec<f64>>) {
        if let Some(t_vals) = state.get("t") {
            self.t = t_vals[0] as usize;
        }
        if let Some(steps_vals) = state.get("steps_since_hessian_update") {
            self.steps_since_hessian_update = steps_vals[0] as usize;
        }

        // Note: per-parameter m/h tensors are only restored for parameters whose state
        // (and therefore shape) already exists in this optimizer; entries for unknown
        // parameters are skipped because their shapes cannot be recovered from a flat Vec.
        for (key, values) in state {
            if let Some(name) = key.strip_prefix("m_") {
                if let Some(m) = self.m.get(name) {
                    let shape = m.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.m.insert(name.to_string(), arr);
                    }
                }
            } else if let Some(name) = key.strip_prefix("h_") {
                if let Some(h) = self.h.get(name) {
                    let shape = h.raw_dim();
                    if let Ok(arr) = Array::from_shape_vec(shape, values) {
                        self.h.insert(name.to_string(), arr);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sophia_initialization() {
        let config = OptimizerConfig::default();
        let optimizer = SophiaOptimizer::new(config);

        assert_eq!(optimizer.t, 0);
        assert!(optimizer.m.is_empty());
        assert!(optimizer.h.is_empty());
    }

    #[test]
    fn test_sophia_custom_config() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 1e-4,
                beta1: 0.965,
                beta2: 0.99,
                ..Default::default()
            },
            rho: 0.04,
            ..Default::default()
        };

        let optimizer = SophiaOptimizer::with_sophia_config(config);
        assert_relative_eq!(optimizer.get_lr(), 1e-4);
    }

    #[test]
    fn test_sophia_single_step() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2, 0.3]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // Parameters should be updated (decreased for positive gradients)
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_convergence() {
        let config = OptimizerConfig {
            learning_rate: 0.1,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[5.0], [-3.0], [2.0]]);

        // Simulate optimization to zero
        for _ in 0..50 {
            let mut grads = HashMap::new();
            grads.insert("w".to_string(), &params["w"] * 2.0); // Gradient of x²
            optimizer.step(&mut params, &grads).expect("unwrap");
        }

        // Should converge close to zero
        for &p in params["w"].iter() {
            assert!(p.abs() < 0.5);
        }
    }

    #[test]
    fn test_sophia_2d_parameters() {
        let config = OptimizerConfig {
            learning_rate: 0.01,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::new(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.1, 0.1], [-0.1, -0.1, -0.1]]);

        let initial_shape = params["w"].shape().to_vec();
        optimizer.step(&mut params, &grads).expect("unwrap");

        assert_eq!(params["w"].shape(), &initial_shape[..]);
    }

    #[test]
    fn test_sophia_reset_and_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(!optimizer.m.is_empty());
        assert_eq!(optimizer.t, 1);

        // Test state dict
        let state = optimizer.state_dict();
        assert!(state.contains_key("t"));
        assert!(state.contains_key("m_w"));
        assert!(state.contains_key("h_w"));
    }

    #[test]
    fn test_sophia_hessian_update_frequency() {
        let config = SophiaConfig {
            hessian_update_freq: 5,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // With freq = 5, the first step only advances the counter; no Hessian update yet
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert_eq!(optimizer.steps_since_hessian_update, 1);

        // Steps 2-5: the fifth step triggers the Hessian update and resets the counter
        for _ in 0..4 {
            optimizer.step(&mut params, &grads).expect("unwrap");
        }
        assert_eq!(optimizer.steps_since_hessian_update, 0); // Reset after 5 steps

        // Hessian state should exist
        assert!(optimizer.h.contains_key("w"));
    }

    #[test]
    fn test_sophia_weight_decay() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                weight_decay: 0.01,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.0, 0.0, 0.0]]); // Zero gradients

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // With weight decay and zero gradients, parameters should decay
        assert!(params["w"][[0, 0]] < initial[[0, 0]]);
        assert!(params["w"][[0, 1]] < initial[[0, 1]]);
        assert!(params["w"][[0, 2]] < initial[[0, 2]]);
    }

    #[test]
    fn test_sophia_gradient_clipping_value() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(0.5),
                grad_clip_mode: GradClipMode::Value,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[1.0, -2.0]]); // Should be clipped to [0.5, -0.5]

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // Effect should be limited by clipping
        let update_mag = (initial[[0, 0]] - params["w"][[0, 0]]).abs();
        assert!(update_mag < 0.2); // Much less than if unclipped
    }

    #[test]
    fn test_sophia_gradient_clipping_norm() {
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                grad_clip: Some(1.0),
                grad_clip_mode: GradClipMode::Norm,
                ..Default::default()
            },
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0, 3.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[10.0, 10.0, 10.0]]); // Large gradients

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");

        // Norm clipping should limit the total update
        let total_update: f64 = initial
            .iter()
            .zip(params["w"].iter())
            .map(|(&p, &u)| (p - u).powi(2))
            .sum::<f64>()
            .sqrt();

        assert!(total_update < 1.0); // Should be limited
    }

    #[test]
    fn test_sophia_learning_rate_getter_setter() {
        let config = OptimizerConfig::default();
        let mut optimizer = SophiaOptimizer::new(config);

        optimizer.set_lr(0.001);
        assert_relative_eq!(optimizer.get_lr(), 0.001);

        optimizer.set_lr(0.1);
        assert_relative_eq!(optimizer.get_lr(), 0.1);
    }

    #[test]
    fn test_sophia_variant_gnb() {
        let config = SophiaConfig {
            variant: SophiaVariant::GaussNewtonBartlett,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(params["w"][[0, 0]] < initial[[0, 0]]); // Should make progress
    }

    #[test]
    fn test_sophia_variant_hutchinson() {
        let config = SophiaConfig {
            variant: SophiaVariant::Hutchinson,
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.5, 0.5]]);

        let initial = params["w"].clone();
        optimizer.step(&mut params, &grads).expect("unwrap");
        assert!(params["w"][[0, 0]] < initial[[0, 0]]); // Should make progress
    }

    #[test]
    fn test_sophia_update_clipping() {
        // Test that updates are clipped to [-1, 1] before applying learning rate
        let config = SophiaConfig {
            base: OptimizerConfig {
                learning_rate: 0.1,
                ..Default::default()
            },
            rho: 0.001, // Very small rho to create large update direction
            ..Default::default()
        };

        let mut optimizer = SophiaOptimizer::with_sophia_config(config);
        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[10.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[100.0]]); // Large gradient

        let initial = params["w"][[0, 0]];
        optimizer.step(&mut params, &grads).expect("unwrap");

        // Even with large gradient, update should be bounded
        let update_size = (initial - params["w"][[0, 0]]).abs();
        assert!(update_size <= 0.12); // lr * 1.0 (clipped) + small margin
    }

    #[test]
    fn test_sophia_load_state_dict() {
        let config = OptimizerConfig::default();
        let mut optimizer1 = SophiaOptimizer::new(config.clone());
        let mut optimizer2 = SophiaOptimizer::new(config);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let mut grads = HashMap::new();
        grads.insert("w".to_string(), array![[0.1, 0.2]]);

        // Take several steps with optimizer1
        for _ in 0..5 {
            optimizer1.step(&mut params, &grads).expect("unwrap");
        }

        // Save and load state
        let state = optimizer1.state_dict();
        optimizer2.load_state_dict(state);

        // Verify the scalar state was restored (m/h entries are skipped here because
        // optimizer2 has no existing per-parameter state to infer shapes from)
        assert_eq!(optimizer2.t, optimizer1.t);
        assert_eq!(
            optimizer2.steps_since_hessian_update,
            optimizer1.steps_since_hessian_update
        );
    }
}