Skip to main content

tensorlogic_compiler/
complexity.rs

1//! Expression complexity analysis for TLExpr trees.
2//!
3//! This module provides comprehensive complexity metrics, threshold-based warnings,
4//! expression comparison, and batch statistics for analyzing the structural complexity
5//! of logical expressions before compilation.
6
7use std::collections::HashSet;
8use std::fmt;
9use tensorlogic_ir::{TLExpr, Term};
10
11// ---------------------------------------------------------------------------
12// ExprComplexity
13// ---------------------------------------------------------------------------
14
15/// Comprehensive complexity metrics for a TLExpr tree.
16#[derive(Debug, Clone, Default)]
17pub struct ExprComplexity {
18    /// Total AST nodes in the expression tree.
19    pub total_nodes: usize,
20    /// Maximum depth of the tree (root = depth 1).
21    pub depth: usize,
22    /// Maximum number of siblings at any single level.
23    pub width: usize,
24    /// Number of distinct variable names across the tree.
25    pub num_variables: usize,
26    /// Number of constant / literal nodes.
27    pub num_constants: usize,
28    /// Number of predicate application nodes.
29    pub num_predicates: usize,
30    /// Number of quantifier nodes (ForAll, Exists, soft, counting, etc.).
31    pub num_quantifiers: usize,
32    /// Maximum nesting depth of quantifiers.
33    pub quantifier_depth: usize,
34    /// Number of Not / FuzzyNot nodes.
35    pub num_negations: usize,
36    /// Number of connective nodes (And, Or, Imply, Iff, TNorm, TCoNorm, etc.).
37    pub num_connectives: usize,
38    /// Number of arithmetic operation nodes.
39    pub num_arithmetic: usize,
40    /// Number of set operation nodes.
41    pub num_set_ops: usize,
42    /// Number of let / lambda / fixpoint binding nodes.
43    pub num_let_bindings: usize,
44    /// Average children per internal (non-leaf) node.
45    pub branching_factor: f64,
46    /// Ratio of leaf nodes to total nodes.
47    pub leaf_ratio: f64,
48}
49
50impl ExprComplexity {
51    /// Compute complexity metrics for a [`TLExpr`].
52    pub fn analyze(expr: &TLExpr) -> Self {
53        let mut ctx = AnalysisContext::default();
54        let mut level_widths: Vec<usize> = Vec::new();
55        Self::visit(expr, 0, 0, &mut ctx, &mut level_widths);
56
57        let max_width = level_widths.iter().copied().max().unwrap_or(1);
58        let internal_nodes = ctx.total_nodes.saturating_sub(ctx.leaf_count);
59        let branching_factor = if internal_nodes > 0 {
60            ctx.total_edges as f64 / internal_nodes as f64
61        } else {
62            0.0
63        };
64        let leaf_ratio = if ctx.total_nodes > 0 {
65            ctx.leaf_count as f64 / ctx.total_nodes as f64
66        } else {
67            0.0
68        };
69
70        Self {
71            total_nodes: ctx.total_nodes,
72            depth: ctx.max_depth,
73            width: max_width,
74            num_variables: ctx.variables.len(),
75            num_constants: ctx.num_constants,
76            num_predicates: ctx.num_predicates,
77            num_quantifiers: ctx.num_quantifiers,
78            quantifier_depth: ctx.max_quantifier_depth,
79            num_negations: ctx.num_negations,
80            num_connectives: ctx.num_connectives,
81            num_arithmetic: ctx.num_arithmetic,
82            num_set_ops: ctx.num_set_ops,
83            num_let_bindings: ctx.num_let_bindings,
84            branching_factor,
85            leaf_ratio,
86        }
87    }
88
89    /// A scalar "complexity score" summarising overall complexity (weighted sum).
90    pub fn score(&self) -> f64 {
91        self.total_nodes as f64 * 1.0
92            + self.depth as f64 * 2.0
93            + self.quantifier_depth as f64 * 5.0
94            + self.num_variables as f64 * 0.5
95            + self.num_quantifiers as f64 * 3.0
96            + self.num_negations as f64 * 1.0
97            + self.num_connectives as f64 * 1.5
98            + self.num_arithmetic as f64 * 1.0
99            + self.num_set_ops as f64 * 2.0
100            + self.num_let_bindings as f64 * 2.0
101    }
102
103    /// Returns `true` when the scalar [`Self::score`] is below `threshold`.
104    pub fn is_simple(&self, threshold: f64) -> bool {
105        self.score() < threshold
106    }
107
108    /// Human-readable one-line summary.
109    pub fn summary(&self) -> String {
110        format!(
111            "nodes={}, depth={}, vars={}, quantifiers={} (depth={}), score={:.1}",
112            self.total_nodes,
113            self.depth,
114            self.num_variables,
115            self.num_quantifiers,
116            self.quantifier_depth,
117            self.score(),
118        )
119    }
120
121    /// Tabular representation of all metrics.
122    pub fn format_table(&self) -> String {
123        let mut buf = String::new();
124        buf.push_str("Metric                | Value\n");
125        buf.push_str("----------------------|------\n");
126        buf.push_str(&format!("Total nodes           | {}\n", self.total_nodes));
127        buf.push_str(&format!("Depth                 | {}\n", self.depth));
128        buf.push_str(&format!("Width                 | {}\n", self.width));
129        buf.push_str(&format!(
130            "Variables (distinct)   | {}\n",
131            self.num_variables
132        ));
133        buf.push_str(&format!("Constants             | {}\n", self.num_constants));
134        buf.push_str(&format!(
135            "Predicates            | {}\n",
136            self.num_predicates
137        ));
138        buf.push_str(&format!(
139            "Quantifiers           | {}\n",
140            self.num_quantifiers
141        ));
142        buf.push_str(&format!(
143            "Quantifier depth      | {}\n",
144            self.quantifier_depth
145        ));
146        buf.push_str(&format!("Negations             | {}\n", self.num_negations));
147        buf.push_str(&format!(
148            "Connectives           | {}\n",
149            self.num_connectives
150        ));
151        buf.push_str(&format!(
152            "Arithmetic ops        | {}\n",
153            self.num_arithmetic
154        ));
155        buf.push_str(&format!("Set ops               | {}\n", self.num_set_ops));
156        buf.push_str(&format!(
157            "Let/Lambda/Fixpoint   | {}\n",
158            self.num_let_bindings
159        ));
160        buf.push_str(&format!(
161            "Branching factor      | {:.3}\n",
162            self.branching_factor
163        ));
164        buf.push_str(&format!("Leaf ratio            | {:.3}\n", self.leaf_ratio));
165        buf.push_str(&format!("Complexity score      | {:.1}\n", self.score()));
166        buf
167    }
168
169    // ------------------------------------------------------------------
170    // Recursive visitor
171    // ------------------------------------------------------------------
172
173    fn visit(
174        expr: &TLExpr,
175        depth: usize,
176        quantifier_depth: usize,
177        ctx: &mut AnalysisContext,
178        level_widths: &mut Vec<usize>,
179    ) {
180        ctx.total_nodes += 1;
181        let current_depth = depth + 1;
182        if current_depth > ctx.max_depth {
183            ctx.max_depth = current_depth;
184        }
185
186        // Ensure level_widths has an entry for this depth.
187        while level_widths.len() <= depth {
188            level_widths.push(0);
189        }
190
191        // Classify and recurse.
192        match expr {
193            // -- Leaves --
194            TLExpr::Constant(_) => {
195                ctx.num_constants += 1;
196                ctx.leaf_count += 1;
197                level_widths[depth] += 1;
198            }
199            TLExpr::EmptySet => {
200                ctx.num_constants += 1;
201                ctx.leaf_count += 1;
202                level_widths[depth] += 1;
203            }
204            TLExpr::Nominal { .. } => {
205                ctx.leaf_count += 1;
206                level_widths[depth] += 1;
207            }
208            TLExpr::Abducible { .. } => {
209                ctx.leaf_count += 1;
210                level_widths[depth] += 1;
211            }
212            TLExpr::AllDifferent { variables } => {
213                // Leaf-like: just collects variable names.
214                for v in variables {
215                    ctx.variables.insert(v.clone());
216                }
217                ctx.leaf_count += 1;
218                level_widths[depth] += 1;
219            }
220
221            // -- Predicates --
222            TLExpr::Pred { args, .. } => {
223                ctx.num_predicates += 1;
224                // Collect variables from term args.
225                for term in args {
226                    Self::collect_term_vars(term, ctx);
227                }
228                ctx.leaf_count += 1; // Pred is a leaf in the *expression* tree.
229                level_widths[depth] += 1;
230            }
231
232            // -- Negation --
233            TLExpr::Not(inner) => {
234                ctx.num_negations += 1;
235                ctx.total_edges += 1;
236                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
237            }
238            TLExpr::FuzzyNot { expr: inner, .. } => {
239                ctx.num_negations += 1;
240                ctx.total_edges += 1;
241                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
242            }
243
244            // -- Connectives (binary logical) --
245            TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
246                ctx.num_connectives += 1;
247                ctx.total_edges += 2;
248                Self::visit(l, current_depth, quantifier_depth, ctx, level_widths);
249                Self::visit(r, current_depth, quantifier_depth, ctx, level_widths);
250            }
251            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
252                ctx.num_connectives += 1;
253                ctx.total_edges += 2;
254                Self::visit(left, current_depth, quantifier_depth, ctx, level_widths);
255                Self::visit(right, current_depth, quantifier_depth, ctx, level_widths);
256            }
257            TLExpr::FuzzyImplication {
258                premise,
259                conclusion,
260                ..
261            } => {
262                ctx.num_connectives += 1;
263                ctx.total_edges += 2;
264                Self::visit(premise, current_depth, quantifier_depth, ctx, level_widths);
265                Self::visit(
266                    conclusion,
267                    current_depth,
268                    quantifier_depth,
269                    ctx,
270                    level_widths,
271                );
272            }
273
274            // -- Quantifiers (standard) --
275            TLExpr::ForAll { var, body, .. } | TLExpr::Exists { var, body, .. } => {
276                ctx.num_quantifiers += 1;
277                ctx.variables.insert(var.clone());
278                let new_qd = quantifier_depth + 1;
279                if new_qd > ctx.max_quantifier_depth {
280                    ctx.max_quantifier_depth = new_qd;
281                }
282                ctx.total_edges += 1;
283                Self::visit(body, current_depth, new_qd, ctx, level_widths);
284            }
285
286            // -- Soft quantifiers --
287            TLExpr::SoftExists { var, body, .. } | TLExpr::SoftForAll { var, body, .. } => {
288                ctx.num_quantifiers += 1;
289                ctx.variables.insert(var.clone());
290                let new_qd = quantifier_depth + 1;
291                if new_qd > ctx.max_quantifier_depth {
292                    ctx.max_quantifier_depth = new_qd;
293                }
294                ctx.total_edges += 1;
295                Self::visit(body, current_depth, new_qd, ctx, level_widths);
296            }
297
298            // -- Counting quantifiers --
299            TLExpr::CountingExists { var, body, .. }
300            | TLExpr::CountingForAll { var, body, .. }
301            | TLExpr::ExactCount { var, body, .. }
302            | TLExpr::Majority { var, body, .. } => {
303                ctx.num_quantifiers += 1;
304                ctx.variables.insert(var.clone());
305                let new_qd = quantifier_depth + 1;
306                if new_qd > ctx.max_quantifier_depth {
307                    ctx.max_quantifier_depth = new_qd;
308                }
309                ctx.total_edges += 1;
310                Self::visit(body, current_depth, new_qd, ctx, level_widths);
311            }
312
313            // -- Aggregate (quantifier-like) --
314            TLExpr::Aggregate { var, body, .. } => {
315                ctx.num_quantifiers += 1;
316                ctx.variables.insert(var.clone());
317                let new_qd = quantifier_depth + 1;
318                if new_qd > ctx.max_quantifier_depth {
319                    ctx.max_quantifier_depth = new_qd;
320                }
321                ctx.total_edges += 1;
322                Self::visit(body, current_depth, new_qd, ctx, level_widths);
323            }
324
325            // -- Arithmetic (binary) --
326            TLExpr::Add(l, r)
327            | TLExpr::Sub(l, r)
328            | TLExpr::Mul(l, r)
329            | TLExpr::Div(l, r)
330            | TLExpr::Pow(l, r)
331            | TLExpr::Mod(l, r)
332            | TLExpr::Min(l, r)
333            | TLExpr::Max(l, r) => {
334                ctx.num_arithmetic += 1;
335                ctx.total_edges += 2;
336                Self::visit(l, current_depth, quantifier_depth, ctx, level_widths);
337                Self::visit(r, current_depth, quantifier_depth, ctx, level_widths);
338            }
339
340            // -- Arithmetic (unary math) --
341            TLExpr::Abs(inner)
342            | TLExpr::Floor(inner)
343            | TLExpr::Ceil(inner)
344            | TLExpr::Round(inner)
345            | TLExpr::Sqrt(inner)
346            | TLExpr::Exp(inner)
347            | TLExpr::Log(inner)
348            | TLExpr::Sin(inner)
349            | TLExpr::Cos(inner)
350            | TLExpr::Tan(inner) => {
351                ctx.num_arithmetic += 1;
352                ctx.total_edges += 1;
353                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
354            }
355
356            // -- Comparison (binary) --
357            TLExpr::Eq(l, r)
358            | TLExpr::Lt(l, r)
359            | TLExpr::Gt(l, r)
360            | TLExpr::Lte(l, r)
361            | TLExpr::Gte(l, r) => {
362                ctx.num_connectives += 1; // comparisons are logical connectives
363                ctx.total_edges += 2;
364                Self::visit(l, current_depth, quantifier_depth, ctx, level_widths);
365                Self::visit(r, current_depth, quantifier_depth, ctx, level_widths);
366            }
367
368            // -- Conditional --
369            TLExpr::IfThenElse {
370                condition,
371                then_branch,
372                else_branch,
373            } => {
374                ctx.total_edges += 3;
375                Self::visit(
376                    condition,
377                    current_depth,
378                    quantifier_depth,
379                    ctx,
380                    level_widths,
381                );
382                Self::visit(
383                    then_branch,
384                    current_depth,
385                    quantifier_depth,
386                    ctx,
387                    level_widths,
388                );
389                Self::visit(
390                    else_branch,
391                    current_depth,
392                    quantifier_depth,
393                    ctx,
394                    level_widths,
395                );
396            }
397
398            // -- Let / Lambda / Fixpoint --
399            TLExpr::Let {
400                var, value, body, ..
401            } => {
402                ctx.num_let_bindings += 1;
403                ctx.variables.insert(var.clone());
404                ctx.total_edges += 2;
405                Self::visit(value, current_depth, quantifier_depth, ctx, level_widths);
406                Self::visit(body, current_depth, quantifier_depth, ctx, level_widths);
407            }
408            TLExpr::Lambda { var, body, .. } => {
409                ctx.num_let_bindings += 1;
410                ctx.variables.insert(var.clone());
411                ctx.total_edges += 1;
412                Self::visit(body, current_depth, quantifier_depth, ctx, level_widths);
413            }
414            TLExpr::LeastFixpoint { var, body, .. }
415            | TLExpr::GreatestFixpoint { var, body, .. } => {
416                ctx.num_let_bindings += 1;
417                ctx.variables.insert(var.clone());
418                ctx.total_edges += 1;
419                Self::visit(body, current_depth, quantifier_depth, ctx, level_widths);
420            }
421
422            // -- Apply (higher-order) --
423            TLExpr::Apply {
424                function, argument, ..
425            } => {
426                ctx.total_edges += 2;
427                Self::visit(function, current_depth, quantifier_depth, ctx, level_widths);
428                Self::visit(argument, current_depth, quantifier_depth, ctx, level_widths);
429            }
430
431            // -- Set operations --
432            TLExpr::SetUnion { left, right }
433            | TLExpr::SetIntersection { left, right }
434            | TLExpr::SetDifference { left, right } => {
435                ctx.num_set_ops += 1;
436                ctx.total_edges += 2;
437                Self::visit(left, current_depth, quantifier_depth, ctx, level_widths);
438                Self::visit(right, current_depth, quantifier_depth, ctx, level_widths);
439            }
440            TLExpr::SetMembership { element, set } => {
441                ctx.num_set_ops += 1;
442                ctx.total_edges += 2;
443                Self::visit(element, current_depth, quantifier_depth, ctx, level_widths);
444                Self::visit(set, current_depth, quantifier_depth, ctx, level_widths);
445            }
446            TLExpr::SetCardinality { set } => {
447                ctx.num_set_ops += 1;
448                ctx.total_edges += 1;
449                Self::visit(set, current_depth, quantifier_depth, ctx, level_widths);
450            }
451            TLExpr::SetComprehension { var, condition, .. } => {
452                ctx.num_set_ops += 1;
453                ctx.variables.insert(var.clone());
454                ctx.total_edges += 1;
455                Self::visit(
456                    condition,
457                    current_depth,
458                    quantifier_depth,
459                    ctx,
460                    level_widths,
461                );
462            }
463
464            // -- Modal logic --
465            TLExpr::Box(inner) | TLExpr::Diamond(inner) => {
466                ctx.num_connectives += 1;
467                ctx.total_edges += 1;
468                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
469            }
470
471            // -- Temporal logic (unary) --
472            TLExpr::Next(inner) | TLExpr::Eventually(inner) | TLExpr::Always(inner) => {
473                ctx.num_connectives += 1;
474                ctx.total_edges += 1;
475                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
476            }
477
478            // -- Temporal logic (binary) --
479            TLExpr::Until { before, after }
480            | TLExpr::Release {
481                released: before,
482                releaser: after,
483            }
484            | TLExpr::WeakUntil { before, after }
485            | TLExpr::StrongRelease {
486                released: before,
487                releaser: after,
488            } => {
489                ctx.num_connectives += 1;
490                ctx.total_edges += 2;
491                Self::visit(before, current_depth, quantifier_depth, ctx, level_widths);
492                Self::visit(after, current_depth, quantifier_depth, ctx, level_widths);
493            }
494
495            // -- Score --
496            TLExpr::Score(inner) => {
497                ctx.total_edges += 1;
498                Self::visit(inner, current_depth, quantifier_depth, ctx, level_widths);
499            }
500
501            // -- Weighted rule --
502            TLExpr::WeightedRule { rule, .. } => {
503                ctx.total_edges += 1;
504                Self::visit(rule, current_depth, quantifier_depth, ctx, level_widths);
505            }
506
507            // -- Probabilistic choice --
508            TLExpr::ProbabilisticChoice { alternatives } => {
509                ctx.total_edges += alternatives.len();
510                for (_prob, alt_expr) in alternatives {
511                    Self::visit(alt_expr, current_depth, quantifier_depth, ctx, level_widths);
512                }
513            }
514
515            // -- Hybrid logic --
516            TLExpr::At { formula, .. } => {
517                ctx.total_edges += 1;
518                Self::visit(formula, current_depth, quantifier_depth, ctx, level_widths);
519            }
520            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
521                ctx.num_connectives += 1;
522                ctx.total_edges += 1;
523                Self::visit(formula, current_depth, quantifier_depth, ctx, level_widths);
524            }
525
526            // -- Constraint programming (GlobalCardinality) --
527            TLExpr::GlobalCardinality {
528                variables, values, ..
529            } => {
530                for v in variables {
531                    ctx.variables.insert(v.clone());
532                }
533                ctx.total_edges += values.len();
534                for val_expr in values {
535                    Self::visit(val_expr, current_depth, quantifier_depth, ctx, level_widths);
536                }
537            }
538
539            // -- Explain (abductive reasoning) --
540            TLExpr::Explain { formula } => {
541                ctx.total_edges += 1;
542                Self::visit(formula, current_depth, quantifier_depth, ctx, level_widths);
543            }
544
545            // -- Symbol literal --
546            TLExpr::SymbolLiteral(_) => {
547                ctx.leaf_count += 1;
548                level_widths[depth] += 1;
549            }
550
551            // -- Pattern matching --
552            TLExpr::Match { scrutinee, arms } => {
553                ctx.total_edges += 1 + arms.len();
554                Self::visit(
555                    scrutinee,
556                    current_depth,
557                    quantifier_depth,
558                    ctx,
559                    level_widths,
560                );
561                for (_, body) in arms {
562                    Self::visit(body, current_depth, quantifier_depth, ctx, level_widths);
563                }
564            }
565        }
566    }
567
568    /// Collect variable names from a [`Term`].
569    fn collect_term_vars(term: &Term, ctx: &mut AnalysisContext) {
570        match term {
571            Term::Var(name) => {
572                ctx.variables.insert(name.clone());
573            }
574            Term::Const(_) => {
575                ctx.num_constants += 1;
576            }
577            Term::Typed { value, .. } => {
578                Self::collect_term_vars(value, ctx);
579            }
580        }
581    }
582}
583
584impl fmt::Display for ExprComplexity {
585    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
586        write!(f, "{}", self.summary())
587    }
588}
589
590// ---------------------------------------------------------------------------
591// Internal analysis context
592// ---------------------------------------------------------------------------
593
/// Mutable accumulator threaded through the recursive `ExprComplexity::visit` walk.
///
/// The category counters mirror the identically named public fields of
/// `ExprComplexity`.
#[derive(Debug, Default)]
struct AnalysisContext {
    /// Nodes visited so far.
    total_nodes: usize,
    /// Deepest node level reached (root = 1).
    max_depth: usize,
    /// Deepest quantifier nesting reached.
    max_quantifier_depth: usize,
    /// Nodes counted as leaves.
    leaf_count: usize,
    /// Parent-to-child edges traversed.
    total_edges: usize,
    /// Distinct variable names encountered anywhere in the tree.
    variables: HashSet<String>,
    /// Constant nodes plus constant term arguments.
    num_constants: usize,
    /// Predicate application nodes.
    num_predicates: usize,
    /// Quantifier-like nodes.
    num_quantifiers: usize,
    /// Negation nodes.
    num_negations: usize,
    /// Connective nodes.
    num_connectives: usize,
    /// Arithmetic nodes.
    num_arithmetic: usize,
    /// Set operation nodes.
    num_set_ops: usize,
    /// Let / lambda / fixpoint nodes.
    num_let_bindings: usize,
}
611
612// ---------------------------------------------------------------------------
613// Thresholds & warnings
614// ---------------------------------------------------------------------------
615
/// Configurable thresholds for complexity warnings.
///
/// Each limit is compared against the matching [`ExprComplexity`] metric by
/// [`check_complexity`]; a metric only triggers a warning when it is strictly
/// greater than its limit.
#[derive(Debug, Clone, PartialEq)]
pub struct ComplexityThresholds {
    /// Limit for [`ExprComplexity::depth`].
    pub max_depth: usize,
    /// Limit for [`ExprComplexity::total_nodes`].
    pub max_nodes: usize,
    /// Limit for [`ExprComplexity::quantifier_depth`].
    pub max_quantifier_depth: usize,
    /// Limit for [`ExprComplexity::num_variables`].
    pub max_variables: usize,
    /// Limit for [`ExprComplexity::branching_factor`].
    pub max_branching_factor: f64,
}
625
626impl Default for ComplexityThresholds {
627    fn default() -> Self {
628        Self {
629            max_depth: 50,
630            max_nodes: 10000,
631            max_quantifier_depth: 10,
632            max_variables: 100,
633            max_branching_factor: 10.0,
634        }
635    }
636}
637
638/// A complexity warning produced by [`check_complexity`].
639#[derive(Debug, Clone)]
640pub struct ComplexityWarning {
641    pub metric: String,
642    pub value: f64,
643    pub threshold: f64,
644    pub severity: WarningSeverity,
645    pub message: String,
646}
647
648impl fmt::Display for ComplexityWarning {
649    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
650        write!(f, "[{:?}] {}", self.severity, self.message)
651    }
652}
653
/// Severity level for a complexity warning.
///
/// Variants are declared from least to most severe, so the derived ordering is
/// `Info < Warning < Critical`; e.g. `warnings.iter().map(|w| w.severity).max()`
/// yields the worst severity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum WarningSeverity {
    /// The metric exceeds its threshold by at most 1.5×.
    Info,
    /// The metric exceeds its threshold by more than 1.5× but at most 2×.
    Warning,
    /// The metric exceeds its threshold by more than 2×.
    Critical,
}
661
662/// Check complexity metrics against thresholds and return any warnings.
663pub fn check_complexity(
664    complexity: &ExprComplexity,
665    thresholds: &ComplexityThresholds,
666) -> Vec<ComplexityWarning> {
667    let mut warnings = Vec::new();
668
669    let checks: Vec<(&str, f64, f64)> = vec![
670        (
671            "depth",
672            complexity.depth as f64,
673            thresholds.max_depth as f64,
674        ),
675        (
676            "total_nodes",
677            complexity.total_nodes as f64,
678            thresholds.max_nodes as f64,
679        ),
680        (
681            "quantifier_depth",
682            complexity.quantifier_depth as f64,
683            thresholds.max_quantifier_depth as f64,
684        ),
685        (
686            "num_variables",
687            complexity.num_variables as f64,
688            thresholds.max_variables as f64,
689        ),
690        (
691            "branching_factor",
692            complexity.branching_factor,
693            thresholds.max_branching_factor,
694        ),
695    ];
696
697    for (metric, value, threshold) in checks {
698        if value > threshold {
699            let ratio = value / threshold;
700            let severity = if ratio > 2.0 {
701                WarningSeverity::Critical
702            } else if ratio > 1.5 {
703                WarningSeverity::Warning
704            } else {
705                WarningSeverity::Info
706            };
707            warnings.push(ComplexityWarning {
708                metric: metric.to_string(),
709                value,
710                threshold,
711                severity,
712                message: format!(
713                    "{} ({:.0}) exceeds threshold ({:.0})",
714                    metric, value, threshold
715                ),
716            });
717        }
718    }
719
720    warnings
721}
722
723// ---------------------------------------------------------------------------
724// ComplexityComparison
725// ---------------------------------------------------------------------------
726
727/// Side-by-side comparison of complexity between two expressions.
728#[derive(Debug, Clone)]
729pub struct ComplexityComparison {
730    pub before: ExprComplexity,
731    pub after: ExprComplexity,
732    pub node_delta: i64,
733    pub depth_delta: i64,
734    pub score_delta: f64,
735    /// `true` when the *after* expression is simpler (lower score).
736    pub improved: bool,
737}
738
739impl ComplexityComparison {
740    /// Compare complexity of two expressions (`before` → `after`).
741    pub fn compare(before: &TLExpr, after: &TLExpr) -> Self {
742        let b = ExprComplexity::analyze(before);
743        let a = ExprComplexity::analyze(after);
744        let node_delta = a.total_nodes as i64 - b.total_nodes as i64;
745        let depth_delta = a.depth as i64 - b.depth as i64;
746        let score_delta = a.score() - b.score();
747        let improved = a.score() < b.score();
748        Self {
749            before: b,
750            after: a,
751            node_delta,
752            depth_delta,
753            score_delta,
754            improved,
755        }
756    }
757
758    /// Human-readable summary of the comparison.
759    pub fn summary(&self) -> String {
760        let direction = if self.improved {
761            "improved"
762        } else if self.score_delta.abs() < f64::EPSILON {
763            "unchanged"
764        } else {
765            "regressed"
766        };
767        format!(
768            "Complexity {}: nodes {} -> {} ({:+}), depth {} -> {} ({:+}), score {:.1} -> {:.1} ({:+.1})",
769            direction,
770            self.before.total_nodes,
771            self.after.total_nodes,
772            self.node_delta,
773            self.before.depth,
774            self.after.depth,
775            self.depth_delta,
776            self.before.score(),
777            self.after.score(),
778            self.score_delta,
779        )
780    }
781}
782
783impl fmt::Display for ComplexityComparison {
784    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
785        write!(f, "{}", self.summary())
786    }
787}
788
789// ---------------------------------------------------------------------------
790// BatchComplexityStats
791// ---------------------------------------------------------------------------
792
/// Aggregate complexity statistics for a batch of expressions.
///
/// The derived `Default` (all zeros) equals the statistics of an empty batch.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct BatchComplexityStats {
    /// Number of expressions analysed.
    pub count: usize,
    /// Mean [`ExprComplexity::total_nodes`] over the batch.
    pub avg_nodes: f64,
    /// Mean [`ExprComplexity::depth`] over the batch.
    pub avg_depth: f64,
    /// Largest `total_nodes` in the batch.
    pub max_nodes: usize,
    /// Largest `depth` in the batch.
    pub max_depth: usize,
    /// Mean [`ExprComplexity::score`] over the batch.
    pub avg_score: f64,
    /// Number of expressions for which [`check_complexity`] reported at least one warning.
    pub above_threshold_count: usize,
}
804
805impl BatchComplexityStats {
806    /// Compute aggregate stats for a slice of expressions.
807    pub fn from_exprs(exprs: &[TLExpr], thresholds: &ComplexityThresholds) -> Self {
808        if exprs.is_empty() {
809            return Self {
810                count: 0,
811                avg_nodes: 0.0,
812                avg_depth: 0.0,
813                max_nodes: 0,
814                max_depth: 0,
815                avg_score: 0.0,
816                above_threshold_count: 0,
817            };
818        }
819
820        let metrics: Vec<ExprComplexity> = exprs.iter().map(ExprComplexity::analyze).collect();
821        let count = metrics.len();
822        let total_nodes_sum: usize = metrics.iter().map(|m| m.total_nodes).sum();
823        let total_depth_sum: usize = metrics.iter().map(|m| m.depth).sum();
824        let total_score_sum: f64 = metrics.iter().map(|m| m.score()).sum();
825        let max_nodes = metrics.iter().map(|m| m.total_nodes).max().unwrap_or(0);
826        let max_depth = metrics.iter().map(|m| m.depth).max().unwrap_or(0);
827        let above_threshold_count = metrics
828            .iter()
829            .filter(|m| !check_complexity(m, thresholds).is_empty())
830            .count();
831
832        Self {
833            count,
834            avg_nodes: total_nodes_sum as f64 / count as f64,
835            avg_depth: total_depth_sum as f64 / count as f64,
836            max_nodes,
837            max_depth,
838            avg_score: total_score_sum / count as f64,
839            above_threshold_count,
840        }
841    }
842
843    /// Human-readable summary of batch statistics.
844    pub fn summary(&self) -> String {
845        format!(
846            "Batch of {} exprs: avg nodes={:.1}, avg depth={:.1}, max nodes={}, max depth={}, avg score={:.1}, {} above threshold",
847            self.count,
848            self.avg_nodes,
849            self.avg_depth,
850            self.max_nodes,
851            self.max_depth,
852            self.avg_score,
853            self.above_threshold_count,
854        )
855    }
856}
857
858impl fmt::Display for BatchComplexityStats {
859    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
860        write!(f, "{}", self.summary())
861    }
862}
863
864// ---------------------------------------------------------------------------
865// Tests
866// ---------------------------------------------------------------------------
867
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::{TLExpr, Term};

    /// Predicate named `name` applied to a single variable also named `name`.
    fn var_expr(name: &str) -> TLExpr {
        TLExpr::pred(name, vec![Term::var(name)])
    }

    /// `Not` applied `times` times on top of a constant (depth = times + 1).
    fn negation_chain(times: usize) -> TLExpr {
        let mut expr = TLExpr::Constant(1.0);
        for _ in 0..times {
            expr = TLExpr::negate(expr);
        }
        expr
    }

    #[test]
    fn test_leaf_node_complexity() {
        let expr = TLExpr::Constant(1.0);
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.total_nodes, 1);
        assert_eq!(c.depth, 1);
    }

    #[test]
    fn test_not_depth() {
        let expr = TLExpr::negate(TLExpr::Constant(1.0));
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.depth, 2);
    }

    #[test]
    fn test_and_node_count() {
        let expr = TLExpr::and(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.total_nodes, 3);
    }

    #[test]
    fn test_quantifier_counted() {
        let body = var_expr("x");
        let expr = TLExpr::forall("x", "D", body);
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.num_quantifiers, 1);
    }

    #[test]
    fn test_quantifier_depth_nested() {
        let inner = TLExpr::exists("y", "D", var_expr("y"));
        let expr = TLExpr::forall("x", "D", inner);
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.quantifier_depth, 2);
    }

    #[test]
    fn test_negation_counted() {
        let expr = TLExpr::negate(TLExpr::Constant(1.0));
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.num_negations, 1);
    }

    #[test]
    fn test_connective_counted() {
        let expr = TLExpr::and(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.num_connectives, 1);
    }

    #[test]
    fn test_distinct_variables() {
        // And(Pred("x", [Var("x")]), Pred("x", [Var("x")]))
        let expr = TLExpr::and(var_expr("x"), var_expr("x"));
        let c = ExprComplexity::analyze(&expr);
        assert_eq!(c.num_variables, 1);
    }

    #[test]
    fn test_branching_factor_binary() {
        // And(Const, Const) => 1 internal node with 2 children => bf = 2.0
        let expr = TLExpr::and(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let c = ExprComplexity::analyze(&expr);
        assert!(c.branching_factor >= 1.0);
        assert!(c.branching_factor <= 2.5);
    }

    #[test]
    fn test_leaf_ratio() {
        // And(Const, Const) = 3 nodes, 2 leaves => 2/3 ≈ 0.667
        let expr = TLExpr::and(TLExpr::Constant(1.0), TLExpr::Constant(2.0));
        let c = ExprComplexity::analyze(&expr);
        let expected = 2.0 / 3.0;
        assert!((c.leaf_ratio - expected).abs() < 0.01);
    }

    #[test]
    fn test_score_increases_with_complexity() {
        let simple = TLExpr::Constant(1.0);
        let complex = TLExpr::forall(
            "x",
            "D",
            TLExpr::exists("y", "D", TLExpr::and(var_expr("x"), var_expr("y"))),
        );
        let s1 = ExprComplexity::analyze(&simple).score();
        let s2 = ExprComplexity::analyze(&complex).score();
        assert!(s2 > s1);
    }

    #[test]
    fn test_is_simple_true() {
        let expr = TLExpr::Constant(42.0);
        let c = ExprComplexity::analyze(&expr);
        assert!(c.is_simple(100.0));
    }

    #[test]
    fn test_is_simple_false() {
        // Build a moderately complex expression: 20 nested quantifiers.
        let mut expr = var_expr("x");
        for i in 0..20 {
            expr = TLExpr::forall(format!("v{}", i), "D", expr);
        }
        let c = ExprComplexity::analyze(&expr);
        assert!(!c.is_simple(10.0));
    }

    #[test]
    fn test_summary_nonempty() {
        let c = ExprComplexity::analyze(&TLExpr::Constant(1.0));
        let s = c.summary();
        assert!(!s.is_empty());
        assert!(s.contains("nodes="));
    }

    #[test]
    fn test_format_table_has_header() {
        let c = ExprComplexity::analyze(&TLExpr::Constant(1.0));
        let table = c.format_table();
        assert!(table.contains("Metric"));
        assert!(table.contains("Value"));
        assert!(table.contains("Total nodes"));
    }

    #[test]
    fn test_thresholds_default() {
        let t = ComplexityThresholds::default();
        assert_eq!(t.max_depth, 50);
        assert_eq!(t.max_nodes, 10000);
        assert_eq!(t.max_quantifier_depth, 10);
        assert_eq!(t.max_variables, 100);
        assert!((t.max_branching_factor - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_check_complexity_no_warnings() {
        let c = ExprComplexity::analyze(&TLExpr::Constant(1.0));
        let warnings = check_complexity(&c, &ComplexityThresholds::default());
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_check_complexity_with_warning() {
        // Build an expression deeper than the threshold (max_depth = 3).
        let expr = negation_chain(10);
        let c = ExprComplexity::analyze(&expr);
        let thresholds = ComplexityThresholds {
            max_depth: 3,
            ..ComplexityThresholds::default()
        };
        let warnings = check_complexity(&c, &thresholds);
        assert!(!warnings.is_empty());
        assert!(warnings.iter().any(|w| w.metric == "depth"));
    }

    #[test]
    fn test_complexity_comparison_improved() {
        let complex = TLExpr::and(
            TLExpr::forall("x", "D", var_expr("x")),
            TLExpr::exists("y", "D", var_expr("y")),
        );
        let simple = TLExpr::Constant(1.0);
        let cmp = ComplexityComparison::compare(&complex, &simple);
        assert!(cmp.improved);
        assert!(cmp.node_delta < 0);
    }

    #[test]
    fn test_comparison_display_matches_summary() {
        let cmp = ComplexityComparison::compare(&TLExpr::Constant(1.0), &TLExpr::Constant(2.0));
        assert_eq!(cmp.to_string(), cmp.summary());
    }

    #[test]
    fn test_batch_stats_avg() {
        let exprs = vec![
            TLExpr::Constant(1.0),
            TLExpr::Constant(2.0),
            TLExpr::and(TLExpr::Constant(3.0), TLExpr::Constant(4.0)),
        ];
        let stats = BatchComplexityStats::from_exprs(&exprs, &ComplexityThresholds::default());
        assert_eq!(stats.count, 3);
        // 1 + 1 + 3 = 5, avg = 5/3 ≈ 1.667
        let expected_avg = 5.0 / 3.0;
        assert!((stats.avg_nodes - expected_avg).abs() < 0.01);
    }

    #[test]
    fn test_batch_stats_empty() {
        // The empty-batch path must not divide by zero and must report zeros.
        let stats = BatchComplexityStats::from_exprs(&[], &ComplexityThresholds::default());
        assert_eq!(stats.count, 0);
        assert_eq!(stats.max_nodes, 0);
        assert_eq!(stats.max_depth, 0);
        assert_eq!(stats.above_threshold_count, 0);
        assert!(stats.avg_nodes.abs() < f64::EPSILON);
        assert!(stats.avg_depth.abs() < f64::EPSILON);
        assert!(stats.avg_score.abs() < f64::EPSILON);
        assert!(stats.summary().starts_with("Batch of 0 exprs"));
    }

    #[test]
    fn test_batch_stats_max_and_threshold() {
        // Constant: depth 1, 1 node. Ten negations: depth 11, 11 nodes.
        let exprs = vec![TLExpr::Constant(1.0), negation_chain(10)];
        let thresholds = ComplexityThresholds {
            max_depth: 3,
            ..ComplexityThresholds::default()
        };
        let stats = BatchComplexityStats::from_exprs(&exprs, &thresholds);
        assert_eq!(stats.count, 2);
        assert_eq!(stats.max_depth, 11);
        assert_eq!(stats.max_nodes, 11);
        assert!((stats.avg_depth - 6.0).abs() < 1e-9);
        // Only the deep chain exceeds max_depth = 3.
        assert_eq!(stats.above_threshold_count, 1);
    }

    #[test]
    fn test_batch_stats_display_matches_summary() {
        let stats = BatchComplexityStats::from_exprs(
            &[TLExpr::Constant(1.0)],
            &ComplexityThresholds::default(),
        );
        assert_eq!(stats.to_string(), stats.summary());
        assert!(stats.summary().starts_with("Batch of 1 exprs"));
    }
}