1use tensorlogic_ir::TLExpr;
2
3use super::config::{InlineConfig, InlineStats};
4use super::helpers::{
5 count_free_occurrences, count_nodes, expr_depth, is_constant_binding, is_var_binding,
6};
7use super::substitute::substitute;
8
9pub struct LetInliner {
15 pub(super) config: InlineConfig,
16}
17
18impl Default for LetInliner {
19 fn default() -> Self {
20 Self::with_default()
21 }
22}
23
24impl LetInliner {
25 pub fn new(config: InlineConfig) -> Self {
27 Self { config }
28 }
29
30 pub fn with_default() -> Self {
32 Self::new(InlineConfig::default())
33 }
34
35 pub fn run(&self, expr: TLExpr) -> (TLExpr, InlineStats) {
43 let mut stats = InlineStats {
44 nodes_before: count_nodes(&expr),
45 ..Default::default()
46 };
47
48 let mut current = expr;
49 let max = self.config.max_passes.max(1);
50
51 for _ in 0..max {
52 let (next, changed) = self.run_pass(current, &mut stats);
53 stats.passes += 1;
54 current = next;
55 if !changed {
56 break;
57 }
58 }
59
60 stats.nodes_after = count_nodes(¤t);
61 (current, stats)
62 }
63
64 fn run_pass(&self, expr: TLExpr, stats: &mut InlineStats) -> (TLExpr, bool) {
72 self.inline_expr(expr, stats)
73 }
74
75 fn inline_expr(&self, expr: TLExpr, stats: &mut InlineStats) -> (TLExpr, bool) {
85 match expr {
86 TLExpr::Let { var, value, body } => {
88 let (new_value, cv) = self.inline_expr(*value, stats);
90 let (new_body, cb) = self.inline_expr(*body, stats);
91 let child_changed = cv || cb;
92
93 let depth_ok = expr_depth(&new_value) <= self.config.max_inline_depth;
95
96 if depth_ok {
97 if self.config.inline_constants && is_constant_binding(&new_value) {
99 stats.constant_inlines += 1;
100 let inlined = substitute(&var, &new_value, new_body);
101 let (final_expr, _) = self.inline_expr(inlined, stats);
103 return (final_expr, true);
104 }
105
106 if self.config.inline_vars && is_var_binding(&new_value) {
108 stats.variable_inlines += 1;
109 let inlined = substitute(&var, &new_value, new_body);
110 let (final_expr, _) = self.inline_expr(inlined, stats);
111 return (final_expr, true);
112 }
113
114 if self.config.inline_single_use && count_free_occurrences(&var, &new_body) == 1
116 {
117 stats.single_use_inlines += 1;
118 let inlined = substitute(&var, &new_value, new_body);
119 let (final_expr, _) = self.inline_expr(inlined, stats);
120 return (final_expr, true);
121 }
122 }
123
124 (
126 TLExpr::Let {
127 var,
128 value: Box::new(new_value),
129 body: Box::new(new_body),
130 },
131 child_changed,
132 )
133 }
134
135 TLExpr::And(l, r) => {
137 let (nl, cl) = self.inline_expr(*l, stats);
138 let (nr, cr) = self.inline_expr(*r, stats);
139 (TLExpr::And(Box::new(nl), Box::new(nr)), cl || cr)
140 }
141 TLExpr::Or(l, r) => {
142 let (nl, cl) = self.inline_expr(*l, stats);
143 let (nr, cr) = self.inline_expr(*r, stats);
144 (TLExpr::Or(Box::new(nl), Box::new(nr)), cl || cr)
145 }
146 TLExpr::Not(e) => {
147 let (ne, changed) = self.inline_expr(*e, stats);
148 (TLExpr::Not(Box::new(ne)), changed)
149 }
150 TLExpr::Imply(l, r) => {
151 let (nl, cl) = self.inline_expr(*l, stats);
152 let (nr, cr) = self.inline_expr(*r, stats);
153 (TLExpr::Imply(Box::new(nl), Box::new(nr)), cl || cr)
154 }
155
156 TLExpr::Add(l, r) => self.map_binary(TLExpr::Add, *l, *r, stats),
158 TLExpr::Sub(l, r) => self.map_binary(TLExpr::Sub, *l, *r, stats),
159 TLExpr::Mul(l, r) => self.map_binary(TLExpr::Mul, *l, *r, stats),
160 TLExpr::Div(l, r) => self.map_binary(TLExpr::Div, *l, *r, stats),
161 TLExpr::Pow(l, r) => self.map_binary(TLExpr::Pow, *l, *r, stats),
162 TLExpr::Mod(l, r) => self.map_binary(TLExpr::Mod, *l, *r, stats),
163 TLExpr::Min(l, r) => self.map_binary(TLExpr::Min, *l, *r, stats),
164 TLExpr::Max(l, r) => self.map_binary(TLExpr::Max, *l, *r, stats),
165
166 TLExpr::Eq(l, r) => self.map_binary(TLExpr::Eq, *l, *r, stats),
168 TLExpr::Lt(l, r) => self.map_binary(TLExpr::Lt, *l, *r, stats),
169 TLExpr::Gt(l, r) => self.map_binary(TLExpr::Gt, *l, *r, stats),
170 TLExpr::Lte(l, r) => self.map_binary(TLExpr::Lte, *l, *r, stats),
171 TLExpr::Gte(l, r) => self.map_binary(TLExpr::Gte, *l, *r, stats),
172
173 TLExpr::Abs(e) => self.map_unary(TLExpr::Abs, *e, stats),
175 TLExpr::Floor(e) => self.map_unary(TLExpr::Floor, *e, stats),
176 TLExpr::Ceil(e) => self.map_unary(TLExpr::Ceil, *e, stats),
177 TLExpr::Round(e) => self.map_unary(TLExpr::Round, *e, stats),
178 TLExpr::Sqrt(e) => self.map_unary(TLExpr::Sqrt, *e, stats),
179 TLExpr::Exp(e) => self.map_unary(TLExpr::Exp, *e, stats),
180 TLExpr::Log(e) => self.map_unary(TLExpr::Log, *e, stats),
181 TLExpr::Sin(e) => self.map_unary(TLExpr::Sin, *e, stats),
182 TLExpr::Cos(e) => self.map_unary(TLExpr::Cos, *e, stats),
183 TLExpr::Tan(e) => self.map_unary(TLExpr::Tan, *e, stats),
184 TLExpr::Score(e) => self.map_unary(TLExpr::Score, *e, stats),
185
186 TLExpr::Box(e) => self.map_unary(TLExpr::Box, *e, stats),
188 TLExpr::Diamond(e) => self.map_unary(TLExpr::Diamond, *e, stats),
189 TLExpr::Next(e) => self.map_unary(TLExpr::Next, *e, stats),
190 TLExpr::Eventually(e) => self.map_unary(TLExpr::Eventually, *e, stats),
191 TLExpr::Always(e) => self.map_unary(TLExpr::Always, *e, stats),
192
193 TLExpr::Until { before, after } => {
195 let (nb, cb) = self.inline_expr(*before, stats);
196 let (na, ca) = self.inline_expr(*after, stats);
197 (
198 TLExpr::Until {
199 before: Box::new(nb),
200 after: Box::new(na),
201 },
202 cb || ca,
203 )
204 }
205 TLExpr::Release { released, releaser } => {
206 let (nr, cr) = self.inline_expr(*released, stats);
207 let (ne, ce) = self.inline_expr(*releaser, stats);
208 (
209 TLExpr::Release {
210 released: Box::new(nr),
211 releaser: Box::new(ne),
212 },
213 cr || ce,
214 )
215 }
216 TLExpr::WeakUntil { before, after } => {
217 let (nb, cb) = self.inline_expr(*before, stats);
218 let (na, ca) = self.inline_expr(*after, stats);
219 (
220 TLExpr::WeakUntil {
221 before: Box::new(nb),
222 after: Box::new(na),
223 },
224 cb || ca,
225 )
226 }
227 TLExpr::StrongRelease { released, releaser } => {
228 let (nr, cr) = self.inline_expr(*released, stats);
229 let (ne, ce) = self.inline_expr(*releaser, stats);
230 (
231 TLExpr::StrongRelease {
232 released: Box::new(nr),
233 releaser: Box::new(ne),
234 },
235 cr || ce,
236 )
237 }
238
239 TLExpr::TNorm { kind, left, right } => {
241 let (nl, cl) = self.inline_expr(*left, stats);
242 let (nr, cr) = self.inline_expr(*right, stats);
243 (
244 TLExpr::TNorm {
245 kind,
246 left: Box::new(nl),
247 right: Box::new(nr),
248 },
249 cl || cr,
250 )
251 }
252 TLExpr::TCoNorm { kind, left, right } => {
253 let (nl, cl) = self.inline_expr(*left, stats);
254 let (nr, cr) = self.inline_expr(*right, stats);
255 (
256 TLExpr::TCoNorm {
257 kind,
258 left: Box::new(nl),
259 right: Box::new(nr),
260 },
261 cl || cr,
262 )
263 }
264 TLExpr::FuzzyNot { kind, expr } => {
265 let (ne, changed) = self.inline_expr(*expr, stats);
266 (
267 TLExpr::FuzzyNot {
268 kind,
269 expr: Box::new(ne),
270 },
271 changed,
272 )
273 }
274 TLExpr::FuzzyImplication {
275 kind,
276 premise,
277 conclusion,
278 } => {
279 let (np, cp) = self.inline_expr(*premise, stats);
280 let (nc, cc) = self.inline_expr(*conclusion, stats);
281 (
282 TLExpr::FuzzyImplication {
283 kind,
284 premise: Box::new(np),
285 conclusion: Box::new(nc),
286 },
287 cp || cc,
288 )
289 }
290
291 TLExpr::WeightedRule { weight, rule } => {
293 let (nr, changed) = self.inline_expr(*rule, stats);
294 (
295 TLExpr::WeightedRule {
296 weight,
297 rule: Box::new(nr),
298 },
299 changed,
300 )
301 }
302 TLExpr::ProbabilisticChoice { alternatives } => {
303 let mut any_changed = false;
304 let new_alts: Vec<(f64, TLExpr)> = alternatives
305 .into_iter()
306 .map(|(prob, e)| {
307 let (ne, changed) = self.inline_expr(e, stats);
308 any_changed = any_changed || changed;
309 (prob, ne)
310 })
311 .collect();
312 (
313 TLExpr::ProbabilisticChoice {
314 alternatives: new_alts,
315 },
316 any_changed,
317 )
318 }
319
320 TLExpr::IfThenElse {
322 condition,
323 then_branch,
324 else_branch,
325 } => {
326 let (nc, cc) = self.inline_expr(*condition, stats);
327 let (nt, ct) = self.inline_expr(*then_branch, stats);
328 let (ne, ce) = self.inline_expr(*else_branch, stats);
329 (
330 TLExpr::IfThenElse {
331 condition: Box::new(nc),
332 then_branch: Box::new(nt),
333 else_branch: Box::new(ne),
334 },
335 cc || ct || ce,
336 )
337 }
338
339 TLExpr::Exists { var, domain, body } => {
341 let (new_body, changed) = self.inline_expr(*body, stats);
342 (
343 TLExpr::Exists {
344 var,
345 domain,
346 body: Box::new(new_body),
347 },
348 changed,
349 )
350 }
351 TLExpr::ForAll { var, domain, body } => {
352 let (new_body, changed) = self.inline_expr(*body, stats);
353 (
354 TLExpr::ForAll {
355 var,
356 domain,
357 body: Box::new(new_body),
358 },
359 changed,
360 )
361 }
362 TLExpr::SoftExists {
363 var,
364 domain,
365 body,
366 temperature,
367 } => {
368 let (new_body, changed) = self.inline_expr(*body, stats);
369 (
370 TLExpr::SoftExists {
371 var,
372 domain,
373 body: Box::new(new_body),
374 temperature,
375 },
376 changed,
377 )
378 }
379 TLExpr::SoftForAll {
380 var,
381 domain,
382 body,
383 temperature,
384 } => {
385 let (new_body, changed) = self.inline_expr(*body, stats);
386 (
387 TLExpr::SoftForAll {
388 var,
389 domain,
390 body: Box::new(new_body),
391 temperature,
392 },
393 changed,
394 )
395 }
396
397 TLExpr::Aggregate {
399 op,
400 var,
401 domain,
402 body,
403 group_by,
404 } => {
405 let (new_body, changed) = self.inline_expr(*body, stats);
406 (
407 TLExpr::Aggregate {
408 op,
409 var,
410 domain,
411 body: Box::new(new_body),
412 group_by,
413 },
414 changed,
415 )
416 }
417
418 TLExpr::Lambda {
420 var,
421 var_type,
422 body,
423 } => {
424 let (new_body, changed) = self.inline_expr(*body, stats);
425 (
426 TLExpr::Lambda {
427 var,
428 var_type,
429 body: Box::new(new_body),
430 },
431 changed,
432 )
433 }
434 TLExpr::Apply { function, argument } => {
435 let (nf, cf) = self.inline_expr(*function, stats);
436 let (na, ca) = self.inline_expr(*argument, stats);
437 (
438 TLExpr::Apply {
439 function: Box::new(nf),
440 argument: Box::new(na),
441 },
442 cf || ca,
443 )
444 }
445
446 TLExpr::SetMembership { element, set } => {
448 let (ne, ce) = self.inline_expr(*element, stats);
449 let (ns, cs) = self.inline_expr(*set, stats);
450 (
451 TLExpr::SetMembership {
452 element: Box::new(ne),
453 set: Box::new(ns),
454 },
455 ce || cs,
456 )
457 }
458 TLExpr::SetUnion { left, right } => {
459 let (nl, cl) = self.inline_expr(*left, stats);
460 let (nr, cr) = self.inline_expr(*right, stats);
461 (
462 TLExpr::SetUnion {
463 left: Box::new(nl),
464 right: Box::new(nr),
465 },
466 cl || cr,
467 )
468 }
469 TLExpr::SetIntersection { left, right } => {
470 let (nl, cl) = self.inline_expr(*left, stats);
471 let (nr, cr) = self.inline_expr(*right, stats);
472 (
473 TLExpr::SetIntersection {
474 left: Box::new(nl),
475 right: Box::new(nr),
476 },
477 cl || cr,
478 )
479 }
480 TLExpr::SetDifference { left, right } => {
481 let (nl, cl) = self.inline_expr(*left, stats);
482 let (nr, cr) = self.inline_expr(*right, stats);
483 (
484 TLExpr::SetDifference {
485 left: Box::new(nl),
486 right: Box::new(nr),
487 },
488 cl || cr,
489 )
490 }
491 TLExpr::SetCardinality { set } => {
492 let (ns, changed) = self.inline_expr(*set, stats);
493 (TLExpr::SetCardinality { set: Box::new(ns) }, changed)
494 }
495 TLExpr::SetComprehension {
496 var,
497 domain,
498 condition,
499 } => {
500 let (nc, changed) = self.inline_expr(*condition, stats);
501 (
502 TLExpr::SetComprehension {
503 var,
504 domain,
505 condition: Box::new(nc),
506 },
507 changed,
508 )
509 }
510
511 TLExpr::CountingExists {
513 var,
514 domain,
515 body,
516 min_count,
517 } => {
518 let (new_body, changed) = self.inline_expr(*body, stats);
519 (
520 TLExpr::CountingExists {
521 var,
522 domain,
523 body: Box::new(new_body),
524 min_count,
525 },
526 changed,
527 )
528 }
529 TLExpr::CountingForAll {
530 var,
531 domain,
532 body,
533 min_count,
534 } => {
535 let (new_body, changed) = self.inline_expr(*body, stats);
536 (
537 TLExpr::CountingForAll {
538 var,
539 domain,
540 body: Box::new(new_body),
541 min_count,
542 },
543 changed,
544 )
545 }
546 TLExpr::ExactCount {
547 var,
548 domain,
549 body,
550 count,
551 } => {
552 let (new_body, changed) = self.inline_expr(*body, stats);
553 (
554 TLExpr::ExactCount {
555 var,
556 domain,
557 body: Box::new(new_body),
558 count,
559 },
560 changed,
561 )
562 }
563 TLExpr::Majority { var, domain, body } => {
564 let (new_body, changed) = self.inline_expr(*body, stats);
565 (
566 TLExpr::Majority {
567 var,
568 domain,
569 body: Box::new(new_body),
570 },
571 changed,
572 )
573 }
574
575 TLExpr::LeastFixpoint { var, body } => {
577 let (new_body, changed) = self.inline_expr(*body, stats);
578 (
579 TLExpr::LeastFixpoint {
580 var,
581 body: Box::new(new_body),
582 },
583 changed,
584 )
585 }
586 TLExpr::GreatestFixpoint { var, body } => {
587 let (new_body, changed) = self.inline_expr(*body, stats);
588 (
589 TLExpr::GreatestFixpoint {
590 var,
591 body: Box::new(new_body),
592 },
593 changed,
594 )
595 }
596
597 TLExpr::At { nominal, formula } => {
599 let (nf, changed) = self.inline_expr(*formula, stats);
600 (
601 TLExpr::At {
602 nominal,
603 formula: Box::new(nf),
604 },
605 changed,
606 )
607 }
608 TLExpr::Somewhere { formula } => {
609 let (nf, changed) = self.inline_expr(*formula, stats);
610 (
611 TLExpr::Somewhere {
612 formula: Box::new(nf),
613 },
614 changed,
615 )
616 }
617 TLExpr::Everywhere { formula } => {
618 let (nf, changed) = self.inline_expr(*formula, stats);
619 (
620 TLExpr::Everywhere {
621 formula: Box::new(nf),
622 },
623 changed,
624 )
625 }
626
627 TLExpr::Explain { formula } => {
629 let (nf, changed) = self.inline_expr(*formula, stats);
630 (
631 TLExpr::Explain {
632 formula: Box::new(nf),
633 },
634 changed,
635 )
636 }
637
638 leaf @ (TLExpr::Pred { .. }
640 | TLExpr::Constant(_)
641 | TLExpr::EmptySet
642 | TLExpr::AllDifferent { .. }
643 | TLExpr::GlobalCardinality { .. }
644 | TLExpr::Nominal { .. }
645 | TLExpr::Abducible { .. }
646 | TLExpr::SymbolLiteral(_)) => (leaf, false),
647
648 TLExpr::Match { scrutinee, arms } => {
649 let (new_scrutinee, sc) = self.inline_expr(*scrutinee, stats);
650 let mut any_changed = sc;
651 let new_arms = arms
652 .into_iter()
653 .map(|(pat, body)| {
654 let (new_body, bc) = self.inline_expr(*body, stats);
655 if bc {
656 any_changed = true;
657 }
658 (pat, Box::new(new_body))
659 })
660 .collect();
661 (
662 TLExpr::Match {
663 scrutinee: Box::new(new_scrutinee),
664 arms: new_arms,
665 },
666 any_changed,
667 )
668 }
669 }
670 }
671
672 #[inline]
677 fn map_binary(
678 &self,
679 ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
680 l: TLExpr,
681 r: TLExpr,
682 stats: &mut InlineStats,
683 ) -> (TLExpr, bool) {
684 let (nl, cl) = self.inline_expr(l, stats);
685 let (nr, cr) = self.inline_expr(r, stats);
686 (ctor(Box::new(nl), Box::new(nr)), cl || cr)
687 }
688
689 #[inline]
690 fn map_unary(
691 &self,
692 ctor: fn(Box<TLExpr>) -> TLExpr,
693 e: TLExpr,
694 stats: &mut InlineStats,
695 ) -> (TLExpr, bool) {
696 let (ne, changed) = self.inline_expr(e, stats);
697 (ctor(Box::new(ne)), changed)
698 }
699
700 pub fn count_free_occurrences(var: &str, expr: &TLExpr) -> usize {
706 count_free_occurrences(var, expr)
707 }
708
709 pub fn substitute(var: &str, replacement: &TLExpr, body: TLExpr) -> TLExpr {
711 substitute(var, replacement, body)
712 }
713
714 pub fn is_constant_binding(expr: &TLExpr) -> bool {
716 is_constant_binding(expr)
717 }
718
719 pub fn is_var_binding(expr: &TLExpr) -> bool {
721 is_var_binding(expr)
722 }
723
724 pub fn is_simple_binding(expr: &TLExpr) -> bool {
727 super::helpers::is_simple_binding(expr)
728 }
729
730 pub fn expr_depth(expr: &TLExpr) -> usize {
732 expr_depth(expr)
733 }
734}