// trustformers_optim/sparse.rs
//! # Sparse Momentum Methods
//!
//! This module provides optimizers designed for sparse gradients, as commonly
//! encountered in embedding layers, large language models, and recommendation systems.
//! They are memory-efficient: only parameter entries whose gradient magnitude exceeds a
//! configurable threshold are updated, and optimizer state is kept only for those entries.
//!
//! ## Benefits
//! - Memory efficient for sparse models (embeddings, transformers)
//! - Faster convergence on sparse data
//! - Reduced computation for inactive parameters
//! - Better handling of rare features
12
13use crate::optimizer::OptimizerState;
14use anyhow::Result;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use trustformers_core::tensor::Tensor;
18
19/// Configuration for sparse optimization methods.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SparseConfig {
22    /// Sparsity threshold below which gradients are considered zero
23    pub sparsity_threshold: f32,
24    /// Maximum number of active parameters to track
25    pub max_active_params: Option<usize>,
26    /// Whether to use lazy updates for momentum
27    pub lazy_updates: bool,
28    /// Frequency of momentum state cleanup (steps)
29    pub cleanup_frequency: usize,
30    /// Whether to compress inactive momentum states
31    pub compress_inactive: bool,
32}
33
34impl Default for SparseConfig {
35    fn default() -> Self {
36        Self {
37            sparsity_threshold: 1e-8,
38            max_active_params: None,
39            lazy_updates: true,
40            cleanup_frequency: 1000,
41            compress_inactive: false,
42        }
43    }
44}
45
/// Per-parameter optimizer state, stored only for indices that have been active.
///
/// Every map is keyed by the flat element index inside the parameter tensor.
#[derive(Debug, Clone)]
pub struct SparseMomentumState {
    /// Momentum value for each tracked index.
    pub momentum: HashMap<usize, f32>,
    /// Step at which each tracked index was last written.
    pub last_update: HashMap<usize, usize>,
    /// Accumulated gradient norm per index (reserved for adaptive methods).
    pub grad_norm_acc: HashMap<usize, f32>,
    /// Set once near-zero entries have been pruned; cleared to allow another pass.
    pub is_compressed: bool,
}
58
59impl Default for SparseMomentumState {
60    fn default() -> Self {
61        Self::new()
62    }
63}
64
65impl SparseMomentumState {
66    pub fn new() -> Self {
67        Self {
68            momentum: HashMap::new(),
69            last_update: HashMap::new(),
70            grad_norm_acc: HashMap::new(),
71            is_compressed: false,
72        }
73    }
74
75    /// Get the number of active parameters.
76    pub fn num_active(&self) -> usize {
77        self.momentum.len()
78    }
79
80    /// Apply lazy momentum updates.
81    pub fn apply_lazy_update(&mut self, current_step: usize, decay: f32) {
82        for (idx, momentum) in self.momentum.iter_mut() {
83            if let Some(&last_step) = self.last_update.get(idx) {
84                let steps_skipped = current_step - last_step - 1;
85                if steps_skipped > 0 {
86                    // Apply exponential decay for skipped steps
87                    *momentum *= decay.powi(steps_skipped as i32);
88                }
89            }
90        }
91    }
92
93    /// Clean up old momentum states.
94    pub fn cleanup(&mut self, max_age_steps: usize, current_step: usize) {
95        let mut to_remove = Vec::new();
96
97        for (idx, &last_step) in &self.last_update {
98            if current_step - last_step > max_age_steps {
99                to_remove.push(*idx);
100            }
101        }
102
103        for idx in to_remove {
104            self.momentum.remove(&idx);
105            self.last_update.remove(&idx);
106            self.grad_norm_acc.remove(&idx);
107        }
108    }
109
110    /// Compress inactive momentum states.
111    pub fn compress(&mut self) {
112        if self.is_compressed {
113            return;
114        }
115
116        // Remove very small momentum values
117        let threshold = 1e-10;
118        self.momentum.retain(|_, &mut v| v.abs() > threshold);
119        self.grad_norm_acc.retain(|_, &mut v| v > threshold);
120
121        self.is_compressed = true;
122    }
123
124    /// Decompress momentum states.
125    pub fn decompress(&mut self) {
126        self.is_compressed = false;
127    }
128}
129
130/// Sparse SGD with momentum optimizer.
131#[derive(Debug)]
132pub struct SparseSGD {
133    learning_rate: f32,
134    momentum: f32,
135    dampening: f32,
136    weight_decay: f32,
137    nesterov: bool,
138    config: SparseConfig,
139    momentum_states: HashMap<usize, SparseMomentumState>,
140    current_step: usize,
141}
142
143impl SparseSGD {
144    pub fn new(
145        learning_rate: f32,
146        momentum: f32,
147        dampening: f32,
148        weight_decay: f32,
149        nesterov: bool,
150        config: SparseConfig,
151    ) -> Self {
152        Self {
153            learning_rate,
154            momentum,
155            dampening,
156            weight_decay,
157            nesterov,
158            config,
159            momentum_states: HashMap::new(),
160            current_step: 0,
161        }
162    }
163
164    /// Create with default sparse configuration.
165    pub fn with_default_config(
166        learning_rate: f32,
167        momentum: f32,
168        dampening: f32,
169        weight_decay: f32,
170        nesterov: bool,
171    ) -> Self {
172        Self::new(
173            learning_rate,
174            momentum,
175            dampening,
176            weight_decay,
177            nesterov,
178            SparseConfig::default(),
179        )
180    }
181
182    /// Get sparse indices from gradient tensor.
183    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
184        let grad_data = gradient.data()?;
185        let indices: Vec<usize> = grad_data
186            .iter()
187            .enumerate()
188            .filter_map(
189                |(i, &val)| {
190                    if val.abs() > self.config.sparsity_threshold {
191                        Some(i)
192                    } else {
193                        None
194                    }
195                },
196            )
197            .collect();
198
199        // Limit active parameters if configured
200        if let Some(max_active) = self.config.max_active_params {
201            if indices.len() > max_active {
202                // Keep indices with largest gradients
203                let mut indexed_grads: Vec<(usize, f32)> =
204                    indices.iter().map(|&i| (i, grad_data[i].abs())).collect();
205                indexed_grads.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
206                return Ok(indexed_grads.into_iter().take(max_active).map(|(i, _)| i).collect());
207            }
208        }
209
210        Ok(indices)
211    }
212
213    /// Update sparse momentum for a parameter.
214    fn update_sparse_momentum(
215        &mut self,
216        param_id: usize,
217        gradient: &Tensor,
218        parameter: &mut Tensor,
219    ) -> Result<()> {
220        let sparse_indices = self.get_sparse_indices(gradient)?;
221        if sparse_indices.is_empty() {
222            return Ok(());
223        }
224
225        let grad_data = gradient.data()?;
226        let mut param_data = parameter.data()?;
227
228        // Get or create momentum state
229        let momentum_state = self.momentum_states.entry(param_id).or_default();
230
231        // Apply lazy updates if enabled
232        if self.config.lazy_updates {
233            momentum_state.apply_lazy_update(self.current_step, self.momentum);
234        }
235
236        // Update momentum for each sparse index
237        for &idx in &sparse_indices {
238            let mut grad_val = grad_data[idx];
239
240            // Apply weight decay
241            if self.weight_decay != 0.0 {
242                grad_val += self.weight_decay * param_data[idx];
243            }
244
245            // Update momentum
246            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
247            let new_momentum = self.momentum * momentum_val + (1.0 - self.dampening) * grad_val;
248            momentum_state.momentum.insert(idx, new_momentum);
249            momentum_state.last_update.insert(idx, self.current_step);
250
251            // Apply update
252            let update = if self.nesterov {
253                grad_val + self.momentum * new_momentum
254            } else {
255                new_momentum
256            };
257
258            param_data[idx] -= self.learning_rate * update;
259        }
260
261        // Update parameter tensor
262        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;
263
264        Ok(())
265    }
266
267    /// Get momentum statistics.
268    pub fn get_momentum_stats(&self) -> HashMap<usize, usize> {
269        self.momentum_states
270            .iter()
271            .map(|(&param_id, state)| (param_id, state.num_active()))
272            .collect()
273    }
274
275    /// Total number of active momentum states across all parameters.
276    pub fn total_active_states(&self) -> usize {
277        self.momentum_states.values().map(|s| s.num_active()).sum()
278    }
279
280    /// Cleanup old momentum states for all parameters.
281    pub fn cleanup_momentum_states(&mut self) {
282        if self.current_step % self.config.cleanup_frequency == 0 {
283            let max_age = self.config.cleanup_frequency * 2;
284            for state in self.momentum_states.values_mut() {
285                state.cleanup(max_age, self.current_step);
286                if self.config.compress_inactive {
287                    state.compress();
288                }
289            }
290        }
291    }
292}
293
294impl OptimizerState for SparseSGD {
295    fn zero_grad(&mut self) -> Result<()> {
296        // For sparse optimizers, we don't need to explicitly zero gradients
297        // since we only process non-zero gradients
298        Ok(())
299    }
300
301    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
302        self.current_step += 1;
303
304        for (param_id, parameter) in parameters.iter_mut().enumerate() {
305            // For sparse optimizers, we assume gradients are already computed
306            // and available in the parameter's grad field
307            if let Ok(gradient) = parameter.grad() {
308                self.update_sparse_momentum(param_id, &gradient, parameter)?;
309            }
310        }
311
312        // Periodic cleanup
313        self.cleanup_momentum_states();
314
315        Ok(())
316    }
317
318    fn get_lr(&self) -> f32 {
319        self.learning_rate
320    }
321
322    fn set_lr(&mut self, lr: f32) {
323        self.learning_rate = lr;
324    }
325
326    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
327        let mut state = HashMap::new();
328
329        // Save hyperparameters
330        state.insert(
331            "learning_rate".to_string(),
332            Tensor::scalar(self.learning_rate)?,
333        );
334        state.insert("momentum".to_string(), Tensor::scalar(self.momentum)?);
335        state.insert("dampening".to_string(), Tensor::scalar(self.dampening)?);
336        state.insert(
337            "weight_decay".to_string(),
338            Tensor::scalar(self.weight_decay)?,
339        );
340        state.insert(
341            "nesterov".to_string(),
342            Tensor::scalar(self.nesterov as i32 as f32)?,
343        );
344        state.insert(
345            "current_step".to_string(),
346            Tensor::scalar(self.current_step as f32)?,
347        );
348
349        // Save momentum states (this is simplified - real implementation would be more complex)
350        for (&param_id, momentum_state) in &self.momentum_states {
351            let num_active = momentum_state.num_active();
352            state.insert(
353                format!("momentum_state_{}_active_count", param_id),
354                Tensor::scalar(num_active as f32)?,
355            );
356        }
357
358        Ok(state)
359    }
360
361    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
362        // Load hyperparameters
363        if let Some(lr_tensor) = state.get("learning_rate") {
364            self.learning_rate = lr_tensor.to_scalar()?;
365        }
366        if let Some(momentum_tensor) = state.get("momentum") {
367            self.momentum = momentum_tensor.to_scalar()?;
368        }
369        if let Some(dampening_tensor) = state.get("dampening") {
370            self.dampening = dampening_tensor.to_scalar()?;
371        }
372        if let Some(wd_tensor) = state.get("weight_decay") {
373            self.weight_decay = wd_tensor.to_scalar()?;
374        }
375        if let Some(nesterov_tensor) = state.get("nesterov") {
376            self.nesterov = nesterov_tensor.to_scalar()? > 0.5;
377        }
378        if let Some(step_tensor) = state.get("current_step") {
379            self.current_step = step_tensor.to_scalar()? as usize;
380        }
381
382        // Note: Loading full momentum states would require more complex serialization
383        // This is a simplified implementation
384
385        Ok(())
386    }
387}
388
389/// Sparse Adam optimizer for handling sparse gradients efficiently.
390#[derive(Debug)]
391pub struct SparseAdam {
392    learning_rate: f32,
393    beta1: f32,
394    beta2: f32,
395    epsilon: f32,
396    weight_decay: f32,
397    config: SparseConfig,
398    momentum_states: HashMap<usize, SparseMomentumState>,
399    variance_states: HashMap<usize, HashMap<usize, f32>>,
400    current_step: usize,
401}
402
403impl SparseAdam {
404    pub fn new(
405        learning_rate: f32,
406        beta1: f32,
407        beta2: f32,
408        epsilon: f32,
409        weight_decay: f32,
410        config: SparseConfig,
411    ) -> Self {
412        Self {
413            learning_rate,
414            beta1,
415            beta2,
416            epsilon,
417            weight_decay,
418            config,
419            momentum_states: HashMap::new(),
420            variance_states: HashMap::new(),
421            current_step: 0,
422        }
423    }
424
425    pub fn with_default_config(
426        learning_rate: f32,
427        beta1: f32,
428        beta2: f32,
429        epsilon: f32,
430        weight_decay: f32,
431    ) -> Self {
432        Self::new(
433            learning_rate,
434            beta1,
435            beta2,
436            epsilon,
437            weight_decay,
438            SparseConfig::default(),
439        )
440    }
441
442    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
443        let grad_data = gradient.data()?;
444        Ok(grad_data
445            .iter()
446            .enumerate()
447            .filter_map(
448                |(i, &val)| {
449                    if val.abs() > self.config.sparsity_threshold {
450                        Some(i)
451                    } else {
452                        None
453                    }
454                },
455            )
456            .collect())
457    }
458
459    fn update_sparse_adam(
460        &mut self,
461        param_id: usize,
462        gradient: &Tensor,
463        parameter: &mut Tensor,
464    ) -> Result<()> {
465        let sparse_indices = self.get_sparse_indices(gradient)?;
466        if sparse_indices.is_empty() {
467            return Ok(());
468        }
469
470        let grad_data = gradient.data()?;
471        let mut param_data = parameter.data()?;
472
473        // Get or create states
474        let momentum_state = self.momentum_states.entry(param_id).or_default();
475        let variance_state = self.variance_states.entry(param_id).or_default();
476
477        // Bias correction terms
478        let bias_correction1 = 1.0 - self.beta1.powi(self.current_step as i32);
479        let bias_correction2 = 1.0 - self.beta2.powi(self.current_step as i32);
480
481        // Update sparse parameters
482        for &idx in &sparse_indices {
483            let mut grad_val = grad_data[idx];
484
485            // Apply weight decay
486            if self.weight_decay != 0.0 {
487                grad_val += self.weight_decay * param_data[idx];
488            }
489
490            // Update biased first moment estimate
491            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
492            let new_momentum = self.beta1 * momentum_val + (1.0 - self.beta1) * grad_val;
493            momentum_state.momentum.insert(idx, new_momentum);
494
495            // Update biased second raw moment estimate
496            let variance_val = variance_state.get(&idx).copied().unwrap_or(0.0);
497            let new_variance = self.beta2 * variance_val + (1.0 - self.beta2) * grad_val * grad_val;
498            variance_state.insert(idx, new_variance);
499
500            // Compute bias-corrected first and second moment estimates
501            let momentum_corrected = new_momentum / bias_correction1;
502            let variance_corrected = new_variance / bias_correction2;
503
504            // Update parameter
505            let denom = variance_corrected.sqrt() + self.epsilon;
506            param_data[idx] -= self.learning_rate * momentum_corrected / denom;
507
508            momentum_state.last_update.insert(idx, self.current_step);
509        }
510
511        // Update parameter tensor
512        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;
513
514        Ok(())
515    }
516}
517
518impl OptimizerState for SparseAdam {
519    fn zero_grad(&mut self) -> Result<()> {
520        Ok(())
521    }
522
523    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
524        self.current_step += 1;
525
526        for (param_id, parameter) in parameters.iter_mut().enumerate() {
527            if let Ok(gradient) = parameter.grad() {
528                self.update_sparse_adam(param_id, &gradient, parameter)?;
529            }
530        }
531
532        Ok(())
533    }
534
535    fn get_lr(&self) -> f32 {
536        self.learning_rate
537    }
538
539    fn set_lr(&mut self, lr: f32) {
540        self.learning_rate = lr;
541    }
542
543    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
544        let mut state = HashMap::new();
545        state.insert(
546            "learning_rate".to_string(),
547            Tensor::scalar(self.learning_rate)?,
548        );
549        state.insert("beta1".to_string(), Tensor::scalar(self.beta1)?);
550        state.insert("beta2".to_string(), Tensor::scalar(self.beta2)?);
551        state.insert("epsilon".to_string(), Tensor::scalar(self.epsilon)?);
552        state.insert(
553            "weight_decay".to_string(),
554            Tensor::scalar(self.weight_decay)?,
555        );
556        state.insert(
557            "current_step".to_string(),
558            Tensor::scalar(self.current_step as f32)?,
559        );
560        Ok(state)
561    }
562
563    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
564        if let Some(lr) = state.get("learning_rate") {
565            self.learning_rate = lr.to_scalar()?;
566        }
567        if let Some(beta1) = state.get("beta1") {
568            self.beta1 = beta1.to_scalar()?;
569        }
570        if let Some(beta2) = state.get("beta2") {
571            self.beta2 = beta2.to_scalar()?;
572        }
573        if let Some(eps) = state.get("epsilon") {
574            self.epsilon = eps.to_scalar()?;
575        }
576        if let Some(wd) = state.get("weight_decay") {
577            self.weight_decay = wd.to_scalar()?;
578        }
579        if let Some(step) = state.get("current_step") {
580            self.current_step = step.to_scalar()? as usize;
581        }
582        Ok(())
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_config_default() {
        let config = SparseConfig::default();
        assert_eq!(config.sparsity_threshold, 1e-8);
        assert!(config.max_active_params.is_none());
        assert!(config.lazy_updates);
        assert_eq!(config.cleanup_frequency, 1000);
        assert!(!config.compress_inactive);
    }

    #[test]
    fn test_sparse_momentum_state() {
        let mut state = SparseMomentumState::new();
        assert_eq!(state.num_active(), 0);

        state.momentum.insert(0, 1.0);
        state.momentum.insert(5, 2.0);
        assert_eq!(state.num_active(), 2);

        state.cleanup(0, 100);
        assert_eq!(state.num_active(), 2); // No cleanup without last_update entries
    }

    #[test]
    fn test_lazy_update_decays_skipped_steps() {
        let mut state = SparseMomentumState::new();
        state.momentum.insert(3, 1.0);
        state.last_update.insert(3, 2);
        // Steps 3 and 4 were skipped before step 5.
        state.apply_lazy_update(5, 0.5);
        assert!((state.momentum[&3] - 0.25).abs() < 1e-6);
    }

    #[test]
    fn test_cleanup_removes_stale_entries() {
        let mut state = SparseMomentumState::new();
        state.momentum.insert(0, 1.0);
        state.last_update.insert(0, 10);
        state.grad_norm_acc.insert(0, 0.5);
        state.momentum.insert(1, 2.0);
        state.last_update.insert(1, 95);

        state.cleanup(50, 100);
        assert!(!state.momentum.contains_key(&0));
        assert!(!state.last_update.contains_key(&0));
        assert!(!state.grad_norm_acc.contains_key(&0));
        assert_eq!(state.num_active(), 1);
    }

    #[test]
    fn test_compress_prunes_tiny_values() {
        let mut state = SparseMomentumState::new();
        state.momentum.insert(0, 1e-12);
        state.momentum.insert(1, 0.5);

        state.compress();
        assert!(state.is_compressed);
        assert_eq!(state.num_active(), 1);

        state.momentum.insert(2, 1e-12);
        state.compress(); // already compressed: no-op
        assert_eq!(state.num_active(), 2);

        state.decompress();
        assert!(!state.is_compressed);
        state.compress();
        assert_eq!(state.num_active(), 1);
    }

    #[test]
    fn test_sparse_sgd_creation() {
        let optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);
        assert_eq!(optimizer.total_active_states(), 0);
        assert!(optimizer.get_momentum_stats().is_empty());
    }

    #[test]
    fn test_sparse_sgd_cleanup_with_no_states() {
        let mut optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        optimizer.cleanup_momentum_states();
        assert_eq!(optimizer.total_active_states(), 0);
    }

    #[test]
    fn test_sparse_sgd_state_dict_roundtrip() {
        let source = SparseSGD::with_default_config(0.05, 0.8, 0.1, 0.0, true);
        let state = source.state_dict().unwrap();

        let mut restored = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        restored.load_state_dict(state).unwrap();
        assert!((restored.get_lr() - 0.05).abs() < 1e-7);
        assert!((restored.momentum - 0.8).abs() < 1e-7);
        assert!(restored.nesterov);
        assert_eq!(restored.current_step, 0);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::with_default_config(1e-3, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
    }

    #[test]
    fn test_sparse_sgd_lr_update() {
        let mut optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);

        optimizer.set_lr(0.001);
        assert_eq!(optimizer.get_lr(), 0.001);
    }
}