use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::common::{OptimizerState, StateMemoryStats};
use crate::traits::StatefulOptimizer;

/// Configuration for the Muon optimizer.
///
/// Muon runs momentum SGD and then orthogonalizes the update with a
/// Newton-Schulz iteration for sufficiently large 2D parameters; 1D and
/// small parameters take a plain momentum-SGD fallback path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MuonConfig {
    /// Learning rate for the 2D (orthogonalized) update path.
    pub learning_rate: f32,
    /// Momentum coefficient for the 2D update path.
    pub momentum: f32,
    /// Number of Newton-Schulz iterations used to orthogonalize the update.
    pub ns_steps: usize,
    /// Minimum size both matrix dimensions must reach to use the 2D path.
    pub min_dim_2d: usize,
    /// Learning rate for the 1D fallback (plain momentum SGD) path.
    pub fallback_lr: f32,
    /// Momentum coefficient for the 1D fallback path.
    pub fallback_momentum: f32,
    /// L2 weight decay added to the gradient (0.0 disables it).
    pub weight_decay: f32,
    /// Whether to run Newton-Schulz orthogonalization at all.
    pub use_orthogonal: bool,
}
50
51impl Default for MuonConfig {
52 fn default() -> Self {
53 Self {
54 learning_rate: 0.02,
55 momentum: 0.95,
56 ns_steps: 5,
57 min_dim_2d: 64,
58 fallback_lr: 1e-3,
59 fallback_momentum: 0.9,
60 weight_decay: 0.0,
61 use_orthogonal: true,
62 }
63 }
64}
65
/// Muon optimizer: momentum SGD whose update is orthogonalized via
/// Newton-Schulz iterations for large 2D parameters, with a plain
/// momentum-SGD fallback for 1D / small parameters.
#[derive(Debug)]
pub struct Muon {
    /// Hyperparameters.
    config: MuonConfig,
    /// Shared optimizer bookkeeping (global step counter).
    state: OptimizerState,
    /// Per-parameter momentum buffers for the 2D path, keyed by parameter id.
    momentum_2d: HashMap<String, Vec<Vec<f32>>>,
    /// Per-parameter momentum buffers for the 1D fallback path.
    momentum_1d: HashMap<String, Vec<f32>>,
    /// Cached (rows, cols) factorization chosen for each parameter.
    param_shapes: HashMap<String, (usize, usize)>,
}
82
83impl Muon {
84 pub fn new() -> Self {
86 Self::with_config(MuonConfig::default())
87 }
88
89 pub fn new_with_lr(learning_rate: f32) -> Self {
91 let config = MuonConfig {
92 learning_rate,
93 ..Default::default()
94 };
95 Self::with_config(config)
96 }
97
98 pub fn for_nanogpt() -> Self {
100 let config = MuonConfig {
101 learning_rate: 0.01,
102 momentum: 0.95,
103 ns_steps: 5,
104 min_dim_2d: 32, fallback_lr: 5e-4,
106 fallback_momentum: 0.9,
107 weight_decay: 0.0,
108 use_orthogonal: true,
109 };
110 Self::with_config(config)
111 }
112
113 pub fn for_cifar10() -> Self {
115 let config = MuonConfig {
116 learning_rate: 0.03,
117 momentum: 0.9,
118 ns_steps: 4, min_dim_2d: 64,
120 fallback_lr: 1e-3,
121 fallback_momentum: 0.9,
122 weight_decay: 1e-4,
123 use_orthogonal: true,
124 };
125 Self::with_config(config)
126 }
127
128 pub fn for_large_lm() -> Self {
130 let config = MuonConfig {
131 learning_rate: 0.015,
132 momentum: 0.98, ns_steps: 6, min_dim_2d: 128, fallback_lr: 3e-4,
136 fallback_momentum: 0.95,
137 weight_decay: 0.01,
138 use_orthogonal: true,
139 };
140 Self::with_config(config)
141 }
142
143 pub fn with_config(config: MuonConfig) -> Self {
145 Self {
146 config,
147 state: OptimizerState::new(),
148 momentum_2d: HashMap::new(),
149 momentum_1d: HashMap::new(),
150 param_shapes: HashMap::new(),
151 }
152 }
153
154 fn should_use_2d_optimization(&self, rows: usize, cols: usize) -> bool {
156 rows >= self.config.min_dim_2d && cols >= self.config.min_dim_2d
157 }
158
159 fn newton_schulz_orthogonalize(&self, matrix: &mut [Vec<f32>]) {
162 if !self.config.use_orthogonal {
163 return;
164 }
165
166 let rows = matrix.len();
167 let cols = matrix[0].len();
168
169 for _ in 0..self.config.ns_steps {
171 let mut xtx = vec![vec![0.0; cols]; cols];
173 for i in 0..cols {
174 for j in 0..cols {
175 let mut sum = 0.0;
176 for k in 0..rows {
177 sum += matrix[k][i] * matrix[k][j];
178 }
179 xtx[i][j] = sum;
180 }
181 }
182
183 for i in 0..cols {
185 for j in 0..cols {
186 if i == j {
187 xtx[i][j] = 3.0 - xtx[i][j];
188 } else {
189 xtx[i][j] = -xtx[i][j];
190 }
191 }
192 }
193
194 let mut new_matrix = vec![vec![0.0; cols]; rows];
196 for i in 0..rows {
197 for j in 0..cols {
198 let mut sum = 0.0;
199 for k in 0..cols {
200 sum += matrix[i][k] * xtx[k][j];
201 }
202 new_matrix[i][j] = sum * 0.5;
203 }
204 }
205
206 for i in 0..rows {
208 for j in 0..cols {
209 matrix[i][j] = new_matrix[i][j];
210 }
211 }
212 }
213 }
214
215 fn update_2d_parameter(
217 &mut self,
218 param_data: &mut [f32],
219 grad_data: &[f32],
220 param_id: &str,
221 rows: usize,
222 cols: usize,
223 ) -> Result<()> {
224 if !self.momentum_2d.contains_key(param_id) {
226 let momentum = vec![vec![0.0; cols]; rows];
227 self.momentum_2d.insert(param_id.to_string(), momentum);
228 }
229
230 let momentum = self.momentum_2d.get_mut(param_id).unwrap();
231
232 let mut param_matrix = vec![vec![0.0; cols]; rows];
234 let mut grad_matrix = vec![vec![0.0; cols]; rows];
235
236 for i in 0..rows {
238 for j in 0..cols {
239 let idx = i * cols + j;
240 param_matrix[i][j] = param_data[idx];
241 grad_matrix[i][j] = grad_data[idx];
242 }
243 }
244
245 if self.config.weight_decay > 0.0 {
247 for i in 0..rows {
248 for j in 0..cols {
249 grad_matrix[i][j] += self.config.weight_decay * param_matrix[i][j];
250 }
251 }
252 }
253
254 for i in 0..rows {
256 for j in 0..cols {
257 momentum[i][j] = self.config.momentum * momentum[i][j] + grad_matrix[i][j];
258 }
259 }
260
261 let mut update_matrix = momentum.clone();
263
264 self.newton_schulz_orthogonalize(&mut update_matrix);
266
267 for i in 0..rows {
269 for j in 0..cols {
270 param_matrix[i][j] -= self.config.learning_rate * update_matrix[i][j];
271
272 let idx = i * cols + j;
274 param_data[idx] = param_matrix[i][j];
275 }
276 }
277
278 Ok(())
279 }
280
281 fn update_1d_parameter(
283 &mut self,
284 param_data: &mut [f32],
285 grad_data: &[f32],
286 param_id: &str,
287 ) -> Result<()> {
288 let param_size = param_data.len();
289
290 if !self.momentum_1d.contains_key(param_id) {
292 self.momentum_1d.insert(param_id.to_string(), vec![0.0; param_size]);
293 }
294
295 let momentum = self.momentum_1d.get_mut(param_id).unwrap();
296
297 for i in 0..param_size {
299 let mut grad = grad_data[i];
300
301 if self.config.weight_decay > 0.0 {
303 grad += self.config.weight_decay * param_data[i];
304 }
305
306 momentum[i] = self.config.fallback_momentum * momentum[i] + grad;
308
309 param_data[i] -= self.config.fallback_lr * momentum[i];
311 }
312
313 Ok(())
314 }
315
316 pub fn memory_stats(&self) -> StateMemoryStats {
318 self.memory_usage()
319 }
320
321 pub fn optimization_stats(&self) -> (usize, usize, f32) {
323 let params_2d = self.momentum_2d.len();
324 let params_1d = self.momentum_1d.len();
325 let total_params = params_2d + params_1d;
326 let ratio_2d = if total_params > 0 { params_2d as f32 / total_params as f32 } else { 0.0 };
327
328 (params_2d, params_1d, ratio_2d)
329 }
330}
331
332impl Default for Muon {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl Optimizer for Muon {
339 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
340 let param_data = parameter.data_mut()?;
341 let grad_data = grad.data()?;
342
343 let param_id = format!("param_{:p}", param_data.as_ptr());
345 let param_size = param_data.len();
346
347 let (rows, cols) = if let Some(&shape) = self.param_shapes.get(¶m_id) {
349 shape
350 } else {
351 let factors = self.find_good_factorization(param_size);
353 self.param_shapes.insert(param_id.clone(), factors);
354 factors
355 };
356
357 if self.should_use_2d_optimization(rows, cols) && rows * cols == param_size {
359 self.update_2d_parameter(param_data, &grad_data, ¶m_id, rows, cols)?;
360 } else {
361 self.update_1d_parameter(param_data, &grad_data, ¶m_id)?;
362 }
363
364 Ok(())
365 }
366
367 fn step(&mut self) {
368 self.state.step += 1;
369 }
370
371 fn zero_grad(&mut self) {
372 }
375
376 fn get_lr(&self) -> f32 {
377 self.config.learning_rate
378 }
379
380 fn set_lr(&mut self, lr: f32) {
381 self.config.learning_rate = lr;
382 }
383}
384
385impl Muon {
386 fn find_good_factorization(&self, size: usize) -> (usize, usize) {
388 if size < self.config.min_dim_2d {
389 return (1, size);
390 }
391
392 let sqrt_size = (size as f32).sqrt() as usize;
394
395 for offset in 0..=sqrt_size / 4 {
397 let candidate1 = sqrt_size + offset;
398 let candidate2 = sqrt_size - offset;
399
400 if candidate1 > 0 && size % candidate1 == 0 {
401 let other = size / candidate1;
402 if candidate1 >= self.config.min_dim_2d && other >= self.config.min_dim_2d {
403 return (candidate1, other);
404 }
405 }
406
407 if candidate2 > 0 && size % candidate2 == 0 {
408 let other = size / candidate2;
409 if candidate2 >= self.config.min_dim_2d && other >= self.config.min_dim_2d {
410 return (candidate2, other);
411 }
412 }
413 }
414
415 (1, size)
417 }
418}
419
420impl StatefulOptimizer for Muon {
421 type Config = MuonConfig;
422 type State = OptimizerState;
423
424 fn config(&self) -> &Self::Config {
425 &self.config
426 }
427
428 fn state(&self) -> &Self::State {
429 &self.state
430 }
431
432 fn state_mut(&mut self) -> &mut Self::State {
433 &mut self.state
434 }
435
436 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
437 let mut state_dict = HashMap::new();
438
439 state_dict.insert(
441 "step".to_string(),
442 Tensor::new(vec![self.state.step as f32])?,
443 );
444
445 for (param_id, momentum) in &self.momentum_2d {
447 let mut flattened = Vec::new();
448 for row in momentum {
449 flattened.extend_from_slice(row);
450 }
451 state_dict.insert(format!("momentum_2d_{}", param_id), Tensor::new(flattened)?);
452 }
453
454 for (param_id, momentum) in &self.momentum_1d {
456 state_dict.insert(
457 format!("momentum_1d_{}", param_id),
458 Tensor::new(momentum.clone())?,
459 );
460 }
461
462 for (param_id, &(rows, cols)) in &self.param_shapes {
464 state_dict.insert(
465 format!("shape_{}", param_id),
466 Tensor::new(vec![rows as f32, cols as f32])?,
467 );
468 }
469
470 Ok(state_dict)
471 }
472
473 fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
474 if let Some(step_tensor) = state_dict.get("step") {
476 let step_data = step_tensor.data()?;
477 if !step_data.is_empty() {
478 self.state.step = step_data[0] as usize;
479 }
480 }
481
482 for (key, tensor) in &state_dict {
484 if let Some(param_id) = key.strip_prefix("shape_") {
485 let shape_data = tensor.data()?;
486 if shape_data.len() >= 2 {
487 let rows = shape_data[0] as usize;
488 let cols = shape_data[1] as usize;
489 self.param_shapes.insert(param_id.to_string(), (rows, cols));
490 }
491 }
492 }
493
494 for (key, tensor) in &state_dict {
496 let data = tensor.data()?;
497 if let Some(param_id) = key.strip_prefix("momentum_2d_") {
498 if let Some(&(rows, cols)) = self.param_shapes.get(param_id) {
499 let mut momentum = vec![vec![0.0; cols]; rows];
500 for i in 0..rows {
501 for j in 0..cols {
502 let idx = i * cols + j;
503 if idx < data.len() {
504 momentum[i][j] = data[idx];
505 }
506 }
507 }
508 self.momentum_2d.insert(param_id.to_string(), momentum);
509 }
510 } else if let Some(param_id) = key.strip_prefix("momentum_1d_") {
511 self.momentum_1d.insert(param_id.to_string(), data);
512 }
513 }
514
515 Ok(())
516 }
517
518 fn memory_usage(&self) -> StateMemoryStats {
519 let mut momentum_elements = 0;
520 let mut total_elements = 0;
521
522 for momentum in self.momentum_2d.values() {
524 let param_count = momentum.len() * momentum[0].len();
525 momentum_elements += param_count;
526 total_elements += param_count;
527 }
528
529 for momentum in self.momentum_1d.values() {
531 momentum_elements += momentum.len();
532 total_elements += momentum.len();
533 }
534
535 let total_bytes = total_elements * std::mem::size_of::<f32>();
536
537 StateMemoryStats {
538 momentum_elements,
539 variance_elements: 0,
540 third_moment_elements: 0,
541 total_bytes,
542 num_parameters: momentum_elements,
543 }
544 }
545
546 fn reset_state(&mut self) {
547 self.state = OptimizerState::new();
548 self.momentum_2d.clear();
549 self.momentum_1d.clear();
550 self.param_shapes.clear();
551 }
552
553 fn num_parameters(&self) -> usize {
554 let mut total = 0;
555 for momentum in self.momentum_2d.values() {
556 total += momentum.len() * momentum[0].len();
557 }
558 for momentum in self.momentum_1d.values() {
559 total += momentum.len();
560 }
561 total
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Constructor picks up the documented defaults and a zeroed step counter.
    #[test]
    fn test_muon_creation() {
        let optimizer = Muon::new();
        assert_eq!(optimizer.config.learning_rate, 0.02);
        assert_eq!(optimizer.config.momentum, 0.95);
        assert_eq!(optimizer.config.ns_steps, 5);
        assert_eq!(optimizer.config.min_dim_2d, 64);
        assert_eq!(optimizer.state.step, 0);
    }

    // new_with_lr overrides only the 2D-path learning rate.
    #[test]
    fn test_muon_with_lr() {
        let optimizer = Muon::new_with_lr(0.01);
        assert_eq!(optimizer.config.learning_rate, 0.01);
    }

    // Spot-check distinctive values of each preset.
    #[test]
    fn test_muon_nanogpt_preset() {
        let optimizer = Muon::for_nanogpt();
        assert_eq!(optimizer.config.learning_rate, 0.01);
        assert_eq!(optimizer.config.min_dim_2d, 32);
        assert_eq!(optimizer.config.fallback_lr, 5e-4);
    }

    #[test]
    fn test_muon_cifar10_preset() {
        let optimizer = Muon::for_cifar10();
        assert_eq!(optimizer.config.learning_rate, 0.03);
        assert_eq!(optimizer.config.ns_steps, 4);
        assert_eq!(optimizer.config.weight_decay, 1e-4);
    }

    #[test]
    fn test_muon_large_lm_preset() {
        let optimizer = Muon::for_large_lm();
        assert_eq!(optimizer.config.learning_rate, 0.015);
        assert_eq!(optimizer.config.momentum, 0.98);
        assert_eq!(optimizer.config.min_dim_2d, 128);
    }

    // Both dimensions must reach min_dim_2d (64 by default) for the 2D path.
    #[test]
    fn test_should_use_2d_optimization() {
        let optimizer = Muon::new();

        assert!(optimizer.should_use_2d_optimization(128, 128));
        assert!(optimizer.should_use_2d_optimization(64, 256));

        assert!(!optimizer.should_use_2d_optimization(32, 32));
        assert!(!optimizer.should_use_2d_optimization(64, 32));
        assert!(!optimizer.should_use_2d_optimization(1, 1000));
    }

    // Factorization must tile the buffer exactly and satisfy the minimum
    // dimension, falling back to (1, size) for small sizes.
    #[test]
    fn test_find_good_factorization() {
        let optimizer = Muon::new();

        let (rows, cols) = optimizer.find_good_factorization(64 * 64);
        assert_eq!(rows * cols, 64 * 64);
        assert!(rows >= optimizer.config.min_dim_2d);
        assert!(cols >= optimizer.config.min_dim_2d);

        let (rows, cols) = optimizer.find_good_factorization(10);
        assert_eq!((rows, cols), (1, 10));

        let (rows, cols) = optimizer.find_good_factorization(128 * 256);
        assert_eq!(rows * cols, 128 * 256);
    }

    // Stats count tracked parameters per path; ratio is 0 when empty.
    #[test]
    fn test_optimization_stats() {
        let mut optimizer = Muon::new();

        let (params_2d, params_1d, ratio) = optimizer.optimization_stats();
        assert_eq!(params_2d, 0);
        assert_eq!(params_1d, 0);
        assert_eq!(ratio, 0.0);

        // Inject buffers directly to exercise the counting logic.
        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![0.0; 128]; 128]);
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.0; 10]);
        optimizer.momentum_1d.insert("param_2".to_string(), vec![0.0; 20]);

        let (params_2d, params_1d, ratio) = optimizer.optimization_stats();
        assert_eq!(params_2d, 1);
        assert_eq!(params_1d, 2);
        assert_relative_eq!(ratio, 1.0 / 3.0, epsilon = 1e-6);
    }

    // 50*100 + 1000 = 6000 elements, f32 = 4 bytes each; Muon keeps only
    // momentum so variance is always zero.
    #[test]
    fn test_memory_stats() {
        let mut optimizer = Muon::new();

        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![0.0; 100]; 50]);
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.0; 1000]);
        let stats = optimizer.memory_stats();
        assert_eq!(stats.num_parameters, 6000);
        assert_eq!(stats.momentum_elements, 6000);
        assert_eq!(stats.variance_elements, 0);
        assert_eq!(stats.total_bytes, 6000 * 4);
    }

    // Round-trip: save state, load into a fresh optimizer, compare.
    #[test]
    fn test_state_dict_operations() {
        let mut optimizer = Muon::new();
        optimizer.state.step = 5;

        optimizer.param_shapes.insert("param_0".to_string(), (2, 3));
        optimizer.momentum_2d.insert(
            "param_0".to_string(),
            vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]],
        );
        optimizer.momentum_1d.insert("param_1".to_string(), vec![0.7, 0.8]);

        let state_dict = optimizer.state_dict().unwrap();
        assert!(state_dict.contains_key("step"));
        assert!(state_dict.contains_key("momentum_2d_param_0"));
        assert!(state_dict.contains_key("momentum_1d_param_1"));
        assert!(state_dict.contains_key("shape_param_0"));

        let mut new_optimizer = Muon::new();
        new_optimizer.load_state_dict(state_dict).unwrap();

        assert_eq!(new_optimizer.state.step, 5);
        assert_eq!(new_optimizer.param_shapes["param_0"], (2, 3));
        assert_eq!(new_optimizer.momentum_1d["param_1"], vec![0.7, 0.8]);
    }

    // set_lr/get_lr operate on the 2D-path learning rate.
    #[test]
    fn test_lr_setter_getter() {
        let mut optimizer = Muon::new();
        assert_eq!(optimizer.get_lr(), 0.02);

        optimizer.set_lr(0.01);
        assert_eq!(optimizer.get_lr(), 0.01);
        assert_eq!(optimizer.config.learning_rate, 0.01);
    }

    // reset_state must clear every buffer and the step counter.
    #[test]
    fn test_reset() {
        let mut optimizer = Muon::new();
        optimizer.state.step = 10;
        optimizer.momentum_2d.insert("param_0".to_string(), vec![vec![1.0]]);
        optimizer.momentum_1d.insert("param_1".to_string(), vec![1.0]);
        optimizer.param_shapes.insert("param_0".to_string(), (1, 1));

        optimizer.reset_state();

        assert_eq!(optimizer.state.step, 0);
        assert!(optimizer.momentum_2d.is_empty());
        assert!(optimizer.momentum_1d.is_empty());
        assert!(optimizer.param_shapes.is_empty());
    }

    // Config round-trips through serde_json without losing fields.
    #[test]
    fn test_config_serialization() {
        let config = MuonConfig {
            learning_rate: 0.01,
            momentum: 0.9,
            ns_steps: 3,
            min_dim_2d: 32,
            fallback_lr: 1e-4,
            fallback_momentum: 0.8,
            weight_decay: 1e-5,
            use_orthogonal: false,
        };

        let serialized = serde_json::to_string(&config).unwrap();
        let deserialized: MuonConfig = serde_json::from_str(&serialized).unwrap();

        assert_relative_eq!(deserialized.learning_rate, config.learning_rate);
        assert_eq!(deserialized.ns_steps, config.ns_steps);
        assert_eq!(deserialized.use_orthogonal, config.use_orthogonal);
    }
}