1use crate::frozen::FrozenIndexedDataset;
17use crate::path::{node_of, succ};
18use crate::sparql::SparqlExecutor;
19use crate::validate::{NonStratifiable, ShapeEvaluator, focus_nodes_with, graph_union};
20use oxrdf::{Graph, NamedNode, NamedOrBlankNode, Term, Triple};
21use shifty_algebra::{NodeExpr, Rule, RuleHead, Schema, Selector, ShapeArena};
22use shifty_opt::{RuleDependencies, analyze, rule_dependencies, rule_guard_dependencies};
23use shifty_parse::vocab;
24use std::collections::{BTreeSet, HashMap, HashSet};
25
26pub struct InferenceOutcome {
28 pub graph: Graph,
30 pub inferred: Vec<Triple>,
32 pub diagnostics: Vec<String>,
34}
35
36pub fn infer(data: &Graph, schema: &Schema) -> Result<InferenceOutcome, NonStratifiable> {
38 infer_with_context(data, data, schema)
39}
40
41pub fn infer_graphs(
46 data: &Graph,
47 shapes: &Graph,
48 schema: &Schema,
49) -> Result<InferenceOutcome, NonStratifiable> {
50 let context = graph_union(data, shapes);
51 infer_with_context(data, &context, schema)
52}
53
54pub fn infer_with_context(
58 data: &Graph,
59 context: &Graph,
60 schema: &Schema,
61) -> Result<InferenceOutcome, NonStratifiable> {
62 let strat = analyze(&schema.arena);
63 if !strat.stratifiable {
64 let components = strat
65 .strata
66 .iter()
67 .filter(|s| !s.stratifiable)
68 .map(|s| s.shapes.clone())
69 .collect();
70 return Err(NonStratifiable { components });
71 }
72
73 let mut graph = data.clone();
74 let mut context = context.clone();
75 let sparql =
76 SparqlExecutor::new(&context).expect("building an in-memory Oxigraph store should succeed");
77 let mut inferred: Vec<Triple> = Vec::new();
78 let mut diags: BTreeSet<String> = BTreeSet::new();
79
80 let mut rules: Vec<ScheduledRule<'_>> = schema
81 .rules
82 .iter()
83 .enumerate()
84 .filter(|(_, rule)| !rule.deactivated)
85 .map(|(index, rule)| ScheduledRule {
86 index,
87 order: rule.order.unwrap_or(0),
88 dependencies: rule_dependencies(rule, &schema.arena),
89 guard_dependencies: rule_guard_dependencies(rule, &schema.arena),
90 rule,
91 })
92 .collect();
93 rules.sort_by_key(|scheduled| (scheduled.order, scheduled.index));
94 let mut frozen = rules
95 .iter()
96 .any(|scheduled| matches!(scheduled.rule.head, RuleHead::Sparql(_)))
97 .then(|| FrozenIndexedDataset::from_graph(&context));
98
99 let mut active: HashSet<usize> = (0..rules.len()).collect();
102 let mut delta_start = 0;
105 let mut first_pass = true;
106 loop {
107 let mut changed_predicates = HashSet::new();
108 let mut added = false;
109 let mut start = 0;
110 let pass_start = inferred.len();
111 let mut visible_changed: HashSet<NamedNode> = inferred[delta_start..]
112 .iter()
113 .map(|triple| triple.predicate.clone())
114 .collect();
115
116 let mut focus_cache: HashMap<Selector, Vec<Term>> = HashMap::new();
120 let mut pass_changed: HashSet<NamedNode> = HashSet::new();
123
124 while start < rules.len() {
125 let order = rules[start].order;
126 let mut end = start + 1;
127 while end < rules.len() && rules[end].order == order {
128 end += 1;
129 }
130
131 let mut candidates: HashSet<Triple> = HashSet::new();
136 for (position, scheduled) in rules[start..end].iter().enumerate() {
137 if !active.contains(&(start + position)) {
138 continue;
139 }
140 let sel = &scheduled.rule.selector;
141 if selector_stale(sel, &pass_changed) {
142 focus_cache.remove(sel);
143 }
144 let focus_nodes = focus_cache.entry(sel.clone()).or_insert_with(|| {
145 focus_nodes_with(&graph, &context, sel, &schema.arena, &sparql)
146 });
147 let mut delta_focus_nodes = Vec::new();
148 let execution_focus_nodes = match &scheduled.rule.head {
149 RuleHead::Sparql(construct)
150 if !first_pass
151 && !focus_nodes.is_empty()
152 && (inferred.len() - delta_start).saturating_mul(2)
156 < focus_nodes.len()
157 && !scheduled
158 .guard_dependencies
159 .affected_by(&visible_changed) =>
160 {
161 match sparql.construct_delta_foci(
162 &construct.query,
163 &inferred[delta_start..],
164 frozen.as_ref(),
165 ) {
166 Ok(Some(affected)) => {
167 delta_focus_nodes.extend(
168 focus_nodes
169 .iter()
170 .filter(|focus| affected.contains(*focus))
171 .cloned(),
172 );
173 delta_focus_nodes.as_slice()
174 }
175 Ok(None) | Err(_) => focus_nodes.as_slice(),
176 }
177 }
178 _ => focus_nodes.as_slice(),
179 };
180 let rule_label = format!("rule[{}]", start + position);
181 let rule_t = std::time::Instant::now();
182 fire_rule(
183 execution_focus_nodes,
184 &context,
185 &schema.arena,
186 scheduled.rule,
187 &sparql,
188 frozen.as_ref(),
189 &mut candidates,
190 &mut diags,
191 );
192 crate::profile::record_shape(&rule_label, rule_t.elapsed().as_micros() as u64);
193 }
194 if let Some(frozen) = frozen.as_mut() {
195 frozen.extend_triples(candidates.iter());
196 }
197 for t in candidates {
198 pass_changed.insert(t.predicate.clone());
199 visible_changed.insert(t.predicate.clone());
200 graph.insert(&t);
201 context.insert(&t);
202 if let Err(error) = sparql.insert(&t) {
203 diags.insert(format!("failed to update SPARQL inference store: {error}"));
204 }
205 changed_predicates.insert(t.predicate.clone());
206 inferred.push(t);
207 added = true;
208 }
209
210 start = end;
211 }
212
213 if !added {
214 break;
215 }
216
217 delta_start = pass_start;
218 first_pass = false;
219 active.clear();
220 for (position, scheduled) in rules.iter().enumerate() {
221 if scheduled.dependencies.affected_by(&changed_predicates) {
222 active.insert(position);
223 }
224 }
225 if active.is_empty() {
226 break;
227 }
228 }
229
230 Ok(InferenceOutcome {
231 graph,
232 inferred,
233 diagnostics: diags.into_iter().collect(),
234 })
235}
236
237struct ScheduledRule<'a> {
238 index: usize,
239 order: i64,
240 dependencies: RuleDependencies,
241 guard_dependencies: RuleDependencies,
242 rule: &'a Rule,
243}
244
245fn selector_stale(sel: &Selector, pass_changed: &HashSet<NamedNode>) -> bool {
248 if pass_changed.is_empty() {
249 return false;
250 }
251 match sel {
252 Selector::HasOut(p) | Selector::HasIn(p) => pass_changed.contains(p),
253 Selector::IsConst(_) => false,
254 Selector::HasPath(..) | Selector::Sparql(_) => true,
256 }
257}
258
259#[allow(clippy::too_many_arguments)]
260fn fire_rule(
261 focus_nodes: &[Term],
262 context: &Graph,
263 arena: &ShapeArena,
264 rule: &shifty_algebra::Rule,
265 sparql: &SparqlExecutor,
266 frozen: Option<&FrozenIndexedDataset>,
267 out: &mut HashSet<Triple>,
268 diags: &mut BTreeSet<String>,
269) {
270 let mut evaluator = ShapeEvaluator::new(context, arena, sparql);
271 let eligible: Vec<&Term> = focus_nodes
272 .iter()
273 .filter(|v| rule.conditions.iter().all(|c| evaluator.holds(v, *c)))
274 .collect();
275
276 match &rule.head {
277 RuleHead::Triple {
278 subject,
279 predicate,
280 object,
281 } => {
282 for v in eligible {
283 let subjects = eval_node_expr(context, v, subject, &mut evaluator, diags);
284 let predicates = eval_node_expr(context, v, predicate, &mut evaluator, diags);
285 let objects = eval_node_expr(context, v, object, &mut evaluator, diags);
286 for s in &subjects {
287 let Some(subj) = node_of(s) else { continue };
288 for p in &predicates {
289 let Term::NamedNode(pred) = p else { continue };
290 for o in &objects {
291 let t = Triple::new(subj.clone(), pred.clone(), o.clone());
292 if !context.contains(&t) {
293 out.insert(t);
294 }
295 }
296 }
297 }
298 }
299 }
300 RuleHead::Sparql(construct) => {
301 let eligible: Vec<Term> = eligible.into_iter().cloned().collect();
302 match sparql.construct_many(&construct.query, &eligible, frozen) {
303 Ok(triples) => {
304 for triple in triples {
305 if matches!(triple.subject, oxrdf::NamedOrBlankNode::BlankNode(_))
306 || matches!(triple.object, Term::BlankNode(_))
307 {
308 diags.insert(
309 "sh:SPARQLRule CONSTRUCT blank nodes are not supported because \
310 they can prevent fixpoint termination"
311 .to_string(),
312 );
313 } else {
314 out.insert(triple);
315 }
316 }
317 }
318 Err(error) => {
319 diags.insert(format!("sh:SPARQLRule evaluation failed: {error}"));
320 }
321 }
322 }
323 }
324}
325
326fn eval_node_expr(
328 g: &Graph,
329 v: &Term,
330 expr: &NodeExpr,
331 evaluator: &mut ShapeEvaluator<'_>,
332 diags: &mut BTreeSet<String>,
333) -> HashSet<Term> {
334 match expr {
335 NodeExpr::This => once(v.clone()),
336 NodeExpr::Constant(t) => once(t.clone()),
337 NodeExpr::Path(p) => succ(g, v, p),
338 NodeExpr::Filter { input, shape } => eval_node_expr(g, v, input, evaluator, diags)
339 .into_iter()
340 .filter(|x| evaluator.holds(x, *shape))
341 .collect(),
342 NodeExpr::Intersection(es) => {
343 let mut iter = es.iter();
344 match iter.next() {
345 Some(first) => {
346 let mut acc = eval_node_expr(g, v, first, evaluator, diags);
347 for e in iter {
348 let s = eval_node_expr(g, v, e, evaluator, diags);
349 acc.retain(|x| s.contains(x));
350 }
351 acc
352 }
353 None => HashSet::new(),
354 }
355 }
356 NodeExpr::Union(es) => {
357 let mut acc = HashSet::new();
358 for e in es {
359 acc.extend(eval_node_expr(g, v, e, evaluator, diags));
360 }
361 acc
362 }
363 NodeExpr::Function { iri, args } => {
364 let arg_values: Vec<HashSet<Term>> = args
366 .iter()
367 .map(|a| eval_node_expr(g, v, a, evaluator, diags))
368 .collect();
369
370 let func = NamedOrBlankNode::NamedNode(iri.clone());
371 let Some(query_text) = g
372 .object_for_subject_predicate(&func, vocab::SH_SELECT)
373 .map(|t| t.into_owned())
374 .and_then(|t| match t {
375 Term::Literal(l) => Some(l.value().to_string()),
376 _ => None,
377 })
378 else {
379 diags.insert(format!("function <{}> has no sh:select", iri.as_str()));
380 return HashSet::new();
381 };
382
383 let params = function_params(g, &func);
384 let sparql = evaluator.sparql();
385 let mut results = HashSet::new();
386 for combo in cartesian_product(&arg_values) {
387 if combo.len() != params.len() {
388 continue;
389 }
390 let bindings: Vec<(String, Term)> = params
391 .iter()
392 .zip(combo)
393 .map(|(name, val)| (name.clone(), val))
394 .collect();
395 match sparql.call_sparql_function(&query_text, &bindings) {
396 Ok(terms) => results.extend(terms),
397 Err(e) => {
398 diags.insert(format!("function <{}> error: {e}", iri.as_str()));
399 }
400 }
401 }
402 results
403 }
404 }
405}
406
407fn once(t: Term) -> HashSet<Term> {
408 let mut s = HashSet::with_capacity(1);
409 s.insert(t);
410 s
411}
412
/// Returns the part of `iri` after its last `#` or `/`.
///
/// If neither separator occurs, the whole string is returned. An IRI ending
/// in a separator yields an empty string.
fn local_name(iri: &str) -> &str {
    match iri.rfind(['#', '/']) {
        // Both separators are single-byte ASCII, so `cut + 1` is a char boundary.
        Some(cut) => &iri[cut + 1..],
        None => iri,
    }
}
417
418fn function_params(g: &Graph, func: &NamedOrBlankNode) -> Vec<String> {
421 let mut params: Vec<(i64, String)> = g
422 .objects_for_subject_predicate(func, vocab::SH_PARAMETER)
423 .filter_map(|param_ref| {
424 let param_node = node_of(¶m_ref.into_owned())?;
425 let order = g
426 .object_for_subject_predicate(¶m_node, vocab::SH_ORDER)
427 .map(|t| t.into_owned())
428 .and_then(|t| match t {
429 Term::Literal(l) => l.value().parse::<i64>().ok(),
430 _ => None,
431 })
432 .unwrap_or(0);
433 let name = g
434 .object_for_subject_predicate(¶m_node, vocab::SH_NAME)
435 .map(|t| t.into_owned())
436 .and_then(|t| match t {
437 Term::Literal(l) => Some(l.value().to_string()),
438 _ => None,
439 })
440 .or_else(|| {
441 g.object_for_subject_predicate(¶m_node, vocab::SH_PATH)
442 .map(|t| t.into_owned())
443 .and_then(|t| match t {
444 Term::NamedNode(n) => Some(local_name(n.as_str()).to_string()),
445 _ => None,
446 })
447 })?;
448 Some((order, name))
449 })
450 .collect();
451 params.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
452 params.into_iter().map(|(_, name)| name).collect()
453}
454
455fn cartesian_product(sets: &[HashSet<Term>]) -> Vec<Vec<Term>> {
457 sets.iter().fold(vec![vec![]], |acc, set| {
458 acc.into_iter()
459 .flat_map(|combo| {
460 set.iter().map(move |item| {
461 let mut row = combo.clone();
462 row.push(item.clone());
463 row
464 })
465 })
466 .collect()
467 })
468}