Skip to main content

tensorlogic_adapters/
incremental_query.rs

1//! Incremental query evaluation using semi-naive Datalog evaluation.
2//!
3//! Implements the semi-naive bottom-up fixpoint algorithm for Datalog: given a set of
4//! rules and base facts (EDB), computes derived facts (IDB) incrementally by only
5//! re-evaluating rules against Δ (newly derived tuples) at each iteration.
6//!
7//! # Example
8//!
9//! ```rust
10//! use tensorlogic_adapters::{Atom, Edb, Fact, FactArg, IncrementalEvaluator, Rule, Term};
11//!
12//! // Build EDB: parent(alice, bob), parent(bob, carol)
13//! let mut edb = Edb::new();
14//! edb.add_fact(Fact::sym("parent", &["alice", "bob"]));
15//! edb.add_fact(Fact::sym("parent", &["bob", "carol"]));
16//!
17//! // Rule: ancestor(X, Y) :- parent(X, Y).
18//! let rule1 = Rule::new(
19//!     Atom::new("ancestor", vec![Term::var("X"), Term::var("Y")]),
20//!     vec![Atom::new("parent", vec![Term::var("X"), Term::var("Y")])],
21//! );
22//!
23//! // Rule: ancestor(X, Z) :- parent(X, Y), ancestor(Y, Z).
24//! let rule2 = Rule::new(
25//!     Atom::new("ancestor", vec![Term::var("X"), Term::var("Z")]),
26//!     vec![
27//!         Atom::new("parent", vec![Term::var("X"), Term::var("Y")]),
28//!         Atom::new("ancestor", vec![Term::var("Y"), Term::var("Z")]),
29//!     ],
30//! );
31//!
32//! let mut evaluator = IncrementalEvaluator::new(vec![rule1, rule2], edb).unwrap();
33//! let derived = evaluator.query("ancestor");
34//! assert_eq!(derived.len(), 3); // alice->bob, bob->carol, alice->carol
35//! ```
36
37use std::collections::{HashMap, HashSet};
38use std::fmt;
39
40// ─────────────────────────────────────────────────────────────────────────────
41// FactArg
42// ─────────────────────────────────────────────────────────────────────────────
43
/// A ground argument value carried by a fact: a symbol or a 64-bit integer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FactArg {
    /// A symbolic (string) constant.
    Symbol(String),
    /// A signed integer constant.
    Integer(i64),
}

impl fmt::Display for FactArg {
    /// Writes a symbol as its raw text and an integer in decimal notation.
    ///
    /// Width/alignment flags on the outer formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Integer(value) => write!(f, "{value}"),
            Self::Symbol(text) => f.write_str(text),
        }
    }
}
61
62// ─────────────────────────────────────────────────────────────────────────────
63// Fact
64// ─────────────────────────────────────────────────────────────────────────────
65
66/// A ground fact: a predicate name together with its argument values.
67#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68pub struct Fact {
69    /// The predicate name (e.g. `"parent"`, `"ancestor"`).
70    pub predicate: String,
71    /// The argument values (ground terms).
72    pub args: Vec<FactArg>,
73}
74
75impl Fact {
76    /// Create a new fact.
77    pub fn new(predicate: impl Into<String>, args: Vec<FactArg>) -> Self {
78        Self {
79            predicate: predicate.into(),
80            args,
81        }
82    }
83
84    /// Convenience constructor: all arguments are `Symbol` values.
85    pub fn sym(predicate: impl Into<String>, args: &[&str]) -> Self {
86        Self {
87            predicate: predicate.into(),
88            args: args
89                .iter()
90                .map(|s| FactArg::Symbol(s.to_string()))
91                .collect(),
92        }
93    }
94
95    /// The number of arguments (arity).
96    pub fn arity(&self) -> usize {
97        self.args.len()
98    }
99}
100
101impl fmt::Display for Fact {
102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103        write!(f, "{}(", self.predicate)?;
104        for (i, a) in self.args.iter().enumerate() {
105            if i > 0 {
106                write!(f, ", ")?;
107            }
108            write!(f, "{a}")?;
109        }
110        write!(f, ")")
111    }
112}
113
114// ─────────────────────────────────────────────────────────────────────────────
115// Term
116// ─────────────────────────────────────────────────────────────────────────────
117
118/// A term in a Datalog rule body or head atom.
119#[derive(Debug, Clone, PartialEq)]
120pub enum Term {
121    /// A logical variable (e.g. `"X"`, `"Y"`).
122    Variable(String),
123    /// A ground constant.
124    Constant(FactArg),
125}
126
127impl Term {
128    /// Create a variable term.
129    pub fn var(name: impl Into<String>) -> Self {
130        Term::Variable(name.into())
131    }
132
133    /// Create a symbol constant term.
134    pub fn sym(s: impl Into<String>) -> Self {
135        Term::Constant(FactArg::Symbol(s.into()))
136    }
137
138    /// Create an integer constant term.
139    pub fn int(n: i64) -> Self {
140        Term::Constant(FactArg::Integer(n))
141    }
142}
143
144// ─────────────────────────────────────────────────────────────────────────────
145// Atom
146// ─────────────────────────────────────────────────────────────────────────────
147
148/// A Datalog atom: a predicate applied to a list of terms.
149#[derive(Debug, Clone, PartialEq)]
150pub struct Atom {
151    /// The predicate name.
152    pub predicate: String,
153    /// The argument terms (may contain variables or constants).
154    pub terms: Vec<Term>,
155}
156
157impl Atom {
158    /// Create a new atom.
159    pub fn new(predicate: impl Into<String>, terms: Vec<Term>) -> Self {
160        Self {
161            predicate: predicate.into(),
162            terms,
163        }
164    }
165}
166
167// ─────────────────────────────────────────────────────────────────────────────
168// Rule
169// ─────────────────────────────────────────────────────────────────────────────
170
171/// A Datalog rule: `head :- body[0], body[1], …`.
172///
173/// If `body` is empty, the rule is a fact (unconditional assertion).
174#[derive(Debug, Clone)]
175pub struct Rule {
176    /// The head atom (conclusion).
177    pub head: Atom,
178    /// The body atoms (premises), conjoined.
179    pub body: Vec<Atom>,
180}
181
182impl Rule {
183    /// Create a new rule.
184    pub fn new(head: Atom, body: Vec<Atom>) -> Self {
185        Self { head, body }
186    }
187
188    /// Returns `true` when the body is empty (the rule is an unconditional fact).
189    pub fn is_fact(&self) -> bool {
190        self.body.is_empty()
191    }
192}
193
194// ─────────────────────────────────────────────────────────────────────────────
195// Relation
196// ─────────────────────────────────────────────────────────────────────────────
197
198/// A set of facts sharing the same predicate.
199#[derive(Debug, Clone, Default)]
200pub struct Relation {
201    facts: HashSet<Fact>,
202}
203
204impl Relation {
205    /// Create an empty relation.
206    pub fn new() -> Self {
207        Self::default()
208    }
209
210    /// Insert a fact.  Returns `true` if the fact was not already present.
211    pub fn insert(&mut self, fact: Fact) -> bool {
212        self.facts.insert(fact)
213    }
214
215    /// Check whether the relation contains the given fact.
216    pub fn contains(&self, fact: &Fact) -> bool {
217        self.facts.contains(fact)
218    }
219
220    /// Number of facts in this relation.
221    pub fn len(&self) -> usize {
222        self.facts.len()
223    }
224
225    /// Returns `true` if the relation has no facts.
226    pub fn is_empty(&self) -> bool {
227        self.facts.is_empty()
228    }
229
230    /// Iterate over all facts.
231    pub fn iter(&self) -> impl Iterator<Item = &Fact> {
232        self.facts.iter()
233    }
234
235    /// Return a cloned `Vec` of all facts.
236    pub fn facts(&self) -> Vec<Fact> {
237        self.facts.iter().cloned().collect()
238    }
239
240    /// Compute the union of two relations (facts from both).
241    pub fn union(&self, other: &Relation) -> Relation {
242        let mut result = self.clone();
243        for f in other.facts.iter() {
244            result.facts.insert(f.clone());
245        }
246        result
247    }
248
249    /// Compute the set-difference `self − other`.
250    pub fn difference(&self, other: &Relation) -> Relation {
251        Relation {
252            facts: self
253                .facts
254                .iter()
255                .filter(|f| !other.facts.contains(*f))
256                .cloned()
257                .collect(),
258        }
259    }
260}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Edb — Extensional Database (base facts)
264// ─────────────────────────────────────────────────────────────────────────────
265
266/// The extensional database: the set of base (input) facts.
267#[derive(Debug, Clone, Default)]
268pub struct Edb {
269    relations: HashMap<String, Relation>,
270}
271
272impl Edb {
273    /// Create an empty EDB.
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    /// Insert a base fact.
279    pub fn add_fact(&mut self, fact: Fact) {
280        self.relations
281            .entry(fact.predicate.clone())
282            .or_default()
283            .insert(fact);
284    }
285
286    /// Retrieve the relation for a predicate, if any.
287    pub fn get_relation(&self, predicate: &str) -> Option<&Relation> {
288        self.relations.get(predicate)
289    }
290
291    /// List the names of all predicates in the EDB.
292    pub fn relation_names(&self) -> Vec<String> {
293        self.relations.keys().cloned().collect()
294    }
295
296    /// Total number of base facts across all predicates.
297    pub fn total_facts(&self) -> usize {
298        self.relations.values().map(|r| r.len()).sum()
299    }
300}
301
302// ─────────────────────────────────────────────────────────────────────────────
303// Idb — Intensional Database (derived facts)
304// ─────────────────────────────────────────────────────────────────────────────
305
306/// The intensional database: the set of derived (output) facts.
307#[derive(Debug, Clone, Default)]
308pub struct Idb {
309    relations: HashMap<String, Relation>,
310}
311
312impl Idb {
313    /// Create an empty IDB.
314    pub fn new() -> Self {
315        Self::default()
316    }
317
318    /// Retrieve the relation for a predicate, if any.
319    pub fn get_relation(&self, predicate: &str) -> Option<&Relation> {
320        self.relations.get(predicate)
321    }
322
323    /// Insert a derived fact.  Returns `true` if the fact was new.
324    pub fn insert(&mut self, predicate: &str, fact: Fact) -> bool {
325        self.relations
326            .entry(predicate.to_owned())
327            .or_default()
328            .insert(fact)
329    }
330
331    /// Total number of derived facts.
332    pub fn total_facts(&self) -> usize {
333        self.relations.values().map(|r| r.len()).sum()
334    }
335
336    /// Return every derived fact as a flat `Vec`.
337    pub fn all_facts(&self) -> Vec<Fact> {
338        self.relations
339            .values()
340            .flat_map(|r| r.facts.iter().cloned())
341            .collect()
342    }
343}
344
345// ─────────────────────────────────────────────────────────────────────────────
346// EvalStats
347// ─────────────────────────────────────────────────────────────────────────────
348
/// Statistics gathered during semi-naive evaluation.
///
/// Derives `PartialEq`/`Eq` so that two statistics snapshots can be compared
/// directly (for example with `assert_eq!`).
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct EvalStats {
    /// Total number of fixpoint iterations performed.
    pub iterations: usize,
    /// Total number of new facts derived across all iterations.
    pub total_new_facts: usize,
    /// How many new facts were derived in each individual iteration, in
    /// iteration order.
    pub facts_per_iteration: Vec<usize>,
}
359
360// ─────────────────────────────────────────────────────────────────────────────
361// QueryError
362// ─────────────────────────────────────────────────────────────────────────────
363
/// Failure modes reported by query evaluation.
#[derive(Debug)]
pub enum QueryError {
    /// A rule body atom references a predicate that is neither EDB nor IDB.
    UnknownPredicate(String),
    /// The arity of a queried fact does not match the schema expectation.
    ArityMismatch {
        /// Name of the predicate whose arity disagreed.
        predicate: String,
        /// The arity that was expected.
        expected: usize,
        /// The arity that was actually supplied.
        got: usize,
    },
    /// An internal evaluation error, described by a free-form message.
    EvaluationError(String),
}

impl fmt::Display for QueryError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EvaluationError(detail) => write!(f, "evaluation error: {detail}"),
            Self::UnknownPredicate(name) => write!(f, "unknown predicate: {name}"),
            Self::ArityMismatch {
                predicate,
                expected,
                got,
            } => {
                write!(
                    f,
                    "arity mismatch for predicate {predicate}: expected {expected}, got {got}"
                )
            }
        }
    }
}

impl std::error::Error for QueryError {}
397
398// ─────────────────────────────────────────────────────────────────────────────
399// Helpers
400// ─────────────────────────────────────────────────────────────────────────────
401
402/// Attempt to unify `term` against the concrete `arg`, extending `bindings`.
403///
404/// * If `term` is a `Variable` that is already bound, the binding must equal `arg`.
405/// * If `term` is a `Variable` that is unbound, it is bound to `arg`.
406/// * If `term` is a `Constant`, it must equal `arg`.
407///
408/// Returns `true` on success, `false` on failure (conflict).
409fn unify_term(term: &Term, arg: &FactArg, bindings: &mut HashMap<String, FactArg>) -> bool {
410    match term {
411        Term::Variable(name) => {
412            if let Some(existing) = bindings.get(name) {
413                existing == arg
414            } else {
415                bindings.insert(name.clone(), arg.clone());
416                true
417            }
418        }
419        Term::Constant(c) => c == arg,
420    }
421}
422
423/// Substitute a complete set of `bindings` into the `head` atom to produce a
424/// ground `Fact`.  Returns `None` if any variable in the head is unbound.
425fn ground_head(head: &Atom, bindings: &HashMap<String, FactArg>) -> Option<Fact> {
426    let mut args = Vec::with_capacity(head.terms.len());
427    for term in &head.terms {
428        let arg = match term {
429            Term::Variable(name) => bindings.get(name)?.clone(),
430            Term::Constant(c) => c.clone(),
431        };
432        args.push(arg);
433    }
434    Some(Fact::new(head.predicate.clone(), args))
435}
436
437/// Recursively extend `current_bindings` over the remaining `atoms`,
438/// looking facts up in `all_facts`.
439///
440/// For the *semi-naive* optimisation at least one atom in the conjunction must
441/// be resolved against `delta` (the set of new facts from the previous
442/// iteration) rather than the full relation.  The boolean `used_delta`
443/// tracks whether a delta relation has already been used in the current
444/// conjunction path.  `delta` contains only the relations that changed in the
445/// last iteration.
446fn eval_body_atoms<'a>(
447    atoms: &'a [Atom],
448    current_bindings: HashMap<String, FactArg>,
449    all_facts: &'a HashMap<String, Relation>,
450    delta: &'a HashMap<String, Relation>,
451    used_delta: bool,
452) -> Vec<HashMap<String, FactArg>> {
453    if atoms.is_empty() {
454        // Require that at least one delta atom was used in order to avoid
455        // re-deriving already-known facts (semi-naive condition).
456        if used_delta {
457            return vec![current_bindings];
458        } else {
459            return vec![];
460        }
461    }
462
463    let (head_atom, rest) = atoms.split_first().expect("atoms is non-empty");
464
465    let predicate = &head_atom.predicate;
466    let mut results: Vec<HashMap<String, FactArg>> = Vec::new();
467
468    // Determine which fact-sets to scan.
469    // Semi-naive: we try the atom against the delta relation (if present) and
470    // separately against the full relation.  The `used_delta` flag ensures the
471    // overall conjunction touches at least one delta tuple.
472
473    let full_rel = all_facts.get(predicate.as_str());
474    let delta_rel = delta.get(predicate.as_str());
475
476    // Helper: iterate over a relation and collect extended bindings.
477    let try_relation = |rel: &Relation,
478                        bindings: &HashMap<String, FactArg>,
479                        is_delta: bool|
480     -> Vec<HashMap<String, FactArg>> {
481        let mut out = Vec::new();
482        for fact in rel.iter() {
483            if fact.terms_len() != head_atom.terms.len() {
484                continue;
485            }
486            let mut b = bindings.clone();
487            let mut ok = true;
488            for (term, arg) in head_atom.terms.iter().zip(fact.args.iter()) {
489                if !unify_term(term, arg, &mut b) {
490                    ok = false;
491                    break;
492                }
493            }
494            if ok {
495                let mut sub = eval_body_atoms(rest, b, all_facts, delta, used_delta || is_delta);
496                out.append(&mut sub);
497            }
498        }
499        out
500    };
501
502    // Strategy:
503    // 1. Use the delta relation for this atom (marks used_delta = true).
504    if let Some(dr) = delta_rel {
505        let mut sub = try_relation(dr, &current_bindings, true);
506        results.append(&mut sub);
507    }
508
509    // 2. Use the full relation for this atom but only if we haven't used delta
510    //    yet or there are remaining atoms that can provide the delta touch.
511    //    In practice we always scan the full relation; the used_delta guard at
512    //    the base case prevents duplicates being emitted.
513    if let Some(fr) = full_rel {
514        // When there is a delta for this predicate and used_delta is already
515        // true (from an earlier atom), we can scan the full relation freely.
516        // When used_delta is false we still scan the full relation because a
517        // later atom may hit the delta.  Duplicates are avoided by the
518        // base-case guard.
519        let mut sub = try_relation(fr, &current_bindings, false);
520        results.append(&mut sub);
521    }
522
523    results
524}
525
526// Extension trait to get the argument count from a Fact without exposing
527// implementation details.
528trait FactExt {
529    fn terms_len(&self) -> usize;
530}
531impl FactExt for Fact {
532    fn terms_len(&self) -> usize {
533        self.args.len()
534    }
535}
536
537// ─────────────────────────────────────────────────────────────────────────────
538// SemiNaiveEvaluator
539// ─────────────────────────────────────────────────────────────────────────────
540
541/// Semi-naive Datalog evaluator.
542///
543/// Computes the least fixpoint of a set of Datalog rules over a base EDB by
544/// iterating only over the *delta* (newly derived facts) at each round.
545pub struct SemiNaiveEvaluator {
546    rules: Vec<Rule>,
547    edb: Edb,
548    idb: Idb,
549    stats: EvalStats,
550}
551
552impl SemiNaiveEvaluator {
553    /// Create a new evaluator with the given rules and EDB.
554    pub fn new(rules: Vec<Rule>, edb: Edb) -> Self {
555        Self {
556            rules,
557            edb,
558            idb: Idb::new(),
559            stats: EvalStats::default(),
560        }
561    }
562
563    // ── Internal helpers ──────────────────────────────────────────────────────
564
565    /// Build a unified view of all known facts: EDB ∪ IDB, keyed by predicate.
566    fn all_facts_snapshot(&self) -> HashMap<String, Relation> {
567        let mut map: HashMap<String, Relation> = HashMap::new();
568
569        for (pred, rel) in &self.edb.relations {
570            map.entry(pred.clone())
571                .or_default()
572                .facts
573                .extend(rel.facts.iter().cloned());
574        }
575        for (pred, rel) in &self.idb.relations {
576            map.entry(pred.clone())
577                .or_default()
578                .facts
579                .extend(rel.facts.iter().cloned());
580        }
581        map
582    }
583
584    /// Apply a single rule using `delta` as the "new facts" layer, returning
585    /// any newly derivable head facts that are not yet in the IDB.
586    fn apply_rule(&self, rule: &Rule, delta: &HashMap<String, Relation>) -> Vec<Fact> {
587        // Rules with an empty body are fact rules — handled separately during
588        // initialisation; skip them here.
589        if rule.is_fact() {
590            return vec![];
591        }
592
593        let all_facts = self.all_facts_snapshot();
594        let bindings: HashMap<String, FactArg> = HashMap::new();
595
596        let binding_sets = eval_body_atoms(&rule.body, bindings, &all_facts, delta, false);
597
598        let mut new_facts: Vec<Fact> = Vec::new();
599        for b in binding_sets {
600            if let Some(fact) = ground_head(&rule.head, &b) {
601                // Only emit facts that are not yet in the IDB.
602                let already_known = self
603                    .idb
604                    .get_relation(&fact.predicate)
605                    .map(|r| r.contains(&fact))
606                    .unwrap_or(false);
607                if !already_known {
608                    new_facts.push(fact);
609                }
610            }
611        }
612        // Deduplicate within the batch.
613        new_facts.sort_unstable_by(|a, b| format!("{a}").cmp(&format!("{b}")));
614        new_facts.dedup();
615        new_facts
616    }
617
618    // ── Public API ────────────────────────────────────────────────────────────
619
620    /// Run semi-naive evaluation to fixpoint and return the final IDB.
621    ///
622    /// Algorithm:
623    /// 1. Seed Δ with all EDB facts plus any IDB fact-rules.
624    /// 2. Repeat until Δ is empty:
625    ///    a. For each rule, derive new facts using Δ.
626    ///    b. New facts not already in IDB become the next Δ.
627    ///    c. Add all new facts to IDB.
628    pub fn evaluate(&mut self) -> Result<&Idb, QueryError> {
629        // ── Step 0: Bootstrap IDB with EDB fact-rules (body-less rules). ──────
630        for rule in &self.rules {
631            if rule.is_fact() {
632                if let Some(fact) = ground_head(&rule.head, &HashMap::new()) {
633                    self.idb.insert(&fact.predicate.clone(), fact);
634                }
635            }
636        }
637
638        // ── Step 1: Seed delta with EDB facts + initial IDB. ─────────────────
639        let mut delta: HashMap<String, Relation> = HashMap::new();
640
641        for (pred, rel) in &self.edb.relations {
642            delta
643                .entry(pred.clone())
644                .or_default()
645                .facts
646                .extend(rel.facts.iter().cloned());
647        }
648        for (pred, rel) in &self.idb.relations {
649            delta
650                .entry(pred.clone())
651                .or_default()
652                .facts
653                .extend(rel.facts.iter().cloned());
654        }
655
656        // ── Step 2: Fixpoint loop. ────────────────────────────────────────────
657        loop {
658            if delta.values().all(|r| r.is_empty()) {
659                break;
660            }
661
662            let mut new_delta: HashMap<String, Relation> = HashMap::new();
663            let mut iteration_count = 0usize;
664
665            for rule in &self.rules {
666                if rule.is_fact() {
667                    continue;
668                }
669                let derived = self.apply_rule(rule, &delta);
670                for fact in derived {
671                    let pred = fact.predicate.clone();
672                    let is_new = self.idb.insert(&pred, fact.clone());
673                    if is_new {
674                        new_delta.entry(pred).or_default().insert(fact);
675                        iteration_count += 1;
676                    }
677                }
678            }
679
680            self.stats.iterations += 1;
681            self.stats.facts_per_iteration.push(iteration_count);
682            self.stats.total_new_facts += iteration_count;
683
684            delta = new_delta;
685        }
686
687        Ok(&self.idb)
688    }
689
690    /// Return evaluation statistics.
691    pub fn stats(&self) -> &EvalStats {
692        &self.stats
693    }
694
695    /// Return a reference to the current IDB.
696    pub fn idb(&self) -> &Idb {
697        &self.idb
698    }
699
700    /// Return a reference to the EDB.
701    pub fn edb(&self) -> &Edb {
702        &self.edb
703    }
704}
705
706// ─────────────────────────────────────────────────────────────────────────────
707// IncrementalEvaluator
708// ─────────────────────────────────────────────────────────────────────────────
709
710/// Wraps [`SemiNaiveEvaluator`] to support incremental updates: add new base
711/// facts without discarding previously derived knowledge.
712pub struct IncrementalEvaluator {
713    evaluator: SemiNaiveEvaluator,
714}
715
716impl IncrementalEvaluator {
717    /// Create and immediately evaluate the initial EDB.
718    pub fn new(rules: Vec<Rule>, initial_edb: Edb) -> Result<Self, QueryError> {
719        let mut evaluator = SemiNaiveEvaluator::new(rules, initial_edb);
720        evaluator.evaluate()?;
721        Ok(Self { evaluator })
722    }
723
724    /// Add new base facts to the EDB and propagate their consequences into the IDB.
725    ///
726    /// Only the newly added facts seed the next Δ, so previously computed
727    /// derived facts are never recomputed from scratch.
728    pub fn add_facts(&mut self, new_facts: Vec<Fact>) -> Result<EvalStats, QueryError> {
729        // Inject into EDB.
730        for fact in &new_facts {
731            self.evaluator.edb.add_fact(fact.clone());
732        }
733
734        // Seed delta only with the new facts.
735        let mut delta: HashMap<String, Relation> = HashMap::new();
736        for fact in new_facts {
737            delta
738                .entry(fact.predicate.clone())
739                .or_default()
740                .insert(fact);
741        }
742
743        let mut local_stats = EvalStats::default();
744
745        // Iterate until no new IDB facts emerge.
746        loop {
747            if delta.values().all(|r| r.is_empty()) {
748                break;
749            }
750
751            let mut new_delta: HashMap<String, Relation> = HashMap::new();
752            let mut iteration_count = 0usize;
753
754            for rule in &self.evaluator.rules {
755                if rule.is_fact() {
756                    continue;
757                }
758                let derived = self.evaluator.apply_rule(rule, &delta);
759                for fact in derived {
760                    let pred = fact.predicate.clone();
761                    let is_new = self.evaluator.idb.insert(&pred, fact.clone());
762                    if is_new {
763                        new_delta.entry(pred).or_default().insert(fact);
764                        iteration_count += 1;
765                    }
766                }
767            }
768
769            local_stats.iterations += 1;
770            local_stats.facts_per_iteration.push(iteration_count);
771            local_stats.total_new_facts += iteration_count;
772
773            // Mirror into the evaluator's global stats.
774            self.evaluator.stats.iterations += 1;
775            self.evaluator.stats.total_new_facts += iteration_count;
776            self.evaluator
777                .stats
778                .facts_per_iteration
779                .push(iteration_count);
780
781            delta = new_delta;
782        }
783
784        Ok(local_stats)
785    }
786
787    /// Query the current IDB for all facts of the given predicate.
788    pub fn query(&self, predicate: &str) -> Vec<Fact> {
789        self.evaluator
790            .idb
791            .get_relation(predicate)
792            .map(|r| r.facts())
793            .unwrap_or_default()
794    }
795
796    /// Total number of derived facts currently in the IDB.
797    pub fn total_derived_facts(&self) -> usize {
798        self.evaluator.idb.total_facts()
799    }
800}
801
802// ─────────────────────────────────────────────────────────────────────────────
803// Tests
804// ─────────────────────────────────────────────────────────────────────────────
805
806#[cfg(test)]
807mod tests {
808    use super::*;
809
810    // ── Utility helpers ───────────────────────────────────────────────────────
811
812    fn make_parent_edb() -> Edb {
813        let mut edb = Edb::new();
814        edb.add_fact(Fact::sym("parent", &["alice", "bob"]));
815        edb.add_fact(Fact::sym("parent", &["bob", "carol"]));
816        edb
817    }
818
819    fn ancestor_rules() -> Vec<Rule> {
820        vec![
821            Rule::new(
822                Atom::new("ancestor", vec![Term::var("X"), Term::var("Y")]),
823                vec![Atom::new("parent", vec![Term::var("X"), Term::var("Y")])],
824            ),
825            Rule::new(
826                Atom::new("ancestor", vec![Term::var("X"), Term::var("Z")]),
827                vec![
828                    Atom::new("parent", vec![Term::var("X"), Term::var("Y")]),
829                    Atom::new("ancestor", vec![Term::var("Y"), Term::var("Z")]),
830                ],
831            ),
832        ]
833    }
834
835    // ── Test 1: Empty EDB + no rules → empty IDB ─────────────────────────────
836
837    #[test]
838    fn test_empty_edb_no_rules() {
839        let mut eval = SemiNaiveEvaluator::new(vec![], Edb::new());
840        let idb = eval.evaluate().expect("evaluation should succeed");
841        assert_eq!(idb.total_facts(), 0, "empty IDB expected");
842    }
843
844    // ── Test 2: Fact rules (body-less) insert into IDB directly ──────────────
845
846    #[test]
847    fn test_fact_rules_insert_directly() {
848        let rule = Rule::new(
849            Atom::new("foo", vec![Term::sym("bar")]),
850            vec![], // body-less
851        );
852        let mut eval = SemiNaiveEvaluator::new(vec![rule], Edb::new());
853        let idb = eval.evaluate().expect("evaluation should succeed");
854        let facts = idb.get_relation("foo").expect("relation foo should exist");
855        assert_eq!(facts.len(), 1);
856        assert!(facts.contains(&Fact::sym("foo", &["bar"])));
857    }
858
859    // ── Test 3: Simple chain: direct ancestor ─────────────────────────────────
860
861    #[test]
862    fn test_simple_ancestor_chain() {
863        let rule = Rule::new(
864            Atom::new("ancestor", vec![Term::var("X"), Term::var("Y")]),
865            vec![Atom::new("parent", vec![Term::var("X"), Term::var("Y")])],
866        );
867        let mut eval = SemiNaiveEvaluator::new(vec![rule], make_parent_edb());
868        let idb = eval.evaluate().expect("evaluation should succeed");
869        let derived = idb.get_relation("ancestor").expect("ancestor relation");
870        // Should contain alice->bob and bob->carol.
871        assert!(derived.contains(&Fact::sym("ancestor", &["alice", "bob"])));
872        assert!(derived.contains(&Fact::sym("ancestor", &["bob", "carol"])));
873    }
874
875    // ── Test 4: Recursive rule → transitive closure ───────────────────────────
876
877    #[test]
878    fn test_recursive_transitive_closure() {
879        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), make_parent_edb());
880        let idb = eval.evaluate().expect("evaluation should succeed");
881        let derived = idb.get_relation("ancestor").expect("ancestor relation");
882        assert!(derived.contains(&Fact::sym("ancestor", &["alice", "carol"])));
883        assert_eq!(derived.len(), 3);
884    }
885
886    // ── Test 5: Fixpoint stops when no new facts ──────────────────────────────
887
888    #[test]
889    fn test_fixpoint_terminates() {
890        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), make_parent_edb());
891        eval.evaluate().expect("evaluation should succeed");
892        // Run again — should produce no new facts.
893        let idb_after = eval.idb().total_facts();
894        assert_eq!(idb_after, 3);
895    }
896
897    // ── Test 6: Stats.iterations > 1 for recursive rules ─────────────────────
898
899    #[test]
900    fn test_eval_stats_iterations() {
901        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), make_parent_edb());
902        eval.evaluate().expect("evaluation should succeed");
903        // At least 2 iterations are needed to derive the transitive fact.
904        assert!(
905            eval.stats().iterations >= 2,
906            "expected >=2 iterations, got {}",
907            eval.stats().iterations
908        );
909    }
910
911    // ── Test 7: EvalStats.total_new_facts counts correctly ───────────────────
912
913    #[test]
914    fn test_eval_stats_total_new_facts() {
915        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), make_parent_edb());
916        eval.evaluate().expect("evaluation should succeed");
917        assert_eq!(eval.stats().total_new_facts, 3);
918    }
919
920    // ── Test 8: Relation.union ────────────────────────────────────────────────
921
922    #[test]
923    fn test_relation_union() {
924        let mut r1 = Relation::new();
925        r1.insert(Fact::sym("foo", &["a"]));
926
927        let mut r2 = Relation::new();
928        r2.insert(Fact::sym("foo", &["b"]));
929        r2.insert(Fact::sym("foo", &["a"])); // duplicate
930
931        let u = r1.union(&r2);
932        assert_eq!(u.len(), 2);
933    }
934
935    // ── Test 9: Relation.difference ──────────────────────────────────────────
936
937    #[test]
938    fn test_relation_difference() {
939        let mut r1 = Relation::new();
940        r1.insert(Fact::sym("foo", &["a"]));
941        r1.insert(Fact::sym("foo", &["b"]));
942
943        let mut r2 = Relation::new();
944        r2.insert(Fact::sym("foo", &["a"]));
945
946        let diff = r1.difference(&r2);
947        assert_eq!(diff.len(), 1);
948        assert!(diff.contains(&Fact::sym("foo", &["b"])));
949    }
950
951    // ── Test 10: Edb.total_facts ──────────────────────────────────────────────
952
953    #[test]
954    fn test_edb_total_facts() {
955        let edb = make_parent_edb();
956        assert_eq!(edb.total_facts(), 2);
957    }
958
959    // ── Test 11: Idb.all_facts ────────────────────────────────────────────────
960
961    #[test]
962    fn test_idb_all_facts() {
963        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), make_parent_edb());
964        eval.evaluate().expect("evaluation should succeed");
965        let all = eval.idb().all_facts();
966        assert_eq!(all.len(), 3);
967    }
968
969    // ── Test 12: Rule with two body atoms — join ──────────────────────────────
970
971    #[test]
972    fn test_two_body_atom_join() {
973        // sibling(X, Z) :- parent(Y, X), parent(Y, Z).
974        let mut edb = Edb::new();
975        edb.add_fact(Fact::sym("parent", &["alice", "bob"]));
976        edb.add_fact(Fact::sym("parent", &["alice", "carol"]));
977
978        let rule = Rule::new(
979            Atom::new("sibling", vec![Term::var("X"), Term::var("Z")]),
980            vec![
981                Atom::new("parent", vec![Term::var("Y"), Term::var("X")]),
982                Atom::new("parent", vec![Term::var("Y"), Term::var("Z")]),
983            ],
984        );
985
986        let mut eval = SemiNaiveEvaluator::new(vec![rule], edb);
987        let idb = eval.evaluate().expect("evaluation should succeed");
988        let siblings = idb.get_relation("sibling").expect("sibling relation");
989        // bob-bob, bob-carol, carol-bob, carol-carol
990        assert_eq!(siblings.len(), 4);
991    }
992
993    // ── Test 13: Constant in rule body filters ────────────────────────────────
994
995    #[test]
996    fn test_constant_in_body_filters() {
997        // known_alice(Y) :- parent("alice", Y).
998        let rule = Rule::new(
999            Atom::new("known_alice", vec![Term::var("Y")]),
1000            vec![Atom::new(
1001                "parent",
1002                vec![Term::sym("alice"), Term::var("Y")],
1003            )],
1004        );
1005
1006        let mut eval = SemiNaiveEvaluator::new(vec![rule], make_parent_edb());
1007        let idb = eval.evaluate().expect("evaluation should succeed");
1008        let rel = idb.get_relation("known_alice").expect("known_alice");
1009        assert_eq!(rel.len(), 1);
1010        assert!(rel.contains(&Fact::new(
1011            "known_alice",
1012            vec![FactArg::Symbol("bob".to_owned())]
1013        )));
1014    }
1015
1016    // ── Test 14: Variable reuse (equality check) ──────────────────────────────
1017
1018    #[test]
1019    fn test_variable_reuse_equality() {
1020        // self_parent(X) :- parent(X, X).
1021        let mut edb = Edb::new();
1022        edb.add_fact(Fact::sym("parent", &["alice", "bob"]));
1023        edb.add_fact(Fact::sym("parent", &["self", "self"]));
1024
1025        let rule = Rule::new(
1026            Atom::new("self_parent", vec![Term::var("X")]),
1027            vec![Atom::new("parent", vec![Term::var("X"), Term::var("X")])],
1028        );
1029
1030        let mut eval = SemiNaiveEvaluator::new(vec![rule], edb);
1031        let idb = eval.evaluate().expect("evaluation should succeed");
1032        let rel = idb.get_relation("self_parent").expect("self_parent");
1033        assert_eq!(rel.len(), 1);
1034        assert!(rel.contains(&Fact::new(
1035            "self_parent",
1036            vec![FactArg::Symbol("self".to_owned())]
1037        )));
1038    }
1039
1040    // ── Test 15: IncrementalEvaluator.add_facts propagates ───────────────────
1041
1042    #[test]
1043    fn test_incremental_add_facts() {
1044        let edb = make_parent_edb(); // alice->bob, bob->carol
1045        let mut inc =
1046            IncrementalEvaluator::new(ancestor_rules(), edb).expect("init should succeed");
1047
1048        // Initially 3 derived facts.
1049        assert_eq!(inc.total_derived_facts(), 3);
1050
1051        // Add carol->dave.
1052        inc.add_facts(vec![Fact::sym("parent", &["carol", "dave"])])
1053            .expect("add_facts should succeed");
1054
1055        let ancestors = inc.query("ancestor");
1056        // alice->bob, alice->carol, alice->dave, bob->carol, bob->dave, carol->dave
1057        assert_eq!(ancestors.len(), 6, "expected 6 ancestor pairs");
1058    }
1059
1060    // ── Test 16: IncrementalEvaluator.query ──────────────────────────────────
1061
1062    #[test]
1063    fn test_incremental_query() {
1064        let edb = make_parent_edb();
1065        let inc = IncrementalEvaluator::new(ancestor_rules(), edb).expect("init should succeed");
1066
1067        let ancestors = inc.query("ancestor");
1068        assert!(!ancestors.is_empty());
1069
1070        // Non-existent predicate returns empty.
1071        let none = inc.query("no_such_predicate");
1072        assert!(none.is_empty());
1073    }
1074
1075    // ── Test 17: Semi-naive avoids re-deriving known facts ───────────────────
1076
1077    #[test]
1078    fn test_semi_naive_no_redundant_recomputation() {
1079        let edb = make_parent_edb();
1080        let mut inc =
1081            IncrementalEvaluator::new(ancestor_rules(), edb).expect("init should succeed");
1082
1083        let before = inc.total_derived_facts();
1084
1085        // Adding a fact that produces no new derived facts should leave the
1086        // IDB size unchanged (alice already has bob as ancestor).
1087        // (We add a duplicate base fact.)
1088        let stats = inc
1089            .add_facts(vec![Fact::sym("parent", &["alice", "bob"])])
1090            .expect("add_facts should succeed");
1091
1092        let after = inc.total_derived_facts();
1093        assert_eq!(before, after, "no new derived facts expected");
1094        assert_eq!(stats.total_new_facts, 0);
1095    }
1096
1097    // ── Test 18: Fact.sym convenience constructor ─────────────────────────────
1098
1099    #[test]
1100    fn test_fact_sym_constructor() {
1101        let f = Fact::sym("edge", &["a", "b"]);
1102        assert_eq!(f.predicate, "edge");
1103        assert_eq!(f.arity(), 2);
1104        assert_eq!(f.args[0], FactArg::Symbol("a".to_owned()));
1105        assert_eq!(f.args[1], FactArg::Symbol("b".to_owned()));
1106    }
1107
1108    // ── Test 19: Term constructors ────────────────────────────────────────────
1109
1110    #[test]
1111    fn test_term_constructors() {
1112        let v = Term::var("X");
1113        let s = Term::sym("hello");
1114        let n = Term::int(42);
1115
1116        assert!(matches!(v, Term::Variable(ref x) if x == "X"));
1117        assert!(matches!(s, Term::Constant(FactArg::Symbol(ref x)) if x == "hello"));
1118        assert!(matches!(n, Term::Constant(FactArg::Integer(42))));
1119    }
1120
1121    // ── Test 20: QueryError for rule body with unknown predicate ─────────────
1122    //
1123    // Our evaluator does not fail hard on unknown predicates (it simply finds
1124    // no facts for that predicate), but we provide a mechanism to detect it
1125    // after evaluation by checking EDB + IDB coverage.  Here we test that an
1126    // empty result is produced when the body atom predicate is absent from
1127    // both EDB and IDB.
1128    #[test]
1129    fn test_unknown_predicate_in_rule_body() {
1130        // foo(X) :- no_such_pred(X).
1131        let rule = Rule::new(
1132            Atom::new("foo", vec![Term::var("X")]),
1133            vec![Atom::new("no_such_pred", vec![Term::var("X")])],
1134        );
1135        let mut eval = SemiNaiveEvaluator::new(vec![rule], Edb::new());
1136        let idb = eval.evaluate().expect("evaluation should not hard-fail");
1137        // No facts derived because no_such_pred is empty.
1138        assert_eq!(idb.total_facts(), 0);
1139
1140        // Verify QueryError can be constructed and displayed.
1141        let err = QueryError::UnknownPredicate("no_such_pred".to_owned());
1142        assert!(err.to_string().contains("no_such_pred"));
1143    }
1144
1145    // ── Test 21: 5-node chain has 10 ancestor pairs ───────────────────────────
1146
1147    #[test]
1148    fn test_five_node_chain() {
1149        let nodes = ["a", "b", "c", "d", "e"];
1150        let mut edb = Edb::new();
1151        for i in 0..nodes.len() - 1 {
1152            edb.add_fact(Fact::sym("parent", &[nodes[i], nodes[i + 1]]));
1153        }
1154
1155        let mut eval = SemiNaiveEvaluator::new(ancestor_rules(), edb);
1156        let idb = eval.evaluate().expect("evaluation should succeed");
1157        let derived = idb.get_relation("ancestor").expect("ancestor relation");
1158        // For a 5-node chain a→b→c→d→e, the transitive closure has
1159        // C(5,2) = 10 pairs.
1160        assert_eq!(derived.len(), 10);
1161    }
1162
1163    // ── Test 22: Multiple rules deriving same head accumulate ─────────────────
1164
1165    #[test]
1166    fn test_multiple_rules_same_head() {
1167        let mut edb = Edb::new();
1168        edb.add_fact(Fact::sym("edge_a", &["x", "y"]));
1169        edb.add_fact(Fact::sym("edge_b", &["y", "z"]));
1170
1171        let rule1 = Rule::new(
1172            Atom::new("reachable", vec![Term::var("X"), Term::var("Y")]),
1173            vec![Atom::new("edge_a", vec![Term::var("X"), Term::var("Y")])],
1174        );
1175        let rule2 = Rule::new(
1176            Atom::new("reachable", vec![Term::var("X"), Term::var("Y")]),
1177            vec![Atom::new("edge_b", vec![Term::var("X"), Term::var("Y")])],
1178        );
1179
1180        let mut eval = SemiNaiveEvaluator::new(vec![rule1, rule2], edb);
1181        let idb = eval.evaluate().expect("evaluation should succeed");
1182        let rel = idb.get_relation("reachable").expect("reachable relation");
1183        assert_eq!(rel.len(), 2);
1184        assert!(rel.contains(&Fact::sym("reachable", &["x", "y"])));
1185        assert!(rel.contains(&Fact::sym("reachable", &["y", "z"])));
1186    }
1187
1188    // ── Bonus: Integer fact args work ─────────────────────────────────────────
1189
1190    #[test]
1191    fn test_integer_fact_args() {
1192        let mut edb = Edb::new();
1193        edb.add_fact(Fact::new(
1194            "score",
1195            vec![FactArg::Symbol("alice".to_owned()), FactArg::Integer(99)],
1196        ));
1197
1198        // high_scorer(X) :- score(X, 99).
1199        let rule = Rule::new(
1200            Atom::new("high_scorer", vec![Term::var("X")]),
1201            vec![Atom::new("score", vec![Term::var("X"), Term::int(99)])],
1202        );
1203
1204        let mut eval = SemiNaiveEvaluator::new(vec![rule], edb);
1205        let idb = eval.evaluate().expect("evaluation should succeed");
1206        let rel = idb.get_relation("high_scorer").expect("high_scorer");
1207        assert_eq!(rel.len(), 1);
1208    }
1209}