Skip to main content

sim_lib_logic/
env.rs

1//! Substitution environment: the variable bindings built during unification.
2//!
3//! A [`LogicEnv`] maps logic variables to their bound terms and drives the
4//! unifier. It bridges to the kernel `Shape` contracts by projecting its
5//! bindings into a `ShapeMatch`; see the [`README`](https://docs.rs/sim-runtime).
6use std::{
7    collections::{BTreeMap, BTreeSet},
8    sync::Arc,
9};
10
11use sim_kernel::{
12    Cx, DefaultFactory, EagerPolicy, Expr, MatchScore, Result, ShapeBindings, ShapeMatch, Symbol,
13};
14use sim_shape::{AnyShape, CaptureShape, ExactExprShape, ListShape, Shape};
15
16use crate::model::OccursCheck;
17
18/// A unification substitution: bindings from logic variables to terms.
19///
20/// Carries a resolution `depth` so renamed clause variables stay distinct
21/// across recursive calls.
22#[derive(Clone, Debug, Default, PartialEq, Eq)]
23pub struct LogicEnv {
24    captures: BTreeMap<Symbol, Expr>,
25    depth: usize,
26}
27
28impl LogicEnv {
29    /// Creates an empty environment at depth zero.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Creates an empty environment recorded at the given resolution depth.
35    pub fn with_depth(depth: usize) -> Self {
36        Self {
37            captures: BTreeMap::new(),
38            depth,
39        }
40    }
41
42    /// Returns the resolution depth recorded for this environment.
43    pub fn depth(&self) -> usize {
44        self.depth
45    }
46
47    /// Sets the recorded resolution depth.
48    pub fn set_depth(&mut self, depth: usize) {
49        self.depth = depth;
50    }
51
52    /// Applies the substitution to `expr`, recursively replacing bound
53    /// variables with their values.
54    pub fn apply(&self, expr: &Expr) -> Expr {
55        match expr {
56            Expr::Local(var) => match self.captures.get(var) {
57                Some(bound) => self.apply(bound),
58                None => Expr::Local(var.clone()),
59            },
60            Expr::List(items) => Expr::List(items.iter().map(|item| self.apply(item)).collect()),
61            Expr::Vector(items) => {
62                Expr::Vector(items.iter().map(|item| self.apply(item)).collect())
63            }
64            Expr::Map(entries) => Expr::Map(
65                entries
66                    .iter()
67                    .map(|(key, value)| (self.apply(key), self.apply(value)))
68                    .collect(),
69            ),
70            Expr::Set(items) => Expr::Set(items.iter().map(|item| self.apply(item)).collect()),
71            Expr::Call { operator, args } => Expr::Call {
72                operator: Box::new(self.apply(operator)),
73                args: args.iter().map(|arg| self.apply(arg)).collect(),
74            },
75            Expr::Infix {
76                operator,
77                left,
78                right,
79            } => Expr::Infix {
80                operator: operator.clone(),
81                left: Box::new(self.apply(left)),
82                right: Box::new(self.apply(right)),
83            },
84            Expr::Prefix { operator, arg } => Expr::Prefix {
85                operator: operator.clone(),
86                arg: Box::new(self.apply(arg)),
87            },
88            Expr::Postfix { operator, arg } => Expr::Postfix {
89                operator: operator.clone(),
90                arg: Box::new(self.apply(arg)),
91            },
92            Expr::Block(items) => Expr::Block(items.iter().map(|item| self.apply(item)).collect()),
93            Expr::Quote { mode, expr } => Expr::Quote {
94                mode: *mode,
95                expr: Box::new(self.apply(expr)),
96            },
97            Expr::Annotated { expr, annotations } => Expr::Annotated {
98                expr: Box::new(self.apply(expr)),
99                annotations: annotations
100                    .iter()
101                    .map(|(name, value)| (name.clone(), self.apply(value)))
102                    .collect(),
103            },
104            Expr::Extension { tag, payload } => Expr::Extension {
105                tag: tag.clone(),
106                payload: Box::new(self.apply(payload)),
107            },
108            other => other.clone(),
109        }
110    }
111
112    /// Returns the term directly bound to `var`, if any.
113    pub fn get(&self, var: &Symbol) -> Option<&Expr> {
114        self.captures.get(var)
115    }
116
117    /// Binds `var` to `value`, honoring the [`OccursCheck`] policy.
118    ///
119    /// Returns an error when an enabled occurs check detects that `var` occurs
120    /// in `value` (which would build a cyclic term).
121    pub fn bind(&mut self, var: Symbol, value: Expr, occurs_check: OccursCheck) -> Result<()> {
122        if matches!(occurs_check, OccursCheck::Always) && occurs(var.clone(), &value, self) {
123            return Err(sim_kernel::Error::Eval(format!(
124                "occurs check failed for ?{}",
125                var.name
126            )));
127        }
128        self.captures.insert(var, value);
129        Ok(())
130    }
131
132    /// Unifies two terms, extending the substitution in place.
133    ///
134    /// Returns `true` when the terms unify and `false` on a structural
135    /// mismatch; errors propagate only from a failed occurs check.
136    pub fn unify(&mut self, left: &Expr, right: &Expr, occurs_check: OccursCheck) -> Result<bool> {
137        let left = self.apply(left);
138        let right = self.apply(right);
139        if left.canonical_eq(&right) {
140            return Ok(true);
141        }
142
143        let left_match = self.shape_unify(&left, &right, occurs_check)?;
144        let right_match = self.shape_unify(&right, &left, occurs_check)?;
145        match (left_match, right_match) {
146            (ShapeUnify::Accepted(next), _) | (_, ShapeUnify::Accepted(next)) => {
147                *self = next;
148                Ok(true)
149            }
150            (ShapeUnify::Unsupported, _) | (_, ShapeUnify::Unsupported) => {
151                unify_ground(self, &left, &right, occurs_check)
152            }
153            (ShapeUnify::Rejected, ShapeUnify::Rejected) => Ok(false),
154        }
155    }
156
157    fn shape_unify(
158        &self,
159        pattern: &Expr,
160        subject: &Expr,
161        occurs_check: OccursCheck,
162    ) -> Result<ShapeUnify> {
163        let Some(shape) = shape_from_pattern(pattern) else {
164            return Ok(ShapeUnify::Unsupported);
165        };
166        let mut cx = Cx::new(Arc::new(EagerPolicy), Arc::new(DefaultFactory));
167        let matched = shape.check_expr(&mut cx, subject)?;
168        if !matched.accepted {
169            return Ok(ShapeUnify::Rejected);
170        }
171        let mut next = self.clone();
172        if next.merge_shape_captures(&matched.captures, occurs_check)? {
173            Ok(ShapeUnify::Accepted(next))
174        } else {
175            Ok(ShapeUnify::Rejected)
176        }
177    }
178
179    fn merge_shape_captures(
180        &mut self,
181        captures: &ShapeBindings,
182        occurs_check: OccursCheck,
183    ) -> Result<bool> {
184        for (var, value) in captures.exprs() {
185            if !self.merge_shape_capture(var.clone(), value.clone(), occurs_check)? {
186                return Ok(false);
187            }
188        }
189        Ok(true)
190    }
191
192    fn merge_shape_capture(
193        &mut self,
194        var: Symbol,
195        value: Expr,
196        occurs_check: OccursCheck,
197    ) -> Result<bool> {
198        let value = self.apply(&value);
199        if let Some(bound) = self.captures.get(&var).cloned() {
200            let bound = self.apply(&bound);
201            return self.unify(&bound, &value, occurs_check);
202        }
203        self.bind(var, value, occurs_check)?;
204        Ok(true)
205    }
206
207    /// Collects the distinct logic variables appearing in `expr`.
208    pub fn free_vars(&self, expr: &Expr) -> Vec<Symbol> {
209        let mut vars = BTreeSet::new();
210        collect_vars(expr, &mut vars);
211        vars.into_iter().collect()
212    }
213
214    /// Projects the current bindings into kernel `ShapeBindings`.
215    pub fn to_shape_bindings(&self, _cx: &mut Cx) -> Result<ShapeBindings> {
216        let mut bindings = ShapeBindings::new();
217        for (name, expr) in &self.captures {
218            bindings.bind_expr(name.clone(), self.apply(expr));
219        }
220        Ok(bindings)
221    }
222
223    /// Builds an accepting kernel `ShapeMatch` whose captures are this
224    /// environment's bindings.
225    pub fn as_shape_match(&self, cx: &mut Cx) -> Result<ShapeMatch> {
226        Ok(ShapeMatch {
227            accepted: true,
228            captures: self.to_shape_bindings(cx)?,
229            score: MatchScore::exact(100),
230            diagnostics: Vec::new(),
231        })
232    }
233}
234
235enum ShapeUnify {
236    Accepted(LogicEnv),
237    Rejected,
238    Unsupported,
239}
240
241fn shape_from_pattern(pattern: &Expr) -> Option<Arc<dyn Shape>> {
242    match pattern {
243        Expr::Local(var) => Some(Arc::new(CaptureShape::new(var.clone(), Arc::new(AnyShape)))),
244        Expr::List(items) => {
245            let item_shapes = items
246                .iter()
247                .map(shape_from_pattern)
248                .collect::<Option<Vec<_>>>()?;
249            Some(Arc::new(ListShape::new(item_shapes)))
250        }
251        other if !contains_local(other) => Some(Arc::new(ExactExprShape::new(other.clone()))),
252        _ => None,
253    }
254}
255
256fn unify_ground(
257    env: &mut LogicEnv,
258    left: &Expr,
259    right: &Expr,
260    occurs_check: OccursCheck,
261) -> Result<bool> {
262    match (left, right) {
263        (Expr::Nil, Expr::Nil)
264        | (Expr::Bool(_), Expr::Bool(_))
265        | (Expr::Number(_), Expr::Number(_))
266        | (Expr::Symbol(_), Expr::Symbol(_))
267        | (Expr::Local(_), Expr::Local(_))
268        | (Expr::String(_), Expr::String(_))
269        | (Expr::Bytes(_), Expr::Bytes(_)) => Ok(left.canonical_eq(right)),
270        (Expr::List(left_items), Expr::List(right_items))
271        | (Expr::Vector(left_items), Expr::Vector(right_items))
272        | (Expr::Set(left_items), Expr::Set(right_items))
273        | (Expr::Block(left_items), Expr::Block(right_items)) => {
274            unify_slices(env, left_items, right_items, occurs_check)
275        }
276        (Expr::Map(left_entries), Expr::Map(right_entries)) => {
277            if left_entries.len() != right_entries.len() {
278                return Ok(false);
279            }
280            for ((left_key, left_value), (right_key, right_value)) in
281                left_entries.iter().zip(right_entries.iter())
282            {
283                if !env.unify(left_key, right_key, occurs_check)? {
284                    return Ok(false);
285                }
286                if !env.unify(left_value, right_value, occurs_check)? {
287                    return Ok(false);
288                }
289            }
290            Ok(true)
291        }
292        (
293            Expr::Call {
294                operator: left_op,
295                args: left_args,
296            },
297            Expr::Call {
298                operator: right_op,
299                args: right_args,
300            },
301        ) => {
302            if left_args.len() != right_args.len() || !env.unify(left_op, right_op, occurs_check)? {
303                return Ok(false);
304            }
305            unify_slices(env, left_args, right_args, occurs_check)
306        }
307        (
308            Expr::Quote {
309                mode: left_mode,
310                expr: left_expr,
311            },
312            Expr::Quote {
313                mode: right_mode,
314                expr: right_expr,
315            },
316        ) => {
317            if left_mode != right_mode {
318                return Ok(false);
319            }
320            env.unify(left_expr, right_expr, occurs_check)
321        }
322        (
323            Expr::Annotated {
324                expr: left_expr,
325                annotations: left_annotations,
326            },
327            Expr::Annotated {
328                expr: right_expr,
329                annotations: right_annotations,
330            },
331        ) => {
332            if left_annotations.len() != right_annotations.len()
333                || !env.unify(left_expr, right_expr, occurs_check)?
334            {
335                return Ok(false);
336            }
337            for ((left_name, left_value), (right_name, right_value)) in
338                left_annotations.iter().zip(right_annotations.iter())
339            {
340                if left_name != right_name || !env.unify(left_value, right_value, occurs_check)? {
341                    return Ok(false);
342                }
343            }
344            Ok(true)
345        }
346        (
347            Expr::Extension {
348                tag: left_tag,
349                payload: left_payload,
350            },
351            Expr::Extension {
352                tag: right_tag,
353                payload: right_payload,
354            },
355        ) => Ok(left_tag == right_tag && env.unify(left_payload, right_payload, occurs_check)?),
356        (
357            Expr::Infix {
358                operator: left_op,
359                left: left_a,
360                right: left_b,
361            },
362            Expr::Infix {
363                operator: right_op,
364                left: right_a,
365                right: right_b,
366            },
367        ) => Ok(left_op == right_op
368            && env.unify(left_a, right_a, occurs_check)?
369            && env.unify(left_b, right_b, occurs_check)?),
370        (
371            Expr::Prefix {
372                operator: left_op,
373                arg: left_arg,
374            },
375            Expr::Prefix {
376                operator: right_op,
377                arg: right_arg,
378            },
379        )
380        | (
381            Expr::Postfix {
382                operator: left_op,
383                arg: left_arg,
384            },
385            Expr::Postfix {
386                operator: right_op,
387                arg: right_arg,
388            },
389        ) => Ok(left_op == right_op && env.unify(left_arg, right_arg, occurs_check)?),
390        _ => Ok(false),
391    }
392}
393
394fn unify_slices(
395    env: &mut LogicEnv,
396    left: &[Expr],
397    right: &[Expr],
398    occurs_check: OccursCheck,
399) -> Result<bool> {
400    if left.len() != right.len() {
401        return Ok(false);
402    }
403    for (left_item, right_item) in left.iter().zip(right.iter()) {
404        if !env.unify(left_item, right_item, occurs_check)? {
405            return Ok(false);
406        }
407    }
408    Ok(true)
409}
410
411fn occurs(var: Symbol, expr: &Expr, env: &LogicEnv) -> bool {
412    match env.apply(expr) {
413        Expr::Local(candidate) => candidate == var,
414        Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
415            items.iter().any(|item| occurs(var.clone(), item, env))
416        }
417        Expr::Map(entries) => entries
418            .iter()
419            .any(|(key, value)| occurs(var.clone(), key, env) || occurs(var.clone(), value, env)),
420        Expr::Call { operator, args } => {
421            occurs(var.clone(), &operator, env)
422                || args.iter().any(|arg| occurs(var.clone(), arg, env))
423        }
424        Expr::Infix { left, right, .. } => {
425            occurs(var.clone(), &left, env) || occurs(var, &right, env)
426        }
427        Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => occurs(var, &arg, env),
428        Expr::Quote { expr, .. } => occurs(var, &expr, env),
429        Expr::Annotated { expr, annotations } => {
430            occurs(var.clone(), &expr, env)
431                || annotations
432                    .iter()
433                    .any(|(_, value)| occurs(var.clone(), value, env))
434        }
435        Expr::Extension { payload, .. } => occurs(var, &payload, env),
436        _ => false,
437    }
438}
439
440fn contains_local(expr: &Expr) -> bool {
441    match expr {
442        Expr::Local(_) => true,
443        Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
444            items.iter().any(contains_local)
445        }
446        Expr::Map(entries) => entries
447            .iter()
448            .any(|(key, value)| contains_local(key) || contains_local(value)),
449        Expr::Call { operator, args } => {
450            contains_local(operator) || args.iter().any(contains_local)
451        }
452        Expr::Infix { left, right, .. } => contains_local(left) || contains_local(right),
453        Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => contains_local(arg),
454        Expr::Quote { expr, .. } => contains_local(expr),
455        Expr::Annotated { expr, annotations } => {
456            contains_local(expr) || annotations.iter().any(|(_, value)| contains_local(value))
457        }
458        Expr::Extension { payload, .. } => contains_local(payload),
459        _ => false,
460    }
461}
462
463fn collect_vars(expr: &Expr, vars: &mut BTreeSet<Symbol>) {
464    match expr {
465        Expr::Local(var) => {
466            vars.insert(var.clone());
467        }
468        Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
469            for item in items {
470                collect_vars(item, vars);
471            }
472        }
473        Expr::Map(entries) => {
474            for (key, value) in entries {
475                collect_vars(key, vars);
476                collect_vars(value, vars);
477            }
478        }
479        Expr::Call { operator, args } => {
480            collect_vars(operator, vars);
481            for arg in args {
482                collect_vars(arg, vars);
483            }
484        }
485        Expr::Infix { left, right, .. } => {
486            collect_vars(left, vars);
487            collect_vars(right, vars);
488        }
489        Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => collect_vars(arg, vars),
490        Expr::Quote { expr, .. } => collect_vars(expr, vars),
491        Expr::Annotated { expr, annotations } => {
492            collect_vars(expr, vars);
493            for (_, value) in annotations {
494                collect_vars(value, vars);
495            }
496        }
497        Expr::Extension { payload, .. } => collect_vars(payload, vars),
498        _ => {}
499    }
500}