//! Differentiable relaxations of discrete operations: a straight-through threshold
//! estimator, Gumbel-softmax sampling, and soft logical quantifiers (exists / forall).

use crate::error::TlBackendResult;
use crate::Scirs2Tensor;
use scirs2_core::ndarray::{Array, ArrayD, Axis};
use scirs2_core::random::arrays::OptimizedArrayRandom;
use scirs2_core::random::prelude::*;

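/// Configuration for the straight-through estimator (STE) threshold operation.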
#[derive(Debug, Clone, Copy)]
pub struct SteConfig {
    /// Values at or above this threshold are mapped to 1.0, all others to 0.0.
    pub threshold: f64,
    /// If true, the straight-through gradient is clamped to [-1.0, 1.0].
    pub clip_gradients: bool,
}

impl Default for SteConfig {
    fn default() -> Self {
        Self {
            threshold: 0.5,
            clip_gradients: false,
        }
    }
}

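/// Configuration for Gumbel-softmax sampling.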
#[derive(Debug, Clone, Copy)]
pub struct GumbelSoftmaxConfig {
    /// Softmax temperature; lower values produce sharper, more discrete samples.
    pub temperature: f64,
    /// If true, the forward pass returns one-hot samples (straight-through in the backward pass).
    pub hard: bool,
    /// Optional RNG seed for reproducible sampling.
    pub seed: Option<u64>,
}

impl Default for GumbelSoftmaxConfig {
    fn default() -> Self {
        Self {
            temperature: 1.0,
            hard: false,
            seed: None,
        }
    }
}

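/// Relaxation strategy used by the soft quantifiers `soft_exists` and `soft_forall`.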
#[derive(Debug, Clone, Copy)]
pub enum QuantifierMode {
    /// Exact max/min reduction; the gradient flows only to the arg-max/arg-min element.
    Hard,
    /// Temperature-scaled log-sum-exp relaxation of max/min.
    Smooth { temperature: f64 },
    /// Probabilistic semantics over truth values in [0, 1] (noisy-or for exists, product for forall).
    Probabilistic,
}

impl Default for QuantifierMode {
    fn default() -> Self {
        Self::Smooth { temperature: 1.0 }
    }
}

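/// Forward pass of the straight-through estimator: hard-thresholds `input` against
/// `config.threshold`, yielding a tensor of 0.0/1.0 values.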
pub fn ste_threshold(input: &Scirs2Tensor, config: SteConfig) -> TlBackendResult<Scirs2Tensor> {
    let output = input.mapv(|x| if x >= config.threshold { 1.0 } else { 0.0 });
    Ok(output)
}

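/// Backward pass of the straight-through estimator: the threshold is treated as the identity,
/// so the incoming gradient passes through unchanged (optionally clipped to [-1, 1]).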
pub fn ste_threshold_backward(
    grad_output: &Scirs2Tensor,
    _input: &Scirs2Tensor,
    config: SteConfig,
) -> TlBackendResult<Scirs2Tensor> {
    if config.clip_gradients {
        Ok(grad_output.mapv(|g| g.clamp(-1.0, 1.0)))
    } else {
        Ok(grad_output.clone())
    }
}

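/// Samples from a categorical distribution over the last axis of `logits` using the
/// Gumbel-softmax trick: Gumbel noise is added to the logits and a temperature-scaled softmax
/// is applied. With `config.hard`, the soft sample is discretized to a one-hot vector.
///
/// # Example
///
/// ```ignore
/// // Sketch, assuming this module's items and the `array!` macro are in scope.
/// let logits = array![[1.0, 5.0, 2.0]].into_dyn();
/// let config = GumbelSoftmaxConfig { temperature: 0.1, hard: true, seed: Some(123) };
/// let sample = gumbel_softmax(&logits, config).unwrap();
/// assert_eq!(sample.shape(), &[1, 3]); // one-hot over three classes
/// ```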
pub fn gumbel_softmax(
    logits: &Scirs2Tensor,
    config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
    let gumbel_noise = sample_gumbel(logits.shape(), config.seed)?;

    let noisy_logits = logits + &gumbel_noise;

    let soft_samples = softmax_temperature(&noisy_logits, config.temperature)?;

    if config.hard {
        let hard_samples = argmax_to_onehot(&soft_samples)?;
        Ok(hard_samples)
    } else {
        Ok(soft_samples)
    }
}

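/// Backward pass for `gumbel_softmax`: applies the softmax Jacobian-vector product
/// `s * (g - sum(s * g))` along the last axis and rescales by `1 / temperature`.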
pub fn gumbel_softmax_backward(
    grad_output: &Scirs2Tensor,
    soft_samples: &Scirs2Tensor,
    config: GumbelSoftmaxConfig,
) -> TlBackendResult<Scirs2Tensor> {
    let last_axis = soft_samples.ndim() - 1;
    let dot_product = (soft_samples * grad_output)
        .sum_axis(Axis(last_axis))
        .insert_axis(Axis(last_axis));

    let grad_logits = soft_samples * &(grad_output - &dot_product);

    Ok(grad_logits.mapv(|g| g / config.temperature))
}

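/// Differentiable existential quantifier: reduces `input` along `axis` (or over the whole
/// tensor when `axis` is `None`) using the relaxation of `max` selected by `mode`.
///
/// # Example
///
/// ```ignore
/// // Sketch, assuming this module's items and the `array!` macro are in scope.
/// // Probabilistic "exists" over the last axis: 1 - (1 - 0.5) * (1 - 0.5) = 0.75.
/// let truth = array![[0.5, 0.5]].into_dyn();
/// let out = soft_exists(&truth, Some(1), QuantifierMode::Probabilistic).unwrap();
/// assert!((out[0] - 0.75).abs() < 1e-6);
/// ```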
pub fn soft_exists(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(0.0_f64, |a, &b| a.max(b))
                }))
            } else {
                let max_val = input.iter().fold(0.0_f64, |a, &b| a.max(b));
                Ok(Array::from_elem(vec![], max_val))
            }
        }
        QuantifierMode::Smooth { temperature } => {
            smooth_max(input, axis, temperature)
        }
        QuantifierMode::Probabilistic => {
            probabilistic_exists(input, axis)
        }
    }
}

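/// Backward pass for `soft_exists`, dispatching to the gradient rule matching `mode`.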
pub fn soft_exists_backward(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    _output: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            argmax_gradient(grad_output, input, axis)
        }
        QuantifierMode::Smooth { temperature } => {
            smooth_max_gradient(grad_output, input, temperature, axis)
        }
        QuantifierMode::Probabilistic => {
            probabilistic_exists_gradient(grad_output, input, axis)
        }
    }
}

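/// Differentiable universal quantifier: reduces `input` along `axis` (or over the whole
/// tensor when `axis` is `None`) using the relaxation of `min` selected by `mode`.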
pub fn soft_forall(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            if let Some(ax) = axis {
                Ok(input.map_axis(Axis(ax), |slice| {
                    slice.iter().fold(1.0_f64, |a, &b| a.min(b))
                }))
            } else {
                let min_val = input.iter().fold(1.0_f64, |a, &b| a.min(b));
                Ok(Array::from_elem(vec![], min_val))
            }
        }
        QuantifierMode::Smooth { temperature } => {
            smooth_min(input, axis, temperature)
        }
        QuantifierMode::Probabilistic => {
            probabilistic_forall(input, axis)
        }
    }
}

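/// Backward pass for `soft_forall`, dispatching to the gradient rule matching `mode`.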
pub fn soft_forall_backward(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    output: &Scirs2Tensor,
    axis: Option<usize>,
    mode: QuantifierMode,
) -> TlBackendResult<Scirs2Tensor> {
    match mode {
        QuantifierMode::Hard => {
            argmin_gradient(grad_output, input, axis)
        }
        QuantifierMode::Smooth { temperature } => {
            smooth_min_gradient(grad_output, input, temperature, axis)
        }
        QuantifierMode::Probabilistic => {
            probabilistic_forall_gradient(grad_output, input, output, axis)
        }
    }
}

/// Samples standard Gumbel(0, 1) noise of the given shape, optionally seeded.
fn sample_gumbel(shape: &[usize], seed: Option<u64>) -> TlBackendResult<Scirs2Tensor> {
    use scirs2_core::ndarray::IxDyn;

    // Sample u ~ Uniform(0, 1), bounded away from 0 and 1 to keep the double log finite.
    let uniform_dist = Uniform::new(1e-10, 1.0 - 1e-10).unwrap();
    let dyn_shape = IxDyn(shape);

    let uniform = if let Some(s) = seed {
        let mut rng = seeded_rng(s);
        ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
    } else {
        let mut rng = thread_rng();
        ArrayD::random_bulk(dyn_shape, uniform_dist, &mut rng)
    };

    // Inverse-CDF transform: g = -ln(-ln(u)) is standard Gumbel(0, 1) noise.
    let gumbel = uniform.mapv(|u: f64| -(-u.ln()).ln());
    Ok(gumbel)
}

/// Temperature-scaled softmax over the last axis.
fn softmax_temperature(logits: &Scirs2Tensor, temperature: f64) -> TlBackendResult<Scirs2Tensor> {
    let scaled = logits.mapv(|x| x / temperature);

    let last_axis = scaled.ndim() - 1;

    // Subtract the per-row maximum for numerical stability before exponentiating.
    let max_vals = scaled
        .map_axis(Axis(last_axis), |slice| {
            slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        })
        .insert_axis(Axis(last_axis));

    let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
    let sum_exp = exp_vals
        .sum_axis(Axis(last_axis))
        .insert_axis(Axis(last_axis));

    Ok(&exp_vals / &sum_exp)
}

/// Converts soft samples to one-hot vectors by taking the arg-max along the last axis.
fn argmax_to_onehot(soft_samples: &Scirs2Tensor) -> TlBackendResult<Scirs2Tensor> {
    let last_axis = soft_samples.ndim() - 1;
    let mut onehot = ArrayD::zeros(soft_samples.raw_dim());

    let n_classes = soft_samples.len_of(Axis(last_axis));

    // Walk every position in the leading dimensions and set a 1.0 at the arg-max class.
    for i in 0..soft_samples.len() / n_classes {
        // Decode the flat index `i` into row-major indices over the leading dimensions.
        let mut flat_idx = i;
        let mut indices = vec![0; soft_samples.ndim()];

        for dim in (0..last_axis).rev() {
            let size = soft_samples.len_of(Axis(dim));
            indices[dim] = flat_idx % size;
            flat_idx /= size;
        }

        let mut max_val = f64::NEG_INFINITY;
        let mut max_idx = 0;

        for j in 0..n_classes {
            indices[last_axis] = j;
            let val = soft_samples[&indices[..]];
            if val > max_val {
                max_val = val;
                max_idx = j;
            }
        }

        indices[last_axis] = max_idx;
        onehot[&indices[..]] = 1.0;
    }

    Ok(onehot)
}

/// Smooth maximum: smooth_max(x) = T * log-sum-exp(x / T), a differentiable upper bound on max(x).
fn smooth_max(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
    let scaled = input.mapv(|x| x / temperature);

    if let Some(ax) = axis {
        let max_vals = scaled.map_axis(Axis(ax), |slice| {
            slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
        });

        let max_vals_broadcast = max_vals.clone().insert_axis(Axis(ax));
        let exp_vals = (&scaled - &max_vals_broadcast).mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(ax));
        let log_sum_exp = &max_vals + &sum_exp.mapv(|x| x.ln());

        Ok(log_sum_exp.mapv(|x| x * temperature))
    } else {
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals = scaled.mapv(|x| (x - max_val).exp());
        let sum_exp: f64 = exp_vals.iter().sum();
        let result = temperature * (max_val + sum_exp.ln());
        Ok(Array::from_elem(vec![], result))
    }
}

/// Gradient of `smooth_max`: d smooth_max / d x_i = softmax(x / T)_i, so the incoming
/// gradient is spread over the inputs by softmax weights.
fn smooth_max_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    temperature: f64,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let scaled = input.mapv(|x| x / temperature);

    if let Some(ax) = axis {
        let max_vals = scaled
            .map_axis(Axis(ax), |slice| {
                slice.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b))
            })
            .insert_axis(Axis(ax));

        let exp_vals = (&scaled - &max_vals).mapv(|x| x.exp());
        let sum_exp = exp_vals.sum_axis(Axis(ax)).insert_axis(Axis(ax));
        let weights = &exp_vals / &sum_exp;

        // The forward reduction removed `ax`, so re-insert it before broadcasting the output gradient.
        let grad_b = grad_output.clone().insert_axis(Axis(ax));
        Ok(&weights * &grad_b)
    } else {
        let max_val = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_vals = scaled.mapv(|x| (x - max_val).exp());
        let sum_exp: f64 = exp_vals.iter().sum();
        let weights = exp_vals.mapv(|x| x / sum_exp);

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(weights.mapv(|w| w * grad_scalar))
    }
}

/// Smooth minimum via the identity smooth_min(x) = -smooth_max(-x).
fn smooth_min(
    input: &Scirs2Tensor,
    axis: Option<usize>,
    temperature: f64,
) -> TlBackendResult<Scirs2Tensor> {
    let negated = input.mapv(|x| -x);
    let result = smooth_max(&negated, axis, temperature)?;
    Ok(result.mapv(|x| -x))
}

/// Gradient of `smooth_min`. Since smooth_min(x) = -smooth_max(-x), the two sign flips cancel
/// in the chain rule and the gradient is softmax(-x / T) with positive weights.
fn smooth_min_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    temperature: f64,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let negated = input.mapv(|x| -x);
    smooth_max_gradient(grad_output, &negated, temperature, axis)
}

/// Probabilistic (noisy-or) existential quantifier: exists(x) = 1 - prod_i (1 - x_i).
fn probabilistic_exists(
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let one_minus_input = input.mapv(|x| 1.0 - x);

    if let Some(ax) = axis {
        let product = one_minus_input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>());
        Ok(product.mapv(|p| 1.0 - p))
    } else {
        let product: f64 = one_minus_input.iter().product();
        Ok(Array::from_elem(vec![], 1.0 - product))
    }
}

/// Gradient of `probabilistic_exists`:
/// d/dx_i [1 - prod_j (1 - x_j)] = prod_{j != i} (1 - x_j) = prod_j (1 - x_j) / (1 - x_i).
fn probabilistic_exists_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let one_minus_input = input.mapv(|x| 1.0 - x);

    if let Some(ax) = axis {
        // Product of (1 - x) along the reduced axis, broadcast back over that axis.
        let product = one_minus_input
            .map_axis(Axis(ax), |slice| slice.iter().product::<f64>())
            .insert_axis(Axis(ax));

        // Guard against division by zero when some (1 - x_i) is numerically zero.
        let inv = one_minus_input.mapv(|d| if d.abs() > 1e-10 { 1.0 / d } else { 0.0 });
        let grad = &inv * &product;

        // The forward reduction removed `ax`, so re-insert it before broadcasting the output gradient.
        let grad_b = grad_output.clone().insert_axis(Axis(ax));
        Ok(&grad * &grad_b)
    } else {
        let product: f64 = one_minus_input.iter().product();
        let grad = input.mapv(|x| {
            let denom = 1.0 - x;
            if denom.abs() > 1e-10 {
                product / denom
            } else {
                0.0
            }
        });

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(grad.mapv(|g| g * grad_scalar))
    }
}

/// Probabilistic (product) universal quantifier: forall(x) = prod_i x_i.
fn probabilistic_forall(
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    if let Some(ax) = axis {
        Ok(input.map_axis(Axis(ax), |slice| slice.iter().product::<f64>()))
    } else {
        let product: f64 = input.iter().product();
        Ok(Array::from_elem(vec![], product))
    }
}

/// Gradient of `probabilistic_forall`: d/dx_i [prod_j x_j] = prod_{j != i} x_j = output / x_i.
fn probabilistic_forall_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    output: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    if let Some(ax) = axis {
        // The forward reduction removed `ax`, so re-insert it before broadcasting.
        let output_b = output.clone().insert_axis(Axis(ax));
        let grad_output_b = grad_output.clone().insert_axis(Axis(ax));

        // Guard against division by zero when some x_i is numerically zero.
        let inv = input.mapv(|x| if x.abs() > 1e-10 { 1.0 / x } else { 0.0 });
        let grad = &inv * &output_b;
        Ok(&grad * &grad_output_b)
    } else {
        let output_val = output.iter().next().unwrap_or(&0.0);
        let grad = input.mapv(|x| if x.abs() > 1e-10 { output_val / x } else { 0.0 });

        let grad_scalar = grad_output.iter().next().unwrap_or(&0.0);
        Ok(grad.mapv(|g| g * grad_scalar))
    }
}

/// Gradient of the hard max reduction: routes each output gradient to the arg-max entry
/// of the corresponding lane; all other entries receive zero gradient.
fn argmax_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let mut grad = ArrayD::zeros(input.raw_dim());

    if let Some(ax) = axis {
        // Iterate lanes along `ax` in step with the (reduced-shape) output gradient.
        for ((lane_in, mut lane_grad), &g) in input
            .lanes(Axis(ax))
            .into_iter()
            .zip(grad.lanes_mut(Axis(ax)))
            .zip(grad_output.iter())
        {
            let max_idx = lane_in
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            lane_grad[max_idx] = g;
        }
    } else {
        let max_idx = input
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        grad.as_slice_mut().unwrap()[max_idx] = *grad_output.iter().next().unwrap_or(&0.0);
    }

    Ok(grad)
}

/// Gradient of the hard min reduction: routes each output gradient to the arg-min entry
/// of the corresponding lane; all other entries receive zero gradient.
fn argmin_gradient(
    grad_output: &Scirs2Tensor,
    input: &Scirs2Tensor,
    axis: Option<usize>,
) -> TlBackendResult<Scirs2Tensor> {
    let mut grad = ArrayD::zeros(input.raw_dim());

    if let Some(ax) = axis {
        // Iterate lanes along `ax` in step with the (reduced-shape) output gradient.
        for ((lane_in, mut lane_grad), &g) in input
            .lanes(Axis(ax))
            .into_iter()
            .zip(grad.lanes_mut(Axis(ax)))
            .zip(grad_output.iter())
        {
            let min_idx = lane_in
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(idx, _)| idx)
                .unwrap_or(0);

            lane_grad[min_idx] = g;
        }
    } else {
        let min_idx = input
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        grad.as_slice_mut().unwrap()[min_idx] = *grad_output.iter().next().unwrap_or(&0.0);
    }

    Ok(grad)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_ste_threshold_forward() {
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig::default();

        let output = ste_threshold(&input, config).unwrap();
        let expected = array![[0.0, 1.0], [0.0, 1.0]].into_dyn();

        assert_eq!(output, expected);
    }

    #[test]
    fn test_ste_threshold_backward() {
        let grad_output = array![[1.0, 2.0], [3.0, 4.0]].into_dyn();
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig::default();

        let grad_input = ste_threshold_backward(&grad_output, &input, config).unwrap();

        // Without clipping, the gradient passes through unchanged.
        assert_eq!(grad_input, grad_output);
    }

    #[test]
    fn test_ste_gradient_clipping() {
        let grad_output = array![[5.0, -3.0], [0.5, -10.0]].into_dyn();
        let input = array![[0.2, 0.6], [0.4, 0.8]].into_dyn();
        let config = SteConfig {
            threshold: 0.5,
            clip_gradients: true,
        };

        let grad_input = ste_threshold_backward(&grad_output, &input, config).unwrap();
        let expected = array![[1.0, -1.0], [0.5, -1.0]].into_dyn();

        assert_eq!(grad_input, expected);
    }

    #[test]
    fn test_gumbel_softmax_deterministic() {
        let logits = array![[1.0, 2.0, 3.0]].into_dyn();
        let config = GumbelSoftmaxConfig {
            temperature: 1.0,
            hard: false,
            seed: Some(42),
        };

        let samples = gumbel_softmax(&logits, config).unwrap();

        assert_eq!(samples.shape(), &[1, 3]);

        // Soft samples form a probability distribution over the last axis.
        let sum: f64 = samples.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        for &val in samples.iter() {
            assert!((0.0..=1.0).contains(&val));
        }
    }

    #[test]
    fn test_gumbel_softmax_hard_mode() {
        let logits = array![[1.0, 5.0, 2.0]].into_dyn();
        let config = GumbelSoftmaxConfig {
            temperature: 0.1,
            hard: true,
            seed: Some(123),
        };

        let samples = gumbel_softmax(&logits, config).unwrap();

        let sum: f64 = samples.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        // Hard mode returns a one-hot vector, so the largest entry is 1.0.
        let max_val = samples.iter().fold(0.0_f64, |a, &b| a.max(b));
        assert!(max_val >= 0.9);
    }

    #[test]
    fn test_soft_exists_smooth() {
        let input = array![[0.1, 0.3], [0.2, 0.9]].into_dyn();
        let mode = QuantifierMode::Smooth { temperature: 1.0 };

        let output = soft_exists(&input, Some(1), mode).unwrap();

        assert_eq!(output.shape(), &[2]);

        // Smooth max at T = 1 is log-sum-exp: ln(e^0.1 + e^0.3) is about 0.90,
        // and ln(e^0.2 + e^0.9) is about 1.30.
        assert!(
            output[0] >= 0.85 && output[0] <= 0.95,
            "output[0] = {} not in [0.85, 0.95]",
            output[0]
        );
        assert!(
            output[1] >= 1.25 && output[1] <= 1.35,
            "output[1] = {} not in [1.25, 1.35]",
            output[1]
        );
    }

    #[test]
    fn test_soft_exists_probabilistic() {
        let input = array![[0.5, 0.5]].into_dyn();
        let mode = QuantifierMode::Probabilistic;

        let output = soft_exists(&input, Some(1), mode).unwrap();

        // 1 - (1 - 0.5) * (1 - 0.5) = 0.75.
        assert!((output[0] - 0.75).abs() < 1e-6);
    }

    #[test]
    fn test_soft_forall_probabilistic() {
        let input = array![[0.5, 0.5]].into_dyn();
        let mode = QuantifierMode::Probabilistic;

        let output = soft_forall(&input, Some(1), mode).unwrap();

        // 0.5 * 0.5 = 0.25.
        assert!((output[0] - 0.25).abs() < 1e-6);
    }

    #[test]
    fn test_probabilistic_forall_gradient() {
        let input = array![[0.5, 0.8]].into_dyn();
        // Forward output: 0.5 * 0.8 = 0.4.
        let output = array![0.4].into_dyn();
        let grad_output = array![1.0].into_dyn();

        let grad_input =
            probabilistic_forall_gradient(&grad_output, &input, &output, Some(1)).unwrap();

        // d(x * y)/dx = y and d(x * y)/dy = x.
        assert!((grad_input[[0, 0]] - 0.8).abs() < 1e-6);
        assert!((grad_input[[0, 1]] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_smooth_max_vs_hard_max() {
        let input = array![[1.0, 2.0, 3.0]].into_dyn();

        let hard = soft_exists(&input, Some(1), QuantifierMode::Hard).unwrap();
        assert!((hard[0] - 3.0).abs() < 1e-6);

        // At a low temperature the smooth max approaches the hard max.
        let smooth = soft_exists(
            &input,
            Some(1),
            QuantifierMode::Smooth { temperature: 0.01 },
        )
        .unwrap();
        assert!((smooth[0] - 3.0).abs() < 0.1);
    }

    #[test]
    fn test_gumbel_noise_properties() {
        let shape = &[1000];
        let noise = sample_gumbel(shape, Some(42)).unwrap();

        // The mean of Gumbel(0, 1) is the Euler-Mascheroni constant, approximately 0.5772.
        let mean: f64 = noise.iter().sum::<f64>() / noise.len() as f64;
        assert!((mean - 0.5772).abs() < 0.1);

        for &val in noise.iter() {
            assert!(val.is_finite());
        }
    }
}