1use tensorlogic_ir::TLExpr;
2
3pub fn substitute(var: &str, replacement: &TLExpr, body: TLExpr) -> TLExpr {
11 subst(var, replacement, body)
12}
13
14pub(crate) fn subst(var: &str, repl: &TLExpr, expr: TLExpr) -> TLExpr {
15 match expr {
16 TLExpr::Pred { ref name, ref args } if args.is_empty() && name == var => repl.clone(),
18
19 TLExpr::Pred { name, args } => {
21 let new_args = args
22 .into_iter()
23 .map(|t| match &t {
24 tensorlogic_ir::Term::Var(v) if v == var => {
25 match repl {
28 TLExpr::Pred { name: rn, args: ra } if ra.is_empty() => {
29 tensorlogic_ir::Term::Var(rn.clone())
30 }
31 _ => t,
32 }
33 }
34 _ => t,
35 })
36 .collect();
37 TLExpr::Pred {
38 name,
39 args: new_args,
40 }
41 }
42
43 TLExpr::And(l, r) => TLExpr::And(
45 Box::new(subst(var, repl, *l)),
46 Box::new(subst(var, repl, *r)),
47 ),
48 TLExpr::Or(l, r) => TLExpr::Or(
49 Box::new(subst(var, repl, *l)),
50 Box::new(subst(var, repl, *r)),
51 ),
52 TLExpr::Imply(l, r) => TLExpr::Imply(
53 Box::new(subst(var, repl, *l)),
54 Box::new(subst(var, repl, *r)),
55 ),
56 TLExpr::Add(l, r) => TLExpr::Add(
57 Box::new(subst(var, repl, *l)),
58 Box::new(subst(var, repl, *r)),
59 ),
60 TLExpr::Sub(l, r) => TLExpr::Sub(
61 Box::new(subst(var, repl, *l)),
62 Box::new(subst(var, repl, *r)),
63 ),
64 TLExpr::Mul(l, r) => TLExpr::Mul(
65 Box::new(subst(var, repl, *l)),
66 Box::new(subst(var, repl, *r)),
67 ),
68 TLExpr::Div(l, r) => TLExpr::Div(
69 Box::new(subst(var, repl, *l)),
70 Box::new(subst(var, repl, *r)),
71 ),
72 TLExpr::Pow(l, r) => TLExpr::Pow(
73 Box::new(subst(var, repl, *l)),
74 Box::new(subst(var, repl, *r)),
75 ),
76 TLExpr::Mod(l, r) => TLExpr::Mod(
77 Box::new(subst(var, repl, *l)),
78 Box::new(subst(var, repl, *r)),
79 ),
80 TLExpr::Min(l, r) => TLExpr::Min(
81 Box::new(subst(var, repl, *l)),
82 Box::new(subst(var, repl, *r)),
83 ),
84 TLExpr::Max(l, r) => TLExpr::Max(
85 Box::new(subst(var, repl, *l)),
86 Box::new(subst(var, repl, *r)),
87 ),
88 TLExpr::Eq(l, r) => TLExpr::Eq(
89 Box::new(subst(var, repl, *l)),
90 Box::new(subst(var, repl, *r)),
91 ),
92 TLExpr::Lt(l, r) => TLExpr::Lt(
93 Box::new(subst(var, repl, *l)),
94 Box::new(subst(var, repl, *r)),
95 ),
96 TLExpr::Gt(l, r) => TLExpr::Gt(
97 Box::new(subst(var, repl, *l)),
98 Box::new(subst(var, repl, *r)),
99 ),
100 TLExpr::Lte(l, r) => TLExpr::Lte(
101 Box::new(subst(var, repl, *l)),
102 Box::new(subst(var, repl, *r)),
103 ),
104 TLExpr::Gte(l, r) => TLExpr::Gte(
105 Box::new(subst(var, repl, *l)),
106 Box::new(subst(var, repl, *r)),
107 ),
108
109 TLExpr::Not(e) => TLExpr::Not(Box::new(subst(var, repl, *e))),
111 TLExpr::Score(e) => TLExpr::Score(Box::new(subst(var, repl, *e))),
112 TLExpr::Abs(e) => TLExpr::Abs(Box::new(subst(var, repl, *e))),
113 TLExpr::Floor(e) => TLExpr::Floor(Box::new(subst(var, repl, *e))),
114 TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(subst(var, repl, *e))),
115 TLExpr::Round(e) => TLExpr::Round(Box::new(subst(var, repl, *e))),
116 TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(subst(var, repl, *e))),
117 TLExpr::Exp(e) => TLExpr::Exp(Box::new(subst(var, repl, *e))),
118 TLExpr::Log(e) => TLExpr::Log(Box::new(subst(var, repl, *e))),
119 TLExpr::Sin(e) => TLExpr::Sin(Box::new(subst(var, repl, *e))),
120 TLExpr::Cos(e) => TLExpr::Cos(Box::new(subst(var, repl, *e))),
121 TLExpr::Tan(e) => TLExpr::Tan(Box::new(subst(var, repl, *e))),
122 TLExpr::Box(e) => TLExpr::Box(Box::new(subst(var, repl, *e))),
123 TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(subst(var, repl, *e))),
124 TLExpr::Next(e) => TLExpr::Next(Box::new(subst(var, repl, *e))),
125 TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(subst(var, repl, *e))),
126 TLExpr::Always(e) => TLExpr::Always(Box::new(subst(var, repl, *e))),
127
128 TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
129 kind,
130 expr: Box::new(subst(var, repl, *expr)),
131 },
132 TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
133 weight,
134 rule: Box::new(subst(var, repl, *rule)),
135 },
136
137 TLExpr::Until { before, after } => TLExpr::Until {
139 before: Box::new(subst(var, repl, *before)),
140 after: Box::new(subst(var, repl, *after)),
141 },
142 TLExpr::Release { released, releaser } => TLExpr::Release {
143 released: Box::new(subst(var, repl, *released)),
144 releaser: Box::new(subst(var, repl, *releaser)),
145 },
146 TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
147 before: Box::new(subst(var, repl, *before)),
148 after: Box::new(subst(var, repl, *after)),
149 },
150 TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
151 released: Box::new(subst(var, repl, *released)),
152 releaser: Box::new(subst(var, repl, *releaser)),
153 },
154
155 TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
156 kind,
157 left: Box::new(subst(var, repl, *left)),
158 right: Box::new(subst(var, repl, *right)),
159 },
160 TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
161 kind,
162 left: Box::new(subst(var, repl, *left)),
163 right: Box::new(subst(var, repl, *right)),
164 },
165 TLExpr::FuzzyImplication {
166 kind,
167 premise,
168 conclusion,
169 } => TLExpr::FuzzyImplication {
170 kind,
171 premise: Box::new(subst(var, repl, *premise)),
172 conclusion: Box::new(subst(var, repl, *conclusion)),
173 },
174
175 TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
176 alternatives: alternatives
177 .into_iter()
178 .map(|(p, e)| (p, subst(var, repl, e)))
179 .collect(),
180 },
181
182 TLExpr::IfThenElse {
184 condition,
185 then_branch,
186 else_branch,
187 } => TLExpr::IfThenElse {
188 condition: Box::new(subst(var, repl, *condition)),
189 then_branch: Box::new(subst(var, repl, *then_branch)),
190 else_branch: Box::new(subst(var, repl, *else_branch)),
191 },
192
193 TLExpr::Exists {
195 var: binder,
196 domain,
197 body,
198 } => {
199 if binder == var {
200 TLExpr::Exists {
201 var: binder,
202 domain,
203 body,
204 }
205 } else {
206 TLExpr::Exists {
207 var: binder,
208 domain,
209 body: Box::new(subst(var, repl, *body)),
210 }
211 }
212 }
213 TLExpr::ForAll {
214 var: binder,
215 domain,
216 body,
217 } => {
218 if binder == var {
219 TLExpr::ForAll {
220 var: binder,
221 domain,
222 body,
223 }
224 } else {
225 TLExpr::ForAll {
226 var: binder,
227 domain,
228 body: Box::new(subst(var, repl, *body)),
229 }
230 }
231 }
232 TLExpr::SoftExists {
233 var: binder,
234 domain,
235 body,
236 temperature,
237 } => {
238 if binder == var {
239 TLExpr::SoftExists {
240 var: binder,
241 domain,
242 body,
243 temperature,
244 }
245 } else {
246 TLExpr::SoftExists {
247 var: binder,
248 domain,
249 body: Box::new(subst(var, repl, *body)),
250 temperature,
251 }
252 }
253 }
254 TLExpr::SoftForAll {
255 var: binder,
256 domain,
257 body,
258 temperature,
259 } => {
260 if binder == var {
261 TLExpr::SoftForAll {
262 var: binder,
263 domain,
264 body,
265 temperature,
266 }
267 } else {
268 TLExpr::SoftForAll {
269 var: binder,
270 domain,
271 body: Box::new(subst(var, repl, *body)),
272 temperature,
273 }
274 }
275 }
276 TLExpr::Aggregate {
277 op,
278 var: binder,
279 domain,
280 body,
281 group_by,
282 } => {
283 if binder == var {
284 TLExpr::Aggregate {
285 op,
286 var: binder,
287 domain,
288 body,
289 group_by,
290 }
291 } else {
292 TLExpr::Aggregate {
293 op,
294 var: binder,
295 domain,
296 body: Box::new(subst(var, repl, *body)),
297 group_by,
298 }
299 }
300 }
301 TLExpr::Let {
304 var: binder,
305 value,
306 body,
307 } => {
308 let new_value = subst(var, repl, *value);
309 if binder == var {
310 TLExpr::Let {
311 var: binder,
312 value: Box::new(new_value),
313 body,
314 }
315 } else {
316 TLExpr::Let {
317 var: binder,
318 value: Box::new(new_value),
319 body: Box::new(subst(var, repl, *body)),
320 }
321 }
322 }
323 TLExpr::Lambda {
324 var: binder,
325 var_type,
326 body,
327 } => {
328 if binder == var {
329 TLExpr::Lambda {
330 var: binder,
331 var_type,
332 body,
333 }
334 } else {
335 TLExpr::Lambda {
336 var: binder,
337 var_type,
338 body: Box::new(subst(var, repl, *body)),
339 }
340 }
341 }
342 TLExpr::CountingExists {
343 var: binder,
344 domain,
345 body,
346 min_count,
347 } => {
348 if binder == var {
349 TLExpr::CountingExists {
350 var: binder,
351 domain,
352 body,
353 min_count,
354 }
355 } else {
356 TLExpr::CountingExists {
357 var: binder,
358 domain,
359 body: Box::new(subst(var, repl, *body)),
360 min_count,
361 }
362 }
363 }
364 TLExpr::CountingForAll {
365 var: binder,
366 domain,
367 body,
368 min_count,
369 } => {
370 if binder == var {
371 TLExpr::CountingForAll {
372 var: binder,
373 domain,
374 body,
375 min_count,
376 }
377 } else {
378 TLExpr::CountingForAll {
379 var: binder,
380 domain,
381 body: Box::new(subst(var, repl, *body)),
382 min_count,
383 }
384 }
385 }
386 TLExpr::ExactCount {
387 var: binder,
388 domain,
389 body,
390 count,
391 } => {
392 if binder == var {
393 TLExpr::ExactCount {
394 var: binder,
395 domain,
396 body,
397 count,
398 }
399 } else {
400 TLExpr::ExactCount {
401 var: binder,
402 domain,
403 body: Box::new(subst(var, repl, *body)),
404 count,
405 }
406 }
407 }
408 TLExpr::Majority {
409 var: binder,
410 domain,
411 body,
412 } => {
413 if binder == var {
414 TLExpr::Majority {
415 var: binder,
416 domain,
417 body,
418 }
419 } else {
420 TLExpr::Majority {
421 var: binder,
422 domain,
423 body: Box::new(subst(var, repl, *body)),
424 }
425 }
426 }
427 TLExpr::LeastFixpoint { var: binder, body } => {
428 if binder == var {
429 TLExpr::LeastFixpoint { var: binder, body }
430 } else {
431 TLExpr::LeastFixpoint {
432 var: binder,
433 body: Box::new(subst(var, repl, *body)),
434 }
435 }
436 }
437 TLExpr::GreatestFixpoint { var: binder, body } => {
438 if binder == var {
439 TLExpr::GreatestFixpoint { var: binder, body }
440 } else {
441 TLExpr::GreatestFixpoint {
442 var: binder,
443 body: Box::new(subst(var, repl, *body)),
444 }
445 }
446 }
447 TLExpr::SetComprehension {
448 var: binder,
449 domain,
450 condition,
451 } => {
452 if binder == var {
453 TLExpr::SetComprehension {
454 var: binder,
455 domain,
456 condition,
457 }
458 } else {
459 TLExpr::SetComprehension {
460 var: binder,
461 domain,
462 condition: Box::new(subst(var, repl, *condition)),
463 }
464 }
465 }
466
467 TLExpr::Apply { function, argument } => TLExpr::Apply {
469 function: Box::new(subst(var, repl, *function)),
470 argument: Box::new(subst(var, repl, *argument)),
471 },
472 TLExpr::SetMembership { element, set } => TLExpr::SetMembership {
473 element: Box::new(subst(var, repl, *element)),
474 set: Box::new(subst(var, repl, *set)),
475 },
476 TLExpr::SetUnion { left, right } => TLExpr::SetUnion {
477 left: Box::new(subst(var, repl, *left)),
478 right: Box::new(subst(var, repl, *right)),
479 },
480 TLExpr::SetIntersection { left, right } => TLExpr::SetIntersection {
481 left: Box::new(subst(var, repl, *left)),
482 right: Box::new(subst(var, repl, *right)),
483 },
484 TLExpr::SetDifference { left, right } => TLExpr::SetDifference {
485 left: Box::new(subst(var, repl, *left)),
486 right: Box::new(subst(var, repl, *right)),
487 },
488 TLExpr::SetCardinality { set } => TLExpr::SetCardinality {
489 set: Box::new(subst(var, repl, *set)),
490 },
491
492 TLExpr::At { nominal, formula } => TLExpr::At {
493 nominal,
494 formula: Box::new(subst(var, repl, *formula)),
495 },
496 TLExpr::Somewhere { formula } => TLExpr::Somewhere {
497 formula: Box::new(subst(var, repl, *formula)),
498 },
499 TLExpr::Everywhere { formula } => TLExpr::Everywhere {
500 formula: Box::new(subst(var, repl, *formula)),
501 },
502 TLExpr::Explain { formula } => TLExpr::Explain {
503 formula: Box::new(subst(var, repl, *formula)),
504 },
505
506 TLExpr::GlobalCardinality {
507 variables,
508 values,
509 min_occurrences,
510 max_occurrences,
511 } => TLExpr::GlobalCardinality {
512 variables,
513 values: values.into_iter().map(|e| subst(var, repl, e)).collect(),
514 min_occurrences,
515 max_occurrences,
516 },
517
518 leaf @ (TLExpr::Constant(_)
520 | TLExpr::EmptySet
521 | TLExpr::AllDifferent { .. }
522 | TLExpr::Nominal { .. }
523 | TLExpr::Abducible { .. }
524 | TLExpr::SymbolLiteral(_)) => leaf,
525
526 TLExpr::Match { scrutinee, arms } => TLExpr::Match {
527 scrutinee: Box::new(subst(var, repl, *scrutinee)),
528 arms: arms
529 .into_iter()
530 .map(|(pat, body)| (pat, Box::new(subst(var, repl, *body))))
531 .collect(),
532 },
533 }
534}