xlog_logic/hypergraph/inference.rs
1//! Transitive type inference across SCC predicates.
2//!
3//! Closes the PR 5 policy gap: where a join-key vertex was anchored
4//! only through SCC-recursive atoms, the typed gate previously left
5//! it untyped under "unknown ≠ unsupported." This module propagates
6//! types through the rule graph — body atoms type variables, head
7//! atoms back-propagate to head-predicate columns, iterate to
8//! fixpoint — so the typed gate has full type information when it
9//! consults [`super::analyze_typed`].
10//!
11//! ## Where inference is engaged
12//!
13//! Only the **group-aware** typed entry points engage inference:
14//!
15//! * [`super::evaluate_scc_fixpoint_typed`] runs inference once at
16//! entry, then types each rule's body using the inferred schemas
17//! plus `base_relations`.
18//! * [`super::evaluate_fixpoint_typed`] treats `target_predicate`
19//! as a single-element rule group and runs the same inference.
20//!
21//! The single-rule entry points retain the base-only typing policy
22//! because they have no SCC structure to propagate over:
23//!
24//! * [`super::evaluate_rule_typed`] takes one rule.
25//! * [`super::plan_rule`] / [`super::plan_rules`] plan per-rule.
26//!
27//! Callers that want SCC-aware planning should drive
28//! [`super::evaluate_scc_fixpoint_typed`] directly or build their
29//! own inference pass via [`infer_scc_predicate_schemas`].
30//!
31//! ## Conflict layering
32//!
33//! Inference detects only **back-propagation conflicts**: e.g.,
34//! predicate `p`'s column 0 is `U32` from rule A's head and
35//! `Symbol` from rule B's head → [`InferenceError::ConflictingPredicateColumnType`].
36//! Within-rule body conflicts (variable `X` typed `U32` in one body
37//! atom and `Symbol` in another) stay in the existing
38//! [`super::typed`] flow and surface as
39//! [`super::RefEvalError::ConflictingVariableType`]. Each conflict
40//! type is detected at exactly one layer.
41//!
42//! ## Cyclic-only predicates
43//!
44//! When an SCC has no base anchor anywhere (e.g., `a(X) :- b(X),
45//! b(X) :- a(X)` with no rule referencing `base_relations`), every
46//! column converges to `None`. The typed gate must NOT reject such
47//! rules: the policy narrows from "unknown ≠ unsupported" to
48//! "unknowable-after-inference ≠ unsupported." Locked by
49//! `cyclic_only_predicate_still_passes_typed_gate_locked_policy`.
50//!
51//! ## Strict-correctness behavior change
52//!
53//! Fixtures whose base-relation schemas disagreed but whose actual
54//! rows happened to agree at runtime were previously silent (the
55//! typed gate types each body atom independently). They now surface
56//! as [`InferenceError::ConflictingPredicateColumnType`] when
57//! back-propagating to a head predicate. That is a strict
58//! correctness win, not a regression — fixtures with internally
59//! contradictory schemas are now caught before evaluation rather
60//! than silently corrupting downstream comparisons.
61
62use super::reference::RefRelationStore;
63use crate::ast::{BodyLiteral, Rule, Term};
64use std::collections::BTreeMap;
65use xlog_core::ScalarType;
66
67/// Errors surfaced by [`infer_scc_predicate_schemas`].
68#[derive(Debug, Clone, PartialEq)]
69pub enum InferenceError {
70 /// Two rules contributing to the same head predicate disagree
71 /// on the type of the same column. The first rule that types
72 /// the column wins `first_*`; the rule that disagrees wins
73 /// `second_*`.
74 ConflictingPredicateColumnType {
75 /// Head predicate name where the conflict was detected.
76 predicate: String,
77 /// 0-based column index where types disagree.
78 column: usize,
79 /// Rule index (within the predicate's rule group) that
80 /// first typed the column.
81 first_rule_index: usize,
82 /// Type derived from the first rule's body for the head
83 /// variable at this column.
84 first_type: ScalarType,
85 /// Rule index (within the predicate's rule group) whose
86 /// derivation conflicts.
87 second_rule_index: usize,
88 /// Type derived from the conflicting rule's body for the
89 /// head variable at this column.
90 second_type: ScalarType,
91 },
92}
93
94/// Per-predicate inferred schema. `Vec` length equals the head
95/// arity; each element is `Some(t)` if inference established the
96/// column's type, or `None` if the column remains unknowable
97/// (e.g., cyclic-only predicate, or a head term whose body atoms
98/// don't type the corresponding variable).
99pub type InferredSchemas = BTreeMap<String, Vec<Option<ScalarType>>>;
100
101/// Infer per-predicate schemas for a rule group via constraint
102/// propagation through the rule graph.
103///
104/// Algorithm:
105///
106/// 1. Determine head arity per predicate from the first rule with
107/// a non-empty head. (Predicates whose every rule has an empty
108/// head are treated as 0-arity; in practice this is rare.)
109/// 2. Initialize each predicate's schema as `vec![None; arity]`.
110/// 3. Iterate: for each rule, compute a per-rule variable-to-type
111/// map by walking body atoms (typing vars from
112/// `base_relations` schemas first, then from currently-inferred
113/// SCC predicate schemas where columns are `Some`). Then
114/// back-propagate: for each `Term::Variable` in the head at
115/// column `i`, if the variable has a derived type, propose it
116/// as the type for `head_predicate.schema[i]`. Conflict if a
117/// column has been previously typed differently.
118/// 4. Stop when no schema column changes between iterations.
119///
120/// Within-rule body conflicts are NOT detected here; they are
121/// caught by the existing [`super::typed`] gate during its own
122/// per-rule type-derivation walk. See module docs for the
123/// conflict-layering split.
124pub fn infer_scc_predicate_schemas(
125 rules: &BTreeMap<String, Vec<Rule>>,
126 base_relations: &RefRelationStore,
127) -> Result<InferredSchemas, InferenceError> {
128 // Step 1+2: arity + initial schemas.
129 let mut schemas: InferredSchemas = BTreeMap::new();
130 for (predicate, group) in rules.iter() {
131 let arity = group
132 .iter()
133 .find(|r| !r.head.terms.is_empty())
134 .map(|r| r.head.terms.len())
135 .unwrap_or(0);
136 schemas.insert(predicate.clone(), vec![None; arity]);
137 }
138 // Track the rule index that first typed each column so the
139 // conflict report can name both contributors.
140 let mut origins: BTreeMap<(String, usize), usize> = BTreeMap::new();
141 // Inference is monotonic: every iteration that changes
142 // anything replaces a `None` with a `Some(_)`. The total
143 // number of column slots across all SCC predicates is the
144 // strict upper bound on iterations that produce change. We
145 // add 1 to allow for the final no-change iteration that
146 // detects convergence.
147 let total_columns: usize = schemas.values().map(|s| s.len()).sum();
148 let max_iterations = total_columns + 1;
149 let mut converged = false;
150 for _ in 0..max_iterations {
151 let mut changed = false;
152 for (predicate, group) in rules.iter() {
153 for (rule_index, rule) in group.iter().enumerate() {
154 let var_types = derive_rule_var_types(rule, base_relations, &schemas);
155 // Back-propagate from head terms to head-predicate
156 // columns.
157 for (col, term) in rule.head.terms.iter().enumerate() {
158 let name = match term {
159 Term::Variable(n) => n,
160 // Head constants / aggregates / wildcards do
161 // not constrain a column type via inference.
162 // Their type would be locked by the value
163 // itself at evaluation time.
164 _ => continue,
165 };
166 let Some(&derived) = var_types.get(name) else {
167 continue;
168 };
169 let schema = schemas
170 .get_mut(predicate)
171 .expect("predicate in initialized schemas");
172 if col >= schema.len() {
173 // Head arity drift across rules — let the
174 // structural SCC fixpoint surface this as
175 // HeadArityMismatch. Inference doesn't
176 // pre-empt; just skip this column.
177 continue;
178 }
179 match schema[col] {
180 None => {
181 schema[col] = Some(derived);
182 origins.insert((predicate.clone(), col), rule_index);
183 changed = true;
184 }
185 Some(existing) if existing == derived => {
186 // Agreement — silent.
187 }
188 Some(existing) => {
189 let first_rule_index =
190 origins.get(&(predicate.clone(), col)).copied().unwrap_or(0);
191 return Err(InferenceError::ConflictingPredicateColumnType {
192 predicate: predicate.clone(),
193 column: col,
194 first_rule_index,
195 first_type: existing,
196 second_rule_index: rule_index,
197 second_type: derived,
198 });
199 }
200 }
201 }
202 }
203 }
204 if !changed {
205 converged = true;
206 break;
207 }
208 }
209 // Monotonic invariant: every iteration that changed something
210 // replaced a None with a Some(_). The bound `total_columns + 1`
211 // strictly exceeds the number of such iterations possible, so
212 // failing to converge here indicates a future code change has
213 // broken the monotonicity guarantee — a programmer error, not
214 // a data error.
215 debug_assert!(
216 converged,
217 "type inference failed to converge within {max_iterations} iterations \
218 (monotonicity invariant violated)"
219 );
220 Ok(schemas)
221}
222
223/// Derive the per-variable type map for a single rule, consulting
224/// both `base_relations` and currently-inferred SCC schemas.
225///
226/// Body conflicts (a variable typed two different ways across
227/// body atoms within this rule) are NOT surfaced here — that is
228/// the responsibility of [`super::typed::derive_vertex_types`],
229/// which the typed gate calls before evaluation. This helper is
230/// a *forward* propagation pass that prefers the first type seen
231/// (in source order) and silently skips later disagreements; the
232/// typed gate later catches the disagreement on the same rule
233/// using its own walk.
234fn derive_rule_var_types(
235 rule: &Rule,
236 base_relations: &RefRelationStore,
237 inferred: &InferredSchemas,
238) -> BTreeMap<String, ScalarType> {
239 let mut var_types: BTreeMap<String, ScalarType> = BTreeMap::new();
240 for literal in &rule.body {
241 let body_atom = match literal {
242 BodyLiteral::Positive(a) => a,
243 _ => continue,
244 };
245 let schema_opt: Option<&[Option<ScalarType>]> =
246 if let Some(rel) = base_relations.get(&body_atom.predicate) {
247 // Build a transient "all-Some" view of the base schema.
248 // We don't actually need to allocate — handle directly.
249 let limit = body_atom.terms.len().min(rel.schema.len());
250 for (pos, term) in body_atom.terms[..limit].iter().enumerate() {
251 if let Term::Variable(name) = term {
252 var_types.entry(name.clone()).or_insert(rel.schema[pos]);
253 }
254 }
255 None
256 } else {
257 inferred.get(&body_atom.predicate).map(|v| v.as_slice())
258 };
259 if let Some(schema) = schema_opt {
260 let limit = body_atom.terms.len().min(schema.len());
261 for (pos, term) in body_atom.terms[..limit].iter().enumerate() {
262 if let Term::Variable(name) = term {
263 if let Some(ty) = schema[pos] {
264 var_types.entry(name.clone()).or_insert(ty);
265 }
266 }
267 }
268 }
269 }
270 var_types
271}
272
273/// Build the typed-gate input map for a single rule using
274/// inferred SCC schemas alongside base relations.
275///
276/// Mirrors [`super::typed::derive_vertex_types`]'s contract — same
277/// conflict surface ([`super::RefEvalError::ConflictingVariableType`])
278/// — but consults `inferred_schemas` whenever a body atom's
279/// predicate is not in `base_relations`. Inferred columns marked
280/// `None` are treated identically to "predicate absent": they
281/// don't type the variable at that position.
282///
283/// Used by [`super::evaluate_scc_fixpoint_typed`] and
284/// [`super::evaluate_fixpoint_typed`] inside their per-rule typed
285/// gate to give [`super::analyze_typed`] full type information.
286pub(super) fn derive_vertex_types_with_inference(
287 rule: &Rule,
288 base_relations: &RefRelationStore,
289 inferred_schemas: &InferredSchemas,
290) -> Result<BTreeMap<String, ScalarType>, super::RefEvalError> {
291 /// First-recorded site for a variable; used to populate the
292 /// `ConflictingVariableType` report when a second body atom
293 /// types the variable differently.
294 struct FirstSite {
295 predicate: String,
296 position: usize,
297 ty: ScalarType,
298 }
299 let mut sites: BTreeMap<String, FirstSite> = BTreeMap::new();
300 for literal in &rule.body {
301 let body_atom = match literal {
302 BodyLiteral::Positive(a) => a,
303 _ => continue,
304 };
305 // Type each position. Base relation wins if both are
306 // present (cannot happen — `base_relations` and
307 // `inferred_schemas` keys are disjoint by construction in
308 // the typed evaluators).
309 let position_types: Vec<Option<ScalarType>> =
310 if let Some(rel) = base_relations.get(&body_atom.predicate) {
311 let limit = body_atom.terms.len().min(rel.schema.len());
312 let mut v: Vec<Option<ScalarType>> = vec![None; body_atom.terms.len()];
313 for (pos_idx, slot) in v.iter_mut().enumerate().take(limit) {
314 *slot = Some(rel.schema[pos_idx]);
315 }
316 v
317 } else if let Some(schema) = inferred_schemas.get(&body_atom.predicate) {
318 let limit = body_atom.terms.len().min(schema.len());
319 let mut v: Vec<Option<ScalarType>> = vec![None; body_atom.terms.len()];
320 for (pos_idx, slot) in v.iter_mut().enumerate().take(limit) {
321 *slot = schema[pos_idx];
322 }
323 v
324 } else {
325 continue; // predicate unknown, no type info
326 };
327 for (position, term) in body_atom.terms.iter().enumerate() {
328 let var_name = match term {
329 Term::Variable(name) => name.clone(),
330 _ => continue,
331 };
332 let Some(ty) = position_types[position] else {
333 continue;
334 };
335 match sites.get(&var_name) {
336 None => {
337 sites.insert(
338 var_name,
339 FirstSite {
340 predicate: body_atom.predicate.clone(),
341 position,
342 ty,
343 },
344 );
345 }
346 Some(prior) if prior.ty == ty => {
347 // Agreeing repeat — silent.
348 }
349 Some(prior) => {
350 return Err(super::RefEvalError::ConflictingVariableType {
351 var: var_name,
352 first_predicate: prior.predicate.clone(),
353 first_position: prior.position,
354 first_type: prior.ty,
355 second_predicate: body_atom.predicate.clone(),
356 second_position: position,
357 second_type: ty,
358 });
359 }
360 }
361 }
362 }
363 Ok(sites
364 .into_iter()
365 .map(|(name, site)| (name, site.ty))
366 .collect())
367}