tensorlogic_ir/expr/
analysis.rs

1//! Expression analysis (free variables, predicate collection).
2
3use std::collections::{HashMap, HashSet};
4
5use crate::term::Term;
6
7use super::TLExpr;
8
9impl TLExpr {
10    /// Collect all free variables in this expression
11    pub fn free_vars(&self) -> HashSet<String> {
12        let mut vars = HashSet::new();
13        self.collect_free_vars(&mut vars, &HashSet::new());
14        vars
15    }
16
17    pub(crate) fn collect_free_vars(&self, vars: &mut HashSet<String>, bound: &HashSet<String>) {
18        match self {
19            TLExpr::Pred { args, .. } => {
20                for arg in args {
21                    if let Term::Var(v) = arg {
22                        if !bound.contains(v) {
23                            vars.insert(v.clone());
24                        }
25                    }
26                }
27            }
28            TLExpr::And(l, r)
29            | TLExpr::Or(l, r)
30            | TLExpr::Imply(l, r)
31            | TLExpr::Add(l, r)
32            | TLExpr::Sub(l, r)
33            | TLExpr::Mul(l, r)
34            | TLExpr::Div(l, r)
35            | TLExpr::Pow(l, r)
36            | TLExpr::Mod(l, r)
37            | TLExpr::Min(l, r)
38            | TLExpr::Max(l, r)
39            | TLExpr::Eq(l, r)
40            | TLExpr::Lt(l, r)
41            | TLExpr::Gt(l, r)
42            | TLExpr::Lte(l, r)
43            | TLExpr::Gte(l, r) => {
44                l.collect_free_vars(vars, bound);
45                r.collect_free_vars(vars, bound);
46            }
47            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
48                left.collect_free_vars(vars, bound);
49                right.collect_free_vars(vars, bound);
50            }
51            TLExpr::FuzzyImplication {
52                premise,
53                conclusion,
54                ..
55            } => {
56                premise.collect_free_vars(vars, bound);
57                conclusion.collect_free_vars(vars, bound);
58            }
59            TLExpr::Not(e)
60            | TLExpr::Score(e)
61            | TLExpr::Abs(e)
62            | TLExpr::Floor(e)
63            | TLExpr::Ceil(e)
64            | TLExpr::Round(e)
65            | TLExpr::Sqrt(e)
66            | TLExpr::Exp(e)
67            | TLExpr::Log(e)
68            | TLExpr::Sin(e)
69            | TLExpr::Cos(e)
70            | TLExpr::Tan(e)
71            | TLExpr::Box(e)
72            | TLExpr::Diamond(e)
73            | TLExpr::Next(e)
74            | TLExpr::Eventually(e)
75            | TLExpr::Always(e) => {
76                e.collect_free_vars(vars, bound);
77            }
78            TLExpr::FuzzyNot { expr, .. } => {
79                expr.collect_free_vars(vars, bound);
80            }
81            TLExpr::WeightedRule { rule, .. } => {
82                rule.collect_free_vars(vars, bound);
83            }
84            TLExpr::Until { before, after }
85            | TLExpr::Release {
86                released: before,
87                releaser: after,
88            }
89            | TLExpr::WeakUntil { before, after }
90            | TLExpr::StrongRelease {
91                released: before,
92                releaser: after,
93            } => {
94                before.collect_free_vars(vars, bound);
95                after.collect_free_vars(vars, bound);
96            }
97            TLExpr::Exists { var, body, .. }
98            | TLExpr::ForAll { var, body, .. }
99            | TLExpr::SoftExists { var, body, .. }
100            | TLExpr::SoftForAll { var, body, .. } => {
101                let mut new_bound = bound.clone();
102                new_bound.insert(var.clone());
103                body.collect_free_vars(vars, &new_bound);
104            }
105            TLExpr::Aggregate {
106                var,
107                body,
108                group_by,
109                ..
110            } => {
111                let mut new_bound = bound.clone();
112                new_bound.insert(var.clone());
113                body.collect_free_vars(vars, &new_bound);
114                // Group-by variables are free if not already bound
115                if let Some(group_vars) = group_by {
116                    for gv in group_vars {
117                        if !bound.contains(gv) {
118                            vars.insert(gv.clone());
119                        }
120                    }
121                }
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                condition.collect_free_vars(vars, bound);
129                then_branch.collect_free_vars(vars, bound);
130                else_branch.collect_free_vars(vars, bound);
131            }
132            TLExpr::Let { var, value, body } => {
133                // First collect free vars from the value expression
134                value.collect_free_vars(vars, bound);
135                // Then collect from body with the variable bound
136                let mut new_bound = bound.clone();
137                new_bound.insert(var.clone());
138                body.collect_free_vars(vars, &new_bound);
139            }
140            TLExpr::Constant(_) => {
141                // No free variables in constants
142            }
143            TLExpr::ProbabilisticChoice { alternatives } => {
144                for (_, expr) in alternatives {
145                    expr.collect_free_vars(vars, bound);
146                }
147            }
148        }
149    }
150
151    /// Collect all predicates and their arities
152    pub fn all_predicates(&self) -> HashMap<String, usize> {
153        let mut preds = HashMap::new();
154        self.collect_predicates(&mut preds);
155        preds
156    }
157
158    pub(crate) fn collect_predicates(&self, preds: &mut HashMap<String, usize>) {
159        match self {
160            TLExpr::Pred { name, args } => {
161                preds.entry(name.clone()).or_insert(args.len());
162            }
163            TLExpr::And(l, r)
164            | TLExpr::Or(l, r)
165            | TLExpr::Imply(l, r)
166            | TLExpr::Add(l, r)
167            | TLExpr::Sub(l, r)
168            | TLExpr::Mul(l, r)
169            | TLExpr::Div(l, r)
170            | TLExpr::Pow(l, r)
171            | TLExpr::Mod(l, r)
172            | TLExpr::Min(l, r)
173            | TLExpr::Max(l, r)
174            | TLExpr::Eq(l, r)
175            | TLExpr::Lt(l, r)
176            | TLExpr::Gt(l, r)
177            | TLExpr::Lte(l, r)
178            | TLExpr::Gte(l, r) => {
179                l.collect_predicates(preds);
180                r.collect_predicates(preds);
181            }
182            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
183                left.collect_predicates(preds);
184                right.collect_predicates(preds);
185            }
186            TLExpr::FuzzyImplication {
187                premise,
188                conclusion,
189                ..
190            } => {
191                premise.collect_predicates(preds);
192                conclusion.collect_predicates(preds);
193            }
194            TLExpr::Not(e)
195            | TLExpr::Score(e)
196            | TLExpr::Abs(e)
197            | TLExpr::Floor(e)
198            | TLExpr::Ceil(e)
199            | TLExpr::Round(e)
200            | TLExpr::Sqrt(e)
201            | TLExpr::Exp(e)
202            | TLExpr::Log(e)
203            | TLExpr::Sin(e)
204            | TLExpr::Cos(e)
205            | TLExpr::Tan(e)
206            | TLExpr::Box(e)
207            | TLExpr::Diamond(e)
208            | TLExpr::Next(e)
209            | TLExpr::Eventually(e)
210            | TLExpr::Always(e) => {
211                e.collect_predicates(preds);
212            }
213            TLExpr::FuzzyNot { expr, .. } => {
214                expr.collect_predicates(preds);
215            }
216            TLExpr::WeightedRule { rule, .. } => {
217                rule.collect_predicates(preds);
218            }
219            TLExpr::Until { before, after }
220            | TLExpr::Release {
221                released: before,
222                releaser: after,
223            }
224            | TLExpr::WeakUntil { before, after }
225            | TLExpr::StrongRelease {
226                released: before,
227                releaser: after,
228            } => {
229                before.collect_predicates(preds);
230                after.collect_predicates(preds);
231            }
232            TLExpr::Exists { body, .. }
233            | TLExpr::ForAll { body, .. }
234            | TLExpr::SoftExists { body, .. }
235            | TLExpr::SoftForAll { body, .. } => {
236                body.collect_predicates(preds);
237            }
238            TLExpr::Aggregate { body, .. } => {
239                body.collect_predicates(preds);
240            }
241            TLExpr::IfThenElse {
242                condition,
243                then_branch,
244                else_branch,
245            } => {
246                condition.collect_predicates(preds);
247                then_branch.collect_predicates(preds);
248                else_branch.collect_predicates(preds);
249            }
250            TLExpr::Let { value, body, .. } => {
251                value.collect_predicates(preds);
252                body.collect_predicates(preds);
253            }
254            TLExpr::Constant(_) => {
255                // No predicates in constants
256            }
257            TLExpr::ProbabilisticChoice { alternatives } => {
258                for (_, expr) in alternatives {
259                    expr.collect_predicates(preds);
260                }
261            }
262        }
263    }
264}