1use crate::optimizer::OptimizerState;
20use anyhow::{anyhow, Result};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use trustformers_core::tensor::Tensor;
24
/// Hyper-parameters for the [`QHM`] optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QHMConfig {
    /// Step size applied to the combined update direction.
    pub learning_rate: f32,
    /// Decay factor of the per-parameter momentum buffer.
    pub momentum: f32,
    /// Weight given to the raw gradient in the update direction; the momentum
    /// buffer receives `1 - nu`.
    pub nu: f32,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient before any other processing.
    pub weight_decay: f32,
}
37
38impl Default for QHMConfig {
39 fn default() -> Self {
40 Self {
41 learning_rate: 1e-3,
42 momentum: 0.9,
43 nu: 0.7,
44 weight_decay: 0.0,
45 }
46 }
47}
48
/// Quasi-hyperbolic momentum optimizer state.
#[derive(Debug)]
pub struct QHM {
    /// Active hyper-parameters.
    config: QHMConfig,
    /// Momentum buffer per parameter, keyed by the parameter's index in the
    /// slice passed to `step`.
    momentum_buffers: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
}
60
61impl QHM {
62 pub fn new(config: QHMConfig) -> Self {
64 Self {
65 config,
66 momentum_buffers: HashMap::new(),
67 current_step: 0,
68 }
69 }
70
71 pub fn with_defaults(learning_rate: f32, momentum: f32, nu: f32) -> Self {
73 Self::new(QHMConfig {
74 learning_rate,
75 momentum,
76 nu,
77 weight_decay: 0.0,
78 })
79 }
80
81 pub fn get_config(&self) -> &QHMConfig {
83 &self.config
84 }
85
86 pub fn set_config(&mut self, config: QHMConfig) {
88 self.config = config;
89 }
90}
91
92impl OptimizerState for QHM {
93 fn zero_grad(&mut self) -> Result<()> {
94 Ok(())
96 }
97
98 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
99 self.current_step += 1;
100
101 for (param_id, parameter) in parameters.iter_mut().enumerate() {
102 let gradient = match parameter.grad() {
104 Ok(grad) => grad,
105 Err(_) => {
106 continue;
108 },
109 };
110
111 let effective_grad = if self.config.weight_decay > 0.0 {
113 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
114 } else {
115 gradient
116 };
117
118 let momentum_buffer = if let Some(buffer) = self.momentum_buffers.get(¶m_id) {
120 let updated = buffer
122 .mul_scalar(self.config.momentum)?
123 .add(&effective_grad.mul_scalar(1.0 - self.config.momentum)?)?;
124 self.momentum_buffers.insert(param_id, updated.clone());
125 updated
126 } else {
127 let initial_momentum = effective_grad.clone();
129 self.momentum_buffers.insert(param_id, initial_momentum.clone());
130 initial_momentum
131 };
132
133 let update_direction = effective_grad
135 .mul_scalar(self.config.nu)?
136 .add(&momentum_buffer.mul_scalar(1.0 - self.config.nu)?)?;
137
138 *parameter = parameter.sub(&update_direction.mul_scalar(self.config.learning_rate)?)?;
140 }
141
142 Ok(())
143 }
144
145 fn get_lr(&self) -> f32 {
146 self.config.learning_rate
147 }
148
149 fn set_lr(&mut self, lr: f32) {
150 self.config.learning_rate = lr;
151 }
152
153 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
154 let mut state = HashMap::new();
155
156 state.insert(
158 "learning_rate".to_string(),
159 Tensor::scalar(self.config.learning_rate)?,
160 );
161 state.insert(
162 "momentum".to_string(),
163 Tensor::scalar(self.config.momentum)?,
164 );
165 state.insert("nu".to_string(), Tensor::scalar(self.config.nu)?);
166 state.insert(
167 "weight_decay".to_string(),
168 Tensor::scalar(self.config.weight_decay)?,
169 );
170 state.insert(
171 "current_step".to_string(),
172 Tensor::scalar(self.current_step as f32)?,
173 );
174
175 for (¶m_id, buffer) in &self.momentum_buffers {
177 state.insert(format!("momentum_buffer_{}", param_id), buffer.clone());
178 }
179
180 Ok(state)
181 }
182
183 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
184 if let Some(lr) = state.get("learning_rate") {
186 self.config.learning_rate = lr.to_scalar()?;
187 }
188 if let Some(momentum) = state.get("momentum") {
189 self.config.momentum = momentum.to_scalar()?;
190 }
191 if let Some(nu) = state.get("nu") {
192 self.config.nu = nu.to_scalar()?;
193 }
194 if let Some(wd) = state.get("weight_decay") {
195 self.config.weight_decay = wd.to_scalar()?;
196 }
197 if let Some(step) = state.get("current_step") {
198 self.current_step = step.to_scalar()? as usize;
199 }
200
201 self.momentum_buffers.clear();
203 for (key, tensor) in state {
204 if let Some(param_id_str) = key.strip_prefix("momentum_buffer_") {
205 if let Ok(param_id) = param_id_str.parse::<usize>() {
206 self.momentum_buffers.insert(param_id, tensor);
207 }
208 }
209 }
210
211 Ok(())
212 }
213}
214
/// Hyper-parameters for the [`AggMo`] (aggregated momentum) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggMoConfig {
    /// Step size applied to the averaged momentum.
    pub learning_rate: f32,
    /// One decay factor per momentum buffer; must not be empty.
    pub momentum_coefficients: Vec<f32>,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient.
    pub weight_decay: f32,
}
225
226impl Default for AggMoConfig {
227 fn default() -> Self {
228 Self {
229 learning_rate: 1e-3,
230 momentum_coefficients: vec![0.0, 0.9, 0.99],
231 weight_decay: 0.0,
232 }
233 }
234}
235
/// Aggregated-momentum optimizer state.
#[derive(Debug)]
pub struct AggMo {
    /// Active hyper-parameters.
    config: AggMoConfig,
    /// For each parameter index, one momentum buffer per configured
    /// coefficient (same order as `config.momentum_coefficients`).
    momentum_buffers: HashMap<usize, Vec<Tensor>>,
    /// Number of `step` calls performed so far.
    current_step: usize,
}
246
247impl AggMo {
248 pub fn new(config: AggMoConfig) -> Self {
250 assert!(
251 !config.momentum_coefficients.is_empty(),
252 "Must provide at least one momentum coefficient"
253 );
254 Self {
255 config,
256 momentum_buffers: HashMap::new(),
257 current_step: 0,
258 }
259 }
260
261 pub fn with_defaults(learning_rate: f32, momentum_coefficients: Vec<f32>) -> Self {
263 Self::new(AggMoConfig {
264 learning_rate,
265 momentum_coefficients,
266 weight_decay: 0.0,
267 })
268 }
269
270 pub fn get_config(&self) -> &AggMoConfig {
272 &self.config
273 }
274
275 pub fn num_momentum_buffers(&self) -> usize {
277 self.config.momentum_coefficients.len()
278 }
279}
280
281impl OptimizerState for AggMo {
282 fn zero_grad(&mut self) -> Result<()> {
283 Ok(())
284 }
285
286 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
287 self.current_step += 1;
288
289 for (param_id, parameter) in parameters.iter_mut().enumerate() {
290 let gradient = match parameter.grad() {
292 Ok(grad) => grad,
293 Err(_) => {
294 continue;
296 },
297 };
298
299 let effective_grad = if self.config.weight_decay > 0.0 {
301 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
302 } else {
303 gradient
304 };
305
306 let buffers = self.momentum_buffers.entry(param_id).or_insert_with(|| {
308 (0..self.config.momentum_coefficients.len())
310 .map(|_| {
311 Tensor::zeros(&effective_grad.shape())
312 .expect("zeros should always succeed for valid gradient shape")
313 })
314 .collect()
315 });
316
317 let mut aggregated_momentum = Tensor::zeros(&effective_grad.shape())?;
319 for (i, &beta) in self.config.momentum_coefficients.iter().enumerate() {
320 buffers[i] =
322 buffers[i].mul_scalar(beta)?.add(&effective_grad.mul_scalar(1.0 - beta)?)?;
323
324 aggregated_momentum = aggregated_momentum.add(&buffers[i])?;
326 }
327
328 let num_buffers = self.config.momentum_coefficients.len() as f32;
330 let averaged_momentum = aggregated_momentum.div_scalar(num_buffers)?;
331
332 *parameter =
334 parameter.sub(&averaged_momentum.mul_scalar(self.config.learning_rate)?)?;
335 }
336
337 Ok(())
338 }
339
340 fn get_lr(&self) -> f32 {
341 self.config.learning_rate
342 }
343
344 fn set_lr(&mut self, lr: f32) {
345 self.config.learning_rate = lr;
346 }
347
348 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
349 let mut state = HashMap::new();
350
351 state.insert(
353 "learning_rate".to_string(),
354 Tensor::scalar(self.config.learning_rate)?,
355 );
356 state.insert(
357 "weight_decay".to_string(),
358 Tensor::scalar(self.config.weight_decay)?,
359 );
360 state.insert(
361 "current_step".to_string(),
362 Tensor::scalar(self.current_step as f32)?,
363 );
364 state.insert(
365 "num_momentum_coeffs".to_string(),
366 Tensor::scalar(self.config.momentum_coefficients.len() as f32)?,
367 );
368
369 for (i, &coeff) in self.config.momentum_coefficients.iter().enumerate() {
371 state.insert(format!("momentum_coeff_{}", i), Tensor::scalar(coeff)?);
372 }
373
374 for (¶m_id, buffers) in &self.momentum_buffers {
376 for (buffer_idx, buffer) in buffers.iter().enumerate() {
377 state.insert(
378 format!("momentum_buffer_{}_{}", param_id, buffer_idx),
379 buffer.clone(),
380 );
381 }
382 }
383
384 Ok(state)
385 }
386
387 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
388 if let Some(lr) = state.get("learning_rate") {
390 self.config.learning_rate = lr.to_scalar()?;
391 }
392 if let Some(wd) = state.get("weight_decay") {
393 self.config.weight_decay = wd.to_scalar()?;
394 }
395 if let Some(step) = state.get("current_step") {
396 self.current_step = step.to_scalar()? as usize;
397 }
398
399 if let Some(num_coeffs_tensor) = state.get("num_momentum_coeffs") {
401 let num_coeffs = num_coeffs_tensor.to_scalar()? as usize;
402 let mut coefficients = Vec::with_capacity(num_coeffs);
403 for i in 0..num_coeffs {
404 if let Some(coeff_tensor) = state.get(&format!("momentum_coeff_{}", i)) {
405 coefficients.push(coeff_tensor.to_scalar()?);
406 }
407 }
408 self.config.momentum_coefficients = coefficients;
409 }
410
411 self.momentum_buffers.clear();
413 let mut param_buffers: HashMap<usize, HashMap<usize, Tensor>> = HashMap::new();
414
415 for (key, tensor) in state {
416 if key.starts_with("momentum_buffer_") {
417 let parts: Vec<&str> = key.split('_').collect();
418 if parts.len() >= 4 {
419 if let (Ok(param_id), Ok(buffer_idx)) =
420 (parts[2].parse::<usize>(), parts[3].parse::<usize>())
421 {
422 param_buffers.entry(param_id).or_default().insert(buffer_idx, tensor);
423 }
424 }
425 }
426 }
427
428 for (param_id, buffer_map) in param_buffers {
430 let mut buffers = Vec::new();
431 for i in 0..self.config.momentum_coefficients.len() {
432 if let Some(buffer) = buffer_map.get(&i) {
433 buffers.push(buffer.clone());
434 }
435 }
436 if buffers.len() == self.config.momentum_coefficients.len() {
437 self.momentum_buffers.insert(param_id, buffers);
438 }
439 }
440
441 Ok(())
442 }
443}
444
/// Hyper-parameters for the [`VarianceReduction`] optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VarianceReductionConfig {
    /// Step size applied to the variance-reduced gradient.
    pub learning_rate: f32,
    /// Which variance-reduction scheme `step` applies.
    pub method: VarianceReductionMethod,
    /// Maximum number of past gradients retained per parameter.
    pub history_size: usize,
    /// Minimum number of steps between two refreshes of the stored full
    /// gradient (consulted only for SVRG).
    pub full_grad_frequency: usize,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient.
    pub weight_decay: f32,
}
459
460impl Default for VarianceReductionConfig {
461 fn default() -> Self {
462 Self {
463 learning_rate: 1e-3,
464 method: VarianceReductionMethod::SVRG,
465 history_size: 100,
466 full_grad_frequency: 10,
467 weight_decay: 0.0,
468 }
469 }
470}
471
/// Variance-reduction scheme applied by [`VarianceReduction`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VarianceReductionMethod {
    /// Corrects the current gradient with the stored full gradient and the
    /// average of the gradient history.
    SVRG,
    /// Steps along the average of the stored gradient history.
    SAG,
}
480
/// Variance-reduced gradient optimizer state.
#[derive(Debug)]
pub struct VarianceReduction {
    /// Active hyper-parameters.
    config: VarianceReductionConfig,
    /// Most recent gradients per parameter index, oldest first.
    gradient_history: HashMap<usize, Vec<Tensor>>,
    /// Last computed history average per parameter index.
    average_gradients: HashMap<usize, Tensor>,
    /// Gradient captured at the last full-gradient refresh, per parameter index.
    full_gradients: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
    /// Value of `current_step` at the last full-gradient refresh.
    last_full_grad_step: usize,
}
491
492impl VarianceReduction {
493 pub fn new(config: VarianceReductionConfig) -> Self {
495 Self {
496 config,
497 gradient_history: HashMap::new(),
498 average_gradients: HashMap::new(),
499 full_gradients: HashMap::new(),
500 current_step: 0,
501 last_full_grad_step: 0,
502 }
503 }
504
505 pub fn svrg(learning_rate: f32, history_size: usize, full_grad_frequency: usize) -> Self {
507 Self::new(VarianceReductionConfig {
508 learning_rate,
509 method: VarianceReductionMethod::SVRG,
510 history_size,
511 full_grad_frequency,
512 weight_decay: 0.0,
513 })
514 }
515
516 pub fn sag(learning_rate: f32, history_size: usize) -> Self {
518 Self::new(VarianceReductionConfig {
519 learning_rate,
520 method: VarianceReductionMethod::SAG,
521 history_size,
522 full_grad_frequency: 1, weight_decay: 0.0,
524 })
525 }
526
527 fn update_gradient_history(&mut self, param_id: usize, gradient: &Tensor) -> Result<()> {
528 let history = self.gradient_history.entry(param_id).or_default();
529
530 history.push(gradient.clone());
531 if history.len() > self.config.history_size {
532 history.remove(0);
533 }
534
535 Ok(())
536 }
537
538 fn compute_average_gradient(&mut self, param_id: usize) -> Result<Tensor> {
539 if let Some(history) = self.gradient_history.get(¶m_id) {
540 if history.is_empty() {
541 return Err(anyhow!("No gradient history available"));
542 }
543
544 let mut sum = history[0].clone();
545 for grad in history.iter().skip(1) {
546 sum = sum.add(grad)?;
547 }
548
549 let average = sum.div_scalar(history.len() as f32)?;
550 self.average_gradients.insert(param_id, average.clone());
551 Ok(average)
552 } else {
553 Err(anyhow!("No gradient history for parameter {}", param_id))
554 }
555 }
556
557 fn should_compute_full_gradient(&self) -> bool {
558 self.current_step - self.last_full_grad_step >= self.config.full_grad_frequency
559 }
560}
561
562impl OptimizerState for VarianceReduction {
563 fn zero_grad(&mut self) -> Result<()> {
564 Ok(())
565 }
566
567 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
568 self.current_step += 1;
569
570 let compute_full_grad = match self.config.method {
572 VarianceReductionMethod::SVRG => self.should_compute_full_gradient(),
573 VarianceReductionMethod::SAG => false,
574 };
575
576 if compute_full_grad {
577 self.last_full_grad_step = self.current_step;
578 for (param_id, parameter) in parameters.iter().enumerate() {
581 let gradient = match parameter.grad() {
583 Ok(grad) => grad,
584 Err(_) => {
585 continue;
587 },
588 };
589 self.full_gradients.insert(param_id, gradient);
590 }
591 }
592
593 for (param_id, parameter) in parameters.iter_mut().enumerate() {
594 let current_gradient = match parameter.grad() {
596 Ok(grad) => grad,
597 Err(_) => {
598 continue;
600 },
601 };
602
603 let effective_grad = if self.config.weight_decay > 0.0 {
605 current_gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
606 } else {
607 current_gradient
608 };
609
610 self.update_gradient_history(param_id, &effective_grad)?;
612
613 let variance_reduced_grad = match self.config.method {
615 VarianceReductionMethod::SVRG => {
616 let full_grad_opt = self.full_gradients.get(¶m_id).cloned();
618 if let Some(full_grad) = full_grad_opt {
619 let avg_grad = self.compute_average_gradient(param_id)?;
620 effective_grad.sub(&avg_grad)?.add(&full_grad)?
622 } else {
623 effective_grad
624 }
625 },
626 VarianceReductionMethod::SAG => {
627 self.compute_average_gradient(param_id)?
629 },
630 };
631
632 *parameter =
634 parameter.sub(&variance_reduced_grad.mul_scalar(self.config.learning_rate)?)?;
635 }
636
637 Ok(())
638 }
639
640 fn get_lr(&self) -> f32 {
641 self.config.learning_rate
642 }
643
644 fn set_lr(&mut self, lr: f32) {
645 self.config.learning_rate = lr;
646 }
647
648 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
649 let mut state = HashMap::new();
650
651 state.insert(
652 "learning_rate".to_string(),
653 Tensor::scalar(self.config.learning_rate)?,
654 );
655 state.insert(
656 "current_step".to_string(),
657 Tensor::scalar(self.current_step as f32)?,
658 );
659 state.insert(
660 "last_full_grad_step".to_string(),
661 Tensor::scalar(self.last_full_grad_step as f32)?,
662 );
663
664 Ok(state)
668 }
669
670 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
671 if let Some(lr) = state.get("learning_rate") {
672 self.config.learning_rate = lr.to_scalar()?;
673 }
674 if let Some(step) = state.get("current_step") {
675 self.current_step = step.to_scalar()? as usize;
676 }
677 if let Some(last_step) = state.get("last_full_grad_step") {
678 self.last_full_grad_step = last_step.to_scalar()? as usize;
679 }
680
681 Ok(())
682 }
683}
684
/// Hyper-parameters for the [`NesterovAcceleratedGradient`] optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NesterovAcceleratedGradientConfig {
    /// Step size applied to the gradient when it is folded into the velocity.
    pub learning_rate: f32,
    /// Decay factor applied to the previous velocity.
    pub momentum: f32,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient.
    pub weight_decay: f32,
    /// When `true`, `set_current_loss` discards all velocities whenever the
    /// reported loss is higher than the previously reported one.
    pub restart_on_increase: bool,
}
697
698impl Default for NesterovAcceleratedGradientConfig {
699 fn default() -> Self {
700 Self {
701 learning_rate: 1e-3,
702 momentum: 0.9,
703 weight_decay: 0.0,
704 restart_on_increase: false,
705 }
706 }
707}
708
/// Momentum-based optimizer state with optional loss-triggered restarts.
#[derive(Debug)]
pub struct NesterovAcceleratedGradient {
    /// Active hyper-parameters.
    config: NesterovAcceleratedGradientConfig,
    /// Velocity per parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
    /// Last loss reported through `set_current_loss`, if any.
    previous_loss: Option<f32>,
}
723
724impl NesterovAcceleratedGradient {
725 pub fn new(config: NesterovAcceleratedGradientConfig) -> Self {
727 Self {
728 config,
729 velocity_buffers: HashMap::new(),
730 current_step: 0,
731 previous_loss: None,
732 }
733 }
734
735 pub fn with_defaults(learning_rate: f32, momentum: f32) -> Self {
737 Self::new(NesterovAcceleratedGradientConfig {
738 learning_rate,
739 momentum,
740 weight_decay: 0.0,
741 restart_on_increase: false,
742 })
743 }
744
745 pub fn get_config(&self) -> &NesterovAcceleratedGradientConfig {
747 &self.config
748 }
749
750 pub fn set_current_loss(&mut self, loss: f32) {
752 if self.config.restart_on_increase {
753 if let Some(prev_loss) = self.previous_loss {
754 if loss > prev_loss {
755 self.velocity_buffers.clear();
757 }
758 }
759 }
760 self.previous_loss = Some(loss);
761 }
762}
763
764impl OptimizerState for NesterovAcceleratedGradient {
765 fn zero_grad(&mut self) -> Result<()> {
766 Ok(())
767 }
768
769 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
770 self.current_step += 1;
771
772 for (param_id, parameter) in parameters.iter_mut().enumerate() {
773 let gradient = match parameter.grad() {
775 Ok(grad) => grad,
776 Err(_) => {
777 continue;
779 },
780 };
781
782 let effective_grad = if self.config.weight_decay > 0.0 {
784 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
785 } else {
786 gradient
787 };
788
789 let velocity = if let Some(v) = self.velocity_buffers.get(¶m_id) {
791 v.clone()
792 } else {
793 Tensor::zeros_like(parameter)?
794 };
795
796 let _lookahead_position = parameter.sub(&velocity.mul_scalar(self.config.momentum)?)?;
798
799 let new_velocity = velocity
805 .mul_scalar(self.config.momentum)?
806 .add(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
807
808 self.velocity_buffers.insert(param_id, new_velocity.clone());
809
810 *parameter = parameter.sub(&new_velocity)?;
812 }
813
814 Ok(())
815 }
816
817 fn get_lr(&self) -> f32 {
818 self.config.learning_rate
819 }
820
821 fn set_lr(&mut self, lr: f32) {
822 self.config.learning_rate = lr;
823 }
824
825 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
826 let mut state = HashMap::new();
827
828 state.insert(
829 "learning_rate".to_string(),
830 Tensor::scalar(self.config.learning_rate)?,
831 );
832 state.insert(
833 "momentum".to_string(),
834 Tensor::scalar(self.config.momentum)?,
835 );
836 state.insert(
837 "weight_decay".to_string(),
838 Tensor::scalar(self.config.weight_decay)?,
839 );
840 state.insert(
841 "current_step".to_string(),
842 Tensor::scalar(self.current_step as f32)?,
843 );
844
845 if let Some(loss) = self.previous_loss {
846 state.insert("previous_loss".to_string(), Tensor::scalar(loss)?);
847 }
848
849 for (¶m_id, velocity) in &self.velocity_buffers {
850 state.insert(format!("velocity_{}", param_id), velocity.clone());
851 }
852
853 Ok(state)
854 }
855
856 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
857 if let Some(lr) = state.get("learning_rate") {
858 self.config.learning_rate = lr.to_scalar()?;
859 }
860 if let Some(momentum) = state.get("momentum") {
861 self.config.momentum = momentum.to_scalar()?;
862 }
863 if let Some(wd) = state.get("weight_decay") {
864 self.config.weight_decay = wd.to_scalar()?;
865 }
866 if let Some(step) = state.get("current_step") {
867 self.current_step = step.to_scalar()? as usize;
868 }
869 if let Some(loss) = state.get("previous_loss") {
870 self.previous_loss = Some(loss.to_scalar()?);
871 }
872
873 self.velocity_buffers.clear();
874 for (key, tensor) in state {
875 if let Some(param_id_str) = key.strip_prefix("velocity_") {
876 if let Ok(param_id) = param_id_str.parse::<usize>() {
877 self.velocity_buffers.insert(param_id, tensor);
878 }
879 }
880 }
881
882 Ok(())
883 }
884}
885
/// Hyper-parameters for the [`HeavyBall`] optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeavyBallConfig {
    /// Step size applied to the gradient when it is folded into the velocity.
    pub learning_rate: f32,
    /// Base decay factor applied to the previous velocity.
    pub beta: f32,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient.
    pub weight_decay: f32,
    /// When `true`, `beta` is scaled per step by the non-negative cosine
    /// similarity between the current and previous gradients.
    pub adaptive_momentum: bool,
}
898
899impl Default for HeavyBallConfig {
900 fn default() -> Self {
901 Self {
902 learning_rate: 1e-3,
903 beta: 0.9,
904 weight_decay: 0.0,
905 adaptive_momentum: false,
906 }
907 }
908}
909
/// Heavy-ball momentum optimizer state.
#[derive(Debug)]
pub struct HeavyBall {
    /// Active hyper-parameters.
    config: HeavyBallConfig,
    /// Velocity per parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    /// Last effective gradient per parameter index; written only when
    /// adaptive momentum is enabled.
    previous_gradients: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
}
923
924impl HeavyBall {
925 pub fn new(config: HeavyBallConfig) -> Self {
927 Self {
928 config,
929 velocity_buffers: HashMap::new(),
930 previous_gradients: HashMap::new(),
931 current_step: 0,
932 }
933 }
934
935 pub fn with_defaults(learning_rate: f32, beta: f32) -> Self {
937 Self::new(HeavyBallConfig {
938 learning_rate,
939 beta,
940 weight_decay: 0.0,
941 adaptive_momentum: false,
942 })
943 }
944
945 pub fn get_config(&self) -> &HeavyBallConfig {
947 &self.config
948 }
949
950 fn compute_adaptive_momentum(&self, param_id: usize, current_grad: &Tensor) -> Result<f32> {
952 if let Some(prev_grad) = self.previous_gradients.get(¶m_id) {
953 let dot_product = current_grad.mul(prev_grad)?.sum(None, false)?;
955 let norm_current = current_grad.norm_squared()?.sqrt()?;
956 let norm_prev = prev_grad.norm_squared()?.sqrt()?;
957
958 let dot_scalar = dot_product.to_scalar()?;
959 let norm_current_scalar = norm_current.to_scalar()?;
960 let norm_prev_scalar = norm_prev.to_scalar()?;
961
962 let denominator = norm_current_scalar * norm_prev_scalar;
963 if denominator > 1e-8 {
964 let cosine_similarity = dot_scalar / denominator;
965 let adaptive_beta = self.config.beta * cosine_similarity.max(0.0);
967 Ok(adaptive_beta)
968 } else {
969 Ok(self.config.beta)
970 }
971 } else {
972 Ok(self.config.beta)
973 }
974 }
975}
976
977impl OptimizerState for HeavyBall {
978 fn zero_grad(&mut self) -> Result<()> {
979 Ok(())
980 }
981
982 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
983 self.current_step += 1;
984
985 for (param_id, parameter) in parameters.iter_mut().enumerate() {
986 let gradient = match parameter.grad() {
988 Ok(grad) => grad,
989 Err(_) => {
990 continue;
992 },
993 };
994
995 let effective_grad = if self.config.weight_decay > 0.0 {
997 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
998 } else {
999 gradient
1000 };
1001
1002 let beta = if self.config.adaptive_momentum {
1004 self.compute_adaptive_momentum(param_id, &effective_grad)?
1005 } else {
1006 self.config.beta
1007 };
1008
1009 let velocity = if let Some(v) = self.velocity_buffers.get(¶m_id) {
1011 v.clone()
1012 } else {
1013 Tensor::zeros_like(parameter)?
1014 };
1015
1016 let new_velocity = velocity
1018 .mul_scalar(beta)?
1019 .sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1020
1021 self.velocity_buffers.insert(param_id, new_velocity.clone());
1022
1023 *parameter = parameter.add(&new_velocity)?;
1025
1026 if self.config.adaptive_momentum {
1028 self.previous_gradients.insert(param_id, effective_grad);
1029 }
1030 }
1031
1032 Ok(())
1033 }
1034
1035 fn get_lr(&self) -> f32 {
1036 self.config.learning_rate
1037 }
1038
1039 fn set_lr(&mut self, lr: f32) {
1040 self.config.learning_rate = lr;
1041 }
1042
1043 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1044 let mut state = HashMap::new();
1045
1046 state.insert(
1047 "learning_rate".to_string(),
1048 Tensor::scalar(self.config.learning_rate)?,
1049 );
1050 state.insert("beta".to_string(), Tensor::scalar(self.config.beta)?);
1051 state.insert(
1052 "weight_decay".to_string(),
1053 Tensor::scalar(self.config.weight_decay)?,
1054 );
1055 state.insert(
1056 "current_step".to_string(),
1057 Tensor::scalar(self.current_step as f32)?,
1058 );
1059
1060 for (¶m_id, velocity) in &self.velocity_buffers {
1061 state.insert(format!("velocity_{}", param_id), velocity.clone());
1062 }
1063
1064 for (¶m_id, grad) in &self.previous_gradients {
1065 state.insert(format!("prev_grad_{}", param_id), grad.clone());
1066 }
1067
1068 Ok(state)
1069 }
1070
1071 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1072 if let Some(lr) = state.get("learning_rate") {
1073 self.config.learning_rate = lr.to_scalar()?;
1074 }
1075 if let Some(beta) = state.get("beta") {
1076 self.config.beta = beta.to_scalar()?;
1077 }
1078 if let Some(wd) = state.get("weight_decay") {
1079 self.config.weight_decay = wd.to_scalar()?;
1080 }
1081 if let Some(step) = state.get("current_step") {
1082 self.current_step = step.to_scalar()? as usize;
1083 }
1084
1085 self.velocity_buffers.clear();
1086 self.previous_gradients.clear();
1087
1088 for (key, tensor) in state {
1089 if let Some(param_id_str) = key.strip_prefix("velocity_") {
1090 if let Ok(param_id) = param_id_str.parse::<usize>() {
1091 self.velocity_buffers.insert(param_id, tensor);
1092 }
1093 } else if let Some(param_id_str) = key.strip_prefix("prev_grad_") {
1094 if let Ok(param_id) = param_id_str.parse::<usize>() {
1095 self.previous_gradients.insert(param_id, tensor);
1096 }
1097 }
1098 }
1099
1100 Ok(())
1101 }
1102}
1103
/// Hyper-parameters for the [`FISTA`] optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FISTAConfig {
    /// Step size applied to the gradient.
    pub learning_rate: f32,
    /// Soft-thresholding level applied to each parameter after the gradient
    /// step.
    pub threshold: f32,
    /// NOTE(review): not read by any of the code visible alongside this
    /// struct — confirm whether restart logic exists elsewhere.
    pub adaptive_restart: bool,
    /// L2 coefficient; when greater than zero, `weight_decay * parameter` is
    /// added to the gradient.
    pub weight_decay: f32,
}
1116
1117impl Default for FISTAConfig {
1118 fn default() -> Self {
1119 Self {
1120 learning_rate: 1e-3,
1121 threshold: 1e-4,
1122 adaptive_restart: true,
1123 weight_decay: 0.0,
1124 }
1125 }
1126}
1127
/// FISTA (fast iterative shrinkage-thresholding) optimizer state.
#[derive(Debug)]
pub struct FISTA {
    /// Active hyper-parameters.
    config: FISTAConfig,
    /// Value each parameter had before its most recent update, per index.
    previous_params: HashMap<usize, Tensor>,
    /// Number of `step` calls performed so far.
    current_step: usize,
    /// Momentum coefficient of the current step.
    momentum_coefficient: f32,
    /// Momentum coefficient of the previous step.
    previous_momentum: f32,
}
1140
1141impl FISTA {
1142 pub fn new(config: FISTAConfig) -> Self {
1144 Self {
1145 config,
1146 previous_params: HashMap::new(),
1147 current_step: 0,
1148 momentum_coefficient: 1.0,
1149 previous_momentum: 1.0,
1150 }
1151 }
1152
1153 pub fn with_defaults(learning_rate: f32, threshold: f32) -> Self {
1155 Self::new(FISTAConfig {
1156 learning_rate,
1157 threshold,
1158 adaptive_restart: true,
1159 weight_decay: 0.0,
1160 })
1161 }
1162
1163 pub fn get_config(&self) -> &FISTAConfig {
1165 &self.config
1166 }
1167
1168 fn soft_threshold(&self, tensor: &Tensor, threshold: f32) -> Result<Tensor> {
1170 let threshold_tensor = Tensor::scalar(threshold)?;
1171 let zero_tensor = Tensor::zeros_like(tensor)?;
1172
1173 let abs_tensor = tensor.abs()?;
1175 let thresholded = abs_tensor.sub(&threshold_tensor)?.max(&zero_tensor)?;
1176 let sign_tensor = tensor.sign()?;
1177
1178 Ok(sign_tensor.mul(&thresholded)?)
1179 }
1180
1181 fn update_momentum_coefficient(&mut self) {
1183 let t = self.current_step as f32;
1184 self.previous_momentum = self.momentum_coefficient;
1185 self.momentum_coefficient = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
1186 }
1187}
1188
1189impl OptimizerState for FISTA {
1190 fn zero_grad(&mut self) -> Result<()> {
1191 Ok(())
1192 }
1193
1194 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1195 self.current_step += 1;
1196 self.update_momentum_coefficient();
1197
1198 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1199 let gradient = match parameter.grad() {
1201 Ok(grad) => grad,
1202 Err(_) => {
1203 continue;
1205 },
1206 };
1207
1208 let effective_grad = if self.config.weight_decay > 0.0 {
1210 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
1211 } else {
1212 gradient
1213 };
1214
1215 let previous_param = if let Some(prev) = self.previous_params.get(¶m_id) {
1217 prev.clone()
1218 } else {
1219 parameter.clone()
1220 };
1221
1222 let beta = (self.previous_momentum - 1.0) / self.momentum_coefficient;
1224
1225 let extrapolated = parameter.add(&previous_param.sub(parameter)?.mul_scalar(beta)?)?;
1227
1228 let grad_step =
1230 extrapolated.sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1231
1232 let new_parameter = self.soft_threshold(&grad_step, self.config.threshold)?;
1234
1235 self.previous_params.insert(param_id, parameter.clone());
1237
1238 *parameter = new_parameter;
1240 }
1241
1242 Ok(())
1243 }
1244
1245 fn get_lr(&self) -> f32 {
1246 self.config.learning_rate
1247 }
1248
1249 fn set_lr(&mut self, lr: f32) {
1250 self.config.learning_rate = lr;
1251 }
1252
1253 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1254 let mut state = HashMap::new();
1255
1256 state.insert(
1257 "learning_rate".to_string(),
1258 Tensor::scalar(self.config.learning_rate)?,
1259 );
1260 state.insert(
1261 "threshold".to_string(),
1262 Tensor::scalar(self.config.threshold)?,
1263 );
1264 state.insert(
1265 "weight_decay".to_string(),
1266 Tensor::scalar(self.config.weight_decay)?,
1267 );
1268 state.insert(
1269 "current_step".to_string(),
1270 Tensor::scalar(self.current_step as f32)?,
1271 );
1272 state.insert(
1273 "momentum_coefficient".to_string(),
1274 Tensor::scalar(self.momentum_coefficient)?,
1275 );
1276 state.insert(
1277 "previous_momentum".to_string(),
1278 Tensor::scalar(self.previous_momentum)?,
1279 );
1280
1281 for (¶m_id, param) in &self.previous_params {
1282 state.insert(format!("prev_param_{}", param_id), param.clone());
1283 }
1284
1285 Ok(state)
1286 }
1287
1288 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1289 if let Some(lr) = state.get("learning_rate") {
1290 self.config.learning_rate = lr.to_scalar()?;
1291 }
1292 if let Some(threshold) = state.get("threshold") {
1293 self.config.threshold = threshold.to_scalar()?;
1294 }
1295 if let Some(wd) = state.get("weight_decay") {
1296 self.config.weight_decay = wd.to_scalar()?;
1297 }
1298 if let Some(step) = state.get("current_step") {
1299 self.current_step = step.to_scalar()? as usize;
1300 }
1301 if let Some(momentum) = state.get("momentum_coefficient") {
1302 self.momentum_coefficient = momentum.to_scalar()?;
1303 }
1304 if let Some(prev_momentum) = state.get("previous_momentum") {
1305 self.previous_momentum = prev_momentum.to_scalar()?;
1306 }
1307
1308 self.previous_params.clear();
1309 for (key, tensor) in state {
1310 if let Some(param_id_str) = key.strip_prefix("prev_param_") {
1311 if let Ok(param_id) = param_id_str.parse::<usize>() {
1312 self.previous_params.insert(param_id, tensor);
1313 }
1314 }
1315 }
1316
1317 Ok(())
1318 }
1319}
1320
/// Settings for [`AdaptiveBatchSizing`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchSizingConfig {
    /// Batch size the controller starts from.
    pub initial_batch_size: usize,
    /// Lower bound for the batch size — presumably enforced by the resize
    /// helpers, which are not visible here; TODO confirm.
    pub min_batch_size: usize,
    /// Upper bound for the batch size — presumably enforced by the resize
    /// helpers, which are not visible here; TODO confirm.
    pub max_batch_size: usize,
    /// NOTE(review): not read by any code visible alongside this struct.
    pub gradient_variance_tolerance: f32,
    /// NOTE(review): not read by any code visible alongside this struct.
    pub lr_adaptation_factor: f32,
    /// Maximum number of variance and loss samples kept in the history.
    pub variance_window_size: usize,
    /// Recent variance below this value (together with an improving loss)
    /// triggers a batch-size decrease.
    pub increase_threshold: f32,
    /// Recent variance above this value (together with a rising variance
    /// trend) triggers a batch-size increase.
    pub decrease_threshold: f32,
}
1341
1342impl Default for AdaptiveBatchSizingConfig {
1343 fn default() -> Self {
1344 Self {
1345 initial_batch_size: 32,
1346 min_batch_size: 8,
1347 max_batch_size: 512,
1348 gradient_variance_tolerance: 0.1,
1349 lr_adaptation_factor: 0.8,
1350 variance_window_size: 10,
1351 increase_threshold: 0.05,
1352 decrease_threshold: 0.2,
1353 }
1354 }
1355}
1356
/// Controller that adapts the training batch size from observed gradient
/// variance and loss.
#[derive(Debug)]
pub struct AdaptiveBatchSizing {
    /// Active settings.
    config: AdaptiveBatchSizingConfig,
    /// Batch size currently recommended.
    current_batch_size: usize,
    /// Sliding window of reported gradient variances, oldest first.
    gradient_variance_history: Vec<f32>,
    /// Sliding window of reported losses, oldest first.
    loss_history: Vec<f32>,
    /// Number of `update` calls performed so far.
    current_step: usize,
    /// Value of `current_step` when the batch size was last re-evaluated.
    last_adjustment_step: usize,
}
1371
1372impl AdaptiveBatchSizing {
1373 pub fn new(config: AdaptiveBatchSizingConfig) -> Self {
1375 let initial_batch_size = config.initial_batch_size;
1376 Self {
1377 config,
1378 current_batch_size: initial_batch_size,
1379 gradient_variance_history: Vec::new(),
1380 loss_history: Vec::new(),
1381 current_step: 0,
1382 last_adjustment_step: 0,
1383 }
1384 }
1385
1386 pub fn with_defaults(
1388 initial_batch_size: usize,
1389 min_batch_size: usize,
1390 max_batch_size: usize,
1391 ) -> Self {
1392 Self::new(AdaptiveBatchSizingConfig {
1393 initial_batch_size,
1394 min_batch_size,
1395 max_batch_size,
1396 ..Default::default()
1397 })
1398 }
1399
1400 pub fn current_batch_size(&self) -> usize {
1402 self.current_batch_size
1403 }
1404
1405 pub fn get_config(&self) -> &AdaptiveBatchSizingConfig {
1407 &self.config
1408 }
1409
1410 pub fn update(&mut self, gradient_variance: f32, current_loss: f32) -> Result<usize> {
1412 self.current_step += 1;
1413
1414 self.gradient_variance_history.push(gradient_variance);
1416 self.loss_history.push(current_loss);
1417
1418 if self.gradient_variance_history.len() > self.config.variance_window_size {
1420 self.gradient_variance_history.remove(0);
1421 }
1422 if self.loss_history.len() > self.config.variance_window_size {
1423 self.loss_history.remove(0);
1424 }
1425
1426 if self.should_adjust_batch_size() {
1428 self.adjust_batch_size()?;
1429 self.last_adjustment_step = self.current_step;
1430 }
1431
1432 Ok(self.current_batch_size)
1433 }
1434
1435 pub fn compute_gradient_variance(&self, gradients: &[Tensor]) -> Result<f32> {
1437 if gradients.is_empty() {
1438 return Ok(0.0);
1439 }
1440
1441 let mut mean_grad = gradients[0].clone();
1443 for grad in gradients.iter().skip(1) {
1444 mean_grad = mean_grad.add(grad)?;
1445 }
1446 mean_grad = mean_grad.div_scalar(gradients.len() as f32)?;
1447
1448 let mut variance_sum = 0.0;
1450 for grad in gradients {
1451 let diff = grad.sub(&mean_grad)?;
1452 let squared_norm = diff.mul(&diff)?.sum(None, false)?;
1453 variance_sum += squared_norm.to_scalar()?;
1454 }
1455
1456 Ok(variance_sum / gradients.len() as f32)
1457 }
1458
1459 fn should_adjust_batch_size(&self) -> bool {
1460 if self.current_step - self.last_adjustment_step < 5 {
1462 return false;
1463 }
1464
1465 self.gradient_variance_history.len() >= 3
1467 }
1468
1469 fn adjust_batch_size(&mut self) -> Result<()> {
1470 let recent_variance = self.recent_average_variance();
1471 let variance_trend = self.variance_trend();
1472 let loss_trend = self.loss_trend();
1473
1474 if recent_variance > self.config.decrease_threshold && variance_trend > 0.0 {
1476 self.increase_batch_size();
1478 } else if recent_variance < self.config.increase_threshold && loss_trend < -0.01 {
1479 self.decrease_batch_size();
1481 }
1482
1483 Ok(())
1484 }
1485
1486 fn recent_average_variance(&self) -> f32 {
1487 if self.gradient_variance_history.is_empty() {
1488 return 0.0;
1489 }
1490
1491 let recent_window = std::cmp::min(5, self.gradient_variance_history.len());
1492 let start_idx = self.gradient_variance_history.len() - recent_window;
1493
1494 self.gradient_variance_history[start_idx..].iter().sum::<f32>() / recent_window as f32
1495 }
1496
1497 fn variance_trend(&self) -> f32 {
1498 if self.gradient_variance_history.len() < 3 {
1499 return 0.0;
1500 }
1501
1502 let len = self.gradient_variance_history.len();
1503 let recent = self.gradient_variance_history[len - 2..].iter().sum::<f32>() / 2.0;
1504 let older = self.gradient_variance_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1505
1506 recent - older
1507 }
1508
1509 fn loss_trend(&self) -> f32 {
1510 if self.loss_history.len() < 3 {
1511 return 0.0;
1512 }
1513
1514 let len = self.loss_history.len();
1515 let recent = self.loss_history[len - 2..].iter().sum::<f32>() / 2.0;
1516 let older = self.loss_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1517
1518 (recent - older) / older.max(1e-8)
1519 }
1520
1521 fn increase_batch_size(&mut self) {
1522 let new_size = (self.current_batch_size as f32 * 1.5) as usize;
1523 self.current_batch_size = new_size.min(self.config.max_batch_size);
1524 }
1525
1526 fn decrease_batch_size(&mut self) {
1527 let new_size = (self.current_batch_size as f32 * 0.8) as usize;
1528 self.current_batch_size = new_size.max(self.config.min_batch_size);
1529 }
1530
1531 pub fn get_lr_adjustment(&self, original_batch_size: usize) -> f32 {
1533 let ratio = self.current_batch_size as f32 / original_batch_size as f32;
1534 ratio.sqrt() * self.config.lr_adaptation_factor
1535 }
1536
1537 pub fn reset(&mut self) {
1539 self.current_batch_size = self.config.initial_batch_size;
1540 self.gradient_variance_history.clear();
1541 self.loss_history.clear();
1542 self.current_step = 0;
1543 self.last_adjustment_step = 0;
1544 }
1545}
1546
/// Tunable settings for [`LossSurfaceSmoothing`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSurfaceSmoothingConfig {
    /// Weight given to the running parameter average when it is blended back
    /// into the live parameter in `smooth_parameters`; the live parameter
    /// keeps weight `1 - smoothing_strength`.
    pub smoothing_strength: f32,
    /// Variance of the injected noise; its square root scales the tensor
    /// produced by `Tensor::randn_like` (presumably standard-normal samples —
    /// confirm against the tensor library).
    pub noise_variance: f32,
    /// Decay factor of the exponential moving averages kept for both
    /// gradients and parameters.
    pub ema_decay: f32,
    /// Maximum number of past gradients retained per parameter for window
    /// averaging.
    pub averaging_window: usize,
    /// Enables window averaging of gradients in `smooth_gradients`.
    pub use_gradient_averaging: bool,
    /// Enables noise injection into the smoothed gradient in
    /// `smooth_gradients`.
    pub use_noise_injection: bool,
}
1563
1564impl Default for LossSurfaceSmoothingConfig {
1565 fn default() -> Self {
1566 Self {
1567 smoothing_strength: 0.1,
1568 noise_variance: 1e-4,
1569 ema_decay: 0.9,
1570 averaging_window: 5,
1571 use_gradient_averaging: true,
1572 use_noise_injection: false,
1573 }
1574 }
1575}
1576
/// Smooths gradients and parameters using window averaging, exponential
/// moving averages and optional noise injection.
///
/// Per-parameter state is keyed by the parameter's index in the slice passed
/// to `smooth_gradients` / `smooth_parameters`.
#[derive(Debug)]
pub struct LossSurfaceSmoothing {
    /// Settings supplied at construction time.
    config: LossSurfaceSmoothingConfig,
    /// Recent raw gradients per parameter, oldest first, used for window
    /// averaging.
    gradient_history: HashMap<usize, Vec<Tensor>>,
    /// Exponential moving average of the gradient per parameter.
    ema_gradients: HashMap<usize, Tensor>,
    /// Exponential moving average of the parameter value per parameter.
    smoothed_parameters: HashMap<usize, Tensor>,
    /// Number of `smooth_gradients` calls since construction or the last
    /// `reset`.
    current_step: usize,
}
1592
1593impl LossSurfaceSmoothing {
1594 pub fn new(config: LossSurfaceSmoothingConfig) -> Self {
1596 Self {
1597 config,
1598 gradient_history: HashMap::new(),
1599 ema_gradients: HashMap::new(),
1600 smoothed_parameters: HashMap::new(),
1601 current_step: 0,
1602 }
1603 }
1604
1605 pub fn with_defaults(smoothing_strength: f32, use_noise: bool) -> Self {
1607 Self::new(LossSurfaceSmoothingConfig {
1608 smoothing_strength,
1609 use_noise_injection: use_noise,
1610 ..Default::default()
1611 })
1612 }
1613
1614 pub fn get_config(&self) -> &LossSurfaceSmoothingConfig {
1616 &self.config
1617 }
1618
1619 pub fn smooth_gradients(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1621 self.current_step += 1;
1622
1623 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1624 let original_grad = parameter.grad()?;
1625 let mut smoothed_grad = original_grad.clone();
1626
1627 if self.config.use_gradient_averaging {
1629 smoothed_grad = self.apply_gradient_averaging(param_id, &original_grad)?;
1630 }
1631
1632 smoothed_grad = self.apply_ema_smoothing(param_id, &smoothed_grad)?;
1634
1635 if self.config.use_noise_injection {
1637 smoothed_grad = self.apply_noise_injection(&smoothed_grad)?;
1638 }
1639
1640 parameter.set_grad(smoothed_grad)?;
1642 }
1643
1644 Ok(())
1645 }
1646
1647 pub fn smooth_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1649 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1650 if let Some(smoothed_param) = self.smoothed_parameters.get(¶m_id) {
1651 let new_smoothed = smoothed_param
1653 .mul_scalar(self.config.ema_decay)?
1654 .add(¶meter.mul_scalar(1.0 - self.config.ema_decay)?)?;
1655
1656 *parameter = parameter
1658 .mul_scalar(1.0 - self.config.smoothing_strength)?
1659 .add(&new_smoothed.mul_scalar(self.config.smoothing_strength)?)?;
1660
1661 self.smoothed_parameters.insert(param_id, new_smoothed);
1662 } else {
1663 self.smoothed_parameters.insert(param_id, parameter.clone());
1665 }
1666 }
1667
1668 Ok(())
1669 }
1670
1671 fn apply_gradient_averaging(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1672 let history = self.gradient_history.entry(param_id).or_default();
1673
1674 history.push(gradient.clone());
1675 if history.len() > self.config.averaging_window {
1676 history.remove(0);
1677 }
1678
1679 if history.len() == 1 {
1681 Ok(gradient.clone())
1682 } else {
1683 let mut sum = history[0].clone();
1684 for grad in history.iter().skip(1) {
1685 sum = sum.add(grad)?;
1686 }
1687 Ok(sum.div_scalar(history.len() as f32)?)
1688 }
1689 }
1690
1691 fn apply_ema_smoothing(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1692 if let Some(ema_grad) = self.ema_gradients.get(¶m_id) {
1693 let new_ema = ema_grad
1694 .mul_scalar(self.config.ema_decay)?
1695 .add(&gradient.mul_scalar(1.0 - self.config.ema_decay)?)?;
1696 self.ema_gradients.insert(param_id, new_ema.clone());
1697 Ok(new_ema)
1698 } else {
1699 self.ema_gradients.insert(param_id, gradient.clone());
1700 Ok(gradient.clone())
1701 }
1702 }
1703
1704 fn apply_noise_injection(&self, gradient: &Tensor) -> Result<Tensor> {
1705 let noise = Tensor::randn_like(gradient)
1706 .map_err(|e| anyhow!("Failed to create noise tensor: {}", e))?
1707 .mul_scalar(self.config.noise_variance.sqrt())
1708 .map_err(|e| anyhow!("Failed to scale noise tensor: {}", e))?;
1709 gradient
1710 .add(&noise)
1711 .map_err(|e| anyhow!("Failed to add noise to gradient: {}", e))
1712 }
1713
1714 pub fn reset(&mut self) {
1716 self.gradient_history.clear();
1717 self.ema_gradients.clear();
1718 self.smoothed_parameters.clear();
1719 self.current_step = 0;
1720 }
1721
1722 pub fn get_statistics(&self) -> HashMap<String, f32> {
1724 let mut stats = HashMap::new();
1725 stats.insert("current_step".to_string(), self.current_step as f32);
1726 stats.insert(
1727 "num_tracked_params".to_string(),
1728 self.gradient_history.len() as f32,
1729 );
1730 stats.insert(
1731 "smoothing_strength".to_string(),
1732 self.config.smoothing_strength,
1733 );
1734 stats.insert("ema_decay".to_string(), self.config.ema_decay);
1735 stats
1736 }
1737}
1738
#[cfg(test)]
mod tests {
    use super::*;

    /// Default QHM hyper-parameters.
    #[test]
    fn test_qhm_config_default() {
        let defaults = QHMConfig::default();
        assert_eq!(defaults.learning_rate, 1e-3);
        assert_eq!(defaults.momentum, 0.9);
        assert_eq!(defaults.nu, 0.7);
        assert_eq!(defaults.weight_decay, 0.0);
    }

    /// Default AggMo hyper-parameters.
    #[test]
    fn test_aggmo_config_default() {
        let defaults = AggMoConfig::default();
        assert_eq!(defaults.learning_rate, 1e-3);
        assert_eq!(defaults.momentum_coefficients, vec![0.0, 0.9, 0.99]);
        assert_eq!(defaults.weight_decay, 0.0);
    }

    /// A freshly built QHM optimizer reports its learning rate and step 0.
    #[test]
    fn test_qhm_creation() {
        let qhm = QHM::with_defaults(1e-3, 0.9, 0.7);
        assert_eq!(qhm.get_lr(), 1e-3);
        assert_eq!(qhm.current_step, 0);
    }

    /// AggMo keeps one momentum buffer per coefficient.
    #[test]
    fn test_aggmo_creation() {
        let coefficients = vec![0.0, 0.9, 0.99];
        let aggmo = AggMo::with_defaults(1e-3, coefficients);
        assert_eq!(aggmo.get_lr(), 1e-3);
        assert_eq!(aggmo.num_momentum_buffers(), 3);
    }

    /// SVRG constructor sets the learning rate and starts at step 0.
    #[test]
    fn test_variance_reduction_svrg() {
        let svrg = VarianceReduction::svrg(1e-3, 50, 10);
        assert_eq!(svrg.get_lr(), 1e-3);
        assert_eq!(svrg.current_step, 0);
    }

    /// SAG constructor selects the SAG method.
    #[test]
    fn test_variance_reduction_sag() {
        let sag = VarianceReduction::sag(1e-3, 100);
        assert_eq!(sag.get_lr(), 1e-3);
        assert!(matches!(
            sag.config.method,
            VarianceReductionMethod::SAG
        ));
    }

    /// Default Nesterov hyper-parameters.
    #[test]
    fn test_nesterov_accelerated_gradient_config() {
        let defaults = NesterovAcceleratedGradientConfig::default();
        assert_eq!(defaults.learning_rate, 1e-3);
        assert_eq!(defaults.momentum, 0.9);
        assert_eq!(defaults.weight_decay, 0.0);
        assert!(!defaults.restart_on_increase);
    }

    /// A freshly built Nesterov optimizer has no recorded loss.
    #[test]
    fn test_nesterov_accelerated_gradient_creation() {
        let nag = NesterovAcceleratedGradient::with_defaults(1e-3, 0.9);
        assert_eq!(nag.get_lr(), 1e-3);
        assert_eq!(nag.current_step, 0);
        assert!(nag.previous_loss.is_none());
    }

    /// `set_current_loss` always records the latest loss, including when the
    /// loss goes up.
    #[test]
    fn test_nesterov_restart_on_increase() {
        let restart_config = NesterovAcceleratedGradientConfig {
            learning_rate: 1e-3,
            momentum: 0.9,
            weight_decay: 0.0,
            restart_on_increase: true,
        };
        let mut nag = NesterovAcceleratedGradient::new(restart_config);

        nag.set_current_loss(1.0);
        assert_eq!(nag.previous_loss, Some(1.0));

        nag.set_current_loss(1.5);
        assert_eq!(nag.previous_loss, Some(1.5));
    }

    /// Default heavy-ball hyper-parameters.
    #[test]
    fn test_heavy_ball_config() {
        let defaults = HeavyBallConfig::default();
        assert_eq!(defaults.learning_rate, 1e-3);
        assert_eq!(defaults.beta, 0.9);
        assert_eq!(defaults.weight_decay, 0.0);
        assert!(!defaults.adaptive_momentum);
    }

    /// A freshly built heavy-ball optimizer exposes its beta through the
    /// config accessor.
    #[test]
    fn test_heavy_ball_creation() {
        let heavy_ball = HeavyBall::with_defaults(1e-3, 0.9);
        assert_eq!(heavy_ball.get_lr(), 1e-3);
        assert_eq!(heavy_ball.current_step, 0);
        assert_eq!(heavy_ball.get_config().beta, 0.9);
    }

    /// The adaptive-momentum flag is carried into the optimizer.
    #[test]
    fn test_heavy_ball_adaptive_momentum() {
        let adaptive_config = HeavyBallConfig {
            learning_rate: 1e-3,
            beta: 0.9,
            weight_decay: 0.0,
            adaptive_momentum: true,
        };
        let heavy_ball = HeavyBall::new(adaptive_config);

        assert!(heavy_ball.config.adaptive_momentum);
    }

    /// Default FISTA hyper-parameters.
    #[test]
    fn test_fista_config() {
        let defaults = FISTAConfig::default();
        assert_eq!(defaults.learning_rate, 1e-3);
        assert_eq!(defaults.threshold, 1e-4);
        assert!(defaults.adaptive_restart);
        assert_eq!(defaults.weight_decay, 0.0);
    }

    /// A freshly built FISTA optimizer starts with both momentum terms at 1.
    #[test]
    fn test_fista_creation() {
        let fista = FISTA::with_defaults(1e-3, 1e-4);
        assert_eq!(fista.get_lr(), 1e-3);
        assert_eq!(fista.current_step, 0);
        assert_eq!(fista.momentum_coefficient, 1.0);
        assert_eq!(fista.previous_momentum, 1.0);
    }

    /// The FISTA momentum coefficient grows with each update.
    #[test]
    fn test_fista_momentum_update() {
        let mut fista = FISTA::with_defaults(1e-3, 1e-4);

        // First update: coefficient rises above 1, previous value is kept.
        fista.current_step = 1;
        fista.update_momentum_coefficient();
        assert!(fista.momentum_coefficient > 1.0);
        assert_eq!(fista.previous_momentum, 1.0);

        // Second update: coefficient keeps rising.
        let after_first_update = fista.momentum_coefficient;
        fista.current_step = 2;
        fista.update_momentum_coefficient();
        assert!(fista.momentum_coefficient > after_first_update);
    }

    /// Default adaptive-batch-sizing settings.
    #[test]
    fn test_adaptive_batch_sizing_config() {
        let AdaptiveBatchSizingConfig {
            initial_batch_size,
            min_batch_size,
            max_batch_size,
            gradient_variance_tolerance,
            lr_adaptation_factor,
            variance_window_size,
            increase_threshold,
            decrease_threshold,
        } = AdaptiveBatchSizingConfig::default();

        assert_eq!(initial_batch_size, 32);
        assert_eq!(min_batch_size, 8);
        assert_eq!(max_batch_size, 512);
        assert_eq!(gradient_variance_tolerance, 0.1);
        assert_eq!(lr_adaptation_factor, 0.8);
        assert_eq!(variance_window_size, 10);
        assert_eq!(increase_threshold, 0.05);
        assert_eq!(decrease_threshold, 0.2);
    }

    /// `with_defaults` stores the requested batch-size bounds.
    #[test]
    fn test_adaptive_batch_sizing_creation() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        assert_eq!(sizer.current_batch_size(), 64);

        let bounds = sizer.get_config();
        assert_eq!(bounds.min_batch_size, 16);
        assert_eq!(bounds.max_batch_size, 256);
    }

    /// The learning-rate multiplier stays in a sane positive range.
    #[test]
    fn test_adaptive_batch_sizing_lr_adjustment() {
        let sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        let multiplier = sizer.get_lr_adjustment(32);
        assert!(multiplier > 0.0);
        assert!(multiplier < 2.0);
    }

    /// `reset` restores the step counter and the initial batch size.
    #[test]
    fn test_adaptive_batch_sizing_reset() {
        let mut sizer = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        sizer.current_step = 10;
        sizer.reset();
        assert_eq!(sizer.current_step, 0);
        assert_eq!(sizer.current_batch_size(), 64);
    }

    /// Default loss-surface-smoothing settings.
    #[test]
    fn test_loss_surface_smoothing_config() {
        let LossSurfaceSmoothingConfig {
            smoothing_strength,
            noise_variance,
            ema_decay,
            averaging_window,
            use_gradient_averaging,
            use_noise_injection,
        } = LossSurfaceSmoothingConfig::default();

        assert_eq!(smoothing_strength, 0.1);
        assert_eq!(noise_variance, 1e-4);
        assert_eq!(ema_decay, 0.9);
        assert_eq!(averaging_window, 5);
        assert!(use_gradient_averaging);
        assert!(!use_noise_injection);
    }

    /// `with_defaults` stores the requested strength and noise switch.
    #[test]
    fn test_loss_surface_smoothing_creation() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.2, true);
        let settings = smoother.get_config();
        assert_eq!(settings.smoothing_strength, 0.2);
        assert!(settings.use_noise_injection);
        assert_eq!(smoother.current_step, 0);
    }

    /// A fresh smoother reports zeroed counters and its configured settings.
    #[test]
    fn test_loss_surface_smoothing_statistics() {
        let smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        let snapshot = smoother.get_statistics();
        assert_eq!(snapshot.get("current_step"), Some(&0.0));
        assert_eq!(snapshot.get("num_tracked_params"), Some(&0.0));
        assert_eq!(snapshot.get("smoothing_strength"), Some(&0.1));
        assert_eq!(snapshot.get("ema_decay"), Some(&0.9));
    }

    /// `reset` restores the step counter.
    #[test]
    fn test_loss_surface_smoothing_reset() {
        let mut smoother = LossSurfaceSmoothing::with_defaults(0.1, false);
        smoother.current_step = 5;
        smoother.reset();
        assert_eq!(smoother.current_step, 0);
    }
}