tensorlogic_scirs_backend/scoring.rs

//! Log-space scoring aggregation and weighted quantifiers.
//!
//! This module provides numerically stable log-space operations for probabilistic
//! inference, along with weighted soft quantifiers (exists/forall) and their
//! gradients for end-to-end training of logical models.
//!
//! ## Log-Space Operations
//!
//! All operations in this module are designed to work in log-probability space
//! to avoid numerical underflow when dealing with very small probabilities.
//!
//! ## Weighted Quantifiers
//!
//! Weighted versions of soft-exists and soft-forall quantifiers that allow
//! assigning importance weights to individual elements before aggregation.
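//!
//! ## Example
//!
//! A minimal usage sketch (illustrative only; the crate path is assumed from
//! the repository layout, so the fence is marked `ignore`):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array;
//! use tensorlogic_scirs_backend::scoring::{LogSpaceAggregator, ScoringConfig};
//!
//! let agg = LogSpaceAggregator::new(ScoringConfig::log_probability());
//! // log(0.5) and log(0.25) in log-probability space.
//! let log_probs = Array::from_vec(vec![0.5_f64.ln(), 0.25_f64.ln()]).into_dyn();
//! // log-sum-exp recovers log(0.5 + 0.25) = log(0.75).
//! let total = agg.log_sum_exp(&log_probs, None).expect("non-empty input");
//! assert!((total[[]] - 0.75_f64.ln()).abs() < 1e-12);
//! ```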

use scirs2_core::ndarray::{Array, ArrayD, Axis, IxDyn};

/// Error type for scoring operations.
#[derive(Debug, thiserror::Error)]
pub enum ScoringError {
    /// Input and weights have incompatible shapes.
    #[error("Shape mismatch: input {input:?}, weights {weights:?}")]
    ShapeMismatch {
        /// Shape of the input tensor.
        input: Vec<usize>,
        /// Shape of the weights tensor.
        weights: Vec<usize>,
    },
    /// Requested axis is out of bounds for the tensor.
    #[error("Axis {axis} out of range for {ndim}D tensor")]
    AxisOutOfRange {
        /// Requested axis.
        axis: usize,
        /// Number of dimensions in the tensor.
        ndim: usize,
    },
    /// All weights sum to zero, cannot normalize.
    #[error("Division by zero in weight normalization")]
    ZeroWeightSum,
    /// A probability value outside [0, 1] was provided.
    #[error("Invalid probability value {value}: must be in [0, 1]")]
    InvalidProbability {
        /// The offending value.
        value: f64,
    },
    /// Reduction attempted on an empty tensor.
    #[error("Empty input tensor")]
    EmptyInput,
}

/// Scoring mode that controls the domain of input values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ScoringMode {
    /// Standard probability space: values in [0, 1].
    Standard,
    /// Log-probability space: values in (-∞, 0].
    LogProbability,
    /// Log-odds space: values in ℝ (logit scale).
    LogOdds,
}

/// Configuration for scoring operations.
#[derive(Debug, Clone)]
pub struct ScoringConfig {
    /// The domain/mode of the scoring values.
    pub mode: ScoringMode,
    /// Floor value for log-space to avoid -∞.
    /// Default: `f64::MIN_POSITIVE.ln() ≈ -708`.
    pub log_floor: f64,
    /// Temperature parameter for softmax-style operations.
    /// Default: 1.0.
    pub temperature: f64,
}

impl Default for ScoringConfig {
    fn default() -> Self {
        Self {
            mode: ScoringMode::Standard,
            log_floor: f64::MIN_POSITIVE.ln(), // ≈ -708
            temperature: 1.0,
        }
    }
}

impl ScoringConfig {
    /// Create a log-probability scoring configuration.
    pub fn log_probability() -> Self {
        Self {
            mode: ScoringMode::LogProbability,
            ..Self::default()
        }
    }

    /// Create a log-odds scoring configuration.
    pub fn log_odds() -> Self {
        Self {
            mode: ScoringMode::LogOdds,
            ..Self::default()
        }
    }

    /// Override the temperature parameter (builder pattern).
    pub fn with_temperature(mut self, t: f64) -> Self {
        self.temperature = t;
        self
    }
}

// ============================================================================
// Internal stable helpers
// ============================================================================

/// Numerically stable log-sum-exp over a flat slice.
///
/// Implements: log Σ exp(x_i) via max subtraction.
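///
/// For example, for `[1000.0, 1000.0]` the naive `exp(1000)` overflows,
/// while the max-subtracted form yields `1000 + ln(2)` exactly.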
fn log_sum_exp_slice(slice: &[f64], log_floor: f64) -> f64 {
    if slice.is_empty() {
        return log_floor;
    }
    let max = slice.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max == f64::NEG_INFINITY {
        return log_floor;
    }
    let sum_exp: f64 = slice.iter().map(|&x| (x - max).exp()).sum();
    max + sum_exp.ln()
}

/// Compute log-sum-exp along a specific axis, returning the reduced array.
///
/// The output has the reduction axis removed.
fn log_sum_exp_along_axis(
    input: &ArrayD<f64>,
    axis: usize,
    log_floor: f64,
) -> Result<ArrayD<f64>, ScoringError> {
    if axis >= input.ndim() {
        return Err(ScoringError::AxisOutOfRange {
            axis,
            ndim: input.ndim(),
        });
    }
    if input.is_empty() {
        return Err(ScoringError::EmptyInput);
    }
    Ok(input.map_axis(Axis(axis), |lane| {
        let s: Vec<f64> = lane.iter().cloned().collect();
        log_sum_exp_slice(&s, log_floor)
    }))
}

/// Compute log-product (sum) along a specific axis, returning the reduced array.
fn log_product_along_axis(input: &ArrayD<f64>, axis: usize) -> Result<ArrayD<f64>, ScoringError> {
    if axis >= input.ndim() {
        return Err(ScoringError::AxisOutOfRange {
            axis,
            ndim: input.ndim(),
        });
    }
    if input.is_empty() {
        return Err(ScoringError::EmptyInput);
    }
    Ok(input.map_axis(Axis(axis), |lane| lane.iter().sum::<f64>()))
}

// ============================================================================
// LogSpaceAggregator
// ============================================================================

/// Numerically stable aggregation operations in log-probability space.
///
/// All reduction operations are implemented with maximum-subtraction tricks
/// to avoid overflow and underflow when working with log-probabilities.
pub struct LogSpaceAggregator {
    config: ScoringConfig,
}

impl LogSpaceAggregator {
    /// Create a new aggregator with the given configuration.
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

    /// Compute log-sum-exp: `log Σ_i exp(x_i)`.
    ///
    /// Numerically stable via max subtraction:
    /// `log Σ exp(x_i) = max + log Σ exp(x_i - max)`.
    ///
    /// # Arguments
    /// * `input` - Input tensor (values in any domain).
    /// * `axis` - `None` → reduce all elements; `Some(k)` → reduce along axis `k`.
    ///
    /// # Returns
    /// Reduced tensor. For `axis=None` this is a 0-D scalar tensor.
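    ///
    /// # Example
    ///
    /// An illustrative sketch (marked `ignore` since it assumes the crate is
    /// importable under its repository name):
    ///
    /// ```ignore
    /// // log-sum-exp([0, 0, 0, 0]) = ln(4)
    /// let x = Array::from_vec(vec![0.0; 4]).into_dyn();
    /// let agg = LogSpaceAggregator::new(ScoringConfig::default());
    /// let y = agg.log_sum_exp(&x, None).expect("non-empty input");
    /// assert!((y[[]] - 4.0_f64.ln()).abs() < 1e-12);
    /// ```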
    pub fn log_sum_exp(
        &self,
        input: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        match axis {
            None => {
                let flat: Vec<f64> = input.iter().cloned().collect();
                let result = log_sum_exp_slice(&flat, self.config.log_floor);
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => log_sum_exp_along_axis(input, ax, self.config.log_floor),
        }
    }

    /// Compute log-product: `Σ_i x_i` (addition in log space = product in probability space).
    ///
    /// # Arguments
    /// * `input` - Input tensor (values assumed to be in log-probability space).
    /// * `axis` - `None` → sum all; `Some(k)` → sum along axis `k`.
    pub fn log_product(
        &self,
        input: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        match axis {
            None => {
                let result: f64 = input.iter().sum();
                // Clamp to log_floor
                let result = result.max(self.config.log_floor);
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let out = log_product_along_axis(input, ax)?;
                Ok(out.mapv(|v| v.max(self.config.log_floor)))
            }
        }
    }

    /// Element-wise binary log-add-exp: `log(exp(a) + exp(b))`.
    ///
    /// Numerically stable: `max + log1p(exp(min - max))`.
    ///
    /// # Arguments
    /// * `a` - First operand (must have same shape as `b`).
    /// * `b` - Second operand.
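    ///
    /// # Example
    ///
    /// A worked instance: `log_add_exp(ln 0.5, ln 0.25) = ln 0.75`, since
    /// `exp(ln 0.5) + exp(ln 0.25) = 0.75`.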
    pub fn log_add_exp(
        &self,
        a: &ArrayD<f64>,
        b: &ArrayD<f64>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if a.shape() != b.shape() {
            return Err(ScoringError::ShapeMismatch {
                input: a.shape().to_vec(),
                weights: b.shape().to_vec(),
            });
        }
        // log_add_exp(a, b) = max(a, b) + log1p(exp(min - max))
        let result = scirs2_core::ndarray::Zip::from(a)
            .and(b)
            .map_collect(|&ai, &bi| {
                let max = ai.max(bi);
                let min = ai.min(bi);
                if max == f64::NEG_INFINITY {
                    self.config.log_floor
                } else {
                    max + (min - max).exp().ln_1p()
                }
            });
        Ok(result)
    }

    /// Convert probabilities to log-space, clamping to `log_floor`.
    ///
    /// # Arguments
    /// * `probs` - Input probabilities (must be in [0, 1]).
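    ///
    /// # Example
    ///
    /// Illustrative sketch (fence marked `ignore`; `agg` is assumed to be a
    /// [`LogSpaceAggregator`] built elsewhere):
    ///
    /// ```ignore
    /// // p = 0 is clamped to `log_floor` rather than producing -inf.
    /// let p = Array::from_vec(vec![0.0, 0.5, 1.0]).into_dyn();
    /// let lp = agg.to_log_space(&p).expect("valid probabilities");
    /// assert!(lp.iter().all(|v| v.is_finite()));
    /// ```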
    pub fn to_log_space(&self, probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
        // Validate that all values are in [0, 1]
        for &v in probs.iter() {
            if !v.is_finite() || !(0.0..=1.0).contains(&v) {
                return Err(ScoringError::InvalidProbability { value: v });
            }
        }
        let floor = self.config.log_floor;
        Ok(probs.mapv(|p| if p <= 0.0 { floor } else { p.ln().max(floor) }))
    }

    /// Convert log-probabilities back to probability space via `exp`.
    ///
    /// # Arguments
    /// * `log_probs` - Input log-probabilities (values in (-∞, 0]).
    pub fn from_log_space(&self, log_probs: &ArrayD<f64>) -> Result<ArrayD<f64>, ScoringError> {
        Ok(log_probs.mapv(|lp| lp.exp()))
    }
}

// ============================================================================
// WeightedQuantifier
// ============================================================================

/// Validate that input and weights share compatible shapes for a given axis.
///
/// Returns `Ok(())` on success; incompatible shapes and out-of-range axes are
/// reported as the corresponding [`ScoringError`] variants.
fn validate_weights_for_axis(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
) -> Result<(), ScoringError> {
    match axis {
        None => {
            // Weights must either have the same shape as the input or the
            // same total number of elements (so they can be reshaped to it).
            if weights.shape() != input.shape() && weights.len() != input.len() {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
        }
        Some(ax) => {
            if ax >= input.ndim() {
                return Err(ScoringError::AxisOutOfRange {
                    axis: ax,
                    ndim: input.ndim(),
                });
            }
            let expected_len = input.shape()[ax];
            // Weights must either match the input shape exactly, or be 1-D
            // with length equal to the size of the reduction axis.
            let compatible = weights.shape() == input.shape()
                || (weights.ndim() == 1 && weights.len() == expected_len);
            if !compatible {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
        }
    }
    Ok(())
}

/// Weighted soft-quantifier operations.
///
/// Provides differentiable, weight-aware approximations of logical quantifiers
/// that can be used in end-to-end training pipelines.
pub struct WeightedQuantifier {
    config: ScoringConfig,
}

impl WeightedQuantifier {
    /// Create a new quantifier with the given configuration.
    pub fn new(config: ScoringConfig) -> Self {
        Self { config }
    }

    /// Soft-exists: weighted mean (standard) or log-sum-exp with log-weights (log mode).
    ///
    /// **Standard mode**: `Σ w_i * x_i / Σ w_i`
    ///
    /// **LogProbability / LogOdds mode**: `log-sum-exp(log(w) + x) - log(Σ w_i)`
    ///
    /// # Arguments
    /// * `input`   - Input values.
    /// * `weights` - Non-negative importance weights (must broadcast with input along `axis`).
    /// * `axis`    - `None` → over all elements; `Some(k)` → along axis `k`.
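    ///
    /// # Example
    ///
    /// A worked standard-mode instance: with `x = [0.1, 0.9]` and
    /// `w = [1, 3]`, the result is `(1·0.1 + 3·0.9) / 4 = 0.7`. As an
    /// illustrative sketch (fence marked `ignore`):
    ///
    /// ```ignore
    /// let x = Array::from_vec(vec![0.1, 0.9]).into_dyn();
    /// let w = Array::from_vec(vec![1.0, 3.0]).into_dyn();
    /// let q = WeightedQuantifier::new(ScoringConfig::default());
    /// let y = q.weighted_exists(&x, &w, None).expect("compatible shapes");
    /// assert!((y[[]] - 0.7).abs() < 1e-12);
    /// ```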
    pub fn weighted_exists(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        match self.config.mode {
            ScoringMode::Standard => self.weighted_exists_standard(input, weights, axis),
            ScoringMode::LogProbability | ScoringMode::LogOdds => {
                self.weighted_exists_log(input, weights, axis)
            }
        }
    }

    fn weighted_exists_standard(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        // Broadcast weights to input shape if needed
        let w = broadcast_weights(weights, input, axis)?;

        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match axis {
            None => {
                let numerator: f64 = input.iter().zip(w.iter()).map(|(&x, &wi)| wi * x).sum();
                let result = numerator / weight_sum;
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let weighted = input * &w;
                let num = weighted.sum_axis(Axis(ax));
                // Per-slice weight sums
                let w_sum = w.sum_axis(Axis(ax));
                // Avoid division by zero per element
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum)
                    .map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
                Ok(result)
            }
        }
    }

    fn weighted_exists_log(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        // Broadcast weights to input shape
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }
        let log_norm = weight_sum.ln();
        let floor = self.config.log_floor;

        // log(w_i) + x_i, then log-sum-exp, minus log(Σ w_i)
        let log_w_plus_x = scirs2_core::ndarray::Zip::from(&w)
            .and(input)
            .map_collect(|&wi, &xi| {
                if wi <= 0.0 {
                    floor
                } else {
                    (wi.ln() + xi).max(floor)
                }
            });

        let agg = LogSpaceAggregator::new(self.config.clone());
        let lse = agg.log_sum_exp(&log_w_plus_x, axis)?;
        Ok(lse.mapv(|v| v - log_norm))
    }

    /// Soft-forall: weighted geometric mean (standard) or log-mean (log mode).
    ///
    /// **Standard mode**: `∏ x_i^(w_i/Σw_i)` = weighted geometric mean.
    ///
    /// **LogProbability / LogOdds mode**: `Σ (w_i/Σw_i) * x_i` (weighted arithmetic mean in log space).
    ///
    /// # Arguments
    /// * `input`   - Input values.
    /// * `weights` - Non-negative importance weights.
    /// * `axis`    - Reduction axis.
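    ///
    /// # Example
    ///
    /// A worked standard-mode instance: with `x = [0.25, 1.0]` and uniform
    /// weights, the result is the geometric mean `√(0.25 · 1.0) = 0.5`.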
    pub fn weighted_forall(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        match self.config.mode {
            ScoringMode::Standard => self.weighted_forall_standard(input, weights, axis),
            ScoringMode::LogProbability | ScoringMode::LogOdds => {
                self.weighted_forall_log(input, weights, axis)
            }
        }
    }

    fn weighted_forall_standard(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        // Compute log(x_i) * (w_i / Σw), sum along axis, then exp
        let log_input = input.mapv(|x| {
            if x <= 0.0 {
                self.config.log_floor
            } else {
                x.ln()
            }
        });

        match axis {
            None => {
                let log_geo: f64 = log_input
                    .iter()
                    .zip(w.iter())
                    .map(|(&lx, &wi)| lx * wi / weight_sum)
                    .sum();
                Ok(ArrayD::from_elem(IxDyn(&[]), log_geo.exp()))
            }
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax));
                let weighted_log = &log_input * &w;
                let num = weighted_log.sum_axis(Axis(ax));
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum_ax)
                    .map_collect(|&n, &ws| {
                        if ws == 0.0 {
                            1.0 // neutral element for geometric mean
                        } else {
                            (n / ws).exp()
                        }
                    });
                Ok(result)
            }
        }
    }

    fn weighted_forall_log(
        &self,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        // In log space, forall is the weighted arithmetic mean of log values.
        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match axis {
            None => {
                let result: f64 = input
                    .iter()
                    .zip(w.iter())
                    .map(|(&xi, &wi)| xi * wi / weight_sum)
                    .sum();
                Ok(ArrayD::from_elem(IxDyn(&[]), result))
            }
            Some(ax) => {
                let w_sum_ax = w.sum_axis(Axis(ax));
                let weighted = input * &w;
                let num = weighted.sum_axis(Axis(ax));
                let result = scirs2_core::ndarray::Zip::from(&num)
                    .and(&w_sum_ax)
                    .map_collect(|&n, &ws| if ws == 0.0 { 0.0 } else { n / ws });
                Ok(result)
            }
        }
    }

    /// Gradient of [`WeightedQuantifier::weighted_exists`] with respect to input.
    ///
    /// **Standard mode**: `∂/∂x_i (Σ w_j x_j / Σ w_j) = w_i / Σ w_j`
    ///
    /// Note: the same normalized-weight gradient is returned in log mode as a
    /// first-order approximation; the exact log-mode gradient would be a
    /// softmax over `log(w) + x`.
    ///
    /// # Arguments
    /// * `grad`    - Upstream gradient (same shape as the forward output).
    /// * `input`   - Forward input (used only for shape validation here).
    /// * `weights` - Same weights used in the forward pass.
    /// * `axis`    - Same axis used in the forward pass.
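    ///
    /// For example, with `w = [1, 3, 1]` (so `Σ w = 5`) the per-element
    /// gradient is `[0.2, 0.6, 0.2]`, scaled by the upstream gradient.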
    pub fn weighted_exists_grad(
        &self,
        grad: &ArrayD<f64>,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        // Normalized weights: w_i / Σ w
        let w_norm = w.mapv(|wi| wi / weight_sum);

        match axis {
            None => {
                // grad is scalar (0-D), broadcast to input shape
                let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                Ok(w_norm.mapv(|wn| wn * g_scalar))
            }
            Some(ax) => {
                // grad has axis `ax` removed; reinsert it for broadcasting
                let grad_expanded = grad.view().insert_axis(Axis(ax));
                Ok(&w_norm * &grad_expanded)
            }
        }
    }

    /// Gradient of [`WeightedQuantifier::weighted_forall`] with respect to input.
    ///
    /// **Standard mode**: geometric mean gradient.
    /// `∂/∂x_i ∏ x_j^(w_j/W) = (w_i/W) * ∏ x_j^(w_j/W) / x_i`
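    /// (This follows by writing the output as `exp(Σ_j (w_j/W) ln x_j)` and
    /// differentiating with respect to `x_i`.)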
    ///
    /// # Arguments
    /// * `grad`    - Upstream gradient (same shape as the forward output).
    /// * `input`   - Forward input (needed for the `∏ x^w` computation).
    /// * `weights` - Same weights used in the forward pass.
    /// * `axis`    - Same axis used in the forward pass.
    pub fn weighted_forall_grad(
        &self,
        grad: &ArrayD<f64>,
        input: &ArrayD<f64>,
        weights: &ArrayD<f64>,
        axis: Option<usize>,
    ) -> Result<ArrayD<f64>, ScoringError> {
        if input.is_empty() {
            return Err(ScoringError::EmptyInput);
        }
        validate_weights_for_axis(input, weights, axis)?;

        let w = broadcast_weights(weights, input, axis)?;
        let weight_sum: f64 = w.iter().sum();
        if weight_sum == 0.0 {
            return Err(ScoringError::ZeroWeightSum);
        }

        match self.config.mode {
            ScoringMode::Standard => {
                // Geometric mean gradient:
                // ∂out/∂x_i = (w_i/W) * out / x_i
                // out = ∏ x_j^(w_j/W)
                let log_input = input.mapv(|x| {
                    if x <= 0.0 {
                        self.config.log_floor
                    } else {
                        x.ln()
                    }
                });

                let forall_out = match axis {
                    None => {
                        let log_geo: f64 = log_input
                            .iter()
                            .zip(w.iter())
                            .map(|(&lx, &wi)| lx * wi / weight_sum)
                            .sum();
                        ArrayD::from_elem(input.raw_dim(), log_geo.exp())
                    }
                    Some(ax) => {
                        let w_sum_ax = w.sum_axis(Axis(ax));
                        let weighted_log = &log_input * &w;
                        let num = weighted_log.sum_axis(Axis(ax));
                        let out_no_ax = scirs2_core::ndarray::Zip::from(&num)
                            .and(&w_sum_ax)
                            .map_collect(|&n, &ws| if ws == 0.0 { 1.0 } else { (n / ws).exp() });
                        // broadcast back to the full input shape
                        out_no_ax
                            .insert_axis(Axis(ax))
                            .broadcast(input.raw_dim())
                            .map_or_else(|| Array::zeros(input.raw_dim()), |v| v.to_owned())
                    }
                };

                // ∂out/∂x_i = (w_i/W) * out / x_i
                let w_norm = w.mapv(|wi| wi / weight_sum);
                let scale = scirs2_core::ndarray::Zip::from(&w_norm)
                    .and(&forall_out)
                    .and(input)
                    .map_collect(|&wn, &out_v, &xi| {
                        if xi == 0.0 {
                            0.0
                        } else {
                            wn * out_v / xi
                        }
                    });

                match axis {
                    None => {
                        let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                        Ok(scale.mapv(|s| s * g_scalar))
                    }
                    Some(ax) => {
                        let grad_expanded = grad.view().insert_axis(Axis(ax));
                        Ok(&scale * &grad_expanded)
                    }
                }
            }
            ScoringMode::LogProbability | ScoringMode::LogOdds => {
                // In log mode forall is a weighted mean: ∂/∂x_i = w_i/W
                let w_norm = w.mapv(|wi| wi / weight_sum);
                match axis {
                    None => {
                        let g_scalar = grad.iter().next().copied().unwrap_or(0.0);
                        Ok(w_norm.mapv(|wn| wn * g_scalar))
                    }
                    Some(ax) => {
                        let grad_expanded = grad.view().insert_axis(Axis(ax));
                        Ok(&w_norm * &grad_expanded)
                    }
                }
            }
        }
    }
}

// ============================================================================
// Broadcast helpers
// ============================================================================

/// Broadcast `weights` to the full shape of `input`.
///
/// Supported cases:
/// - `weights.shape() == input.shape()` → clone
/// - 1-D weights along `axis` → expand dimensions
/// - flat weights with same length as input → reshape
fn broadcast_weights(
    weights: &ArrayD<f64>,
    input: &ArrayD<f64>,
    axis: Option<usize>,
) -> Result<ArrayD<f64>, ScoringError> {
    if weights.shape() == input.shape() {
        return Ok(weights.clone());
    }

    match axis {
        None => {
            // Flat weights: must have same total length as input
            if weights.len() != input.len() {
                return Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                });
            }
            // Reshape to input shape
            weights
                .clone()
                .into_shape_with_order(input.raw_dim())
                .map_err(|_| ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                })
        }
        Some(ax) => {
            if weights.ndim() == 1 && weights.len() == input.shape()[ax] {
                // Build broadcast shape: 1 everywhere except `ax`
                let mut shape = vec![1usize; input.ndim()];
                shape[ax] = input.shape()[ax];
                let reshaped = weights
                    .clone()
                    .into_shape_with_order(IxDyn(&shape))
                    .map_err(|_| ScoringError::ShapeMismatch {
                        input: input.shape().to_vec(),
                        weights: weights.shape().to_vec(),
                    })?;
                reshaped
                    .broadcast(input.raw_dim())
                    .map(|v| v.to_owned())
                    .ok_or_else(|| ScoringError::ShapeMismatch {
                        input: input.shape().to_vec(),
                        weights: weights.shape().to_vec(),
                    })
            } else {
                // Exact shape equality was already handled above.
                Err(ScoringError::ShapeMismatch {
                    input: input.shape().to_vec(),
                    weights: weights.shape().to_vec(),
                })
            }
        }
    }
}

// ============================================================================
// Free functions (gradient_ops.rs style)
// ============================================================================

/// Compute log-sum-exp with an explicit `ScoringConfig`.
///
/// Convenience wrapper around [`LogSpaceAggregator::log_sum_exp`].
pub fn log_sum_exp(
    input: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    LogSpaceAggregator::new(config).log_sum_exp(input, axis)
}

/// Compute weighted soft-exists with an explicit `ScoringConfig`.
///
/// Convenience wrapper around [`WeightedQuantifier::weighted_exists`].
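///
/// # Example
///
/// Illustrative sketch (fence marked `ignore`; assumes the crate path):
///
/// ```ignore
/// // With uniform weights the soft-exists reduces to the plain mean.
/// let x = Array::from_vec(vec![0.2, 0.4, 0.6]).into_dyn();
/// let w = Array::from_vec(vec![1.0, 1.0, 1.0]).into_dyn();
/// let y = weighted_soft_exists(&x, &w, None, ScoringConfig::default())
///     .expect("compatible shapes");
/// assert!((y[[]] - 0.4).abs() < 1e-12);
/// ```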
pub fn weighted_soft_exists(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    WeightedQuantifier::new(config).weighted_exists(input, weights, axis)
}

/// Compute weighted soft-forall with an explicit `ScoringConfig`.
///
/// Convenience wrapper around [`WeightedQuantifier::weighted_forall`].
pub fn weighted_soft_forall(
    input: &ArrayD<f64>,
    weights: &ArrayD<f64>,
    axis: Option<usize>,
    config: ScoringConfig,
) -> Result<ArrayD<f64>, ScoringError> {
    WeightedQuantifier::new(config).weighted_forall(input, weights, axis)
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    const EPS: f64 = 1e-9;

    fn config() -> ScoringConfig {
        ScoringConfig::default()
    }

    fn agg() -> LogSpaceAggregator {
        LogSpaceAggregator::new(config())
    }

    fn make_1d(data: Vec<f64>) -> ArrayD<f64> {
        Array::from_vec(data).into_dyn()
    }

    fn make_2d(data: Vec<Vec<f64>>) -> ArrayD<f64> {
        let rows = data.len();
        let cols = data[0].len();
        let flat: Vec<f64> = data.into_iter().flatten().collect();
        Array2::from_shape_vec((rows, cols), flat)
            .expect("valid shape")
            .into_dyn()
    }

    // -------------------------------------------------------------------------
    // 1. test_log_sum_exp_scalar
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_scalar() {
        let input = make_1d(vec![3.0]);
        let result = agg().log_sum_exp(&input, None).expect("log_sum_exp scalar");
        // log(exp(3)) == 3
        assert!(
            (result[[]] - 3.0).abs() < EPS,
            "expected 3.0, got {}",
            result[[]]
        );
    }

    // -------------------------------------------------------------------------
    // 2. test_log_sum_exp_zeros
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_zeros() {
        // log-sum-exp([0,0,0,0]) = log(4)
        let n = 4usize;
        let input = make_1d(vec![0.0; n]);
        let result = agg().log_sum_exp(&input, None).expect("log_sum_exp zeros");
        let expected = (n as f64).ln();
        assert!(
            (result[[]] - expected).abs() < EPS,
            "expected log({}), got {}",
            n,
            result[[]]
        );
    }

    // -------------------------------------------------------------------------
    // 3. test_log_sum_exp_vs_naive
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_vs_naive() {
        let vals = vec![1.0, 2.0, 3.0];
        let input = make_1d(vals.clone());
        let result = agg().log_sum_exp(&input, None).expect("vs naive");
        let naive = vals.iter().map(|&x| x.exp()).sum::<f64>().ln();
        assert!(
            (result[[]] - naive).abs() < 1e-10,
            "stable != naive: {} vs {}",
            result[[]],
            naive
        );
    }

    // -------------------------------------------------------------------------
    // 4. test_log_sum_exp_numerical_stability
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_numerical_stability() {
        // exp(300) ≈ 1.9e130 strains naive summation; max subtraction keeps
        // all intermediates near 1 and the result finite.
        let input = make_1d(vec![300.0, 299.0, 298.0]);
        let result = agg()
            .log_sum_exp(&input, None)
            .expect("numerical stability");
        assert!(
            result[[]].is_finite(),
            "result should be finite, got {}",
            result[[]]
        );
        // Should be close to 300 + log(1 + exp(-1) + exp(-2))
        let expected = 300.0 + (1.0 + (-1.0_f64).exp() + (-2.0_f64).exp()).ln();
        assert!((result[[]] - expected).abs() < 1e-10);
    }

    // -------------------------------------------------------------------------
    // 5. test_log_sum_exp_axis_0
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_axis_0() {
        // 2x3 matrix, reduce along axis 0 → shape [3]
        let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let result = agg().log_sum_exp(&input, Some(0)).expect("axis 0");
        assert_eq!(result.shape(), &[3]);
        for col in 0..3 {
            let a = (col + 1) as f64;
            let b = (col + 4) as f64;
            let expected = a.max(b) + (1.0 + (a.min(b) - a.max(b)).exp()).ln();
            assert!((result[[col]] - expected).abs() < 1e-10);
        }
    }

    // -------------------------------------------------------------------------
    // 6. test_log_sum_exp_axis_1
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_axis_1() {
        // 2x3 matrix, reduce along axis 1 → shape [2]
        let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let result = agg().log_sum_exp(&input, Some(1)).expect("axis 1");
        assert_eq!(result.shape(), &[2]);
        for row in 0..2 {
            let vals: Vec<f64> = (1..=3).map(|c| (row * 3 + c) as f64).collect();
            let expected_v = vals.iter().map(|&v| v.exp()).sum::<f64>().ln();
            assert!(
                (result[[row]] - expected_v).abs() < 1e-8,
                "row {}: {} vs {}",
                row,
                result[[row]],
                expected_v
            );
        }
    }

    // -------------------------------------------------------------------------
    // 7. test_log_sum_exp_full_reduction
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_sum_exp_full_reduction() {
        let input = make_2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let result = agg().log_sum_exp(&input, None).expect("full reduction");
        assert_eq!(result.shape(), &[] as &[usize]);
        let naive = (1.0_f64.exp() + 2.0_f64.exp() + 3.0_f64.exp() + 4.0_f64.exp()).ln();
        assert!((result[[]] - naive).abs() < 1e-8);
    }

    // -------------------------------------------------------------------------
    // 8. test_log_product_basic
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_product_basic() {
        // log(0.5) + log(0.25) = log(0.125)
        let input = make_1d(vec![0.5_f64.ln(), 0.25_f64.ln()]);
        let result = agg().log_product(&input, None).expect("log_product basic");
        let expected = 0.125_f64.ln();
        assert!((result[[]] - expected).abs() < 1e-10);
    }

    // -------------------------------------------------------------------------
    // 9. test_log_add_exp_symmetry
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_add_exp_symmetry() {
        let a = make_1d(vec![1.0, 2.0, 3.0]);
        let b = make_1d(vec![3.0, 1.0, 2.0]);
        let ab = agg().log_add_exp(&a, &b).expect("log_add_exp ab");
        let ba = agg().log_add_exp(&b, &a).expect("log_add_exp ba");
        for i in 0..3 {
            assert!(
                (ab[[i]] - ba[[i]]).abs() < EPS,
                "symmetry violated at {}",
                i
            );
        }
    }

    // -------------------------------------------------------------------------
    // 10. test_to_log_space_range
    // -------------------------------------------------------------------------
    #[test]
    fn test_to_log_space_range() {
        let probs = make_1d(vec![0.0, 0.1, 0.5, 0.9, 1.0]);
        let result = agg().to_log_space(&probs).expect("to_log_space");
        for &v in result.iter() {
            assert!(v <= 0.0, "log-probability must be <= 0, got {}", v);
        }
    }

    // -------------------------------------------------------------------------
    // 11. test_from_log_space_roundtrip
    // -------------------------------------------------------------------------
    #[test]
    fn test_from_log_space_roundtrip() {
        let probs = make_1d(vec![0.1, 0.5, 0.9]);
        let log_p = agg().to_log_space(&probs).expect("to_log_space");
        let recovered = agg().from_log_space(&log_p).expect("from_log_space");
        for i in 0..3 {
            assert!(
                (probs[[i]] - recovered[[i]]).abs() < 1e-12,
                "roundtrip failed at {}: {} != {}",
                i,
                probs[[i]],
                recovered[[i]]
            );
        }
    }

    // -------------------------------------------------------------------------
    // 12. test_log_floor_prevents_neg_inf
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_floor_prevents_neg_inf() {
        let probs = make_1d(vec![0.0, 0.5, 1.0]); // p=0 would give -inf
        let result = agg().to_log_space(&probs).expect("log_floor");
        for &v in result.iter() {
            assert!(v.is_finite(), "value should be finite, got {}", v);
        }
        assert!(result[[0]] <= 0.0, "floor should be <= 0");
    }

    // -------------------------------------------------------------------------
    // 13. test_weighted_exists_uniform_weights
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_exists_uniform_weights() {
        // Uniform weights → weighted mean = simple mean
        let input = make_1d(vec![0.2, 0.4, 0.6, 0.8]);
        let weights = make_1d(vec![1.0, 1.0, 1.0, 1.0]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_exists(&input, &weights, None)
            .expect("uniform weights");
        let expected = 0.5; // (0.2+0.4+0.6+0.8)/4
        assert!(
            (result[[]] - expected).abs() < EPS,
            "expected {}, got {}",
            expected,
            result[[]]
        );
    }

    // -------------------------------------------------------------------------
    // 14. test_weighted_exists_zero_weight_error
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_exists_zero_weight_error() {
        let input = make_1d(vec![0.5, 0.5]);
        let weights = make_1d(vec![0.0, 0.0]);
        let q = WeightedQuantifier::new(config());
        let result = q.weighted_exists(&input, &weights, None);
        assert!(
            matches!(result, Err(ScoringError::ZeroWeightSum)),
            "expected ZeroWeightSum error"
        );
    }

    // -------------------------------------------------------------------------
    // 15. test_weighted_exists_concentrated_weight
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_exists_concentrated_weight() {
        // All weight on the third element → result ≈ x[2]
        let input = make_1d(vec![0.1, 0.3, 0.7, 0.9]);
        let weights = make_1d(vec![0.0, 0.0, 1.0, 0.0]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_exists(&input, &weights, None)
            .expect("concentrated weight");
        assert!(
            (result[[]] - 0.7).abs() < EPS,
            "expected 0.7, got {}",
            result[[]]
        );
    }

    // -------------------------------------------------------------------------
    // 16. test_weighted_forall_uniform
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_forall_uniform() {
        // Uniform weights → geometric mean
        let vals = vec![0.5, 0.25, 1.0, 0.5];
        let input = make_1d(vals.clone());
        let weights = make_1d(vec![1.0; 4]);
        let q = WeightedQuantifier::new(config());
        let result = q
            .weighted_forall(&input, &weights, None)
            .expect("forall uniform");
        // geometric mean = (0.5 * 0.25 * 1.0 * 0.5)^(1/4)
        let geo: f64 = vals.iter().product::<f64>().powf(0.25);
        assert!(
            (result[[]] - geo).abs() < 1e-10,
            "expected {}, got {}",
            geo,
            result[[]]
        );
    }

    // -------------------------------------------------------------------------
    // 17. test_weighted_exists_gradient_shape
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_exists_gradient_shape() {
        let input = make_2d(vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]]);
        let weights = make_2d(vec![vec![1.0, 2.0, 1.0], vec![1.0, 2.0, 1.0]]);
        let q = WeightedQuantifier::new(config());
        // Forward along axis 1 → shape [2]
        let out = q
            .weighted_exists(&input, &weights, Some(1))
            .expect("forward");
        assert_eq!(out.shape(), &[2]);
        let grad = Array::ones(out.raw_dim());
        let d_input = q
            .weighted_exists_grad(&grad, &input, &weights, Some(1))
            .expect("grad");
        assert_eq!(
            d_input.shape(),
            input.shape(),
            "gradient should match input shape"
        );
    }

    // -------------------------------------------------------------------------
    // 18. test_weighted_exists_gradient_finite
    // -------------------------------------------------------------------------
    #[test]
    fn test_weighted_exists_gradient_finite() {
        let input = make_1d(vec![0.2, 0.5, 0.8]);
        let weights = make_1d(vec![1.0, 3.0, 1.0]);
        let q = WeightedQuantifier::new(config());
        let out = q.weighted_exists(&input, &weights, None).expect("forward");
        let grad = Array::ones(out.raw_dim());
        let d_input = q
            .weighted_exists_grad(&grad, &input, &weights, None)
            .expect("grad");
        for &v in d_input.iter() {
            assert!(v.is_finite(), "gradient must be finite, got {}", v);
        }
    }

    // -------------------------------------------------------------------------
    // 19. test_scoring_config_default
    // -------------------------------------------------------------------------
    #[test]
    fn test_scoring_config_default() {
        let cfg = ScoringConfig::default();
        assert_eq!(cfg.mode, ScoringMode::Standard);
        assert!((cfg.temperature - 1.0).abs() < EPS);
        assert!(cfg.log_floor < -100.0, "log_floor should be very negative");
        assert!(cfg.log_floor.is_finite(), "log_floor must be finite");
    }

    // -------------------------------------------------------------------------
    // 20. test_scoring_config_builders
    // -------------------------------------------------------------------------
    #[test]
    fn test_scoring_config_builders() {
        let lp = ScoringConfig::log_probability();
        assert_eq!(lp.mode, ScoringMode::LogProbability);

        let lo = ScoringConfig::log_odds();
        assert_eq!(lo.mode, ScoringMode::LogOdds);

        let with_t = ScoringConfig::default().with_temperature(0.5);
        assert!((with_t.temperature - 0.5).abs() < EPS);
    }

    // -------------------------------------------------------------------------
    // 21. test_free_function_log_sum_exp
    // -------------------------------------------------------------------------
    #[test]
    fn test_free_function_log_sum_exp() {
        let input = make_1d(vec![0.0, 0.0, 0.0]);
        let result = log_sum_exp(&input, None, config()).expect("free fn log_sum_exp");
        let expected = (3.0_f64).ln();
        assert!((result[[]] - expected).abs() < EPS);
    }

    // -------------------------------------------------------------------------
    // 22. test_log_space_quantifier_mode_via_gradient_ops
    // -------------------------------------------------------------------------
    #[test]
    fn test_log_space_quantifier_mode_via_gradient_ops() {
        use crate::gradient_ops::{soft_exists, QuantifierMode};

        // LogSpace mode should delegate to LogSpaceAggregator::log_sum_exp
        let input = make_1d(vec![0.0, 0.0, 0.0]);
        let scoring_cfg = ScoringConfig::log_probability();
        let mode = QuantifierMode::LogSpace(scoring_cfg);
        let result = soft_exists(&input, None, mode).expect("log_space quantifier");
        let expected = (3.0_f64).ln(); // log-sum-exp([0,0,0]) = log(3)
        assert!(
            (result[[]] - expected).abs() < 1e-10,
            "expected log(3)={}, got {}",
            expected,
            result[[]]
        );
    }
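
    // -------------------------------------------------------------------------
    // 23. test_weighted_exists_grad_matches_finite_difference (added sketch)
    // -------------------------------------------------------------------------
    // Added check sketch: verifies the analytic standard-mode gradient
    // (w_i / Σ w) against a central finite difference. The step size and
    // tolerance below are assumptions, not values from the original module.
    #[test]
    fn test_weighted_exists_grad_matches_finite_difference() {
        let input = make_1d(vec![0.2, 0.5, 0.8]);
        let weights = make_1d(vec![1.0, 3.0, 1.0]);
        let q = WeightedQuantifier::new(config());
        let out = q.weighted_exists(&input, &weights, None).expect("forward");
        let grad = Array::ones(out.raw_dim());
        let analytic = q
            .weighted_exists_grad(&grad, &input, &weights, None)
            .expect("grad");
        let h = 1e-6;
        for i in 0..3 {
            // Perturb one coordinate at a time and compare slopes.
            let mut plus = input.clone();
            plus[[i]] += h;
            let mut minus = input.clone();
            minus[[i]] -= h;
            let f_plus = q.weighted_exists(&plus, &weights, None).expect("f+")[[]];
            let f_minus = q.weighted_exists(&minus, &weights, None).expect("f-")[[]];
            let numeric = (f_plus - f_minus) / (2.0 * h);
            assert!(
                (analytic[[i]] - numeric).abs() < 1e-6,
                "gradient mismatch at {}: analytic {} vs numeric {}",
                i,
                analytic[[i]],
                numeric
            );
        }
    }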
}