Skip to main content

tensorlogic_ir/expr/
analysis.rs

1//! Expression analysis (free variables, predicate collection).
2
3use std::collections::{HashMap, HashSet};
4
5use crate::term::Term;
6
7use super::TLExpr;
8
9impl TLExpr {
10    /// Collect all free variables in this expression
11    pub fn free_vars(&self) -> HashSet<String> {
12        let mut vars = HashSet::new();
13        self.collect_free_vars(&mut vars, &HashSet::new());
14        vars
15    }
16
17    pub(crate) fn collect_free_vars(&self, vars: &mut HashSet<String>, bound: &HashSet<String>) {
18        match self {
19            TLExpr::Pred { args, .. } => {
20                for arg in args {
21                    if let Term::Var(v) = arg {
22                        if !bound.contains(v) {
23                            vars.insert(v.clone());
24                        }
25                    }
26                }
27            }
28            TLExpr::And(l, r)
29            | TLExpr::Or(l, r)
30            | TLExpr::Imply(l, r)
31            | TLExpr::Add(l, r)
32            | TLExpr::Sub(l, r)
33            | TLExpr::Mul(l, r)
34            | TLExpr::Div(l, r)
35            | TLExpr::Pow(l, r)
36            | TLExpr::Mod(l, r)
37            | TLExpr::Min(l, r)
38            | TLExpr::Max(l, r)
39            | TLExpr::Eq(l, r)
40            | TLExpr::Lt(l, r)
41            | TLExpr::Gt(l, r)
42            | TLExpr::Lte(l, r)
43            | TLExpr::Gte(l, r) => {
44                l.collect_free_vars(vars, bound);
45                r.collect_free_vars(vars, bound);
46            }
47            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
48                left.collect_free_vars(vars, bound);
49                right.collect_free_vars(vars, bound);
50            }
51            TLExpr::FuzzyImplication {
52                premise,
53                conclusion,
54                ..
55            } => {
56                premise.collect_free_vars(vars, bound);
57                conclusion.collect_free_vars(vars, bound);
58            }
59            TLExpr::Not(e)
60            | TLExpr::Score(e)
61            | TLExpr::Abs(e)
62            | TLExpr::Floor(e)
63            | TLExpr::Ceil(e)
64            | TLExpr::Round(e)
65            | TLExpr::Sqrt(e)
66            | TLExpr::Exp(e)
67            | TLExpr::Log(e)
68            | TLExpr::Sin(e)
69            | TLExpr::Cos(e)
70            | TLExpr::Tan(e)
71            | TLExpr::Box(e)
72            | TLExpr::Diamond(e)
73            | TLExpr::Next(e)
74            | TLExpr::Eventually(e)
75            | TLExpr::Always(e) => {
76                e.collect_free_vars(vars, bound);
77            }
78            TLExpr::FuzzyNot { expr, .. } => {
79                expr.collect_free_vars(vars, bound);
80            }
81            TLExpr::WeightedRule { rule, .. } => {
82                rule.collect_free_vars(vars, bound);
83            }
84            TLExpr::Until { before, after }
85            | TLExpr::Release {
86                released: before,
87                releaser: after,
88            }
89            | TLExpr::WeakUntil { before, after }
90            | TLExpr::StrongRelease {
91                released: before,
92                releaser: after,
93            } => {
94                before.collect_free_vars(vars, bound);
95                after.collect_free_vars(vars, bound);
96            }
97            TLExpr::Exists { var, body, .. }
98            | TLExpr::ForAll { var, body, .. }
99            | TLExpr::SoftExists { var, body, .. }
100            | TLExpr::SoftForAll { var, body, .. } => {
101                let mut new_bound = bound.clone();
102                new_bound.insert(var.clone());
103                body.collect_free_vars(vars, &new_bound);
104            }
105            TLExpr::Aggregate {
106                var,
107                body,
108                group_by,
109                ..
110            } => {
111                let mut new_bound = bound.clone();
112                new_bound.insert(var.clone());
113                body.collect_free_vars(vars, &new_bound);
114                // Group-by variables are free if not already bound
115                if let Some(group_vars) = group_by {
116                    for gv in group_vars {
117                        if !bound.contains(gv) {
118                            vars.insert(gv.clone());
119                        }
120                    }
121                }
122            }
123            TLExpr::IfThenElse {
124                condition,
125                then_branch,
126                else_branch,
127            } => {
128                condition.collect_free_vars(vars, bound);
129                then_branch.collect_free_vars(vars, bound);
130                else_branch.collect_free_vars(vars, bound);
131            }
132            TLExpr::Let { var, value, body } => {
133                // First collect free vars from the value expression
134                value.collect_free_vars(vars, bound);
135                // Then collect from body with the variable bound
136                let mut new_bound = bound.clone();
137                new_bound.insert(var.clone());
138                body.collect_free_vars(vars, &new_bound);
139            }
140            TLExpr::Constant(_) => {
141                // No free variables in constants
142            }
143            TLExpr::ProbabilisticChoice { alternatives } => {
144                for (_, expr) in alternatives {
145                    expr.collect_free_vars(vars, bound);
146                }
147            }
148            // Beta.1 enhancements
149            TLExpr::Lambda { var, body, .. } => {
150                // Lambda binds the variable
151                let mut new_bound = bound.clone();
152                new_bound.insert(var.clone());
153                body.collect_free_vars(vars, &new_bound);
154            }
155            TLExpr::Apply { function, argument } => {
156                function.collect_free_vars(vars, bound);
157                argument.collect_free_vars(vars, bound);
158            }
159            TLExpr::SetMembership { element, set }
160            | TLExpr::SetUnion {
161                left: element,
162                right: set,
163            }
164            | TLExpr::SetIntersection {
165                left: element,
166                right: set,
167            }
168            | TLExpr::SetDifference {
169                left: element,
170                right: set,
171            } => {
172                element.collect_free_vars(vars, bound);
173                set.collect_free_vars(vars, bound);
174            }
175            TLExpr::SetCardinality { set } => {
176                set.collect_free_vars(vars, bound);
177            }
178            TLExpr::EmptySet => {
179                // No free variables
180            }
181            TLExpr::SetComprehension { var, condition, .. } => {
182                // Set comprehension binds the variable
183                let mut new_bound = bound.clone();
184                new_bound.insert(var.clone());
185                condition.collect_free_vars(vars, &new_bound);
186            }
187            TLExpr::CountingExists { var, body, .. }
188            | TLExpr::CountingForAll { var, body, .. }
189            | TLExpr::ExactCount { var, body, .. }
190            | TLExpr::Majority { var, body, .. } => {
191                // Counting quantifiers bind the variable
192                let mut new_bound = bound.clone();
193                new_bound.insert(var.clone());
194                body.collect_free_vars(vars, &new_bound);
195            }
196            TLExpr::LeastFixpoint { var, body } | TLExpr::GreatestFixpoint { var, body } => {
197                // Fixed-point operators bind the variable
198                let mut new_bound = bound.clone();
199                new_bound.insert(var.clone());
200                body.collect_free_vars(vars, &new_bound);
201            }
202            TLExpr::Nominal { .. } => {
203                // No free variables
204            }
205            TLExpr::At { formula, .. } => {
206                formula.collect_free_vars(vars, bound);
207            }
208            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
209                formula.collect_free_vars(vars, bound);
210            }
211            TLExpr::AllDifferent { variables } => {
212                // Variables in constraint are free if not bound
213                for v in variables {
214                    if !bound.contains(v) {
215                        vars.insert(v.clone());
216                    }
217                }
218            }
219            TLExpr::GlobalCardinality {
220                variables, values, ..
221            } => {
222                // Variables are free if not bound
223                for v in variables {
224                    if !bound.contains(v) {
225                        vars.insert(v.clone());
226                    }
227                }
228                // Collect from value expressions
229                for val in values {
230                    val.collect_free_vars(vars, bound);
231                }
232            }
233            TLExpr::Abducible { .. } => {
234                // No free variables (it's a literal)
235            }
236            TLExpr::Explain { formula } => {
237                formula.collect_free_vars(vars, bound);
238            }
239            TLExpr::SymbolLiteral(_) => {
240                // No free variables — a symbol literal is a constant
241            }
242            TLExpr::Match { scrutinee, arms } => {
243                scrutinee.collect_free_vars(vars, bound);
244                for (_, body) in arms {
245                    body.collect_free_vars(vars, bound);
246                }
247            }
248        }
249    }
250
251    /// Collect all predicates and their arities
252    pub fn all_predicates(&self) -> HashMap<String, usize> {
253        let mut preds = HashMap::new();
254        self.collect_predicates(&mut preds);
255        preds
256    }
257
258    pub(crate) fn collect_predicates(&self, preds: &mut HashMap<String, usize>) {
259        match self {
260            TLExpr::Pred { name, args } => {
261                preds.entry(name.clone()).or_insert(args.len());
262            }
263            TLExpr::And(l, r)
264            | TLExpr::Or(l, r)
265            | TLExpr::Imply(l, r)
266            | TLExpr::Add(l, r)
267            | TLExpr::Sub(l, r)
268            | TLExpr::Mul(l, r)
269            | TLExpr::Div(l, r)
270            | TLExpr::Pow(l, r)
271            | TLExpr::Mod(l, r)
272            | TLExpr::Min(l, r)
273            | TLExpr::Max(l, r)
274            | TLExpr::Eq(l, r)
275            | TLExpr::Lt(l, r)
276            | TLExpr::Gt(l, r)
277            | TLExpr::Lte(l, r)
278            | TLExpr::Gte(l, r) => {
279                l.collect_predicates(preds);
280                r.collect_predicates(preds);
281            }
282            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
283                left.collect_predicates(preds);
284                right.collect_predicates(preds);
285            }
286            TLExpr::FuzzyImplication {
287                premise,
288                conclusion,
289                ..
290            } => {
291                premise.collect_predicates(preds);
292                conclusion.collect_predicates(preds);
293            }
294            TLExpr::Not(e)
295            | TLExpr::Score(e)
296            | TLExpr::Abs(e)
297            | TLExpr::Floor(e)
298            | TLExpr::Ceil(e)
299            | TLExpr::Round(e)
300            | TLExpr::Sqrt(e)
301            | TLExpr::Exp(e)
302            | TLExpr::Log(e)
303            | TLExpr::Sin(e)
304            | TLExpr::Cos(e)
305            | TLExpr::Tan(e)
306            | TLExpr::Box(e)
307            | TLExpr::Diamond(e)
308            | TLExpr::Next(e)
309            | TLExpr::Eventually(e)
310            | TLExpr::Always(e) => {
311                e.collect_predicates(preds);
312            }
313            TLExpr::FuzzyNot { expr, .. } => {
314                expr.collect_predicates(preds);
315            }
316            TLExpr::WeightedRule { rule, .. } => {
317                rule.collect_predicates(preds);
318            }
319            TLExpr::Until { before, after }
320            | TLExpr::Release {
321                released: before,
322                releaser: after,
323            }
324            | TLExpr::WeakUntil { before, after }
325            | TLExpr::StrongRelease {
326                released: before,
327                releaser: after,
328            } => {
329                before.collect_predicates(preds);
330                after.collect_predicates(preds);
331            }
332            TLExpr::Exists { body, .. }
333            | TLExpr::ForAll { body, .. }
334            | TLExpr::SoftExists { body, .. }
335            | TLExpr::SoftForAll { body, .. } => {
336                body.collect_predicates(preds);
337            }
338            TLExpr::Aggregate { body, .. } => {
339                body.collect_predicates(preds);
340            }
341            TLExpr::IfThenElse {
342                condition,
343                then_branch,
344                else_branch,
345            } => {
346                condition.collect_predicates(preds);
347                then_branch.collect_predicates(preds);
348                else_branch.collect_predicates(preds);
349            }
350            TLExpr::Let { value, body, .. } => {
351                value.collect_predicates(preds);
352                body.collect_predicates(preds);
353            }
354            TLExpr::Constant(_) => {
355                // No predicates in constants
356            }
357            TLExpr::ProbabilisticChoice { alternatives } => {
358                for (_, expr) in alternatives {
359                    expr.collect_predicates(preds);
360                }
361            }
362            // Beta.1 enhancements
363            TLExpr::Lambda { body, .. } => {
364                body.collect_predicates(preds);
365            }
366            TLExpr::Apply { function, argument } => {
367                function.collect_predicates(preds);
368                argument.collect_predicates(preds);
369            }
370            TLExpr::SetMembership { element, set }
371            | TLExpr::SetUnion {
372                left: element,
373                right: set,
374            }
375            | TLExpr::SetIntersection {
376                left: element,
377                right: set,
378            }
379            | TLExpr::SetDifference {
380                left: element,
381                right: set,
382            } => {
383                element.collect_predicates(preds);
384                set.collect_predicates(preds);
385            }
386            TLExpr::SetCardinality { set } => {
387                set.collect_predicates(preds);
388            }
389            TLExpr::EmptySet => {
390                // No predicates
391            }
392            TLExpr::SetComprehension { condition, .. } => {
393                condition.collect_predicates(preds);
394            }
395            TLExpr::CountingExists { body, .. }
396            | TLExpr::CountingForAll { body, .. }
397            | TLExpr::ExactCount { body, .. }
398            | TLExpr::Majority { body, .. } => {
399                body.collect_predicates(preds);
400            }
401            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
402                body.collect_predicates(preds);
403            }
404            TLExpr::Nominal { .. } => {
405                // No predicates
406            }
407            TLExpr::At { formula, .. } => {
408                formula.collect_predicates(preds);
409            }
410            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
411                formula.collect_predicates(preds);
412            }
413            TLExpr::AllDifferent { .. } => {
414                // No predicates (constraint on variables)
415            }
416            TLExpr::GlobalCardinality { values, .. } => {
417                // Collect from value expressions
418                for val in values {
419                    val.collect_predicates(preds);
420                }
421            }
422            TLExpr::Abducible { .. } => {
423                // No predicates
424            }
425            TLExpr::Explain { formula } => {
426                formula.collect_predicates(preds);
427            }
428            TLExpr::SymbolLiteral(_) => {
429                // No predicates — a symbol literal is a constant
430            }
431            TLExpr::Match { scrutinee, arms } => {
432                scrutinee.collect_predicates(preds);
433                for (_, body) in arms {
434                    body.collect_predicates(preds);
435                }
436            }
437        }
438    }
439}