tensorlogic_compiler/inline/
mod.rs1pub mod config;
53pub mod helpers;
54pub mod substitute;
55pub mod traversal;
56
57pub use config::{InlineConfig, InlineStats};
58pub use traversal::LetInliner;
59
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::TLExpr;

    /// Builds a nullary predicate `name()`; the tests below use it as a
    /// stand-in for a reference to the variable `name`.
    fn var(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![])
    }

    /// Builds a left-nested chain of `depth` additions,
    /// `((1 + 1) + 1) + ...`, or just the constant `1` when `depth == 0`.
    fn deep_add(depth: usize) -> TLExpr {
        (0..depth).fold(TLExpr::Constant(1.0), |acc, _| {
            TLExpr::add(acc, TLExpr::Constant(1.0))
        })
    }

    // ---- InlineStats -----------------------------------------------------

    #[test]
    fn test_inline_stats_default() {
        let stats = InlineStats::default();
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.constant_inlines, 0);
        assert_eq!(stats.variable_inlines, 0);
        assert_eq!(stats.total(), 0);
        assert_eq!(stats.nodes_before, 0);
        assert_eq!(stats.nodes_after, 0);
        assert_eq!(stats.passes, 0);
    }

    #[test]
    fn test_inline_stats_summary_nonempty() {
        let stats = InlineStats {
            single_use_inlines: 2,
            constant_inlines: 3,
            variable_inlines: 1,
            nodes_before: 20,
            nodes_after: 14,
            passes: 2,
        };
        let summary = stats.summary();
        assert!(summary.contains("2 passes"));
        assert!(summary.contains("14/20"));
        assert!(summary.contains("2 single-use"));
        assert!(summary.contains("3 constant"));
        assert!(summary.contains("1 variable-alias"));
    }

    #[test]
    fn test_total_inlines() {
        let stats = InlineStats {
            single_use_inlines: 4,
            constant_inlines: 5,
            variable_inlines: 3,
            ..Default::default()
        };
        assert_eq!(stats.total(), 12);
    }

    #[test]
    fn test_reduction_pct() {
        let stats = InlineStats {
            nodes_before: 100,
            nodes_after: 60,
            ..Default::default()
        };
        let pct = stats.reduction_pct();
        // 100 -> 60 nodes is a 40% reduction; compare with a tolerance since
        // the result is floating point.
        assert!((pct - 40.0).abs() < 1e-9, "expected ~40%, got {pct}");
    }

    // ---- Configuration ---------------------------------------------------

    #[test]
    fn test_inline_config_default() {
        let cfg = InlineConfig::default();
        assert!(cfg.inline_single_use);
        assert!(cfg.inline_constants);
        assert!(cfg.inline_vars);
        assert_eq!(cfg.max_passes, 20);
        assert_eq!(cfg.max_inline_depth, 10);
    }

    #[test]
    fn test_inliner_with_default() {
        let inliner = LetInliner::with_default();
        // All three inlining kinds are expected to be on: the constant and
        // variable-alias tests below rely on `with_default` inlining them.
        assert!(inliner.config.inline_single_use);
        assert!(inliner.config.inline_constants);
        assert!(inliner.config.inline_vars);
    }

    // ---- Occurrence counting ---------------------------------------------

    #[test]
    fn test_count_free_occurrences_zero() {
        let expr = var("p");
        assert_eq!(LetInliner::count_free_occurrences("z", &expr), 0);
    }

    #[test]
    fn test_count_free_occurrences_one() {
        let expr = var("x");
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 1);
    }

    #[test]
    fn test_count_free_occurrences_multi() {
        let expr = TLExpr::add(var("x"), var("x"));
        assert_eq!(LetInliner::count_free_occurrences("x", &expr), 2);
    }

    // ---- Substitution ----------------------------------------------------

    #[test]
    fn test_substitute_simple() {
        let body = var("x");
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), body);
        assert_eq!(result, TLExpr::Constant(7.0));
    }

    #[test]
    fn test_substitute_shadowed() {
        // `x` is re-bound by the quantifier, so the inner `x` is not free and
        // substitution must leave the expression untouched.
        let inner = TLExpr::exists("x", "D", var("x"));
        let result = LetInliner::substitute("x", &TLExpr::Constant(7.0), inner.clone());
        assert_eq!(result, inner);
    }

    // ---- Full inliner runs -----------------------------------------------

    #[test]
    fn test_inline_constant_binding() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding("x", TLExpr::Constant(5.0), TLExpr::add(var("x"), var("x")));
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.constant_inlines, 1);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(5.0), TLExpr::Constant(5.0))
        );
    }

    #[test]
    fn test_inline_variable_binding() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding("x", var("y"), TLExpr::add(var("x"), TLExpr::Constant(1.0)));
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.variable_inlines, 1);
        assert_eq!(result, TLExpr::add(var("y"), TLExpr::Constant(1.0)));
    }

    #[test]
    fn test_inline_single_use() {
        let inliner = LetInliner::with_default();
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        // The clone is needed: `binding_val` is reused in the expected result.
        let expr = TLExpr::let_binding("x", binding_val.clone(), TLExpr::sqrt(var("x")));
        let (result, stats) = inliner.run(expr);
        assert_eq!(stats.single_use_inlines, 1);
        assert_eq!(result, TLExpr::sqrt(binding_val));
    }

    #[test]
    fn test_no_inline_multi_use_by_default() {
        // Only single-use inlining is enabled. A non-trivial value referenced
        // twice must therefore stay bound.
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let binding_val = TLExpr::add(TLExpr::Constant(3.0), TLExpr::Constant(4.0));
        let expr = TLExpr::let_binding("x", binding_val, TLExpr::add(var("x"), var("x")));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.single_use_inlines, 0);
        assert_eq!(stats.total(), 0);
    }

    #[test]
    fn test_inline_depth_limit() {
        // The binding is single-use, so only `max_inline_depth` can stop it
        // from being inlined: a depth-5 value exceeds the limit of 3.
        let cfg = InlineConfig {
            inline_single_use: true,
            inline_constants: true,
            inline_vars: true,
            max_passes: 5,
            max_inline_depth: 3,
        };
        let inliner = LetInliner::new(cfg);
        let deep = deep_add(5);
        let expr = TLExpr::let_binding("x", deep, TLExpr::sqrt(var("x")));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "deep binding should not be inlined");
    }

    #[test]
    fn test_shadowing_respected() {
        let inliner = LetInliner::with_default();
        // The inner `let x = 2` shadows the outer `let x = 5`, so the body's
        // `x` must resolve to 2.
        let expr = TLExpr::let_binding(
            "x",
            TLExpr::Constant(5.0),
            TLExpr::let_binding("x", TLExpr::Constant(2.0), var("x")),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(result, TLExpr::Constant(2.0));
        assert!(stats.constant_inlines >= 2);
    }

    #[test]
    fn test_run_fixed_point() {
        let inliner = LetInliner::with_default();
        // `b` aliases `a`, which is a constant. Reaching the fully inlined
        // result takes more than one substitution.
        let expr = TLExpr::let_binding(
            "a",
            TLExpr::Constant(1.0),
            TLExpr::let_binding("b", var("a"), TLExpr::add(var("b"), var("b"))),
        );
        let (result, stats) = inliner.run(expr);
        assert_eq!(
            result,
            TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(1.0))
        );
        assert!(stats.total() >= 2);
    }

    #[test]
    fn test_run_preserves_non_let() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::and(var("P"), TLExpr::Constant(1.0));
        let (result, stats) = inliner.run(expr.clone());
        assert_eq!(result, expr);
        assert_eq!(stats.total(), 0);
    }

    #[test]
    fn test_inline_disabled() {
        let cfg = InlineConfig {
            inline_single_use: false,
            inline_constants: false,
            inline_vars: false,
            max_passes: 5,
            max_inline_depth: 10,
        };
        let inliner = LetInliner::new(cfg);
        let expr = TLExpr::let_binding("x", TLExpr::Constant(99.0), var("x"));
        let (_result, stats) = inliner.run(expr);
        assert_eq!(stats.total(), 0, "all flags disabled => no inlining");
    }

    #[test]
    fn test_reduction_pct_after_inlining() {
        let inliner = LetInliner::with_default();
        let expr = TLExpr::let_binding("x", TLExpr::Constant(3.0), TLExpr::add(var("x"), var("x")));
        let (_, stats) = inliner.run(expr);
        assert!(
            stats.nodes_after <= stats.nodes_before,
            "nodes should not grow: before={}, after={}",
            stats.nodes_before,
            stats.nodes_after
        );
        assert!(stats.reduction_pct() > 0.0, "should have some reduction");
    }
}