Skip to main content

sim_lib_logic/
builtins.rs

1//! Browseable logic builtin bindings.
2
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fmt,
6    sync::{Arc, Mutex},
7};
8
9use indexmap::IndexMap;
10use sim_kernel::{Cx, Error, Expr, Result, ShapeMatch, Symbol};
11
12use crate::{
13    LogicConfig, LogicDb,
14    all_solutions::{
15        FindallRequest, bagof_through_sequence, findall_through_sequence, setof_through_sequence,
16    },
17    arith::{eval_compare_through_tower, eval_is_through_tower},
18    clause::{predicate_symbol, rename_clause_apart},
19    env::LogicEnv,
20    error::logic_eval_error,
21    lists::{
22        append_through_sequence, length_through_sequence, member_through_sequence,
23        select_through_sequence,
24    },
25    query::query_all,
26    unify::occurs_check,
27};
28
29/// Context handed to every builtin projection.
30pub struct BuiltinCtx<'a> {
31    /// Active clause database for child queries.
32    pub db: &'a LogicDb,
33    /// Active query limits and search configuration.
34    pub config: &'a LogicConfig,
35    /// Effective answer cap for the current query stream.
36    pub answer_limit: Option<usize>,
37}
38
39/// Projection function used by a builtin binding.
40pub type BuiltinSolve = dyn for<'a> Fn(&mut Cx, &BuiltinCtx<'a>, &[Expr], &LogicEnv) -> Result<Vec<LogicEnv>>
41    + Send
42    + Sync;
43
44type BuiltinProjection =
45    for<'a> fn(&mut Cx, &BuiltinCtx<'a>, &[Expr], &LogicEnv) -> Result<Vec<LogicEnv>>;
46
47/// Data record describing one builtin goal.
48#[derive(Clone)]
49pub struct BuiltinBinding {
50    /// Goal functor handled by this binding.
51    pub key: Symbol,
52    /// Organ that resolves the builtin, exposed as browseable metadata.
53    pub organ: Symbol,
54    /// Thin projection from a goal's arguments to continuation environments.
55    pub solve: Arc<BuiltinSolve>,
56}
57
58impl fmt::Debug for BuiltinBinding {
59    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
60        formatter
61            .debug_struct("BuiltinBinding")
62            .field("key", &self.key)
63            .field("organ", &self.organ)
64            .finish_non_exhaustive()
65    }
66}
67
68/// Table of builtin goal bindings.
69#[derive(Clone, Default)]
70pub struct BuiltinTable {
71    bindings: IndexMap<Symbol, BuiltinBinding>,
72}
73
74impl fmt::Debug for BuiltinTable {
75    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
76        formatter
77            .debug_struct("BuiltinTable")
78            .field("keys", &self.keys().collect::<Vec<_>>())
79            .finish()
80    }
81}
82
83impl BuiltinTable {
84    /// Registers or replaces a builtin binding.
85    pub fn register(&mut self, binding: BuiltinBinding) {
86        self.bindings.insert(binding.key.clone(), binding);
87    }
88
89    /// Returns the binding for `key`, when one is registered.
90    pub fn get(&self, key: &Symbol) -> Option<&BuiltinBinding> {
91        self.bindings.get(key)
92    }
93
94    /// Returns the organ metadata for `key`, when one is registered.
95    pub fn organ_of(&self, key: &Symbol) -> Option<&Symbol> {
96        self.bindings.get(key).map(|binding| &binding.organ)
97    }
98
99    /// Returns the registered builtin keys in insertion order.
100    pub fn keys(&self) -> impl Iterator<Item = &Symbol> {
101        self.bindings.keys()
102    }
103
104    /// Returns the standard builtin table.
105    pub fn standard() -> Self {
106        let mut table = Self::default();
107        register_keystones(&mut table);
108        register_constraints(&mut table);
109        register_arithmetic_comparisons(&mut table);
110        register_lists(&mut table);
111        table
112    }
113}
114
115/// Builds a sequence-organ tabling memo binding for `predicate`.
116///
117/// The binding is ordinary [`BuiltinTable`] data: registering it under a
118/// predicate key makes that predicate tabled without changing the resolver. It
119/// computes a finite fixed point for the predicate's current clauses, caches the
120/// ground answer tuples by arity, and replays matching tuples into the active
121/// environment on every call.
122pub fn tabling_memo_binding(predicate: Symbol) -> BuiltinBinding {
123    let memo = Arc::new(Mutex::new(TabledMemo::default()));
124    BuiltinBinding {
125        key: predicate.clone(),
126        organ: Symbol::new("sequence"),
127        solve: Arc::new(move |cx, ctx, args, env| {
128            let tuples = cached_tabled_tuples(cx, ctx, &predicate, args.len(), &memo)?;
129            replay_tabled_tuples(ctx.config, args, env, &tuples)
130        }),
131    }
132}
133
134fn register_keystones(table: &mut BuiltinTable) {
135    table.register(BuiltinBinding {
136        key: Symbol::new("is"),
137        organ: Symbol::qualified("numbers", "arith"),
138        solve: Arc::new(|cx, ctx, args, env| {
139            let [left, right] = args else {
140                return Err(logic_eval_error("is expects two arguments"));
141            };
142            eval_is_through_tower(cx, ctx.config, left, right, env)
143        }),
144    });
145    table.register(BuiltinBinding {
146        key: Symbol::new("findall"),
147        organ: Symbol::new("sequence"),
148        solve: Arc::new(|cx, ctx, args, env| {
149            let [template, goal, output] = args else {
150                return Err(logic_eval_error("findall expects three arguments"));
151            };
152            findall_through_sequence(
153                cx,
154                FindallRequest {
155                    db: ctx.db,
156                    config: ctx.config,
157                    template,
158                    goal,
159                    output,
160                    env,
161                },
162            )
163        }),
164    });
165    register_sequence_binding(table, "bagof", bagof_through_sequence);
166    register_sequence_binding(table, "setof", setof_through_sequence);
167}
168
169fn register_constraints(table: &mut BuiltinTable) {
170    for key in ["#=", "#<", "dif"] {
171        let key = Symbol::new(key);
172        table.register(BuiltinBinding {
173            key: key.clone(),
174            organ: Symbol::new("control"),
175            solve: Arc::new(move |cx, ctx, args, env| {
176                crate::constraints::solve_constraint(cx, ctx.config, &key, args, env)
177            }),
178        });
179    }
180    for key in [
181        "=",
182        "<",
183        "<=",
184        ">",
185        ">=",
186        "plus",
187        "minus",
188        "times",
189        "between",
190        "tool-call",
191    ] {
192        let key = Symbol::new(key);
193        table.register(BuiltinBinding {
194            key: key.clone(),
195            organ: Symbol::qualified("logic", "constraint"),
196            solve: Arc::new(move |cx, ctx, args, env| {
197                crate::constraints::solve_constraint(cx, ctx.config, &key, args, env)
198            }),
199        });
200    }
201}
202
203fn register_arithmetic_comparisons(table: &mut BuiltinTable) {
204    for key in ["=:=", "=\\=", "<", "=<", ">", ">="] {
205        let key = Symbol::new(key);
206        table.register(BuiltinBinding {
207            key: key.clone(),
208            organ: Symbol::qualified("numbers", "arith"),
209            solve: Arc::new(move |cx, _ctx, args, env| {
210                eval_compare_through_tower(cx, &key, args, env)
211            }),
212        });
213    }
214}
215
216fn register_lists(table: &mut BuiltinTable) {
217    register_sequence_binding(table, "member", member_through_sequence);
218    register_sequence_binding(table, "append", append_through_sequence);
219    register_sequence_binding(table, "length", length_through_sequence);
220    register_sequence_binding(table, "select", select_through_sequence);
221}
222
223fn register_sequence_binding(table: &mut BuiltinTable, key: &str, solve: BuiltinProjection) {
224    table.register(BuiltinBinding {
225        key: Symbol::new(key),
226        organ: Symbol::new("sequence"),
227        solve: Arc::new(solve),
228    });
229}
230
231#[derive(Default)]
232struct TabledMemo {
233    by_arity: BTreeMap<usize, Vec<Vec<Expr>>>,
234}
235
236fn cached_tabled_tuples(
237    cx: &mut Cx,
238    ctx: &BuiltinCtx<'_>,
239    predicate: &Symbol,
240    arity: usize,
241    memo: &Arc<Mutex<TabledMemo>>,
242) -> Result<Vec<Vec<Expr>>> {
243    if let Some(cached) = memo
244        .lock()
245        .map_err(|_| Error::PoisonedLock("logic tabling memo"))?
246        .by_arity
247        .get(&arity)
248        .cloned()
249    {
250        return Ok(cached);
251    }
252
253    let computed = compute_tabled_tuples(cx, ctx, predicate, arity)?;
254    let mut guard = memo
255        .lock()
256        .map_err(|_| Error::PoisonedLock("logic tabling memo"))?;
257    Ok(guard
258        .by_arity
259        .entry(arity)
260        .or_insert_with(|| computed.clone())
261        .clone())
262}
263
264fn compute_tabled_tuples(
265    cx: &mut Cx,
266    ctx: &BuiltinCtx<'_>,
267    predicate: &Symbol,
268    arity: usize,
269) -> Result<Vec<Vec<Expr>>> {
270    let mut tuples = Vec::new();
271    let mut seen = BTreeSet::new();
272    let max_rounds = ctx.config.limits.max_depth.max(1);
273    for round in 0..max_rounds {
274        let before = tuples.len();
275        for clause in ctx.db.clauses() {
276            if clause.predicate()? != predicate.clone() || clause.arity()? != arity {
277                continue;
278            }
279            let clause = rename_clause_apart(clause, round + 1);
280            for env in solve_tabled_body(cx, ctx, predicate, &tuples, &clause.body)? {
281                let tuple = tabled_head_tuple(&clause.head, &env)?;
282                if tuple.len() == arity
283                    && tuple.iter().all(is_ground)
284                    && seen.insert(tuple_key(&tuple))
285                {
286                    tuples.push(tuple);
287                }
288            }
289        }
290        if tuples.len() == before {
291            return Ok(tuples);
292        }
293    }
294    Err(logic_eval_error(format!(
295        "tabling memo for {predicate} exceeded fixed-point limit {max_rounds}"
296    )))
297}
298
299fn solve_tabled_body(
300    cx: &mut Cx,
301    ctx: &BuiltinCtx<'_>,
302    predicate: &Symbol,
303    tuples: &[Vec<Expr>],
304    body: &[Expr],
305) -> Result<Vec<LogicEnv>> {
306    let mut envs = vec![LogicEnv::new()];
307    for goal in body {
308        let mut next_envs = Vec::new();
309        for env in envs {
310            let applied = env.apply(goal);
311            if predicate_symbol(&applied)? == predicate.clone() {
312                next_envs.extend(replay_tabled_tuples(
313                    ctx.config,
314                    goal_args(&applied)?,
315                    &env,
316                    tuples,
317                )?);
318            } else {
319                next_envs.extend(resolve_non_tabled_goal(cx, ctx, &applied, &env)?);
320            }
321        }
322        envs = next_envs;
323        if envs.is_empty() {
324            break;
325        }
326    }
327    Ok(envs)
328}
329
330fn replay_tabled_tuples(
331    config: &LogicConfig,
332    args: &[Expr],
333    env: &LogicEnv,
334    tuples: &[Vec<Expr>],
335) -> Result<Vec<LogicEnv>> {
336    let mut out = Vec::new();
337    for tuple in tuples.iter().filter(|tuple| tuple.len() == args.len()) {
338        let mut next = env.clone();
339        let mut accepted = true;
340        for (arg, value) in args.iter().zip(tuple) {
341            if !next.unify(arg, value, occurs_check(config))? {
342                accepted = false;
343                break;
344            }
345        }
346        if accepted {
347            out.push(next);
348        }
349    }
350    Ok(out)
351}
352
353fn resolve_non_tabled_goal(
354    cx: &mut Cx,
355    ctx: &BuiltinCtx<'_>,
356    goal: &Expr,
357    env: &LogicEnv,
358) -> Result<Vec<LogicEnv>> {
359    let mut out = Vec::new();
360    for answer in query_all(cx, ctx.db, ctx.config, goal.clone(), ctx.answer_limit)? {
361        if let Some(next) = merge_answer(env.clone(), ctx.config, &answer)? {
362            out.push(next);
363        }
364    }
365    Ok(out)
366}
367
368fn merge_answer(
369    mut env: LogicEnv,
370    config: &LogicConfig,
371    answer: &ShapeMatch,
372) -> Result<Option<LogicEnv>> {
373    for (var, value) in answer.captures.exprs() {
374        if !env.unify(&Expr::Local(var.clone()), value, occurs_check(config))? {
375            return Ok(None);
376        }
377    }
378    Ok(Some(env))
379}
380
381fn tabled_head_tuple(head: &Expr, env: &LogicEnv) -> Result<Vec<Expr>> {
382    Ok(goal_args(head)?.iter().map(|arg| env.apply(arg)).collect())
383}
384
385fn goal_args(goal: &Expr) -> Result<&[Expr]> {
386    match goal {
387        Expr::List(items) => Ok(&items[1..]),
388        Expr::Call { args, .. } => Ok(args),
389        _ => Err(logic_eval_error("tabled goal must be call-shaped")),
390    }
391}
392
393fn is_ground(expr: &Expr) -> bool {
394    match expr {
395        Expr::Local(_) => false,
396        Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
397            items.iter().all(is_ground)
398        }
399        Expr::Map(entries) => entries
400            .iter()
401            .all(|(key, value)| is_ground(key) && is_ground(value)),
402        Expr::Call { operator, args } => is_ground(operator) && args.iter().all(is_ground),
403        Expr::Infix { left, right, .. } => is_ground(left) && is_ground(right),
404        Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => is_ground(arg),
405        Expr::Quote { expr, .. } | Expr::Extension { payload: expr, .. } => is_ground(expr),
406        Expr::Annotated { expr, annotations } => {
407            is_ground(expr) && annotations.iter().all(|(_, value)| is_ground(value))
408        }
409        _ => true,
410    }
411}
412
413fn tuple_key(tuple: &[Expr]) -> String {
414    tuple
415        .iter()
416        .map(|expr| format!("{:?}", expr.canonical_key()))
417        .collect::<Vec<_>>()
418        .join("\0")
419}