Skip to main content

tensorlogic_ir/expr/
analysis.rs

1//! Expression analysis (free variables, predicate collection).
2
3use std::collections::{HashMap, HashSet};
4
5use crate::term::Term;
6
7use super::TLExpr;
8
9impl TLExpr {
10    /// Collect all free variables in this expression
11    pub fn free_vars(&self) -> HashSet<String> {
12        let mut vars = HashSet::new();
13        self.collect_free_vars(&mut vars, &HashSet::new());
14        vars
15    }
16
17    pub(crate) fn collect_free_vars(&self, vars: &mut HashSet<String>, bound: &HashSet<String>) {
18        match self {
19            TLExpr::Pred { args, .. } => {
20                for arg in args {
21                    if let Term::Var(v) = arg {
22                        if !bound.contains(v) {
23                            vars.insert(v.clone());
24                        }
25                    }
26                }
27            }
28            TLExpr::And(l, r)
29            | TLExpr::Or(l, r)
30            | TLExpr::Imply(l, r)
31            | TLExpr::Add(l, r)
32            | TLExpr::Sub(l, r)
33            | TLExpr::Mul(l, r)
34            | TLExpr::Div(l, r)
35            | TLExpr::Pow(l, r)
36            | TLExpr::Mod(l, r)
37            | TLExpr::Min(l, r)
38            | TLExpr::Max(l, r)
39            | TLExpr::Eq(l, r)
40            | TLExpr::Lt(l, r)
41            | TLExpr::Gt(l, r)
42            | TLExpr::Lte(l, r)
43            | TLExpr::Gte(l, r) => {
44                l.collect_free_vars(vars, bound);
45                r.collect_free_vars(vars, bound);
46            }
47            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
48                left.collect_free_vars(vars, bound);
49                right.collect_free_vars(vars, bound);
50            }
51            TLExpr::FuzzyImplication {
52                premise,
53                conclusion,
54                ..
55            } => {
56                premise.collect_free_vars(vars, bound);
57                conclusion.collect_free_vars(vars, bound);
58            }
59            TLExpr::Not(e)
60            | TLExpr::Score(e)
61            | TLExpr::Abs(e)
62            | TLExpr::Floor(e)
63            | TLExpr::Ceil(e)
64            | TLExpr::Round(e)
65            | TLExpr::Sqrt(e)
66            | TLExpr::Exp(e)
67            | TLExpr::Log(e)
68            | TLExpr::Sin(e)
69            | TLExpr::Cos(e)
70            | TLExpr::Tan(e)
71            | TLExpr::Box(e)
72            | TLExpr::Diamond(e)
73            | TLExpr::Next(e)
74            | TLExpr::Eventually(e)
75            | TLExpr::Always(e) => {
76                e.collect_free_vars(vars, bound);
77            }
78            TLExpr::FuzzyNot { expr, .. } => {
79                expr.collect_free_vars(vars, bound);
80            }
81            TLExpr::WeightedRule { rule, .. } => {
82                rule.collect_free_vars(vars, bound);
83            }
84            TLExpr::Until { before, after }
85            | TLExpr::Release {
86                released: before,
87                releaser: after,
88            }
89            | TLExpr::WeakUntil { before, after }
90            | TLExpr::StrongRelease {
91                released: before,
92                releaser: after,
93            } => {
94                before.collect_free_vars(vars, bound);
95                after.collect_free_vars(vars, bound);
96            }
97            TLExpr::Exists { var, body, .. }
98            | TLExpr::ForAll { var, body, .. }
99            | TLExpr::SoftExists { var, body, .. }
100            | TLExpr::SoftForAll { var, body, .. } => {
101                let mut new_bound = bound.clone();
102                new_bound.insert(var.clone());
103                body.collect_free_vars(vars, &new_bound);
104            }
105            TLExpr::Aggregate {
106                var,
107                body,
108                group_by,
109                ..
110            } => {
111                let mut new_bound = bound.clone();
112                new_bound.insert(var.clone());
113                body.collect_free_vars(vars, &new_bound);
114                // Group-by variables are free if not already bound
115                if let Some(group_vars) = group_by {
116                    for gv in group_vars {
117                        if !bound.contains(gv) {
118                            vars.insert(gv.clone());
119                        }
120                    }
121                }
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                condition.collect_free_vars(vars, bound);
129                then_branch.collect_free_vars(vars, bound);
130                else_branch.collect_free_vars(vars, bound);
131            }
132            TLExpr::Let { var, value, body } => {
133                // First collect free vars from the value expression
134                value.collect_free_vars(vars, bound);
135                // Then collect from body with the variable bound
136                let mut new_bound = bound.clone();
137                new_bound.insert(var.clone());
138                body.collect_free_vars(vars, &new_bound);
139            }
140            TLExpr::Constant(_) => {
141                // No free variables in constants
142            }
143            TLExpr::ProbabilisticChoice { alternatives } => {
144                for (_, expr) in alternatives {
145                    expr.collect_free_vars(vars, bound);
146                }
147            }
148            // Beta.1 enhancements
149            TLExpr::Lambda { var, body, .. } => {
150                // Lambda binds the variable
151                let mut new_bound = bound.clone();
152                new_bound.insert(var.clone());
153                body.collect_free_vars(vars, &new_bound);
154            }
155            TLExpr::Apply { function, argument } => {
156                function.collect_free_vars(vars, bound);
157                argument.collect_free_vars(vars, bound);
158            }
159            TLExpr::SetMembership { element, set }
160            | TLExpr::SetUnion {
161                left: element,
162                right: set,
163            }
164            | TLExpr::SetIntersection {
165                left: element,
166                right: set,
167            }
168            | TLExpr::SetDifference {
169                left: element,
170                right: set,
171            } => {
172                element.collect_free_vars(vars, bound);
173                set.collect_free_vars(vars, bound);
174            }
175            TLExpr::SetCardinality { set } => {
176                set.collect_free_vars(vars, bound);
177            }
178            TLExpr::EmptySet => {
179                // No free variables
180            }
181            TLExpr::SetComprehension { var, condition, .. } => {
182                // Set comprehension binds the variable
183                let mut new_bound = bound.clone();
184                new_bound.insert(var.clone());
185                condition.collect_free_vars(vars, &new_bound);
186            }
187            TLExpr::CountingExists { var, body, .. }
188            | TLExpr::CountingForAll { var, body, .. }
189            | TLExpr::ExactCount { var, body, .. }
190            | TLExpr::Majority { var, body, .. } => {
191                // Counting quantifiers bind the variable
192                let mut new_bound = bound.clone();
193                new_bound.insert(var.clone());
194                body.collect_free_vars(vars, &new_bound);
195            }
196            TLExpr::LeastFixpoint { var, body } | TLExpr::GreatestFixpoint { var, body } => {
197                // Fixed-point operators bind the variable
198                let mut new_bound = bound.clone();
199                new_bound.insert(var.clone());
200                body.collect_free_vars(vars, &new_bound);
201            }
202            TLExpr::Nominal { .. } => {
203                // No free variables
204            }
205            TLExpr::At { formula, .. } => {
206                formula.collect_free_vars(vars, bound);
207            }
208            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
209                formula.collect_free_vars(vars, bound);
210            }
211            TLExpr::AllDifferent { variables } => {
212                // Variables in constraint are free if not bound
213                for v in variables {
214                    if !bound.contains(v) {
215                        vars.insert(v.clone());
216                    }
217                }
218            }
219            TLExpr::GlobalCardinality {
220                variables, values, ..
221            } => {
222                // Variables are free if not bound
223                for v in variables {
224                    if !bound.contains(v) {
225                        vars.insert(v.clone());
226                    }
227                }
228                // Collect from value expressions
229                for val in values {
230                    val.collect_free_vars(vars, bound);
231                }
232            }
233            TLExpr::Abducible { .. } => {
234                // No free variables (it's a literal)
235            }
236            TLExpr::Explain { formula } => {
237                formula.collect_free_vars(vars, bound);
238            }
239        }
240    }
241
242    /// Collect all predicates and their arities
243    pub fn all_predicates(&self) -> HashMap<String, usize> {
244        let mut preds = HashMap::new();
245        self.collect_predicates(&mut preds);
246        preds
247    }
248
249    pub(crate) fn collect_predicates(&self, preds: &mut HashMap<String, usize>) {
250        match self {
251            TLExpr::Pred { name, args } => {
252                preds.entry(name.clone()).or_insert(args.len());
253            }
254            TLExpr::And(l, r)
255            | TLExpr::Or(l, r)
256            | TLExpr::Imply(l, r)
257            | TLExpr::Add(l, r)
258            | TLExpr::Sub(l, r)
259            | TLExpr::Mul(l, r)
260            | TLExpr::Div(l, r)
261            | TLExpr::Pow(l, r)
262            | TLExpr::Mod(l, r)
263            | TLExpr::Min(l, r)
264            | TLExpr::Max(l, r)
265            | TLExpr::Eq(l, r)
266            | TLExpr::Lt(l, r)
267            | TLExpr::Gt(l, r)
268            | TLExpr::Lte(l, r)
269            | TLExpr::Gte(l, r) => {
270                l.collect_predicates(preds);
271                r.collect_predicates(preds);
272            }
273            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
274                left.collect_predicates(preds);
275                right.collect_predicates(preds);
276            }
277            TLExpr::FuzzyImplication {
278                premise,
279                conclusion,
280                ..
281            } => {
282                premise.collect_predicates(preds);
283                conclusion.collect_predicates(preds);
284            }
285            TLExpr::Not(e)
286            | TLExpr::Score(e)
287            | TLExpr::Abs(e)
288            | TLExpr::Floor(e)
289            | TLExpr::Ceil(e)
290            | TLExpr::Round(e)
291            | TLExpr::Sqrt(e)
292            | TLExpr::Exp(e)
293            | TLExpr::Log(e)
294            | TLExpr::Sin(e)
295            | TLExpr::Cos(e)
296            | TLExpr::Tan(e)
297            | TLExpr::Box(e)
298            | TLExpr::Diamond(e)
299            | TLExpr::Next(e)
300            | TLExpr::Eventually(e)
301            | TLExpr::Always(e) => {
302                e.collect_predicates(preds);
303            }
304            TLExpr::FuzzyNot { expr, .. } => {
305                expr.collect_predicates(preds);
306            }
307            TLExpr::WeightedRule { rule, .. } => {
308                rule.collect_predicates(preds);
309            }
310            TLExpr::Until { before, after }
311            | TLExpr::Release {
312                released: before,
313                releaser: after,
314            }
315            | TLExpr::WeakUntil { before, after }
316            | TLExpr::StrongRelease {
317                released: before,
318                releaser: after,
319            } => {
320                before.collect_predicates(preds);
321                after.collect_predicates(preds);
322            }
323            TLExpr::Exists { body, .. }
324            | TLExpr::ForAll { body, .. }
325            | TLExpr::SoftExists { body, .. }
326            | TLExpr::SoftForAll { body, .. } => {
327                body.collect_predicates(preds);
328            }
329            TLExpr::Aggregate { body, .. } => {
330                body.collect_predicates(preds);
331            }
332            TLExpr::IfThenElse {
333                condition,
334                then_branch,
335                else_branch,
336            } => {
337                condition.collect_predicates(preds);
338                then_branch.collect_predicates(preds);
339                else_branch.collect_predicates(preds);
340            }
341            TLExpr::Let { value, body, .. } => {
342                value.collect_predicates(preds);
343                body.collect_predicates(preds);
344            }
345            TLExpr::Constant(_) => {
346                // No predicates in constants
347            }
348            TLExpr::ProbabilisticChoice { alternatives } => {
349                for (_, expr) in alternatives {
350                    expr.collect_predicates(preds);
351                }
352            }
353            // Beta.1 enhancements
354            TLExpr::Lambda { body, .. } => {
355                body.collect_predicates(preds);
356            }
357            TLExpr::Apply { function, argument } => {
358                function.collect_predicates(preds);
359                argument.collect_predicates(preds);
360            }
361            TLExpr::SetMembership { element, set }
362            | TLExpr::SetUnion {
363                left: element,
364                right: set,
365            }
366            | TLExpr::SetIntersection {
367                left: element,
368                right: set,
369            }
370            | TLExpr::SetDifference {
371                left: element,
372                right: set,
373            } => {
374                element.collect_predicates(preds);
375                set.collect_predicates(preds);
376            }
377            TLExpr::SetCardinality { set } => {
378                set.collect_predicates(preds);
379            }
380            TLExpr::EmptySet => {
381                // No predicates
382            }
383            TLExpr::SetComprehension { condition, .. } => {
384                condition.collect_predicates(preds);
385            }
386            TLExpr::CountingExists { body, .. }
387            | TLExpr::CountingForAll { body, .. }
388            | TLExpr::ExactCount { body, .. }
389            | TLExpr::Majority { body, .. } => {
390                body.collect_predicates(preds);
391            }
392            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
393                body.collect_predicates(preds);
394            }
395            TLExpr::Nominal { .. } => {
396                // No predicates
397            }
398            TLExpr::At { formula, .. } => {
399                formula.collect_predicates(preds);
400            }
401            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
402                formula.collect_predicates(preds);
403            }
404            TLExpr::AllDifferent { .. } => {
405                // No predicates (constraint on variables)
406            }
407            TLExpr::GlobalCardinality { values, .. } => {
408                // Collect from value expressions
409                for val in values {
410                    val.collect_predicates(preds);
411                }
412            }
413            TLExpr::Abducible { .. } => {
414                // No predicates
415            }
416            TLExpr::Explain { formula } => {
417                formula.collect_predicates(preds);
418            }
419        }
420    }
421}