Skip to main content

shifty_engine/
infer.rs

1//! SHACL-AF rule inference (Layer 6) — least-fixpoint forward chaining.
2//!
3//! A rule fires on its focus nodes for which all `sh:condition`s hold, producing
4//! triples from its head's node expressions. Per the decided semantics
5//! ([`docs/03-recursion-semantics.md`](../../../docs/03-recursion-semantics.md)),
6//! inference is the **least fixpoint**. Rules run in ascending `sh:order`
7//! groups, and output from a later group may reactivate an earlier group in the
8//! next pass. Predicate-level delta scheduling avoids rerunning rules whose
9//! graph reads cannot observe the newly added triples. Triple rules only
10//! combine existing terms, and SPARQL `CONSTRUCT` results containing fresh
11//! blank nodes are rejected, preserving termination for the supported subset.
12//!
//! Function node expressions are evaluated by running the function's
//! `sh:select` query; a missing `sh:select` or a failing query is reported as a
//! diagnostic rather than silently skipped.
15
16use crate::frozen::FrozenIndexedDataset;
17use crate::path::{node_of, succ};
18use crate::sparql::SparqlExecutor;
19use crate::validate::{NonStratifiable, ShapeEvaluator, focus_nodes_with, graph_union};
20use oxrdf::{Graph, NamedNode, NamedOrBlankNode, Term, Triple};
21use shifty_algebra::{NodeExpr, Rule, RuleHead, Schema, Selector, ShapeArena};
22use shifty_opt::{RuleDependencies, analyze, rule_dependencies, rule_guard_dependencies};
23use shifty_parse::vocab;
24use std::collections::{BTreeSet, HashMap, HashSet};
25
26/// The result of running inference over a data graph.
27pub struct InferenceOutcome {
28    /// The data graph augmented with all inferred triples.
29    pub graph: Graph,
30    /// The triples that were newly inferred (not already asserted).
31    pub inferred: Vec<Triple>,
32    /// Unsupported rule features encountered (deduplicated).
33    pub diagnostics: Vec<String>,
34}
35
36/// Run rule inference to a least fixpoint, ordered by `sh:order`.
37pub fn infer(data: &Graph, schema: &Schema) -> Result<InferenceOutcome, NonStratifiable> {
38    infer_with_context(data, data, schema)
39}
40
41/// Run inference over split data and shapes graphs.
42///
43/// Rule focus nodes come from `data`, while class hierarchy, paths, conditions,
44/// and SPARQL rule bodies see the RDF union of `data` and `shapes`.
45pub fn infer_graphs(
46    data: &Graph,
47    shapes: &Graph,
48    schema: &Schema,
49) -> Result<InferenceOutcome, NonStratifiable> {
50    let context = graph_union(data, shapes);
51    infer_with_context(data, &context, schema)
52}
53
54/// Run inference with data-scoped focus discovery and a broader execution
55/// context. `context` should contain `data`; newly inferred triples are added to
56/// both the returned data graph and the mutable execution context.
57pub fn infer_with_context(
58    data: &Graph,
59    context: &Graph,
60    schema: &Schema,
61) -> Result<InferenceOutcome, NonStratifiable> {
62    let strat = analyze(&schema.arena);
63    if !strat.stratifiable {
64        let components = strat
65            .strata
66            .iter()
67            .filter(|s| !s.stratifiable)
68            .map(|s| s.shapes.clone())
69            .collect();
70        return Err(NonStratifiable { components });
71    }
72
73    let mut graph = data.clone();
74    let mut context = context.clone();
75    let sparql =
76        SparqlExecutor::new(&context).expect("building an in-memory Oxigraph store should succeed");
77    let mut inferred: Vec<Triple> = Vec::new();
78    let mut diags: BTreeSet<String> = BTreeSet::new();
79
80    let mut rules: Vec<ScheduledRule<'_>> = schema
81        .rules
82        .iter()
83        .enumerate()
84        .filter(|(_, rule)| !rule.deactivated)
85        .map(|(index, rule)| ScheduledRule {
86            index,
87            order: rule.order.unwrap_or(0),
88            dependencies: rule_dependencies(rule, &schema.arena),
89            guard_dependencies: rule_guard_dependencies(rule, &schema.arena),
90            rule,
91        })
92        .collect();
93    rules.sort_by_key(|scheduled| (scheduled.order, scheduled.index));
94    let mut frozen = rules
95        .iter()
96        .any(|scheduled| matches!(scheduled.rule.head, RuleHead::Sparql(_)))
97        .then(|| FrozenIndexedDataset::from_graph(&context));
98
99    // The first pass evaluates every rule. Later passes are semi-naive at rule
100    // granularity: only rules that may read a changed predicate are active.
101    let mut active: HashSet<usize> = (0..rules.len()).collect();
102    // Additions from each pass occupy one contiguous suffix of `inferred`.
103    // `delta_start` avoids cloning RDF terms into separate delta buffers.
104    let mut delta_start = 0;
105    let mut first_pass = true;
106    loop {
107        let mut changed_predicates = HashSet::new();
108        let mut added = false;
109        let mut start = 0;
110        let pass_start = inferred.len();
111        let mut visible_changed: HashSet<NamedNode> = inferred[delta_start..]
112            .iter()
113            .map(|triple| triple.predicate.clone())
114            .collect();
115
116        // Focus node sets are recomputed at most once per selector per pass.
117        // Entries are evicted lazily when a committed triple's predicate matches
118        // the selector's read dependency.
119        let mut focus_cache: HashMap<Selector, Vec<Term>> = HashMap::new();
120        // Predicates of triples committed so far within this pass, used to
121        // invalidate stale cache entries before they are read.
122        let mut pass_changed: HashSet<NamedNode> = HashSet::new();
123
124        while start < rules.len() {
125            let order = rules[start].order;
126            let mut end = start + 1;
127            while end < rules.len() && rules[end].order == order {
128                end += 1;
129            }
130
131            // Tied rules observe the same graph snapshot. Their additions are
132            // visible to subsequent order groups in this pass.
133            // HashSet deduplicates within the batch; fire_rule pre-filters
134            // against the context so only genuinely new triples reach here.
135            let mut candidates: HashSet<Triple> = HashSet::new();
136            for (position, scheduled) in rules[start..end].iter().enumerate() {
137                if !active.contains(&(start + position)) {
138                    continue;
139                }
140                let sel = &scheduled.rule.selector;
141                if selector_stale(sel, &pass_changed) {
142                    focus_cache.remove(sel);
143                }
144                let focus_nodes = focus_cache.entry(sel.clone()).or_insert_with(|| {
145                    focus_nodes_with(&graph, &context, sel, &schema.arena, &sparql)
146                });
147                let mut delta_focus_nodes = Vec::new();
148                let execution_focus_nodes = match &scheduled.rule.head {
149                    RuleHead::Sparql(construct)
150                        if !first_pass
151                            && !focus_nodes.is_empty()
152                            // Differential BGP execution visits the delta once
153                            // per scan. Above this crossover, the existing
154                            // focus-bound batch is the cheaper access path.
155                            && (inferred.len() - delta_start).saturating_mul(2)
156                                < focus_nodes.len()
157                            && !scheduled
158                                .guard_dependencies
159                                .affected_by(&visible_changed) =>
160                    {
161                        match sparql.construct_delta_foci(
162                            &construct.query,
163                            &inferred[delta_start..],
164                            frozen.as_ref(),
165                        ) {
166                            Ok(Some(affected)) => {
167                                delta_focus_nodes.extend(
168                                    focus_nodes
169                                        .iter()
170                                        .filter(|focus| affected.contains(*focus))
171                                        .cloned(),
172                                );
173                                delta_focus_nodes.as_slice()
174                            }
175                            Ok(None) | Err(_) => focus_nodes.as_slice(),
176                        }
177                    }
178                    _ => focus_nodes.as_slice(),
179                };
180                let rule_label = format!("rule[{}]", start + position);
181                let rule_t = std::time::Instant::now();
182                fire_rule(
183                    execution_focus_nodes,
184                    &context,
185                    &schema.arena,
186                    scheduled.rule,
187                    &sparql,
188                    frozen.as_ref(),
189                    &mut candidates,
190                    &mut diags,
191                );
192                crate::profile::record_shape(&rule_label, rule_t.elapsed().as_micros() as u64);
193            }
194            if let Some(frozen) = frozen.as_mut() {
195                frozen.extend_triples(candidates.iter());
196            }
197            for t in candidates {
198                pass_changed.insert(t.predicate.clone());
199                visible_changed.insert(t.predicate.clone());
200                graph.insert(&t);
201                context.insert(&t);
202                if let Err(error) = sparql.insert(&t) {
203                    diags.insert(format!("failed to update SPARQL inference store: {error}"));
204                }
205                changed_predicates.insert(t.predicate.clone());
206                inferred.push(t);
207                added = true;
208            }
209
210            start = end;
211        }
212
213        if !added {
214            break;
215        }
216
217        delta_start = pass_start;
218        first_pass = false;
219        active.clear();
220        for (position, scheduled) in rules.iter().enumerate() {
221            if scheduled.dependencies.affected_by(&changed_predicates) {
222                active.insert(position);
223            }
224        }
225        if active.is_empty() {
226            break;
227        }
228    }
229
230    Ok(InferenceOutcome {
231        graph,
232        inferred,
233        diagnostics: diags.into_iter().collect(),
234    })
235}
236
237struct ScheduledRule<'a> {
238    index: usize,
239    order: i64,
240    dependencies: RuleDependencies,
241    guard_dependencies: RuleDependencies,
242    rule: &'a Rule,
243}
244
245/// Whether a cached focus-node set for `sel` may have become stale given the
246/// predicates committed so far within the current pass.
247fn selector_stale(sel: &Selector, pass_changed: &HashSet<NamedNode>) -> bool {
248    if pass_changed.is_empty() {
249        return false;
250    }
251    match sel {
252        Selector::HasOut(p) | Selector::HasIn(p) => pass_changed.contains(p),
253        Selector::IsConst(_) => false,
254        // HasPath traversal and SPARQL queries can read any predicate.
255        Selector::HasPath(..) | Selector::Sparql(_) => true,
256    }
257}
258
259#[allow(clippy::too_many_arguments)]
260fn fire_rule(
261    focus_nodes: &[Term],
262    context: &Graph,
263    arena: &ShapeArena,
264    rule: &shifty_algebra::Rule,
265    sparql: &SparqlExecutor,
266    frozen: Option<&FrozenIndexedDataset>,
267    out: &mut HashSet<Triple>,
268    diags: &mut BTreeSet<String>,
269) {
270    let mut evaluator = ShapeEvaluator::new(context, arena, sparql);
271    let eligible: Vec<&Term> = focus_nodes
272        .iter()
273        .filter(|v| rule.conditions.iter().all(|c| evaluator.holds(v, *c)))
274        .collect();
275
276    match &rule.head {
277        RuleHead::Triple {
278            subject,
279            predicate,
280            object,
281        } => {
282            for v in eligible {
283                let subjects = eval_node_expr(context, v, subject, &mut evaluator, diags);
284                let predicates = eval_node_expr(context, v, predicate, &mut evaluator, diags);
285                let objects = eval_node_expr(context, v, object, &mut evaluator, diags);
286                for s in &subjects {
287                    let Some(subj) = node_of(s) else { continue };
288                    for p in &predicates {
289                        let Term::NamedNode(pred) = p else { continue };
290                        for o in &objects {
291                            let t = Triple::new(subj.clone(), pred.clone(), o.clone());
292                            if !context.contains(&t) {
293                                out.insert(t);
294                            }
295                        }
296                    }
297                }
298            }
299        }
300        RuleHead::Sparql(construct) => {
301            let eligible: Vec<Term> = eligible.into_iter().cloned().collect();
302            match sparql.construct_many(&construct.query, &eligible, frozen) {
303                Ok(triples) => {
304                    for triple in triples {
305                        if matches!(triple.subject, oxrdf::NamedOrBlankNode::BlankNode(_))
306                            || matches!(triple.object, Term::BlankNode(_))
307                        {
308                            diags.insert(
309                                "sh:SPARQLRule CONSTRUCT blank nodes are not supported because \
310                                 they can prevent fixpoint termination"
311                                    .to_string(),
312                            );
313                        } else {
314                            out.insert(triple);
315                        }
316                    }
317                }
318                Err(error) => {
319                    diags.insert(format!("sh:SPARQLRule evaluation failed: {error}"));
320                }
321            }
322        }
323    }
324}
325
326/// Evaluate a node expression at focus node `v` to its set of result terms.
327fn eval_node_expr(
328    g: &Graph,
329    v: &Term,
330    expr: &NodeExpr,
331    evaluator: &mut ShapeEvaluator<'_>,
332    diags: &mut BTreeSet<String>,
333) -> HashSet<Term> {
334    match expr {
335        NodeExpr::This => once(v.clone()),
336        NodeExpr::Constant(t) => once(t.clone()),
337        NodeExpr::Path(p) => succ(g, v, p),
338        NodeExpr::Filter { input, shape } => eval_node_expr(g, v, input, evaluator, diags)
339            .into_iter()
340            .filter(|x| evaluator.holds(x, *shape))
341            .collect(),
342        NodeExpr::Intersection(es) => {
343            let mut iter = es.iter();
344            match iter.next() {
345                Some(first) => {
346                    let mut acc = eval_node_expr(g, v, first, evaluator, diags);
347                    for e in iter {
348                        let s = eval_node_expr(g, v, e, evaluator, diags);
349                        acc.retain(|x| s.contains(x));
350                    }
351                    acc
352                }
353                None => HashSet::new(),
354            }
355        }
356        NodeExpr::Union(es) => {
357            let mut acc = HashSet::new();
358            for e in es {
359                acc.extend(eval_node_expr(g, v, e, evaluator, diags));
360            }
361            acc
362        }
363        NodeExpr::Function { iri, args } => {
364            // Evaluate arguments before borrowing evaluator for sparql().
365            let arg_values: Vec<HashSet<Term>> = args
366                .iter()
367                .map(|a| eval_node_expr(g, v, a, evaluator, diags))
368                .collect();
369
370            let func = NamedOrBlankNode::NamedNode(iri.clone());
371            let Some(query_text) = g
372                .object_for_subject_predicate(&func, vocab::SH_SELECT)
373                .map(|t| t.into_owned())
374                .and_then(|t| match t {
375                    Term::Literal(l) => Some(l.value().to_string()),
376                    _ => None,
377                })
378            else {
379                diags.insert(format!("function <{}> has no sh:select", iri.as_str()));
380                return HashSet::new();
381            };
382
383            let params = function_params(g, &func);
384            let sparql = evaluator.sparql();
385            let mut results = HashSet::new();
386            for combo in cartesian_product(&arg_values) {
387                if combo.len() != params.len() {
388                    continue;
389                }
390                let bindings: Vec<(String, Term)> = params
391                    .iter()
392                    .zip(combo)
393                    .map(|(name, val)| (name.clone(), val))
394                    .collect();
395                match sparql.call_sparql_function(&query_text, &bindings) {
396                    Ok(terms) => results.extend(terms),
397                    Err(e) => {
398                        diags.insert(format!("function <{}> error: {e}", iri.as_str()));
399                    }
400                }
401            }
402            results
403        }
404    }
405}
406
407fn once(t: Term) -> HashSet<Term> {
408    let mut s = HashSet::with_capacity(1);
409    s.insert(t);
410    s
411}
412
/// Return the local name of an IRI: everything after its last `#` or `/`,
/// or the whole string when neither separator occurs.
fn local_name(iri: &str) -> &str {
    match iri.rfind(['#', '/']) {
        // Both separators are one byte wide, so `cut + 1` is a char boundary.
        Some(cut) => &iri[cut + 1..],
        None => iri,
    }
}
417
418/// Resolve a SPARQL function's parameter names from the context graph,
419/// sorted by `sh:order` then by local name of `sh:path` (or `sh:name`).
420fn function_params(g: &Graph, func: &NamedOrBlankNode) -> Vec<String> {
421    let mut params: Vec<(i64, String)> = g
422        .objects_for_subject_predicate(func, vocab::SH_PARAMETER)
423        .filter_map(|param_ref| {
424            let param_node = node_of(&param_ref.into_owned())?;
425            let order = g
426                .object_for_subject_predicate(&param_node, vocab::SH_ORDER)
427                .map(|t| t.into_owned())
428                .and_then(|t| match t {
429                    Term::Literal(l) => l.value().parse::<i64>().ok(),
430                    _ => None,
431                })
432                .unwrap_or(0);
433            let name = g
434                .object_for_subject_predicate(&param_node, vocab::SH_NAME)
435                .map(|t| t.into_owned())
436                .and_then(|t| match t {
437                    Term::Literal(l) => Some(l.value().to_string()),
438                    _ => None,
439                })
440                .or_else(|| {
441                    g.object_for_subject_predicate(&param_node, vocab::SH_PATH)
442                        .map(|t| t.into_owned())
443                        .and_then(|t| match t {
444                            Term::NamedNode(n) => Some(local_name(n.as_str()).to_string()),
445                            _ => None,
446                        })
447                })?;
448            Some((order, name))
449        })
450        .collect();
451    params.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
452    params.into_iter().map(|(_, name)| name).collect()
453}
454
455/// Cartesian product of term sets — one arg combo per returned vec.
456fn cartesian_product(sets: &[HashSet<Term>]) -> Vec<Vec<Term>> {
457    sets.iter().fold(vec![vec![]], |acc, set| {
458        acc.into_iter()
459            .flat_map(|combo| {
460                set.iter().map(move |item| {
461                    let mut row = combo.clone();
462                    row.push(item.clone());
463                    row
464                })
465            })
466            .collect()
467    })
468}