1#![allow(dead_code)]
7
8use tensorlogic_ir::{TLExpr, Term};
9
10pub fn simplify_expression(expr: &TLExpr) -> TLExpr {
21 let expr = apply_constant_folding(expr);
22 let expr = apply_identity_laws(&expr);
23 let expr = apply_double_negation(&expr);
24 let expr = apply_idempotent_laws(&expr);
25 let expr = apply_absorption_laws(&expr);
26
27 apply_de_morgan(&expr)
28}
29
30fn apply_constant_folding(expr: &TLExpr) -> TLExpr {
37 match expr {
38 TLExpr::Add(left, right) => {
40 let left_folded = apply_constant_folding(left);
41 let right_folded = apply_constant_folding(right);
42 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
43 TLExpr::Constant(a + b)
44 } else {
45 TLExpr::Add(Box::new(left_folded), Box::new(right_folded))
46 }
47 }
48 TLExpr::Sub(left, right) => {
49 let left_folded = apply_constant_folding(left);
50 let right_folded = apply_constant_folding(right);
51 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
52 TLExpr::Constant(a - b)
53 } else {
54 TLExpr::Sub(Box::new(left_folded), Box::new(right_folded))
55 }
56 }
57 TLExpr::Mul(left, right) => {
58 let left_folded = apply_constant_folding(left);
59 let right_folded = apply_constant_folding(right);
60 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
61 TLExpr::Constant(a * b)
62 } else {
63 TLExpr::Mul(Box::new(left_folded), Box::new(right_folded))
64 }
65 }
66 TLExpr::Div(left, right) => {
67 let left_folded = apply_constant_folding(left);
68 let right_folded = apply_constant_folding(right);
69 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
70 if *b != 0.0 {
71 TLExpr::Constant(a / b)
72 } else {
73 TLExpr::Div(Box::new(left_folded), Box::new(right_folded))
74 }
75 } else {
76 TLExpr::Div(Box::new(left_folded), Box::new(right_folded))
77 }
78 }
79 TLExpr::Pow(left, right) => {
80 let left_folded = apply_constant_folding(left);
81 let right_folded = apply_constant_folding(right);
82 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
83 TLExpr::Constant(a.powf(*b))
84 } else {
85 TLExpr::Pow(Box::new(left_folded), Box::new(right_folded))
86 }
87 }
88 TLExpr::Min(left, right) => {
89 let left_folded = apply_constant_folding(left);
90 let right_folded = apply_constant_folding(right);
91 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
92 TLExpr::Constant(a.min(*b))
93 } else {
94 TLExpr::Min(Box::new(left_folded), Box::new(right_folded))
95 }
96 }
97 TLExpr::Max(left, right) => {
98 let left_folded = apply_constant_folding(left);
99 let right_folded = apply_constant_folding(right);
100 if let (TLExpr::Constant(a), TLExpr::Constant(b)) = (&left_folded, &right_folded) {
101 TLExpr::Constant(a.max(*b))
102 } else {
103 TLExpr::Max(Box::new(left_folded), Box::new(right_folded))
104 }
105 }
106 TLExpr::Abs(inner) => {
108 let inner_folded = apply_constant_folding(inner);
109 if let TLExpr::Constant(a) = inner_folded {
110 TLExpr::Constant(a.abs())
111 } else {
112 TLExpr::Abs(Box::new(inner_folded))
113 }
114 }
115 TLExpr::Sqrt(inner) => {
116 let inner_folded = apply_constant_folding(inner);
117 if let TLExpr::Constant(a) = inner_folded {
118 TLExpr::Constant(a.sqrt())
119 } else {
120 TLExpr::Sqrt(Box::new(inner_folded))
121 }
122 }
123 TLExpr::Exp(inner) => {
124 let inner_folded = apply_constant_folding(inner);
125 if let TLExpr::Constant(a) = inner_folded {
126 TLExpr::Constant(a.exp())
127 } else {
128 TLExpr::Exp(Box::new(inner_folded))
129 }
130 }
131 TLExpr::Log(inner) => {
132 let inner_folded = apply_constant_folding(inner);
133 if let TLExpr::Constant(a) = inner_folded {
134 TLExpr::Constant(a.ln())
135 } else {
136 TLExpr::Log(Box::new(inner_folded))
137 }
138 }
139 TLExpr::And(left, right) => TLExpr::And(
141 Box::new(apply_constant_folding(left)),
142 Box::new(apply_constant_folding(right)),
143 ),
144 TLExpr::Or(left, right) => TLExpr::Or(
145 Box::new(apply_constant_folding(left)),
146 Box::new(apply_constant_folding(right)),
147 ),
148 TLExpr::Not(inner) => TLExpr::Not(Box::new(apply_constant_folding(inner))),
149 TLExpr::Imply(left, right) => TLExpr::Imply(
150 Box::new(apply_constant_folding(left)),
151 Box::new(apply_constant_folding(right)),
152 ),
153 _ => expr.clone(),
154 }
155}
156
157fn apply_identity_laws(expr: &TLExpr) -> TLExpr {
167 match expr {
168 TLExpr::And(left, right) => {
169 let left_simplified = apply_identity_laws(left);
170 let right_simplified = apply_identity_laws(right);
171
172 if let TLExpr::Constant(c) = &right_simplified {
174 if (*c - 1.0).abs() < 1e-10 {
175 return left_simplified;
176 }
177 if c.abs() < 1e-10 {
179 return TLExpr::Constant(0.0);
180 }
181 }
182 if let TLExpr::Constant(c) = &left_simplified {
184 if (*c - 1.0).abs() < 1e-10 {
185 return right_simplified;
186 }
187 if c.abs() < 1e-10 {
189 return TLExpr::Constant(0.0);
190 }
191 }
192
193 TLExpr::And(Box::new(left_simplified), Box::new(right_simplified))
194 }
195 TLExpr::Or(left, right) => {
196 let left_simplified = apply_identity_laws(left);
197 let right_simplified = apply_identity_laws(right);
198
199 if let TLExpr::Constant(c) = &right_simplified {
201 if c.abs() < 1e-10 {
202 return left_simplified;
203 }
204 if (*c - 1.0).abs() < 1e-10 {
206 return TLExpr::Constant(1.0);
207 }
208 }
209 if let TLExpr::Constant(c) = &left_simplified {
211 if c.abs() < 1e-10 {
212 return right_simplified;
213 }
214 if (*c - 1.0).abs() < 1e-10 {
216 return TLExpr::Constant(1.0);
217 }
218 }
219
220 TLExpr::Or(Box::new(left_simplified), Box::new(right_simplified))
221 }
222 TLExpr::Not(inner) => TLExpr::Not(Box::new(apply_identity_laws(inner))),
223 TLExpr::Imply(left, right) => TLExpr::Imply(
224 Box::new(apply_identity_laws(left)),
225 Box::new(apply_identity_laws(right)),
226 ),
227 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
228 var: var.clone(),
229 domain: domain.clone(),
230 body: Box::new(apply_identity_laws(body)),
231 },
232 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
233 var: var.clone(),
234 domain: domain.clone(),
235 body: Box::new(apply_identity_laws(body)),
236 },
237 _ => expr.clone(),
238 }
239}
240
241fn apply_double_negation(expr: &TLExpr) -> TLExpr {
243 match expr {
244 TLExpr::Not(inner) => {
245 if let TLExpr::Not(inner_inner) = &**inner {
246 apply_double_negation(inner_inner)
247 } else {
248 TLExpr::Not(Box::new(apply_double_negation(inner)))
249 }
250 }
251 TLExpr::And(left, right) => TLExpr::And(
252 Box::new(apply_double_negation(left)),
253 Box::new(apply_double_negation(right)),
254 ),
255 TLExpr::Or(left, right) => TLExpr::Or(
256 Box::new(apply_double_negation(left)),
257 Box::new(apply_double_negation(right)),
258 ),
259 TLExpr::Imply(left, right) => TLExpr::Imply(
260 Box::new(apply_double_negation(left)),
261 Box::new(apply_double_negation(right)),
262 ),
263 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
264 var: var.clone(),
265 domain: domain.clone(),
266 body: Box::new(apply_double_negation(body)),
267 },
268 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
269 var: var.clone(),
270 domain: domain.clone(),
271 body: Box::new(apply_double_negation(body)),
272 },
273 _ => expr.clone(),
274 }
275}
276
277fn apply_idempotent_laws(expr: &TLExpr) -> TLExpr {
279 match expr {
280 TLExpr::And(left, right) => {
281 let left_simplified = apply_idempotent_laws(left);
282 let right_simplified = apply_idempotent_laws(right);
283
284 if expressions_equal(&left_simplified, &right_simplified) {
286 left_simplified
287 } else {
288 TLExpr::And(Box::new(left_simplified), Box::new(right_simplified))
289 }
290 }
291 TLExpr::Or(left, right) => {
292 let left_simplified = apply_idempotent_laws(left);
293 let right_simplified = apply_idempotent_laws(right);
294
295 if expressions_equal(&left_simplified, &right_simplified) {
297 left_simplified
298 } else {
299 TLExpr::Or(Box::new(left_simplified), Box::new(right_simplified))
300 }
301 }
302 TLExpr::Not(inner) => TLExpr::Not(Box::new(apply_idempotent_laws(inner))),
303 TLExpr::Imply(left, right) => TLExpr::Imply(
304 Box::new(apply_idempotent_laws(left)),
305 Box::new(apply_idempotent_laws(right)),
306 ),
307 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
308 var: var.clone(),
309 domain: domain.clone(),
310 body: Box::new(apply_idempotent_laws(body)),
311 },
312 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
313 var: var.clone(),
314 domain: domain.clone(),
315 body: Box::new(apply_idempotent_laws(body)),
316 },
317 _ => expr.clone(),
318 }
319}
320
321fn apply_absorption_laws(expr: &TLExpr) -> TLExpr {
323 match expr {
324 TLExpr::And(left, right) => {
325 let left_simplified = apply_absorption_laws(left);
326 let right_simplified = apply_absorption_laws(right);
327
328 if let TLExpr::Or(or_left, _or_right) = &right_simplified {
330 if expressions_equal(&left_simplified, or_left) {
331 return left_simplified;
332 }
333 }
334 if let TLExpr::Or(or_left, _or_right) = &left_simplified {
336 if expressions_equal(&right_simplified, or_left) {
337 return right_simplified;
338 }
339 }
340
341 TLExpr::And(Box::new(left_simplified), Box::new(right_simplified))
342 }
343 TLExpr::Or(left, right) => {
344 let left_simplified = apply_absorption_laws(left);
345 let right_simplified = apply_absorption_laws(right);
346
347 if let TLExpr::And(and_left, _and_right) = &right_simplified {
349 if expressions_equal(&left_simplified, and_left) {
350 return left_simplified;
351 }
352 }
353 if let TLExpr::And(and_left, _and_right) = &left_simplified {
355 if expressions_equal(&right_simplified, and_left) {
356 return right_simplified;
357 }
358 }
359
360 TLExpr::Or(Box::new(left_simplified), Box::new(right_simplified))
361 }
362 TLExpr::Not(inner) => TLExpr::Not(Box::new(apply_absorption_laws(inner))),
363 TLExpr::Imply(left, right) => TLExpr::Imply(
364 Box::new(apply_absorption_laws(left)),
365 Box::new(apply_absorption_laws(right)),
366 ),
367 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
368 var: var.clone(),
369 domain: domain.clone(),
370 body: Box::new(apply_absorption_laws(body)),
371 },
372 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
373 var: var.clone(),
374 domain: domain.clone(),
375 body: Box::new(apply_absorption_laws(body)),
376 },
377 _ => expr.clone(),
378 }
379}
380
381fn apply_de_morgan(expr: &TLExpr) -> TLExpr {
383 match expr {
384 TLExpr::Not(inner) => match &**inner {
385 TLExpr::And(left, right) => {
386 TLExpr::Or(
388 Box::new(apply_de_morgan(&TLExpr::Not(left.clone()))),
389 Box::new(apply_de_morgan(&TLExpr::Not(right.clone()))),
390 )
391 }
392 TLExpr::Or(left, right) => {
393 TLExpr::And(
395 Box::new(apply_de_morgan(&TLExpr::Not(left.clone()))),
396 Box::new(apply_de_morgan(&TLExpr::Not(right.clone()))),
397 )
398 }
399 _ => TLExpr::Not(Box::new(apply_de_morgan(inner))),
400 },
401 TLExpr::And(left, right) => TLExpr::And(
402 Box::new(apply_de_morgan(left)),
403 Box::new(apply_de_morgan(right)),
404 ),
405 TLExpr::Or(left, right) => TLExpr::Or(
406 Box::new(apply_de_morgan(left)),
407 Box::new(apply_de_morgan(right)),
408 ),
409 TLExpr::Imply(left, right) => TLExpr::Imply(
410 Box::new(apply_de_morgan(left)),
411 Box::new(apply_de_morgan(right)),
412 ),
413 TLExpr::Exists { var, domain, body } => TLExpr::Exists {
414 var: var.clone(),
415 domain: domain.clone(),
416 body: Box::new(apply_de_morgan(body)),
417 },
418 TLExpr::ForAll { var, domain, body } => TLExpr::ForAll {
419 var: var.clone(),
420 domain: domain.clone(),
421 body: Box::new(apply_de_morgan(body)),
422 },
423 _ => expr.clone(),
424 }
425}
426
427fn expressions_equal(left: &TLExpr, right: &TLExpr) -> bool {
429 match (left, right) {
430 (TLExpr::Pred { name: n1, args: a1 }, TLExpr::Pred { name: n2, args: a2 }) => {
431 n1 == n2 && terms_equal_vec(a1, a2)
432 }
433 (TLExpr::And(l1, r1), TLExpr::And(l2, r2)) => {
434 expressions_equal(l1, l2) && expressions_equal(r1, r2)
435 }
436 (TLExpr::Or(l1, r1), TLExpr::Or(l2, r2)) => {
437 expressions_equal(l1, l2) && expressions_equal(r1, r2)
438 }
439 (TLExpr::Not(e1), TLExpr::Not(e2)) => expressions_equal(e1, e2),
440 (TLExpr::Imply(l1, r1), TLExpr::Imply(l2, r2)) => {
441 expressions_equal(l1, l2) && expressions_equal(r1, r2)
442 }
443 (
444 TLExpr::Exists {
445 var: v1,
446 domain: d1,
447 body: b1,
448 },
449 TLExpr::Exists {
450 var: v2,
451 domain: d2,
452 body: b2,
453 },
454 ) => v1 == v2 && d1 == d2 && expressions_equal(b1, b2),
455 (
456 TLExpr::ForAll {
457 var: v1,
458 domain: d1,
459 body: b1,
460 },
461 TLExpr::ForAll {
462 var: v2,
463 domain: d2,
464 body: b2,
465 },
466 ) => v1 == v2 && d1 == d2 && expressions_equal(b1, b2),
467 _ => false,
468 }
469}
470
471fn terms_equal_vec(terms1: &[Term], terms2: &[Term]) -> bool {
473 if terms1.len() != terms2.len() {
474 return false;
475 }
476 terms1.iter().zip(terms2.iter()).all(|(t1, t2)| t1 == t2)
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the unary predicate `name(x)`.
    fn pred(name: &str) -> TLExpr {
        TLExpr::Pred {
            name: name.to_string(),
            args: vec![Term::Var("x".to_string())],
        }
    }

    /// Panics unless `expr` is a predicate called `expected`.
    fn assert_pred_named(expr: &TLExpr, expected: &str) {
        match expr {
            TLExpr::Pred { name, .. } => assert_eq!(name, expected),
            _ => panic!("expected predicate `{}`", expected),
        }
    }

    /// Panics unless `expr` is the negated predicate `¬expected(..)`.
    fn assert_negated_pred(expr: &TLExpr, expected: &str) {
        match expr {
            TLExpr::Not(inner) => assert_pred_named(inner, expected),
            _ => panic!("expected negated predicate `{}`", expected),
        }
    }

    #[test]
    fn test_double_negation() {
        let expr = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("P")))));

        let simplified = apply_double_negation(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_idempotent_and() {
        let expr = TLExpr::And(Box::new(pred("P")), Box::new(pred("P")));

        let simplified = apply_idempotent_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_idempotent_or() {
        let expr = TLExpr::Or(Box::new(pred("P")), Box::new(pred("P")));

        let simplified = apply_idempotent_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_idempotent_keeps_distinct_operands() {
        let expr = TLExpr::And(Box::new(pred("P")), Box::new(pred("Q")));

        let simplified = apply_idempotent_laws(&expr);

        assert!(matches!(simplified, TLExpr::And(_, _)));
    }

    #[test]
    fn test_simplify_complex_expression() {
        let double_neg = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("P")))));

        let simplified = simplify_expression(&double_neg);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_de_morgan_and() {
        // ¬(P ∧ Q) = ¬P ∨ ¬Q
        let expr = TLExpr::Not(Box::new(TLExpr::And(
            Box::new(pred("P")),
            Box::new(pred("Q")),
        )));

        match apply_de_morgan(&expr) {
            TLExpr::Or(left, right) => {
                assert_negated_pred(&left, "P");
                assert_negated_pred(&right, "Q");
            }
            _ => panic!("¬(P ∧ Q) should become ¬P ∨ ¬Q"),
        }
    }

    #[test]
    fn test_de_morgan_or() {
        // ¬(P ∨ Q) = ¬P ∧ ¬Q
        let expr = TLExpr::Not(Box::new(TLExpr::Or(
            Box::new(pred("P")),
            Box::new(pred("Q")),
        )));

        match apply_de_morgan(&expr) {
            TLExpr::And(left, right) => {
                assert_negated_pred(&left, "P");
                assert_negated_pred(&right, "Q");
            }
            _ => panic!("¬(P ∨ Q) should become ¬P ∧ ¬Q"),
        }
    }

    #[test]
    fn test_constant_folding_add() {
        let expr = TLExpr::Add(
            Box::new(TLExpr::Constant(2.0)),
            Box::new(TLExpr::Constant(3.0)),
        );

        let simplified = apply_constant_folding(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if (c - 5.0).abs() < 1e-10));
    }

    #[test]
    fn test_constant_folding_mul() {
        let expr = TLExpr::Mul(
            Box::new(TLExpr::Constant(4.0)),
            Box::new(TLExpr::Constant(5.0)),
        );

        let simplified = apply_constant_folding(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if (c - 20.0).abs() < 1e-10));
    }

    #[test]
    fn test_constant_folding_sqrt() {
        let expr = TLExpr::Sqrt(Box::new(TLExpr::Constant(16.0)));

        let simplified = apply_constant_folding(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if (c - 4.0).abs() < 1e-10));
    }

    #[test]
    fn test_constant_folding_nested() {
        // (2 * 3) + 4 = 10
        let expr = TLExpr::Add(
            Box::new(TLExpr::Mul(
                Box::new(TLExpr::Constant(2.0)),
                Box::new(TLExpr::Constant(3.0)),
            )),
            Box::new(TLExpr::Constant(4.0)),
        );

        let simplified = apply_constant_folding(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if (c - 10.0).abs() < 1e-10));
    }

    #[test]
    fn test_constant_folding_keeps_division_by_zero() {
        let expr = TLExpr::Div(
            Box::new(TLExpr::Constant(1.0)),
            Box::new(TLExpr::Constant(0.0)),
        );

        let simplified = apply_constant_folding(&expr);

        assert!(matches!(simplified, TLExpr::Div(_, _)));
    }

    #[test]
    fn test_identity_law_and_true() {
        let expr = TLExpr::And(Box::new(pred("P")), Box::new(TLExpr::Constant(1.0)));

        let simplified = apply_identity_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_identity_law_and_false() {
        let expr = TLExpr::And(Box::new(pred("P")), Box::new(TLExpr::Constant(0.0)));

        let simplified = apply_identity_laws(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if c.abs() < 1e-10));
    }

    #[test]
    fn test_identity_law_or_false() {
        let expr = TLExpr::Or(Box::new(pred("P")), Box::new(TLExpr::Constant(0.0)));

        let simplified = apply_identity_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_identity_law_or_true() {
        let expr = TLExpr::Or(Box::new(pred("P")), Box::new(TLExpr::Constant(1.0)));

        let simplified = apply_identity_laws(&expr);

        assert!(matches!(simplified, TLExpr::Constant(c) if (c - 1.0).abs() < 1e-10));
    }

    #[test]
    fn test_absorption_and_over_or() {
        // P ∧ (P ∨ Q) = P
        let expr = TLExpr::And(
            Box::new(pred("P")),
            Box::new(TLExpr::Or(Box::new(pred("P")), Box::new(pred("Q")))),
        );

        let simplified = apply_absorption_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_absorption_or_over_and() {
        // P ∨ (P ∧ Q) = P
        let expr = TLExpr::Or(
            Box::new(pred("P")),
            Box::new(TLExpr::And(Box::new(pred("P")), Box::new(pred("Q")))),
        );

        let simplified = apply_absorption_laws(&expr);

        assert_pred_named(&simplified, "P");
    }

    #[test]
    fn test_combined_simplification() {
        // ¬¬P ∧ 1 = P
        let double_neg = TLExpr::Not(Box::new(TLExpr::Not(Box::new(pred("P")))));
        let expr = TLExpr::And(Box::new(double_neg), Box::new(TLExpr::Constant(1.0)));

        let simplified = simplify_expression(&expr);

        assert_pred_named(&simplified, "P");
    }
}