// scirs2_optimize/differentiable_optimization/perturbed_optimizer.rs
1//! Differentiable combinatorial optimization via perturbed optimizers.
2//!
3//! Implements the **Perturbed Optimizer** framework (Berthet et al., 2020) for
4//! making any black-box combinatorial solver differentiable through additive
5//! Gaussian noise perturbations.
6//!
7//! Given a combinatorial optimizer `y*(θ) = argmax_y θᵀ y` (or argmin),
8//! the perturbed optimizer computes:
9//!
10//!   ŷ(θ) = E[y*(θ + σZ)]   where  Z ~ N(0, I)
11//!
12//! The gradient is estimated as:
13//!
14//!   ∇_θ L ≈ (1/σ) E[L(y*(θ + σZ)) · Z]   (REINFORCE)
15//!
16//! or via the reparameterized covariance estimator:
17//!
18//!   ∇_θ L ≈ (1/σ) Cov[y*(θ + σZ), Z] · ∇_y L
19//!
20//! Also includes `SparseMap` for structured prediction on the marginal
21//! polytope via QP.
22//!
23//! # References
24//! - Berthet et al. (2020). "Learning with Differentiable Perturbed Optimizers." NeurIPS.
25//! - Niculae & Blondel (2017). "A regularized framework for sparse and structured
26//!   neural attention." NeurIPS.
27
28use crate::error::{OptimizeError, OptimizeResult};
29
30use super::kkt_sensitivity::kkt_sensitivity;
31
32// ─────────────────────────────────────────────────────────────────────────────
33// Random number generator (xorshift64 — pure Rust, no external deps)
34// ─────────────────────────────────────────────────────────────────────────────
35
/// Minimal xorshift64 PRNG used for Monte Carlo sampling (pure Rust, no deps).
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Seed the generator; a zero seed is remapped to 1 because the
    /// all-zero state is a fixed point of xorshift.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self { state }
    }

    /// Advance the state and return the next raw 64-bit value
    /// (Marsaglia's 13/7/17 xorshift variant).
    fn next_u64(&mut self) -> u64 {
        let mut s = self.state;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.state = s;
        s
    }

    /// Uniform sample in [0, 1), built from the top 53 bits of the state.
    fn uniform(&mut self) -> f64 {
        let bits53 = self.next_u64() >> 11;
        bits53 as f64 / (1u64 << 53) as f64
    }

    /// Standard normal sample via the Box–Muller transform.
    fn normal(&mut self) -> f64 {
        // Clamp u1 away from zero so ln() stays finite.
        let u1 = self.uniform().max(1e-15);
        let u2 = self.uniform();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }

    /// Draw an n-dimensional N(0, I) vector.
    fn normal_vector(&mut self, n: usize) -> Vec<f64> {
        let mut out = Vec::with_capacity(n);
        for _ in 0..n {
            out.push(self.normal());
        }
        out
    }
}
76
77// ─────────────────────────────────────────────────────────────────────────────
78// Configuration
79// ─────────────────────────────────────────────────────────────────────────────
80
/// Configuration for the perturbed optimizer.
#[derive(Debug, Clone)]
pub struct PerturbedOptimizerConfig {
    /// Number of Monte Carlo samples for expectation estimation.
    pub n_samples: usize,
    /// Perturbation standard deviation σ.
    pub sigma: f64,
    /// Random seed (for reproducibility).
    pub seed: u64,
}

impl Default for PerturbedOptimizerConfig {
    /// Defaults: 20 Monte Carlo samples, unit noise scale, fixed seed 42.
    fn default() -> Self {
        PerturbedOptimizerConfig {
            n_samples: 20,
            sigma: 1.0,
            seed: 42,
        }
    }
}
101
102// ─────────────────────────────────────────────────────────────────────────────
103// Perturbed optimizer
104// ─────────────────────────────────────────────────────────────────────────────
105
/// A differentiable wrapper around any black-box combinatorial optimizer.
///
/// Wraps a function `optimizer: θ → y*(θ)` and computes a smooth
/// approximation `E[y*(θ + σZ)]` via Monte Carlo.
///
/// # Type Parameter
/// * `F` – function type for the combinatorial optimizer, mapping `&[f64]` to `Vec<f64>`.
pub struct PerturbedOptimizer<F>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    /// The wrapped black-box solver θ ↦ y*(θ).
    optimizer: F,
    /// Monte Carlo sampling parameters (sample count, σ, seed).
    config: PerturbedOptimizerConfig,
    /// Cached perturbed parameters θ + σZ_k from the last forward call.
    /// NOTE(review): written by `forward` but not read by `gradient` or
    /// `reinforce_gradient` (both use only outputs and noise) — kept,
    /// presumably for future use; confirm before removing.
    cached_samples: Option<Vec<Vec<f64>>>,
    /// Cached optimizer outputs y*(θ + σZ_k) from the last forward call
    /// (consumed by both gradient estimators).
    cached_outputs: Option<Vec<Vec<f64>>>,
    /// Cached noise draws Z_k from the last forward call
    /// (consumed by both gradient estimators).
    cached_noise: Option<Vec<Vec<f64>>>,
}
126
127impl<F> PerturbedOptimizer<F>
128where
129    F: Fn(&[f64]) -> Vec<f64>,
130{
131    /// Create a new perturbed optimizer with default configuration.
132    pub fn new(optimizer: F) -> Self {
133        Self {
134            optimizer,
135            config: PerturbedOptimizerConfig::default(),
136            cached_samples: None,
137            cached_outputs: None,
138            cached_noise: None,
139        }
140    }
141
142    /// Create a new perturbed optimizer with custom configuration.
143    pub fn with_config(optimizer: F, config: PerturbedOptimizerConfig) -> Self {
144        Self {
145            optimizer,
146            config,
147            cached_samples: None,
148            cached_outputs: None,
149            cached_noise: None,
150        }
151    }
152
153    /// Forward pass: compute `E[y*(θ + σZ)]` via Monte Carlo.
154    ///
155    /// Samples `n_samples` perturbations Z_k ~ N(0, I) and returns the
156    /// sample mean of the optimizer outputs.
157    ///
158    /// # Arguments
159    /// * `theta` – parameter vector (length d).
160    ///
161    /// # Returns
162    /// Expected optimizer output `ŷ(θ)` (length equal to optimizer output).
163    pub fn forward(&mut self, theta: &[f64]) -> OptimizeResult<Vec<f64>> {
164        let d = theta.len();
165        let mut rng = Xorshift64::new(self.config.seed);
166
167        let mut outputs: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
168        let mut noises: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
169
170        for _ in 0..self.config.n_samples {
171            let z = rng.normal_vector(d);
172            let theta_perturbed: Vec<f64> = theta
173                .iter()
174                .zip(z.iter())
175                .map(|(&ti, &zi)| ti + self.config.sigma * zi)
176                .collect();
177            let y = (self.optimizer)(&theta_perturbed);
178            outputs.push(y);
179            noises.push(z);
180        }
181
182        // Compute mean output
183        if outputs.is_empty() {
184            return Err(OptimizeError::ComputationError(
185                "No samples generated in PerturbedOptimizer::forward".to_string(),
186            ));
187        }
188        let out_len = outputs[0].len();
189        let mut mean_y = vec![0.0_f64; out_len];
190        for output in &outputs {
191            if output.len() != out_len {
192                return Err(OptimizeError::ComputationError(
193                    "Inconsistent optimizer output lengths".to_string(),
194                ));
195            }
196            for (i, &oi) in output.iter().enumerate() {
197                mean_y[i] += oi;
198            }
199        }
200        let n = self.config.n_samples as f64;
201        for mi in &mut mean_y {
202            *mi /= n;
203        }
204
205        // Cache for backward
206        self.cached_samples = Some(
207            (0..self.config.n_samples)
208                .map(|k| {
209                    theta
210                        .iter()
211                        .zip(noises[k].iter())
212                        .map(|(&ti, &zi)| ti + self.config.sigma * zi)
213                        .collect()
214                })
215                .collect(),
216        );
217        self.cached_outputs = Some(outputs);
218        self.cached_noise = Some(noises);
219
220        Ok(mean_y)
221    }
222
223    /// Gradient estimate via reparameterized covariance:
224    ///
225    ///   grad_theta L ~ (1/sigma) Cov\[y*(theta + sigma*Z), Z\] * dL/dy
226    ///            = `(1/sigma^2*N) Sum_k (y_k - y_mean) * Z_k * dL/dy`
227    ///
228    /// This is an unbiased estimator when y* is the gradient of a linear function,
229    /// and has lower variance than REINFORCE.
230    ///
231    /// # Arguments
232    /// * `theta`  – parameter vector (length d).
233    /// * `dl_dy`  – upstream gradient dL/dŷ (length = optimizer output length).
234    ///
235    /// # Returns
236    /// Gradient estimate ∇_θ L (length d).
237    pub fn gradient(&self, theta: &[f64], dl_dy: &[f64]) -> OptimizeResult<Vec<f64>> {
238        let outputs = self.cached_outputs.as_ref().ok_or_else(|| {
239            OptimizeError::ComputationError(
240                "PerturbedOptimizer::gradient called before forward".to_string(),
241            )
242        })?;
243        let noises = self
244            .cached_noise
245            .as_ref()
246            .ok_or_else(|| OptimizeError::ComputationError("No cached noise".to_string()))?;
247
248        let d = theta.len();
249        let out_len = dl_dy.len();
250        let n_samples = outputs.len();
251
252        if n_samples == 0 {
253            return Err(OptimizeError::ComputationError(
254                "Empty sample cache".to_string(),
255            ));
256        }
257
258        // Compute mean output ȳ
259        let mut mean_y = vec![0.0_f64; out_len];
260        for output in outputs.iter() {
261            for (i, &oi) in output.iter().enumerate().take(out_len) {
262                mean_y[i] += oi;
263            }
264        }
265        for mi in &mut mean_y {
266            *mi /= n_samples as f64;
267        }
268
269        // Reparameterized covariance estimator:
270        // ∇_θ_j L ≈ (1/σ) * (1/N) * Σ_k [(y_k - ȳ) · dL/dy] * Z_k_j
271        let sigma = self.config.sigma;
272        let mut grad = vec![0.0_f64; d];
273
274        for k in 0..n_samples {
275            // Compute (y_k - ȳ) · dL/dy = scalar coefficient for sample k
276            let coeff: f64 = outputs[k]
277                .iter()
278                .zip(mean_y.iter())
279                .zip(dl_dy.iter())
280                .map(|((&yk, &ybar), &dly)| (yk - ybar) * dly)
281                .sum();
282
283            // Gradient contribution: coeff * Z_k / (σ * N)
284            for j in 0..d {
285                let z_kj = if j < noises[k].len() {
286                    noises[k][j]
287                } else {
288                    0.0
289                };
290                grad[j] += coeff * z_kj;
291            }
292        }
293
294        let scale = 1.0 / (sigma * n_samples as f64);
295        for gi in &mut grad {
296            *gi *= scale;
297        }
298
299        Ok(grad)
300    }
301
302    /// REINFORCE (score-function) gradient estimator:
303    ///
304    ///   ∇_θ L ≈ (1/σN) Σ_k L(y_k) Z_k
305    ///
306    /// where L(y_k) = dL/dy · y_k (linear approximation to the loss).
307    ///
308    /// # Arguments
309    /// * `theta`     – parameter vector.
310    /// * `dl_dy`     – upstream gradient (defines the loss as L = dl_dy · y).
311    pub fn reinforce_gradient(&self, theta: &[f64], dl_dy: &[f64]) -> OptimizeResult<Vec<f64>> {
312        let outputs = self.cached_outputs.as_ref().ok_or_else(|| {
313            OptimizeError::ComputationError(
314                "PerturbedOptimizer::reinforce_gradient called before forward".to_string(),
315            )
316        })?;
317        let noises = self
318            .cached_noise
319            .as_ref()
320            .ok_or_else(|| OptimizeError::ComputationError("No cached noise".to_string()))?;
321
322        let d = theta.len();
323        let n_samples = outputs.len();
324        let sigma = self.config.sigma;
325
326        let mut grad = vec![0.0_f64; d];
327        for k in 0..n_samples {
328            // Approximate loss: L_k = dl_dy · y_k
329            let l_k: f64 = outputs[k]
330                .iter()
331                .zip(dl_dy.iter())
332                .map(|(&yk, &dly)| yk * dly)
333                .sum();
334
335            for j in 0..d {
336                let z_kj = if j < noises[k].len() {
337                    noises[k][j]
338                } else {
339                    0.0
340                };
341                grad[j] += l_k * z_kj;
342            }
343        }
344
345        let scale = 1.0 / (sigma * n_samples as f64);
346        for gi in &mut grad {
347            *gi *= scale;
348        }
349
350        Ok(grad)
351    }
352
353    /// Access the cached mean output from the last forward pass.
354    pub fn last_mean_output(&self) -> Option<Vec<f64>> {
355        let outputs = self.cached_outputs.as_ref()?;
356        if outputs.is_empty() {
357            return None;
358        }
359        let out_len = outputs[0].len();
360        let mut mean = vec![0.0_f64; out_len];
361        for output in outputs {
362            for (i, &oi) in output.iter().enumerate().take(out_len) {
363                mean[i] += oi;
364            }
365        }
366        let n = outputs.len() as f64;
367        for mi in &mut mean {
368            *mi /= n;
369        }
370        Some(mean)
371    }
372}
373
374// ─────────────────────────────────────────────────────────────────────────────
375// SparseMap
376// ─────────────────────────────────────────────────────────────────────────────
377
/// Configuration for the SparseMap structured prediction layer.
#[derive(Debug, Clone)]
pub struct SparseMapConfig {
    /// Maximum number of projected-gradient iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: f64,
    /// Step size for projected-gradient updates.
    pub step_size: f64,
}

impl Default for SparseMapConfig {
    /// Defaults: up to 1000 iterations, tolerance 1e-8, step size 0.1.
    fn default() -> Self {
        SparseMapConfig {
            max_iter: 1000,
            tol: 1e-8,
            step_size: 0.1,
        }
    }
}
398
/// A SparseMap layer: structured prediction with sparse marginals.
///
/// Solves the QP:
///
///   max  θᵀμ - ½ μᵀ μ   s.t.  μ ∈ M
///
/// where M is the marginal polytope (e.g., the simplex for unstructured
/// classification). The solution is a sparse probability distribution.
///
/// The backward pass uses the KKT conditions of the QP.
#[derive(Debug, Clone)]
pub struct SparseMap {
    /// Solver parameters (iteration cap, tolerance, step size).
    config: SparseMapConfig,
    /// Equality constraint matrix A (defines the polytope via Ax = b, x ≥ 0).
    a_marginal: Vec<Vec<f64>>,
    /// Equality rhs b.
    b_marginal: Vec<f64>,
    /// Last forward result: μ* (sparse distribution), consumed by `backward`.
    last_mu: Option<Vec<f64>>,
    /// Last dual: ν* for equality constraints, consumed by `backward`.
    last_nu: Option<Vec<f64>>,
    /// Last theta (cached alongside μ*/ν* by `forward`).
    last_theta: Option<Vec<f64>>,
}
423
424impl SparseMap {
425    /// Create a new SparseMap layer for a given marginal polytope.
426    ///
427    /// # Arguments
428    /// * `a_marginal` – equality constraints defining the polytope (Ax = b, x ≥ 0).
429    /// * `b_marginal` – equality rhs.
430    pub fn new(a_marginal: Vec<Vec<f64>>, b_marginal: Vec<f64>) -> Self {
431        Self {
432            config: SparseMapConfig::default(),
433            a_marginal,
434            b_marginal,
435            last_mu: None,
436            last_nu: None,
437            last_theta: None,
438        }
439    }
440
441    /// Create a SparseMap for the simplex: Σ μ_i = 1, μ_i ≥ 0.
442    pub fn simplex(n: usize) -> Self {
443        let a = vec![vec![1.0_f64; n]];
444        let b = vec![1.0_f64];
445        Self::new(a, b)
446    }
447
448    /// Create a SparseMap with custom configuration.
449    pub fn with_config(
450        a_marginal: Vec<Vec<f64>>,
451        b_marginal: Vec<f64>,
452        config: SparseMapConfig,
453    ) -> Self {
454        Self {
455            config,
456            a_marginal,
457            b_marginal,
458            last_mu: None,
459            last_nu: None,
460            last_theta: None,
461        }
462    }
463
464    /// Forward pass: solve the QP on the marginal polytope.
465    ///
466    ///   μ* = argmax_{μ ∈ M} θᵀμ - ½ ||μ||²
467    ///       = argmin_{μ ∈ M} ½ ||μ - θ||²
468    ///       = Π_M(θ)   (Euclidean projection onto M)
469    ///
470    /// Uses iterative projected gradient on the Lagrangian.
471    ///
472    /// # Arguments
473    /// * `theta` – score vector (length n).
474    pub fn forward(&mut self, theta: &[f64]) -> OptimizeResult<Vec<f64>> {
475        let n = theta.len();
476        let p = self.b_marginal.len();
477
478        if self.a_marginal.len() != p {
479            return Err(OptimizeError::InvalidInput(format!(
480                "A_marginal rows ({}) != b_marginal length ({})",
481                self.a_marginal.len(),
482                p
483            )));
484        }
485
486        // Solve: min ½ μᵀ μ - θᵀ μ  s.t. A μ = b, μ ≥ 0
487        // via projected gradient descent in the dual:
488        //
489        // Lagrangian: L = ½ μᵀμ - θᵀμ + νᵀ(Aμ - b)
490        // Primal: μ = max(0, θ - Aᵀν)
491        // Dual: max -½ ||max(0, θ - Aᵀν)||² + θᵀ max(0, θ-Aᵀν) - νᵀ b
492
493        let mut nu = vec![0.0_f64; p];
494        let step = self.config.step_size;
495
496        for _ in 0..self.config.max_iter {
497            // Primal: μ(ν) = max(0, θ - Aᵀν)
498            let at_nu: Vec<f64> = (0..n)
499                .map(|j| {
500                    (0..p)
501                        .map(|i| {
502                            let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len()
503                            {
504                                self.a_marginal[i][j]
505                            } else {
506                                0.0
507                            };
508                            nu[i] * a_ij
509                        })
510                        .sum::<f64>()
511                })
512                .collect();
513
514            let mu: Vec<f64> = (0..n).map(|j| (theta[j] - at_nu[j]).max(0.0)).collect();
515
516            // Dual gradient: ∂L/∂ν_i = Σ_j A_{ij} μ_j - b_i = (Aμ)_i - b_i
517            let amu: Vec<f64> = (0..p)
518                .map(|i| {
519                    (0..n)
520                        .map(|j| {
521                            let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len()
522                            {
523                                self.a_marginal[i][j]
524                            } else {
525                                0.0
526                            };
527                            a_ij * mu[j]
528                        })
529                        .sum::<f64>()
530                })
531                .collect();
532
533            let nu_new: Vec<f64> = (0..p)
534                .map(|i| nu[i] + step * (amu[i] - self.b_marginal[i]))
535                .collect();
536
537            // Check convergence
538            let delta: f64 = nu_new
539                .iter()
540                .zip(nu.iter())
541                .map(|(a, b)| (a - b).powi(2))
542                .sum::<f64>()
543                .sqrt();
544
545            nu = nu_new;
546
547            if delta < self.config.tol {
548                break;
549            }
550        }
551
552        // Final primal
553        let at_nu: Vec<f64> = (0..n)
554            .map(|j| {
555                (0..p)
556                    .map(|i| {
557                        let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len() {
558                            self.a_marginal[i][j]
559                        } else {
560                            0.0
561                        };
562                        nu[i] * a_ij
563                    })
564                    .sum::<f64>()
565            })
566            .collect();
567
568        let mu: Vec<f64> = (0..n).map(|j| (theta[j] - at_nu[j]).max(0.0)).collect();
569
570        self.last_mu = Some(mu.clone());
571        self.last_nu = Some(nu);
572        self.last_theta = Some(theta.to_vec());
573
574        Ok(mu)
575    }
576
577    /// Backward pass: compute dL/dθ via KKT sensitivity.
578    ///
579    /// At the optimal μ*, the KKT conditions of the QP are:
580    ///
581    ///   μ* - θ + Aᵀν* + s = 0   (stationarity, s = -min(μ*, 0))
582    ///   Aμ* = b                  (equality)
583    ///   μ* ≥ 0, s ≥ 0, s⊙μ* = 0  (complementarity)
584    ///
585    /// For the active variables (μ*_i > 0), we have s_i = 0, and the
586    /// KKT system reduces to an equality system on the support.
587    ///
588    /// # Arguments
589    /// * `dl_dmu` – upstream gradient dL/dμ (length n).
590    pub fn backward(&self, dl_dmu: &[f64]) -> OptimizeResult<Vec<f64>> {
591        let mu = self.last_mu.as_ref().ok_or_else(|| {
592            OptimizeError::ComputationError("SparseMap::backward called before forward".to_string())
593        })?;
594        let nu = self
595            .last_nu
596            .as_ref()
597            .ok_or_else(|| OptimizeError::ComputationError("No cached nu".to_string()))?;
598        let theta = self
599            .last_theta
600            .as_ref()
601            .ok_or_else(|| OptimizeError::ComputationError("No cached theta".to_string()))?;
602
603        let n = mu.len();
604        let tol = 1e-8_f64;
605
606        // Active support: μ*_i > 0
607        let support: Vec<usize> = (0..n).filter(|&i| mu[i] > tol).collect();
608
609        if support.is_empty() {
610            // All-zero solution: gradient is zero
611            return Ok(vec![0.0_f64; n]);
612        }
613
614        let s = support.len();
615        let p = nu.len();
616
617        // Build restricted system: Q_S = I_s, A_S = A[:, support]
618        let q_s: Vec<Vec<f64>> = (0..s)
619            .map(|i| {
620                let mut row = vec![0.0_f64; s];
621                row[i] = 1.0;
622                row
623            })
624            .collect();
625
626        let a_s: Vec<Vec<f64>> = (0..p)
627            .map(|i| {
628                support
629                    .iter()
630                    .map(|&j| {
631                        if i < self.a_marginal.len() && j < self.a_marginal[i].len() {
632                            self.a_marginal[i][j]
633                        } else {
634                            0.0
635                        }
636                    })
637                    .collect()
638            })
639            .collect();
640
641        // Restricted primal and dual
642        let x_s: Vec<f64> = support
643            .iter()
644            .map(|&j| if j < mu.len() { mu[j] } else { 0.0 })
645            .collect();
646
647        let dl_dx_s: Vec<f64> = support
648            .iter()
649            .map(|&j| if j < dl_dmu.len() { dl_dmu[j] } else { 0.0 })
650            .collect();
651
652        // KKT sensitivity on the restricted system
653        let kkt_grad = kkt_sensitivity(&q_s, &a_s, &x_s, nu, &dl_dx_s)?;
654
655        // Expand gradient back to full n: dL/dθ_j = dx_adj_j for active, 0 for inactive
656        let mut dl_dtheta = vec![0.0_f64; n];
657        for (idx, &j) in support.iter().enumerate() {
658            if idx < kkt_grad.dx_adj.len() {
659                dl_dtheta[j] = kkt_grad.dx_adj[idx];
660            }
661        }
662
663        let _ = theta;
664        Ok(dl_dtheta)
665    }
666
667    /// Project a vector onto the probability simplex.
668    ///
669    /// Solves: argmin_{μ ≥ 0, Σμ = 1} ||μ - v||²
670    ///
671    /// Uses the O(n log n) sorting algorithm.
672    pub fn project_simplex(v: &[f64]) -> Vec<f64> {
673        let n = v.len();
674        if n == 0 {
675            return vec![];
676        }
677
678        let mut u: Vec<f64> = v.to_vec();
679        u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
680
681        let mut cssv = 0.0_f64;
682        let mut rho = 0_usize;
683        for j in 0..n {
684            cssv += u[j];
685            let tau = (cssv - 1.0) / (j + 1) as f64;
686            if tau < u[j] {
687                rho = j;
688            }
689        }
690
691        let cssv_rho: f64 = u[..=rho].iter().sum();
692        let theta = (cssv_rho - 1.0) / (rho + 1) as f64;
693
694        v.iter().map(|&vi| (vi - theta).max(0.0)).collect()
695    }
696}
697
698// ─────────────────────────────────────────────────────────────────────────────
699// Tests
700// ─────────────────────────────────────────────────────────────────────────────
701
#[cfg(test)]
mod tests {
    use super::*;

    /// Simple linear optimizer: y* = argmax_y θᵀy s.t. y ∈ {0, 1}^n
    /// (i.e., select the maximum-score element as a one-hot vector).
    fn argmax_binary(theta: &[f64]) -> Vec<f64> {
        if theta.is_empty() {
            return vec![];
        }
        let max_idx = theta
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        let mut y = vec![0.0_f64; theta.len()];
        y[max_idx] = 1.0;
        y
    }

    /// Simple sort optimizer: returns normalized rank vector
    /// (largest input gets rank n/n = 1.0, smallest gets 1/n).
    fn soft_sort_optimizer(theta: &[f64]) -> Vec<f64> {
        let n = theta.len();
        if n == 0 {
            return vec![];
        }
        let mut indexed: Vec<(f64, usize)> = theta
            .iter()
            .cloned()
            .enumerate()
            .map(|(i, v)| (v, i))
            .collect();
        indexed.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        let mut rank = vec![0.0_f64; n];
        for (r, (_, i)) in indexed.iter().enumerate() {
            rank[*i] = (n - r) as f64 / n as f64;
        }
        rank
    }

    #[test]
    fn test_perturbed_optimizer_config_default() {
        let cfg = PerturbedOptimizerConfig::default();
        assert_eq!(cfg.n_samples, 20);
        assert!((cfg.sigma - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_perturbed_optimizer_forward_shape() {
        let mut opt = PerturbedOptimizer::new(argmax_binary);
        let theta = vec![1.0, 2.0, 3.0_f64];

        let y = opt.forward(&theta).expect("Forward failed");
        assert_eq!(y.len(), 3, "Output length should match input");
        // Each y_i in [0, 1]: the mean of one-hot vectors is a distribution.
        for yi in &y {
            assert!(*yi >= 0.0 && *yi <= 1.0, "y_i = {} should be in [0, 1]", yi);
        }
    }

    #[test]
    fn test_perturbed_optimizer_forward_distribution_sums_to_one() {
        // For argmax binary, each sample is one-hot, so the Monte Carlo
        // mean always sums to exactly 1 (up to floating-point rounding).
        let cfg = PerturbedOptimizerConfig {
            n_samples: 100,
            sigma: 0.1, // small sigma → less randomness
            seed: 123,
        };
        let mut opt = PerturbedOptimizer::with_config(argmax_binary, cfg);
        let theta = vec![1.0, 5.0, 2.0_f64]; // θ[1] is largest

        let y = opt.forward(&theta).expect("Forward failed");
        let sum: f64 = y.iter().sum();
        assert!(
            (sum - 1.0).abs() < 0.05,
            "Sum = {} (expected ~1.0 for binary argmax)",
            sum
        );
    }

    #[test]
    fn test_perturbed_optimizer_gradient_sign() {
        // Loss L = dl_dy · y with dl_dy = [1, 0, 0], i.e. L = y[0].
        // Increasing θ[0] raises p(argmax = 0), hence E[y[0]] and L increase,
        // so the true gradient satisfies dL/dθ[0] > 0. With finite Monte
        // Carlo samples we only assert the sign loosely below.

        let cfg = PerturbedOptimizerConfig {
            n_samples: 1000,
            sigma: 1.0,
            seed: 42,
        };
        let mut opt = PerturbedOptimizer::with_config(argmax_binary, cfg);
        let theta = vec![2.0, 0.0, 0.0_f64];

        let _y = opt.forward(&theta).expect("Forward failed");

        // L = y[0], dL/dy = [1, 0, 0].
        let grad = opt
            .gradient(&theta, &[1.0, 0.0, 0.0])
            .expect("Gradient failed");

        assert_eq!(grad.len(), 3);
        // Expect dL/dθ[0] > 0, with slack for Monte Carlo variance.
        assert!(
            grad[0] > -0.5, // Allow some MC variance
            "grad[0] = {} should be roughly positive",
            grad[0]
        );
    }

    #[test]
    fn test_perturbed_optimizer_gradient_shape() {
        let mut opt = PerturbedOptimizer::new(argmax_binary);
        let theta = vec![1.0, 2.0, 3.0_f64];

        let _y = opt.forward(&theta).expect("Forward failed");
        let grad = opt
            .gradient(&theta, &[1.0, 0.0, 0.0])
            .expect("Gradient failed");

        assert_eq!(grad.len(), 3);
        for gi in &grad {
            assert!(gi.is_finite(), "grad not finite");
        }
    }

    #[test]
    fn test_perturbed_optimizer_reinforce_shape() {
        let mut opt = PerturbedOptimizer::new(soft_sort_optimizer);
        let theta = vec![1.0, 3.0, 2.0_f64];

        let _y = opt.forward(&theta).expect("Forward failed");
        let grad = opt
            .reinforce_gradient(&theta, &[0.0, 1.0, 0.0])
            .expect("REINFORCE failed");

        assert_eq!(grad.len(), 3);
        for gi in &grad {
            assert!(gi.is_finite(), "REINFORCE grad not finite");
        }
    }

    #[test]
    fn test_perturbed_optimizer_no_forward_error() {
        let opt = PerturbedOptimizer::new(argmax_binary);
        let result = opt.gradient(&[1.0, 2.0], &[1.0, 0.0]);
        assert!(result.is_err(), "Should error without forward pass");
    }

    #[test]
    fn test_sparsemap_simplex_projection() {
        // 3-class simplex: μ ≥ 0, Σμ = 1.
        let mut sm = SparseMap::simplex(3);
        let theta = vec![1.0, 2.0, 0.5_f64];

        let mu = sm.forward(&theta).expect("SparseMap forward failed");

        // Check μ ≥ 0
        for mi in &mu {
            assert!(*mi >= -1e-6, "μ < 0: {}", mi);
        }

        // Check Σμ ≈ 1 (loose tolerance: dual ascent converges gradually).
        let sum: f64 = mu.iter().sum();
        assert!(
            (sum - 1.0).abs() < 0.1,
            "Σμ = {} (expected ~1.0 for simplex)",
            sum
        );
    }

    #[test]
    fn test_sparsemap_backward_shape() {
        let mut sm = SparseMap::simplex(4);
        let theta = vec![1.0, 3.0, 2.0, 0.5_f64];

        let _mu = sm.forward(&theta).expect("SparseMap forward failed");
        let dl_dtheta = sm
            .backward(&[1.0, 0.0, 0.0, 0.0])
            .expect("SparseMap backward failed");

        assert_eq!(dl_dtheta.len(), 4, "Gradient length mismatch");
        for gi in &dl_dtheta {
            assert!(gi.is_finite(), "SparseMap gradient not finite");
        }
    }

    #[test]
    fn test_sparsemap_no_forward_error() {
        let sm = SparseMap::simplex(3);
        let result = sm.backward(&[1.0, 0.0, 0.0]);
        assert!(result.is_err(), "Should error without forward pass");
    }

    #[test]
    fn test_project_simplex_properties() {
        let v = vec![0.5, 1.5, -0.3, 2.0_f64];
        let p = SparseMap::project_simplex(&v);

        // Σ = 1 (exact projection, tight tolerance).
        let sum: f64 = p.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "Simplex sum = {} (expected 1.0)",
            sum
        );

        // All ≥ 0
        for pi in &p {
            assert!(*pi >= -1e-12, "Negative simplex component: {}", pi);
        }
    }

    #[test]
    fn test_project_simplex_uniform_input() {
        // A point already on the simplex projects to itself.
        let v = vec![0.5, 0.5_f64];
        let p = SparseMap::project_simplex(&v);
        assert!((p[0] - 0.5).abs() < 1e-10);
        assert!((p[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_xorshift_reproducible() {
        // Same seed must produce identical streams.
        let mut rng1 = Xorshift64::new(42);
        let mut rng2 = Xorshift64::new(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn test_xorshift_normal_finite() {
        let mut rng = Xorshift64::new(12345);
        for _ in 0..100 {
            let v = rng.normal();
            assert!(v.is_finite(), "Normal sample not finite: {}", v);
        }
    }
}