1use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12pub fn pretty_print_expr(expr: &TLExpr) -> String {
14 let mut buffer = String::new();
15 pretty_print_expr_inner(expr, &mut buffer, 0).unwrap();
16 buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20 let spaces = " ".repeat(indent);
21
22 match expr {
23 TLExpr::Pred { name, args } => {
24 write!(buf, "{}{}(", spaces, name)?;
25 for (i, arg) in args.iter().enumerate() {
26 if i > 0 {
27 write!(buf, ", ")?;
28 }
29 write!(buf, "{}", term_to_string(arg))?;
30 }
31 writeln!(buf, ")")?;
32 }
33 TLExpr::And(l, r) => {
34 writeln!(buf, "{}AND(", spaces)?;
35 pretty_print_expr_inner(l, buf, indent + 1)?;
36 writeln!(buf, "{},", spaces)?;
37 pretty_print_expr_inner(r, buf, indent + 1)?;
38 writeln!(buf, "{})", spaces)?;
39 }
40 TLExpr::Or(l, r) => {
41 writeln!(buf, "{}OR(", spaces)?;
42 pretty_print_expr_inner(l, buf, indent + 1)?;
43 writeln!(buf, "{},", spaces)?;
44 pretty_print_expr_inner(r, buf, indent + 1)?;
45 writeln!(buf, "{})", spaces)?;
46 }
47 TLExpr::Not(e) => {
48 writeln!(buf, "{}NOT(", spaces)?;
49 pretty_print_expr_inner(e, buf, indent + 1)?;
50 writeln!(buf, "{})", spaces)?;
51 }
52 TLExpr::Exists { var, domain, body } => {
53 writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54 pretty_print_expr_inner(body, buf, indent + 1)?;
55 writeln!(buf, "{})", spaces)?;
56 }
57 TLExpr::ForAll { var, domain, body } => {
58 writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59 pretty_print_expr_inner(body, buf, indent + 1)?;
60 writeln!(buf, "{})", spaces)?;
61 }
62 TLExpr::Aggregate {
63 op,
64 var,
65 domain,
66 body,
67 group_by,
68 } => {
69 write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70 if let Some(group_vars) = group_by {
71 write!(buf, " GROUP BY {:?}", group_vars)?;
72 }
73 writeln!(buf, ")(")?;
74 pretty_print_expr_inner(body, buf, indent + 1)?;
75 writeln!(buf, "{})", spaces)?;
76 }
77 TLExpr::Imply(premise, conclusion) => {
78 writeln!(buf, "{}IMPLY(", spaces)?;
79 pretty_print_expr_inner(premise, buf, indent + 1)?;
80 writeln!(buf, "{}⇒", spaces)?;
81 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82 writeln!(buf, "{})", spaces)?;
83 }
84 TLExpr::Score(e) => {
85 writeln!(buf, "{}SCORE(", spaces)?;
86 pretty_print_expr_inner(e, buf, indent + 1)?;
87 writeln!(buf, "{})", spaces)?;
88 }
89 TLExpr::Add(l, r) => {
90 writeln!(buf, "{}ADD(", spaces)?;
91 pretty_print_expr_inner(l, buf, indent + 1)?;
92 writeln!(buf, "{},", spaces)?;
93 pretty_print_expr_inner(r, buf, indent + 1)?;
94 writeln!(buf, "{})", spaces)?;
95 }
96 TLExpr::Sub(l, r) => {
97 writeln!(buf, "{}SUB(", spaces)?;
98 pretty_print_expr_inner(l, buf, indent + 1)?;
99 writeln!(buf, "{},", spaces)?;
100 pretty_print_expr_inner(r, buf, indent + 1)?;
101 writeln!(buf, "{})", spaces)?;
102 }
103 TLExpr::Mul(l, r) => {
104 writeln!(buf, "{}MUL(", spaces)?;
105 pretty_print_expr_inner(l, buf, indent + 1)?;
106 writeln!(buf, "{},", spaces)?;
107 pretty_print_expr_inner(r, buf, indent + 1)?;
108 writeln!(buf, "{})", spaces)?;
109 }
110 TLExpr::Div(l, r) => {
111 writeln!(buf, "{}DIV(", spaces)?;
112 pretty_print_expr_inner(l, buf, indent + 1)?;
113 writeln!(buf, "{},", spaces)?;
114 pretty_print_expr_inner(r, buf, indent + 1)?;
115 writeln!(buf, "{})", spaces)?;
116 }
117 TLExpr::Pow(l, r) => {
118 writeln!(buf, "{}POW(", spaces)?;
119 pretty_print_expr_inner(l, buf, indent + 1)?;
120 writeln!(buf, "{},", spaces)?;
121 pretty_print_expr_inner(r, buf, indent + 1)?;
122 writeln!(buf, "{})", spaces)?;
123 }
124 TLExpr::Mod(l, r) => {
125 writeln!(buf, "{}MOD(", spaces)?;
126 pretty_print_expr_inner(l, buf, indent + 1)?;
127 writeln!(buf, "{},", spaces)?;
128 pretty_print_expr_inner(r, buf, indent + 1)?;
129 writeln!(buf, "{})", spaces)?;
130 }
131 TLExpr::Min(l, r) => {
132 writeln!(buf, "{}MIN(", spaces)?;
133 pretty_print_expr_inner(l, buf, indent + 1)?;
134 writeln!(buf, "{},", spaces)?;
135 pretty_print_expr_inner(r, buf, indent + 1)?;
136 writeln!(buf, "{})", spaces)?;
137 }
138 TLExpr::Max(l, r) => {
139 writeln!(buf, "{}MAX(", spaces)?;
140 pretty_print_expr_inner(l, buf, indent + 1)?;
141 writeln!(buf, "{},", spaces)?;
142 pretty_print_expr_inner(r, buf, indent + 1)?;
143 writeln!(buf, "{})", spaces)?;
144 }
145 TLExpr::Abs(e) => {
146 writeln!(buf, "{}ABS(", spaces)?;
147 pretty_print_expr_inner(e, buf, indent + 1)?;
148 writeln!(buf, "{})", spaces)?;
149 }
150 TLExpr::Floor(e) => {
151 writeln!(buf, "{}FLOOR(", spaces)?;
152 pretty_print_expr_inner(e, buf, indent + 1)?;
153 writeln!(buf, "{})", spaces)?;
154 }
155 TLExpr::Ceil(e) => {
156 writeln!(buf, "{}CEIL(", spaces)?;
157 pretty_print_expr_inner(e, buf, indent + 1)?;
158 writeln!(buf, "{})", spaces)?;
159 }
160 TLExpr::Round(e) => {
161 writeln!(buf, "{}ROUND(", spaces)?;
162 pretty_print_expr_inner(e, buf, indent + 1)?;
163 writeln!(buf, "{})", spaces)?;
164 }
165 TLExpr::Sqrt(e) => {
166 writeln!(buf, "{}SQRT(", spaces)?;
167 pretty_print_expr_inner(e, buf, indent + 1)?;
168 writeln!(buf, "{})", spaces)?;
169 }
170 TLExpr::Exp(e) => {
171 writeln!(buf, "{}EXP(", spaces)?;
172 pretty_print_expr_inner(e, buf, indent + 1)?;
173 writeln!(buf, "{})", spaces)?;
174 }
175 TLExpr::Log(e) => {
176 writeln!(buf, "{}LOG(", spaces)?;
177 pretty_print_expr_inner(e, buf, indent + 1)?;
178 writeln!(buf, "{})", spaces)?;
179 }
180 TLExpr::Sin(e) => {
181 writeln!(buf, "{}SIN(", spaces)?;
182 pretty_print_expr_inner(e, buf, indent + 1)?;
183 writeln!(buf, "{})", spaces)?;
184 }
185 TLExpr::Cos(e) => {
186 writeln!(buf, "{}COS(", spaces)?;
187 pretty_print_expr_inner(e, buf, indent + 1)?;
188 writeln!(buf, "{})", spaces)?;
189 }
190 TLExpr::Tan(e) => {
191 writeln!(buf, "{}TAN(", spaces)?;
192 pretty_print_expr_inner(e, buf, indent + 1)?;
193 writeln!(buf, "{})", spaces)?;
194 }
195 TLExpr::Box(e) => {
196 writeln!(buf, "{}BOX(", spaces)?;
197 pretty_print_expr_inner(e, buf, indent + 1)?;
198 writeln!(buf, "{})", spaces)?;
199 }
200 TLExpr::Diamond(e) => {
201 writeln!(buf, "{}DIAMOND(", spaces)?;
202 pretty_print_expr_inner(e, buf, indent + 1)?;
203 writeln!(buf, "{})", spaces)?;
204 }
205 TLExpr::Next(e) => {
206 writeln!(buf, "{}NEXT(", spaces)?;
207 pretty_print_expr_inner(e, buf, indent + 1)?;
208 writeln!(buf, "{})", spaces)?;
209 }
210 TLExpr::Eventually(e) => {
211 writeln!(buf, "{}EVENTUALLY(", spaces)?;
212 pretty_print_expr_inner(e, buf, indent + 1)?;
213 writeln!(buf, "{})", spaces)?;
214 }
215 TLExpr::Always(e) => {
216 writeln!(buf, "{}ALWAYS(", spaces)?;
217 pretty_print_expr_inner(e, buf, indent + 1)?;
218 writeln!(buf, "{})", spaces)?;
219 }
220 TLExpr::Until { before, after } => {
221 writeln!(buf, "{}UNTIL(", spaces)?;
222 pretty_print_expr_inner(before, buf, indent + 1)?;
223 writeln!(buf, "{},", spaces)?;
224 pretty_print_expr_inner(after, buf, indent + 1)?;
225 writeln!(buf, "{})", spaces)?;
226 }
227
228 TLExpr::TNorm { kind, left, right } => {
230 writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231 pretty_print_expr_inner(left, buf, indent + 1)?;
232 writeln!(buf, "{},", spaces)?;
233 pretty_print_expr_inner(right, buf, indent + 1)?;
234 writeln!(buf, "{})", spaces)?;
235 }
236 TLExpr::TCoNorm { kind, left, right } => {
237 writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238 pretty_print_expr_inner(left, buf, indent + 1)?;
239 writeln!(buf, "{},", spaces)?;
240 pretty_print_expr_inner(right, buf, indent + 1)?;
241 writeln!(buf, "{})", spaces)?;
242 }
243 TLExpr::FuzzyNot { kind, expr } => {
244 writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245 pretty_print_expr_inner(expr, buf, indent + 1)?;
246 writeln!(buf, "{})", spaces)?;
247 }
248 TLExpr::FuzzyImplication {
249 kind,
250 premise,
251 conclusion,
252 } => {
253 writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254 pretty_print_expr_inner(premise, buf, indent + 1)?;
255 writeln!(buf, "{}⇒", spaces)?;
256 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257 writeln!(buf, "{})", spaces)?;
258 }
259
260 TLExpr::SoftExists {
262 var,
263 domain,
264 body,
265 temperature,
266 } => {
267 writeln!(
268 buf,
269 "{}SOFT-∃{}:{}[T={}](",
270 spaces, var, domain, temperature
271 )?;
272 pretty_print_expr_inner(body, buf, indent + 1)?;
273 writeln!(buf, "{})", spaces)?;
274 }
275 TLExpr::SoftForAll {
276 var,
277 domain,
278 body,
279 temperature,
280 } => {
281 writeln!(
282 buf,
283 "{}SOFT-∀{}:{}[T={}](",
284 spaces, var, domain, temperature
285 )?;
286 pretty_print_expr_inner(body, buf, indent + 1)?;
287 writeln!(buf, "{})", spaces)?;
288 }
289 TLExpr::WeightedRule { weight, rule } => {
290 writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291 pretty_print_expr_inner(rule, buf, indent + 1)?;
292 writeln!(buf, "{})", spaces)?;
293 }
294 TLExpr::ProbabilisticChoice { alternatives } => {
295 writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296 for (i, (prob, expr)) in alternatives.iter().enumerate() {
297 if i > 0 {
298 writeln!(buf, "{},", spaces)?;
299 }
300 writeln!(buf, "{} {}: ", spaces, prob)?;
301 pretty_print_expr_inner(expr, buf, indent + 2)?;
302 }
303 writeln!(buf, "{}]", spaces)?;
304 }
305
306 TLExpr::Release { released, releaser } => {
308 writeln!(buf, "{}RELEASE(", spaces)?;
309 pretty_print_expr_inner(released, buf, indent + 1)?;
310 writeln!(buf, "{},", spaces)?;
311 pretty_print_expr_inner(releaser, buf, indent + 1)?;
312 writeln!(buf, "{})", spaces)?;
313 }
314 TLExpr::WeakUntil { before, after } => {
315 writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316 pretty_print_expr_inner(before, buf, indent + 1)?;
317 writeln!(buf, "{},", spaces)?;
318 pretty_print_expr_inner(after, buf, indent + 1)?;
319 writeln!(buf, "{})", spaces)?;
320 }
321 TLExpr::StrongRelease { released, releaser } => {
322 writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323 pretty_print_expr_inner(released, buf, indent + 1)?;
324 writeln!(buf, "{},", spaces)?;
325 pretty_print_expr_inner(releaser, buf, indent + 1)?;
326 writeln!(buf, "{})", spaces)?;
327 }
328
329 TLExpr::Eq(l, r) => {
330 writeln!(buf, "{}EQ(", spaces)?;
331 pretty_print_expr_inner(l, buf, indent + 1)?;
332 writeln!(buf, "{},", spaces)?;
333 pretty_print_expr_inner(r, buf, indent + 1)?;
334 writeln!(buf, "{})", spaces)?;
335 }
336 TLExpr::Lt(l, r) => {
337 writeln!(buf, "{}LT(", spaces)?;
338 pretty_print_expr_inner(l, buf, indent + 1)?;
339 writeln!(buf, "{},", spaces)?;
340 pretty_print_expr_inner(r, buf, indent + 1)?;
341 writeln!(buf, "{})", spaces)?;
342 }
343 TLExpr::Gt(l, r) => {
344 writeln!(buf, "{}GT(", spaces)?;
345 pretty_print_expr_inner(l, buf, indent + 1)?;
346 writeln!(buf, "{},", spaces)?;
347 pretty_print_expr_inner(r, buf, indent + 1)?;
348 writeln!(buf, "{})", spaces)?;
349 }
350 TLExpr::Lte(l, r) => {
351 writeln!(buf, "{}LTE(", spaces)?;
352 pretty_print_expr_inner(l, buf, indent + 1)?;
353 writeln!(buf, "{},", spaces)?;
354 pretty_print_expr_inner(r, buf, indent + 1)?;
355 writeln!(buf, "{})", spaces)?;
356 }
357 TLExpr::Gte(l, r) => {
358 writeln!(buf, "{}GTE(", spaces)?;
359 pretty_print_expr_inner(l, buf, indent + 1)?;
360 writeln!(buf, "{},", spaces)?;
361 pretty_print_expr_inner(r, buf, indent + 1)?;
362 writeln!(buf, "{})", spaces)?;
363 }
364 TLExpr::IfThenElse {
365 condition,
366 then_branch,
367 else_branch,
368 } => {
369 writeln!(buf, "{}IF(", spaces)?;
370 pretty_print_expr_inner(condition, buf, indent + 1)?;
371 writeln!(buf, "{}) THEN(", spaces)?;
372 pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373 writeln!(buf, "{}) ELSE(", spaces)?;
374 pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375 writeln!(buf, "{})", spaces)?;
376 }
377 TLExpr::Let { var, value, body } => {
378 writeln!(buf, "{}LET {} =(", spaces, var)?;
379 pretty_print_expr_inner(value, buf, indent + 1)?;
380 writeln!(buf, "{}) IN(", spaces)?;
381 pretty_print_expr_inner(body, buf, indent + 1)?;
382 writeln!(buf, "{})", spaces)?;
383 }
384 TLExpr::Lambda {
386 var,
387 var_type,
388 body,
389 } => {
390 if let Some(ty) = var_type {
391 writeln!(buf, "{}LAMBDA {}:{} ⇒(", spaces, var, ty)?;
392 } else {
393 writeln!(buf, "{}LAMBDA {} ⇒(", spaces, var)?;
394 }
395 pretty_print_expr_inner(body, buf, indent + 1)?;
396 writeln!(buf, "{})", spaces)?;
397 }
398 TLExpr::Apply { function, argument } => {
399 writeln!(buf, "{}APPLY(", spaces)?;
400 pretty_print_expr_inner(function, buf, indent + 1)?;
401 writeln!(buf, "{}TO", spaces)?;
402 pretty_print_expr_inner(argument, buf, indent + 1)?;
403 writeln!(buf, "{})", spaces)?;
404 }
405 TLExpr::SetMembership { element, set } => {
406 writeln!(buf, "{}MEMBER(", spaces)?;
407 pretty_print_expr_inner(element, buf, indent + 1)?;
408 writeln!(buf, "{}IN", spaces)?;
409 pretty_print_expr_inner(set, buf, indent + 1)?;
410 writeln!(buf, "{})", spaces)?;
411 }
412 TLExpr::SetUnion { left, right } => {
413 writeln!(buf, "{}UNION(", spaces)?;
414 pretty_print_expr_inner(left, buf, indent + 1)?;
415 writeln!(buf, "{},", spaces)?;
416 pretty_print_expr_inner(right, buf, indent + 1)?;
417 writeln!(buf, "{})", spaces)?;
418 }
419 TLExpr::SetIntersection { left, right } => {
420 writeln!(buf, "{}INTERSECT(", spaces)?;
421 pretty_print_expr_inner(left, buf, indent + 1)?;
422 writeln!(buf, "{},", spaces)?;
423 pretty_print_expr_inner(right, buf, indent + 1)?;
424 writeln!(buf, "{})", spaces)?;
425 }
426 TLExpr::SetDifference { left, right } => {
427 writeln!(buf, "{}DIFFERENCE(", spaces)?;
428 pretty_print_expr_inner(left, buf, indent + 1)?;
429 writeln!(buf, "{},", spaces)?;
430 pretty_print_expr_inner(right, buf, indent + 1)?;
431 writeln!(buf, "{})", spaces)?;
432 }
433 TLExpr::SetCardinality { set } => {
434 writeln!(buf, "{}CARDINALITY(", spaces)?;
435 pretty_print_expr_inner(set, buf, indent + 1)?;
436 writeln!(buf, "{})", spaces)?;
437 }
438 TLExpr::EmptySet => {
439 writeln!(buf, "{}EMPTY-SET", spaces)?;
440 }
441 TLExpr::SetComprehension {
442 var,
443 domain,
444 condition,
445 } => {
446 writeln!(buf, "{}SET-COMPREHENSION {{ {}:{} | ", spaces, var, domain)?;
447 pretty_print_expr_inner(condition, buf, indent + 1)?;
448 writeln!(buf, "{}}}", spaces)?;
449 }
450 TLExpr::CountingExists {
451 var,
452 domain,
453 body,
454 min_count,
455 } => {
456 writeln!(buf, "{}∃≥{}{}:{}.(", spaces, min_count, var, domain)?;
457 pretty_print_expr_inner(body, buf, indent + 1)?;
458 writeln!(buf, "{})", spaces)?;
459 }
460 TLExpr::CountingForAll {
461 var,
462 domain,
463 body,
464 min_count,
465 } => {
466 writeln!(buf, "{}∀≥{}{}:{}.(", spaces, min_count, var, domain)?;
467 pretty_print_expr_inner(body, buf, indent + 1)?;
468 writeln!(buf, "{})", spaces)?;
469 }
470 TLExpr::ExactCount {
471 var,
472 domain,
473 body,
474 count,
475 } => {
476 writeln!(buf, "{}∃={}{}:{}.(", spaces, count, var, domain)?;
477 pretty_print_expr_inner(body, buf, indent + 1)?;
478 writeln!(buf, "{})", spaces)?;
479 }
480 TLExpr::Majority { var, domain, body } => {
481 writeln!(buf, "{}MAJORITY {}:{}.(", spaces, var, domain)?;
482 pretty_print_expr_inner(body, buf, indent + 1)?;
483 writeln!(buf, "{})", spaces)?;
484 }
485 TLExpr::LeastFixpoint { var, body } => {
486 writeln!(buf, "{}μ{}.(", spaces, var)?;
487 pretty_print_expr_inner(body, buf, indent + 1)?;
488 writeln!(buf, "{})", spaces)?;
489 }
490 TLExpr::GreatestFixpoint { var, body } => {
491 writeln!(buf, "{}ν{}.(", spaces, var)?;
492 pretty_print_expr_inner(body, buf, indent + 1)?;
493 writeln!(buf, "{})", spaces)?;
494 }
495 TLExpr::Nominal { name } => {
496 writeln!(buf, "{}@{}", spaces, name)?;
497 }
498 TLExpr::At { nominal, formula } => {
499 writeln!(buf, "{}AT @{}(", spaces, nominal)?;
500 pretty_print_expr_inner(formula, buf, indent + 1)?;
501 writeln!(buf, "{})", spaces)?;
502 }
503 TLExpr::Somewhere { formula } => {
504 writeln!(buf, "{}SOMEWHERE(", spaces)?;
505 pretty_print_expr_inner(formula, buf, indent + 1)?;
506 writeln!(buf, "{})", spaces)?;
507 }
508 TLExpr::Everywhere { formula } => {
509 writeln!(buf, "{}EVERYWHERE(", spaces)?;
510 pretty_print_expr_inner(formula, buf, indent + 1)?;
511 writeln!(buf, "{})", spaces)?;
512 }
513 TLExpr::AllDifferent { variables } => {
514 writeln!(buf, "{}ALL-DIFFERENT({:?})", spaces, variables)?;
515 }
516 TLExpr::GlobalCardinality {
517 variables,
518 values,
519 min_occurrences,
520 max_occurrences,
521 } => {
522 writeln!(buf, "{}GLOBAL-CARDINALITY(", spaces)?;
523 writeln!(buf, "{} vars: {:?}", spaces, variables)?;
524 writeln!(buf, "{} constraints: [", spaces)?;
525 for (i, val) in values.iter().enumerate() {
526 write!(buf, "{} ", spaces)?;
527 pretty_print_expr_inner(val, buf, 0)?;
528 writeln!(buf, ": [{}, {}]", min_occurrences[i], max_occurrences[i])?;
529 }
530 writeln!(buf, "{} ]", spaces)?;
531 writeln!(buf, "{})", spaces)?;
532 }
533 TLExpr::Abducible { name, cost } => {
534 writeln!(buf, "{}ABDUCIBLE({}, cost={})", spaces, name, cost)?;
535 }
536 TLExpr::Explain { formula } => {
537 writeln!(buf, "{}EXPLAIN(", spaces)?;
538 pretty_print_expr_inner(formula, buf, indent + 1)?;
539 writeln!(buf, "{})", spaces)?;
540 }
541 TLExpr::Constant(value) => {
542 writeln!(buf, "{}{}", spaces, value)?;
543 }
544 }
545
546 Ok(())
547}
548
549fn term_to_string(term: &Term) -> String {
550 match term {
551 Term::Var(name) => format!("?{}", name),
552 Term::Const(name) => name.clone(),
553 Term::Typed {
554 value,
555 type_annotation,
556 } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
557 }
558}
559
/// Structural statistics about a `TLExpr` tree, produced by `ExprStats::compute`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprStats {
    /// Total number of nodes in the expression tree (every visited node counts once).
    pub node_count: usize,
    /// Number of edges on the longest root-to-leaf path; a lone leaf has depth 0.
    pub max_depth: usize,
    /// Number of `Pred` nodes; `Abducible` leaves are counted here as well.
    pub predicate_count: usize,
    /// Number of binding constructs: quantifiers (plain, soft and counting
    /// variants, `Majority`), `Aggregate`, `Lambda` and `SetComprehension`.
    pub quantifier_count: usize,
    /// Number of logical/structural operator nodes, e.g. `And`, `Or`, `Not`,
    /// fuzzy operators, fixpoints and set operations.
    pub logical_op_count: usize,
    /// Number of arithmetic operator nodes, e.g. `Add`, `Mul`, `Abs`, `Sqrt`;
    /// `SetCardinality` is counted here as well.
    pub arithmetic_op_count: usize,
    /// Number of comparison nodes (`Eq`, `Lt`, `Gt`, `Lte`, `Gte`).
    pub comparison_op_count: usize,
    /// Number of free variables, as reported by `TLExpr::free_vars()`.
    pub free_var_count: usize,
}
580
581impl ExprStats {
582 pub fn compute(expr: &TLExpr) -> Self {
584 let mut stats = ExprStats {
585 node_count: 0,
586 max_depth: 0,
587 predicate_count: 0,
588 quantifier_count: 0,
589 logical_op_count: 0,
590 arithmetic_op_count: 0,
591 comparison_op_count: 0,
592 free_var_count: expr.free_vars().len(),
593 };
594
595 stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
596 stats
597 }
598
599 fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
600 stats.node_count += 1;
601 let mut max_child_depth = depth;
602
603 match expr {
604 TLExpr::Pred { .. } => {
605 stats.predicate_count += 1;
606 }
607 TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
608 stats.logical_op_count += 1;
609 let left_depth = Self::compute_recursive(l, stats, depth + 1);
610 let right_depth = Self::compute_recursive(r, stats, depth + 1);
611 max_child_depth = left_depth.max(right_depth);
612 }
613 TLExpr::Not(e) | TLExpr::Score(e) => {
614 stats.logical_op_count += 1;
615 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
616 }
617 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
618 stats.quantifier_count += 1;
619 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
620 }
621 TLExpr::Aggregate { body, .. } => {
622 stats.quantifier_count += 1; max_child_depth = Self::compute_recursive(body, stats, depth + 1);
624 }
625 TLExpr::Add(l, r)
626 | TLExpr::Sub(l, r)
627 | TLExpr::Mul(l, r)
628 | TLExpr::Div(l, r)
629 | TLExpr::Pow(l, r)
630 | TLExpr::Mod(l, r)
631 | TLExpr::Min(l, r)
632 | TLExpr::Max(l, r) => {
633 stats.arithmetic_op_count += 1;
634 let left_depth = Self::compute_recursive(l, stats, depth + 1);
635 let right_depth = Self::compute_recursive(r, stats, depth + 1);
636 max_child_depth = left_depth.max(right_depth);
637 }
638 TLExpr::Abs(e)
639 | TLExpr::Floor(e)
640 | TLExpr::Ceil(e)
641 | TLExpr::Round(e)
642 | TLExpr::Sqrt(e)
643 | TLExpr::Exp(e)
644 | TLExpr::Log(e)
645 | TLExpr::Sin(e)
646 | TLExpr::Cos(e)
647 | TLExpr::Tan(e)
648 | TLExpr::Box(e)
649 | TLExpr::Diamond(e)
650 | TLExpr::Next(e)
651 | TLExpr::Eventually(e)
652 | TLExpr::Always(e) => {
653 stats.arithmetic_op_count += 1;
654 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
655 }
656 TLExpr::Until { before, after } => {
657 stats.logical_op_count += 1;
658 let depth_before = Self::compute_recursive(before, stats, depth + 1);
659 let depth_after = Self::compute_recursive(after, stats, depth + 1);
660 max_child_depth = depth_before.max(depth_after);
661 }
662 TLExpr::Eq(l, r)
663 | TLExpr::Lt(l, r)
664 | TLExpr::Gt(l, r)
665 | TLExpr::Lte(l, r)
666 | TLExpr::Gte(l, r) => {
667 stats.comparison_op_count += 1;
668 let left_depth = Self::compute_recursive(l, stats, depth + 1);
669 let right_depth = Self::compute_recursive(r, stats, depth + 1);
670 max_child_depth = left_depth.max(right_depth);
671 }
672 TLExpr::IfThenElse {
673 condition,
674 then_branch,
675 else_branch,
676 } => {
677 let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
678 let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
679 let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
680 max_child_depth = cond_depth.max(then_depth).max(else_depth);
681 }
682 TLExpr::Let { value, body, .. } => {
683 let value_depth = Self::compute_recursive(value, stats, depth + 1);
684 let body_depth = Self::compute_recursive(body, stats, depth + 1);
685 max_child_depth = value_depth.max(body_depth);
686 }
687
688 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
690 stats.logical_op_count += 1;
691 let left_depth = Self::compute_recursive(left, stats, depth + 1);
692 let right_depth = Self::compute_recursive(right, stats, depth + 1);
693 max_child_depth = left_depth.max(right_depth);
694 }
695 TLExpr::FuzzyNot { expr, .. } => {
696 stats.logical_op_count += 1;
697 max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
698 }
699 TLExpr::FuzzyImplication {
700 premise,
701 conclusion,
702 ..
703 } => {
704 stats.logical_op_count += 1;
705 let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
706 let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
707 max_child_depth = prem_depth.max(conc_depth);
708 }
709
710 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
712 stats.quantifier_count += 1;
713 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
714 }
715 TLExpr::WeightedRule { rule, .. } => {
716 stats.logical_op_count += 1;
717 max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
718 }
719 TLExpr::ProbabilisticChoice { alternatives } => {
720 stats.logical_op_count += 1;
721 let mut max_alt_depth = depth;
722 for (_, expr) in alternatives {
723 let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
724 max_alt_depth = max_alt_depth.max(alt_depth);
725 }
726 max_child_depth = max_alt_depth;
727 }
728
729 TLExpr::Release { released, releaser }
731 | TLExpr::WeakUntil {
732 before: released,
733 after: releaser,
734 }
735 | TLExpr::StrongRelease { released, releaser } => {
736 stats.logical_op_count += 1;
737 let rel_depth = Self::compute_recursive(released, stats, depth + 1);
738 let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
739 max_child_depth = rel_depth.max(reler_depth);
740 }
741
742 TLExpr::Lambda { body, .. } => {
744 stats.quantifier_count += 1; max_child_depth = Self::compute_recursive(body, stats, depth + 1);
746 }
747 TLExpr::Apply { function, argument } => {
748 stats.logical_op_count += 1;
749 let func_depth = Self::compute_recursive(function, stats, depth + 1);
750 let arg_depth = Self::compute_recursive(argument, stats, depth + 1);
751 max_child_depth = func_depth.max(arg_depth);
752 }
753 TLExpr::SetMembership { element, set }
754 | TLExpr::SetUnion {
755 left: element,
756 right: set,
757 }
758 | TLExpr::SetIntersection {
759 left: element,
760 right: set,
761 }
762 | TLExpr::SetDifference {
763 left: element,
764 right: set,
765 } => {
766 stats.logical_op_count += 1;
767 let elem_depth = Self::compute_recursive(element, stats, depth + 1);
768 let set_depth = Self::compute_recursive(set, stats, depth + 1);
769 max_child_depth = elem_depth.max(set_depth);
770 }
771 TLExpr::SetCardinality { set } => {
772 stats.arithmetic_op_count += 1;
773 max_child_depth = Self::compute_recursive(set, stats, depth + 1);
774 }
775 TLExpr::EmptySet => {
776 }
778 TLExpr::SetComprehension { condition, .. } => {
779 stats.quantifier_count += 1;
780 max_child_depth = Self::compute_recursive(condition, stats, depth + 1);
781 }
782 TLExpr::CountingExists { body, .. }
783 | TLExpr::CountingForAll { body, .. }
784 | TLExpr::ExactCount { body, .. }
785 | TLExpr::Majority { body, .. } => {
786 stats.quantifier_count += 1;
787 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
788 }
789 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
790 stats.logical_op_count += 1;
791 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
792 }
793 TLExpr::Nominal { .. } => {
794 }
796 TLExpr::At { formula, .. } => {
797 stats.logical_op_count += 1;
798 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
799 }
800 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
801 stats.logical_op_count += 1;
802 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
803 }
804 TLExpr::AllDifferent { .. } => {
805 stats.logical_op_count += 1;
806 }
808 TLExpr::GlobalCardinality { values, .. } => {
809 stats.logical_op_count += 1;
810 let mut max_val_depth = depth;
811 for val in values {
812 let val_depth = Self::compute_recursive(val, stats, depth + 1);
813 max_val_depth = max_val_depth.max(val_depth);
814 }
815 max_child_depth = max_val_depth;
816 }
817 TLExpr::Abducible { .. } => {
818 stats.predicate_count += 1;
819 }
821 TLExpr::Explain { formula } => {
822 stats.logical_op_count += 1;
823 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
824 }
825
826 TLExpr::Constant(_) => {
827 }
829 }
830
831 max_child_depth
832 }
833}
834
/// Summary statistics about an `EinsumGraph`, produced by `GraphStats::compute`.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of tensors registered in the graph (`graph.tensors.len()`).
    pub tensor_count: usize,
    /// Number of operation nodes (`graph.nodes.len()`).
    pub node_count: usize,
    /// Number of graph outputs (`graph.outputs.len()`).
    pub output_count: usize,
    /// Number of nodes whose op is `OpType::Einsum`.
    pub einsum_count: usize,
    /// Number of nodes whose op is `OpType::ElemUnary`.
    pub elem_unary_count: usize,
    /// Number of nodes whose op is `OpType::ElemBinary`.
    pub elem_binary_count: usize,
    /// Number of nodes whose op is `OpType::Reduce`.
    pub reduce_count: usize,
    /// Mean number of inputs per node; `0.0` for a graph with no nodes.
    pub avg_inputs_per_node: f64,
}
855
856impl GraphStats {
857 pub fn compute(graph: &EinsumGraph) -> Self {
859 let mut stats = GraphStats {
860 tensor_count: graph.tensors.len(),
861 node_count: graph.nodes.len(),
862 output_count: graph.outputs.len(),
863 einsum_count: 0,
864 elem_unary_count: 0,
865 elem_binary_count: 0,
866 reduce_count: 0,
867 avg_inputs_per_node: 0.0,
868 };
869
870 let mut total_inputs = 0;
871
872 for node in &graph.nodes {
873 total_inputs += node.inputs.len();
874
875 match &node.op {
876 crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
877 crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
878 crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
879 crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
880 }
881 }
882
883 if stats.node_count > 0 {
884 stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
885 }
886
887 stats
888 }
889}
890
891pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
893 let mut buffer = String::new();
894 writeln!(buffer, "EinsumGraph {{").unwrap();
895 writeln!(buffer, " Tensors: {} total", graph.tensors.len()).unwrap();
896
897 for (idx, name) in graph.tensors.iter().enumerate() {
898 writeln!(buffer, " t{}: {}", idx, name).unwrap();
899 }
900
901 writeln!(buffer, " Nodes: {} total", graph.nodes.len()).unwrap();
902 for (idx, node) in graph.nodes.iter().enumerate() {
903 write!(buffer, " n{}: ", idx).unwrap();
904 match &node.op {
905 crate::graph::OpType::Einsum { spec } => {
906 write!(buffer, "Einsum(\"{}\")", spec).unwrap()
907 }
908 crate::graph::OpType::ElemUnary { op } => write!(buffer, "ElemUnary({})", op).unwrap(),
909 crate::graph::OpType::ElemBinary { op } => {
910 write!(buffer, "ElemBinary({})", op).unwrap()
911 }
912 crate::graph::OpType::Reduce { op, axes } => {
913 write!(buffer, "Reduce({}, axes={:?})", op, axes).unwrap()
914 }
915 }
916 write!(buffer, " <- [").unwrap();
917 for (i, input) in node.inputs.iter().enumerate() {
918 if i > 0 {
919 write!(buffer, ", ").unwrap();
920 }
921 write!(buffer, "t{}", input).unwrap();
922 }
923 writeln!(buffer, "]").unwrap();
924 }
925
926 writeln!(buffer, " Outputs: {:?}", graph.outputs).unwrap();
927 writeln!(buffer, "}}").unwrap();
928
929 buffer
930}
931
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_stats_simple() {
        // A single predicate over one free variable.
        let stats = ExprStats::compute(&TLExpr::pred("P", vec![Term::var("x")]));

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    #[test]
    fn test_expr_stats_complex() {
        // ∀x:Domain. (P(x) ∧ Q(x))
        let conjunction = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );
        let stats = ExprStats::compute(&TLExpr::forall("x", "Domain", conjunction));

        assert_eq!(stats.node_count, 4);
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        assert_eq!(stats.free_var_count, 0);
    }

    #[test]
    fn test_expr_stats_arithmetic() {
        // score(x) * 2 + 1
        let scaled = TLExpr::mul(
            TLExpr::pred("score", vec![Term::var("x")]),
            TLExpr::constant(2.0),
        );
        let stats = ExprStats::compute(&TLExpr::add(scaled, TLExpr::constant(1.0)));

        assert_eq!(stats.arithmetic_op_count, 2);
        assert_eq!(stats.predicate_count, 1);
    }

    #[test]
    fn test_graph_stats() {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("input");
        let output = graph.add_tensor("output");

        let identity = crate::EinsumNode {
            inputs: vec![input],
            outputs: vec![output],
            op: crate::graph::OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        };
        graph.add_node(identity).unwrap();
        graph.add_output(output).unwrap();

        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    #[test]
    fn test_pretty_print_expr() {
        let rendered = pretty_print_expr(&TLExpr::pred("Person", vec![Term::var("x")]));
        assert!(rendered.contains("Person(?x)"));
    }

    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _input = graph.add_tensor("input");

        let listing = pretty_print_graph(&graph);
        assert!(listing.contains("t0: input"));
        assert!(listing.contains("Tensors: 1 total"));
    }
}