trustformers_optim/sparse.rs

//! # Sparse Momentum Methods
//!
//! This module provides optimizers specifically designed for sparse gradients,
//! commonly encountered in embedding layers, large language models, and recommendation systems.
//! These optimizers are memory-efficient and only update parameters that receive non-zero gradients.
//!
//! ## Benefits
//! - Memory efficient for sparse models (embeddings, transformers)
//! - Faster convergence on sparse data
//! - Reduced computation for inactive parameters
//! - Better handling of rare features
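//!
//! ## Example
//!
//! A minimal usage sketch (not compiled as a doctest; it assumes the crate
//! exposes this module as `trustformers_optim::sparse`, that the
//! `OptimizerState` trait is in scope, and that the parameter tensors already
//! carry gradients from a backward pass):
//!
//! ```ignore
//! use trustformers_optim::optimizer::OptimizerState;
//! use trustformers_optim::sparse::{SparseConfig, SparseSGD};
//!
//! // Raise the sparsity threshold so smaller gradient entries are treated as
//! // zero; the remaining fields keep their defaults.
//! let config = SparseConfig { sparsity_threshold: 1e-6, ..SparseConfig::default() };
//! let mut optimizer = SparseSGD::new(0.01, 0.9, 0.0, 1e-4, false, config);
//!
//! // `parameters: &mut [Tensor]`, each with its gradient populated.
//! optimizer.step(&mut parameters)?;
//! ```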

use crate::optimizer::OptimizerState;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;

/// Configuration for sparse optimization methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseConfig {
    /// Sparsity threshold below which gradients are considered zero
    pub sparsity_threshold: f32,
    /// Maximum number of active parameters to track
    pub max_active_params: Option<usize>,
    /// Whether to use lazy updates for momentum
    pub lazy_updates: bool,
    /// Frequency of momentum state cleanup (steps)
    pub cleanup_frequency: usize,
    /// Whether to compress inactive momentum states
    pub compress_inactive: bool,
}

impl Default for SparseConfig {
    fn default() -> Self {
        Self {
            sparsity_threshold: 1e-8,
            max_active_params: None,
            lazy_updates: true,
            cleanup_frequency: 1000,
            compress_inactive: false,
        }
    }
}

/// Sparse momentum state for a parameter.
#[derive(Debug, Clone)]
pub struct SparseMomentumState {
    /// Momentum buffer (only for active indices)
    pub momentum: HashMap<usize, f32>,
    /// Last update step for each active index
    pub last_update: HashMap<usize, usize>,
    /// Accumulated gradient norm (for adaptive methods)
    pub grad_norm_acc: HashMap<usize, f32>,
    /// Whether state is compressed
    pub is_compressed: bool,
}

impl Default for SparseMomentumState {
    fn default() -> Self {
        Self::new()
    }
}

impl SparseMomentumState {
    pub fn new() -> Self {
        Self {
            momentum: HashMap::new(),
            last_update: HashMap::new(),
            grad_norm_acc: HashMap::new(),
            is_compressed: false,
        }
    }

    /// Get the number of active parameters.
    pub fn num_active(&self) -> usize {
        self.momentum.len()
    }

    /// Apply lazy momentum updates.
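    ///
    /// For an index last updated at step `s`, the buffered momentum is scaled
    /// by `decay^(current_step - s - 1)`: the per-step decay that would have
    /// been applied on each skipped step is applied in a single multiplication.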
    pub fn apply_lazy_update(&mut self, current_step: usize, decay: f32) {
        for (idx, momentum) in self.momentum.iter_mut() {
            if let Some(&last_step) = self.last_update.get(idx) {
                // Saturate so an index that was already updated at this step
                // cannot underflow the unsigned arithmetic.
                let steps_skipped = current_step.saturating_sub(last_step + 1);
                if steps_skipped > 0 {
                    // Apply exponential decay for skipped steps
                    *momentum *= decay.powi(steps_skipped as i32);
                }
            }
        }
    }

    /// Clean up old momentum states.
    pub fn cleanup(&mut self, max_age_steps: usize, current_step: usize) {
        let mut to_remove = Vec::new();

        for (idx, &last_step) in &self.last_update {
            if current_step - last_step > max_age_steps {
                to_remove.push(*idx);
            }
        }

        for idx in to_remove {
            self.momentum.remove(&idx);
            self.last_update.remove(&idx);
            self.grad_norm_acc.remove(&idx);
        }
    }

    /// Compress inactive momentum states.
    pub fn compress(&mut self) {
        if self.is_compressed {
            return;
        }

        // Remove very small momentum values
        let threshold = 1e-10;
        self.momentum.retain(|_, &mut v| v.abs() > threshold);
        self.grad_norm_acc.retain(|_, &mut v| v > threshold);

        self.is_compressed = true;
    }

    /// Decompress momentum states.
    pub fn decompress(&mut self) {
        self.is_compressed = false;
    }
}

/// Sparse SGD with momentum optimizer.
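///
/// A minimal construction and update sketch (not compiled as a doctest; it
/// assumes the `OptimizerState` trait is in scope and that `parameters` is a
/// `&mut [Tensor]` whose gradients are already populated):
///
/// ```ignore
/// let mut optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
/// optimizer.step(&mut parameters)?;
/// println!("active momentum entries: {}", optimizer.total_active_states());
/// ```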
#[derive(Debug)]
pub struct SparseSGD {
    learning_rate: f32,
    momentum: f32,
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
    config: SparseConfig,
    momentum_states: HashMap<usize, SparseMomentumState>,
    current_step: usize,
}

impl SparseSGD {
    pub fn new(
        learning_rate: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
        config: SparseConfig,
    ) -> Self {
        Self {
            learning_rate,
            momentum,
            dampening,
            weight_decay,
            nesterov,
            config,
            momentum_states: HashMap::new(),
            current_step: 0,
        }
    }

    /// Create with default sparse configuration.
    pub fn with_default_config(
        learning_rate: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> Self {
        Self::new(
            learning_rate,
            momentum,
            dampening,
            weight_decay,
            nesterov,
            SparseConfig::default(),
        )
    }

    /// Get sparse indices from gradient tensor.
    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
        let grad_data = gradient.data()?;
        let indices: Vec<usize> = grad_data
            .iter()
            .enumerate()
            .filter_map(
                |(i, &val)| {
                    if val.abs() > self.config.sparsity_threshold {
                        Some(i)
                    } else {
                        None
                    }
                },
            )
            .collect();

        // Limit active parameters if configured
        if let Some(max_active) = self.config.max_active_params {
            if indices.len() > max_active {
                // Keep indices with largest gradients
                let mut indexed_grads: Vec<(usize, f32)> =
                    indices.iter().map(|&i| (i, grad_data[i].abs())).collect();
                indexed_grads
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                return Ok(indexed_grads.into_iter().take(max_active).map(|(i, _)| i).collect());
            }
        }

        Ok(indices)
    }

    /// Update sparse momentum for a parameter.
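    ///
    /// Only the active indices are touched. When weight decay is enabled, the
    /// gradient is first adjusted as `g += weight_decay * p`. The momentum
    /// buffer is then updated as `b = momentum * b + (1 - dampening) * g`, and
    /// the parameter as `p -= lr * (g + momentum * b)` with Nesterov enabled,
    /// or `p -= lr * b` otherwise.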
    fn update_sparse_momentum(
        &mut self,
        param_id: usize,
        gradient: &Tensor,
        parameter: &mut Tensor,
    ) -> Result<()> {
        let sparse_indices = self.get_sparse_indices(gradient)?;
        if sparse_indices.is_empty() {
            return Ok(());
        }

        let grad_data = gradient.data()?;
        let mut param_data = parameter.data()?;

        // Get or create momentum state
        let momentum_state = self.momentum_states.entry(param_id).or_default();

        // Apply lazy updates if enabled
        if self.config.lazy_updates {
            momentum_state.apply_lazy_update(self.current_step, self.momentum);
        }

        // Update momentum for each sparse index
        for &idx in &sparse_indices {
            let mut grad_val = grad_data[idx];

            // Apply weight decay
            if self.weight_decay != 0.0 {
                grad_val += self.weight_decay * param_data[idx];
            }

            // Update momentum
            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
            let new_momentum = self.momentum * momentum_val + (1.0 - self.dampening) * grad_val;
            momentum_state.momentum.insert(idx, new_momentum);
            momentum_state.last_update.insert(idx, self.current_step);

            // Apply update
            let update = if self.nesterov {
                grad_val + self.momentum * new_momentum
            } else {
                new_momentum
            };

            param_data[idx] -= self.learning_rate * update;
        }

        // Update parameter tensor
        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;

        Ok(())
    }

    /// Get momentum statistics.
    pub fn get_momentum_stats(&self) -> HashMap<usize, usize> {
        self.momentum_states
            .iter()
            .map(|(&param_id, state)| (param_id, state.num_active()))
            .collect()
    }

    /// Total number of active momentum states across all parameters.
    pub fn total_active_states(&self) -> usize {
        self.momentum_states.values().map(|s| s.num_active()).sum()
    }

    /// Cleanup old momentum states for all parameters.
    pub fn cleanup_momentum_states(&mut self) {
        if self.current_step.is_multiple_of(self.config.cleanup_frequency) {
            let max_age = self.config.cleanup_frequency * 2;
            for state in self.momentum_states.values_mut() {
                state.cleanup(max_age, self.current_step);
                if self.config.compress_inactive {
                    state.compress();
                }
            }
        }
    }
}

impl OptimizerState for SparseSGD {
    fn zero_grad(&mut self) -> Result<()> {
        // For sparse optimizers, we don't need to explicitly zero gradients
        // since we only process non-zero gradients
        Ok(())
    }

    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            // For sparse optimizers, we assume gradients are already computed
            // and available in the parameter's grad field
            if let Ok(gradient) = parameter.grad() {
                self.update_sparse_momentum(param_id, &gradient, parameter)?;
            }
        }

        // Periodic cleanup
        self.cleanup_momentum_states();

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        // Save hyperparameters
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.learning_rate)?,
        );
        state.insert("momentum".to_string(), Tensor::scalar(self.momentum)?);
        state.insert("dampening".to_string(), Tensor::scalar(self.dampening)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.weight_decay)?,
        );
        state.insert(
            "nesterov".to_string(),
            Tensor::scalar(self.nesterov as i32 as f32)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );

        // Save momentum states (simplified: only per-parameter active counts are stored, not the full sparse maps)
        for (&param_id, momentum_state) in &self.momentum_states {
            let num_active = momentum_state.num_active();
            state.insert(
                format!("momentum_state_{}_active_count", param_id),
                Tensor::scalar(num_active as f32)?,
            );
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        // Load hyperparameters
        if let Some(lr_tensor) = state.get("learning_rate") {
            self.learning_rate = lr_tensor.to_scalar()?;
        }
        if let Some(momentum_tensor) = state.get("momentum") {
            self.momentum = momentum_tensor.to_scalar()?;
        }
        if let Some(dampening_tensor) = state.get("dampening") {
            self.dampening = dampening_tensor.to_scalar()?;
        }
        if let Some(wd_tensor) = state.get("weight_decay") {
            self.weight_decay = wd_tensor.to_scalar()?;
        }
        if let Some(nesterov_tensor) = state.get("nesterov") {
            self.nesterov = nesterov_tensor.to_scalar()? > 0.5;
        }
        if let Some(step_tensor) = state.get("current_step") {
            self.current_step = step_tensor.to_scalar()? as usize;
        }

        // Note: Loading full momentum states would require more complex serialization
        // This is a simplified implementation

        Ok(())
    }
}

/// Sparse Adam optimizer for handling sparse gradients efficiently.
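///
/// Example sketch (not compiled as a doctest; same assumptions as for
/// [`SparseSGD`]):
///
/// ```ignore
/// let mut optimizer = SparseAdam::with_default_config(1e-3, 0.9, 0.999, 1e-8, 0.01);
/// optimizer.step(&mut parameters)?;
/// ```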
#[derive(Debug)]
pub struct SparseAdam {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    config: SparseConfig,
    momentum_states: HashMap<usize, SparseMomentumState>,
    variance_states: HashMap<usize, HashMap<usize, f32>>,
    current_step: usize,
}

impl SparseAdam {
    pub fn new(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
        config: SparseConfig,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            config,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
            current_step: 0,
        }
    }

    pub fn with_default_config(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Self {
        Self::new(
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            SparseConfig::default(),
        )
    }

    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
        let grad_data = gradient.data()?;
        Ok(grad_data
            .iter()
            .enumerate()
            .filter_map(
                |(i, &val)| {
                    if val.abs() > self.config.sparsity_threshold {
                        Some(i)
                    } else {
                        None
                    }
                },
            )
            .collect())
    }

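    /// Sparse Adam update applied only to the active indices.
    ///
    /// Per index: `m = beta1 * m + (1 - beta1) * g`,
    /// `v = beta2 * v + (1 - beta2) * g^2`, and
    /// `p -= lr * (m / (1 - beta1^t)) / (sqrt(v / (1 - beta2^t)) + epsilon)`,
    /// where `t` is the global step counter.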
    fn update_sparse_adam(
        &mut self,
        param_id: usize,
        gradient: &Tensor,
        parameter: &mut Tensor,
    ) -> Result<()> {
        let sparse_indices = self.get_sparse_indices(gradient)?;
        if sparse_indices.is_empty() {
            return Ok(());
        }

        let grad_data = gradient.data()?;
        let mut param_data = parameter.data()?;

        // Get or create states
        let momentum_state = self.momentum_states.entry(param_id).or_default();
        let variance_state = self.variance_states.entry(param_id).or_default();

        // Bias correction terms
        let bias_correction1 = 1.0 - self.beta1.powi(self.current_step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.current_step as i32);

        // Update sparse parameters
        for &idx in &sparse_indices {
            let mut grad_val = grad_data[idx];

            // Apply weight decay
            if self.weight_decay != 0.0 {
                grad_val += self.weight_decay * param_data[idx];
            }

            // Update biased first moment estimate
            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
            let new_momentum = self.beta1 * momentum_val + (1.0 - self.beta1) * grad_val;
            momentum_state.momentum.insert(idx, new_momentum);

            // Update biased second raw moment estimate
            let variance_val = variance_state.get(&idx).copied().unwrap_or(0.0);
            let new_variance = self.beta2 * variance_val + (1.0 - self.beta2) * grad_val * grad_val;
            variance_state.insert(idx, new_variance);

            // Compute bias-corrected first and second moment estimates
            let momentum_corrected = new_momentum / bias_correction1;
            let variance_corrected = new_variance / bias_correction2;

            // Update parameter
            let denom = variance_corrected.sqrt() + self.epsilon;
            param_data[idx] -= self.learning_rate * momentum_corrected / denom;

            momentum_state.last_update.insert(idx, self.current_step);
        }

        // Update parameter tensor
        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;

        Ok(())
    }
}

impl OptimizerState for SparseAdam {
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            if let Ok(gradient) = parameter.grad() {
                self.update_sparse_adam(param_id, &gradient, parameter)?;
            }
        }

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.learning_rate)?,
        );
        state.insert("beta1".to_string(), Tensor::scalar(self.beta1)?);
        state.insert("beta2".to_string(), Tensor::scalar(self.beta2)?);
        state.insert("epsilon".to_string(), Tensor::scalar(self.epsilon)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.learning_rate = lr.to_scalar()?;
        }
        if let Some(beta1) = state.get("beta1") {
            self.beta1 = beta1.to_scalar()?;
        }
        if let Some(beta2) = state.get("beta2") {
            self.beta2 = beta2.to_scalar()?;
        }
        if let Some(eps) = state.get("epsilon") {
            self.epsilon = eps.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_config_default() {
        let config = SparseConfig::default();
        assert_eq!(config.sparsity_threshold, 1e-8);
        assert!(config.max_active_params.is_none());
        assert!(config.lazy_updates);
        assert_eq!(config.cleanup_frequency, 1000);
        assert!(!config.compress_inactive);
    }

    #[test]
    fn test_sparse_momentum_state() {
        let mut state = SparseMomentumState::new();
        assert_eq!(state.num_active(), 0);

        state.momentum.insert(0, 1.0);
        state.momentum.insert(5, 2.0);
        assert_eq!(state.num_active(), 2);

        state.cleanup(0, 100);
        assert_eq!(state.num_active(), 2); // No cleanup without last_update entries
    }

    #[test]
    fn test_sparse_sgd_creation() {
        let optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);
        assert_eq!(optimizer.total_active_states(), 0);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::with_default_config(1e-3, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
    }

    #[test]
    fn test_sparse_sgd_lr_update() {
        let mut optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);

        optimizer.set_lr(0.001);
        assert_eq!(optimizer.get_lr(), 0.001);
    }
}