1use std::{
4 collections::{BTreeMap, BTreeSet},
5 fmt,
6 sync::{Arc, Mutex},
7};
8
9use indexmap::IndexMap;
10use sim_kernel::{Cx, Error, Expr, Result, ShapeMatch, Symbol};
11
12use crate::{
13 LogicConfig, LogicDb,
14 all_solutions::{
15 FindallRequest, bagof_through_sequence, findall_through_sequence, setof_through_sequence,
16 },
17 arith::{eval_compare_through_tower, eval_is_through_tower},
18 clause::{predicate_symbol, rename_clause_apart},
19 env::LogicEnv,
20 error::logic_eval_error,
21 lists::{
22 append_through_sequence, length_through_sequence, member_through_sequence,
23 select_through_sequence,
24 },
25 query::query_all,
26 unify::occurs_check,
27};
28
29pub struct BuiltinCtx<'a> {
31 pub db: &'a LogicDb,
33 pub config: &'a LogicConfig,
35 pub answer_limit: Option<usize>,
37}
38
39pub type BuiltinSolve = dyn for<'a> Fn(&mut Cx, &BuiltinCtx<'a>, &[Expr], &LogicEnv) -> Result<Vec<LogicEnv>>
41 + Send
42 + Sync;
43
44type BuiltinProjection =
45 for<'a> fn(&mut Cx, &BuiltinCtx<'a>, &[Expr], &LogicEnv) -> Result<Vec<LogicEnv>>;
46
47#[derive(Clone)]
49pub struct BuiltinBinding {
50 pub key: Symbol,
52 pub organ: Symbol,
54 pub solve: Arc<BuiltinSolve>,
56}
57
58impl fmt::Debug for BuiltinBinding {
59 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
60 formatter
61 .debug_struct("BuiltinBinding")
62 .field("key", &self.key)
63 .field("organ", &self.organ)
64 .finish_non_exhaustive()
65 }
66}
67
68#[derive(Clone, Default)]
70pub struct BuiltinTable {
71 bindings: IndexMap<Symbol, BuiltinBinding>,
72}
73
74impl fmt::Debug for BuiltinTable {
75 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
76 formatter
77 .debug_struct("BuiltinTable")
78 .field("keys", &self.keys().collect::<Vec<_>>())
79 .finish()
80 }
81}
82
83impl BuiltinTable {
84 pub fn register(&mut self, binding: BuiltinBinding) {
86 self.bindings.insert(binding.key.clone(), binding);
87 }
88
89 pub fn get(&self, key: &Symbol) -> Option<&BuiltinBinding> {
91 self.bindings.get(key)
92 }
93
94 pub fn organ_of(&self, key: &Symbol) -> Option<&Symbol> {
96 self.bindings.get(key).map(|binding| &binding.organ)
97 }
98
99 pub fn keys(&self) -> impl Iterator<Item = &Symbol> {
101 self.bindings.keys()
102 }
103
104 pub fn standard() -> Self {
106 let mut table = Self::default();
107 register_keystones(&mut table);
108 register_constraints(&mut table);
109 register_arithmetic_comparisons(&mut table);
110 register_lists(&mut table);
111 table
112 }
113}
114
115pub fn tabling_memo_binding(predicate: Symbol) -> BuiltinBinding {
123 let memo = Arc::new(Mutex::new(TabledMemo::default()));
124 BuiltinBinding {
125 key: predicate.clone(),
126 organ: Symbol::new("sequence"),
127 solve: Arc::new(move |cx, ctx, args, env| {
128 let tuples = cached_tabled_tuples(cx, ctx, &predicate, args.len(), &memo)?;
129 replay_tabled_tuples(ctx.config, args, env, &tuples)
130 }),
131 }
132}
133
134fn register_keystones(table: &mut BuiltinTable) {
135 table.register(BuiltinBinding {
136 key: Symbol::new("is"),
137 organ: Symbol::qualified("numbers", "arith"),
138 solve: Arc::new(|cx, ctx, args, env| {
139 let [left, right] = args else {
140 return Err(logic_eval_error("is expects two arguments"));
141 };
142 eval_is_through_tower(cx, ctx.config, left, right, env)
143 }),
144 });
145 table.register(BuiltinBinding {
146 key: Symbol::new("findall"),
147 organ: Symbol::new("sequence"),
148 solve: Arc::new(|cx, ctx, args, env| {
149 let [template, goal, output] = args else {
150 return Err(logic_eval_error("findall expects three arguments"));
151 };
152 findall_through_sequence(
153 cx,
154 FindallRequest {
155 db: ctx.db,
156 config: ctx.config,
157 template,
158 goal,
159 output,
160 env,
161 },
162 )
163 }),
164 });
165 register_sequence_binding(table, "bagof", bagof_through_sequence);
166 register_sequence_binding(table, "setof", setof_through_sequence);
167}
168
169fn register_constraints(table: &mut BuiltinTable) {
170 for key in ["#=", "#<", "dif"] {
171 let key = Symbol::new(key);
172 table.register(BuiltinBinding {
173 key: key.clone(),
174 organ: Symbol::new("control"),
175 solve: Arc::new(move |cx, ctx, args, env| {
176 crate::constraints::solve_constraint(cx, ctx.config, &key, args, env)
177 }),
178 });
179 }
180 for key in [
181 "=",
182 "<",
183 "<=",
184 ">",
185 ">=",
186 "plus",
187 "minus",
188 "times",
189 "between",
190 "tool-call",
191 ] {
192 let key = Symbol::new(key);
193 table.register(BuiltinBinding {
194 key: key.clone(),
195 organ: Symbol::qualified("logic", "constraint"),
196 solve: Arc::new(move |cx, ctx, args, env| {
197 crate::constraints::solve_constraint(cx, ctx.config, &key, args, env)
198 }),
199 });
200 }
201}
202
203fn register_arithmetic_comparisons(table: &mut BuiltinTable) {
204 for key in ["=:=", "=\\=", "<", "=<", ">", ">="] {
205 let key = Symbol::new(key);
206 table.register(BuiltinBinding {
207 key: key.clone(),
208 organ: Symbol::qualified("numbers", "arith"),
209 solve: Arc::new(move |cx, _ctx, args, env| {
210 eval_compare_through_tower(cx, &key, args, env)
211 }),
212 });
213 }
214}
215
216fn register_lists(table: &mut BuiltinTable) {
217 register_sequence_binding(table, "member", member_through_sequence);
218 register_sequence_binding(table, "append", append_through_sequence);
219 register_sequence_binding(table, "length", length_through_sequence);
220 register_sequence_binding(table, "select", select_through_sequence);
221}
222
223fn register_sequence_binding(table: &mut BuiltinTable, key: &str, solve: BuiltinProjection) {
224 table.register(BuiltinBinding {
225 key: Symbol::new(key),
226 organ: Symbol::new("sequence"),
227 solve: Arc::new(solve),
228 });
229}
230
231#[derive(Default)]
232struct TabledMemo {
233 by_arity: BTreeMap<usize, Vec<Vec<Expr>>>,
234}
235
236fn cached_tabled_tuples(
237 cx: &mut Cx,
238 ctx: &BuiltinCtx<'_>,
239 predicate: &Symbol,
240 arity: usize,
241 memo: &Arc<Mutex<TabledMemo>>,
242) -> Result<Vec<Vec<Expr>>> {
243 if let Some(cached) = memo
244 .lock()
245 .map_err(|_| Error::PoisonedLock("logic tabling memo"))?
246 .by_arity
247 .get(&arity)
248 .cloned()
249 {
250 return Ok(cached);
251 }
252
253 let computed = compute_tabled_tuples(cx, ctx, predicate, arity)?;
254 let mut guard = memo
255 .lock()
256 .map_err(|_| Error::PoisonedLock("logic tabling memo"))?;
257 Ok(guard
258 .by_arity
259 .entry(arity)
260 .or_insert_with(|| computed.clone())
261 .clone())
262}
263
264fn compute_tabled_tuples(
265 cx: &mut Cx,
266 ctx: &BuiltinCtx<'_>,
267 predicate: &Symbol,
268 arity: usize,
269) -> Result<Vec<Vec<Expr>>> {
270 let mut tuples = Vec::new();
271 let mut seen = BTreeSet::new();
272 let max_rounds = ctx.config.limits.max_depth.max(1);
273 for round in 0..max_rounds {
274 let before = tuples.len();
275 for clause in ctx.db.clauses() {
276 if clause.predicate()? != predicate.clone() || clause.arity()? != arity {
277 continue;
278 }
279 let clause = rename_clause_apart(clause, round + 1);
280 for env in solve_tabled_body(cx, ctx, predicate, &tuples, &clause.body)? {
281 let tuple = tabled_head_tuple(&clause.head, &env)?;
282 if tuple.len() == arity
283 && tuple.iter().all(is_ground)
284 && seen.insert(tuple_key(&tuple))
285 {
286 tuples.push(tuple);
287 }
288 }
289 }
290 if tuples.len() == before {
291 return Ok(tuples);
292 }
293 }
294 Err(logic_eval_error(format!(
295 "tabling memo for {predicate} exceeded fixed-point limit {max_rounds}"
296 )))
297}
298
299fn solve_tabled_body(
300 cx: &mut Cx,
301 ctx: &BuiltinCtx<'_>,
302 predicate: &Symbol,
303 tuples: &[Vec<Expr>],
304 body: &[Expr],
305) -> Result<Vec<LogicEnv>> {
306 let mut envs = vec![LogicEnv::new()];
307 for goal in body {
308 let mut next_envs = Vec::new();
309 for env in envs {
310 let applied = env.apply(goal);
311 if predicate_symbol(&applied)? == predicate.clone() {
312 next_envs.extend(replay_tabled_tuples(
313 ctx.config,
314 goal_args(&applied)?,
315 &env,
316 tuples,
317 )?);
318 } else {
319 next_envs.extend(resolve_non_tabled_goal(cx, ctx, &applied, &env)?);
320 }
321 }
322 envs = next_envs;
323 if envs.is_empty() {
324 break;
325 }
326 }
327 Ok(envs)
328}
329
330fn replay_tabled_tuples(
331 config: &LogicConfig,
332 args: &[Expr],
333 env: &LogicEnv,
334 tuples: &[Vec<Expr>],
335) -> Result<Vec<LogicEnv>> {
336 let mut out = Vec::new();
337 for tuple in tuples.iter().filter(|tuple| tuple.len() == args.len()) {
338 let mut next = env.clone();
339 let mut accepted = true;
340 for (arg, value) in args.iter().zip(tuple) {
341 if !next.unify(arg, value, occurs_check(config))? {
342 accepted = false;
343 break;
344 }
345 }
346 if accepted {
347 out.push(next);
348 }
349 }
350 Ok(out)
351}
352
353fn resolve_non_tabled_goal(
354 cx: &mut Cx,
355 ctx: &BuiltinCtx<'_>,
356 goal: &Expr,
357 env: &LogicEnv,
358) -> Result<Vec<LogicEnv>> {
359 let mut out = Vec::new();
360 for answer in query_all(cx, ctx.db, ctx.config, goal.clone(), ctx.answer_limit)? {
361 if let Some(next) = merge_answer(env.clone(), ctx.config, &answer)? {
362 out.push(next);
363 }
364 }
365 Ok(out)
366}
367
368fn merge_answer(
369 mut env: LogicEnv,
370 config: &LogicConfig,
371 answer: &ShapeMatch,
372) -> Result<Option<LogicEnv>> {
373 for (var, value) in answer.captures.exprs() {
374 if !env.unify(&Expr::Local(var.clone()), value, occurs_check(config))? {
375 return Ok(None);
376 }
377 }
378 Ok(Some(env))
379}
380
381fn tabled_head_tuple(head: &Expr, env: &LogicEnv) -> Result<Vec<Expr>> {
382 Ok(goal_args(head)?.iter().map(|arg| env.apply(arg)).collect())
383}
384
385fn goal_args(goal: &Expr) -> Result<&[Expr]> {
386 match goal {
387 Expr::List(items) => Ok(&items[1..]),
388 Expr::Call { args, .. } => Ok(args),
389 _ => Err(logic_eval_error("tabled goal must be call-shaped")),
390 }
391}
392
393fn is_ground(expr: &Expr) -> bool {
394 match expr {
395 Expr::Local(_) => false,
396 Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
397 items.iter().all(is_ground)
398 }
399 Expr::Map(entries) => entries
400 .iter()
401 .all(|(key, value)| is_ground(key) && is_ground(value)),
402 Expr::Call { operator, args } => is_ground(operator) && args.iter().all(is_ground),
403 Expr::Infix { left, right, .. } => is_ground(left) && is_ground(right),
404 Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => is_ground(arg),
405 Expr::Quote { expr, .. } | Expr::Extension { payload: expr, .. } => is_ground(expr),
406 Expr::Annotated { expr, annotations } => {
407 is_ground(expr) && annotations.iter().all(|(_, value)| is_ground(value))
408 }
409 _ => true,
410 }
411}
412
413fn tuple_key(tuple: &[Expr]) -> String {
414 tuple
415 .iter()
416 .map(|expr| format!("{:?}", expr.canonical_key()))
417 .collect::<Vec<_>>()
418 .join("\0")
419}