// trustformers_optim/muon.rs
//! # Muon Optimizer
//!
//! Implementation of the Muon optimizer, a second-order optimization algorithm designed for
//! neural network training, particularly with hidden layers having 2D weight matrices.
//!
//! Muon is used in the current training speed records for both NanoGPT and CIFAR-10 speedrunning.
//!
//! ## Key Features
//!
//! - **Second-Order Optimization**: Uses Newton-Schulz iteration for efficient orthogonalization
//! - **Low FLOP Overhead**: Below 1% FLOP overhead for typical LM training scenarios
//! - **2D Parameter Focus**: Designed specifically for 2D weight matrices (linear layers)
//! - **Speed Records**: Achieves state-of-the-art training speed on multiple benchmarks
//!
//! ## Design Philosophy
//!
//! Muon only applies to 2D parameters (weight matrices), while scalar and vector parameters
//! must be optimized using a standard method (e.g., AdamW). This hybrid approach provides
//! the best of both worlds: second-order benefits for main parameters and proven stability
//! for auxiliary parameters.

use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;

/// Configuration for the Muon optimizer.
///
/// `learning_rate`, `momentum`, and `ns_steps` govern the 2D (orthogonalized)
/// update path; `fallback_lr` and `fallback_momentum` govern the momentum-SGD
/// fallback used for parameters that do not qualify for the 2D path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MuonConfig {
    /// Learning rate for the 2D Muon path (default: 0.02)
    pub learning_rate: f32,
    /// Momentum coefficient for the 2D Muon path (default: 0.95)
    pub momentum: f32,
    /// Newton-Schulz iteration steps (default: 5)
    pub ns_steps: usize,
    /// Minimum size of *both* dimensions for 2D optimization (default: 64)
    pub min_dim_2d: usize,
    /// Fallback optimizer learning rate for 1D parameters (default: 1e-3)
    pub fallback_lr: f32,
    /// Fallback momentum for 1D parameters (default: 0.9)
    pub fallback_momentum: f32,
    /// Weight decay coefficient, applied on both paths (default: 0.0)
    pub weight_decay: f32,
    /// Whether to run Newton-Schulz orthogonalization at all (default: true)
    pub use_orthogonal: bool,
}

51impl Default for MuonConfig {
52    fn default() -> Self {
53        Self {
54            learning_rate: 0.02,
55            momentum: 0.95,
56            ns_steps: 5,
57            min_dim_2d: 64,
58            fallback_lr: 1e-3,
59            fallback_momentum: 0.9,
60            weight_decay: 0.0,
61            use_orthogonal: true,
62        }
63    }
64}
65
/// Muon optimizer implementation
///
/// Muon uses Newton-Schulz iteration for orthogonalization of 2D weight matrices,
/// providing efficient second-order optimization. For 1D parameters, it falls back
/// to a standard momentum-based update.
///
/// All per-parameter maps are keyed by a string id derived from the parameter's
/// data-buffer address (see `Optimizer::update`).
#[derive(Debug)]
pub struct Muon {
    config: MuonConfig,
    state: OptimizerState,
    /// Momentum buffers for 2D parameters (row-major, `rows` inner vectors of `cols`)
    momentum_2d: HashMap<String, Vec<Vec<f32>>>,
    /// Momentum buffers for 1D parameters (momentum-SGD fallback)
    momentum_1d: HashMap<String, Vec<f32>>,
    /// Cached `(rows, cols)` factorization guessed for each parameter
    param_shapes: HashMap<String, (usize, usize)>,
}

83impl Muon {
84    /// Create a new Muon optimizer with default configuration
85    pub fn new() -> Self {
86        Self::with_config(MuonConfig::default())
87    }
88
89    /// Create Muon with custom learning rate
90    pub fn new_with_lr(learning_rate: f32) -> Self {
91        let config = MuonConfig {
92            learning_rate,
93            ..Default::default()
94        };
95        Self::with_config(config)
96    }
97
98    /// Create Muon optimized for NanoGPT training
99    pub fn for_nanogpt() -> Self {
100        let config = MuonConfig {
101            learning_rate: 0.01,
102            momentum: 0.95,
103            ns_steps: 5,
104            min_dim_2d: 32, // Lower threshold for smaller models
105            fallback_lr: 5e-4,
106            fallback_momentum: 0.9,
107            weight_decay: 0.0,
108            use_orthogonal: true,
109        };
110        Self::with_config(config)
111    }
112
113    /// Create Muon optimized for CIFAR-10 training
114    pub fn for_cifar10() -> Self {
115        let config = MuonConfig {
116            learning_rate: 0.03,
117            momentum: 0.9,
118            ns_steps: 4, // Fewer steps for vision tasks
119            min_dim_2d: 64,
120            fallback_lr: 1e-3,
121            fallback_momentum: 0.9,
122            weight_decay: 1e-4,
123            use_orthogonal: true,
124        };
125        Self::with_config(config)
126    }
127
128    /// Create Muon optimized for large language models
129    pub fn for_large_lm() -> Self {
130        let config = MuonConfig {
131            learning_rate: 0.015,
132            momentum: 0.98,  // Higher momentum for large models
133            ns_steps: 6,     // More steps for better approximation
134            min_dim_2d: 128, // Higher threshold for large models
135            fallback_lr: 3e-4,
136            fallback_momentum: 0.95,
137            weight_decay: 0.01,
138            use_orthogonal: true,
139        };
140        Self::with_config(config)
141    }
142
143    /// Create Muon with custom configuration
144    pub fn with_config(config: MuonConfig) -> Self {
145        Self {
146            config,
147            state: OptimizerState::new(),
148            momentum_2d: HashMap::new(),
149            momentum_1d: HashMap::new(),
150            param_shapes: HashMap::new(),
151        }
152    }
153
154    /// Check if parameter should use 2D optimization
155    fn should_use_2d_optimization(&self, rows: usize, cols: usize) -> bool {
156        rows >= self.config.min_dim_2d && cols >= self.config.min_dim_2d
157    }
158
159    /// Newton-Schulz iteration for matrix orthogonalization
160    /// Approximates the orthogonal polar factor of a matrix
161    fn newton_schulz_orthogonalize(&self, matrix: &mut [Vec<f32>]) {
162        if !self.config.use_orthogonal {
163            return;
164        }
165
166        let rows = matrix.len();
167        let cols = matrix[0].len();
168
169        // Newton-Schulz iteration: X_{k+1} = X_k * (3I - X_k^T * X_k) / 2
170        for _ in 0..self.config.ns_steps {
171            // Compute X^T * X
172            let mut xtx = vec![vec![0.0; cols]; cols];
173            for i in 0..cols {
174                for j in 0..cols {
175                    let mut sum = 0.0;
176                    for k in 0..rows {
177                        sum += matrix[k][i] * matrix[k][j];
178                    }
179                    xtx[i][j] = sum;
180                }
181            }
182
183            // Compute 3I - X^T * X
184            for i in 0..cols {
185                for j in 0..cols {
186                    if i == j {
187                        xtx[i][j] = 3.0 - xtx[i][j];
188                    } else {
189                        xtx[i][j] = -xtx[i][j];
190                    }
191                }
192            }
193
194            // Compute X * (3I - X^T * X) / 2
195            let mut new_matrix = vec![vec![0.0; cols]; rows];
196            for i in 0..rows {
197                for j in 0..cols {
198                    let mut sum = 0.0;
199                    for k in 0..cols {
200                        sum += matrix[i][k] * xtx[k][j];
201                    }
202                    new_matrix[i][j] = sum * 0.5;
203                }
204            }
205
206            // Update matrix
207            for i in 0..rows {
208                for j in 0..cols {
209                    matrix[i][j] = new_matrix[i][j];
210                }
211            }
212        }
213    }
214
215    /// Update 2D parameter using Muon algorithm
216    fn update_2d_parameter(
217        &mut self,
218        param_data: &mut [f32],
219        grad_data: &[f32],
220        param_id: &str,
221        rows: usize,
222        cols: usize,
223    ) -> Result<()> {
224        // Initialize momentum if needed
225        if !self.momentum_2d.contains_key(param_id) {
226            let momentum = vec![vec![0.0; cols]; rows];
227            self.momentum_2d.insert(param_id.to_string(), momentum);
228        }
229
230        let momentum = self.momentum_2d.get_mut(param_id).unwrap();
231
232        // Reshape flat arrays to 2D views
233        let mut param_matrix = vec![vec![0.0; cols]; rows];
234        let mut grad_matrix = vec![vec![0.0; cols]; rows];
235
236        // Convert flat to 2D
237        for i in 0..rows {
238            for j in 0..cols {
239                let idx = i * cols + j;
240                param_matrix[i][j] = param_data[idx];
241                grad_matrix[i][j] = grad_data[idx];
242            }
243        }
244
245        // Apply weight decay
246        if self.config.weight_decay > 0.0 {
247            for i in 0..rows {
248                for j in 0..cols {
249                    grad_matrix[i][j] += self.config.weight_decay * param_matrix[i][j];
250                }
251            }
252        }
253
254        // Update momentum: m = momentum * m + grad
255        for i in 0..rows {
256            for j in 0..cols {
257                momentum[i][j] = self.config.momentum * momentum[i][j] + grad_matrix[i][j];
258            }
259        }
260
261        // Create update matrix (copy of momentum for orthogonalization)
262        let mut update_matrix = momentum.clone();
263
264        // Apply Newton-Schulz orthogonalization
265        self.newton_schulz_orthogonalize(&mut update_matrix);
266
267        // Apply update: param = param - lr * orthogonalized_momentum
268        for i in 0..rows {
269            for j in 0..cols {
270                param_matrix[i][j] -= self.config.learning_rate * update_matrix[i][j];
271
272                // Convert back to flat array
273                let idx = i * cols + j;
274                param_data[idx] = param_matrix[i][j];
275            }
276        }
277
278        Ok(())
279    }
280
281    /// Update 1D parameter using fallback method (momentum SGD)
282    fn update_1d_parameter(
283        &mut self,
284        param_data: &mut [f32],
285        grad_data: &[f32],
286        param_id: &str,
287    ) -> Result<()> {
288        let param_size = param_data.len();
289
290        // Initialize momentum if needed
291        if !self.momentum_1d.contains_key(param_id) {
292            self.momentum_1d.insert(param_id.to_string(), vec![0.0; param_size]);
293        }
294
295        let momentum = self.momentum_1d.get_mut(param_id).unwrap();
296
297        // Apply momentum SGD update
298        for i in 0..param_size {
299            let mut grad = grad_data[i];
300
301            // Apply weight decay
302            if self.config.weight_decay > 0.0 {
303                grad += self.config.weight_decay * param_data[i];
304            }
305
306            // Update momentum
307            momentum[i] = self.config.fallback_momentum * momentum[i] + grad;
308
309            // Update parameter
310            param_data[i] -= self.config.fallback_lr * momentum[i];
311        }
312
313        Ok(())
314    }
315
316    /// Get memory statistics for Muon state (deprecated - use memory_usage instead)
317    pub fn memory_stats(&self) -> StateMemoryStats {
318        self.memory_usage()
319    }
320
321    /// Get optimization statistics
322    pub fn optimization_stats(&self) -> (usize, usize, f32) {
323        let params_2d = self.momentum_2d.len();
324        let params_1d = self.momentum_1d.len();
325        let total_params = params_2d + params_1d;
326        let ratio_2d = if total_params > 0 { params_2d as f32 / total_params as f32 } else { 0.0 };
327
328        (params_2d, params_1d, ratio_2d)
329    }
330}
331
332impl Default for Muon {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
impl Optimizer for Muon {
    /// Apply one Muon update to `parameter` given `grad`.
    ///
    /// Routing: the flat parameter is (heuristically) factorized into a 2D
    /// shape; shapes with both dimensions >= `min_dim_2d` take the
    /// orthogonalized Muon path, everything else falls back to momentum SGD.
    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
        let param_data = parameter.data_mut()?;
        let grad_data = grad.data()?;

        // NOTE(review): the parameter is identified by the address of its
        // data buffer. If a tensor's storage is ever reallocated, or a freed
        // buffer's address is reused by a different tensor, optimizer state
        // will be lost or mixed up — confirm Tensor storage addresses are
        // stable for the lifetime of training.
        let param_id = format!("param_{:p}", param_data.as_ptr());
        let param_size = param_data.len();

        // Determine (and cache) the assumed 2D shape. The true tensor shape
        // is not available here, so it is guessed from the flat length.
        let (rows, cols) = if let Some(&shape) = self.param_shapes.get(&param_id) {
            shape
        } else {
            // Try common factorizations for typical NN layers
            let factors = self.find_good_factorization(param_size);
            self.param_shapes.insert(param_id.clone(), factors);
            factors
        };

        // Choose optimization method based on parameter shape; the
        // rows * cols == param_size guard protects against stale cached shapes.
        if self.should_use_2d_optimization(rows, cols) && rows * cols == param_size {
            self.update_2d_parameter(param_data, &grad_data, &param_id, rows, cols)?;
        } else {
            self.update_1d_parameter(param_data, &grad_data, &param_id)?;
        }

        Ok(())
    }

    /// Advance the global step counter (call once per optimization step).
    fn step(&mut self) {
        self.state.step += 1;
    }

    /// No-op: gradient zeroing is owned by the surrounding training loop.
    fn zero_grad(&mut self) {
        // This is typically handled by the training framework
        // No action needed here as gradients are managed externally
    }

    /// Learning rate of the 2D (Muon) path. The 1D fallback path uses
    /// `config.fallback_lr`, which this getter does not report.
    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    /// Set the learning rate of the 2D (Muon) path only.
    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }
}

385impl Muon {
386    /// Find a good factorization for a given parameter size
387    fn find_good_factorization(&self, size: usize) -> (usize, usize) {
388        if size < self.config.min_dim_2d {
389            return (1, size);
390        }
391
392        // Common neural network layer sizes
393        let sqrt_size = (size as f32).sqrt() as usize;
394
395        // Try factors close to square root
396        for offset in 0..=sqrt_size / 4 {
397            let candidate1 = sqrt_size + offset;
398            let candidate2 = sqrt_size - offset;
399
400            if candidate1 > 0 && size % candidate1 == 0 {
401                let other = size / candidate1;
402                if candidate1 >= self.config.min_dim_2d && other >= self.config.min_dim_2d {
403                    return (candidate1, other);
404                }
405            }
406
407            if candidate2 > 0 && size % candidate2 == 0 {
408                let other = size / candidate2;
409                if candidate2 >= self.config.min_dim_2d && other >= self.config.min_dim_2d {
410                    return (candidate2, other);
411                }
412            }
413        }
414
415        // If no good factorization found, treat as 1D
416        (1, size)
417    }
418}
419
impl StatefulOptimizer for Muon {
    type Config = MuonConfig;
    type State = OptimizerState;

    fn config(&self) -> &Self::Config {
        &self.config
    }

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Serialize optimizer state into flat tensors.
    ///
    /// Key layout: `"step"` (one f32), `"momentum_2d_<id>"` (row-major
    /// flattened), `"momentum_1d_<id>"`, and `"shape_<id>"` (`[rows, cols]`
    /// as f32). NOTE(review): step counts round-trip through f32, so exact
    /// values are only preserved up to 2^24 steps.
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state_dict = HashMap::new();

        // Save step count
        state_dict.insert(
            "step".to_string(),
            Tensor::new(vec![self.state.step as f32])?,
        );

        // Save 2D momentum buffers (flattened row-major)
        for (param_id, momentum) in &self.momentum_2d {
            let mut flattened = Vec::new();
            for row in momentum {
                flattened.extend_from_slice(row);
            }
            state_dict.insert(format!("momentum_2d_{}", param_id), Tensor::new(flattened)?);
        }

        // Save 1D momentum buffers
        for (param_id, momentum) in &self.momentum_1d {
            state_dict.insert(
                format!("momentum_1d_{}", param_id),
                Tensor::new(momentum.clone())?,
            );
        }

        // Save parameter shapes (needed to un-flatten the 2D buffers on load)
        for (param_id, &(rows, cols)) in &self.param_shapes {
            state_dict.insert(
                format!("shape_{}", param_id),
                Tensor::new(vec![rows as f32, cols as f32])?,
            );
        }

        Ok(state_dict)
    }

    /// Restore state produced by `state_dict`.
    ///
    /// Shapes are loaded first because un-flattening the 2D momentum buffers
    /// depends on them. A `momentum_2d_<id>` entry with no matching
    /// `shape_<id>` key is silently skipped; 2D data shorter than
    /// rows * cols is zero-padded (the `idx < data.len()` guard).
    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
        // Load step count
        if let Some(step_tensor) = state_dict.get("step") {
            let step_data = step_tensor.data()?;
            if !step_data.is_empty() {
                self.state.step = step_data[0] as usize;
            }
        }

        // Load parameter shapes first
        for (key, tensor) in &state_dict {
            if let Some(param_id) = key.strip_prefix("shape_") {
                let shape_data = tensor.data()?;
                if shape_data.len() >= 2 {
                    let rows = shape_data[0] as usize;
                    let cols = shape_data[1] as usize;
                    self.param_shapes.insert(param_id.to_string(), (rows, cols));
                }
            }
        }

        // Load momentum buffers
        for (key, tensor) in &state_dict {
            let data = tensor.data()?;
            if let Some(param_id) = key.strip_prefix("momentum_2d_") {
                if let Some(&(rows, cols)) = self.param_shapes.get(param_id) {
                    let mut momentum = vec![vec![0.0; cols]; rows];
                    for i in 0..rows {
                        for j in 0..cols {
                            let idx = i * cols + j;
                            if idx < data.len() {
                                momentum[i][j] = data[idx];
                            }
                        }
                    }
                    self.momentum_2d.insert(param_id.to_string(), momentum);
                }
            } else if let Some(param_id) = key.strip_prefix("momentum_1d_") {
                self.momentum_1d.insert(param_id.to_string(), data);
            }
        }

        Ok(())
    }

    /// Report state memory: Muon keeps only momentum (no variance or third
    /// moment), so `num_parameters` equals `momentum_elements`.
    fn memory_usage(&self) -> StateMemoryStats {
        let mut momentum_elements = 0;
        let mut total_elements = 0;

        // Count 2D momentum elements.
        // NOTE(review): `momentum[0]` assumes every 2D buffer has at least one
        // row; update_2d_parameter only creates buffers with rows >= 1.
        for momentum in self.momentum_2d.values() {
            let param_count = momentum.len() * momentum[0].len();
            momentum_elements += param_count;
            total_elements += param_count;
        }

        // Count 1D momentum elements
        for momentum in self.momentum_1d.values() {
            momentum_elements += momentum.len();
            total_elements += momentum.len();
        }

        let total_bytes = total_elements * std::mem::size_of::<f32>();

        StateMemoryStats {
            momentum_elements,
            variance_elements: 0,
            third_moment_elements: 0,
            total_bytes,
            num_parameters: momentum_elements,
        }
    }

    /// Drop all accumulated state (step count, momenta, cached shapes).
    fn reset_state(&mut self) {
        self.state = OptimizerState::new();
        self.momentum_2d.clear();
        self.momentum_1d.clear();
        self.param_shapes.clear();
    }

    /// Total number of scalar elements tracked across all momentum buffers.
    fn num_parameters(&self) -> usize {
        let mut total = 0;
        for momentum in self.momentum_2d.values() {
            total += momentum.len() * momentum[0].len();
        }
        for momentum in self.momentum_1d.values() {
            total += momentum.len();
        }
        total
    }
}

// Unit tests: constructors/presets, routing heuristics, factorization,
// state (de)serialization, and memory accounting. Tests reach into private
// fields directly, which is fine within the same module tree.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_muon_creation() {
        let optimizer = Muon::new();
        assert_eq!(optimizer.config.learning_rate, 0.02);
        assert_eq!(optimizer.config.momentum, 0.95);
        assert_eq!(optimizer.config.ns_steps, 5);
        assert_eq!(optimizer.config.min_dim_2d, 64);
        assert_eq!(optimizer.state.step, 0);
    }

    #[test]
    fn test_muon_with_lr() {
        let optimizer = Muon::new_with_lr(0.01);
        assert_eq!(optimizer.config.learning_rate, 0.01);
    }

    #[test]
    fn test_muon_nanogpt_preset() {
        let optimizer = Muon::for_nanogpt();
        assert_eq!(optimizer.config.learning_rate, 0.01);
        assert_eq!(optimizer.config.min_dim_2d, 32);
        assert_eq!(optimizer.config.fallback_lr, 5e-4);
    }

    #[test]
    fn test_muon_cifar10_preset() {
        let optimizer = Muon::for_cifar10();
        assert_eq!(optimizer.config.learning_rate, 0.03);
        assert_eq!(optimizer.config.ns_steps, 4);
        assert_eq!(optimizer.config.weight_decay, 1e-4);
    }

    #[test]
    fn test_muon_large_lm_preset() {
        let optimizer = Muon::for_large_lm();
        assert_eq!(optimizer.config.learning_rate, 0.015);
        assert_eq!(optimizer.config.momentum, 0.98);
        assert_eq!(optimizer.config.min_dim_2d, 128);
    }

    // Both dimensions must reach min_dim_2d (64 by default) for the 2D path.
    #[test]
    fn test_should_use_2d_optimization() {
        let optimizer = Muon::new();

        // Should use 2D for large matrices
        assert!(optimizer.should_use_2d_optimization(128, 128));
        assert!(optimizer.should_use_2d_optimization(64, 256));

        // Should not use 2D for small matrices
        assert!(!optimizer.should_use_2d_optimization(32, 32));
        assert!(!optimizer.should_use_2d_optimization(64, 32));
        assert!(!optimizer.should_use_2d_optimization(1, 1000));
    }

    // Note: these assertions only constrain the product rows * cols, not the
    // exact factor pair, so they hold for any valid factorization strategy.
    #[test]
    fn test_find_good_factorization() {
        let optimizer = Muon::new();

        // Perfect square
        let (rows, cols) = optimizer.find_good_factorization(64 * 64);
        assert_eq!(rows * cols, 64 * 64);
        assert!(rows >= optimizer.config.min_dim_2d);
        assert!(cols >= optimizer.config.min_dim_2d);

        // Small size should be treated as 1D
        let (rows, cols) = optimizer.find_good_factorization(10);
        assert_eq!((rows, cols), (1, 10));

        // Common NN layer size
        let (rows, cols) = optimizer.find_good_factorization(128 * 256);
        assert_eq!(rows * cols, 128 * 256);
    }

    #[test]
    fn test_optimization_stats() {
        let mut optimizer = Muon::new();

        // Initially no parameters
        let (params_2d, params_1d, ratio) = optimizer.optimization_stats();
        assert_eq!(params_2d, 0);
        assert_eq!(params_1d, 0);
        assert_eq!(ratio, 0.0);

        // Add some 2D and 1D parameters
        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![0.0; 128]; 128]);
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.0; 10]);
        optimizer.momentum_1d.insert("param_2".to_string(), vec![0.0; 20]);

        let (params_2d, params_1d, ratio) = optimizer.optimization_stats();
        assert_eq!(params_2d, 1);
        assert_eq!(params_1d, 2);
        assert_relative_eq!(ratio, 1.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_memory_stats() {
        let mut optimizer = Muon::new();

        // Add momentum buffers
        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![0.0; 100]; 50]); // 5000 params
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.0; 1000]); // 1000 params

        let stats = optimizer.memory_stats();
        assert_eq!(stats.num_parameters, 6000);
        assert_eq!(stats.momentum_elements, 6000);
        assert_eq!(stats.variance_elements, 0);
        assert_eq!(stats.total_bytes, 6000 * 4); // 4 bytes per f32
    }

    // Round-trips state through state_dict / load_state_dict.
    #[test]
    fn test_state_dict_operations() {
        let mut optimizer = Muon::new();
        optimizer.state.step = 5;

        // Add parameter shapes and momentum
        optimizer.param_shapes.insert("param_0".to_string(), (2, 3));
        optimizer.momentum_2d.insert(
            "param_0".to_string(),
            vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
        );
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.7, 0.8]);

        // Save state
        let state_dict = optimizer.state_dict().unwrap();
        assert!(state_dict.contains_key("step"));
        assert!(state_dict.contains_key("momentum_2d_param_0"));
        assert!(state_dict.contains_key("momentum_1d_param_1"));
        assert!(state_dict.contains_key("shape_param_0"));

        // Create new optimizer and load state
        let mut new_optimizer = Muon::new();
        new_optimizer.load_state_dict(state_dict).unwrap();

        assert_eq!(new_optimizer.state.step, 5);
        assert_eq!(new_optimizer.param_shapes["param_0"], (2, 3));
        assert_eq!(new_optimizer.momentum_1d["param_1"], vec![0.7, 0.8]);
    }

    #[test]
    fn test_lr_setter_getter() {
        let mut optimizer = Muon::new();
        assert_eq!(optimizer.get_lr(), 0.02);

        optimizer.set_lr(0.01);
        assert_eq!(optimizer.get_lr(), 0.01);
        assert_eq!(optimizer.config.learning_rate, 0.01);
    }

    #[test]
    fn test_reset() {
        let mut optimizer = Muon::new();
        optimizer.state.step = 10;
        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![1.0]]);
        optimizer.momentum_1d.insert("param_1".to_string(), vec![1.0]);
        optimizer.param_shapes.insert("param_0".to_string(), (1, 1));

        optimizer.reset_state();

        assert_eq!(optimizer.state.step, 0);
        assert!(optimizer.momentum_2d.is_empty());
        assert!(optimizer.momentum_1d.is_empty());
        assert!(optimizer.param_shapes.is_empty());
    }

    #[test]
    fn test_config_serialization() {
        let config = MuonConfig {
            learning_rate: 0.01,
            momentum: 0.9,
            ns_steps: 3,
            min_dim_2d: 32,
            fallback_lr: 1e-4,
            fallback_momentum: 0.8,
            weight_decay: 1e-5,
            use_orthogonal: false,
        };

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: MuonConfig = serde_json::from_str(&serialized).unwrap();

        assert_relative_eq!(deserialized.learning_rate, config.learning_rate);
        assert_eq!(deserialized.ns_steps, config.ns_steps);
        assert_eq!(deserialized.use_orthogonal, config.use_orthogonal);
    }
}