1use crate::error::{OptimizeError, OptimizeResult};
29
30use super::kkt_sensitivity::kkt_sensitivity;
31
/// Minimal xorshift64 PRNG (Marsaglia) used for reproducible perturbation
/// sampling without pulling in an external RNG dependency.
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`; a zero seed is remapped to 1 because
    /// the all-zero state is a fixed point of the xorshift recurrence.
    fn new(seed: u64) -> Self {
        let state = if seed == 0 { 1 } else { seed };
        Self { state }
    }

    /// Advances the state with the classic 13/7/17 shift triple and returns
    /// the new state as the next raw 64-bit sample.
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Uniform sample in [0, 1) built from the top 53 bits (the f64
    /// mantissa width), so every representable step is equally likely.
    fn uniform(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Standard-normal sample via the Box–Muller transform; `u1` is clamped
    /// away from zero so the logarithm stays finite.
    fn normal(&mut self) -> f64 {
        let u1 = self.uniform().max(1e-15);
        let u2 = self.uniform();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }

    /// Vector of `n` independent standard-normal samples.
    fn normal_vector(&mut self, n: usize) -> Vec<f64> {
        std::iter::repeat_with(|| self.normal()).take(n).collect()
    }
}
76
/// Configuration for a perturbed optimizer: Monte-Carlo sample count,
/// Gaussian noise scale, and RNG seed.
#[derive(Debug, Clone)]
pub struct PerturbedOptimizerConfig {
    /// Number of Monte-Carlo perturbation samples per forward pass.
    pub n_samples: usize,
    /// Standard deviation of the Gaussian noise added to the input.
    pub sigma: f64,
    /// Seed for the internal deterministic RNG.
    pub seed: u64,
}

impl Default for PerturbedOptimizerConfig {
    /// Defaults: 20 samples, unit noise scale, seed 42.
    fn default() -> Self {
        PerturbedOptimizerConfig {
            n_samples: 20,
            sigma: 1.0,
            seed: 42,
        }
    }
}
101
/// Smoothed wrapper around a black-box solver `F`: `forward` averages the
/// solver's output over Gaussian perturbations of the input, and the cached
/// samples/outputs/noise let `gradient` / `reinforce_gradient` build
/// score-function gradient estimates afterwards.
pub struct PerturbedOptimizer<F>
where
    F: Fn(&[f64]) -> Vec<f64>,
{
    // The black-box solver being smoothed.
    optimizer: F,
    // Sampling parameters: sample count, noise scale sigma, RNG seed.
    config: PerturbedOptimizerConfig,
    // Perturbed inputs theta + sigma * z_k from the last `forward`, one per sample.
    cached_samples: Option<Vec<Vec<f64>>>,
    // Solver outputs y_k for each perturbed input.
    cached_outputs: Option<Vec<Vec<f64>>>,
    // Gaussian noise vectors z_k used to build each perturbed input.
    cached_noise: Option<Vec<Vec<f64>>>,
}
126
127impl<F> PerturbedOptimizer<F>
128where
129 F: Fn(&[f64]) -> Vec<f64>,
130{
131 pub fn new(optimizer: F) -> Self {
133 Self {
134 optimizer,
135 config: PerturbedOptimizerConfig::default(),
136 cached_samples: None,
137 cached_outputs: None,
138 cached_noise: None,
139 }
140 }
141
142 pub fn with_config(optimizer: F, config: PerturbedOptimizerConfig) -> Self {
144 Self {
145 optimizer,
146 config,
147 cached_samples: None,
148 cached_outputs: None,
149 cached_noise: None,
150 }
151 }
152
153 pub fn forward(&mut self, theta: &[f64]) -> OptimizeResult<Vec<f64>> {
164 let d = theta.len();
165 let mut rng = Xorshift64::new(self.config.seed);
166
167 let mut outputs: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
168 let mut noises: Vec<Vec<f64>> = Vec::with_capacity(self.config.n_samples);
169
170 for _ in 0..self.config.n_samples {
171 let z = rng.normal_vector(d);
172 let theta_perturbed: Vec<f64> = theta
173 .iter()
174 .zip(z.iter())
175 .map(|(&ti, &zi)| ti + self.config.sigma * zi)
176 .collect();
177 let y = (self.optimizer)(&theta_perturbed);
178 outputs.push(y);
179 noises.push(z);
180 }
181
182 if outputs.is_empty() {
184 return Err(OptimizeError::ComputationError(
185 "No samples generated in PerturbedOptimizer::forward".to_string(),
186 ));
187 }
188 let out_len = outputs[0].len();
189 let mut mean_y = vec![0.0_f64; out_len];
190 for output in &outputs {
191 if output.len() != out_len {
192 return Err(OptimizeError::ComputationError(
193 "Inconsistent optimizer output lengths".to_string(),
194 ));
195 }
196 for (i, &oi) in output.iter().enumerate() {
197 mean_y[i] += oi;
198 }
199 }
200 let n = self.config.n_samples as f64;
201 for mi in &mut mean_y {
202 *mi /= n;
203 }
204
205 self.cached_samples = Some(
207 (0..self.config.n_samples)
208 .map(|k| {
209 theta
210 .iter()
211 .zip(noises[k].iter())
212 .map(|(&ti, &zi)| ti + self.config.sigma * zi)
213 .collect()
214 })
215 .collect(),
216 );
217 self.cached_outputs = Some(outputs);
218 self.cached_noise = Some(noises);
219
220 Ok(mean_y)
221 }
222
223 pub fn gradient(&self, theta: &[f64], dl_dy: &[f64]) -> OptimizeResult<Vec<f64>> {
238 let outputs = self.cached_outputs.as_ref().ok_or_else(|| {
239 OptimizeError::ComputationError(
240 "PerturbedOptimizer::gradient called before forward".to_string(),
241 )
242 })?;
243 let noises = self
244 .cached_noise
245 .as_ref()
246 .ok_or_else(|| OptimizeError::ComputationError("No cached noise".to_string()))?;
247
248 let d = theta.len();
249 let out_len = dl_dy.len();
250 let n_samples = outputs.len();
251
252 if n_samples == 0 {
253 return Err(OptimizeError::ComputationError(
254 "Empty sample cache".to_string(),
255 ));
256 }
257
258 let mut mean_y = vec![0.0_f64; out_len];
260 for output in outputs.iter() {
261 for (i, &oi) in output.iter().enumerate().take(out_len) {
262 mean_y[i] += oi;
263 }
264 }
265 for mi in &mut mean_y {
266 *mi /= n_samples as f64;
267 }
268
269 let sigma = self.config.sigma;
272 let mut grad = vec![0.0_f64; d];
273
274 for k in 0..n_samples {
275 let coeff: f64 = outputs[k]
277 .iter()
278 .zip(mean_y.iter())
279 .zip(dl_dy.iter())
280 .map(|((&yk, &ybar), &dly)| (yk - ybar) * dly)
281 .sum();
282
283 for j in 0..d {
285 let z_kj = if j < noises[k].len() {
286 noises[k][j]
287 } else {
288 0.0
289 };
290 grad[j] += coeff * z_kj;
291 }
292 }
293
294 let scale = 1.0 / (sigma * n_samples as f64);
295 for gi in &mut grad {
296 *gi *= scale;
297 }
298
299 Ok(grad)
300 }
301
302 pub fn reinforce_gradient(&self, theta: &[f64], dl_dy: &[f64]) -> OptimizeResult<Vec<f64>> {
312 let outputs = self.cached_outputs.as_ref().ok_or_else(|| {
313 OptimizeError::ComputationError(
314 "PerturbedOptimizer::reinforce_gradient called before forward".to_string(),
315 )
316 })?;
317 let noises = self
318 .cached_noise
319 .as_ref()
320 .ok_or_else(|| OptimizeError::ComputationError("No cached noise".to_string()))?;
321
322 let d = theta.len();
323 let n_samples = outputs.len();
324 let sigma = self.config.sigma;
325
326 let mut grad = vec![0.0_f64; d];
327 for k in 0..n_samples {
328 let l_k: f64 = outputs[k]
330 .iter()
331 .zip(dl_dy.iter())
332 .map(|(&yk, &dly)| yk * dly)
333 .sum();
334
335 for j in 0..d {
336 let z_kj = if j < noises[k].len() {
337 noises[k][j]
338 } else {
339 0.0
340 };
341 grad[j] += l_k * z_kj;
342 }
343 }
344
345 let scale = 1.0 / (sigma * n_samples as f64);
346 for gi in &mut grad {
347 *gi *= scale;
348 }
349
350 Ok(grad)
351 }
352
353 pub fn last_mean_output(&self) -> Option<Vec<f64>> {
355 let outputs = self.cached_outputs.as_ref()?;
356 if outputs.is_empty() {
357 return None;
358 }
359 let out_len = outputs[0].len();
360 let mut mean = vec![0.0_f64; out_len];
361 for output in outputs {
362 for (i, &oi) in output.iter().enumerate().take(out_len) {
363 mean[i] += oi;
364 }
365 }
366 let n = outputs.len() as f64;
367 for mi in &mut mean {
368 *mi /= n;
369 }
370 Some(mean)
371 }
372}
373
/// Configuration for the SparseMap dual-ascent solver.
#[derive(Debug, Clone)]
pub struct SparseMapConfig {
    /// Maximum number of dual-ascent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the norm of the dual update.
    pub tol: f64,
    /// Dual-ascent step size.
    pub step_size: f64,
}

impl Default for SparseMapConfig {
    /// Defaults: 1000 iterations, 1e-8 tolerance, step size 0.1.
    fn default() -> Self {
        SparseMapConfig {
            max_iter: 1000,
            tol: 1e-8,
            step_size: 0.1,
        }
    }
}
398
/// Differentiable projection onto the polytope { mu >= 0, A mu = b }
/// (SparseMAP-style structured inference). `forward` solves the projection
/// by dual ascent; `backward` differentiates through the active set via a
/// KKT sensitivity system.
#[derive(Debug, Clone)]
pub struct SparseMap {
    // Solver settings: iteration cap, tolerance, step size.
    config: SparseMapConfig,
    // Constraint matrix A, one inner Vec per equality constraint row.
    a_marginal: Vec<Vec<f64>>,
    // Constraint right-hand side b (one entry per row of A).
    b_marginal: Vec<f64>,
    // Primal solution mu from the last `forward` call.
    last_mu: Option<Vec<f64>>,
    // Dual variables nu from the last `forward` call.
    last_nu: Option<Vec<f64>>,
    // Input theta from the last `forward` call.
    last_theta: Option<Vec<f64>>,
}
423
424impl SparseMap {
425 pub fn new(a_marginal: Vec<Vec<f64>>, b_marginal: Vec<f64>) -> Self {
431 Self {
432 config: SparseMapConfig::default(),
433 a_marginal,
434 b_marginal,
435 last_mu: None,
436 last_nu: None,
437 last_theta: None,
438 }
439 }
440
441 pub fn simplex(n: usize) -> Self {
443 let a = vec![vec![1.0_f64; n]];
444 let b = vec![1.0_f64];
445 Self::new(a, b)
446 }
447
448 pub fn with_config(
450 a_marginal: Vec<Vec<f64>>,
451 b_marginal: Vec<f64>,
452 config: SparseMapConfig,
453 ) -> Self {
454 Self {
455 config,
456 a_marginal,
457 b_marginal,
458 last_mu: None,
459 last_nu: None,
460 last_theta: None,
461 }
462 }
463
464 pub fn forward(&mut self, theta: &[f64]) -> OptimizeResult<Vec<f64>> {
475 let n = theta.len();
476 let p = self.b_marginal.len();
477
478 if self.a_marginal.len() != p {
479 return Err(OptimizeError::InvalidInput(format!(
480 "A_marginal rows ({}) != b_marginal length ({})",
481 self.a_marginal.len(),
482 p
483 )));
484 }
485
486 let mut nu = vec![0.0_f64; p];
494 let step = self.config.step_size;
495
496 for _ in 0..self.config.max_iter {
497 let at_nu: Vec<f64> = (0..n)
499 .map(|j| {
500 (0..p)
501 .map(|i| {
502 let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len()
503 {
504 self.a_marginal[i][j]
505 } else {
506 0.0
507 };
508 nu[i] * a_ij
509 })
510 .sum::<f64>()
511 })
512 .collect();
513
514 let mu: Vec<f64> = (0..n).map(|j| (theta[j] - at_nu[j]).max(0.0)).collect();
515
516 let amu: Vec<f64> = (0..p)
518 .map(|i| {
519 (0..n)
520 .map(|j| {
521 let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len()
522 {
523 self.a_marginal[i][j]
524 } else {
525 0.0
526 };
527 a_ij * mu[j]
528 })
529 .sum::<f64>()
530 })
531 .collect();
532
533 let nu_new: Vec<f64> = (0..p)
534 .map(|i| nu[i] + step * (amu[i] - self.b_marginal[i]))
535 .collect();
536
537 let delta: f64 = nu_new
539 .iter()
540 .zip(nu.iter())
541 .map(|(a, b)| (a - b).powi(2))
542 .sum::<f64>()
543 .sqrt();
544
545 nu = nu_new;
546
547 if delta < self.config.tol {
548 break;
549 }
550 }
551
552 let at_nu: Vec<f64> = (0..n)
554 .map(|j| {
555 (0..p)
556 .map(|i| {
557 let a_ij = if i < self.a_marginal.len() && j < self.a_marginal[i].len() {
558 self.a_marginal[i][j]
559 } else {
560 0.0
561 };
562 nu[i] * a_ij
563 })
564 .sum::<f64>()
565 })
566 .collect();
567
568 let mu: Vec<f64> = (0..n).map(|j| (theta[j] - at_nu[j]).max(0.0)).collect();
569
570 self.last_mu = Some(mu.clone());
571 self.last_nu = Some(nu);
572 self.last_theta = Some(theta.to_vec());
573
574 Ok(mu)
575 }
576
577 pub fn backward(&self, dl_dmu: &[f64]) -> OptimizeResult<Vec<f64>> {
591 let mu = self.last_mu.as_ref().ok_or_else(|| {
592 OptimizeError::ComputationError("SparseMap::backward called before forward".to_string())
593 })?;
594 let nu = self
595 .last_nu
596 .as_ref()
597 .ok_or_else(|| OptimizeError::ComputationError("No cached nu".to_string()))?;
598 let theta = self
599 .last_theta
600 .as_ref()
601 .ok_or_else(|| OptimizeError::ComputationError("No cached theta".to_string()))?;
602
603 let n = mu.len();
604 let tol = 1e-8_f64;
605
606 let support: Vec<usize> = (0..n).filter(|&i| mu[i] > tol).collect();
608
609 if support.is_empty() {
610 return Ok(vec![0.0_f64; n]);
612 }
613
614 let s = support.len();
615 let p = nu.len();
616
617 let q_s: Vec<Vec<f64>> = (0..s)
619 .map(|i| {
620 let mut row = vec![0.0_f64; s];
621 row[i] = 1.0;
622 row
623 })
624 .collect();
625
626 let a_s: Vec<Vec<f64>> = (0..p)
627 .map(|i| {
628 support
629 .iter()
630 .map(|&j| {
631 if i < self.a_marginal.len() && j < self.a_marginal[i].len() {
632 self.a_marginal[i][j]
633 } else {
634 0.0
635 }
636 })
637 .collect()
638 })
639 .collect();
640
641 let x_s: Vec<f64> = support
643 .iter()
644 .map(|&j| if j < mu.len() { mu[j] } else { 0.0 })
645 .collect();
646
647 let dl_dx_s: Vec<f64> = support
648 .iter()
649 .map(|&j| if j < dl_dmu.len() { dl_dmu[j] } else { 0.0 })
650 .collect();
651
652 let kkt_grad = kkt_sensitivity(&q_s, &a_s, &x_s, nu, &dl_dx_s)?;
654
655 let mut dl_dtheta = vec![0.0_f64; n];
657 for (idx, &j) in support.iter().enumerate() {
658 if idx < kkt_grad.dx_adj.len() {
659 dl_dtheta[j] = kkt_grad.dx_adj[idx];
660 }
661 }
662
663 let _ = theta;
664 Ok(dl_dtheta)
665 }
666
667 pub fn project_simplex(v: &[f64]) -> Vec<f64> {
673 let n = v.len();
674 if n == 0 {
675 return vec![];
676 }
677
678 let mut u: Vec<f64> = v.to_vec();
679 u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
680
681 let mut cssv = 0.0_f64;
682 let mut rho = 0_usize;
683 for j in 0..n {
684 cssv += u[j];
685 let tau = (cssv - 1.0) / (j + 1) as f64;
686 if tau < u[j] {
687 rho = j;
688 }
689 }
690
691 let cssv_rho: f64 = u[..=rho].iter().sum();
692 let theta = (cssv_rho - 1.0) / (rho + 1) as f64;
693
694 v.iter().map(|&vi| (vi - theta).max(0.0)).collect()
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// One-hot indicator of the maximal coordinate of `theta` (ties resolve
    /// to the last maximum, per `Iterator::max_by`).
    fn argmax_binary(theta: &[f64]) -> Vec<f64> {
        if theta.is_empty() {
            return vec![];
        }
        let winner = theta
            .iter()
            .enumerate()
            .max_by(|(_, lhs), (_, rhs)| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(0);
        let mut onehot = vec![0.0_f64; theta.len()];
        onehot[winner] = 1.0;
        onehot
    }

    /// Normalized descending rank: the largest value maps to 1.0, the
    /// smallest to 1/n.
    fn soft_sort_optimizer(theta: &[f64]) -> Vec<f64> {
        let n = theta.len();
        if n == 0 {
            return vec![];
        }
        let mut order: Vec<usize> = (0..n).collect();
        order.sort_by(|&a, &b| {
            theta[b]
                .partial_cmp(&theta[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut ranks = vec![0.0_f64; n];
        for (place, &idx) in order.iter().enumerate() {
            ranks[idx] = (n - place) as f64 / n as f64;
        }
        ranks
    }

    #[test]
    fn test_perturbed_optimizer_config_default() {
        let defaults = PerturbedOptimizerConfig::default();
        assert_eq!(defaults.n_samples, 20);
        assert!((defaults.sigma - 1.0).abs() < 1e-15);
    }

    #[test]
    fn test_perturbed_optimizer_forward_shape() {
        let mut po = PerturbedOptimizer::new(argmax_binary);
        let y = po.forward(&[1.0, 2.0, 3.0_f64]).expect("Forward failed");
        assert_eq!(y.len(), 3, "Output length should match input");
        for yi in &y {
            assert!(*yi >= 0.0 && *yi <= 1.0, "y_i = {} should be in [0, 1]", yi);
        }
    }

    #[test]
    fn test_perturbed_optimizer_forward_distribution_sums_to_one() {
        let config = PerturbedOptimizerConfig {
            n_samples: 100,
            sigma: 0.1,
            seed: 123,
        };
        let mut po = PerturbedOptimizer::with_config(argmax_binary, config);
        let y = po.forward(&[1.0, 5.0, 2.0_f64]).expect("Forward failed");
        let sum: f64 = y.iter().sum();
        assert!(
            (sum - 1.0).abs() < 0.05,
            "Sum = {} (expected ~1.0 for binary argmax)",
            sum
        );
    }

    #[test]
    fn test_perturbed_optimizer_gradient_sign() {
        let config = PerturbedOptimizerConfig {
            n_samples: 1000,
            sigma: 1.0,
            seed: 42,
        };
        let mut po = PerturbedOptimizer::with_config(argmax_binary, config);
        let input = [2.0, 0.0, 0.0_f64];
        po.forward(&input).expect("Forward failed");
        let grad = po
            .gradient(&input, &[1.0, 0.0, 0.0])
            .expect("Gradient failed");
        assert_eq!(grad.len(), 3);
        assert!(
            grad[0] > -0.5,
            "grad[0] = {} should be roughly positive",
            grad[0]
        );
    }

    #[test]
    fn test_perturbed_optimizer_gradient_shape() {
        let mut po = PerturbedOptimizer::new(argmax_binary);
        let input = [1.0, 2.0, 3.0_f64];
        po.forward(&input).expect("Forward failed");
        let grad = po
            .gradient(&input, &[1.0, 0.0, 0.0])
            .expect("Gradient failed");
        assert_eq!(grad.len(), 3);
        assert!(grad.iter().all(|g| g.is_finite()), "grad not finite");
    }

    #[test]
    fn test_perturbed_optimizer_reinforce_shape() {
        let mut po = PerturbedOptimizer::new(soft_sort_optimizer);
        let input = [1.0, 3.0, 2.0_f64];
        po.forward(&input).expect("Forward failed");
        let grad = po
            .reinforce_gradient(&input, &[0.0, 1.0, 0.0])
            .expect("REINFORCE failed");
        assert_eq!(grad.len(), 3);
        assert!(
            grad.iter().all(|g| g.is_finite()),
            "REINFORCE grad not finite"
        );
    }

    #[test]
    fn test_perturbed_optimizer_no_forward_error() {
        let po = PerturbedOptimizer::new(argmax_binary);
        assert!(
            po.gradient(&[1.0, 2.0], &[1.0, 0.0]).is_err(),
            "Should error without forward pass"
        );
    }

    #[test]
    fn test_sparsemap_simplex_projection() {
        let mut sm = SparseMap::simplex(3);
        let mu = sm
            .forward(&[1.0, 2.0, 0.5_f64])
            .expect("SparseMap forward failed");
        for mi in &mu {
            assert!(*mi >= -1e-6, "μ < 0: {}", mi);
        }
        let sum: f64 = mu.iter().sum();
        assert!(
            (sum - 1.0).abs() < 0.1,
            "Σμ = {} (expected ~1.0 for simplex)",
            sum
        );
    }

    #[test]
    fn test_sparsemap_backward_shape() {
        let mut sm = SparseMap::simplex(4);
        sm.forward(&[1.0, 3.0, 2.0, 0.5_f64])
            .expect("SparseMap forward failed");
        let dl_dtheta = sm
            .backward(&[1.0, 0.0, 0.0, 0.0])
            .expect("SparseMap backward failed");
        assert_eq!(dl_dtheta.len(), 4, "Gradient length mismatch");
        for gi in &dl_dtheta {
            assert!(gi.is_finite(), "SparseMap gradient not finite");
        }
    }

    #[test]
    fn test_sparsemap_no_forward_error() {
        let sm = SparseMap::simplex(3);
        assert!(
            sm.backward(&[1.0, 0.0, 0.0]).is_err(),
            "Should error without forward pass"
        );
    }

    #[test]
    fn test_project_simplex_properties() {
        let projected = SparseMap::project_simplex(&[0.5, 1.5, -0.3, 2.0_f64]);
        let sum: f64 = projected.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-10,
            "Simplex sum = {} (expected 1.0)",
            sum
        );
        for pi in &projected {
            assert!(*pi >= -1e-12, "Negative simplex component: {}", pi);
        }
    }

    #[test]
    fn test_project_simplex_uniform_input() {
        let projected = SparseMap::project_simplex(&[0.5, 0.5_f64]);
        assert!((projected[0] - 0.5).abs() < 1e-10);
        assert!((projected[1] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_xorshift_reproducible() {
        let mut first = Xorshift64::new(42);
        let mut second = Xorshift64::new(42);
        for _ in 0..100 {
            assert_eq!(first.next_u64(), second.next_u64());
        }
    }

    #[test]
    fn test_xorshift_normal_finite() {
        let mut rng = Xorshift64::new(12345);
        for _ in 0..100 {
            let sample = rng.normal();
            assert!(sample.is_finite(), "Normal sample not finite: {}", sample);
        }
    }
}