Skip to main content

trustformers_optim/
convergence.rs

1//! # Convergence Improvement Methods
2//!
3//! This module implements advanced optimization techniques that improve convergence
4//! speed, stability, and final performance through sophisticated momentum variants
5//! and variance reduction methods.
6//!
7//! ## Available Methods
8//!
9//! - **QHM (Quasi-Hyperbolic Momentum)**: Generalizes momentum and Nesterov acceleration
10//! - **AggMo (Aggregated Momentum)**: Maintains multiple momentum buffers for better convergence
11//! - **SVRG (Stochastic Variance Reduced Gradient)**: Reduces gradient variance for better convergence
12//! - **SAG (Stochastic Average Gradient)**: Maintains running average of gradients
13//! - **Nesterov Accelerated Gradient (NAG)**: Classical acceleration method with lookahead
14//! - **Heavy Ball Method**: Momentum-based acceleration with inertia
15//! - **FISTA**: Fast Iterative Shrinkage-Thresholding Algorithm for proximal methods
16//! - **Adaptive Batch Sizing**: Dynamically adjusts batch size based on training progress
17//! - **Loss Surface Smoothing**: Reduces noise in the loss surface for better convergence
18
19use crate::optimizer::OptimizerState;
20use anyhow::{anyhow, Result};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use trustformers_core::tensor::Tensor;
24
25/// Configuration for Quasi-Hyperbolic Momentum (QHM).
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct QHMConfig {
28    /// Learning rate
29    pub learning_rate: f32,
30    /// Momentum parameter (β)
31    pub momentum: f32,
32    /// Averaging parameter (ν) - controls interpolation between current gradient and momentum
33    pub nu: f32,
34    /// Weight decay
35    pub weight_decay: f32,
36}
37
38impl Default for QHMConfig {
39    fn default() -> Self {
40        Self {
41            learning_rate: 1e-3,
42            momentum: 0.9,
43            nu: 0.7,
44            weight_decay: 0.0,
45        }
46    }
47}
48
49/// Quasi-Hyperbolic Momentum optimizer.
50///
51/// QHM interpolates between the current gradient and the momentum buffer,
52/// providing a generalization of both momentum and Nesterov acceleration.
53/// Update rule: p = p - lr * (nu * g + (1 - nu) * momentum)
54#[derive(Debug)]
55pub struct QHM {
56    config: QHMConfig,
57    momentum_buffers: HashMap<usize, Tensor>,
58    current_step: usize,
59}
60
61impl QHM {
62    /// Create a new QHM optimizer.
63    pub fn new(config: QHMConfig) -> Self {
64        Self {
65            config,
66            momentum_buffers: HashMap::new(),
67            current_step: 0,
68        }
69    }
70
71    /// Create QHM with default configuration.
72    pub fn with_defaults(learning_rate: f32, momentum: f32, nu: f32) -> Self {
73        Self::new(QHMConfig {
74            learning_rate,
75            momentum,
76            nu,
77            weight_decay: 0.0,
78        })
79    }
80
81    /// Get the configuration.
82    pub fn get_config(&self) -> &QHMConfig {
83        &self.config
84    }
85
86    /// Update configuration.
87    pub fn set_config(&mut self, config: QHMConfig) {
88        self.config = config;
89    }
90}
91
92impl OptimizerState for QHM {
93    fn zero_grad(&mut self) -> Result<()> {
94        // QHM doesn't need explicit gradient zeroing
95        Ok(())
96    }
97
98    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
99        self.current_step += 1;
100
101        for (param_id, parameter) in parameters.iter_mut().enumerate() {
102            // Access gradient from parameter (should be computed during forward/backward pass)
103            let gradient = match parameter.grad() {
104                Ok(grad) => grad,
105                Err(_) => {
106                    // If gradient is not available, skip this parameter
107                    continue;
108                },
109            };
110
111            // Apply weight decay to gradient
112            let effective_grad = if self.config.weight_decay > 0.0 {
113                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
114            } else {
115                gradient
116            };
117
118            // Get or initialize momentum buffer
119            let momentum_buffer = if let Some(buffer) = self.momentum_buffers.get(&param_id) {
120                // Update momentum: momentum = β * momentum + (1 - β) * grad
121                let updated = buffer
122                    .mul_scalar(self.config.momentum)?
123                    .add(&effective_grad.mul_scalar(1.0 - self.config.momentum)?)?;
124                self.momentum_buffers.insert(param_id, updated.clone());
125                updated
126            } else {
127                // Initialize momentum buffer with current gradient
128                let initial_momentum = effective_grad.clone();
129                self.momentum_buffers.insert(param_id, initial_momentum.clone());
130                initial_momentum
131            };
132
133            // QHM update: interpolate between current gradient and momentum
134            let update_direction = effective_grad
135                .mul_scalar(self.config.nu)?
136                .add(&momentum_buffer.mul_scalar(1.0 - self.config.nu)?)?;
137
138            // Apply update
139            *parameter = parameter.sub(&update_direction.mul_scalar(self.config.learning_rate)?)?;
140        }
141
142        Ok(())
143    }
144
145    fn get_lr(&self) -> f32 {
146        self.config.learning_rate
147    }
148
149    fn set_lr(&mut self, lr: f32) {
150        self.config.learning_rate = lr;
151    }
152
153    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
154        let mut state = HashMap::new();
155
156        // Save configuration
157        state.insert(
158            "learning_rate".to_string(),
159            Tensor::scalar(self.config.learning_rate)?,
160        );
161        state.insert(
162            "momentum".to_string(),
163            Tensor::scalar(self.config.momentum)?,
164        );
165        state.insert("nu".to_string(), Tensor::scalar(self.config.nu)?);
166        state.insert(
167            "weight_decay".to_string(),
168            Tensor::scalar(self.config.weight_decay)?,
169        );
170        state.insert(
171            "current_step".to_string(),
172            Tensor::scalar(self.current_step as f32)?,
173        );
174
175        // Save momentum buffers
176        for (&param_id, buffer) in &self.momentum_buffers {
177            state.insert(format!("momentum_buffer_{}", param_id), buffer.clone());
178        }
179
180        Ok(state)
181    }
182
183    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
184        // Load configuration
185        if let Some(lr) = state.get("learning_rate") {
186            self.config.learning_rate = lr.to_scalar()?;
187        }
188        if let Some(momentum) = state.get("momentum") {
189            self.config.momentum = momentum.to_scalar()?;
190        }
191        if let Some(nu) = state.get("nu") {
192            self.config.nu = nu.to_scalar()?;
193        }
194        if let Some(wd) = state.get("weight_decay") {
195            self.config.weight_decay = wd.to_scalar()?;
196        }
197        if let Some(step) = state.get("current_step") {
198            self.current_step = step.to_scalar()? as usize;
199        }
200
201        // Load momentum buffers
202        self.momentum_buffers.clear();
203        for (key, tensor) in state {
204            if let Some(param_id_str) = key.strip_prefix("momentum_buffer_") {
205                if let Ok(param_id) = param_id_str.parse::<usize>() {
206                    self.momentum_buffers.insert(param_id, tensor);
207                }
208            }
209        }
210
211        Ok(())
212    }
213}
214
215/// Configuration for Aggregated Momentum (AggMo).
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct AggMoConfig {
218    /// Learning rate
219    pub learning_rate: f32,
220    /// List of momentum coefficients
221    pub momentum_coefficients: Vec<f32>,
222    /// Weight decay
223    pub weight_decay: f32,
224}
225
226impl Default for AggMoConfig {
227    fn default() -> Self {
228        Self {
229            learning_rate: 1e-3,
230            momentum_coefficients: vec![0.0, 0.9, 0.99],
231            weight_decay: 0.0,
232        }
233    }
234}
235
236/// Aggregated Momentum optimizer.
237///
238/// AggMo maintains multiple momentum buffers with different decay rates
239/// and averages their contributions to improve convergence.
240#[derive(Debug)]
241pub struct AggMo {
242    config: AggMoConfig,
243    momentum_buffers: HashMap<usize, Vec<Tensor>>, // param_id -> list of momentum buffers
244    current_step: usize,
245}
246
247impl AggMo {
248    /// Create a new AggMo optimizer.
249    pub fn new(config: AggMoConfig) -> Self {
250        assert!(
251            !config.momentum_coefficients.is_empty(),
252            "Must provide at least one momentum coefficient"
253        );
254        Self {
255            config,
256            momentum_buffers: HashMap::new(),
257            current_step: 0,
258        }
259    }
260
261    /// Create AggMo with default configuration.
262    pub fn with_defaults(learning_rate: f32, momentum_coefficients: Vec<f32>) -> Self {
263        Self::new(AggMoConfig {
264            learning_rate,
265            momentum_coefficients,
266            weight_decay: 0.0,
267        })
268    }
269
270    /// Get the configuration.
271    pub fn get_config(&self) -> &AggMoConfig {
272        &self.config
273    }
274
275    /// Get the number of momentum buffers per parameter.
276    pub fn num_momentum_buffers(&self) -> usize {
277        self.config.momentum_coefficients.len()
278    }
279}
280
281impl OptimizerState for AggMo {
282    fn zero_grad(&mut self) -> Result<()> {
283        Ok(())
284    }
285
286    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
287        self.current_step += 1;
288
289        for (param_id, parameter) in parameters.iter_mut().enumerate() {
290            // Access gradient from parameter (should be computed during forward/backward pass)
291            let gradient = match parameter.grad() {
292                Ok(grad) => grad,
293                Err(_) => {
294                    // If gradient is not available, skip this parameter
295                    continue;
296                },
297            };
298
299            // Apply weight decay
300            let effective_grad = if self.config.weight_decay > 0.0 {
301                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
302            } else {
303                gradient
304            };
305
306            // Get or initialize momentum buffers for this parameter
307            let buffers = self.momentum_buffers.entry(param_id).or_insert_with(|| {
308                // Initialize all momentum buffers with zeros
309                (0..self.config.momentum_coefficients.len())
310                    .map(|_| {
311                        Tensor::zeros(&effective_grad.shape())
312                            .expect("zeros should always succeed for valid gradient shape")
313                    })
314                    .collect()
315            });
316
317            // Update each momentum buffer
318            let mut aggregated_momentum = Tensor::zeros(&effective_grad.shape())?;
319            for (i, &beta) in self.config.momentum_coefficients.iter().enumerate() {
320                // Update momentum: m_i = β_i * m_i + (1 - β_i) * grad
321                buffers[i] =
322                    buffers[i].mul_scalar(beta)?.add(&effective_grad.mul_scalar(1.0 - beta)?)?;
323
324                // Add to aggregated momentum
325                aggregated_momentum = aggregated_momentum.add(&buffers[i])?;
326            }
327
328            // Average the momentum buffers
329            let num_buffers = self.config.momentum_coefficients.len() as f32;
330            let averaged_momentum = aggregated_momentum.div_scalar(num_buffers)?;
331
332            // Apply update
333            *parameter =
334                parameter.sub(&averaged_momentum.mul_scalar(self.config.learning_rate)?)?;
335        }
336
337        Ok(())
338    }
339
340    fn get_lr(&self) -> f32 {
341        self.config.learning_rate
342    }
343
344    fn set_lr(&mut self, lr: f32) {
345        self.config.learning_rate = lr;
346    }
347
348    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
349        let mut state = HashMap::new();
350
351        // Save configuration
352        state.insert(
353            "learning_rate".to_string(),
354            Tensor::scalar(self.config.learning_rate)?,
355        );
356        state.insert(
357            "weight_decay".to_string(),
358            Tensor::scalar(self.config.weight_decay)?,
359        );
360        state.insert(
361            "current_step".to_string(),
362            Tensor::scalar(self.current_step as f32)?,
363        );
364        state.insert(
365            "num_momentum_coeffs".to_string(),
366            Tensor::scalar(self.config.momentum_coefficients.len() as f32)?,
367        );
368
369        // Save momentum coefficients
370        for (i, &coeff) in self.config.momentum_coefficients.iter().enumerate() {
371            state.insert(format!("momentum_coeff_{}", i), Tensor::scalar(coeff)?);
372        }
373
374        // Save momentum buffers
375        for (&param_id, buffers) in &self.momentum_buffers {
376            for (buffer_idx, buffer) in buffers.iter().enumerate() {
377                state.insert(
378                    format!("momentum_buffer_{}_{}", param_id, buffer_idx),
379                    buffer.clone(),
380                );
381            }
382        }
383
384        Ok(state)
385    }
386
387    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
388        // Load configuration
389        if let Some(lr) = state.get("learning_rate") {
390            self.config.learning_rate = lr.to_scalar()?;
391        }
392        if let Some(wd) = state.get("weight_decay") {
393            self.config.weight_decay = wd.to_scalar()?;
394        }
395        if let Some(step) = state.get("current_step") {
396            self.current_step = step.to_scalar()? as usize;
397        }
398
399        // Load momentum coefficients
400        if let Some(num_coeffs_tensor) = state.get("num_momentum_coeffs") {
401            let num_coeffs = num_coeffs_tensor.to_scalar()? as usize;
402            let mut coefficients = Vec::with_capacity(num_coeffs);
403            for i in 0..num_coeffs {
404                if let Some(coeff_tensor) = state.get(&format!("momentum_coeff_{}", i)) {
405                    coefficients.push(coeff_tensor.to_scalar()?);
406                }
407            }
408            self.config.momentum_coefficients = coefficients;
409        }
410
411        // Load momentum buffers
412        self.momentum_buffers.clear();
413        let mut param_buffers: HashMap<usize, HashMap<usize, Tensor>> = HashMap::new();
414
415        for (key, tensor) in state {
416            if key.starts_with("momentum_buffer_") {
417                let parts: Vec<&str> = key.split('_').collect();
418                if parts.len() >= 4 {
419                    if let (Ok(param_id), Ok(buffer_idx)) =
420                        (parts[2].parse::<usize>(), parts[3].parse::<usize>())
421                    {
422                        param_buffers.entry(param_id).or_default().insert(buffer_idx, tensor);
423                    }
424                }
425            }
426        }
427
428        // Reconstruct momentum buffers in correct order
429        for (param_id, buffer_map) in param_buffers {
430            let mut buffers = Vec::new();
431            for i in 0..self.config.momentum_coefficients.len() {
432                if let Some(buffer) = buffer_map.get(&i) {
433                    buffers.push(buffer.clone());
434                }
435            }
436            if buffers.len() == self.config.momentum_coefficients.len() {
437                self.momentum_buffers.insert(param_id, buffers);
438            }
439        }
440
441        Ok(())
442    }
443}
444
445/// Configuration for Variance Reduction methods.
446#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct VarianceReductionConfig {
448    /// Learning rate
449    pub learning_rate: f32,
450    /// Method type
451    pub method: VarianceReductionMethod,
452    /// Gradient history size for SVRG
453    pub history_size: usize,
454    /// Update frequency for full gradient computation
455    pub full_grad_frequency: usize,
456    /// Weight decay
457    pub weight_decay: f32,
458}
459
460impl Default for VarianceReductionConfig {
461    fn default() -> Self {
462        Self {
463            learning_rate: 1e-3,
464            method: VarianceReductionMethod::SVRG,
465            history_size: 100,
466            full_grad_frequency: 10,
467            weight_decay: 0.0,
468        }
469    }
470}
471
472/// Types of variance reduction methods.
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub enum VarianceReductionMethod {
475    /// Stochastic Variance Reduced Gradient
476    SVRG,
477    /// Stochastic Average Gradient
478    SAG,
479}
480
481/// Variance Reduction optimizer implementing SVRG and SAG methods.
482#[derive(Debug)]
483pub struct VarianceReduction {
484    config: VarianceReductionConfig,
485    gradient_history: HashMap<usize, Vec<Tensor>>,
486    average_gradients: HashMap<usize, Tensor>,
487    full_gradients: HashMap<usize, Tensor>,
488    current_step: usize,
489    last_full_grad_step: usize,
490}
491
492impl VarianceReduction {
493    /// Create a new variance reduction optimizer.
494    pub fn new(config: VarianceReductionConfig) -> Self {
495        Self {
496            config,
497            gradient_history: HashMap::new(),
498            average_gradients: HashMap::new(),
499            full_gradients: HashMap::new(),
500            current_step: 0,
501            last_full_grad_step: 0,
502        }
503    }
504
505    /// Create SVRG optimizer with default settings.
506    pub fn svrg(learning_rate: f32, history_size: usize, full_grad_frequency: usize) -> Self {
507        Self::new(VarianceReductionConfig {
508            learning_rate,
509            method: VarianceReductionMethod::SVRG,
510            history_size,
511            full_grad_frequency,
512            weight_decay: 0.0,
513        })
514    }
515
516    /// Create SAG optimizer with default settings.
517    pub fn sag(learning_rate: f32, history_size: usize) -> Self {
518        Self::new(VarianceReductionConfig {
519            learning_rate,
520            method: VarianceReductionMethod::SAG,
521            history_size,
522            full_grad_frequency: 1, // Not used for SAG
523            weight_decay: 0.0,
524        })
525    }
526
527    fn update_gradient_history(&mut self, param_id: usize, gradient: &Tensor) -> Result<()> {
528        let history = self.gradient_history.entry(param_id).or_default();
529
530        history.push(gradient.clone());
531        if history.len() > self.config.history_size {
532            history.remove(0);
533        }
534
535        Ok(())
536    }
537
538    fn compute_average_gradient(&mut self, param_id: usize) -> Result<Tensor> {
539        if let Some(history) = self.gradient_history.get(&param_id) {
540            if history.is_empty() {
541                return Err(anyhow!("No gradient history available"));
542            }
543
544            let mut sum = history[0].clone();
545            for grad in history.iter().skip(1) {
546                sum = sum.add(grad)?;
547            }
548
549            let average = sum.div_scalar(history.len() as f32)?;
550            self.average_gradients.insert(param_id, average.clone());
551            Ok(average)
552        } else {
553            Err(anyhow!("No gradient history for parameter {}", param_id))
554        }
555    }
556
557    fn should_compute_full_gradient(&self) -> bool {
558        self.current_step - self.last_full_grad_step >= self.config.full_grad_frequency
559    }
560}
561
562impl OptimizerState for VarianceReduction {
563    fn zero_grad(&mut self) -> Result<()> {
564        Ok(())
565    }
566
567    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
568        self.current_step += 1;
569
570        // Check if we need to compute full gradient (for SVRG)
571        let compute_full_grad = match self.config.method {
572            VarianceReductionMethod::SVRG => self.should_compute_full_gradient(),
573            VarianceReductionMethod::SAG => false,
574        };
575
576        if compute_full_grad {
577            self.last_full_grad_step = self.current_step;
578            // In practice, full gradient computation would require access to the full dataset
579            // Here we'll use the current gradient as an approximation
580            for (param_id, parameter) in parameters.iter().enumerate() {
581                // Access gradient from parameter (should be computed during forward/backward pass)
582                let gradient = match parameter.grad() {
583                    Ok(grad) => grad,
584                    Err(_) => {
585                        // If gradient is not available, skip this parameter
586                        continue;
587                    },
588                };
589                self.full_gradients.insert(param_id, gradient);
590            }
591        }
592
593        for (param_id, parameter) in parameters.iter_mut().enumerate() {
594            // Access gradient from parameter (should be computed during forward/backward pass)
595            let current_gradient = match parameter.grad() {
596                Ok(grad) => grad,
597                Err(_) => {
598                    // If gradient is not available, skip this parameter
599                    continue;
600                },
601            };
602
603            // Apply weight decay
604            let effective_grad = if self.config.weight_decay > 0.0 {
605                current_gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
606            } else {
607                current_gradient
608            };
609
610            // Update gradient history
611            self.update_gradient_history(param_id, &effective_grad)?;
612
613            // Apply variance reduction
614            let variance_reduced_grad = match self.config.method {
615                VarianceReductionMethod::SVRG => {
616                    // Clone full_grad before mutable borrow for compute_average_gradient
617                    let full_grad_opt = self.full_gradients.get(&param_id).cloned();
618                    if let Some(full_grad) = full_grad_opt {
619                        let avg_grad = self.compute_average_gradient(param_id)?;
620                        // SVRG update: grad - avg_grad + full_grad
621                        effective_grad.sub(&avg_grad)?.add(&full_grad)?
622                    } else {
623                        effective_grad
624                    }
625                },
626                VarianceReductionMethod::SAG => {
627                    // SAG uses running average of gradients
628                    self.compute_average_gradient(param_id)?
629                },
630            };
631
632            // Apply update
633            *parameter =
634                parameter.sub(&variance_reduced_grad.mul_scalar(self.config.learning_rate)?)?;
635        }
636
637        Ok(())
638    }
639
640    fn get_lr(&self) -> f32 {
641        self.config.learning_rate
642    }
643
644    fn set_lr(&mut self, lr: f32) {
645        self.config.learning_rate = lr;
646    }
647
648    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
649        let mut state = HashMap::new();
650
651        state.insert(
652            "learning_rate".to_string(),
653            Tensor::scalar(self.config.learning_rate)?,
654        );
655        state.insert(
656            "current_step".to_string(),
657            Tensor::scalar(self.current_step as f32)?,
658        );
659        state.insert(
660            "last_full_grad_step".to_string(),
661            Tensor::scalar(self.last_full_grad_step as f32)?,
662        );
663
664        // Note: Saving full gradient history would be expensive
665        // In practice, you might want to save only recent gradients or statistics
666
667        Ok(state)
668    }
669
670    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
671        if let Some(lr) = state.get("learning_rate") {
672            self.config.learning_rate = lr.to_scalar()?;
673        }
674        if let Some(step) = state.get("current_step") {
675            self.current_step = step.to_scalar()? as usize;
676        }
677        if let Some(last_step) = state.get("last_full_grad_step") {
678            self.last_full_grad_step = last_step.to_scalar()? as usize;
679        }
680
681        Ok(())
682    }
683}
684
685/// Configuration for Nesterov Accelerated Gradient (NAG).
686#[derive(Debug, Clone, Serialize, Deserialize)]
687pub struct NesterovAcceleratedGradientConfig {
688    /// Learning rate
689    pub learning_rate: f32,
690    /// Momentum parameter
691    pub momentum: f32,
692    /// Weight decay
693    pub weight_decay: f32,
694    /// Whether to use strong convexity assumption for restart
695    pub restart_on_increase: bool,
696}
697
698impl Default for NesterovAcceleratedGradientConfig {
699    fn default() -> Self {
700        Self {
701            learning_rate: 1e-3,
702            momentum: 0.9,
703            weight_decay: 0.0,
704            restart_on_increase: false,
705        }
706    }
707}
708
709/// Nesterov Accelerated Gradient optimizer.
710///
711/// NAG uses lookahead to evaluate the gradient at the predicted next position,
712/// which can lead to faster convergence than standard momentum methods.
713/// Update rule:
714/// v_t = momentum * v_{t-1} + lr * grad(x_t + momentum * v_{t-1})
715/// x_{t+1} = x_t - v_t
716#[derive(Debug)]
717pub struct NesterovAcceleratedGradient {
718    config: NesterovAcceleratedGradientConfig,
719    velocity_buffers: HashMap<usize, Tensor>,
720    current_step: usize,
721    previous_loss: Option<f32>,
722}
723
724impl NesterovAcceleratedGradient {
725    /// Create a new NAG optimizer.
726    pub fn new(config: NesterovAcceleratedGradientConfig) -> Self {
727        Self {
728            config,
729            velocity_buffers: HashMap::new(),
730            current_step: 0,
731            previous_loss: None,
732        }
733    }
734
735    /// Create NAG with default configuration.
736    pub fn with_defaults(learning_rate: f32, momentum: f32) -> Self {
737        Self::new(NesterovAcceleratedGradientConfig {
738            learning_rate,
739            momentum,
740            weight_decay: 0.0,
741            restart_on_increase: false,
742        })
743    }
744
745    /// Get the configuration.
746    pub fn get_config(&self) -> &NesterovAcceleratedGradientConfig {
747        &self.config
748    }
749
750    /// Set a loss value for restart detection.
751    pub fn set_current_loss(&mut self, loss: f32) {
752        if self.config.restart_on_increase {
753            if let Some(prev_loss) = self.previous_loss {
754                if loss > prev_loss {
755                    // Restart by clearing velocity buffers
756                    self.velocity_buffers.clear();
757                }
758            }
759        }
760        self.previous_loss = Some(loss);
761    }
762}
763
764impl OptimizerState for NesterovAcceleratedGradient {
765    fn zero_grad(&mut self) -> Result<()> {
766        Ok(())
767    }
768
769    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
770        self.current_step += 1;
771
772        for (param_id, parameter) in parameters.iter_mut().enumerate() {
773            // Access gradient from parameter (should be computed during forward/backward pass)
774            let gradient = match parameter.grad() {
775                Ok(grad) => grad,
776                Err(_) => {
777                    // If gradient is not available, skip this parameter
778                    continue;
779                },
780            };
781
782            // Apply weight decay to gradient
783            let effective_grad = if self.config.weight_decay > 0.0 {
784                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
785            } else {
786                gradient
787            };
788
789            // Get or initialize velocity buffer
790            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
791                v.clone()
792            } else {
793                Tensor::zeros_like(parameter)?
794            };
795
796            // Nesterov acceleration: compute gradient at lookahead position
797            let _lookahead_position = parameter.sub(&velocity.mul_scalar(self.config.momentum)?)?;
798
799            // In practice, we'd need to recompute the gradient at the lookahead position
800            // For now, we'll use the current gradient as approximation
801            // Lookahead gradient computation using current gradient state
802
803            // Update velocity: v_t = momentum * v_{t-1} + lr * grad
804            let new_velocity = velocity
805                .mul_scalar(self.config.momentum)?
806                .add(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
807
808            self.velocity_buffers.insert(param_id, new_velocity.clone());
809
810            // Update parameters: x_{t+1} = x_t - v_t
811            *parameter = parameter.sub(&new_velocity)?;
812        }
813
814        Ok(())
815    }
816
817    fn get_lr(&self) -> f32 {
818        self.config.learning_rate
819    }
820
821    fn set_lr(&mut self, lr: f32) {
822        self.config.learning_rate = lr;
823    }
824
825    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
826        let mut state = HashMap::new();
827
828        state.insert(
829            "learning_rate".to_string(),
830            Tensor::scalar(self.config.learning_rate)?,
831        );
832        state.insert(
833            "momentum".to_string(),
834            Tensor::scalar(self.config.momentum)?,
835        );
836        state.insert(
837            "weight_decay".to_string(),
838            Tensor::scalar(self.config.weight_decay)?,
839        );
840        state.insert(
841            "current_step".to_string(),
842            Tensor::scalar(self.current_step as f32)?,
843        );
844
845        if let Some(loss) = self.previous_loss {
846            state.insert("previous_loss".to_string(), Tensor::scalar(loss)?);
847        }
848
849        for (&param_id, velocity) in &self.velocity_buffers {
850            state.insert(format!("velocity_{}", param_id), velocity.clone());
851        }
852
853        Ok(state)
854    }
855
856    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
857        if let Some(lr) = state.get("learning_rate") {
858            self.config.learning_rate = lr.to_scalar()?;
859        }
860        if let Some(momentum) = state.get("momentum") {
861            self.config.momentum = momentum.to_scalar()?;
862        }
863        if let Some(wd) = state.get("weight_decay") {
864            self.config.weight_decay = wd.to_scalar()?;
865        }
866        if let Some(step) = state.get("current_step") {
867            self.current_step = step.to_scalar()? as usize;
868        }
869        if let Some(loss) = state.get("previous_loss") {
870            self.previous_loss = Some(loss.to_scalar()?);
871        }
872
873        self.velocity_buffers.clear();
874        for (key, tensor) in state {
875            if let Some(param_id_str) = key.strip_prefix("velocity_") {
876                if let Ok(param_id) = param_id_str.parse::<usize>() {
877                    self.velocity_buffers.insert(param_id, tensor);
878                }
879            }
880        }
881
882        Ok(())
883    }
884}
885
/// Configuration for the Heavy Ball Method optimizer ([`HeavyBall`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeavyBallConfig {
    /// Learning rate `lr`: the gradient is scaled by this before it is
    /// subtracted from the decayed velocity.
    pub learning_rate: f32,
    /// Momentum coefficient (β) applied to the previous velocity.
    /// When `adaptive_momentum` is enabled, this is the maximum value that gets
    /// scaled down per parameter.
    pub beta: f32,
    /// L2 weight-decay coefficient; when positive, `weight_decay * param` is
    /// added to the gradient before the update.
    pub weight_decay: f32,
    /// When true, β is scaled per parameter by `max(0, cos θ)`, where θ is the
    /// angle between the current and previous gradient. Aligned gradients keep
    /// momentum near β; orthogonal or opposing gradients reduce it toward 0.
    pub adaptive_momentum: bool,
}
898
899impl Default for HeavyBallConfig {
900    fn default() -> Self {
901        Self {
902            learning_rate: 1e-3,
903            beta: 0.9,
904            weight_decay: 0.0,
905            adaptive_momentum: false,
906        }
907    }
908}
909
/// Heavy Ball Method optimizer.
///
/// Classical momentum-based acceleration method that adds inertia to gradient descent.
/// Update rule (per parameter, with `grad` including weight decay when enabled):
/// v_t = β * v_{t-1} - lr * grad(x_t)
/// x_{t+1} = x_t + v_t
///
/// Per-parameter state is keyed by the parameter's index in the slice passed to
/// `step`, so callers must pass parameters in a stable order.
#[derive(Debug)]
pub struct HeavyBall {
    /// Hyperparameters.
    config: HeavyBallConfig,
    /// Velocity `v` for each parameter index; created lazily as zeros.
    velocity_buffers: HashMap<usize, Tensor>,
    /// Effective gradient from the previous step for each parameter index.
    /// Only populated when `adaptive_momentum` is enabled.
    previous_gradients: HashMap<usize, Tensor>,
    /// Step counter, incremented once per `step` call.
    current_step: usize,
}
923
924impl HeavyBall {
925    /// Create a new Heavy Ball optimizer.
926    pub fn new(config: HeavyBallConfig) -> Self {
927        Self {
928            config,
929            velocity_buffers: HashMap::new(),
930            previous_gradients: HashMap::new(),
931            current_step: 0,
932        }
933    }
934
935    /// Create Heavy Ball with default configuration.
936    pub fn with_defaults(learning_rate: f32, beta: f32) -> Self {
937        Self::new(HeavyBallConfig {
938            learning_rate,
939            beta,
940            weight_decay: 0.0,
941            adaptive_momentum: false,
942        })
943    }
944
945    /// Get the configuration.
946    pub fn get_config(&self) -> &HeavyBallConfig {
947        &self.config
948    }
949
950    /// Compute adaptive momentum based on gradient alignment.
951    fn compute_adaptive_momentum(&self, param_id: usize, current_grad: &Tensor) -> Result<f32> {
952        if let Some(prev_grad) = self.previous_gradients.get(&param_id) {
953            // Compute cosine similarity between current and previous gradients
954            let dot_product = current_grad.mul(prev_grad)?.sum(None, false)?;
955            let norm_current = current_grad.norm_squared()?.sqrt()?;
956            let norm_prev = prev_grad.norm_squared()?.sqrt()?;
957
958            let dot_scalar = dot_product.to_scalar()?;
959            let norm_current_scalar = norm_current.to_scalar()?;
960            let norm_prev_scalar = norm_prev.to_scalar()?;
961
962            let denominator = norm_current_scalar * norm_prev_scalar;
963            if denominator > 1e-8 {
964                let cosine_similarity = dot_scalar / denominator;
965                // Increase momentum when gradients are aligned, decrease when opposed
966                let adaptive_beta = self.config.beta * cosine_similarity.max(0.0);
967                Ok(adaptive_beta)
968            } else {
969                Ok(self.config.beta)
970            }
971        } else {
972            Ok(self.config.beta)
973        }
974    }
975}
976
977impl OptimizerState for HeavyBall {
978    fn zero_grad(&mut self) -> Result<()> {
979        Ok(())
980    }
981
982    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
983        self.current_step += 1;
984
985        for (param_id, parameter) in parameters.iter_mut().enumerate() {
986            // Access gradient from parameter (should be computed during forward/backward pass)
987            let gradient = match parameter.grad() {
988                Ok(grad) => grad,
989                Err(_) => {
990                    // If gradient is not available, skip this parameter
991                    continue;
992                },
993            };
994
995            // Apply weight decay to gradient
996            let effective_grad = if self.config.weight_decay > 0.0 {
997                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
998            } else {
999                gradient
1000            };
1001
1002            // Compute momentum coefficient
1003            let beta = if self.config.adaptive_momentum {
1004                self.compute_adaptive_momentum(param_id, &effective_grad)?
1005            } else {
1006                self.config.beta
1007            };
1008
1009            // Get or initialize velocity buffer
1010            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
1011                v.clone()
1012            } else {
1013                Tensor::zeros_like(parameter)?
1014            };
1015
1016            // Heavy Ball update: v_t = β * v_{t-1} - lr * grad
1017            let new_velocity = velocity
1018                .mul_scalar(beta)?
1019                .sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1020
1021            self.velocity_buffers.insert(param_id, new_velocity.clone());
1022
1023            // Update parameters: x_{t+1} = x_t + v_t
1024            *parameter = parameter.add(&new_velocity)?;
1025
1026            // Store gradient for adaptive momentum
1027            if self.config.adaptive_momentum {
1028                self.previous_gradients.insert(param_id, effective_grad);
1029            }
1030        }
1031
1032        Ok(())
1033    }
1034
1035    fn get_lr(&self) -> f32 {
1036        self.config.learning_rate
1037    }
1038
1039    fn set_lr(&mut self, lr: f32) {
1040        self.config.learning_rate = lr;
1041    }
1042
1043    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1044        let mut state = HashMap::new();
1045
1046        state.insert(
1047            "learning_rate".to_string(),
1048            Tensor::scalar(self.config.learning_rate)?,
1049        );
1050        state.insert("beta".to_string(), Tensor::scalar(self.config.beta)?);
1051        state.insert(
1052            "weight_decay".to_string(),
1053            Tensor::scalar(self.config.weight_decay)?,
1054        );
1055        state.insert(
1056            "current_step".to_string(),
1057            Tensor::scalar(self.current_step as f32)?,
1058        );
1059
1060        for (&param_id, velocity) in &self.velocity_buffers {
1061            state.insert(format!("velocity_{}", param_id), velocity.clone());
1062        }
1063
1064        for (&param_id, grad) in &self.previous_gradients {
1065            state.insert(format!("prev_grad_{}", param_id), grad.clone());
1066        }
1067
1068        Ok(state)
1069    }
1070
1071    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1072        if let Some(lr) = state.get("learning_rate") {
1073            self.config.learning_rate = lr.to_scalar()?;
1074        }
1075        if let Some(beta) = state.get("beta") {
1076            self.config.beta = beta.to_scalar()?;
1077        }
1078        if let Some(wd) = state.get("weight_decay") {
1079            self.config.weight_decay = wd.to_scalar()?;
1080        }
1081        if let Some(step) = state.get("current_step") {
1082            self.current_step = step.to_scalar()? as usize;
1083        }
1084
1085        self.velocity_buffers.clear();
1086        self.previous_gradients.clear();
1087
1088        for (key, tensor) in state {
1089            if let Some(param_id_str) = key.strip_prefix("velocity_") {
1090                if let Ok(param_id) = param_id_str.parse::<usize>() {
1091                    self.velocity_buffers.insert(param_id, tensor);
1092                }
1093            } else if let Some(param_id_str) = key.strip_prefix("prev_grad_") {
1094                if let Ok(param_id) = param_id_str.parse::<usize>() {
1095                    self.previous_gradients.insert(param_id, tensor);
1096                }
1097            }
1098        }
1099
1100        Ok(())
1101    }
1102}
1103
/// Configuration for FISTA (Fast Iterative Shrinkage-Thresholding Algorithm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FISTAConfig {
    /// Learning rate: step size applied to the gradient before the proximal
    /// (soft-threshold) step.
    pub learning_rate: f32,
    /// Proximal threshold parameter: the soft-threshold width used by the L1
    /// proximal operator after each gradient step. It is applied as-is and is
    /// not multiplied by `learning_rate`.
    pub threshold: f32,
    /// Whether to use adaptive restart of the momentum sequence.
    ///
    /// NOTE(review): the visible `FISTA` step does not read this flag and
    /// `state_dict` does not save it, so restart is currently never performed.
    pub adaptive_restart: bool,
    /// L2 weight-decay coefficient; when positive, `weight_decay * param` is
    /// added to the gradient before the step.
    pub weight_decay: f32,
}
1116
1117impl Default for FISTAConfig {
1118    fn default() -> Self {
1119        Self {
1120            learning_rate: 1e-3,
1121            threshold: 1e-4,
1122            adaptive_restart: true,
1123            weight_decay: 0.0,
1124        }
1125    }
1126}
1127
/// FISTA optimizer for problems with L1 regularization or other proximal operators.
///
/// FISTA is designed for problems of the form: min f(x) + λ||x||_1
/// where f(x) is smooth and convex, and λ||x||_1 is the L1 regularization term.
/// The proximal step here is element-wise soft thresholding with the configured
/// `threshold` as its width.
///
/// Per-parameter state is keyed by the parameter's index in the slice passed to
/// `step`, so callers must pass parameters in a stable order.
#[derive(Debug)]
pub struct FISTA {
    /// Hyperparameters.
    config: FISTAConfig,
    /// Parameter values from before the most recent update (x_{k-1}), per index.
    previous_params: HashMap<usize, Tensor>,
    /// Step counter, incremented once per `step` call.
    current_step: usize,
    /// Current value of the momentum sequence `t` (initialised to 1.0).
    momentum_coefficient: f32,
    /// Value of the momentum sequence before its most recent update (initialised to 1.0).
    previous_momentum: f32,
}
1140
1141impl FISTA {
1142    /// Create a new FISTA optimizer.
1143    pub fn new(config: FISTAConfig) -> Self {
1144        Self {
1145            config,
1146            previous_params: HashMap::new(),
1147            current_step: 0,
1148            momentum_coefficient: 1.0,
1149            previous_momentum: 1.0,
1150        }
1151    }
1152
1153    /// Create FISTA with default configuration.
1154    pub fn with_defaults(learning_rate: f32, threshold: f32) -> Self {
1155        Self::new(FISTAConfig {
1156            learning_rate,
1157            threshold,
1158            adaptive_restart: true,
1159            weight_decay: 0.0,
1160        })
1161    }
1162
1163    /// Get the configuration.
1164    pub fn get_config(&self) -> &FISTAConfig {
1165        &self.config
1166    }
1167
1168    /// Apply soft thresholding (proximal operator for L1 regularization).
1169    fn soft_threshold(&self, tensor: &Tensor, threshold: f32) -> Result<Tensor> {
1170        let threshold_tensor = Tensor::scalar(threshold)?;
1171        let zero_tensor = Tensor::zeros_like(tensor)?;
1172
1173        // Soft thresholding: sign(x) * max(0, |x| - threshold)
1174        let abs_tensor = tensor.abs()?;
1175        let thresholded = abs_tensor.sub(&threshold_tensor)?.max(&zero_tensor)?;
1176        let sign_tensor = tensor.sign()?;
1177
1178        Ok(sign_tensor.mul(&thresholded)?)
1179    }
1180
1181    /// Update momentum coefficient using FISTA formula.
1182    fn update_momentum_coefficient(&mut self) {
1183        let t = self.current_step as f32;
1184        self.previous_momentum = self.momentum_coefficient;
1185        self.momentum_coefficient = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
1186    }
1187}
1188
1189impl OptimizerState for FISTA {
1190    fn zero_grad(&mut self) -> Result<()> {
1191        Ok(())
1192    }
1193
1194    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1195        self.current_step += 1;
1196        self.update_momentum_coefficient();
1197
1198        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1199            // Access gradient from parameter (should be computed during forward/backward pass)
1200            let gradient = match parameter.grad() {
1201                Ok(grad) => grad,
1202                Err(_) => {
1203                    // If gradient is not available, skip this parameter
1204                    continue;
1205                },
1206            };
1207
1208            // Apply weight decay to gradient
1209            let effective_grad = if self.config.weight_decay > 0.0 {
1210                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
1211            } else {
1212                gradient
1213            };
1214
1215            // Get previous parameter value
1216            let previous_param = if let Some(prev) = self.previous_params.get(&param_id) {
1217                prev.clone()
1218            } else {
1219                parameter.clone()
1220            };
1221
1222            // Momentum coefficient ratio
1223            let beta = (self.previous_momentum - 1.0) / self.momentum_coefficient;
1224
1225            // Compute extrapolated point
1226            let extrapolated = parameter.add(&previous_param.sub(parameter)?.mul_scalar(beta)?)?;
1227
1228            // Gradient step
1229            let grad_step =
1230                extrapolated.sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1231
1232            // Apply proximal operator (soft thresholding)
1233            let new_parameter = self.soft_threshold(&grad_step, self.config.threshold)?;
1234
1235            // Store current parameter for next iteration
1236            self.previous_params.insert(param_id, parameter.clone());
1237
1238            // Update parameter
1239            *parameter = new_parameter;
1240        }
1241
1242        Ok(())
1243    }
1244
1245    fn get_lr(&self) -> f32 {
1246        self.config.learning_rate
1247    }
1248
1249    fn set_lr(&mut self, lr: f32) {
1250        self.config.learning_rate = lr;
1251    }
1252
1253    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1254        let mut state = HashMap::new();
1255
1256        state.insert(
1257            "learning_rate".to_string(),
1258            Tensor::scalar(self.config.learning_rate)?,
1259        );
1260        state.insert(
1261            "threshold".to_string(),
1262            Tensor::scalar(self.config.threshold)?,
1263        );
1264        state.insert(
1265            "weight_decay".to_string(),
1266            Tensor::scalar(self.config.weight_decay)?,
1267        );
1268        state.insert(
1269            "current_step".to_string(),
1270            Tensor::scalar(self.current_step as f32)?,
1271        );
1272        state.insert(
1273            "momentum_coefficient".to_string(),
1274            Tensor::scalar(self.momentum_coefficient)?,
1275        );
1276        state.insert(
1277            "previous_momentum".to_string(),
1278            Tensor::scalar(self.previous_momentum)?,
1279        );
1280
1281        for (&param_id, param) in &self.previous_params {
1282            state.insert(format!("prev_param_{}", param_id), param.clone());
1283        }
1284
1285        Ok(state)
1286    }
1287
1288    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1289        if let Some(lr) = state.get("learning_rate") {
1290            self.config.learning_rate = lr.to_scalar()?;
1291        }
1292        if let Some(threshold) = state.get("threshold") {
1293            self.config.threshold = threshold.to_scalar()?;
1294        }
1295        if let Some(wd) = state.get("weight_decay") {
1296            self.config.weight_decay = wd.to_scalar()?;
1297        }
1298        if let Some(step) = state.get("current_step") {
1299            self.current_step = step.to_scalar()? as usize;
1300        }
1301        if let Some(momentum) = state.get("momentum_coefficient") {
1302            self.momentum_coefficient = momentum.to_scalar()?;
1303        }
1304        if let Some(prev_momentum) = state.get("previous_momentum") {
1305            self.previous_momentum = prev_momentum.to_scalar()?;
1306        }
1307
1308        self.previous_params.clear();
1309        for (key, tensor) in state {
1310            if let Some(param_id_str) = key.strip_prefix("prev_param_") {
1311                if let Ok(param_id) = param_id_str.parse::<usize>() {
1312                    self.previous_params.insert(param_id, tensor);
1313                }
1314            }
1315        }
1316
1317        Ok(())
1318    }
1319}
1320
/// Configuration for Adaptive Batch Sizing ([`AdaptiveBatchSizing`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchSizingConfig {
    /// Initial batch size; also restored by `reset`. Not clamped to
    /// `[min_batch_size, max_batch_size]`.
    pub initial_batch_size: usize,
    /// Minimum batch size, enforced when the batch size is decreased.
    pub min_batch_size: usize,
    /// Maximum batch size, enforced when the batch size is increased.
    pub max_batch_size: usize,
    /// Tolerance for gradient variance.
    /// NOTE(review): not currently read by `AdaptiveBatchSizing`.
    pub gradient_variance_tolerance: f32,
    /// Multiplier applied to `sqrt(current / original batch size)` in
    /// `get_lr_adjustment`.
    pub lr_adaptation_factor: f32,
    /// Maximum number of recent gradient-variance and loss samples kept.
    pub variance_window_size: usize,
    /// Despite its name, this is the gradient-variance level *below* which the
    /// batch size may be **decreased** (when the loss is also falling).
    pub increase_threshold: f32,
    /// Despite its name, this is the gradient-variance level *above* which the
    /// batch size is **increased** (when the variance is also trending up).
    pub decrease_threshold: f32,
}
1341
1342impl Default for AdaptiveBatchSizingConfig {
1343    fn default() -> Self {
1344        Self {
1345            initial_batch_size: 32,
1346            min_batch_size: 8,
1347            max_batch_size: 512,
1348            gradient_variance_tolerance: 0.1,
1349            lr_adaptation_factor: 0.8,
1350            variance_window_size: 10,
1351            increase_threshold: 0.05,
1352            decrease_threshold: 0.2,
1353        }
1354    }
1355}
1356
/// Adaptive Batch Sizing utility for dynamically adjusting batch size based on training progress.
///
/// This strategy monitors gradient variance and training stability to determine optimal batch sizes.
/// When gradient variance is high, it increases batch size to reduce noise.
/// When variance is low, it may decrease batch size to improve convergence speed.
#[derive(Debug)]
pub struct AdaptiveBatchSizing {
    /// Tuning parameters.
    config: AdaptiveBatchSizingConfig,
    /// Batch size most recently chosen; starts at `initial_batch_size`.
    current_batch_size: usize,
    /// Recent gradient-variance samples, oldest first, capped at `variance_window_size`.
    gradient_variance_history: Vec<f32>,
    /// Recent loss samples, oldest first, capped at `variance_window_size`.
    loss_history: Vec<f32>,
    /// Number of `update` calls since construction or the last `reset`.
    current_step: usize,
    /// `current_step` at the most recent adjustment evaluation (0 initially);
    /// used to rate-limit adjustments.
    last_adjustment_step: usize,
}
1371
1372impl AdaptiveBatchSizing {
1373    /// Create a new adaptive batch sizing utility.
1374    pub fn new(config: AdaptiveBatchSizingConfig) -> Self {
1375        let initial_batch_size = config.initial_batch_size;
1376        Self {
1377            config,
1378            current_batch_size: initial_batch_size,
1379            gradient_variance_history: Vec::new(),
1380            loss_history: Vec::new(),
1381            current_step: 0,
1382            last_adjustment_step: 0,
1383        }
1384    }
1385
1386    /// Create with default configuration.
1387    pub fn with_defaults(
1388        initial_batch_size: usize,
1389        min_batch_size: usize,
1390        max_batch_size: usize,
1391    ) -> Self {
1392        Self::new(AdaptiveBatchSizingConfig {
1393            initial_batch_size,
1394            min_batch_size,
1395            max_batch_size,
1396            ..Default::default()
1397        })
1398    }
1399
1400    /// Get current batch size.
1401    pub fn current_batch_size(&self) -> usize {
1402        self.current_batch_size
1403    }
1404
1405    /// Get the configuration.
1406    pub fn get_config(&self) -> &AdaptiveBatchSizingConfig {
1407        &self.config
1408    }
1409
1410    /// Update with current gradient variance and loss.
1411    pub fn update(&mut self, gradient_variance: f32, current_loss: f32) -> Result<usize> {
1412        self.current_step += 1;
1413
1414        // Add to history
1415        self.gradient_variance_history.push(gradient_variance);
1416        self.loss_history.push(current_loss);
1417
1418        // Keep only recent history
1419        if self.gradient_variance_history.len() > self.config.variance_window_size {
1420            self.gradient_variance_history.remove(0);
1421        }
1422        if self.loss_history.len() > self.config.variance_window_size {
1423            self.loss_history.remove(0);
1424        }
1425
1426        // Check if we should adjust batch size
1427        if self.should_adjust_batch_size() {
1428            self.adjust_batch_size()?;
1429            self.last_adjustment_step = self.current_step;
1430        }
1431
1432        Ok(self.current_batch_size)
1433    }
1434
1435    /// Compute gradient variance from gradients.
1436    pub fn compute_gradient_variance(&self, gradients: &[Tensor]) -> Result<f32> {
1437        if gradients.is_empty() {
1438            return Ok(0.0);
1439        }
1440
1441        // Compute mean gradient
1442        let mut mean_grad = gradients[0].clone();
1443        for grad in gradients.iter().skip(1) {
1444            mean_grad = mean_grad.add(grad)?;
1445        }
1446        mean_grad = mean_grad.div_scalar(gradients.len() as f32)?;
1447
1448        // Compute variance
1449        let mut variance_sum = 0.0;
1450        for grad in gradients {
1451            let diff = grad.sub(&mean_grad)?;
1452            let squared_norm = diff.mul(&diff)?.sum(None, false)?;
1453            variance_sum += squared_norm.to_scalar()?;
1454        }
1455
1456        Ok(variance_sum / gradients.len() as f32)
1457    }
1458
1459    fn should_adjust_batch_size(&self) -> bool {
1460        // Don't adjust too frequently
1461        if self.current_step - self.last_adjustment_step < 5 {
1462            return false;
1463        }
1464
1465        // Need enough history
1466        self.gradient_variance_history.len() >= 3
1467    }
1468
1469    fn adjust_batch_size(&mut self) -> Result<()> {
1470        let recent_variance = self.recent_average_variance();
1471        let variance_trend = self.variance_trend();
1472        let loss_trend = self.loss_trend();
1473
1474        // Decide whether to increase or decrease batch size
1475        if recent_variance > self.config.decrease_threshold && variance_trend > 0.0 {
1476            // High variance and increasing - increase batch size
1477            self.increase_batch_size();
1478        } else if recent_variance < self.config.increase_threshold && loss_trend < -0.01 {
1479            // Low variance and decreasing loss - try smaller batch size
1480            self.decrease_batch_size();
1481        }
1482
1483        Ok(())
1484    }
1485
1486    fn recent_average_variance(&self) -> f32 {
1487        if self.gradient_variance_history.is_empty() {
1488            return 0.0;
1489        }
1490
1491        let recent_window = std::cmp::min(5, self.gradient_variance_history.len());
1492        let start_idx = self.gradient_variance_history.len() - recent_window;
1493
1494        self.gradient_variance_history[start_idx..].iter().sum::<f32>() / recent_window as f32
1495    }
1496
1497    fn variance_trend(&self) -> f32 {
1498        if self.gradient_variance_history.len() < 3 {
1499            return 0.0;
1500        }
1501
1502        let len = self.gradient_variance_history.len();
1503        let recent = self.gradient_variance_history[len - 2..].iter().sum::<f32>() / 2.0;
1504        let older = self.gradient_variance_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1505
1506        recent - older
1507    }
1508
1509    fn loss_trend(&self) -> f32 {
1510        if self.loss_history.len() < 3 {
1511            return 0.0;
1512        }
1513
1514        let len = self.loss_history.len();
1515        let recent = self.loss_history[len - 2..].iter().sum::<f32>() / 2.0;
1516        let older = self.loss_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1517
1518        (recent - older) / older.max(1e-8)
1519    }
1520
1521    fn increase_batch_size(&mut self) {
1522        let new_size = (self.current_batch_size as f32 * 1.5) as usize;
1523        self.current_batch_size = new_size.min(self.config.max_batch_size);
1524    }
1525
1526    fn decrease_batch_size(&mut self) {
1527        let new_size = (self.current_batch_size as f32 * 0.8) as usize;
1528        self.current_batch_size = new_size.max(self.config.min_batch_size);
1529    }
1530
1531    /// Get suggested learning rate adjustment based on batch size changes.
1532    pub fn get_lr_adjustment(&self, original_batch_size: usize) -> f32 {
1533        let ratio = self.current_batch_size as f32 / original_batch_size as f32;
1534        ratio.sqrt() * self.config.lr_adaptation_factor
1535    }
1536
1537    /// Reset state for new training run.
1538    pub fn reset(&mut self) {
1539        self.current_batch_size = self.config.initial_batch_size;
1540        self.gradient_variance_history.clear();
1541        self.loss_history.clear();
1542        self.current_step = 0;
1543        self.last_adjustment_step = 0;
1544    }
1545}
1546
/// Configuration for Loss Surface Smoothing ([`LossSurfaceSmoothing`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSurfaceSmoothingConfig {
    /// Interpolation weight toward the EMA-smoothed parameters in
    /// `smooth_parameters`: `x <- (1 - s) * x + s * ema`. Not used for gradients.
    pub smoothing_strength: f32,
    /// Noise injection variance: noise from `Tensor::randn_like` is scaled by
    /// `sqrt(noise_variance)` before being added to gradients.
    pub noise_variance: f32,
    /// Exponential moving average decay, shared by the gradient EMA and the
    /// parameter EMA.
    pub ema_decay: f32,
    /// Number of most recent gradients averaged per parameter; should be at least 1.
    pub averaging_window: usize,
    /// Whether to average gradients over `averaging_window` steps before the EMA.
    pub use_gradient_averaging: bool,
    /// Whether to add noise to the smoothed gradients.
    pub use_noise_injection: bool,
}
1563
1564impl Default for LossSurfaceSmoothingConfig {
1565    fn default() -> Self {
1566        Self {
1567            smoothing_strength: 0.1,
1568            noise_variance: 1e-4,
1569            ema_decay: 0.9,
1570            averaging_window: 5,
1571            use_gradient_averaging: true,
1572            use_noise_injection: false,
1573        }
1574    }
1575}
1576
/// Loss Surface Smoothing utility for reducing noise in the loss landscape.
///
/// This implements several techniques to smooth the loss surface:
/// - Gradient averaging over multiple steps
/// - Exponential moving average of gradients
/// - Controlled noise injection for exploration
/// - Parameter smoothing to reduce sharp changes
///
/// Per-parameter state is keyed by the parameter's index in the slice passed
/// in, so callers must pass parameters in a stable order.
#[derive(Debug)]
pub struct LossSurfaceSmoothing {
    /// Tuning parameters.
    config: LossSurfaceSmoothingConfig,
    /// Recent raw gradients per parameter index (only filled when gradient
    /// averaging is enabled).
    gradient_history: HashMap<usize, Vec<Tensor>>,
    /// Exponential moving average of the (optionally averaged) gradients.
    ema_gradients: HashMap<usize, Tensor>,
    /// Exponential moving average of parameter values, used by `smooth_parameters`.
    smoothed_parameters: HashMap<usize, Tensor>,
    /// Number of `smooth_gradients` calls since construction or the last `reset`.
    current_step: usize,
}
1592
1593impl LossSurfaceSmoothing {
1594    /// Create a new loss surface smoothing utility.
1595    pub fn new(config: LossSurfaceSmoothingConfig) -> Self {
1596        Self {
1597            config,
1598            gradient_history: HashMap::new(),
1599            ema_gradients: HashMap::new(),
1600            smoothed_parameters: HashMap::new(),
1601            current_step: 0,
1602        }
1603    }
1604
1605    /// Create with default configuration.
1606    pub fn with_defaults(smoothing_strength: f32, use_noise: bool) -> Self {
1607        Self::new(LossSurfaceSmoothingConfig {
1608            smoothing_strength,
1609            use_noise_injection: use_noise,
1610            ..Default::default()
1611        })
1612    }
1613
1614    /// Get the configuration.
1615    pub fn get_config(&self) -> &LossSurfaceSmoothingConfig {
1616        &self.config
1617    }
1618
1619    /// Apply smoothing to gradients.
1620    pub fn smooth_gradients(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1621        self.current_step += 1;
1622
1623        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1624            let original_grad = parameter.grad()?;
1625            let mut smoothed_grad = original_grad.clone();
1626
1627            // Apply gradient averaging
1628            if self.config.use_gradient_averaging {
1629                smoothed_grad = self.apply_gradient_averaging(param_id, &original_grad)?;
1630            }
1631
1632            // Apply exponential moving average
1633            smoothed_grad = self.apply_ema_smoothing(param_id, &smoothed_grad)?;
1634
1635            // Apply noise injection for exploration
1636            if self.config.use_noise_injection {
1637                smoothed_grad = self.apply_noise_injection(&smoothed_grad)?;
1638            }
1639
1640            // Update parameter gradient
1641            parameter.set_grad(smoothed_grad)?;
1642        }
1643
1644        Ok(())
1645    }
1646
1647    /// Apply parameter smoothing.
1648    pub fn smooth_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1649        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1650            if let Some(smoothed_param) = self.smoothed_parameters.get(&param_id) {
1651                // Apply exponential moving average to parameters
1652                let new_smoothed = smoothed_param
1653                    .mul_scalar(self.config.ema_decay)?
1654                    .add(&parameter.mul_scalar(1.0 - self.config.ema_decay)?)?;
1655
1656                // Interpolate between original and smoothed parameters
1657                *parameter = parameter
1658                    .mul_scalar(1.0 - self.config.smoothing_strength)?
1659                    .add(&new_smoothed.mul_scalar(self.config.smoothing_strength)?)?;
1660
1661                self.smoothed_parameters.insert(param_id, new_smoothed);
1662            } else {
1663                // Initialize smoothed parameter
1664                self.smoothed_parameters.insert(param_id, parameter.clone());
1665            }
1666        }
1667
1668        Ok(())
1669    }
1670
1671    fn apply_gradient_averaging(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1672        let history = self.gradient_history.entry(param_id).or_default();
1673
1674        history.push(gradient.clone());
1675        if history.len() > self.config.averaging_window {
1676            history.remove(0);
1677        }
1678
1679        // Compute average of recent gradients
1680        if history.len() == 1 {
1681            Ok(gradient.clone())
1682        } else {
1683            let mut sum = history[0].clone();
1684            for grad in history.iter().skip(1) {
1685                sum = sum.add(grad)?;
1686            }
1687            Ok(sum.div_scalar(history.len() as f32)?)
1688        }
1689    }
1690
1691    fn apply_ema_smoothing(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1692        if let Some(ema_grad) = self.ema_gradients.get(&param_id) {
1693            let new_ema = ema_grad
1694                .mul_scalar(self.config.ema_decay)?
1695                .add(&gradient.mul_scalar(1.0 - self.config.ema_decay)?)?;
1696            self.ema_gradients.insert(param_id, new_ema.clone());
1697            Ok(new_ema)
1698        } else {
1699            self.ema_gradients.insert(param_id, gradient.clone());
1700            Ok(gradient.clone())
1701        }
1702    }
1703
1704    fn apply_noise_injection(&self, gradient: &Tensor) -> Result<Tensor> {
1705        let noise = Tensor::randn_like(gradient)
1706            .map_err(|e| anyhow!("Failed to create noise tensor: {}", e))?
1707            .mul_scalar(self.config.noise_variance.sqrt())
1708            .map_err(|e| anyhow!("Failed to scale noise tensor: {}", e))?;
1709        gradient
1710            .add(&noise)
1711            .map_err(|e| anyhow!("Failed to add noise to gradient: {}", e))
1712    }
1713
1714    /// Reset state for new training run.
1715    pub fn reset(&mut self) {
1716        self.gradient_history.clear();
1717        self.ema_gradients.clear();
1718        self.smoothed_parameters.clear();
1719        self.current_step = 0;
1720    }
1721
1722    /// Get smoothing statistics.
1723    pub fn get_statistics(&self) -> HashMap<String, f32> {
1724        let mut stats = HashMap::new();
1725        stats.insert("current_step".to_string(), self.current_step as f32);
1726        stats.insert(
1727            "num_tracked_params".to_string(),
1728            self.gradient_history.len() as f32,
1729        );
1730        stats.insert(
1731            "smoothing_strength".to_string(),
1732            self.config.smoothing_strength,
1733        );
1734        stats.insert("ema_decay".to_string(), self.config.ema_decay);
1735        stats
1736    }
1737}
1738
#[cfg(test)]
mod tests {
    // Smoke tests for the default configurations and constructors of the
    // optimizers and training helpers defined in this module.
    use super::*;

    // ---- Quasi-Hyperbolic Momentum / Aggregated Momentum ----

    #[test]
    fn test_qhm_config_default() {
        let cfg = QHMConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.nu, 0.7);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    #[test]
    fn test_aggmo_config_default() {
        let cfg = AggMoConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum_coefficients, [0.0, 0.9, 0.99]);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    #[test]
    fn test_qhm_creation() {
        let opt = QHM::with_defaults(1e-3, 0.9, 0.7);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    #[test]
    fn test_aggmo_creation() {
        let opt = AggMo::with_defaults(1e-3, vec![0.0, 0.9, 0.99]);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.num_momentum_buffers(), 3);
    }

    // ---- Variance reduction ----

    #[test]
    fn test_variance_reduction_svrg() {
        let opt = VarianceReduction::svrg(1e-3, 50, 10);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    #[test]
    fn test_variance_reduction_sag() {
        let opt = VarianceReduction::sag(1e-3, 100);
        assert_eq!(opt.get_lr(), 1e-3);
        assert!(matches!(opt.config.method, VarianceReductionMethod::SAG));
    }

    // ---- Nesterov accelerated gradient ----

    #[test]
    fn test_nesterov_accelerated_gradient_config() {
        let cfg = NesterovAcceleratedGradientConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.restart_on_increase);
    }

    #[test]
    fn test_nesterov_accelerated_gradient_creation() {
        let opt = NesterovAcceleratedGradient::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert!(opt.previous_loss.is_none());
    }

    #[test]
    fn test_nesterov_restart_on_increase() {
        // Same values as the defaults checked above, but with restarts enabled.
        let mut opt = NesterovAcceleratedGradient::new(NesterovAcceleratedGradientConfig {
            restart_on_increase: true,
            ..Default::default()
        });

        // The first reported loss is simply recorded.
        opt.set_current_loss(1.0);
        assert_eq!(opt.previous_loss, Some(1.0));

        // A higher loss is the intended restart trigger. This test only checks
        // that the latest loss is recorded, not the restart itself.
        opt.set_current_loss(1.5);
        assert_eq!(opt.previous_loss, Some(1.5));
    }

    // ---- Heavy ball ----

    #[test]
    fn test_heavy_ball_config() {
        let cfg = HeavyBallConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.beta, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.adaptive_momentum);
    }

    #[test]
    fn test_heavy_ball_creation() {
        let opt = HeavyBall::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.get_config().beta, 0.9);
    }

    #[test]
    fn test_heavy_ball_adaptive_momentum() {
        let opt = HeavyBall::new(HeavyBallConfig {
            adaptive_momentum: true,
            ..Default::default()
        });
        assert!(opt.config.adaptive_momentum);
    }

    // ---- FISTA ----

    #[test]
    fn test_fista_config() {
        let cfg = FISTAConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.threshold, 1e-4);
        assert!(cfg.adaptive_restart);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    #[test]
    fn test_fista_creation() {
        let opt = FISTA::with_defaults(1e-3, 1e-4);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.momentum_coefficient, 1.0);
        assert_eq!(opt.previous_momentum, 1.0);
    }

    #[test]
    fn test_fista_momentum_update() {
        let mut opt = FISTA::with_defaults(1e-3, 1e-4);

        // The coefficient is derived from the step, so advance the step first.
        opt.current_step = 1;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > 1.0);
        assert_eq!(opt.previous_momentum, 1.0);

        let before = opt.momentum_coefficient;
        opt.current_step = 2;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > before);
    }

    // ---- Adaptive batch sizing ----

    #[test]
    fn test_adaptive_batch_sizing_config() {
        let cfg = AdaptiveBatchSizingConfig::default();
        assert_eq!(cfg.initial_batch_size, 32);
        assert_eq!(cfg.min_batch_size, 8);
        assert_eq!(cfg.max_batch_size, 512);
        assert_eq!(cfg.gradient_variance_tolerance, 0.1);
        assert_eq!(cfg.lr_adaptation_factor, 0.8);
        assert_eq!(cfg.variance_window_size, 10);
        assert_eq!(cfg.increase_threshold, 0.05);
        assert_eq!(cfg.decrease_threshold, 0.2);
    }

    #[test]
    fn test_adaptive_batch_sizing_creation() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        assert_eq!(sizer.current_batch_size(), 64);
        let cfg = sizer.get_config();
        assert_eq!(cfg.min_batch_size, 16);
        assert_eq!(cfg.max_batch_size, 256);
    }

    #[test]
    fn test_adaptive_batch_sizing_lr_adjustment() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        let lr_adj = sizer.get_lr_adjustment(32);
        assert!(
            lr_adj > 0.0 && lr_adj < 2.0,
            "lr adjustment out of range: {}",
            lr_adj
        );
    }

    #[test]
    fn test_adaptive_batch_sizing_reset() {
        let mut sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        sizer.current_step = 10;
        sizer.reset();
        assert_eq!(sizer.current_step, 0);
        assert_eq!(sizer.current_batch_size(), 64);
    }

    // ---- Loss surface smoothing ----

    #[test]
    fn test_loss_surface_smoothing_config() {
        let cfg = LossSurfaceSmoothingConfig::default();
        assert_eq!(cfg.smoothing_strength, 0.1);
        assert_eq!(cfg.noise_variance, 1e-4);
        assert_eq!(cfg.ema_decay, 0.9);
        assert_eq!(cfg.averaging_window, 5);
        assert!(cfg.use_gradient_averaging);
        assert!(!cfg.use_noise_injection);
    }

    #[test]
    fn test_loss_surface_smoothing_creation() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.2, true);
        let cfg = smoother.get_config();
        assert_eq!(cfg.smoothing_strength, 0.2);
        assert!(cfg.use_noise_injection);
        assert_eq!(smoother.current_step, 0);
    }

    #[test]
    fn test_loss_surface_smoothing_statistics() {
        let stats = LossSurfaceSmoothing::with_defaults(0.1, false).get_statistics();
        let expected: [(&str, f32); 4] = [
            ("current_step", 0.0),
            ("num_tracked_params", 0.0),
            ("smoothing_strength", 0.1),
            ("ema_decay", 0.9),
        ];
        for &(key, value) in expected.iter() {
            assert_eq!(stats.get(key), Some(&value), "statistic {}", key);
        }
    }

    #[test]
    fn test_loss_surface_smoothing_reset() {
        let mut smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        smoother.current_step = 5;
        smoother.reset();
        assert_eq!(smoother.current_step, 0);
    }
}