use crate::OptimizerState;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::Result;
use trustformers_core::Tensor;

/// Optimizer operations that can be queued and executed as fused batches.
#[derive(Debug, Clone)]
pub enum FusedOperation {
    /// Fused Adam parameter update.
    FusedAdam {
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    },
    /// Fused AdamW parameter update (decoupled weight decay).
    FusedAdamW {
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    },
    /// Fused SGD update with momentum.
    FusedSGDMomentum {
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    },
    /// Fused gradient clipping by global norm.
    FusedGradientClipping { max_norm: f64, scale_factor: f64 },
    /// Fused batch normalization update.
    FusedBatchNorm { eps: f64, momentum: f64 },
}

/// Configuration options controlling how operations are fused.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Enable memory-coalesced access patterns.
    pub enable_memory_coalescing: bool,
    /// Enable SIMD vectorization where supported.
    pub enable_vectorization: bool,
    /// Number of queued operations that triggers a fused batch execution.
    pub batch_size: usize,
    /// Enable kernel fusion.
    pub enable_kernel_fusion: bool,
    /// Size of the internal operation buffers.
    pub buffer_size: usize,
    /// Enable asynchronous parameter updates.
    pub enable_async_updates: bool,
}

impl Default for FusionConfig {
    fn default() -> Self {
        Self {
            enable_memory_coalescing: true,
            enable_vectorization: true,
            batch_size: 64,
            enable_kernel_fusion: true,
            buffer_size: 1024,
            enable_async_updates: false,
        }
    }
}

/// Serializable state of the fused optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedOptimizerState {
    /// Per-parameter optimizer state, keyed by parameter name.
    pub parameter_states: HashMap<String, OptimizerState>,
    /// Auxiliary per-operation buffers.
    pub operation_buffers: HashMap<String, Vec<f64>>,
    /// Statistics about fused execution.
    pub fusion_stats: FusionStats,
}

/// Statistics collected while executing fused operations.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FusionStats {
    /// Number of fused batches executed.
    pub fused_operations: u64,
    /// Estimated memory bandwidth saved (heuristic).
    pub memory_bandwidth_saved: u64,
    /// Estimated floating-point operations saved (heuristic).
    pub flops_saved: u64,
    /// Running average batch size across fused executions.
    pub avg_batch_size: f64,
    /// Measured fusion efficiency.
    pub fusion_efficiency: f64,
}

/// Optimizer that queues parameter updates and executes them as fused batches.
#[derive(Debug)]
pub struct FusedOptimizer {
    config: FusionConfig,
    state: Arc<Mutex<FusedOptimizerState>>,
    pending_operations: Arc<Mutex<Vec<(String, FusedOperation, Tensor, Tensor)>>>,
    #[allow(dead_code)]
    operation_queue: Arc<Mutex<HashMap<String, Vec<FusedOperation>>>>,
}

impl FusedOptimizer {
    /// Creates a new fused optimizer with the given configuration.
    pub fn new(config: FusionConfig) -> Result<Self> {
        let state = FusedOptimizerState {
            parameter_states: HashMap::new(),
            operation_buffers: HashMap::new(),
            fusion_stats: FusionStats::default(),
        };

        Ok(Self {
            config,
            state: Arc::new(Mutex::new(state)),
            pending_operations: Arc::new(Mutex::new(Vec::new())),
            operation_queue: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Queues an operation for fused execution. Once `batch_size` operations are
    /// pending, the whole batch is executed automatically.
    pub fn queue_operation(
        &mut self,
        param_name: String,
        operation: FusedOperation,
        parameter: Tensor,
        gradient: Tensor,
    ) -> Result<()> {
        let should_execute = {
            let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
            pending.push((param_name, operation, parameter, gradient));
            pending.len() >= self.config.batch_size
        };

        if should_execute {
            self.execute_fused_batch()?;
        }

        Ok(())
    }

    /// Drains the pending queue and executes the queued operations, grouped by kind.
    pub fn execute_fused_batch(&mut self) -> Result<()> {
        let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
        if pending.is_empty() {
            return Ok(());
        }

        let operations = std::mem::take(&mut *pending);
        drop(pending);

        // Group operations by type so each group can be applied as one fused batch.
        let mut adam_ops = Vec::new();
        let mut adamw_ops = Vec::new();
        let mut sgd_ops = Vec::new();
        let mut clip_ops = Vec::new();

        for (param_name, op, param, grad) in operations {
            match op {
                FusedOperation::FusedAdam { .. } => adam_ops.push((param_name, op, param, grad)),
                FusedOperation::FusedAdamW { .. } => adamw_ops.push((param_name, op, param, grad)),
                FusedOperation::FusedSGDMomentum { .. } => {
                    sgd_ops.push((param_name, op, param, grad))
                },
                FusedOperation::FusedGradientClipping { .. } => {
                    clip_ops.push((param_name, op, param, grad))
                },
                _ => {
                    // Operations without a dedicated fused path fall back to single execution.
                    self.execute_single_operation(param_name, op, param, grad)?;
                },
            }
        }

        if !adam_ops.is_empty() {
            self.execute_fused_adam_batch(adam_ops)?;
        }
        if !adamw_ops.is_empty() {
            self.execute_fused_adamw_batch(adamw_ops)?;
        }
        if !sgd_ops.is_empty() {
            self.execute_fused_sgd_batch(sgd_ops)?;
        }
        if !clip_ops.is_empty() {
            self.execute_fused_clipping_batch(clip_ops)?;
        }

        Ok(())
    }

    fn execute_fused_adam_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedAdam {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } = op
            {
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            variance: HashMap::new(),
                            ..Default::default()
                        }
                    });

                self.fused_adam_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                )?;
            }
        }

        // Update the running statistics for this fused batch.
        state.fusion_stats.fused_operations += 1;
        state.fusion_stats.avg_batch_size = (state.fusion_stats.avg_batch_size
            * (state.fusion_stats.fused_operations - 1) as f64
            + batch_size as f64)
            / state.fusion_stats.fused_operations as f64;

        // Rough estimate of memory traffic saved by fusing this batch.
        let bandwidth_saved = batch_size * 4 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    fn execute_fused_adamw_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedAdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } = op
            {
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            variance: HashMap::new(),
                            ..Default::default()
                        }
                    });

                self.fused_adamw_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                )?;
            }
        }

        state.fusion_stats.fused_operations += 1;
        // Rough estimate of memory traffic saved by fusing this batch.
        let bandwidth_saved = batch_size * 4 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    fn execute_fused_sgd_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedSGDMomentum {
                lr,
                momentum,
                dampening,
                weight_decay,
                nesterov,
            } = op
            {
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            ..Default::default()
                        }
                    });

                self.fused_sgd_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    momentum,
                    dampening,
                    weight_decay,
                    nesterov,
                )?;
            }
        }

        state.fusion_stats.fused_operations += 1;
        // Rough estimate of memory traffic saved by fusing this batch.
        let bandwidth_saved = batch_size * 2 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    fn execute_fused_clipping_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        // The clip coefficient is derived from the global norm over all gradients in the batch.
        let gradients: Vec<Tensor> =
            operations.iter().map(|(_, _, _, grad)| grad.clone()).collect();
        let global_norm = self.compute_global_norm(&gradients)?;

        for (_, op, _, grad) in operations {
            if let FusedOperation::FusedGradientClipping {
                max_norm,
                scale_factor,
            } = op
            {
                if global_norm > max_norm {
                    let clip_coef = max_norm / global_norm;
                    grad.mul_scalar((clip_coef * scale_factor) as f32)?;
                } else {
                    grad.mul_scalar(scale_factor as f32)?;
                }
            }
        }

        state.fusion_stats.fused_operations += 1;
        // Rough estimate of memory traffic saved by fusing this batch.
        let bandwidth_saved = batch_size * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    /// Fallback for operations without a dedicated fused path; currently a no-op.
    fn execute_single_operation(
        &mut self,
        _param_name: String,
        _operation: FusedOperation,
        _parameter: Tensor,
        _gradient: Tensor,
    ) -> Result<()> {
        Ok(())
    }

    /// Applies a single Adam update to `param` using `grad` and the stored moments.
    fn fused_adam_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        let momentum =
            state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
        let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - beta2.powi(state.step as i32);

        for i in 0..param_data.len() {
            let mut grad_val = grad_data[i];

            // Classic (coupled) weight decay is added directly to the gradient.
            if weight_decay > 0.0 {
                grad_val += weight_decay as f32 * param_data[i];
            }

            // Update biased first and second moment estimates.
            momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;
            variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;

            // Bias-corrected estimates.
            let m_hat = momentum[i] / bias_correction1 as f32;
            let v_hat = variance[i] / bias_correction2 as f32;

            param_data[i] -= lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
        }

        Ok(())
    }

    /// Applies a single AdamW update to `param`; weight decay is decoupled from the
    /// gradient-based step.
    fn fused_adamw_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        let momentum =
            state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
        let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - beta2.powi(state.step as i32);

        for i in 0..param_data.len() {
            let grad_val = grad_data[i];

            // Update biased first and second moment estimates.
            momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;
            variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;

            // Bias-corrected estimates.
            let m_hat = momentum[i] / bias_correction1 as f32;
            let v_hat = variance[i] / bias_correction2 as f32;

            // Decoupled weight decay: applied as a separate step on the parameter.
            let adaptive_step = lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
            let weight_decay_step = lr as f32 * weight_decay as f32 * param_data[i];

            param_data[i] -= adaptive_step + weight_decay_step;
        }

        Ok(())
    }

    /// Applies a single SGD-with-momentum update to `param`.
    fn fused_sgd_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        momentum_coef: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        let momentum = state.momentum.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        for i in 0..param_data.len() {
            let mut grad_val = grad_data[i];

            if weight_decay > 0.0 {
                grad_val += weight_decay as f32 * param_data[i];
            }

            if momentum_coef > 0.0 {
                // On the first step the momentum buffer is initialized with the gradient.
                if state.step == 1 {
                    momentum[i] = grad_val;
                } else {
                    momentum[i] =
                        momentum_coef as f32 * momentum[i] + (1.0 - dampening as f32) * grad_val;
                }

                let update_direction = if nesterov {
                    grad_val + momentum_coef as f32 * momentum[i]
                } else {
                    momentum[i]
                };

                param_data[i] -= lr as f32 * update_direction;
            } else {
                param_data[i] -= lr as f32 * grad_val;
            }
        }

        Ok(())
    }

    /// Computes the global L2 norm across a set of gradients.
    fn compute_global_norm(&self, gradients: &[Tensor]) -> Result<f64> {
        let mut total_norm_sq = 0.0;

        for grad in gradients {
            let norm = grad.norm()?;
            total_norm_sq += norm * norm;
        }

        Ok(total_norm_sq.sqrt() as f64)
    }

    /// Executes any pending operations immediately, regardless of batch size.
    pub fn flush(&mut self) -> Result<()> {
        self.execute_fused_batch()
    }

    /// Returns a snapshot of the current fusion statistics.
    pub fn get_fusion_stats(&self) -> FusionStats {
        let state = self.state.lock().expect("Mutex lock poisoned");
        state.fusion_stats.clone()
    }

    /// Resets the fusion statistics to their default values.
    pub fn reset_stats(&mut self) {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        state.fusion_stats = FusionStats::default();
    }

    /// Replaces the current fusion configuration.
    pub fn update_config(&mut self, config: FusionConfig) {
        self.config = config;
    }
}

/// SIMD-accelerated optimizer kernels for x86_64.
#[cfg(target_arch = "x86_64")]
pub mod simd {

    /// Performs an Adam update over flat `f32` slices, using AVX/FMA intrinsics when
    /// the CPU supports them and a scalar loop otherwise.
    pub fn simd_adam_update(
        param: &mut [f32],
        grad: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) {
        use std::arch::x86_64::*;

        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        // The 256-bit intrinsics below require AVX and FMA; fall back to a scalar
        // loop when either feature is unavailable at runtime.
        if !(is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma")) {
            for i in 0..param.len() {
                let g = grad[i];
                momentum[i] = beta1 * momentum[i] + (1.0 - beta1) * g;
                velocity[i] = beta2 * velocity[i] + (1.0 - beta2) * g * g;
                param[i] -= corrected_lr * momentum[i] / (velocity[i].sqrt() + eps);
            }
            return;
        }

        unsafe {
            let beta1_vec = _mm256_set1_ps(beta1);
            let beta2_vec = _mm256_set1_ps(beta2);
            let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
            let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
            let eps_vec = _mm256_set1_ps(eps);
            let lr_vec = _mm256_set1_ps(corrected_lr);

            let chunks = param.len() / 8;
            for i in 0..chunks {
                let idx = i * 8;

                let p = _mm256_loadu_ps(param.as_ptr().add(idx));
                let g = _mm256_loadu_ps(grad.as_ptr().add(idx));
                let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
                let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));

                // m = beta1 * m + (1 - beta1) * g
                let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));

                // v = beta2 * v + (1 - beta2) * g^2
                let g_sq = _mm256_mul_ps(g, g);
                let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));

                // p -= corrected_lr * m / (sqrt(v) + eps)
                let v_sqrt = _mm256_sqrt_ps(v_new);
                let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
                let update = _mm256_div_ps(m_new, v_sqrt_eps);
                let p_new = _mm256_fnmadd_ps(lr_vec, update, p);

                _mm256_storeu_ps(param.as_mut_ptr().add(idx), p_new);
                _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
                _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
            }

            // Scalar tail for the remaining elements.
            for i in (chunks * 8)..param.len() {
                let g = grad[i];
                momentum[i] = beta1 * momentum[i] + (1.0 - beta1) * g;
                velocity[i] = beta2 * velocity[i] + (1.0 - beta2) * g * g;
                param[i] -= corrected_lr * momentum[i] / (velocity[i].sqrt() + eps);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    #[test]
    fn test_fused_optimizer_creation() {
        let config = FusionConfig::default();
        let optimizer = FusedOptimizer::new(config).unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 0);
    }

    #[test]
    fn test_fused_adam_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[10, 10]).unwrap();
        let grad = Tensor::ones(&[10, 10]).unwrap();

        let operation = FusedOperation::FusedAdam {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        };

        optimizer.queue_operation("param1".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_fused_adamw_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[5, 5]).unwrap();
        let grad = Tensor::ones(&[5, 5]).unwrap();

        let operation = FusedOperation::FusedAdamW {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
        };

        optimizer.queue_operation("param2".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_fused_sgd_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[3, 3]).unwrap();
        let grad = Tensor::ones(&[3, 3]).unwrap();

        let operation = FusedOperation::FusedSGDMomentum {
            lr: 0.01,
            momentum: 0.9,
            dampening: 0.0,
            weight_decay: 0.0,
            nesterov: false,
        };

        optimizer.queue_operation("param3".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_batch_fusion() {
        let config = FusionConfig {
            batch_size: 2,
            ..FusionConfig::default()
        };
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        // Queue three operations; reaching the batch size of two triggers an automatic flush.
        for i in 0..3 {
            let param = Tensor::ones(&[2, 2]).unwrap();
            let grad = Tensor::ones(&[2, 2]).unwrap();

            let operation = FusedOperation::FusedAdam {
                lr: 0.001,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 0.0,
            };

            optimizer
                .queue_operation(format!("param_{}", i), operation, param, grad)
                .unwrap();
        }

        let stats = optimizer.get_fusion_stats();
        assert!(stats.fused_operations > 0);
    }

    #[test]
    fn test_fusion_stats() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[10, 10]).unwrap();
        let grad = Tensor::ones(&[10, 10]).unwrap();

        let operation = FusedOperation::FusedAdam {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        };

        optimizer.queue_operation("param1".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
        assert!(stats.memory_bandwidth_saved > 0);

        optimizer.reset_stats();
        let reset_stats = optimizer.get_fusion_stats();
        assert_eq!(reset_stats.fused_operations, 0);
        assert_eq!(reset_stats.memory_bandwidth_saved, 0);
    }

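    // Sketch of a test exercising the fused gradient-clipping path. It assumes
    // `Tensor::ones` and `mul_scalar` behave as in the other tests and only
    // asserts that the batch was executed, not the exact post-clipping values.
    #[test]
    fn test_fused_gradient_clipping_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[4, 4]).unwrap();
        // An all-ones [4, 4] gradient has an L2 norm of 4.0, which exceeds max_norm below.
        let grad = Tensor::ones(&[4, 4]).unwrap();

        let operation = FusedOperation::FusedGradientClipping {
            max_norm: 1.0,
            scale_factor: 1.0,
        };

        optimizer.queue_operation("param_clip".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
        assert!(stats.memory_bandwidth_saved > 0);
    }
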
    #[test]
    fn test_global_norm_computation() {
        let config = FusionConfig::default();
        let optimizer = FusedOptimizer::new(config).unwrap();

        let grad1 = Tensor::ones(&[3, 3]).unwrap();
        let grad2 = Tensor::ones(&[2, 2]).unwrap();

        let gradients = vec![grad1, grad2];
        let global_norm = optimizer.compute_global_norm(&gradients).unwrap();

        // The per-tensor norms are 3 and 2, so the global norm is sqrt(3^2 + 2^2) = sqrt(13) ≈ 3.606.
        assert!((global_norm - 3.606).abs() < 0.01);
    }
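
    // Sketch of a smoke test for the SIMD Adam kernel (x86_64 only). The kernel folds
    // the bias corrections into the learning rate, so this only checks that parameters,
    // momentum, and velocity were all updated, not exact values.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_simd_adam_update_smoke() {
        // Deliberately not a multiple of 8 so the scalar tail is exercised too.
        let n = 19;
        let mut param: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let grad: Vec<f32> = (0..n).map(|i| 0.01 * (i as f32 + 1.0)).collect();
        let mut momentum = vec![0.0f32; n];
        let mut velocity = vec![0.0f32; n];
        let before = param.clone();

        simd::simd_adam_update(
            &mut param,
            &grad,
            &mut momentum,
            &mut velocity,
            0.001,
            0.9,
            0.999,
            1e-8,
            1,
        );

        for i in 0..n {
            // Every element has a non-zero gradient, so every slot should have been updated.
            assert!(param[i] != before[i]);
            assert!(momentum[i] != 0.0);
            assert!(velocity[i] != 0.0);
        }
    }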
}