1use tensorlogic_ir::TLExpr;
26
/// Relative per-operation cost weights consumed by
/// [`ExpressionComplexity::total_cost_with_weights`].
///
/// Weights are unitless multipliers: each counted operation contributes
/// `count * weight` to the estimated total cost.
#[derive(Debug, Clone)]
pub struct CostWeights {
    /// Cost of one addition or subtraction.
    pub add_sub: f64,
    /// Cost of one multiplication.
    pub mul: f64,
    /// Cost of one division.
    pub div: f64,
    /// Cost of one power operation.
    pub pow: f64,
    /// Cost of one exponential.
    pub exp: f64,
    /// Cost of one logarithm.
    pub log: f64,
    /// Cost of one square root.
    pub sqrt: f64,
    /// Cost of one absolute value.
    pub abs: f64,
    /// Cost of one comparison; the cost model also charges this for each
    /// min/max operation.
    pub cmp: f64,
    /// Cost of one quantifier / reduction.
    pub reduction: f64,
}

impl Default for CostWeights {
    /// Baseline weights: cheap elementwise operations, pricier division and
    /// roots, and the most expensive power/transcendental functions.
    fn default() -> Self {
        Self {
            // Cheap elementwise operations.
            add_sub: 1.0,
            abs: 1.0,
            cmp: 1.0,
            mul: 2.0,
            // Division-like operations.
            div: 4.0,
            sqrt: 4.0,
            // Power and transcendental functions.
            pow: 8.0,
            exp: 10.0,
            log: 10.0,
            // Quantifiers / reductions.
            reduction: 5.0,
        }
    }
}

impl CostWeights {
    /// Preset named for GPU targets.
    ///
    /// Relative to [`CostWeights::default`], every arithmetic weight is equal
    /// or lower (multiplication costs the same as addition; `exp`/`log` drop
    /// to 3.0), while `reduction` doubles to 10.0.
    pub fn gpu_optimized() -> Self {
        Self {
            add_sub: 1.0,
            abs: 1.0,
            cmp: 1.0,
            mul: 1.0,
            div: 2.0,
            sqrt: 2.0,
            pow: 4.0,
            exp: 3.0,
            log: 3.0,
            reduction: 10.0,
        }
    }

    /// Preset named for SIMD targets.
    ///
    /// Relative to [`CostWeights::default`], every weight is equal or lower,
    /// including `reduction` (3.0 instead of 5.0).
    pub fn simd_optimized() -> Self {
        Self {
            add_sub: 1.0,
            abs: 1.0,
            cmp: 1.0,
            mul: 1.0,
            div: 3.0,
            sqrt: 3.0,
            pow: 6.0,
            exp: 8.0,
            log: 8.0,
            reduction: 3.0,
        }
    }
}
102
103#[derive(Debug, Clone, Default)]
105pub struct ExpressionComplexity {
106 pub additions: usize,
108 pub subtractions: usize,
110 pub multiplications: usize,
112 pub divisions: usize,
114 pub powers: usize,
116 pub exponentials: usize,
118 pub logarithms: usize,
120 pub square_roots: usize,
122 pub absolute_values: usize,
124 pub negations: usize,
126 pub comparisons: usize,
128 pub logical_ands: usize,
130 pub logical_ors: usize,
132 pub logical_nots: usize,
134 pub existential_quantifiers: usize,
136 pub universal_quantifiers: usize,
138 pub conditionals: usize,
140 pub predicates: usize,
142 pub constants: usize,
144 pub variables: usize,
146 pub min_operations: usize,
148 pub max_operations: usize,
150 pub max_depth: usize,
152 pub unique_variables: usize,
154 pub unique_predicates: usize,
156}
157
158impl ExpressionComplexity {
159 pub fn arithmetic_operations(&self) -> usize {
161 self.additions
162 + self.subtractions
163 + self.multiplications
164 + self.divisions
165 + self.powers
166 + self.exponentials
167 + self.logarithms
168 + self.square_roots
169 + self.absolute_values
170 + self.negations
171 }
172
173 pub fn logical_operations(&self) -> usize {
175 self.logical_ands + self.logical_ors + self.logical_nots
176 }
177
178 pub fn total_operations(&self) -> usize {
180 self.arithmetic_operations()
181 + self.logical_operations()
182 + self.comparisons
183 + self.conditionals
184 + self.min_operations
185 + self.max_operations
186 }
187
188 pub fn total_cost(&self) -> f64 {
190 self.total_cost_with_weights(&CostWeights::default())
191 }
192
193 pub fn total_cost_with_weights(&self, weights: &CostWeights) -> f64 {
195 let mut cost = 0.0;
196 cost += (self.additions + self.subtractions) as f64 * weights.add_sub;
197 cost += self.multiplications as f64 * weights.mul;
198 cost += self.divisions as f64 * weights.div;
199 cost += self.powers as f64 * weights.pow;
200 cost += self.exponentials as f64 * weights.exp;
201 cost += self.logarithms as f64 * weights.log;
202 cost += self.square_roots as f64 * weights.sqrt;
203 cost += self.absolute_values as f64 * weights.abs;
204 cost += self.comparisons as f64 * weights.cmp;
205 cost +=
206 (self.existential_quantifiers + self.universal_quantifiers) as f64 * weights.reduction;
207 cost += self.min_operations as f64 * weights.cmp;
208 cost += self.max_operations as f64 * weights.cmp;
209 cost
210 }
211
212 pub fn leaf_count(&self) -> usize {
214 self.constants + self.variables + self.predicates
215 }
216
217 pub fn cse_potential(&self) -> bool {
219 self.total_operations() > 5 && self.max_depth > 3
221 }
222
223 pub fn strength_reduction_potential(&self) -> bool {
225 self.powers > 0 || self.divisions > 2 || self.exponentials + self.logarithms > 0
226 }
227
228 pub fn complexity_level(&self) -> &'static str {
230 let total = self.total_operations();
231 if total <= 3 {
232 "trivial"
233 } else if total <= 10 {
234 "simple"
235 } else if total <= 30 {
236 "moderate"
237 } else if total <= 100 {
238 "complex"
239 } else {
240 "very_complex"
241 }
242 }
243}
244
245impl std::fmt::Display for ExpressionComplexity {
246 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247 writeln!(f, "Expression Complexity Analysis:")?;
248 writeln!(f, " Total operations: {}", self.total_operations())?;
249 writeln!(
250 f,
251 " Arithmetic operations: {}",
252 self.arithmetic_operations()
253 )?;
254 writeln!(f, " Logical operations: {}", self.logical_operations())?;
255 writeln!(f, " Maximum depth: {}", self.max_depth)?;
256 writeln!(f, " Estimated cost: {:.2}", self.total_cost())?;
257 writeln!(f, " Complexity level: {}", self.complexity_level())?;
258 Ok(())
259 }
260}
261
262pub fn analyze_complexity(expr: &TLExpr) -> ExpressionComplexity {
272 let mut complexity = ExpressionComplexity::default();
273 let mut var_names = std::collections::HashSet::new();
274 let mut pred_names = std::collections::HashSet::new();
275
276 analyze_complexity_impl(expr, &mut complexity, 0, &mut var_names, &mut pred_names);
277
278 complexity.unique_variables = var_names.len();
279 complexity.unique_predicates = pred_names.len();
280
281 complexity
282}
283
284fn analyze_complexity_impl(
285 expr: &TLExpr,
286 complexity: &mut ExpressionComplexity,
287 depth: usize,
288 var_names: &mut std::collections::HashSet<String>,
289 pred_names: &mut std::collections::HashSet<String>,
290) {
291 complexity.max_depth = complexity.max_depth.max(depth);
292
293 match expr {
294 TLExpr::Add(lhs, rhs) => {
295 complexity.additions += 1;
296 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
297 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
298 }
299
300 TLExpr::Sub(lhs, rhs) => {
301 complexity.subtractions += 1;
302 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
303 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
304 }
305
306 TLExpr::Mul(lhs, rhs) => {
307 complexity.multiplications += 1;
308 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
309 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
310 }
311
312 TLExpr::Div(lhs, rhs) => {
313 complexity.divisions += 1;
314 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
315 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
316 }
317
318 TLExpr::Pow(base, exp) => {
319 complexity.powers += 1;
320 analyze_complexity_impl(base, complexity, depth + 1, var_names, pred_names);
321 analyze_complexity_impl(exp, complexity, depth + 1, var_names, pred_names);
322 }
323
324 TLExpr::Exp(inner) => {
325 complexity.exponentials += 1;
326 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
327 }
328
329 TLExpr::Log(inner) => {
330 complexity.logarithms += 1;
331 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
332 }
333
334 TLExpr::Sqrt(inner) => {
335 complexity.square_roots += 1;
336 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
337 }
338
339 TLExpr::Abs(inner) => {
340 complexity.absolute_values += 1;
341 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
342 }
343
344 TLExpr::And(lhs, rhs) => {
345 complexity.logical_ands += 1;
346 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
347 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
348 }
349
350 TLExpr::Or(lhs, rhs) => {
351 complexity.logical_ors += 1;
352 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
353 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
354 }
355
356 TLExpr::Not(inner) => {
357 complexity.logical_nots += 1;
358 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
359 }
360
361 TLExpr::Imply(lhs, rhs) => {
362 complexity.logical_nots += 1;
364 complexity.logical_ors += 1;
365 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
366 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
367 }
368
369 TLExpr::Eq(lhs, rhs)
370 | TLExpr::Lt(lhs, rhs)
371 | TLExpr::Lte(lhs, rhs)
372 | TLExpr::Gt(lhs, rhs)
373 | TLExpr::Gte(lhs, rhs) => {
374 complexity.comparisons += 1;
375 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
376 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
377 }
378
379 TLExpr::Min(lhs, rhs) => {
380 complexity.min_operations += 1;
381 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
382 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
383 }
384
385 TLExpr::Max(lhs, rhs) => {
386 complexity.max_operations += 1;
387 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
388 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
389 }
390
391 TLExpr::Exists { var, body, .. } => {
392 complexity.existential_quantifiers += 1;
393 var_names.insert(var.clone());
394 analyze_complexity_impl(body, complexity, depth + 1, var_names, pred_names);
395 }
396
397 TLExpr::ForAll { var, body, .. } => {
398 complexity.universal_quantifiers += 1;
399 var_names.insert(var.clone());
400 analyze_complexity_impl(body, complexity, depth + 1, var_names, pred_names);
401 }
402
403 TLExpr::Let {
404 var, value, body, ..
405 } => {
406 var_names.insert(var.clone());
407 analyze_complexity_impl(value, complexity, depth + 1, var_names, pred_names);
408 analyze_complexity_impl(body, complexity, depth + 1, var_names, pred_names);
409 }
410
411 TLExpr::IfThenElse {
412 condition,
413 then_branch,
414 else_branch,
415 } => {
416 complexity.conditionals += 1;
417 analyze_complexity_impl(condition, complexity, depth + 1, var_names, pred_names);
418 analyze_complexity_impl(then_branch, complexity, depth + 1, var_names, pred_names);
419 analyze_complexity_impl(else_branch, complexity, depth + 1, var_names, pred_names);
420 }
421
422 TLExpr::Pred { name, args } => {
423 complexity.predicates += 1;
424 pred_names.insert(name.clone());
425 for arg in args {
426 if let tensorlogic_ir::Term::Var(v) = arg {
427 var_names.insert(v.clone());
428 }
429 }
430 }
431
432 TLExpr::Constant(_) => {
433 complexity.constants += 1;
434 }
435
436 TLExpr::Box(inner) | TLExpr::Diamond(inner) => {
438 complexity.universal_quantifiers += 1; analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
440 }
441
442 TLExpr::Next(inner) | TLExpr::Eventually(inner) | TLExpr::Always(inner) => {
444 complexity.existential_quantifiers += 1; analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
446 }
447
448 TLExpr::Until { before, after } => {
449 complexity.existential_quantifiers += 1;
450 analyze_complexity_impl(before, complexity, depth + 1, var_names, pred_names);
451 analyze_complexity_impl(after, complexity, depth + 1, var_names, pred_names);
452 }
453
454 TLExpr::Score(inner)
456 | TLExpr::Floor(inner)
457 | TLExpr::Ceil(inner)
458 | TLExpr::Round(inner)
459 | TLExpr::Sin(inner)
460 | TLExpr::Cos(inner)
461 | TLExpr::Tan(inner)
462 | TLExpr::FuzzyNot { expr: inner, .. } => {
463 analyze_complexity_impl(inner, complexity, depth + 1, var_names, pred_names);
464 }
465
466 TLExpr::Mod(lhs, rhs)
467 | TLExpr::TNorm {
468 left: lhs,
469 right: rhs,
470 ..
471 }
472 | TLExpr::TCoNorm {
473 left: lhs,
474 right: rhs,
475 ..
476 }
477 | TLExpr::FuzzyImplication {
478 premise: lhs,
479 conclusion: rhs,
480 ..
481 }
482 | TLExpr::Release {
483 released: lhs,
484 releaser: rhs,
485 }
486 | TLExpr::WeakUntil {
487 before: lhs,
488 after: rhs,
489 }
490 | TLExpr::StrongRelease {
491 released: lhs,
492 releaser: rhs,
493 } => {
494 analyze_complexity_impl(lhs, complexity, depth + 1, var_names, pred_names);
495 analyze_complexity_impl(rhs, complexity, depth + 1, var_names, pred_names);
496 }
497
498 TLExpr::Aggregate { body, .. }
499 | TLExpr::SoftExists { body, .. }
500 | TLExpr::SoftForAll { body, .. }
501 | TLExpr::WeightedRule { rule: body, .. } => {
502 complexity.existential_quantifiers += 1;
503 analyze_complexity_impl(body, complexity, depth + 1, var_names, pred_names);
504 }
505
506 TLExpr::ProbabilisticChoice { alternatives } => {
507 for (_, expr) in alternatives {
508 analyze_complexity_impl(expr, complexity, depth + 1, var_names, pred_names);
509 }
510 }
511
512 _ => {}
514 }
515}
516
517pub fn compare_complexity(expr1: &TLExpr, expr2: &TLExpr) -> std::cmp::Ordering {
521 let c1 = analyze_complexity(expr1);
522 let c2 = analyze_complexity(expr2);
523
524 let cost1 = c1.total_cost();
525 let cost2 = c2.total_cost();
526
527 cost1
528 .partial_cmp(&cost2)
529 .unwrap_or(std::cmp::Ordering::Equal)
530}
531
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// The predicate `x(i)`, used as a non-constant leaf in several tests.
    fn x_of_i() -> TLExpr {
        TLExpr::pred("x", vec![Term::var("i")])
    }

    /// `1 + 2`: a single addition over two constants.
    fn one_plus_two() -> TLExpr {
        TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0))
    }

    #[test]
    fn test_simple_addition() {
        let stats = analyze_complexity(&one_plus_two());

        assert_eq!(stats.additions, 1);
        assert_eq!(stats.constants, 2);
        assert_eq!(stats.total_operations(), 1);
    }

    #[test]
    fn test_nested_operations() {
        // (x(i) + 1) * (x(i) - 2)
        let expr = TLExpr::mul(
            TLExpr::add(x_of_i(), TLExpr::Constant(1.0)),
            TLExpr::sub(x_of_i(), TLExpr::Constant(2.0)),
        );
        let stats = analyze_complexity(&expr);

        assert_eq!(
            (stats.additions, stats.subtractions, stats.multiplications),
            (1, 1, 1)
        );
        assert_eq!(stats.predicates, 2);
        assert_eq!(stats.constants, 2);
    }

    #[test]
    fn test_logical_operations() {
        // a(x) AND NOT b(y)
        let expr = TLExpr::and(
            TLExpr::pred("a", vec![Term::var("x")]),
            TLExpr::negate(TLExpr::pred("b", vec![Term::var("y")])),
        );
        let stats = analyze_complexity(&expr);

        assert_eq!(stats.logical_ands, 1);
        assert_eq!(stats.logical_nots, 1);
        assert_eq!(stats.predicates, 2);
    }

    #[test]
    fn test_quantifiers() {
        let body = TLExpr::pred("p", vec![Term::var("x"), Term::var("y")]);
        let expr = TLExpr::exists("x", "D1", TLExpr::forall("y", "D2", body));
        let stats = analyze_complexity(&expr);

        assert_eq!(stats.existential_quantifiers, 1);
        assert_eq!(stats.universal_quantifiers, 1);
        assert_eq!(stats.predicates, 1);
        assert_eq!(stats.unique_variables, 2);
    }

    #[test]
    fn test_depth_calculation() {
        // add (depth 0) -> mul (depth 1) -> x(i) (depth 2)
        let expr = TLExpr::add(
            TLExpr::mul(x_of_i(), TLExpr::Constant(2.0)),
            TLExpr::Constant(3.0),
        );

        assert_eq!(analyze_complexity(&expr).max_depth, 2);
    }

    #[test]
    fn test_cost_calculation() {
        // One addition (default weight 1) plus one multiplication (default weight 2).
        let expr = TLExpr::add(
            TLExpr::mul(x_of_i(), TLExpr::Constant(2.0)),
            TLExpr::Constant(3.0),
        );

        assert_eq!(analyze_complexity(&expr).total_cost(), 3.0);
    }

    #[test]
    fn test_gpu_weights() {
        let gpu = CostWeights::gpu_optimized();
        assert!(gpu.reduction > gpu.mul);
    }

    #[test]
    fn test_complexity_level() {
        // x(i) + 1 + 1 + ... (20 additions)
        let chain = (0..20).fold(x_of_i(), |acc, _| TLExpr::add(acc, TLExpr::Constant(1.0)));

        assert_eq!(
            analyze_complexity(&one_plus_two()).complexity_level(),
            "trivial"
        );
        let level = analyze_complexity(&chain).complexity_level();
        assert!(matches!(level, "moderate" | "complex"));
    }

    #[test]
    fn test_cse_potential() {
        assert!(!analyze_complexity(&one_plus_two()).cse_potential());

        // exp(x(i) * 2 + 1) * log(x(i) / 3 - 4)
        let nested = TLExpr::mul(
            TLExpr::exp(TLExpr::add(
                TLExpr::mul(x_of_i(), TLExpr::Constant(2.0)),
                TLExpr::Constant(1.0),
            )),
            TLExpr::log(TLExpr::sub(
                TLExpr::div(x_of_i(), TLExpr::Constant(3.0)),
                TLExpr::Constant(4.0),
            )),
        );
        assert!(analyze_complexity(&nested).cse_potential());
    }

    #[test]
    fn test_strength_reduction_potential() {
        let squared = TLExpr::pow(x_of_i(), TLExpr::Constant(2.0));
        assert!(analyze_complexity(&squared).strength_reduction_potential());

        let incremented = TLExpr::add(x_of_i(), TLExpr::Constant(1.0));
        assert!(!analyze_complexity(&incremented).strength_reduction_potential());
    }

    #[test]
    fn test_compare_complexity() {
        let heavier = TLExpr::mul(
            TLExpr::add(x_of_i(), TLExpr::Constant(1.0)),
            TLExpr::sub(x_of_i(), TLExpr::Constant(2.0)),
        );

        assert_eq!(
            compare_complexity(&one_plus_two(), &heavier),
            std::cmp::Ordering::Less
        );
    }

    #[test]
    fn test_display() {
        let rendered = analyze_complexity(&one_plus_two()).to_string();

        assert!(rendered.contains("Expression Complexity Analysis"));
        assert!(rendered.contains("Total operations:"));
    }

    #[test]
    fn test_arithmetic_vs_logical() {
        let arithmetic = TLExpr::mul(
            TLExpr::add(TLExpr::Constant(1.0), TLExpr::Constant(2.0)),
            TLExpr::Constant(3.0),
        );
        let logical = TLExpr::and(
            TLExpr::or(TLExpr::pred("a", vec![]), TLExpr::pred("b", vec![])),
            TLExpr::pred("c", vec![]),
        );

        let arith_stats = analyze_complexity(&arithmetic);
        let logic_stats = analyze_complexity(&logical);

        assert!(arith_stats.arithmetic_operations() > 0);
        assert_eq!(arith_stats.logical_operations(), 0);
        assert_eq!(logic_stats.arithmetic_operations(), 0);
        assert!(logic_stats.logical_operations() > 0);
    }

    #[test]
    fn test_unique_variables() {
        // Bound x and y, plus the free z in the predicate arguments.
        let expr = TLExpr::exists(
            "x",
            "D",
            TLExpr::forall(
                "y",
                "D",
                TLExpr::pred("p", vec![Term::var("x"), Term::var("y"), Term::var("z")]),
            ),
        );

        assert_eq!(analyze_complexity(&expr).unique_variables, 3);
    }

    #[test]
    fn test_unique_predicates() {
        // `foo` appears twice, so there are 3 applications but 2 distinct names.
        let expr = TLExpr::and(
            TLExpr::pred("foo", vec![Term::var("x")]),
            TLExpr::or(
                TLExpr::pred("bar", vec![Term::var("y")]),
                TLExpr::pred("foo", vec![Term::var("z")]),
            ),
        );
        let stats = analyze_complexity(&expr);

        assert_eq!(stats.unique_predicates, 2);
        assert_eq!(stats.predicates, 3);
    }
}