Skip to main content

tensorlogic_ir/expr/optimization/
algebraic.rs

//! Algebraic simplification rules for logical and arithmetic expressions.
//!
//! This module implements algebraic identities and simplification rules that
//! transform expressions into simpler equivalent forms without changing semantics.
//! Logical rules treat the numeric constants `0.0` and `1.0` as FALSE and TRUE.

6use crate::expr::TLExpr;
7
8pub fn algebraic_simplify(expr: &TLExpr) -> TLExpr {
9    match expr {
10        // Addition identities
11        TLExpr::Add(l, r) => {
12            let left = algebraic_simplify(l);
13            let right = algebraic_simplify(r);
14
15            // x + 0 = x
16            if let TLExpr::Constant(0.0) = right {
17                return left;
18            }
19            // 0 + x = x
20            if let TLExpr::Constant(0.0) = left {
21                return right;
22            }
23
24            TLExpr::Add(Box::new(left), Box::new(right))
25        }
26
27        // Subtraction identities
28        TLExpr::Sub(l, r) => {
29            let left = algebraic_simplify(l);
30            let right = algebraic_simplify(r);
31
32            // x - 0 = x
33            if let TLExpr::Constant(0.0) = right {
34                return left;
35            }
36            // x - x = 0 (simplified form comparison)
37            if left == right {
38                return TLExpr::Constant(0.0);
39            }
40
41            TLExpr::Sub(Box::new(left), Box::new(right))
42        }
43
44        // Multiplication identities
45        TLExpr::Mul(l, r) => {
46            let left = algebraic_simplify(l);
47            let right = algebraic_simplify(r);
48
49            // x * 0 = 0
50            if let TLExpr::Constant(0.0) = right {
51                return TLExpr::Constant(0.0);
52            }
53            if let TLExpr::Constant(0.0) = left {
54                return TLExpr::Constant(0.0);
55            }
56
57            // x * 1 = x
58            if let TLExpr::Constant(1.0) = right {
59                return left;
60            }
61            // 1 * x = x
62            if let TLExpr::Constant(1.0) = left {
63                return right;
64            }
65
66            TLExpr::Mul(Box::new(left), Box::new(right))
67        }
68
69        // Division identities
70        TLExpr::Div(l, r) => {
71            let left = algebraic_simplify(l);
72            let right = algebraic_simplify(r);
73
74            // x / 1 = x
75            if let TLExpr::Constant(1.0) = right {
76                return left;
77            }
78
79            // 0 / x = 0 (assuming x != 0)
80            if let TLExpr::Constant(0.0) = left {
81                if let TLExpr::Constant(rv) = right {
82                    if rv != 0.0 {
83                        return TLExpr::Constant(0.0);
84                    }
85                }
86            }
87
88            // x / x = 1 (assuming x != 0)
89            // Only apply for constants to avoid division by zero issues
90            if left == right {
91                if let TLExpr::Constant(v) = left {
92                    if v != 0.0 {
93                        return TLExpr::Constant(1.0);
94                    }
95                }
96            }
97
98            TLExpr::Div(Box::new(left), Box::new(right))
99        }
100
101        // Power identities
102        TLExpr::Pow(l, r) => {
103            let left = algebraic_simplify(l);
104            let right = algebraic_simplify(r);
105
106            // x ^ 0 = 1
107            if let TLExpr::Constant(0.0) = right {
108                return TLExpr::Constant(1.0);
109            }
110            // x ^ 1 = x
111            if let TLExpr::Constant(1.0) = right {
112                return left;
113            }
114            // 0 ^ x = 0 (for x > 0)
115            if let TLExpr::Constant(0.0) = left {
116                if let TLExpr::Constant(rv) = right {
117                    if rv > 0.0 {
118                        return TLExpr::Constant(0.0);
119                    }
120                }
121            }
122            // 1 ^ x = 1
123            if let TLExpr::Constant(1.0) = left {
124                return TLExpr::Constant(1.0);
125            }
126
127            TLExpr::Pow(Box::new(left), Box::new(right))
128        }
129
130        // Double negation: NOT(NOT(x)) = x
131        TLExpr::Not(e) => {
132            let inner = algebraic_simplify(e);
133            if let TLExpr::Not(inner_inner) = &inner {
134                return *inner_inner.clone();
135            }
136            TLExpr::Not(Box::new(inner))
137        }
138
139        // Recursively simplify other operations
140        TLExpr::Mod(l, r) => {
141            let left = algebraic_simplify(l);
142            let right = algebraic_simplify(r);
143            TLExpr::Mod(Box::new(left), Box::new(right))
144        }
145        TLExpr::Min(l, r) => {
146            let left = algebraic_simplify(l);
147            let right = algebraic_simplify(r);
148            TLExpr::Min(Box::new(left), Box::new(right))
149        }
150        TLExpr::Max(l, r) => {
151            let left = algebraic_simplify(l);
152            let right = algebraic_simplify(r);
153            TLExpr::Max(Box::new(left), Box::new(right))
154        }
155        TLExpr::Abs(e) => TLExpr::Abs(Box::new(algebraic_simplify(e))),
156        TLExpr::Floor(e) => TLExpr::Floor(Box::new(algebraic_simplify(e))),
157        TLExpr::Ceil(e) => TLExpr::Ceil(Box::new(algebraic_simplify(e))),
158        TLExpr::Round(e) => TLExpr::Round(Box::new(algebraic_simplify(e))),
159        TLExpr::Sqrt(e) => TLExpr::Sqrt(Box::new(algebraic_simplify(e))),
160        // Modal logic simplifications
161        TLExpr::Box(e) => {
162            let inner = algebraic_simplify(e);
163
164            // □(TRUE) = TRUE, □(FALSE) = FALSE
165            if let TLExpr::Constant(v) = inner {
166                return TLExpr::Constant(v);
167            }
168
169            TLExpr::Box(Box::new(inner))
170        }
171        TLExpr::Diamond(e) => {
172            let inner = algebraic_simplify(e);
173
174            // ◇(TRUE) = TRUE, ◇(FALSE) = FALSE
175            if let TLExpr::Constant(v) = inner {
176                return TLExpr::Constant(v);
177            }
178
179            TLExpr::Diamond(Box::new(inner))
180        }
181
182        // Temporal logic simplifications
183        TLExpr::Next(e) => {
184            let inner = algebraic_simplify(e);
185
186            // X(TRUE) = TRUE, X(FALSE) = FALSE
187            if let TLExpr::Constant(v) = inner {
188                return TLExpr::Constant(v);
189            }
190
191            TLExpr::Next(Box::new(inner))
192        }
193        TLExpr::Eventually(e) => {
194            let inner = algebraic_simplify(e);
195
196            // F(TRUE) = TRUE, F(FALSE) = FALSE
197            if let TLExpr::Constant(v) = inner {
198                return TLExpr::Constant(v);
199            }
200
201            // Idempotence: F(F(P)) = F(P)
202            if let TLExpr::Eventually(inner_inner) = &inner {
203                return TLExpr::Eventually(inner_inner.clone());
204            }
205
206            TLExpr::Eventually(Box::new(inner))
207        }
208        TLExpr::Always(e) => {
209            let inner = algebraic_simplify(e);
210
211            // G(TRUE) = TRUE, G(FALSE) = FALSE
212            if let TLExpr::Constant(v) = inner {
213                return TLExpr::Constant(v);
214            }
215
216            // Idempotence: G(G(P)) = G(P)
217            if let TLExpr::Always(inner_inner) = &inner {
218                return TLExpr::Always(inner_inner.clone());
219            }
220
221            TLExpr::Always(Box::new(inner))
222        }
223        TLExpr::Until { before, after } => {
224            let before_simplified = algebraic_simplify(before);
225            let after_simplified = algebraic_simplify(after);
226
227            // P U TRUE = TRUE (after becomes immediately true)
228            if let TLExpr::Constant(1.0) = after_simplified {
229                return TLExpr::Constant(1.0);
230            }
231
232            // FALSE U P = F(P) (before is never true, so we just wait for after)
233            if let TLExpr::Constant(0.0) = before_simplified {
234                return TLExpr::Eventually(Box::new(after_simplified));
235            }
236
237            TLExpr::Until {
238                before: Box::new(before_simplified),
239                after: Box::new(after_simplified),
240            }
241        }
242
243        // Fuzzy logic operators - pass through with recursive simplification
244        TLExpr::TNorm { kind, left, right } => TLExpr::TNorm {
245            kind: *kind,
246            left: Box::new(algebraic_simplify(left)),
247            right: Box::new(algebraic_simplify(right)),
248        },
249        TLExpr::TCoNorm { kind, left, right } => TLExpr::TCoNorm {
250            kind: *kind,
251            left: Box::new(algebraic_simplify(left)),
252            right: Box::new(algebraic_simplify(right)),
253        },
254        TLExpr::FuzzyNot { kind, expr } => TLExpr::FuzzyNot {
255            kind: *kind,
256            expr: Box::new(algebraic_simplify(expr)),
257        },
258        TLExpr::FuzzyImplication {
259            kind,
260            premise,
261            conclusion,
262        } => TLExpr::FuzzyImplication {
263            kind: *kind,
264            premise: Box::new(algebraic_simplify(premise)),
265            conclusion: Box::new(algebraic_simplify(conclusion)),
266        },
267
268        // Probabilistic operators - pass through
269        TLExpr::SoftExists {
270            var,
271            domain,
272            body,
273            temperature,
274        } => TLExpr::SoftExists {
275            var: var.clone(),
276            domain: domain.clone(),
277            body: Box::new(algebraic_simplify(body)),
278            temperature: *temperature,
279        },
280        TLExpr::SoftForAll {
281            var,
282            domain,
283            body,
284            temperature,
285        } => TLExpr::SoftForAll {
286            var: var.clone(),
287            domain: domain.clone(),
288            body: Box::new(algebraic_simplify(body)),
289            temperature: *temperature,
290        },
291        TLExpr::WeightedRule { weight, rule } => TLExpr::WeightedRule {
292            weight: *weight,
293            rule: Box::new(algebraic_simplify(rule)),
294        },
295        TLExpr::ProbabilisticChoice { alternatives } => TLExpr::ProbabilisticChoice {
296            alternatives: alternatives
297                .iter()
298                .map(|(p, e)| (*p, algebraic_simplify(e)))
299                .collect(),
300        },
301
302        // Extended temporal logic - pass through
303        TLExpr::Release { released, releaser } => TLExpr::Release {
304            released: Box::new(algebraic_simplify(released)),
305            releaser: Box::new(algebraic_simplify(releaser)),
306        },
307        TLExpr::WeakUntil { before, after } => TLExpr::WeakUntil {
308            before: Box::new(algebraic_simplify(before)),
309            after: Box::new(algebraic_simplify(after)),
310        },
311        TLExpr::StrongRelease { released, releaser } => TLExpr::StrongRelease {
312            released: Box::new(algebraic_simplify(released)),
313            releaser: Box::new(algebraic_simplify(releaser)),
314        },
315
316        TLExpr::Exp(e) => TLExpr::Exp(Box::new(algebraic_simplify(e))),
317        TLExpr::Log(e) => TLExpr::Log(Box::new(algebraic_simplify(e))),
318        TLExpr::Sin(e) => TLExpr::Sin(Box::new(algebraic_simplify(e))),
319        TLExpr::Cos(e) => TLExpr::Cos(Box::new(algebraic_simplify(e))),
320        TLExpr::Tan(e) => TLExpr::Tan(Box::new(algebraic_simplify(e))),
321        // EQ simplifications
322        TLExpr::Eq(l, r) => {
323            let left = algebraic_simplify(l);
324            let right = algebraic_simplify(r);
325
326            // x = x → TRUE
327            if left == right {
328                return TLExpr::Constant(1.0);
329            }
330
331            TLExpr::Eq(Box::new(left), Box::new(right))
332        }
333
334        // LT simplifications
335        TLExpr::Lt(l, r) => {
336            let left = algebraic_simplify(l);
337            let right = algebraic_simplify(r);
338
339            // x < x → FALSE
340            if left == right {
341                return TLExpr::Constant(0.0);
342            }
343
344            TLExpr::Lt(Box::new(left), Box::new(right))
345        }
346
347        // GT simplifications
348        TLExpr::Gt(l, r) => {
349            let left = algebraic_simplify(l);
350            let right = algebraic_simplify(r);
351
352            // x > x → FALSE
353            if left == right {
354                return TLExpr::Constant(0.0);
355            }
356
357            TLExpr::Gt(Box::new(left), Box::new(right))
358        }
359
360        // LTE simplifications
361        TLExpr::Lte(l, r) => {
362            let left = algebraic_simplify(l);
363            let right = algebraic_simplify(r);
364
365            // x <= x → TRUE
366            if left == right {
367                return TLExpr::Constant(1.0);
368            }
369
370            TLExpr::Lte(Box::new(left), Box::new(right))
371        }
372
373        // GTE simplifications
374        TLExpr::Gte(l, r) => {
375            let left = algebraic_simplify(l);
376            let right = algebraic_simplify(r);
377
378            // x >= x → TRUE
379            if left == right {
380                return TLExpr::Constant(1.0);
381            }
382
383            TLExpr::Gte(Box::new(left), Box::new(right))
384        }
385        // AND logical laws
386        TLExpr::And(l, r) => {
387            let left = algebraic_simplify(l);
388            let right = algebraic_simplify(r);
389
390            // Idempotence: A ∧ A = A
391            if left == right {
392                return left;
393            }
394
395            // Identity: A ∧ TRUE = A, TRUE ∧ A = A
396            if let TLExpr::Constant(1.0) = right {
397                return left;
398            }
399            if let TLExpr::Constant(1.0) = left {
400                return right;
401            }
402
403            // Annihilation: A ∧ FALSE = FALSE, FALSE ∧ A = FALSE
404            if let TLExpr::Constant(0.0) = right {
405                return TLExpr::Constant(0.0);
406            }
407            if let TLExpr::Constant(0.0) = left {
408                return TLExpr::Constant(0.0);
409            }
410
411            // Complement: A ∧ ¬A = FALSE
412            if let TLExpr::Not(inner) = &right {
413                if **inner == left {
414                    return TLExpr::Constant(0.0);
415                }
416            }
417            if let TLExpr::Not(inner) = &left {
418                if **inner == right {
419                    return TLExpr::Constant(0.0);
420                }
421            }
422
423            // Absorption: A ∧ (A ∨ B) = A
424            if let TLExpr::Or(or_left, _or_right) = &right {
425                if **or_left == left {
426                    return left;
427                }
428            }
429            if let TLExpr::Or(or_left, _or_right) = &left {
430                if **or_left == right {
431                    return right;
432                }
433            }
434
435            TLExpr::And(Box::new(left), Box::new(right))
436        }
437
438        // OR logical laws
439        TLExpr::Or(l, r) => {
440            let left = algebraic_simplify(l);
441            let right = algebraic_simplify(r);
442
443            // Idempotence: A ∨ A = A
444            if left == right {
445                return left;
446            }
447
448            // Annihilation: A ∨ TRUE = TRUE, TRUE ∨ A = TRUE
449            if let TLExpr::Constant(1.0) = right {
450                return TLExpr::Constant(1.0);
451            }
452            if let TLExpr::Constant(1.0) = left {
453                return TLExpr::Constant(1.0);
454            }
455
456            // Identity: A ∨ FALSE = A, FALSE ∨ A = A
457            if let TLExpr::Constant(0.0) = right {
458                return left;
459            }
460            if let TLExpr::Constant(0.0) = left {
461                return right;
462            }
463
464            // Complement: A ∨ ¬A = TRUE
465            if let TLExpr::Not(inner) = &right {
466                if **inner == left {
467                    return TLExpr::Constant(1.0);
468                }
469            }
470            if let TLExpr::Not(inner) = &left {
471                if **inner == right {
472                    return TLExpr::Constant(1.0);
473                }
474            }
475
476            // Absorption: A ∨ (A ∧ B) = A
477            if let TLExpr::And(and_left, _and_right) = &right {
478                if **and_left == left {
479                    return left;
480                }
481            }
482            if let TLExpr::And(and_left, _and_right) = &left {
483                if **and_left == right {
484                    return right;
485                }
486            }
487
488            TLExpr::Or(Box::new(left), Box::new(right))
489        }
490
491        // IMPLY simplifications
492        TLExpr::Imply(l, r) => {
493            let left = algebraic_simplify(l);
494            let right = algebraic_simplify(r);
495
496            // TRUE → P = P
497            if let TLExpr::Constant(1.0) = left {
498                return right;
499            }
500
501            // FALSE → P = TRUE
502            if let TLExpr::Constant(0.0) = left {
503                return TLExpr::Constant(1.0);
504            }
505
506            // P → TRUE = TRUE
507            if let TLExpr::Constant(1.0) = right {
508                return TLExpr::Constant(1.0);
509            }
510
511            // P → FALSE = ¬P
512            if let TLExpr::Constant(0.0) = right {
513                return TLExpr::negate(left);
514            }
515
516            // P → P = TRUE
517            if left == right {
518                return TLExpr::Constant(1.0);
519            }
520
521            TLExpr::Imply(Box::new(left), Box::new(right))
522        }
523        TLExpr::Score(e) => TLExpr::Score(Box::new(algebraic_simplify(e))),
524        TLExpr::Exists { var, domain, body } => TLExpr::Exists {
525            var: var.clone(),
526            domain: domain.clone(),
527            body: Box::new(algebraic_simplify(body)),
528        },
529        TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
530            var: var.clone(),
531            domain: domain.clone(),
532            body: Box::new(algebraic_simplify(body)),
533        },
534        TLExpr::Aggregate {
535            op,
536            var,
537            domain,
538            body,
539            group_by,
540        } => TLExpr::Aggregate {
541            op: op.clone(),
542            var: var.clone(),
543            domain: domain.clone(),
544            body: Box::new(algebraic_simplify(body)),
545            group_by: group_by.clone(),
546        },
547        TLExpr::IfThenElse {
548            condition,
549            then_branch,
550            else_branch,
551        } => TLExpr::IfThenElse {
552            condition: Box::new(algebraic_simplify(condition)),
553            then_branch: Box::new(algebraic_simplify(then_branch)),
554            else_branch: Box::new(algebraic_simplify(else_branch)),
555        },
556        TLExpr::Let { var, value, body } => TLExpr::Let {
557            var: var.clone(),
558            value: Box::new(algebraic_simplify(value)),
559            body: Box::new(algebraic_simplify(body)),
560        },
561
562        // Alpha.3 enhancements
563        TLExpr::Lambda {
564            var,
565            var_type,
566            body,
567        } => TLExpr::lambda(var.clone(), var_type.clone(), algebraic_simplify(body)),
568        TLExpr::Apply { function, argument } => {
569            TLExpr::apply(algebraic_simplify(function), algebraic_simplify(argument))
570        }
571        TLExpr::SetMembership { element, set } => {
572            TLExpr::set_membership(algebraic_simplify(element), algebraic_simplify(set))
573        }
574        TLExpr::SetUnion { left, right } => {
575            TLExpr::set_union(algebraic_simplify(left), algebraic_simplify(right))
576        }
577        TLExpr::SetIntersection { left, right } => {
578            TLExpr::set_intersection(algebraic_simplify(left), algebraic_simplify(right))
579        }
580        TLExpr::SetDifference { left, right } => {
581            TLExpr::set_difference(algebraic_simplify(left), algebraic_simplify(right))
582        }
583        TLExpr::SetCardinality { set } => TLExpr::set_cardinality(algebraic_simplify(set)),
584        TLExpr::EmptySet => expr.clone(),
585        TLExpr::SetComprehension {
586            var,
587            domain,
588            condition,
589        } => TLExpr::set_comprehension(var.clone(), domain.clone(), algebraic_simplify(condition)),
590        TLExpr::CountingExists {
591            var,
592            domain,
593            body,
594            min_count,
595        } => TLExpr::counting_exists(
596            var.clone(),
597            domain.clone(),
598            algebraic_simplify(body),
599            *min_count,
600        ),
601        TLExpr::CountingForAll {
602            var,
603            domain,
604            body,
605            min_count,
606        } => TLExpr::counting_forall(
607            var.clone(),
608            domain.clone(),
609            algebraic_simplify(body),
610            *min_count,
611        ),
612        TLExpr::ExactCount {
613            var,
614            domain,
615            body,
616            count,
617        } => TLExpr::exact_count(
618            var.clone(),
619            domain.clone(),
620            algebraic_simplify(body),
621            *count,
622        ),
623        TLExpr::Majority { var, domain, body } => {
624            TLExpr::majority(var.clone(), domain.clone(), algebraic_simplify(body))
625        }
626        TLExpr::LeastFixpoint { var, body } => {
627            TLExpr::least_fixpoint(var.clone(), algebraic_simplify(body))
628        }
629        TLExpr::GreatestFixpoint { var, body } => {
630            TLExpr::greatest_fixpoint(var.clone(), algebraic_simplify(body))
631        }
632        TLExpr::Nominal { .. } => expr.clone(),
633        TLExpr::At { nominal, formula } => TLExpr::at(nominal.clone(), algebraic_simplify(formula)),
634        TLExpr::Somewhere { formula } => TLExpr::somewhere(algebraic_simplify(formula)),
635        TLExpr::Everywhere { formula } => TLExpr::everywhere(algebraic_simplify(formula)),
636        TLExpr::AllDifferent { .. } => expr.clone(),
637        TLExpr::GlobalCardinality {
638            variables,
639            values,
640            min_occurrences,
641            max_occurrences,
642        } => TLExpr::global_cardinality(
643            variables.clone(),
644            values.iter().map(algebraic_simplify).collect(),
645            min_occurrences.clone(),
646            max_occurrences.clone(),
647        ),
648        TLExpr::Abducible { .. } => expr.clone(),
649        TLExpr::Explain { formula } => TLExpr::explain(algebraic_simplify(formula)),
650
651        // Leaves
652        TLExpr::Pred { .. } | TLExpr::Constant(_) => expr.clone(),
653    }
654}