Skip to main content

tensorlogic_compiler/optimize/
algebraic.rs

1//! Algebraic simplification optimization pass.
2//!
3//! This module implements algebraic simplifications based on mathematical identities
4//! and properties, such as x + 0 = x, x * 1 = x, x * 0 = 0, etc.
5
6use tensorlogic_ir::TLExpr;
7
/// Counters collected while [`simplify_algebraic`] rewrites an expression tree.
///
/// All counters start at zero (via `Default`) and only ever grow during a pass.
#[derive(Clone, Debug, Default)]
pub struct AlgebraicSimplificationStats {
    /// How many neutral-element operations were removed (e.g. `x + 0`, `x * 1`).
    pub identities_eliminated: usize,
    /// How many absorbing-element rewrites collapsed an operation to a constant (e.g. `x * 0`).
    pub annihilations_applied: usize,
    /// How many idempotent operations were collapsed (e.g. `min(x, x)`, `abs(abs(x))`).
    pub idempotent_simplified: usize,
    /// How many expression nodes the pass visited in total.
    pub total_processed: usize,
}
20
21/// Simplify an expression using algebraic identities.
22///
23/// This pass applies mathematical identities to simplify expressions:
24/// - Identity: x + 0 = x, x * 1 = x, x - 0 = x, x / 1 = x
25/// - Annihilation: x * 0 = 0, 0 / x = 0
26/// - Idempotent: min(x, x) = x, max(x, x) = x
27/// - Power identities: x^0 = 1, x^1 = x, 1^x = 1
28///
29/// # Examples
30///
31/// ```
32/// use tensorlogic_compiler::optimize::simplify_algebraic;
33/// use tensorlogic_ir::{TLExpr, Term};
34///
35/// // x + 0 => x
36/// let x = TLExpr::pred("x", vec![Term::var("i")]);
37/// let expr = TLExpr::Add(Box::new(x.clone()), Box::new(TLExpr::Constant(0.0)));
38///
39/// let (simplified, stats) = simplify_algebraic(&expr);
40/// assert!(matches!(simplified, TLExpr::Pred { .. }));
41/// assert_eq!(stats.identities_eliminated, 1);
42/// ```
43pub fn simplify_algebraic(expr: &TLExpr) -> (TLExpr, AlgebraicSimplificationStats) {
44    let mut stats = AlgebraicSimplificationStats::default();
45    let result = simplify_algebraic_impl(expr, &mut stats);
46    (result, stats)
47}
48
49fn simplify_algebraic_impl(expr: &TLExpr, stats: &mut AlgebraicSimplificationStats) -> TLExpr {
50    stats.total_processed += 1;
51
52    match expr {
53        // Addition: x + 0 = x, 0 + x = x
54        TLExpr::Add(left, right) => {
55            let left_simp = simplify_algebraic_impl(left, stats);
56            let right_simp = simplify_algebraic_impl(right, stats);
57
58            if is_zero(&right_simp) {
59                stats.identities_eliminated += 1;
60                left_simp
61            } else if is_zero(&left_simp) {
62                stats.identities_eliminated += 1;
63                right_simp
64            } else {
65                TLExpr::Add(Box::new(left_simp), Box::new(right_simp))
66            }
67        }
68
69        // Subtraction: x - 0 = x
70        TLExpr::Sub(left, right) => {
71            let left_simp = simplify_algebraic_impl(left, stats);
72            let right_simp = simplify_algebraic_impl(right, stats);
73
74            if is_zero(&right_simp) {
75                stats.identities_eliminated += 1;
76                left_simp
77            } else {
78                TLExpr::Sub(Box::new(left_simp), Box::new(right_simp))
79            }
80        }
81
82        // Multiplication: x * 1 = x, 1 * x = x, x * 0 = 0, 0 * x = 0
83        TLExpr::Mul(left, right) => {
84            let left_simp = simplify_algebraic_impl(left, stats);
85            let right_simp = simplify_algebraic_impl(right, stats);
86
87            if is_zero(&left_simp) || is_zero(&right_simp) {
88                stats.annihilations_applied += 1;
89                TLExpr::Constant(0.0)
90            } else if is_one(&right_simp) {
91                stats.identities_eliminated += 1;
92                left_simp
93            } else if is_one(&left_simp) {
94                stats.identities_eliminated += 1;
95                right_simp
96            } else {
97                TLExpr::Mul(Box::new(left_simp), Box::new(right_simp))
98            }
99        }
100
101        // Division: x / 1 = x, 0 / x = 0
102        TLExpr::Div(left, right) => {
103            let left_simp = simplify_algebraic_impl(left, stats);
104            let right_simp = simplify_algebraic_impl(right, stats);
105
106            if is_one(&right_simp) {
107                stats.identities_eliminated += 1;
108                left_simp
109            } else if is_zero(&left_simp) {
110                stats.annihilations_applied += 1;
111                TLExpr::Constant(0.0)
112            } else {
113                TLExpr::Div(Box::new(left_simp), Box::new(right_simp))
114            }
115        }
116
117        // Power: x^0 = 1, x^1 = x, 1^x = 1, 0^x = 0 (for x > 0)
118        TLExpr::Pow(base, exponent) => {
119            let base_simp = simplify_algebraic_impl(base, stats);
120            let exp_simp = simplify_algebraic_impl(exponent, stats);
121
122            if is_zero(&exp_simp) {
123                stats.identities_eliminated += 1;
124                TLExpr::Constant(1.0)
125            } else if is_one(&exp_simp) {
126                stats.identities_eliminated += 1;
127                base_simp
128            } else if is_one(&base_simp) {
129                stats.annihilations_applied += 1;
130                TLExpr::Constant(1.0)
131            } else if is_zero(&base_simp) {
132                stats.annihilations_applied += 1;
133                TLExpr::Constant(0.0)
134            } else {
135                TLExpr::Pow(Box::new(base_simp), Box::new(exp_simp))
136            }
137        }
138
139        // Min/Max: min(x, x) = x, max(x, x) = x (idempotent)
140        TLExpr::Min(left, right) => {
141            let left_simp = simplify_algebraic_impl(left, stats);
142            let right_simp = simplify_algebraic_impl(right, stats);
143
144            if expressions_equal(&left_simp, &right_simp) {
145                stats.idempotent_simplified += 1;
146                left_simp
147            } else {
148                TLExpr::Min(Box::new(left_simp), Box::new(right_simp))
149            }
150        }
151        TLExpr::Max(left, right) => {
152            let left_simp = simplify_algebraic_impl(left, stats);
153            let right_simp = simplify_algebraic_impl(right, stats);
154
155            if expressions_equal(&left_simp, &right_simp) {
156                stats.idempotent_simplified += 1;
157                left_simp
158            } else {
159                TLExpr::Max(Box::new(left_simp), Box::new(right_simp))
160            }
161        }
162
163        // Unary operations: abs(abs(x)) = abs(x) is handled by simplification
164        TLExpr::Abs(inner) => {
165            let inner_simp = simplify_algebraic_impl(inner, stats);
166            // abs(abs(x)) = abs(x)
167            if matches!(&inner_simp, TLExpr::Abs(_)) {
168                stats.idempotent_simplified += 1;
169                inner_simp
170            } else {
171                TLExpr::Abs(Box::new(inner_simp))
172            }
173        }
174
175        // Other unary operations
176        TLExpr::Floor(inner) => {
177            let inner_simp = simplify_algebraic_impl(inner, stats);
178            TLExpr::Floor(Box::new(inner_simp))
179        }
180        TLExpr::Ceil(inner) => {
181            let inner_simp = simplify_algebraic_impl(inner, stats);
182            TLExpr::Ceil(Box::new(inner_simp))
183        }
184        TLExpr::Round(inner) => {
185            let inner_simp = simplify_algebraic_impl(inner, stats);
186            TLExpr::Round(Box::new(inner_simp))
187        }
188        TLExpr::Sqrt(inner) => {
189            let inner_simp = simplify_algebraic_impl(inner, stats);
190            TLExpr::Sqrt(Box::new(inner_simp))
191        }
192        TLExpr::Exp(inner) => {
193            let inner_simp = simplify_algebraic_impl(inner, stats);
194            TLExpr::Exp(Box::new(inner_simp))
195        }
196        TLExpr::Log(inner) => {
197            let inner_simp = simplify_algebraic_impl(inner, stats);
198            TLExpr::Log(Box::new(inner_simp))
199        }
200        TLExpr::Sin(inner) => {
201            let inner_simp = simplify_algebraic_impl(inner, stats);
202            TLExpr::Sin(Box::new(inner_simp))
203        }
204        TLExpr::Cos(inner) => {
205            let inner_simp = simplify_algebraic_impl(inner, stats);
206            TLExpr::Cos(Box::new(inner_simp))
207        }
208        TLExpr::Tan(inner) => {
209            let inner_simp = simplify_algebraic_impl(inner, stats);
210            TLExpr::Tan(Box::new(inner_simp))
211        }
212
213        // Modulo
214        TLExpr::Mod(left, right) => {
215            let left_simp = simplify_algebraic_impl(left, stats);
216            let right_simp = simplify_algebraic_impl(right, stats);
217            TLExpr::Mod(Box::new(left_simp), Box::new(right_simp))
218        }
219
220        // Logical operations
221        TLExpr::And(left, right) => {
222            let left_simp = simplify_algebraic_impl(left, stats);
223            let right_simp = simplify_algebraic_impl(right, stats);
224            TLExpr::And(Box::new(left_simp), Box::new(right_simp))
225        }
226        TLExpr::Or(left, right) => {
227            let left_simp = simplify_algebraic_impl(left, stats);
228            let right_simp = simplify_algebraic_impl(right, stats);
229            TLExpr::Or(Box::new(left_simp), Box::new(right_simp))
230        }
231        TLExpr::Not(inner) => {
232            let inner_simp = simplify_algebraic_impl(inner, stats);
233            TLExpr::Not(Box::new(inner_simp))
234        }
235        TLExpr::Imply(left, right) => {
236            let left_simp = simplify_algebraic_impl(left, stats);
237            let right_simp = simplify_algebraic_impl(right, stats);
238            TLExpr::Imply(Box::new(left_simp), Box::new(right_simp))
239        }
240
241        // Comparison operations
242        TLExpr::Eq(left, right) => {
243            let left_simp = simplify_algebraic_impl(left, stats);
244            let right_simp = simplify_algebraic_impl(right, stats);
245            TLExpr::Eq(Box::new(left_simp), Box::new(right_simp))
246        }
247        TLExpr::Lt(left, right) => {
248            let left_simp = simplify_algebraic_impl(left, stats);
249            let right_simp = simplify_algebraic_impl(right, stats);
250            TLExpr::Lt(Box::new(left_simp), Box::new(right_simp))
251        }
252        TLExpr::Gt(left, right) => {
253            let left_simp = simplify_algebraic_impl(left, stats);
254            let right_simp = simplify_algebraic_impl(right, stats);
255            TLExpr::Gt(Box::new(left_simp), Box::new(right_simp))
256        }
257        TLExpr::Lte(left, right) => {
258            let left_simp = simplify_algebraic_impl(left, stats);
259            let right_simp = simplify_algebraic_impl(right, stats);
260            TLExpr::Lte(Box::new(left_simp), Box::new(right_simp))
261        }
262        TLExpr::Gte(left, right) => {
263            let left_simp = simplify_algebraic_impl(left, stats);
264            let right_simp = simplify_algebraic_impl(right, stats);
265            TLExpr::Gte(Box::new(left_simp), Box::new(right_simp))
266        }
267
268        // Quantifiers and other constructs
269        TLExpr::Exists { var, domain, body } => {
270            let body_simp = simplify_algebraic_impl(body, stats);
271            TLExpr::Exists {
272                var: var.clone(),
273                domain: domain.clone(),
274                body: Box::new(body_simp),
275            }
276        }
277        TLExpr::ForAll { var, domain, body } => {
278            let body_simp = simplify_algebraic_impl(body, stats);
279            TLExpr::ForAll {
280                var: var.clone(),
281                domain: domain.clone(),
282                body: Box::new(body_simp),
283            }
284        }
285        TLExpr::Aggregate {
286            op,
287            var,
288            domain,
289            body,
290            group_by,
291        } => {
292            let body_simp = simplify_algebraic_impl(body, stats);
293            TLExpr::Aggregate {
294                op: op.clone(),
295                var: var.clone(),
296                domain: domain.clone(),
297                body: Box::new(body_simp),
298                group_by: group_by.clone(),
299            }
300        }
301        TLExpr::IfThenElse {
302            condition,
303            then_branch,
304            else_branch,
305        } => {
306            let cond_simp = simplify_algebraic_impl(condition, stats);
307            let then_simp = simplify_algebraic_impl(then_branch, stats);
308            let else_simp = simplify_algebraic_impl(else_branch, stats);
309            TLExpr::IfThenElse {
310                condition: Box::new(cond_simp),
311                then_branch: Box::new(then_simp),
312                else_branch: Box::new(else_simp),
313            }
314        }
315        TLExpr::Let { var, value, body } => {
316            let value_simp = simplify_algebraic_impl(value, stats);
317            let body_simp = simplify_algebraic_impl(body, stats);
318            TLExpr::Let {
319                var: var.clone(),
320                value: Box::new(value_simp),
321                body: Box::new(body_simp),
322            }
323        }
324
325        // Fuzzy logic operators
326        TLExpr::TNorm { kind, left, right } => {
327            let left_simp = simplify_algebraic_impl(left, stats);
328            let right_simp = simplify_algebraic_impl(right, stats);
329            TLExpr::TNorm {
330                kind: *kind,
331                left: Box::new(left_simp),
332                right: Box::new(right_simp),
333            }
334        }
335        TLExpr::TCoNorm { kind, left, right } => {
336            let left_simp = simplify_algebraic_impl(left, stats);
337            let right_simp = simplify_algebraic_impl(right, stats);
338            TLExpr::TCoNorm {
339                kind: *kind,
340                left: Box::new(left_simp),
341                right: Box::new(right_simp),
342            }
343        }
344        TLExpr::FuzzyNot { kind, expr: inner } => {
345            let inner_simp = simplify_algebraic_impl(inner, stats);
346            TLExpr::FuzzyNot {
347                kind: *kind,
348                expr: Box::new(inner_simp),
349            }
350        }
351        TLExpr::FuzzyImplication {
352            kind,
353            premise,
354            conclusion,
355        } => {
356            let premise_simp = simplify_algebraic_impl(premise, stats);
357            let conclusion_simp = simplify_algebraic_impl(conclusion, stats);
358            TLExpr::FuzzyImplication {
359                kind: *kind,
360                premise: Box::new(premise_simp),
361                conclusion: Box::new(conclusion_simp),
362            }
363        }
364        TLExpr::SoftExists {
365            var,
366            domain,
367            body,
368            temperature,
369        } => {
370            let body_simp = simplify_algebraic_impl(body, stats);
371            TLExpr::SoftExists {
372                var: var.clone(),
373                domain: domain.clone(),
374                body: Box::new(body_simp),
375                temperature: *temperature,
376            }
377        }
378        TLExpr::SoftForAll {
379            var,
380            domain,
381            body,
382            temperature,
383        } => {
384            let body_simp = simplify_algebraic_impl(body, stats);
385            TLExpr::SoftForAll {
386                var: var.clone(),
387                domain: domain.clone(),
388                body: Box::new(body_simp),
389                temperature: *temperature,
390            }
391        }
392        TLExpr::WeightedRule { weight, rule } => {
393            let rule_simp = simplify_algebraic_impl(rule, stats);
394            TLExpr::WeightedRule {
395                weight: *weight,
396                rule: Box::new(rule_simp),
397            }
398        }
399        TLExpr::ProbabilisticChoice { alternatives } => {
400            let alts_simp: Vec<_> = alternatives
401                .iter()
402                .map(|(w, e)| (*w, simplify_algebraic_impl(e, stats)))
403                .collect();
404            TLExpr::ProbabilisticChoice {
405                alternatives: alts_simp,
406            }
407        }
408
409        // Modal/temporal logic operators - not yet implemented, pass through with recursion
410        TLExpr::Box(inner) => TLExpr::Box(Box::new(simplify_algebraic_impl(inner, stats))),
411        TLExpr::Diamond(inner) => TLExpr::Diamond(Box::new(simplify_algebraic_impl(inner, stats))),
412        TLExpr::Next(inner) => TLExpr::Next(Box::new(simplify_algebraic_impl(inner, stats))),
413        TLExpr::Eventually(inner) => {
414            TLExpr::Eventually(Box::new(simplify_algebraic_impl(inner, stats)))
415        }
416        TLExpr::Always(inner) => TLExpr::Always(Box::new(simplify_algebraic_impl(inner, stats))),
417        TLExpr::Until { before, after } => TLExpr::Until {
418            before: Box::new(simplify_algebraic_impl(before, stats)),
419            after: Box::new(simplify_algebraic_impl(after, stats)),
420        },
421        TLExpr::Release { released, releaser } => TLExpr::Release {
422            released: Box::new(simplify_algebraic_impl(released, stats)),
423            releaser: Box::new(simplify_algebraic_impl(releaser, stats)),
424        },
425        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
426            before: Box::new(simplify_algebraic_impl(before, stats)),
427            after: Box::new(simplify_algebraic_impl(after, stats)),
428        },
429        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
430            released: Box::new(simplify_algebraic_impl(released, stats)),
431            releaser: Box::new(simplify_algebraic_impl(releaser, stats)),
432        },
433
434        // Base cases
435        TLExpr::Pred { .. } | TLExpr::Constant(_) | TLExpr::Score(_) => expr.clone(),
436        // All other expression types (enhancements) - no algebraic simplification
437        _ => expr.clone(),
438    }
439}
440
441/// Check if an expression is constant zero
442fn is_zero(expr: &TLExpr) -> bool {
443    matches!(expr, TLExpr::Constant(x) if x.abs() < f64::EPSILON)
444}
445
446/// Check if an expression is constant one
447fn is_one(expr: &TLExpr) -> bool {
448    matches!(expr, TLExpr::Constant(x) if (x - 1.0).abs() < f64::EPSILON)
449}
450
451/// Check if two expressions are structurally equal (for idempotent simplification)
452fn expressions_equal(a: &TLExpr, b: &TLExpr) -> bool {
453    match (a, b) {
454        (TLExpr::Constant(x), TLExpr::Constant(y)) => (x - y).abs() < f64::EPSILON,
455        (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
456            n1 == n2 && a1 == a2
457        }
458        _ => false, // Conservative: only check simple cases
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    #[test]
    fn test_addition_identity() {
        // x + 0 = x
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Add(Box::new(x.clone()), Box::new(TLExpr::Constant(0.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.identities_eliminated, 1);
    }

    #[test]
    fn test_total_processed_counts_every_node() {
        // Add(x, 0) visits three nodes: the Add, the predicate and the constant.
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Add(Box::new(x), Box::new(TLExpr::Constant(0.0)));

        let (_, stats) = simplify_algebraic(&expr);
        assert_eq!(stats.total_processed, 3);
    }

    #[test]
    fn test_subtraction_identity() {
        // x - 0 = x
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Sub(Box::new(x), Box::new(TLExpr::Constant(0.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.identities_eliminated, 1);
    }

    #[test]
    fn test_zero_minus_x_is_kept() {
        // 0 - x is a negation, not x, so it must not be simplified away
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Sub(Box::new(TLExpr::Constant(0.0)), Box::new(x));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Sub(..)));
        assert_eq!(stats.identities_eliminated, 0);
    }

    #[test]
    fn test_multiplication_identity() {
        // x * 1 = x
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Mul(Box::new(x.clone()), Box::new(TLExpr::Constant(1.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.identities_eliminated, 1);
    }

    #[test]
    fn test_multiplication_annihilation() {
        // x * 0 = 0
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Mul(Box::new(x), Box::new(TLExpr::Constant(0.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Constant(0.0)));
        assert_eq!(stats.annihilations_applied, 1);
    }

    #[test]
    fn test_division_identity() {
        // x / 1 = x
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Div(Box::new(x), Box::new(TLExpr::Constant(1.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.identities_eliminated, 1);
    }

    #[test]
    fn test_zero_numerator() {
        // 0 / x = 0
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Div(Box::new(TLExpr::Constant(0.0)), Box::new(x));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Constant(0.0)));
        assert_eq!(stats.annihilations_applied, 1);
    }

    #[test]
    fn test_power_identities() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);

        // x^0 = 1
        let expr1 = TLExpr::Pow(Box::new(x.clone()), Box::new(TLExpr::Constant(0.0)));
        let (result1, stats1) = simplify_algebraic(&expr1);
        assert!(matches!(result1, TLExpr::Constant(1.0)));
        assert_eq!(stats1.identities_eliminated, 1);

        // x^1 = x
        let expr2 = TLExpr::Pow(Box::new(x), Box::new(TLExpr::Constant(1.0)));
        let (result2, stats2) = simplify_algebraic(&expr2);
        assert!(matches!(result2, TLExpr::Pred { .. }));
        assert_eq!(stats2.identities_eliminated, 1);
    }

    #[test]
    fn test_one_to_any_power() {
        // 1^x = 1
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Pow(Box::new(TLExpr::Constant(1.0)), Box::new(x));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Constant(1.0)));
        assert_eq!(stats.annihilations_applied, 1);
    }

    #[test]
    fn test_idempotent_min_max() {
        let x = TLExpr::pred("x", vec![Term::var("i")]);

        // min(x, x) = x
        let expr1 = TLExpr::Min(Box::new(x.clone()), Box::new(x.clone()));
        let (result1, stats1) = simplify_algebraic(&expr1);
        assert!(matches!(result1, TLExpr::Pred { .. }));
        assert_eq!(stats1.idempotent_simplified, 1);

        // max(x, x) = x
        let expr2 = TLExpr::Max(Box::new(x.clone()), Box::new(x));
        let (result2, stats2) = simplify_algebraic(&expr2);
        assert!(matches!(result2, TLExpr::Pred { .. }));
        assert_eq!(stats2.idempotent_simplified, 1);
    }

    #[test]
    fn test_min_of_distinct_operands_is_kept() {
        // min(x, y) has no idempotent shortcut
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let y = TLExpr::pred("y", vec![Term::var("i")]);
        let expr = TLExpr::Min(Box::new(x), Box::new(y));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Min(..)));
        assert_eq!(stats.idempotent_simplified, 0);
    }

    #[test]
    fn test_nested_abs_is_idempotent() {
        // abs(abs(x)) = abs(x)
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let expr = TLExpr::Abs(Box::new(TLExpr::Abs(Box::new(x))));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Abs(_)));
        assert_eq!(stats.idempotent_simplified, 1);
    }

    #[test]
    fn test_nested_simplification() {
        // (x + 0) * 1 = x
        let x = TLExpr::pred("x", vec![Term::var("i")]);
        let add = TLExpr::Add(Box::new(x), Box::new(TLExpr::Constant(0.0)));
        let expr = TLExpr::Mul(Box::new(add), Box::new(TLExpr::Constant(1.0)));

        let (result, stats) = simplify_algebraic(&expr);
        assert!(matches!(result, TLExpr::Pred { .. }));
        assert_eq!(stats.identities_eliminated, 2);
    }
}