use scirs2_core::num_traits::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{OptimizeError, OptimizeResult};

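/// Combinatorial structure that [`sparsemap`] projects onto.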
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum StructureType {
    /// Probability simplex: solutions are non-negative and sum to one.
    Simplex,
    /// Knapsack relaxation: box-constrained variables under a weighted capacity limit.
    Knapsack {
        /// Per-item weights (a missing weight is treated as 1.0).
        weights: Vec<f64>,
        /// Total capacity that the weighted selection may not exceed.
        capacity: f64,
    },
    /// Relaxed permutation (doubly stochastic) structure.
    Permutation {
        /// Side length of the underlying `dim x dim` assignment matrix.
        dim: usize,
    },
}

impl Default for StructureType {
    fn default() -> Self {
        StructureType::Simplex
    }
}

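/// Configuration for the [`sparsemap`] solver.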
#[derive(Debug, Clone)]
pub struct SparsemapConfig {
    /// Maximum number of iterations for the iterative solvers.
    pub max_iter: usize,
    /// Convergence tolerance on the change between iterations.
    pub tol: f64,
    /// Combinatorial structure to project onto.
    pub structure_type: StructureType,
    /// Step size for the projected-gradient updates.
    pub step_size: f64,
}

impl Default for SparsemapConfig {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-6,
            structure_type: StructureType::default(),
            step_size: 0.1,
        }
    }
}

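/// Result returned by [`sparsemap`]: the sparse solution together with its
/// support, dual variables, and iteration count.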
#[derive(Debug, Clone)]
pub struct SparsemapResult<F> {
    /// Sparse solution vector (relaxed structure marginals).
    pub solution: Vec<F>,
    /// Indices of entries with numerically non-zero mass in `solution`.
    pub support: Vec<usize>,
    /// Dual variables associated with the structure constraints.
    pub dual: Vec<F>,
    /// Number of iterations performed by the solver.
    pub n_iters: usize,
}

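/// Euclidean projection of `v` onto the probability simplex (non-negative
/// entries summing to one), using the standard sort-and-threshold procedure.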
fn project_simplex<F>(v: &[F]) -> Vec<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    // Sort in descending order to find the threshold index rho.
    let mut u: Vec<F> = v.to_vec();
    u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

    let mut cssv = F::zero();
    let mut rho = 0usize;
    for (j, &uj) in u.iter().enumerate() {
        cssv = cssv + uj;
        let j_f = F::from_usize(j + 1).unwrap_or(F::one());
        let one = F::one();
        if uj - (cssv - one) / j_f > F::zero() {
            rho = j;
        }
    }

    // Threshold theta chosen so that sum(max(v - theta, 0)) = 1.
    let rho_f = F::from_usize(rho + 1).unwrap_or(F::one());
    let one = F::one();
    let mut cssv2 = F::zero();
    for uj in u.iter().take(rho + 1) {
        cssv2 = cssv2 + *uj;
    }
    let theta = (cssv2 - one) / rho_f;

    v.iter()
        .map(|&vi| {
            let diff = vi - theta;
            if diff > F::zero() {
                diff
            } else {
                F::zero()
            }
        })
        .collect()
}

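/// Approximate projection onto the box `[0, 1]^n` intersected with the
/// weighted capacity constraint `sum_i w_i * mu_i <= capacity`. If clipping to
/// the box already satisfies the capacity, that clip is returned; otherwise a
/// multiplier is found by bisection and entries are shrunk accordingly.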
fn project_knapsack<F>(v: &[F], weights: &[f64], capacity: f64) -> Vec<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n = v.len();
    // Clip to the box [0, 1]; if the capacity constraint already holds,
    // no further adjustment is needed.
    let mut mu: Vec<F> = v
        .iter()
        .map(|&vi| {
            if vi < F::zero() {
                F::zero()
            } else if vi > F::one() {
                F::one()
            } else {
                vi
            }
        })
        .collect();

    let total_weight: f64 = (0..n)
        .map(|i| weights.get(i).copied().unwrap_or(1.0) * mu[i].to_f64().unwrap_or(0.0))
        .sum();

    if total_weight <= capacity + 1e-12 {
        return mu;
    }

    // Bisection on a multiplier lambda: entries are shrunk as
    // v_i / (1 + lambda * w_i) and clipped to [0, 1] until the weighted
    // total fits within the capacity.
    let mut lo = 0.0_f64;
    let mut hi = 1e8_f64;

    for _ in 0..200 {
        let mid = (lo + hi) / 2.0;
        let w_total: f64 = (0..n)
            .map(|i| {
                let wi = weights.get(i).copied().unwrap_or(1.0);
                let vi = v[i].to_f64().unwrap_or(0.0);
                let mu_i = (vi / (1.0 + mid * wi)).clamp(0.0, 1.0);
                wi * mu_i
            })
            .sum();
        if w_total > capacity {
            lo = mid;
        } else {
            hi = mid;
        }
    }

    let lambda = (lo + hi) / 2.0;
    mu = (0..n)
        .map(|i| {
            let wi = weights.get(i).copied().unwrap_or(1.0);
            let vi = v[i].to_f64().unwrap_or(0.0);
            let val = (vi / (1.0 + lambda * wi)).clamp(0.0, 1.0);
            F::from_f64(val).unwrap_or(F::zero())
        })
        .collect();
    mu
}

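/// SparseMAP-style projection of a score vector onto the structure selected by
/// `config.structure_type`:
///
/// * `Simplex`: closed-form Euclidean projection onto the probability simplex.
/// * `Knapsack`: projected gradient descent under a weighted capacity limit.
/// * `Permutation`: gradient steps combined with Sinkhorn normalisation,
///   yielding an approximately doubly stochastic `dim x dim` matrix.
///
/// Returns an error if `scores` is empty or, for `Permutation`, if
/// `scores.len() != dim * dim`.
///
/// # Example
///
/// Sketch only (marked `ignore` because the import path depends on where this
/// module is mounted in the crate):
///
/// ```ignore
/// let scores = vec![1.0_f64, 2.0, 0.5];
/// let res = sparsemap(&scores, &SparsemapConfig::default()).unwrap();
/// assert!((res.solution.iter().sum::<f64>() - 1.0).abs() < 1e-6);
/// ```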
pub fn sparsemap<F>(scores: &[F], config: &SparsemapConfig) -> OptimizeResult<SparsemapResult<F>>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    if scores.is_empty() {
        return Err(OptimizeError::InvalidInput(
            "scores vector must be non-empty".into(),
        ));
    }

    let n = scores.len();
    let tol_f = F::from_f64(config.tol).unwrap_or(F::epsilon());

    let solution: Vec<F>;
    let n_iters: usize;
    let dual: Vec<F>;

    match &config.structure_type {
        StructureType::Simplex => {
            // Closed-form Euclidean projection onto the simplex.
            solution = project_simplex(scores);
            n_iters = 1;
            // Recover a dual-style multiplier from the active (positive) coordinates.
            let support_sum: F = solution.iter().fold(F::zero(), |acc, &x| {
                if x > F::zero() {
                    acc + x
                } else {
                    acc
                }
            });
            let support_count = solution.iter().filter(|&&x| x > F::zero()).count();
            let count_f = F::from_usize(support_count).unwrap_or(F::one());
            let lambda = if count_f > F::zero() {
                (support_sum - F::one()) / count_f
            } else {
                F::zero()
            };
            dual = vec![lambda];
        }

        StructureType::Knapsack { weights, capacity } => {
            // Projected gradient descent on 0.5 * ||mu - scores||^2 over the
            // knapsack-constrained box.
            let mut mu: Vec<F> = vec![F::zero(); n];
            let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());

            let mut iter = 0usize;
            let mut prev_obj = F::neg_infinity();

            loop {
                let grad: Vec<F> = mu.iter().zip(scores.iter()).map(|(&m, &s)| m - s).collect();

                let mu_new: Vec<F> = mu
                    .iter()
                    .zip(grad.iter())
                    .map(|(&m, &g)| m - step * g)
                    .collect();

                let mu_proj = project_knapsack(&mu_new, weights, *capacity);

                let obj = mu_proj
                    .iter()
                    .zip(scores.iter())
                    .fold(F::zero(), |acc, (&m, &s)| {
                        let diff = m - s;
                        acc + diff * diff
                    });
                let half = F::from_f64(0.5).unwrap_or(F::one());
                let obj = obj * half;

                let diff = (obj - prev_obj).abs();
                mu = mu_proj;
                prev_obj = obj;
                iter += 1;

                if iter >= config.max_iter || diff < tol_f {
                    break;
                }
            }

            solution = mu;
            n_iters = iter;
            // Approximate dual: -1 when the capacity constraint is numerically tight.
            let total_w: f64 = (0..n)
                .map(|i| {
                    weights.get(i).copied().unwrap_or(1.0) * solution[i].to_f64().unwrap_or(0.0)
                })
                .sum();
            let slack = *capacity - total_w;
            let lambda_val = if slack.abs() < 1e-8 { -1.0 } else { 0.0 };
            dual = vec![F::from_f64(lambda_val).unwrap_or(F::zero())];
        }

        StructureType::Permutation { dim } => {
            let d = *dim;
            if scores.len() != d * d {
                return Err(OptimizeError::InvalidInput(format!(
                    "Permutation structure requires d²={} scores but got {}",
                    d * d,
                    scores.len()
                )));
            }

            // Start from the uniform doubly stochastic matrix and alternate a
            // gradient step towards the scores with Sinkhorn normalisation.
            let inv_d = F::from_f64(1.0 / d as f64).unwrap_or(F::one());
            let mut mu: Vec<F> = vec![inv_d; d * d];
            let step = F::from_f64(config.step_size).unwrap_or(F::epsilon());
            let mut iter = 0usize;

            loop {
                let mu_step: Vec<F> = mu
                    .iter()
                    .zip(scores.iter())
                    .map(|(&m, &s)| m - step * (m - s))
                    .collect();

                // Sinkhorn normalisation: alternately rescale rows and columns.
                let mut m_sink = mu_step;
                for _ in 0..50 {
                    for row in 0..d {
                        let row_sum: F = (0..d)
                            .map(|col| m_sink[row * d + col])
                            .fold(F::zero(), |a, b| a + b);
                        if row_sum > F::zero() {
                            for col in 0..d {
                                m_sink[row * d + col] = m_sink[row * d + col] / row_sum;
                            }
                        }
                    }
                    for col in 0..d {
                        let col_sum: F = (0..d)
                            .map(|row| m_sink[row * d + col])
                            .fold(F::zero(), |a, b| a + b);
                        if col_sum > F::zero() {
                            for row in 0..d {
                                m_sink[row * d + col] = m_sink[row * d + col] / col_sum;
                            }
                        }
                    }
                }

                let change: F = mu
                    .iter()
                    .zip(m_sink.iter())
                    .map(|(&a, &b)| {
                        let diff = a - b;
                        diff * diff
                    })
                    .fold(F::zero(), |a, b| a + b);

                mu = m_sink;
                iter += 1;

                if iter >= config.max_iter || change < tol_f * tol_f {
                    break;
                }
            }

            solution = mu;
            n_iters = iter;
            // Placeholder duals for the 2 * d row/column constraints.
            dual = vec![F::zero(); 2 * d];
        }
    }

    // Support: indices whose mass exceeds a small numerical threshold.
    let support: Vec<usize> = solution
        .iter()
        .enumerate()
        .filter_map(|(i, &v)| {
            if v > F::from_f64(1e-9).unwrap_or(F::zero()) {
                Some(i)
            } else {
                None
            }
        })
        .collect();

    Ok(SparsemapResult {
        solution,
        support,
        dual,
        n_iters,
    })
}

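/// Backward pass for [`sparsemap`]: the upstream gradient is restricted to the
/// support and centred so that its components on the support sum to zero;
/// coordinates outside the support receive zero gradient.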
pub fn sparsemap_gradient<F>(result: &SparsemapResult<F>, upstream_grad: &[F]) -> Vec<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n = result.solution.len();
    if upstream_grad.len() != n {
        return vec![F::zero(); n];
    }

    let s = &result.support;
    if s.is_empty() {
        return vec![F::zero(); n];
    }

    // Centre the upstream gradient over the support so it sums to zero there.
    let s_size = F::from_usize(s.len()).unwrap_or(F::one());
    let mean_s: F = s
        .iter()
        .map(|&i| upstream_grad[i])
        .fold(F::zero(), |a, b| a + b)
        / s_size;

    let mut grad = vec![F::zero(); n];
    for &i in s {
        grad[i] = upstream_grad[i] - mean_s;
    }
    grad
}

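/// Configuration for [`PerturbedOptimizer`].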
#[derive(Debug, Clone)]
pub struct PerturbedOptimizerConfig {
    /// Number of Monte-Carlo perturbation samples.
    pub n_samples: usize,
    /// Scale of the Gaussian perturbation noise.
    pub epsilon: f64,
    /// Seed for the internal pseudo-random number generator.
    pub seed: u64,
}

impl Default for PerturbedOptimizerConfig {
    fn default() -> Self {
        Self {
            n_samples: 100,
            epsilon: 0.1,
            seed: 42,
        }
    }
}

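/// Differentiable argmax via Gaussian perturbations: `forward` estimates the
/// expected one-hot argmax of the perturbed scores by Monte-Carlo sampling,
/// and `backward` estimates the corresponding gradient with respect to the
/// scores.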
#[derive(Debug, Clone)]
pub struct PerturbedOptimizer {
    config: PerturbedOptimizerConfig,
}

impl PerturbedOptimizer {
    /// Create a new optimizer from the given configuration.
    pub fn new(config: PerturbedOptimizerConfig) -> Self {
        Self { config }
    }

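    /// Estimate, by Monte-Carlo sampling, the probability that each index is
    /// the argmax of `scores + epsilon * z` for standard-normal noise `z`.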
    pub fn forward<F>(&self, scores: &[F]) -> OptimizeResult<Vec<F>>
    where
        F: Float + FromPrimitive + Debug + Clone,
    {
        if scores.is_empty() {
            return Err(OptimizeError::InvalidInput(
                "scores must be non-empty".into(),
            ));
        }
        let n = scores.len();
        let mut counts = vec![0usize; n];
        let eps = self.config.epsilon;

        let mut rng_state = self.config.seed;
        let n_samples = self.config.n_samples;

        for _ in 0..n_samples {
            // Perturb each score with Gaussian noise and record the argmax.
            let mut best_idx = 0usize;
            let mut best_val = F::neg_infinity();

            for i in 0..n {
                let z = sample_standard_normal(&mut rng_state);
                let perturbed = scores[i] + F::from_f64(eps * z).unwrap_or(F::zero());
                if perturbed > best_val {
                    best_val = perturbed;
                    best_idx = i;
                }
            }
            counts[best_idx] += 1;
        }

        let n_samples_f = F::from_usize(n_samples).unwrap_or(F::one());
        let probs: Vec<F> = counts
            .iter()
            .map(|&c| F::from_usize(c).unwrap_or(F::zero()) / n_samples_f)
            .collect();

        Ok(probs)
    }

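    /// Monte-Carlo estimate of a gradient of the perturbed argmax with respect
    /// to `scores`, combining the upstream gradient at the sampled argmax with
    /// the noise that produced it (same seeded noise stream as `forward`).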
    pub fn backward<F>(&self, scores: &[F], upstream: &[F]) -> OptimizeResult<Vec<F>>
    where
        F: Float + FromPrimitive + Debug + Clone,
    {
        if scores.len() != upstream.len() {
            return Err(OptimizeError::InvalidInput(
                "scores and upstream must have the same length".into(),
            ));
        }
        let n = scores.len();
        let eps = self.config.epsilon;
        let eps_sq = eps * eps;
        let n_samples = self.config.n_samples;

        let mut grad = vec![F::zero(); n];
        let mut rng_state = self.config.seed;

        for _ in 0..n_samples {
            let noise: Vec<f64> = (0..n)
                .map(|_| sample_standard_normal(&mut rng_state))
                .collect();

            let mut best_idx = 0usize;
            let mut best_val = F::neg_infinity();
            for i in 0..n {
                let perturbed = scores[i] + F::from_f64(eps * noise[i]).unwrap_or(F::zero());
                if perturbed > best_val {
                    best_val = perturbed;
                    best_idx = i;
                }
            }

            // One-hot argmax makes the upstream dot product a single coordinate.
            let dot = upstream[best_idx];

            for i in 0..n {
                let zi = F::from_f64(noise[i]).unwrap_or(F::zero());
                let eps_sq_f = F::from_f64(eps_sq).unwrap_or(F::one());
                grad[i] = grad[i] + dot * zi / eps_sq_f;
            }
        }

        let n_f = F::from_usize(n_samples).unwrap_or(F::one());
        for g in &mut grad {
            *g = *g / n_f;
        }

        Ok(grad)
    }
}

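/// One step of the SplitMix64 pseudo-random number generator (deterministic,
/// seedable).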
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9e3779b97f4a7c15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xbf58476d1ce4e5b9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94d049bb133111eb);
    z ^ (z >> 31)
}

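/// Draw a standard-normal sample via the Box-Muller transform, driven by
/// [`splitmix64`].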
fn sample_standard_normal(state: &mut u64) -> f64 {
    let u1_raw = splitmix64(state);
    let u2_raw = splitmix64(state);
    // Map to (0, 1) and apply the Box-Muller transform.
    let u1 = (u1_raw as f64 + 0.5) / (u64::MAX as f64 + 1.0);
    let u2 = (u2_raw as f64 + 0.5) / (u64::MAX as f64 + 1.0);
    let two_pi = 2.0 * std::f64::consts::PI;
    (-2.0 * u1.ln()).sqrt() * (two_pi * u2).cos()
}

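/// Differentiable (temperature-smoothed) sorting: at `temperature == 0` this
/// is a hard sort; for positive temperatures the sorted values are blended
/// with their mean (temperature clamped to at most 1) and an isotonic
/// regression pass keeps the output non-decreasing.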
pub fn soft_sort<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    if x.is_empty() {
        return Err(OptimizeError::InvalidInput(
            "input vector must be non-empty".into(),
        ));
    }

    let n = x.len();
    let mut sorted_x: Vec<F> = x.to_vec();
    sorted_x.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    // Zero temperature reduces to a hard sort.
    if temperature == F::zero() {
        return Ok(sorted_x);
    }

    let mean_val =
        sorted_x.iter().fold(F::zero(), |a, b| a + *b) / F::from_usize(n).unwrap_or(F::one());

    // Blend each sorted value with the mean; higher temperature means more smoothing.
    let t_clamped = if temperature > F::one() {
        F::one()
    } else {
        temperature
    };
    let one_minus_t = F::one() - t_clamped;

    let mixed: Vec<F> = sorted_x
        .iter()
        .map(|&v| one_minus_t * v + t_clamped * mean_val)
        .collect();

    // Isotonic regression pass to guarantee a non-decreasing output.
    let result = pool_adjacent_violators(&mixed);

    Ok(result)
}

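/// Pool-adjacent-violators algorithm: returns the non-decreasing sequence in
/// which violating adjacent blocks are replaced by their means (isotonic
/// regression of `s`).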
fn pool_adjacent_violators<F>(s: &[F]) -> Vec<F>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n = s.len();
    // Each block stores (sum of its elements, element count).
    let mut blocks: Vec<(F, usize)> = s.iter().map(|&v| (v, 1)).collect();

    let mut changed = true;
    while changed {
        changed = false;
        let mut i = 0usize;
        let mut new_blocks: Vec<(F, usize)> = Vec::with_capacity(blocks.len());
        while i < blocks.len() {
            let mut sum = blocks[i].0;
            let mut cnt = blocks[i].1;
            // Merge with following blocks while their means violate monotonicity.
            while i + 1 < blocks.len() {
                let next_mean =
                    blocks[i + 1].0 / F::from_usize(blocks[i + 1].1).unwrap_or(F::one());
                let cur_mean = sum / F::from_usize(cnt).unwrap_or(F::one());
                if cur_mean > next_mean {
                    sum = sum + blocks[i + 1].0;
                    cnt += blocks[i + 1].1;
                    i += 1;
                    changed = true;
                } else {
                    break;
                }
            }
            new_blocks.push((sum, cnt));
            i += 1;
        }
        blocks = new_blocks;
    }

    // Expand each block back to its elements, replaced by the block mean.
    let mut result = Vec::with_capacity(n);
    for (sum, cnt) in blocks {
        let mean = sum / F::from_usize(cnt).unwrap_or(F::one());
        for _ in 0..cnt {
            result.push(mean);
        }
    }
    result
}

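/// Differentiable ranking: at `temperature == 0` the exact 1-based ranks are
/// returned; for positive temperatures each rank is a sum of sigmoids of
/// pairwise differences, blended towards the mid rank `(n + 1) / 2` as the
/// temperature grows.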
pub fn soft_rank<F>(x: &[F], temperature: F) -> OptimizeResult<Vec<F>>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    if x.is_empty() {
        return Err(OptimizeError::InvalidInput(
            "input vector must be non-empty".into(),
        ));
    }
    let n = x.len();
    let one = F::one();
    let n_f = F::from_usize(n).unwrap_or(one);

    // Zero temperature: exact 1-based ranks.
    if temperature == F::zero() {
        let ranks: Vec<F> = (0..n)
            .map(|i| {
                let rank = x.iter().filter(|&&v| v < x[i]).count();
                F::from_usize(rank + 1).unwrap_or(one)
            })
            .collect();
        return Ok(ranks);
    }

    let two = F::from_f64(2.0).unwrap_or(one);

    let ranks: Vec<F> = (0..n)
        .map(|i| {
            // Soft rank: 1 plus a sum of sigmoids of pairwise differences.
            let mut soft_rank_i = one;
            for j in 0..n {
                if i == j {
                    continue;
                }
                let diff = (x[i] - x[j]) / temperature;
                // Clamp the argument to keep exp() well behaved.
                let diff_clamped = if diff < F::from_f64(-50.0).unwrap_or(-one) {
                    F::from_f64(-50.0).unwrap_or(-one)
                } else if diff > F::from_f64(50.0).unwrap_or(one) {
                    F::from_f64(50.0).unwrap_or(one)
                } else {
                    diff
                };
                let sigmoid_val = one / (one + (-diff_clamped).exp());
                soft_rank_i = soft_rank_i + sigmoid_val;
            }
            // Blend towards the mid rank as the temperature grows.
            let mid = (n_f + one) / two;
            let t = if temperature > F::from_f64(10.0).unwrap_or(one) {
                one
            } else {
                temperature / F::from_f64(10.0).unwrap_or(one)
            };
            (one - t) * soft_rank_i + t * mid
        })
        .collect();

    Ok(ranks)
}

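/// Differentiable top-k selection: at `temperature == 0` a hard 0/1 indicator
/// of the k largest scores is returned; otherwise a temperature-scaled softmax
/// over the scores is multiplied by `k`, so the output sums to `k`.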
pub fn diff_topk<F>(scores: &[F], k: usize, temperature: F) -> OptimizeResult<Vec<F>>
where
    F: Float + FromPrimitive + Debug + Clone,
{
    let n = scores.len();
    if n == 0 {
        return Err(OptimizeError::InvalidInput(
            "scores must be non-empty".into(),
        ));
    }
    if k == 0 || k > n {
        return Err(OptimizeError::InvalidInput(format!(
            "k must be in [1, {}] but got {}",
            n, k
        )));
    }

    let k_f = F::from_usize(k).unwrap_or(F::one());

    // Zero temperature: hard top-k indicator.
    if temperature == F::zero() {
        let mut indexed: Vec<(usize, F)> = scores.iter().copied().enumerate().collect();
        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
        let mut result = vec![F::zero(); n];
        for (idx, _) in indexed.iter().take(k) {
            result[*idx] = F::one();
        }
        return Ok(result);
    }

    // Temperature-scaled softmax, shifted by the maximum for numerical stability.
    let max_score = scores
        .iter()
        .copied()
        .fold(F::neg_infinity(), |a, b| if b > a { b } else { a });

    let exp_scores: Vec<F> = scores
        .iter()
        .map(|&s| {
            let scaled = (s - max_score) / temperature;
            let clamped = if scaled < F::from_f64(-80.0).unwrap_or(-F::one()) {
                F::from_f64(-80.0).unwrap_or(-F::one())
            } else {
                scaled
            };
            clamped.exp()
        })
        .collect();

    let sum_exp: F = exp_scores.iter().fold(F::zero(), |a, b| a + *b);
    if sum_exp == F::zero() {
        // Degenerate case: fall back to a uniform relaxation that sums to k.
        let uniform = k_f / F::from_usize(n).unwrap_or(F::one());
        return Ok(vec![uniform; n]);
    }

    // Scale the softmax by k so the relaxed selection sums to k.
    let result: Vec<F> = exp_scores.iter().map(|&e| k_f * e / sum_exp).collect();

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-5;

    #[test]
    fn test_sparsemap_config_defaults() {
        let cfg = SparsemapConfig::default();
        assert_eq!(cfg.max_iter, 1000);
        assert!((cfg.tol - 1e-6).abs() < 1e-12);
        assert!(matches!(cfg.structure_type, StructureType::Simplex));
    }

    #[test]
    fn test_sparsemap_simplex_sums_to_one() {
        let scores = vec![1.0_f64, 2.0, 0.5, -0.3, 1.8];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let sum: f64 = res.solution.iter().sum();
        assert!((sum - 1.0).abs() < EPS, "sum = {}", sum);
    }

    #[test]
    fn test_sparsemap_simplex_sparse_support() {
        let scores = vec![10.0_f64, 0.1, 0.1, 0.1, 0.1];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let n_nonzero = res.solution.iter().filter(|&&v| v > 1e-9).count();
        assert!(
            n_nonzero < scores.len(),
            "solution should be sparse, but {} of {} entries are non-zero",
            n_nonzero,
            scores.len()
        );
        assert!(!res.support.is_empty());
    }

    #[test]
    fn test_sparsemap_simplex_nonneg() {
        let scores = vec![-1.0_f64, -0.5, 0.3, 2.0, -3.0];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        for &v in &res.solution {
            assert!(v >= -1e-10, "negative value {}", v);
        }
    }

    #[test]
    fn test_sparsemap_gradient_shape_matches_input() {
        let scores = vec![1.0_f64, 2.0, 0.5];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let upstream = vec![1.0_f64, 0.0, -1.0];
        let grad = sparsemap_gradient(&res, &upstream);
        assert_eq!(grad.len(), scores.len());
    }

    #[test]
    fn test_sparsemap_gradient_zeros_outside_support() {
        let scores = vec![5.0_f64, -5.0, -5.0];
        let cfg = SparsemapConfig::default();
        let res = sparsemap(&scores, &cfg).unwrap();
        let upstream = vec![1.0_f64, 1.0, 1.0];
        let grad = sparsemap_gradient(&res, &upstream);
        for (i, &g) in grad.iter().enumerate() {
            if !res.support.contains(&i) {
                assert!(g.abs() < EPS, "index {} outside support has grad {}", i, g);
            }
        }
    }

    #[test]
    fn test_sparsemap_knapsack_feasibility() {
        let weights = vec![1.0_f64, 2.0, 3.0];
        let capacity = 3.0_f64;
        let cfg = SparsemapConfig {
            structure_type: StructureType::Knapsack {
                weights: weights.clone(),
                capacity,
            },
            max_iter: 500,
            ..SparsemapConfig::default()
        };
        let scores = vec![3.0_f64, 2.0, 1.0];
        let res = sparsemap(&scores, &cfg).unwrap();
        for &v in &res.solution {
            assert!(v >= -EPS && v <= 1.0 + EPS, "value {} out of [0,1]", v);
        }
        let used: f64 = weights
            .iter()
            .zip(res.solution.iter())
            .map(|(&w, &v)| w * v)
            .sum();
        assert!(used <= capacity + EPS, "capacity exceeded: {}", used);
    }

    #[test]
    fn test_perturbed_optimizer_config_defaults() {
        let cfg = PerturbedOptimizerConfig::default();
        assert_eq!(cfg.n_samples, 100);
        assert!((cfg.epsilon - 0.1).abs() < 1e-12);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn test_perturbed_optimizer_output_sums_to_one() {
        let cfg = PerturbedOptimizerConfig {
            n_samples: 200,
            ..Default::default()
        };
        let opt = PerturbedOptimizer::new(cfg);
        let scores = vec![1.0_f64, 2.0, 0.5, 3.0];
        let probs = opt.forward(&scores).unwrap();
        let sum: f64 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 0.01, "sum = {}", sum);
    }

    #[test]
    fn test_perturbed_optimizer_n_samples_1_deterministic() {
        let cfg = PerturbedOptimizerConfig {
            n_samples: 1,
            seed: 7,
            ..Default::default()
        };
        let opt = PerturbedOptimizer::new(cfg.clone());
        let scores = vec![1.0_f64, 2.0, 0.5];
        let p1 = opt.forward(&scores).unwrap();
        let opt2 = PerturbedOptimizer::new(cfg);
        let p2 = opt2.forward(&scores).unwrap();
        for (a, b) in p1.iter().zip(p2.iter()) {
            assert_eq!(a, b, "results differ between identical seeds");
        }
    }

    #[test]
    fn test_soft_sort_nondecreasing() {
        let x = vec![3.0_f64, 1.0, 4.0, 1.5, 9.0, 2.6];
        let sorted = soft_sort(&x, 0.0_f64).unwrap();
        for w in sorted.windows(2) {
            assert!(w[0] <= w[1] + 1e-10, "not sorted: {} > {}", w[0], w[1]);
        }
    }

    #[test]
    fn test_soft_sort_nonzero_temp_nondecreasing() {
        let x = vec![5.0_f64, 1.0, 3.0, 2.0];
        let sorted = soft_sort(&x, 0.5_f64).unwrap();
        for w in sorted.windows(2) {
            assert!(
                w[0] <= w[1] + 1e-9,
                "soft_sort not sorted: {} > {}",
                w[0],
                w[1]
            );
        }
    }

    #[test]
    fn test_soft_rank_zero_temp_input_3_1_2() {
        let x = vec![3.0_f64, 1.0, 2.0];
        let ranks = soft_rank(&x, 0.0_f64).unwrap();
        assert_eq!(ranks[0] as usize, 3, "rank of largest should be 3");
        assert_eq!(ranks[1] as usize, 1, "rank of smallest should be 1");
        assert_eq!(ranks[2] as usize, 2, "rank of middle should be 2");
    }

    #[test]
    fn test_diff_topk_sums_to_k() {
        let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
        let k = 3;
        let p = diff_topk(&scores, k, 0.5_f64).unwrap();
        let sum: f64 = p.iter().sum();
        assert!(
            (sum - k as f64).abs() < 1e-6,
            "sum = {} but expected k={}",
            sum,
            k
        );
    }

    #[test]
    fn test_diff_topk_zero_temp_hard_topk() {
        let scores = vec![1.0_f64, 5.0, 2.0, 4.0, 3.0];
        let k = 2;
        let p = diff_topk(&scores, k, 0.0_f64).unwrap();
        let sum: f64 = p.iter().sum();
        assert!((sum - k as f64).abs() < 1e-9);
        assert!((p[1] - 1.0).abs() < 1e-9, "index 1 should be selected");
        assert!((p[3] - 1.0).abs() < 1e-9, "index 3 should be selected");
    }

    #[test]
    fn test_diff_topk_all_values_nonneg() {
        let scores = vec![0.1_f64, 2.3, -1.0, 5.0, 0.7];
        let k = 2usize;
        let p = diff_topk(&scores, k, 1.0_f64).unwrap();
        for &v in &p {
            assert!(v >= -1e-9, "value {} is negative", v);
            assert!(v <= k as f64 + 1e-9, "value {} exceeds k={}", v, k);
        }
        let sum: f64 = p.iter().sum();
        assert!(
            (sum - k as f64).abs() < 1e-6,
            "sum = {} expected k={}",
            sum,
            k
        );
    }
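
    // Added example (sketch): exercises the Permutation branch of `sparsemap`.
    // The assertions rely only on what the Sinkhorn loop does by construction:
    // entries stay non-negative and columns are normalised last.
    #[test]
    fn test_sparsemap_permutation_columns_normalized() {
        let d = 2usize;
        let scores = vec![1.0_f64, 2.0, 3.0, 4.0];
        let cfg = SparsemapConfig {
            structure_type: StructureType::Permutation { dim: d },
            ..SparsemapConfig::default()
        };
        let res = sparsemap(&scores, &cfg).unwrap();
        assert_eq!(res.solution.len(), d * d);
        for &v in &res.solution {
            assert!(v >= -1e-9 && v <= 1.0 + 1e-9, "entry {} out of [0, 1]", v);
        }
        for col in 0..d {
            let col_sum: f64 = (0..d).map(|row| res.solution[row * d + col]).sum();
            assert!((col_sum - 1.0).abs() < 1e-6, "column {} sums to {}", col, col_sum);
        }
    }

    // Added example (sketch): at a positive temperature, soft ranks should
    // still preserve the ordering of the inputs.
    #[test]
    fn test_soft_rank_positive_temp_preserves_order() {
        let x = vec![3.0_f64, 1.0, 2.0];
        let ranks = soft_rank(&x, 0.5_f64).unwrap();
        assert!(ranks[1] < ranks[2], "rank of 1.0 should be below rank of 2.0");
        assert!(ranks[2] < ranks[0], "rank of 2.0 should be below rank of 3.0");
    }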
}