
trustformers_optim/fusion.rs

//! Optimizer Fusion Techniques
//!
//! This module fuses multiple optimizer operations into batched, single-pass
//! update kernels to reduce memory bandwidth pressure and improve overall
//! training performance.
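//!
//! A minimal usage sketch (the module path and the `param`/`grad` tensors are
//! assumed for illustration):
//!
//! ```rust,ignore
//! use trustformers_optim::fusion::{FusedOperation, FusedOptimizer, FusionConfig};
//!
//! let mut optimizer = FusedOptimizer::new(FusionConfig::default())?;
//! let op = FusedOperation::FusedAdam {
//!     lr: 1e-3, beta1: 0.9, beta2: 0.999, eps: 1e-8, weight_decay: 0.0,
//! };
//! // Updates are queued and run as a fused batch once `batch_size` is reached.
//! optimizer.queue_operation("w".to_string(), op, param, grad)?;
//! optimizer.flush()?; // execute any remaining queued updates
//! ```
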
use crate::OptimizerState;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::Result;
use trustformers_core::Tensor;

/// Fused optimizer operations for performance optimization
#[derive(Debug, Clone)]
pub enum FusedOperation {
    /// Fused Adam update (parameter, gradient, momentum, velocity)
    FusedAdam {
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    },
    /// Fused AdamW update with decoupled weight decay
    FusedAdamW {
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    },
    /// Fused SGD with momentum
    FusedSGDMomentum {
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    },
    /// Fused gradient clipping and scaling
    FusedGradientClipping { max_norm: f64, scale_factor: f64 },
    /// Fused batch normalization update
    FusedBatchNorm { eps: f64, momentum: f64 },
}

/// Configuration for fused optimizer operations
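///
/// Overriding defaults, as a sketch:
///
/// ```rust,ignore
/// let config = FusionConfig { batch_size: 128, ..Default::default() };
/// ```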
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusionConfig {
    /// Enable memory bandwidth optimization
    pub enable_memory_coalescing: bool,
    /// Use vectorized operations when possible
    pub enable_vectorization: bool,
    /// Batch size for parameter updates
    pub batch_size: usize,
    /// Enable kernel fusion for compatible operations
    pub enable_kernel_fusion: bool,
    /// Buffer size for batched operations
    pub buffer_size: usize,
    /// Enable asynchronous updates
    pub enable_async_updates: bool,
}

impl Default for FusionConfig {
    fn default() -> Self {
        Self {
            enable_memory_coalescing: true,
            enable_vectorization: true,
            batch_size: 64,
            enable_kernel_fusion: true,
            buffer_size: 1024,
            enable_async_updates: false,
        }
    }
}

/// Fused optimizer state for multiple parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FusedOptimizerState {
    /// Parameter states indexed by parameter name
    pub parameter_states: HashMap<String, OptimizerState>,
    /// Fused operation buffers
    pub operation_buffers: HashMap<String, Vec<f64>>,
    /// Fusion statistics
    pub fusion_stats: FusionStats,
}

/// Statistics for fusion operations
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FusionStats {
    /// Number of fused operations executed
    pub fused_operations: u64,
    /// Memory bandwidth saved (bytes)
    pub memory_bandwidth_saved: u64,
    /// FLOPs saved through fusion
    pub flops_saved: u64,
    /// Average batch size
    pub avg_batch_size: f64,
    /// Fusion efficiency ratio
    pub fusion_efficiency: f64,
}

/// Fused optimizer that combines multiple optimization operations
#[derive(Debug)]
pub struct FusedOptimizer {
    config: FusionConfig,
    state: Arc<Mutex<FusedOptimizerState>>,
    pending_operations: Arc<Mutex<Vec<(String, FusedOperation, Tensor, Tensor)>>>,
    #[allow(dead_code)]
    operation_queue: Arc<Mutex<HashMap<String, Vec<FusedOperation>>>>,
}

impl FusedOptimizer {
    /// Create a new fused optimizer
    pub fn new(config: FusionConfig) -> Result<Self> {
        let state = FusedOptimizerState {
            parameter_states: HashMap::new(),
            operation_buffers: HashMap::new(),
            fusion_stats: FusionStats::default(),
        };

        Ok(Self {
            config,
            state: Arc::new(Mutex::new(state)),
            pending_operations: Arc::new(Mutex::new(Vec::new())),
            operation_queue: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Queue an operation for fusion; the pending batch is executed
    /// automatically once it reaches `batch_size` entries.
    pub fn queue_operation(
        &mut self,
        param_name: String,
        operation: FusedOperation,
        parameter: Tensor,
        gradient: Tensor,
    ) -> Result<()> {
        let should_execute = {
            let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
            pending.push((param_name, operation, parameter, gradient));
            pending.len() >= self.config.batch_size
        };

        // Execute the batch if the buffer is full
        if should_execute {
            self.execute_fused_batch()?;
        }

        Ok(())
    }

    /// Execute all pending operations in a fused manner
    pub fn execute_fused_batch(&mut self) -> Result<()> {
        let mut pending = self.pending_operations.lock().expect("Mutex lock poisoned");
        if pending.is_empty() {
            return Ok(());
        }

        let operations = std::mem::take(&mut *pending);
        drop(pending);

        // Group operations by type for maximum fusion efficiency
        let mut adam_ops = Vec::new();
        let mut adamw_ops = Vec::new();
        let mut sgd_ops = Vec::new();
        let mut clip_ops = Vec::new();

        for (param_name, op, param, grad) in operations {
            match op {
                FusedOperation::FusedAdam { .. } => adam_ops.push((param_name, op, param, grad)),
                FusedOperation::FusedAdamW { .. } => adamw_ops.push((param_name, op, param, grad)),
                FusedOperation::FusedSGDMomentum { .. } => {
                    sgd_ops.push((param_name, op, param, grad))
                },
                FusedOperation::FusedGradientClipping { .. } => {
                    clip_ops.push((param_name, op, param, grad))
                },
                _ => {
                    // Handle remaining operations (e.g. FusedBatchNorm) individually
                    self.execute_single_operation(param_name, op, param, grad)?;
                },
            }
        }

        // Execute fused batches
        if !adam_ops.is_empty() {
            self.execute_fused_adam_batch(adam_ops)?;
        }
        if !adamw_ops.is_empty() {
            self.execute_fused_adamw_batch(adamw_ops)?;
        }
        if !sgd_ops.is_empty() {
            self.execute_fused_sgd_batch(sgd_ops)?;
        }
        if !clip_ops.is_empty() {
            self.execute_fused_clipping_batch(clip_ops)?;
        }

        Ok(())
    }

    /// Execute fused Adam operations
    fn execute_fused_adam_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedAdam {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } = op
            {
                // Get or create optimizer state
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            variance: HashMap::new(),
                            ..Default::default()
                        }
                    });

                // Fused Adam update with optimized memory access
                self.fused_adam_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                )?;
            }
        }

        // Update fusion statistics (running average of batch size)
        state.fusion_stats.fused_operations += 1;
        state.fusion_stats.avg_batch_size = (state.fusion_stats.avg_batch_size
            * (state.fusion_stats.fused_operations - 1) as f64
            + batch_size as f64)
            / state.fusion_stats.fused_operations as f64;

        // Crude bandwidth-savings estimate: four tensor streams
        // (param, grad, momentum, variance) per fused update
        let bandwidth_saved = batch_size * 4 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    /// Execute fused AdamW operations
    fn execute_fused_adamw_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedAdamW {
                lr,
                beta1,
                beta2,
                eps,
                weight_decay,
            } = op
            {
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            variance: HashMap::new(),
                            ..Default::default()
                        }
                    });

                // Fused AdamW update with decoupled weight decay
                self.fused_adamw_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    beta1,
                    beta2,
                    eps,
                    weight_decay,
                )?;
            }
        }

        // Update statistics
        state.fusion_stats.fused_operations += 1;
        let bandwidth_saved = batch_size * 4 * 8;
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    /// Execute fused SGD operations
    fn execute_fused_sgd_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        for (param_name, op, param, grad) in operations {
            if let FusedOperation::FusedSGDMomentum {
                lr,
                momentum,
                dampening,
                weight_decay,
                nesterov,
            } = op
            {
                let opt_state =
                    state.parameter_states.entry(param_name.clone()).or_insert_with(|| {
                        OptimizerState {
                            step: 0,
                            momentum: HashMap::new(),
                            ..Default::default()
                        }
                    });

                // Fused SGD with momentum update
                self.fused_sgd_update(
                    &param,
                    &grad,
                    opt_state,
                    lr,
                    momentum,
                    dampening,
                    weight_decay,
                    nesterov,
                )?;
            }
        }

        // Update statistics
        state.fusion_stats.fused_operations += 1;
        let bandwidth_saved = batch_size * 2 * 8; // SGD touches fewer tensor streams
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    /// Execute fused gradient clipping operations
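    ///
    /// For a batch of gradients `g_1..g_k`, the global norm is
    /// `sqrt(sum_i ||g_i||^2)`; each gradient is then scaled by
    /// `scale_factor * min(1, max_norm / global_norm)`.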
    fn execute_fused_clipping_batch(
        &mut self,
        operations: Vec<(String, FusedOperation, Tensor, Tensor)>,
    ) -> Result<()> {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        let batch_size = operations.len();

        // Collect all gradients for the global norm computation
        let mut gradients = Vec::new();
        for (_, _, _, grad) in &operations {
            gradients.push(grad.clone());
        }

        // Compute the global gradient norm once for the whole batch
        let global_norm = self.compute_global_norm(&gradients)?;

        for (_, op, _, grad) in operations {
            if let FusedOperation::FusedGradientClipping {
                max_norm,
                scale_factor,
            } = op
            {
                // Apply clipping with the pre-computed global norm; only shrink
                // when the norm exceeds the threshold
                let clip_coef = if global_norm > max_norm {
                    max_norm / global_norm
                } else {
                    1.0
                };
                grad.mul_scalar((clip_coef * scale_factor) as f32)?;
            }
        }

        // Update statistics
        state.fusion_stats.fused_operations += 1;
        let bandwidth_saved = batch_size * 8; // Single pass through gradients
        state.fusion_stats.memory_bandwidth_saved += bandwidth_saved as u64;

        Ok(())
    }

    /// Execute a single operation (fallback for non-batchable operations)
    fn execute_single_operation(
        &mut self,
        _param_name: String,
        _operation: FusedOperation,
        _parameter: Tensor,
        _gradient: Tensor,
    ) -> Result<()> {
        // Placeholder: individual (non-fusable) operations are currently a no-op
        Ok(())
    }

    /// Optimized Adam update with fused operations
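    ///
    /// Per-element single-pass form of the update implemented below, with L2
    /// weight decay folded into the gradient:
    ///
    /// ```text
    /// g      = grad + weight_decay * param
    /// m      = beta1 * m + (1 - beta1) * g
    /// v      = beta2 * v + (1 - beta2) * g^2
    /// m_hat  = m / (1 - beta1^t),  v_hat = v / (1 - beta2^t)
    /// param -= lr * m_hat / (sqrt(v_hat) + eps)
    /// ```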
    fn fused_adam_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        // Get or initialize momentum and variance buffers
        let momentum =
            state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
        let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        // Bias correction factors
        let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - beta2.powi(state.step as i32);

        // Fused update loop - combines all operations in a single pass
        for i in 0..param_data.len() {
            let mut grad_val = grad_data[i];

            // Apply weight decay if specified (L2 regularization)
            if weight_decay > 0.0 {
                grad_val += weight_decay as f32 * param_data[i];
            }

            // Update biased first moment estimate (momentum)
            momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;

            // Update biased second raw moment estimate (variance)
            variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;

            // Compute bias-corrected first and second moment estimates
            let m_hat = momentum[i] / bias_correction1 as f32;
            let v_hat = variance[i] / bias_correction2 as f32;

            // Update parameter with the fused Adam step
            param_data[i] -= lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
        }

        Ok(())
    }

    /// Optimized AdamW update with fused operations and decoupled weight decay
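    ///
    /// Uses the same moment updates as Adam, but applies weight decay directly
    /// to the parameters (decoupled, per Loshchilov & Hutter's AdamW) instead
    /// of folding it into the gradient:
    ///
    /// ```text
    /// param -= lr * m_hat / (sqrt(v_hat) + eps) + lr * weight_decay * param
    /// ```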
    fn fused_adamw_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        beta1: f64,
        beta2: f64,
        eps: f64,
        weight_decay: f64,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        // Get or initialize momentum and variance buffers
        let momentum =
            state.momentum.entry(param_id.clone()).or_insert_with(|| vec![0.0; param_len]);
        let variance = state.variance.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        // Bias correction factors
        let bias_correction1 = 1.0 - beta1.powi(state.step as i32);
        let bias_correction2 = 1.0 - beta2.powi(state.step as i32);

        // Fused AdamW update loop - decoupled weight decay
        for i in 0..param_data.len() {
            let grad_val = grad_data[i];

            // Update biased first moment estimate (momentum)
            momentum[i] = beta1 as f32 * momentum[i] + (1.0 - beta1 as f32) * grad_val;

            // Update biased second raw moment estimate (variance)
            variance[i] = beta2 as f32 * variance[i] + (1.0 - beta2 as f32) * grad_val * grad_val;

            // Compute bias-corrected first and second moment estimates
            let m_hat = momentum[i] / bias_correction1 as f32;
            let v_hat = variance[i] / bias_correction2 as f32;

            // AdamW: apply weight decay directly to the parameters (decoupled)
            let adaptive_step = lr as f32 * m_hat / (v_hat.sqrt() + eps as f32);
            let weight_decay_step = lr as f32 * weight_decay as f32 * param_data[i];

            // Combined update with decoupled weight decay
            param_data[i] -= adaptive_step + weight_decay_step;
        }

        Ok(())
    }

    /// Optimized SGD update with fused momentum
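    ///
    /// Per-element update following the common (PyTorch-style) semantics:
    ///
    /// ```text
    /// g = grad + weight_decay * param
    /// b = momentum * b + (1 - dampening) * g    (b = g on the first step)
    /// param -= lr * (g + momentum * b)    if nesterov
    /// param -= lr * b                     otherwise
    /// ```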
    fn fused_sgd_update(
        &self,
        param: &Tensor,
        grad: &Tensor,
        state: &mut OptimizerState,
        lr: f64,
        momentum_coef: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<()> {
        use crate::common::ParameterIds;

        state.step += 1;
        let param_id = ParameterIds::from_tensor(param)?;
        let param_len = param.data()?.len();

        // Get or initialize the momentum buffer
        let momentum = state.momentum.entry(param_id).or_insert_with(|| vec![0.0; param_len]);

        let grad_data = grad.data()?;
        let mut param_data = param.data()?;

        // Fused SGD update loop with momentum
        for i in 0..param_data.len() {
            let mut grad_val = grad_data[i];

            // Apply weight decay if specified
            if weight_decay > 0.0 {
                grad_val += weight_decay as f32 * param_data[i];
            }

            // Update the momentum buffer
            if momentum_coef > 0.0 {
                if state.step == 1 {
                    // First step: initialize momentum with the gradient
                    momentum[i] = grad_val;
                } else {
                    // Update momentum with dampening
                    momentum[i] =
                        momentum_coef as f32 * momentum[i] + (1.0 - dampening as f32) * grad_val;
                }

                // Apply Nesterov momentum if enabled
                let update_direction = if nesterov {
                    grad_val + momentum_coef as f32 * momentum[i]
                } else {
                    momentum[i]
                };

                // Update the parameter
                param_data[i] -= lr as f32 * update_direction;
            } else {
                // Plain SGD without momentum
                param_data[i] -= lr as f32 * grad_val;
            }
        }

        Ok(())
    }

    /// Compute the global gradient norm for clipping
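    ///
    /// Returns `sqrt(sum_i ||g_i||^2)` over the supplied gradients.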
    fn compute_global_norm(&self, gradients: &[Tensor]) -> Result<f64> {
        let mut total_norm_sq = 0.0;

        for grad in gradients {
            let norm = grad.norm()?;
            total_norm_sq += norm * norm;
        }

        Ok(total_norm_sq.sqrt() as f64)
    }

    /// Flush all pending operations
    pub fn flush(&mut self) -> Result<()> {
        self.execute_fused_batch()
    }

    /// Get fusion statistics
    pub fn get_fusion_stats(&self) -> FusionStats {
        let state = self.state.lock().expect("Mutex lock poisoned");
        state.fusion_stats.clone()
    }

    /// Reset fusion statistics
    pub fn reset_stats(&mut self) {
        let mut state = self.state.lock().expect("Mutex lock poisoned");
        state.fusion_stats = FusionStats::default();
    }

    /// Update the fusion configuration
    pub fn update_config(&mut self, config: FusionConfig) {
        self.config = config;
    }
}

/// SIMD-optimized vectorized operations
#[cfg(target_arch = "x86_64")]
pub mod simd {

    /// SIMD-optimized Adam update
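    ///
    /// Processes eight `f32` lanes per iteration with AVX2/FMA intrinsics and
    /// falls back to scalar code for the tail.
    ///
    /// # Safety
    ///
    /// The CPU must support AVX2 and FMA. A dispatch sketch (argument names
    /// are illustrative):
    ///
    /// ```rust,ignore
    /// if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
    ///     unsafe { simd_adam_update(p, g, m, v, lr, b1, b2, eps, step) };
    /// }
    /// ```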
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn simd_adam_update(
        param: &mut [f32],
        grad: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) {
        use std::arch::x86_64::*;

        debug_assert_eq!(param.len(), grad.len());
        debug_assert_eq!(param.len(), momentum.len());
        debug_assert_eq!(param.len(), velocity.len());

        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        // Fold the bias correction into the learning rate; note eps is added to
        // sqrt(v) rather than sqrt(v_hat), a common approximation
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        unsafe {
            let beta1_vec = _mm256_set1_ps(beta1);
            let beta2_vec = _mm256_set1_ps(beta2);
            let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
            let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
            let eps_vec = _mm256_set1_ps(eps);
            let lr_vec = _mm256_set1_ps(corrected_lr);

            let chunks = param.len() / 8;
            for i in 0..chunks {
                let idx = i * 8;

                // Load values
                let p = _mm256_loadu_ps(param.as_ptr().add(idx));
                let g = _mm256_loadu_ps(grad.as_ptr().add(idx));
                let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
                let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));

                // Update momentum: momentum = beta1 * momentum + (1 - beta1) * grad
                let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));

                // Update velocity: velocity = beta2 * velocity + (1 - beta2) * grad^2
                let g_sq = _mm256_mul_ps(g, g);
                let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));

                // Update parameter: param = param - lr * momentum / (sqrt(velocity) + eps)
                let v_sqrt = _mm256_sqrt_ps(v_new);
                let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
                let update = _mm256_div_ps(m_new, v_sqrt_eps);
                let p_new = _mm256_fnmadd_ps(lr_vec, update, p);

                // Store results
                _mm256_storeu_ps(param.as_mut_ptr().add(idx), p_new);
                _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
                _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
            }

            // Handle the remaining elements with scalar code
            for i in (chunks * 8)..param.len() {
                let g = grad[i];
                momentum[i] = beta1 * momentum[i] + (1.0 - beta1) * g;
                velocity[i] = beta2 * velocity[i] + (1.0 - beta2) * g * g;
                param[i] -= corrected_lr * momentum[i] / (velocity[i].sqrt() + eps);
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    #[test]
    fn test_fused_optimizer_creation() {
        let config = FusionConfig::default();
        let optimizer = FusedOptimizer::new(config).unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 0);
    }

    #[test]
    fn test_fused_adam_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[10, 10]).unwrap();
        let grad = Tensor::ones(&[10, 10]).unwrap();

        let operation = FusedOperation::FusedAdam {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        };

        optimizer.queue_operation("param1".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_fused_adamw_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[5, 5]).unwrap();
        let grad = Tensor::ones(&[5, 5]).unwrap();

        let operation = FusedOperation::FusedAdamW {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.01,
        };

        optimizer.queue_operation("param2".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_fused_sgd_operation() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[3, 3]).unwrap();
        let grad = Tensor::ones(&[3, 3]).unwrap();

        let operation = FusedOperation::FusedSGDMomentum {
            lr: 0.01,
            momentum: 0.9,
            dampening: 0.0,
            weight_decay: 0.0,
            nesterov: false,
        };

        optimizer.queue_operation("param3".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
    }

    #[test]
    fn test_batch_fusion() {
        let config = FusionConfig {
            batch_size: 2,
            ..Default::default()
        };
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        // Queue multiple operations
        for i in 0..3 {
            let param = Tensor::ones(&[2, 2]).unwrap();
            let grad = Tensor::ones(&[2, 2]).unwrap();

            let operation = FusedOperation::FusedAdam {
                lr: 0.001,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 0.0,
            };

            optimizer
                .queue_operation(format!("param_{}", i), operation, param, grad)
                .unwrap();
        }

        // The batch should have executed automatically
        let stats = optimizer.get_fusion_stats();
        assert!(stats.fused_operations > 0);
    }

    #[test]
    fn test_fusion_stats() {
        let config = FusionConfig::default();
        let mut optimizer = FusedOptimizer::new(config).unwrap();

        let param = Tensor::ones(&[10, 10]).unwrap();
        let grad = Tensor::ones(&[10, 10]).unwrap();

        let operation = FusedOperation::FusedAdam {
            lr: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
        };

        optimizer.queue_operation("param1".to_string(), operation, param, grad).unwrap();

        optimizer.flush().unwrap();

        let stats = optimizer.get_fusion_stats();
        assert_eq!(stats.fused_operations, 1);
        assert!(stats.memory_bandwidth_saved > 0);

        optimizer.reset_stats();
        let reset_stats = optimizer.get_fusion_stats();
        assert_eq!(reset_stats.fused_operations, 0);
        assert_eq!(reset_stats.memory_bandwidth_saved, 0);
    }

    #[test]
    fn test_global_norm_computation() {
        let config = FusionConfig::default();
        let optimizer = FusedOptimizer::new(config).unwrap();

        let grad1 = Tensor::ones(&[3, 3]).unwrap();
        let grad2 = Tensor::ones(&[2, 2]).unwrap();

        let gradients = vec![grad1, grad2];
        let global_norm = optimizer.compute_global_norm(&gradients).unwrap();

        // Expected: sqrt(||grad1||^2 + ||grad2||^2) = sqrt(9 + 4) = sqrt(13) ≈ 3.606
        assert!((global_norm - 3.606).abs() < 0.01);
    }
}
859}