Skip to main content

tensorlogic_compiler/dead_code/
node_count.rs

1//! Whole-tree node counting plus small recursion helpers.
2
3use tensorlogic_ir::TLExpr;
4
5use super::types::{DceStats, DeadCodeEliminator};
6
7impl DeadCodeEliminator {
8    /// Count the total number of nodes in an expression tree.
9    ///
10    /// Every node (leaf or internal) counts as 1.
11    pub fn count_nodes(expr: &TLExpr) -> u64 {
12        match expr {
13            TLExpr::And(l, r)
14            | TLExpr::Or(l, r)
15            | TLExpr::Imply(l, r)
16            | TLExpr::Add(l, r)
17            | TLExpr::Sub(l, r)
18            | TLExpr::Mul(l, r)
19            | TLExpr::Div(l, r)
20            | TLExpr::Pow(l, r)
21            | TLExpr::Mod(l, r)
22            | TLExpr::Min(l, r)
23            | TLExpr::Max(l, r)
24            | TLExpr::Eq(l, r)
25            | TLExpr::Lt(l, r)
26            | TLExpr::Gt(l, r)
27            | TLExpr::Lte(l, r)
28            | TLExpr::Gte(l, r) => 1 + Self::count_nodes(l) + Self::count_nodes(r),
29
30            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
31                1 + Self::count_nodes(left) + Self::count_nodes(right)
32            }
33            TLExpr::FuzzyImplication {
34                premise,
35                conclusion,
36                ..
37            } => 1 + Self::count_nodes(premise) + Self::count_nodes(conclusion),
38
39            TLExpr::Not(e)
40            | TLExpr::Score(e)
41            | TLExpr::Abs(e)
42            | TLExpr::Floor(e)
43            | TLExpr::Ceil(e)
44            | TLExpr::Round(e)
45            | TLExpr::Sqrt(e)
46            | TLExpr::Exp(e)
47            | TLExpr::Log(e)
48            | TLExpr::Sin(e)
49            | TLExpr::Cos(e)
50            | TLExpr::Tan(e)
51            | TLExpr::Box(e)
52            | TLExpr::Diamond(e)
53            | TLExpr::Next(e)
54            | TLExpr::Eventually(e)
55            | TLExpr::Always(e) => 1 + Self::count_nodes(e),
56
57            TLExpr::FuzzyNot { expr, .. } => 1 + Self::count_nodes(expr),
58            TLExpr::WeightedRule { rule, .. } => 1 + Self::count_nodes(rule),
59
60            TLExpr::Until { before, after }
61            | TLExpr::Release {
62                released: before,
63                releaser: after,
64            }
65            | TLExpr::WeakUntil { before, after }
66            | TLExpr::StrongRelease {
67                released: before,
68                releaser: after,
69            } => 1 + Self::count_nodes(before) + Self::count_nodes(after),
70
71            TLExpr::IfThenElse {
72                condition,
73                then_branch,
74                else_branch,
75            } => {
76                1 + Self::count_nodes(condition)
77                    + Self::count_nodes(then_branch)
78                    + Self::count_nodes(else_branch)
79            }
80
81            TLExpr::Exists { body, .. }
82            | TLExpr::ForAll { body, .. }
83            | TLExpr::SoftExists { body, .. }
84            | TLExpr::SoftForAll { body, .. }
85            | TLExpr::Aggregate { body, .. }
86            | TLExpr::Lambda { body, .. }
87            | TLExpr::SetComprehension {
88                condition: body, ..
89            }
90            | TLExpr::CountingExists { body, .. }
91            | TLExpr::CountingForAll { body, .. }
92            | TLExpr::ExactCount { body, .. }
93            | TLExpr::Majority { body, .. }
94            | TLExpr::LeastFixpoint { body, .. }
95            | TLExpr::GreatestFixpoint { body, .. } => 1 + Self::count_nodes(body),
96
97            TLExpr::Let { value, body, .. } => {
98                1 + Self::count_nodes(value) + Self::count_nodes(body)
99            }
100
101            TLExpr::Apply { function, argument } => {
102                1 + Self::count_nodes(function) + Self::count_nodes(argument)
103            }
104
105            TLExpr::SetMembership { element, set }
106            | TLExpr::SetUnion {
107                left: element,
108                right: set,
109            }
110            | TLExpr::SetIntersection {
111                left: element,
112                right: set,
113            }
114            | TLExpr::SetDifference {
115                left: element,
116                right: set,
117            } => 1 + Self::count_nodes(element) + Self::count_nodes(set),
118
119            TLExpr::SetCardinality { set } => 1 + Self::count_nodes(set),
120
121            TLExpr::At { formula, .. }
122            | TLExpr::Somewhere { formula }
123            | TLExpr::Everywhere { formula }
124            | TLExpr::Explain { formula } => 1 + Self::count_nodes(formula),
125
126            TLExpr::ProbabilisticChoice { alternatives } => {
127                1 + alternatives
128                    .iter()
129                    .map(|(_, e)| Self::count_nodes(e))
130                    .sum::<u64>()
131            }
132
133            TLExpr::GlobalCardinality { values, .. } => {
134                1 + values.iter().map(Self::count_nodes).sum::<u64>()
135            }
136
137            TLExpr::Pred { .. }
138            | TLExpr::Constant(_)
139            | TLExpr::EmptySet
140            | TLExpr::AllDifferent { .. }
141            | TLExpr::Nominal { .. }
142            | TLExpr::Abducible { .. }
143            | TLExpr::SymbolLiteral(_) => 1,
144
145            TLExpr::Match { scrutinee, arms } => {
146                1 + Self::count_nodes(scrutinee)
147                    + arms.iter().map(|(_, b)| Self::count_nodes(b)).sum::<u64>()
148            }
149        }
150    }
151
152    /// Recurse into the single child of a unary constructor and reconstruct.
153    pub(super) fn map_unary<F>(
154        &self,
155        ctor: F,
156        child: TLExpr,
157        stats: &mut DceStats,
158    ) -> (TLExpr, bool)
159    where
160        F: Fn(Box<TLExpr>) -> TLExpr,
161    {
162        let (new_child, changed) = self.eliminate(child, stats);
163        (ctor(Box::new(new_child)), changed)
164    }
165
166    /// Recurse into both children of a binary constructor and reconstruct.
167    pub(super) fn map_binary<F>(
168        &self,
169        ctor: F,
170        left: TLExpr,
171        right: TLExpr,
172        stats: &mut DceStats,
173    ) -> (TLExpr, bool)
174    where
175        F: Fn(Box<TLExpr>, Box<TLExpr>) -> TLExpr,
176    {
177        let (nl, cl) = self.eliminate(left, stats);
178        let (nr, cr) = self.eliminate(right, stats);
179        (ctor(Box::new(nl), Box::new(nr)), cl || cr)
180    }
181}