// scirs2_optimize/differentiable_optimization/combinatorial.rs
//! Differentiable Combinatorial Optimization
//!
//! This module implements differentiable relaxations of combinatorial
//! optimization problems, enabling gradients to flow through discrete
//! decision layers in end-to-end learning pipelines.
//!
//! # Algorithms
//!
//! - **SparseMAP** (Niculae & Blondel, 2017): Sparse structured prediction
//!   via QP over the marginal polytope. Yields sparse probability distributions
//!   over combinatorial structures with exact gradients via the active-set
//!   theorem.
//!
//! - **Perturbed Optimizers** (Berthet et al., 2020): Sample-based
//!   differentiable argmax using additive Gaussian noise, enabling unbiased
//!   gradient estimates through any black-box combinatorial solver.
//!
//! - **Differentiable Sorting** (Blondel et al., 2020): Regularised isotonic
//!   regression for soft sorting and ranking.
//!
//! - **Differentiable Top-K**: Entropy-regularised LP relaxation of the hard
//!   top-k selector.
//!
//! # References
//! - Niculae & Blondel (2017). "A regularized framework for sparse and
//!   structured neural attention." NeurIPS.
//! - Berthet et al. (2020). "Learning with Differentiable Perturbed Optimizers."
//!   NeurIPS.
//! - Blondel et al. (2020). "Fast Differentiable Sorting and Ranking." ICML.
31use scirs2_core::num_traits::{Float, FromPrimitive};
32use std::fmt::Debug;
33
34use crate::error::{OptimizeError, OptimizeResult};
35
// ─────────────────────────────────────────────────────────────────────────────
// SparseMAP
// ─────────────────────────────────────────────────────────────────────────────
39
/// Type of combinatorial structure defining the polytope for SparseMAP.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StructureType {
    /// Standard probability simplex: Σμ_i = 1, μ_i ≥ 0.
    Simplex,
    /// Knapsack polytope: Σw_i · μ_i ≤ capacity, 0 ≤ μ_i ≤ 1.
    Knapsack {
        /// Weights for each item.
        weights: Vec<f64>,
        /// Knapsack capacity.
        capacity: f64,
    },
    /// Birkhoff polytope (doubly-stochastic matrices): permutation marginals.
    /// `dim` is the side length (number of items to rank).
    Permutation {
        /// Number of items.
        dim: usize,
    },
}

impl Default for StructureType {
    /// The simplex is the default: it carries no extra data and admits a
    /// closed-form projection.
    fn default() -> Self {
        Self::Simplex
    }
}

/// Configuration for the SparseMAP solver.
#[derive(Debug, Clone)]
pub struct SparsemapConfig {
    /// Maximum number of active-set / projected-gradient iterations.
    pub max_iter: usize,
    /// Convergence tolerance (dual gap or gradient norm).
    pub tol: f64,
    /// Combinatorial structure type.
    pub structure_type: StructureType,
    /// Step size for projected-gradient updates.
    pub step_size: f64,
}

impl Default for SparsemapConfig {
    /// Defaults: 1000 iterations, tolerance 1e-6, simplex structure,
    /// step size 0.1.
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-6,
            structure_type: StructureType::default(),
            step_size: 0.1,
        }
    }
}
90
/// Result of SparseMAP.
///
/// Bundles the optimal sparse primal solution with its active support, the
/// dual variables at optimality, and the solver's iteration count.
#[derive(Debug, Clone)]
pub struct SparsemapResult<F> {
    /// Sparse probability distribution over combinatorial structures.
    pub solution: Vec<F>,
    /// Indices of atoms with non-zero weight (the active support).
    pub support: Vec<usize>,
    /// Dual variables at optimality (Lagrange multipliers for equality / active
    /// inequality constraints).
    pub dual: Vec<F>,
    /// Number of iterations performed.
    pub n_iters: usize,
}
104
105/// Project a vector onto the probability simplex Δ^{n-1} = {μ | Σμ_i=1, μ≥0}.
106///
107/// Uses Duchi et al. (2008) O(n log n) algorithm.
108fn project_simplex<F>(v: &[F]) -> Vec<F>
109where
110    F: Float + FromPrimitive + Debug + Clone,
111{
112    let n = v.len();
113    let mut u: Vec<F> = v.to_vec();
114    // Sort descending
115    u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
116
117    let mut cssv = F::zero();
118    let mut rho = 0usize;
119    for (j, &uj) in u.iter().enumerate() {
120        cssv = cssv + uj;
121        let j_f = F::from_usize(j + 1).unwrap_or(F::one());
122        let one = F::one();
123        if uj - (cssv - one) / j_f > F::zero() {
124            rho = j;
125        }
126    }
127
128    let rho_f = F::from_usize(rho + 1).unwrap_or(F::one());
129    let one = F::one();
130    // Recompute cssv up to rho
131    let mut cssv2 = F::zero();
132    for uj in u.iter().take(rho + 1) {
133        cssv2 = cssv2 + *uj;
134    }
135    let theta = (cssv2 - one) / rho_f;
136
137    v.iter()
138        .map(|&vi| {
139            let diff = vi - theta;
140            if diff > F::zero() {
141                diff
142            } else {
143                F::zero()
144            }
145        })
146        .collect()
147}
148
149/// Project a vector onto the knapsack polytope with binary variables and
150/// capacity constraint: {μ | Σw_i·μ_i ≤ cap, 0 ≤ μ_i ≤ 1}.
151///
152/// Uses greedy fractional knapsack rounding followed by projected gradient.
153fn project_knapsack<F>(v: &[F], weights: &[f64], capacity: f64) -> Vec<F>
154where
155    F: Float + FromPrimitive + Debug + Clone,
156{
157    let n = v.len();
158    // Clip to [0,1] first
159    let mut mu: Vec<F> = v
160        .iter()
161        .map(|&vi| {
162            if vi < F::zero() {
163                F::zero()
164            } else if vi > F::one() {
165                F::one()
166            } else {
167                vi
168            }
169        })
170        .collect();
171
172    // Check if capacity is satisfied; if so, return clipped value
173    let total_weight: f64 = (0..n)
174        .map(|i| weights.get(i).copied().unwrap_or(1.0) * mu[i].to_f64().unwrap_or(0.0))
175        .sum();
176
177    if total_weight <= capacity + 1e-12 {
178        return mu;
179    }
180
181    // Binary search on the Lagrange multiplier λ for the capacity constraint:
182    //   μ_i(λ) = clip(v_i / (1 + λ·w_i), 0, 1)   => capacity constraint
183    let mut lo = 0.0_f64;
184    let mut hi = 1e8_f64;
185
186    for _ in 0..200 {
187        let mid = (lo + hi) / 2.0;
188        let w_total: f64 = (0..n)
189            .map(|i| {
190                let wi = weights.get(i).copied().unwrap_or(1.0);
191                let vi = v[i].to_f64().unwrap_or(0.0);
192                let mu_i = (vi / (1.0 + mid * wi)).clamp(0.0, 1.0);
193                wi * mu_i
194            })
195            .sum();
196        if w_total > capacity {
197            lo = mid;
198        } else {
199            hi = mid;
200        }
201    }
202
203    let lambda = (lo + hi) / 2.0;
204    mu = (0..n)
205        .map(|i| {
206            let wi = weights.get(i).copied().unwrap_or(1.0);
207            let vi = v[i].to_f64().unwrap_or(0.0);
208            let val = (vi / (1.0 + lambda * wi)).clamp(0.0, 1.0);
209            F::from_f64(val).unwrap_or(F::zero())
210        })
211        .collect();
212    mu
213}
214
215/// Solve SparseMAP via projected gradient descent on the regularised QP.
216///
217/// For `StructureType::Simplex` this is equivalent to computing the Euclidean
218/// projection of `scores` onto the probability simplex, which has a closed-form
219/// O(n log n) solution.  For other structures, it falls back to iterative
220/// projected gradient.
221///
222/// # Arguments
223/// * `scores` – score vector θ ∈ ℝ^d.
224/// * `config` – solver configuration.
225///
226/// # Returns
227/// [`SparsemapResult`] containing the sparse distribution, active support,
228/// dual variables, and iteration count.
229pub fn sparsemap<F>(scores: &[F], config: &SparsemapConfig) -> OptimizeResult<SparsemapResult<F>>
230where
231    F: Float + FromPrimitive + Debug + Clone,
232{
233    if scores.is_empty() {
234        return Err(OptimizeError::InvalidInput(
235            "scores vector must be non-empty".into(),
236        ));
237    }
238
239    let n = scores.len();
240    let tol_f = F::from_f64(config.tol).unwrap_or(F::epsilon());
241
242    let solution: Vec<F>;
243    let n_iters: usize;
244    let dual: Vec<F>;
245
246    match &config.structure_type {
247        StructureType::Simplex => {
248            // Closed-form: Euclidean projection onto probability simplex.
249            solution = project_simplex(scores);
250            n_iters = 1;
251            // The dual variable λ for the equality constraint Σμ=1 equals
252            // the threshold used in the projection: λ = (Σ_{i∈S} θ_i - 1)/|S|
253            let support_sum: F =
254                solution.iter().fold(
255                    F::zero(),
256                    |acc, &x| {
257                        if x > F::zero() {
258                            acc + x
259                        } else {
260                            acc
261                        }
262                    },
263                );
264            let support_count = solution.iter().filter(|&&x| x > F::zero()).count();
265            let count_f = F::from_usize(support_count).unwrap_or(F::one());
266            let lambda = if count_f > F::zero() {
267                (support_sum - F::one()) / count_f
268            } else {
269                F::zero()
270            };
271            dual = vec![lambda];
272        }
273
274        StructureType::Knapsack { weights, capacity } => {
275            // Projected gradient descent on knapsack polytope.
276            let mut mu: Vec<F> = vec![F::zero(); n];
277            let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());
278
279            let mut iter = 0usize;
280            let mut prev_obj = F::neg_infinity();
281
282            loop {
283                // Gradient of ½||μ-θ||² w.r.t. μ is (μ - θ)
284                let grad: Vec<F> = mu.iter().zip(scores.iter()).map(|(&m, &s)| m - s).collect();
285
286                // Gradient step: μ ← μ - step * grad = μ - step*(μ-θ)
287                let mu_new: Vec<F> = mu
288                    .iter()
289                    .zip(grad.iter())
290                    .map(|(&m, &g)| m - step * g)
291                    .collect();
292
293                // Project onto knapsack polytope
294                let mu_proj = project_knapsack(&mu_new, weights, *capacity);
295
296                // Compute objective
297                let obj = mu_proj
298                    .iter()
299                    .zip(scores.iter())
300                    .fold(F::zero(), |acc, (&m, &s)| {
301                        let diff = m - s;
302                        acc + diff * diff
303                    });
304                let half = F::from_f64(0.5).unwrap_or(F::one());
305                let obj = obj * half;
306
307                let diff = (obj - prev_obj).abs();
308                mu = mu_proj;
309                prev_obj = obj;
310                iter += 1;
311
312                if iter >= config.max_iter || diff < tol_f {
313                    break;
314                }
315            }
316
317            solution = mu;
318            n_iters = iter;
319            // Dual: reduced costs for capacity constraint
320            let total_w: f64 = (0..n)
321                .map(|i| {
322                    weights.get(i).copied().unwrap_or(1.0) * solution[i].to_f64().unwrap_or(0.0)
323                })
324                .sum();
325            let slack = *capacity - total_w;
326            let lambda_val = if slack.abs() < 1e-8 { -1.0 } else { 0.0 };
327            dual = vec![F::from_f64(lambda_val).unwrap_or(F::zero())];
328        }
329
330        StructureType::Permutation { dim } => {
331            // For permutations, solve QP over Birkhoff polytope via Sinkhorn
332            // projection (alternating row/column normalisation).
333            let d = *dim;
334            if scores.len() != d * d {
335                return Err(OptimizeError::InvalidInput(format!(
336                    "Permutation structure requires d²={} scores but got {}",
337                    d * d,
338                    scores.len()
339                )));
340            }
341
342            // Initialise as uniform doubly-stochastic matrix
343            let inv_d = F::from_f64(1.0 / d as f64).unwrap_or(F::one());
344            let mut mu: Vec<F> = vec![inv_d; d * d];
345            let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());
346            let mut iter = 0usize;
347
348            loop {
349                // Gradient step
350                let mu_step: Vec<F> = mu
351                    .iter()
352                    .zip(scores.iter())
353                    .map(|(&m, &s)| m - step * (m - s))
354                    .collect();
355
356                // Sinkhorn projection: alternate row / column normalisation
357                let mut m_sink = mu_step;
358                for _ in 0..50 {
359                    // Row normalisation
360                    for row in 0..d {
361                        let row_sum: F = (0..d)
362                            .map(|col| m_sink[row * d + col])
363                            .fold(F::zero(), |a, b| a + b);
364                        if row_sum > F::zero() {
365                            for col in 0..d {
366                                m_sink[row * d + col] = m_sink[row * d + col] / row_sum;
367                            }
368                        }
369                    }
370                    // Column normalisation
371                    for col in 0..d {
372                        let col_sum: F = (0..d)
373                            .map(|row| m_sink[row * d + col])
374                            .fold(F::zero(), |a, b| a + b);
375                        if col_sum > F::zero() {
376                            for row in 0..d {
377                                m_sink[row * d + col] = m_sink[row * d + col] / col_sum;
378                            }
379                        }
380                    }
381                }
382
383                // Check convergence
384                let change: F = mu
385                    .iter()
386                    .zip(m_sink.iter())
387                    .map(|(&a, &b)| {
388                        let d = a - b;
389                        d * d
390                    })
391                    .fold(F::zero(), |a, b| a + b);
392
393                mu = m_sink;
394                iter += 1;
395
396                if iter >= config.max_iter || change < tol_f * tol_f {
397                    break;
398                }
399            }
400
401            solution = mu;
402            n_iters = iter;
403            dual = vec![F::zero(); 2 * d]; // row + column duals
404        }
405    }
406
407    // Extract active support
408    let support: Vec<usize> = solution
409        .iter()
410        .enumerate()
411        .filter_map(|(i, &v)| {
412            if v > F::from_f64(1e-9).unwrap_or(F::zero()) {
413                Some(i)
414            } else {
415                None
416            }
417        })
418        .collect();
419
420    Ok(SparsemapResult {
421        solution,
422        support,
423        dual,
424        n_iters,
425    })
426}
427
428/// Compute gradient of a loss through SparseMAP via the active-set theorem.
429///
430/// For SparseMAP, the Jacobian of the optimal solution μ*(θ) w.r.t. θ is:
431/// ```text
432/// dμ*/dθ = Π_S  (projection onto tangent space of active support S)
433/// ```
434/// Concretely, for the simplex case, only active coordinates (support) can
435/// receive gradient.  The backward pass is:
436/// ```text
437/// dL/dθ = Π_S (upstream_grad)
438///       = upstream_grad[support] - mean(upstream_grad[support]) · 1_S
439/// ```
440/// This is the projection of `upstream_grad` onto the tangent space of the
441/// simplex face defined by the active support.
442///
443/// # Arguments
444/// * `result` – the forward-pass [`SparsemapResult`].
445/// * `upstream_grad` – gradient of the scalar loss w.r.t. `solution` (∂L/∂μ).
446///
447/// # Returns
448/// Gradient ∂L/∂θ of the same length as the score input.
449pub fn sparsemap_gradient<F>(result: &SparsemapResult<F>, upstream_grad: &[F]) -> Vec<F>
450where
451    F: Float + FromPrimitive + Debug + Clone,
452{
453    let n = result.solution.len();
454    if upstream_grad.len() != n {
455        // Length mismatch — return zero gradient rather than panic
456        return vec![F::zero(); n];
457    }
458
459    let s = &result.support;
460    if s.is_empty() {
461        return vec![F::zero(); n];
462    }
463
464    // Restrict upstream gradient to active support
465    let s_size = F::from_usize(s.len()).unwrap_or(F::one());
466    let mean_s: F = s
467        .iter()
468        .map(|&i| upstream_grad[i])
469        .fold(F::zero(), |a, b| a + b)
470        / s_size;
471
472    // Projected gradient: g_i - mean(g_S)  for i ∈ S, else 0
473    let mut grad = vec![F::zero(); n];
474    for &i in s {
475        grad[i] = upstream_grad[i] - mean_s;
476    }
477    grad
478}
479
// ─────────────────────────────────────────────────────────────────────────────
// Perturbed Optimizers
// ─────────────────────────────────────────────────────────────────────────────
483
/// Configuration for the Perturbed Optimizer.
#[derive(Debug, Clone)]
pub struct PerturbedOptimizerConfig {
    /// Number of Monte Carlo samples.
    pub n_samples: usize,
    /// Perturbation magnitude ε.
    pub epsilon: f64,
    /// RNG seed for reproducibility.
    pub seed: u64,
}

impl Default for PerturbedOptimizerConfig {
    /// Defaults: 100 samples, ε = 0.1, seed 42.
    fn default() -> Self {
        Self {
            n_samples: 100,
            epsilon: 0.1,
            seed: 42,
        }
    }
}

/// Differentiable argmax via additive Gaussian perturbations.
///
/// Estimates `E_Z[argmax(θ + ε·Z)]` where `Z ~ N(0, I)`.  The forward pass
/// returns soft assignment probabilities; the backward pass returns an
/// unbiased gradient estimate via the score-function estimator.
#[derive(Debug, Clone)]
pub struct PerturbedOptimizer {
    // Sampling configuration shared by the forward and backward passes.
    config: PerturbedOptimizerConfig,
}
514
515impl PerturbedOptimizer {
516    /// Create a new perturbed optimizer with the given configuration.
517    pub fn new(config: PerturbedOptimizerConfig) -> Self {
518        Self { config }
519    }
520
521    /// Forward pass: estimate E[argmax(θ + εZ)] via Monte Carlo.
522    ///
523    /// Returns a probability vector of length `scores.len()`.
524    pub fn forward<F>(&self, scores: &[F]) -> OptimizeResult<Vec<F>>
525    where
526        F: Float + FromPrimitive + Debug + Clone,
527    {
528        if scores.is_empty() {
529            return Err(OptimizeError::InvalidInput(
530                "scores must be non-empty".into(),
531            ));
532        }
533        let n = scores.len();
534        let mut counts = vec![0usize; n];
535        let eps = self.config.epsilon;
536
537        // Deterministic PRNG (xoshiro-style via splitmix64)
538        let mut rng_state = self.config.seed;
539        let n_samples = self.config.n_samples;
540
541        for _ in 0..n_samples {
542            // Sample argmax of perturbed scores
543            let mut best_idx = 0usize;
544            let mut best_val = F::neg_infinity();
545
546            for i in 0..n {
547                let z = sample_standard_normal(&mut rng_state);
548                let perturbed = scores[i] + F::from_f64(eps * z).unwrap_or(F::zero());
549                if perturbed > best_val {
550                    best_val = perturbed;
551                    best_idx = i;
552                }
553            }
554            counts[best_idx] += 1;
555        }
556
557        let n_samples_f = F::from_usize(n_samples).unwrap_or(F::one());
558        let probs: Vec<F> = counts
559            .iter()
560            .map(|&c| F::from_usize(c).unwrap_or(F::zero()) / n_samples_f)
561            .collect();
562
563        Ok(probs)
564    }
565
566    /// Backward pass: gradient estimate via score-function / log-derivative
567    /// trick.
568    ///
569    /// # Formula
570    /// ```text
571    /// dL/dθ ≈ (1 / (ε² · n)) Σ_i [<argmax(θ+εZ_i), upstream>] · Z_i
572    /// ```
573    ///
574    /// # Arguments
575    /// * `scores` – original (unperturbed) scores.
576    /// * `upstream` – upstream gradient ∂L/∂p (same shape as `forward` output).
577    ///
578    /// # Returns
579    /// Gradient ∂L/∂θ of the same length as `scores`.
580    pub fn backward<F>(&self, scores: &[F], upstream: &[F]) -> OptimizeResult<Vec<F>>
581    where
582        F: Float + FromPrimitive + Debug + Clone,
583    {
584        if scores.len() != upstream.len() {
585            return Err(OptimizeError::InvalidInput(
586                "scores and upstream must have the same length".into(),
587            ));
588        }
589        let n = scores.len();
590        let eps = self.config.epsilon;
591        let eps_sq = eps * eps;
592        let n_samples = self.config.n_samples;
593
594        let mut grad = vec![F::zero(); n];
595        let mut rng_state = self.config.seed;
596
597        for _ in 0..n_samples {
598            // Sample perturbed noise vector and find argmax
599            let noise: Vec<f64> = (0..n)
600                .map(|_| sample_standard_normal(&mut rng_state))
601                .collect();
602
603            let mut best_idx = 0usize;
604            let mut best_val = F::neg_infinity();
605            for i in 0..n {
606                let perturbed = scores[i] + F::from_f64(eps * noise[i]).unwrap_or(F::zero());
607                if perturbed > best_val {
608                    best_val = perturbed;
609                    best_idx = i;
610                }
611            }
612
613            // argmax is a one-hot vector e_{best_idx}
614            // <e_{best_idx}, upstream> = upstream[best_idx]
615            let dot = upstream[best_idx];
616
617            // Accumulate: grad += dot * Z / ε²
618            for i in 0..n {
619                let zi = F::from_f64(noise[i]).unwrap_or(F::zero());
620                let eps_sq_f = F::from_f64(eps_sq).unwrap_or(F::one());
621                grad[i] = grad[i] + dot * zi / eps_sq_f;
622            }
623        }
624
625        let n_f = F::from_usize(n_samples).unwrap_or(F::one());
626        for g in &mut grad {
627            *g = *g / n_f;
628        }
629
630        Ok(grad)
631    }
632}
633
/// One step of the splitmix64 PRNG; advances `state` and returns 64 random
/// bits.  Constants are the standard splitmix64 finalizer.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9e3779b97f4a7c15);
    let mut x = *state;
    x = (x ^ (x >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
    x = (x ^ (x >> 27)).wrapping_mul(0x94d049bb133111eb);
    x ^ (x >> 31)
}

/// Draw one standard-normal sample via the Box–Muller transform.
///
/// Consumes two uniform draws from `state`.  The raw u64s are mapped into
/// (0, 1] — the +0.5 offset keeps `ln` strictly away from zero.
fn sample_standard_normal(state: &mut u64) -> f64 {
    let r1 = splitmix64(state);
    let r2 = splitmix64(state);
    let u1 = (r1 as f64 + 0.5) / (u64::MAX as f64 + 1.0);
    let u2 = (r2 as f64 + 0.5) / (u64::MAX as f64 + 1.0);
    let two_pi = 2.0 * std::f64::consts::PI;
    (-2.0 * u1.ln()).sqrt() * (two_pi * u2).cos()
}
653
// ─────────────────────────────────────────────────────────────────────────────
// Differentiable Sorting and Ranking
// ─────────────────────────────────────────────────────────────────────────────
657
658/// Compute the soft sort of a vector via regularised isotonic regression.
659///
660/// Returns a non-decreasing sequence of the same length as `x`.  At
661/// `temperature → 0` this recovers the exact sorted sequence; at high
662/// temperature the output approaches the element-wise mean.
663///
664/// Internally uses the Pool Adjacent Violators (PAV) algorithm to solve the
665/// isotonic regression:
666/// ```text
667/// min_{ŝ non-decreasing} Σ (ŝ_i - s_i)²  +  temperature · regularisation
668/// ```
669///
670/// # Arguments
671/// * `x` – input vector.
672/// * `temperature` – controls softness (≥ 0; typical values 0.01–1.0).
673///
674/// # Returns
675/// Non-decreasing vector of the same length as `x`.
676pub fn soft_sort<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
677where
678    F: Float + FromPrimitive + Debug + Clone,
679{
680    if x.is_empty() {
681        return Err(OptimizeError::InvalidInput(
682            "input vector must be non-empty".into(),
683        ));
684    }
685
686    let n = x.len();
687    // Sort indices to get the sorted values
688    let mut sorted_x: Vec<F> = x.to_vec();
689    sorted_x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
690
691    // Apply temperature-based regularisation:
692    // For positive temperature, we compute the regularised soft sort by blending
693    // sorted values with the mean (which acts as the max-entropy limit).
694    if temperature == F::zero() {
695        return Ok(sorted_x);
696    }
697
698    let mean_val =
699        sorted_x.iter().fold(F::zero(), |a, b| a + *b) / F::from_usize(n).unwrap_or(F::one());
700
701    // Regularised isotonic regression via PAV on scores shifted toward mean
702    // The regularisation adds a quadratic penalty pulling toward mean.
703    // We implement PAV on the mixture: (1-t)·sorted + t·mean
704    let t_clamped = if temperature > F::one() {
705        F::one()
706    } else {
707        temperature
708    };
709    let one_minus_t = F::one() - t_clamped;
710
711    let mixed: Vec<F> = sorted_x
712        .iter()
713        .map(|&v| one_minus_t * v + t_clamped * mean_val)
714        .collect();
715
716    // PAV to ensure non-decreasingness (already guaranteed by initial sort,
717    // but the blending preserves it so PAV is a no-op here — kept for
718    // generality)
719    let result = pool_adjacent_violators(&mixed);
720
721    Ok(result)
722}
723
724/// Pool Adjacent Violators (PAV) algorithm for isotonic regression.
725/// Solves `min Σ(ŝ_i - s_i)² s.t. ŝ_1 ≤ ŝ_2 ≤ … ≤ ŝ_n`.
726fn pool_adjacent_violators<F>(s: &[F]) -> Vec<F>
727where
728    F: Float + FromPrimitive + Debug + Clone,
729{
730    let n = s.len();
731    // Each block stores (sum, count)
732    let mut blocks: Vec<(F, usize)> = s.iter().map(|&v| (v, 1)).collect();
733
734    let mut changed = true;
735    while changed {
736        changed = false;
737        let mut i = 0usize;
738        let mut new_blocks: Vec<(F, usize)> = Vec::with_capacity(blocks.len());
739        while i < blocks.len() {
740            let mut sum = blocks[i].0;
741            let mut cnt = blocks[i].1;
742            // Merge with next block if violates monotonicity
743            while i + 1 < blocks.len() {
744                let next_mean =
745                    blocks[i + 1].0 / F::from_usize(blocks[i + 1].1).unwrap_or(F::one());
746                let cur_mean = sum / F::from_usize(cnt).unwrap_or(F::one());
747                if cur_mean > next_mean {
748                    sum = sum + blocks[i + 1].0;
749                    cnt += blocks[i + 1].1;
750                    i += 1;
751                    changed = true;
752                } else {
753                    break;
754                }
755            }
756            new_blocks.push((sum, cnt));
757            i += 1;
758        }
759        blocks = new_blocks;
760    }
761
762    // Expand blocks back to length-n vector
763    let mut result = Vec::with_capacity(n);
764    for (sum, cnt) in blocks {
765        let mean = sum / F::from_usize(cnt).unwrap_or(F::one());
766        for _ in 0..cnt {
767            result.push(mean);
768        }
769    }
770    result
771}
772
773/// Compute soft ranks of elements in `x`.
774///
775/// Returns a vector of the same length as `x`, where each entry is the
776/// (1-indexed) soft rank of the corresponding element.  At `temperature → 0`
777/// this recovers the exact ranks (with ties broken by index).  At high
778/// temperature all ranks are pulled toward `(n+1)/2`.
779///
780/// # Arguments
781/// * `x` – input vector.
782/// * `temperature` – smoothing parameter (≥ 0).
783pub fn soft_rank<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
784where
785    F: Float + FromPrimitive + Debug + Clone,
786{
787    if x.is_empty() {
788        return Err(OptimizeError::InvalidInput(
789            "input vector must be non-empty".into(),
790        ));
791    }
792    let n = x.len();
793    let one = F::one();
794    let n_f = F::from_usize(n).unwrap_or(one);
795
796    if temperature == F::zero() {
797        // Hard rank: rank each element by counting how many elements it exceeds
798        let ranks: Vec<F> = (0..n)
799            .map(|i| {
800                let rank = x.iter().filter(|&&v| v < x[i]).count();
801                F::from_usize(rank + 1).unwrap_or(one)
802            })
803            .collect();
804        return Ok(ranks);
805    }
806
807    // Soft rank via pairwise comparison with sigmoid smoothing:
808    // rank_i ≈ 1 + Σ_{j≠i} σ((x_i - x_j) / temperature)
809    let two = F::from_f64(2.0).unwrap_or(one);
810
811    let ranks: Vec<F> = (0..n)
812        .map(|i| {
813            let mut soft_rank_i = one; // starts at 1
814            for j in 0..n {
815                if i == j {
816                    continue;
817                }
818                let diff = (x[i] - x[j]) / temperature;
819                // σ(diff) = 1/(1+exp(-diff)), clipped for numerical safety
820                let diff_clamped = if diff < F::from_f64(-50.0).unwrap_or(-one) {
821                    F::from_f64(-50.0).unwrap_or(-one)
822                } else if diff > F::from_f64(50.0).unwrap_or(one) {
823                    F::from_f64(50.0).unwrap_or(one)
824                } else {
825                    diff
826                };
827                let sigmoid_val = one / (one + (-diff_clamped).exp());
828                soft_rank_i = soft_rank_i + sigmoid_val;
829            }
830            // Blend with mid-rank at high temperature to ensure sum = n(n+1)/2
831            let mid = (n_f + one) / two;
832            let t = if temperature > F::from_f64(10.0).unwrap_or(one) {
833                one
834            } else {
835                temperature / F::from_f64(10.0).unwrap_or(one)
836            };
837            (one - t) * soft_rank_i + t * mid
838        })
839        .collect();
840
841    Ok(ranks)
842}
843
// ─────────────────────────────────────────────────────────────────────────────
// Differentiable Top-K
// ─────────────────────────────────────────────────────────────────────────────
847
848/// Differentiable top-k selector via entropy-regularised LP.
849///
850/// Solves the relaxed problem:
851/// ```text
852/// max_{p ∈ Δ^n, Σp_i = k}  <scores, p>  -  temperature · H(p)
853/// ```
854/// where H(p) = -Σ p_i log p_i is the entropy regulariser.  The solution
855/// has the closed form:
856/// ```text
857/// p_i = k · softmax(scores / temperature)_i
858/// ```
859/// normalised so that Σp_i = k.
860///
861/// At `temperature → 0`, `p` approaches the hard top-k indicator vector.
862///
863/// # Arguments
864/// * `scores` – input scores (arbitrary real values).
865/// * `k` – number of elements to select (1 ≤ k ≤ n).
866/// * `temperature` – regularisation strength (> 0 for differentiable;
867///   use a small value like 0.01 for near-hard top-k).
868///
869/// # Returns
870/// Soft indicator vector p in `[0,1]`^n with Sum(p_i) ~ k.
871pub fn diff_topk<F>(scores: &[F], k: usize, temperature: F) -> OptimizeResult<Vec<F>>
872where
873    F: Float + FromPrimitive + Debug + Clone,
874{
875    let n = scores.len();
876    if n == 0 {
877        return Err(OptimizeError::InvalidInput(
878            "scores must be non-empty".into(),
879        ));
880    }
881    if k == 0 || k > n {
882        return Err(OptimizeError::InvalidInput(format!(
883            "k must be in [1, {}] but got {}",
884            n, k
885        )));
886    }
887
888    let k_f = F::from_usize(k).unwrap_or(F::one());
889
890    if temperature == F::zero() {
891        // Hard top-k: indicator vector
892        let mut indexed: Vec<(usize, F)> = scores.iter().copied().enumerate().collect();
893        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
894        let mut result = vec![F::zero(); n];
895        for (idx, _) in indexed.iter().take(k) {
896            result[*idx] = F::one();
897        }
898        return Ok(result);
899    }
900
901    // Numerically stable softmax: subtract max before exponentiating
902    let max_score = scores
903        .iter()
904        .copied()
905        .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });
906
907    let exp_scores: Vec<F> = scores
908        .iter()
909        .map(|&s| {
910            let scaled = (s - max_score) / temperature;
911            // Clamp to avoid underflow
912            let clamped = if scaled < F::from_f64(-80.0).unwrap_or(-F::one()) {
913                F::from_f64(-80.0).unwrap_or(-F::one())
914            } else {
915                scaled
916            };
917            clamped.exp()
918        })
919        .collect();
920
921    let sum_exp: F = exp_scores.iter().fold(F::zero(), |a, b| a + *b);
922    if sum_exp == F::zero() {
923        // All scores are -inf relative to max → uniform
924        let uniform = k_f / F::from_usize(n).unwrap_or(F::one());
925        return Ok(vec![uniform; n]);
926    }
927
928    let result: Vec<F> = exp_scores.iter().map(|&e| k_f * e / sum_exp).collect();
929
930    Ok(result)
931}
932
933// ─────────────────────────────────────────────────────────────────────────────
934// Tests
935// ─────────────────────────────────────────────────────────────────────────────
936
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-5;

    // ── SparsemapConfig defaults ─────────────────────────────────────────────

    #[test]
    fn test_sparsemap_config_defaults() {
        let cfg = SparsemapConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-12);
        assert!(matches!(cfg.structure_type, StructureType::Simplex));
    }

    // ── SparseMAP on simplex ─────────────────────────────────────────────────

    #[test]
    fn test_sparsemap_simplex_sums_to_one() {
        let scores = vec![1.0_f64, 2.0, 0.5, -0.3, 1.8];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let sum: f64 = res.solution.iter().sum();
        assert!((sum - 1.0).abs() < EPS, "sum = {}", sum);
    }

    #[test]
    fn test_sparsemap_simplex_sparse_support() {
        // Scores with clear winner should produce sparse solution
        let scores = vec![10.0_f64, 0.1, 0.1, 0.1, 0.1];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        // The highest score dominates; at least one coordinate must be zero,
        // otherwise the "sparse" claim is untested.
        let n_nonzero = res.solution.iter().filter(|&&v| v > 1e-9).count();
        assert!(
            n_nonzero < scores.len(),
            "non-zero count {} should be < n for a dominating score",
            n_nonzero
        );
        assert!(!res.support.is_empty());
    }

    #[test]
    fn test_sparsemap_simplex_nonneg() {
        let scores = vec![-1.0_f64, -0.5, 0.3, 2.0, -3.0];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        for &v in &res.solution {
            assert!(v >= -1e-10, "negative value {}", v);
        }
    }

    #[test]
    fn test_sparsemap_gradient_shape_matches_input() {
        let scores = vec![1.0_f64, 2.0, 0.5];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let upstream = vec![1.0_f64, 0.0, -1.0];
        let grad = sparsemap_gradient(&res, &upstream);
        assert_eq!(grad.len(), scores.len());
    }

    #[test]
    fn test_sparsemap_gradient_zeros_outside_support() {
        let scores = vec![5.0_f64, -5.0, -5.0];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let upstream = vec![1.0_f64, 1.0, 1.0];
        let grad = sparsemap_gradient(&res, &upstream);
        // Indices not in support should receive zero gradient
        for (i, &g) in grad.iter().enumerate() {
            if !res.support.contains(&i) {
                assert!(g.abs() < EPS, "index {} outside support has grad {}", i, g);
            }
        }
    }

    #[test]
    fn test_sparsemap_knapsack_feasibility() {
        let weights = vec![1.0_f64, 2.0, 3.0];
        let capacity = 3.0_f64;
        let cfg = SparsemapConfig {
            structure_type: StructureType::Knapsack {
                weights: weights.clone(),
                capacity,
            },
            max_iter: 500,
            ..SparsemapConfig::default()
        };
        let scores = vec![3.0_f64, 2.0, 1.0];
        let res = sparsemap(&scores, &cfg).unwrap();
        // All values in [0,1]
        for &v in &res.solution {
            assert!(v >= -EPS && v <= 1.0 + EPS, "value {} out of [0,1]", v);
        }
        // Weighted sum ≤ capacity
        let used: f64 = weights
            .iter()
            .zip(res.solution.iter())
            .map(|(&w, &v)| w * v)
            .sum();
        assert!(used <= capacity + EPS, "capacity exceeded: {}", used);
    }

    // ── PerturbedOptimizerConfig defaults ───────────────────────────────────

    #[test]
    fn test_perturbed_optimizer_config_defaults() {
        let cfg = PerturbedOptimizerConfig::default();
        assert_eq!(cfg.n_samples, 100);
        assert!((cfg.epsilon - 0.1).abs() < 1e-12);
        assert_eq!(cfg.seed, 42);
    }

    // ── PerturbedOptimizer forward ───────────────────────────────────────────

    #[test]
    fn test_perturbed_optimizer_output_sums_to_one() {
        let cfg = PerturbedOptimizerConfig {
            n_samples: 200,
            ..Default::default()
        };
        let opt = PerturbedOptimizer::new(cfg);
        let scores = vec![1.0_f64, 2.0, 0.5, 3.0];
        let probs = opt.forward(&scores).unwrap();
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 0.01, "sum = {}", sum);
    }

    #[test]
    fn test_perturbed_optimizer_n_samples_1_deterministic() {
        let cfg = PerturbedOptimizerConfig {
            n_samples: 1,
            seed: 7,
            ..Default::default()
        };
        let opt = PerturbedOptimizer::new(cfg.clone());
        let scores = vec![1.0_f64, 2.0, 0.5];
        let p1 = opt.forward(&scores).unwrap();
        let opt2 = PerturbedOptimizer::new(cfg);
        let p2 = opt2.forward(&scores).unwrap();
        for (a, b) in p1.iter().zip(p2.iter()) {
            assert_eq!(a, b, "results differ between identical seeds");
        }
    }

    // ── soft_sort ────────────────────────────────────────────────────────────

    #[test]
    fn test_soft_sort_nondecreasing() {
        let x = vec![3.0_f64, 1.0, 4.0, 1.5, 9.0, 2.6];
        let sorted = soft_sort(&x, 0.0_f64).unwrap();
        for w in sorted.windows(2) {
            assert!(w[0] <= w[1] + 1e-10, "not sorted: {} > {}", w[0], w[1]);
        }
    }

    #[test]
    fn test_soft_sort_nonzero_temp_nondecreasing() {
        let x = vec![5.0_f64, 1.0, 3.0, 2.0];
        let sorted = soft_sort(&x, 0.5_f64).unwrap();
        for w in sorted.windows(2) {
            assert!(
                w[0] <= w[1] + 1e-9,
                "soft_sort not sorted: {} > {}",
                w[0],
                w[1]
            );
        }
    }

    // ── soft_rank ────────────────────────────────────────────────────────────

    #[test]
    fn test_soft_rank_zero_temp_input_3_1_2() {
        // At zero temperature the mid-rank blend weight t is 0, so ranks are
        // hard: ranks of [3,1,2] should be [3,1,2] (largest element gets
        // rank 3).
        let x = vec![3.0_f64, 1.0, 2.0];
        let ranks = soft_rank(&x, 0.0_f64).unwrap();
        assert_eq!(ranks[0] as usize, 3, "rank of largest should be 3");
        assert_eq!(ranks[1] as usize, 1, "rank of smallest should be 1");
        assert_eq!(ranks[2] as usize, 2, "rank of middle should be 2");
    }

    // ── diff_topk ────────────────────────────────────────────────────────────

    #[test]
    fn test_diff_topk_sums_to_k() {
        let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
        let k = 3;
        let p = diff_topk(&scores, k, 0.5_f64).unwrap();
        let sum: f64 = p.iter().sum();
        assert!(
            (sum - k as f64).abs() < 1e-6,
            "sum = {} but expected k={}",
            sum,
            k
        );
    }

    #[test]
    fn test_diff_topk_zero_temp_hard_topk() {
        let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
        let k = 2;
        let p = diff_topk(&scores, k, 0.0_f64).unwrap();
        // Should select indices 1 (score 5.0) and 3 (score 4.0)
        let sum: f64 = p.iter().sum();
        assert!((sum - k as f64).abs() < 1e-9);
        assert!((p[1] - 1.0).abs() < 1e-9, "index 1 should be selected");
        assert!((p[3] - 1.0).abs() < 1e-9, "index 3 should be selected");
    }

    #[test]
    fn test_diff_topk_all_values_nonneg() {
        // diff_topk returns p_i in [0, k]; each value is non-negative and sum = k
        let scores = vec![0.1_f64, 2.3, -1.0, 5.0, 0.7];
        let k = 2usize;
        let p = diff_topk(&scores, k, 1.0_f64).unwrap();
        for &v in &p {
            assert!(v >= -1e-9, "value {} is negative", v);
            assert!(v <= k as f64 + 1e-9, "value {} exceeds k={}", v, k);
        }
        let sum: f64 = p.iter().sum();
        assert!(
            (sum - k as f64).abs() < 1e-6,
            "sum = {} expected k={}",
            sum,
            k
        );
    }
}