Skip to main content

tensorlogic_compiler/inline/
mod.rs

1//! Let-Inlining pass for TLExpr trees.
2//!
3//! This module provides a let-inlining optimization pass that substitutes
4//! `Let`-bound variables into their usage sites, reducing the number of
5//! explicit bindings and enabling downstream passes to work with smaller,
6//! simpler trees.
7//!
8//! # Inlining Strategy
9//!
10//! Three independent criteria control whether a binding is inlined:
11//!
12//! 1. **Single-use inlining** (`inline_single_use`): If the bound variable
13//!    appears free exactly once in the body, inlining is always safe — it
14//!    does not duplicate work.
15//!
16//! 2. **Constant inlining** (`inline_constants`): If the bound *value* is a
17//!    `Constant(f64)`, it is cheap to duplicate so we always inline regardless
18//!    of use count.
19//!
20//! 3. **Variable-alias inlining** (`inline_vars`): If the bound value is a
21//!    zero-argument `Pred` (which serves as a variable reference in TLExpr),
22//!    we inline it unconditionally because the binding is just a rename.
23//!
//! A `max_inline_depth` guard prevents inlining of bound values whose
//! expression depth exceeds the limit, keeping code-size growth bounded.
26//!
27//! # Correctness
28//!
//! Substitution respects shadowing: when descending into a binder (e.g. a
//! nested `Exists` or `Let`) that re-uses the same variable name, substitution
//! stops at that binder boundary.
31//!
32//! # Example
33//!
34//! ```rust
35//! use tensorlogic_compiler::inline::{LetInliner, InlineConfig};
36//! use tensorlogic_ir::TLExpr;
37//!
38//! let inliner = LetInliner::with_default();
39//! // Let x = 5.0 in Add(x, x)  →  Add(5.0, 5.0)
40//! let expr = TLExpr::let_binding(
41//!     "x",
42//!     TLExpr::Constant(5.0),
43//!     TLExpr::add(
44//!         TLExpr::pred("x", vec![]),
45//!         TLExpr::pred("x", vec![]),
46//!     ),
47//! );
48//! let (result, stats) = inliner.run(expr);
49//! assert_eq!(stats.constant_inlines, 1);
50//! ```
51
52pub mod config;
53pub mod helpers;
54pub mod substitute;
55pub mod traversal;
56
57pub use config::{InlineConfig, InlineStats};
58pub use traversal::LetInliner;
59
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::TLExpr;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Create a zero-argument predicate, which let bodies use as a variable reference.
    fn var(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![])
    }

    /// Build a left-nested chain of `depth` `Add` nodes over `Constant(1.0)` leaves.
    ///
    /// `deep_add(0)` is a single `Constant(1.0)`; each additional level wraps the
    /// previous chain as `Add(prev, Constant(1.0))`, so the resulting tree has depth
    /// `depth + 1` (e.g. `deep_add(5)` has depth 6).
    fn deep_add(depth: usize) -> TLExpr {
        if depth == 0 {
            return TLExpr::Constant(1.0);
        }
        TLExpr::add(deep_add(depth - 1), TLExpr::Constant(1.0))
    }

    // ─────────────────────────────────────────────────────────────────────────
    // InlineStats tests
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_inline_stats_default() {
        let stats = InlineStats::default();
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.constant_inlines, 0);
        assert_eq!(stats.variable_inlines, 0);
        assert_eq!(stats.total(), 0);
        assert_eq!(stats.nodes_before, 0);
        assert_eq!(stats.nodes_after, 0);
        assert_eq!(stats.passes, 0);
    }

    #[test]
    fn test_inline_stats_summary_nonempty() {
        let stats = InlineStats {
            single_use_inlines: 2,
            constant_inlines: 3,
            variable_inlines: 1,
            nodes_before: 20,
            nodes_after: 14,
            passes: 2,
        };
        let summary = stats.summary();
        assert!(summary.contains("2 passes"));
        assert!(summary.contains("14/20"));
        assert!(summary.contains("2 single-use"));
        assert!(summary.contains("3 constant"));
        assert!(summary.contains("1 variable-alias"));
    }

    #[test]
    fn test_total_inlines() {
        let stats = InlineStats {
            single_use_inlines: 4,
            constant_inlines: 5,
            variable_inlines: 3,
            ..Default::default()
        };
        assert_eq!(stats.total(), 12);
    }

    #[test]
    fn test_reduction_pct() {
        let stats = InlineStats {
            nodes_before: 100,
            nodes_after: 60,
            ..Default::default()
        };
        // (100 - 60) / 100 * 100 = 40%
        let pct = stats.reduction_pct();
        assert!((pct - 40.0).abs() < 1e-9, "expected ~40%, got {pct}");
    }

    // ─────────────────────────────────────────────────────────────────────────
    // InlineConfig tests
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_inline_config_default() {
        let cfg = InlineConfig::default();
        assert!(cfg.inline_single_use);
        assert!(cfg.inline_constants);
        assert!(cfg.inline_vars);
        assert_eq!(cfg.max_passes, 20);
        assert_eq!(cfg.max_inline_depth, 10);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // LetInliner construction
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_inliner_with_default() {
        let inliner = LetInliner::with_default();
        // Just verify it constructs without panic and default config is sound.
        assert!(inliner.config.inline_single_use);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // count_free_occurrences tests
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_count_free_occurrences_zero() {
        // var "z" does not appear in pred("p", [])
        let expr = var("p");
        assert_eq!(LetInliner::count_free_occurrences("z", &expr), 0);
    }

    #[test]
    fn test_count_free_occurrences_one() {
        // var "x" appears once in pred("x", [])
        let expr = var("x");
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 1);
    }

    #[test]
    fn test_count_free_occurrences_multi() {
        // var "x" appears twice: Add(x, x)
        let expr = TLExpr::add(var("x"), var("x"));
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 2);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // substitute tests
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_substitute_simple() {
        // substitute "x" with Constant(7.0) in pred("x", []) => Constant(7.0)
        let body = var("x");
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), body);
        assert_eq!(result, TLExpr::Constant(7.0));
    }

    #[test]
    fn test_substitute_shadowed() {
        // substitute "x" with Constant(7.0) in Exists{x, D, pred("x",[])}
        // The binder "x" shadows the substitution.
        let inner = TLExpr::exists("x", "D", var("x"));
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), inner.clone());
        // Because "x" is shadowed, the result should equal the original (no substitution in body).
        assert_eq!(result, inner);
    }

    // ─────────────────────────────────────────────────────────────────────────
    // Inlining behaviour tests
    // ─────────────────────────────────────────────────────────────────────────

    #[test]
    fn test_inline_constant_binding() {
        // Let x = 5.0 in Add(x, x)  => Add(5.0, 5.0)  (constant, always inlined)
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(5.0),
            TLExpr::add(var("x"), var("x")),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.constant_inlines, 1);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(5.0), TLExpr::Constant(5.0))
        );
    }

    #[test]
    fn test_inline_variable_binding() {
        // Let x = y in Add(x, 1.0)  => Add(y, 1.0)
        // The bound value is a zero-arg Pred (a pure rename), so it is inlined
        // as a variable alias.
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            var("y"),
            TLExpr::add(var("x"), TLExpr::Constant(1.0)),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.variable_inlines, 1);
        assert_eq!(result, TLExpr::add(var("y"), TLExpr::Constant(1.0)));
    }

    #[test]
    fn test_inline_single_use() {
        // Let x = Add(Constant(3.0), Constant(4.0)) in Sqrt(x)
        // x used once, so inline it.
        let inliner = LetInliner::with_default();
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        let expr = TLExpr::let_binding("x", binding_val.clone(), TLExpr::sqrt(var("x")));
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.single_use_inlines, 1);
        assert_eq!(result, TLExpr::sqrt(binding_val));
    }

    #[test]
    fn test_no_inline_multi_use_by_default() {
        // Only single-use inlining is enabled, the binding is neither a constant
        // nor a var-alias, and x is used 2 times => should NOT be inlined.
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        let expr = TLExpr::let_binding("x", binding_val, TLExpr::add(var("x"), var("x")));
        let (_result, stats) = inliner.run(expr);
        // x used twice, non-simple binding → not inlined
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.total(), 0);
    }

    #[test]
    fn test_inline_depth_limit() {
        // Binding value is very deep (depth > max_inline_depth) => not inlined
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: true,
            inline_vars: true,
            max_passes: 5,
            max_inline_depth: 3,
        };
        let inliner = LetInliner::new(cfg);
        // deep_add(5) has depth 6 > 3
        let deep = deep_add(5);
        let expr = TLExpr::let_binding("x", deep, TLExpr::sqrt(var("x")));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "deep binding should not be inlined");
    }

    #[test]
    fn test_shadowing_respected() {
        // Let x = 5.0 in Let x = 2.0 in pred("x",[])
        // Outer x is inlined (constant). The inner binding re-introduces x,
        // so the outer substitution must not reach the innermost reference.
        // Inner x=2.0 should then also be inlined as a constant.
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(5.0),
            TLExpr::let_binding("x", TLExpr::Constant(2.0), var("x")),
        );
        // After full inlining the result should be Constant(2.0)
        let (result, stats) = inliner.run(expr);
        assert_eq!(result, TLExpr::Constant(2.0));
        // Both constant bindings were inlined.
        assert!(stats.constant_inlines >= 2);
    }

    #[test]
    fn test_run_fixed_point() {
        // Let a = 1.0 in Let b = a in Add(b, b)
        // Pass 1: a=1.0 (constant) inlined → Let b = 1.0 in Add(b, b)
        // Pass 2: b=1.0 (constant) inlined → Add(1.0, 1.0)
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "a",
            TLExpr::Constant(1.0),
            TLExpr::let_binding("b", var("a"), TLExpr::add(var("b"), var("b"))),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0))
        );
        assert!(stats.total() >= 2);
    }

    #[test]
    fn test_run_preserves_non_let() {
        // An expression with no Let bindings is returned unchanged.
        let inliner = LetInliner::with_default();
        let expr = TLExpr::and(TLExpr::pred("P", vec![]), TLExpr::Constant(1.0));
        let (result, stats) = inliner.run(expr.clone());
        assert_eq!(result, expr);
        assert_eq!(stats.total(), 0);
    }

    #[test]
    fn test_inline_disabled() {
        // With all inlining disabled, nothing should be inlined.
        let cfg = InlineConfig {
            inline_single_use: false,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let expr = TLExpr::let_binding("x", TLExpr::Constant(99.0), var("x"));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "all flags disabled => no inlining");
    }

    #[test]
    fn test_reduction_pct_after_inlining() {
        // Inlining a constant in Let x = C in Add(x, x) should reduce node count.
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(3.0),
            TLExpr::add(var("x"), var("x")),
        );
        let (_, stats) = inliner.run(expr);
        assert!(
            stats.nodes_after <= stats.nodes_before,
            "nodes should not grow: before={}, after={}",
            stats.nodes_before,
            stats.nodes_after
        );
        // Should have removed the Let node and the binding Constant
        // Before: Let(C(3), Add(x, x)) = 5 nodes
        // After:  Add(C(3), C(3)) = 3 nodes
        assert!(stats.reduction_pct() > 0.0, "should have some reduction");
    }
}