1use tensorlogic_ir::TLExpr;
23
/// Counters collected while running the distributivity pass over an expression.
///
/// All counters start at zero (see [`Default`]). The type derives `PartialEq`
/// and `Eq` so that two snapshots can be compared directly, e.g. in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DistributivityStats {
    /// Number of factoring rewrites applied (incremented once each time one of
    /// the `try_factor_*` helpers succeeds).
    pub expressions_factored: usize,
    /// Number of expansion rewrites applied.
    ///
    /// NOTE(review): nothing in the visible pass increments this counter.
    pub expressions_expanded: usize,
    /// Number of common-term extractions applied.
    ///
    /// NOTE(review): nothing in the visible pass increments this counter.
    pub common_terms_extracted: usize,
    /// Number of expression nodes visited by the recursive worker.
    pub total_processed: usize,
}

impl DistributivityStats {
    /// Returns the total number of rewrites recorded.
    ///
    /// This is the sum of `expressions_factored`, `expressions_expanded` and
    /// `common_terms_extracted`; `total_processed` is deliberately not part of
    /// the sum because it counts visited nodes, not rewrites.
    pub fn total_optimizations(&self) -> usize {
        self.expressions_factored + self.expressions_expanded + self.common_terms_extracted
    }
}
43
44pub fn optimize_distributivity(expr: &TLExpr) -> (TLExpr, DistributivityStats) {
57 let mut stats = DistributivityStats::default();
58 let result = optimize_distributivity_impl(expr, &mut stats);
59 (result, stats)
60}
61
62fn optimize_distributivity_impl(expr: &TLExpr, stats: &mut DistributivityStats) -> TLExpr {
63 stats.total_processed += 1;
64
65 match expr {
66 TLExpr::Add(lhs, rhs) => {
68 let lhs_opt = optimize_distributivity_impl(lhs, stats);
69 let rhs_opt = optimize_distributivity_impl(rhs, stats);
70
71 if let Some(factored) = try_factor_add(&lhs_opt, &rhs_opt) {
73 stats.expressions_factored += 1;
74 return factored;
75 }
76
77 TLExpr::Add(Box::new(lhs_opt), Box::new(rhs_opt))
78 }
79
80 TLExpr::Sub(lhs, rhs) => {
82 let lhs_opt = optimize_distributivity_impl(lhs, stats);
83 let rhs_opt = optimize_distributivity_impl(rhs, stats);
84
85 if let Some(factored) = try_factor_sub(&lhs_opt, &rhs_opt) {
87 stats.expressions_factored += 1;
88 return factored;
89 }
90
91 TLExpr::Sub(Box::new(lhs_opt), Box::new(rhs_opt))
92 }
93
94 TLExpr::Mul(lhs, rhs) => {
96 let lhs_opt = optimize_distributivity_impl(lhs, stats);
97 let rhs_opt = optimize_distributivity_impl(rhs, stats);
98
99 TLExpr::Mul(Box::new(lhs_opt), Box::new(rhs_opt))
102 }
103
104 TLExpr::And(lhs, rhs) => {
106 let lhs_opt = optimize_distributivity_impl(lhs, stats);
107 let rhs_opt = optimize_distributivity_impl(rhs, stats);
108
109 if let Some(factored) = try_factor_and(&lhs_opt, &rhs_opt) {
111 stats.expressions_factored += 1;
112 return factored;
113 }
114
115 TLExpr::And(Box::new(lhs_opt), Box::new(rhs_opt))
116 }
117
118 TLExpr::Or(lhs, rhs) => {
120 let lhs_opt = optimize_distributivity_impl(lhs, stats);
121 let rhs_opt = optimize_distributivity_impl(rhs, stats);
122
123 if let Some(factored) = try_factor_or(&lhs_opt, &rhs_opt) {
125 stats.expressions_factored += 1;
126 return factored;
127 }
128
129 TLExpr::Or(Box::new(lhs_opt), Box::new(rhs_opt))
130 }
131
132 TLExpr::Not(inner) => {
134 let inner_opt = optimize_distributivity_impl(inner, stats);
135 TLExpr::Not(Box::new(inner_opt))
136 }
137
138 TLExpr::Imply(lhs, rhs) => {
139 let lhs_opt = optimize_distributivity_impl(lhs, stats);
140 let rhs_opt = optimize_distributivity_impl(rhs, stats);
141 TLExpr::Imply(Box::new(lhs_opt), Box::new(rhs_opt))
142 }
143
144 TLExpr::Div(lhs, rhs) => {
145 let lhs_opt = optimize_distributivity_impl(lhs, stats);
146 let rhs_opt = optimize_distributivity_impl(rhs, stats);
147 TLExpr::Div(Box::new(lhs_opt), Box::new(rhs_opt))
148 }
149
150 TLExpr::Pow(base, exp) => {
151 let base_opt = optimize_distributivity_impl(base, stats);
152 let exp_opt = optimize_distributivity_impl(exp, stats);
153 TLExpr::Pow(Box::new(base_opt), Box::new(exp_opt))
154 }
155
156 TLExpr::Abs(inner) => {
157 let inner_opt = optimize_distributivity_impl(inner, stats);
158 TLExpr::Abs(Box::new(inner_opt))
159 }
160
161 TLExpr::Sqrt(inner) => {
162 let inner_opt = optimize_distributivity_impl(inner, stats);
163 TLExpr::Sqrt(Box::new(inner_opt))
164 }
165
166 TLExpr::Exp(inner) => {
167 let inner_opt = optimize_distributivity_impl(inner, stats);
168 TLExpr::Exp(Box::new(inner_opt))
169 }
170
171 TLExpr::Log(inner) => {
172 let inner_opt = optimize_distributivity_impl(inner, stats);
173 TLExpr::Log(Box::new(inner_opt))
174 }
175
176 TLExpr::Exists { var, domain, body } => {
177 let body_opt = optimize_distributivity_impl(body, stats);
178 TLExpr::Exists {
179 var: var.clone(),
180 domain: domain.clone(),
181 body: Box::new(body_opt),
182 }
183 }
184
185 TLExpr::ForAll { var, domain, body } => {
186 let body_opt = optimize_distributivity_impl(body, stats);
187 TLExpr::ForAll {
188 var: var.clone(),
189 domain: domain.clone(),
190 body: Box::new(body_opt),
191 }
192 }
193
194 TLExpr::Let { var, value, body } => {
195 let value_opt = optimize_distributivity_impl(value, stats);
196 let body_opt = optimize_distributivity_impl(body, stats);
197 TLExpr::Let {
198 var: var.clone(),
199 value: Box::new(value_opt),
200 body: Box::new(body_opt),
201 }
202 }
203
204 TLExpr::IfThenElse {
205 condition,
206 then_branch,
207 else_branch,
208 } => {
209 let cond_opt = optimize_distributivity_impl(condition, stats);
210 let then_opt = optimize_distributivity_impl(then_branch, stats);
211 let else_opt = optimize_distributivity_impl(else_branch, stats);
212 TLExpr::IfThenElse {
213 condition: Box::new(cond_opt),
214 then_branch: Box::new(then_opt),
215 else_branch: Box::new(else_opt),
216 }
217 }
218
219 TLExpr::Eq(lhs, rhs) => {
221 let lhs_opt = optimize_distributivity_impl(lhs, stats);
222 let rhs_opt = optimize_distributivity_impl(rhs, stats);
223 TLExpr::Eq(Box::new(lhs_opt), Box::new(rhs_opt))
224 }
225
226 TLExpr::Lt(lhs, rhs) => {
227 let lhs_opt = optimize_distributivity_impl(lhs, stats);
228 let rhs_opt = optimize_distributivity_impl(rhs, stats);
229 TLExpr::Lt(Box::new(lhs_opt), Box::new(rhs_opt))
230 }
231
232 TLExpr::Lte(lhs, rhs) => {
233 let lhs_opt = optimize_distributivity_impl(lhs, stats);
234 let rhs_opt = optimize_distributivity_impl(rhs, stats);
235 TLExpr::Lte(Box::new(lhs_opt), Box::new(rhs_opt))
236 }
237
238 TLExpr::Gt(lhs, rhs) => {
239 let lhs_opt = optimize_distributivity_impl(lhs, stats);
240 let rhs_opt = optimize_distributivity_impl(rhs, stats);
241 TLExpr::Gt(Box::new(lhs_opt), Box::new(rhs_opt))
242 }
243
244 TLExpr::Gte(lhs, rhs) => {
245 let lhs_opt = optimize_distributivity_impl(lhs, stats);
246 let rhs_opt = optimize_distributivity_impl(rhs, stats);
247 TLExpr::Gte(Box::new(lhs_opt), Box::new(rhs_opt))
248 }
249
250 TLExpr::Min(lhs, rhs) => {
252 let lhs_opt = optimize_distributivity_impl(lhs, stats);
253 let rhs_opt = optimize_distributivity_impl(rhs, stats);
254 TLExpr::Min(Box::new(lhs_opt), Box::new(rhs_opt))
255 }
256
257 TLExpr::Max(lhs, rhs) => {
258 let lhs_opt = optimize_distributivity_impl(lhs, stats);
259 let rhs_opt = optimize_distributivity_impl(rhs, stats);
260 TLExpr::Max(Box::new(lhs_opt), Box::new(rhs_opt))
261 }
262
263 TLExpr::Box(inner) => {
265 let inner_opt = optimize_distributivity_impl(inner, stats);
266 TLExpr::Box(Box::new(inner_opt))
267 }
268
269 TLExpr::Diamond(inner) => {
270 let inner_opt = optimize_distributivity_impl(inner, stats);
271 TLExpr::Diamond(Box::new(inner_opt))
272 }
273
274 TLExpr::Next(inner) => {
276 let inner_opt = optimize_distributivity_impl(inner, stats);
277 TLExpr::Next(Box::new(inner_opt))
278 }
279
280 TLExpr::Eventually(inner) => {
281 let inner_opt = optimize_distributivity_impl(inner, stats);
282 TLExpr::Eventually(Box::new(inner_opt))
283 }
284
285 TLExpr::Always(inner) => {
286 let inner_opt = optimize_distributivity_impl(inner, stats);
287 TLExpr::Always(Box::new(inner_opt))
288 }
289
290 TLExpr::Until { before, after } => {
291 let before_opt = optimize_distributivity_impl(before, stats);
292 let after_opt = optimize_distributivity_impl(after, stats);
293 TLExpr::Until {
294 before: Box::new(before_opt),
295 after: Box::new(after_opt),
296 }
297 }
298
299 TLExpr::Pred { .. }
301 | TLExpr::Constant(_)
302 | TLExpr::Score(_)
303 | TLExpr::Mod(_, _)
304 | TLExpr::Floor(_)
305 | TLExpr::Ceil(_)
306 | TLExpr::Round(_)
307 | TLExpr::Sin(_)
308 | TLExpr::Cos(_)
309 | TLExpr::Tan(_)
310 | TLExpr::Aggregate { .. }
311 | TLExpr::TNorm { .. }
312 | TLExpr::TCoNorm { .. }
313 | TLExpr::FuzzyNot { .. }
314 | TLExpr::FuzzyImplication { .. }
315 | TLExpr::SoftExists { .. }
316 | TLExpr::SoftForAll { .. }
317 | TLExpr::WeightedRule { .. }
318 | TLExpr::ProbabilisticChoice { .. }
319 | TLExpr::Release { .. }
320 | TLExpr::WeakUntil { .. }
321 | TLExpr::StrongRelease { .. } => expr.clone(),
322
323 _ => expr.clone(),
325 }
326}
327
328fn try_factor_add(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
330 if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
332 if l1 == r1 {
334 return Some(TLExpr::mul(
335 (**l1).clone(),
336 TLExpr::add((**l2).clone(), (**r2).clone()),
337 ));
338 }
339 if l2 == r2 {
341 return Some(TLExpr::mul(
342 TLExpr::add((**l1).clone(), (**r1).clone()),
343 (**l2).clone(),
344 ));
345 }
346 if l1 == r2 {
348 return Some(TLExpr::mul(
349 (**l1).clone(),
350 TLExpr::add((**l2).clone(), (**r1).clone()),
351 ));
352 }
353 if l2 == r1 {
355 return Some(TLExpr::mul(
356 (**l2).clone(),
357 TLExpr::add((**l1).clone(), (**r2).clone()),
358 ));
359 }
360 }
361
362 if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
364 if let (TLExpr::Constant(c1), TLExpr::Constant(c2)) = (l1.as_ref(), r1.as_ref()) {
365 if c1 == c2 {
366 return Some(TLExpr::mul(
367 TLExpr::Constant(*c1),
368 TLExpr::add((**l2).clone(), (**r2).clone()),
369 ));
370 }
371 }
372 }
373
374 None
375}
376
377fn try_factor_sub(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
379 if let (TLExpr::Mul(l1, l2), TLExpr::Mul(r1, r2)) = (lhs, rhs) {
381 if l1 == r1 {
383 return Some(TLExpr::mul(
384 (**l1).clone(),
385 TLExpr::sub((**l2).clone(), (**r2).clone()),
386 ));
387 }
388 if l2 == r2 {
390 return Some(TLExpr::mul(
391 TLExpr::sub((**l1).clone(), (**r1).clone()),
392 (**l2).clone(),
393 ));
394 }
395 }
396
397 None
398}
399
400fn try_factor_and(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
402 if let (TLExpr::Or(l1, l2), TLExpr::Or(r1, r2)) = (lhs, rhs) {
403 if l1 == r1 {
405 return Some(TLExpr::or(
406 (**l1).clone(),
407 TLExpr::and((**l2).clone(), (**r2).clone()),
408 ));
409 }
410 if l1 == r2 {
412 return Some(TLExpr::or(
413 (**l1).clone(),
414 TLExpr::and((**l2).clone(), (**r1).clone()),
415 ));
416 }
417 if l2 == r1 {
419 return Some(TLExpr::or(
420 (**l2).clone(),
421 TLExpr::and((**l1).clone(), (**r2).clone()),
422 ));
423 }
424 if l2 == r2 {
426 return Some(TLExpr::or(
427 (**l2).clone(),
428 TLExpr::and((**l1).clone(), (**r1).clone()),
429 ));
430 }
431 }
432
433 None
434}
435
436fn try_factor_or(lhs: &TLExpr, rhs: &TLExpr) -> Option<TLExpr> {
438 if let (TLExpr::And(l1, l2), TLExpr::And(r1, r2)) = (lhs, rhs) {
439 if l1 == r1 {
441 return Some(TLExpr::and(
442 (**l1).clone(),
443 TLExpr::or((**l2).clone(), (**r2).clone()),
444 ));
445 }
446 if l1 == r2 {
448 return Some(TLExpr::and(
449 (**l1).clone(),
450 TLExpr::or((**l2).clone(), (**r1).clone()),
451 ));
452 }
453 if l2 == r1 {
455 return Some(TLExpr::and(
456 (**l2).clone(),
457 TLExpr::or((**l1).clone(), (**r2).clone()),
458 ));
459 }
460 if l2 == r2 {
462 return Some(TLExpr::and(
463 (**l2).clone(),
464 TLExpr::or((**l1).clone(), (**r1).clone()),
465 ));
466 }
467 }
468
469 None
470}
471
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds three distinct single-argument predicates `a(i)`, `b(j)`, `c(k)`
    /// that serve as opaque operands in the tests below.
    fn abc() -> (TLExpr, TLExpr, TLExpr) {
        (
            TLExpr::pred("a", vec![Term::var("i")]),
            TLExpr::pred("b", vec![Term::var("j")]),
            TLExpr::pred("c", vec![Term::var("k")]),
        )
    }

    /// `a*b + a*c` factors to `a * (b + c)`.
    #[test]
    fn test_factor_add_common_left() {
        let (a, b, c) = abc();

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::mul(a, TLExpr::add(b, c)));
    }

    /// `a*b + c*b` factors to `(a + c) * b`, keeping the shared factor on the right.
    #[test]
    fn test_factor_add_common_right() {
        let (a, b, c) = abc();

        let expr = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(c.clone(), b.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::mul(TLExpr::add(a, c), b));
    }

    /// The shared factor is also found when it sits in different positions
    /// of the two products.
    #[test]
    fn test_factor_add_cross_operands() {
        let (a, b, c) = abc();
        let expected = TLExpr::mul(a.clone(), TLExpr::add(b.clone(), c.clone()));

        // a*b + c*a
        let left_right = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(c.clone(), a.clone()),
        );
        let (optimized, stats) = optimize_distributivity(&left_right);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, expected);

        // b*a + a*c
        let right_left = TLExpr::add(
            TLExpr::mul(b.clone(), a.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let (optimized, stats) = optimize_distributivity(&right_left);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, expected);
    }

    /// `a*b - a*c` factors to `a * (b - c)`.
    #[test]
    fn test_factor_sub() {
        let (a, b, c) = abc();

        let expr = TLExpr::sub(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::mul(a, TLExpr::sub(b, c)));
    }

    /// `(a ∨ b) ∧ (a ∨ c)` factors to `a ∨ (b ∧ c)`.
    #[test]
    fn test_factor_and_over_or() {
        let (a, b, c) = abc();

        let expr = TLExpr::and(
            TLExpr::or(a.clone(), b.clone()),
            TLExpr::or(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::or(a, TLExpr::and(b, c)));
    }

    /// `(a ∧ b) ∨ (a ∧ c)` factors to `a ∧ (b ∨ c)`.
    #[test]
    fn test_factor_or_over_and() {
        let (a, b, c) = abc();

        let expr = TLExpr::or(
            TLExpr::and(a.clone(), b.clone()),
            TLExpr::and(a.clone(), c.clone()),
        );

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);
        assert_eq!(optimized, TLExpr::and(a, TLExpr::or(b, c)));
    }

    /// With four distinct operands nothing is shared, so the tree comes back unchanged.
    #[test]
    fn test_no_factoring_possible() {
        let (a, b, c) = abc();
        let d = TLExpr::pred("d", vec![Term::var("l")]);

        let expr = TLExpr::add(TLExpr::mul(a, b), TLExpr::mul(c, d));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 0);
        assert!(matches!(optimized, TLExpr::Add(_, _)));
        assert_eq!(optimized, expr);
    }

    /// `(a*b + a*c) + a*d`: the inner sum factors first, then the outer sum
    /// factors again against the already-factored left operand.
    #[test]
    fn test_nested_factoring() {
        let (a, b, c) = abc();
        let d = TLExpr::pred("d", vec![Term::var("l")]);

        let inner = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let expr = TLExpr::add(inner, TLExpr::mul(a.clone(), d.clone()));

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 2);
        assert_eq!(
            optimized,
            TLExpr::mul(a, TLExpr::add(TLExpr::add(b, c), d))
        );
    }

    /// Factoring also happens inside the body of a quantifier.
    #[test]
    fn test_quantifier_body() {
        let a = TLExpr::pred("a", vec![Term::var("x"), Term::var("i")]);
        let b = TLExpr::pred("b", vec![Term::var("x"), Term::var("j")]);
        let c = TLExpr::pred("c", vec![Term::var("x"), Term::var("k")]);

        let body = TLExpr::add(
            TLExpr::mul(a.clone(), b.clone()),
            TLExpr::mul(a.clone(), c.clone()),
        );
        let expr = TLExpr::exists("x", "D", body);

        let (optimized, stats) = optimize_distributivity(&expr);
        assert_eq!(stats.expressions_factored, 1);

        match optimized {
            TLExpr::Exists { body, .. } => {
                assert_eq!(*body, TLExpr::mul(a, TLExpr::add(b, c)));
            }
            _ => panic!("Expected Exists expression"),
        }
    }

    /// `total_optimizations` sums the three rewrite counters and ignores
    /// `total_processed`.
    #[test]
    fn test_stats_total_optimizations() {
        let stats = DistributivityStats {
            expressions_factored: 3,
            expressions_expanded: 2,
            common_terms_extracted: 1,
            total_processed: 100,
        };
        assert_eq!(stats.total_optimizations(), 6);
    }
}