Skip to main content

tensorlogic_compiler/optimize/
dead_code.rs

1//! Dead code elimination optimization pass.
2//!
3//! This module provides optimizations that remove unreachable or redundant
4//! code from TLExpr expressions. Examples include:
5//!
6//! - `if true then A else B` → `A`
7//! - `if false then A else B` → `B`
8//! - `AND(false, x)` → `false` (short-circuit)
9//! - `OR(true, x)` → `true` (short-circuit)
10//! - `EXISTS x. constant` → `constant` (if x is not free in constant)
//! - `x * 0` / `0 * x` → `0`, `min(x, x)` / `max(x, x)` → `x`, `NOT(true)` → `false`
12//!
13//! # Examples
14//!
15//! ```
16//! use tensorlogic_compiler::optimize::eliminate_dead_code;
17//! use tensorlogic_ir::{TLExpr, Term};
18//!
19//! // if true then A else B → A
20//! let a = TLExpr::pred("a", vec![Term::var("i")]);
21//! let b = TLExpr::pred("b", vec![Term::var("i")]);
22//! let expr = TLExpr::IfThenElse {
23//!     condition: Box::new(TLExpr::Constant(1.0)),
24//!     then_branch: Box::new(a.clone()),
25//!     else_branch: Box::new(b),
26//! };
27//! let (optimized, stats) = eliminate_dead_code(&expr);
28//! assert!(stats.branches_eliminated > 0);
29//! ```
30
31use std::collections::HashSet;
32use tensorlogic_ir::TLExpr;
33
/// Counters collected while running dead code elimination.
///
/// Each rewrite counter records how many times that kind of rewrite fired;
/// `total_processed` counts every expression node visited, rewritten or not.
#[derive(Debug, Clone, Default)]
pub struct DeadCodeStats {
    /// Conditionals whose constant condition allowed one branch to be dropped.
    pub branches_eliminated: usize,
    /// Rewrites where a single operand fixed the result (e.g. `AND(false, x)`, `0 * x`).
    pub short_circuits: usize,
    /// Quantifiers dropped because their variable was not free in the body.
    pub unused_quantifiers_removed: usize,
    /// Identity-style simplifications (e.g. `AND(true, x)`, `min(x, x)`, `NOT(constant)`).
    pub identity_simplifications: usize,
    /// Number of expression nodes visited.
    pub total_processed: usize,
}

impl DeadCodeStats {
    /// Sum of every rewrite counter.
    ///
    /// `total_processed` is deliberately excluded: it counts visited nodes,
    /// not optimizations.
    pub fn total_optimizations(&self) -> usize {
        let counters = [
            self.branches_eliminated,
            self.short_circuits,
            self.unused_quantifiers_removed,
            self.identity_simplifications,
        ];
        counters.iter().sum()
    }
}
58
59/// Apply dead code elimination to an expression.
60///
61/// This pass removes unreachable code and simplifies expressions
62/// that have known outcomes.
63///
64/// # Arguments
65///
66/// * `expr` - The expression to optimize
67///
68/// # Returns
69///
70/// A tuple of (optimized expression, statistics)
71pub fn eliminate_dead_code(expr: &TLExpr) -> (TLExpr, DeadCodeStats) {
72    let mut stats = DeadCodeStats::default();
73    let result = eliminate_dead_code_impl(expr, &mut stats);
74    (result, stats)
75}
76
77fn eliminate_dead_code_impl(expr: &TLExpr, stats: &mut DeadCodeStats) -> TLExpr {
78    stats.total_processed += 1;
79
80    match expr {
81        // Conditional elimination
82        TLExpr::IfThenElse {
83            condition,
84            then_branch,
85            else_branch,
86        } => {
87            let cond_opt = eliminate_dead_code_impl(condition, stats);
88            let then_opt = eliminate_dead_code_impl(then_branch, stats);
89            let else_opt = eliminate_dead_code_impl(else_branch, stats);
90
91            // Check for constant conditions
92            if let TLExpr::Constant(c) = &cond_opt {
93                stats.branches_eliminated += 1;
94                // Non-zero is truthy
95                return if *c != 0.0 { then_opt } else { else_opt };
96            }
97
98            TLExpr::IfThenElse {
99                condition: Box::new(cond_opt),
100                then_branch: Box::new(then_opt),
101                else_branch: Box::new(else_opt),
102            }
103        }
104
105        // AND short-circuit: AND(false, x) → false, AND(x, false) → false
106        TLExpr::And(lhs, rhs) => {
107            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
108            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
109
110            // Check for constant false
111            if let TLExpr::Constant(c) = &lhs_opt {
112                if *c == 0.0 {
113                    stats.short_circuits += 1;
114                    return TLExpr::Constant(0.0);
115                }
116            }
117            if let TLExpr::Constant(c) = &rhs_opt {
118                if *c == 0.0 {
119                    stats.short_circuits += 1;
120                    return TLExpr::Constant(0.0);
121                }
122            }
123
124            // AND(true, x) → x
125            if let TLExpr::Constant(c) = &lhs_opt {
126                if *c != 0.0 {
127                    stats.identity_simplifications += 1;
128                    return rhs_opt;
129                }
130            }
131            // AND(x, true) → x
132            if let TLExpr::Constant(c) = &rhs_opt {
133                if *c != 0.0 {
134                    stats.identity_simplifications += 1;
135                    return lhs_opt;
136                }
137            }
138
139            TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
140        }
141
142        // OR short-circuit: OR(true, x) → true, OR(x, true) → true
143        TLExpr::Or(lhs, rhs) => {
144            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
145            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
146
147            // Check for constant true
148            if let TLExpr::Constant(c) = &lhs_opt {
149                if *c != 0.0 {
150                    stats.short_circuits += 1;
151                    return TLExpr::Constant(1.0);
152                }
153            }
154            if let TLExpr::Constant(c) = &rhs_opt {
155                if *c != 0.0 {
156                    stats.short_circuits += 1;
157                    return TLExpr::Constant(1.0);
158                }
159            }
160
161            // OR(false, x) → x
162            if let TLExpr::Constant(c) = &lhs_opt {
163                if *c == 0.0 {
164                    stats.identity_simplifications += 1;
165                    return rhs_opt;
166                }
167            }
168            // OR(x, false) → x
169            if let TLExpr::Constant(c) = &rhs_opt {
170                if *c == 0.0 {
171                    stats.identity_simplifications += 1;
172                    return lhs_opt;
173                }
174            }
175
176            TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
177        }
178
179        // Imply with constant conditions
180        TLExpr::Imply(lhs, rhs) => {
181            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
182            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
183
184            // false → x is always true
185            if let TLExpr::Constant(c) = &lhs_opt {
186                if *c == 0.0 {
187                    stats.short_circuits += 1;
188                    return TLExpr::Constant(1.0);
189                }
190            }
191            // x → true is always true
192            if let TLExpr::Constant(c) = &rhs_opt {
193                if *c != 0.0 {
194                    stats.short_circuits += 1;
195                    return TLExpr::Constant(1.0);
196                }
197            }
198            // true → x is just x
199            if let TLExpr::Constant(c) = &lhs_opt {
200                if *c != 0.0 {
201                    stats.identity_simplifications += 1;
202                    return rhs_opt;
203                }
204            }
205
206            TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
207        }
208
209        // EXISTS with unused variable
210        TLExpr::Exists { var, domain, body } => {
211            let body_opt = eliminate_dead_code_impl(body, stats);
212
213            // If the variable is not free in the body, remove the quantifier
214            let free_vars = collect_free_vars(&body_opt);
215            if !free_vars.contains(var.as_str()) {
216                stats.unused_quantifiers_removed += 1;
217                return body_opt;
218            }
219
220            TLExpr::Exists {
221                var: var.clone(),
222                domain: domain.clone(),
223                body: Box::new(body_opt),
224            }
225        }
226
227        // FORALL with unused variable
228        TLExpr::ForAll { var, domain, body } => {
229            let body_opt = eliminate_dead_code_impl(body, stats);
230
231            // If the variable is not free in the body, remove the quantifier
232            let free_vars = collect_free_vars(&body_opt);
233            if !free_vars.contains(var.as_str()) {
234                stats.unused_quantifiers_removed += 1;
235                return body_opt;
236            }
237
238            TLExpr::ForAll {
239                var: var.clone(),
240                domain: domain.clone(),
241                body: Box::new(body_opt),
242            }
243        }
244
245        // Multiplication by zero
246        TLExpr::Mul(lhs, rhs) => {
247            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
248            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
249
250            // 0 * x = 0, x * 0 = 0
251            if matches!(&lhs_opt, TLExpr::Constant(c) if *c == 0.0) {
252                stats.short_circuits += 1;
253                return TLExpr::Constant(0.0);
254            }
255            if matches!(&rhs_opt, TLExpr::Constant(c) if *c == 0.0) {
256                stats.short_circuits += 1;
257                return TLExpr::Constant(0.0);
258            }
259
260            TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
261        }
262
263        // NOT with constant
264        TLExpr::Not(inner) => {
265            let inner_opt = eliminate_dead_code_impl(inner, stats);
266
267            // NOT(true) → false, NOT(false) → true
268            if let TLExpr::Constant(c) = &inner_opt {
269                stats.identity_simplifications += 1;
270                return TLExpr::Constant(if *c == 0.0 { 1.0 } else { 0.0 });
271            }
272
273            TLExpr::Not(Box::new(inner_opt))
274        }
275
276        // Min/Max with same operand
277        TLExpr::Min(lhs, rhs) => {
278            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
279            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
280
281            // min(x, x) = x
282            if exprs_equal(&lhs_opt, &rhs_opt) {
283                stats.identity_simplifications += 1;
284                return lhs_opt;
285            }
286
287            TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
288        }
289
290        TLExpr::Max(lhs, rhs) => {
291            let lhs_opt = eliminate_dead_code_impl(lhs, stats);
292            let rhs_opt = eliminate_dead_code_impl(rhs, stats);
293
294            // max(x, x) = x
295            if exprs_equal(&lhs_opt, &rhs_opt) {
296                stats.identity_simplifications += 1;
297                return lhs_opt;
298            }
299
300            TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
301        }
302
303        // Recursive cases - binary operations
304        TLExpr::Add(lhs, rhs) => TLExpr::Add(
305            Box::new(eliminate_dead_code_impl(lhs, stats)),
306            Box::new(eliminate_dead_code_impl(rhs, stats)),
307        ),
308        TLExpr::Sub(lhs, rhs) => TLExpr::Sub(
309            Box::new(eliminate_dead_code_impl(lhs, stats)),
310            Box::new(eliminate_dead_code_impl(rhs, stats)),
311        ),
312        TLExpr::Div(lhs, rhs) => TLExpr::Div(
313            Box::new(eliminate_dead_code_impl(lhs, stats)),
314            Box::new(eliminate_dead_code_impl(rhs, stats)),
315        ),
316        TLExpr::Pow(base, exp) => TLExpr::Pow(
317            Box::new(eliminate_dead_code_impl(base, stats)),
318            Box::new(eliminate_dead_code_impl(exp, stats)),
319        ),
320        TLExpr::Mod(lhs, rhs) => TLExpr::Mod(
321            Box::new(eliminate_dead_code_impl(lhs, stats)),
322            Box::new(eliminate_dead_code_impl(rhs, stats)),
323        ),
324
325        // Comparison operations
326        TLExpr::Eq(lhs, rhs) => TLExpr::Eq(
327            Box::new(eliminate_dead_code_impl(lhs, stats)),
328            Box::new(eliminate_dead_code_impl(rhs, stats)),
329        ),
330        TLExpr::Lt(lhs, rhs) => TLExpr::Lt(
331            Box::new(eliminate_dead_code_impl(lhs, stats)),
332            Box::new(eliminate_dead_code_impl(rhs, stats)),
333        ),
334        TLExpr::Lte(lhs, rhs) => TLExpr::Lte(
335            Box::new(eliminate_dead_code_impl(lhs, stats)),
336            Box::new(eliminate_dead_code_impl(rhs, stats)),
337        ),
338        TLExpr::Gt(lhs, rhs) => TLExpr::Gt(
339            Box::new(eliminate_dead_code_impl(lhs, stats)),
340            Box::new(eliminate_dead_code_impl(rhs, stats)),
341        ),
342        TLExpr::Gte(lhs, rhs) => TLExpr::Gte(
343            Box::new(eliminate_dead_code_impl(lhs, stats)),
344            Box::new(eliminate_dead_code_impl(rhs, stats)),
345        ),
346
347        // Unary operations
348        TLExpr::Exp(inner) => TLExpr::Exp(Box::new(eliminate_dead_code_impl(inner, stats))),
349        TLExpr::Log(inner) => TLExpr::Log(Box::new(eliminate_dead_code_impl(inner, stats))),
350        TLExpr::Sqrt(inner) => TLExpr::Sqrt(Box::new(eliminate_dead_code_impl(inner, stats))),
351        TLExpr::Abs(inner) => TLExpr::Abs(Box::new(eliminate_dead_code_impl(inner, stats))),
352        TLExpr::Sin(inner) => TLExpr::Sin(Box::new(eliminate_dead_code_impl(inner, stats))),
353        TLExpr::Cos(inner) => TLExpr::Cos(Box::new(eliminate_dead_code_impl(inner, stats))),
354        TLExpr::Tan(inner) => TLExpr::Tan(Box::new(eliminate_dead_code_impl(inner, stats))),
355        TLExpr::Floor(inner) => TLExpr::Floor(Box::new(eliminate_dead_code_impl(inner, stats))),
356        TLExpr::Ceil(inner) => TLExpr::Ceil(Box::new(eliminate_dead_code_impl(inner, stats))),
357        TLExpr::Round(inner) => TLExpr::Round(Box::new(eliminate_dead_code_impl(inner, stats))),
358        TLExpr::Score(inner) => TLExpr::Score(Box::new(eliminate_dead_code_impl(inner, stats))),
359
360        // Modal operators
361        TLExpr::Box(inner) => TLExpr::Box(Box::new(eliminate_dead_code_impl(inner, stats))),
362        TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(eliminate_dead_code_impl(inner, stats))),
363
364        // Temporal operators
365        TLExpr::Next(inner) => TLExpr::Next(Box::new(eliminate_dead_code_impl(inner, stats))),
366        TLExpr::Eventually(inner) => {
367            TLExpr::Eventually(Box::new(eliminate_dead_code_impl(inner, stats)))
368        }
369        TLExpr::Always(inner) => TLExpr::Always(Box::new(eliminate_dead_code_impl(inner, stats))),
370        TLExpr::Until { before, after } => TLExpr::Until {
371            before: Box::new(eliminate_dead_code_impl(before, stats)),
372            after: Box::new(eliminate_dead_code_impl(after, stats)),
373        },
374        TLExpr::Release { released, releaser } => TLExpr::Release {
375            released: Box::new(eliminate_dead_code_impl(released, stats)),
376            releaser: Box::new(eliminate_dead_code_impl(releaser, stats)),
377        },
378        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
379            before: Box::new(eliminate_dead_code_impl(before, stats)),
380            after: Box::new(eliminate_dead_code_impl(after, stats)),
381        },
382        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
383            released: Box::new(eliminate_dead_code_impl(released, stats)),
384            releaser: Box::new(eliminate_dead_code_impl(releaser, stats)),
385        },
386
387        // Fuzzy operators
388        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
389            kind: *kind,
390            left: Box::new(eliminate_dead_code_impl(left, stats)),
391            right: Box::new(eliminate_dead_code_impl(right, stats)),
392        },
393        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
394            kind: *kind,
395            left: Box::new(eliminate_dead_code_impl(left, stats)),
396            right: Box::new(eliminate_dead_code_impl(right, stats)),
397        },
398        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
399            kind: *kind,
400            expr: Box::new(eliminate_dead_code_impl(expr, stats)),
401        },
402        TLExpr::FuzzyImplication {
403            kind,
404            premise,
405            conclusion,
406        } => TLExpr::FuzzyImplication {
407            kind: *kind,
408            premise: Box::new(eliminate_dead_code_impl(premise, stats)),
409            conclusion: Box::new(eliminate_dead_code_impl(conclusion, stats)),
410        },
411
412        // Probabilistic operators
413        TLExpr::SoftExists {
414            var,
415            domain,
416            body,
417            temperature,
418        } => TLExpr::SoftExists {
419            var: var.clone(),
420            domain: domain.clone(),
421            body: Box::new(eliminate_dead_code_impl(body, stats)),
422            temperature: *temperature,
423        },
424        TLExpr::SoftForAll {
425            var,
426            domain,
427            body,
428            temperature,
429        } => TLExpr::SoftForAll {
430            var: var.clone(),
431            domain: domain.clone(),
432            body: Box::new(eliminate_dead_code_impl(body, stats)),
433            temperature: *temperature,
434        },
435        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
436            weight: *weight,
437            rule: Box::new(eliminate_dead_code_impl(rule, stats)),
438        },
439        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
440            alternatives: alternatives
441                .iter()
442                .map(|(prob, e)| (*prob, eliminate_dead_code_impl(e, stats)))
443                .collect(),
444        },
445
446        // Leaf nodes - no recursion needed
447        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
448
449        // Aggregate operations
450        TLExpr::Aggregate {
451            op,
452            var,
453            domain,
454            body,
455            group_by,
456        } => TLExpr::Aggregate {
457            op: op.clone(),
458            var: var.clone(),
459            domain: domain.clone(),
460            body: Box::new(eliminate_dead_code_impl(body, stats)),
461            group_by: group_by.clone(),
462        },
463
464        // Let binding
465        TLExpr::Let { var, value, body } => TLExpr::Let {
466            var: var.clone(),
467            value: Box::new(eliminate_dead_code_impl(value, stats)),
468            body: Box::new(eliminate_dead_code_impl(body, stats)),
469        },
470
471        // All other expression types (enhancements)
472        _ => expr.clone(),
473    }
474}
475
476/// Collect free variables in an expression.
477fn collect_free_vars(expr: &TLExpr) -> HashSet<String> {
478    let mut vars = HashSet::new();
479    collect_free_vars_impl(expr, &mut vars, &HashSet::new());
480    vars
481}
482
483fn collect_free_vars_impl(
484    expr: &TLExpr,
485    free_vars: &mut HashSet<String>,
486    bound_vars: &HashSet<String>,
487) {
488    match expr {
489        TLExpr::Pred { args, .. } => {
490            for arg in args {
491                if let tensorlogic_ir::Term::Var(v) = arg {
492                    if !bound_vars.contains(v) {
493                        free_vars.insert(v.clone());
494                    }
495                }
496            }
497        }
498
499        TLExpr::Exists { var, body, .. }
500        | TLExpr::ForAll { var, body, .. }
501        | TLExpr::SoftExists { var, body, .. }
502        | TLExpr::SoftForAll { var, body, .. } => {
503            let mut new_bound = bound_vars.clone();
504            new_bound.insert(var.clone());
505            collect_free_vars_impl(body, free_vars, &new_bound);
506        }
507
508        TLExpr::Aggregate { var, body, .. } => {
509            let mut new_bound = bound_vars.clone();
510            new_bound.insert(var.clone());
511            collect_free_vars_impl(body, free_vars, &new_bound);
512        }
513
514        TLExpr::Let { var, value, body } => {
515            collect_free_vars_impl(value, free_vars, bound_vars);
516            let mut new_bound = bound_vars.clone();
517            new_bound.insert(var.clone());
518            collect_free_vars_impl(body, free_vars, &new_bound);
519        }
520
521        // Binary operations
522        TLExpr::And(lhs, rhs)
523        | TLExpr::Or(lhs, rhs)
524        | TLExpr::Imply(lhs, rhs)
525        | TLExpr::Add(lhs, rhs)
526        | TLExpr::Sub(lhs, rhs)
527        | TLExpr::Mul(lhs, rhs)
528        | TLExpr::Div(lhs, rhs)
529        | TLExpr::Pow(lhs, rhs)
530        | TLExpr::Mod(lhs, rhs)
531        | TLExpr::Min(lhs, rhs)
532        | TLExpr::Max(lhs, rhs)
533        | TLExpr::Eq(lhs, rhs)
534        | TLExpr::Lt(lhs, rhs)
535        | TLExpr::Lte(lhs, rhs)
536        | TLExpr::Gt(lhs, rhs)
537        | TLExpr::Gte(lhs, rhs) => {
538            collect_free_vars_impl(lhs, free_vars, bound_vars);
539            collect_free_vars_impl(rhs, free_vars, bound_vars);
540        }
541
542        TLExpr::Until { before, after }
543        | TLExpr::WeakUntil { before, after }
544        | TLExpr::Release {
545            released: before,
546            releaser: after,
547        }
548        | TLExpr::StrongRelease {
549            released: before,
550            releaser: after,
551        } => {
552            collect_free_vars_impl(before, free_vars, bound_vars);
553            collect_free_vars_impl(after, free_vars, bound_vars);
554        }
555
556        TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
557            collect_free_vars_impl(left, free_vars, bound_vars);
558            collect_free_vars_impl(right, free_vars, bound_vars);
559        }
560
561        TLExpr::FuzzyImplication {
562            premise,
563            conclusion,
564            ..
565        } => {
566            collect_free_vars_impl(premise, free_vars, bound_vars);
567            collect_free_vars_impl(conclusion, free_vars, bound_vars);
568        }
569
570        // Unary operations
571        TLExpr::Not(inner)
572        | TLExpr::Exp(inner)
573        | TLExpr::Log(inner)
574        | TLExpr::Sqrt(inner)
575        | TLExpr::Abs(inner)
576        | TLExpr::Sin(inner)
577        | TLExpr::Cos(inner)
578        | TLExpr::Tan(inner)
579        | TLExpr::Floor(inner)
580        | TLExpr::Ceil(inner)
581        | TLExpr::Round(inner)
582        | TLExpr::Score(inner)
583        | TLExpr::Box(inner)
584        | TLExpr::Diamond(inner)
585        | TLExpr::Next(inner)
586        | TLExpr::Eventually(inner)
587        | TLExpr::Always(inner) => {
588            collect_free_vars_impl(inner, free_vars, bound_vars);
589        }
590
591        TLExpr::FuzzyNot { expr, .. } => {
592            collect_free_vars_impl(expr, free_vars, bound_vars);
593        }
594
595        TLExpr::WeightedRule { rule, .. } => {
596            collect_free_vars_impl(rule, free_vars, bound_vars);
597        }
598
599        TLExpr::ProbabilisticChoice { alternatives } => {
600            for (_, e) in alternatives {
601                collect_free_vars_impl(e, free_vars, bound_vars);
602            }
603        }
604
605        TLExpr::IfThenElse {
606            condition,
607            then_branch,
608            else_branch,
609        } => {
610            collect_free_vars_impl(condition, free_vars, bound_vars);
611            collect_free_vars_impl(then_branch, free_vars, bound_vars);
612            collect_free_vars_impl(else_branch, free_vars, bound_vars);
613        }
614
615        // Leaves
616        TLExpr::Constant(_) => {}
617
618        // All other expression types (enhancements)
619        _ => {}
620    }
621}
622
623/// Check if two expressions are structurally equal.
624fn exprs_equal(a: &TLExpr, b: &TLExpr) -> bool {
625    match (a, b) {
626        (TLExpr::Constant(c1), TLExpr::Constant(c2)) => (c1 - c2).abs() < 1e-10,
627        (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
628            n1 == n2 && a1 == a2
629        }
630        (TLExpr::Add(l1, r1), TLExpr::Add(l2, r2))
631        | (TLExpr::Sub(l1, r1), TLExpr::Sub(l2, r2))
632        | (TLExpr::Mul(l1, r1), TLExpr::Mul(l2, r2))
633        | (TLExpr::Div(l1, r1), TLExpr::Div(l2, r2))
634        | (TLExpr::And(l1, r1), TLExpr::And(l2, r2))
635        | (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2)) => exprs_equal(l1, l2) && exprs_equal(r1, r2),
636        (TLExpr::Not(e1), TLExpr::Not(e2))
637        | (TLExpr::Exp(e1), TLExpr::Exp(e2))
638        | (TLExpr::Log(e1), TLExpr::Log(e2))
639        | (TLExpr::Sqrt(e1), TLExpr::Sqrt(e2))
640        | (TLExpr::Abs(e1), TLExpr::Abs(e2)) => exprs_equal(e1, e2),
641        _ => false,
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Build the unary predicate `name(var)`, used as an opaque operand.
    fn atom(name: &'static str, var: &'static str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(var)])
    }

    /// True when `e` is exactly the constant `expected`.
    fn is_constant(e: &TLExpr, expected: f64) -> bool {
        matches!(e, TLExpr::Constant(c) if *c == expected)
    }

    /// True when `e` is a predicate called `expected`.
    fn is_pred_named(e: &TLExpr, expected: &str) -> bool {
        matches!(e, TLExpr::Pred { name, .. } if name == expected)
    }

    /// Wrap `body` as `EXISTS x in D. body`.
    fn exists_x(body: TLExpr) -> TLExpr {
        TLExpr::Exists {
            var: "x".to_string(),
            domain: "D".to_string(),
            body: Box::new(body),
        }
    }

    #[test]
    fn test_if_true_elimination() {
        let expr = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(1.0)),
            then_branch: Box::new(atom("a", "i")),
            else_branch: Box::new(atom("b", "i")),
        };
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.branches_eliminated, 1);
        assert!(is_pred_named(&optimized, "a"));
    }

    #[test]
    fn test_if_false_elimination() {
        let expr = TLExpr::IfThenElse {
            condition: Box::new(TLExpr::Constant(0.0)),
            then_branch: Box::new(atom("a", "i")),
            else_branch: Box::new(atom("b", "i")),
        };
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.branches_eliminated, 1);
        assert!(is_pred_named(&optimized, "b"));
    }

    #[test]
    fn test_and_short_circuit_false() {
        let expr = TLExpr::and(TLExpr::Constant(0.0), atom("x", "i"));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.short_circuits, 1);
        assert!(is_constant(&optimized, 0.0));
    }

    #[test]
    fn test_or_short_circuit_true() {
        let expr = TLExpr::or(TLExpr::Constant(1.0), atom("x", "i"));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.short_circuits, 1);
        assert!(is_constant(&optimized, 1.0));
    }

    #[test]
    fn test_unused_exists_quantifier() {
        let expr = exists_x(TLExpr::Constant(5.0));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.unused_quantifiers_removed, 1);
        assert!(is_constant(&optimized, 5.0));
    }

    #[test]
    fn test_used_exists_quantifier() {
        let expr = exists_x(atom("p", "x"));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.unused_quantifiers_removed, 0);
        assert!(matches!(optimized, TLExpr::Exists { .. }));
    }

    #[test]
    fn test_mul_by_zero() {
        let expr = TLExpr::mul(atom("x", "i"), TLExpr::Constant(0.0));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.short_circuits, 1);
        assert!(is_constant(&optimized, 0.0));
    }

    #[test]
    fn test_min_same_operands() {
        let operand = atom("x", "i");
        let expr = TLExpr::Min(Box::new(operand.clone()), Box::new(operand));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.identity_simplifications, 1);
        assert!(matches!(optimized, TLExpr::Pred { .. }));
    }

    #[test]
    fn test_not_constant() {
        let expr = TLExpr::Not(Box::new(TLExpr::Constant(1.0)));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.identity_simplifications, 1);
        assert!(is_constant(&optimized, 0.0));
    }

    #[test]
    fn test_imply_false_antecedent() {
        let expr = TLExpr::Imply(Box::new(TLExpr::Constant(0.0)), Box::new(atom("x", "i")));
        let (optimized, stats) = eliminate_dead_code(&expr);
        assert_eq!(stats.short_circuits, 1);
        assert!(is_constant(&optimized, 1.0));
    }
}