uni_locy/semiring.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Probability semiring abstraction for Locy aggregation.
5//!
6//! The [`LocySemiring`] trait lifts MNOR (noisy-OR) and MPROD (product) off
7//! the hard-coded `match FoldAggKind { Nor, Prod }` arms in the runtime and
8//! onto a typed abstraction so future semirings (Scallop-style `TopKProofs`,
9//! gradient lifts) can drop in without re-shaping the planner.
10//!
11//! ### Scope
12//!
13//! The trait is **row-at-a-time**: it composes per-tuple tags via
14//! [`plus`](LocySemiring::plus) (disjunction) and [`times`](LocySemiring::times)
15//! (conjunction). This covers `AddMultProb` (independent noisy-OR / product)
16//! and `MaxMinProb` (Viterbi). It deliberately does **not** cover
17//! [`SemiringKind::BddExact`], which operates over a whole aggregation group's
18//! lineage at once via weighted model counting (see
19//! `crates/uni-query/src/query/df_graph/locy_bdd.rs`). `BddExact` is
20//! dispatched at the fixpoint level outside this trait; C0 will absorb it
21//! once tag-DNFs land.
22//!
23//! See `/home/rohit/work/dragonscale/uni-locy-docs/DEEP_LOCY_IMPLEMENTATION_PLAN.md`
24//! §1.6 (decision D-7) for the design rationale.
25
26use crate::types::SemiringKind;
27
28/// Domain / unsupported-operation error from a semiring.
29///
30/// Callers map this into their own error type — the semiring layer is
31/// deliberately decoupled from DataFusion's error type so `uni-locy` can
32/// remain free of a query-engine dependency.
33#[derive(Debug, Clone, PartialEq)]
34pub enum SemiringError {
35 /// A probability input fell outside `[0, 1]` and `strict_probability_domain`
36 /// was set. `op` is `"MNOR"` or `"MPROD"`.
37 DomainViolation { value: f64, op: &'static str },
38 /// Operation not supported by this semiring (e.g., `negate` on a
39 /// non-Boolean-tagged semiring once C0 lands).
40 NotSupported {
41 op: &'static str,
42 kind: SemiringKind,
43 },
44}
45
46impl std::fmt::Display for SemiringError {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 match self {
49 Self::DomainViolation { value, op } => write!(
50 f,
51 "strict_probability_domain: {op} input {value} is outside [0, 1]"
52 ),
53 Self::NotSupported { op, kind } => {
54 write!(f, "semiring {kind:?} does not support {op}")
55 }
56 }
57 }
58}
59
60impl std::error::Error for SemiringError {}
61
62/// Domain-check a raw scalar before feeding it to `plus` / `times`.
63///
64/// Returns the value unchanged when it lies in `[0, 1]`. Outside that
65/// range it returns [`SemiringError::DomainViolation`] when `strict` is
66/// set, otherwise clamps and emits the pre-refactor tracing literal
67/// (`"<op> input <raw> outside [0,1], clamped to <clamped>"`) so any
68/// string-asserting tests remain stable. Shared by every `f64`-tagged
69/// [`LocySemiring::validate_domain`] impl.
70pub fn validate_probability_domain(
71 raw: f64,
72 op: &'static str,
73 strict: bool,
74) -> Result<f64, SemiringError> {
75 if (0.0..=1.0).contains(&raw) {
76 return Ok(raw);
77 }
78 if strict {
79 return Err(SemiringError::DomainViolation { value: raw, op });
80 }
81 let clamped = raw.clamp(0.0, 1.0);
82 tracing::warn!("{op} input {raw} outside [0,1], clamped to {clamped}");
83 Ok(clamped)
84}
85
86/// Consolidated semiring configuration threaded through planner and
87/// executors. Constructed by [`crate::LocyConfig::resolve`].
88#[derive(Debug, Clone, Copy, PartialEq)]
89pub struct ResolvedSemiringConfig {
90 pub kind: SemiringKind,
91 pub strict_probability_domain: bool,
92 pub probability_epsilon: f64,
93 pub max_bdd_variables: usize,
94}
95
96impl ResolvedSemiringConfig {
97 /// Returns true when the active semiring is the default `AddMultProb`
98 /// path. Phase-3 shared-proof detection and the AddMultProb-specific
99 /// complement code paths gate on this.
100 pub fn is_add_mult_prob(&self) -> bool {
101 matches!(self.kind, SemiringKind::AddMultProb)
102 }
103}
104
/// Row-at-a-time probability semiring.
///
/// Implementors carry their own state (e.g., underflow epsilon for
/// `AddMultProb`'s log-space switch). All operations are pure and
/// reentrant so the same semiring instance is safely shared across the
/// fixpoint loop.
pub trait LocySemiring: Send + Sync + 'static {
    /// The per-tuple tag type. For the two Phase-A semirings this is
    /// `f64`; C0's `TopKProofs` will carry a proof-DNF tag.
    type Tag: Clone + Send + Sync;

    /// The [`SemiringKind`] this implementation stands for.
    fn kind(&self) -> SemiringKind;

    /// Whether this semiring composes pointwise via [`plus`](Self::plus) and
    /// [`times`](Self::times). Returns `true` for `AddMultProb` and
    /// `MaxMinProb`; `false` for whole-group semirings such as `BddExact`
    /// which are dispatched outside this trait.
    ///
    /// The default implementation returns `true`.
    fn is_row_at_a_time(&self) -> bool {
        true
    }

    /// Identity for [`plus`](Self::plus) — `0.0` for both Phase-A semirings.
    fn zero_disjunction(&self) -> Self::Tag;

    /// Identity for [`times`](Self::times) — `1.0` for both Phase-A semirings.
    fn one_conjunction(&self) -> Self::Tag;

    /// Disjunction. MNOR / proof-OR.
    fn plus(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;

    /// Conjunction. MPROD / proof-AND.
    fn times(&self, a: &Self::Tag, b: &Self::Tag) -> Self::Tag;

    /// Complement (`1 - p` conventionally). May return
    /// [`SemiringError::NotSupported`] for semirings whose tags do not
    /// admit a complement.
    fn negate(&self, a: &Self::Tag) -> Result<Self::Tag, SemiringError>;

    /// Collapse a tag to a probability in `[0, 1]`.
    fn weight(&self, a: &Self::Tag) -> f64;

    /// Domain-check a raw scalar before feeding it to `plus` / `times`.
    /// Returns the clamped value, or `DomainViolation` when `strict` is
    /// set and `raw` falls outside `[0, 1]`. Emits the same tracing literal
    /// (`"MNOR input ..."` / `"MPROD input ..."`) used in the pre-refactor
    /// code so any string-asserting tests remain stable.
    ///
    /// `op` is the operator label that ends up in the warning / error
    /// (`"MNOR"` or `"MPROD"`).
    fn validate_domain(
        &self,
        raw: f64,
        op: &'static str,
        strict: bool,
    ) -> Result<f64, SemiringError>;
}
158
159// ---------------------------------------------------------------------------
160// AddMultProb — Phase 1/2 default. Independence-assumed noisy-OR and product.
161// ---------------------------------------------------------------------------
162
/// Independence-assumed probability semiring:
/// `(plus = noisy-OR, times = product, negate = 1 - p)`.
///
/// Stateful: carries `probability_epsilon`, the threshold below which
/// `times` switches into log-space accumulation to avoid floating-point
/// underflow (spec §5.3). This keeps the underflow guard inside the
/// semiring rather than scattering it across executors.
#[derive(Debug, Clone, Copy)]
pub struct AddMultProb {
    /// Threshold consulted by `times` for its log-space switch.
    pub probability_epsilon: f64,
}

impl AddMultProb {
    /// Epsilon installed by the [`Default`] impl.
    const DEFAULT_PROBABILITY_EPSILON: f64 = 1e-15;

    /// Creates the semiring with an explicit underflow epsilon.
    pub fn new(probability_epsilon: f64) -> Self {
        Self {
            probability_epsilon,
        }
    }
}

impl Default for AddMultProb {
    /// Equivalent to `AddMultProb::new(1e-15)`.
    fn default() -> Self {
        Self::new(Self::DEFAULT_PROBABILITY_EPSILON)
    }
}
189
190impl LocySemiring for AddMultProb {
191 type Tag = f64;
192
193 fn kind(&self) -> SemiringKind {
194 SemiringKind::AddMultProb
195 }
196
197 fn zero_disjunction(&self) -> f64 {
198 0.0
199 }
200
201 fn one_conjunction(&self) -> f64 {
202 1.0
203 }
204
205 fn plus(&self, a: &f64, b: &f64) -> f64 {
206 1.0 - (1.0 - *a) * (1.0 - *b)
207 }
208
209 fn times(&self, a: &f64, b: &f64) -> f64 {
210 // Log-space switch when the running product drops below epsilon.
211 if *a < self.probability_epsilon || *b < self.probability_epsilon {
212 let la = a.max(self.probability_epsilon).ln();
213 let lb = b.max(self.probability_epsilon).ln();
214 (la + lb).exp()
215 } else {
216 *a * *b
217 }
218 }
219
220 fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
221 Ok(1.0 - *a)
222 }
223
224 fn weight(&self, a: &f64) -> f64 {
225 *a
226 }
227
228 fn validate_domain(
229 &self,
230 raw: f64,
231 op: &'static str,
232 strict: bool,
233 ) -> Result<f64, SemiringError> {
234 validate_probability_domain(raw, op, strict)
235 }
236}
237
238// ---------------------------------------------------------------------------
239// MaxMinProb — Viterbi / fuzzy. Opt-in only; triggers FuzzyNotProbabilistic.
240// ---------------------------------------------------------------------------
241
242/// `(plus = max, times = min, negate = 1 - p)`.
243///
244/// This is **fuzzy logic**, not probability. Any PROB-bearing rule
245/// evaluated under this semiring produces a non-suppressible
246/// `RuntimeWarningCode::FuzzyNotProbabilistic` (rollout decision D-9).
247#[derive(Debug, Clone, Copy, Default)]
248pub struct MaxMinProb;
249
250impl LocySemiring for MaxMinProb {
251 type Tag = f64;
252
253 fn kind(&self) -> SemiringKind {
254 SemiringKind::MaxMinProb
255 }
256
257 fn zero_disjunction(&self) -> f64 {
258 0.0
259 }
260
261 fn one_conjunction(&self) -> f64 {
262 1.0
263 }
264
265 fn plus(&self, a: &f64, b: &f64) -> f64 {
266 a.max(*b)
267 }
268
269 fn times(&self, a: &f64, b: &f64) -> f64 {
270 a.min(*b)
271 }
272
273 fn negate(&self, a: &f64) -> Result<f64, SemiringError> {
274 Ok(1.0 - *a)
275 }
276
277 fn weight(&self, a: &f64) -> f64 {
278 *a
279 }
280
281 fn validate_domain(
282 &self,
283 raw: f64,
284 op: &'static str,
285 strict: bool,
286 ) -> Result<f64, SemiringError> {
287 validate_probability_domain(raw, op, strict)
288 }
289}
290
291// ---------------------------------------------------------------------------
292// SemiringDispatch — runtime-selectable concrete type for executors.
293// ---------------------------------------------------------------------------
294
295/// Concrete enum dispatching to the active row-at-a-time semiring.
296/// Used by `MonotonicAggState` and `FoldExec` instead of a
297/// `Box<dyn LocySemiring>` so that the per-row `plus`/`times` calls stay
298/// inlineable (the match is a small branch; LLVM specializes through).
299///
300/// * `SemiringKind::BddExact` maps to `AddMultProb` at the row level —
301/// the BDD post-correction runs over the same independence-mode
302/// accumulators (see `weighted_model_count` in `locy_bdd.rs`).
303/// * `SemiringKind::TopKProofs { k }` likewise dispatches to
304/// `AddMultProb` at the row level in **Stage 1** (this Phase C C0
305/// slice): the library-layer `TopKProofs<K>` impl in
306/// `crate::top_k_proofs` carries true tag math, but the runtime
307/// hot-path operates on `f64`. Stage 2 wires `TopKTag` flow through
308/// `MonotonicAggState` / `FoldExec` / record-batch encoding.
309#[derive(Debug, Clone, Copy)]
310pub enum SemiringDispatch {
311 AddMultProb(AddMultProb),
312 MaxMinProb(MaxMinProb),
313 TopKProofs { inner: AddMultProb, k: u32 },
314}
315
316impl SemiringDispatch {
317 /// Build a dispatch from the resolved kind and the underflow epsilon.
318 /// `BddExact` collapses to `AddMultProb` row math (post-correction
319 /// runs separately at the fixpoint level); `TopKProofs` likewise
320 /// in Stage 1.
321 pub fn new(kind: SemiringKind, probability_epsilon: f64) -> Self {
322 match kind {
323 SemiringKind::AddMultProb | SemiringKind::BddExact => {
324 Self::AddMultProb(AddMultProb::new(probability_epsilon))
325 }
326 SemiringKind::MaxMinProb => Self::MaxMinProb(MaxMinProb),
327 SemiringKind::TopKProofs { k } => {
328 // Stage 1 of Phase C C0: runtime tag flow is pending,
329 // so this dispatches to AddMultProb row math. Library
330 // users wanting the true `TopKTag` math should call
331 // `crate::top_k_proofs::TopKProofs::<K>` directly.
332 //
333 // The warn fires per executor creation rather than per
334 // row, so cost is negligible; the message helps users
335 // understand why their `TopKProofs` results match
336 // `AddMultProb` byte-for-byte until Stage 2.
337 tracing::warn!(
338 "TopKProofs(k={k}) runtime tag flow pending Stage 2 — \
339 falling back to AddMultProb row math; library-layer \
340 TopKProofs<K> math is available via uni_locy::top_k_proofs"
341 );
342 Self::TopKProofs {
343 inner: AddMultProb::new(probability_epsilon),
344 k,
345 }
346 }
347 }
348 }
349
350 pub fn kind(&self) -> SemiringKind {
351 match self {
352 Self::AddMultProb(sr) => sr.kind(),
353 Self::MaxMinProb(sr) => sr.kind(),
354 // Phase C C0 Stage 1 dispatch reports the underlying row
355 // math kind so existing AddMultProb-gated runtime paths
356 // (Phase-3 detector, BDD correction site) continue to fire
357 // correctly. Callers needing the original kind read it
358 // from `ResolvedSemiringConfig.kind` instead.
359 Self::TopKProofs { inner, .. } => inner.kind(),
360 }
361 }
362
363 /// Returns the `k` parameter when the dispatch is `TopKProofs`.
364 /// Stage 2 callers use this to find the K at the row-eval site
365 /// where they materialize tags. `None` for other semirings.
366 pub fn top_k(&self) -> Option<u32> {
367 match self {
368 Self::TopKProofs { k, .. } => Some(*k),
369 _ => None,
370 }
371 }
372
373 pub fn plus(&self, a: f64, b: f64) -> f64 {
374 match self {
375 Self::AddMultProb(sr) => sr.plus(&a, &b),
376 Self::MaxMinProb(sr) => sr.plus(&a, &b),
377 Self::TopKProofs { inner, .. } => inner.plus(&a, &b),
378 }
379 }
380
381 pub fn times(&self, a: f64, b: f64) -> f64 {
382 match self {
383 Self::AddMultProb(sr) => sr.times(&a, &b),
384 Self::MaxMinProb(sr) => sr.times(&a, &b),
385 Self::TopKProofs { inner, .. } => inner.times(&a, &b),
386 }
387 }
388
389 pub fn validate_domain(
390 &self,
391 raw: f64,
392 op: &'static str,
393 strict: bool,
394 ) -> Result<f64, SemiringError> {
395 match self {
396 Self::AddMultProb(sr) => sr.validate_domain(raw, op, strict),
397 Self::MaxMinProb(sr) => sr.validate_domain(raw, op, strict),
398 Self::TopKProofs { inner, .. } => inner.validate_domain(raw, op, strict),
399 }
400 }
401
402 // -----------------------------------------------------------------
403 // Stage 2 tag-flow surface (pending wiring).
404 //
405 // The methods below (`plus_tag` / `times_tag` / `zero_tag` /
406 // `singleton_tag` / `weight_of`) and the [`AggregatorValue`] enum
407 // are the typed-tag interface that Stage 2 will plumb through
408 // `MonotonicAggState` / `FoldExec` / record-batch encoding so the
409 // `TopKProofs` runtime path stops falling back to `AddMultProb` row
410 // math (see `SemiringDispatch::new`). They have no non-test callers
411 // yet — kept here, compiled and unit-tested, so the Stage 2 wiring
412 // lands against an already-validated surface rather than re-deriving
413 // it. Do not delete pending Stage 2.
414 // -----------------------------------------------------------------
415
416 /// Phase C C0 Stage 2: tag-level `plus` that supports both f64
417 /// semirings (AddMultProb / MaxMinProb) and the proof-tag
418 /// semiring (TopKProofs). Returns the merged value plus an
419 /// optional `PruneNotice` callers use to drive
420 /// `RuntimeWarningCode::TopKPruningCrossedDependency` emission.
421 /// Existing `plus(a: f64, b: f64) -> f64` stays unchanged for
422 /// hot paths that don't carry proof tags.
423 pub fn plus_tag(
424 &self,
425 a: &AggregatorValue,
426 b: &AggregatorValue,
427 ) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
428 match (self, a, b) {
429 (Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
430 (AggregatorValue::F64(sr.plus(x, y)), None)
431 }
432 (Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
433 (AggregatorValue::F64(sr.plus(x, y)), None)
434 }
435 (Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
436 let (proofs, notice) = merge_top_k_dispatch(ta, tb, *k as usize);
437 (
438 AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
439 Some(notice),
440 )
441 }
442 // Type mismatch between dispatch arm and value variant —
443 // indicates a callsite bug (constructed the wrong
444 // AggregatorValue for the active semiring).
445 _ => unreachable!(
446 "SemiringDispatch::plus_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
447 self.kind(),
448 std::mem::discriminant(a),
449 std::mem::discriminant(b),
450 ),
451 }
452 }
453
454 /// Phase C C0 Stage 2: tag-level `times`. Same contract as
455 /// `plus_tag`.
456 pub fn times_tag(
457 &self,
458 a: &AggregatorValue,
459 b: &AggregatorValue,
460 ) -> (AggregatorValue, Option<crate::top_k_proofs::PruneNotice>) {
461 match (self, a, b) {
462 (Self::AddMultProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
463 (AggregatorValue::F64(sr.times(x, y)), None)
464 }
465 (Self::MaxMinProb(sr), AggregatorValue::F64(x), AggregatorValue::F64(y)) => {
466 (AggregatorValue::F64(sr.times(x, y)), None)
467 }
468 (Self::TopKProofs { k, .. }, AggregatorValue::TopK(ta), AggregatorValue::TopK(tb)) => {
469 // Cartesian product per the library impl; reuse
470 // merge_top_k for dedup + pruning.
471 if ta.proofs.is_empty() || tb.proofs.is_empty() {
472 return (
473 AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
474 None,
475 );
476 }
477 let mut cart: Vec<crate::top_k_proofs::Proof> =
478 Vec::with_capacity(ta.proofs.len() * tb.proofs.len());
479 for pa in &ta.proofs {
480 for pb in &tb.proofs {
481 let mut nc = pa.neural_calls.clone();
482 let existing: std::collections::HashSet<u32> =
483 pa.neural_calls.iter().map(|c| c.0).collect();
484 for c in &pb.neural_calls {
485 if !existing.contains(&c.0) {
486 nc.push(*c);
487 }
488 }
489 cart.push(crate::top_k_proofs::Proof {
490 weight: pa.weight * pb.weight,
491 base_rvs: crate::dependency_dnf::BaseRvSet::union(
492 &pa.base_rvs,
493 &pb.base_rvs,
494 ),
495 neural_calls: nc,
496 });
497 }
498 }
499 let (proofs, notice) = merge_top_k_dispatch_owned(Vec::new(), cart, *k as usize);
500 (
501 AggregatorValue::TopK(crate::top_k_proofs::TopKTag { proofs }),
502 Some(notice),
503 )
504 }
505 _ => unreachable!(
506 "SemiringDispatch::times_tag: type mismatch — dispatch {:?} vs ({:?}, {:?})",
507 self.kind(),
508 std::mem::discriminant(a),
509 std::mem::discriminant(b),
510 ),
511 }
512 }
513
514 /// Phase C C0 Stage 2: return the additive-identity value for
515 /// the active semiring. Used by `MonotonicAggState` to
516 /// initialize new accumulator slots.
517 pub fn zero_tag(&self) -> AggregatorValue {
518 match self {
519 Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(0.0),
520 Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag::zero()),
521 }
522 }
523
524 /// Phase C C0 Stage 2: lift a row's f64 weight into the
525 /// dispatch's tag type. For `AddMultProb` / `MaxMinProb` this is
526 /// just `F64(w)`; for `TopKProofs` the caller supplies the
527 /// row's `base_rvs` and `neural_calls` so a single-Proof tag is
528 /// materialized.
529 pub fn singleton_tag(
530 &self,
531 weight: f64,
532 base_rvs: crate::dependency_dnf::BaseRvSet,
533 neural_calls: Vec<crate::top_k_proofs::NeuralCallId>,
534 ) -> AggregatorValue {
535 match self {
536 Self::AddMultProb(_) | Self::MaxMinProb(_) => AggregatorValue::F64(weight),
537 Self::TopKProofs { .. } => AggregatorValue::TopK(crate::top_k_proofs::TopKTag {
538 proofs: vec![crate::top_k_proofs::Proof {
539 weight,
540 base_rvs,
541 neural_calls,
542 }],
543 }),
544 }
545 }
546
547 /// Phase C C0 Stage 2: collapse an aggregator value to its
548 /// scalar probability for downstream f64-typed consumers
549 /// (record-batch encoding, BDD post-correction site, etc.).
550 pub fn weight_of(&self, value: &AggregatorValue) -> f64 {
551 match (self, value) {
552 (Self::AddMultProb(_) | Self::MaxMinProb(_), AggregatorValue::F64(v)) => *v,
553 (Self::TopKProofs { .. }, AggregatorValue::TopK(t)) => {
554 // Conservative weight per library impl: noisy-OR
555 // over proof weights under independence-mode.
556 let mut complement = 1.0;
557 for p in &t.proofs {
558 complement *= 1.0 - p.weight;
559 }
560 (1.0 - complement).clamp(0.0, 1.0)
561 }
562 _ => unreachable!(
563 "SemiringDispatch::weight_of: type mismatch — dispatch {:?} vs {:?}",
564 self.kind(),
565 std::mem::discriminant(value),
566 ),
567 }
568 }
569}
570
571/// Phase C C0 Stage 2: tag-typed accumulator value used by
572/// `MonotonicAggState`. The variant must match the active
573/// `SemiringDispatch` — `F64` for `AddMultProb` / `MaxMinProb`,
574/// `TopK` for `TopKProofs`. Cross-type pairs panic in
575/// `plus_tag` / `times_tag` (callsite bug).
576#[derive(Debug, Clone)]
577pub enum AggregatorValue {
578 F64(f64),
579 TopK(crate::top_k_proofs::TopKTag),
580}
581
582impl AggregatorValue {
583 /// Convenience constructor for f64 callers that want to
584 /// initialize an accumulator from a row weight under
585 /// `SemiringDispatch::AddMultProb` / `MaxMinProb`.
586 pub fn f64(v: f64) -> Self {
587 AggregatorValue::F64(v)
588 }
589}
590
591/// Helper that delegates to [`crate::top_k_proofs::merge_top_k_with`]
592/// over cloned proof lists. The library impl is generic over `K`
593/// (compile-time const); the runtime needs a value-level `k`.
594fn merge_top_k_dispatch(
595 a: &crate::top_k_proofs::TopKTag,
596 b: &crate::top_k_proofs::TopKTag,
597 k: usize,
598) -> (
599 Vec<crate::top_k_proofs::Proof>,
600 crate::top_k_proofs::PruneNotice,
601) {
602 merge_top_k_dispatch_owned(a.proofs.clone(), b.proofs.clone(), k)
603}
604
605/// Phase C C0 Stage 2: runtime-K merge over owned proof lists.
606/// Delegates to [`crate::top_k_proofs::merge_top_k_with`]; exposed
607/// `pub` (re-exported as `merge_top_k_runtime`) so the fixpoint loop
608/// can call it directly without constructing `AggregatorValue`
609/// wrappers when it has owned `Vec<Proof>` already.
610pub fn merge_top_k_dispatch_owned(
611 base: Vec<crate::top_k_proofs::Proof>,
612 additional: Vec<crate::top_k_proofs::Proof>,
613 k: usize,
614) -> (
615 Vec<crate::top_k_proofs::Proof>,
616 crate::top_k_proofs::PruneNotice,
617) {
618 crate::top_k_proofs::merge_top_k_with(base, additional, k)
619}
620
621impl Default for SemiringDispatch {
622 fn default() -> Self {
623 Self::AddMultProb(AddMultProb::default())
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn add_mult_prob_matches_pre_refactor_noisy_or() {
        let semiring = AddMultProb::default();
        let combined = [0.72, 0.54, 0.56, 0.42]
            .iter()
            .fold(semiring.zero_disjunction(), |acc, p| semiring.plus(&acc, p));
        // From DEEP_LOCY.md §4.5: MNOR over [0.72, 0.54, 0.56, 0.42] ≈ 0.9671
        // Hand-computed exact: 1 - 0.28 * 0.46 * 0.44 * 0.58 = 0.96713024
        assert!((combined - 0.967_130_24).abs() < 1e-9, "got {combined}");
    }

    #[test]
    fn add_mult_prob_product_underflow_safe() {
        // Both operands sit below epsilon, which drives `times` into log-space.
        let product = AddMultProb::new(1e-12).times(&1e-20, &1e-20);
        assert!(product.is_finite());
        assert!(product >= 0.0);
    }

    #[test]
    fn max_min_prob_viterbi() {
        let semiring = MaxMinProb;
        let (low, high) = (0.3, 0.7);
        assert_eq!(semiring.plus(&low, &high), 0.7);
        assert_eq!(semiring.times(&low, &high), 0.3);
    }

    #[test]
    fn strict_domain_violation() {
        let semiring = AddMultProb::default();
        let strict = semiring.validate_domain(1.5, "MNOR", true);
        assert!(matches!(
            strict,
            Err(SemiringError::DomainViolation { .. })
        ));
        let lenient = semiring.validate_domain(1.5, "MNOR", false);
        assert_eq!(lenient.unwrap(), 1.0);
    }

    #[test]
    fn max_min_prob_strict_domain_violation() {
        let semiring = MaxMinProb;
        let strict = semiring.validate_domain(-0.1, "MPROD", true);
        assert!(matches!(
            strict,
            Err(SemiringError::DomainViolation { .. })
        ));
        let clamped_low = semiring.validate_domain(-0.1, "MPROD", false);
        assert_eq!(clamped_low.unwrap(), 0.0);
        let clamped_high = semiring.validate_domain(2.0, "MNOR", false);
        assert_eq!(clamped_high.unwrap(), 1.0);
    }

    #[test]
    fn identities_are_correct() {
        // MNOR over empty set = 0.0 (additive identity).
        // MPROD over empty set = 1.0 (multiplicative identity).
        // Both semirings agree on identities — exercised by FoldExec
        // when a key group has no input rows.
        let add_mult = AddMultProb::default();
        let max_min = MaxMinProb;
        assert_eq!(add_mult.zero_disjunction(), 0.0);
        assert_eq!(add_mult.one_conjunction(), 1.0);
        assert_eq!(max_min.zero_disjunction(), 0.0);
        assert_eq!(max_min.one_conjunction(), 1.0);
    }

    #[test]
    fn dispatch_routes_to_correct_impl() {
        // Same operands, different semirings — verifies SemiringDispatch
        // doesn't accidentally collapse the two semirings.
        let epsilon = 1e-15;
        let add_mult = SemiringDispatch::new(SemiringKind::AddMultProb, epsilon);
        let max_min = SemiringDispatch::new(SemiringKind::MaxMinProb, epsilon);
        assert_eq!(add_mult.plus(0.3, 0.5), 1.0 - 0.7 * 0.5); // 0.65
        assert_eq!(max_min.plus(0.3, 0.5), 0.5);
        assert_eq!(add_mult.times(0.3, 0.5), 0.15);
        assert_eq!(max_min.times(0.3, 0.5), 0.3);

        // BddExact dispatches to AddMultProb at the row level — the BDD
        // post-correction runs separately at the fixpoint level. So
        // `SemiringDispatch::new(BddExact, ε).kind()` is AddMultProb by
        // design; the original kind is tracked separately on
        // `ResolvedSemiringConfig`.
        let bdd_exact = SemiringDispatch::new(SemiringKind::BddExact, epsilon);
        assert_eq!(bdd_exact.kind(), SemiringKind::AddMultProb);
        assert_eq!(bdd_exact.plus(0.3, 0.5), 1.0 - 0.7 * 0.5);
    }
}
714}