1use std::{
7 collections::{BTreeMap, BTreeSet},
8 sync::Arc,
9};
10
11use sim_kernel::{
12 Cx, DefaultFactory, EagerPolicy, Expr, MatchScore, Result, ShapeBindings, ShapeMatch, Symbol,
13};
14use sim_shape::{AnyShape, CaptureShape, ExactExprShape, ListShape, Shape};
15
16use crate::model::OccursCheck;
17
/// Substitution environment for logic unification.
///
/// Maps logic variables (the symbols inside `Expr::Local`) to the expressions they are bound
/// to. A binding may itself mention other variables; [`LogicEnv::apply`] resolves such chains.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct LogicEnv {
    /// Variable bindings, keyed by variable symbol.
    captures: BTreeMap<Symbol, Expr>,
    /// Caller-managed depth value; nothing in this module reads it besides the accessors.
    /// NOTE(review): presumably a search/recursion depth — confirm with callers.
    depth: usize,
}
27
28impl LogicEnv {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn with_depth(depth: usize) -> Self {
36 Self {
37 captures: BTreeMap::new(),
38 depth,
39 }
40 }
41
42 pub fn depth(&self) -> usize {
44 self.depth
45 }
46
47 pub fn set_depth(&mut self, depth: usize) {
49 self.depth = depth;
50 }
51
52 pub fn apply(&self, expr: &Expr) -> Expr {
55 match expr {
56 Expr::Local(var) => match self.captures.get(var) {
57 Some(bound) => self.apply(bound),
58 None => Expr::Local(var.clone()),
59 },
60 Expr::List(items) => Expr::List(items.iter().map(|item| self.apply(item)).collect()),
61 Expr::Vector(items) => {
62 Expr::Vector(items.iter().map(|item| self.apply(item)).collect())
63 }
64 Expr::Map(entries) => Expr::Map(
65 entries
66 .iter()
67 .map(|(key, value)| (self.apply(key), self.apply(value)))
68 .collect(),
69 ),
70 Expr::Set(items) => Expr::Set(items.iter().map(|item| self.apply(item)).collect()),
71 Expr::Call { operator, args } => Expr::Call {
72 operator: Box::new(self.apply(operator)),
73 args: args.iter().map(|arg| self.apply(arg)).collect(),
74 },
75 Expr::Infix {
76 operator,
77 left,
78 right,
79 } => Expr::Infix {
80 operator: operator.clone(),
81 left: Box::new(self.apply(left)),
82 right: Box::new(self.apply(right)),
83 },
84 Expr::Prefix { operator, arg } => Expr::Prefix {
85 operator: operator.clone(),
86 arg: Box::new(self.apply(arg)),
87 },
88 Expr::Postfix { operator, arg } => Expr::Postfix {
89 operator: operator.clone(),
90 arg: Box::new(self.apply(arg)),
91 },
92 Expr::Block(items) => Expr::Block(items.iter().map(|item| self.apply(item)).collect()),
93 Expr::Quote { mode, expr } => Expr::Quote {
94 mode: *mode,
95 expr: Box::new(self.apply(expr)),
96 },
97 Expr::Annotated { expr, annotations } => Expr::Annotated {
98 expr: Box::new(self.apply(expr)),
99 annotations: annotations
100 .iter()
101 .map(|(name, value)| (name.clone(), self.apply(value)))
102 .collect(),
103 },
104 Expr::Extension { tag, payload } => Expr::Extension {
105 tag: tag.clone(),
106 payload: Box::new(self.apply(payload)),
107 },
108 other => other.clone(),
109 }
110 }
111
112 pub fn get(&self, var: &Symbol) -> Option<&Expr> {
114 self.captures.get(var)
115 }
116
117 pub fn bind(&mut self, var: Symbol, value: Expr, occurs_check: OccursCheck) -> Result<()> {
122 if matches!(occurs_check, OccursCheck::Always) && occurs(var.clone(), &value, self) {
123 return Err(sim_kernel::Error::Eval(format!(
124 "occurs check failed for ?{}",
125 var.name
126 )));
127 }
128 self.captures.insert(var, value);
129 Ok(())
130 }
131
132 pub fn unify(&mut self, left: &Expr, right: &Expr, occurs_check: OccursCheck) -> Result<bool> {
137 let left = self.apply(left);
138 let right = self.apply(right);
139 if left.canonical_eq(&right) {
140 return Ok(true);
141 }
142
143 let left_match = self.shape_unify(&left, &right, occurs_check)?;
144 let right_match = self.shape_unify(&right, &left, occurs_check)?;
145 match (left_match, right_match) {
146 (ShapeUnify::Accepted(next), _) | (_, ShapeUnify::Accepted(next)) => {
147 *self = next;
148 Ok(true)
149 }
150 (ShapeUnify::Unsupported, _) | (_, ShapeUnify::Unsupported) => {
151 unify_ground(self, &left, &right, occurs_check)
152 }
153 (ShapeUnify::Rejected, ShapeUnify::Rejected) => Ok(false),
154 }
155 }
156
157 fn shape_unify(
158 &self,
159 pattern: &Expr,
160 subject: &Expr,
161 occurs_check: OccursCheck,
162 ) -> Result<ShapeUnify> {
163 let Some(shape) = shape_from_pattern(pattern) else {
164 return Ok(ShapeUnify::Unsupported);
165 };
166 let mut cx = Cx::new(Arc::new(EagerPolicy), Arc::new(DefaultFactory));
167 let matched = shape.check_expr(&mut cx, subject)?;
168 if !matched.accepted {
169 return Ok(ShapeUnify::Rejected);
170 }
171 let mut next = self.clone();
172 if next.merge_shape_captures(&matched.captures, occurs_check)? {
173 Ok(ShapeUnify::Accepted(next))
174 } else {
175 Ok(ShapeUnify::Rejected)
176 }
177 }
178
179 fn merge_shape_captures(
180 &mut self,
181 captures: &ShapeBindings,
182 occurs_check: OccursCheck,
183 ) -> Result<bool> {
184 for (var, value) in captures.exprs() {
185 if !self.merge_shape_capture(var.clone(), value.clone(), occurs_check)? {
186 return Ok(false);
187 }
188 }
189 Ok(true)
190 }
191
192 fn merge_shape_capture(
193 &mut self,
194 var: Symbol,
195 value: Expr,
196 occurs_check: OccursCheck,
197 ) -> Result<bool> {
198 let value = self.apply(&value);
199 if let Some(bound) = self.captures.get(&var).cloned() {
200 let bound = self.apply(&bound);
201 return self.unify(&bound, &value, occurs_check);
202 }
203 self.bind(var, value, occurs_check)?;
204 Ok(true)
205 }
206
207 pub fn free_vars(&self, expr: &Expr) -> Vec<Symbol> {
209 let mut vars = BTreeSet::new();
210 collect_vars(expr, &mut vars);
211 vars.into_iter().collect()
212 }
213
214 pub fn to_shape_bindings(&self, _cx: &mut Cx) -> Result<ShapeBindings> {
216 let mut bindings = ShapeBindings::new();
217 for (name, expr) in &self.captures {
218 bindings.bind_expr(name.clone(), self.apply(expr));
219 }
220 Ok(bindings)
221 }
222
223 pub fn as_shape_match(&self, cx: &mut Cx) -> Result<ShapeMatch> {
226 Ok(ShapeMatch {
227 accepted: true,
228 captures: self.to_shape_bindings(cx)?,
229 score: MatchScore::exact(100),
230 diagnostics: Vec::new(),
231 })
232 }
233}
234
/// Outcome of trying one side of a unification as a shape pattern against the other side.
enum ShapeUnify {
    /// The pattern matched; holds a cloned environment extended with the captured bindings.
    Accepted(LogicEnv),
    /// A shape was built but did not accept the subject, or its captures conflicted with
    /// existing bindings.
    Rejected,
    /// No shape could be built for the pattern (`shape_from_pattern` returned `None`).
    Unsupported,
}
240
241fn shape_from_pattern(pattern: &Expr) -> Option<Arc<dyn Shape>> {
242 match pattern {
243 Expr::Local(var) => Some(Arc::new(CaptureShape::new(var.clone(), Arc::new(AnyShape)))),
244 Expr::List(items) => {
245 let item_shapes = items
246 .iter()
247 .map(shape_from_pattern)
248 .collect::<Option<Vec<_>>>()?;
249 Some(Arc::new(ListShape::new(item_shapes)))
250 }
251 other if !contains_local(other) => Some(Arc::new(ExactExprShape::new(other.clone()))),
252 _ => None,
253 }
254}
255
/// Structurally unifies two expressions by variant.
///
/// The only visible caller is `LogicEnv::unify`, which resolves both sides through `env`
/// first.
/// - Atoms (nil, booleans, numbers, symbols, locals, strings, bytes) must be canonically
///   equal.
/// - Compound expressions must share their variant and any operator, tag, quote mode or
///   annotation names. Their children are then unified pairwise, left to right, via
///   `LogicEnv::unify`, which may add bindings to `env`.
/// - Mismatched variants yield `Ok(false)`.
///
/// Bindings made before a later child fails are not rolled back by this function.
///
/// # Errors
/// Propagates any error returned by a nested `LogicEnv::unify`.
fn unify_ground(
    env: &mut LogicEnv,
    left: &Expr,
    right: &Expr,
    occurs_check: OccursCheck,
) -> Result<bool> {
    match (left, right) {
        // Atoms: equal only when canonically equal.
        (Expr::Nil, Expr::Nil)
        | (Expr::Bool(_), Expr::Bool(_))
        | (Expr::Number(_), Expr::Number(_))
        | (Expr::Symbol(_), Expr::Symbol(_))
        | (Expr::Local(_), Expr::Local(_))
        | (Expr::String(_), Expr::String(_))
        | (Expr::Bytes(_), Expr::Bytes(_)) => Ok(left.canonical_eq(right)),
        // Sequences: same length, items unified pairwise in order.
        // NOTE(review): sets are paired positionally too, which assumes a canonical element
        // order — confirm.
        (Expr::List(left_items), Expr::List(right_items))
        | (Expr::Vector(left_items), Expr::Vector(right_items))
        | (Expr::Set(left_items), Expr::Set(right_items))
        | (Expr::Block(left_items), Expr::Block(right_items)) => {
            unify_slices(env, left_items, right_items, occurs_check)
        }
        (Expr::Map(left_entries), Expr::Map(right_entries)) => {
            if left_entries.len() != right_entries.len() {
                return Ok(false);
            }
            // Entries are paired by position (key with key, value with value), not looked
            // up by key.
            for ((left_key, left_value), (right_key, right_value)) in
                left_entries.iter().zip(right_entries.iter())
            {
                if !env.unify(left_key, right_key, occurs_check)? {
                    return Ok(false);
                }
                if !env.unify(left_value, right_value, occurs_check)? {
                    return Ok(false);
                }
            }
            Ok(true)
        }
        (
            Expr::Call {
                operator: left_op,
                args: left_args,
            },
            Expr::Call {
                operator: right_op,
                args: right_args,
            },
        ) => {
            // Check arity before unifying operators so a mismatch binds nothing.
            if left_args.len() != right_args.len() || !env.unify(left_op, right_op, occurs_check)? {
                return Ok(false);
            }
            unify_slices(env, left_args, right_args, occurs_check)
        }
        (
            Expr::Quote {
                mode: left_mode,
                expr: left_expr,
            },
            Expr::Quote {
                mode: right_mode,
                expr: right_expr,
            },
        ) => {
            if left_mode != right_mode {
                return Ok(false);
            }
            env.unify(left_expr, right_expr, occurs_check)
        }
        (
            Expr::Annotated {
                expr: left_expr,
                annotations: left_annotations,
            },
            Expr::Annotated {
                expr: right_expr,
                annotations: right_annotations,
            },
        ) => {
            // Inner expressions first, then annotations pairwise: names must match exactly
            // and values must unify.
            if left_annotations.len() != right_annotations.len()
                || !env.unify(left_expr, right_expr, occurs_check)?
            {
                return Ok(false);
            }
            for ((left_name, left_value), (right_name, right_value)) in
                left_annotations.iter().zip(right_annotations.iter())
            {
                if left_name != right_name || !env.unify(left_value, right_value, occurs_check)? {
                    return Ok(false);
                }
            }
            Ok(true)
        }
        (
            Expr::Extension {
                tag: left_tag,
                payload: left_payload,
            },
            Expr::Extension {
                tag: right_tag,
                payload: right_payload,
            },
        ) => Ok(left_tag == right_tag && env.unify(left_payload, right_payload, occurs_check)?),
        (
            Expr::Infix {
                operator: left_op,
                left: left_a,
                right: left_b,
            },
            Expr::Infix {
                operator: right_op,
                left: right_a,
                right: right_b,
            },
        ) => Ok(left_op == right_op
            && env.unify(left_a, right_a, occurs_check)?
            && env.unify(left_b, right_b, occurs_check)?),
        (
            Expr::Prefix {
                operator: left_op,
                arg: left_arg,
            },
            Expr::Prefix {
                operator: right_op,
                arg: right_arg,
            },
        )
        | (
            Expr::Postfix {
                operator: left_op,
                arg: left_arg,
            },
            Expr::Postfix {
                operator: right_op,
                arg: right_arg,
            },
        ) => Ok(left_op == right_op && env.unify(left_arg, right_arg, occurs_check)?),
        // Different variants never unify structurally.
        _ => Ok(false),
    }
}
393
394fn unify_slices(
395 env: &mut LogicEnv,
396 left: &[Expr],
397 right: &[Expr],
398 occurs_check: OccursCheck,
399) -> Result<bool> {
400 if left.len() != right.len() {
401 return Ok(false);
402 }
403 for (left_item, right_item) in left.iter().zip(right.iter()) {
404 if !env.unify(left_item, right_item, occurs_check)? {
405 return Ok(false);
406 }
407 }
408 Ok(true)
409}
410
411fn occurs(var: Symbol, expr: &Expr, env: &LogicEnv) -> bool {
412 match env.apply(expr) {
413 Expr::Local(candidate) => candidate == var,
414 Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
415 items.iter().any(|item| occurs(var.clone(), item, env))
416 }
417 Expr::Map(entries) => entries
418 .iter()
419 .any(|(key, value)| occurs(var.clone(), key, env) || occurs(var.clone(), value, env)),
420 Expr::Call { operator, args } => {
421 occurs(var.clone(), &operator, env)
422 || args.iter().any(|arg| occurs(var.clone(), arg, env))
423 }
424 Expr::Infix { left, right, .. } => {
425 occurs(var.clone(), &left, env) || occurs(var, &right, env)
426 }
427 Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => occurs(var, &arg, env),
428 Expr::Quote { expr, .. } => occurs(var, &expr, env),
429 Expr::Annotated { expr, annotations } => {
430 occurs(var.clone(), &expr, env)
431 || annotations
432 .iter()
433 .any(|(_, value)| occurs(var.clone(), value, env))
434 }
435 Expr::Extension { payload, .. } => occurs(var, &payload, env),
436 _ => false,
437 }
438}
439
440fn contains_local(expr: &Expr) -> bool {
441 match expr {
442 Expr::Local(_) => true,
443 Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
444 items.iter().any(contains_local)
445 }
446 Expr::Map(entries) => entries
447 .iter()
448 .any(|(key, value)| contains_local(key) || contains_local(value)),
449 Expr::Call { operator, args } => {
450 contains_local(operator) || args.iter().any(contains_local)
451 }
452 Expr::Infix { left, right, .. } => contains_local(left) || contains_local(right),
453 Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => contains_local(arg),
454 Expr::Quote { expr, .. } => contains_local(expr),
455 Expr::Annotated { expr, annotations } => {
456 contains_local(expr) || annotations.iter().any(|(_, value)| contains_local(value))
457 }
458 Expr::Extension { payload, .. } => contains_local(payload),
459 _ => false,
460 }
461}
462
463fn collect_vars(expr: &Expr, vars: &mut BTreeSet<Symbol>) {
464 match expr {
465 Expr::Local(var) => {
466 vars.insert(var.clone());
467 }
468 Expr::List(items) | Expr::Vector(items) | Expr::Set(items) | Expr::Block(items) => {
469 for item in items {
470 collect_vars(item, vars);
471 }
472 }
473 Expr::Map(entries) => {
474 for (key, value) in entries {
475 collect_vars(key, vars);
476 collect_vars(value, vars);
477 }
478 }
479 Expr::Call { operator, args } => {
480 collect_vars(operator, vars);
481 for arg in args {
482 collect_vars(arg, vars);
483 }
484 }
485 Expr::Infix { left, right, .. } => {
486 collect_vars(left, vars);
487 collect_vars(right, vars);
488 }
489 Expr::Prefix { arg, .. } | Expr::Postfix { arg, .. } => collect_vars(arg, vars),
490 Expr::Quote { expr, .. } => collect_vars(expr, vars),
491 Expr::Annotated { expr, annotations } => {
492 collect_vars(expr, vars);
493 for (_, value) in annotations {
494 collect_vars(value, vars);
495 }
496 }
497 Expr::Extension { payload, .. } => collect_vars(payload, vars),
498 _ => {}
499 }
500}