Skip to main content

trustformers_optim/came/
mod.rs

1//! # CAME Optimizer Module
2//!
3//! Contains both the original CAME implementation and the new advanced
4//! CAME optimizer with factored second-moment estimation and confidence guidance.
5
6pub mod legacy;
7
8pub use legacy::{CAMEConfig, CAME};
9
10// New advanced CAME implementation as specified in Wave 15 Workstream BB
11
12use trustformers_core::errors::TrustformersError;
13
14/// Error type for the advanced optimizer implementations.
15#[derive(Debug, thiserror::Error)]
16pub enum OptimError {
17    /// Parameter and gradient length mismatch.
18    #[error("length mismatch: param length {param} != grad length {grad}")]
19    LengthMismatch { param: usize, grad: usize },
20    /// Row/col dimensions inconsistent with total size.
21    #[error("dimension mismatch: rows * cols ({rows} * {cols} = {product}) != size {size}")]
22    DimensionMismatch {
23        rows: usize,
24        cols: usize,
25        product: usize,
26        size: usize,
27    },
28    /// State not initialised for a parameter group index.
29    #[error("no state initialised for parameter group index {0}")]
30    StateNotInitialised(usize),
31    /// Unexpected numerical issue (NaN/Inf).
32    #[error("numerical error: {0}")]
33    NumericalError(String),
34}
35
36impl From<OptimError> for TrustformersError {
37    fn from(e: OptimError) -> Self {
38        TrustformersError::invalid_operation(e.to_string())
39    }
40}
41
/// Hyperparameters for the advanced CAME optimizer (Luo et al., 2023).
///
/// Reference: "CAME: Confidence-guided Adaptive Memory Efficient Optimization"
#[derive(Debug, Clone)]
pub struct CameConfig {
    /// Step size applied to every update (default `2e-4`).
    pub lr: f64,
    /// `(β1, β2, β3)`: first-moment decay, upper cap on the scheduled
    /// second-moment decay, and decay of the confidence factors.
    /// Default `(0.9, 0.999, 0.9999)`.
    pub betas: (f64, f64, f64),
    /// `(ε1, ε2)`: ε1 is added to squared gradients (and guards the clip
    /// division); ε2 is added to update denominators.
    /// Default `(1e-30, 1e-16)`.
    pub eps: (f64, f64),
    /// Decoupled weight-decay coefficient, multiplied by `lr` (default `0.0`).
    pub weight_decay: f64,
    /// A gradient whose RMS exceeds this value is rescaled down to it (default `1.0`).
    pub clip_threshold: f64,
    /// Exponent of the second-moment schedule `β2_t = min(1 − t^decay_rate, β2)`.
    /// Default `-0.8`.
    pub decay_rate: f64,
}

impl Default for CameConfig {
    /// The defaults listed on each field above.
    fn default() -> Self {
        let (beta1, beta2, beta3) = (0.9, 0.999, 0.9999);
        let (eps1, eps2) = (1e-30, 1e-16);
        CameConfig {
            decay_rate: -0.8,
            clip_threshold: 1.0,
            weight_decay: 0.0,
            eps: (eps1, eps2),
            betas: (beta1, beta2, beta3),
            lr: 2e-4,
        }
    }
}
76
/// Mutable optimizer state kept for one parameter group.
///
/// A 2-D group (see [`CameParamState::new_2d`]) stores row/column factors and
/// leaves `exp_avg_sq` as `None`. A 1-D group (see [`CameParamState::new_1d`])
/// stores a full `exp_avg_sq` and leaves every factor vector empty.
#[derive(Debug, Clone)]
pub struct CameParamState {
    /// Number of update steps applied so far.
    pub step: u64,
    /// First-moment EMA of the (clipped) gradient, one entry per element.
    pub exp_avg: Vec<f32>,
    /// Row factor of the factored second moment (length `rows`).
    pub exp_avg_sq_row: Vec<f32>,
    /// Column factor of the factored second moment (length `cols`).
    pub exp_avg_sq_col: Vec<f32>,
    /// Full second moment; `Some` only for 1-D groups.
    pub exp_avg_sq: Option<Vec<f32>>,
    /// Row factor used for confidence estimation; `came_update` decays it with β3.
    pub exp_avg_insta_sq_row: Vec<f32>,
    /// Column factor used for confidence estimation; `came_update` decays it with β3.
    pub exp_avg_insta_sq_col: Vec<f32>,
}

impl CameParamState {
    /// Zero-initialised state for a `rows × cols` matrix holding `size` elements.
    pub fn new_2d(size: usize, rows: usize, cols: usize) -> Self {
        let zeros = |len: usize| vec![0.0_f32; len];
        Self {
            step: 0,
            exp_avg: zeros(size),
            exp_avg_sq: None,
            exp_avg_sq_row: zeros(rows),
            exp_avg_insta_sq_row: zeros(rows),
            exp_avg_sq_col: zeros(cols),
            exp_avg_insta_sq_col: zeros(cols),
        }
    }

    /// Zero-initialised state for a vector of `size` elements: a full second
    /// moment and no (empty) row/column factors.
    pub fn new_1d(size: usize) -> Self {
        Self {
            exp_avg_sq: Some(vec![0.0_f32; size]),
            ..Self::new_2d(size, 0, 0)
        }
    }
}
123
/// Compute the root-mean-square of `v`, returning `0.0` for an empty slice.
///
/// Squares are accumulated in `f64`. An `f32` accumulator overflows to `inf`
/// once an element exceeds ~1.8e19 (its square leaves the `f32` range), and it
/// loses precision on long slices. The result is rounded back to `f32`.
#[inline]
fn rms(v: &[f32]) -> f32 {
    if v.is_empty() {
        return 0.0;
    }
    let sq_sum: f64 = v.iter().map(|&x| f64::from(x) * f64::from(x)).sum();
    (sq_sum / v.len() as f64).sqrt() as f32
}
133
134/// Perform one CAME update step for a single parameter group.
135///
136/// # Arguments
137///
138/// * `param`  – mutable slice of parameter values (length = `rows * cols`).
139/// * `grad`   – gradient slice (same length).
140/// * `state`  – mutable per-parameter state.
141/// * `config` – optimizer configuration.
142/// * `rows`   – matrix row count (set to 1 for 1-D parameters).
143/// * `cols`   – matrix column count (= `param.len()` for 1-D parameters).
144///
145/// # Errors
146///
147/// Returns [`OptimError`] on dimension mismatches or numerical issues.
148pub fn came_update(
149    param: &mut [f32],
150    grad: &[f32],
151    state: &mut CameParamState,
152    config: &CameConfig,
153    rows: usize,
154    cols: usize,
155) -> Result<(), OptimError> {
156    // --- Validate dimensions ------------------------------------------------
157    let size = param.len();
158    if grad.len() != size {
159        return Err(OptimError::LengthMismatch {
160            param: size,
161            grad: grad.len(),
162        });
163    }
164    let expected = rows * cols;
165    if expected != size {
166        return Err(OptimError::DimensionMismatch {
167            rows,
168            cols,
169            product: expected,
170            size,
171        });
172    }
173
174    // --- Step counter -------------------------------------------------------
175    state.step += 1;
176    let step = state.step as f64;
177
178    // --- Dynamic β2_t -------------------------------------------------------
179    // β2_t = min(1 - step^decay_rate, β2)
180    let beta2_t = (1.0 - step.powf(config.decay_rate)).min(config.betas.1) as f32;
181
182    let beta1 = config.betas.0 as f32;
183    let beta3 = config.betas.2 as f32;
184    let eps1 = config.eps.0 as f32;
185    let eps2 = config.eps.1 as f32;
186
187    // --- RMS gradient clip --------------------------------------------------
188    let grad_rms = rms(grad);
189    let clip_scale = if grad_rms > config.clip_threshold as f32 {
190        config.clip_threshold as f32 / (grad_rms + eps1)
191    } else {
192        1.0
193    };
194
195    // Lazily clipped gradient (we avoid a heap allocation by applying the
196    // scale inline in the loops below).
197
198    // --- First moment update ------------------------------------------------
199    for (m, &g) in state.exp_avg.iter_mut().zip(grad.iter()) {
200        let g_clipped = g * clip_scale;
201        *m = beta1 * *m + (1.0 - beta1) * g_clipped;
202    }
203
204    // --- Second-moment and confidence update --------------------------------
205    if rows == 1 {
206        // ---- 1-D path: full second moment -----------------------------------
207        let sq = state
208            .exp_avg_sq
209            .as_mut()
210            .ok_or_else(|| OptimError::NumericalError("1-D state missing exp_avg_sq".into()))?;
211        for (s, &g) in sq.iter_mut().zip(grad.iter()) {
212            let g_clipped = g * clip_scale;
213            *s = beta2_t * *s + (1.0 - beta2_t) * (g_clipped * g_clipped + eps1);
214        }
215
216        // Parameter update
217        for ((p, &m), &s) in param.iter_mut().zip(state.exp_avg.iter()).zip(sq.iter()) {
218            let denom = s.sqrt() + eps2;
219            let update = m / denom;
220            if config.weight_decay != 0.0 {
221                *p -= config.lr as f32 * config.weight_decay as f32 * *p;
222            }
223            *p -= config.lr as f32 * update;
224        }
225    } else {
226        // ---- 2-D path: factored second moment + confidence ------------------
227        // grad² row-means and col-means
228        let mut row_mean = vec![0.0_f32; rows];
229        let mut col_mean = vec![0.0_f32; cols];
230
231        for i in 0..rows {
232            let mut s = 0.0_f32;
233            for j in 0..cols {
234                let g = grad[i * cols + j] * clip_scale;
235                s += g * g;
236            }
237            row_mean[i] = s / cols as f32 + eps1;
238        }
239        for j in 0..cols {
240            let mut s = 0.0_f32;
241            for i in 0..rows {
242                let g = grad[i * cols + j] * clip_scale;
243                s += g * g;
244            }
245            col_mean[j] = s / rows as f32 + eps1;
246        }
247
248        // Smoothed second-moment factors
249        for (r, &rm) in state.exp_avg_sq_row.iter_mut().zip(row_mean.iter()) {
250            *r = beta2_t * *r + (1.0 - beta2_t) * rm;
251        }
252        for (c, &cm) in state.exp_avg_sq_col.iter_mut().zip(col_mean.iter()) {
253            *c = beta2_t * *c + (1.0 - beta2_t) * cm;
254        }
255
256        // Instantaneous second-moment factors (for confidence), use β3
257        for (r, &rm) in state.exp_avg_insta_sq_row.iter_mut().zip(row_mean.iter()) {
258            *r = beta3 * *r + (1.0 - beta3) * rm;
259        }
260        for (c, &cm) in state.exp_avg_insta_sq_col.iter_mut().zip(col_mean.iter()) {
261            *c = beta3 * *c + (1.0 - beta3) * cm;
262        }
263
264        // Compute R = mean of smoothed row factors (used to normalize outer-product)
265        let row_sum: f32 = state.exp_avg_sq_row.iter().sum();
266        let row_normaliser = (row_sum / rows as f32).max(eps1);
267
268        // Parameter update with confidence weighting
269        for i in 0..rows {
270            let smoothed_row = state.exp_avg_sq_row[i];
271            let insta_row = state.exp_avg_insta_sq_row[i];
272
273            for j in 0..cols {
274                let smoothed_col = state.exp_avg_sq_col[j];
275                let insta_col = state.exp_avg_insta_sq_col[j];
276
277                // RMS estimate from factored moments
278                let v_approx = (smoothed_row * smoothed_col / row_normaliser).sqrt();
279
280                // Confidence weight: ratio of smoothed vs instantaneous
281                let smoothed_insta_row = (insta_row * insta_col / row_normaliser).sqrt();
282                let confidence = if smoothed_insta_row > eps1 {
283                    (v_approx / (smoothed_insta_row + eps2)).min(1.0_f32)
284                } else {
285                    1.0_f32
286                };
287
288                let denom = v_approx + eps2;
289                let idx = i * cols + j;
290                let m = state.exp_avg[idx];
291                let update = confidence * m / denom;
292
293                let p = &mut param[idx];
294                if config.weight_decay != 0.0 {
295                    *p -= config.lr as f32 * config.weight_decay as f32 * *p;
296                }
297                *p -= config.lr as f32 * update;
298            }
299        }
300    }
301
302    Ok(())
303}
304
305/// Per-parameter group descriptor stored alongside the state.
306#[derive(Debug, Clone)]
307struct ParamGroupMeta {
308    #[allow(dead_code)]
309    size: usize,
310    rows: usize,
311    cols: usize,
312}
313
314/// Advanced CAME optimizer (factored second-moment + confidence guidance).
315///
316/// Reference: "CAME: Confidence-guided Adaptive Memory Efficient Optimization"
317/// (Luo et al., 2023)
318#[derive(Debug)]
319pub struct CameOptimizer {
320    /// Hyperparameter configuration.
321    pub config: CameConfig,
322    /// Per-parameter states.
323    pub states: Vec<CameParamState>,
324    /// Metadata (size/rows/cols) for each parameter group.
325    meta: Vec<ParamGroupMeta>,
326}
327
328impl CameOptimizer {
329    /// Create a new optimizer with the given configuration.
330    pub fn new(config: CameConfig) -> Self {
331        Self {
332            config,
333            states: Vec::new(),
334            meta: Vec::new(),
335        }
336    }
337
338    /// Register a parameter group and initialise its state.
339    ///
340    /// For 2-D matrices set `rows` and `cols` appropriately.
341    /// For 1-D tensors use `rows = 1` and `cols = param_size`.
342    pub fn add_param_group(&mut self, param_size: usize, rows: usize, cols: usize) {
343        let state = if rows == 1 {
344            CameParamState::new_1d(param_size)
345        } else {
346            CameParamState::new_2d(param_size, rows, cols)
347        };
348        self.states.push(state);
349        self.meta.push(ParamGroupMeta {
350            size: param_size,
351            rows,
352            cols,
353        });
354    }
355
356    /// Perform one update step across all parameter groups.
357    ///
358    /// # Arguments
359    ///
360    /// * `params` – mutable reference to all parameter vectors (one per group).
361    /// * `grads`  – gradient vectors (same order as `params`).
362    ///
363    /// # Errors
364    ///
365    /// Returns [`OptimError`] on any dimension mismatch.
366    pub fn step(&mut self, params: &mut [Vec<f32>], grads: &[Vec<f32>]) -> Result<(), OptimError> {
367        for (idx, ((param, grad), state)) in
368            params.iter_mut().zip(grads.iter()).zip(self.states.iter_mut()).enumerate()
369        {
370            let meta = self.meta.get(idx).ok_or(OptimError::StateNotInitialised(idx))?;
371            came_update(param, grad, state, &self.config, meta.rows, meta.cols)?;
372        }
373        Ok(())
374    }
375}
376
377// ---------------------------------------------------------------------------
378// Tests
379// ---------------------------------------------------------------------------
380
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // -----------------------------------------------------------------------
    // 1. Config defaults
    // -----------------------------------------------------------------------
    #[test]
    fn test_came_config_defaults() {
        let cfg = CameConfig::default();
        assert_relative_eq!(cfg.lr, 2e-4);
        assert_relative_eq!(cfg.betas.0, 0.9);
        assert_relative_eq!(cfg.betas.1, 0.999);
        assert_relative_eq!(cfg.betas.2, 0.9999);
        assert_relative_eq!(cfg.eps.0, 1e-30);
        assert_relative_eq!(cfg.eps.1, 1e-16);
        assert_relative_eq!(cfg.weight_decay, 0.0);
        assert_relative_eq!(cfg.clip_threshold, 1.0);
        assert_relative_eq!(cfg.decay_rate, -0.8);
    }

    // -----------------------------------------------------------------------
    // 2. State initialisation — 2-D
    // -----------------------------------------------------------------------
    #[test]
    fn test_state_init_2d() {
        let state = CameParamState::new_2d(6, 2, 3);
        assert_eq!(state.step, 0);
        assert_eq!(state.exp_avg.len(), 6);
        assert_eq!(state.exp_avg_sq_row.len(), 2);
        assert_eq!(state.exp_avg_sq_col.len(), 3);
        assert!(state.exp_avg_sq.is_none());
        assert_eq!(state.exp_avg_insta_sq_row.len(), 2);
        assert_eq!(state.exp_avg_insta_sq_col.len(), 3);
        assert!(state.exp_avg.iter().all(|&x| x == 0.0));
    }

    // -----------------------------------------------------------------------
    // 3. State initialisation — 1-D
    // -----------------------------------------------------------------------
    #[test]
    fn test_state_init_1d() {
        let state = CameParamState::new_1d(5);
        assert_eq!(state.step, 0);
        assert_eq!(state.exp_avg.len(), 5);
        assert!(state.exp_avg_sq_row.is_empty());
        assert!(state.exp_avg_sq_col.is_empty());
        assert!(state.exp_avg_sq.is_some());
        assert_eq!(state.exp_avg_sq.as_ref().map(|v| v.len()), Some(5));
    }

    // -----------------------------------------------------------------------
    // 4. Step counter increments
    // -----------------------------------------------------------------------
    #[test]
    fn test_step_counter() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(2);
        let mut param = vec![1.0_f32; 2];
        let grad = vec![0.1_f32; 2];

        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");
        assert_eq!(state.step, 1);
        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");
        assert_eq!(state.step, 2);
    }

    // -----------------------------------------------------------------------
    // 5. Factored second moment update (2-D)
    // -----------------------------------------------------------------------
    #[test]
    fn test_factored_second_moment_update() {
        let cfg = CameConfig {
            lr: 0.0,
            ..CameConfig::default()
        };
        let rows = 2_usize;
        let cols = 3_usize;
        let size = rows * cols;
        let mut state = CameParamState::new_2d(size, rows, cols);
        let mut param = vec![0.0_f32; size];
        let grad = vec![1.0_f32; size];

        // After step 1 all row/col factors must be positive
        came_update(&mut param, &grad, &mut state, &cfg, rows, cols).expect("update failed");
        assert!(state.exp_avg_sq_row.iter().all(|&x| x > 0.0));
        assert!(state.exp_avg_sq_col.iter().all(|&x| x > 0.0));
    }

    // -----------------------------------------------------------------------
    // 6. Dynamic β2 schedule
    // -----------------------------------------------------------------------
    #[test]
    fn test_dynamic_beta2_schedule() {
        let cfg = CameConfig::default();
        // At step 1: beta2_t = min(1 - 1^(-0.8), 0.999) = min(0.0, 0.999) = 0.0
        let step = 1_f64;
        let beta2_t = (1.0 - step.powf(cfg.decay_rate)).min(cfg.betas.1);
        assert_relative_eq!(beta2_t, 0.0, epsilon = 1e-9);

        // At step 100: 1 - 100^(-0.8) ≈ 1 - 0.025 = 0.975 < 0.999, so not capped
        let step100 = 100_f64;
        let beta2_100 = (1.0 - step100.powf(cfg.decay_rate)).min(cfg.betas.1);
        assert!(beta2_100 > 0.9 && beta2_100 < 1.0);
    }

    // -----------------------------------------------------------------------
    // 7. Confidence adaptation (insta rows updated with β3)
    // -----------------------------------------------------------------------
    #[test]
    fn test_confidence_adaptation() {
        let cfg = CameConfig::default();
        let rows = 2_usize;
        let cols = 2_usize;
        let size = rows * cols;
        let mut state = CameParamState::new_2d(size, rows, cols);
        let mut param = vec![0.0_f32; size];
        let grad = vec![1.0_f32; size];

        came_update(&mut param, &grad, &mut state, &cfg, rows, cols).expect("update failed");

        // Confidence factors are updated with β3 = 0.9999 — they should be non-zero
        assert!(state.exp_avg_insta_sq_row.iter().all(|&x| x > 0.0));
        assert!(state.exp_avg_insta_sq_col.iter().all(|&x| x > 0.0));
    }

    // -----------------------------------------------------------------------
    // 8. Weight decay applied
    // -----------------------------------------------------------------------
    #[test]
    fn test_weight_decay() {
        let cfg = CameConfig {
            lr: 1e-1,
            weight_decay: 0.1,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(2);
        let initial_param = vec![1.0_f32; 2];
        let mut param = initial_param.clone();
        let grad = vec![0.0_f32; 2]; // zero grad — only weight decay effect

        came_update(&mut param, &grad, &mut state, &cfg, 1, 2).expect("update failed");

        // Parameters must be strictly smaller in absolute value
        for (p_new, p_old) in param.iter().zip(initial_param.iter()) {
            assert!(
                p_new.abs() < p_old.abs(),
                "weight decay did not reduce param"
            );
        }
    }

    // -----------------------------------------------------------------------
    // 9. Single-step update moves in the right direction
    // -----------------------------------------------------------------------
    #[test]
    fn test_single_step_direction() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(3);
        let mut param = vec![0.5_f32; 3];
        let grad = vec![0.1_f32; 3]; // positive gradient

        let param_before = param.clone();
        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");

        // With positive gradient, parameters should decrease
        for (p_new, p_old) in param.iter().zip(param_before.iter()) {
            assert!(
                p_new < p_old,
                "param did not decrease with positive gradient"
            );
        }
    }

    // -----------------------------------------------------------------------
    // 10. Gradient clipping — first moment is smaller under aggressive clip
    // -----------------------------------------------------------------------
    #[test]
    fn test_gradient_clipping() {
        // The clip_scale = clip_threshold / (rms(grad) + eps1) when rms > threshold.
        // With a large gradient the clipped first moment should be smaller than the
        // unclipped first moment.
        let cfg_tight = CameConfig {
            clip_threshold: 0.1,
            ..CameConfig::default()
        };
        let cfg_loose = CameConfig {
            clip_threshold: 1000.0,
            ..CameConfig::default()
        };

        let large_grad = vec![5.0_f32; 4];

        let mut state_tight = CameParamState::new_1d(4);
        let mut param_tight = vec![0.0_f32; 4];
        came_update(
            &mut param_tight,
            &large_grad,
            &mut state_tight,
            &cfg_tight,
            1,
            4,
        )
        .expect("tight update failed");

        let mut state_loose = CameParamState::new_1d(4);
        let mut param_loose = vec![0.0_f32; 4];
        came_update(
            &mut param_loose,
            &large_grad,
            &mut state_loose,
            &cfg_loose,
            1,
            4,
        )
        .expect("loose update failed");

        // Under tight clipping the first moment exp_avg values must be smaller in
        // absolute value because the effective gradient fed into the EMA was scaled down.
        let m_tight: f32 = state_tight.exp_avg.iter().map(|x| x.abs()).sum();
        let m_loose: f32 = state_loose.exp_avg.iter().map(|x| x.abs()).sum();
        assert!(
            m_tight < m_loose,
            "tight clipping did not reduce first moment: m_tight={m_tight} m_loose={m_loose}"
        );
    }

    // -----------------------------------------------------------------------
    // 11. Multi-step convergence on a quadratic (1-D)
    // -----------------------------------------------------------------------
    #[test]
    fn test_convergence_quadratic() {
        // Minimise f(x) = x^2 / 2, gradient = x
        let cfg = CameConfig {
            lr: 1e-2,
            ..CameConfig::default()
        };
        let mut state = CameParamState::new_1d(1);
        let mut param = vec![5.0_f32];

        for _ in 0..2000 {
            let grad = param.clone(); // gradient of x^2/2 is x
            came_update(&mut param, &grad, &mut state, &cfg, 1, 1).expect("update failed");
        }

        assert!(
            param[0].abs() < 0.1,
            "CAME did not converge on quadratic: final param = {}",
            param[0]
        );
    }

    // -----------------------------------------------------------------------
    // 12. Length and dimension mismatches are returned as errors (not panics)
    // -----------------------------------------------------------------------
    #[test]
    fn test_dimension_mismatch_error() {
        let cfg = CameConfig::default();
        let mut state = CameParamState::new_1d(4);
        let mut param = vec![0.0_f32; 4];

        // Gradient length differs from the parameter length.
        let grad = vec![0.0_f32; 5];
        let result = came_update(&mut param, &grad, &mut state, &cfg, 1, 4);
        assert!(
            matches!(result, Err(OptimError::LengthMismatch { param: 4, grad: 5 })),
            "expected LengthMismatch, got {result:?}"
        );

        // rows * cols disagrees with the parameter length.
        let grad = vec![0.0_f32; 4];
        let result = came_update(&mut param, &grad, &mut state, &cfg, 1, 3);
        assert!(
            matches!(
                result,
                Err(OptimError::DimensionMismatch {
                    rows: 1,
                    cols: 3,
                    product: 3,
                    size: 4
                })
            ),
            "expected DimensionMismatch, got {result:?}"
        );
    }

    // -----------------------------------------------------------------------
    // 13. CameOptimizer multi-param step
    // -----------------------------------------------------------------------
    #[test]
    fn test_came_optimizer_multi_param() {
        let cfg = CameConfig::default();
        let mut optimizer = CameOptimizer::new(cfg);
        optimizer.add_param_group(4, 2, 2);
        optimizer.add_param_group(3, 1, 3);

        let mut params = vec![vec![1.0_f32; 4], vec![1.0_f32; 3]];
        let grads = vec![vec![0.1_f32; 4], vec![0.1_f32; 3]];

        optimizer.step(&mut params, &grads).expect("step failed");
        assert_eq!(optimizer.states[0].step, 1);
        assert_eq!(optimizer.states[1].step, 1);
    }
}
667
668#[cfg(test)]
669mod extended_tests {
670    use super::*;
671    use approx::assert_relative_eq;
672
673    #[test]
674    fn test_came_state_step_zero_at_init() {
675        let state = CameParamState::new_2d(6, 2, 3);
676        assert_eq!(state.step, 0);
677        let state1d = CameParamState::new_1d(4);
678        assert_eq!(state1d.step, 0);
679    }
680
681    #[test]
682    fn test_came_confidence_factors_nonzero_after_step() {
683        let cfg = CameConfig::default();
684        let mut state = CameParamState::new_2d(6, 2, 3);
685        let mut param = vec![0.5_f32; 6];
686        let grad = vec![0.1_f32; 6];
687        came_update(&mut param, &grad, &mut state, &cfg, 2, 3).expect("update failed");
688        assert!(
689            state.exp_avg_insta_sq_row.iter().all(|&x| x > 0.0),
690            "insta_sq_row should be nonzero after update"
691        );
692        assert!(
693            state.exp_avg_insta_sq_col.iter().all(|&x| x > 0.0),
694            "insta_sq_col should be nonzero after update"
695        );
696    }
697
698    #[test]
699    fn test_came_positive_grad_decreases_params() {
700        let cfg = CameConfig::default();
701        let mut state = CameParamState::new_1d(4);
702        let mut param = vec![1.0_f32; 4];
703        let grad = vec![0.5_f32; 4];
704        let before = param.clone();
705        came_update(&mut param, &grad, &mut state, &cfg, 1, 4).expect("update failed");
706        for (p_new, p_old) in param.iter().zip(before.iter()) {
707            assert!(
708                p_new < p_old,
709                "param should decrease with positive gradient"
710            );
711        }
712    }
713
714    #[test]
715    fn test_came_1d_vs_2d_single_element_both_decrease() {
716        let cfg = CameConfig::default();
717        let grad = vec![0.2_f32];
718
719        // 1D path: new_1d, rows=1, cols=1
720        let mut state_1d = CameParamState::new_1d(1);
721        let mut param_1d = vec![1.0_f32];
722        came_update(&mut param_1d, &grad, &mut state_1d, &cfg, 1, 1).expect("1d update failed");
723        assert!(param_1d[0] < 1.0, "1D param should decrease");
724
725        // True 2D path: 2 rows x 2 cols (rows != 1 to take the factored path)
726        let grad_2d = vec![0.2_f32; 4];
727        let mut state_2d = CameParamState::new_2d(4, 2, 2);
728        let mut param_2d = vec![1.0_f32; 4];
729        came_update(&mut param_2d, &grad_2d, &mut state_2d, &cfg, 2, 2).expect("2d update failed");
730        for &p in &param_2d {
731            assert!(p < 1.0, "2D param should decrease");
732        }
733    }
734
735    #[test]
736    fn test_came_weight_decay_larger_shrinks_more() {
737        let grad = vec![0.0_f32; 3];
738
739        let cfg_small = CameConfig {
740            lr: 0.1,
741            weight_decay: 0.01,
742            ..CameConfig::default()
743        };
744        let mut state_small = CameParamState::new_1d(3);
745        let mut param_small = vec![1.0_f32; 3];
746        came_update(&mut param_small, &grad, &mut state_small, &cfg_small, 1, 3)
747            .expect("small wd update failed");
748
749        let cfg_large = CameConfig {
750            lr: 0.1,
751            weight_decay: 0.1,
752            ..CameConfig::default()
753        };
754        let mut state_large = CameParamState::new_1d(3);
755        let mut param_large = vec![1.0_f32; 3];
756        came_update(&mut param_large, &grad, &mut state_large, &cfg_large, 1, 3)
757            .expect("large wd update failed");
758
759        for (ps, pl) in param_small.iter().zip(param_large.iter()) {
760            assert!(
761                ps.abs() > pl.abs(),
762                "larger weight_decay should shrink more: small={ps}, large={pl}"
763            );
764        }
765    }
766
767    #[test]
768    fn test_came_zero_grad_zero_wd_params_unchanged() {
769        let cfg = CameConfig {
770            lr: 0.1,
771            weight_decay: 0.0,
772            ..CameConfig::default()
773        };
774        let mut state = CameParamState::new_1d(3);
775        let mut param = vec![2.0_f32; 3];
776        let original = param.clone();
777        let grad = vec![0.0_f32; 3];
778        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");
779        for (p_new, p_old) in param.iter().zip(original.iter()) {
780            assert_relative_eq!(*p_new, *p_old, epsilon = 1e-6);
781        }
782    }
783
784    #[test]
785    fn test_came_multiple_steps_move_toward_zero() {
786        let cfg = CameConfig {
787            lr: 1e-2,
788            weight_decay: 0.0,
789            ..CameConfig::default()
790        };
791        let mut state = CameParamState::new_1d(1);
792        let mut param = vec![3.0_f32];
793        for _ in 0..500 {
794            let grad = param.clone();
795            came_update(&mut param, &grad, &mut state, &cfg, 1, 1).expect("update failed");
796        }
797        assert!(
798            param[0].abs() < 3.0,
799            "param should move toward 0 over many steps"
800        );
801    }
802
803    #[test]
804    fn test_came_state_not_initialised_no_panic() {
805        let cfg = CameConfig::default();
806        let mut optimizer = CameOptimizer::new(cfg);
807        // No add_param_group calls — zip with 0 states = 0 iterations, no panic
808        let mut params = vec![vec![1.0_f32; 3]];
809        let grads = vec![vec![0.1_f32; 3]];
810        let result = optimizer.step(&mut params, &grads);
811        // Should not panic; either Ok or Err is acceptable
812        let _ = result;
813    }
814
815    #[test]
816    fn test_came_batch_2d_params_step_count() {
817        let cfg = CameConfig::default();
818        let mut optimizer = CameOptimizer::new(cfg);
819        optimizer.add_param_group(6, 2, 3);
820        optimizer.add_param_group(9, 3, 3);
821        let mut params = vec![vec![0.5_f32; 6], vec![0.5_f32; 9]];
822        let grads = vec![vec![0.1_f32; 6], vec![0.1_f32; 9]];
823        optimizer.step(&mut params, &grads).expect("step failed");
824        assert_eq!(optimizer.states[0].step, 1);
825        assert_eq!(optimizer.states[1].step, 1);
826    }
827
828    #[test]
829    fn test_came_clipping_bounds_param_change() {
830        // Clipping affects the first moment (exp_avg). After step 1:
831        // exp_avg_tight[i] = (1-beta1) * clip_scale * grad[i]  (small clip_scale for tight)
832        // exp_avg_loose[i] = (1-beta1) * 1.0 * grad[i]         (no clipping needed)
833        // We verify by checking that the first moments differ.
834        let large_grad = vec![100.0_f32; 4];
835
836        let cfg_tight = CameConfig {
837            lr: 1.0,
838            clip_threshold: 0.001,
839            weight_decay: 0.0,
840            ..CameConfig::default()
841        };
842        let mut s_tight = CameParamState::new_1d(4);
843        let mut p_tight = vec![0.0_f32; 4];
844        came_update(&mut p_tight, &large_grad, &mut s_tight, &cfg_tight, 1, 4)
845            .expect("tight failed");
846
847        let cfg_loose = CameConfig {
848            lr: 1.0,
849            clip_threshold: 1000.0,
850            weight_decay: 0.0,
851            ..CameConfig::default()
852        };
853        let mut s_loose = CameParamState::new_1d(4);
854        let mut p_loose = vec![0.0_f32; 4];
855        came_update(&mut p_loose, &large_grad, &mut s_loose, &cfg_loose, 1, 4)
856            .expect("loose failed");
857
858        // The tight-clipped first moment should be much smaller in magnitude
859        let m_tight: f32 = s_tight.exp_avg.iter().map(|x| x.abs()).sum();
860        let m_loose: f32 = s_loose.exp_avg.iter().map(|x| x.abs()).sum();
861        assert!(
862            m_tight < m_loose,
863            "tight clipping should reduce first moment: tight={m_tight}, loose={m_loose}"
864        );
865    }
866
867    #[test]
868    fn test_came_2d_factored_memory_efficiency() {
869        let rows = 100_usize;
870        let cols = 200_usize;
871        let size = rows * cols;
872        let state = CameParamState::new_2d(size, rows, cols);
873        let factored_size = state.exp_avg_sq_row.len() + state.exp_avg_sq_col.len();
874        assert!(
875            factored_size < size,
876            "factored memory ({factored_size}) should be less than full size ({size})"
877        );
878    }
879
880    #[test]
881    fn test_came_beta3_effect_on_insta_sq() {
882        let rows = 2_usize;
883        let cols = 2_usize;
884        let grad = vec![1.0_f32; 4];
885
886        let cfg_high = CameConfig {
887            betas: (0.9, 0.999, 0.9999),
888            ..CameConfig::default()
889        };
890        let mut state_high = CameParamState::new_2d(4, rows, cols);
891        let mut param_high = vec![0.5_f32; 4];
892        came_update(
893            &mut param_high,
894            &grad,
895            &mut state_high,
896            &cfg_high,
897            rows,
898            cols,
899        )
900        .expect("high beta3 update failed");
901
902        let cfg_low = CameConfig {
903            betas: (0.9, 0.999, 0.5),
904            ..CameConfig::default()
905        };
906        let mut state_low = CameParamState::new_2d(4, rows, cols);
907        let mut param_low = vec![0.5_f32; 4];
908        came_update(&mut param_low, &grad, &mut state_low, &cfg_low, rows, cols)
909            .expect("low beta3 update failed");
910
911        let sum_high: f32 = state_high.exp_avg_insta_sq_row.iter().sum();
912        let sum_low: f32 = state_low.exp_avg_insta_sq_row.iter().sum();
913        assert!(
914            sum_high < sum_low,
915            "higher β3 should give smaller insta_sq update: high={sum_high}, low={sum_low}"
916        );
917    }
918
919    #[test]
920    fn test_came_three_groups_distinct_states() {
921        let cfg = CameConfig::default();
922        let mut optimizer = CameOptimizer::new(cfg);
923        optimizer.add_param_group(2, 1, 2);
924        optimizer.add_param_group(4, 2, 2);
925        optimizer.add_param_group(6, 2, 3);
926
927        let mut params = vec![vec![1.0_f32; 2], vec![1.0_f32; 4], vec![1.0_f32; 6]];
928        let grads = vec![vec![0.1_f32; 2], vec![0.1_f32; 4], vec![0.1_f32; 6]];
929        optimizer.step(&mut params, &grads).expect("step failed");
930        assert_eq!(optimizer.states[0].step, 1);
931        assert_eq!(optimizer.states[1].step, 1);
932        assert_eq!(optimizer.states[2].step, 1);
933        assert_eq!(optimizer.states[0].exp_avg.len(), 2);
934        assert_eq!(optimizer.states[1].exp_avg.len(), 4);
935        assert_eq!(optimizer.states[2].exp_avg.len(), 6);
936    }
937
938    #[test]
939    fn test_came_lr_scaling_effect() {
940        let grad = vec![0.1_f32; 3];
941
942        let cfg_small_lr = CameConfig {
943            lr: 1e-4,
944            weight_decay: 0.0,
945            ..CameConfig::default()
946        };
947        let mut s_small = CameParamState::new_1d(3);
948        let mut p_small = vec![2.0_f32; 3];
949        came_update(&mut p_small, &grad, &mut s_small, &cfg_small_lr, 1, 3)
950            .expect("small lr failed");
951
952        let cfg_large_lr = CameConfig {
953            lr: 1e-1,
954            weight_decay: 0.0,
955            ..CameConfig::default()
956        };
957        let mut s_large = CameParamState::new_1d(3);
958        let mut p_large = vec![2.0_f32; 3];
959        came_update(&mut p_large, &grad, &mut s_large, &cfg_large_lr, 1, 3)
960            .expect("large lr failed");
961
962        let change_small: f32 = (2.0 - p_small[0]).abs();
963        let change_large: f32 = (2.0 - p_large[0]).abs();
964        assert!(
965            change_large > change_small,
966            "larger lr should produce larger change: small={change_small}, large={change_large}"
967        );
968    }
969
970    #[test]
971    fn test_came_dimension_mismatch_rows_cols_wrong() {
972        let cfg = CameConfig::default();
973        let mut state = CameParamState::new_2d(9, 3, 3);
974        // param has 8 elements but rows*cols=9
975        let mut param = vec![0.0_f32; 8];
976        let grad = vec![0.0_f32; 8];
977        let result = came_update(&mut param, &grad, &mut state, &cfg, 3, 3);
978        assert!(result.is_err(), "should return error on dimension mismatch");
979    }
980
981    #[test]
982    fn test_came_exp_avg_direction_matches_grad() {
983        let cfg = CameConfig::default();
984        let mut state = CameParamState::new_1d(3);
985        let mut param = vec![0.0_f32; 3];
986        let grad = vec![0.5_f32, -0.5_f32, 0.3_f32];
987        came_update(&mut param, &grad, &mut state, &cfg, 1, 3).expect("update failed");
988        assert!(
989            state.exp_avg[0] > 0.0,
990            "positive grad → positive exp_avg[0]"
991        );
992        assert!(
993            state.exp_avg[1] < 0.0,
994            "negative grad → negative exp_avg[1]"
995        );
996        assert!(
997            state.exp_avg[2] > 0.0,
998            "positive grad → positive exp_avg[2]"
999        );
1000    }
1001}