use crate::optimizer::OptimizerState;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::tensor::Tensor;

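/// Configuration controlling sparse optimizer behaviour: which gradient
/// entries count as active, how many may be active per update, and how stale
/// momentum state is lazily decayed, cleaned up, and compressed.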
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseConfig {
    /// Gradient magnitudes at or below this value are treated as zero.
    pub sparsity_threshold: f32,
    /// Optional cap on active parameters per update (top-k by gradient magnitude).
    pub max_active_params: Option<usize>,
    /// Defer momentum decay for untouched entries until they are next updated.
    pub lazy_updates: bool,
    /// Run stale-state cleanup every this many steps.
    pub cleanup_frequency: usize,
    /// Prune near-zero entries from inactive state to save memory.
    pub compress_inactive: bool,
}

impl Default for SparseConfig {
    fn default() -> Self {
        Self {
            sparsity_threshold: 1e-8,
            max_active_params: None,
            lazy_updates: true,
            cleanup_frequency: 1000,
            compress_inactive: false,
        }
    }
}

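/// Per-parameter sparse momentum state. Only indices that have actually been
/// updated carry an entry, so memory scales with the number of active
/// coordinates rather than with the full parameter size.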
#[derive(Debug, Clone)]
pub struct SparseMomentumState {
    /// Momentum value per active parameter index.
    pub momentum: HashMap<usize, f32>,
    /// Step at which each index was last updated (used for lazy decay and cleanup).
    pub last_update: HashMap<usize, usize>,
    /// Accumulated gradient norm per active index.
    pub grad_norm_acc: HashMap<usize, f32>,
    /// Whether near-zero entries have been pruned by `compress`.
    pub is_compressed: bool,
}

impl Default for SparseMomentumState {
    fn default() -> Self {
        Self::new()
    }
}

impl SparseMomentumState {
    pub fn new() -> Self {
        Self {
            momentum: HashMap::new(),
            last_update: HashMap::new(),
            grad_norm_acc: HashMap::new(),
            is_compressed: false,
        }
    }

    /// Number of parameter indices that currently hold momentum state.
    pub fn num_active(&self) -> usize {
        self.momentum.len()
    }

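    /// Applies deferred momentum decay to entries that were skipped on
    /// previous steps, so lazily-updated coordinates catch up to where a
    /// dense update would have left them.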
    pub fn apply_lazy_update(&mut self, current_step: usize, decay: f32) {
        for (idx, momentum) in self.momentum.iter_mut() {
            if let Some(&last_step) = self.last_update.get(idx) {
                // Decay once for every step this index was skipped; saturating_sub
                // guards against underflow if the entry was already touched at the
                // current step.
                let steps_skipped = current_step.saturating_sub(last_step + 1);
                if steps_skipped > 0 {
                    *momentum *= decay.powi(steps_skipped as i32);
                }
            }
        }
    }

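    /// Removes state for indices that have not been updated within
    /// `max_age_steps` of `current_step`.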
    pub fn cleanup(&mut self, max_age_steps: usize, current_step: usize) {
        let mut to_remove = Vec::new();

        for (idx, &last_step) in &self.last_update {
            if current_step.saturating_sub(last_step) > max_age_steps {
                to_remove.push(*idx);
            }
        }

        for idx in to_remove {
            self.momentum.remove(&idx);
            self.last_update.remove(&idx);
            self.grad_norm_acc.remove(&idx);
        }
    }

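    /// Prunes entries whose magnitude has decayed to effectively zero and
    /// marks the state as compressed so the pass is not repeated.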
    pub fn compress(&mut self) {
        if self.is_compressed {
            return;
        }

        let threshold = 1e-10;
        self.momentum.retain(|_, &mut v| v.abs() > threshold);
        self.grad_norm_acc.retain(|_, &mut v| v > threshold);

        self.is_compressed = true;
    }

    pub fn decompress(&mut self) {
        self.is_compressed = false;
    }
}

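/// SGD with momentum that only touches parameter coordinates whose gradient
/// magnitude exceeds `SparseConfig::sparsity_threshold`, keeping momentum in
/// per-index hash maps instead of dense buffers.
///
/// A minimal usage sketch, assuming the type is in scope:
///
/// ```ignore
/// let mut opt = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
/// opt.set_lr(0.005);
/// assert_eq!(opt.get_lr(), 0.005);
/// ```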
#[derive(Debug)]
pub struct SparseSGD {
    learning_rate: f32,
    momentum: f32,
    dampening: f32,
    weight_decay: f32,
    nesterov: bool,
    config: SparseConfig,
    momentum_states: HashMap<usize, SparseMomentumState>,
    current_step: usize,
}

impl SparseSGD {
    pub fn new(
        learning_rate: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
        config: SparseConfig,
    ) -> Self {
        Self {
            learning_rate,
            momentum,
            dampening,
            weight_decay,
            nesterov,
            config,
            momentum_states: HashMap::new(),
            current_step: 0,
        }
    }

    pub fn with_default_config(
        learning_rate: f32,
        momentum: f32,
        dampening: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> Self {
        Self::new(
            learning_rate,
            momentum,
            dampening,
            weight_decay,
            nesterov,
            SparseConfig::default(),
        )
    }

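    /// Collects the indices whose gradient magnitude exceeds the sparsity
    /// threshold, optionally truncated to the `max_active_params` largest.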
    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
        let grad_data = gradient.data()?;
        let indices: Vec<usize> = grad_data
            .iter()
            .enumerate()
            .filter_map(|(i, &val)| {
                if val.abs() > self.config.sparsity_threshold {
                    Some(i)
                } else {
                    None
                }
            })
            .collect();

        if let Some(max_active) = self.config.max_active_params {
            if indices.len() > max_active {
                // Keep only the top-k indices by gradient magnitude.
                let mut indexed_grads: Vec<(usize, f32)> =
                    indices.iter().map(|&i| (i, grad_data[i].abs())).collect();
                indexed_grads
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                return Ok(indexed_grads
                    .into_iter()
                    .take(max_active)
                    .map(|(i, _)| i)
                    .collect());
            }
        }

        Ok(indices)
    }

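    /// Performs an SGD-with-momentum update restricted to the sparse indices
    /// of `gradient`, creating momentum entries on demand and recording the
    /// step at which each index was touched.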
    fn update_sparse_momentum(
        &mut self,
        param_id: usize,
        gradient: &Tensor,
        parameter: &mut Tensor,
    ) -> Result<()> {
        let sparse_indices = self.get_sparse_indices(gradient)?;
        if sparse_indices.is_empty() {
            return Ok(());
        }

        let grad_data = gradient.data()?;
        let mut param_data = parameter.data()?;

        let momentum_state = self.momentum_states.entry(param_id).or_default();

        if self.config.lazy_updates {
            // Catch up momentum decay for indices skipped since their last update.
            momentum_state.apply_lazy_update(self.current_step, self.momentum);
        }

        for &idx in &sparse_indices {
            let mut grad_val = grad_data[idx];

            // Fold L2 weight decay into the gradient.
            if self.weight_decay != 0.0 {
                grad_val += self.weight_decay * param_data[idx];
            }

            // Momentum update with dampening on the gradient term.
            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
            let new_momentum = self.momentum * momentum_val + (1.0 - self.dampening) * grad_val;
            momentum_state.momentum.insert(idx, new_momentum);
            momentum_state.last_update.insert(idx, self.current_step);

            let update = if self.nesterov {
                grad_val + self.momentum * new_momentum
            } else {
                new_momentum
            };

            param_data[idx] -= self.learning_rate * update;
        }

        // Write the updated values back into the parameter tensor.
        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;

        Ok(())
    }

    /// Returns, for each parameter id, the number of active momentum entries.
    pub fn get_momentum_stats(&self) -> HashMap<usize, usize> {
        self.momentum_states
            .iter()
            .map(|(&param_id, state)| (param_id, state.num_active()))
            .collect()
    }

    /// Total number of active momentum entries across all parameters.
    pub fn total_active_states(&self) -> usize {
        self.momentum_states.values().map(|s| s.num_active()).sum()
    }

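    /// Every `cleanup_frequency` steps, drops momentum entries older than
    /// twice the cleanup interval and optionally compresses the remaining
    /// state.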
    pub fn cleanup_momentum_states(&mut self) {
        if self.current_step.is_multiple_of(self.config.cleanup_frequency) {
            let max_age = self.config.cleanup_frequency * 2;
            for state in self.momentum_states.values_mut() {
                state.cleanup(max_age, self.current_step);
                if self.config.compress_inactive {
                    state.compress();
                }
            }
        }
    }
}

impl OptimizerState for SparseSGD {
    fn zero_grad(&mut self) -> Result<()> {
        // Gradients live on the parameter tensors themselves, so there is no
        // optimizer-owned buffer to clear here.
        Ok(())
    }

    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            if let Ok(gradient) = parameter.grad() {
                self.update_sparse_momentum(param_id, &gradient, parameter)?;
            }
        }

        self.cleanup_momentum_states();

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();

        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.learning_rate)?,
        );
        state.insert("momentum".to_string(), Tensor::scalar(self.momentum)?);
        state.insert("dampening".to_string(), Tensor::scalar(self.dampening)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.weight_decay)?,
        );
        state.insert(
            "nesterov".to_string(),
            Tensor::scalar(self.nesterov as i32 as f32)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );

        for (&param_id, momentum_state) in &self.momentum_states {
            let num_active = momentum_state.num_active();
            state.insert(
                format!("momentum_state_{}_active_count", param_id),
                Tensor::scalar(num_active as f32)?,
            );
        }

        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr_tensor) = state.get("learning_rate") {
            self.learning_rate = lr_tensor.to_scalar()?;
        }
        if let Some(momentum_tensor) = state.get("momentum") {
            self.momentum = momentum_tensor.to_scalar()?;
        }
        if let Some(dampening_tensor) = state.get("dampening") {
            self.dampening = dampening_tensor.to_scalar()?;
        }
        if let Some(wd_tensor) = state.get("weight_decay") {
            self.weight_decay = wd_tensor.to_scalar()?;
        }
        if let Some(nesterov_tensor) = state.get("nesterov") {
            self.nesterov = nesterov_tensor.to_scalar()? > 0.5;
        }
        if let Some(step_tensor) = state.get("current_step") {
            self.current_step = step_tensor.to_scalar()? as usize;
        }

        Ok(())
    }
}

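/// Adam optimizer that restricts first- and second-moment updates to the
/// sparse indices of each gradient, storing the moments in per-index hash
/// maps rather than dense buffers.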
#[derive(Debug)]
pub struct SparseAdam {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
    weight_decay: f32,
    config: SparseConfig,
    momentum_states: HashMap<usize, SparseMomentumState>,
    variance_states: HashMap<usize, HashMap<usize, f32>>,
    current_step: usize,
}

impl SparseAdam {
    pub fn new(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
        config: SparseConfig,
    ) -> Self {
        Self {
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            config,
            momentum_states: HashMap::new(),
            variance_states: HashMap::new(),
            current_step: 0,
        }
    }

    pub fn with_default_config(
        learning_rate: f32,
        beta1: f32,
        beta2: f32,
        epsilon: f32,
        weight_decay: f32,
    ) -> Self {
        Self::new(
            learning_rate,
            beta1,
            beta2,
            epsilon,
            weight_decay,
            SparseConfig::default(),
        )
    }

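    /// Collects the indices whose gradient magnitude exceeds the sparsity
    /// threshold. Unlike the SGD variant, no `max_active_params` cap is
    /// applied here.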
    fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
        let grad_data = gradient.data()?;
        Ok(grad_data
            .iter()
            .enumerate()
            .filter_map(|(i, &val)| {
                if val.abs() > self.config.sparsity_threshold {
                    Some(i)
                } else {
                    None
                }
            })
            .collect())
    }

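    /// Performs a bias-corrected Adam update restricted to the sparse indices
    /// of `gradient`, maintaining first and second moments only for indices
    /// that have been active.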
    fn update_sparse_adam(
        &mut self,
        param_id: usize,
        gradient: &Tensor,
        parameter: &mut Tensor,
    ) -> Result<()> {
        let sparse_indices = self.get_sparse_indices(gradient)?;
        if sparse_indices.is_empty() {
            return Ok(());
        }

        let grad_data = gradient.data()?;
        let mut param_data = parameter.data()?;

        let momentum_state = self.momentum_states.entry(param_id).or_default();
        let variance_state = self.variance_states.entry(param_id).or_default();

        // Bias corrections use the global step count, as in dense Adam.
        let bias_correction1 = 1.0 - self.beta1.powi(self.current_step as i32);
        let bias_correction2 = 1.0 - self.beta2.powi(self.current_step as i32);

        for &idx in &sparse_indices {
            let mut grad_val = grad_data[idx];

            // Classic L2-style weight decay folded into the gradient.
            if self.weight_decay != 0.0 {
                grad_val += self.weight_decay * param_data[idx];
            }

            // First moment (momentum) update.
            let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
            let new_momentum = self.beta1 * momentum_val + (1.0 - self.beta1) * grad_val;
            momentum_state.momentum.insert(idx, new_momentum);

            // Second moment (uncentered variance) update.
            let variance_val = variance_state.get(&idx).copied().unwrap_or(0.0);
            let new_variance = self.beta2 * variance_val + (1.0 - self.beta2) * grad_val * grad_val;
            variance_state.insert(idx, new_variance);

            let momentum_corrected = new_momentum / bias_correction1;
            let variance_corrected = new_variance / bias_correction2;

            let denom = variance_corrected.sqrt() + self.epsilon;
            param_data[idx] -= self.learning_rate * momentum_corrected / denom;

            momentum_state.last_update.insert(idx, self.current_step);
        }

        *parameter = Tensor::from_vec(param_data, &parameter.shape())?;

        Ok(())
    }
}

impl OptimizerState for SparseAdam {
    fn zero_grad(&mut self) -> Result<()> {
        Ok(())
    }

    fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
        self.current_step += 1;

        for (param_id, parameter) in parameters.iter_mut().enumerate() {
            if let Ok(gradient) = parameter.grad() {
                self.update_sparse_adam(param_id, &gradient, parameter)?;
            }
        }

        Ok(())
    }

    fn get_lr(&self) -> f32 {
        self.learning_rate
    }

    fn set_lr(&mut self, lr: f32) {
        self.learning_rate = lr;
    }

    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
        let mut state = HashMap::new();
        state.insert(
            "learning_rate".to_string(),
            Tensor::scalar(self.learning_rate)?,
        );
        state.insert("beta1".to_string(), Tensor::scalar(self.beta1)?);
        state.insert("beta2".to_string(), Tensor::scalar(self.beta2)?);
        state.insert("epsilon".to_string(), Tensor::scalar(self.epsilon)?);
        state.insert(
            "weight_decay".to_string(),
            Tensor::scalar(self.weight_decay)?,
        );
        state.insert(
            "current_step".to_string(),
            Tensor::scalar(self.current_step as f32)?,
        );
        Ok(state)
    }

    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
        if let Some(lr) = state.get("learning_rate") {
            self.learning_rate = lr.to_scalar()?;
        }
        if let Some(beta1) = state.get("beta1") {
            self.beta1 = beta1.to_scalar()?;
        }
        if let Some(beta2) = state.get("beta2") {
            self.beta2 = beta2.to_scalar()?;
        }
        if let Some(eps) = state.get("epsilon") {
            self.epsilon = eps.to_scalar()?;
        }
        if let Some(wd) = state.get("weight_decay") {
            self.weight_decay = wd.to_scalar()?;
        }
        if let Some(step) = state.get("current_step") {
            self.current_step = step.to_scalar()? as usize;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_config_default() {
        let config = SparseConfig::default();
        assert_eq!(config.sparsity_threshold, 1e-8);
        assert!(config.max_active_params.is_none());
        assert!(config.lazy_updates);
        assert_eq!(config.cleanup_frequency, 1000);
        assert!(!config.compress_inactive);
    }

    #[test]
    fn test_sparse_momentum_state() {
        let mut state = SparseMomentumState::new();
        assert_eq!(state.num_active(), 0);

        state.momentum.insert(0, 1.0);
        state.momentum.insert(5, 2.0);
        assert_eq!(state.num_active(), 2);

        // cleanup only inspects `last_update`, which is empty here, so both
        // momentum entries survive.
        state.cleanup(0, 100);
        assert_eq!(state.num_active(), 2);
    }

    #[test]
    fn test_sparse_sgd_creation() {
        let optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);
        assert_eq!(optimizer.total_active_states(), 0);
    }

    #[test]
    fn test_sparse_adam_creation() {
        let optimizer = SparseAdam::with_default_config(1e-3, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
    }

    #[test]
    fn test_sparse_sgd_lr_update() {
        let mut optimizer = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(optimizer.get_lr(), 0.01);

        optimizer.set_lr(0.001);
        assert_eq!(optimizer.get_lr(), 0.001);
    }
}