Skip to main content

tensorlogic_compiler/passes/
cse.rs

//! Common Subexpression Elimination (CSE) for TLExpr.
2
3use std::collections::HashMap;
4
5use tensorlogic_ir::TLExpr;
6
7/// CSE result containing optimized expression and statistics
8#[derive(Debug, Clone)]
9pub struct CseResult {
10    pub optimized_expr: TLExpr,
11    pub eliminated_count: usize,
12}
13
14/// Perform common subexpression elimination on a TLExpr
15pub fn eliminate_common_subexpressions(expr: &TLExpr) -> CseResult {
16    let mut cache: HashMap<String, TLExpr> = HashMap::new();
17    let mut eliminated_count = 0;
18
19    let optimized = cse_recursive(expr, &mut cache, &mut eliminated_count);
20
21    CseResult {
22        optimized_expr: optimized,
23        eliminated_count,
24    }
25}
26
27fn cse_recursive(
28    expr: &TLExpr,
29    cache: &mut HashMap<String, TLExpr>,
30    eliminated_count: &mut usize,
31) -> TLExpr {
32    // Compute a hash/key for this expression
33    let key = expr_to_key(expr);
34
35    // Check if we've seen this exact subexpression before
36    if let Some(cached) = cache.get(&key) {
37        *eliminated_count += 1;
38        return cached.clone();
39    }
40
41    // Recursively process subexpressions
42    let result = match expr {
43        TLExpr::Pred { .. } => {
44            // Predicates are atomic, just cache them
45            expr.clone()
46        }
47        TLExpr::And(left, right) => {
48            let left_opt = cse_recursive(left, cache, eliminated_count);
49            let right_opt = cse_recursive(right, cache, eliminated_count);
50            TLExpr::and(left_opt, right_opt)
51        }
52        TLExpr::Or(left, right) => {
53            let left_opt = cse_recursive(left, cache, eliminated_count);
54            let right_opt = cse_recursive(right, cache, eliminated_count);
55            TLExpr::or(left_opt, right_opt)
56        }
57        TLExpr::Imply(premise, conclusion) => {
58            let premise_opt = cse_recursive(premise, cache, eliminated_count);
59            let conclusion_opt = cse_recursive(conclusion, cache, eliminated_count);
60            TLExpr::imply(premise_opt, conclusion_opt)
61        }
62        TLExpr::Not(inner) => {
63            let inner_opt = cse_recursive(inner, cache, eliminated_count);
64            TLExpr::negate(inner_opt)
65        }
66        TLExpr::Exists { var, domain, body } => {
67            let body_opt = cse_recursive(body, cache, eliminated_count);
68            TLExpr::exists(var, domain, body_opt)
69        }
70        TLExpr::ForAll { var, domain, body } => {
71            let body_opt = cse_recursive(body, cache, eliminated_count);
72            TLExpr::forall(var, domain, body_opt)
73        }
74        TLExpr::Aggregate {
75            op,
76            var,
77            domain,
78            body,
79            group_by,
80        } => {
81            let body_opt = cse_recursive(body, cache, eliminated_count);
82            TLExpr::aggregate_with_group_by(
83                op.clone(),
84                var,
85                domain,
86                body_opt,
87                group_by.clone().unwrap_or_default(),
88            )
89        }
90        TLExpr::Score(inner) => {
91            let inner_opt = cse_recursive(inner, cache, eliminated_count);
92            TLExpr::score(inner_opt)
93        }
94        // Arithmetic operations
95        TLExpr::Add(left, right) => {
96            let left_opt = cse_recursive(left, cache, eliminated_count);
97            let right_opt = cse_recursive(right, cache, eliminated_count);
98            TLExpr::add(left_opt, right_opt)
99        }
100        TLExpr::Sub(left, right) => {
101            let left_opt = cse_recursive(left, cache, eliminated_count);
102            let right_opt = cse_recursive(right, cache, eliminated_count);
103            TLExpr::sub(left_opt, right_opt)
104        }
105        TLExpr::Mul(left, right) => {
106            let left_opt = cse_recursive(left, cache, eliminated_count);
107            let right_opt = cse_recursive(right, cache, eliminated_count);
108            TLExpr::mul(left_opt, right_opt)
109        }
110        TLExpr::Div(left, right) => {
111            let left_opt = cse_recursive(left, cache, eliminated_count);
112            let right_opt = cse_recursive(right, cache, eliminated_count);
113            TLExpr::div(left_opt, right_opt)
114        }
115        // Comparison operations
116        TLExpr::Eq(left, right) => {
117            let left_opt = cse_recursive(left, cache, eliminated_count);
118            let right_opt = cse_recursive(right, cache, eliminated_count);
119            TLExpr::eq(left_opt, right_opt)
120        }
121        TLExpr::Lt(left, right) => {
122            let left_opt = cse_recursive(left, cache, eliminated_count);
123            let right_opt = cse_recursive(right, cache, eliminated_count);
124            TLExpr::lt(left_opt, right_opt)
125        }
126        TLExpr::Gt(left, right) => {
127            let left_opt = cse_recursive(left, cache, eliminated_count);
128            let right_opt = cse_recursive(right, cache, eliminated_count);
129            TLExpr::gt(left_opt, right_opt)
130        }
131        TLExpr::Lte(left, right) => {
132            let left_opt = cse_recursive(left, cache, eliminated_count);
133            let right_opt = cse_recursive(right, cache, eliminated_count);
134            TLExpr::lte(left_opt, right_opt)
135        }
136        TLExpr::Gte(left, right) => {
137            let left_opt = cse_recursive(left, cache, eliminated_count);
138            let right_opt = cse_recursive(right, cache, eliminated_count);
139            TLExpr::gte(left_opt, right_opt)
140        }
141        TLExpr::Pow(left, right) => {
142            let left_opt = cse_recursive(left, cache, eliminated_count);
143            let right_opt = cse_recursive(right, cache, eliminated_count);
144            TLExpr::pow(left_opt, right_opt)
145        }
146        TLExpr::Mod(left, right) => {
147            let left_opt = cse_recursive(left, cache, eliminated_count);
148            let right_opt = cse_recursive(right, cache, eliminated_count);
149            TLExpr::modulo(left_opt, right_opt)
150        }
151        TLExpr::Min(left, right) => {
152            let left_opt = cse_recursive(left, cache, eliminated_count);
153            let right_opt = cse_recursive(right, cache, eliminated_count);
154            TLExpr::min(left_opt, right_opt)
155        }
156        TLExpr::Max(left, right) => {
157            let left_opt = cse_recursive(left, cache, eliminated_count);
158            let right_opt = cse_recursive(right, cache, eliminated_count);
159            TLExpr::max(left_opt, right_opt)
160        }
161        // Unary math operations
162        TLExpr::Abs(inner) => {
163            let inner_opt = cse_recursive(inner, cache, eliminated_count);
164            TLExpr::abs(inner_opt)
165        }
166        TLExpr::Floor(inner) => {
167            let inner_opt = cse_recursive(inner, cache, eliminated_count);
168            TLExpr::floor(inner_opt)
169        }
170        TLExpr::Ceil(inner) => {
171            let inner_opt = cse_recursive(inner, cache, eliminated_count);
172            TLExpr::ceil(inner_opt)
173        }
174        TLExpr::Round(inner) => {
175            let inner_opt = cse_recursive(inner, cache, eliminated_count);
176            TLExpr::round(inner_opt)
177        }
178        TLExpr::Sqrt(inner) => {
179            let inner_opt = cse_recursive(inner, cache, eliminated_count);
180            TLExpr::sqrt(inner_opt)
181        }
182        TLExpr::Exp(inner) => {
183            let inner_opt = cse_recursive(inner, cache, eliminated_count);
184            TLExpr::exp(inner_opt)
185        }
186        TLExpr::Log(inner) => {
187            let inner_opt = cse_recursive(inner, cache, eliminated_count);
188            TLExpr::log(inner_opt)
189        }
190        TLExpr::Sin(inner) => {
191            let inner_opt = cse_recursive(inner, cache, eliminated_count);
192            TLExpr::sin(inner_opt)
193        }
194        TLExpr::Cos(inner) => {
195            let inner_opt = cse_recursive(inner, cache, eliminated_count);
196            TLExpr::cos(inner_opt)
197        }
198        TLExpr::Tan(inner) => {
199            let inner_opt = cse_recursive(inner, cache, eliminated_count);
200            TLExpr::tan(inner_opt)
201        }
202        // Let binding
203        TLExpr::Let { var, value, body } => {
204            let value_opt = cse_recursive(value, cache, eliminated_count);
205            let body_opt = cse_recursive(body, cache, eliminated_count);
206            TLExpr::let_binding(var, value_opt, body_opt)
207        }
208        // Conditional
209        TLExpr::IfThenElse {
210            condition,
211            then_branch,
212            else_branch,
213        } => {
214            let cond_opt = cse_recursive(condition, cache, eliminated_count);
215            let then_opt = cse_recursive(then_branch, cache, eliminated_count);
216            let else_opt = cse_recursive(else_branch, cache, eliminated_count);
217            TLExpr::if_then_else(cond_opt, then_opt, else_opt)
218        }
219        // Constant
220        TLExpr::Constant(_) => {
221            // Constants are atomic, just cache them
222            expr.clone()
223        }
224
225        // Modal/temporal logic operators - not yet implemented, pass through with recursion
226        TLExpr::Box(inner) => {
227            let inner_opt = cse_recursive(inner, cache, eliminated_count);
228            TLExpr::Box(Box::new(inner_opt))
229        }
230        TLExpr::Diamond(inner) => {
231            let inner_opt = cse_recursive(inner, cache, eliminated_count);
232            TLExpr::Diamond(Box::new(inner_opt))
233        }
234        TLExpr::Next(inner) => {
235            let inner_opt = cse_recursive(inner, cache, eliminated_count);
236            TLExpr::Next(Box::new(inner_opt))
237        }
238        TLExpr::Eventually(inner) => {
239            let inner_opt = cse_recursive(inner, cache, eliminated_count);
240            TLExpr::Eventually(Box::new(inner_opt))
241        }
242        TLExpr::Always(inner) => {
243            let inner_opt = cse_recursive(inner, cache, eliminated_count);
244            TLExpr::Always(Box::new(inner_opt))
245        }
246        TLExpr::Until { before, after } => {
247            let before_opt = cse_recursive(before, cache, eliminated_count);
248            let after_opt = cse_recursive(after, cache, eliminated_count);
249            TLExpr::Until {
250                before: Box::new(before_opt),
251                after: Box::new(after_opt),
252            }
253        }
254        // Fuzzy logic operators
255        TLExpr::TNorm { kind, left, right } => {
256            let left_opt = cse_recursive(left, cache, eliminated_count);
257            let right_opt = cse_recursive(right, cache, eliminated_count);
258            TLExpr::TNorm {
259                kind: *kind,
260                left: Box::new(left_opt),
261                right: Box::new(right_opt),
262            }
263        }
264        TLExpr::TCoNorm { kind, left, right } => {
265            let left_opt = cse_recursive(left, cache, eliminated_count);
266            let right_opt = cse_recursive(right, cache, eliminated_count);
267            TLExpr::TCoNorm {
268                kind: *kind,
269                left: Box::new(left_opt),
270                right: Box::new(right_opt),
271            }
272        }
273        TLExpr::FuzzyNot { kind, expr: inner } => {
274            let inner_opt = cse_recursive(inner, cache, eliminated_count);
275            TLExpr::FuzzyNot {
276                kind: *kind,
277                expr: Box::new(inner_opt),
278            }
279        }
280        TLExpr::FuzzyImplication {
281            kind,
282            premise,
283            conclusion,
284        } => {
285            let premise_opt = cse_recursive(premise, cache, eliminated_count);
286            let conclusion_opt = cse_recursive(conclusion, cache, eliminated_count);
287            TLExpr::FuzzyImplication {
288                kind: *kind,
289                premise: Box::new(premise_opt),
290                conclusion: Box::new(conclusion_opt),
291            }
292        }
293        // Soft quantifiers
294        TLExpr::SoftExists {
295            var,
296            domain,
297            body,
298            temperature,
299        } => {
300            let body_opt = cse_recursive(body, cache, eliminated_count);
301            TLExpr::SoftExists {
302                var: var.clone(),
303                domain: domain.clone(),
304                body: Box::new(body_opt),
305                temperature: *temperature,
306            }
307        }
308        TLExpr::SoftForAll {
309            var,
310            domain,
311            body,
312            temperature,
313        } => {
314            let body_opt = cse_recursive(body, cache, eliminated_count);
315            TLExpr::SoftForAll {
316                var: var.clone(),
317                domain: domain.clone(),
318                body: Box::new(body_opt),
319                temperature: *temperature,
320            }
321        }
322        // Weighted/probabilistic operators
323        TLExpr::WeightedRule { weight, rule } => {
324            let rule_opt = cse_recursive(rule, cache, eliminated_count);
325            TLExpr::WeightedRule {
326                weight: *weight,
327                rule: Box::new(rule_opt),
328            }
329        }
330        TLExpr::ProbabilisticChoice { alternatives } => {
331            let alts_opt: Vec<(f64, TLExpr)> = alternatives
332                .iter()
333                .map(|(prob, expr)| (*prob, cse_recursive(expr, cache, eliminated_count)))
334                .collect();
335            TLExpr::ProbabilisticChoice {
336                alternatives: alts_opt,
337            }
338        }
339        // Extended temporal operators
340        TLExpr::Release { released, releaser } => {
341            let released_opt = cse_recursive(released, cache, eliminated_count);
342            let releaser_opt = cse_recursive(releaser, cache, eliminated_count);
343            TLExpr::Release {
344                released: Box::new(released_opt),
345                releaser: Box::new(releaser_opt),
346            }
347        }
348        TLExpr::WeakUntil { before, after } => {
349            let before_opt = cse_recursive(before, cache, eliminated_count);
350            let after_opt = cse_recursive(after, cache, eliminated_count);
351            TLExpr::WeakUntil {
352                before: Box::new(before_opt),
353                after: Box::new(after_opt),
354            }
355        }
356        TLExpr::StrongRelease { released, releaser } => {
357            let released_opt = cse_recursive(released, cache, eliminated_count);
358            let releaser_opt = cse_recursive(releaser, cache, eliminated_count);
359            TLExpr::StrongRelease {
360                released: Box::new(released_opt),
361                releaser: Box::new(releaser_opt),
362            }
363        }
364        // All other expression types (enhancements)
365        _ => expr.clone(),
366    };
367
368    // Cache this result
369    cache.insert(key, result.clone());
370    result
371}
372
373/// Convert an expression to a hashable key
374fn expr_to_key(expr: &TLExpr) -> String {
375    // Use debug format as a simple hash
376    // In production, you'd want a proper hash function
377    format!("{:?}", expr)
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Fixture: the predicate `p(x)`.
    fn p_of_x() -> TLExpr {
        TLExpr::pred("p", vec![Term::var("x")])
    }

    /// Fixture: the predicate `q(y)`.
    fn q_of_y() -> TLExpr {
        TLExpr::pred("q", vec![Term::var("y")])
    }

    /// Runs the pass and returns only the number of eliminated subexpressions.
    fn eliminated(expr: &TLExpr) -> usize {
        eliminate_common_subexpressions(expr).eliminated_count
    }

    #[test]
    fn test_cse_no_duplicates() {
        // p(x) ∧ q(y): nothing repeats.
        let expr = TLExpr::and(p_of_x(), q_of_y());

        assert_eq!(eliminated(&expr), 0);
    }

    #[test]
    fn test_cse_duplicate_predicates() {
        // p(x) ∧ p(x): the second operand repeats the first.
        let expr = TLExpr::and(p_of_x(), p_of_x());

        assert!(eliminated(&expr) > 0);
    }

    #[test]
    fn test_cse_nested_duplicates() {
        // (p(x) ∧ q(y)) ∧ (p(x) ∧ q(y)): the whole conjunction repeats.
        let conj = TLExpr::and(p_of_x(), q_of_y());
        let expr = TLExpr::and(conj.clone(), conj);

        assert!(eliminated(&expr) > 0);
    }

    #[test]
    fn test_cse_with_quantifiers() {
        // ∃x. p(x) ∧ ∃x. p(x): the quantified expression repeats.
        let quantified = TLExpr::exists("x", "Domain", p_of_x());
        let expr = TLExpr::and(quantified.clone(), quantified);

        assert!(eliminated(&expr) > 0);
    }

    #[test]
    fn test_cse_preserves_semantics() {
        // With no repeats, the output must still be an AND of two predicates.
        let expr = TLExpr::and(p_of_x(), q_of_y());

        let result = eliminate_common_subexpressions(&expr);

        if let TLExpr::And(left, right) = result.optimized_expr {
            assert!(matches!(*left, TLExpr::Pred { .. }));
            assert!(matches!(*right, TLExpr::Pred { .. }));
        } else {
            panic!("Expected And expression");
        }
    }

    #[test]
    fn test_cse_complex_expression() {
        // p(x) ∧ (q(y) ∨ p(x)): p(x) occurs at two different depths.
        let disjunction = TLExpr::or(q_of_y(), p_of_x());
        let expr = TLExpr::and(p_of_x(), disjunction);

        assert!(eliminated(&expr) > 0);
    }
}