Skip to main content

tensorlogic_ir/
util.rs

1//! Utility functions for the IR.
2//!
3//! This module provides helper functions for:
4//! - Pretty printing expressions and graphs
5//! - Computing IR statistics
6//! - Formatting and display utilities
7
8use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12/// Pretty-print a TLExpr to a string.
13pub fn pretty_print_expr(expr: &TLExpr) -> String {
14    let mut buffer = String::new();
15    pretty_print_expr_inner(expr, &mut buffer, 0).unwrap();
16    buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20    let spaces = "  ".repeat(indent);
21
22    match expr {
23        TLExpr::Pred { name, args } => {
24            write!(buf, "{}{}(", spaces, name)?;
25            for (i, arg) in args.iter().enumerate() {
26                if i > 0 {
27                    write!(buf, ", ")?;
28                }
29                write!(buf, "{}", term_to_string(arg))?;
30            }
31            writeln!(buf, ")")?;
32        }
33        TLExpr::And(l, r) => {
34            writeln!(buf, "{}AND(", spaces)?;
35            pretty_print_expr_inner(l, buf, indent + 1)?;
36            writeln!(buf, "{},", spaces)?;
37            pretty_print_expr_inner(r, buf, indent + 1)?;
38            writeln!(buf, "{})", spaces)?;
39        }
40        TLExpr::Or(l, r) => {
41            writeln!(buf, "{}OR(", spaces)?;
42            pretty_print_expr_inner(l, buf, indent + 1)?;
43            writeln!(buf, "{},", spaces)?;
44            pretty_print_expr_inner(r, buf, indent + 1)?;
45            writeln!(buf, "{})", spaces)?;
46        }
47        TLExpr::Not(e) => {
48            writeln!(buf, "{}NOT(", spaces)?;
49            pretty_print_expr_inner(e, buf, indent + 1)?;
50            writeln!(buf, "{})", spaces)?;
51        }
52        TLExpr::Exists { var, domain, body } => {
53            writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54            pretty_print_expr_inner(body, buf, indent + 1)?;
55            writeln!(buf, "{})", spaces)?;
56        }
57        TLExpr::ForAll { var, domain, body } => {
58            writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59            pretty_print_expr_inner(body, buf, indent + 1)?;
60            writeln!(buf, "{})", spaces)?;
61        }
62        TLExpr::Aggregate {
63            op,
64            var,
65            domain,
66            body,
67            group_by,
68        } => {
69            write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70            if let Some(group_vars) = group_by {
71                write!(buf, " GROUP BY {:?}", group_vars)?;
72            }
73            writeln!(buf, ")(")?;
74            pretty_print_expr_inner(body, buf, indent + 1)?;
75            writeln!(buf, "{})", spaces)?;
76        }
77        TLExpr::Imply(premise, conclusion) => {
78            writeln!(buf, "{}IMPLY(", spaces)?;
79            pretty_print_expr_inner(premise, buf, indent + 1)?;
80            writeln!(buf, "{}⇒", spaces)?;
81            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82            writeln!(buf, "{})", spaces)?;
83        }
84        TLExpr::Score(e) => {
85            writeln!(buf, "{}SCORE(", spaces)?;
86            pretty_print_expr_inner(e, buf, indent + 1)?;
87            writeln!(buf, "{})", spaces)?;
88        }
89        TLExpr::Add(l, r) => {
90            writeln!(buf, "{}ADD(", spaces)?;
91            pretty_print_expr_inner(l, buf, indent + 1)?;
92            writeln!(buf, "{},", spaces)?;
93            pretty_print_expr_inner(r, buf, indent + 1)?;
94            writeln!(buf, "{})", spaces)?;
95        }
96        TLExpr::Sub(l, r) => {
97            writeln!(buf, "{}SUB(", spaces)?;
98            pretty_print_expr_inner(l, buf, indent + 1)?;
99            writeln!(buf, "{},", spaces)?;
100            pretty_print_expr_inner(r, buf, indent + 1)?;
101            writeln!(buf, "{})", spaces)?;
102        }
103        TLExpr::Mul(l, r) => {
104            writeln!(buf, "{}MUL(", spaces)?;
105            pretty_print_expr_inner(l, buf, indent + 1)?;
106            writeln!(buf, "{},", spaces)?;
107            pretty_print_expr_inner(r, buf, indent + 1)?;
108            writeln!(buf, "{})", spaces)?;
109        }
110        TLExpr::Div(l, r) => {
111            writeln!(buf, "{}DIV(", spaces)?;
112            pretty_print_expr_inner(l, buf, indent + 1)?;
113            writeln!(buf, "{},", spaces)?;
114            pretty_print_expr_inner(r, buf, indent + 1)?;
115            writeln!(buf, "{})", spaces)?;
116        }
117        TLExpr::Pow(l, r) => {
118            writeln!(buf, "{}POW(", spaces)?;
119            pretty_print_expr_inner(l, buf, indent + 1)?;
120            writeln!(buf, "{},", spaces)?;
121            pretty_print_expr_inner(r, buf, indent + 1)?;
122            writeln!(buf, "{})", spaces)?;
123        }
124        TLExpr::Mod(l, r) => {
125            writeln!(buf, "{}MOD(", spaces)?;
126            pretty_print_expr_inner(l, buf, indent + 1)?;
127            writeln!(buf, "{},", spaces)?;
128            pretty_print_expr_inner(r, buf, indent + 1)?;
129            writeln!(buf, "{})", spaces)?;
130        }
131        TLExpr::Min(l, r) => {
132            writeln!(buf, "{}MIN(", spaces)?;
133            pretty_print_expr_inner(l, buf, indent + 1)?;
134            writeln!(buf, "{},", spaces)?;
135            pretty_print_expr_inner(r, buf, indent + 1)?;
136            writeln!(buf, "{})", spaces)?;
137        }
138        TLExpr::Max(l, r) => {
139            writeln!(buf, "{}MAX(", spaces)?;
140            pretty_print_expr_inner(l, buf, indent + 1)?;
141            writeln!(buf, "{},", spaces)?;
142            pretty_print_expr_inner(r, buf, indent + 1)?;
143            writeln!(buf, "{})", spaces)?;
144        }
145        TLExpr::Abs(e) => {
146            writeln!(buf, "{}ABS(", spaces)?;
147            pretty_print_expr_inner(e, buf, indent + 1)?;
148            writeln!(buf, "{})", spaces)?;
149        }
150        TLExpr::Floor(e) => {
151            writeln!(buf, "{}FLOOR(", spaces)?;
152            pretty_print_expr_inner(e, buf, indent + 1)?;
153            writeln!(buf, "{})", spaces)?;
154        }
155        TLExpr::Ceil(e) => {
156            writeln!(buf, "{}CEIL(", spaces)?;
157            pretty_print_expr_inner(e, buf, indent + 1)?;
158            writeln!(buf, "{})", spaces)?;
159        }
160        TLExpr::Round(e) => {
161            writeln!(buf, "{}ROUND(", spaces)?;
162            pretty_print_expr_inner(e, buf, indent + 1)?;
163            writeln!(buf, "{})", spaces)?;
164        }
165        TLExpr::Sqrt(e) => {
166            writeln!(buf, "{}SQRT(", spaces)?;
167            pretty_print_expr_inner(e, buf, indent + 1)?;
168            writeln!(buf, "{})", spaces)?;
169        }
170        TLExpr::Exp(e) => {
171            writeln!(buf, "{}EXP(", spaces)?;
172            pretty_print_expr_inner(e, buf, indent + 1)?;
173            writeln!(buf, "{})", spaces)?;
174        }
175        TLExpr::Log(e) => {
176            writeln!(buf, "{}LOG(", spaces)?;
177            pretty_print_expr_inner(e, buf, indent + 1)?;
178            writeln!(buf, "{})", spaces)?;
179        }
180        TLExpr::Sin(e) => {
181            writeln!(buf, "{}SIN(", spaces)?;
182            pretty_print_expr_inner(e, buf, indent + 1)?;
183            writeln!(buf, "{})", spaces)?;
184        }
185        TLExpr::Cos(e) => {
186            writeln!(buf, "{}COS(", spaces)?;
187            pretty_print_expr_inner(e, buf, indent + 1)?;
188            writeln!(buf, "{})", spaces)?;
189        }
190        TLExpr::Tan(e) => {
191            writeln!(buf, "{}TAN(", spaces)?;
192            pretty_print_expr_inner(e, buf, indent + 1)?;
193            writeln!(buf, "{})", spaces)?;
194        }
195        TLExpr::Box(e) => {
196            writeln!(buf, "{}BOX(", spaces)?;
197            pretty_print_expr_inner(e, buf, indent + 1)?;
198            writeln!(buf, "{})", spaces)?;
199        }
200        TLExpr::Diamond(e) => {
201            writeln!(buf, "{}DIAMOND(", spaces)?;
202            pretty_print_expr_inner(e, buf, indent + 1)?;
203            writeln!(buf, "{})", spaces)?;
204        }
205        TLExpr::Next(e) => {
206            writeln!(buf, "{}NEXT(", spaces)?;
207            pretty_print_expr_inner(e, buf, indent + 1)?;
208            writeln!(buf, "{})", spaces)?;
209        }
210        TLExpr::Eventually(e) => {
211            writeln!(buf, "{}EVENTUALLY(", spaces)?;
212            pretty_print_expr_inner(e, buf, indent + 1)?;
213            writeln!(buf, "{})", spaces)?;
214        }
215        TLExpr::Always(e) => {
216            writeln!(buf, "{}ALWAYS(", spaces)?;
217            pretty_print_expr_inner(e, buf, indent + 1)?;
218            writeln!(buf, "{})", spaces)?;
219        }
220        TLExpr::Until { before, after } => {
221            writeln!(buf, "{}UNTIL(", spaces)?;
222            pretty_print_expr_inner(before, buf, indent + 1)?;
223            writeln!(buf, "{},", spaces)?;
224            pretty_print_expr_inner(after, buf, indent + 1)?;
225            writeln!(buf, "{})", spaces)?;
226        }
227
228        // Fuzzy logic operators
229        TLExpr::TNorm { kind, left, right } => {
230            writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231            pretty_print_expr_inner(left, buf, indent + 1)?;
232            writeln!(buf, "{},", spaces)?;
233            pretty_print_expr_inner(right, buf, indent + 1)?;
234            writeln!(buf, "{})", spaces)?;
235        }
236        TLExpr::TCoNorm { kind, left, right } => {
237            writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238            pretty_print_expr_inner(left, buf, indent + 1)?;
239            writeln!(buf, "{},", spaces)?;
240            pretty_print_expr_inner(right, buf, indent + 1)?;
241            writeln!(buf, "{})", spaces)?;
242        }
243        TLExpr::FuzzyNot { kind, expr } => {
244            writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245            pretty_print_expr_inner(expr, buf, indent + 1)?;
246            writeln!(buf, "{})", spaces)?;
247        }
248        TLExpr::FuzzyImplication {
249            kind,
250            premise,
251            conclusion,
252        } => {
253            writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254            pretty_print_expr_inner(premise, buf, indent + 1)?;
255            writeln!(buf, "{}⇒", spaces)?;
256            pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257            writeln!(buf, "{})", spaces)?;
258        }
259
260        // Probabilistic operators
261        TLExpr::SoftExists {
262            var,
263            domain,
264            body,
265            temperature,
266        } => {
267            writeln!(
268                buf,
269                "{}SOFT-∃{}:{}[T={}](",
270                spaces, var, domain, temperature
271            )?;
272            pretty_print_expr_inner(body, buf, indent + 1)?;
273            writeln!(buf, "{})", spaces)?;
274        }
275        TLExpr::SoftForAll {
276            var,
277            domain,
278            body,
279            temperature,
280        } => {
281            writeln!(
282                buf,
283                "{}SOFT-∀{}:{}[T={}](",
284                spaces, var, domain, temperature
285            )?;
286            pretty_print_expr_inner(body, buf, indent + 1)?;
287            writeln!(buf, "{})", spaces)?;
288        }
289        TLExpr::WeightedRule { weight, rule } => {
290            writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291            pretty_print_expr_inner(rule, buf, indent + 1)?;
292            writeln!(buf, "{})", spaces)?;
293        }
294        TLExpr::ProbabilisticChoice { alternatives } => {
295            writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296            for (i, (prob, expr)) in alternatives.iter().enumerate() {
297                if i > 0 {
298                    writeln!(buf, "{},", spaces)?;
299                }
300                writeln!(buf, "{}  {}: ", spaces, prob)?;
301                pretty_print_expr_inner(expr, buf, indent + 2)?;
302            }
303            writeln!(buf, "{}]", spaces)?;
304        }
305
306        // Extended temporal logic
307        TLExpr::Release { released, releaser } => {
308            writeln!(buf, "{}RELEASE(", spaces)?;
309            pretty_print_expr_inner(released, buf, indent + 1)?;
310            writeln!(buf, "{},", spaces)?;
311            pretty_print_expr_inner(releaser, buf, indent + 1)?;
312            writeln!(buf, "{})", spaces)?;
313        }
314        TLExpr::WeakUntil { before, after } => {
315            writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316            pretty_print_expr_inner(before, buf, indent + 1)?;
317            writeln!(buf, "{},", spaces)?;
318            pretty_print_expr_inner(after, buf, indent + 1)?;
319            writeln!(buf, "{})", spaces)?;
320        }
321        TLExpr::StrongRelease { released, releaser } => {
322            writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323            pretty_print_expr_inner(released, buf, indent + 1)?;
324            writeln!(buf, "{},", spaces)?;
325            pretty_print_expr_inner(releaser, buf, indent + 1)?;
326            writeln!(buf, "{})", spaces)?;
327        }
328
329        TLExpr::Eq(l, r) => {
330            writeln!(buf, "{}EQ(", spaces)?;
331            pretty_print_expr_inner(l, buf, indent + 1)?;
332            writeln!(buf, "{},", spaces)?;
333            pretty_print_expr_inner(r, buf, indent + 1)?;
334            writeln!(buf, "{})", spaces)?;
335        }
336        TLExpr::Lt(l, r) => {
337            writeln!(buf, "{}LT(", spaces)?;
338            pretty_print_expr_inner(l, buf, indent + 1)?;
339            writeln!(buf, "{},", spaces)?;
340            pretty_print_expr_inner(r, buf, indent + 1)?;
341            writeln!(buf, "{})", spaces)?;
342        }
343        TLExpr::Gt(l, r) => {
344            writeln!(buf, "{}GT(", spaces)?;
345            pretty_print_expr_inner(l, buf, indent + 1)?;
346            writeln!(buf, "{},", spaces)?;
347            pretty_print_expr_inner(r, buf, indent + 1)?;
348            writeln!(buf, "{})", spaces)?;
349        }
350        TLExpr::Lte(l, r) => {
351            writeln!(buf, "{}LTE(", spaces)?;
352            pretty_print_expr_inner(l, buf, indent + 1)?;
353            writeln!(buf, "{},", spaces)?;
354            pretty_print_expr_inner(r, buf, indent + 1)?;
355            writeln!(buf, "{})", spaces)?;
356        }
357        TLExpr::Gte(l, r) => {
358            writeln!(buf, "{}GTE(", spaces)?;
359            pretty_print_expr_inner(l, buf, indent + 1)?;
360            writeln!(buf, "{},", spaces)?;
361            pretty_print_expr_inner(r, buf, indent + 1)?;
362            writeln!(buf, "{})", spaces)?;
363        }
364        TLExpr::IfThenElse {
365            condition,
366            then_branch,
367            else_branch,
368        } => {
369            writeln!(buf, "{}IF(", spaces)?;
370            pretty_print_expr_inner(condition, buf, indent + 1)?;
371            writeln!(buf, "{}) THEN(", spaces)?;
372            pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373            writeln!(buf, "{}) ELSE(", spaces)?;
374            pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375            writeln!(buf, "{})", spaces)?;
376        }
377        TLExpr::Let { var, value, body } => {
378            writeln!(buf, "{}LET {} =(", spaces, var)?;
379            pretty_print_expr_inner(value, buf, indent + 1)?;
380            writeln!(buf, "{}) IN(", spaces)?;
381            pretty_print_expr_inner(body, buf, indent + 1)?;
382            writeln!(buf, "{})", spaces)?;
383        }
384        // Alpha.3 enhancements
385        TLExpr::Lambda {
386            var,
387            var_type,
388            body,
389        } => {
390            if let Some(ty) = var_type {
391                writeln!(buf, "{}LAMBDA {}:{} ⇒(", spaces, var, ty)?;
392            } else {
393                writeln!(buf, "{}LAMBDA {} ⇒(", spaces, var)?;
394            }
395            pretty_print_expr_inner(body, buf, indent + 1)?;
396            writeln!(buf, "{})", spaces)?;
397        }
398        TLExpr::Apply { function, argument } => {
399            writeln!(buf, "{}APPLY(", spaces)?;
400            pretty_print_expr_inner(function, buf, indent + 1)?;
401            writeln!(buf, "{}TO", spaces)?;
402            pretty_print_expr_inner(argument, buf, indent + 1)?;
403            writeln!(buf, "{})", spaces)?;
404        }
405        TLExpr::SetMembership { element, set } => {
406            writeln!(buf, "{}MEMBER(", spaces)?;
407            pretty_print_expr_inner(element, buf, indent + 1)?;
408            writeln!(buf, "{}IN", spaces)?;
409            pretty_print_expr_inner(set, buf, indent + 1)?;
410            writeln!(buf, "{})", spaces)?;
411        }
412        TLExpr::SetUnion { left, right } => {
413            writeln!(buf, "{}UNION(", spaces)?;
414            pretty_print_expr_inner(left, buf, indent + 1)?;
415            writeln!(buf, "{},", spaces)?;
416            pretty_print_expr_inner(right, buf, indent + 1)?;
417            writeln!(buf, "{})", spaces)?;
418        }
419        TLExpr::SetIntersection { left, right } => {
420            writeln!(buf, "{}INTERSECT(", spaces)?;
421            pretty_print_expr_inner(left, buf, indent + 1)?;
422            writeln!(buf, "{},", spaces)?;
423            pretty_print_expr_inner(right, buf, indent + 1)?;
424            writeln!(buf, "{})", spaces)?;
425        }
426        TLExpr::SetDifference { left, right } => {
427            writeln!(buf, "{}DIFFERENCE(", spaces)?;
428            pretty_print_expr_inner(left, buf, indent + 1)?;
429            writeln!(buf, "{},", spaces)?;
430            pretty_print_expr_inner(right, buf, indent + 1)?;
431            writeln!(buf, "{})", spaces)?;
432        }
433        TLExpr::SetCardinality { set } => {
434            writeln!(buf, "{}CARDINALITY(", spaces)?;
435            pretty_print_expr_inner(set, buf, indent + 1)?;
436            writeln!(buf, "{})", spaces)?;
437        }
438        TLExpr::EmptySet => {
439            writeln!(buf, "{}EMPTY-SET", spaces)?;
440        }
441        TLExpr::SetComprehension {
442            var,
443            domain,
444            condition,
445        } => {
446            writeln!(buf, "{}SET-COMPREHENSION {{ {}:{} | ", spaces, var, domain)?;
447            pretty_print_expr_inner(condition, buf, indent + 1)?;
448            writeln!(buf, "{}}}", spaces)?;
449        }
450        TLExpr::CountingExists {
451            var,
452            domain,
453            body,
454            min_count,
455        } => {
456            writeln!(buf, "{}∃≥{}{}:{}.(", spaces, min_count, var, domain)?;
457            pretty_print_expr_inner(body, buf, indent + 1)?;
458            writeln!(buf, "{})", spaces)?;
459        }
460        TLExpr::CountingForAll {
461            var,
462            domain,
463            body,
464            min_count,
465        } => {
466            writeln!(buf, "{}∀≥{}{}:{}.(", spaces, min_count, var, domain)?;
467            pretty_print_expr_inner(body, buf, indent + 1)?;
468            writeln!(buf, "{})", spaces)?;
469        }
470        TLExpr::ExactCount {
471            var,
472            domain,
473            body,
474            count,
475        } => {
476            writeln!(buf, "{}∃={}{}:{}.(", spaces, count, var, domain)?;
477            pretty_print_expr_inner(body, buf, indent + 1)?;
478            writeln!(buf, "{})", spaces)?;
479        }
480        TLExpr::Majority { var, domain, body } => {
481            writeln!(buf, "{}MAJORITY {}:{}.(", spaces, var, domain)?;
482            pretty_print_expr_inner(body, buf, indent + 1)?;
483            writeln!(buf, "{})", spaces)?;
484        }
485        TLExpr::LeastFixpoint { var, body } => {
486            writeln!(buf, "{}μ{}.(", spaces, var)?;
487            pretty_print_expr_inner(body, buf, indent + 1)?;
488            writeln!(buf, "{})", spaces)?;
489        }
490        TLExpr::GreatestFixpoint { var, body } => {
491            writeln!(buf, "{}ν{}.(", spaces, var)?;
492            pretty_print_expr_inner(body, buf, indent + 1)?;
493            writeln!(buf, "{})", spaces)?;
494        }
495        TLExpr::Nominal { name } => {
496            writeln!(buf, "{}@{}", spaces, name)?;
497        }
498        TLExpr::At { nominal, formula } => {
499            writeln!(buf, "{}AT @{}(", spaces, nominal)?;
500            pretty_print_expr_inner(formula, buf, indent + 1)?;
501            writeln!(buf, "{})", spaces)?;
502        }
503        TLExpr::Somewhere { formula } => {
504            writeln!(buf, "{}SOMEWHERE(", spaces)?;
505            pretty_print_expr_inner(formula, buf, indent + 1)?;
506            writeln!(buf, "{})", spaces)?;
507        }
508        TLExpr::Everywhere { formula } => {
509            writeln!(buf, "{}EVERYWHERE(", spaces)?;
510            pretty_print_expr_inner(formula, buf, indent + 1)?;
511            writeln!(buf, "{})", spaces)?;
512        }
513        TLExpr::AllDifferent { variables } => {
514            writeln!(buf, "{}ALL-DIFFERENT({:?})", spaces, variables)?;
515        }
516        TLExpr::GlobalCardinality {
517            variables,
518            values,
519            min_occurrences,
520            max_occurrences,
521        } => {
522            writeln!(buf, "{}GLOBAL-CARDINALITY(", spaces)?;
523            writeln!(buf, "{}  vars: {:?}", spaces, variables)?;
524            writeln!(buf, "{}  constraints: [", spaces)?;
525            for (i, val) in values.iter().enumerate() {
526                write!(buf, "{}    ", spaces)?;
527                pretty_print_expr_inner(val, buf, 0)?;
528                writeln!(buf, ": [{}, {}]", min_occurrences[i], max_occurrences[i])?;
529            }
530            writeln!(buf, "{}  ]", spaces)?;
531            writeln!(buf, "{})", spaces)?;
532        }
533        TLExpr::Abducible { name, cost } => {
534            writeln!(buf, "{}ABDUCIBLE({}, cost={})", spaces, name, cost)?;
535        }
536        TLExpr::Explain { formula } => {
537            writeln!(buf, "{}EXPLAIN(", spaces)?;
538            pretty_print_expr_inner(formula, buf, indent + 1)?;
539            writeln!(buf, "{})", spaces)?;
540        }
541        TLExpr::Constant(value) => {
542            writeln!(buf, "{}{}", spaces, value)?;
543        }
544    }
545
546    Ok(())
547}
548
549fn term_to_string(term: &Term) -> String {
550    match term {
551        Term::Var(name) => format!("?{}", name),
552        Term::Const(name) => name.clone(),
553        Term::Typed {
554            value,
555            type_annotation,
556        } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
557    }
558}
559
/// Statistics about a TLExpr.
///
/// Produced by `ExprStats::compute`. Every visited node increments
/// `node_count`; most node kinds additionally increment exactly one of the
/// coarse category counters below. Some nodes increment only `node_count`:
/// leaves such as `EmptySet` and `Nominal`, and structural nodes such as
/// `IfThenElse` and `Let`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprStats {
    /// Total number of nodes in the expression tree (every visited node).
    pub node_count: usize,
    /// Maximum depth of the expression tree, measured from the root (which is
    /// visited at depth 0).
    pub max_depth: usize,
    /// Number of predicate (`Pred`) nodes.
    pub predicate_count: usize,
    /// Number of variable-binding nodes. This includes `Exists`/`ForAll`, but
    /// also aggregates, soft and counting quantifiers, `Majority`, `Lambda`
    /// and set comprehensions.
    pub quantifier_count: usize,
    /// Number of logical/connective nodes.
    ///
    /// This includes `And`, `Or`, `Not` and `Imply`, plus `Score`, binary
    /// temporal operators (`Until`, `Release`, ...), fuzzy and probabilistic
    /// connectives, `Apply`, binary set operations, fixpoints and `At`,
    /// among others.
    pub logical_op_count: usize,
    /// Number of arithmetic nodes (binary and unary numeric operators).
    ///
    /// Note: this bucket also receives the unary modal/temporal operators
    /// (`Box`, `Diamond`, `Next`, `Eventually`, `Always`) and `SetCardinality`.
    pub arithmetic_op_count: usize,
    /// Number of comparison nodes (`Eq`, `Lt`, `Gt`, `Lte`, `Gte`).
    pub comparison_op_count: usize,
    /// Number of free variables, as reported by `TLExpr::free_vars`.
    pub free_var_count: usize,
}
580
581impl ExprStats {
582    /// Compute statistics for an expression.
583    pub fn compute(expr: &TLExpr) -> Self {
584        let mut stats = ExprStats {
585            node_count: 0,
586            max_depth: 0,
587            predicate_count: 0,
588            quantifier_count: 0,
589            logical_op_count: 0,
590            arithmetic_op_count: 0,
591            comparison_op_count: 0,
592            free_var_count: expr.free_vars().len(),
593        };
594
595        stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
596        stats
597    }
598
599    fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
600        stats.node_count += 1;
601        let mut max_child_depth = depth;
602
603        match expr {
604            TLExpr::Pred { .. } => {
605                stats.predicate_count += 1;
606            }
607            TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
608                stats.logical_op_count += 1;
609                let left_depth = Self::compute_recursive(l, stats, depth + 1);
610                let right_depth = Self::compute_recursive(r, stats, depth + 1);
611                max_child_depth = left_depth.max(right_depth);
612            }
613            TLExpr::Not(e) | TLExpr::Score(e) => {
614                stats.logical_op_count += 1;
615                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
616            }
617            TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
618                stats.quantifier_count += 1;
619                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
620            }
621            TLExpr::Aggregate { body, .. } => {
622                stats.quantifier_count += 1; // Aggregates are similar to quantifiers
623                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
624            }
625            TLExpr::Add(l, r)
626            | TLExpr::Sub(l, r)
627            | TLExpr::Mul(l, r)
628            | TLExpr::Div(l, r)
629            | TLExpr::Pow(l, r)
630            | TLExpr::Mod(l, r)
631            | TLExpr::Min(l, r)
632            | TLExpr::Max(l, r) => {
633                stats.arithmetic_op_count += 1;
634                let left_depth = Self::compute_recursive(l, stats, depth + 1);
635                let right_depth = Self::compute_recursive(r, stats, depth + 1);
636                max_child_depth = left_depth.max(right_depth);
637            }
638            TLExpr::Abs(e)
639            | TLExpr::Floor(e)
640            | TLExpr::Ceil(e)
641            | TLExpr::Round(e)
642            | TLExpr::Sqrt(e)
643            | TLExpr::Exp(e)
644            | TLExpr::Log(e)
645            | TLExpr::Sin(e)
646            | TLExpr::Cos(e)
647            | TLExpr::Tan(e)
648            | TLExpr::Box(e)
649            | TLExpr::Diamond(e)
650            | TLExpr::Next(e)
651            | TLExpr::Eventually(e)
652            | TLExpr::Always(e) => {
653                stats.arithmetic_op_count += 1;
654                max_child_depth = Self::compute_recursive(e, stats, depth + 1);
655            }
656            TLExpr::Until { before, after } => {
657                stats.logical_op_count += 1;
658                let depth_before = Self::compute_recursive(before, stats, depth + 1);
659                let depth_after = Self::compute_recursive(after, stats, depth + 1);
660                max_child_depth = depth_before.max(depth_after);
661            }
662            TLExpr::Eq(l, r)
663            | TLExpr::Lt(l, r)
664            | TLExpr::Gt(l, r)
665            | TLExpr::Lte(l, r)
666            | TLExpr::Gte(l, r) => {
667                stats.comparison_op_count += 1;
668                let left_depth = Self::compute_recursive(l, stats, depth + 1);
669                let right_depth = Self::compute_recursive(r, stats, depth + 1);
670                max_child_depth = left_depth.max(right_depth);
671            }
672            TLExpr::IfThenElse {
673                condition,
674                then_branch,
675                else_branch,
676            } => {
677                let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
678                let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
679                let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
680                max_child_depth = cond_depth.max(then_depth).max(else_depth);
681            }
682            TLExpr::Let { value, body, .. } => {
683                let value_depth = Self::compute_recursive(value, stats, depth + 1);
684                let body_depth = Self::compute_recursive(body, stats, depth + 1);
685                max_child_depth = value_depth.max(body_depth);
686            }
687
688            // Fuzzy logic operators
689            TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
690                stats.logical_op_count += 1;
691                let left_depth = Self::compute_recursive(left, stats, depth + 1);
692                let right_depth = Self::compute_recursive(right, stats, depth + 1);
693                max_child_depth = left_depth.max(right_depth);
694            }
695            TLExpr::FuzzyNot { expr, .. } => {
696                stats.logical_op_count += 1;
697                max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
698            }
699            TLExpr::FuzzyImplication {
700                premise,
701                conclusion,
702                ..
703            } => {
704                stats.logical_op_count += 1;
705                let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
706                let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
707                max_child_depth = prem_depth.max(conc_depth);
708            }
709
710            // Probabilistic operators
711            TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
712                stats.quantifier_count += 1;
713                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
714            }
715            TLExpr::WeightedRule { rule, .. } => {
716                stats.logical_op_count += 1;
717                max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
718            }
719            TLExpr::ProbabilisticChoice { alternatives } => {
720                stats.logical_op_count += 1;
721                let mut max_alt_depth = depth;
722                for (_, expr) in alternatives {
723                    let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
724                    max_alt_depth = max_alt_depth.max(alt_depth);
725                }
726                max_child_depth = max_alt_depth;
727            }
728
729            // Extended temporal logic
730            TLExpr::Release { released, releaser }
731            | TLExpr::WeakUntil {
732                before: released,
733                after: releaser,
734            }
735            | TLExpr::StrongRelease { released, releaser } => {
736                stats.logical_op_count += 1;
737                let rel_depth = Self::compute_recursive(released, stats, depth + 1);
738                let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
739                max_child_depth = rel_depth.max(reler_depth);
740            }
741
742            // Alpha.3 enhancements
743            TLExpr::Lambda { body, .. } => {
744                stats.quantifier_count += 1; // Lambda binds a variable
745                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
746            }
747            TLExpr::Apply { function, argument } => {
748                stats.logical_op_count += 1;
749                let func_depth = Self::compute_recursive(function, stats, depth + 1);
750                let arg_depth = Self::compute_recursive(argument, stats, depth + 1);
751                max_child_depth = func_depth.max(arg_depth);
752            }
753            TLExpr::SetMembership { element, set }
754            | TLExpr::SetUnion {
755                left: element,
756                right: set,
757            }
758            | TLExpr::SetIntersection {
759                left: element,
760                right: set,
761            }
762            | TLExpr::SetDifference {
763                left: element,
764                right: set,
765            } => {
766                stats.logical_op_count += 1;
767                let elem_depth = Self::compute_recursive(element, stats, depth + 1);
768                let set_depth = Self::compute_recursive(set, stats, depth + 1);
769                max_child_depth = elem_depth.max(set_depth);
770            }
771            TLExpr::SetCardinality { set } => {
772                stats.arithmetic_op_count += 1;
773                max_child_depth = Self::compute_recursive(set, stats, depth + 1);
774            }
775            TLExpr::EmptySet => {
776                // Leaf node
777            }
778            TLExpr::SetComprehension { condition, .. } => {
779                stats.quantifier_count += 1;
780                max_child_depth = Self::compute_recursive(condition, stats, depth + 1);
781            }
782            TLExpr::CountingExists { body, .. }
783            | TLExpr::CountingForAll { body, .. }
784            | TLExpr::ExactCount { body, .. }
785            | TLExpr::Majority { body, .. } => {
786                stats.quantifier_count += 1;
787                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
788            }
789            TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
790                stats.logical_op_count += 1;
791                max_child_depth = Self::compute_recursive(body, stats, depth + 1);
792            }
793            TLExpr::Nominal { .. } => {
794                // Leaf node
795            }
796            TLExpr::At { formula, .. } => {
797                stats.logical_op_count += 1;
798                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
799            }
800            TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
801                stats.logical_op_count += 1;
802                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
803            }
804            TLExpr::AllDifferent { .. } => {
805                stats.logical_op_count += 1;
806                // Leaf node (no subexpressions)
807            }
808            TLExpr::GlobalCardinality { values, .. } => {
809                stats.logical_op_count += 1;
810                let mut max_val_depth = depth;
811                for val in values {
812                    let val_depth = Self::compute_recursive(val, stats, depth + 1);
813                    max_val_depth = max_val_depth.max(val_depth);
814                }
815                max_child_depth = max_val_depth;
816            }
817            TLExpr::Abducible { .. } => {
818                stats.predicate_count += 1;
819                // Leaf node
820            }
821            TLExpr::Explain { formula } => {
822                stats.logical_op_count += 1;
823                max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
824            }
825
826            TLExpr::Constant(_) => {
827                // Leaf node
828            }
829        }
830
831        max_child_depth
832    }
833}
834
835/// Statistics about an EinsumGraph.
836#[derive(Debug, Clone, PartialEq)]
837pub struct GraphStats {
838    /// Number of tensors
839    pub tensor_count: usize,
840    /// Number of nodes
841    pub node_count: usize,
842    /// Number of output tensors
843    pub output_count: usize,
844    /// Number of einsum operations
845    pub einsum_count: usize,
846    /// Number of element-wise unary operations
847    pub elem_unary_count: usize,
848    /// Number of element-wise binary operations
849    pub elem_binary_count: usize,
850    /// Number of reduction operations
851    pub reduce_count: usize,
852    /// Average inputs per node
853    pub avg_inputs_per_node: f64,
854}
855
856impl GraphStats {
857    /// Compute statistics for a graph.
858    pub fn compute(graph: &EinsumGraph) -> Self {
859        let mut stats = GraphStats {
860            tensor_count: graph.tensors.len(),
861            node_count: graph.nodes.len(),
862            output_count: graph.outputs.len(),
863            einsum_count: 0,
864            elem_unary_count: 0,
865            elem_binary_count: 0,
866            reduce_count: 0,
867            avg_inputs_per_node: 0.0,
868        };
869
870        let mut total_inputs = 0;
871
872        for node in &graph.nodes {
873            total_inputs += node.inputs.len();
874
875            match &node.op {
876                crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
877                crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
878                crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
879                crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
880            }
881        }
882
883        if stats.node_count > 0 {
884            stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
885        }
886
887        stats
888    }
889}
890
891/// Pretty-print a graph to a string.
892pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
893    let mut buffer = String::new();
894    writeln!(buffer, "EinsumGraph {{").unwrap();
895    writeln!(buffer, "  Tensors: {} total", graph.tensors.len()).unwrap();
896
897    for (idx, name) in graph.tensors.iter().enumerate() {
898        writeln!(buffer, "    t{}: {}", idx, name).unwrap();
899    }
900
901    writeln!(buffer, "  Nodes: {} total", graph.nodes.len()).unwrap();
902    for (idx, node) in graph.nodes.iter().enumerate() {
903        write!(buffer, "    n{}: ", idx).unwrap();
904        match &node.op {
905            crate::graph::OpType::Einsum { spec } => {
906                write!(buffer, "Einsum(\"{}\")", spec).unwrap()
907            }
908            crate::graph::OpType::ElemUnary { op } => write!(buffer, "ElemUnary({})", op).unwrap(),
909            crate::graph::OpType::ElemBinary { op } => {
910                write!(buffer, "ElemBinary({})", op).unwrap()
911            }
912            crate::graph::OpType::Reduce { op, axes } => {
913                write!(buffer, "Reduce({}, axes={:?})", op, axes).unwrap()
914            }
915        }
916        write!(buffer, " <- [").unwrap();
917        for (i, input) in node.inputs.iter().enumerate() {
918            if i > 0 {
919                write!(buffer, ", ").unwrap();
920            }
921            write!(buffer, "t{}", input).unwrap();
922        }
923        writeln!(buffer, "]").unwrap();
924    }
925
926    writeln!(buffer, "  Outputs: {:?}", graph.outputs).unwrap();
927    writeln!(buffer, "}}").unwrap();
928
929    buffer
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_stats_simple() {
        let expr = TLExpr::pred("P", vec![Term::var("x")]);
        let stats = ExprStats::compute(&expr);

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    #[test]
    fn test_expr_stats_complex() {
        // ∀x. P(x) ∧ Q(x)
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("x")]);
        let and_expr = TLExpr::and(p, q);
        let expr = TLExpr::forall("x", "Domain", and_expr);

        let stats = ExprStats::compute(&expr);

        assert_eq!(stats.node_count, 4); // forall, and, p, q
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        assert_eq!(stats.free_var_count, 0); // x is bound
    }

    #[test]
    fn test_expr_stats_arithmetic() {
        // score(x) * 2 + 1
        let score = TLExpr::pred("score", vec![Term::var("x")]);
        let mul = TLExpr::mul(score, TLExpr::constant(2.0));
        let add = TLExpr::add(mul, TLExpr::constant(1.0));

        let stats = ExprStats::compute(&add);

        assert_eq!(stats.arithmetic_op_count, 2); // mul, add
        assert_eq!(stats.predicate_count, 1);
    }

    #[test]
    fn test_graph_stats() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(crate::EinsumNode {
                inputs: vec![t0],
                outputs: vec![t1],
                op: crate::graph::OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .unwrap();

        graph.add_output(t1).unwrap();

        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    #[test]
    fn test_graph_stats_empty_graph() {
        let graph = EinsumGraph::new();
        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 0);
        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.output_count, 0);
        assert_eq!(stats.einsum_count, 0);
        assert_eq!(stats.elem_unary_count, 0);
        assert_eq!(stats.elem_binary_count, 0);
        assert_eq!(stats.reduce_count, 0);
        // With no nodes the average must stay at 0.0 rather than becoming NaN.
        assert_eq!(stats.avg_inputs_per_node, 0.0);
    }

    #[test]
    fn test_pretty_print_expr() {
        let expr = TLExpr::pred("Person", vec![Term::var("x")]);
        let output = pretty_print_expr(&expr);
        assert!(output.contains("Person(?x)"));
    }

    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _t0 = graph.add_tensor("input");

        let output = pretty_print_graph(&graph);
        assert!(output.contains("t0: input"));
        assert!(output.contains("Tensors: 1 total"));
    }

    #[test]
    fn test_pretty_print_graph_with_node() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(crate::EinsumNode {
                inputs: vec![t0],
                outputs: vec![t1],
                op: crate::graph::OpType::Einsum {
                    spec: "i->i".to_string(),
                },
                metadata: None,
            })
            .unwrap();

        let output = pretty_print_graph(&graph);
        assert!(output.contains("t1: output"));
        assert!(output.contains("Nodes: 1 total"));
        assert!(output.contains("n0: Einsum(\"i->i\")"));
        assert!(output.contains("<- [t0]"));
    }
}