use scirs2_core::ndarray::{Array, ArrayD, Axis, IxDyn};

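/// Errors returned by the scoring and weighted-quantifier operations in this
/// module.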
#[derive(Debug, thiserror::Error)]
pub enum ScoringError {
    #[error("Shape mismatch: input {input:?}, weights {weights:?}")]
    ShapeMismatch {
        input: Vec<usize>,
        weights: Vec<usize>,
    },
    #[error("Axis {axis} out of range for {ndim}D tensor")]
    AxisOutOfRange {
        axis: usize,
        ndim: usize,
    },
    #[error("Division by zero in weight normalization")]
    ZeroWeightSum,
    #[error("Invalid probability value {value}: must be in [0, 1]")]
    InvalidProbability {
        value: f64,
    },
    #[error("Empty input tensor")]
    EmptyInput,
}

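/// Selects how predicate scores are interpreted and aggregated.
///
/// `Standard` treats inputs as probabilities in [0, 1]; `LogProbability`
/// treats them as log-probabilities; `LogOdds` treats them as log-odds and is
/// dispatched to the same log-space machinery as `LogProbability`.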
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScoringMode {
    Standard,
    LogProbability,
    LogOdds,
}

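/// Configuration shared by the log-space aggregators and weighted
/// quantifiers.
///
/// `log_floor` clamps log-space values so `ln(0)` never produces negative
/// infinity; the default is `ln(f64::MIN_POSITIVE)` (about -708).
/// `temperature` is not consumed in this module; it is carried for callers
/// that scale scores before aggregation.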
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    pub mode: ScoringMode,
    pub log_floor: f64,
    pub temperature: f64,
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            mode: ScoringMode::Standard,
            log_floor: f64::MIN_POSITIVE.ln(),
            temperature: 1.0,
        }
    }
}

impl ScoringConfig {
    pub fn log_probability() -> Self {
        Self {
            mode: ScoringMode::LogProbability,
            ..Self::default()
        }
    }

    pub fn log_odds() -> Self {
        Self {
            mode: ScoringMode::LogOdds,
            ..Self::default()
        }
    }

    pub fn with_temperature(mut self, t: f64) -> Self {
        self.temperature = t;
        self
    }
}

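/// Numerically stable `log(sum(exp(x)))` over a slice, using the max-shift
/// identity `LSE(x) = m + ln(Σ exp(x_i - m))` with `m = max_i x_i`.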
fn log_sum_exp_slice(slice: &[f64], log_floor: f64) -> f64 {
    if slice.is_empty() {
        return log_floor;
    }
    let max = slice.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return log_floor;
    }
    let sum_exp: f64 = slice.iter().map(|&x| (x - max).exp()).sum();
    max + sum_exp.ln()
}

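/// Applies [`log_sum_exp_slice`] to every lane along `axis`, reducing that
/// axis away.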
fn log_sum_exp_along_axis(
    input: &ArrayD<f64>,
    axis: usize,
    log_floor: f64,
) -> Result<ArrayD<f64>, ScoringError> {
    if axis >= input.ndim() {
        return Err(ScoringError::AxisOutOfRange {
            axis,
            ndim: input.ndim(),
        });
    }
    if input.is_empty() {
        return Err(ScoringError::EmptyInput);
    }
    Ok(input.map_axis(Axis(axis), |lane| {
        let s: Vec<f64> = lane.iter().cloned().collect();
        log_sum_exp_slice(&s, log_floor)
    }))
}

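/// Sums log-values along `axis`; in log space, a sum of logs is the log of
/// the product.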
fn log_product_along_axis(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ScoringError> {
    if axis >= input.ndim() {
        return Err(ScoringError::AxisOutOfRange {
            axis,
            ndim: input.ndim(),
        });
    }
    if input.is_empty() {
        return Err(ScoringError::EmptyInput);
    }
    Ok(input.map_axis(Axis(axis), |lane| lane.iter().sum::<f64>()))
}

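/// Aggregates scores in log space so that products and sums of many small
/// probabilities stay numerically stable.
///
/// A minimal usage sketch (assuming `scirs2_core` re-exports `ndarray::arr1`):
///
/// ```ignore
/// use scirs2_core::ndarray::{arr1, ArrayD};
///
/// let agg = LogSpaceAggregator::new(ScoringConfig::default());
/// let x: ArrayD<f64> = arr1(&[0.0_f64, 0.0]).into_dyn();
/// let lse = agg.log_sum_exp(&x, None)?; // zero-dim array holding ≈ ln(2)
/// ```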
pub struct LogSpaceAggregator {
    config: ScoringConfig,
}

impl LogSpaceAggregator {
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

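    /// Numerically stable log-sum-exp of `input`, either over all elements
    /// (`axis == None`, returning a zero-dimensional array) or along one
    /// axis. Underflowing values are clamped to `config.log_floor`.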
    pub fn log_sum_exp(
        &self,
        input: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        match axis {
            None => {
                let flat: Vec<f64> = input.iter().cloned().collect();
                let result = log_sum_exp_slice(&flat, self.config.log_floor);
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => log_sum_exp_along_axis(input, ax, self.config.log_floor),
        }
    }

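    /// Log-space product: sums log-values (over all elements or along one
    /// axis) and clamps the result at `config.log_floor`.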
    pub fn log_product(
        &self,
        input: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        match axis {
            None => {
                let result: f64 = input.iter().sum();
                let result = result.max(self.config.log_floor);
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let out = log_product_along_axis(input, ax)?;
                Ok(out.mapv(|v| v.max(self.config.log_floor)))
            }
        }
    }

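    /// Element-wise stable `log(exp(a) + exp(b))` for two arrays of the same
    /// shape, computed as `max + ln(1 + exp(min - max))`.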
    pub fn log_add_exp(
        &self,
        a: &ArrayD<f64>,
        b: &ArrayD<f64>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if a.shape() != b.shape() {
            return Err(ScoringError::ShapeMismatch {
                input: a.shape().to_vec(),
                weights: b.shape().to_vec(),
            });
        }
        let result = scirs2_core::ndarray::Zip::from(a)
            .and(b)
            .map_collect(|&ai, &bi| {
                let max = ai.max(bi);
                let min = ai.min(bi);
                if max == f64::NEG_INFINITY {
                    self.config.log_floor
                } else {
                    max + (1.0_f64 + (min - max).exp()).ln()
                }
            });
        Ok(result)
    }

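    /// Converts probabilities in [0, 1] to log space, clamping `ln(0)` to
    /// `config.log_floor`. Returns an error for non-finite values or values
    /// outside [0, 1].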
    pub fn to_log_space(&self, probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
        for &v in probs.iter() {
            if !v.is_finite() || !(0.0..=1.0).contains(&v) {
                return Err(ScoringError::InvalidProbability { value: v });
            }
        }
        let floor = self.config.log_floor;
        Ok(probs.mapv(|p| if p <= 0.0 { floor } else { p.ln().max(floor) }))
    }

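    /// Maps log-space values back to probabilities via `exp`. Infallible in
    /// practice; the `Result` keeps the signature symmetric with
    /// [`Self::to_log_space`].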
    pub fn from_log_space(&self, log_probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
        Ok(log_probs.mapv(|lp| lp.exp()))
    }
}

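/// Checks that `weights` can be matched to `input`: same shape, same total
/// length (full reduction), or a 1-D vector whose length equals the size of
/// the reduced axis.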
fn validate_weights_for_axis(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
) -> Result<(), ScoringError> {
    match axis {
        None => {
            if weights.shape() != input.shape() && weights.len() != input.len() {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
        }
        Some(ax) => {
            if ax >= input.ndim() {
                return Err(ScoringError::AxisOutOfRange {
                    axis: ax,
                    ndim: input.ndim(),
                });
            }
            let expected_len = input.shape()[ax];
            let compatible = weights.shape() == input.shape()
                || (weights.ndim() == 1 && weights.len() == expected_len);
            if !compatible {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
        }
    }
    Ok(())
}

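/// Weighted soft quantifiers (`exists` / `forall`) over score tensors.
///
/// A brief sketch of the intended call pattern (the weights here are
/// illustrative, and `arr1` is assumed to be re-exported by `scirs2_core`):
///
/// ```ignore
/// use scirs2_core::ndarray::{arr1, ArrayD};
///
/// let q = WeightedQuantifier::new(ScoringConfig::default());
/// let scores: ArrayD<f64> = arr1(&[0.2, 0.8]).into_dyn();
/// let weights: ArrayD<f64> = arr1(&[1.0, 3.0]).into_dyn();
/// let e = q.weighted_exists(&scores, &weights, None)?; // weighted mean = 0.65
/// ```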
pub struct WeightedQuantifier {
    config: ScoringConfig,
}

impl WeightedQuantifier {
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

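    /// Weighted soft existential quantifier. In `Standard` mode this is the
    /// weighted arithmetic mean `Σ w_i x_i / Σ w_i`; in the log-space modes
    /// it is `LSE(ln w_i + x_i) - ln(Σ w_i)`, i.e. the log of that same
    /// weighted mean (per lane when an axis is reduced).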
    pub fn weighted_exists(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        match self.config.mode {
            ScoringMode::Standard => self.weighted_exists_standard(input, weights, axis),
            ScoringMode::LogProbability | ScoringMode::LogOdds => {
                self.weighted_exists_log(input, weights, axis)
            }
        }
    }

    fn weighted_exists_standard(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        let w = broadcast_weights(weights, input, axis)?;

        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match axis {
            None => {
                let numerator: f64 = input.iter().zip(w.iter()).map(|(&x, &wi)| wi * x).sum();
                let result = numerator / weight_sum;
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let weighted = input * &w;
                let num = weighted.sum_axis(Axis(ax));
                let w_sum = w.sum_axis(Axis(ax));
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum)
                    .map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
                Ok(result)
            }
        }
    }

    fn weighted_exists_log(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }
        let floor = self.config.log_floor;

        // log(w_i * exp(x_i)) = ln(w_i) + x_i, clamped at the floor.
        let log_w_plus_x = scirs2_core::ndarray::Zip::from(&w)
            .and(input)
            .map_collect(|&wi, &xi| {
                if wi <= 0.0 {
                    floor
                } else {
                    (wi.ln() + xi).max(floor)
                }
            });

        let agg = LogSpaceAggregator::new(self.config.clone());
        let lse = agg.log_sum_exp(&log_w_plus_x, axis)?;
        // Normalize by the same weight sum the standard path uses: the global
        // sum for a full reduction, the per-lane sum along a reduced axis.
        match axis {
            None => Ok(lse.mapv(|v| v - weight_sum.ln())),
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax));
                Ok(scirs2_core::ndarray::Zip::from(&lse)
                    .and(&w_sum_ax)
                    .map_collect(|&l, &ws| if ws <= 0.0 { floor } else { l - ws.ln() }))
            }
        }
    }

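    /// Weighted soft universal quantifier. In `Standard` mode this is the
    /// weighted geometric mean `exp(Σ w_i ln x_i / Σ w_i)`; in the log-space
    /// modes the inputs are already logs, so it reduces to the weighted
    /// arithmetic mean of the log-values.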
    pub fn weighted_forall(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        match self.config.mode {
            ScoringMode::Standard => self.weighted_forall_standard(input, weights, axis),
            ScoringMode::LogProbability | ScoringMode::LogOdds => {
                self.weighted_forall_log(input, weights, axis)
            }
        }
    }

    fn weighted_forall_standard(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        let log_input = input.mapv(|x| {
            if x <= 0.0 {
                self.config.log_floor
            } else {
                x.ln()
            }
        });

        match axis {
            None => {
                let log_geo: f64 = log_input
                    .iter()
                    .zip(w.iter())
                    .map(|(&lx, &wi)| lx * wi / weight_sum)
                    .sum();
                Ok(ArrayD::from_elem(IxDyn(&[]), log_geo.exp()))
            }
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax));
                let weighted_log = &log_input * &w;
                let num = weighted_log.sum_axis(Axis(ax));
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum_ax)
                    .map_collect(|&n, &ws| {
                        if ws == 0.0 {
                            1.0
                        } else {
                            (n / ws).exp()
                        }
                    });
                Ok(result)
            }
        }
    }

    fn weighted_forall_log(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match axis {
            None => {
                let result: f64 = input
                    .iter()
                    .zip(w.iter())
                    .map(|(&xi, &wi)| xi * wi / weight_sum)
                    .sum();
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax));
                let weighted = input * &w;
                let num = weighted.sum_axis(Axis(ax));
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum_ax)
                    .map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
                Ok(result)
            }
        }
    }

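    /// Backward pass for [`Self::weighted_exists`]. The forward output is
    /// linear in each input, so `d(out)/d(x_i) = w_i / Σ w` (with the sum
    /// taken per lane when an axis is reduced), and the incoming `grad` is
    /// scaled by that ratio.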
    pub fn weighted_exists_grad(
        &self,
        grad: &ArrayD<f64>,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match axis {
            None => {
                let w_norm = w.mapv(|wi| wi / weight_sum);
                let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                Ok(w_norm.mapv(|wn| wn * g_scalar))
            }
            Some(ax) => {
                // Match the forward pass, which normalizes by the per-lane
                // weight sum along the reduced axis.
                let w_sum_ax = w.sum_axis(Axis(ax)).insert_axis(Axis(ax));
                let w_norm = scirs2_core::ndarray::Zip::from(&w)
                    .and_broadcast(&w_sum_ax)
                    .map_collect(|&wi, &ws| if ws == 0.0 { 0.0 } else { wi / ws });
                let grad_expanded = grad.view().insert_axis(Axis(ax));
                Ok(&w_norm * &grad_expanded)
            }
        }
    }

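    /// Backward pass for [`Self::weighted_forall`]. In `Standard` mode the
    /// output is a weighted geometric mean, so `d(out)/d(x_i)` is the
    /// normalized weight times `out / x_i`; in the log-space modes the
    /// forward pass is linear, so the gradient is just the normalized weight.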
    pub fn weighted_forall_grad(
        &self,
        grad: &ArrayD<f64>,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        // Normalized weights matching the forward pass: the global sum for a
        // full reduction, the per-lane sum along a reduced axis.
        let w_norm = match axis {
            None => w.mapv(|wi| wi / weight_sum),
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax)).insert_axis(Axis(ax));
                scirs2_core::ndarray::Zip::from(&w)
                    .and_broadcast(&w_sum_ax)
                    .map_collect(|&wi, &ws| if ws == 0.0 { 0.0 } else { wi / ws })
            }
        };

        match self.config.mode {
            ScoringMode::Standard => {
                let log_input = input.mapv(|x| {
                    if x <= 0.0 {
                        self.config.log_floor
                    } else {
                        x.ln()
                    }
                });

                // Recompute the forward output, broadcast back to the input
                // shape so each element can be scaled by out / x_i.
                let forall_out = match axis {
                    None => {
                        let log_geo: f64 = log_input
                            .iter()
                            .zip(w.iter())
                            .map(|(&lx, &wi)| lx * wi / weight_sum)
                            .sum();
                        ArrayD::from_elem(input.raw_dim(), log_geo.exp())
                    }
                    Some(ax) => {
                        let w_sum_ax = w.sum_axis(Axis(ax));
                        let weighted_log = &log_input * &w;
                        let num = weighted_log.sum_axis(Axis(ax));
                        let out_no_ax = scirs2_core::ndarray::Zip::from(&num)
                            .and(&w_sum_ax)
                            .map_collect(|&n, &ws| if ws == 0.0 { 1.0 } else { (n / ws).exp() });
                        out_no_ax
                            .insert_axis(Axis(ax))
                            .broadcast(input.raw_dim())
                            .map_or_else(|| Array::zeros(input.raw_dim()), |v| v.to_owned())
                    }
                };

                let scale = scirs2_core::ndarray::Zip::from(&w_norm)
                    .and(&forall_out)
                    .and(input)
                    .map_collect(|&wn, &out_v, &xi| {
                        if xi == 0.0 {
                            0.0
                        } else {
                            wn * out_v / xi
                        }
                    });

                match axis {
                    None => {
                        let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                        Ok(scale.mapv(|s| s * g_scalar))
                    }
                    Some(ax) => {
                        let grad_expanded = grad.view().insert_axis(Axis(ax));
                        Ok(&scale * &grad_expanded)
                    }
                }
            }
            ScoringMode::LogProbability | ScoringMode::LogOdds => match axis {
                None => {
                    let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                    Ok(w_norm.mapv(|wn| wn * g_scalar))
                }
                Some(ax) => {
                    let grad_expanded = grad.view().insert_axis(Axis(ax));
                    Ok(&w_norm * &grad_expanded)
                }
            },
        }
    }
}

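/// Broadcasts `weights` to the shape of `input`: pass-through for identical
/// shapes, a reshape for equal-length full reductions, and a reshape plus
/// broadcast for a 1-D weight vector along the reduced axis.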
fn broadcast_weights(
    weights: &ArrayD<f64>,
    input: &ArrayD<f64>,
    axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
    if weights.shape() == input.shape() {
        return Ok(weights.clone());
    }

    match axis {
        None => {
            if weights.len() != input.len() {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
            weights
                .clone()
                .into_shape_with_order(input.raw_dim())
                .map_err(|_| ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                })
        }
        Some(ax) => {
            if weights.ndim() == 1 && weights.len() == input.shape()[ax] {
                let mut shape = vec![1usize; input.ndim()];
                shape[ax] = input.shape()[ax];
                let reshaped = weights
                    .clone()
                    .into_shape_with_order(IxDyn(&shape))
                    .map_err(|_| ScoringError::ShapeMismatch {
                        input: input.shape().to_vec(),
                        weights: weights.shape().to_vec(),
                    })?;
                reshaped
                    .broadcast(input.raw_dim())
                    .map(|v| v.to_owned())
                    .ok_or_else(|| ScoringError::ShapeMismatch {
                        input: input.shape().to_vec(),
                        weights: weights.shape().to_vec(),
                    })
            } else {
                Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                })
            }
        }
    }
}

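/// Convenience free function: stable log-sum-exp with an explicit config.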
pub fn log_sum_exp(
    input: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    LogSpaceAggregator::new(config).log_sum_exp(input, axis)
}

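/// Convenience free function: weighted soft existential quantifier.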
pub fn weighted_soft_exists(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    WeightedQuantifier::new(config).weighted_exists(input, weights, axis)
}

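/// Convenience free function: weighted soft universal quantifier.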
pub fn weighted_soft_forall(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    WeightedQuantifier::new(config).weighted_forall(input, weights, axis)
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    const EPS: f64 = 1e-9;

    fn config() -> ScoringConfig {
        ScoringConfig::default()
    }

    fn agg() -> LogSpaceAggregator {
        LogSpaceAggregator::new(config())
    }

    fn make_1d(data: Vec<f64>) -> ArrayD<f64> {
        Array::from_vec(data).into_dyn()
    }

    fn make_2d(data: Vec<Vec<f64>>) -> ArrayD<f64> {
        let rows = data.len();
        let cols = data[0].len();
        let flat: Vec<f64> = data.into_iter().flatten().collect();
        Array2::from_shape_vec((rows, cols), flat)
            .expect("valid shape")
            .into_dyn()
    }

    #[test]
    fn test_log_sum_exp_scalar() {
        let input = make_1d(vec![3.0]);
        let result = agg().log_sum_exp(&input, None).expect("log_sum_exp scalar");
        assert!(
            (result[[]] - 3.0).abs() < EPS,
            "expected 3.0, got {}",
            result[[]]
        );
    }

    #[test]
    fn test_log_sum_exp_zeros() {
        let n = 4usize;
        let input = make_1d(vec![0.0; n]);
        let result = agg().log_sum_exp(&input, None).expect("log_sum_exp zeros");
        let expected = (n as f64).ln();
        assert!(
            (result[[]] - expected).abs() < EPS,
            "expected log({}), got {}",
            n,
            result[[]]
        );
    }

    #[test]
    fn test_log_sum_exp_vs_naive() {
        let vals = vec![1.0, 2.0, 3.0];
        let input = make_1d(vals.clone());
        let result = agg().log_sum_exp(&input, None).expect("vs naive");
        let naive = vals.iter().map(|&x| x.exp()).sum::<f64>().ln();
        assert!(
            (result[[]] - naive).abs() < 1e-10,
            "stable != naive: {} vs {}",
            result[[]],
            naive
        );
    }

    #[test]
    fn test_log_sum_exp_numerical_stability() {
        let input = make_1d(vec![300.0, 299.0, 298.0]);
        let result = agg()
            .log_sum_exp(&input, None)
            .expect("numerical stability");
        assert!(
            result[[]].is_finite(),
            "result should be finite, got {}",
            result[[]]
        );
        let expected = 300.0 + (1.0 + (-1.0_f64).exp() + (-2.0_f64).exp()).ln();
        assert!((result[[]] - expected).abs() < 1e-10);
    }

926
927 #[test]
931 fn test_log_sum_exp_axis_0() {
932 let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
934 let result = agg().log_sum_exp(&input, Some(0)).expect("axis 0");
935 assert_eq!(result.shape(), &[3]);
936 for col in 0..3 {
937 let a = (col + 1) as f64;
938 let b = (col + 4) as f64;
939 let expected = a.max(b) + (1.0 + (a.min(b) - a.max(b)).exp()).ln();
940 assert!((result[[col]] - expected).abs() < 1e-10);
941 }
942 }
943
944 #[test]
948 fn test_log_sum_exp_axis_1() {
949 let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
951 let result = agg().log_sum_exp(&input, Some(1)).expect("axis 1");
952 assert_eq!(result.shape(), &[2]);
953 for row in 0..2 {
954 let vals: Vec<f64> = (1..=3).map(|c| (row * 3 + c) as f64).collect();
955 let expected_v = vals.iter().map(|&v| v.exp()).sum::<f64>().ln();
956 assert!(
957 (result[[row]] - expected_v).abs() < 1e-8,
958 "row {}: {} vs {}",
959 row,
960 result[[row]],
961 expected_v
962 );
963 }
964 }
965
966 #[test]
970 fn test_log_sum_exp_full_reduction() {
971 let input = make_2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
972 let result = agg().log_sum_exp(&input, None).expect("full reduction");
973 assert_eq!(result.shape(), &[] as &[usize]);
974 let naive = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp() + 4.0_f64.exp()).ln();
975 assert!((result[[]] - naive).abs() < 1e-8);
976 }
977
978 #[test]
982 fn test_log_product_basic() {
983 let input = make_1d(vec![0.5_f64.ln(), 0.25_f64.ln()]);
985 let result = agg().log_product(&input, None).expect("log_product basic");
986 let expected = 0.125_f64.ln();
987 assert!((result[[]] - expected).abs() < 1e-10);
988 }
989
990 #[test]
994 fn test_log_add_exp_symmetry() {
995 let a = make_1d(vec![1.0, 2.0, 3.0]);
996 let b = make_1d(vec![3.0, 1.0, 2.0]);
997 let ab = agg().log_add_exp(&a, &b).expect("log_add_exp ab");
998 let ba = agg().log_add_exp(&b, &a).expect("log_add_exp ba");
999 for i in 0..3 {
1000 assert!(
1001 (ab[[i]] - ba[[i]]).abs() < EPS,
1002 "symmetry violated at {}",
1003 i
1004 );
1005 }
1006 }
1007
1008 #[test]
1012 fn test_to_log_space_range() {
1013 let probs = make_1d(vec![0.0, 0.1, 0.5, 0.9, 1.0]);
1014 let result = agg().to_log_space(&probs).expect("to_log_space");
1015 for &v in result.iter() {
1016 assert!(v <= 0.0, "log-probability must be <= 0, got {}", v);
1017 }
1018 }
1019
1020 #[test]
1024 fn test_from_log_space_roundtrip() {
1025 let probs = make_1d(vec![0.1, 0.5, 0.9]);
1026 let log_p = agg().to_log_space(&probs).expect("to_log_space");
1027 let recovered = agg().from_log_space(&log_p).expect("from_log_space");
1028 for i in 0..3 {
1029 assert!(
1030 (probs[[i]] - recovered[[i]]).abs() < 1e-12,
1031 "roundtrip failed at {}: {} != {}",
1032 i,
1033 probs[[i]],
1034 recovered[[i]]
1035 );
1036 }
1037 }
1038
1039 #[test]
1043 fn test_log_floor_prevents_neg_inf() {
1044 let probs = make_1d(vec![0.0, 0.5, 1.0]); let result = agg().to_log_space(&probs).expect("log_floor");
1046 for &v in result.iter() {
1047 assert!(v.is_finite(), "value should be finite, got {}", v);
1048 }
1049 assert!(result[[0]] <= 0.0, "floor should be <= 0");
1050 }
1051
    #[test]
    fn test_weighted_exists_uniform_weights() {
        let input = make_1d(vec![0.2, 0.4, 0.6, 0.8]);
        let weights = make_1d(vec![1.0, 1.0, 1.0, 1.0]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_exists(&input, &weights, None)
            .expect("uniform weights");
        let expected = 0.5;
        assert!(
            (result[[]] - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            result[[]]
        );
    }

    #[test]
    fn test_weighted_exists_zero_weight_error() {
        let input = make_1d(vec![0.5, 0.5]);
        let weights = make_1d(vec![0.0, 0.0]);
        let q = WeightedQuantifier::new(config());
        let result = q.weighted_exists(&input, &weights, None);
        assert!(
            matches!(result, Err(ScoringError::ZeroWeightSum)),
            "expected ZeroWeightSum error"
        );
    }

    #[test]
    fn test_weighted_exists_concentrated_weight() {
        let input = make_1d(vec![0.1, 0.3, 0.7, 0.9]);
        let weights = make_1d(vec![0.0, 0.0, 1.0, 0.0]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_exists(&input, &weights, None)
            .expect("concentrated weight");
        assert!(
            (result[[]] - 0.7).abs() < EPS,
            "expected 0.7, got {}",
            result[[]]
        );
    }

    #[test]
    fn test_weighted_forall_uniform() {
        let vals = vec![0.5, 0.25, 1.0, 0.5];
        let input = make_1d(vals.clone());
        let weights = make_1d(vec![1.0; 4]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_forall(&input, &weights, None)
            .expect("forall uniform");
        let geo: f64 = vals.iter().product::<f64>().powf(0.25);
        assert!(
            (result[[]] - geo).abs() < 1e-10,
            "expected {}, got {}",
            geo,
            result[[]]
        );
    }

    #[test]
    fn test_weighted_exists_gradient_shape() {
        let input = make_2d(vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        let weights = make_2d(vec![vec![1.0, 2.0, 1.0], vec![1.0, 2.0, 1.0]]);
        let q = WeightedQuantifier::new(config());
        let out = q
            .weighted_exists(&input, &weights, Some(1))
            .expect("forward");
        assert_eq!(out.shape(), &[2]);
        let grad = Array::ones(out.raw_dim());
        let d_input = q
            .weighted_exists_grad(&grad, &input, &weights, Some(1))
            .expect("grad");
        assert_eq!(
            d_input.shape(),
            input.shape(),
            "gradient should match input shape"
        );
    }

    #[test]
    fn test_weighted_exists_gradient_finite() {
        let input = make_1d(vec![0.2, 0.5, 0.8]);
        let weights = make_1d(vec![1.0, 3.0, 1.0]);
        let q = WeightedQuantifier::new(config());
        let out = q.weighted_exists(&input, &weights, None).expect("forward");
        let grad = Array::ones(out.raw_dim());
        let d_input = q
            .weighted_exists_grad(&grad, &input, &weights, None)
            .expect("grad");
        for &v in d_input.iter() {
            assert!(v.is_finite(), "gradient must be finite, got {}", v);
        }
    }

    #[test]
    fn test_scoring_config_default() {
        let cfg = ScoringConfig::default();
        assert_eq!(cfg.mode, ScoringMode::Standard);
        assert!((cfg.temperature - 1.0).abs() < EPS);
        assert!(cfg.log_floor < -100.0, "log_floor should be very negative");
        assert!(cfg.log_floor.is_finite(), "log_floor must be finite");
    }

    #[test]
    fn test_scoring_config_builders() {
        let lp = ScoringConfig::log_probability();
        assert_eq!(lp.mode, ScoringMode::LogProbability);

        let lo = ScoringConfig::log_odds();
        assert_eq!(lo.mode, ScoringMode::LogOdds);

        let with_t = ScoringConfig::default().with_temperature(0.5);
        assert!((with_t.temperature - 0.5).abs() < EPS);
    }

    #[test]
    fn test_free_function_log_sum_exp() {
        let input = make_1d(vec![0.0, 0.0, 0.0]);
        let result = log_sum_exp(&input, None, config()).expect("free fn log_sum_exp");
        let expected = (3.0_f64).ln();
        assert!((result[[]] - expected).abs() < EPS);
    }

    #[test]
    fn test_log_space_quantifier_mode_via_gradient_ops() {
        use crate::gradient_ops::{soft_exists, QuantifierMode};

        let input = make_1d(vec![0.0, 0.0, 0.0]);
        let scoring_cfg = ScoringConfig::log_probability();
        let mode = QuantifierMode::LogSpace(scoring_cfg);
        let result = soft_exists(&input, None, mode).expect("log_space quantifier");
        let expected = (3.0_f64).ln();
        assert!(
            (result[[]] - expected).abs() < 1e-10,
            "expected log(3)={}, got {}",
            expected,
            result[[]]
        );
    }
}