tensorlogic_ir/
util.rs

1//! Utility functions for the IR.
2//!
3//! This module provides helper functions for:
4//! - Pretty printing expressions and graphs
5//! - Computing IR statistics
6//! - Formatting and display utilities
7
8use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12/// Pretty-print a TLExpr to a string.
13pub fn pretty_print_expr(expr: &TLExpr) -> String {
14    let mut buffer = String::new();
15    pretty_print_expr_inner(expr, &mut buffer, 0).unwrap();
16    buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20    let spaces = "  ".repeat(indent);
21
22    match expr {
23        TLExpr::Pred { name, args } => {
24            write!(buf, "{}{}(", spaces, name)?;
25            for (i, arg) in args.iter().enumerate() {
26                if i > 0 {
27                    write!(buf, ", ")?;
28                }
29                write!(buf, "{}", term_to_string(arg))?;
30            }
31            writeln!(buf, ")")?;
32        }
33        TLExpr::And(l, r) => {
34            writeln!(buf, "{}AND(", spaces)?;
35            pretty_print_expr_inner(l, buf, indent + 1)?;
36            writeln!(buf, "{},", spaces)?;
37            pretty_print_expr_inner(r, buf, indent + 1)?;
38            writeln!(buf, "{})", spaces)?;
39        }
40        TLExpr::Or(l, r) => {
41            writeln!(buf, "{}OR(", spaces)?;
42            pretty_print_expr_inner(l, buf, indent + 1)?;
43            writeln!(buf, "{},", spaces)?;
44            pretty_print_expr_inner(r, buf, indent + 1)?;
45            writeln!(buf, "{})", spaces)?;
46        }
47        TLExpr::Not(e) => {
48            writeln!(buf, "{}NOT(", spaces)?;
49            pretty_print_expr_inner(e, buf, indent + 1)?;
50            writeln!(buf, "{})", spaces)?;
51        }
52        TLExpr::Exists { var, domain, body } => {
53            writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54            pretty_print_expr_inner(body, buf, indent + 1)?;
55            writeln!(buf, "{})", spaces)?;
56        }
57        TLExpr::ForAll { var, domain, body } => {
58            writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59            pretty_print_expr_inner(body, buf, indent + 1)?;
60            writeln!(buf, "{})", spaces)?;
61        }
62        TLExpr::Aggregate {
63            op,
64            var,
65            domain,
66            body,
67            group_by,
68        } => {
69            write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70            if let Some(group_vars) = group_by {
71                write!(buf, " GROUP BY {:?}", group_vars)?;
72            }
73            writeln!(buf, ")(")?;
74            pretty_print_expr_inner(body, buf, indent + 1)?;
75            writeln!(buf, "{})", spaces)?;
76        }
77        TLExpr::Imply(premise, conclusion) => {
78            writeln!(buf, "{}IMPLY(", spaces)?;
79            pretty_print_expr_inner(premise, buf, indent + 1)?;
80            writeln!(buf, "{}⇒", spaces)?;
81            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82            writeln!(buf, "{})", spaces)?;
83        }
84        TLExpr::Score(e) => {
85            writeln!(buf, "{}SCORE(", spaces)?;
86            pretty_print_expr_inner(e, buf, indent + 1)?;
87            writeln!(buf, "{})", spaces)?;
88        }
89        TLExpr::Add(l, r) => {
90            writeln!(buf, "{}ADD(", spaces)?;
91            pretty_print_expr_inner(l, buf, indent + 1)?;
92            writeln!(buf, "{},", spaces)?;
93            pretty_print_expr_inner(r, buf, indent + 1)?;
94            writeln!(buf, "{})", spaces)?;
95        }
96        TLExpr::Sub(l, r) => {
97            writeln!(buf, "{}SUB(", spaces)?;
98            pretty_print_expr_inner(l, buf, indent + 1)?;
99            writeln!(buf, "{},", spaces)?;
100            pretty_print_expr_inner(r, buf, indent + 1)?;
101            writeln!(buf, "{})", spaces)?;
102        }
103        TLExpr::Mul(l, r) => {
104            writeln!(buf, "{}MUL(", spaces)?;
105            pretty_print_expr_inner(l, buf, indent + 1)?;
106            writeln!(buf, "{},", spaces)?;
107            pretty_print_expr_inner(r, buf, indent + 1)?;
108            writeln!(buf, "{})", spaces)?;
109        }
110        TLExpr::Div(l, r) => {
111            writeln!(buf, "{}DIV(", spaces)?;
112            pretty_print_expr_inner(l, buf, indent + 1)?;
113            writeln!(buf, "{},", spaces)?;
114            pretty_print_expr_inner(r, buf, indent + 1)?;
115            writeln!(buf, "{})", spaces)?;
116        }
117        TLExpr::Pow(l, r) => {
118            writeln!(buf, "{}POW(", spaces)?;
119            pretty_print_expr_inner(l, buf, indent + 1)?;
120            writeln!(buf, "{},", spaces)?;
121            pretty_print_expr_inner(r, buf, indent + 1)?;
122            writeln!(buf, "{})", spaces)?;
123        }
124        TLExpr::Mod(l, r) => {
125            writeln!(buf, "{}MOD(", spaces)?;
126            pretty_print_expr_inner(l, buf, indent + 1)?;
127            writeln!(buf, "{},", spaces)?;
128            pretty_print_expr_inner(r, buf, indent + 1)?;
129            writeln!(buf, "{})", spaces)?;
130        }
131        TLExpr::Min(l, r) => {
132            writeln!(buf, "{}MIN(", spaces)?;
133            pretty_print_expr_inner(l, buf, indent + 1)?;
134            writeln!(buf, "{},", spaces)?;
135            pretty_print_expr_inner(r, buf, indent + 1)?;
136            writeln!(buf, "{})", spaces)?;
137        }
138        TLExpr::Max(l, r) => {
139            writeln!(buf, "{}MAX(", spaces)?;
140            pretty_print_expr_inner(l, buf, indent + 1)?;
141            writeln!(buf, "{},", spaces)?;
142            pretty_print_expr_inner(r, buf, indent + 1)?;
143            writeln!(buf, "{})", spaces)?;
144        }
145        TLExpr::Abs(e) => {
146            writeln!(buf, "{}ABS(", spaces)?;
147            pretty_print_expr_inner(e, buf, indent + 1)?;
148            writeln!(buf, "{})", spaces)?;
149        }
150        TLExpr::Floor(e) => {
151            writeln!(buf, "{}FLOOR(", spaces)?;
152            pretty_print_expr_inner(e, buf, indent + 1)?;
153            writeln!(buf, "{})", spaces)?;
154        }
155        TLExpr::Ceil(e) => {
156            writeln!(buf, "{}CEIL(", spaces)?;
157            pretty_print_expr_inner(e, buf, indent + 1)?;
158            writeln!(buf, "{})", spaces)?;
159        }
160        TLExpr::Round(e) => {
161            writeln!(buf, "{}ROUND(", spaces)?;
162            pretty_print_expr_inner(e, buf, indent + 1)?;
163            writeln!(buf, "{})", spaces)?;
164        }
165        TLExpr::Sqrt(e) => {
166            writeln!(buf, "{}SQRT(", spaces)?;
167            pretty_print_expr_inner(e, buf, indent + 1)?;
168            writeln!(buf, "{})", spaces)?;
169        }
170        TLExpr::Exp(e) => {
171            writeln!(buf, "{}EXP(", spaces)?;
172            pretty_print_expr_inner(e, buf, indent + 1)?;
173            writeln!(buf, "{})", spaces)?;
174        }
175        TLExpr::Log(e) => {
176            writeln!(buf, "{}LOG(", spaces)?;
177            pretty_print_expr_inner(e, buf, indent + 1)?;
178            writeln!(buf, "{})", spaces)?;
179        }
180        TLExpr::Sin(e) => {
181            writeln!(buf, "{}SIN(", spaces)?;
182            pretty_print_expr_inner(e, buf, indent + 1)?;
183            writeln!(buf, "{})", spaces)?;
184        }
185        TLExpr::Cos(e) => {
186            writeln!(buf, "{}COS(", spaces)?;
187            pretty_print_expr_inner(e, buf, indent + 1)?;
188            writeln!(buf, "{})", spaces)?;
189        }
190        TLExpr::Tan(e) => {
191            writeln!(buf, "{}TAN(", spaces)?;
192            pretty_print_expr_inner(e, buf, indent + 1)?;
193            writeln!(buf, "{})", spaces)?;
194        }
195        TLExpr::Box(e) => {
196            writeln!(buf, "{}BOX(", spaces)?;
197            pretty_print_expr_inner(e, buf, indent + 1)?;
198            writeln!(buf, "{})", spaces)?;
199        }
200        TLExpr::Diamond(e) => {
201            writeln!(buf, "{}DIAMOND(", spaces)?;
202            pretty_print_expr_inner(e, buf, indent + 1)?;
203            writeln!(buf, "{})", spaces)?;
204        }
205        TLExpr::Next(e) => {
206            writeln!(buf, "{}NEXT(", spaces)?;
207            pretty_print_expr_inner(e, buf, indent + 1)?;
208            writeln!(buf, "{})", spaces)?;
209        }
210        TLExpr::Eventually(e) => {
211            writeln!(buf, "{}EVENTUALLY(", spaces)?;
212            pretty_print_expr_inner(e, buf, indent + 1)?;
213            writeln!(buf, "{})", spaces)?;
214        }
215        TLExpr::Always(e) => {
216            writeln!(buf, "{}ALWAYS(", spaces)?;
217            pretty_print_expr_inner(e, buf, indent + 1)?;
218            writeln!(buf, "{})", spaces)?;
219        }
220        TLExpr::Until { before, after } => {
221            writeln!(buf, "{}UNTIL(", spaces)?;
222            pretty_print_expr_inner(before, buf, indent + 1)?;
223            writeln!(buf, "{},", spaces)?;
224            pretty_print_expr_inner(after, buf, indent + 1)?;
225            writeln!(buf, "{})", spaces)?;
226        }
227
228        // Fuzzy logic operators
229        TLExpr::TNorm { kind, left, right } => {
230            writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231            pretty_print_expr_inner(left, buf, indent + 1)?;
232            writeln!(buf, "{},", spaces)?;
233            pretty_print_expr_inner(right, buf, indent + 1)?;
234            writeln!(buf, "{})", spaces)?;
235        }
236        TLExpr::TCoNorm { kind, left, right } => {
237            writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238            pretty_print_expr_inner(left, buf, indent + 1)?;
239            writeln!(buf, "{},", spaces)?;
240            pretty_print_expr_inner(right, buf, indent + 1)?;
241            writeln!(buf, "{})", spaces)?;
242        }
243        TLExpr::FuzzyNot { kind, expr } => {
244            writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245            pretty_print_expr_inner(expr, buf, indent + 1)?;
246            writeln!(buf, "{})", spaces)?;
247        }
248        TLExpr::FuzzyImplication {
249            kind,
250            premise,
251            conclusion,
252        } => {
253            writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254            pretty_print_expr_inner(premise, buf, indent + 1)?;
255            writeln!(buf, "{}⇒", spaces)?;
256            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257            writeln!(buf, "{})", spaces)?;
258        }
259
260        // Probabilistic operators
261        TLExpr::SoftExists {
262            var,
263            domain,
264            body,
265            temperature,
266        } => {
267            writeln!(
268                buf,
269                "{}SOFT-∃{}:{}[T={}](",
270                spaces, var, domain, temperature
271            )?;
272            pretty_print_expr_inner(body, buf, indent + 1)?;
273            writeln!(buf, "{})", spaces)?;
274        }
275        TLExpr::SoftForAll {
276            var,
277            domain,
278            body,
279            temperature,
280        } => {
281            writeln!(
282                buf,
283                "{}SOFT-∀{}:{}[T={}](",
284                spaces, var, domain, temperature
285            )?;
286            pretty_print_expr_inner(body, buf, indent + 1)?;
287            writeln!(buf, "{})", spaces)?;
288        }
289        TLExpr::WeightedRule { weight, rule } => {
290            writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291            pretty_print_expr_inner(rule, buf, indent + 1)?;
292            writeln!(buf, "{})", spaces)?;
293        }
294        TLExpr::ProbabilisticChoice { alternatives } => {
295            writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296            for (i, (prob, expr)) in alternatives.iter().enumerate() {
297                if i > 0 {
298                    writeln!(buf, "{},", spaces)?;
299                }
300                writeln!(buf, "{}  {}: ", spaces, prob)?;
301                pretty_print_expr_inner(expr, buf, indent + 2)?;
302            }
303            writeln!(buf, "{}]", spaces)?;
304        }
305
306        // Extended temporal logic
307        TLExpr::Release { released, releaser } => {
308            writeln!(buf, "{}RELEASE(", spaces)?;
309            pretty_print_expr_inner(released, buf, indent + 1)?;
310            writeln!(buf, "{},", spaces)?;
311            pretty_print_expr_inner(releaser, buf, indent + 1)?;
312            writeln!(buf, "{})", spaces)?;
313        }
314        TLExpr::WeakUntil { before, after } => {
315            writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316            pretty_print_expr_inner(before, buf, indent + 1)?;
317            writeln!(buf, "{},", spaces)?;
318            pretty_print_expr_inner(after, buf, indent + 1)?;
319            writeln!(buf, "{})", spaces)?;
320        }
321        TLExpr::StrongRelease { released, releaser } => {
322            writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323            pretty_print_expr_inner(released, buf, indent + 1)?;
324            writeln!(buf, "{},", spaces)?;
325            pretty_print_expr_inner(releaser, buf, indent + 1)?;
326            writeln!(buf, "{})", spaces)?;
327        }
328
329        TLExpr::Eq(l, r) => {
330            writeln!(buf, "{}EQ(", spaces)?;
331            pretty_print_expr_inner(l, buf, indent + 1)?;
332            writeln!(buf, "{},", spaces)?;
333            pretty_print_expr_inner(r, buf, indent + 1)?;
334            writeln!(buf, "{})", spaces)?;
335        }
336        TLExpr::Lt(l, r) => {
337            writeln!(buf, "{}LT(", spaces)?;
338            pretty_print_expr_inner(l, buf, indent + 1)?;
339            writeln!(buf, "{},", spaces)?;
340            pretty_print_expr_inner(r, buf, indent + 1)?;
341            writeln!(buf, "{})", spaces)?;
342        }
343        TLExpr::Gt(l, r) => {
344            writeln!(buf, "{}GT(", spaces)?;
345            pretty_print_expr_inner(l, buf, indent + 1)?;
346            writeln!(buf, "{},", spaces)?;
347            pretty_print_expr_inner(r, buf, indent + 1)?;
348            writeln!(buf, "{})", spaces)?;
349        }
350        TLExpr::Lte(l, r) => {
351            writeln!(buf, "{}LTE(", spaces)?;
352            pretty_print_expr_inner(l, buf, indent + 1)?;
353            writeln!(buf, "{},", spaces)?;
354            pretty_print_expr_inner(r, buf, indent + 1)?;
355            writeln!(buf, "{})", spaces)?;
356        }
357        TLExpr::Gte(l, r) => {
358            writeln!(buf, "{}GTE(", spaces)?;
359            pretty_print_expr_inner(l, buf, indent + 1)?;
360            writeln!(buf, "{},", spaces)?;
361            pretty_print_expr_inner(r, buf, indent + 1)?;
362            writeln!(buf, "{})", spaces)?;
363        }
364        TLExpr::IfThenElse {
365            condition,
366            then_branch,
367            else_branch,
368        } => {
369            writeln!(buf, "{}IF(", spaces)?;
370            pretty_print_expr_inner(condition, buf, indent + 1)?;
371            writeln!(buf, "{}) THEN(", spaces)?;
372            pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373            writeln!(buf, "{}) ELSE(", spaces)?;
374            pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375            writeln!(buf, "{})", spaces)?;
376        }
377        TLExpr::Let { var, value, body } => {
378            writeln!(buf, "{}LET {} =(", spaces, var)?;
379            pretty_print_expr_inner(value, buf, indent + 1)?;
380            writeln!(buf, "{}) IN(", spaces)?;
381            pretty_print_expr_inner(body, buf, indent + 1)?;
382            writeln!(buf, "{})", spaces)?;
383        }
384        TLExpr::Constant(value) => {
385            writeln!(buf, "{}{}", spaces, value)?;
386        }
387    }
388
389    Ok(())
390}
391
392fn term_to_string(term: &Term) -> String {
393    match term {
394        Term::Var(name) => format!("?{}", name),
395        Term::Const(name) => name.clone(),
396        Term::Typed {
397            value,
398            type_annotation,
399        } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
400    }
401}
402
403/// Statistics about a TLExpr.
404#[derive(Debug, Clone, PartialEq, Eq)]
405pub struct ExprStats {
406    /// Total number of nodes in the expression tree
407    pub node_count: usize,
408    /// Maximum depth of the expression tree
409    pub max_depth: usize,
410    /// Number of predicates
411    pub predicate_count: usize,
412    /// Number of quantifiers (exists + forall)
413    pub quantifier_count: usize,
414    /// Number of logical operators (and, or, not, imply)
415    pub logical_op_count: usize,
416    /// Number of arithmetic operators
417    pub arithmetic_op_count: usize,
418    /// Number of comparison operators
419    pub comparison_op_count: usize,
420    /// Number of free variables
421    pub free_var_count: usize,
422}
423
424impl ExprStats {
425    /// Compute statistics for an expression.
426    pub fn compute(expr: &TLExpr) -> Self {
427        let mut stats = ExprStats {
428            node_count: 0,
429            max_depth: 0,
430            predicate_count: 0,
431            quantifier_count: 0,
432            logical_op_count: 0,
433            arithmetic_op_count: 0,
434            comparison_op_count: 0,
435            free_var_count: expr.free_vars().len(),
436        };
437
438        stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
439        stats
440    }
441
442    fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
443        stats.node_count += 1;
444        let mut max_child_depth = depth;
445
446        match expr {
447            TLExpr::Pred { .. } => {
448                stats.predicate_count += 1;
449            }
450            TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
451                stats.logical_op_count += 1;
452                let left_depth = Self::compute_recursive(l, stats, depth + 1);
453                let right_depth = Self::compute_recursive(r, stats, depth + 1);
454                max_child_depth = left_depth.max(right_depth);
455            }
456            TLExpr::Not(e) | TLExpr::Score(e) => {
457                stats.logical_op_count += 1;
458                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
459            }
460            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
461                stats.quantifier_count += 1;
462                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
463            }
464            TLExpr::Aggregate { body, .. } => {
465                stats.quantifier_count += 1; // Aggregates are similar to quantifiers
466                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
467            }
468            TLExpr::Add(l, r)
469            | TLExpr::Sub(l, r)
470            | TLExpr::Mul(l, r)
471            | TLExpr::Div(l, r)
472            | TLExpr::Pow(l, r)
473            | TLExpr::Mod(l, r)
474            | TLExpr::Min(l, r)
475            | TLExpr::Max(l, r) => {
476                stats.arithmetic_op_count += 1;
477                let left_depth = Self::compute_recursive(l, stats, depth + 1);
478                let right_depth = Self::compute_recursive(r, stats, depth + 1);
479                max_child_depth = left_depth.max(right_depth);
480            }
481            TLExpr::Abs(e)
482            | TLExpr::Floor(e)
483            | TLExpr::Ceil(e)
484            | TLExpr::Round(e)
485            | TLExpr::Sqrt(e)
486            | TLExpr::Exp(e)
487            | TLExpr::Log(e)
488            | TLExpr::Sin(e)
489            | TLExpr::Cos(e)
490            | TLExpr::Tan(e)
491            | TLExpr::Box(e)
492            | TLExpr::Diamond(e)
493            | TLExpr::Next(e)
494            | TLExpr::Eventually(e)
495            | TLExpr::Always(e) => {
496                stats.arithmetic_op_count += 1;
497                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
498            }
499            TLExpr::Until { before, after } => {
500                stats.logical_op_count += 1;
501                let depth_before = Self::compute_recursive(before, stats, depth + 1);
502                let depth_after = Self::compute_recursive(after, stats, depth + 1);
503                max_child_depth = depth_before.max(depth_after);
504            }
505            TLExpr::Eq(l, r)
506            | TLExpr::Lt(l, r)
507            | TLExpr::Gt(l, r)
508            | TLExpr::Lte(l, r)
509            | TLExpr::Gte(l, r) => {
510                stats.comparison_op_count += 1;
511                let left_depth = Self::compute_recursive(l, stats, depth + 1);
512                let right_depth = Self::compute_recursive(r, stats, depth + 1);
513                max_child_depth = left_depth.max(right_depth);
514            }
515            TLExpr::IfThenElse {
516                condition,
517                then_branch,
518                else_branch,
519            } => {
520                let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
521                let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
522                let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
523                max_child_depth = cond_depth.max(then_depth).max(else_depth);
524            }
525            TLExpr::Let { value, body, .. } => {
526                let value_depth = Self::compute_recursive(value, stats, depth + 1);
527                let body_depth = Self::compute_recursive(body, stats, depth + 1);
528                max_child_depth = value_depth.max(body_depth);
529            }
530
531            // Fuzzy logic operators
532            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
533                stats.logical_op_count += 1;
534                let left_depth = Self::compute_recursive(left, stats, depth + 1);
535                let right_depth = Self::compute_recursive(right, stats, depth + 1);
536                max_child_depth = left_depth.max(right_depth);
537            }
538            TLExpr::FuzzyNot { expr, .. } => {
539                stats.logical_op_count += 1;
540                max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
541            }
542            TLExpr::FuzzyImplication {
543                premise,
544                conclusion,
545                ..
546            } => {
547                stats.logical_op_count += 1;
548                let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
549                let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
550                max_child_depth = prem_depth.max(conc_depth);
551            }
552
553            // Probabilistic operators
554            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
555                stats.quantifier_count += 1;
556                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
557            }
558            TLExpr::WeightedRule { rule, .. } => {
559                stats.logical_op_count += 1;
560                max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
561            }
562            TLExpr::ProbabilisticChoice { alternatives } => {
563                stats.logical_op_count += 1;
564                let mut max_alt_depth = depth;
565                for (_, expr) in alternatives {
566                    let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
567                    max_alt_depth = max_alt_depth.max(alt_depth);
568                }
569                max_child_depth = max_alt_depth;
570            }
571
572            // Extended temporal logic
573            TLExpr::Release { released, releaser }
574            | TLExpr::WeakUntil {
575                before: released,
576                after: releaser,
577            }
578            | TLExpr::StrongRelease { released, releaser } => {
579                stats.logical_op_count += 1;
580                let rel_depth = Self::compute_recursive(released, stats, depth + 1);
581                let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
582                max_child_depth = rel_depth.max(reler_depth);
583            }
584
585            TLExpr::Constant(_) => {
586                // Leaf node
587            }
588        }
589
590        max_child_depth
591    }
592}
593
594/// Statistics about an EinsumGraph.
595#[derive(Debug, Clone, PartialEq)]
596pub struct GraphStats {
597    /// Number of tensors
598    pub tensor_count: usize,
599    /// Number of nodes
600    pub node_count: usize,
601    /// Number of output tensors
602    pub output_count: usize,
603    /// Number of einsum operations
604    pub einsum_count: usize,
605    /// Number of element-wise unary operations
606    pub elem_unary_count: usize,
607    /// Number of element-wise binary operations
608    pub elem_binary_count: usize,
609    /// Number of reduction operations
610    pub reduce_count: usize,
611    /// Average inputs per node
612    pub avg_inputs_per_node: f64,
613}
614
615impl GraphStats {
616    /// Compute statistics for a graph.
617    pub fn compute(graph: &EinsumGraph) -> Self {
618        let mut stats = GraphStats {
619            tensor_count: graph.tensors.len(),
620            node_count: graph.nodes.len(),
621            output_count: graph.outputs.len(),
622            einsum_count: 0,
623            elem_unary_count: 0,
624            elem_binary_count: 0,
625            reduce_count: 0,
626            avg_inputs_per_node: 0.0,
627        };
628
629        let mut total_inputs = 0;
630
631        for node in &graph.nodes {
632            total_inputs += node.inputs.len();
633
634            match &node.op {
635                crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
636                crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
637                crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
638                crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
639            }
640        }
641
642        if stats.node_count > 0 {
643            stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
644        }
645
646        stats
647    }
648}
649
650/// Pretty-print a graph to a string.
651pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
652    let mut buffer = String::new();
653    writeln!(buffer, "EinsumGraph {{").unwrap();
654    writeln!(buffer, "  Tensors: {} total", graph.tensors.len()).unwrap();
655
656    for (idx, name) in graph.tensors.iter().enumerate() {
657        writeln!(buffer, "    t{}: {}", idx, name).unwrap();
658    }
659
660    writeln!(buffer, "  Nodes: {} total", graph.nodes.len()).unwrap();
661    for (idx, node) in graph.nodes.iter().enumerate() {
662        write!(buffer, "    n{}: ", idx).unwrap();
663        match &node.op {
664            crate::graph::OpType::Einsum { spec } => {
665                write!(buffer, "Einsum(\"{}\")", spec).unwrap()
666            }
667            crate::graph::OpType::ElemUnary { op } => write!(buffer, "ElemUnary({})", op).unwrap(),
668            crate::graph::OpType::ElemBinary { op } => {
669                write!(buffer, "ElemBinary({})", op).unwrap()
670            }
671            crate::graph::OpType::Reduce { op, axes } => {
672                write!(buffer, "Reduce({}, axes={:?})", op, axes).unwrap()
673            }
674        }
675        write!(buffer, " <- [").unwrap();
676        for (i, input) in node.inputs.iter().enumerate() {
677            if i > 0 {
678                write!(buffer, ", ").unwrap();
679            }
680            write!(buffer, "t{}", input).unwrap();
681        }
682        writeln!(buffer, "]").unwrap();
683    }
684
685    writeln!(buffer, "  Outputs: {:?}", graph.outputs).unwrap();
686    writeln!(buffer, "}}").unwrap();
687
688    buffer
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone predicate is a single node with one free variable.
    #[test]
    fn test_expr_stats_simple() {
        let stats = ExprStats::compute(&TLExpr::pred("P", vec![Term::var("x")]));

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    /// `∀x. P(x) ∧ Q(x)` has four nodes and binds its only variable.
    #[test]
    fn test_expr_stats_complex() {
        let conjunction = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );
        let quantified = TLExpr::forall("x", "Domain", conjunction);

        let stats = ExprStats::compute(&quantified);

        assert_eq!(stats.node_count, 4); // forall, and, p, q
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        assert_eq!(stats.free_var_count, 0); // x is bound
    }

    /// `score(x) * 2 + 1` contains two arithmetic operators.
    #[test]
    fn test_expr_stats_arithmetic() {
        let product = TLExpr::mul(
            TLExpr::pred("score", vec![Term::var("x")]),
            TLExpr::constant(2.0),
        );
        let sum = TLExpr::add(product, TLExpr::constant(1.0));

        let stats = ExprStats::compute(&sum);

        assert_eq!(stats.arithmetic_op_count, 2); // mul, add
        assert_eq!(stats.predicate_count, 1);
    }

    /// A one-node graph reports its tensors, node, output and input average.
    #[test]
    fn test_graph_stats() {
        let mut graph = EinsumGraph::new();
        let source = graph.add_tensor("input");
        let sink = graph.add_tensor("output");

        let identity = crate::EinsumNode {
            inputs: vec![source],
            outputs: vec![sink],
            op: crate::graph::OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        };
        graph.add_node(identity).unwrap();
        graph.add_output(sink).unwrap();

        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    /// Variables are rendered with a leading `?`.
    #[test]
    fn test_pretty_print_expr() {
        let rendered = pretty_print_expr(&TLExpr::pred("Person", vec![Term::var("x")]));
        assert!(rendered.contains("Person(?x)"));
    }

    /// The tensor listing shows both the index/name pair and the total.
    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("input");

        let rendered = pretty_print_graph(&graph);
        assert!(rendered.contains("t0: input"));
        assert!(rendered.contains("Tensors: 1 total"));
    }
}