1use std::collections::HashSet;
40
41use tensorlogic_ir::TLExpr;
42
43use super::complexity::{analyze_complexity, CostWeights};
44use crate::CompilerContext;
45
/// Summary of a single cost-based optimization run.
///
/// Produced by `CostBasedOptimizer::optimize` and the `optimize_by_cost*` helpers.
#[derive(Debug, Clone, PartialEq)]
pub struct CostBasedStats {
    /// Number of candidate expressions held when the search ended, counting the
    /// original input expression.
    pub alternatives_explored: usize,
    /// Estimated cost of the input expression.
    pub original_cost: f64,
    /// Estimated cost of the expression that was returned.
    pub optimized_cost: f64,
    /// Number of rewrite rules applied, in sequence, to reach the returned expression.
    pub rewrites_applied: usize,
    /// Elapsed time spent in the search, in microseconds.
    pub time_us: u64,
}
60
61impl CostBasedStats {
62 pub fn cost_reduction_percent(&self) -> f64 {
64 if self.original_cost == 0.0 {
65 0.0
66 } else {
67 ((self.original_cost - self.optimized_cost) / self.original_cost) * 100.0
68 }
69 }
70
71 pub fn is_beneficial(&self) -> bool {
73 self.optimized_cost < self.original_cost
74 }
75
76 pub fn cost_ratio(&self) -> f64 {
78 if self.original_cost == 0.0 {
79 1.0
80 } else {
81 self.optimized_cost / self.original_cost
82 }
83 }
84}
85
/// Algebraic rewrite rules the cost-based optimizer can try.
///
/// `CostBasedOptimizer` matches each rule only against the top-level node of an
/// expression; rules are not applied inside subexpressions.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum RewriteRule {
    /// `A & (B | C)` -> `(A & B) | (A & C)`; also the mirrored `(A | B) & C` form.
    DistributeAndOverOr,
    /// `A | (B & C)` -> `(A | B) & (A | C)`; also the mirrored `(A & B) | C` form.
    DistributeOrOverAnd,
    /// `(A & B) | (A & C)` -> `A & (B | C)`; also factors a shared right operand.
    FactorCommonAnd,
    /// `(A | B) & (A | C)` -> `A | (B & C)`; also factors a shared right operand.
    FactorCommonOr,
    /// `exists x. (P & Q)` -> `(exists x. P) & Q` when `x` is not free in `Q`
    /// (or the symmetric form when `x` is not free in `P`).
    PushExistsInward,
    /// `forall x. (P & Q)` -> `(forall x. P) & Q` when `x` is not free in `Q`
    /// (or the symmetric form when `x` is not free in `P`).
    PushForallInward,
    /// `(exists x. P) & Q` -> `exists x. (P & Q)` when `x` is not free in `Q`
    /// (the quantifier may be on either side of the `&`).
    PullExistsOutward,
    /// `(forall x. P) & Q` -> `forall x. (P & Q)` when `x` is not free in `Q`
    /// (the quantifier may be on either side of the `&`).
    PullForallOutward,
    /// Presumably meant to merge directly nested existentials.
    /// NOTE(review): the optimizer's implementation is a stub that never rewrites.
    MergeNestedExists,
    /// Presumably meant to merge directly nested universals.
    /// NOTE(review): the optimizer's implementation is a stub that never rewrites.
    MergeNestedForall,
    /// Swaps the operands of `&` when the right operand's estimated cost is lower.
    ReorderConjunctions,
    /// Swaps the operands of `|` when the right operand's estimated cost is lower.
    ReorderDisjunctions,
}
114
/// One candidate expression in the optimizer's search.
#[derive(Debug, Clone)]
struct Alternative {
    /// The candidate expression.
    expr: TLExpr,
    /// Estimated cost of `expr` under the optimizer's cost weights.
    cost: f64,
    /// Rules applied, in order, to derive `expr` from the original input.
    rules_applied: Vec<RewriteRule>,
}
122
/// Breadth-first search over algebraic rewrites of a `TLExpr` that keeps the
/// candidate with the lowest estimated cost.
pub struct CostBasedOptimizer<'a> {
    /// Compiler context; stored, but not read by the optimizer's current methods.
    _context: &'a CompilerContext,
    /// Weights used to turn complexity metrics into a scalar cost.
    cost_weights: CostWeights,
    /// Budget that bounds how many alternatives a search collects.
    max_alternatives: usize,
    /// `Debug`-rendered keys (see `expr_hash`) of expressions already seen in the
    /// current search, used to skip duplicates.
    explored: HashSet<String>,
}
130
131impl<'a> CostBasedOptimizer<'a> {
132 pub fn new(context: &'a CompilerContext) -> Self {
134 Self {
135 _context: context,
136 cost_weights: CostWeights::default(),
137 max_alternatives: 100,
138 explored: HashSet::new(),
139 }
140 }
141
142 pub fn with_cost_weights(mut self, weights: CostWeights) -> Self {
144 self.cost_weights = weights;
145 self
146 }
147
148 pub fn with_max_alternatives(mut self, max: usize) -> Self {
150 self.max_alternatives = max;
151 self
152 }
153
154 pub fn optimize(&mut self, expr: &TLExpr) -> (TLExpr, CostBasedStats) {
156 let start = std::time::Instant::now();
157
158 let original_cost = self.estimate_cost(expr);
159 let mut alternatives = vec![Alternative {
160 expr: expr.clone(),
161 cost: original_cost,
162 rules_applied: Vec::new(),
163 }];
164
165 self.explored.clear();
166 self.explored.insert(expr_hash(expr));
167
168 let mut iteration = 0;
170 while iteration < self.max_alternatives && iteration < alternatives.len() {
171 let current = &alternatives[iteration].clone();
172 let new_alts = self.generate_alternatives(¤t.expr, ¤t.rules_applied);
173
174 for alt in new_alts {
175 let hash = expr_hash(&alt.expr);
176 if !self.explored.contains(&hash) {
177 self.explored.insert(hash);
178 alternatives.push(alt);
179
180 if alternatives.len() >= self.max_alternatives {
181 break;
182 }
183 }
184 }
185
186 iteration += 1;
187 }
188
189 let best = alternatives
191 .iter()
192 .min_by(|a, b| {
193 a.cost
194 .partial_cmp(&b.cost)
195 .unwrap_or(std::cmp::Ordering::Equal)
196 })
197 .unwrap();
198
199 let time_us = start.elapsed().as_micros() as u64;
200
201 let stats = CostBasedStats {
202 alternatives_explored: alternatives.len(),
203 original_cost,
204 optimized_cost: best.cost,
205 rewrites_applied: best.rules_applied.len(),
206 time_us,
207 };
208
209 (best.expr.clone(), stats)
210 }
211
212 fn estimate_cost(&self, expr: &TLExpr) -> f64 {
214 let complexity = analyze_complexity(expr);
215 complexity.total_cost_with_weights(&self.cost_weights)
216 }
217
218 fn generate_alternatives(&self, expr: &TLExpr, applied: &[RewriteRule]) -> Vec<Alternative> {
220 let mut alternatives = Vec::new();
221
222 for rule in self.available_rules() {
224 if let Some(rewritten) = self.apply_rule(expr, &rule) {
225 let cost = self.estimate_cost(&rewritten);
226 let mut new_applied = applied.to_vec();
227 new_applied.push(rule);
228
229 alternatives.push(Alternative {
230 expr: rewritten,
231 cost,
232 rules_applied: new_applied,
233 });
234 }
235 }
236
237 alternatives
238 }
239
240 fn available_rules(&self) -> Vec<RewriteRule> {
242 vec![
243 RewriteRule::DistributeAndOverOr,
244 RewriteRule::DistributeOrOverAnd,
245 RewriteRule::FactorCommonAnd,
246 RewriteRule::FactorCommonOr,
247 RewriteRule::PushExistsInward,
248 RewriteRule::PushForallInward,
249 RewriteRule::PullExistsOutward,
250 RewriteRule::PullForallOutward,
251 RewriteRule::MergeNestedExists,
252 RewriteRule::MergeNestedForall,
253 RewriteRule::ReorderConjunctions,
254 RewriteRule::ReorderDisjunctions,
255 ]
256 }
257
258 fn apply_rule(&self, expr: &TLExpr, rule: &RewriteRule) -> Option<TLExpr> {
260 match rule {
261 RewriteRule::DistributeAndOverOr => self.distribute_and_over_or(expr),
262 RewriteRule::DistributeOrOverAnd => self.distribute_or_over_and(expr),
263 RewriteRule::FactorCommonAnd => self.factor_common_and(expr),
264 RewriteRule::FactorCommonOr => self.factor_common_or(expr),
265 RewriteRule::PushExistsInward => self.push_exists_inward(expr),
266 RewriteRule::PushForallInward => self.push_forall_inward(expr),
267 RewriteRule::PullExistsOutward => self.pull_exists_outward(expr),
268 RewriteRule::PullForallOutward => self.pull_forall_outward(expr),
269 RewriteRule::MergeNestedExists => self.merge_nested_exists(expr),
270 RewriteRule::MergeNestedForall => self.merge_nested_forall(expr),
271 RewriteRule::ReorderConjunctions => self.reorder_conjunctions(expr),
272 RewriteRule::ReorderDisjunctions => self.reorder_disjunctions(expr),
273 }
274 }
275
276 fn distribute_and_over_or(&self, expr: &TLExpr) -> Option<TLExpr> {
278 match expr {
279 TLExpr::And(a, b) => {
280 if let TLExpr::Or(b1, b2) = b.as_ref() {
281 Some(TLExpr::or(
282 TLExpr::and(a.as_ref().clone(), b1.as_ref().clone()),
283 TLExpr::and(a.as_ref().clone(), b2.as_ref().clone()),
284 ))
285 } else if let TLExpr::Or(a1, a2) = a.as_ref() {
286 Some(TLExpr::or(
287 TLExpr::and(a1.as_ref().clone(), b.as_ref().clone()),
288 TLExpr::and(a2.as_ref().clone(), b.as_ref().clone()),
289 ))
290 } else {
291 None
292 }
293 }
294 _ => None,
295 }
296 }
297
298 fn distribute_or_over_and(&self, expr: &TLExpr) -> Option<TLExpr> {
300 match expr {
301 TLExpr::Or(a, b) => {
302 if let TLExpr::And(b1, b2) = b.as_ref() {
303 Some(TLExpr::and(
304 TLExpr::or(a.as_ref().clone(), b1.as_ref().clone()),
305 TLExpr::or(a.as_ref().clone(), b2.as_ref().clone()),
306 ))
307 } else if let TLExpr::And(a1, a2) = a.as_ref() {
308 Some(TLExpr::and(
309 TLExpr::or(a1.as_ref().clone(), b.as_ref().clone()),
310 TLExpr::or(a2.as_ref().clone(), b.as_ref().clone()),
311 ))
312 } else {
313 None
314 }
315 }
316 _ => None,
317 }
318 }
319
320 fn factor_common_and(&self, expr: &TLExpr) -> Option<TLExpr> {
322 match expr {
323 TLExpr::Or(left, right) => {
324 if let (TLExpr::And(a1, b1), TLExpr::And(a2, b2)) = (left.as_ref(), right.as_ref())
325 {
326 if a1 == a2 {
327 return Some(TLExpr::and(
328 a1.as_ref().clone(),
329 TLExpr::or(b1.as_ref().clone(), b2.as_ref().clone()),
330 ));
331 }
332 if b1 == b2 {
333 return Some(TLExpr::and(
334 b1.as_ref().clone(),
335 TLExpr::or(a1.as_ref().clone(), a2.as_ref().clone()),
336 ));
337 }
338 }
339 None
340 }
341 _ => None,
342 }
343 }
344
345 fn factor_common_or(&self, expr: &TLExpr) -> Option<TLExpr> {
347 match expr {
348 TLExpr::And(left, right) => {
349 if let (TLExpr::Or(a1, b1), TLExpr::Or(a2, b2)) = (left.as_ref(), right.as_ref()) {
350 if a1 == a2 {
351 return Some(TLExpr::or(
352 a1.as_ref().clone(),
353 TLExpr::and(b1.as_ref().clone(), b2.as_ref().clone()),
354 ));
355 }
356 if b1 == b2 {
357 return Some(TLExpr::or(
358 b1.as_ref().clone(),
359 TLExpr::and(a1.as_ref().clone(), a2.as_ref().clone()),
360 ));
361 }
362 }
363 None
364 }
365 _ => None,
366 }
367 }
368
369 fn push_exists_inward(&self, expr: &TLExpr) -> Option<TLExpr> {
371 match expr {
372 TLExpr::Exists { var, domain, body } => {
373 if let TLExpr::And(p, q) = body.as_ref() {
374 let q_vars = q.free_vars();
375 if !q_vars.contains(var.as_str()) {
376 return Some(TLExpr::and(
377 TLExpr::exists(var, domain, p.as_ref().clone()),
378 q.as_ref().clone(),
379 ));
380 }
381
382 let p_vars = p.free_vars();
383 if !p_vars.contains(var.as_str()) {
384 return Some(TLExpr::and(
385 p.as_ref().clone(),
386 TLExpr::exists(var, domain, q.as_ref().clone()),
387 ));
388 }
389 }
390 None
391 }
392 _ => None,
393 }
394 }
395
396 fn push_forall_inward(&self, expr: &TLExpr) -> Option<TLExpr> {
398 match expr {
399 TLExpr::ForAll { var, domain, body } => {
400 if let TLExpr::And(p, q) = body.as_ref() {
401 let q_vars = q.free_vars();
402 if !q_vars.contains(var.as_str()) {
403 return Some(TLExpr::and(
404 TLExpr::forall(var, domain, p.as_ref().clone()),
405 q.as_ref().clone(),
406 ));
407 }
408
409 let p_vars = p.free_vars();
410 if !p_vars.contains(var.as_str()) {
411 return Some(TLExpr::and(
412 p.as_ref().clone(),
413 TLExpr::forall(var, domain, q.as_ref().clone()),
414 ));
415 }
416 }
417 None
418 }
419 _ => None,
420 }
421 }
422
423 fn pull_exists_outward(&self, expr: &TLExpr) -> Option<TLExpr> {
425 match expr {
426 TLExpr::And(left, right) => {
427 if let TLExpr::Exists { var, domain, body } = left.as_ref() {
428 let right_vars = right.free_vars();
429 if !right_vars.contains(var.as_str()) {
430 return Some(TLExpr::exists(
431 var,
432 domain,
433 TLExpr::and(body.as_ref().clone(), right.as_ref().clone()),
434 ));
435 }
436 }
437
438 if let TLExpr::Exists { var, domain, body } = right.as_ref() {
439 let left_vars = left.free_vars();
440 if !left_vars.contains(var.as_str()) {
441 return Some(TLExpr::exists(
442 var,
443 domain,
444 TLExpr::and(left.as_ref().clone(), body.as_ref().clone()),
445 ));
446 }
447 }
448 None
449 }
450 _ => None,
451 }
452 }
453
454 fn pull_forall_outward(&self, expr: &TLExpr) -> Option<TLExpr> {
456 match expr {
457 TLExpr::And(left, right) => {
458 if let TLExpr::ForAll { var, domain, body } = left.as_ref() {
459 let right_vars = right.free_vars();
460 if !right_vars.contains(var.as_str()) {
461 return Some(TLExpr::forall(
462 var,
463 domain,
464 TLExpr::and(body.as_ref().clone(), right.as_ref().clone()),
465 ));
466 }
467 }
468
469 if let TLExpr::ForAll { var, domain, body } = right.as_ref() {
470 let left_vars = left.free_vars();
471 if !left_vars.contains(var.as_str()) {
472 return Some(TLExpr::forall(
473 var,
474 domain,
475 TLExpr::and(left.as_ref().clone(), body.as_ref().clone()),
476 ));
477 }
478 }
479 None
480 }
481 _ => None,
482 }
483 }
484
485 fn merge_nested_exists(&self, _expr: &TLExpr) -> Option<TLExpr> {
487 None
490 }
491
492 fn merge_nested_forall(&self, _expr: &TLExpr) -> Option<TLExpr> {
494 None
497 }
498
499 fn reorder_conjunctions(&self, expr: &TLExpr) -> Option<TLExpr> {
501 match expr {
502 TLExpr::And(a, b) => {
503 let cost_a = self.estimate_cost(a);
504 let cost_b = self.estimate_cost(b);
505
506 if cost_b < cost_a {
508 Some(TLExpr::and(b.as_ref().clone(), a.as_ref().clone()))
509 } else {
510 None
511 }
512 }
513 _ => None,
514 }
515 }
516
517 fn reorder_disjunctions(&self, expr: &TLExpr) -> Option<TLExpr> {
519 match expr {
520 TLExpr::Or(a, b) => {
521 let cost_a = self.estimate_cost(a);
522 let cost_b = self.estimate_cost(b);
523
524 if cost_b < cost_a {
526 Some(TLExpr::or(b.as_ref().clone(), a.as_ref().clone()))
527 } else {
528 None
529 }
530 }
531 _ => None,
532 }
533 }
534}
535
536fn expr_hash(expr: &TLExpr) -> String {
538 format!("{:?}", expr)
539}
540
541pub fn optimize_by_cost(expr: &TLExpr, context: &CompilerContext) -> (TLExpr, CostBasedStats) {
565 let mut optimizer = CostBasedOptimizer::new(context);
566 optimizer.optimize(expr)
567}
568
569pub fn optimize_by_cost_with_config(
571 expr: &TLExpr,
572 context: &CompilerContext,
573 weights: CostWeights,
574 max_alternatives: usize,
575) -> (TLExpr, CostBasedStats) {
576 let mut optimizer = CostBasedOptimizer::new(context)
577 .with_cost_weights(weights)
578 .with_max_alternatives(max_alternatives);
579 optimizer.optimize(expr)
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Context declaring the two domains referenced by the quantified test expressions.
    fn test_context() -> CompilerContext {
        let mut ctx = CompilerContext::new();
        ctx.add_domain("Person", 100);
        ctx.add_domain("City", 50);
        ctx
    }

    /// Shorthand for a predicate over variables: `atom("p", &["x"])` builds `p(x)`.
    fn atom(name: &str, vars: &[&str]) -> TLExpr {
        let args: Vec<_> = vars.iter().map(|&v| Term::var(v)).collect();
        TLExpr::pred(name, args)
    }

    #[test]
    fn test_distribute_and_over_or() {
        let ctx = test_context();
        // p(x) AND (q(x) OR r(x))
        let expr = TLExpr::and(
            atom("p", &["x"]),
            TLExpr::or(atom("q", &["x"]), atom("r", &["x"])),
        );

        let (_, stats) = optimize_by_cost_with_config(&expr, &ctx, CostWeights::default(), 10);
        assert!(stats.alternatives_explored > 1);
        assert!(stats.alternatives_explored < 50);
    }

    #[test]
    fn test_factor_common_and() {
        let ctx = test_context();
        let (p, q, r) = (atom("p", &["x"]), atom("q", &["x"]), atom("r", &["x"]));
        // (p AND q) OR (p AND r): `p` is a shared left operand that can be factored out.
        let expr = TLExpr::or(TLExpr::and(p.clone(), q), TLExpr::and(p, r));

        let (_, stats) = optimize_by_cost_with_config(&expr, &ctx, CostWeights::default(), 10);
        assert!(stats.alternatives_explored > 1);
        assert!(stats.alternatives_explored < 50);
    }

    #[test]
    fn test_push_exists_inward() {
        let ctx = test_context();
        // exists x: Person. (p(x) AND q(y)); the q(y) conjunct does not mention x.
        let body = TLExpr::and(atom("p", &["x"]), atom("q", &["y"]));
        let expr = TLExpr::exists("x", "Person", body);

        let (_, stats) = optimize_by_cost(&expr, &ctx);
        assert!(stats.alternatives_explored > 0);
    }

    #[test]
    fn test_reorder_conjunctions() {
        let ctx = test_context();
        let expensive = TLExpr::exists("y", "City", atom("expensive", &["x", "y"]));
        let cheap = atom("cheap", &["x"]);
        // The expensive conjunct comes first, so reordering has something to do.
        let expr = TLExpr::and(expensive, cheap);

        let (_, stats) = optimize_by_cost(&expr, &ctx);
        assert!(stats.alternatives_explored > 1);
    }

    #[test]
    fn test_cost_reduction_calculation() {
        let stats = CostBasedStats {
            alternatives_explored: 5,
            original_cost: 100.0,
            optimized_cost: 75.0,
            rewrites_applied: 2,
            time_us: 1000,
        };

        assert_eq!(stats.cost_reduction_percent(), 25.0);
        assert!(stats.is_beneficial());
        assert_eq!(stats.cost_ratio(), 0.75);
    }

    #[test]
    fn test_no_improvement() {
        let stats = CostBasedStats {
            alternatives_explored: 3,
            original_cost: 50.0,
            optimized_cost: 50.0,
            rewrites_applied: 0,
            time_us: 500,
        };

        assert_eq!(stats.cost_reduction_percent(), 0.0);
        assert!(!stats.is_beneficial());
        assert_eq!(stats.cost_ratio(), 1.0);
    }

    #[test]
    fn test_simple_expression_no_rewrites() {
        let ctx = test_context();
        let expr = atom("p", &["x"]);

        let (optimized, stats) = optimize_by_cost(&expr, &ctx);
        assert_eq!(optimized, expr);
        assert_eq!(stats.rewrites_applied, 0);
    }

    #[test]
    fn test_custom_cost_weights() {
        let ctx = test_context();
        let expr = TLExpr::and(atom("p", &["x"]), atom("q", &["x"]));
        let weights = CostWeights {
            reduction: 10.0,
            cmp: 5.0,
            ..Default::default()
        };

        let (_, stats) = optimize_by_cost_with_config(&expr, &ctx, weights, 50);
        assert!(stats.alternatives_explored > 0);
    }

    #[test]
    fn test_max_alternatives_limit() {
        let ctx = test_context();
        // (p(x) OR q(x)) AND (r(x) OR s(x)) admits several distributive rewrites.
        let expr = TLExpr::and(
            TLExpr::or(atom("p", &["x"]), atom("q", &["x"])),
            TLExpr::or(atom("r", &["x"]), atom("s", &["x"])),
        );

        let (_, stats) = optimize_by_cost_with_config(&expr, &ctx, CostWeights::default(), 5);
        assert!(stats.alternatives_explored < 25);
    }

    #[test]
    fn test_complex_quantifier_expression() {
        let ctx = test_context();
        // exists x: Person. exists y: City. (p(x) AND q(y))
        let conj = TLExpr::and(atom("p", &["x"]), atom("q", &["y"]));
        let inner = TLExpr::exists("y", "City", conj);
        let expr = TLExpr::exists("x", "Person", inner);

        let (_, stats) = optimize_by_cost(&expr, &ctx);
        assert!(stats.alternatives_explored > 0);
    }
}