Skip to main content

tensorlogic_ir/expr/optimization/
propagation.rs

1//! Constant propagation for Let bindings.
2//!
//! This module implements constant propagation: when a `Let` expression binds a
//! variable to a constant value, that constant is substituted for the variable
//! throughout the binding's body and the `Let` node itself is removed.
5
6use super::substitution::substitute;
7use crate::expr::TLExpr;
8
9pub fn propagate_constants(expr: &TLExpr) -> TLExpr {
10    match expr {
11        // If the Let binding value is a constant, substitute it into the body
12        TLExpr::Let { var, value, body } => {
13            let optimized_value = propagate_constants(value);
14            let optimized_body = propagate_constants(body);
15
16            // If the value is constant, substitute it
17            if matches!(optimized_value, TLExpr::Constant(_)) {
18                substitute(&optimized_body, var, &optimized_value)
19            } else {
20                TLExpr::Let {
21                    var: var.clone(),
22                    value: Box::new(optimized_value),
23                    body: Box::new(optimized_body),
24                }
25            }
26        }
27
28        // Recursively propagate in other expressions
29        TLExpr::And(l, r) => TLExpr::And(
30            Box::new(propagate_constants(l)),
31            Box::new(propagate_constants(r)),
32        ),
33        TLExpr::Or(l, r) => TLExpr::Or(
34            Box::new(propagate_constants(l)),
35            Box::new(propagate_constants(r)),
36        ),
37        TLExpr::Imply(l, r) => TLExpr::Imply(
38            Box::new(propagate_constants(l)),
39            Box::new(propagate_constants(r)),
40        ),
41        TLExpr::Add(l, r) => TLExpr::Add(
42            Box::new(propagate_constants(l)),
43            Box::new(propagate_constants(r)),
44        ),
45        TLExpr::Sub(l, r) => TLExpr::Sub(
46            Box::new(propagate_constants(l)),
47            Box::new(propagate_constants(r)),
48        ),
49        TLExpr::Mul(l, r) => TLExpr::Mul(
50            Box::new(propagate_constants(l)),
51            Box::new(propagate_constants(r)),
52        ),
53        TLExpr::Div(l, r) => TLExpr::Div(
54            Box::new(propagate_constants(l)),
55            Box::new(propagate_constants(r)),
56        ),
57        TLExpr::Pow(l, r) => TLExpr::Pow(
58            Box::new(propagate_constants(l)),
59            Box::new(propagate_constants(r)),
60        ),
61        TLExpr::Mod(l, r) => TLExpr::Mod(
62            Box::new(propagate_constants(l)),
63            Box::new(propagate_constants(r)),
64        ),
65        TLExpr::Min(l, r) => TLExpr::Min(
66            Box::new(propagate_constants(l)),
67            Box::new(propagate_constants(r)),
68        ),
69        TLExpr::Max(l, r) => TLExpr::Max(
70            Box::new(propagate_constants(l)),
71            Box::new(propagate_constants(r)),
72        ),
73        TLExpr::Eq(l, r) => TLExpr::Eq(
74            Box::new(propagate_constants(l)),
75            Box::new(propagate_constants(r)),
76        ),
77        TLExpr::Lt(l, r) => TLExpr::Lt(
78            Box::new(propagate_constants(l)),
79            Box::new(propagate_constants(r)),
80        ),
81        TLExpr::Gt(l, r) => TLExpr::Gt(
82            Box::new(propagate_constants(l)),
83            Box::new(propagate_constants(r)),
84        ),
85        TLExpr::Lte(l, r) => TLExpr::Lte(
86            Box::new(propagate_constants(l)),
87            Box::new(propagate_constants(r)),
88        ),
89        TLExpr::Gte(l, r) => TLExpr::Gte(
90            Box::new(propagate_constants(l)),
91            Box::new(propagate_constants(r)),
92        ),
93        TLExpr::Not(e) => TLExpr::Not(Box::new(propagate_constants(e))),
94        TLExpr::Score(e) => TLExpr::Score(Box::new(propagate_constants(e))),
95        TLExpr::Abs(e) => TLExpr::Abs(Box::new(propagate_constants(e))),
96        TLExpr::Floor(e) => TLExpr::Floor(Box::new(propagate_constants(e))),
97        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(propagate_constants(e))),
98        TLExpr::Round(e) => TLExpr::Round(Box::new(propagate_constants(e))),
99        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(propagate_constants(e))),
100        TLExpr::Exp(e) => TLExpr::Exp(Box::new(propagate_constants(e))),
101        TLExpr::Log(e) => TLExpr::Log(Box::new(propagate_constants(e))),
102        TLExpr::Sin(e) => TLExpr::Sin(Box::new(propagate_constants(e))),
103        TLExpr::Cos(e) => TLExpr::Cos(Box::new(propagate_constants(e))),
104        TLExpr::Tan(e) => TLExpr::Tan(Box::new(propagate_constants(e))),
105        TLExpr::Box(e) => TLExpr::Box(Box::new(propagate_constants(e))),
106        TLExpr::Diamond(e) => TLExpr::Diamond(Box::new(propagate_constants(e))),
107        TLExpr::Next(e) => TLExpr::Next(Box::new(propagate_constants(e))),
108        TLExpr::Eventually(e) => TLExpr::Eventually(Box::new(propagate_constants(e))),
109        TLExpr::Always(e) => TLExpr::Always(Box::new(propagate_constants(e))),
110        TLExpr::Until { before, after } => TLExpr::Until {
111            before: Box::new(propagate_constants(before)),
112            after: Box::new(propagate_constants(after)),
113        },
114
115        // Fuzzy logic operators
116        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
117            kind: *kind,
118            left: Box::new(propagate_constants(left)),
119            right: Box::new(propagate_constants(right)),
120        },
121        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
122            kind: *kind,
123            left: Box::new(propagate_constants(left)),
124            right: Box::new(propagate_constants(right)),
125        },
126        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
127            kind: *kind,
128            expr: Box::new(propagate_constants(expr)),
129        },
130        TLExpr::FuzzyImplication {
131            kind,
132            premise,
133            conclusion,
134        } => TLExpr::FuzzyImplication {
135            kind: *kind,
136            premise: Box::new(propagate_constants(premise)),
137            conclusion: Box::new(propagate_constants(conclusion)),
138        },
139
140        // Probabilistic operators
141        TLExpr::SoftExists {
142            var,
143            domain,
144            body,
145            temperature,
146        } => TLExpr::SoftExists {
147            var: var.clone(),
148            domain: domain.clone(),
149            body: Box::new(propagate_constants(body)),
150            temperature: *temperature,
151        },
152        TLExpr::SoftForAll {
153            var,
154            domain,
155            body,
156            temperature,
157        } => TLExpr::SoftForAll {
158            var: var.clone(),
159            domain: domain.clone(),
160            body: Box::new(propagate_constants(body)),
161            temperature: *temperature,
162        },
163        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
164            weight: *weight,
165            rule: Box::new(propagate_constants(rule)),
166        },
167        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
168            alternatives: alternatives
169                .iter()
170                .map(|(p, e)| (*p, propagate_constants(e)))
171                .collect(),
172        },
173
174        // Extended temporal logic
175        TLExpr::Release { released, releaser } => TLExpr::Release {
176            released: Box::new(propagate_constants(released)),
177            releaser: Box::new(propagate_constants(releaser)),
178        },
179        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
180            before: Box::new(propagate_constants(before)),
181            after: Box::new(propagate_constants(after)),
182        },
183        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
184            released: Box::new(propagate_constants(released)),
185            releaser: Box::new(propagate_constants(releaser)),
186        },
187
188        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
189            var: var.clone(),
190            domain: domain.clone(),
191            body: Box::new(propagate_constants(body)),
192        },
193        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
194            var: var.clone(),
195            domain: domain.clone(),
196            body: Box::new(propagate_constants(body)),
197        },
198        TLExpr::Aggregate {
199            op,
200            var,
201            domain,
202            body,
203            group_by,
204        } => TLExpr::Aggregate {
205            op: op.clone(),
206            var: var.clone(),
207            domain: domain.clone(),
208            body: Box::new(propagate_constants(body)),
209            group_by: group_by.clone(),
210        },
211        TLExpr::IfThenElse {
212            condition,
213            then_branch,
214            else_branch,
215        } => TLExpr::IfThenElse {
216            condition: Box::new(propagate_constants(condition)),
217            then_branch: Box::new(propagate_constants(then_branch)),
218            else_branch: Box::new(propagate_constants(else_branch)),
219        },
220
221        // Alpha.3 enhancements
222        TLExpr::Lambda {
223            var,
224            var_type,
225            body,
226        } => TLExpr::lambda(var.clone(), var_type.clone(), propagate_constants(body)),
227        TLExpr::Apply { function, argument } => {
228            TLExpr::apply(propagate_constants(function), propagate_constants(argument))
229        }
230        TLExpr::SetMembership { element, set } => {
231            TLExpr::set_membership(propagate_constants(element), propagate_constants(set))
232        }
233        TLExpr::SetUnion { left, right } => {
234            TLExpr::set_union(propagate_constants(left), propagate_constants(right))
235        }
236        TLExpr::SetIntersection { left, right } => {
237            TLExpr::set_intersection(propagate_constants(left), propagate_constants(right))
238        }
239        TLExpr::SetDifference { left, right } => {
240            TLExpr::set_difference(propagate_constants(left), propagate_constants(right))
241        }
242        TLExpr::SetCardinality { set } => TLExpr::set_cardinality(propagate_constants(set)),
243        TLExpr::EmptySet => expr.clone(),
244        TLExpr::SetComprehension {
245            var,
246            domain,
247            condition,
248        } => TLExpr::set_comprehension(var.clone(), domain.clone(), propagate_constants(condition)),
249        TLExpr::CountingExists {
250            var,
251            domain,
252            body,
253            min_count,
254        } => TLExpr::counting_exists(
255            var.clone(),
256            domain.clone(),
257            propagate_constants(body),
258            *min_count,
259        ),
260        TLExpr::CountingForAll {
261            var,
262            domain,
263            body,
264            min_count,
265        } => TLExpr::counting_forall(
266            var.clone(),
267            domain.clone(),
268            propagate_constants(body),
269            *min_count,
270        ),
271        TLExpr::ExactCount {
272            var,
273            domain,
274            body,
275            count,
276        } => TLExpr::exact_count(
277            var.clone(),
278            domain.clone(),
279            propagate_constants(body),
280            *count,
281        ),
282        TLExpr::Majority { var, domain, body } => {
283            TLExpr::majority(var.clone(), domain.clone(), propagate_constants(body))
284        }
285        TLExpr::LeastFixpoint { var, body } => {
286            TLExpr::least_fixpoint(var.clone(), propagate_constants(body))
287        }
288        TLExpr::GreatestFixpoint { var, body } => {
289            TLExpr::greatest_fixpoint(var.clone(), propagate_constants(body))
290        }
291        TLExpr::Nominal { .. } => expr.clone(),
292        TLExpr::At { nominal, formula } => {
293            TLExpr::at(nominal.clone(), propagate_constants(formula))
294        }
295        TLExpr::Somewhere { formula } => TLExpr::somewhere(propagate_constants(formula)),
296        TLExpr::Everywhere { formula } => TLExpr::everywhere(propagate_constants(formula)),
297        TLExpr::AllDifferent { .. } => expr.clone(),
298        TLExpr::GlobalCardinality {
299            variables,
300            values,
301            min_occurrences,
302            max_occurrences,
303        } => TLExpr::global_cardinality(
304            variables.clone(),
305            values.iter().map(propagate_constants).collect(),
306            min_occurrences.clone(),
307            max_occurrences.clone(),
308        ),
309        TLExpr::Abducible { .. } => expr.clone(),
310        TLExpr::Explain { formula } => TLExpr::explain(propagate_constants(formula)),
311
312        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
313    }
314}