Skip to main content

uni_locy/
semiring.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Probability semiring abstraction for Locy aggregation.
5//!
6//! The [`LocySemiring`] trait lifts MNOR (noisy-OR) and MPROD (product) off
7//! the hard-coded `match FoldAggKind { Nor, Prod }` arms in the runtime and
8//! onto a typed abstraction so future semirings (Scallop-style `TopKProofs`,
9//! gradient lifts) can drop in without re-shaping the planner.
10//!
11//! ### Scope
12//!
13//! The trait is **row-at-a-time**: it composes per-tuple tags via
14//! [`plus`](LocySemiring::plus) (disjunction) and [`times`](LocySemiring::times)
15//! (conjunction). This covers `AddMultProb` (independent noisy-OR / product)
16//! and `MaxMinProb` (Viterbi). It deliberately does **not** cover
17//! [`SemiringKind::BddExact`], which operates over a whole aggregation group's
18//! lineage at once via weighted model counting (see
19//! `crates/uni-query/src/query/df_graph/locy_bdd.rs`). `BddExact` is
20//! dispatched at the fixpoint level outside this trait; C0 will absorb it
21//! once tag-DNFs land.
22//!
//! See `DEEP_LOCY_IMPLEMENTATION_PLAN.md` in the `uni-locy-docs` repository,
//! §1.6 (decision D-7), for the design rationale.
25
26use crate::types::SemiringKind;
27
28/// Domain / unsupported-operation error from a semiring.
29///
30/// Callers map this into their own error type — the semiring layer is
31/// deliberately decoupled from DataFusion's error type so `uni-locy` can
32/// remain free of a query-engine dependency.
33#[derive(Debug, Clone, PartialEq)]
34pub enum SemiringError {
35    /// A probability input fell outside `[0, 1]` and `strict_probability_domain`
36    /// was set. `op` is `"MNOR"` or `"MPROD"`.
37    DomainViolation { value: f64, op: &'static str },
38    /// Operation not supported by this semiring (e.g., `negate` on a
39    /// non-Boolean-tagged semiring once C0 lands).
40    NotSupported {
41        op: &'static str,
42        kind: SemiringKind,
43    },
44}
45
46impl std::fmt::Display for SemiringError {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        match self {
49            Self::DomainViolation { value, op } => write!(
50                f,
51                "strict_probability_domain: {op} input {value} is outside [0, 1]"
52            ),
53            Self::NotSupported { op, kind } => {
54                write!(f, "semiring {kind:?} does not support {op}")
55            }
56        }
57    }
58}
59
60impl std::error::Error for SemiringError {}
61
62/// Domain-check a raw scalar before feeding it to `plus` / `times`.
63///
64/// Returns the value unchanged when it lies in `[0, 1]`. Outside that
65/// range it returns [`SemiringError::DomainViolation`] when `strict` is
66/// set, otherwise clamps and emits the pre-refactor tracing literal
67/// (`"<op> input <raw> outside [0,1], clamped to <clamped>"`) so any
68/// string-asserting tests remain stable. Shared by every `f64`-tagged
69/// [`LocySemiring::validate_domain`] impl.
70pub fn validate_probability_domain(
71    raw: f64,
72    op: &'static str,
73    strict: bool,
74) -> Result<f64, SemiringError> {
75    if (0.0..=1.0).contains(&raw) {
76        return Ok(raw);
77    }
78    if strict {
79        return Err(SemiringError::DomainViolation { value: raw, op });
80    }
81    let clamped = raw.clamp(0.0, 1.0);
82    tracing::warn!("{op} input {raw} outside [0,1], clamped to {clamped}");
83    Ok(clamped)
84}
85
86/// Consolidated semiring configuration threaded through planner and
87/// executors. Constructed by [`crate::LocyConfig::resolve`].
88#[derive(Debug, Clone, Copy, PartialEq)]
89pub struct ResolvedSemiringConfig {
90    pub kind: SemiringKind,
91    pub strict_probability_domain: bool,
92    pub probability_epsilon: f64,
93    pub max_bdd_variables: usize,
94}
95
96impl ResolvedSemiringConfig {
97    /// Returns true when the active semiring is the default `AddMultProb`
98    /// path. Phase-3 shared-proof detection and the AddMultProb-specific
99    /// complement code paths gate on this.
100    pub fn is_add_mult_prob(&self) -> bool {
101        matches!(self.kind, SemiringKind::AddMultProb)
102    }
103}
104
105/// Row-at-a-time probability semiring.
106///
107/// Implementors carry their own state (e.g., underflow epsilon for
108/// `AddMultProb`'s log-space switch). All operations are pure and
109/// reentrant so the same semiring instance is safely shared across the
110/// fixpoint loop.
111pub trait LocySemiring: Send + Sync + 'static {
112    /// The per-tuple tag type. For the two Phase-A semirings this is
113    /// `f64`; C0's `TopKProofs` will carry a proof-DNF tag.
114    type Tag: Clone + Send + Sync;
115
116    fn kind(&self) -> SemiringKind;
117
118    /// Whether this semiring composes pointwise via [`plus`](Self::plus) and
119    /// [`times`](Self::times). Returns `true` for `AddMultProb` and
120    /// `MaxMinProb`; `false` for whole-group semirings such as `BddExact`
121    /// which are dispatched outside this trait.
122    fn is_row_at_a_time(&self) -> bool {
123        true
124    }
125
126    /// Identity for [`plus`](Self::plus) — `0.0` for both Phase-A semirings.
127    fn zero_disjunction(&self) -> Self::Tag;
128
129    /// Identity for [`times`](Self::times) — `1.0` for both Phase-A semirings.
130    fn one_conjunction(&self) -> Self::Tag;
131
132    /// Disjunction. MNOR / proof-OR.
133    fn plus(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;
134
135    /// Conjunction. MPROD / proof-AND.
136    fn times(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;
137
138    /// Complement (`1 - p` conventionally). May return
139    /// [`SemiringError::NotSupported`] for semirings whose tags do not
140    /// admit a complement.
141    fn negate(&self, a: &Self::Tag) -> Result<Self::Tag, SemiringError>;
142
143    /// Collapse a tag to a probability in `[0, 1]`.
144    fn weight(&self, a: &Self::Tag) -> f64;
145
146    /// Domain-check a raw scalar before feeding it to `plus` / `times`.
147    /// Returns the clamped value, or `DomainViolation` when `strict` is
148    /// set and `raw` falls outside `[0, 1]`. Emits the same tracing literal
149    /// (`"MNOR input ..."` / `"MPROD input ..."`) used in the pre-refactor
150    /// code so any string-asserting tests remain stable.
151    fn validate_domain(
152        &self,
153        raw: f64,
154        op: &'static str,
155        strict: bool,
156    ) -> Result<f64, SemiringError>;
157}
158
159// ---------------------------------------------------------------------------
160// AddMultProb — Phase 1/2 default. Independence-assumed noisy-OR and product.
161// ---------------------------------------------------------------------------
162
/// Independence-assuming probability semiring.
///
/// `plus` is noisy-OR, `times` is the product, and `negate` is `1 - p`.
///
/// The value is stateful: it holds `probability_epsilon`, the threshold at
/// which `times` takes its log-space underflow branch (spec §5.3). Keeping
/// the guard here avoids re-implementing it in every executor.
#[derive(Debug, Clone, Copy)]
pub struct AddMultProb {
    /// Operands below this value send `times` down its underflow branch.
    pub probability_epsilon: f64,
}

impl AddMultProb {
    /// Builds the semiring with an explicit underflow threshold.
    pub fn new(probability_epsilon: f64) -> Self {
        AddMultProb { probability_epsilon }
    }
}

impl Default for AddMultProb {
    /// Uses an underflow threshold of `1e-15`.
    fn default() -> Self {
        AddMultProb::new(1e-15)
    }
}
189
190impl LocySemiring for AddMultProb {
191    type Tag = f64;
192
193    fn kind(&self) -> SemiringKind {
194        SemiringKind::AddMultProb
195    }
196
197    fn zero_disjunction(&self) -> f64 {
198        0.0
199    }
200
201    fn one_conjunction(&self) -> f64 {
202        1.0
203    }
204
205    fn plus(&self, a: &f64, b: &f64) -> f64 {
206        1.0 - (1.0 - *a) * (1.0 - *b)
207    }
208
209    fn times(&self, a: &f64, b: &f64) -> f64 {
210        // Log-space switch when the running product drops below epsilon.
211        if *a < self.probability_epsilon || *b < self.probability_epsilon {
212            let la = a.max(self.probability_epsilon).ln();
213            let lb = b.max(self.probability_epsilon).ln();
214            (la + lb).exp()
215        } else {
216            *a * *b
217        }
218    }
219
220    fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
221        Ok(1.0 - *a)
222    }
223
224    fn weight(&self, a: &f64) -> f64 {
225        *a
226    }
227
228    fn validate_domain(
229        &self,
230        raw: f64,
231        op: &'static str,
232        strict: bool,
233    ) -> Result<f64, SemiringError> {
234        validate_probability_domain(raw, op, strict)
235    }
236}
237
238// ---------------------------------------------------------------------------
239// MaxMinProb — Viterbi / fuzzy. Opt-in only; triggers FuzzyNotProbabilistic.
240// ---------------------------------------------------------------------------
241
242/// `(plus = max, times = min, negate = 1 - p)`.
243///
244/// This is **fuzzy logic**, not probability. Any PROB-bearing rule
245/// evaluated under this semiring produces a non-suppressible
246/// `RuntimeWarningCode::FuzzyNotProbabilistic` (rollout decision D-9).
247#[derive(Debug, Clone, Copy, Default)]
248pub struct MaxMinProb;
249
250impl LocySemiring for MaxMinProb {
251    type Tag = f64;
252
253    fn kind(&self) -> SemiringKind {
254        SemiringKind::MaxMinProb
255    }
256
257    fn zero_disjunction(&self) -> f64 {
258        0.0
259    }
260
261    fn one_conjunction(&self) -> f64 {
262        1.0
263    }
264
265    fn plus(&self, a: &f64, b: &f64) -> f64 {
266        a.max(*b)
267    }
268
269    fn times(&self, a: &f64, b: &f64) -> f64 {
270        a.min(*b)
271    }
272
273    fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
274        Ok(1.0 - *a)
275    }
276
277    fn weight(&self, a: &f64) -> f64 {
278        *a
279    }
280
281    fn validate_domain(
282        &self,
283        raw: f64,
284        op: &'static str,
285        strict: bool,
286    ) -> Result<f64, SemiringError> {
287        validate_probability_domain(raw, op, strict)
288    }
289}
290
291// ---------------------------------------------------------------------------
292// SemiringDispatch — runtime-selectable concrete type for executors.
293// ---------------------------------------------------------------------------
294
295/// Concrete enum dispatching to the active row-at-a-time semiring.
296/// Used by `MonotonicAggState` and `FoldExec` instead of a
297/// `Box<dyn LocySemiring>` so that the per-row `plus`/`times` calls stay
298/// inlineable (the match is a small branch; LLVM specializes through).
299///
300/// * `SemiringKind::BddExact` maps to `AddMultProb` at the row level —
301///   the BDD post-correction runs over the same independence-mode
302///   accumulators (see `weighted_model_count` in `locy_bdd.rs`).
303/// * `SemiringKind::TopKProofs { k }` likewise dispatches to
304///   `AddMultProb` at the row level in **Stage 1** (this Phase C C0
305///   slice): the library-layer `TopKProofs<K>` impl in
306///   `crate::top_k_proofs` carries true tag math, but the runtime
307///   hot-path operates on `f64`. Stage 2 wires `TopKTag` flow through
308///   `MonotonicAggState` / `FoldExec` / record-batch encoding.
309#[derive(Debug, Clone, Copy)]
310pub enum SemiringDispatch {
311    AddMultProb(AddMultProb),
312    MaxMinProb(MaxMinProb),
313    TopKProofs { inner: AddMultProb, k: u32 },
314}
315
316impl SemiringDispatch {
317    /// Build a dispatch from the resolved kind and the underflow epsilon.
318    /// `BddExact` collapses to `AddMultProb` row math (post-correction
319    /// runs separately at the fixpoint level); `TopKProofs` likewise
320    /// in Stage 1.
321    pub fn new(kind: SemiringKind, probability_epsilon: f64) -> Self {
322        match kind {
323            SemiringKind::AddMultProb | SemiringKind::BddExact => {
324                Self::AddMultProb(AddMultProb::new(probability_epsilon))
325            }
326            SemiringKind::MaxMinProb => Self::MaxMinProb(MaxMinProb),
327            SemiringKind::TopKProofs { k } => {
328                // Stage 1 of Phase C C0: runtime tag flow is pending,
329                // so this dispatches to AddMultProb row math. Library
330                // users wanting the true `TopKTag` math should call
331                // `crate::top_k_proofs::TopKProofs::<K>` directly.
332                //
333                // The warn fires per executor creation rather than per
334                // row, so cost is negligible; the message helps users
335                // understand why their `TopKProofs` results match
336                // `AddMultProb` byte-for-byte until Stage 2.
337                tracing::warn!(
338                    "TopKProofs(k={k}) runtime tag flow pending Stage 2 — \
339                     falling back to AddMultProb row math; library-layer \
340                     TopKProofs<K> math is available via uni_locy::top_k_proofs"
341                );
342                Self::TopKProofs {
343                    inner: AddMultProb::new(probability_epsilon),
344                    k,
345                }
346            }
347        }
348    }
349
350    pub fn kind(&self) -> SemiringKind {
351        match self {
352            Self::AddMultProb(sr) => sr.kind(),
353            Self::MaxMinProb(sr) => sr.kind(),
354            // Phase C C0 Stage 1 dispatch reports the underlying row
355            // math kind so existing AddMultProb-gated runtime paths
356            // (Phase-3 detector, BDD correction site) continue to fire
357            // correctly. Callers needing the original kind read it
358            // from `ResolvedSemiringConfig.kind` instead.
359            Self::TopKProofs { inner, .. } => inner.kind(),
360        }
361    }
362
363    /// Returns the `k` parameter when the dispatch is `TopKProofs`.
364    /// Stage 2 callers use this to find the K at the row-eval site
365    /// where they materialize tags. `None` for other semirings.
366    pub fn top_k(&self) -> Option<u32> {
367        match self {
368            Self::TopKProofs { k, .. } => Some(*k),
369            _ => None,
370        }
371    }
372
373    pub fn plus(&self, a: f64, b: f64) -> f64 {
374        match self {
375            Self::AddMultProb(sr) => sr.plus(&a, &b),
376            Self::MaxMinProb(sr) => sr.plus(&a, &b),
377            Self::TopKProofs { inner, .. } => inner.plus(&a, &b),
378        }
379    }
380
381    pub fn times(&self, a: f64, b: f64) -> f64 {
382        match self {
383            Self::AddMultProb(sr) => sr.times(&a, &b),
384            Self::MaxMinProb(sr) => sr.times(&a, &b),
385            Self::TopKProofs { inner, .. } => inner.times(&a, &b),
386        }
387    }
388
389    pub fn validate_domain(
390        &self,
391        raw: f64,
392        op: &'static str,
393        strict: bool,
394    ) -> Result<f64, SemiringError> {
395        match self {
396            Self::AddMultProb(sr) => sr.validate_domain(raw, op, strict),
397            Self::MaxMinProb(sr) => sr.validate_domain(raw, op, strict),
398            Self::TopKProofs { inner, .. } => inner.validate_domain(raw, op, strict),
399        }
400    }
401
402    // -----------------------------------------------------------------
403    // Stage 2 tag-flow surface (pending wiring).
404    //
405    // The methods below (`plus_tag` / `times_tag` / `zero_tag` /
406    // `singleton_tag` / `weight_of`) and the [`AggregatorValue`] enum
407    // are the typed-tag interface that Stage 2 will plumb through
408    // `MonotonicAggState` / `FoldExec` / record-batch encoding so the
409    // `TopKProofs` runtime path stops falling back to `AddMultProb` row
410    // math (see `SemiringDispatch::new`). They have no non-test callers
411    // yet — kept here, compiled and unit-tested, so the Stage 2 wiring
412    // lands against an already-validated surface rather than re-deriving
413    // it. Do not delete pending Stage 2.
414    // -----------------------------------------------------------------
415
416    /// Phase C C0 Stage 2: tag-level `plus` that supports both f64
417    /// semirings (AddMultProb / MaxMinProb) and the proof-tag
418    /// semiring (TopKProofs). Returns the merged value plus an
419    /// optional `PruneNotice` callers use to drive
420    /// `RuntimeWarningCode::TopKPruningCrossedDependency` emission.
421    /// Existing `plus(a: f64, b: f64) -> f64` stays unchanged for
422    /// hot paths that don't carry proof tags.
423    pub fn plus_tag(
424        &self,
425        a: &AggregatorValue,
426        b: &AggregatorValue,
427    ) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
428        match (self, a, b) {
429            (Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
430                (AggregatorValue::F64(sr.plus(x, y)), None)
431            }
432            (Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
433                (AggregatorValue::F64(sr.plus(x, y)), None)
434            }
435            (Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
436                let (proofs, notice) = merge_top_k_dispatch(ta, tb, *k as usize);
437                (
438                    AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
439                    Some(notice),
440                )
441            }
442            // Type mismatch between dispatch arm and value variant —
443            // indicates a callsite bug (constructed the wrong
444            // AggregatorValue for the active semiring).
445            _ => unreachable!(
446                "SemiringDispatch::plus_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
447                self.kind(),
448                std::mem::discriminant(a),
449                std::mem::discriminant(b),
450            ),
451        }
452    }
453
454    /// Phase C C0 Stage 2: tag-level `times`. Same contract as
455    /// `plus_tag`.
456    pub fn times_tag(
457        &self,
458        a: &AggregatorValue,
459        b: &AggregatorValue,
460    ) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
461        match (self, a, b) {
462            (Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
463                (AggregatorValue::F64(sr.times(x, y)), None)
464            }
465            (Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
466                (AggregatorValue::F64(sr.times(x, y)), None)
467            }
468            (Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
469                // Cartesian product per the library impl; reuse
470                // merge_top_k for dedup + pruning.
471                if ta.proofs.is_empty() || tb.proofs.is_empty() {
472                    return (
473                        AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
474                        None,
475                    );
476                }
477                let mut cart: Vec<crate::top_k_proofs::Proof> =
478                    Vec::with_capacity(ta.proofs.len() * tb.proofs.len());
479                for pa in &ta.proofs {
480                    for pb in &tb.proofs {
481                        let mut nc = pa.neural_calls.clone();
482                        let existing: std::collections::HashSet<u32> =
483                            pa.neural_calls.iter().map(|c| c.0).collect();
484                        for c in &pb.neural_calls {
485                            if !existing.contains(&c.0) {
486                                nc.push(*c);
487                            }
488                        }
489                        cart.push(crate::top_k_proofs::Proof {
490                            weight: pa.weight * pb.weight,
491                            base_rvs: crate::dependency_dnf::BaseRvSet::union(
492                                &pa.base_rvs,
493                                &pb.base_rvs,
494                            ),
495                            neural_calls: nc,
496                        });
497                    }
498                }
499                let (proofs, notice) = merge_top_k_dispatch_owned(Vec::new(), cart, *k as usize);
500                (
501                    AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
502                    Some(notice),
503                )
504            }
505            _ => unreachable!(
506                "SemiringDispatch::times_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
507                self.kind(),
508                std::mem::discriminant(a),
509                std::mem::discriminant(b),
510            ),
511        }
512    }
513
514    /// Phase C C0 Stage 2: return the additive-identity value for
515    /// the active semiring. Used by `MonotonicAggState` to
516    /// initialize new accumulator slots.
517    pub fn zero_tag(&self) -> AggregatorValue {
518        match self {
519            Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(0.0),
520            Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
521        }
522    }
523
524    /// Phase C C0 Stage 2: lift a row's f64 weight into the
525    /// dispatch's tag type. For `AddMultProb` / `MaxMinProb` this is
526    /// just `F64(w)`; for `TopKProofs` the caller supplies the
527    /// row's `base_rvs` and `neural_calls` so a single-Proof tag is
528    /// materialized.
529    pub fn singleton_tag(
530        &self,
531        weight: f64,
532        base_rvs: crate::dependency_dnf::BaseRvSet,
533        neural_calls: Vec<crate::top_k_proofs::NeuralCallId>,
534    ) -> AggregatorValue {
535        match self {
536            Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(weight),
537            Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag {
538                proofs: vec![crate::top_k_proofs::Proof {
539                    weight,
540                    base_rvs,
541                    neural_calls,
542                }],
543            }),
544        }
545    }
546
547    /// Phase C C0 Stage 2: collapse an aggregator value to its
548    /// scalar probability for downstream f64-typed consumers
549    /// (record-batch encoding, BDD post-correction site, etc.).
550    pub fn weight_of(&self, value: &AggregatorValue) -> f64 {
551        match (self, value) {
552            (Self::AddMultProb(_) | Self::MaxMinProb(_), AggregatorValue::F64(v)) => *v,
553            (Self::TopKProofs { .. }, AggregatorValue::TopK(t)) => {
554                // Conservative weight per library impl: noisy-OR
555                // over proof weights under independence-mode.
556                let mut complement = 1.0;
557                for p in &t.proofs {
558                    complement *= 1.0 - p.weight;
559                }
560                (1.0 - complement).clamp(0.0, 1.0)
561            }
562            _ => unreachable!(
563                "SemiringDispatch::weight_of: type mismatch — dispatch {:?} vs {:?}",
564                self.kind(),
565                std::mem::discriminant(value),
566            ),
567        }
568    }
569}
570
571/// Phase C C0 Stage 2: tag-typed accumulator value used by
572/// `MonotonicAggState`. The variant must match the active
573/// `SemiringDispatch` — `F64` for `AddMultProb` / `MaxMinProb`,
574/// `TopK` for `TopKProofs`. Cross-type pairs panic in
575/// `plus_tag` / `times_tag` (callsite bug).
576#[derive(Debug, Clone)]
577pub enum AggregatorValue {
578    F64(f64),
579    TopK(crate::top_k_proofs::TopKTag),
580}
581
582impl AggregatorValue {
583    /// Convenience constructor for f64 callers that want to
584    /// initialize an accumulator from a row weight under
585    /// `SemiringDispatch::AddMultProb` / `MaxMinProb`.
586    pub fn f64(v: f64) -> Self {
587        AggregatorValue::F64(v)
588    }
589}
590
591/// Helper that delegates to [`crate::top_k_proofs::merge_top_k_with`]
592/// over cloned proof lists. The library impl is generic over `K`
593/// (compile-time const); the runtime needs a value-level `k`.
594fn merge_top_k_dispatch(
595    a: &crate::top_k_proofs::TopKTag,
596    b: &crate::top_k_proofs::TopKTag,
597    k: usize,
598) -> (
599    Vec<crate::top_k_proofs::Proof>,
600    crate::top_k_proofs::PruneNotice,
601) {
602    merge_top_k_dispatch_owned(a.proofs.clone(), b.proofs.clone(), k)
603}
604
605/// Phase C C0 Stage 2: runtime-K merge over owned proof lists.
606/// Delegates to [`crate::top_k_proofs::merge_top_k_with`]; exposed
607/// `pub` (re-exported as `merge_top_k_runtime`) so the fixpoint loop
608/// can call it directly without constructing `AggregatorValue`
609/// wrappers when it has owned `Vec<Proof>` already.
610pub fn merge_top_k_dispatch_owned(
611    base: Vec<crate::top_k_proofs::Proof>,
612    additional: Vec<crate::top_k_proofs::Proof>,
613    k: usize,
614) -> (
615    Vec<crate::top_k_proofs::Proof>,
616    crate::top_k_proofs::PruneNotice,
617) {
618    crate::top_k_proofs::merge_top_k_with(base, additional, k)
619}
620
621impl Default for SemiringDispatch {
622    fn default() -> Self {
623        Self::AddMultProb(AddMultProb::default())
624    }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_mult_prob_matches_pre_refactor_noisy_or() {
        let semiring = AddMultProb::default();
        // DEEP_LOCY.md §4.5: MNOR over [0.72, 0.54, 0.56, 0.42] ≈ 0.9671.
        // Exact by hand: 1 - 0.28 * 0.46 * 0.44 * 0.58 = 0.96713024.
        let combined = [0.72, 0.54, 0.56, 0.42]
            .iter()
            .fold(semiring.zero_disjunction(), |acc, p| semiring.plus(&acc, p));
        assert!((combined - 0.967_130_24).abs() < 1e-9, "got {combined}");
    }

    #[test]
    fn add_mult_prob_product_underflow_safe() {
        // Both operands sit below epsilon, which forces the log-space branch.
        let product = AddMultProb::new(1e-12).times(&1e-20, &1e-20);
        assert!(product.is_finite());
        assert!(product >= 0.0);
    }

    #[test]
    fn max_min_prob_viterbi() {
        let fuzzy = MaxMinProb;
        assert_eq!(fuzzy.plus(&0.3, &0.7), 0.7);
        assert_eq!(fuzzy.times(&0.3, &0.7), 0.3);
    }

    #[test]
    fn strict_domain_violation() {
        let semiring = AddMultProb::default();
        let strict = semiring.validate_domain(1.5, "MNOR", true);
        assert!(matches!(strict, Err(SemiringError::DomainViolation { .. })));
        let lenient = semiring.validate_domain(1.5, "MNOR", false);
        assert_eq!(lenient.unwrap(), 1.0);
    }

    #[test]
    fn max_min_prob_strict_domain_violation() {
        let fuzzy = MaxMinProb;
        let strict = fuzzy.validate_domain(-0.1, "MPROD", true);
        assert!(matches!(strict, Err(SemiringError::DomainViolation { .. })));
        assert_eq!(fuzzy.validate_domain(-0.1, "MPROD", false).unwrap(), 0.0);
        assert_eq!(fuzzy.validate_domain(2.0, "MNOR", false).unwrap(), 1.0);
    }

    #[test]
    fn identities_are_correct() {
        // The empty MNOR is 0.0 (additive identity) and the empty MPROD is
        // 1.0 (multiplicative identity). Both semirings agree on this, and
        // FoldExec relies on it when a key group has no input rows.
        let noisy = AddMultProb::default();
        assert_eq!(noisy.zero_disjunction(), 0.0);
        assert_eq!(noisy.one_conjunction(), 1.0);
        let fuzzy = MaxMinProb;
        assert_eq!(fuzzy.zero_disjunction(), 0.0);
        assert_eq!(fuzzy.one_conjunction(), 1.0);
    }

    #[test]
    fn dispatch_routes_to_correct_impl() {
        // Identical operands fed to different semirings must give different
        // answers. Otherwise SemiringDispatch has collapsed the two.
        let noisy = SemiringDispatch::new(SemiringKind::AddMultProb, 1e-15);
        let fuzzy = SemiringDispatch::new(SemiringKind::MaxMinProb, 1e-15);
        assert_eq!(noisy.plus(0.3, 0.5), 1.0 - 0.7 * 0.5); // 0.65
        assert_eq!(fuzzy.plus(0.3, 0.5), 0.5);
        assert_eq!(noisy.times(0.3, 0.5), 0.15);
        assert_eq!(fuzzy.times(0.3, 0.5), 0.3);

        // BddExact runs AddMultProb row math, and the BDD post-correction
        // happens separately at the fixpoint level. Its dispatch therefore
        // reports AddMultProb by design; the original kind lives on
        // `ResolvedSemiringConfig`.
        let bdd = SemiringDispatch::new(SemiringKind::BddExact, 1e-15);
        assert_eq!(bdd.kind(), SemiringKind::AddMultProb);
        assert_eq!(bdd.plus(0.3, 0.5), 1.0 - 0.7 * 0.5);
    }
}