Skip to main content

tensorlogic_compiler/inline/
helpers.rs

1use tensorlogic_ir::TLExpr;
2
3/// Count how many times `var` appears free in `expr`.
4///
5/// A variable occurrence is:
6/// - A `Pred { name, args: [] }` node whose `name == var` (zero-arg
7///   predicate used as a variable reference in let bodies), OR
8/// - A `Term::Var(v)` inside `Pred` args where `v == var`.
9///
10/// The count respects capture: once a binder with the same name is
11/// entered, occurrences inside that scope are not counted as free.
12pub fn count_free_occurrences(var: &str, expr: &TLExpr) -> usize {
13    count_free_in(var, expr)
14}
15
16pub(crate) fn count_free_in(var: &str, expr: &TLExpr) -> usize {
17    match expr {
18        // A zero-argument predicate serves as a variable reference.
19        TLExpr::Pred { name, args } => {
20            if args.is_empty() && name == var {
21                1
22            } else {
23                // Count Term::Var occurrences in the argument list.
24                args.iter()
25                    .filter(|t| matches!(t, tensorlogic_ir::Term::Var(v) if v == var))
26                    .count()
27            }
28        }
29
30        // ── Binary nodes ─────────────────────────────────────────────────
31        TLExpr::And(l, r)
32        | TLExpr::Or(l, r)
33        | TLExpr::Imply(l, r)
34        | TLExpr::Add(l, r)
35        | TLExpr::Sub(l, r)
36        | TLExpr::Mul(l, r)
37        | TLExpr::Div(l, r)
38        | TLExpr::Pow(l, r)
39        | TLExpr::Mod(l, r)
40        | TLExpr::Min(l, r)
41        | TLExpr::Max(l, r)
42        | TLExpr::Eq(l, r)
43        | TLExpr::Lt(l, r)
44        | TLExpr::Gt(l, r)
45        | TLExpr::Lte(l, r)
46        | TLExpr::Gte(l, r) => count_free_in(var, l) + count_free_in(var, r),
47
48        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
49            count_free_in(var, left) + count_free_in(var, right)
50        }
51        TLExpr::FuzzyImplication {
52            premise,
53            conclusion,
54            ..
55        } => count_free_in(var, premise) + count_free_in(var, conclusion),
56
57        // ── Unary nodes ──────────────────────────────────────────────────
58        TLExpr::Not(e)
59        | TLExpr::Score(e)
60        | TLExpr::Abs(e)
61        | TLExpr::Floor(e)
62        | TLExpr::Ceil(e)
63        | TLExpr::Round(e)
64        | TLExpr::Sqrt(e)
65        | TLExpr::Exp(e)
66        | TLExpr::Log(e)
67        | TLExpr::Sin(e)
68        | TLExpr::Cos(e)
69        | TLExpr::Tan(e)
70        | TLExpr::Box(e)
71        | TLExpr::Diamond(e)
72        | TLExpr::Next(e)
73        | TLExpr::Eventually(e)
74        | TLExpr::Always(e) => count_free_in(var, e),
75
76        TLExpr::FuzzyNot { expr, .. } => count_free_in(var, expr),
77        TLExpr::WeightedRule { rule, .. } => count_free_in(var, rule),
78
79        TLExpr::Until { before, after }
80        | TLExpr::Release {
81            released: before,
82            releaser: after,
83        }
84        | TLExpr::WeakUntil { before, after }
85        | TLExpr::StrongRelease {
86            released: before,
87            releaser: after,
88        } => count_free_in(var, before) + count_free_in(var, after),
89
90        TLExpr::IfThenElse {
91            condition,
92            then_branch,
93            else_branch,
94        } => {
95            count_free_in(var, condition)
96                + count_free_in(var, then_branch)
97                + count_free_in(var, else_branch)
98        }
99
100        TLExpr::Apply { function, argument } => {
101            count_free_in(var, function) + count_free_in(var, argument)
102        }
103
104        // ── Binders — shadow the variable in the body ────────────────────
105        TLExpr::Exists {
106            var: binder, body, ..
107        }
108        | TLExpr::ForAll {
109            var: binder, body, ..
110        }
111        | TLExpr::SoftExists {
112            var: binder, body, ..
113        }
114        | TLExpr::SoftForAll {
115            var: binder, body, ..
116        }
117        | TLExpr::CountingExists {
118            var: binder, body, ..
119        }
120        | TLExpr::CountingForAll {
121            var: binder, body, ..
122        }
123        | TLExpr::ExactCount {
124            var: binder, body, ..
125        }
126        | TLExpr::Majority {
127            var: binder, body, ..
128        }
129        | TLExpr::LeastFixpoint { var: binder, body }
130        | TLExpr::GreatestFixpoint { var: binder, body } => {
131            if binder == var {
132                0
133            } else {
134                count_free_in(var, body)
135            }
136        }
137
138        TLExpr::Lambda {
139            var: binder, body, ..
140        } => {
141            if binder == var {
142                0
143            } else {
144                count_free_in(var, body)
145            }
146        }
147
148        TLExpr::Aggregate {
149            var: binder,
150            body,
151            group_by,
152            ..
153        } => {
154            let in_body = if binder == var {
155                0
156            } else {
157                count_free_in(var, body)
158            };
159            let in_group = group_by
160                .as_ref()
161                .map(|gs| gs.iter().filter(|g| g.as_str() == var).count())
162                .unwrap_or(0);
163            in_body + in_group
164        }
165
166        // For Let: value is in scope of outer env, body is in scope of
167        // the binding; if binder == var, occurrences in body are shadowed.
168        TLExpr::Let {
169            var: binder,
170            value,
171            body,
172        } => {
173            let in_value = count_free_in(var, value);
174            let in_body = if binder == var {
175                0
176            } else {
177                count_free_in(var, body)
178            };
179            in_value + in_body
180        }
181
182        TLExpr::SetComprehension {
183            var: binder,
184            condition,
185            ..
186        } => {
187            if binder == var {
188                0
189            } else {
190                count_free_in(var, condition)
191            }
192        }
193
194        TLExpr::SetMembership { element, set }
195        | TLExpr::SetUnion {
196            left: element,
197            right: set,
198        }
199        | TLExpr::SetIntersection {
200            left: element,
201            right: set,
202        }
203        | TLExpr::SetDifference {
204            left: element,
205            right: set,
206        } => count_free_in(var, element) + count_free_in(var, set),
207
208        TLExpr::SetCardinality { set } => count_free_in(var, set),
209
210        TLExpr::At { formula, .. } => count_free_in(var, formula),
211        TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
212            count_free_in(var, formula)
213        }
214        TLExpr::Explain { formula } => count_free_in(var, formula),
215
216        TLExpr::ProbabilisticChoice { alternatives } => alternatives
217            .iter()
218            .map(|(_, e)| count_free_in(var, e))
219            .sum(),
220
221        TLExpr::AllDifferent { variables } => {
222            variables.iter().filter(|v| v.as_str() == var).count()
223        }
224        TLExpr::GlobalCardinality {
225            variables, values, ..
226        } => {
227            let in_vars = variables.iter().filter(|v| v.as_str() == var).count();
228            let in_vals: usize = values.iter().map(|e| count_free_in(var, e)).sum();
229            in_vars + in_vals
230        }
231
232        // ── Leaves with no variable occurrences ──────────────────────────
233        TLExpr::Constant(_)
234        | TLExpr::EmptySet
235        | TLExpr::Nominal { .. }
236        | TLExpr::Abducible { .. }
237        | TLExpr::SymbolLiteral(_) => 0,
238
239        TLExpr::Match { scrutinee, arms } => {
240            count_free_in(var, scrutinee)
241                + arms
242                    .iter()
243                    .map(|(_, b)| count_free_in(var, b))
244                    .sum::<usize>()
245        }
246    }
247}
248
249/// Returns `true` if `expr` is a constant literal (`Constant(_)`).
250pub fn is_constant_binding(expr: &TLExpr) -> bool {
251    matches!(expr, TLExpr::Constant(_))
252}
253
254/// Returns `true` if `expr` is a zero-argument predicate (variable alias).
255pub fn is_var_binding(expr: &TLExpr) -> bool {
256    matches!(expr, TLExpr::Pred { args, .. } if args.is_empty())
257}
258
259/// Returns `true` if `expr` is a "simple" binding worth inlining regardless
260/// of use count: either a constant or a variable alias.
261pub fn is_simple_binding(expr: &TLExpr) -> bool {
262    is_constant_binding(expr) || is_var_binding(expr)
263}
264
265/// Compute the depth (height) of an expression tree.
266///
267/// Leaf nodes have depth 1; each internal node adds 1 to the maximum
268/// depth of its children.
269pub fn expr_depth(expr: &TLExpr) -> usize {
270    match expr {
271        // ── Binary nodes ─────────────────────────────────────────────────
272        TLExpr::And(l, r)
273        | TLExpr::Or(l, r)
274        | TLExpr::Imply(l, r)
275        | TLExpr::Add(l, r)
276        | TLExpr::Sub(l, r)
277        | TLExpr::Mul(l, r)
278        | TLExpr::Div(l, r)
279        | TLExpr::Pow(l, r)
280        | TLExpr::Mod(l, r)
281        | TLExpr::Min(l, r)
282        | TLExpr::Max(l, r)
283        | TLExpr::Eq(l, r)
284        | TLExpr::Lt(l, r)
285        | TLExpr::Gt(l, r)
286        | TLExpr::Lte(l, r)
287        | TLExpr::Gte(l, r) => 1 + expr_depth(l).max(expr_depth(r)),
288
289        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
290            1 + expr_depth(left).max(expr_depth(right))
291        }
292        TLExpr::FuzzyImplication {
293            premise,
294            conclusion,
295            ..
296        } => 1 + expr_depth(premise).max(expr_depth(conclusion)),
297
298        // ── Unary nodes ──────────────────────────────────────────────────
299        TLExpr::Not(e)
300        | TLExpr::Score(e)
301        | TLExpr::Abs(e)
302        | TLExpr::Floor(e)
303        | TLExpr::Ceil(e)
304        | TLExpr::Round(e)
305        | TLExpr::Sqrt(e)
306        | TLExpr::Exp(e)
307        | TLExpr::Log(e)
308        | TLExpr::Sin(e)
309        | TLExpr::Cos(e)
310        | TLExpr::Tan(e)
311        | TLExpr::Box(e)
312        | TLExpr::Diamond(e)
313        | TLExpr::Next(e)
314        | TLExpr::Eventually(e)
315        | TLExpr::Always(e) => 1 + expr_depth(e),
316
317        TLExpr::FuzzyNot { expr, .. } => 1 + expr_depth(expr),
318        TLExpr::WeightedRule { rule, .. } => 1 + expr_depth(rule),
319
320        TLExpr::Until { before, after }
321        | TLExpr::Release {
322            released: before,
323            releaser: after,
324        }
325        | TLExpr::WeakUntil { before, after }
326        | TLExpr::StrongRelease {
327            released: before,
328            releaser: after,
329        } => 1 + expr_depth(before).max(expr_depth(after)),
330
331        TLExpr::IfThenElse {
332            condition,
333            then_branch,
334            else_branch,
335        } => {
336            1 + expr_depth(condition)
337                .max(expr_depth(then_branch))
338                .max(expr_depth(else_branch))
339        }
340
341        TLExpr::Exists { body, .. }
342        | TLExpr::ForAll { body, .. }
343        | TLExpr::SoftExists { body, .. }
344        | TLExpr::SoftForAll { body, .. }
345        | TLExpr::Aggregate { body, .. }
346        | TLExpr::Lambda { body, .. }
347        | TLExpr::SetComprehension {
348            condition: body, ..
349        }
350        | TLExpr::CountingExists { body, .. }
351        | TLExpr::CountingForAll { body, .. }
352        | TLExpr::ExactCount { body, .. }
353        | TLExpr::Majority { body, .. }
354        | TLExpr::LeastFixpoint { body, .. }
355        | TLExpr::GreatestFixpoint { body, .. } => 1 + expr_depth(body),
356
357        TLExpr::Let { value, body, .. } => 1 + expr_depth(value).max(expr_depth(body)),
358
359        TLExpr::Apply { function, argument } => 1 + expr_depth(function).max(expr_depth(argument)),
360
361        TLExpr::SetMembership { element, set }
362        | TLExpr::SetUnion {
363            left: element,
364            right: set,
365        }
366        | TLExpr::SetIntersection {
367            left: element,
368            right: set,
369        }
370        | TLExpr::SetDifference {
371            left: element,
372            right: set,
373        } => 1 + expr_depth(element).max(expr_depth(set)),
374
375        TLExpr::SetCardinality { set } => 1 + expr_depth(set),
376
377        TLExpr::At { formula, .. } => 1 + expr_depth(formula),
378        TLExpr::Somewhere { formula }
379        | TLExpr::Everywhere { formula }
380        | TLExpr::Explain { formula } => 1 + expr_depth(formula),
381
382        TLExpr::ProbabilisticChoice { alternatives } => {
383            let max_depth = alternatives
384                .iter()
385                .map(|(_, e)| expr_depth(e))
386                .max()
387                .unwrap_or(0);
388            1 + max_depth
389        }
390
391        // ── Leaves ───────────────────────────────────────────────────────
392        TLExpr::Pred { .. }
393        | TLExpr::Constant(_)
394        | TLExpr::EmptySet
395        | TLExpr::AllDifferent { .. }
396        | TLExpr::GlobalCardinality { .. }
397        | TLExpr::Nominal { .. }
398        | TLExpr::Abducible { .. }
399        | TLExpr::SymbolLiteral(_) => 1,
400
401        TLExpr::Match { scrutinee, arms } => {
402            1 + expr_depth(scrutinee) + arms.iter().map(|(_, b)| expr_depth(b)).max().unwrap_or(0)
403        }
404    }
405}
406
407/// Count total nodes in an expression tree.
408pub fn count_nodes(expr: &TLExpr) -> u64 {
409    match expr {
410        TLExpr::And(l, r)
411        | TLExpr::Or(l, r)
412        | TLExpr::Imply(l, r)
413        | TLExpr::Add(l, r)
414        | TLExpr::Sub(l, r)
415        | TLExpr::Mul(l, r)
416        | TLExpr::Div(l, r)
417        | TLExpr::Pow(l, r)
418        | TLExpr::Mod(l, r)
419        | TLExpr::Min(l, r)
420        | TLExpr::Max(l, r)
421        | TLExpr::Eq(l, r)
422        | TLExpr::Lt(l, r)
423        | TLExpr::Gt(l, r)
424        | TLExpr::Lte(l, r)
425        | TLExpr::Gte(l, r) => 1 + count_nodes(l) + count_nodes(r),
426
427        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
428            1 + count_nodes(left) + count_nodes(right)
429        }
430        TLExpr::FuzzyImplication {
431            premise,
432            conclusion,
433            ..
434        } => 1 + count_nodes(premise) + count_nodes(conclusion),
435
436        TLExpr::Not(e)
437        | TLExpr::Score(e)
438        | TLExpr::Abs(e)
439        | TLExpr::Floor(e)
440        | TLExpr::Ceil(e)
441        | TLExpr::Round(e)
442        | TLExpr::Sqrt(e)
443        | TLExpr::Exp(e)
444        | TLExpr::Log(e)
445        | TLExpr::Sin(e)
446        | TLExpr::Cos(e)
447        | TLExpr::Tan(e)
448        | TLExpr::Box(e)
449        | TLExpr::Diamond(e)
450        | TLExpr::Next(e)
451        | TLExpr::Eventually(e)
452        | TLExpr::Always(e) => 1 + count_nodes(e),
453
454        TLExpr::FuzzyNot { expr, .. } => 1 + count_nodes(expr),
455        TLExpr::WeightedRule { rule, .. } => 1 + count_nodes(rule),
456
457        TLExpr::Until { before, after }
458        | TLExpr::Release {
459            released: before,
460            releaser: after,
461        }
462        | TLExpr::WeakUntil { before, after }
463        | TLExpr::StrongRelease {
464            released: before,
465            releaser: after,
466        } => 1 + count_nodes(before) + count_nodes(after),
467
468        TLExpr::IfThenElse {
469            condition,
470            then_branch,
471            else_branch,
472        } => 1 + count_nodes(condition) + count_nodes(then_branch) + count_nodes(else_branch),
473
474        TLExpr::Exists { body, .. }
475        | TLExpr::ForAll { body, .. }
476        | TLExpr::SoftExists { body, .. }
477        | TLExpr::SoftForAll { body, .. }
478        | TLExpr::Aggregate { body, .. }
479        | TLExpr::Lambda { body, .. }
480        | TLExpr::SetComprehension {
481            condition: body, ..
482        }
483        | TLExpr::CountingExists { body, .. }
484        | TLExpr::CountingForAll { body, .. }
485        | TLExpr::ExactCount { body, .. }
486        | TLExpr::Majority { body, .. }
487        | TLExpr::LeastFixpoint { body, .. }
488        | TLExpr::GreatestFixpoint { body, .. } => 1 + count_nodes(body),
489
490        TLExpr::Let { value, body, .. } => 1 + count_nodes(value) + count_nodes(body),
491
492        TLExpr::Apply { function, argument } => 1 + count_nodes(function) + count_nodes(argument),
493
494        TLExpr::SetMembership { element, set }
495        | TLExpr::SetUnion {
496            left: element,
497            right: set,
498        }
499        | TLExpr::SetIntersection {
500            left: element,
501            right: set,
502        }
503        | TLExpr::SetDifference {
504            left: element,
505            right: set,
506        } => 1 + count_nodes(element) + count_nodes(set),
507
508        TLExpr::SetCardinality { set } => 1 + count_nodes(set),
509
510        TLExpr::At { formula, .. } => 1 + count_nodes(formula),
511        TLExpr::Somewhere { formula }
512        | TLExpr::Everywhere { formula }
513        | TLExpr::Explain { formula } => 1 + count_nodes(formula),
514
515        TLExpr::ProbabilisticChoice { alternatives } => {
516            1 + alternatives
517                .iter()
518                .map(|(_, e)| count_nodes(e))
519                .sum::<u64>()
520        }
521
522        TLExpr::Pred { .. }
523        | TLExpr::Constant(_)
524        | TLExpr::EmptySet
525        | TLExpr::AllDifferent { .. }
526        | TLExpr::GlobalCardinality { .. }
527        | TLExpr::Nominal { .. }
528        | TLExpr::Abducible { .. }
529        | TLExpr::SymbolLiteral(_) => 1,
530
531        TLExpr::Match { scrutinee, arms } => {
532            1 + count_nodes(scrutinee) + arms.iter().map(|(_, b)| count_nodes(b)).sum::<u64>()
533        }
534    }
535}