1use crate::optimizer::OptimizerState;
20use anyhow::{anyhow, Result};
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use trustformers_core::tensor::Tensor;
24
/// Configuration for the Quasi-Hyperbolic Momentum (QHM) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QHMConfig {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// EMA coefficient (beta) for the momentum buffer.
    pub momentum: f32,
    /// Interpolation factor blending the raw gradient with the momentum
    /// buffer in the update direction (see `QHM::step`).
    pub nu: f32,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
}
37
38impl Default for QHMConfig {
39 fn default() -> Self {
40 Self {
41 learning_rate: 1e-3,
42 momentum: 0.9,
43 nu: 0.7,
44 weight_decay: 0.0,
45 }
46 }
47}
48
/// Quasi-Hyperbolic Momentum optimizer state.
#[derive(Debug)]
pub struct QHM {
    // Hyperparameters controlling the update rule.
    config: QHMConfig,
    // One EMA momentum buffer per parameter, keyed by parameter index.
    momentum_buffers: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
}
60
61impl QHM {
62 pub fn new(config: QHMConfig) -> Self {
64 Self {
65 config,
66 momentum_buffers: HashMap::new(),
67 current_step: 0,
68 }
69 }
70
71 pub fn with_defaults(learning_rate: f32, momentum: f32, nu: f32) -> Self {
73 Self::new(QHMConfig {
74 learning_rate,
75 momentum,
76 nu,
77 weight_decay: 0.0,
78 })
79 }
80
81 pub fn get_config(&self) -> &QHMConfig {
83 &self.config
84 }
85
86 pub fn set_config(&mut self, config: QHMConfig) {
88 self.config = config;
89 }
90}
91
92impl OptimizerState for QHM {
93 fn zero_grad(&mut self) -> Result<()> {
94 Ok(())
96 }
97
98 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
99 self.current_step += 1;
100
101 for (param_id, parameter) in parameters.iter_mut().enumerate() {
102 let gradient = match parameter.grad() {
104 Ok(grad) => grad,
105 Err(_) => {
106 continue;
108 },
109 };
110
111 let effective_grad = if self.config.weight_decay > 0.0 {
113 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
114 } else {
115 gradient
116 };
117
118 let momentum_buffer = if let Some(buffer) = self.momentum_buffers.get(¶m_id) {
120 let updated = buffer
122 .mul_scalar(self.config.momentum)?
123 .add(&effective_grad.mul_scalar(1.0 - self.config.momentum)?)?;
124 self.momentum_buffers.insert(param_id, updated.clone());
125 updated
126 } else {
127 let initial_momentum = effective_grad.clone();
129 self.momentum_buffers.insert(param_id, initial_momentum.clone());
130 initial_momentum
131 };
132
133 let update_direction = effective_grad
135 .mul_scalar(self.config.nu)?
136 .add(&momentum_buffer.mul_scalar(1.0 - self.config.nu)?)?;
137
138 *parameter = parameter.sub(&update_direction.mul_scalar(self.config.learning_rate)?)?;
140 }
141
142 Ok(())
143 }
144
145 fn get_lr(&self) -> f32 {
146 self.config.learning_rate
147 }
148
149 fn set_lr(&mut self, lr: f32) {
150 self.config.learning_rate = lr;
151 }
152
153 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
154 let mut state = HashMap::new();
155
156 state.insert(
158 "learning_rate".to_string(),
159 Tensor::scalar(self.config.learning_rate)?,
160 );
161 state.insert(
162 "momentum".to_string(),
163 Tensor::scalar(self.config.momentum)?,
164 );
165 state.insert("nu".to_string(), Tensor::scalar(self.config.nu)?);
166 state.insert(
167 "weight_decay".to_string(),
168 Tensor::scalar(self.config.weight_decay)?,
169 );
170 state.insert(
171 "current_step".to_string(),
172 Tensor::scalar(self.current_step as f32)?,
173 );
174
175 for (¶m_id, buffer) in &self.momentum_buffers {
177 state.insert(format!("momentum_buffer_{}", param_id), buffer.clone());
178 }
179
180 Ok(state)
181 }
182
183 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
184 if let Some(lr) = state.get("learning_rate") {
186 self.config.learning_rate = lr.to_scalar()?;
187 }
188 if let Some(momentum) = state.get("momentum") {
189 self.config.momentum = momentum.to_scalar()?;
190 }
191 if let Some(nu) = state.get("nu") {
192 self.config.nu = nu.to_scalar()?;
193 }
194 if let Some(wd) = state.get("weight_decay") {
195 self.config.weight_decay = wd.to_scalar()?;
196 }
197 if let Some(step) = state.get("current_step") {
198 self.current_step = step.to_scalar()? as usize;
199 }
200
201 self.momentum_buffers.clear();
203 for (key, tensor) in state {
204 if let Some(param_id_str) = key.strip_prefix("momentum_buffer_") {
205 if let Ok(param_id) = param_id_str.parse::<usize>() {
206 self.momentum_buffers.insert(param_id, tensor);
207 }
208 }
209 }
210
211 Ok(())
212 }
213}
214
/// Configuration for the Aggregated Momentum (AggMo) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggMoConfig {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// One damping coefficient per parallel momentum buffer; the update
    /// averages across all buffers. Must be non-empty (checked in `new`).
    pub momentum_coefficients: Vec<f32>,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
}
225
226impl Default for AggMoConfig {
227 fn default() -> Self {
228 Self {
229 learning_rate: 1e-3,
230 momentum_coefficients: vec![0.0, 0.9, 0.99],
231 weight_decay: 0.0,
232 }
233 }
234}
235
236#[derive(Debug)]
241pub struct AggMo {
242 config: AggMoConfig,
243 momentum_buffers: HashMap<usize, Vec<Tensor>>, current_step: usize,
245}
246
247impl AggMo {
248 pub fn new(config: AggMoConfig) -> Self {
250 assert!(
251 !config.momentum_coefficients.is_empty(),
252 "Must provide at least one momentum coefficient"
253 );
254 Self {
255 config,
256 momentum_buffers: HashMap::new(),
257 current_step: 0,
258 }
259 }
260
261 pub fn with_defaults(learning_rate: f32, momentum_coefficients: Vec<f32>) -> Self {
263 Self::new(AggMoConfig {
264 learning_rate,
265 momentum_coefficients,
266 weight_decay: 0.0,
267 })
268 }
269
270 pub fn get_config(&self) -> &AggMoConfig {
272 &self.config
273 }
274
275 pub fn num_momentum_buffers(&self) -> usize {
277 self.config.momentum_coefficients.len()
278 }
279}
280
281impl OptimizerState for AggMo {
282 fn zero_grad(&mut self) -> Result<()> {
283 Ok(())
284 }
285
286 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
287 self.current_step += 1;
288
289 for (param_id, parameter) in parameters.iter_mut().enumerate() {
290 let gradient = match parameter.grad() {
292 Ok(grad) => grad,
293 Err(_) => {
294 continue;
296 },
297 };
298
299 let effective_grad = if self.config.weight_decay > 0.0 {
301 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
302 } else {
303 gradient
304 };
305
306 let buffers = self.momentum_buffers.entry(param_id).or_insert_with(|| {
308 (0..self.config.momentum_coefficients.len())
310 .map(|_| Tensor::zeros(&effective_grad.shape()).unwrap())
311 .collect()
312 });
313
314 let mut aggregated_momentum = Tensor::zeros(&effective_grad.shape())?;
316 for (i, &beta) in self.config.momentum_coefficients.iter().enumerate() {
317 buffers[i] =
319 buffers[i].mul_scalar(beta)?.add(&effective_grad.mul_scalar(1.0 - beta)?)?;
320
321 aggregated_momentum = aggregated_momentum.add(&buffers[i])?;
323 }
324
325 let num_buffers = self.config.momentum_coefficients.len() as f32;
327 let averaged_momentum = aggregated_momentum.div_scalar(num_buffers)?;
328
329 *parameter =
331 parameter.sub(&averaged_momentum.mul_scalar(self.config.learning_rate)?)?;
332 }
333
334 Ok(())
335 }
336
337 fn get_lr(&self) -> f32 {
338 self.config.learning_rate
339 }
340
341 fn set_lr(&mut self, lr: f32) {
342 self.config.learning_rate = lr;
343 }
344
345 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
346 let mut state = HashMap::new();
347
348 state.insert(
350 "learning_rate".to_string(),
351 Tensor::scalar(self.config.learning_rate)?,
352 );
353 state.insert(
354 "weight_decay".to_string(),
355 Tensor::scalar(self.config.weight_decay)?,
356 );
357 state.insert(
358 "current_step".to_string(),
359 Tensor::scalar(self.current_step as f32)?,
360 );
361 state.insert(
362 "num_momentum_coeffs".to_string(),
363 Tensor::scalar(self.config.momentum_coefficients.len() as f32)?,
364 );
365
366 for (i, &coeff) in self.config.momentum_coefficients.iter().enumerate() {
368 state.insert(format!("momentum_coeff_{}", i), Tensor::scalar(coeff)?);
369 }
370
371 for (¶m_id, buffers) in &self.momentum_buffers {
373 for (buffer_idx, buffer) in buffers.iter().enumerate() {
374 state.insert(
375 format!("momentum_buffer_{}_{}", param_id, buffer_idx),
376 buffer.clone(),
377 );
378 }
379 }
380
381 Ok(state)
382 }
383
384 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
385 if let Some(lr) = state.get("learning_rate") {
387 self.config.learning_rate = lr.to_scalar()?;
388 }
389 if let Some(wd) = state.get("weight_decay") {
390 self.config.weight_decay = wd.to_scalar()?;
391 }
392 if let Some(step) = state.get("current_step") {
393 self.current_step = step.to_scalar()? as usize;
394 }
395
396 if let Some(num_coeffs_tensor) = state.get("num_momentum_coeffs") {
398 let num_coeffs = num_coeffs_tensor.to_scalar()? as usize;
399 let mut coefficients = Vec::with_capacity(num_coeffs);
400 for i in 0..num_coeffs {
401 if let Some(coeff_tensor) = state.get(&format!("momentum_coeff_{}", i)) {
402 coefficients.push(coeff_tensor.to_scalar()?);
403 }
404 }
405 self.config.momentum_coefficients = coefficients;
406 }
407
408 self.momentum_buffers.clear();
410 let mut param_buffers: HashMap<usize, HashMap<usize, Tensor>> = HashMap::new();
411
412 for (key, tensor) in state {
413 if key.starts_with("momentum_buffer_") {
414 let parts: Vec<&str> = key.split('_').collect();
415 if parts.len() >= 4 {
416 if let (Ok(param_id), Ok(buffer_idx)) =
417 (parts[2].parse::<usize>(), parts[3].parse::<usize>())
418 {
419 param_buffers.entry(param_id).or_default().insert(buffer_idx, tensor);
420 }
421 }
422 }
423 }
424
425 for (param_id, buffer_map) in param_buffers {
427 let mut buffers = Vec::new();
428 for i in 0..self.config.momentum_coefficients.len() {
429 if let Some(buffer) = buffer_map.get(&i) {
430 buffers.push(buffer.clone());
431 }
432 }
433 if buffers.len() == self.config.momentum_coefficients.len() {
434 self.momentum_buffers.insert(param_id, buffers);
435 }
436 }
437
438 Ok(())
439 }
440}
441
/// Configuration for the variance-reduction optimizer (SVRG / SAG).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VarianceReductionConfig {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// Which variance-reduction scheme to run.
    pub method: VarianceReductionMethod,
    /// Maximum number of past gradients kept per parameter.
    pub history_size: usize,
    /// SVRG only: steps between full-gradient snapshot refreshes.
    pub full_grad_frequency: usize,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
}
456
457impl Default for VarianceReductionConfig {
458 fn default() -> Self {
459 Self {
460 learning_rate: 1e-3,
461 method: VarianceReductionMethod::SVRG,
462 history_size: 100,
463 full_grad_frequency: 10,
464 weight_decay: 0.0,
465 }
466 }
467}
468
/// Variance-reduction scheme selector for `VarianceReduction`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum VarianceReductionMethod {
    /// Stochastic Variance Reduced Gradient: corrects each gradient with a
    /// periodically refreshed snapshot (see `step`).
    SVRG,
    /// Stochastic Average Gradient: steps along the running average of the
    /// per-parameter gradient history.
    SAG,
}
477
/// Variance-reduction optimizer state shared by the SVRG and SAG methods.
#[derive(Debug)]
pub struct VarianceReduction {
    // Hyperparameters controlling the update rule.
    config: VarianceReductionConfig,
    // Bounded per-parameter window of recent gradients.
    gradient_history: HashMap<usize, Tensor>,
    // Cache of the last computed per-parameter average gradient.
    average_gradients: HashMap<usize, Tensor>,
    // SVRG snapshot gradients, refreshed every `full_grad_frequency` steps.
    full_gradients: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // Step at which the SVRG snapshot was last refreshed.
    last_full_grad_step: usize,
}
488
489impl VarianceReduction {
490 pub fn new(config: VarianceReductionConfig) -> Self {
492 Self {
493 config,
494 gradient_history: HashMap::new(),
495 average_gradients: HashMap::new(),
496 full_gradients: HashMap::new(),
497 current_step: 0,
498 last_full_grad_step: 0,
499 }
500 }
501
502 pub fn svrg(learning_rate: f32, history_size: usize, full_grad_frequency: usize) -> Self {
504 Self::new(VarianceReductionConfig {
505 learning_rate,
506 method: VarianceReductionMethod::SVRG,
507 history_size,
508 full_grad_frequency,
509 weight_decay: 0.0,
510 })
511 }
512
513 pub fn sag(learning_rate: f32, history_size: usize) -> Self {
515 Self::new(VarianceReductionConfig {
516 learning_rate,
517 method: VarianceReductionMethod::SAG,
518 history_size,
519 full_grad_frequency: 1, weight_decay: 0.0,
521 })
522 }
523
524 fn update_gradient_history(&mut self, param_id: usize, gradient: &Tensor) -> Result<()> {
525 let history = self.gradient_history.entry(param_id).or_default();
526
527 history.push(gradient.clone());
528 if history.len() > self.config.history_size {
529 history.remove(0);
530 }
531
532 Ok(())
533 }
534
535 fn compute_average_gradient(&mut self, param_id: usize) -> Result<Tensor> {
536 if let Some(history) = self.gradient_history.get(¶m_id) {
537 if history.is_empty() {
538 return Err(anyhow!("No gradient history available"));
539 }
540
541 let mut sum = history[0].clone();
542 for grad in history.iter().skip(1) {
543 sum = sum.add(grad)?;
544 }
545
546 let average = sum.div_scalar(history.len() as f32)?;
547 self.average_gradients.insert(param_id, average.clone());
548 Ok(average)
549 } else {
550 Err(anyhow!("No gradient history for parameter {}", param_id))
551 }
552 }
553
554 fn should_compute_full_gradient(&self) -> bool {
555 self.current_step - self.last_full_grad_step >= self.config.full_grad_frequency
556 }
557}
558
559impl OptimizerState for VarianceReduction {
560 fn zero_grad(&mut self) -> Result<()> {
561 Ok(())
562 }
563
564 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
565 self.current_step += 1;
566
567 let compute_full_grad = match self.config.method {
569 VarianceReductionMethod::SVRG => self.should_compute_full_gradient(),
570 VarianceReductionMethod::SAG => false,
571 };
572
573 if compute_full_grad {
574 self.last_full_grad_step = self.current_step;
575 for (param_id, parameter) in parameters.iter().enumerate() {
578 let gradient = match parameter.grad() {
580 Ok(grad) => grad,
581 Err(_) => {
582 continue;
584 },
585 };
586 self.full_gradients.insert(param_id, gradient);
587 }
588 }
589
590 for (param_id, parameter) in parameters.iter_mut().enumerate() {
591 let current_gradient = match parameter.grad() {
593 Ok(grad) => grad,
594 Err(_) => {
595 continue;
597 },
598 };
599
600 let effective_grad = if self.config.weight_decay > 0.0 {
602 current_gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
603 } else {
604 current_gradient
605 };
606
607 self.update_gradient_history(param_id, &effective_grad)?;
609
610 let variance_reduced_grad = match self.config.method {
612 VarianceReductionMethod::SVRG => {
613 if self.full_gradients.contains_key(¶m_id) {
614 let avg_grad = self.compute_average_gradient(param_id)?;
615 let full_grad = self.full_gradients.get(¶m_id).unwrap();
616 effective_grad.sub(&avg_grad)?.add(full_grad)?
618 } else {
619 effective_grad
620 }
621 },
622 VarianceReductionMethod::SAG => {
623 self.compute_average_gradient(param_id)?
625 },
626 };
627
628 *parameter =
630 parameter.sub(&variance_reduced_grad.mul_scalar(self.config.learning_rate)?)?;
631 }
632
633 Ok(())
634 }
635
636 fn get_lr(&self) -> f32 {
637 self.config.learning_rate
638 }
639
640 fn set_lr(&mut self, lr: f32) {
641 self.config.learning_rate = lr;
642 }
643
644 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
645 let mut state = HashMap::new();
646
647 state.insert(
648 "learning_rate".to_string(),
649 Tensor::scalar(self.config.learning_rate)?,
650 );
651 state.insert(
652 "current_step".to_string(),
653 Tensor::scalar(self.current_step as f32)?,
654 );
655 state.insert(
656 "last_full_grad_step".to_string(),
657 Tensor::scalar(self.last_full_grad_step as f32)?,
658 );
659
660 Ok(state)
664 }
665
666 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
667 if let Some(lr) = state.get("learning_rate") {
668 self.config.learning_rate = lr.to_scalar()?;
669 }
670 if let Some(step) = state.get("current_step") {
671 self.current_step = step.to_scalar()? as usize;
672 }
673 if let Some(last_step) = state.get("last_full_grad_step") {
674 self.last_full_grad_step = last_step.to_scalar()? as usize;
675 }
676
677 Ok(())
678 }
679}
680
/// Configuration for the Nesterov-accelerated-gradient optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NesterovAcceleratedGradientConfig {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// Velocity damping coefficient.
    pub momentum: f32,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
    /// When true, `set_current_loss` clears all velocity buffers whenever
    /// the reported loss increases (momentum restart).
    pub restart_on_increase: bool,
}
693
694impl Default for NesterovAcceleratedGradientConfig {
695 fn default() -> Self {
696 Self {
697 learning_rate: 1e-3,
698 momentum: 0.9,
699 weight_decay: 0.0,
700 restart_on_increase: false,
701 }
702 }
703}
704
/// Nesterov-accelerated-gradient optimizer state with optional
/// loss-triggered momentum restarts.
#[derive(Debug)]
pub struct NesterovAcceleratedGradient {
    // Hyperparameters controlling the update rule.
    config: NesterovAcceleratedGradientConfig,
    // One velocity buffer per parameter, keyed by parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // Last loss reported through `set_current_loss`, if any.
    previous_loss: Option<f32>,
}
719
720impl NesterovAcceleratedGradient {
721 pub fn new(config: NesterovAcceleratedGradientConfig) -> Self {
723 Self {
724 config,
725 velocity_buffers: HashMap::new(),
726 current_step: 0,
727 previous_loss: None,
728 }
729 }
730
731 pub fn with_defaults(learning_rate: f32, momentum: f32) -> Self {
733 Self::new(NesterovAcceleratedGradientConfig {
734 learning_rate,
735 momentum,
736 weight_decay: 0.0,
737 restart_on_increase: false,
738 })
739 }
740
741 pub fn get_config(&self) -> &NesterovAcceleratedGradientConfig {
743 &self.config
744 }
745
746 pub fn set_current_loss(&mut self, loss: f32) {
748 if self.config.restart_on_increase {
749 if let Some(prev_loss) = self.previous_loss {
750 if loss > prev_loss {
751 self.velocity_buffers.clear();
753 }
754 }
755 }
756 self.previous_loss = Some(loss);
757 }
758}
759
760impl OptimizerState for NesterovAcceleratedGradient {
761 fn zero_grad(&mut self) -> Result<()> {
762 Ok(())
763 }
764
765 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
766 self.current_step += 1;
767
768 for (param_id, parameter) in parameters.iter_mut().enumerate() {
769 let gradient = match parameter.grad() {
771 Ok(grad) => grad,
772 Err(_) => {
773 continue;
775 },
776 };
777
778 let effective_grad = if self.config.weight_decay > 0.0 {
780 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
781 } else {
782 gradient
783 };
784
785 let velocity = if let Some(v) = self.velocity_buffers.get(¶m_id) {
787 v.clone()
788 } else {
789 Tensor::zeros_like(parameter)?
790 };
791
792 let _lookahead_position = parameter.sub(&velocity.mul_scalar(self.config.momentum)?)?;
794
795 let new_velocity = velocity
801 .mul_scalar(self.config.momentum)?
802 .add(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
803
804 self.velocity_buffers.insert(param_id, new_velocity.clone());
805
806 *parameter = parameter.sub(&new_velocity)?;
808 }
809
810 Ok(())
811 }
812
813 fn get_lr(&self) -> f32 {
814 self.config.learning_rate
815 }
816
817 fn set_lr(&mut self, lr: f32) {
818 self.config.learning_rate = lr;
819 }
820
821 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
822 let mut state = HashMap::new();
823
824 state.insert(
825 "learning_rate".to_string(),
826 Tensor::scalar(self.config.learning_rate)?,
827 );
828 state.insert(
829 "momentum".to_string(),
830 Tensor::scalar(self.config.momentum)?,
831 );
832 state.insert(
833 "weight_decay".to_string(),
834 Tensor::scalar(self.config.weight_decay)?,
835 );
836 state.insert(
837 "current_step".to_string(),
838 Tensor::scalar(self.current_step as f32)?,
839 );
840
841 if let Some(loss) = self.previous_loss {
842 state.insert("previous_loss".to_string(), Tensor::scalar(loss)?);
843 }
844
845 for (¶m_id, velocity) in &self.velocity_buffers {
846 state.insert(format!("velocity_{}", param_id), velocity.clone());
847 }
848
849 Ok(state)
850 }
851
852 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
853 if let Some(lr) = state.get("learning_rate") {
854 self.config.learning_rate = lr.to_scalar()?;
855 }
856 if let Some(momentum) = state.get("momentum") {
857 self.config.momentum = momentum.to_scalar()?;
858 }
859 if let Some(wd) = state.get("weight_decay") {
860 self.config.weight_decay = wd.to_scalar()?;
861 }
862 if let Some(step) = state.get("current_step") {
863 self.current_step = step.to_scalar()? as usize;
864 }
865 if let Some(loss) = state.get("previous_loss") {
866 self.previous_loss = Some(loss.to_scalar()?);
867 }
868
869 self.velocity_buffers.clear();
870 for (key, tensor) in state {
871 if let Some(param_id_str) = key.strip_prefix("velocity_") {
872 if let Ok(param_id) = param_id_str.parse::<usize>() {
873 self.velocity_buffers.insert(param_id, tensor);
874 }
875 }
876 }
877
878 Ok(())
879 }
880}
881
/// Configuration for the heavy-ball momentum optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeavyBallConfig {
    /// Step size applied to each parameter update.
    pub learning_rate: f32,
    /// Velocity damping coefficient.
    pub beta: f32,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
    /// When true, beta is rescaled each step by the (clamped) cosine
    /// similarity between consecutive gradients.
    pub adaptive_momentum: bool,
}
894
895impl Default for HeavyBallConfig {
896 fn default() -> Self {
897 Self {
898 learning_rate: 1e-3,
899 beta: 0.9,
900 weight_decay: 0.0,
901 adaptive_momentum: false,
902 }
903 }
904}
905
/// Heavy-ball momentum optimizer state with optional adaptive momentum.
#[derive(Debug)]
pub struct HeavyBall {
    // Hyperparameters controlling the update rule.
    config: HeavyBallConfig,
    // One velocity buffer per parameter, keyed by parameter index.
    velocity_buffers: HashMap<usize, Tensor>,
    // Previous gradient per parameter; only populated when
    // `adaptive_momentum` is enabled (used for cosine similarity).
    previous_gradients: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
}
919
920impl HeavyBall {
921 pub fn new(config: HeavyBallConfig) -> Self {
923 Self {
924 config,
925 velocity_buffers: HashMap::new(),
926 previous_gradients: HashMap::new(),
927 current_step: 0,
928 }
929 }
930
931 pub fn with_defaults(learning_rate: f32, beta: f32) -> Self {
933 Self::new(HeavyBallConfig {
934 learning_rate,
935 beta,
936 weight_decay: 0.0,
937 adaptive_momentum: false,
938 })
939 }
940
941 pub fn get_config(&self) -> &HeavyBallConfig {
943 &self.config
944 }
945
946 fn compute_adaptive_momentum(&self, param_id: usize, current_grad: &Tensor) -> Result<f32> {
948 if let Some(prev_grad) = self.previous_gradients.get(¶m_id) {
949 let dot_product = current_grad.mul(prev_grad)?.sum(None, false)?;
951 let norm_current = current_grad.norm_squared()?.sqrt()?;
952 let norm_prev = prev_grad.norm_squared()?.sqrt()?;
953
954 let dot_scalar = dot_product.to_scalar()?;
955 let norm_current_scalar = norm_current.to_scalar()?;
956 let norm_prev_scalar = norm_prev.to_scalar()?;
957
958 let denominator = norm_current_scalar * norm_prev_scalar;
959 if denominator > 1e-8 {
960 let cosine_similarity = dot_scalar / denominator;
961 let adaptive_beta = self.config.beta * cosine_similarity.max(0.0);
963 Ok(adaptive_beta)
964 } else {
965 Ok(self.config.beta)
966 }
967 } else {
968 Ok(self.config.beta)
969 }
970 }
971}
972
973impl OptimizerState for HeavyBall {
974 fn zero_grad(&mut self) -> Result<()> {
975 Ok(())
976 }
977
978 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
979 self.current_step += 1;
980
981 for (param_id, parameter) in parameters.iter_mut().enumerate() {
982 let gradient = match parameter.grad() {
984 Ok(grad) => grad,
985 Err(_) => {
986 continue;
988 },
989 };
990
991 let effective_grad = if self.config.weight_decay > 0.0 {
993 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
994 } else {
995 gradient
996 };
997
998 let beta = if self.config.adaptive_momentum {
1000 self.compute_adaptive_momentum(param_id, &effective_grad)?
1001 } else {
1002 self.config.beta
1003 };
1004
1005 let velocity = if let Some(v) = self.velocity_buffers.get(¶m_id) {
1007 v.clone()
1008 } else {
1009 Tensor::zeros_like(parameter)?
1010 };
1011
1012 let new_velocity = velocity
1014 .mul_scalar(beta)?
1015 .sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1016
1017 self.velocity_buffers.insert(param_id, new_velocity.clone());
1018
1019 *parameter = parameter.add(&new_velocity)?;
1021
1022 if self.config.adaptive_momentum {
1024 self.previous_gradients.insert(param_id, effective_grad);
1025 }
1026 }
1027
1028 Ok(())
1029 }
1030
1031 fn get_lr(&self) -> f32 {
1032 self.config.learning_rate
1033 }
1034
1035 fn set_lr(&mut self, lr: f32) {
1036 self.config.learning_rate = lr;
1037 }
1038
1039 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1040 let mut state = HashMap::new();
1041
1042 state.insert(
1043 "learning_rate".to_string(),
1044 Tensor::scalar(self.config.learning_rate)?,
1045 );
1046 state.insert("beta".to_string(), Tensor::scalar(self.config.beta)?);
1047 state.insert(
1048 "weight_decay".to_string(),
1049 Tensor::scalar(self.config.weight_decay)?,
1050 );
1051 state.insert(
1052 "current_step".to_string(),
1053 Tensor::scalar(self.current_step as f32)?,
1054 );
1055
1056 for (¶m_id, velocity) in &self.velocity_buffers {
1057 state.insert(format!("velocity_{}", param_id), velocity.clone());
1058 }
1059
1060 for (¶m_id, grad) in &self.previous_gradients {
1061 state.insert(format!("prev_grad_{}", param_id), grad.clone());
1062 }
1063
1064 Ok(state)
1065 }
1066
1067 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1068 if let Some(lr) = state.get("learning_rate") {
1069 self.config.learning_rate = lr.to_scalar()?;
1070 }
1071 if let Some(beta) = state.get("beta") {
1072 self.config.beta = beta.to_scalar()?;
1073 }
1074 if let Some(wd) = state.get("weight_decay") {
1075 self.config.weight_decay = wd.to_scalar()?;
1076 }
1077 if let Some(step) = state.get("current_step") {
1078 self.current_step = step.to_scalar()? as usize;
1079 }
1080
1081 self.velocity_buffers.clear();
1082 self.previous_gradients.clear();
1083
1084 for (key, tensor) in state {
1085 if let Some(param_id_str) = key.strip_prefix("velocity_") {
1086 if let Ok(param_id) = param_id_str.parse::<usize>() {
1087 self.velocity_buffers.insert(param_id, tensor);
1088 }
1089 } else if let Some(param_id_str) = key.strip_prefix("prev_grad_") {
1090 if let Ok(param_id) = param_id_str.parse::<usize>() {
1091 self.previous_gradients.insert(param_id, tensor);
1092 }
1093 }
1094 }
1095
1096 Ok(())
1097 }
1098}
1099
/// Configuration for the FISTA (proximal-gradient) optimizer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FISTAConfig {
    /// Step size applied to each gradient step.
    pub learning_rate: f32,
    /// Soft-thresholding level applied after each gradient step
    /// (L1 proximal operator).
    pub threshold: f32,
    /// Declared restart flag; note the visible `step` implementation does
    /// not currently read it.
    pub adaptive_restart: bool,
    /// L2 weight-decay coefficient; 0.0 disables weight decay.
    pub weight_decay: f32,
}
1112
1113impl Default for FISTAConfig {
1114 fn default() -> Self {
1115 Self {
1116 learning_rate: 1e-3,
1117 threshold: 1e-4,
1118 adaptive_restart: true,
1119 weight_decay: 0.0,
1120 }
1121 }
1122}
1123
/// FISTA optimizer state (accelerated proximal gradient with
/// soft-thresholding).
#[derive(Debug)]
pub struct FISTA {
    // Hyperparameters controlling the update rule.
    config: FISTAConfig,
    // Previous parameter values, used for the momentum extrapolation step.
    previous_params: HashMap<usize, Tensor>,
    // Number of `step` calls performed so far.
    current_step: usize,
    // FISTA t_k sequence value for the current step.
    momentum_coefficient: f32,
    // t_{k-1}; used to form the extrapolation weight (t_{k-1} - 1) / t_k.
    previous_momentum: f32,
}
1136
1137impl FISTA {
1138 pub fn new(config: FISTAConfig) -> Self {
1140 Self {
1141 config,
1142 previous_params: HashMap::new(),
1143 current_step: 0,
1144 momentum_coefficient: 1.0,
1145 previous_momentum: 1.0,
1146 }
1147 }
1148
1149 pub fn with_defaults(learning_rate: f32, threshold: f32) -> Self {
1151 Self::new(FISTAConfig {
1152 learning_rate,
1153 threshold,
1154 adaptive_restart: true,
1155 weight_decay: 0.0,
1156 })
1157 }
1158
1159 pub fn get_config(&self) -> &FISTAConfig {
1161 &self.config
1162 }
1163
1164 fn soft_threshold(&self, tensor: &Tensor, threshold: f32) -> Result<Tensor> {
1166 let threshold_tensor = Tensor::scalar(threshold)?;
1167 let zero_tensor = Tensor::zeros_like(tensor)?;
1168
1169 let abs_tensor = tensor.abs()?;
1171 let thresholded = abs_tensor.sub(&threshold_tensor)?.max(&zero_tensor)?;
1172 let sign_tensor = tensor.sign()?;
1173
1174 Ok(sign_tensor.mul(&thresholded)?)
1175 }
1176
1177 fn update_momentum_coefficient(&mut self) {
1179 let t = self.current_step as f32;
1180 self.previous_momentum = self.momentum_coefficient;
1181 self.momentum_coefficient = (1.0 + (1.0 + 4.0 * t * t).sqrt()) / 2.0;
1182 }
1183}
1184
1185impl OptimizerState for FISTA {
1186 fn zero_grad(&mut self) -> Result<()> {
1187 Ok(())
1188 }
1189
1190 fn step(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1191 self.current_step += 1;
1192 self.update_momentum_coefficient();
1193
1194 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1195 let gradient = match parameter.grad() {
1197 Ok(grad) => grad,
1198 Err(_) => {
1199 continue;
1201 },
1202 };
1203
1204 let effective_grad = if self.config.weight_decay > 0.0 {
1206 gradient.add(¶meter.mul_scalar(self.config.weight_decay)?)?
1207 } else {
1208 gradient
1209 };
1210
1211 let previous_param = if let Some(prev) = self.previous_params.get(¶m_id) {
1213 prev.clone()
1214 } else {
1215 parameter.clone()
1216 };
1217
1218 let beta = (self.previous_momentum - 1.0) / self.momentum_coefficient;
1220
1221 let extrapolated = parameter.add(&previous_param.sub(parameter)?.mul_scalar(beta)?)?;
1223
1224 let grad_step =
1226 extrapolated.sub(&effective_grad.mul_scalar(self.config.learning_rate)?)?;
1227
1228 let new_parameter = self.soft_threshold(&grad_step, self.config.threshold)?;
1230
1231 self.previous_params.insert(param_id, parameter.clone());
1233
1234 *parameter = new_parameter;
1236 }
1237
1238 Ok(())
1239 }
1240
1241 fn get_lr(&self) -> f32 {
1242 self.config.learning_rate
1243 }
1244
1245 fn set_lr(&mut self, lr: f32) {
1246 self.config.learning_rate = lr;
1247 }
1248
1249 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
1250 let mut state = HashMap::new();
1251
1252 state.insert(
1253 "learning_rate".to_string(),
1254 Tensor::scalar(self.config.learning_rate)?,
1255 );
1256 state.insert(
1257 "threshold".to_string(),
1258 Tensor::scalar(self.config.threshold)?,
1259 );
1260 state.insert(
1261 "weight_decay".to_string(),
1262 Tensor::scalar(self.config.weight_decay)?,
1263 );
1264 state.insert(
1265 "current_step".to_string(),
1266 Tensor::scalar(self.current_step as f32)?,
1267 );
1268 state.insert(
1269 "momentum_coefficient".to_string(),
1270 Tensor::scalar(self.momentum_coefficient)?,
1271 );
1272 state.insert(
1273 "previous_momentum".to_string(),
1274 Tensor::scalar(self.previous_momentum)?,
1275 );
1276
1277 for (¶m_id, param) in &self.previous_params {
1278 state.insert(format!("prev_param_{}", param_id), param.clone());
1279 }
1280
1281 Ok(state)
1282 }
1283
1284 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
1285 if let Some(lr) = state.get("learning_rate") {
1286 self.config.learning_rate = lr.to_scalar()?;
1287 }
1288 if let Some(threshold) = state.get("threshold") {
1289 self.config.threshold = threshold.to_scalar()?;
1290 }
1291 if let Some(wd) = state.get("weight_decay") {
1292 self.config.weight_decay = wd.to_scalar()?;
1293 }
1294 if let Some(step) = state.get("current_step") {
1295 self.current_step = step.to_scalar()? as usize;
1296 }
1297 if let Some(momentum) = state.get("momentum_coefficient") {
1298 self.momentum_coefficient = momentum.to_scalar()?;
1299 }
1300 if let Some(prev_momentum) = state.get("previous_momentum") {
1301 self.previous_momentum = prev_momentum.to_scalar()?;
1302 }
1303
1304 self.previous_params.clear();
1305 for (key, tensor) in state {
1306 if let Some(param_id_str) = key.strip_prefix("prev_param_") {
1307 if let Ok(param_id) = param_id_str.parse::<usize>() {
1308 self.previous_params.insert(param_id, tensor);
1309 }
1310 }
1311 }
1312
1313 Ok(())
1314 }
1315}
1316
/// Configuration for the adaptive batch-size controller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveBatchSizingConfig {
    /// Batch size used before any adaptation occurs.
    pub initial_batch_size: usize,
    /// Lower bound the controller will not go below.
    pub min_batch_size: usize,
    /// Upper bound the controller will not exceed.
    pub max_batch_size: usize,
    /// Tolerated gradient variance level.
    pub gradient_variance_tolerance: f32,
    /// Factor applied when adapting the learning rate alongside batch size.
    pub lr_adaptation_factor: f32,
    /// Window length for the variance and loss histories.
    pub variance_window_size: usize,
    /// Variance level below which an adjustment toward larger batches is
    /// considered (see `adjust_batch_size`).
    pub increase_threshold: f32,
    /// Variance level above which an adjustment toward smaller batches is
    /// considered (see `adjust_batch_size`).
    pub decrease_threshold: f32,
}
1337
1338impl Default for AdaptiveBatchSizingConfig {
1339 fn default() -> Self {
1340 Self {
1341 initial_batch_size: 32,
1342 min_batch_size: 8,
1343 max_batch_size: 512,
1344 gradient_variance_tolerance: 0.1,
1345 lr_adaptation_factor: 0.8,
1346 variance_window_size: 10,
1347 increase_threshold: 0.05,
1348 decrease_threshold: 0.2,
1349 }
1350 }
1351}
1352
/// Adaptive batch-size controller: tracks gradient-variance and loss
/// histories and adjusts the batch size within configured bounds.
#[derive(Debug)]
pub struct AdaptiveBatchSizing {
    // Bounds, thresholds, and window sizes for adaptation.
    config: AdaptiveBatchSizingConfig,
    // Batch size currently recommended to the caller.
    current_batch_size: usize,
    // Sliding window of recent gradient-variance estimates.
    gradient_variance_history: Vec<f32>,
    // Sliding window of recent loss values.
    loss_history: Vec<f32>,
    // Number of `update` calls performed so far.
    current_step: usize,
    // Step at which the batch size was last adjusted (enforces a cooldown).
    last_adjustment_step: usize,
}
1367
1368impl AdaptiveBatchSizing {
1369 pub fn new(config: AdaptiveBatchSizingConfig) -> Self {
1371 let initial_batch_size = config.initial_batch_size;
1372 Self {
1373 config,
1374 current_batch_size: initial_batch_size,
1375 gradient_variance_history: Vec::new(),
1376 loss_history: Vec::new(),
1377 current_step: 0,
1378 last_adjustment_step: 0,
1379 }
1380 }
1381
1382 pub fn with_defaults(
1384 initial_batch_size: usize,
1385 min_batch_size: usize,
1386 max_batch_size: usize,
1387 ) -> Self {
1388 Self::new(AdaptiveBatchSizingConfig {
1389 initial_batch_size,
1390 min_batch_size,
1391 max_batch_size,
1392 ..Default::default()
1393 })
1394 }
1395
1396 pub fn current_batch_size(&self) -> usize {
1398 self.current_batch_size
1399 }
1400
1401 pub fn get_config(&self) -> &AdaptiveBatchSizingConfig {
1403 &self.config
1404 }
1405
1406 pub fn update(&mut self, gradient_variance: f32, current_loss: f32) -> Result<usize> {
1408 self.current_step += 1;
1409
1410 self.gradient_variance_history.push(gradient_variance);
1412 self.loss_history.push(current_loss);
1413
1414 if self.gradient_variance_history.len() > self.config.variance_window_size {
1416 self.gradient_variance_history.remove(0);
1417 }
1418 if self.loss_history.len() > self.config.variance_window_size {
1419 self.loss_history.remove(0);
1420 }
1421
1422 if self.should_adjust_batch_size() {
1424 self.adjust_batch_size()?;
1425 self.last_adjustment_step = self.current_step;
1426 }
1427
1428 Ok(self.current_batch_size)
1429 }
1430
1431 pub fn compute_gradient_variance(&self, gradients: &[Tensor]) -> Result<f32> {
1433 if gradients.is_empty() {
1434 return Ok(0.0);
1435 }
1436
1437 let mut mean_grad = gradients[0].clone();
1439 for grad in gradients.iter().skip(1) {
1440 mean_grad = mean_grad.add(grad)?;
1441 }
1442 mean_grad = mean_grad.div_scalar(gradients.len() as f32)?;
1443
1444 let mut variance_sum = 0.0;
1446 for grad in gradients {
1447 let diff = grad.sub(&mean_grad)?;
1448 let squared_norm = diff.mul(&diff)?.sum(None, false)?;
1449 variance_sum += squared_norm.to_scalar()?;
1450 }
1451
1452 Ok(variance_sum / gradients.len() as f32)
1453 }
1454
1455 fn should_adjust_batch_size(&self) -> bool {
1456 if self.current_step - self.last_adjustment_step < 5 {
1458 return false;
1459 }
1460
1461 self.gradient_variance_history.len() >= 3
1463 }
1464
1465 fn adjust_batch_size(&mut self) -> Result<()> {
1466 let recent_variance = self.recent_average_variance();
1467 let variance_trend = self.variance_trend();
1468 let loss_trend = self.loss_trend();
1469
1470 if recent_variance > self.config.decrease_threshold && variance_trend > 0.0 {
1472 self.increase_batch_size();
1474 } else if recent_variance < self.config.increase_threshold && loss_trend < -0.01 {
1475 self.decrease_batch_size();
1477 }
1478
1479 Ok(())
1480 }
1481
1482 fn recent_average_variance(&self) -> f32 {
1483 if self.gradient_variance_history.is_empty() {
1484 return 0.0;
1485 }
1486
1487 let recent_window = std::cmp::min(5, self.gradient_variance_history.len());
1488 let start_idx = self.gradient_variance_history.len() - recent_window;
1489
1490 self.gradient_variance_history[start_idx..].iter().sum::<f32>() / recent_window as f32
1491 }
1492
1493 fn variance_trend(&self) -> f32 {
1494 if self.gradient_variance_history.len() < 3 {
1495 return 0.0;
1496 }
1497
1498 let len = self.gradient_variance_history.len();
1499 let recent = self.gradient_variance_history[len - 2..].iter().sum::<f32>() / 2.0;
1500 let older = self.gradient_variance_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1501
1502 recent - older
1503 }
1504
1505 fn loss_trend(&self) -> f32 {
1506 if self.loss_history.len() < 3 {
1507 return 0.0;
1508 }
1509
1510 let len = self.loss_history.len();
1511 let recent = self.loss_history[len - 2..].iter().sum::<f32>() / 2.0;
1512 let older = self.loss_history[len - 4..len - 2].iter().sum::<f32>() / 2.0;
1513
1514 (recent - older) / older.max(1e-8)
1515 }
1516
1517 fn increase_batch_size(&mut self) {
1518 let new_size = (self.current_batch_size as f32 * 1.5) as usize;
1519 self.current_batch_size = new_size.min(self.config.max_batch_size);
1520 }
1521
1522 fn decrease_batch_size(&mut self) {
1523 let new_size = (self.current_batch_size as f32 * 0.8) as usize;
1524 self.current_batch_size = new_size.max(self.config.min_batch_size);
1525 }
1526
1527 pub fn get_lr_adjustment(&self, original_batch_size: usize) -> f32 {
1529 let ratio = self.current_batch_size as f32 / original_batch_size as f32;
1530 ratio.sqrt() * self.config.lr_adaptation_factor
1531 }
1532
1533 pub fn reset(&mut self) {
1535 self.current_batch_size = self.config.initial_batch_size;
1536 self.gradient_variance_history.clear();
1537 self.loss_history.clear();
1538 self.current_step = 0;
1539 self.last_adjustment_step = 0;
1540 }
1541}
1542
/// Configuration for [`LossSurfaceSmoothing`], which smooths gradients and
/// maintains EMA-smoothed parameter shadows during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LossSurfaceSmoothingConfig {
    /// Blend weight used in `smooth_parameters`: fraction of the EMA shadow
    /// mixed into the live parameter (0.0 = no smoothing).
    pub smoothing_strength: f32,
    /// Variance of the Gaussian noise added to gradients; its sqrt is used
    /// as the noise scale in `apply_noise_injection`.
    pub noise_variance: f32,
    /// Decay factor for the exponential moving averages of gradients and
    /// parameters (closer to 1.0 = slower-moving average).
    pub ema_decay: f32,
    /// Number of recent gradients kept per parameter for window averaging.
    pub averaging_window: usize,
    /// Whether to average each gradient over its recent history window.
    pub use_gradient_averaging: bool,
    /// Whether to inject Gaussian noise into the smoothed gradient.
    pub use_noise_injection: bool,
}
1559
1560impl Default for LossSurfaceSmoothingConfig {
1561 fn default() -> Self {
1562 Self {
1563 smoothing_strength: 0.1,
1564 noise_variance: 1e-4,
1565 ema_decay: 0.9,
1566 averaging_window: 5,
1567 use_gradient_averaging: true,
1568 use_noise_injection: false,
1569 }
1570 }
1571}
1572
/// Gradient/parameter smoother: window-averages and EMA-smooths gradients,
/// optionally injects noise, and keeps EMA shadow copies of parameters.
/// Parameters are keyed by their position in the slice passed to the
/// smoothing methods.
#[derive(Debug)]
pub struct LossSurfaceSmoothing {
    // Tuning knobs; see `LossSurfaceSmoothingConfig`.
    config: LossSurfaceSmoothingConfig,
    // Per-parameter FIFO of recent gradients (bounded by `averaging_window`).
    gradient_history: HashMap<usize, Vec<Tensor>>,
    // Per-parameter exponential moving average of gradients.
    ema_gradients: HashMap<usize, Tensor>,
    // Per-parameter EMA shadow of the parameter values themselves.
    smoothed_parameters: HashMap<usize, Tensor>,
    // Number of `smooth_gradients` calls made so far.
    current_step: usize,
}
1588
1589impl LossSurfaceSmoothing {
1590 pub fn new(config: LossSurfaceSmoothingConfig) -> Self {
1592 Self {
1593 config,
1594 gradient_history: HashMap::new(),
1595 ema_gradients: HashMap::new(),
1596 smoothed_parameters: HashMap::new(),
1597 current_step: 0,
1598 }
1599 }
1600
1601 pub fn with_defaults(smoothing_strength: f32, use_noise: bool) -> Self {
1603 Self::new(LossSurfaceSmoothingConfig {
1604 smoothing_strength,
1605 use_noise_injection: use_noise,
1606 ..Default::default()
1607 })
1608 }
1609
1610 pub fn get_config(&self) -> &LossSurfaceSmoothingConfig {
1612 &self.config
1613 }
1614
1615 pub fn smooth_gradients(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1617 self.current_step += 1;
1618
1619 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1620 let original_grad = parameter.grad()?;
1621 let mut smoothed_grad = original_grad.clone();
1622
1623 if self.config.use_gradient_averaging {
1625 smoothed_grad = self.apply_gradient_averaging(param_id, &original_grad)?;
1626 }
1627
1628 smoothed_grad = self.apply_ema_smoothing(param_id, &smoothed_grad)?;
1630
1631 if self.config.use_noise_injection {
1633 smoothed_grad = self.apply_noise_injection(&smoothed_grad)?;
1634 }
1635
1636 parameter.set_grad(smoothed_grad)?;
1638 }
1639
1640 Ok(())
1641 }
1642
1643 pub fn smooth_parameters(&mut self, parameters: &mut [Tensor]) -> Result<()> {
1645 for (param_id, parameter) in parameters.iter_mut().enumerate() {
1646 if let Some(smoothed_param) = self.smoothed_parameters.get(¶m_id) {
1647 let new_smoothed = smoothed_param
1649 .mul_scalar(self.config.ema_decay)?
1650 .add(¶meter.mul_scalar(1.0 - self.config.ema_decay)?)?;
1651
1652 *parameter = parameter
1654 .mul_scalar(1.0 - self.config.smoothing_strength)?
1655 .add(&new_smoothed.mul_scalar(self.config.smoothing_strength)?)?;
1656
1657 self.smoothed_parameters.insert(param_id, new_smoothed);
1658 } else {
1659 self.smoothed_parameters.insert(param_id, parameter.clone());
1661 }
1662 }
1663
1664 Ok(())
1665 }
1666
1667 fn apply_gradient_averaging(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1668 let history = self.gradient_history.entry(param_id).or_default();
1669
1670 history.push(gradient.clone());
1671 if history.len() > self.config.averaging_window {
1672 history.remove(0);
1673 }
1674
1675 if history.len() == 1 {
1677 Ok(gradient.clone())
1678 } else {
1679 let mut sum = history[0].clone();
1680 for grad in history.iter().skip(1) {
1681 sum = sum.add(grad)?;
1682 }
1683 Ok(sum.div_scalar(history.len() as f32)?)
1684 }
1685 }
1686
1687 fn apply_ema_smoothing(&mut self, param_id: usize, gradient: &Tensor) -> Result<Tensor> {
1688 if let Some(ema_grad) = self.ema_gradients.get(¶m_id) {
1689 let new_ema = ema_grad
1690 .mul_scalar(self.config.ema_decay)?
1691 .add(&gradient.mul_scalar(1.0 - self.config.ema_decay)?)?;
1692 self.ema_gradients.insert(param_id, new_ema.clone());
1693 Ok(new_ema)
1694 } else {
1695 self.ema_gradients.insert(param_id, gradient.clone());
1696 Ok(gradient.clone())
1697 }
1698 }
1699
1700 fn apply_noise_injection(&self, gradient: &Tensor) -> Result<Tensor> {
1701 let noise = Tensor::randn_like(gradient)
1702 .map_err(|e| anyhow!("Failed to create noise tensor: {}", e))?
1703 .mul_scalar(self.config.noise_variance.sqrt())
1704 .map_err(|e| anyhow!("Failed to scale noise tensor: {}", e))?;
1705 gradient
1706 .add(&noise)
1707 .map_err(|e| anyhow!("Failed to add noise to gradient: {}", e))
1708 }
1709
1710 pub fn reset(&mut self) {
1712 self.gradient_history.clear();
1713 self.ema_gradients.clear();
1714 self.smoothed_parameters.clear();
1715 self.current_step = 0;
1716 }
1717
1718 pub fn get_statistics(&self) -> HashMap<String, f32> {
1720 let mut stats = HashMap::new();
1721 stats.insert("current_step".to_string(), self.current_step as f32);
1722 stats.insert(
1723 "num_tracked_params".to_string(),
1724 self.gradient_history.len() as f32,
1725 );
1726 stats.insert(
1727 "smoothing_strength".to_string(),
1728 self.config.smoothing_strength,
1729 );
1730 stats.insert("ema_decay".to_string(), self.config.ema_decay);
1731 stats
1732 }
1733}
1734
// Unit tests: config defaults, constructor invariants, and the small pure
// state-transition helpers of the optimizers/utilities defined in this file.
// No tensor math is exercised here — these tests stay CPU-trivial.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qhm_config_default() {
        let config = QHMConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.momentum, 0.9);
        assert_eq!(config.nu, 0.7);
        assert_eq!(config.weight_decay, 0.0);
    }

    #[test]
    fn test_aggmo_config_default() {
        let config = AggMoConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.momentum_coefficients, vec![0.0, 0.9, 0.99]);
        assert_eq!(config.weight_decay, 0.0);
    }

    #[test]
    fn test_qhm_creation() {
        let optimizer = QHM::with_defaults(1e-3, 0.9, 0.7);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
    }

    #[test]
    fn test_aggmo_creation() {
        let optimizer = AggMo::with_defaults(1e-3, vec![0.0, 0.9, 0.99]);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.num_momentum_buffers(), 3);
    }

    #[test]
    fn test_variance_reduction_svrg() {
        let optimizer = VarianceReduction::svrg(1e-3, 50, 10);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
    }

    #[test]
    fn test_variance_reduction_sag() {
        let optimizer = VarianceReduction::sag(1e-3, 100);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert!(matches!(
            optimizer.config.method,
            VarianceReductionMethod::SAG
        ));
    }

    #[test]
    fn test_nesterov_accelerated_gradient_config() {
        let config = NesterovAcceleratedGradientConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.momentum, 0.9);
        assert_eq!(config.weight_decay, 0.0);
        assert!(!config.restart_on_increase);
    }

    #[test]
    fn test_nesterov_accelerated_gradient_creation() {
        let optimizer = NesterovAcceleratedGradient::with_defaults(1e-3, 0.9);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
        assert!(optimizer.previous_loss.is_none());
    }

    // `set_current_loss` should record the latest loss even when it rises
    // (restart_on_increase only affects momentum, not the stored loss).
    #[test]
    fn test_nesterov_restart_on_increase() {
        let mut optimizer = NesterovAcceleratedGradient::new(NesterovAcceleratedGradientConfig {
            learning_rate: 1e-3,
            momentum: 0.9,
            weight_decay: 0.0,
            restart_on_increase: true,
        });

        optimizer.set_current_loss(1.0);
        assert_eq!(optimizer.previous_loss, Some(1.0));

        optimizer.set_current_loss(1.5);
        assert_eq!(optimizer.previous_loss, Some(1.5));
    }

    #[test]
    fn test_heavy_ball_config() {
        let config = HeavyBallConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.beta, 0.9);
        assert_eq!(config.weight_decay, 0.0);
        assert!(!config.adaptive_momentum);
    }

    #[test]
    fn test_heavy_ball_creation() {
        let optimizer = HeavyBall::with_defaults(1e-3, 0.9);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
        assert_eq!(optimizer.get_config().beta, 0.9);
    }

    #[test]
    fn test_heavy_ball_adaptive_momentum() {
        let optimizer = HeavyBall::new(HeavyBallConfig {
            learning_rate: 1e-3,
            beta: 0.9,
            weight_decay: 0.0,
            adaptive_momentum: true,
        });

        assert!(optimizer.config.adaptive_momentum);
    }

    #[test]
    fn test_fista_config() {
        let config = FISTAConfig::default();
        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.threshold, 1e-4);
        assert!(config.adaptive_restart);
        assert_eq!(config.weight_decay, 0.0);
    }

    #[test]
    fn test_fista_creation() {
        let optimizer = FISTA::with_defaults(1e-3, 1e-4);
        assert_eq!(optimizer.get_lr(), 1e-3);
        assert_eq!(optimizer.current_step, 0);
        assert_eq!(optimizer.momentum_coefficient, 1.0);
        assert_eq!(optimizer.previous_momentum, 1.0);
    }

    // The FISTA momentum coefficient must grow monotonically with each step.
    #[test]
    fn test_fista_momentum_update() {
        let mut optimizer = FISTA::with_defaults(1e-3, 1e-4);

        optimizer.current_step = 1;
        optimizer.update_momentum_coefficient();
        assert!(optimizer.momentum_coefficient > 1.0);
        assert_eq!(optimizer.previous_momentum, 1.0);

        let prev_momentum = optimizer.momentum_coefficient;
        optimizer.current_step = 2;
        optimizer.update_momentum_coefficient();
        assert!(optimizer.momentum_coefficient > prev_momentum);
    }

    #[test]
    fn test_adaptive_batch_sizing_config() {
        let config = AdaptiveBatchSizingConfig::default();
        assert_eq!(config.initial_batch_size, 32);
        assert_eq!(config.min_batch_size, 8);
        assert_eq!(config.max_batch_size, 512);
        assert_eq!(config.gradient_variance_tolerance, 0.1);
        assert_eq!(config.lr_adaptation_factor, 0.8);
        assert_eq!(config.variance_window_size, 10);
        assert_eq!(config.increase_threshold, 0.05);
        assert_eq!(config.decrease_threshold, 0.2);
    }

    #[test]
    fn test_adaptive_batch_sizing_creation() {
        let abs = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        assert_eq!(abs.current_batch_size(), 64);
        assert_eq!(abs.get_config().min_batch_size, 16);
        assert_eq!(abs.get_config().max_batch_size, 256);
    }

    // Sanity bounds only: 64/32 ratio with the default 0.8 factor should land
    // strictly between 0 and 2.
    #[test]
    fn test_adaptive_batch_sizing_lr_adjustment() {
        let abs = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        let lr_adj = abs.get_lr_adjustment(32);
        assert!(lr_adj > 0.0);
        assert!(lr_adj < 2.0);
    }

    #[test]
    fn test_adaptive_batch_sizing_reset() {
        let mut abs = AdaptiveBatchSizing::with_defaults(64, 16, 256);
        abs.current_step = 10;
        abs.reset();
        assert_eq!(abs.current_step, 0);
        assert_eq!(abs.current_batch_size(), 64);
    }

    #[test]
    fn test_loss_surface_smoothing_config() {
        let config = LossSurfaceSmoothingConfig::default();
        assert_eq!(config.smoothing_strength, 0.1);
        assert_eq!(config.noise_variance, 1e-4);
        assert_eq!(config.ema_decay, 0.9);
        assert_eq!(config.averaging_window, 5);
        assert!(config.use_gradient_averaging);
        assert!(!config.use_noise_injection);
    }

    #[test]
    fn test_loss_surface_smoothing_creation() {
        let lss = LossSurfaceSmoothing::with_defaults(0.2, true);
        assert_eq!(lss.get_config().smoothing_strength, 0.2);
        assert!(lss.get_config().use_noise_injection);
        assert_eq!(lss.current_step, 0);
    }

    #[test]
    fn test_loss_surface_smoothing_statistics() {
        let lss = LossSurfaceSmoothing::with_defaults(0.1, false);
        let stats = lss.get_statistics();
        assert_eq!(stats.get("current_step"), Some(&0.0));
        assert_eq!(stats.get("num_tracked_params"), Some(&0.0));
        assert_eq!(stats.get("smoothing_strength"), Some(&0.1));
        assert_eq!(stats.get("ema_decay"), Some(&0.9));
    }

    #[test]
    fn test_loss_surface_smoothing_reset() {
        let mut lss = LossSurfaceSmoothing::with_defaults(0.1, false);
        lss.current_step = 5;
        lss.reset();
        assert_eq!(lss.current_step, 0);
    }
}