1use crate::optimizer::OptimizerState;
14use anyhow::Result;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use trustformers_core::tensor::Tensor;
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SparseConfig {
22 pub sparsity_threshold: f32,
24 pub max_active_params: Option<usize>,
26 pub lazy_updates: bool,
28 pub cleanup_frequency: usize,
30 pub compress_inactive: bool,
32}
33
34impl Default for SparseConfig {
35 fn default() -> Self {
36 Self {
37 sparsity_threshold: 1e-8,
38 max_active_params: None,
39 lazy_updates: true,
40 cleanup_frequency: 1000,
41 compress_inactive: false,
42 }
43 }
44}
45
/// Sparse per-tensor optimizer state, keyed by flat element index.
///
/// Only indices that have been touched by an update carry an entry, so the
/// memory footprint follows the number of active elements rather than the
/// size of the tensor.
#[derive(Debug, Clone)]
pub struct SparseMomentumState {
    /// Momentum value for each active index.
    pub momentum: HashMap<usize, f32>,
    /// Optimizer step at which each index was last updated.
    pub last_update: HashMap<usize, usize>,
    /// Per-index accumulator; presumably accumulated gradient norms, but
    /// nothing in this file writes to it (TODO confirm against callers).
    pub grad_norm_acc: HashMap<usize, f32>,
    /// Set by [`compress`](Self::compress) and cleared by
    /// [`decompress`](Self::decompress).
    pub is_compressed: bool,
}

impl Default for SparseMomentumState {
    fn default() -> Self {
        Self::new()
    }
}

impl SparseMomentumState {
    /// Entries whose magnitude does not exceed this value are dropped by
    /// [`compress`](Self::compress).
    const COMPRESS_THRESHOLD: f32 = 1e-10;

    /// Creates an empty, uncompressed state.
    pub fn new() -> Self {
        Self {
            momentum: HashMap::new(),
            last_update: HashMap::new(),
            grad_norm_acc: HashMap::new(),
            is_compressed: false,
        }
    }

    /// Number of indices that currently hold a momentum value.
    pub fn num_active(&self) -> usize {
        self.momentum.len()
    }

    /// Decays every momentum entry by `decay` once for each step it missed
    /// between its last update and `current_step` (exclusive on both ends).
    ///
    /// Entries without a timestamp, and entries whose timestamp is at or
    /// after `current_step`, are left untouched. The subtraction saturates:
    /// a plain `current_step - last_step - 1` underflows in that situation,
    /// which panics in debug builds and yields a bogus exponent in release
    /// builds.
    ///
    /// NOTE(review): timestamps are not advanced here, so calling this on
    /// consecutive steps applies the decay again to entries that were not
    /// updated in between.
    pub fn apply_lazy_update(&mut self, current_step: usize, decay: f32) {
        let last_update = &self.last_update;
        for (idx, momentum) in self.momentum.iter_mut() {
            if let Some(&last_step) = last_update.get(idx) {
                let steps_skipped = current_step.saturating_sub(last_step).saturating_sub(1);
                if steps_skipped > 0 {
                    // Clamp instead of truncating with `as`, which could wrap
                    // a very large gap into a negative exponent.
                    let exponent = i32::try_from(steps_skipped).unwrap_or(i32::MAX);
                    *momentum *= decay.powi(exponent);
                }
            }
        }
    }

    /// Drops every index whose last update is more than `max_age_steps`
    /// steps older than `current_step`.
    ///
    /// Staleness is decided from `last_update` only: indices without a
    /// timestamp are never removed. Timestamps ahead of `current_step`
    /// count as age zero.
    pub fn cleanup(&mut self, max_age_steps: usize, current_step: usize) {
        let Self {
            momentum,
            last_update,
            grad_norm_acc,
            ..
        } = self;

        last_update.retain(|idx, last_step| {
            let is_stale = current_step.saturating_sub(*last_step) > max_age_steps;
            if is_stale {
                momentum.remove(idx);
                grad_norm_acc.remove(idx);
            }
            !is_stale
        });
    }

    /// Removes near-zero momentum and accumulator entries and marks the
    /// state as compressed.
    ///
    /// This is a no-op while the state is already flagged as compressed;
    /// call [`decompress`](Self::decompress) first to allow another pass.
    pub fn compress(&mut self) {
        if self.is_compressed {
            return;
        }

        self.momentum
            .retain(|_, value| value.abs() > Self::COMPRESS_THRESHOLD);
        self.grad_norm_acc
            .retain(|_, value| *value > Self::COMPRESS_THRESHOLD);

        self.is_compressed = true;
    }

    /// Clears the compressed flag; entries removed by
    /// [`compress`](Self::compress) are not restored.
    pub fn decompress(&mut self) {
        self.is_compressed = false;
    }
}
129
130#[derive(Debug)]
132pub struct SparseSGD {
133 learning_rate: f32,
134 momentum: f32,
135 dampening: f32,
136 weight_decay: f32,
137 nesterov: bool,
138 config: SparseConfig,
139 momentum_states: HashMap<usize, SparseMomentumState>,
140 current_step: usize,
141}
142
143impl SparseSGD {
144 pub fn new(
145 learning_rate: f32,
146 momentum: f32,
147 dampening: f32,
148 weight_decay: f32,
149 nesterov: bool,
150 config: SparseConfig,
151 ) -> Self {
152 Self {
153 learning_rate,
154 momentum,
155 dampening,
156 weight_decay,
157 nesterov,
158 config,
159 momentum_states: HashMap::new(),
160 current_step: 0,
161 }
162 }
163
164 pub fn with_default_config(
166 learning_rate: f32,
167 momentum: f32,
168 dampening: f32,
169 weight_decay: f32,
170 nesterov: bool,
171 ) -> Self {
172 Self::new(
173 learning_rate,
174 momentum,
175 dampening,
176 weight_decay,
177 nesterov,
178 SparseConfig::default(),
179 )
180 }
181
182 fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
184 let grad_data = gradient.data()?;
185 let indices: Vec<usize> = grad_data
186 .iter()
187 .enumerate()
188 .filter_map(
189 |(i, &val)| {
190 if val.abs() > self.config.sparsity_threshold {
191 Some(i)
192 } else {
193 None
194 }
195 },
196 )
197 .collect();
198
199 if let Some(max_active) = self.config.max_active_params {
201 if indices.len() > max_active {
202 let mut indexed_grads: Vec<(usize, f32)> =
204 indices.iter().map(|&i| (i, grad_data[i].abs())).collect();
205 indexed_grads.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
206 return Ok(indexed_grads.into_iter().take(max_active).map(|(i, _)| i).collect());
207 }
208 }
209
210 Ok(indices)
211 }
212
213 fn update_sparse_momentum(
215 &mut self,
216 param_id: usize,
217 gradient: &Tensor,
218 parameter: &mut Tensor,
219 ) -> Result<()> {
220 let sparse_indices = self.get_sparse_indices(gradient)?;
221 if sparse_indices.is_empty() {
222 return Ok(());
223 }
224
225 let grad_data = gradient.data()?;
226 let mut param_data = parameter.data()?;
227
228 let momentum_state = self.momentum_states.entry(param_id).or_default();
230
231 if self.config.lazy_updates {
233 momentum_state.apply_lazy_update(self.current_step, self.momentum);
234 }
235
236 for &idx in &sparse_indices {
238 let mut grad_val = grad_data[idx];
239
240 if self.weight_decay != 0.0 {
242 grad_val += self.weight_decay * param_data[idx];
243 }
244
245 let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
247 let new_momentum = self.momentum * momentum_val + (1.0 - self.dampening) * grad_val;
248 momentum_state.momentum.insert(idx, new_momentum);
249 momentum_state.last_update.insert(idx, self.current_step);
250
251 let update = if self.nesterov {
253 grad_val + self.momentum * new_momentum
254 } else {
255 new_momentum
256 };
257
258 param_data[idx] -= self.learning_rate * update;
259 }
260
261 *parameter = Tensor::from_vec(param_data, ¶meter.shape())?;
263
264 Ok(())
265 }
266
267 pub fn get_momentum_stats(&self) -> HashMap<usize, usize> {
269 self.momentum_states
270 .iter()
271 .map(|(¶m_id, state)| (param_id, state.num_active()))
272 .collect()
273 }
274
275 pub fn total_active_states(&self) -> usize {
277 self.momentum_states.values().map(|s| s.num_active()).sum()
278 }
279
280 pub fn cleanup_momentum_states(&mut self) {
282 if self.current_step % self.config.cleanup_frequency == 0 {
283 let max_age = self.config.cleanup_frequency * 2;
284 for state in self.momentum_states.values_mut() {
285 state.cleanup(max_age, self.current_step);
286 if self.config.compress_inactive {
287 state.compress();
288 }
289 }
290 }
291 }
292}
293
294impl OptimizerState for SparseSGD {
295 fn zero_grad(&mut self) -> Result<()> {
296 Ok(())
299 }
300
301 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
302 self.current_step += 1;
303
304 for (param_id, parameter) in parameters.iter_mut().enumerate() {
305 if let Ok(gradient) = parameter.grad() {
308 self.update_sparse_momentum(param_id, &gradient, parameter)?;
309 }
310 }
311
312 self.cleanup_momentum_states();
314
315 Ok(())
316 }
317
318 fn get_lr(&self) -> f32 {
319 self.learning_rate
320 }
321
322 fn set_lr(&mut self, lr: f32) {
323 self.learning_rate = lr;
324 }
325
326 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
327 let mut state = HashMap::new();
328
329 state.insert(
331 "learning_rate".to_string(),
332 Tensor::scalar(self.learning_rate)?,
333 );
334 state.insert("momentum".to_string(), Tensor::scalar(self.momentum)?);
335 state.insert("dampening".to_string(), Tensor::scalar(self.dampening)?);
336 state.insert(
337 "weight_decay".to_string(),
338 Tensor::scalar(self.weight_decay)?,
339 );
340 state.insert(
341 "nesterov".to_string(),
342 Tensor::scalar(self.nesterov as i32 as f32)?,
343 );
344 state.insert(
345 "current_step".to_string(),
346 Tensor::scalar(self.current_step as f32)?,
347 );
348
349 for (¶m_id, momentum_state) in &self.momentum_states {
351 let num_active = momentum_state.num_active();
352 state.insert(
353 format!("momentum_state_{}_active_count", param_id),
354 Tensor::scalar(num_active as f32)?,
355 );
356 }
357
358 Ok(state)
359 }
360
361 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
362 if let Some(lr_tensor) = state.get("learning_rate") {
364 self.learning_rate = lr_tensor.to_scalar()?;
365 }
366 if let Some(momentum_tensor) = state.get("momentum") {
367 self.momentum = momentum_tensor.to_scalar()?;
368 }
369 if let Some(dampening_tensor) = state.get("dampening") {
370 self.dampening = dampening_tensor.to_scalar()?;
371 }
372 if let Some(wd_tensor) = state.get("weight_decay") {
373 self.weight_decay = wd_tensor.to_scalar()?;
374 }
375 if let Some(nesterov_tensor) = state.get("nesterov") {
376 self.nesterov = nesterov_tensor.to_scalar()? > 0.5;
377 }
378 if let Some(step_tensor) = state.get("current_step") {
379 self.current_step = step_tensor.to_scalar()? as usize;
380 }
381
382 Ok(())
386 }
387}
388
389#[derive(Debug)]
391pub struct SparseAdam {
392 learning_rate: f32,
393 beta1: f32,
394 beta2: f32,
395 epsilon: f32,
396 weight_decay: f32,
397 config: SparseConfig,
398 momentum_states: HashMap<usize, SparseMomentumState>,
399 variance_states: HashMap<usize, HashMap<usize, f32>>,
400 current_step: usize,
401}
402
403impl SparseAdam {
404 pub fn new(
405 learning_rate: f32,
406 beta1: f32,
407 beta2: f32,
408 epsilon: f32,
409 weight_decay: f32,
410 config: SparseConfig,
411 ) -> Self {
412 Self {
413 learning_rate,
414 beta1,
415 beta2,
416 epsilon,
417 weight_decay,
418 config,
419 momentum_states: HashMap::new(),
420 variance_states: HashMap::new(),
421 current_step: 0,
422 }
423 }
424
425 pub fn with_default_config(
426 learning_rate: f32,
427 beta1: f32,
428 beta2: f32,
429 epsilon: f32,
430 weight_decay: f32,
431 ) -> Self {
432 Self::new(
433 learning_rate,
434 beta1,
435 beta2,
436 epsilon,
437 weight_decay,
438 SparseConfig::default(),
439 )
440 }
441
442 fn get_sparse_indices(&self, gradient: &Tensor) -> Result<Vec<usize>> {
443 let grad_data = gradient.data()?;
444 Ok(grad_data
445 .iter()
446 .enumerate()
447 .filter_map(
448 |(i, &val)| {
449 if val.abs() > self.config.sparsity_threshold {
450 Some(i)
451 } else {
452 None
453 }
454 },
455 )
456 .collect())
457 }
458
459 fn update_sparse_adam(
460 &mut self,
461 param_id: usize,
462 gradient: &Tensor,
463 parameter: &mut Tensor,
464 ) -> Result<()> {
465 let sparse_indices = self.get_sparse_indices(gradient)?;
466 if sparse_indices.is_empty() {
467 return Ok(());
468 }
469
470 let grad_data = gradient.data()?;
471 let mut param_data = parameter.data()?;
472
473 let momentum_state = self.momentum_states.entry(param_id).or_default();
475 let variance_state = self.variance_states.entry(param_id).or_default();
476
477 let bias_correction1 = 1.0 - self.beta1.powi(self.current_step as i32);
479 let bias_correction2 = 1.0 - self.beta2.powi(self.current_step as i32);
480
481 for &idx in &sparse_indices {
483 let mut grad_val = grad_data[idx];
484
485 if self.weight_decay != 0.0 {
487 grad_val += self.weight_decay * param_data[idx];
488 }
489
490 let momentum_val = momentum_state.momentum.get(&idx).copied().unwrap_or(0.0);
492 let new_momentum = self.beta1 * momentum_val + (1.0 - self.beta1) * grad_val;
493 momentum_state.momentum.insert(idx, new_momentum);
494
495 let variance_val = variance_state.get(&idx).copied().unwrap_or(0.0);
497 let new_variance = self.beta2 * variance_val + (1.0 - self.beta2) * grad_val * grad_val;
498 variance_state.insert(idx, new_variance);
499
500 let momentum_corrected = new_momentum / bias_correction1;
502 let variance_corrected = new_variance / bias_correction2;
503
504 let denom = variance_corrected.sqrt() + self.epsilon;
506 param_data[idx] -= self.learning_rate * momentum_corrected / denom;
507
508 momentum_state.last_update.insert(idx, self.current_step);
509 }
510
511 *parameter = Tensor::from_vec(param_data, ¶meter.shape())?;
513
514 Ok(())
515 }
516}
517
518impl OptimizerState for SparseAdam {
519 fn zero_grad(&mut self) -> Result<()> {
520 Ok(())
521 }
522
523 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
524 self.current_step += 1;
525
526 for (param_id, parameter) in parameters.iter_mut().enumerate() {
527 if let Ok(gradient) = parameter.grad() {
528 self.update_sparse_adam(param_id, &gradient, parameter)?;
529 }
530 }
531
532 Ok(())
533 }
534
535 fn get_lr(&self) -> f32 {
536 self.learning_rate
537 }
538
539 fn set_lr(&mut self, lr: f32) {
540 self.learning_rate = lr;
541 }
542
543 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
544 let mut state = HashMap::new();
545 state.insert(
546 "learning_rate".to_string(),
547 Tensor::scalar(self.learning_rate)?,
548 );
549 state.insert("beta1".to_string(), Tensor::scalar(self.beta1)?);
550 state.insert("beta2".to_string(), Tensor::scalar(self.beta2)?);
551 state.insert("epsilon".to_string(), Tensor::scalar(self.epsilon)?);
552 state.insert(
553 "weight_decay".to_string(),
554 Tensor::scalar(self.weight_decay)?,
555 );
556 state.insert(
557 "current_step".to_string(),
558 Tensor::scalar(self.current_step as f32)?,
559 );
560 Ok(state)
561 }
562
563 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
564 if let Some(lr) = state.get("learning_rate") {
565 self.learning_rate = lr.to_scalar()?;
566 }
567 if let Some(beta1) = state.get("beta1") {
568 self.beta1 = beta1.to_scalar()?;
569 }
570 if let Some(beta2) = state.get("beta2") {
571 self.beta2 = beta2.to_scalar()?;
572 }
573 if let Some(eps) = state.get("epsilon") {
574 self.epsilon = eps.to_scalar()?;
575 }
576 if let Some(wd) = state.get("weight_decay") {
577 self.weight_decay = wd.to_scalar()?;
578 }
579 if let Some(step) = state.get("current_step") {
580 self.current_step = step.to_scalar()? as usize;
581 }
582 Ok(())
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// `SparseConfig::default()` yields the expected baseline values.
    #[test]
    fn test_sparse_config_default() {
        let SparseConfig {
            sparsity_threshold,
            max_active_params,
            lazy_updates,
            cleanup_frequency,
            compress_inactive,
        } = SparseConfig::default();

        assert_eq!(sparsity_threshold, 1e-8);
        assert!(max_active_params.is_none());
        assert!(lazy_updates);
        assert_eq!(cleanup_frequency, 1000);
        assert!(!compress_inactive);
    }

    /// Momentum entries are counted, and entries without a timestamp survive
    /// a clean-up pass.
    #[test]
    fn test_sparse_momentum_state() {
        let mut state = SparseMomentumState::new();
        assert_eq!(state.num_active(), 0);

        for (index, value) in [(0usize, 1.0f32), (5, 2.0)].iter().copied() {
            state.momentum.insert(index, value);
        }
        assert_eq!(state.num_active(), 2);

        // No `last_update` entries exist, so nothing qualifies as stale.
        state.cleanup(0, 100);
        assert_eq!(state.num_active(), 2);
    }

    /// A freshly built SGD optimizer reports its learning rate and no state.
    #[test]
    fn test_sparse_sgd_creation() {
        let sgd = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(sgd.get_lr(), 0.01);
        assert_eq!(sgd.total_active_states(), 0);
    }

    /// A freshly built Adam optimizer starts at step zero.
    #[test]
    fn test_sparse_adam_creation() {
        let adam = SparseAdam::with_default_config(1e-3, 0.9, 0.999, 1e-8, 0.01);
        assert_eq!(adam.get_lr(), 1e-3);
        assert_eq!(adam.current_step, 0);
    }

    /// `set_lr` is reflected by `get_lr`.
    #[test]
    fn test_sparse_sgd_lr_update() {
        let mut sgd = SparseSGD::with_default_config(0.01, 0.9, 0.0, 1e-4, false);
        assert_eq!(sgd.get_lr(), 0.01);

        sgd.set_lr(0.001);
        assert_eq!(sgd.get_lr(), 0.001);
    }
}