// trustformers_optim/convergence.rs
1//! # Convergence Improvement Methods
2//!
3//! This module implements advanced optimization techniques that improve convergence
4//! speed, stability, and final performance through sophisticated momentum variants
5//! and variance reduction methods.
6//!
7//! ## Available Methods
8//!
9//! - **QHM (Quasi-Hyperbolic Momentum)**: Generalizes momentum and Nesterov acceleration
10//! - **AggMo (Aggregated Momentum)**: Maintains multiple momentum buffers for better convergence
11//! - **SVRG (Stochastic Variance Reduced Gradient)**: Reduces gradient variance for better convergence
12//! - **SAG (Stochastic Average Gradient)**: Maintains running average of gradients
13//! - **Nesterov Accelerated Gradient (NAG)**: Classical acceleration method with lookahead
14//! - **Heavy Ball Method**: Momentum-based acceleration with inertia
15//! - **FISTA**: Fast Iterative Shrinkage-Thresholding Algorithm for proximal methods
16//! - **Adaptive Batch Sizing**: Dynamically adjusts batch size based on training progress
17//! - **Loss Surface Smoothing**: Reduces noise in the loss surface for better convergence
18
19use crate::optimizer::OptimizerState;
20use anyhow::{anyhow, Result};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use trustformers_core::tensor::Tensor;
24
/// Configuration for Quasi-Hyperbolic Momentum (QHM).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QHMConfig {
    /// Learning rate (step size applied to the interpolated update direction)
    pub learning_rate: f32,
    /// Momentum parameter (β) - exponential decay rate of the momentum buffer
    pub momentum: f32,
    /// Averaging parameter (ν) - controls interpolation between current gradient and momentum;
    /// with the update rule p = p - lr * (ν * g + (1 - ν) * m), ν = 1 reduces to plain SGD
    pub nu: f32,
    /// Weight decay (L2 penalty coefficient folded into the gradient; 0 disables it)
    pub weight_decay: f32,
}
37
38impl Default for QHMConfig {
39    fn default() -> Self {
40        Self {
41            learning_rate: 1e-3,
42            momentum: 0.9,
43            nu: 0.7,
44            weight_decay: 0.0,
45        }
46    }
47}
48
/// Quasi-Hyperbolic Momentum optimizer.
///
/// QHM interpolates between the current gradient and the momentum buffer,
/// providing a generalization of both momentum and Nesterov acceleration.
/// Update rule: p = p - lr * (nu * g + (1 - nu) * momentum)
#[derive(Debug)]
pub struct QHM {
    /// Hyperparameters controlling the update rule.
    config: QHMConfig,
    /// Per-parameter momentum buffers, keyed by the parameter's index in the
    /// `parameters` slice passed to `step`.
    momentum_buffers: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
}
60
61impl QHM {
62    /// Create a new QHM optimizer.
63    pub fn new(config: QHMConfig) -> Self {
64        Self {
65            config,
66            momentum_buffers: HashMap::new(),
67            current_step: 0,
68        }
69    }
70
71    /// Create QHM with default configuration.
72    pub fn with_defaults(learning_rate: f32, momentum: f32, nu: f32) -> Self {
73        Self::new(QHMConfig {
74            learning_rate,
75            momentum,
76            nu,
77            weight_decay: 0.0,
78        })
79    }
80
81    /// Get the configuration.
82    pub fn get_config(&self) -> &QHMConfig {
83        &self.config
84    }
85
86    /// Update configuration.
87    pub fn set_config(&mut self, config: QHMConfig) {
88        self.config = config;
89    }
90}
91
92impl OptimizerState for QHM {
93    fn zero_grad(&mut self) -> Result<()> {
94        // QHM doesn't need explicit gradient zeroing
95        Ok(())
96    }
97
98    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
99        self.current_step += 1;
100
101        for (param_id, parameter) in parameters.iter_mut().enumerate() {
102            // Access gradient from parameter (should be computed during forward/backward pass)
103            let gradient = match parameter.grad() {
104                Ok(grad) => grad,
105                Err(_) => {
106                    // If gradient is not available, skip this parameter
107                    continue;
108                },
109            };
110
111            // Apply weight decay to gradient
112            let effective_grad = if self.config.weight_decay > 0.0 {
113                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
114            } else {
115                gradient
116            };
117
118            // Get or initialize momentum buffer
119            let momentum_buffer = if let Some(buffer) = self.momentum_buffers.get(&param_id) {
120                // Update momentum: momentum = β * momentum + (1 - β) * grad
121                let updated = buffer
122                    .mul_scalar(self.config.momentum)?
123                    .add(&effective_grad.mul_scalar(1.0 - self.config.momentum)?)?;
124                self.momentum_buffers.insert(param_id, updated.clone());
125                updated
126            } else {
127                // Initialize momentum buffer with current gradient
128                let initial_momentum = effective_grad.clone();
129                self.momentum_buffers.insert(param_id, initial_momentum.clone());
130                initial_momentum
131            };
132
133            // QHM update: interpolate between current gradient and momentum
134            let update_direction = effective_grad
135                .mul_scalar(self.config.nu)?
136                .add(&momentum_buffer.mul_scalar(1.0 - self.config.nu)?)?;
137
138            // Apply update
139            *parameter = parameter.sub(&update_direction.mul_scalar(self.config.learning_rate)?)?;
140        }
141
142        Ok(())
143    }
144
145    fn get_lr(&self) -> f32 {
146        self.config.learning_rate
147    }
148
149    fn set_lr(&mut self, lr: f32) {
150        self.config.learning_rate = lr;
151    }
152
153    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
154        let mut state = HashMap::new();
155
156        // Save configuration
157        state.insert(
158            "learning_rate".to_string(),
159            Tensor::scalar(self.config.learning_rate)?,
160        );
161        state.insert(
162            "momentum".to_string(),
163            Tensor::scalar(self.config.momentum)?,
164        );
165        state.insert("nu".to_string(), Tensor::scalar(self.config.nu)?);
166        state.insert(
167            "weight_decay".to_string(),
168            Tensor::scalar(self.config.weight_decay)?,
169        );
170        state.insert(
171            "current_step".to_string(),
172            Tensor::scalar(self.current_step as f32)?,
173        );
174
175        // Save momentum buffers
176        for (&param_id, buffer) in &self.momentum_buffers {
177            state.insert(format!("momentum_buffer_{}", param_id), buffer.clone());
178        }
179
180        Ok(state)
181    }
182
183    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
184        // Load configuration
185        if let Some(lr) = state.get("learning_rate") {
186            self.config.learning_rate = lr.to_scalar()?;
187        }
188        if let Some(momentum) = state.get("momentum") {
189            self.config.momentum = momentum.to_scalar()?;
190        }
191        if let Some(nu) = state.get("nu") {
192            self.config.nu = nu.to_scalar()?;
193        }
194        if let Some(wd) = state.get("weight_decay") {
195            self.config.weight_decay = wd.to_scalar()?;
196        }
197        if let Some(step) = state.get("current_step") {
198            self.current_step = step.to_scalar()? as usize;
199        }
200
201        // Load momentum buffers
202        self.momentum_buffers.clear();
203        for (key, tensor) in state {
204            if let Some(param_id_str) = key.strip_prefix("momentum_buffer_") {
205                if let Ok(param_id) = param_id_str.parse::<usize>() {
206                    self.momentum_buffers.insert(param_id, tensor);
207                }
208            }
209        }
210
211        Ok(())
212    }
213}
214
/// Configuration for Aggregated Momentum (AggMo).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggMoConfig {
    /// Learning rate applied to the averaged momentum direction
    pub learning_rate: f32,
    /// List of momentum coefficients - one momentum buffer is maintained per
    /// entry; must be non-empty (enforced by `AggMo::new`)
    pub momentum_coefficients: Vec<f32>,
    /// Weight decay (L2 penalty coefficient folded into the gradient; 0 disables it)
    pub weight_decay: f32,
}
225
226impl Default for AggMoConfig {
227    fn default() -> Self {
228        Self {
229            learning_rate: 1e-3,
230            momentum_coefficients: vec![0.0, 0.9, 0.99],
231            weight_decay: 0.0,
232        }
233    }
234}
235
/// Aggregated Momentum optimizer.
///
/// AggMo maintains multiple momentum buffers with different decay rates
/// and averages their contributions to improve convergence.
#[derive(Debug)]
pub struct AggMo {
    /// Hyperparameters, including the list of momentum coefficients.
    config: AggMoConfig,
    /// One momentum buffer per coefficient, per parameter index.
    momentum_buffers: HashMap<usize, Vec<Tensor>>, // param_id -> list of momentum buffers
    /// Number of `step` calls performed so far.
    current_step: usize,
}
246
247impl AggMo {
248    /// Create a new AggMo optimizer.
249    pub fn new(config: AggMoConfig) -> Self {
250        assert!(
251            !config.momentum_coefficients.is_empty(),
252            "Must provide at least one momentum coefficient"
253        );
254        Self {
255            config,
256            momentum_buffers: HashMap::new(),
257            current_step: 0,
258        }
259    }
260
261    /// Create AggMo with default configuration.
262    pub fn with_defaults(learning_rate: f32, momentum_coefficients: Vec<f32>) -> Self {
263        Self::new(AggMoConfig {
264            learning_rate,
265            momentum_coefficients,
266            weight_decay: 0.0,
267        })
268    }
269
270    /// Get the configuration.
271    pub fn get_config(&self) -> &AggMoConfig {
272        &self.config
273    }
274
275    /// Get the number of momentum buffers per parameter.
276    pub fn num_momentum_buffers(&self) -> usize {
277        self.config.momentum_coefficients.len()
278    }
279}
280
281impl OptimizerState for AggMo {
282    fn zero_grad(&mut self) -> Result<()> {
283        Ok(())
284    }
285
286    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
287        self.current_step += 1;
288
289        for (param_id, parameter) in parameters.iter_mut().enumerate() {
290            // Access gradient from parameter (should be computed during forward/backward pass)
291            let gradient = match parameter.grad() {
292                Ok(grad) => grad,
293                Err(_) => {
294                    // If gradient is not available, skip this parameter
295                    continue;
296                },
297            };
298
299            // Apply weight decay
300            let effective_grad = if self.config.weight_decay > 0.0 {
301                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
302            } else {
303                gradient
304            };
305
306            // Get or initialize momentum buffers for this parameter
307            let buffers = self.momentum_buffers.entry(param_id).or_insert_with(|| {
308                // Initialize all momentum buffers with zeros
309                (0..self.config.momentum_coefficients.len())
310                    .map(|_| Tensor::zeros(&effective_grad.shape()).unwrap())
311                    .collect()
312            });
313
314            // Update each momentum buffer
315            let mut aggregated_momentum = Tensor::zeros(&effective_grad.shape())?;
316            for (i, &beta) in self.config.momentum_coefficients.iter().enumerate() {
317                // Update momentum: m_i = β_i * m_i + (1 - β_i) * grad
318                buffers[i] =
319                    buffers[i].mul_scalar(beta)?.add(&effective_grad.mul_scalar(1.0 - beta)?)?;
320
321                // Add to aggregated momentum
322                aggregated_momentum = aggregated_momentum.add(&buffers[i])?;
323            }
324
325            // Average the momentum buffers
326            let num_buffers = self.config.momentum_coefficients.len() as f32;
327            let averaged_momentum = aggregated_momentum.div_scalar(num_buffers)?;
328
329            // Apply update
330            *parameter =
331                parameter.sub(&averaged_momentum.mul_scalar(self.config.learning_rate)?)?;
332        }
333
334        Ok(())
335    }
336
337    fn get_lr(&self) -> f32 {
338        self.config.learning_rate
339    }
340
341    fn set_lr(&mut self, lr: f32) {
342        self.config.learning_rate = lr;
343    }
344
345    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
346        let mut state = HashMap::new();
347
348        // Save configuration
349        state.insert(
350            "learning_rate".to_string(),
351            Tensor::scalar(self.config.learning_rate)?,
352        );
353        state.insert(
354            "weight_decay".to_string(),
355            Tensor::scalar(self.config.weight_decay)?,
356        );
357        state.insert(
358            "current_step".to_string(),
359            Tensor::scalar(self.current_step as f32)?,
360        );
361        state.insert(
362            "num_momentum_coeffs".to_string(),
363            Tensor::scalar(self.config.momentum_coefficients.len() as f32)?,
364        );
365
366        // Save momentum coefficients
367        for (i, &coeff) in self.config.momentum_coefficients.iter().enumerate() {
368            state.insert(format!("momentum_coeff_{}", i), Tensor::scalar(coeff)?);
369        }
370
371        // Save momentum buffers
372        for (&param_id, buffers) in &self.momentum_buffers {
373            for (buffer_idx, buffer) in buffers.iter().enumerate() {
374                state.insert(
375                    format!("momentum_buffer_{}_{}", param_id, buffer_idx),
376                    buffer.clone(),
377                );
378            }
379        }
380
381        Ok(state)
382    }
383
384    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
385        // Load configuration
386        if let Some(lr) = state.get("learning_rate") {
387            self.config.learning_rate = lr.to_scalar()?;
388        }
389        if let Some(wd) = state.get("weight_decay") {
390            self.config.weight_decay = wd.to_scalar()?;
391        }
392        if let Some(step) = state.get("current_step") {
393            self.current_step = step.to_scalar()? as usize;
394        }
395
396        // Load momentum coefficients
397        if let Some(num_coeffs_tensor) = state.get("num_momentum_coeffs") {
398            let num_coeffs = num_coeffs_tensor.to_scalar()? as usize;
399            let mut coefficients = Vec::with_capacity(num_coeffs);
400            for i in 0..num_coeffs {
401                if let Some(coeff_tensor) = state.get(&format!("momentum_coeff_{}", i)) {
402                    coefficients.push(coeff_tensor.to_scalar()?);
403                }
404            }
405            self.config.momentum_coefficients = coefficients;
406        }
407
408        // Load momentum buffers
409        self.momentum_buffers.clear();
410        let mut param_buffers: HashMap<usize, HashMap<usize, Tensor>> = HashMap::new();
411
412        for (key, tensor) in state {
413            if key.starts_with("momentum_buffer_") {
414                let parts: Vec<&str> = key.split('_').collect();
415                if parts.len() >= 4 {
416                    if let (Ok(param_id), Ok(buffer_idx)) =
417                        (parts[2].parse::<usize>(), parts[3].parse::<usize>())
418                    {
419                        param_buffers.entry(param_id).or_default().insert(buffer_idx, tensor);
420                    }
421                }
422            }
423        }
424
425        // Reconstruct momentum buffers in correct order
426        for (param_id, buffer_map) in param_buffers {
427            let mut buffers = Vec::new();
428            for i in 0..self.config.momentum_coefficients.len() {
429                if let Some(buffer) = buffer_map.get(&i) {
430                    buffers.push(buffer.clone());
431                }
432            }
433            if buffers.len() == self.config.momentum_coefficients.len() {
434                self.momentum_buffers.insert(param_id, buffers);
435            }
436        }
437
438        Ok(())
439    }
440}
441
/// Configuration for Variance Reduction methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VarianceReductionConfig {
    /// Learning rate applied to the variance-reduced gradient
    pub learning_rate: f32,
    /// Method type (SVRG or SAG)
    pub method: VarianceReductionMethod,
    /// Gradient history size for SVRG - maximum number of past gradients kept
    /// per parameter for averaging (also used by SAG)
    pub history_size: usize,
    /// Update frequency for full gradient computation, in steps (SVRG only;
    /// ignored by SAG)
    pub full_grad_frequency: usize,
    /// Weight decay (L2 penalty coefficient folded into the gradient; 0 disables it)
    pub weight_decay: f32,
}
456
457impl Default for VarianceReductionConfig {
458    fn default() -> Self {
459        Self {
460            learning_rate: 1e-3,
461            method: VarianceReductionMethod::SVRG,
462            history_size: 100,
463            full_grad_frequency: 10,
464            weight_decay: 0.0,
465        }
466    }
467}
468
/// Types of variance reduction methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VarianceReductionMethod {
    /// Stochastic Variance Reduced Gradient - corrects each stochastic gradient
    /// with a periodically refreshed full-gradient snapshot
    SVRG,
    /// Stochastic Average Gradient - steps along a running average of recent gradients
    SAG,
}
477
/// Variance Reduction optimizer implementing SVRG and SAG methods.
#[derive(Debug)]
pub struct VarianceReduction {
    /// Hyperparameters and method selection.
    config: VarianceReductionConfig,
    /// Sliding window of recent gradients per parameter (bounded by `history_size`).
    gradient_history: HashMap<usize, Vec<Tensor>>,
    /// Cache of the most recently computed average gradient per parameter.
    average_gradients: HashMap<usize, Tensor>,
    /// Full-gradient snapshots used by SVRG (approximated by the current
    /// stochastic gradient in this implementation - see `step`).
    full_gradients: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
    /// Step at which the full-gradient snapshot was last refreshed.
    last_full_grad_step: usize,
}
488
489impl VarianceReduction {
490    /// Create a new variance reduction optimizer.
491    pub fn new(config: VarianceReductionConfig) -> Self {
492        Self {
493            config,
494            gradient_history: HashMap::new(),
495            average_gradients: HashMap::new(),
496            full_gradients: HashMap::new(),
497            current_step: 0,
498            last_full_grad_step: 0,
499        }
500    }
501
502    /// Create SVRG optimizer with default settings.
503    pub fn svrg(learning_rate: f32, history_size: usize, full_grad_frequency: usize) -> Self {
504        Self::new(VarianceReductionConfig {
505            learning_rate,
506            method: VarianceReductionMethod::SVRG,
507            history_size,
508            full_grad_frequency,
509            weight_decay: 0.0,
510        })
511    }
512
513    /// Create SAG optimizer with default settings.
514    pub fn sag(learning_rate: f32, history_size: usize) -> Self {
515        Self::new(VarianceReductionConfig {
516            learning_rate,
517            method: VarianceReductionMethod::SAG,
518            history_size,
519            full_grad_frequency: 1, // Not used for SAG
520            weight_decay: 0.0,
521        })
522    }
523
524    fn update_gradient_history(&mut self, param_id: usize, gradient: &Tensor) -> Result<()> {
525        let history = self.gradient_history.entry(param_id).or_default();
526
527        history.push(gradient.clone());
528        if history.len() > self.config.history_size {
529            history.remove(0);
530        }
531
532        Ok(())
533    }
534
535    fn compute_average_gradient(&mut self, param_id: usize) -> Result<Tensor> {
536        if let Some(history) = self.gradient_history.get(&param_id) {
537            if history.is_empty() {
538                return Err(anyhow!("No gradient history available"));
539            }
540
541            let mut sum = history[0].clone();
542            for grad in history.iter().skip(1) {
543                sum = sum.add(grad)?;
544            }
545
546            let average = sum.div_scalar(history.len() as f32)?;
547            self.average_gradients.insert(param_id, average.clone());
548            Ok(average)
549        } else {
550            Err(anyhow!("No gradient history for parameter {}", param_id))
551        }
552    }
553
554    fn should_compute_full_gradient(&self) -> bool {
555        self.current_step - self.last_full_grad_step >= self.config.full_grad_frequency
556    }
557}
558
559impl OptimizerState for VarianceReduction {
560    fn zero_grad(&mut self) -> Result<()> {
561        Ok(())
562    }
563
564    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
565        self.current_step += 1;
566
567        // Check if we need to compute full gradient (for SVRG)
568        let compute_full_grad = match self.config.method {
569            VarianceReductionMethod::SVRG => self.should_compute_full_gradient(),
570            VarianceReductionMethod::SAG => false,
571        };
572
573        if compute_full_grad {
574            self.last_full_grad_step = self.current_step;
575            // In practice, full gradient computation would require access to the full dataset
576            // Here we'll use the current gradient as an approximation
577            for (param_id, parameter) in parameters.iter().enumerate() {
578                // Access gradient from parameter (should be computed during forward/backward pass)
579                let gradient = match parameter.grad() {
580                    Ok(grad) => grad,
581                    Err(_) => {
582                        // If gradient is not available, skip this parameter
583                        continue;
584                    },
585                };
586                self.full_gradients.insert(param_id, gradient);
587            }
588        }
589
590        for (param_id, parameter) in parameters.iter_mut().enumerate() {
591            // Access gradient from parameter (should be computed during forward/backward pass)
592            let current_gradient = match parameter.grad() {
593                Ok(grad) => grad,
594                Err(_) => {
595                    // If gradient is not available, skip this parameter
596                    continue;
597                },
598            };
599
600            // Apply weight decay
601            let effective_grad = if self.config.weight_decay > 0.0 {
602                current_gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
603            } else {
604                current_gradient
605            };
606
607            // Update gradient history
608            self.update_gradient_history(param_id, &effective_grad)?;
609
610            // Apply variance reduction
611            let variance_reduced_grad = match self.config.method {
612                VarianceReductionMethod::SVRG => {
613                    if self.full_gradients.contains_key(&param_id) {
614                        let avg_grad = self.compute_average_gradient(param_id)?;
615                        let full_grad = self.full_gradients.get(&param_id).unwrap();
616                        // SVRG update: grad - avg_grad + full_grad
617                        effective_grad.sub(&avg_grad)?.add(full_grad)?
618                    } else {
619                        effective_grad
620                    }
621                },
622                VarianceReductionMethod::SAG => {
623                    // SAG uses running average of gradients
624                    self.compute_average_gradient(param_id)?
625                },
626            };
627
628            // Apply update
629            *parameter =
630                parameter.sub(&variance_reduced_grad.mul_scalar(self.config.learning_rate)?)?;
631        }
632
633        Ok(())
634    }
635
636    fn get_lr(&self) -> f32 {
637        self.config.learning_rate
638    }
639
640    fn set_lr(&mut self, lr: f32) {
641        self.config.learning_rate = lr;
642    }
643
644    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
645        let mut state = HashMap::new();
646
647        state.insert(
648            "learning_rate".to_string(),
649            Tensor::scalar(self.config.learning_rate)?,
650        );
651        state.insert(
652            "current_step".to_string(),
653            Tensor::scalar(self.current_step as f32)?,
654        );
655        state.insert(
656            "last_full_grad_step".to_string(),
657            Tensor::scalar(self.last_full_grad_step as f32)?,
658        );
659
660        // Note: Saving full gradient history would be expensive
661        // In practice, you might want to save only recent gradients or statistics
662
663        Ok(state)
664    }
665
666    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
667        if let Some(lr) = state.get("learning_rate") {
668            self.config.learning_rate = lr.to_scalar()?;
669        }
670        if let Some(step) = state.get("current_step") {
671            self.current_step = step.to_scalar()? as usize;
672        }
673        if let Some(last_step) = state.get("last_full_grad_step") {
674            self.last_full_grad_step = last_step.to_scalar()? as usize;
675        }
676
677        Ok(())
678    }
679}
680
/// Configuration for Nesterov Accelerated Gradient (NAG).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NesterovAcceleratedGradientConfig {
    /// Learning rate applied to the gradient before it enters the velocity
    pub learning_rate: f32,
    /// Momentum parameter (decay rate of the velocity buffer)
    pub momentum: f32,
    /// Weight decay (L2 penalty coefficient folded into the gradient; 0 disables it)
    pub weight_decay: f32,
    /// Whether to use strong convexity assumption for restart - when enabled,
    /// velocity buffers are cleared whenever the loss reported via
    /// `set_current_loss` increases
    pub restart_on_increase: bool,
}
693
694impl Default for NesterovAcceleratedGradientConfig {
695    fn default() -> Self {
696        Self {
697            learning_rate: 1e-3,
698            momentum: 0.9,
699            weight_decay: 0.0,
700            restart_on_increase: false,
701        }
702    }
703}
704
/// Nesterov Accelerated Gradient optimizer.
///
/// NAG uses lookahead to evaluate the gradient at the predicted next position,
/// which can lead to faster convergence than standard momentum methods.
/// Update rule:
/// v_t = momentum * v_{t-1} + lr * grad(x_t + momentum * v_{t-1})
/// x_{t+1} = x_t - v_t
///
/// NOTE(review): this implementation approximates the lookahead gradient with
/// the current gradient (see `step`), so it behaves like classical momentum.
#[derive(Debug)]
pub struct NesterovAcceleratedGradient {
    /// Hyperparameters controlling the update rule.
    config: NesterovAcceleratedGradientConfig,
    /// Per-parameter velocity buffers, keyed by parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
    /// Last loss reported via `set_current_loss`, used for restart detection.
    previous_loss: Option<f32>,
}
719
720impl NesterovAcceleratedGradient {
721    /// Create a new NAG optimizer.
722    pub fn new(config: NesterovAcceleratedGradientConfig) -> Self {
723        Self {
724            config,
725            velocity_buffers: HashMap::new(),
726            current_step: 0,
727            previous_loss: None,
728        }
729    }
730
731    /// Create NAG with default configuration.
732    pub fn with_defaults(learning_rate: f32, momentum: f32) -> Self {
733        Self::new(NesterovAcceleratedGradientConfig {
734            learning_rate,
735            momentum,
736            weight_decay: 0.0,
737            restart_on_increase: false,
738        })
739    }
740
741    /// Get the configuration.
742    pub fn get_config(&self) -> &NesterovAcceleratedGradientConfig {
743        &self.config
744    }
745
746    /// Set a loss value for restart detection.
747    pub fn set_current_loss(&mut self, loss: f32) {
748        if self.config.restart_on_increase {
749            if let Some(prev_loss) = self.previous_loss {
750                if loss > prev_loss {
751                    // Restart by clearing velocity buffers
752                    self.velocity_buffers.clear();
753                }
754            }
755        }
756        self.previous_loss = Some(loss);
757    }
758}
759
760impl OptimizerState for NesterovAcceleratedGradient {
761    fn zero_grad(&mut self) -> Result<()> {
762        Ok(())
763    }
764
765    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
766        self.current_step += 1;
767
768        for (param_id, parameter) in parameters.iter_mut().enumerate() {
769            // Access gradient from parameter (should be computed during forward/backward pass)
770            let gradient = match parameter.grad() {
771                Ok(grad) => grad,
772                Err(_) => {
773                    // If gradient is not available, skip this parameter
774                    continue;
775                },
776            };
777
778            // Apply weight decay to gradient
779            let effective_grad = if self.config.weight_decay > 0.0 {
780                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
781            } else {
782                gradient
783            };
784
785            // Get or initialize velocity buffer
786            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
787                v.clone()
788            } else {
789                Tensor::zeros_like(parameter)?
790            };
791
792            // Nesterov acceleration: compute gradient at lookahead position
793            let _lookahead_position = parameter.sub(&velocity.mul_scalar(self.config.momentum)?)?;
794
795            // In practice, we'd need to recompute the gradient at the lookahead position
796            // For now, we'll use the current gradient as approximation
797            // Lookahead gradient computation using current gradient state
798
799            // Update velocity: v_t = momentum * v_{t-1} + lr * grad
800            let new_velocity = velocity
801                .mul_scalar(self.config.momentum)?
802                .add(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
803
804            self.velocity_buffers.insert(param_id, new_velocity.clone());
805
806            // Update parameters: x_{t+1} = x_t - v_t
807            *parameter = parameter.sub(&new_velocity)?;
808        }
809
810        Ok(())
811    }
812
813    fn get_lr(&self) -> f32 {
814        self.config.learning_rate
815    }
816
817    fn set_lr(&mut self, lr: f32) {
818        self.config.learning_rate = lr;
819    }
820
821    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
822        let mut state = HashMap::new();
823
824        state.insert(
825            "learning_rate".to_string(),
826            Tensor::scalar(self.config.learning_rate)?,
827        );
828        state.insert(
829            "momentum".to_string(),
830            Tensor::scalar(self.config.momentum)?,
831        );
832        state.insert(
833            "weight_decay".to_string(),
834            Tensor::scalar(self.config.weight_decay)?,
835        );
836        state.insert(
837            "current_step".to_string(),
838            Tensor::scalar(self.current_step as f32)?,
839        );
840
841        if let Some(loss) = self.previous_loss {
842            state.insert("previous_loss".to_string(), Tensor::scalar(loss)?);
843        }
844
845        for (&param_id, velocity) in &self.velocity_buffers {
846            state.insert(format!("velocity_{}", param_id), velocity.clone());
847        }
848
849        Ok(state)
850    }
851
852    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
853        if let Some(lr) = state.get("learning_rate") {
854            self.config.learning_rate = lr.to_scalar()?;
855        }
856        if let Some(momentum) = state.get("momentum") {
857            self.config.momentum = momentum.to_scalar()?;
858        }
859        if let Some(wd) = state.get("weight_decay") {
860            self.config.weight_decay = wd.to_scalar()?;
861        }
862        if let Some(step) = state.get("current_step") {
863            self.current_step = step.to_scalar()? as usize;
864        }
865        if let Some(loss) = state.get("previous_loss") {
866            self.previous_loss = Some(loss.to_scalar()?);
867        }
868
869        self.velocity_buffers.clear();
870        for (key, tensor) in state {
871            if let Some(param_id_str) = key.strip_prefix("velocity_") {
872                if let Ok(param_id) = param_id_str.parse::<usize>() {
873                    self.velocity_buffers.insert(param_id, tensor);
874                }
875            }
876        }
877
878        Ok(())
879    }
880}
881
/// Configuration for Heavy Ball Method (Polyak momentum).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeavyBallConfig {
    /// Step size applied to the (possibly weight-decayed) gradient.
    pub learning_rate: f32,
    /// Momentum coefficient (β) multiplying the previous velocity; values
    /// near 1 retain more inertia.
    pub beta: f32,
    /// L2 weight-decay coefficient, folded into the gradient as
    /// `grad + weight_decay * param` (0.0 disables it).
    pub weight_decay: f32,
    /// When true, β is rescaled each step by the cosine similarity between
    /// the current and previous gradients (clamped at zero), so opposing
    /// gradients suppress momentum.
    pub adaptive_momentum: bool,
}
894
895impl Default for HeavyBallConfig {
896    fn default() -> Self {
897        Self {
898            learning_rate: 1e-3,
899            beta: 0.9,
900            weight_decay: 0.0,
901            adaptive_momentum: false,
902        }
903    }
904}
905
/// Heavy Ball Method optimizer.
///
/// Classical momentum-based acceleration method that adds inertia to gradient descent.
/// Update rule:
/// v_t = β * v_{t-1} - lr * grad(x_t)
/// x_{t+1} = x_t + v_t
#[derive(Debug)]
pub struct HeavyBall {
    /// Hyperparameters controlling the update.
    config: HeavyBallConfig,
    /// Per-parameter velocity v_t, keyed by the parameter's index in the
    /// slice passed to `step`.
    velocity_buffers: HashMap<usize, Tensor>,
    /// Last effective gradient per parameter; only populated when
    /// `adaptive_momentum` is enabled.
    previous_gradients: HashMap<usize, Tensor>,
    /// Number of completed `step` calls.
    current_step: usize,
}
919
920impl HeavyBall {
921    /// Create a new Heavy Ball optimizer.
922    pub fn new(config: HeavyBallConfig) -> Self {
923        Self {
924            config,
925            velocity_buffers: HashMap::new(),
926            previous_gradients: HashMap::new(),
927            current_step: 0,
928        }
929    }
930
931    /// Create Heavy Ball with default configuration.
932    pub fn with_defaults(learning_rate: f32, beta: f32) -> Self {
933        Self::new(HeavyBallConfig {
934            learning_rate,
935            beta,
936            weight_decay: 0.0,
937            adaptive_momentum: false,
938        })
939    }
940
941    /// Get the configuration.
942    pub fn get_config(&self) -> &HeavyBallConfig {
943        &self.config
944    }
945
946    /// Compute adaptive momentum based on gradient alignment.
947    fn compute_adaptive_momentum(&self, param_id: usize, current_grad: &Tensor) -> Result<f32> {
948        if let Some(prev_grad) = self.previous_gradients.get(&param_id) {
949            // Compute cosine similarity between current and previous gradients
950            let dot_product = current_grad.mul(prev_grad)?.sum(None, false)?;
951            let norm_current = current_grad.norm_squared()?.sqrt()?;
952            let norm_prev = prev_grad.norm_squared()?.sqrt()?;
953
954            let dot_scalar = dot_product.to_scalar()?;
955            let norm_current_scalar = norm_current.to_scalar()?;
956            let norm_prev_scalar = norm_prev.to_scalar()?;
957
958            let denominator = norm_current_scalar * norm_prev_scalar;
959            if denominator > 1e-8 {
960                let cosine_similarity = dot_scalar / denominator;
961                // Increase momentum when gradients are aligned, decrease when opposed
962                let adaptive_beta = self.config.beta * cosine_similarity.max(0.0);
963                Ok(adaptive_beta)
964            } else {
965                Ok(self.config.beta)
966            }
967        } else {
968            Ok(self.config.beta)
969        }
970    }
971}
972
impl OptimizerState for HeavyBall {
    /// Heavy Ball keeps no gradient storage of its own, so there is nothing
    /// to clear; gradients live on the parameter tensors themselves.
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    /// Apply one Heavy Ball update to every parameter that has a gradient:
    /// v_t = β * v_{t-1} - lr * g_t, then x_{t+1} = x_t + v_t.
    /// Parameters without a gradient are silently skipped.
    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            // Access gradient from parameter (should be computed during forward/backward pass)
            let gradient = match parameter.grad() {
                Ok(grad) => grad,
                Err(_) => {
                    // If gradient is not available, skip this parameter
                    continue;
                },
            };

            // Apply weight decay to gradient: g <- g + wd * x
            let effective_grad = if self.config.weight_decay > 0.0 {
                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
            } else {
                gradient
            };

            // Compute momentum coefficient, optionally rescaled by alignment
            // with the previous step's gradient.
            let beta = if self.config.adaptive_momentum {
                self.compute_adaptive_momentum(param_id, &effective_grad)?
            } else {
                self.config.beta
            };

            // Get or initialize velocity buffer (zeros on first encounter).
            let velocity = if let Some(v) = self.velocity_buffers.get(&param_id) {
                v.clone()
            } else {
                Tensor::zeros_like(parameter)?
            };

            // Heavy Ball update: v_t = β * v_{t-1} - lr * grad
            let new_velocity = velocity
                .mul_scalar(beta)?
                .sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;

            self.velocity_buffers.insert(param_id, new_velocity.clone());

            // Update parameters: x_{t+1} = x_t + v_t
            *parameter = parameter.add(&new_velocity)?;

            // Store the (weight-decayed) gradient for the next step's
            // adaptive-momentum alignment check.
            if self.config.adaptive_momentum {
                self.previous_gradients.insert(param_id, effective_grad);
            }
        }

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.config.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.config.learning_rate = lr;
    }

    /// Serialize hyperparameters, the step counter, and per-parameter
    /// buffers. Velocities are stored as "velocity_<id>" and previous
    /// gradients (present only with adaptive momentum) as "prev_grad_<id>".
    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.config.learning_rate)?,
        );
        state.insert("beta".to_string(), Tensor::scalar(self.config.beta)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.config.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );

        for (&param_id, velocity) in &self.velocity_buffers {
            state.insert(format!("velocity_{}", param_id), velocity.clone());
        }

        for (&param_id, grad) in &self.previous_gradients {
            state.insert(format!("prev_grad_{}", param_id), grad.clone());
        }

        Ok(state)
    }

    /// Restore state written by `state_dict`. Missing scalar keys leave the
    /// current configuration untouched; both buffer maps are rebuilt from
    /// scratch, so entries absent from `state` are dropped.
    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.config.learning_rate = lr.to_scalar()?;
        }
        if let Some(beta) = state.get("beta") {
            self.config.beta = beta.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.config.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }

        self.velocity_buffers.clear();
        self.previous_gradients.clear();

        // Keys that fail to parse are ignored rather than treated as errors.
        for (key, tensor) in state {
            if let Some(param_id_str) = key.strip_prefix("velocity_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.velocity_buffers.insert(param_id, tensor);
                }
            } else if let Some(param_id_str) = key.strip_prefix("prev_grad_") {
                if let Ok(param_id) = param_id_str.parse::<usize>() {
                    self.previous_gradients.insert(param_id, tensor);
                }
            }
        }

        Ok(())
    }
}
1099
/// Configuration for FISTA (Fast Iterative Shrinkage-Thresholding Algorithm).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FISTAConfig {
    /// Step size for the gradient step taken before the proximal operator.
    pub learning_rate: f32,
    /// Soft-thresholding level (λ) used by the L1 proximal operator.
    pub threshold: f32,
    /// Whether to use adaptive restart.
    /// NOTE(review): this flag is not consulted anywhere in this module's
    /// `FISTA` implementation — confirm before relying on it.
    pub adaptive_restart: bool,
    /// L2 weight-decay coefficient folded into the gradient (0.0 disables it).
    pub weight_decay: f32,
}
1112
1113impl Default for FISTAConfig {
1114    fn default() -> Self {
1115        Self {
1116            learning_rate: 1e-3,
1117            threshold: 1e-4,
1118            adaptive_restart: true,
1119            weight_decay: 0.0,
1120        }
1121    }
1122}
1123
/// FISTA optimizer for problems with L1 regularization or other proximal operators.
///
/// FISTA is designed for problems of the form: min f(x) + λ||x||_1
/// where f(x) is smooth and convex, and λ||x||_1 is the L1 regularization term.
#[derive(Debug)]
pub struct FISTA {
    /// Hyperparameters.
    config: FISTAConfig,
    /// Previous iterate x_{k-1} per parameter, used for momentum extrapolation.
    previous_params: HashMap<usize, Tensor>,
    /// Number of completed `step` calls.
    current_step: usize,
    /// Current momentum coefficient t_k of the FISTA sequence.
    momentum_coefficient: f32,
    /// Momentum coefficient t_{k-1} from the previous step.
    previous_momentum: f32,
}
1136
1137impl FISTA {
1138    /// Create a new FISTA optimizer.
1139    pub fn new(config: FISTAConfig) -> Self {
1140        Self {
1141            config,
1142            previous_params: HashMap::new(),
1143            current_step: 0,
1144            momentum_coefficient: 1.0,
1145            previous_momentum: 1.0,
1146        }
1147    }
1148
1149    /// Create FISTA with default configuration.
1150    pub fn with_defaults(learning_rate: f32, threshold: f32) -> Self {
1151        Self::new(FISTAConfig {
1152            learning_rate,
1153            threshold,
1154            adaptive_restart: true,
1155            weight_decay: 0.0,
1156        })
1157    }
1158
1159    /// Get the configuration.
1160    pub fn get_config(&self) -> &FISTAConfig {
1161        &self.config
1162    }
1163
1164    /// Apply soft thresholding (proximal operator for L1 regularization).
1165    fn soft_threshold(&self, tensor: &Tensor, threshold: f32) -> Result<Tensor> {
1166        let threshold_tensor = Tensor::scalar(threshold)?;
1167        let zero_tensor = Tensor::zeros_like(tensor)?;
1168
1169        // Soft thresholding: sign(x) * max(0, |x| - threshold)
1170        let abs_tensor = tensor.abs()?;
1171        let thresholded = abs_tensor.sub(&threshold_tensor)?.max(&zero_tensor)?;
1172        let sign_tensor = tensor.sign()?;
1173
1174        Ok(sign_tensor.mul(&thresholded)?)
1175    }
1176
1177    /// Update momentum coefficient using FISTA formula.
1178    fn update_momentum_coefficient(&mut self) {
1179        let t = self.current_step as f32;
1180        self.previous_momentum = self.momentum_coefficient;
1181        self.momentum_coefficient = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
1182    }
1183}
1184
1185impl OptimizerState for FISTA {
1186    fn zero_grad(&mut self) -> Result<()> {
1187        Ok(())
1188    }
1189
1190    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1191        self.current_step += 1;
1192        self.update_momentum_coefficient();
1193
1194        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1195            // Access gradient from parameter (should be computed during forward/backward pass)
1196            let gradient = match parameter.grad() {
1197                Ok(grad) => grad,
1198                Err(_) => {
1199                    // If gradient is not available, skip this parameter
1200                    continue;
1201                },
1202            };
1203
1204            // Apply weight decay to gradient
1205            let effective_grad = if self.config.weight_decay > 0.0 {
1206                gradient.add(&parameter.mul_scalar(self.config.weight_decay)?)?
1207            } else {
1208                gradient
1209            };
1210
1211            // Get previous parameter value
1212            let previous_param = if let Some(prev) = self.previous_params.get(&param_id) {
1213                prev.clone()
1214            } else {
1215                parameter.clone()
1216            };
1217
1218            // Momentum coefficient ratio
1219            let beta = (self.previous_momentum - 1.0) / self.momentum_coefficient;
1220
1221            // Compute extrapolated point
1222            let extrapolated = parameter.add(&previous_param.sub(parameter)?.mul_scalar(beta)?)?;
1223
1224            // Gradient step
1225            let grad_step =
1226                extrapolated.sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1227
1228            // Apply proximal operator (soft thresholding)
1229            let new_parameter = self.soft_threshold(&grad_step, self.config.threshold)?;
1230
1231            // Store current parameter for next iteration
1232            self.previous_params.insert(param_id, parameter.clone());
1233
1234            // Update parameter
1235            *parameter = new_parameter;
1236        }
1237
1238        Ok(())
1239    }
1240
1241    fn get_lr(&self) -> f32 {
1242        self.config.learning_rate
1243    }
1244
1245    fn set_lr(&mut self, lr: f32) {
1246        self.config.learning_rate = lr;
1247    }
1248
1249    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1250        let mut state = HashMap::new();
1251
1252        state.insert(
1253            "learning_rate".to_string(),
1254            Tensor::scalar(self.config.learning_rate)?,
1255        );
1256        state.insert(
1257            "threshold".to_string(),
1258            Tensor::scalar(self.config.threshold)?,
1259        );
1260        state.insert(
1261            "weight_decay".to_string(),
1262            Tensor::scalar(self.config.weight_decay)?,
1263        );
1264        state.insert(
1265            "current_step".to_string(),
1266            Tensor::scalar(self.current_step as f32)?,
1267        );
1268        state.insert(
1269            "momentum_coefficient".to_string(),
1270            Tensor::scalar(self.momentum_coefficient)?,
1271        );
1272        state.insert(
1273            "previous_momentum".to_string(),
1274            Tensor::scalar(self.previous_momentum)?,
1275        );
1276
1277        for (&param_id, param) in &self.previous_params {
1278            state.insert(format!("prev_param_{}", param_id), param.clone());
1279        }
1280
1281        Ok(state)
1282    }
1283
1284    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1285        if let Some(lr) = state.get("learning_rate") {
1286            self.config.learning_rate = lr.to_scalar()?;
1287        }
1288        if let Some(threshold) = state.get("threshold") {
1289            self.config.threshold = threshold.to_scalar()?;
1290        }
1291        if let Some(wd) = state.get("weight_decay") {
1292            self.config.weight_decay = wd.to_scalar()?;
1293        }
1294        if let Some(step) = state.get("current_step") {
1295            self.current_step = step.to_scalar()? as usize;
1296        }
1297        if let Some(momentum) = state.get("momentum_coefficient") {
1298            self.momentum_coefficient = momentum.to_scalar()?;
1299        }
1300        if let Some(prev_momentum) = state.get("previous_momentum") {
1301            self.previous_momentum = prev_momentum.to_scalar()?;
1302        }
1303
1304        self.previous_params.clear();
1305        for (key, tensor) in state {
1306            if let Some(param_id_str) = key.strip_prefix("prev_param_") {
1307                if let Ok(param_id) = param_id_str.parse::<usize>() {
1308                    self.previous_params.insert(param_id, tensor);
1309                }
1310            }
1311        }
1312
1313        Ok(())
1314    }
1315}
1316
/// Configuration for Adaptive Batch Sizing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchSizingConfig {
    /// Batch size used before any adjustment (also restored by `reset`).
    pub initial_batch_size: usize,
    /// Lower bound enforced when the batch size is decreased.
    pub min_batch_size: usize,
    /// Upper bound enforced when the batch size is increased.
    pub max_batch_size: usize,
    /// Tolerance for gradient variance.
    /// NOTE(review): not read by the visible `AdaptiveBatchSizing` logic —
    /// confirm whether it is used elsewhere.
    pub gradient_variance_tolerance: f32,
    /// Multiplier applied to the sqrt-scaled LR suggestion in
    /// `get_lr_adjustment`.
    pub lr_adaptation_factor: f32,
    /// Sliding-window length for the variance and loss histories.
    pub variance_window_size: usize,
    /// Variance below this (with decreasing loss) triggers a smaller batch.
    pub increase_threshold: f32,
    /// Variance above this (and rising) triggers a larger batch.
    pub decrease_threshold: f32,
}
1337
1338impl Default for AdaptiveBatchSizingConfig {
1339    fn default() -> Self {
1340        Self {
1341            initial_batch_size: 32,
1342            min_batch_size: 8,
1343            max_batch_size: 512,
1344            gradient_variance_tolerance: 0.1,
1345            lr_adaptation_factor: 0.8,
1346            variance_window_size: 10,
1347            increase_threshold: 0.05,
1348            decrease_threshold: 0.2,
1349        }
1350    }
1351}
1352
/// Adaptive Batch Sizing utility for dynamically adjusting batch size based on training progress.
///
/// This strategy monitors gradient variance and training stability to determine optimal batch sizes.
/// When gradient variance is high, it increases batch size to reduce noise.
/// When variance is low, it may decrease batch size to improve convergence speed.
#[derive(Debug)]
pub struct AdaptiveBatchSizing {
    /// Tuning knobs and size bounds.
    config: AdaptiveBatchSizingConfig,
    /// Batch size currently in effect; returned by `update`.
    current_batch_size: usize,
    /// Recent gradient-variance samples (bounded by `variance_window_size`).
    gradient_variance_history: Vec<f32>,
    /// Recent loss samples (bounded by `variance_window_size`).
    loss_history: Vec<f32>,
    /// Number of `update` calls so far.
    current_step: usize,
    /// Step at which the batch size was last changed (rate-limits changes).
    last_adjustment_step: usize,
}
1367
1368impl AdaptiveBatchSizing {
1369    /// Create a new adaptive batch sizing utility.
1370    pub fn new(config: AdaptiveBatchSizingConfig) -> Self {
1371        let initial_batch_size = config.initial_batch_size;
1372        Self {
1373            config,
1374            current_batch_size: initial_batch_size,
1375            gradient_variance_history: Vec::new(),
1376            loss_history: Vec::new(),
1377            current_step: 0,
1378            last_adjustment_step: 0,
1379        }
1380    }
1381
1382    /// Create with default configuration.
1383    pub fn with_defaults(
1384        initial_batch_size: usize,
1385        min_batch_size: usize,
1386        max_batch_size: usize,
1387    ) -> Self {
1388        Self::new(AdaptiveBatchSizingConfig {
1389            initial_batch_size,
1390            min_batch_size,
1391            max_batch_size,
1392            ..Default::default()
1393        })
1394    }
1395
1396    /// Get current batch size.
1397    pub fn current_batch_size(&self) -> usize {
1398        self.current_batch_size
1399    }
1400
1401    /// Get the configuration.
1402    pub fn get_config(&self) -> &AdaptiveBatchSizingConfig {
1403        &self.config
1404    }
1405
1406    /// Update with current gradient variance and loss.
1407    pub fn update(&mut self, gradient_variance: f32, current_loss: f32) -> Result<usize> {
1408        self.current_step += 1;
1409
1410        // Add to history
1411        self.gradient_variance_history.push(gradient_variance);
1412        self.loss_history.push(current_loss);
1413
1414        // Keep only recent history
1415        if self.gradient_variance_history.len() > self.config.variance_window_size {
1416            self.gradient_variance_history.remove(0);
1417        }
1418        if self.loss_history.len() > self.config.variance_window_size {
1419            self.loss_history.remove(0);
1420        }
1421
1422        // Check if we should adjust batch size
1423        if self.should_adjust_batch_size() {
1424            self.adjust_batch_size()?;
1425            self.last_adjustment_step = self.current_step;
1426        }
1427
1428        Ok(self.current_batch_size)
1429    }
1430
1431    /// Compute gradient variance from gradients.
1432    pub fn compute_gradient_variance(&self, gradients: &[Tensor]) -> Result<f32> {
1433        if gradients.is_empty() {
1434            return Ok(0.0);
1435        }
1436
1437        // Compute mean gradient
1438        let mut mean_grad = gradients[0].clone();
1439        for grad in gradients.iter().skip(1) {
1440            mean_grad = mean_grad.add(grad)?;
1441        }
1442        mean_grad = mean_grad.div_scalar(gradients.len() as f32)?;
1443
1444        // Compute variance
1445        let mut variance_sum = 0.0;
1446        for grad in gradients {
1447            let diff = grad.sub(&mean_grad)?;
1448            let squared_norm = diff.mul(&diff)?.sum(None, false)?;
1449            variance_sum += squared_norm.to_scalar()?;
1450        }
1451
1452        Ok(variance_sum / gradients.len() as f32)
1453    }
1454
1455    fn should_adjust_batch_size(&self) -> bool {
1456        // Don't adjust too frequently
1457        if self.current_step - self.last_adjustment_step < 5 {
1458            return false;
1459        }
1460
1461        // Need enough history
1462        self.gradient_variance_history.len() >= 3
1463    }
1464
1465    fn adjust_batch_size(&mut self) -> Result<()> {
1466        let recent_variance = self.recent_average_variance();
1467        let variance_trend = self.variance_trend();
1468        let loss_trend = self.loss_trend();
1469
1470        // Decide whether to increase or decrease batch size
1471        if recent_variance > self.config.decrease_threshold && variance_trend > 0.0 {
1472            // High variance and increasing - increase batch size
1473            self.increase_batch_size();
1474        } else if recent_variance < self.config.increase_threshold && loss_trend < -0.01 {
1475            // Low variance and decreasing loss - try smaller batch size
1476            self.decrease_batch_size();
1477        }
1478
1479        Ok(())
1480    }
1481
1482    fn recent_average_variance(&self) -> f32 {
1483        if self.gradient_variance_history.is_empty() {
1484            return 0.0;
1485        }
1486
1487        let recent_window = std::cmp::min(5, self.gradient_variance_history.len());
1488        let start_idx = self.gradient_variance_history.len() - recent_window;
1489
1490        self.gradient_variance_history[start_idx..].iter().sum::<f32>() / recent_window as f32
1491    }
1492
1493    fn variance_trend(&self) -> f32 {
1494        if self.gradient_variance_history.len() < 3 {
1495            return 0.0;
1496        }
1497
1498        let len = self.gradient_variance_history.len();
1499        let recent = self.gradient_variance_history[len - 2..].iter().sum::<f32>() / 2.0;
1500        let older = self.gradient_variance_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1501
1502        recent - older
1503    }
1504
1505    fn loss_trend(&self) -> f32 {
1506        if self.loss_history.len() < 3 {
1507            return 0.0;
1508        }
1509
1510        let len = self.loss_history.len();
1511        let recent = self.loss_history[len - 2..].iter().sum::<f32>() / 2.0;
1512        let older = self.loss_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1513
1514        (recent - older) / older.max(1e-8)
1515    }
1516
1517    fn increase_batch_size(&mut self) {
1518        let new_size = (self.current_batch_size as f32 * 1.5) as usize;
1519        self.current_batch_size = new_size.min(self.config.max_batch_size);
1520    }
1521
1522    fn decrease_batch_size(&mut self) {
1523        let new_size = (self.current_batch_size as f32 * 0.8) as usize;
1524        self.current_batch_size = new_size.max(self.config.min_batch_size);
1525    }
1526
1527    /// Get suggested learning rate adjustment based on batch size changes.
1528    pub fn get_lr_adjustment(&self, original_batch_size: usize) -> f32 {
1529        let ratio = self.current_batch_size as f32 / original_batch_size as f32;
1530        ratio.sqrt() * self.config.lr_adaptation_factor
1531    }
1532
1533    /// Reset state for new training run.
1534    pub fn reset(&mut self) {
1535        self.current_batch_size = self.config.initial_batch_size;
1536        self.gradient_variance_history.clear();
1537        self.loss_history.clear();
1538        self.current_step = 0;
1539        self.last_adjustment_step = 0;
1540    }
1541}
1542
/// Configuration for Loss Surface Smoothing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSurfaceSmoothingConfig {
    /// Interpolation weight used by `smooth_parameters` to blend each
    /// parameter toward its EMA-smoothed copy (0 = no smoothing).
    pub smoothing_strength: f32,
    /// Variance of the Gaussian noise injected into gradients; its square
    /// root scales the sampled noise.
    pub noise_variance: f32,
    /// Decay for the exponential moving averages of gradients and parameters.
    pub ema_decay: f32,
    /// Number of recent gradients kept and averaged per parameter.
    pub averaging_window: usize,
    /// Whether to replace each gradient with its windowed average.
    pub use_gradient_averaging: bool,
    /// Whether to add exploration noise to the smoothed gradient.
    pub use_noise_injection: bool,
}
1559
1560impl Default for LossSurfaceSmoothingConfig {
1561    fn default() -> Self {
1562        Self {
1563            smoothing_strength: 0.1,
1564            noise_variance: 1e-4,
1565            ema_decay: 0.9,
1566            averaging_window: 5,
1567            use_gradient_averaging: true,
1568            use_noise_injection: false,
1569        }
1570    }
1571}
1572
/// Loss Surface Smoothing utility for reducing noise in the loss landscape.
///
/// This implements several techniques to smooth the loss surface:
/// - Gradient averaging over multiple steps
/// - Exponential moving average of gradients
/// - Controlled noise injection for exploration
/// - Parameter smoothing to reduce sharp changes
#[derive(Debug)]
pub struct LossSurfaceSmoothing {
    /// Smoothing knobs and feature toggles.
    config: LossSurfaceSmoothingConfig,
    /// Recent raw gradients per parameter (bounded by `averaging_window`).
    gradient_history: HashMap<usize, Vec<Tensor>>,
    /// EMA of (averaged) gradients per parameter.
    ema_gradients: HashMap<usize, Tensor>,
    /// EMA of parameter values, used by `smooth_parameters`.
    smoothed_parameters: HashMap<usize, Tensor>,
    /// Number of `smooth_gradients` calls so far.
    current_step: usize,
}
1588
1589impl LossSurfaceSmoothing {
1590    /// Create a new loss surface smoothing utility.
1591    pub fn new(config: LossSurfaceSmoothingConfig) -> Self {
1592        Self {
1593            config,
1594            gradient_history: HashMap::new(),
1595            ema_gradients: HashMap::new(),
1596            smoothed_parameters: HashMap::new(),
1597            current_step: 0,
1598        }
1599    }
1600
1601    /// Create with default configuration.
1602    pub fn with_defaults(smoothing_strength: f32, use_noise: bool) -> Self {
1603        Self::new(LossSurfaceSmoothingConfig {
1604            smoothing_strength,
1605            use_noise_injection: use_noise,
1606            ..Default::default()
1607        })
1608    }
1609
1610    /// Get the configuration.
1611    pub fn get_config(&self) -> &LossSurfaceSmoothingConfig {
1612        &self.config
1613    }
1614
1615    /// Apply smoothing to gradients.
1616    pub fn smooth_gradients(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1617        self.current_step += 1;
1618
1619        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1620            let original_grad = parameter.grad()?;
1621            let mut smoothed_grad = original_grad.clone();
1622
1623            // Apply gradient averaging
1624            if self.config.use_gradient_averaging {
1625                smoothed_grad = self.apply_gradient_averaging(param_id, &original_grad)?;
1626            }
1627
1628            // Apply exponential moving average
1629            smoothed_grad = self.apply_ema_smoothing(param_id, &smoothed_grad)?;
1630
1631            // Apply noise injection for exploration
1632            if self.config.use_noise_injection {
1633                smoothed_grad = self.apply_noise_injection(&smoothed_grad)?;
1634            }
1635
1636            // Update parameter gradient
1637            parameter.set_grad(smoothed_grad)?;
1638        }
1639
1640        Ok(())
1641    }
1642
1643    /// Apply parameter smoothing.
1644    pub fn smooth_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1645        for (param_id, parameter) in parameters.iter_mut().enumerate() {
1646            if let Some(smoothed_param) = self.smoothed_parameters.get(&param_id) {
1647                // Apply exponential moving average to parameters
1648                let new_smoothed = smoothed_param
1649                    .mul_scalar(self.config.ema_decay)?
1650                    .add(&parameter.mul_scalar(1.0 - self.config.ema_decay)?)?;
1651
1652                // Interpolate between original and smoothed parameters
1653                *parameter = parameter
1654                    .mul_scalar(1.0 - self.config.smoothing_strength)?
1655                    .add(&new_smoothed.mul_scalar(self.config.smoothing_strength)?)?;
1656
1657                self.smoothed_parameters.insert(param_id, new_smoothed);
1658            } else {
1659                // Initialize smoothed parameter
1660                self.smoothed_parameters.insert(param_id, parameter.clone());
1661            }
1662        }
1663
1664        Ok(())
1665    }
1666
1667    fn apply_gradient_averaging(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1668        let history = self.gradient_history.entry(param_id).or_default();
1669
1670        history.push(gradient.clone());
1671        if history.len() > self.config.averaging_window {
1672            history.remove(0);
1673        }
1674
1675        // Compute average of recent gradients
1676        if history.len() == 1 {
1677            Ok(gradient.clone())
1678        } else {
1679            let mut sum = history[0].clone();
1680            for grad in history.iter().skip(1) {
1681                sum = sum.add(grad)?;
1682            }
1683            Ok(sum.div_scalar(history.len() as f32)?)
1684        }
1685    }
1686
1687    fn apply_ema_smoothing(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1688        if let Some(ema_grad) = self.ema_gradients.get(&param_id) {
1689            let new_ema = ema_grad
1690                .mul_scalar(self.config.ema_decay)?
1691                .add(&gradient.mul_scalar(1.0 - self.config.ema_decay)?)?;
1692            self.ema_gradients.insert(param_id, new_ema.clone());
1693            Ok(new_ema)
1694        } else {
1695            self.ema_gradients.insert(param_id, gradient.clone());
1696            Ok(gradient.clone())
1697        }
1698    }
1699
1700    fn apply_noise_injection(&self, gradient: &Tensor) -> Result<Tensor> {
1701        let noise = Tensor::randn_like(gradient)
1702            .map_err(|e| anyhow!("Failed to create noise tensor: {}", e))?
1703            .mul_scalar(self.config.noise_variance.sqrt())
1704            .map_err(|e| anyhow!("Failed to scale noise tensor: {}", e))?;
1705        gradient
1706            .add(&noise)
1707            .map_err(|e| anyhow!("Failed to add noise to gradient: {}", e))
1708    }
1709
1710    /// Reset state for new training run.
1711    pub fn reset(&mut self) {
1712        self.gradient_history.clear();
1713        self.ema_gradients.clear();
1714        self.smoothed_parameters.clear();
1715        self.current_step = 0;
1716    }
1717
1718    /// Get smoothing statistics.
1719    pub fn get_statistics(&self) -> HashMap<String, f32> {
1720        let mut stats = HashMap::new();
1721        stats.insert("current_step".to_string(), self.current_step as f32);
1722        stats.insert(
1723            "num_tracked_params".to_string(),
1724            self.gradient_history.len() as f32,
1725        );
1726        stats.insert(
1727            "smoothing_strength".to_string(),
1728            self.config.smoothing_strength,
1729        );
1730        stats.insert("ema_decay".to_string(), self.config.ema_decay);
1731        stats
1732    }
1733}
1734
#[cfg(test)]
mod tests {
    use super::*;

    // QHM defaults should match the values documented on `QHMConfig`.
    #[test]
    fn test_qhm_config_default() {
        let cfg = QHMConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.nu, 0.7);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    // AggMo defaults include the canonical three-buffer coefficient set.
    #[test]
    fn test_aggmo_config_default() {
        let cfg = AggMoConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum_coefficients, vec![0.0, 0.9, 0.99]);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    // A freshly built QHM optimizer starts at step zero with the given LR.
    #[test]
    fn test_qhm_creation() {
        let opt = QHM::with_defaults(1e-3, 0.9, 0.7);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    // AggMo allocates one momentum buffer per coefficient.
    #[test]
    fn test_aggmo_creation() {
        let opt = AggMo::with_defaults(1e-3, vec![0.0, 0.9, 0.99]);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.num_momentum_buffers(), 3);
    }

    // SVRG construction: LR is stored and the step counter starts at zero.
    #[test]
    fn test_variance_reduction_svrg() {
        let opt = VarianceReduction::svrg(1e-3, 50, 10);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
    }

    // SAG construction selects the SAG method variant.
    #[test]
    fn test_variance_reduction_sag() {
        let opt = VarianceReduction::sag(1e-3, 100);
        assert_eq!(opt.get_lr(), 1e-3);
        assert!(matches!(opt.config.method, VarianceReductionMethod::SAG));
    }

    // NAG defaults: standard LR/momentum, restarts disabled.
    #[test]
    fn test_nesterov_accelerated_gradient_config() {
        let cfg = NesterovAcceleratedGradientConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.momentum, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.restart_on_increase);
    }

    // A fresh NAG optimizer has no recorded loss yet.
    #[test]
    fn test_nesterov_accelerated_gradient_creation() {
        let opt = NesterovAcceleratedGradient::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert!(opt.previous_loss.is_none());
    }

    // With restarts enabled, each reported loss replaces the stored one,
    // including a loss that increased (which triggers the restart path).
    #[test]
    fn test_nesterov_restart_on_increase() {
        let mut nag = NesterovAcceleratedGradient::new(NesterovAcceleratedGradientConfig {
            learning_rate: 1e-3,
            momentum: 0.9,
            weight_decay: 0.0,
            restart_on_increase: true,
        });

        // Set initial loss
        nag.set_current_loss(1.0);
        assert_eq!(nag.previous_loss, Some(1.0));

        // Increasing loss should trigger restart
        nag.set_current_loss(1.5);
        assert_eq!(nag.previous_loss, Some(1.5));
    }

    // Heavy-ball defaults: fixed momentum, adaptivity off.
    #[test]
    fn test_heavy_ball_config() {
        let cfg = HeavyBallConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.beta, 0.9);
        assert_eq!(cfg.weight_decay, 0.0);
        assert!(!cfg.adaptive_momentum);
    }

    // Heavy-ball construction stores LR and beta, starting at step zero.
    #[test]
    fn test_heavy_ball_creation() {
        let opt = HeavyBall::with_defaults(1e-3, 0.9);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.get_config().beta, 0.9);
    }

    // Adaptive momentum can be switched on via the explicit config constructor.
    #[test]
    fn test_heavy_ball_adaptive_momentum() {
        let opt = HeavyBall::new(HeavyBallConfig {
            learning_rate: 1e-3,
            beta: 0.9,
            weight_decay: 0.0,
            adaptive_momentum: true,
        });

        assert!(opt.config.adaptive_momentum);
    }

    // FISTA defaults: small threshold, adaptive restart enabled.
    #[test]
    fn test_fista_config() {
        let cfg = FISTAConfig::default();
        assert_eq!(cfg.learning_rate, 1e-3);
        assert_eq!(cfg.threshold, 1e-4);
        assert!(cfg.adaptive_restart);
        assert_eq!(cfg.weight_decay, 0.0);
    }

    // FISTA starts with both momentum coefficients at 1.0.
    #[test]
    fn test_fista_creation() {
        let opt = FISTA::with_defaults(1e-3, 1e-4);
        assert_eq!(opt.get_lr(), 1e-3);
        assert_eq!(opt.current_step, 0);
        assert_eq!(opt.momentum_coefficient, 1.0);
        assert_eq!(opt.previous_momentum, 1.0);
    }

    // The FISTA momentum coefficient is strictly increasing across steps.
    #[test]
    fn test_fista_momentum_update() {
        let mut opt = FISTA::with_defaults(1e-3, 1e-4);

        // Momentum coefficient should update with step (increment step first)
        opt.current_step = 1;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > 1.0);
        assert_eq!(opt.previous_momentum, 1.0);

        let momentum_before = opt.momentum_coefficient;
        opt.current_step = 2;
        opt.update_momentum_coefficient();
        assert!(opt.momentum_coefficient > momentum_before);
    }

    // Adaptive batch sizing defaults cover the full documented parameter set.
    #[test]
    fn test_adaptive_batch_sizing_config() {
        let cfg = AdaptiveBatchSizingConfig::default();
        assert_eq!(cfg.initial_batch_size, 32);
        assert_eq!(cfg.min_batch_size, 8);
        assert_eq!(cfg.max_batch_size, 512);
        assert_eq!(cfg.gradient_variance_tolerance, 0.1);
        assert_eq!(cfg.lr_adaptation_factor, 0.8);
        assert_eq!(cfg.variance_window_size, 10);
        assert_eq!(cfg.increase_threshold, 0.05);
        assert_eq!(cfg.decrease_threshold, 0.2);
    }

    // with_defaults seeds the batch size and the min/max bounds.
    #[test]
    fn test_adaptive_batch_sizing_creation() {
        let batcher = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        assert_eq!(batcher.current_batch_size(), 64);
        assert_eq!(batcher.get_config().min_batch_size, 16);
        assert_eq!(batcher.get_config().max_batch_size, 256);
    }

    // The LR adjustment for a smaller reference batch stays in a sane range.
    #[test]
    fn test_adaptive_batch_sizing_lr_adjustment() {
        let batcher = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        let adjustment = batcher.get_lr_adjustment(32);
        assert!(adjustment > 0.0);
        assert!(adjustment < 2.0);
    }

    // reset() rewinds the step counter and restores the initial batch size.
    #[test]
    fn test_adaptive_batch_sizing_reset() {
        let mut batcher = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        batcher.current_step = 10;
        batcher.reset();
        assert_eq!(batcher.current_step, 0);
        assert_eq!(batcher.current_batch_size(), 64);
    }

    // Loss-surface smoothing defaults: averaging on, noise injection off.
    #[test]
    fn test_loss_surface_smoothing_config() {
        let cfg = LossSurfaceSmoothingConfig::default();
        assert_eq!(cfg.smoothing_strength, 0.1);
        assert_eq!(cfg.noise_variance, 1e-4);
        assert_eq!(cfg.ema_decay, 0.9);
        assert_eq!(cfg.averaging_window, 5);
        assert!(cfg.use_gradient_averaging);
        assert!(!cfg.use_noise_injection);
    }

    // with_defaults wires the strength and noise-injection flag through.
    #[test]
    fn test_loss_surface_smoothing_creation() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.2, true);
        assert_eq!(smoother.get_config().smoothing_strength, 0.2);
        assert!(smoother.get_config().use_noise_injection);
        assert_eq!(smoother.current_step, 0);
    }

    // A fresh smoother reports zeroed counters and its configured parameters.
    #[test]
    fn test_loss_surface_smoothing_statistics() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        let metrics = smoother.get_statistics();
        assert_eq!(metrics.get("current_step"), Some(&0.0));
        assert_eq!(metrics.get("num_tracked_params"), Some(&0.0));
        assert_eq!(metrics.get("smoothing_strength"), Some(&0.1));
        assert_eq!(metrics.get("ema_decay"), Some(&0.9));
    }

    // reset() returns the smoother to step zero.
    #[test]
    fn test_loss_surface_smoothing_reset() {
        let mut smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        smoother.current_step = 5;
        smoother.reset();
        assert_eq!(smoother.current_step, 0);
    }
}