1pub mod legacy;
7
8pub use legacy::{CAMEConfig, CAME};
9
10use trustformers_core::errors::TrustformersError;
13
14#[derive(Debug, thiserror::Error)]
16pub enum OptimError {
17 #[error("length mismatch: param length {param} != grad length {grad}")]
19 LengthMismatch { param: usize, grad: usize },
20 #[error("dimension mismatch: rows * cols ({rows} * {cols} = {product}) != size {size}")]
22 DimensionMismatch {
23 rows: usize,
24 cols: usize,
25 product: usize,
26 size: usize,
27 },
28 #[error("no state initialised for parameter group index {0}")]
30 StateNotInitialised(usize),
31 #[error("numerical error: {0}")]
33 NumericalError(String),
34}
35
36impl From<OptimError> for TrustformersError {
37 fn from(e: OptimError) -> Self {
38 TrustformersError::invalid_operation(e.to_string())
39 }
40}
41
/// Hyper-parameters for the CAME optimizer.
#[derive(Debug, Clone)]
pub struct CameConfig {
    /// Learning rate; scales both the update and the decoupled weight decay.
    pub lr: f64,
    /// `(beta1, beta2_cap, beta3)`:
    /// * `beta1` – decay of the first-moment EMA (`exp_avg`);
    /// * `beta2_cap` – upper bound of the scheduled second-moment decay;
    /// * `beta3` – decay of the instantaneous row/column EMAs used for the
    ///   confidence factor (factored parameters only).
    pub betas: (f64, f64, f64),
    /// `(eps1, eps2)`: `eps1` is added to squared gradients (and guards the
    /// clipping division and row normaliser); `eps2` is added to update
    /// denominators.
    pub eps: (f64, f64),
    /// Decoupled weight-decay coefficient (`p -= lr * weight_decay * p`);
    /// `0.0` disables it.
    pub weight_decay: f64,
    /// The gradient is rescaled so its RMS does not exceed this value.
    pub clip_threshold: f64,
    /// Exponent of the second-moment schedule
    /// `beta2_t = min(1 - step^decay_rate, beta2_cap)`.
    pub decay_rate: f64,
}
63
64impl Default for CameConfig {
65 fn default() -> Self {
66 Self {
67 lr: 2e-4,
68 betas: (0.9, 0.999, 0.9999),
69 eps: (1e-30, 1e-16),
70 weight_decay: 0.0,
71 clip_threshold: 1.0,
72 decay_rate: -0.8,
73 }
74 }
75}
76
/// Per-parameter optimizer state for CAME.
///
/// A parameter is either unfactored (1-D, `exp_avg_sq` is `Some` and the
/// row/column vectors are empty) or factored (2-D, `exp_avg_sq` is `None` and
/// the row/column vectors hold per-row and per-column statistics).
#[derive(Debug, Clone)]
pub struct CameParamState {
    /// Number of updates applied so far.
    pub step: u64,
    /// First moment: EMA of the clipped gradient, one entry per element.
    pub exp_avg: Vec<f32>,
    /// Factored second moment, row factor (EMA of per-row mean squared gradient).
    pub exp_avg_sq_row: Vec<f32>,
    /// Factored second moment, column factor (EMA of per-column mean squared gradient).
    pub exp_avg_sq_col: Vec<f32>,
    /// Full second moment, used only for unfactored (1-D) parameters.
    pub exp_avg_sq: Option<Vec<f32>>,
    /// `beta3` EMA of the per-row mean squared gradient (confidence statistic).
    pub exp_avg_insta_sq_row: Vec<f32>,
    /// `beta3` EMA of the per-column mean squared gradient (confidence statistic).
    pub exp_avg_insta_sq_col: Vec<f32>,
}
95
96impl CameParamState {
97 pub fn new_2d(size: usize, rows: usize, cols: usize) -> Self {
99 Self {
100 step: 0,
101 exp_avg: vec![0.0_f32; size],
102 exp_avg_sq_row: vec![0.0_f32; rows],
103 exp_avg_sq_col: vec![0.0_f32; cols],
104 exp_avg_sq: None,
105 exp_avg_insta_sq_row: vec![0.0_f32; rows],
106 exp_avg_insta_sq_col: vec![0.0_f32; cols],
107 }
108 }
109
110 pub fn new_1d(size: usize) -> Self {
112 Self {
113 step: 0,
114 exp_avg: vec![0.0_f32; size],
115 exp_avg_sq_row: Vec::new(),
116 exp_avg_sq_col: Vec::new(),
117 exp_avg_sq: Some(vec![0.0_f32; size]),
118 exp_avg_insta_sq_row: Vec::new(),
119 exp_avg_insta_sq_col: Vec::new(),
120 }
121 }
122}
123
/// Root-mean-square of `v`, or `0.0` for an empty slice.
///
/// Squares are accumulated in `f64`. This avoids two problems with an `f32`
/// accumulator. First, elements with magnitude above ~1.8e19 would overflow it
/// to `inf`, which would collapse the clipping scale to zero. Second, long
/// slices would lose precision. The result is converted back to `f32`.
#[inline]
fn rms(v: &[f32]) -> f32 {
    if v.is_empty() {
        return 0.0;
    }
    let sq_sum: f64 = v.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    (sq_sum / v.len() as f64).sqrt() as f32
}
133
134pub fn came_update(
149 param: &mut [f32],
150 grad: &[f32],
151 state: &mut CameParamState,
152 config: &CameConfig,
153 rows: usize,
154 cols: usize,
155) -> Result<(), OptimError> {
156 let size = param.len();
158 if grad.len() != size {
159 return Err(OptimError::LengthMismatch {
160 param: size,
161 grad: grad.len(),
162 });
163 }
164 let expected = rows * cols;
165 if expected != size {
166 return Err(OptimError::DimensionMismatch {
167 rows,
168 cols,
169 product: expected,
170 size,
171 });
172 }
173
174 state.step += 1;
176 let step = state.step as f64;
177
178 let beta2_t = (1.0 - step.powf(config.decay_rate)).min(config.betas.1) as f32;
181
182 let beta1 = config.betas.0 as f32;
183 let beta3 = config.betas.2 as f32;
184 let eps1 = config.eps.0 as f32;
185 let eps2 = config.eps.1 as f32;
186
187 let grad_rms = rms(grad);
189 let clip_scale = if grad_rms > config.clip_threshold as f32 {
190 config.clip_threshold as f32 / (grad_rms + eps1)
191 } else {
192 1.0
193 };
194
195 for (m, &g) in state.exp_avg.iter_mut().zip(grad.iter()) {
200 let g_clipped = g * clip_scale;
201 *m = beta1 * *m + (1.0 - beta1) * g_clipped;
202 }
203
204 if rows == 1 {
206 let sq = state
208 .exp_avg_sq
209 .as_mut()
210 .ok_or_else(|| OptimError::NumericalError("1-D state missing exp_avg_sq".into()))?;
211 for (s, &g) in sq.iter_mut().zip(grad.iter()) {
212 let g_clipped = g * clip_scale;
213 *s = beta2_t * *s + (1.0 - beta2_t) * (g_clipped * g_clipped + eps1);
214 }
215
216 for ((p, &m), &s) in param.iter_mut().zip(state.exp_avg.iter()).zip(sq.iter()) {
218 let denom = s.sqrt() + eps2;
219 let update = m / denom;
220 if config.weight_decay != 0.0 {
221 *p -= config.lr as f32 * config.weight_decay as f32 * *p;
222 }
223 *p -= config.lr as f32 * update;
224 }
225 } else {
226 let mut row_mean = vec![0.0_f32; rows];
229 let mut col_mean = vec![0.0_f32; cols];
230
231 for i in 0..rows {
232 let mut s = 0.0_f32;
233 for j in 0..cols {
234 let g = grad[i * cols + j] * clip_scale;
235 s += g * g;
236 }
237 row_mean[i] = s / cols as f32 + eps1;
238 }
239 for j in 0..cols {
240 let mut s = 0.0_f32;
241 for i in 0..rows {
242 let g = grad[i * cols + j] * clip_scale;
243 s += g * g;
244 }
245 col_mean[j] = s / rows as f32 + eps1;
246 }
247
248 for (r, &rm) in state.exp_avg_sq_row.iter_mut().zip(row_mean.iter()) {
250 *r = beta2_t * *r + (1.0 - beta2_t) * rm;
251 }
252 for (c, &cm) in state.exp_avg_sq_col.iter_mut().zip(col_mean.iter()) {
253 *c = beta2_t * *c + (1.0 - beta2_t) * cm;
254 }
255
256 for (r, &rm) in state.exp_avg_insta_sq_row.iter_mut().zip(row_mean.iter()) {
258 *r = beta3 * *r + (1.0 - beta3) * rm;
259 }
260 for (c, &cm) in state.exp_avg_insta_sq_col.iter_mut().zip(col_mean.iter()) {
261 *c = beta3 * *c + (1.0 - beta3) * cm;
262 }
263
264 let row_sum: f32 = state.exp_avg_sq_row.iter().sum();
266 let row_normaliser = (row_sum / rows as f32).max(eps1);
267
268 for i in 0..rows {
270 let smoothed_row = state.exp_avg_sq_row[i];
271 let insta_row = state.exp_avg_insta_sq_row[i];
272
273 for j in 0..cols {
274 let smoothed_col = state.exp_avg_sq_col[j];
275 let insta_col = state.exp_avg_insta_sq_col[j];
276
277 let v_approx = (smoothed_row * smoothed_col / row_normaliser).sqrt();
279
280 let smoothed_insta_row = (insta_row * insta_col / row_normaliser).sqrt();
282 let confidence = if smoothed_insta_row > eps1 {
283 (v_approx / (smoothed_insta_row + eps2)).min(1.0_f32)
284 } else {
285 1.0_f32
286 };
287
288 let denom = v_approx + eps2;
289 let idx = i * cols + j;
290 let m = state.exp_avg[idx];
291 let update = confidence * m / denom;
292
293 let p = &mut param[idx];
294 if config.weight_decay != 0.0 {
295 *p -= config.lr as f32 * config.weight_decay as f32 * *p;
296 }
297 *p -= config.lr as f32 * update;
298 }
299 }
300 }
301
302 Ok(())
303}
304
/// Shape information recorded for each registered parameter group.
#[derive(Debug, Clone)]
struct ParamGroupMeta {
    /// Element count passed at registration. It is kept for diagnostics but not
    /// currently read by the optimizer, hence the `dead_code` allowance.
    #[allow(dead_code)]
    size: usize,
    /// Number of rows; `1` selects the unfactored update.
    rows: usize,
    /// Number of columns.
    cols: usize,
}
313
314#[derive(Debug)]
319pub struct CameOptimizer {
320 pub config: CameConfig,
322 pub states: Vec<CameParamState>,
324 meta: Vec<ParamGroupMeta>,
326}
327
328impl CameOptimizer {
329 pub fn new(config: CameConfig) -> Self {
331 Self {
332 config,
333 states: Vec::new(),
334 meta: Vec::new(),
335 }
336 }
337
338 pub fn add_param_group(&mut self, param_size: usize, rows: usize, cols: usize) {
343 let state = if rows == 1 {
344 CameParamState::new_1d(param_size)
345 } else {
346 CameParamState::new_2d(param_size, rows, cols)
347 };
348 self.states.push(state);
349 self.meta.push(ParamGroupMeta {
350 size: param_size,
351 rows,
352 cols,
353 });
354 }
355
356 pub fn step(&mut self, params: &mut [Vec<f32>], grads: &[Vec<f32>]) -> Result<(), OptimError> {
367 for (idx, ((param, grad), state)) in
368 params.iter_mut().zip(grads.iter()).zip(self.states.iter_mut()).enumerate()
369 {
370 let meta = self.meta.get(idx).ok_or(OptimError::StateNotInitialised(idx))?;
371 came_update(param, grad, state, &self.config, meta.rows, meta.cols)?;
372 }
373 Ok(())
374 }
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// `CameConfig::default` exposes the expected hyper-parameters.
    #[test]
    fn test_came_config_defaults() {
        let cfg = CameConfig::default();
        assert_relative_eq!(cfg.lr, 2e-4);
        assert_relative_eq!(cfg.betas.0, 0.9);
        assert_relative_eq!(cfg.betas.1, 0.999);
        assert_relative_eq!(cfg.betas.2, 0.9999);
        assert_relative_eq!(cfg.eps.0, 1e-30);
        assert_relative_eq!(cfg.eps.1, 1e-16);
        assert_relative_eq!(cfg.weight_decay, 0.0);
        assert_relative_eq!(cfg.clip_threshold, 1.0);
        assert_relative_eq!(cfg.decay_rate, -0.8);
    }

    /// Factored state allocates row/column vectors and no full second moment.
    #[test]
    fn test_state_init_2d() {
        let state = CameParamState::new_2d(6, 2, 3);
        assert_eq!(state.step, 0);
        assert_eq!(state.exp_avg.len(), 6);
        assert_eq!(state.exp_avg_sq_row.len(), 2);
        assert_eq!(state.exp_avg_sq_col.len(), 3);
        assert!(state.exp_avg_sq.is_none());
        assert_eq!(state.exp_avg_insta_sq_row.len(), 2);
        assert_eq!(state.exp_avg_insta_sq_col.len(), 3);
        assert!(state.exp_avg.iter().all(|&x| x == 0.0));
    }

    /// Unfactored state allocates a full second moment and no row/column vectors.
    #[test]
    fn test_state_init_1d() {
        let state = CameParamState::new_1d(5);
        assert_eq!(state.step, 0);
        assert_eq!(state.exp_avg.len(), 5);
        assert!(state.exp_avg_sq_row.is_empty());
        assert!(state.exp_avg_sq_col.is_empty());
        assert!(state.exp_avg_sq.is_some());
        assert_eq!(state.exp_avg_sq.as_ref().map(|v| v.len()), Some(5));
    }

    /// Each successful update increments the step counter by one.
    #[test]
    fn test_step_counter() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(2);
        let mut param = vec![1.0_f32; 2];
        let grad = vec![0.1_f32; 2];

        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");
        assert_eq!(state.step, 1);
        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");
        assert_eq!(state.step, 2);
    }

    /// The factored row/column second moments become positive after one step.
    #[test]
    fn test_factored_second_moment_update() {
        let cfg = CameConfig {
            lr: 0.0,
            ..CameConfig::default()
        };
        let rows = 2_usize;
        let cols = 3_usize;
        let size = rows * cols;
        let mut state = CameParamState::new_2d(size, rows, cols);
        let mut param = vec![0.0_f32; size];
        let grad = vec![1.0_f32; size];

        came_update(&mut param, &grad, &mut state, &cfg, rows, cols).expect("update failed");
        assert!(state.exp_avg_sq_row.iter().all(|&x| x > 0.0));
        assert!(state.exp_avg_sq_col.iter().all(|&x| x > 0.0));
    }

    /// beta2_t starts at 0 and approaches (but stays below) 1.
    #[test]
    fn test_dynamic_beta2_schedule() {
        let cfg = CameConfig::default();
        let step = 1_f64;
        let beta2_t = (1.0 - step.powf(cfg.decay_rate)).min(cfg.betas.1);
        assert_relative_eq!(beta2_t, 0.0, epsilon = 1e-9);

        let step100 = 100_f64;
        let beta2_100 = (1.0 - step100.powf(cfg.decay_rate)).min(cfg.betas.1);
        assert!(beta2_100 > 0.9 && beta2_100 < 1.0);
    }

    /// The instantaneous (confidence) statistics become positive after one step.
    #[test]
    fn test_confidence_adaptation() {
        let cfg = CameConfig::default();
        let rows = 2_usize;
        let cols = 2_usize;
        let size = rows * cols;
        let mut state = CameParamState::new_2d(size, rows, cols);
        let mut param = vec![0.0_f32; size];
        let grad = vec![1.0_f32; size];

        came_update(&mut param, &grad, &mut state, &cfg, rows, cols).expect("update failed");

        assert!(state.exp_avg_insta_sq_row.iter().all(|&x| x > 0.0));
        assert!(state.exp_avg_insta_sq_col.iter().all(|&x| x > 0.0));
    }

    /// With a zero gradient, weight decay alone shrinks the parameters.
    #[test]
    fn test_weight_decay() {
        let cfg = CameConfig {
            lr: 1e-1,
            weight_decay: 0.1,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(2);
        let initial_param = vec![1.0_f32; 2];
        let mut param = initial_param.clone();
        let grad = vec![0.0_f32; 2];
        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");

        for (p_new, p_old) in param.iter().zip(initial_param.iter()) {
            assert!(
                p_new.abs() < p_old.abs(),
                "weight decay did not reduce param"
            );
        }
    }

    /// A positive gradient moves parameters down.
    #[test]
    fn test_single_step_direction() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(3);
        let mut param = vec![0.5_f32; 3];
        let grad = vec![0.1_f32; 3];
        let param_before = param.clone();
        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");

        for (p_new, p_old) in param.iter().zip(param_before.iter()) {
            assert!(
                p_new < p_old,
                "param did not decrease with positive gradient"
            );
        }
    }

    /// A tighter clip threshold yields a smaller first moment.
    #[test]
    fn test_gradient_clipping() {
        let cfg_tight = CameConfig {
            clip_threshold: 0.1,
            ..CameConfig::default()
        };
        let cfg_loose = CameConfig {
            clip_threshold: 1000.0,
            ..CameConfig::default()
        };

        let large_grad = vec![5.0_f32; 4];

        let mut state_tight = CameParamState::new_1d(4);
        let mut param_tight = vec![0.0_f32; 4];
        came_update(
            &mut param_tight,
            &large_grad,
            &mut state_tight,
            &cfg_tight,
            1,
            4,
        )
        .expect("tight update failed");

        let mut state_loose = CameParamState::new_1d(4);
        let mut param_loose = vec![0.0_f32; 4];
        came_update(
            &mut param_loose,
            &large_grad,
            &mut state_loose,
            &cfg_loose,
            1,
            4,
        )
        .expect("loose update failed");

        let m_tight: f32 = state_tight.exp_avg.iter().map(|x| x.abs()).sum();
        let m_loose: f32 = state_loose.exp_avg.iter().map(|x| x.abs()).sum();
        assert!(
            m_tight < m_loose,
            "tight clipping did not reduce first moment: m_tight={m_tight} m_loose={m_loose}"
        );
    }

    /// Minimising f(x) = x^2 / 2 (gradient = x) converges towards 0.
    #[test]
    fn test_convergence_quadratic() {
        let cfg = CameConfig {
            lr: 1e-2,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(1);
        let mut param = vec![5.0_f32];

        for _ in 0..2000 {
            let grad = param.clone();
            came_update(&mut param, &grad, &mut state, &cfg, 1, 1).expect("update failed");
        }

        assert!(
            param[0].abs() < 0.1,
            "CAME did not converge on quadratic: final param = {}",
            param[0]
        );
    }

    /// Mismatched param/grad lengths are reported as `LengthMismatch`.
    #[test]
    fn test_dimension_mismatch_error() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(4);
        let mut param = vec![0.0_f32; 4];
        let grad = vec![0.0_f32; 5];
        let result = came_update(&mut param, &grad, &mut state, &cfg, 1, 4);
        // `matches!` only evaluates to a bool; without `assert!` the variant
        // check has no effect.
        assert!(
            matches!(result, Err(OptimError::LengthMismatch { param: 4, grad: 5 })),
            "expected LengthMismatch {{ param: 4, grad: 5 }}, got {result:?}"
        );
    }

    /// The optimizer steps every registered group once.
    #[test]
    fn test_came_optimizer_multi_param() {
        let cfg = CameConfig::default();
        let mut optimizer = CameOptimizer::new(cfg);
        optimizer.add_param_group(4, 2, 2);
        optimizer.add_param_group(3, 1, 3);

        let mut params = vec![vec![1.0_f32; 4], vec![1.0_f32; 3]];
        let grads = vec![vec![0.1_f32; 4], vec![0.1_f32; 3]];

        optimizer.step(&mut params, &grads).expect("step failed");
        assert_eq!(optimizer.states[0].step, 1);
        assert_eq!(optimizer.states[1].step, 1);
    }
}
667
#[cfg(test)]
mod extended_tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_came_state_step_zero_at_init() {
        let state = CameParamState::new_2d(6, 2, 3);
        assert_eq!(state.step, 0);
        let state1d = CameParamState::new_1d(4);
        assert_eq!(state1d.step, 0);
    }

    #[test]
    fn test_came_confidence_factors_nonzero_after_step() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_2d(6, 2, 3);
        let mut param = vec![0.5_f32; 6];
        let grad = vec![0.1_f32; 6];
        came_update(&mut param, &grad, &mut state, &cfg, 2, 3).expect("update failed");
        assert!(
            state.exp_avg_insta_sq_row.iter().all(|&x| x > 0.0),
            "insta_sq_row should be nonzero after update"
        );
        assert!(
            state.exp_avg_insta_sq_col.iter().all(|&x| x > 0.0),
            "insta_sq_col should be nonzero after update"
        );
    }

    #[test]
    fn test_came_positive_grad_decreases_params() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(4);
        let mut param = vec![1.0_f32; 4];
        let grad = vec![0.5_f32; 4];
        let before = param.clone();
        came_update(&mut param, &grad, &mut state, &cfg, 1, 4).expect("update failed");
        for (p_new, p_old) in param.iter().zip(before.iter()) {
            assert!(
                p_new < p_old,
                "param should decrease with positive gradient"
            );
        }
    }

    #[test]
    fn test_came_1d_vs_2d_single_element_both_decrease() {
        let cfg = CameConfig::default();
        let grad = vec![0.2_f32];

        let mut state_1d = CameParamState::new_1d(1);
        let mut param_1d = vec![1.0_f32];
        came_update(&mut param_1d, &grad, &mut state_1d, &cfg, 1, 1).expect("1d update failed");
        assert!(param_1d[0] < 1.0, "1D param should decrease");

        let grad_2d = vec![0.2_f32; 4];
        let mut state_2d = CameParamState::new_2d(4, 2, 2);
        let mut param_2d = vec![1.0_f32; 4];
        came_update(&mut param_2d, &grad_2d, &mut state_2d, &cfg, 2, 2).expect("2d update failed");
        // Was the mis-encoded `¶m_2d`, which does not compile.
        for &p in &param_2d {
            assert!(p < 1.0, "2D param should decrease");
        }
    }

    #[test]
    fn test_came_weight_decay_larger_shrinks_more() {
        let grad = vec![0.0_f32; 3];

        let cfg_small = CameConfig {
            lr: 0.1,
            weight_decay: 0.01,
            ..CameConfig::default()
        };
        let mut state_small = CameParamState::new_1d(3);
        let mut param_small = vec![1.0_f32; 3];
        came_update(&mut param_small, &grad, &mut state_small, &cfg_small, 1, 3)
            .expect("small wd update failed");

        let cfg_large = CameConfig {
            lr: 0.1,
            weight_decay: 0.1,
            ..CameConfig::default()
        };
        let mut state_large = CameParamState::new_1d(3);
        let mut param_large = vec![1.0_f32; 3];
        came_update(&mut param_large, &grad, &mut state_large, &cfg_large, 1, 3)
            .expect("large wd update failed");

        for (ps, pl) in param_small.iter().zip(param_large.iter()) {
            assert!(
                ps.abs() > pl.abs(),
                "larger weight_decay should shrink more: small={ps}, large={pl}"
            );
        }
    }

    #[test]
    fn test_came_zero_grad_zero_wd_params_unchanged() {
        let cfg = CameConfig {
            lr: 0.1,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(3);
        let mut param = vec![2.0_f32; 3];
        let original = param.clone();
        let grad = vec![0.0_f32; 3];
        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");
        for (p_new, p_old) in param.iter().zip(original.iter()) {
            assert_relative_eq!(*p_new, *p_old, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_came_multiple_steps_move_toward_zero() {
        let cfg = CameConfig {
            lr: 1e-2,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(1);
        let mut param = vec![3.0_f32];
        for _ in 0..500 {
            let grad = param.clone();
            came_update(&mut param, &grad, &mut state, &cfg, 1, 1).expect("update failed");
        }
        assert!(
            param[0].abs() < 3.0,
            "param should move toward 0 over many steps"
        );
    }

    #[test]
    fn test_came_state_not_initialised_no_panic() {
        let cfg = CameConfig::default();
        let mut optimizer = CameOptimizer::new(cfg);
        let mut params = vec![vec![1.0_f32; 3]];
        let grads = vec![vec![0.1_f32; 3]];
        // Only checks that stepping without registered groups does not panic;
        // the returned value is deliberately not inspected.
        let result = optimizer.step(&mut params, &grads);
        let _ = result;
    }

    #[test]
    fn test_came_batch_2d_params_step_count() {
        let cfg = CameConfig::default();
        let mut optimizer = CameOptimizer::new(cfg);
        optimizer.add_param_group(6, 2, 3);
        optimizer.add_param_group(9, 3, 3);
        let mut params = vec![vec![0.5_f32; 6], vec![0.5_f32; 9]];
        let grads = vec![vec![0.1_f32; 6], vec![0.1_f32; 9]];
        optimizer.step(&mut params, &grads).expect("step failed");
        assert_eq!(optimizer.states[0].step, 1);
        assert_eq!(optimizer.states[1].step, 1);
    }

    #[test]
    fn test_came_clipping_bounds_param_change() {
        let large_grad = vec![100.0_f32; 4];

        let cfg_tight = CameConfig {
            lr: 1.0,
            clip_threshold: 0.001,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut s_tight = CameParamState::new_1d(4);
        let mut p_tight = vec![0.0_f32; 4];
        came_update(&mut p_tight, &large_grad, &mut s_tight, &cfg_tight, 1, 4)
            .expect("tight failed");

        let cfg_loose = CameConfig {
            lr: 1.0,
            clip_threshold: 1000.0,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut s_loose = CameParamState::new_1d(4);
        let mut p_loose = vec![0.0_f32; 4];
        came_update(&mut p_loose, &large_grad, &mut s_loose, &cfg_loose, 1, 4)
            .expect("loose failed");

        let m_tight: f32 = s_tight.exp_avg.iter().map(|x| x.abs()).sum();
        let m_loose: f32 = s_loose.exp_avg.iter().map(|x| x.abs()).sum();
        assert!(
            m_tight < m_loose,
            "tight clipping should reduce first moment: tight={m_tight}, loose={m_loose}"
        );
    }

    #[test]
    fn test_came_2d_factored_memory_efficiency() {
        let rows = 100_usize;
        let cols = 200_usize;
        let size = rows * cols;
        let state = CameParamState::new_2d(size, rows, cols);
        let factored_size = state.exp_avg_sq_row.len() + state.exp_avg_sq_col.len();
        assert!(
            factored_size < size,
            "factored memory ({factored_size}) should be less than full size ({size})"
        );
    }

    #[test]
    fn test_came_beta3_effect_on_insta_sq() {
        let rows = 2_usize;
        let cols = 2_usize;
        let grad = vec![1.0_f32; 4];

        let cfg_high = CameConfig {
            betas: (0.9, 0.999, 0.9999),
            ..CameConfig::default()
        };
        let mut state_high = CameParamState::new_2d(4, rows, cols);
        let mut param_high = vec![0.5_f32; 4];
        came_update(
            &mut param_high,
            &grad,
            &mut state_high,
            &cfg_high,
            rows,
            cols,
        )
        .expect("high beta3 update failed");

        let cfg_low = CameConfig {
            betas: (0.9, 0.999, 0.5),
            ..CameConfig::default()
        };
        let mut state_low = CameParamState::new_2d(4, rows, cols);
        let mut param_low = vec![0.5_f32; 4];
        came_update(&mut param_low, &grad, &mut state_low, &cfg_low, rows, cols)
            .expect("low beta3 update failed");

        let sum_high: f32 = state_high.exp_avg_insta_sq_row.iter().sum();
        let sum_low: f32 = state_low.exp_avg_insta_sq_row.iter().sum();
        assert!(
            sum_high < sum_low,
            "higher β3 should give smaller insta_sq update: high={sum_high}, low={sum_low}"
        );
    }

    #[test]
    fn test_came_three_groups_distinct_states() {
        let cfg = CameConfig::default();
        let mut optimizer = CameOptimizer::new(cfg);
        optimizer.add_param_group(2, 1, 2);
        optimizer.add_param_group(4, 2, 2);
        optimizer.add_param_group(6, 2, 3);

        let mut params = vec![vec![1.0_f32; 2], vec![1.0_f32; 4], vec![1.0_f32; 6]];
        let grads = vec![vec![0.1_f32; 2], vec![0.1_f32; 4], vec![0.1_f32; 6]];
        optimizer.step(&mut params, &grads).expect("step failed");
        assert_eq!(optimizer.states[0].step, 1);
        assert_eq!(optimizer.states[1].step, 1);
        assert_eq!(optimizer.states[2].step, 1);
        assert_eq!(optimizer.states[0].exp_avg.len(), 2);
        assert_eq!(optimizer.states[1].exp_avg.len(), 4);
        assert_eq!(optimizer.states[2].exp_avg.len(), 6);
    }

    #[test]
    fn test_came_lr_scaling_effect() {
        let grad = vec![0.1_f32; 3];

        let cfg_small_lr = CameConfig {
            lr: 1e-4,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut s_small = CameParamState::new_1d(3);
        let mut p_small = vec![2.0_f32; 3];
        came_update(&mut p_small, &grad, &mut s_small, &cfg_small_lr, 1, 3)
            .expect("small lr failed");

        let cfg_large_lr = CameConfig {
            lr: 1e-1,
            weight_decay: 0.0,
            ..CameConfig::default()
        };
        let mut s_large = CameParamState::new_1d(3);
        let mut p_large = vec![2.0_f32; 3];
        came_update(&mut p_large, &grad, &mut s_large, &cfg_large_lr, 1, 3)
            .expect("large lr failed");

        let change_small: f32 = (2.0 - p_small[0]).abs();
        let change_large: f32 = (2.0 - p_large[0]).abs();
        assert!(
            change_large > change_small,
            "larger lr should produce larger change: small={change_small}, large={change_large}"
        );
    }

    #[test]
    fn test_came_dimension_mismatch_rows_cols_wrong() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_2d(9, 3, 3);
        let mut param = vec![0.0_f32; 8];
        let grad = vec![0.0_f32; 8];
        let result = came_update(&mut param, &grad, &mut state, &cfg, 3, 3);
        assert!(result.is_err(), "should return error on dimension mismatch");
    }

    #[test]
    fn test_came_exp_avg_direction_matches_grad() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(3);
        let mut param = vec![0.0_f32; 3];
        let grad = vec![0.5_f32, -0.5_f32, 0.3_f32];
        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");
        assert!(
            state.exp_avg[0] > 0.0,
            "positive grad → positive exp_avg[0]"
        );
        assert!(
            state.exp_avg[1] < 0.0,
            "negative grad → negative exp_avg[1]"
        );
        assert!(
            state.exp_avg[2] > 0.0,
            "positive grad → positive exp_avg[2]"
        );
    }
}