Skip to main content

tensorlogic_ir/
util.rs

1//! Utility functions for the IR.
2//!
3//! This module provides helper functions for:
4//! - Pretty printing expressions and graphs
5//! - Computing IR statistics
6//! - Formatting and display utilities
7
8use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12/// Pretty-print a TLExpr to a string.
13pub fn pretty_print_expr(expr: &TLExpr) -> String {
14    let mut buffer = String::new();
15    pretty_print_expr_inner(expr, &mut buffer, 0).expect("writing to String never fails");
16    buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20    let spaces = "  ".repeat(indent);
21
22    match expr {
23        TLExpr::Pred { name, args } => {
24            write!(buf, "{}{}(", spaces, name)?;
25            for (i, arg) in args.iter().enumerate() {
26                if i > 0 {
27                    write!(buf, ", ")?;
28                }
29                write!(buf, "{}", term_to_string(arg))?;
30            }
31            writeln!(buf, ")")?;
32        }
33        TLExpr::And(l, r) => {
34            writeln!(buf, "{}AND(", spaces)?;
35            pretty_print_expr_inner(l, buf, indent + 1)?;
36            writeln!(buf, "{},", spaces)?;
37            pretty_print_expr_inner(r, buf, indent + 1)?;
38            writeln!(buf, "{})", spaces)?;
39        }
40        TLExpr::Or(l, r) => {
41            writeln!(buf, "{}OR(", spaces)?;
42            pretty_print_expr_inner(l, buf, indent + 1)?;
43            writeln!(buf, "{},", spaces)?;
44            pretty_print_expr_inner(r, buf, indent + 1)?;
45            writeln!(buf, "{})", spaces)?;
46        }
47        TLExpr::Not(e) => {
48            writeln!(buf, "{}NOT(", spaces)?;
49            pretty_print_expr_inner(e, buf, indent + 1)?;
50            writeln!(buf, "{})", spaces)?;
51        }
52        TLExpr::Exists { var, domain, body } => {
53            writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54            pretty_print_expr_inner(body, buf, indent + 1)?;
55            writeln!(buf, "{})", spaces)?;
56        }
57        TLExpr::ForAll { var, domain, body } => {
58            writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59            pretty_print_expr_inner(body, buf, indent + 1)?;
60            writeln!(buf, "{})", spaces)?;
61        }
62        TLExpr::Aggregate {
63            op,
64            var,
65            domain,
66            body,
67            group_by,
68        } => {
69            write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70            if let Some(group_vars) = group_by {
71                write!(buf, " GROUP BY {:?}", group_vars)?;
72            }
73            writeln!(buf, ")(")?;
74            pretty_print_expr_inner(body, buf, indent + 1)?;
75            writeln!(buf, "{})", spaces)?;
76        }
77        TLExpr::Imply(premise, conclusion) => {
78            writeln!(buf, "{}IMPLY(", spaces)?;
79            pretty_print_expr_inner(premise, buf, indent + 1)?;
80            writeln!(buf, "{}⇒", spaces)?;
81            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82            writeln!(buf, "{})", spaces)?;
83        }
84        TLExpr::Score(e) => {
85            writeln!(buf, "{}SCORE(", spaces)?;
86            pretty_print_expr_inner(e, buf, indent + 1)?;
87            writeln!(buf, "{})", spaces)?;
88        }
89        TLExpr::Add(l, r) => {
90            writeln!(buf, "{}ADD(", spaces)?;
91            pretty_print_expr_inner(l, buf, indent + 1)?;
92            writeln!(buf, "{},", spaces)?;
93            pretty_print_expr_inner(r, buf, indent + 1)?;
94            writeln!(buf, "{})", spaces)?;
95        }
96        TLExpr::Sub(l, r) => {
97            writeln!(buf, "{}SUB(", spaces)?;
98            pretty_print_expr_inner(l, buf, indent + 1)?;
99            writeln!(buf, "{},", spaces)?;
100            pretty_print_expr_inner(r, buf, indent + 1)?;
101            writeln!(buf, "{})", spaces)?;
102        }
103        TLExpr::Mul(l, r) => {
104            writeln!(buf, "{}MUL(", spaces)?;
105            pretty_print_expr_inner(l, buf, indent + 1)?;
106            writeln!(buf, "{},", spaces)?;
107            pretty_print_expr_inner(r, buf, indent + 1)?;
108            writeln!(buf, "{})", spaces)?;
109        }
110        TLExpr::Div(l, r) => {
111            writeln!(buf, "{}DIV(", spaces)?;
112            pretty_print_expr_inner(l, buf, indent + 1)?;
113            writeln!(buf, "{},", spaces)?;
114            pretty_print_expr_inner(r, buf, indent + 1)?;
115            writeln!(buf, "{})", spaces)?;
116        }
117        TLExpr::Pow(l, r) => {
118            writeln!(buf, "{}POW(", spaces)?;
119            pretty_print_expr_inner(l, buf, indent + 1)?;
120            writeln!(buf, "{},", spaces)?;
121            pretty_print_expr_inner(r, buf, indent + 1)?;
122            writeln!(buf, "{})", spaces)?;
123        }
124        TLExpr::Mod(l, r) => {
125            writeln!(buf, "{}MOD(", spaces)?;
126            pretty_print_expr_inner(l, buf, indent + 1)?;
127            writeln!(buf, "{},", spaces)?;
128            pretty_print_expr_inner(r, buf, indent + 1)?;
129            writeln!(buf, "{})", spaces)?;
130        }
131        TLExpr::Min(l, r) => {
132            writeln!(buf, "{}MIN(", spaces)?;
133            pretty_print_expr_inner(l, buf, indent + 1)?;
134            writeln!(buf, "{},", spaces)?;
135            pretty_print_expr_inner(r, buf, indent + 1)?;
136            writeln!(buf, "{})", spaces)?;
137        }
138        TLExpr::Max(l, r) => {
139            writeln!(buf, "{}MAX(", spaces)?;
140            pretty_print_expr_inner(l, buf, indent + 1)?;
141            writeln!(buf, "{},", spaces)?;
142            pretty_print_expr_inner(r, buf, indent + 1)?;
143            writeln!(buf, "{})", spaces)?;
144        }
145        TLExpr::Abs(e) => {
146            writeln!(buf, "{}ABS(", spaces)?;
147            pretty_print_expr_inner(e, buf, indent + 1)?;
148            writeln!(buf, "{})", spaces)?;
149        }
150        TLExpr::Floor(e) => {
151            writeln!(buf, "{}FLOOR(", spaces)?;
152            pretty_print_expr_inner(e, buf, indent + 1)?;
153            writeln!(buf, "{})", spaces)?;
154        }
155        TLExpr::Ceil(e) => {
156            writeln!(buf, "{}CEIL(", spaces)?;
157            pretty_print_expr_inner(e, buf, indent + 1)?;
158            writeln!(buf, "{})", spaces)?;
159        }
160        TLExpr::Round(e) => {
161            writeln!(buf, "{}ROUND(", spaces)?;
162            pretty_print_expr_inner(e, buf, indent + 1)?;
163            writeln!(buf, "{})", spaces)?;
164        }
165        TLExpr::Sqrt(e) => {
166            writeln!(buf, "{}SQRT(", spaces)?;
167            pretty_print_expr_inner(e, buf, indent + 1)?;
168            writeln!(buf, "{})", spaces)?;
169        }
170        TLExpr::Exp(e) => {
171            writeln!(buf, "{}EXP(", spaces)?;
172            pretty_print_expr_inner(e, buf, indent + 1)?;
173            writeln!(buf, "{})", spaces)?;
174        }
175        TLExpr::Log(e) => {
176            writeln!(buf, "{}LOG(", spaces)?;
177            pretty_print_expr_inner(e, buf, indent + 1)?;
178            writeln!(buf, "{})", spaces)?;
179        }
180        TLExpr::Sin(e) => {
181            writeln!(buf, "{}SIN(", spaces)?;
182            pretty_print_expr_inner(e, buf, indent + 1)?;
183            writeln!(buf, "{})", spaces)?;
184        }
185        TLExpr::Cos(e) => {
186            writeln!(buf, "{}COS(", spaces)?;
187            pretty_print_expr_inner(e, buf, indent + 1)?;
188            writeln!(buf, "{})", spaces)?;
189        }
190        TLExpr::Tan(e) => {
191            writeln!(buf, "{}TAN(", spaces)?;
192            pretty_print_expr_inner(e, buf, indent + 1)?;
193            writeln!(buf, "{})", spaces)?;
194        }
195        TLExpr::Box(e) => {
196            writeln!(buf, "{}BOX(", spaces)?;
197            pretty_print_expr_inner(e, buf, indent + 1)?;
198            writeln!(buf, "{})", spaces)?;
199        }
200        TLExpr::Diamond(e) => {
201            writeln!(buf, "{}DIAMOND(", spaces)?;
202            pretty_print_expr_inner(e, buf, indent + 1)?;
203            writeln!(buf, "{})", spaces)?;
204        }
205        TLExpr::Next(e) => {
206            writeln!(buf, "{}NEXT(", spaces)?;
207            pretty_print_expr_inner(e, buf, indent + 1)?;
208            writeln!(buf, "{})", spaces)?;
209        }
210        TLExpr::Eventually(e) => {
211            writeln!(buf, "{}EVENTUALLY(", spaces)?;
212            pretty_print_expr_inner(e, buf, indent + 1)?;
213            writeln!(buf, "{})", spaces)?;
214        }
215        TLExpr::Always(e) => {
216            writeln!(buf, "{}ALWAYS(", spaces)?;
217            pretty_print_expr_inner(e, buf, indent + 1)?;
218            writeln!(buf, "{})", spaces)?;
219        }
220        TLExpr::Until { before, after } => {
221            writeln!(buf, "{}UNTIL(", spaces)?;
222            pretty_print_expr_inner(before, buf, indent + 1)?;
223            writeln!(buf, "{},", spaces)?;
224            pretty_print_expr_inner(after, buf, indent + 1)?;
225            writeln!(buf, "{})", spaces)?;
226        }
227
228        // Fuzzy logic operators
229        TLExpr::TNorm { kind, left, right } => {
230            writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231            pretty_print_expr_inner(left, buf, indent + 1)?;
232            writeln!(buf, "{},", spaces)?;
233            pretty_print_expr_inner(right, buf, indent + 1)?;
234            writeln!(buf, "{})", spaces)?;
235        }
236        TLExpr::TCoNorm { kind, left, right } => {
237            writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238            pretty_print_expr_inner(left, buf, indent + 1)?;
239            writeln!(buf, "{},", spaces)?;
240            pretty_print_expr_inner(right, buf, indent + 1)?;
241            writeln!(buf, "{})", spaces)?;
242        }
243        TLExpr::FuzzyNot { kind, expr } => {
244            writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245            pretty_print_expr_inner(expr, buf, indent + 1)?;
246            writeln!(buf, "{})", spaces)?;
247        }
248        TLExpr::FuzzyImplication {
249            kind,
250            premise,
251            conclusion,
252        } => {
253            writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254            pretty_print_expr_inner(premise, buf, indent + 1)?;
255            writeln!(buf, "{}⇒", spaces)?;
256            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257            writeln!(buf, "{})", spaces)?;
258        }
259
260        // Probabilistic operators
261        TLExpr::SoftExists {
262            var,
263            domain,
264            body,
265            temperature,
266        } => {
267            writeln!(
268                buf,
269                "{}SOFT-∃{}:{}[T={}](",
270                spaces, var, domain, temperature
271            )?;
272            pretty_print_expr_inner(body, buf, indent + 1)?;
273            writeln!(buf, "{})", spaces)?;
274        }
275        TLExpr::SoftForAll {
276            var,
277            domain,
278            body,
279            temperature,
280        } => {
281            writeln!(
282                buf,
283                "{}SOFT-∀{}:{}[T={}](",
284                spaces, var, domain, temperature
285            )?;
286            pretty_print_expr_inner(body, buf, indent + 1)?;
287            writeln!(buf, "{})", spaces)?;
288        }
289        TLExpr::WeightedRule { weight, rule } => {
290            writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291            pretty_print_expr_inner(rule, buf, indent + 1)?;
292            writeln!(buf, "{})", spaces)?;
293        }
294        TLExpr::ProbabilisticChoice { alternatives } => {
295            writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296            for (i, (prob, expr)) in alternatives.iter().enumerate() {
297                if i > 0 {
298                    writeln!(buf, "{},", spaces)?;
299                }
300                writeln!(buf, "{}  {}: ", spaces, prob)?;
301                pretty_print_expr_inner(expr, buf, indent + 2)?;
302            }
303            writeln!(buf, "{}]", spaces)?;
304        }
305
306        // Extended temporal logic
307        TLExpr::Release { released, releaser } => {
308            writeln!(buf, "{}RELEASE(", spaces)?;
309            pretty_print_expr_inner(released, buf, indent + 1)?;
310            writeln!(buf, "{},", spaces)?;
311            pretty_print_expr_inner(releaser, buf, indent + 1)?;
312            writeln!(buf, "{})", spaces)?;
313        }
314        TLExpr::WeakUntil { before, after } => {
315            writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316            pretty_print_expr_inner(before, buf, indent + 1)?;
317            writeln!(buf, "{},", spaces)?;
318            pretty_print_expr_inner(after, buf, indent + 1)?;
319            writeln!(buf, "{})", spaces)?;
320        }
321        TLExpr::StrongRelease { released, releaser } => {
322            writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323            pretty_print_expr_inner(released, buf, indent + 1)?;
324            writeln!(buf, "{},", spaces)?;
325            pretty_print_expr_inner(releaser, buf, indent + 1)?;
326            writeln!(buf, "{})", spaces)?;
327        }
328
329        TLExpr::Eq(l, r) => {
330            writeln!(buf, "{}EQ(", spaces)?;
331            pretty_print_expr_inner(l, buf, indent + 1)?;
332            writeln!(buf, "{},", spaces)?;
333            pretty_print_expr_inner(r, buf, indent + 1)?;
334            writeln!(buf, "{})", spaces)?;
335        }
336        TLExpr::Lt(l, r) => {
337            writeln!(buf, "{}LT(", spaces)?;
338            pretty_print_expr_inner(l, buf, indent + 1)?;
339            writeln!(buf, "{},", spaces)?;
340            pretty_print_expr_inner(r, buf, indent + 1)?;
341            writeln!(buf, "{})", spaces)?;
342        }
343        TLExpr::Gt(l, r) => {
344            writeln!(buf, "{}GT(", spaces)?;
345            pretty_print_expr_inner(l, buf, indent + 1)?;
346            writeln!(buf, "{},", spaces)?;
347            pretty_print_expr_inner(r, buf, indent + 1)?;
348            writeln!(buf, "{})", spaces)?;
349        }
350        TLExpr::Lte(l, r) => {
351            writeln!(buf, "{}LTE(", spaces)?;
352            pretty_print_expr_inner(l, buf, indent + 1)?;
353            writeln!(buf, "{},", spaces)?;
354            pretty_print_expr_inner(r, buf, indent + 1)?;
355            writeln!(buf, "{})", spaces)?;
356        }
357        TLExpr::Gte(l, r) => {
358            writeln!(buf, "{}GTE(", spaces)?;
359            pretty_print_expr_inner(l, buf, indent + 1)?;
360            writeln!(buf, "{},", spaces)?;
361            pretty_print_expr_inner(r, buf, indent + 1)?;
362            writeln!(buf, "{})", spaces)?;
363        }
364        TLExpr::IfThenElse {
365            condition,
366            then_branch,
367            else_branch,
368        } => {
369            writeln!(buf, "{}IF(", spaces)?;
370            pretty_print_expr_inner(condition, buf, indent + 1)?;
371            writeln!(buf, "{}) THEN(", spaces)?;
372            pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373            writeln!(buf, "{}) ELSE(", spaces)?;
374            pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375            writeln!(buf, "{})", spaces)?;
376        }
377        TLExpr::Let { var, value, body } => {
378            writeln!(buf, "{}LET {} =(", spaces, var)?;
379            pretty_print_expr_inner(value, buf, indent + 1)?;
380            writeln!(buf, "{}) IN(", spaces)?;
381            pretty_print_expr_inner(body, buf, indent + 1)?;
382            writeln!(buf, "{})", spaces)?;
383        }
384        // Alpha.3 enhancements
385        TLExpr::Lambda {
386            var,
387            var_type,
388            body,
389        } => {
390            if let Some(ty) = var_type {
391                writeln!(buf, "{}LAMBDA {}:{} ⇒(", spaces, var, ty)?;
392            } else {
393                writeln!(buf, "{}LAMBDA {} ⇒(", spaces, var)?;
394            }
395            pretty_print_expr_inner(body, buf, indent + 1)?;
396            writeln!(buf, "{})", spaces)?;
397        }
398        TLExpr::Apply { function, argument } => {
399            writeln!(buf, "{}APPLY(", spaces)?;
400            pretty_print_expr_inner(function, buf, indent + 1)?;
401            writeln!(buf, "{}TO", spaces)?;
402            pretty_print_expr_inner(argument, buf, indent + 1)?;
403            writeln!(buf, "{})", spaces)?;
404        }
405        TLExpr::SetMembership { element, set } => {
406            writeln!(buf, "{}MEMBER(", spaces)?;
407            pretty_print_expr_inner(element, buf, indent + 1)?;
408            writeln!(buf, "{}IN", spaces)?;
409            pretty_print_expr_inner(set, buf, indent + 1)?;
410            writeln!(buf, "{})", spaces)?;
411        }
412        TLExpr::SetUnion { left, right } => {
413            writeln!(buf, "{}UNION(", spaces)?;
414            pretty_print_expr_inner(left, buf, indent + 1)?;
415            writeln!(buf, "{},", spaces)?;
416            pretty_print_expr_inner(right, buf, indent + 1)?;
417            writeln!(buf, "{})", spaces)?;
418        }
419        TLExpr::SetIntersection { left, right } => {
420            writeln!(buf, "{}INTERSECT(", spaces)?;
421            pretty_print_expr_inner(left, buf, indent + 1)?;
422            writeln!(buf, "{},", spaces)?;
423            pretty_print_expr_inner(right, buf, indent + 1)?;
424            writeln!(buf, "{})", spaces)?;
425        }
426        TLExpr::SetDifference { left, right } => {
427            writeln!(buf, "{}DIFFERENCE(", spaces)?;
428            pretty_print_expr_inner(left, buf, indent + 1)?;
429            writeln!(buf, "{},", spaces)?;
430            pretty_print_expr_inner(right, buf, indent + 1)?;
431            writeln!(buf, "{})", spaces)?;
432        }
433        TLExpr::SetCardinality { set } => {
434            writeln!(buf, "{}CARDINALITY(", spaces)?;
435            pretty_print_expr_inner(set, buf, indent + 1)?;
436            writeln!(buf, "{})", spaces)?;
437        }
438        TLExpr::EmptySet => {
439            writeln!(buf, "{}EMPTY-SET", spaces)?;
440        }
441        TLExpr::SetComprehension {
442            var,
443            domain,
444            condition,
445        } => {
446            writeln!(buf, "{}SET-COMPREHENSION {{ {}:{} | ", spaces, var, domain)?;
447            pretty_print_expr_inner(condition, buf, indent + 1)?;
448            writeln!(buf, "{}}}", spaces)?;
449        }
450        TLExpr::CountingExists {
451            var,
452            domain,
453            body,
454            min_count,
455        } => {
456            writeln!(buf, "{}∃≥{}{}:{}.(", spaces, min_count, var, domain)?;
457            pretty_print_expr_inner(body, buf, indent + 1)?;
458            writeln!(buf, "{})", spaces)?;
459        }
460        TLExpr::CountingForAll {
461            var,
462            domain,
463            body,
464            min_count,
465        } => {
466            writeln!(buf, "{}∀≥{}{}:{}.(", spaces, min_count, var, domain)?;
467            pretty_print_expr_inner(body, buf, indent + 1)?;
468            writeln!(buf, "{})", spaces)?;
469        }
470        TLExpr::ExactCount {
471            var,
472            domain,
473            body,
474            count,
475        } => {
476            writeln!(buf, "{}∃={}{}:{}.(", spaces, count, var, domain)?;
477            pretty_print_expr_inner(body, buf, indent + 1)?;
478            writeln!(buf, "{})", spaces)?;
479        }
480        TLExpr::Majority { var, domain, body } => {
481            writeln!(buf, "{}MAJORITY {}:{}.(", spaces, var, domain)?;
482            pretty_print_expr_inner(body, buf, indent + 1)?;
483            writeln!(buf, "{})", spaces)?;
484        }
485        TLExpr::LeastFixpoint { var, body } => {
486            writeln!(buf, "{}μ{}.(", spaces, var)?;
487            pretty_print_expr_inner(body, buf, indent + 1)?;
488            writeln!(buf, "{})", spaces)?;
489        }
490        TLExpr::GreatestFixpoint { var, body } => {
491            writeln!(buf, "{}ν{}.(", spaces, var)?;
492            pretty_print_expr_inner(body, buf, indent + 1)?;
493            writeln!(buf, "{})", spaces)?;
494        }
495        TLExpr::Nominal { name } => {
496            writeln!(buf, "{}@{}", spaces, name)?;
497        }
498        TLExpr::At { nominal, formula } => {
499            writeln!(buf, "{}AT @{}(", spaces, nominal)?;
500            pretty_print_expr_inner(formula, buf, indent + 1)?;
501            writeln!(buf, "{})", spaces)?;
502        }
503        TLExpr::Somewhere { formula } => {
504            writeln!(buf, "{}SOMEWHERE(", spaces)?;
505            pretty_print_expr_inner(formula, buf, indent + 1)?;
506            writeln!(buf, "{})", spaces)?;
507        }
508        TLExpr::Everywhere { formula } => {
509            writeln!(buf, "{}EVERYWHERE(", spaces)?;
510            pretty_print_expr_inner(formula, buf, indent + 1)?;
511            writeln!(buf, "{})", spaces)?;
512        }
513        TLExpr::AllDifferent { variables } => {
514            writeln!(buf, "{}ALL-DIFFERENT({:?})", spaces, variables)?;
515        }
516        TLExpr::GlobalCardinality {
517            variables,
518            values,
519            min_occurrences,
520            max_occurrences,
521        } => {
522            writeln!(buf, "{}GLOBAL-CARDINALITY(", spaces)?;
523            writeln!(buf, "{}  vars: {:?}", spaces, variables)?;
524            writeln!(buf, "{}  constraints: [", spaces)?;
525            for (i, val) in values.iter().enumerate() {
526                write!(buf, "{}    ", spaces)?;
527                pretty_print_expr_inner(val, buf, 0)?;
528                writeln!(buf, ": [{}, {}]", min_occurrences[i], max_occurrences[i])?;
529            }
530            writeln!(buf, "{}  ]", spaces)?;
531            writeln!(buf, "{})", spaces)?;
532        }
533        TLExpr::Abducible { name, cost } => {
534            writeln!(buf, "{}ABDUCIBLE({}, cost={})", spaces, name, cost)?;
535        }
536        TLExpr::Explain { formula } => {
537            writeln!(buf, "{}EXPLAIN(", spaces)?;
538            pretty_print_expr_inner(formula, buf, indent + 1)?;
539            writeln!(buf, "{})", spaces)?;
540        }
541        TLExpr::Constant(value) => {
542            writeln!(buf, "{}{}", spaces, value)?;
543        }
544        TLExpr::SymbolLiteral(s) => {
545            writeln!(buf, "{}:{}", spaces, s)?;
546        }
547        TLExpr::Match { scrutinee, arms } => {
548            writeln!(buf, "{}MATCH(", spaces)?;
549            pretty_print_expr_inner(scrutinee, buf, indent + 1)?;
550            for (pat, body) in arms {
551                writeln!(buf, "{}  {} =>", spaces, pat)?;
552                pretty_print_expr_inner(body, buf, indent + 2)?;
553            }
554            writeln!(buf, "{})", spaces)?;
555        }
556    }
557
558    Ok(())
559}
560
561fn term_to_string(term: &Term) -> String {
562    match term {
563        Term::Var(name) => format!("?{}", name),
564        Term::Const(name) => name.clone(),
565        Term::Typed {
566            value,
567            type_annotation,
568        } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
569    }
570}
571
/// Summary statistics about a [`TLExpr`] tree, produced by [`ExprStats::compute`].
///
/// All counters start at zero (the derived [`Default`]) and are incremented
/// once per matching node while the expression is traversed.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ExprStats {
    /// Total number of nodes in the expression tree.
    pub node_count: usize,
    /// Maximum depth of the expression tree.
    pub max_depth: usize,
    /// Number of predicate applications (`Pred` nodes).
    pub predicate_count: usize,
    /// Number of variable-binding nodes: exists/forall, aggregates, soft and
    /// counting quantifiers, set comprehensions and lambdas.
    pub quantifier_count: usize,
    /// Number of logical connectives (and, or, not, imply) plus the other
    /// connective-like nodes that [`ExprStats::compute`] classifies as logical
    /// (e.g. score, fuzzy, until/release, probabilistic, set and apply nodes).
    pub logical_op_count: usize,
    /// Number of arithmetic operators, including unary math functions and set
    /// cardinality.
    ///
    /// NOTE(review): [`ExprStats::compute`] also counts the unary modal/temporal
    /// operators (box, diamond, next, eventually, always) here — confirm intended.
    pub arithmetic_op_count: usize,
    /// Number of comparison operators (eq, lt, gt, lte, gte).
    pub comparison_op_count: usize,
    /// Number of free variables, as reported by `TLExpr::free_vars`.
    pub free_var_count: usize,
}
592
593impl ExprStats {
594    /// Compute statistics for an expression.
595    pub fn compute(expr: &TLExpr) -> Self {
596        let mut stats = ExprStats {
597            node_count: 0,
598            max_depth: 0,
599            predicate_count: 0,
600            quantifier_count: 0,
601            logical_op_count: 0,
602            arithmetic_op_count: 0,
603            comparison_op_count: 0,
604            free_var_count: expr.free_vars().len(),
605        };
606
607        stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
608        stats
609    }
610
611    fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
612        stats.node_count += 1;
613        let mut max_child_depth = depth;
614
615        match expr {
616            TLExpr::Pred { .. } => {
617                stats.predicate_count += 1;
618            }
619            TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
620                stats.logical_op_count += 1;
621                let left_depth = Self::compute_recursive(l, stats, depth + 1);
622                let right_depth = Self::compute_recursive(r, stats, depth + 1);
623                max_child_depth = left_depth.max(right_depth);
624            }
625            TLExpr::Not(e) | TLExpr::Score(e) => {
626                stats.logical_op_count += 1;
627                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
628            }
629            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
630                stats.quantifier_count += 1;
631                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
632            }
633            TLExpr::Aggregate { body, .. } => {
634                stats.quantifier_count += 1; // Aggregates are similar to quantifiers
635                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
636            }
637            TLExpr::Add(l, r)
638            | TLExpr::Sub(l, r)
639            | TLExpr::Mul(l, r)
640            | TLExpr::Div(l, r)
641            | TLExpr::Pow(l, r)
642            | TLExpr::Mod(l, r)
643            | TLExpr::Min(l, r)
644            | TLExpr::Max(l, r) => {
645                stats.arithmetic_op_count += 1;
646                let left_depth = Self::compute_recursive(l, stats, depth + 1);
647                let right_depth = Self::compute_recursive(r, stats, depth + 1);
648                max_child_depth = left_depth.max(right_depth);
649            }
650            TLExpr::Abs(e)
651            | TLExpr::Floor(e)
652            | TLExpr::Ceil(e)
653            | TLExpr::Round(e)
654            | TLExpr::Sqrt(e)
655            | TLExpr::Exp(e)
656            | TLExpr::Log(e)
657            | TLExpr::Sin(e)
658            | TLExpr::Cos(e)
659            | TLExpr::Tan(e)
660            | TLExpr::Box(e)
661            | TLExpr::Diamond(e)
662            | TLExpr::Next(e)
663            | TLExpr::Eventually(e)
664            | TLExpr::Always(e) => {
665                stats.arithmetic_op_count += 1;
666                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
667            }
668            TLExpr::Until { before, after } => {
669                stats.logical_op_count += 1;
670                let depth_before = Self::compute_recursive(before, stats, depth + 1);
671                let depth_after = Self::compute_recursive(after, stats, depth + 1);
672                max_child_depth = depth_before.max(depth_after);
673            }
674            TLExpr::Eq(l, r)
675            | TLExpr::Lt(l, r)
676            | TLExpr::Gt(l, r)
677            | TLExpr::Lte(l, r)
678            | TLExpr::Gte(l, r) => {
679                stats.comparison_op_count += 1;
680                let left_depth = Self::compute_recursive(l, stats, depth + 1);
681                let right_depth = Self::compute_recursive(r, stats, depth + 1);
682                max_child_depth = left_depth.max(right_depth);
683            }
684            TLExpr::IfThenElse {
685                condition,
686                then_branch,
687                else_branch,
688            } => {
689                let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
690                let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
691                let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
692                max_child_depth = cond_depth.max(then_depth).max(else_depth);
693            }
694            TLExpr::Let { value, body, .. } => {
695                let value_depth = Self::compute_recursive(value, stats, depth + 1);
696                let body_depth = Self::compute_recursive(body, stats, depth + 1);
697                max_child_depth = value_depth.max(body_depth);
698            }
699
700            // Fuzzy logic operators
701            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
702                stats.logical_op_count += 1;
703                let left_depth = Self::compute_recursive(left, stats, depth + 1);
704                let right_depth = Self::compute_recursive(right, stats, depth + 1);
705                max_child_depth = left_depth.max(right_depth);
706            }
707            TLExpr::FuzzyNot { expr, .. } => {
708                stats.logical_op_count += 1;
709                max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
710            }
711            TLExpr::FuzzyImplication {
712                premise,
713                conclusion,
714                ..
715            } => {
716                stats.logical_op_count += 1;
717                let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
718                let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
719                max_child_depth = prem_depth.max(conc_depth);
720            }
721
722            // Probabilistic operators
723            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
724                stats.quantifier_count += 1;
725                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
726            }
727            TLExpr::WeightedRule { rule, .. } => {
728                stats.logical_op_count += 1;
729                max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
730            }
731            TLExpr::ProbabilisticChoice { alternatives } => {
732                stats.logical_op_count += 1;
733                let mut max_alt_depth = depth;
734                for (_, expr) in alternatives {
735                    let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
736                    max_alt_depth = max_alt_depth.max(alt_depth);
737                }
738                max_child_depth = max_alt_depth;
739            }
740
741            // Extended temporal logic
742            TLExpr::Release { released, releaser }
743            | TLExpr::WeakUntil {
744                before: released,
745                after: releaser,
746            }
747            | TLExpr::StrongRelease { released, releaser } => {
748                stats.logical_op_count += 1;
749                let rel_depth = Self::compute_recursive(released, stats, depth + 1);
750                let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
751                max_child_depth = rel_depth.max(reler_depth);
752            }
753
754            // Alpha.3 enhancements
755            TLExpr::Lambda { body, .. } => {
756                stats.quantifier_count += 1; // Lambda binds a variable
757                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
758            }
759            TLExpr::Apply { function, argument } => {
760                stats.logical_op_count += 1;
761                let func_depth = Self::compute_recursive(function, stats, depth + 1);
762                let arg_depth = Self::compute_recursive(argument, stats, depth + 1);
763                max_child_depth = func_depth.max(arg_depth);
764            }
765            TLExpr::SetMembership { element, set }
766            | TLExpr::SetUnion {
767                left: element,
768                right: set,
769            }
770            | TLExpr::SetIntersection {
771                left: element,
772                right: set,
773            }
774            | TLExpr::SetDifference {
775                left: element,
776                right: set,
777            } => {
778                stats.logical_op_count += 1;
779                let elem_depth = Self::compute_recursive(element, stats, depth + 1);
780                let set_depth = Self::compute_recursive(set, stats, depth + 1);
781                max_child_depth = elem_depth.max(set_depth);
782            }
783            TLExpr::SetCardinality { set } => {
784                stats.arithmetic_op_count += 1;
785                max_child_depth = Self::compute_recursive(set, stats, depth + 1);
786            }
787            TLExpr::EmptySet => {
788                // Leaf node
789            }
790            TLExpr::SetComprehension { condition, .. } => {
791                stats.quantifier_count += 1;
792                max_child_depth = Self::compute_recursive(condition, stats, depth + 1);
793            }
794            TLExpr::CountingExists { body, .. }
795            | TLExpr::CountingForAll { body, .. }
796            | TLExpr::ExactCount { body, .. }
797            | TLExpr::Majority { body, .. } => {
798                stats.quantifier_count += 1;
799                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
800            }
801            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
802                stats.logical_op_count += 1;
803                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
804            }
805            TLExpr::Nominal { .. } => {
806                // Leaf node
807            }
808            TLExpr::At { formula, .. } => {
809                stats.logical_op_count += 1;
810                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
811            }
812            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
813                stats.logical_op_count += 1;
814                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
815            }
816            TLExpr::AllDifferent { .. } => {
817                stats.logical_op_count += 1;
818                // Leaf node (no subexpressions)
819            }
820            TLExpr::GlobalCardinality { values, .. } => {
821                stats.logical_op_count += 1;
822                let mut max_val_depth = depth;
823                for val in values {
824                    let val_depth = Self::compute_recursive(val, stats, depth + 1);
825                    max_val_depth = max_val_depth.max(val_depth);
826                }
827                max_child_depth = max_val_depth;
828            }
829            TLExpr::Abducible { .. } => {
830                stats.predicate_count += 1;
831                // Leaf node
832            }
833            TLExpr::Explain { formula } => {
834                stats.logical_op_count += 1;
835                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
836            }
837
838            TLExpr::Constant(_) => {
839                // Leaf node
840            }
841            TLExpr::SymbolLiteral(_) => {
842                // Leaf node
843            }
844            TLExpr::Match { scrutinee, arms } => {
845                stats.logical_op_count += 1;
846                let sd = Self::compute_recursive(scrutinee, stats, depth + 1);
847                if sd > max_child_depth {
848                    max_child_depth = sd;
849                }
850                for (_, body) in arms {
851                    let bd = Self::compute_recursive(body, stats, depth + 1);
852                    if bd > max_child_depth {
853                        max_child_depth = bd;
854                    }
855                }
856            }
857        }
858
859        max_child_depth
860    }
861}
862
/// Statistics about an EinsumGraph.
///
/// Built by [`GraphStats::compute`]. Each node of the graph increments exactly
/// one of the four operation counters (`einsum_count`, `elem_unary_count`,
/// `elem_binary_count`, `reduce_count`), so together they sum to `node_count`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of tensors (length of the graph's tensor list)
    pub tensor_count: usize,
    /// Number of nodes (length of the graph's node list)
    pub node_count: usize,
    /// Number of output tensors (length of the graph's output list)
    pub output_count: usize,
    /// Number of einsum operations
    pub einsum_count: usize,
    /// Number of element-wise unary operations
    pub elem_unary_count: usize,
    /// Number of element-wise binary operations
    pub elem_binary_count: usize,
    /// Number of reduction operations
    pub reduce_count: usize,
    /// Average inputs per node; `0.0` when the graph has no nodes
    pub avg_inputs_per_node: f64,
}
883
884impl GraphStats {
885    /// Compute statistics for a graph.
886    pub fn compute(graph: &EinsumGraph) -> Self {
887        let mut stats = GraphStats {
888            tensor_count: graph.tensors.len(),
889            node_count: graph.nodes.len(),
890            output_count: graph.outputs.len(),
891            einsum_count: 0,
892            elem_unary_count: 0,
893            elem_binary_count: 0,
894            reduce_count: 0,
895            avg_inputs_per_node: 0.0,
896        };
897
898        let mut total_inputs = 0;
899
900        for node in &graph.nodes {
901            total_inputs += node.inputs.len();
902
903            match &node.op {
904                crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
905                crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
906                crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
907                crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
908            }
909        }
910
911        if stats.node_count > 0 {
912            stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
913        }
914
915        stats
916    }
917}
918
919/// Pretty-print a graph to a string.
920pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
921    let mut buffer = String::new();
922    writeln!(buffer, "EinsumGraph {{").expect("writing to String buffer never fails");
923    writeln!(buffer, "  Tensors: {} total", graph.tensors.len())
924        .expect("writing to String buffer never fails");
925
926    for (idx, name) in graph.tensors.iter().enumerate() {
927        writeln!(buffer, "    t{}: {}", idx, name).expect("writing to String buffer never fails");
928    }
929
930    writeln!(buffer, "  Nodes: {} total", graph.nodes.len())
931        .expect("writing to String buffer never fails");
932    for (idx, node) in graph.nodes.iter().enumerate() {
933        write!(buffer, "    n{}: ", idx).expect("writing to String buffer never fails");
934        match &node.op {
935            crate::graph::OpType::Einsum { spec } => write!(buffer, "Einsum(\"{}\")", spec)
936                .expect("writing to String buffer never fails"),
937            crate::graph::OpType::ElemUnary { op } => {
938                write!(buffer, "ElemUnary({})", op).expect("writing to String buffer never fails")
939            }
940            crate::graph::OpType::ElemBinary { op } => {
941                write!(buffer, "ElemBinary({})", op).expect("writing to String buffer never fails")
942            }
943            crate::graph::OpType::Reduce { op, axes } => {
944                write!(buffer, "Reduce({}, axes={:?})", op, axes)
945                    .expect("writing to String buffer never fails")
946            }
947        }
948        write!(buffer, " <- [").expect("writing to String buffer never fails");
949        for (i, input) in node.inputs.iter().enumerate() {
950            if i > 0 {
951                write!(buffer, ", ").expect("writing to String buffer never fails");
952            }
953            write!(buffer, "t{}", input).expect("writing to String buffer never fails");
954        }
955        writeln!(buffer, "]").expect("writing to String buffer never fails");
956    }
957
958    writeln!(buffer, "  Outputs: {:?}", graph.outputs)
959        .expect("writing to String buffer never fails");
960    writeln!(buffer, "}}").expect("writing to String buffer never fails");
961
962    buffer
963}
964
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a graph with an `input` tensor feeding one identity einsum node
    /// (`"i->i"`) that writes an `output` tensor, which is marked as a graph output.
    fn single_einsum_graph() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(crate::EinsumNode {
                inputs: vec![t0],
                outputs: vec![t1],
                op: crate::graph::OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .expect("adding a node over freshly created tensors should succeed");

        graph
            .add_output(t1)
            .expect("marking an existing tensor as an output should succeed");

        graph
    }

    #[test]
    fn test_expr_stats_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let stats = ExprStats::compute(&expr);

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    #[test]
    fn test_expr_stats_complex() {
        // ∀x. P(x) ∧ Q(x)
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("x")]);
        let and_expr = TLExpr::and(p, q);
        let expr = TLExpr::forall("x", "Domain", and_expr);

        let stats = ExprStats::compute(&expr);

        assert_eq!(stats.node_count, 4); // forall, and, p, q
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        assert_eq!(stats.free_var_count, 0); // x is bound
    }

    #[test]
    fn test_expr_stats_arithmetic() {
        // score(x) * 2 + 1
        let score = TLExpr::pred("score", vec![Term::var("x")]);
        let mul = TLExpr::mul(score, TLExpr::constant(2.0));
        let add = TLExpr::add(mul, TLExpr::constant(1.0));

        let stats = ExprStats::compute(&add);

        assert_eq!(stats.arithmetic_op_count, 2); // mul, add
        assert_eq!(stats.predicate_count, 1);
    }

    #[test]
    fn test_graph_stats() {
        let graph = single_einsum_graph();
        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.elem_unary_count, 0);
        assert_eq!(stats.elem_binary_count, 0);
        assert_eq!(stats.reduce_count, 0);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    #[test]
    fn test_graph_stats_empty_graph() {
        // No nodes: the average must be 0.0 rather than NaN from 0/0.
        let graph = EinsumGraph::new();
        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 0);
        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.einsum_count, 0);
        assert_eq!(stats.avg_inputs_per_node, 0.0);
    }

    #[test]
    fn test_pretty_print_expr() {
        let expr = TLExpr::pred("Person", vec![Term::var("x")]);
        let output = pretty_print_expr(&expr);
        assert!(output.contains("Person(?x)"));
    }

    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("input");

        let output = pretty_print_graph(&graph);
        assert!(output.contains("t0: input"));
        assert!(output.contains("Tensors: 1 total"));
    }

    #[test]
    fn test_pretty_print_graph_lists_nodes() {
        let graph = single_einsum_graph();
        let output = pretty_print_graph(&graph);

        // Build the expected node line from the node's actual input id so the
        // test does not depend on how tensor ids are allocated.
        let expected_node = format!("n0: Einsum(\"i->i\") <- [t{}]", graph.nodes[0].inputs[0]);

        assert!(output.starts_with("EinsumGraph {"));
        assert!(output.contains("t1: output"));
        assert!(output.contains("Nodes: 1 total"));
        assert!(
            output.contains(&expected_node),
            "missing `{}` in:\n{}",
            expected_node,
            output
        );
        assert!(output.trim_end().ends_with('}'));
    }
}