1use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12pub fn pretty_print_expr(expr: &TLExpr) -> String {
14 let mut buffer = String::new();
15 pretty_print_expr_inner(expr, &mut buffer, 0).expect("writing to String never fails");
16 buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20 let spaces = " ".repeat(indent);
21
22 match expr {
23 TLExpr::Pred { name, args } => {
24 write!(buf, "{}{}(", spaces, name)?;
25 for (i, arg) in args.iter().enumerate() {
26 if i > 0 {
27 write!(buf, ", ")?;
28 }
29 write!(buf, "{}", term_to_string(arg))?;
30 }
31 writeln!(buf, ")")?;
32 }
33 TLExpr::And(l, r) => {
34 writeln!(buf, "{}AND(", spaces)?;
35 pretty_print_expr_inner(l, buf, indent + 1)?;
36 writeln!(buf, "{},", spaces)?;
37 pretty_print_expr_inner(r, buf, indent + 1)?;
38 writeln!(buf, "{})", spaces)?;
39 }
40 TLExpr::Or(l, r) => {
41 writeln!(buf, "{}OR(", spaces)?;
42 pretty_print_expr_inner(l, buf, indent + 1)?;
43 writeln!(buf, "{},", spaces)?;
44 pretty_print_expr_inner(r, buf, indent + 1)?;
45 writeln!(buf, "{})", spaces)?;
46 }
47 TLExpr::Not(e) => {
48 writeln!(buf, "{}NOT(", spaces)?;
49 pretty_print_expr_inner(e, buf, indent + 1)?;
50 writeln!(buf, "{})", spaces)?;
51 }
52 TLExpr::Exists { var, domain, body } => {
53 writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54 pretty_print_expr_inner(body, buf, indent + 1)?;
55 writeln!(buf, "{})", spaces)?;
56 }
57 TLExpr::ForAll { var, domain, body } => {
58 writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59 pretty_print_expr_inner(body, buf, indent + 1)?;
60 writeln!(buf, "{})", spaces)?;
61 }
62 TLExpr::Aggregate {
63 op,
64 var,
65 domain,
66 body,
67 group_by,
68 } => {
69 write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70 if let Some(group_vars) = group_by {
71 write!(buf, " GROUP BY {:?}", group_vars)?;
72 }
73 writeln!(buf, ")(")?;
74 pretty_print_expr_inner(body, buf, indent + 1)?;
75 writeln!(buf, "{})", spaces)?;
76 }
77 TLExpr::Imply(premise, conclusion) => {
78 writeln!(buf, "{}IMPLY(", spaces)?;
79 pretty_print_expr_inner(premise, buf, indent + 1)?;
80 writeln!(buf, "{}⇒", spaces)?;
81 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82 writeln!(buf, "{})", spaces)?;
83 }
84 TLExpr::Score(e) => {
85 writeln!(buf, "{}SCORE(", spaces)?;
86 pretty_print_expr_inner(e, buf, indent + 1)?;
87 writeln!(buf, "{})", spaces)?;
88 }
89 TLExpr::Add(l, r) => {
90 writeln!(buf, "{}ADD(", spaces)?;
91 pretty_print_expr_inner(l, buf, indent + 1)?;
92 writeln!(buf, "{},", spaces)?;
93 pretty_print_expr_inner(r, buf, indent + 1)?;
94 writeln!(buf, "{})", spaces)?;
95 }
96 TLExpr::Sub(l, r) => {
97 writeln!(buf, "{}SUB(", spaces)?;
98 pretty_print_expr_inner(l, buf, indent + 1)?;
99 writeln!(buf, "{},", spaces)?;
100 pretty_print_expr_inner(r, buf, indent + 1)?;
101 writeln!(buf, "{})", spaces)?;
102 }
103 TLExpr::Mul(l, r) => {
104 writeln!(buf, "{}MUL(", spaces)?;
105 pretty_print_expr_inner(l, buf, indent + 1)?;
106 writeln!(buf, "{},", spaces)?;
107 pretty_print_expr_inner(r, buf, indent + 1)?;
108 writeln!(buf, "{})", spaces)?;
109 }
110 TLExpr::Div(l, r) => {
111 writeln!(buf, "{}DIV(", spaces)?;
112 pretty_print_expr_inner(l, buf, indent + 1)?;
113 writeln!(buf, "{},", spaces)?;
114 pretty_print_expr_inner(r, buf, indent + 1)?;
115 writeln!(buf, "{})", spaces)?;
116 }
117 TLExpr::Pow(l, r) => {
118 writeln!(buf, "{}POW(", spaces)?;
119 pretty_print_expr_inner(l, buf, indent + 1)?;
120 writeln!(buf, "{},", spaces)?;
121 pretty_print_expr_inner(r, buf, indent + 1)?;
122 writeln!(buf, "{})", spaces)?;
123 }
124 TLExpr::Mod(l, r) => {
125 writeln!(buf, "{}MOD(", spaces)?;
126 pretty_print_expr_inner(l, buf, indent + 1)?;
127 writeln!(buf, "{},", spaces)?;
128 pretty_print_expr_inner(r, buf, indent + 1)?;
129 writeln!(buf, "{})", spaces)?;
130 }
131 TLExpr::Min(l, r) => {
132 writeln!(buf, "{}MIN(", spaces)?;
133 pretty_print_expr_inner(l, buf, indent + 1)?;
134 writeln!(buf, "{},", spaces)?;
135 pretty_print_expr_inner(r, buf, indent + 1)?;
136 writeln!(buf, "{})", spaces)?;
137 }
138 TLExpr::Max(l, r) => {
139 writeln!(buf, "{}MAX(", spaces)?;
140 pretty_print_expr_inner(l, buf, indent + 1)?;
141 writeln!(buf, "{},", spaces)?;
142 pretty_print_expr_inner(r, buf, indent + 1)?;
143 writeln!(buf, "{})", spaces)?;
144 }
145 TLExpr::Abs(e) => {
146 writeln!(buf, "{}ABS(", spaces)?;
147 pretty_print_expr_inner(e, buf, indent + 1)?;
148 writeln!(buf, "{})", spaces)?;
149 }
150 TLExpr::Floor(e) => {
151 writeln!(buf, "{}FLOOR(", spaces)?;
152 pretty_print_expr_inner(e, buf, indent + 1)?;
153 writeln!(buf, "{})", spaces)?;
154 }
155 TLExpr::Ceil(e) => {
156 writeln!(buf, "{}CEIL(", spaces)?;
157 pretty_print_expr_inner(e, buf, indent + 1)?;
158 writeln!(buf, "{})", spaces)?;
159 }
160 TLExpr::Round(e) => {
161 writeln!(buf, "{}ROUND(", spaces)?;
162 pretty_print_expr_inner(e, buf, indent + 1)?;
163 writeln!(buf, "{})", spaces)?;
164 }
165 TLExpr::Sqrt(e) => {
166 writeln!(buf, "{}SQRT(", spaces)?;
167 pretty_print_expr_inner(e, buf, indent + 1)?;
168 writeln!(buf, "{})", spaces)?;
169 }
170 TLExpr::Exp(e) => {
171 writeln!(buf, "{}EXP(", spaces)?;
172 pretty_print_expr_inner(e, buf, indent + 1)?;
173 writeln!(buf, "{})", spaces)?;
174 }
175 TLExpr::Log(e) => {
176 writeln!(buf, "{}LOG(", spaces)?;
177 pretty_print_expr_inner(e, buf, indent + 1)?;
178 writeln!(buf, "{})", spaces)?;
179 }
180 TLExpr::Sin(e) => {
181 writeln!(buf, "{}SIN(", spaces)?;
182 pretty_print_expr_inner(e, buf, indent + 1)?;
183 writeln!(buf, "{})", spaces)?;
184 }
185 TLExpr::Cos(e) => {
186 writeln!(buf, "{}COS(", spaces)?;
187 pretty_print_expr_inner(e, buf, indent + 1)?;
188 writeln!(buf, "{})", spaces)?;
189 }
190 TLExpr::Tan(e) => {
191 writeln!(buf, "{}TAN(", spaces)?;
192 pretty_print_expr_inner(e, buf, indent + 1)?;
193 writeln!(buf, "{})", spaces)?;
194 }
195 TLExpr::Box(e) => {
196 writeln!(buf, "{}BOX(", spaces)?;
197 pretty_print_expr_inner(e, buf, indent + 1)?;
198 writeln!(buf, "{})", spaces)?;
199 }
200 TLExpr::Diamond(e) => {
201 writeln!(buf, "{}DIAMOND(", spaces)?;
202 pretty_print_expr_inner(e, buf, indent + 1)?;
203 writeln!(buf, "{})", spaces)?;
204 }
205 TLExpr::Next(e) => {
206 writeln!(buf, "{}NEXT(", spaces)?;
207 pretty_print_expr_inner(e, buf, indent + 1)?;
208 writeln!(buf, "{})", spaces)?;
209 }
210 TLExpr::Eventually(e) => {
211 writeln!(buf, "{}EVENTUALLY(", spaces)?;
212 pretty_print_expr_inner(e, buf, indent + 1)?;
213 writeln!(buf, "{})", spaces)?;
214 }
215 TLExpr::Always(e) => {
216 writeln!(buf, "{}ALWAYS(", spaces)?;
217 pretty_print_expr_inner(e, buf, indent + 1)?;
218 writeln!(buf, "{})", spaces)?;
219 }
220 TLExpr::Until { before, after } => {
221 writeln!(buf, "{}UNTIL(", spaces)?;
222 pretty_print_expr_inner(before, buf, indent + 1)?;
223 writeln!(buf, "{},", spaces)?;
224 pretty_print_expr_inner(after, buf, indent + 1)?;
225 writeln!(buf, "{})", spaces)?;
226 }
227
228 TLExpr::TNorm { kind, left, right } => {
230 writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231 pretty_print_expr_inner(left, buf, indent + 1)?;
232 writeln!(buf, "{},", spaces)?;
233 pretty_print_expr_inner(right, buf, indent + 1)?;
234 writeln!(buf, "{})", spaces)?;
235 }
236 TLExpr::TCoNorm { kind, left, right } => {
237 writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238 pretty_print_expr_inner(left, buf, indent + 1)?;
239 writeln!(buf, "{},", spaces)?;
240 pretty_print_expr_inner(right, buf, indent + 1)?;
241 writeln!(buf, "{})", spaces)?;
242 }
243 TLExpr::FuzzyNot { kind, expr } => {
244 writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245 pretty_print_expr_inner(expr, buf, indent + 1)?;
246 writeln!(buf, "{})", spaces)?;
247 }
248 TLExpr::FuzzyImplication {
249 kind,
250 premise,
251 conclusion,
252 } => {
253 writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254 pretty_print_expr_inner(premise, buf, indent + 1)?;
255 writeln!(buf, "{}⇒", spaces)?;
256 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257 writeln!(buf, "{})", spaces)?;
258 }
259
260 TLExpr::SoftExists {
262 var,
263 domain,
264 body,
265 temperature,
266 } => {
267 writeln!(
268 buf,
269 "{}SOFT-∃{}:{}[T={}](",
270 spaces, var, domain, temperature
271 )?;
272 pretty_print_expr_inner(body, buf, indent + 1)?;
273 writeln!(buf, "{})", spaces)?;
274 }
275 TLExpr::SoftForAll {
276 var,
277 domain,
278 body,
279 temperature,
280 } => {
281 writeln!(
282 buf,
283 "{}SOFT-∀{}:{}[T={}](",
284 spaces, var, domain, temperature
285 )?;
286 pretty_print_expr_inner(body, buf, indent + 1)?;
287 writeln!(buf, "{})", spaces)?;
288 }
289 TLExpr::WeightedRule { weight, rule } => {
290 writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291 pretty_print_expr_inner(rule, buf, indent + 1)?;
292 writeln!(buf, "{})", spaces)?;
293 }
294 TLExpr::ProbabilisticChoice { alternatives } => {
295 writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296 for (i, (prob, expr)) in alternatives.iter().enumerate() {
297 if i > 0 {
298 writeln!(buf, "{},", spaces)?;
299 }
300 writeln!(buf, "{} {}: ", spaces, prob)?;
301 pretty_print_expr_inner(expr, buf, indent + 2)?;
302 }
303 writeln!(buf, "{}]", spaces)?;
304 }
305
306 TLExpr::Release { released, releaser } => {
308 writeln!(buf, "{}RELEASE(", spaces)?;
309 pretty_print_expr_inner(released, buf, indent + 1)?;
310 writeln!(buf, "{},", spaces)?;
311 pretty_print_expr_inner(releaser, buf, indent + 1)?;
312 writeln!(buf, "{})", spaces)?;
313 }
314 TLExpr::WeakUntil { before, after } => {
315 writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316 pretty_print_expr_inner(before, buf, indent + 1)?;
317 writeln!(buf, "{},", spaces)?;
318 pretty_print_expr_inner(after, buf, indent + 1)?;
319 writeln!(buf, "{})", spaces)?;
320 }
321 TLExpr::StrongRelease { released, releaser } => {
322 writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323 pretty_print_expr_inner(released, buf, indent + 1)?;
324 writeln!(buf, "{},", spaces)?;
325 pretty_print_expr_inner(releaser, buf, indent + 1)?;
326 writeln!(buf, "{})", spaces)?;
327 }
328
329 TLExpr::Eq(l, r) => {
330 writeln!(buf, "{}EQ(", spaces)?;
331 pretty_print_expr_inner(l, buf, indent + 1)?;
332 writeln!(buf, "{},", spaces)?;
333 pretty_print_expr_inner(r, buf, indent + 1)?;
334 writeln!(buf, "{})", spaces)?;
335 }
336 TLExpr::Lt(l, r) => {
337 writeln!(buf, "{}LT(", spaces)?;
338 pretty_print_expr_inner(l, buf, indent + 1)?;
339 writeln!(buf, "{},", spaces)?;
340 pretty_print_expr_inner(r, buf, indent + 1)?;
341 writeln!(buf, "{})", spaces)?;
342 }
343 TLExpr::Gt(l, r) => {
344 writeln!(buf, "{}GT(", spaces)?;
345 pretty_print_expr_inner(l, buf, indent + 1)?;
346 writeln!(buf, "{},", spaces)?;
347 pretty_print_expr_inner(r, buf, indent + 1)?;
348 writeln!(buf, "{})", spaces)?;
349 }
350 TLExpr::Lte(l, r) => {
351 writeln!(buf, "{}LTE(", spaces)?;
352 pretty_print_expr_inner(l, buf, indent + 1)?;
353 writeln!(buf, "{},", spaces)?;
354 pretty_print_expr_inner(r, buf, indent + 1)?;
355 writeln!(buf, "{})", spaces)?;
356 }
357 TLExpr::Gte(l, r) => {
358 writeln!(buf, "{}GTE(", spaces)?;
359 pretty_print_expr_inner(l, buf, indent + 1)?;
360 writeln!(buf, "{},", spaces)?;
361 pretty_print_expr_inner(r, buf, indent + 1)?;
362 writeln!(buf, "{})", spaces)?;
363 }
364 TLExpr::IfThenElse {
365 condition,
366 then_branch,
367 else_branch,
368 } => {
369 writeln!(buf, "{}IF(", spaces)?;
370 pretty_print_expr_inner(condition, buf, indent + 1)?;
371 writeln!(buf, "{}) THEN(", spaces)?;
372 pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373 writeln!(buf, "{}) ELSE(", spaces)?;
374 pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375 writeln!(buf, "{})", spaces)?;
376 }
377 TLExpr::Let { var, value, body } => {
378 writeln!(buf, "{}LET {} =(", spaces, var)?;
379 pretty_print_expr_inner(value, buf, indent + 1)?;
380 writeln!(buf, "{}) IN(", spaces)?;
381 pretty_print_expr_inner(body, buf, indent + 1)?;
382 writeln!(buf, "{})", spaces)?;
383 }
384 TLExpr::Lambda {
386 var,
387 var_type,
388 body,
389 } => {
390 if let Some(ty) = var_type {
391 writeln!(buf, "{}LAMBDA {}:{} ⇒(", spaces, var, ty)?;
392 } else {
393 writeln!(buf, "{}LAMBDA {} ⇒(", spaces, var)?;
394 }
395 pretty_print_expr_inner(body, buf, indent + 1)?;
396 writeln!(buf, "{})", spaces)?;
397 }
398 TLExpr::Apply { function, argument } => {
399 writeln!(buf, "{}APPLY(", spaces)?;
400 pretty_print_expr_inner(function, buf, indent + 1)?;
401 writeln!(buf, "{}TO", spaces)?;
402 pretty_print_expr_inner(argument, buf, indent + 1)?;
403 writeln!(buf, "{})", spaces)?;
404 }
405 TLExpr::SetMembership { element, set } => {
406 writeln!(buf, "{}MEMBER(", spaces)?;
407 pretty_print_expr_inner(element, buf, indent + 1)?;
408 writeln!(buf, "{}IN", spaces)?;
409 pretty_print_expr_inner(set, buf, indent + 1)?;
410 writeln!(buf, "{})", spaces)?;
411 }
412 TLExpr::SetUnion { left, right } => {
413 writeln!(buf, "{}UNION(", spaces)?;
414 pretty_print_expr_inner(left, buf, indent + 1)?;
415 writeln!(buf, "{},", spaces)?;
416 pretty_print_expr_inner(right, buf, indent + 1)?;
417 writeln!(buf, "{})", spaces)?;
418 }
419 TLExpr::SetIntersection { left, right } => {
420 writeln!(buf, "{}INTERSECT(", spaces)?;
421 pretty_print_expr_inner(left, buf, indent + 1)?;
422 writeln!(buf, "{},", spaces)?;
423 pretty_print_expr_inner(right, buf, indent + 1)?;
424 writeln!(buf, "{})", spaces)?;
425 }
426 TLExpr::SetDifference { left, right } => {
427 writeln!(buf, "{}DIFFERENCE(", spaces)?;
428 pretty_print_expr_inner(left, buf, indent + 1)?;
429 writeln!(buf, "{},", spaces)?;
430 pretty_print_expr_inner(right, buf, indent + 1)?;
431 writeln!(buf, "{})", spaces)?;
432 }
433 TLExpr::SetCardinality { set } => {
434 writeln!(buf, "{}CARDINALITY(", spaces)?;
435 pretty_print_expr_inner(set, buf, indent + 1)?;
436 writeln!(buf, "{})", spaces)?;
437 }
438 TLExpr::EmptySet => {
439 writeln!(buf, "{}EMPTY-SET", spaces)?;
440 }
441 TLExpr::SetComprehension {
442 var,
443 domain,
444 condition,
445 } => {
446 writeln!(buf, "{}SET-COMPREHENSION {{ {}:{} | ", spaces, var, domain)?;
447 pretty_print_expr_inner(condition, buf, indent + 1)?;
448 writeln!(buf, "{}}}", spaces)?;
449 }
450 TLExpr::CountingExists {
451 var,
452 domain,
453 body,
454 min_count,
455 } => {
456 writeln!(buf, "{}∃≥{}{}:{}.(", spaces, min_count, var, domain)?;
457 pretty_print_expr_inner(body, buf, indent + 1)?;
458 writeln!(buf, "{})", spaces)?;
459 }
460 TLExpr::CountingForAll {
461 var,
462 domain,
463 body,
464 min_count,
465 } => {
466 writeln!(buf, "{}∀≥{}{}:{}.(", spaces, min_count, var, domain)?;
467 pretty_print_expr_inner(body, buf, indent + 1)?;
468 writeln!(buf, "{})", spaces)?;
469 }
470 TLExpr::ExactCount {
471 var,
472 domain,
473 body,
474 count,
475 } => {
476 writeln!(buf, "{}∃={}{}:{}.(", spaces, count, var, domain)?;
477 pretty_print_expr_inner(body, buf, indent + 1)?;
478 writeln!(buf, "{})", spaces)?;
479 }
480 TLExpr::Majority { var, domain, body } => {
481 writeln!(buf, "{}MAJORITY {}:{}.(", spaces, var, domain)?;
482 pretty_print_expr_inner(body, buf, indent + 1)?;
483 writeln!(buf, "{})", spaces)?;
484 }
485 TLExpr::LeastFixpoint { var, body } => {
486 writeln!(buf, "{}μ{}.(", spaces, var)?;
487 pretty_print_expr_inner(body, buf, indent + 1)?;
488 writeln!(buf, "{})", spaces)?;
489 }
490 TLExpr::GreatestFixpoint { var, body } => {
491 writeln!(buf, "{}ν{}.(", spaces, var)?;
492 pretty_print_expr_inner(body, buf, indent + 1)?;
493 writeln!(buf, "{})", spaces)?;
494 }
495 TLExpr::Nominal { name } => {
496 writeln!(buf, "{}@{}", spaces, name)?;
497 }
498 TLExpr::At { nominal, formula } => {
499 writeln!(buf, "{}AT @{}(", spaces, nominal)?;
500 pretty_print_expr_inner(formula, buf, indent + 1)?;
501 writeln!(buf, "{})", spaces)?;
502 }
503 TLExpr::Somewhere { formula } => {
504 writeln!(buf, "{}SOMEWHERE(", spaces)?;
505 pretty_print_expr_inner(formula, buf, indent + 1)?;
506 writeln!(buf, "{})", spaces)?;
507 }
508 TLExpr::Everywhere { formula } => {
509 writeln!(buf, "{}EVERYWHERE(", spaces)?;
510 pretty_print_expr_inner(formula, buf, indent + 1)?;
511 writeln!(buf, "{})", spaces)?;
512 }
513 TLExpr::AllDifferent { variables } => {
514 writeln!(buf, "{}ALL-DIFFERENT({:?})", spaces, variables)?;
515 }
516 TLExpr::GlobalCardinality {
517 variables,
518 values,
519 min_occurrences,
520 max_occurrences,
521 } => {
522 writeln!(buf, "{}GLOBAL-CARDINALITY(", spaces)?;
523 writeln!(buf, "{} vars: {:?}", spaces, variables)?;
524 writeln!(buf, "{} constraints: [", spaces)?;
525 for (i, val) in values.iter().enumerate() {
526 write!(buf, "{} ", spaces)?;
527 pretty_print_expr_inner(val, buf, 0)?;
528 writeln!(buf, ": [{}, {}]", min_occurrences[i], max_occurrences[i])?;
529 }
530 writeln!(buf, "{} ]", spaces)?;
531 writeln!(buf, "{})", spaces)?;
532 }
533 TLExpr::Abducible { name, cost } => {
534 writeln!(buf, "{}ABDUCIBLE({}, cost={})", spaces, name, cost)?;
535 }
536 TLExpr::Explain { formula } => {
537 writeln!(buf, "{}EXPLAIN(", spaces)?;
538 pretty_print_expr_inner(formula, buf, indent + 1)?;
539 writeln!(buf, "{})", spaces)?;
540 }
541 TLExpr::Constant(value) => {
542 writeln!(buf, "{}{}", spaces, value)?;
543 }
544 TLExpr::SymbolLiteral(s) => {
545 writeln!(buf, "{}:{}", spaces, s)?;
546 }
547 TLExpr::Match { scrutinee, arms } => {
548 writeln!(buf, "{}MATCH(", spaces)?;
549 pretty_print_expr_inner(scrutinee, buf, indent + 1)?;
550 for (pat, body) in arms {
551 writeln!(buf, "{} {} =>", spaces, pat)?;
552 pretty_print_expr_inner(body, buf, indent + 2)?;
553 }
554 writeln!(buf, "{})", spaces)?;
555 }
556 }
557
558 Ok(())
559}
560
561fn term_to_string(term: &Term) -> String {
562 match term {
563 Term::Var(name) => format!("?{}", name),
564 Term::Const(name) => name.clone(),
565 Term::Typed {
566 value,
567 type_annotation,
568 } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
569 }
570}
571
/// Structural statistics for a [`TLExpr`] tree, produced by [`ExprStats::compute`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprStats {
    /// Total number of nodes visited in the expression tree.
    pub node_count: usize,
    /// Depth of the deepest node, counting the root as depth 0.
    pub max_depth: usize,
    /// Number of `Pred` and `Abducible` nodes.
    pub predicate_count: usize,
    /// Number of quantifier-like nodes (as classified by [`ExprStats::compute`]).
    pub quantifier_count: usize,
    /// Number of logical-operator nodes (as classified by [`ExprStats::compute`]).
    pub logical_op_count: usize,
    /// Number of arithmetic-operator nodes (as classified by [`ExprStats::compute`]).
    pub arithmetic_op_count: usize,
    /// Number of `Eq`, `Lt`, `Gt`, `Lte` and `Gte` nodes.
    pub comparison_op_count: usize,
    /// Number of free variables, i.e. `expr.free_vars().len()`.
    pub free_var_count: usize,
}
592
593impl ExprStats {
594 pub fn compute(expr: &TLExpr) -> Self {
596 let mut stats = ExprStats {
597 node_count: 0,
598 max_depth: 0,
599 predicate_count: 0,
600 quantifier_count: 0,
601 logical_op_count: 0,
602 arithmetic_op_count: 0,
603 comparison_op_count: 0,
604 free_var_count: expr.free_vars().len(),
605 };
606
607 stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
608 stats
609 }
610
611 fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
612 stats.node_count += 1;
613 let mut max_child_depth = depth;
614
615 match expr {
616 TLExpr::Pred { .. } => {
617 stats.predicate_count += 1;
618 }
619 TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
620 stats.logical_op_count += 1;
621 let left_depth = Self::compute_recursive(l, stats, depth + 1);
622 let right_depth = Self::compute_recursive(r, stats, depth + 1);
623 max_child_depth = left_depth.max(right_depth);
624 }
625 TLExpr::Not(e) | TLExpr::Score(e) => {
626 stats.logical_op_count += 1;
627 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
628 }
629 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
630 stats.quantifier_count += 1;
631 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
632 }
633 TLExpr::Aggregate { body, .. } => {
634 stats.quantifier_count += 1; max_child_depth = Self::compute_recursive(body, stats, depth + 1);
636 }
637 TLExpr::Add(l, r)
638 | TLExpr::Sub(l, r)
639 | TLExpr::Mul(l, r)
640 | TLExpr::Div(l, r)
641 | TLExpr::Pow(l, r)
642 | TLExpr::Mod(l, r)
643 | TLExpr::Min(l, r)
644 | TLExpr::Max(l, r) => {
645 stats.arithmetic_op_count += 1;
646 let left_depth = Self::compute_recursive(l, stats, depth + 1);
647 let right_depth = Self::compute_recursive(r, stats, depth + 1);
648 max_child_depth = left_depth.max(right_depth);
649 }
650 TLExpr::Abs(e)
651 | TLExpr::Floor(e)
652 | TLExpr::Ceil(e)
653 | TLExpr::Round(e)
654 | TLExpr::Sqrt(e)
655 | TLExpr::Exp(e)
656 | TLExpr::Log(e)
657 | TLExpr::Sin(e)
658 | TLExpr::Cos(e)
659 | TLExpr::Tan(e)
660 | TLExpr::Box(e)
661 | TLExpr::Diamond(e)
662 | TLExpr::Next(e)
663 | TLExpr::Eventually(e)
664 | TLExpr::Always(e) => {
665 stats.arithmetic_op_count += 1;
666 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
667 }
668 TLExpr::Until { before, after } => {
669 stats.logical_op_count += 1;
670 let depth_before = Self::compute_recursive(before, stats, depth + 1);
671 let depth_after = Self::compute_recursive(after, stats, depth + 1);
672 max_child_depth = depth_before.max(depth_after);
673 }
674 TLExpr::Eq(l, r)
675 | TLExpr::Lt(l, r)
676 | TLExpr::Gt(l, r)
677 | TLExpr::Lte(l, r)
678 | TLExpr::Gte(l, r) => {
679 stats.comparison_op_count += 1;
680 let left_depth = Self::compute_recursive(l, stats, depth + 1);
681 let right_depth = Self::compute_recursive(r, stats, depth + 1);
682 max_child_depth = left_depth.max(right_depth);
683 }
684 TLExpr::IfThenElse {
685 condition,
686 then_branch,
687 else_branch,
688 } => {
689 let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
690 let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
691 let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
692 max_child_depth = cond_depth.max(then_depth).max(else_depth);
693 }
694 TLExpr::Let { value, body, .. } => {
695 let value_depth = Self::compute_recursive(value, stats, depth + 1);
696 let body_depth = Self::compute_recursive(body, stats, depth + 1);
697 max_child_depth = value_depth.max(body_depth);
698 }
699
700 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
702 stats.logical_op_count += 1;
703 let left_depth = Self::compute_recursive(left, stats, depth + 1);
704 let right_depth = Self::compute_recursive(right, stats, depth + 1);
705 max_child_depth = left_depth.max(right_depth);
706 }
707 TLExpr::FuzzyNot { expr, .. } => {
708 stats.logical_op_count += 1;
709 max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
710 }
711 TLExpr::FuzzyImplication {
712 premise,
713 conclusion,
714 ..
715 } => {
716 stats.logical_op_count += 1;
717 let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
718 let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
719 max_child_depth = prem_depth.max(conc_depth);
720 }
721
722 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
724 stats.quantifier_count += 1;
725 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
726 }
727 TLExpr::WeightedRule { rule, .. } => {
728 stats.logical_op_count += 1;
729 max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
730 }
731 TLExpr::ProbabilisticChoice { alternatives } => {
732 stats.logical_op_count += 1;
733 let mut max_alt_depth = depth;
734 for (_, expr) in alternatives {
735 let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
736 max_alt_depth = max_alt_depth.max(alt_depth);
737 }
738 max_child_depth = max_alt_depth;
739 }
740
741 TLExpr::Release { released, releaser }
743 | TLExpr::WeakUntil {
744 before: released,
745 after: releaser,
746 }
747 | TLExpr::StrongRelease { released, releaser } => {
748 stats.logical_op_count += 1;
749 let rel_depth = Self::compute_recursive(released, stats, depth + 1);
750 let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
751 max_child_depth = rel_depth.max(reler_depth);
752 }
753
754 TLExpr::Lambda { body, .. } => {
756 stats.quantifier_count += 1; max_child_depth = Self::compute_recursive(body, stats, depth + 1);
758 }
759 TLExpr::Apply { function, argument } => {
760 stats.logical_op_count += 1;
761 let func_depth = Self::compute_recursive(function, stats, depth + 1);
762 let arg_depth = Self::compute_recursive(argument, stats, depth + 1);
763 max_child_depth = func_depth.max(arg_depth);
764 }
765 TLExpr::SetMembership { element, set }
766 | TLExpr::SetUnion {
767 left: element,
768 right: set,
769 }
770 | TLExpr::SetIntersection {
771 left: element,
772 right: set,
773 }
774 | TLExpr::SetDifference {
775 left: element,
776 right: set,
777 } => {
778 stats.logical_op_count += 1;
779 let elem_depth = Self::compute_recursive(element, stats, depth + 1);
780 let set_depth = Self::compute_recursive(set, stats, depth + 1);
781 max_child_depth = elem_depth.max(set_depth);
782 }
783 TLExpr::SetCardinality { set } => {
784 stats.arithmetic_op_count += 1;
785 max_child_depth = Self::compute_recursive(set, stats, depth + 1);
786 }
787 TLExpr::EmptySet => {
788 }
790 TLExpr::SetComprehension { condition, .. } => {
791 stats.quantifier_count += 1;
792 max_child_depth = Self::compute_recursive(condition, stats, depth + 1);
793 }
794 TLExpr::CountingExists { body, .. }
795 | TLExpr::CountingForAll { body, .. }
796 | TLExpr::ExactCount { body, .. }
797 | TLExpr::Majority { body, .. } => {
798 stats.quantifier_count += 1;
799 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
800 }
801 TLExpr::LeastFixpoint { body, .. } | TLExpr::GreatestFixpoint { body, .. } => {
802 stats.logical_op_count += 1;
803 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
804 }
805 TLExpr::Nominal { .. } => {
806 }
808 TLExpr::At { formula, .. } => {
809 stats.logical_op_count += 1;
810 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
811 }
812 TLExpr::Somewhere { formula } | TLExpr::Everywhere { formula } => {
813 stats.logical_op_count += 1;
814 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
815 }
816 TLExpr::AllDifferent { .. } => {
817 stats.logical_op_count += 1;
818 }
820 TLExpr::GlobalCardinality { values, .. } => {
821 stats.logical_op_count += 1;
822 let mut max_val_depth = depth;
823 for val in values {
824 let val_depth = Self::compute_recursive(val, stats, depth + 1);
825 max_val_depth = max_val_depth.max(val_depth);
826 }
827 max_child_depth = max_val_depth;
828 }
829 TLExpr::Abducible { .. } => {
830 stats.predicate_count += 1;
831 }
833 TLExpr::Explain { formula } => {
834 stats.logical_op_count += 1;
835 max_child_depth = Self::compute_recursive(formula, stats, depth + 1);
836 }
837
838 TLExpr::Constant(_) => {
839 }
841 TLExpr::SymbolLiteral(_) => {
842 }
844 TLExpr::Match { scrutinee, arms } => {
845 stats.logical_op_count += 1;
846 let sd = Self::compute_recursive(scrutinee, stats, depth + 1);
847 if sd > max_child_depth {
848 max_child_depth = sd;
849 }
850 for (_, body) in arms {
851 let bd = Self::compute_recursive(body, stats, depth + 1);
852 if bd > max_child_depth {
853 max_child_depth = bd;
854 }
855 }
856 }
857 }
858
859 max_child_depth
860 }
861}
862
/// Summary statistics for an [`EinsumGraph`], produced by [`GraphStats::compute`].
#[derive(Debug, Clone, PartialEq)]
pub struct GraphStats {
    /// Number of tensors in the graph (`graph.tensors.len()`).
    pub tensor_count: usize,
    /// Number of operation nodes (`graph.nodes.len()`).
    pub node_count: usize,
    /// Number of graph outputs (`graph.outputs.len()`).
    pub output_count: usize,
    /// Number of `OpType::Einsum` nodes.
    pub einsum_count: usize,
    /// Number of `OpType::ElemUnary` nodes.
    pub elem_unary_count: usize,
    /// Number of `OpType::ElemBinary` nodes.
    pub elem_binary_count: usize,
    /// Number of `OpType::Reduce` nodes.
    pub reduce_count: usize,
    /// Mean number of inputs per node; `0.0` when the graph has no nodes.
    pub avg_inputs_per_node: f64,
}
883
884impl GraphStats {
885 pub fn compute(graph: &EinsumGraph) -> Self {
887 let mut stats = GraphStats {
888 tensor_count: graph.tensors.len(),
889 node_count: graph.nodes.len(),
890 output_count: graph.outputs.len(),
891 einsum_count: 0,
892 elem_unary_count: 0,
893 elem_binary_count: 0,
894 reduce_count: 0,
895 avg_inputs_per_node: 0.0,
896 };
897
898 let mut total_inputs = 0;
899
900 for node in &graph.nodes {
901 total_inputs += node.inputs.len();
902
903 match &node.op {
904 crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
905 crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
906 crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
907 crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
908 }
909 }
910
911 if stats.node_count > 0 {
912 stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
913 }
914
915 stats
916 }
917}
918
919pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
921 let mut buffer = String::new();
922 writeln!(buffer, "EinsumGraph {{").expect("writing to String buffer never fails");
923 writeln!(buffer, " Tensors: {} total", graph.tensors.len())
924 .expect("writing to String buffer never fails");
925
926 for (idx, name) in graph.tensors.iter().enumerate() {
927 writeln!(buffer, " t{}: {}", idx, name).expect("writing to String buffer never fails");
928 }
929
930 writeln!(buffer, " Nodes: {} total", graph.nodes.len())
931 .expect("writing to String buffer never fails");
932 for (idx, node) in graph.nodes.iter().enumerate() {
933 write!(buffer, " n{}: ", idx).expect("writing to String buffer never fails");
934 match &node.op {
935 crate::graph::OpType::Einsum { spec } => write!(buffer, "Einsum(\"{}\")", spec)
936 .expect("writing to String buffer never fails"),
937 crate::graph::OpType::ElemUnary { op } => {
938 write!(buffer, "ElemUnary({})", op).expect("writing to String buffer never fails")
939 }
940 crate::graph::OpType::ElemBinary { op } => {
941 write!(buffer, "ElemBinary({})", op).expect("writing to String buffer never fails")
942 }
943 crate::graph::OpType::Reduce { op, axes } => {
944 write!(buffer, "Reduce({}, axes={:?})", op, axes)
945 .expect("writing to String buffer never fails")
946 }
947 }
948 write!(buffer, " <- [").expect("writing to String buffer never fails");
949 for (i, input) in node.inputs.iter().enumerate() {
950 if i > 0 {
951 write!(buffer, ", ").expect("writing to String buffer never fails");
952 }
953 write!(buffer, "t{}", input).expect("writing to String buffer never fails");
954 }
955 writeln!(buffer, "]").expect("writing to String buffer never fails");
956 }
957
958 writeln!(buffer, " Outputs: {:?}", graph.outputs)
959 .expect("writing to String buffer never fails");
960 writeln!(buffer, "}}").expect("writing to String buffer never fails");
961
962 buffer
963}
964
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone predicate is a single node with one free variable.
    #[test]
    fn test_expr_stats_simple() {
        let stats = ExprStats::compute(&TLExpr::pred("P", vec![Term::var("x")]));

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    /// `∀x:Domain. P(x) ∧ Q(x)` has four nodes and binds its only variable.
    #[test]
    fn test_expr_stats_complex() {
        let conjunction = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );
        let stats = ExprStats::compute(&TLExpr::forall("x", "Domain", conjunction));

        assert_eq!(stats.node_count, 4);
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        assert_eq!(stats.free_var_count, 0);
    }

    /// `score(x) * 2 + 1` contains two arithmetic operators.
    #[test]
    fn test_expr_stats_arithmetic() {
        let scaled = TLExpr::mul(
            TLExpr::pred("score", vec![Term::var("x")]),
            TLExpr::constant(2.0),
        );
        let shifted = TLExpr::add(scaled, TLExpr::constant(1.0));

        let stats = ExprStats::compute(&shifted);

        assert_eq!(stats.arithmetic_op_count, 2);
        assert_eq!(stats.predicate_count, 1);
    }

    /// A one-node identity einsum graph.
    #[test]
    fn test_graph_stats() {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("input");
        let output = graph.add_tensor("output");

        let node = crate::EinsumNode {
            inputs: vec![input],
            outputs: vec![output],
            op: crate::graph::OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        };
        graph.add_node(node).expect("unwrap");
        graph.add_output(output).expect("unwrap");

        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    #[test]
    fn test_pretty_print_expr() {
        let rendered = pretty_print_expr(&TLExpr::pred("Person", vec![Term::var("x")]));
        assert!(rendered.contains("Person(?x)"));
    }

    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _input = graph.add_tensor("input");

        let rendered = pretty_print_graph(&graph);
        assert!(rendered.contains("t0: input"));
        assert!(rendered.contains("Tensors: 1 total"));
    }
}