Skip to main content

tensorlogic_compiler/inline/
traversal.rs

1use tensorlogic_ir::TLExpr;
2
3use super::config::{InlineConfig, InlineStats};
4use super::helpers::{
5    count_free_occurrences, count_nodes, expr_depth, is_constant_binding, is_var_binding,
6};
7use super::substitute::substitute;
8
9/// The let-inlining pass.
10///
11/// Iterates to a fixed point (or until `config.max_passes` is reached),
12/// replacing eligible `Let` bindings with direct substitution of the bound
13/// value into the body.
14pub struct LetInliner {
15    pub(super) config: InlineConfig,
16}
17
18impl Default for LetInliner {
19    fn default() -> Self {
20        Self::with_default()
21    }
22}
23
24impl LetInliner {
25    /// Construct a new inliner with the given configuration.
26    pub fn new(config: InlineConfig) -> Self {
27        Self { config }
28    }
29
30    /// Construct a new inliner with the default configuration.
31    pub fn with_default() -> Self {
32        Self::new(InlineConfig::default())
33    }
34
35    // ─────────────────────────────────────────────────────────────────────
36    // Public entry-point
37    // ─────────────────────────────────────────────────────────────────────
38
39    /// Run inlining to a fixed point (or until `config.max_passes` is reached).
40    ///
41    /// Returns the rewritten expression and collected [`InlineStats`].
42    pub fn run(&self, expr: TLExpr) -> (TLExpr, InlineStats) {
43        let mut stats = InlineStats {
44            nodes_before: count_nodes(&expr),
45            ..Default::default()
46        };
47
48        let mut current = expr;
49        let max = self.config.max_passes.max(1);
50
51        for _ in 0..max {
52            let (next, changed) = self.run_pass(current, &mut stats);
53            stats.passes += 1;
54            current = next;
55            if !changed {
56                break;
57            }
58        }
59
60        stats.nodes_after = count_nodes(&current);
61        (current, stats)
62    }
63
64    // ─────────────────────────────────────────────────────────────────────
65    // Single pass
66    // ─────────────────────────────────────────────────────────────────────
67
68    /// Execute a single bottom-up inlining pass.
69    ///
70    /// Returns `(new_expr, did_change)`.
71    fn run_pass(&self, expr: TLExpr, stats: &mut InlineStats) -> (TLExpr, bool) {
72        self.inline_expr(expr, stats)
73    }
74
75    // ─────────────────────────────────────────────────────────────────────
76    // Core recursive transformation
77    // ─────────────────────────────────────────────────────────────────────
78
79    /// Recursively inline let-bindings in `expr`.
80    ///
81    /// The traversal is bottom-up: children are processed first so that
82    /// nested bindings are simplified before the enclosing binder is
83    /// considered.
84    fn inline_expr(&self, expr: TLExpr, stats: &mut InlineStats) -> (TLExpr, bool) {
85        match expr {
86            // ── The key case: Let bindings ───────────────────────────────────
87            TLExpr::Let { var, value, body } => {
88                // First recurse into the value and body.
89                let (new_value, cv) = self.inline_expr(*value, stats);
90                let (new_body, cb) = self.inline_expr(*body, stats);
91                let child_changed = cv || cb;
92
93                // Decide whether to inline this binding.
94                let depth_ok = expr_depth(&new_value) <= self.config.max_inline_depth;
95
96                if depth_ok {
97                    // Case 1: constant — always inline if flag set
98                    if self.config.inline_constants && is_constant_binding(&new_value) {
99                        stats.constant_inlines += 1;
100                        let inlined = substitute(&var, &new_value, new_body);
101                        // Recurse once more: the substituted body may expose new opportunities.
102                        let (final_expr, _) = self.inline_expr(inlined, stats);
103                        return (final_expr, true);
104                    }
105
106                    // Case 2: simple variable alias — always inline if flag set
107                    if self.config.inline_vars && is_var_binding(&new_value) {
108                        stats.variable_inlines += 1;
109                        let inlined = substitute(&var, &new_value, new_body);
110                        let (final_expr, _) = self.inline_expr(inlined, stats);
111                        return (final_expr, true);
112                    }
113
114                    // Case 3: single-use — inline if flag set
115                    if self.config.inline_single_use && count_free_occurrences(&var, &new_body) == 1
116                    {
117                        stats.single_use_inlines += 1;
118                        let inlined = substitute(&var, &new_value, new_body);
119                        let (final_expr, _) = self.inline_expr(inlined, stats);
120                        return (final_expr, true);
121                    }
122                }
123
124                // Not inlined — keep the Let node with updated children.
125                (
126                    TLExpr::Let {
127                        var,
128                        value: Box::new(new_value),
129                        body: Box::new(new_body),
130                    },
131                    child_changed,
132                )
133            }
134
135            // ── Boolean connectives ──────────────────────────────────────────
136            TLExpr::And(l, r) => {
137                let (nl, cl) = self.inline_expr(*l, stats);
138                let (nr, cr) = self.inline_expr(*r, stats);
139                (TLExpr::And(Box::new(nl), Box::new(nr)), cl || cr)
140            }
141            TLExpr::Or(l, r) => {
142                let (nl, cl) = self.inline_expr(*l, stats);
143                let (nr, cr) = self.inline_expr(*r, stats);
144                (TLExpr::Or(Box::new(nl), Box::new(nr)), cl || cr)
145            }
146            TLExpr::Not(e) => {
147                let (ne, changed) = self.inline_expr(*e, stats);
148                (TLExpr::Not(Box::new(ne)), changed)
149            }
150            TLExpr::Imply(l, r) => {
151                let (nl, cl) = self.inline_expr(*l, stats);
152                let (nr, cr) = self.inline_expr(*r, stats);
153                (TLExpr::Imply(Box::new(nl), Box::new(nr)), cl || cr)
154            }
155
156            // ── Arithmetic binary ops ────────────────────────────────────────
157            TLExpr::Add(l, r) => self.map_binary(TLExpr::Add, *l, *r, stats),
158            TLExpr::Sub(l, r) => self.map_binary(TLExpr::Sub, *l, *r, stats),
159            TLExpr::Mul(l, r) => self.map_binary(TLExpr::Mul, *l, *r, stats),
160            TLExpr::Div(l, r) => self.map_binary(TLExpr::Div, *l, *r, stats),
161            TLExpr::Pow(l, r) => self.map_binary(TLExpr::Pow, *l, *r, stats),
162            TLExpr::Mod(l, r) => self.map_binary(TLExpr::Mod, *l, *r, stats),
163            TLExpr::Min(l, r) => self.map_binary(TLExpr::Min, *l, *r, stats),
164            TLExpr::Max(l, r) => self.map_binary(TLExpr::Max, *l, *r, stats),
165
166            // ── Comparison binary ops ────────────────────────────────────────
167            TLExpr::Eq(l, r) => self.map_binary(TLExpr::Eq, *l, *r, stats),
168            TLExpr::Lt(l, r) => self.map_binary(TLExpr::Lt, *l, *r, stats),
169            TLExpr::Gt(l, r) => self.map_binary(TLExpr::Gt, *l, *r, stats),
170            TLExpr::Lte(l, r) => self.map_binary(TLExpr::Lte, *l, *r, stats),
171            TLExpr::Gte(l, r) => self.map_binary(TLExpr::Gte, *l, *r, stats),
172
173            // ── Unary math ops ───────────────────────────────────────────────
174            TLExpr::Abs(e) => self.map_unary(TLExpr::Abs, *e, stats),
175            TLExpr::Floor(e) => self.map_unary(TLExpr::Floor, *e, stats),
176            TLExpr::Ceil(e) => self.map_unary(TLExpr::Ceil, *e, stats),
177            TLExpr::Round(e) => self.map_unary(TLExpr::Round, *e, stats),
178            TLExpr::Sqrt(e) => self.map_unary(TLExpr::Sqrt, *e, stats),
179            TLExpr::Exp(e) => self.map_unary(TLExpr::Exp, *e, stats),
180            TLExpr::Log(e) => self.map_unary(TLExpr::Log, *e, stats),
181            TLExpr::Sin(e) => self.map_unary(TLExpr::Sin, *e, stats),
182            TLExpr::Cos(e) => self.map_unary(TLExpr::Cos, *e, stats),
183            TLExpr::Tan(e) => self.map_unary(TLExpr::Tan, *e, stats),
184            TLExpr::Score(e) => self.map_unary(TLExpr::Score, *e, stats),
185
186            // ── Modal / temporal unary ───────────────────────────────────────
187            TLExpr::Box(e) => self.map_unary(TLExpr::Box, *e, stats),
188            TLExpr::Diamond(e) => self.map_unary(TLExpr::Diamond, *e, stats),
189            TLExpr::Next(e) => self.map_unary(TLExpr::Next, *e, stats),
190            TLExpr::Eventually(e) => self.map_unary(TLExpr::Eventually, *e, stats),
191            TLExpr::Always(e) => self.map_unary(TLExpr::Always, *e, stats),
192
193            // ── Temporal binary ──────────────────────────────────────────────
194            TLExpr::Until { before, after } => {
195                let (nb, cb) = self.inline_expr(*before, stats);
196                let (na, ca) = self.inline_expr(*after, stats);
197                (
198                    TLExpr::Until {
199                        before: Box::new(nb),
200                        after: Box::new(na),
201                    },
202                    cb || ca,
203                )
204            }
205            TLExpr::Release { released, releaser } => {
206                let (nr, cr) = self.inline_expr(*released, stats);
207                let (ne, ce) = self.inline_expr(*releaser, stats);
208                (
209                    TLExpr::Release {
210                        released: Box::new(nr),
211                        releaser: Box::new(ne),
212                    },
213                    cr || ce,
214                )
215            }
216            TLExpr::WeakUntil { before, after } => {
217                let (nb, cb) = self.inline_expr(*before, stats);
218                let (na, ca) = self.inline_expr(*after, stats);
219                (
220                    TLExpr::WeakUntil {
221                        before: Box::new(nb),
222                        after: Box::new(na),
223                    },
224                    cb || ca,
225                )
226            }
227            TLExpr::StrongRelease { released, releaser } => {
228                let (nr, cr) = self.inline_expr(*released, stats);
229                let (ne, ce) = self.inline_expr(*releaser, stats);
230                (
231                    TLExpr::StrongRelease {
232                        released: Box::new(nr),
233                        releaser: Box::new(ne),
234                    },
235                    cr || ce,
236                )
237            }
238
239            // ── Fuzzy operators ──────────────────────────────────────────────
240            TLExpr::TNorm { kind, left, right } => {
241                let (nl, cl) = self.inline_expr(*left, stats);
242                let (nr, cr) = self.inline_expr(*right, stats);
243                (
244                    TLExpr::TNorm {
245                        kind,
246                        left: Box::new(nl),
247                        right: Box::new(nr),
248                    },
249                    cl || cr,
250                )
251            }
252            TLExpr::TCoNorm { kind, left, right } => {
253                let (nl, cl) = self.inline_expr(*left, stats);
254                let (nr, cr) = self.inline_expr(*right, stats);
255                (
256                    TLExpr::TCoNorm {
257                        kind,
258                        left: Box::new(nl),
259                        right: Box::new(nr),
260                    },
261                    cl || cr,
262                )
263            }
264            TLExpr::FuzzyNot { kind, expr } => {
265                let (ne, changed) = self.inline_expr(*expr, stats);
266                (
267                    TLExpr::FuzzyNot {
268                        kind,
269                        expr: Box::new(ne),
270                    },
271                    changed,
272                )
273            }
274            TLExpr::FuzzyImplication {
275                kind,
276                premise,
277                conclusion,
278            } => {
279                let (np, cp) = self.inline_expr(*premise, stats);
280                let (nc, cc) = self.inline_expr(*conclusion, stats);
281                (
282                    TLExpr::FuzzyImplication {
283                        kind,
284                        premise: Box::new(np),
285                        conclusion: Box::new(nc),
286                    },
287                    cp || cc,
288                )
289            }
290
291            // ── Weighted / probabilistic ─────────────────────────────────────
292            TLExpr::WeightedRule { weight, rule } => {
293                let (nr, changed) = self.inline_expr(*rule, stats);
294                (
295                    TLExpr::WeightedRule {
296                        weight,
297                        rule: Box::new(nr),
298                    },
299                    changed,
300                )
301            }
302            TLExpr::ProbabilisticChoice { alternatives } => {
303                let mut any_changed = false;
304                let new_alts: Vec<(f64, TLExpr)> = alternatives
305                    .into_iter()
306                    .map(|(prob, e)| {
307                        let (ne, changed) = self.inline_expr(e, stats);
308                        any_changed = any_changed || changed;
309                        (prob, ne)
310                    })
311                    .collect();
312                (
313                    TLExpr::ProbabilisticChoice {
314                        alternatives: new_alts,
315                    },
316                    any_changed,
317                )
318            }
319
320            // ── IfThenElse ───────────────────────────────────────────────────
321            TLExpr::IfThenElse {
322                condition,
323                then_branch,
324                else_branch,
325            } => {
326                let (nc, cc) = self.inline_expr(*condition, stats);
327                let (nt, ct) = self.inline_expr(*then_branch, stats);
328                let (ne, ce) = self.inline_expr(*else_branch, stats);
329                (
330                    TLExpr::IfThenElse {
331                        condition: Box::new(nc),
332                        then_branch: Box::new(nt),
333                        else_branch: Box::new(ne),
334                    },
335                    cc || ct || ce,
336                )
337            }
338
339            // ── Quantifiers ──────────────────────────────────────────────────
340            TLExpr::Exists { var, domain, body } => {
341                let (new_body, changed) = self.inline_expr(*body, stats);
342                (
343                    TLExpr::Exists {
344                        var,
345                        domain,
346                        body: Box::new(new_body),
347                    },
348                    changed,
349                )
350            }
351            TLExpr::ForAll { var, domain, body } => {
352                let (new_body, changed) = self.inline_expr(*body, stats);
353                (
354                    TLExpr::ForAll {
355                        var,
356                        domain,
357                        body: Box::new(new_body),
358                    },
359                    changed,
360                )
361            }
362            TLExpr::SoftExists {
363                var,
364                domain,
365                body,
366                temperature,
367            } => {
368                let (new_body, changed) = self.inline_expr(*body, stats);
369                (
370                    TLExpr::SoftExists {
371                        var,
372                        domain,
373                        body: Box::new(new_body),
374                        temperature,
375                    },
376                    changed,
377                )
378            }
379            TLExpr::SoftForAll {
380                var,
381                domain,
382                body,
383                temperature,
384            } => {
385                let (new_body, changed) = self.inline_expr(*body, stats);
386                (
387                    TLExpr::SoftForAll {
388                        var,
389                        domain,
390                        body: Box::new(new_body),
391                        temperature,
392                    },
393                    changed,
394                )
395            }
396
397            // ── Aggregation ──────────────────────────────────────────────────
398            TLExpr::Aggregate {
399                op,
400                var,
401                domain,
402                body,
403                group_by,
404            } => {
405                let (new_body, changed) = self.inline_expr(*body, stats);
406                (
407                    TLExpr::Aggregate {
408                        op,
409                        var,
410                        domain,
411                        body: Box::new(new_body),
412                        group_by,
413                    },
414                    changed,
415                )
416            }
417
418            // ── Higher-order ─────────────────────────────────────────────────
419            TLExpr::Lambda {
420                var,
421                var_type,
422                body,
423            } => {
424                let (new_body, changed) = self.inline_expr(*body, stats);
425                (
426                    TLExpr::Lambda {
427                        var,
428                        var_type,
429                        body: Box::new(new_body),
430                    },
431                    changed,
432                )
433            }
434            TLExpr::Apply { function, argument } => {
435                let (nf, cf) = self.inline_expr(*function, stats);
436                let (na, ca) = self.inline_expr(*argument, stats);
437                (
438                    TLExpr::Apply {
439                        function: Box::new(nf),
440                        argument: Box::new(na),
441                    },
442                    cf || ca,
443                )
444            }
445
446            // ── Set theory ───────────────────────────────────────────────────
447            TLExpr::SetMembership { element, set } => {
448                let (ne, ce) = self.inline_expr(*element, stats);
449                let (ns, cs) = self.inline_expr(*set, stats);
450                (
451                    TLExpr::SetMembership {
452                        element: Box::new(ne),
453                        set: Box::new(ns),
454                    },
455                    ce || cs,
456                )
457            }
458            TLExpr::SetUnion { left, right } => {
459                let (nl, cl) = self.inline_expr(*left, stats);
460                let (nr, cr) = self.inline_expr(*right, stats);
461                (
462                    TLExpr::SetUnion {
463                        left: Box::new(nl),
464                        right: Box::new(nr),
465                    },
466                    cl || cr,
467                )
468            }
469            TLExpr::SetIntersection { left, right } => {
470                let (nl, cl) = self.inline_expr(*left, stats);
471                let (nr, cr) = self.inline_expr(*right, stats);
472                (
473                    TLExpr::SetIntersection {
474                        left: Box::new(nl),
475                        right: Box::new(nr),
476                    },
477                    cl || cr,
478                )
479            }
480            TLExpr::SetDifference { left, right } => {
481                let (nl, cl) = self.inline_expr(*left, stats);
482                let (nr, cr) = self.inline_expr(*right, stats);
483                (
484                    TLExpr::SetDifference {
485                        left: Box::new(nl),
486                        right: Box::new(nr),
487                    },
488                    cl || cr,
489                )
490            }
491            TLExpr::SetCardinality { set } => {
492                let (ns, changed) = self.inline_expr(*set, stats);
493                (TLExpr::SetCardinality { set: Box::new(ns) }, changed)
494            }
495            TLExpr::SetComprehension {
496                var,
497                domain,
498                condition,
499            } => {
500                let (nc, changed) = self.inline_expr(*condition, stats);
501                (
502                    TLExpr::SetComprehension {
503                        var,
504                        domain,
505                        condition: Box::new(nc),
506                    },
507                    changed,
508                )
509            }
510
511            // ── Counting quantifiers ─────────────────────────────────────────
512            TLExpr::CountingExists {
513                var,
514                domain,
515                body,
516                min_count,
517            } => {
518                let (new_body, changed) = self.inline_expr(*body, stats);
519                (
520                    TLExpr::CountingExists {
521                        var,
522                        domain,
523                        body: Box::new(new_body),
524                        min_count,
525                    },
526                    changed,
527                )
528            }
529            TLExpr::CountingForAll {
530                var,
531                domain,
532                body,
533                min_count,
534            } => {
535                let (new_body, changed) = self.inline_expr(*body, stats);
536                (
537                    TLExpr::CountingForAll {
538                        var,
539                        domain,
540                        body: Box::new(new_body),
541                        min_count,
542                    },
543                    changed,
544                )
545            }
546            TLExpr::ExactCount {
547                var,
548                domain,
549                body,
550                count,
551            } => {
552                let (new_body, changed) = self.inline_expr(*body, stats);
553                (
554                    TLExpr::ExactCount {
555                        var,
556                        domain,
557                        body: Box::new(new_body),
558                        count,
559                    },
560                    changed,
561                )
562            }
563            TLExpr::Majority { var, domain, body } => {
564                let (new_body, changed) = self.inline_expr(*body, stats);
565                (
566                    TLExpr::Majority {
567                        var,
568                        domain,
569                        body: Box::new(new_body),
570                    },
571                    changed,
572                )
573            }
574
575            // ── Fixed-point operators ────────────────────────────────────────
576            TLExpr::LeastFixpoint { var, body } => {
577                let (new_body, changed) = self.inline_expr(*body, stats);
578                (
579                    TLExpr::LeastFixpoint {
580                        var,
581                        body: Box::new(new_body),
582                    },
583                    changed,
584                )
585            }
586            TLExpr::GreatestFixpoint { var, body } => {
587                let (new_body, changed) = self.inline_expr(*body, stats);
588                (
589                    TLExpr::GreatestFixpoint {
590                        var,
591                        body: Box::new(new_body),
592                    },
593                    changed,
594                )
595            }
596
597            // ── Hybrid logic ─────────────────────────────────────────────────
598            TLExpr::At { nominal, formula } => {
599                let (nf, changed) = self.inline_expr(*formula, stats);
600                (
601                    TLExpr::At {
602                        nominal,
603                        formula: Box::new(nf),
604                    },
605                    changed,
606                )
607            }
608            TLExpr::Somewhere { formula } => {
609                let (nf, changed) = self.inline_expr(*formula, stats);
610                (
611                    TLExpr::Somewhere {
612                        formula: Box::new(nf),
613                    },
614                    changed,
615                )
616            }
617            TLExpr::Everywhere { formula } => {
618                let (nf, changed) = self.inline_expr(*formula, stats);
619                (
620                    TLExpr::Everywhere {
621                        formula: Box::new(nf),
622                    },
623                    changed,
624                )
625            }
626
627            // ── Abductive ────────────────────────────────────────────────────
628            TLExpr::Explain { formula } => {
629                let (nf, changed) = self.inline_expr(*formula, stats);
630                (
631                    TLExpr::Explain {
632                        formula: Box::new(nf),
633                    },
634                    changed,
635                )
636            }
637
638            // ── Leaves ───────────────────────────────────────────────────────
639            leaf @ (TLExpr::Pred { .. }
640            | TLExpr::Constant(_)
641            | TLExpr::EmptySet
642            | TLExpr::AllDifferent { .. }
643            | TLExpr::GlobalCardinality { .. }
644            | TLExpr::Nominal { .. }
645            | TLExpr::Abducible { .. }
646            | TLExpr::SymbolLiteral(_)) => (leaf, false),
647
648            TLExpr::Match { scrutinee, arms } => {
649                let (new_scrutinee, sc) = self.inline_expr(*scrutinee, stats);
650                let mut any_changed = sc;
651                let new_arms = arms
652                    .into_iter()
653                    .map(|(pat, body)| {
654                        let (new_body, bc) = self.inline_expr(*body, stats);
655                        if bc {
656                            any_changed = true;
657                        }
658                        (pat, Box::new(new_body))
659                    })
660                    .collect();
661                (
662                    TLExpr::Match {
663                        scrutinee: Box::new(new_scrutinee),
664                        arms: new_arms,
665                    },
666                    any_changed,
667                )
668            }
669        }
670    }
671
672    // ─────────────────────────────────────────────────────────────────────
673    // Helper: binary / unary mapping
674    // ─────────────────────────────────────────────────────────────────────
675
676    #[inline]
677    fn map_binary(
678        &self,
679        ctor: fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
680        l: TLExpr,
681        r: TLExpr,
682        stats: &mut InlineStats,
683    ) -> (TLExpr, bool) {
684        let (nl, cl) = self.inline_expr(l, stats);
685        let (nr, cr) = self.inline_expr(r, stats);
686        (ctor(Box::new(nl), Box::new(nr)), cl || cr)
687    }
688
689    #[inline]
690    fn map_unary(
691        &self,
692        ctor: fn(Box<TLExpr>) -> TLExpr,
693        e: TLExpr,
694        stats: &mut InlineStats,
695    ) -> (TLExpr, bool) {
696        let (ne, changed) = self.inline_expr(e, stats);
697        (ctor(Box::new(ne)), changed)
698    }
699
700    // ─────────────────────────────────────────────────────────────────────
701    // Public helper forwarding functions (keeping backward-compat API)
702    // ─────────────────────────────────────────────────────────────────────
703
704    /// Count how many times `var` appears free in `expr`.
705    pub fn count_free_occurrences(var: &str, expr: &TLExpr) -> usize {
706        count_free_occurrences(var, expr)
707    }
708
709    /// Substitute all free occurrences of `var` with `replacement` in `body`.
710    pub fn substitute(var: &str, replacement: &TLExpr, body: TLExpr) -> TLExpr {
711        substitute(var, replacement, body)
712    }
713
714    /// Returns `true` if `expr` is a constant literal (`Constant(_)`).
715    pub fn is_constant_binding(expr: &TLExpr) -> bool {
716        is_constant_binding(expr)
717    }
718
719    /// Returns `true` if `expr` is a zero-argument predicate (variable alias).
720    pub fn is_var_binding(expr: &TLExpr) -> bool {
721        is_var_binding(expr)
722    }
723
724    /// Returns `true` if `expr` is a "simple" binding worth inlining regardless
725    /// of use count: either a constant or a variable alias.
726    pub fn is_simple_binding(expr: &TLExpr) -> bool {
727        super::helpers::is_simple_binding(expr)
728    }
729
730    /// Compute the depth (height) of an expression tree.
731    pub fn expr_depth(expr: &TLExpr) -> usize {
732        expr_depth(expr)
733    }
734}