Skip to main content

tensorlogic_compiler/inline/
substitute.rs

1use tensorlogic_ir::TLExpr;
2
3/// Substitute all free occurrences of `var` with `replacement` in `body`.
4///
5/// This is capture-avoiding: when a binder that re-introduces `var` is
6/// encountered, substitution stops for the sub-tree guarded by that binder.
7///
8/// Variable occurrences are zero-argument predicates with `name == var`
9/// and `Term::Var(var)` inside predicate arguments.
10pub fn substitute(var: &str, replacement: &TLExpr, body: TLExpr) -> TLExpr {
11    subst(var, replacement, body)
12}
13
14pub(crate) fn subst(var: &str, repl: &TLExpr, expr: TLExpr) -> TLExpr {
15    match expr {
16        // Zero-arg predicate used as a variable reference.
17        TLExpr::Pred { ref name, ref args } if args.is_empty() && name == var => repl.clone(),
18
19        // Predicate with arguments: substitute in Term::Var occurrences.
20        TLExpr::Pred { name, args } => {
21            let new_args = args
22                .into_iter()
23                .map(|t| match &t {
24                    tensorlogic_ir::Term::Var(v) if v == var => {
25                        // We can only substitute if replacement is a zero-arg
26                        // Pred (variable) or Constant; otherwise keep the Term.
27                        match repl {
28                            TLExpr::Pred { name: rn, args: ra } if ra.is_empty() => {
29                                tensorlogic_ir::Term::Var(rn.clone())
30                            }
31                            _ => t,
32                        }
33                    }
34                    _ => t,
35                })
36                .collect();
37            TLExpr::Pred {
38                name,
39                args: new_args,
40            }
41        }
42
43        // ── Binary nodes ─────────────────────────────────────────────────
44        TLExpr::And(l, r) => TLExpr::And(
45            Box::new(subst(var, repl, *l)),
46            Box::new(subst(var, repl, *r)),
47        ),
48        TLExpr::Or(l, r) => TLExpr::Or(
49            Box::new(subst(var, repl, *l)),
50            Box::new(subst(var, repl, *r)),
51        ),
52        TLExpr::Imply(l, r) => TLExpr::Imply(
53            Box::new(subst(var, repl, *l)),
54            Box::new(subst(var, repl, *r)),
55        ),
56        TLExpr::Add(l, r) => TLExpr::Add(
57            Box::new(subst(var, repl, *l)),
58            Box::new(subst(var, repl, *r)),
59        ),
60        TLExpr::Sub(l, r) => TLExpr::Sub(
61            Box::new(subst(var, repl, *l)),
62            Box::new(subst(var, repl, *r)),
63        ),
64        TLExpr::Mul(l, r) => TLExpr::Mul(
65            Box::new(subst(var, repl, *l)),
66            Box::new(subst(var, repl, *r)),
67        ),
68        TLExpr::Div(l, r) => TLExpr::Div(
69            Box::new(subst(var, repl, *l)),
70            Box::new(subst(var, repl, *r)),
71        ),
72        TLExpr::Pow(l, r) => TLExpr::Pow(
73            Box::new(subst(var, repl, *l)),
74            Box::new(subst(var, repl, *r)),
75        ),
76        TLExpr::Mod(l, r) => TLExpr::Mod(
77            Box::new(subst(var, repl, *l)),
78            Box::new(subst(var, repl, *r)),
79        ),
80        TLExpr::Min(l, r) => TLExpr::Min(
81            Box::new(subst(var, repl, *l)),
82            Box::new(subst(var, repl, *r)),
83        ),
84        TLExpr::Max(l, r) => TLExpr::Max(
85            Box::new(subst(var, repl, *l)),
86            Box::new(subst(var, repl, *r)),
87        ),
88        TLExpr::Eq(l, r) => TLExpr::Eq(
89            Box::new(subst(var, repl, *l)),
90            Box::new(subst(var, repl, *r)),
91        ),
92        TLExpr::Lt(l, r) => TLExpr::Lt(
93            Box::new(subst(var, repl, *l)),
94            Box::new(subst(var, repl, *r)),
95        ),
96        TLExpr::Gt(l, r) => TLExpr::Gt(
97            Box::new(subst(var, repl, *l)),
98            Box::new(subst(var, repl, *r)),
99        ),
100        TLExpr::Lte(l, r) => TLExpr::Lte(
101            Box::new(subst(var, repl, *l)),
102            Box::new(subst(var, repl, *r)),
103        ),
104        TLExpr::Gte(l, r) => TLExpr::Gte(
105            Box::new(subst(var, repl, *l)),
106            Box::new(subst(var, repl, *r)),
107        ),
108
109        // ── Unary nodes ──────────────────────────────────────────────────
110        TLExpr::Not(e) => TLExpr::Not(Box::new(subst(var, repl, *e))),
111        TLExpr::Score(e) => TLExpr::Score(Box::new(subst(var, repl, *e))),
112        TLExpr::Abs(e) => TLExpr::Abs(Box::new(subst(var, repl, *e))),
113        TLExpr::Floor(e) => TLExpr::Floor(Box::new(subst(var, repl, *e))),
114        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(subst(var, repl, *e))),
115        TLExpr::Round(e) => TLExpr::Round(Box::new(subst(var, repl, *e))),
116        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(subst(var, repl, *e))),
117        TLExpr::Exp(e) => TLExpr::Exp(Box::new(subst(var, repl, *e))),
118        TLExpr::Log(e) => TLExpr::Log(Box::new(subst(var, repl, *e))),
119        TLExpr::Sin(e) => TLExpr::Sin(Box::new(subst(var, repl, *e))),
120        TLExpr::Cos(e) => TLExpr::Cos(Box::new(subst(var, repl, *e))),
121        TLExpr::Tan(e) => TLExpr::Tan(Box::new(subst(var, repl, *e))),
122        TLExpr::Box(e) => TLExpr::Box(Box::new(subst(var, repl, *e))),
123        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(subst(var, repl, *e))),
124        TLExpr::Next(e) => TLExpr::Next(Box::new(subst(var, repl, *e))),
125        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(subst(var, repl, *e))),
126        TLExpr::Always(e) => TLExpr::Always(Box::new(subst(var, repl, *e))),
127
128        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
129            kind,
130            expr: Box::new(subst(var, repl, *expr)),
131        },
132        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
133            weight,
134            rule: Box::new(subst(var, repl, *rule)),
135        },
136
137        // ── Temporal / logical binary ─────────────────────────────────────
138        TLExpr::Until { before, after } => TLExpr::Until {
139            before: Box::new(subst(var, repl, *before)),
140            after: Box::new(subst(var, repl, *after)),
141        },
142        TLExpr::Release { released, releaser } => TLExpr::Release {
143            released: Box::new(subst(var, repl, *released)),
144            releaser: Box::new(subst(var, repl, *releaser)),
145        },
146        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
147            before: Box::new(subst(var, repl, *before)),
148            after: Box::new(subst(var, repl, *after)),
149        },
150        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
151            released: Box::new(subst(var, repl, *released)),
152            releaser: Box::new(subst(var, repl, *releaser)),
153        },
154
155        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
156            kind,
157            left: Box::new(subst(var, repl, *left)),
158            right: Box::new(subst(var, repl, *right)),
159        },
160        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
161            kind,
162            left: Box::new(subst(var, repl, *left)),
163            right: Box::new(subst(var, repl, *right)),
164        },
165        TLExpr::FuzzyImplication {
166            kind,
167            premise,
168            conclusion,
169        } => TLExpr::FuzzyImplication {
170            kind,
171            premise: Box::new(subst(var, repl, *premise)),
172            conclusion: Box::new(subst(var, repl, *conclusion)),
173        },
174
175        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
176            alternatives: alternatives
177                .into_iter()
178                .map(|(p, e)| (p, subst(var, repl, e)))
179                .collect(),
180        },
181
182        // ── IfThenElse ────────────────────────────────────────────────────
183        TLExpr::IfThenElse {
184            condition,
185            then_branch,
186            else_branch,
187        } => TLExpr::IfThenElse {
188            condition: Box::new(subst(var, repl, *condition)),
189            then_branch: Box::new(subst(var, repl, *then_branch)),
190            else_branch: Box::new(subst(var, repl, *else_branch)),
191        },
192
193        // ── Binders — capture-avoiding ────────────────────────────────────
194        TLExpr::Exists {
195            var: binder,
196            domain,
197            body,
198        } => {
199            if binder == var {
200                TLExpr::Exists {
201                    var: binder,
202                    domain,
203                    body,
204                }
205            } else {
206                TLExpr::Exists {
207                    var: binder,
208                    domain,
209                    body: Box::new(subst(var, repl, *body)),
210                }
211            }
212        }
213        TLExpr::ForAll {
214            var: binder,
215            domain,
216            body,
217        } => {
218            if binder == var {
219                TLExpr::ForAll {
220                    var: binder,
221                    domain,
222                    body,
223                }
224            } else {
225                TLExpr::ForAll {
226                    var: binder,
227                    domain,
228                    body: Box::new(subst(var, repl, *body)),
229                }
230            }
231        }
232        TLExpr::SoftExists {
233            var: binder,
234            domain,
235            body,
236            temperature,
237        } => {
238            if binder == var {
239                TLExpr::SoftExists {
240                    var: binder,
241                    domain,
242                    body,
243                    temperature,
244                }
245            } else {
246                TLExpr::SoftExists {
247                    var: binder,
248                    domain,
249                    body: Box::new(subst(var, repl, *body)),
250                    temperature,
251                }
252            }
253        }
254        TLExpr::SoftForAll {
255            var: binder,
256            domain,
257            body,
258            temperature,
259        } => {
260            if binder == var {
261                TLExpr::SoftForAll {
262                    var: binder,
263                    domain,
264                    body,
265                    temperature,
266                }
267            } else {
268                TLExpr::SoftForAll {
269                    var: binder,
270                    domain,
271                    body: Box::new(subst(var, repl, *body)),
272                    temperature,
273                }
274            }
275        }
276        TLExpr::Aggregate {
277            op,
278            var: binder,
279            domain,
280            body,
281            group_by,
282        } => {
283            if binder == var {
284                TLExpr::Aggregate {
285                    op,
286                    var: binder,
287                    domain,
288                    body,
289                    group_by,
290                }
291            } else {
292                TLExpr::Aggregate {
293                    op,
294                    var: binder,
295                    domain,
296                    body: Box::new(subst(var, repl, *body)),
297                    group_by,
298                }
299            }
300        }
301        // Let: substitute in value unconditionally (outer scope), substitute
302        // in body only if binder != var (capture avoidance).
303        TLExpr::Let {
304            var: binder,
305            value,
306            body,
307        } => {
308            let new_value = subst(var, repl, *value);
309            if binder == var {
310                TLExpr::Let {
311                    var: binder,
312                    value: Box::new(new_value),
313                    body,
314                }
315            } else {
316                TLExpr::Let {
317                    var: binder,
318                    value: Box::new(new_value),
319                    body: Box::new(subst(var, repl, *body)),
320                }
321            }
322        }
323        TLExpr::Lambda {
324            var: binder,
325            var_type,
326            body,
327        } => {
328            if binder == var {
329                TLExpr::Lambda {
330                    var: binder,
331                    var_type,
332                    body,
333                }
334            } else {
335                TLExpr::Lambda {
336                    var: binder,
337                    var_type,
338                    body: Box::new(subst(var, repl, *body)),
339                }
340            }
341        }
342        TLExpr::CountingExists {
343            var: binder,
344            domain,
345            body,
346            min_count,
347        } => {
348            if binder == var {
349                TLExpr::CountingExists {
350                    var: binder,
351                    domain,
352                    body,
353                    min_count,
354                }
355            } else {
356                TLExpr::CountingExists {
357                    var: binder,
358                    domain,
359                    body: Box::new(subst(var, repl, *body)),
360                    min_count,
361                }
362            }
363        }
364        TLExpr::CountingForAll {
365            var: binder,
366            domain,
367            body,
368            min_count,
369        } => {
370            if binder == var {
371                TLExpr::CountingForAll {
372                    var: binder,
373                    domain,
374                    body,
375                    min_count,
376                }
377            } else {
378                TLExpr::CountingForAll {
379                    var: binder,
380                    domain,
381                    body: Box::new(subst(var, repl, *body)),
382                    min_count,
383                }
384            }
385        }
386        TLExpr::ExactCount {
387            var: binder,
388            domain,
389            body,
390            count,
391        } => {
392            if binder == var {
393                TLExpr::ExactCount {
394                    var: binder,
395                    domain,
396                    body,
397                    count,
398                }
399            } else {
400                TLExpr::ExactCount {
401                    var: binder,
402                    domain,
403                    body: Box::new(subst(var, repl, *body)),
404                    count,
405                }
406            }
407        }
408        TLExpr::Majority {
409            var: binder,
410            domain,
411            body,
412        } => {
413            if binder == var {
414                TLExpr::Majority {
415                    var: binder,
416                    domain,
417                    body,
418                }
419            } else {
420                TLExpr::Majority {
421                    var: binder,
422                    domain,
423                    body: Box::new(subst(var, repl, *body)),
424                }
425            }
426        }
427        TLExpr::LeastFixpoint { var: binder, body } => {
428            if binder == var {
429                TLExpr::LeastFixpoint { var: binder, body }
430            } else {
431                TLExpr::LeastFixpoint {
432                    var: binder,
433                    body: Box::new(subst(var, repl, *body)),
434                }
435            }
436        }
437        TLExpr::GreatestFixpoint { var: binder, body } => {
438            if binder == var {
439                TLExpr::GreatestFixpoint { var: binder, body }
440            } else {
441                TLExpr::GreatestFixpoint {
442                    var: binder,
443                    body: Box::new(subst(var, repl, *body)),
444                }
445            }
446        }
447        TLExpr::SetComprehension {
448            var: binder,
449            domain,
450            condition,
451        } => {
452            if binder == var {
453                TLExpr::SetComprehension {
454                    var: binder,
455                    domain,
456                    condition,
457                }
458            } else {
459                TLExpr::SetComprehension {
460                    var: binder,
461                    domain,
462                    condition: Box::new(subst(var, repl, *condition)),
463                }
464            }
465        }
466
467        // ── Set operations ────────────────────────────────────────────────
468        TLExpr::Apply { function, argument } => TLExpr::Apply {
469            function: Box::new(subst(var, repl, *function)),
470            argument: Box::new(subst(var, repl, *argument)),
471        },
472        TLExpr::SetMembership { element, set } => TLExpr::SetMembership {
473            element: Box::new(subst(var, repl, *element)),
474            set: Box::new(subst(var, repl, *set)),
475        },
476        TLExpr::SetUnion { left, right } => TLExpr::SetUnion {
477            left: Box::new(subst(var, repl, *left)),
478            right: Box::new(subst(var, repl, *right)),
479        },
480        TLExpr::SetIntersection { left, right } => TLExpr::SetIntersection {
481            left: Box::new(subst(var, repl, *left)),
482            right: Box::new(subst(var, repl, *right)),
483        },
484        TLExpr::SetDifference { left, right } => TLExpr::SetDifference {
485            left: Box::new(subst(var, repl, *left)),
486            right: Box::new(subst(var, repl, *right)),
487        },
488        TLExpr::SetCardinality { set } => TLExpr::SetCardinality {
489            set: Box::new(subst(var, repl, *set)),
490        },
491
492        TLExpr::At { nominal, formula } => TLExpr::At {
493            nominal,
494            formula: Box::new(subst(var, repl, *formula)),
495        },
496        TLExpr::Somewhere { formula } => TLExpr::Somewhere {
497            formula: Box::new(subst(var, repl, *formula)),
498        },
499        TLExpr::Everywhere { formula } => TLExpr::Everywhere {
500            formula: Box::new(subst(var, repl, *formula)),
501        },
502        TLExpr::Explain { formula } => TLExpr::Explain {
503            formula: Box::new(subst(var, repl, *formula)),
504        },
505
506        TLExpr::GlobalCardinality {
507            variables,
508            values,
509            min_occurrences,
510            max_occurrences,
511        } => TLExpr::GlobalCardinality {
512            variables,
513            values: values.into_iter().map(|e| subst(var, repl, e)).collect(),
514            min_occurrences,
515            max_occurrences,
516        },
517
518        // ── Leaves ───────────────────────────────────────────────────────
519        leaf @ (TLExpr::Constant(_)
520        | TLExpr::EmptySet
521        | TLExpr::AllDifferent { .. }
522        | TLExpr::Nominal { .. }
523        | TLExpr::Abducible { .. }
524        | TLExpr::SymbolLiteral(_)) => leaf,
525
526        TLExpr::Match { scrutinee, arms } => TLExpr::Match {
527            scrutinee: Box::new(subst(var, repl, *scrutinee)),
528            arms: arms
529                .into_iter()
530                .map(|(pat, body)| (pat, Box::new(subst(var, repl, *body))))
531                .collect(),
532        },
533    }
534}