1use std::fmt::{self, Write};
9
10use crate::{EinsumGraph, TLExpr, Term};
11
12pub fn pretty_print_expr(expr: &TLExpr) -> String {
14 let mut buffer = String::new();
15 pretty_print_expr_inner(expr, &mut buffer, 0).unwrap();
16 buffer
17}
18
19fn pretty_print_expr_inner(expr: &TLExpr, buf: &mut String, indent: usize) -> fmt::Result {
20 let spaces = " ".repeat(indent);
21
22 match expr {
23 TLExpr::Pred { name, args } => {
24 write!(buf, "{}{}(", spaces, name)?;
25 for (i, arg) in args.iter().enumerate() {
26 if i > 0 {
27 write!(buf, ", ")?;
28 }
29 write!(buf, "{}", term_to_string(arg))?;
30 }
31 writeln!(buf, ")")?;
32 }
33 TLExpr::And(l, r) => {
34 writeln!(buf, "{}AND(", spaces)?;
35 pretty_print_expr_inner(l, buf, indent + 1)?;
36 writeln!(buf, "{},", spaces)?;
37 pretty_print_expr_inner(r, buf, indent + 1)?;
38 writeln!(buf, "{})", spaces)?;
39 }
40 TLExpr::Or(l, r) => {
41 writeln!(buf, "{}OR(", spaces)?;
42 pretty_print_expr_inner(l, buf, indent + 1)?;
43 writeln!(buf, "{},", spaces)?;
44 pretty_print_expr_inner(r, buf, indent + 1)?;
45 writeln!(buf, "{})", spaces)?;
46 }
47 TLExpr::Not(e) => {
48 writeln!(buf, "{}NOT(", spaces)?;
49 pretty_print_expr_inner(e, buf, indent + 1)?;
50 writeln!(buf, "{})", spaces)?;
51 }
52 TLExpr::Exists { var, domain, body } => {
53 writeln!(buf, "{}∃{}:{}.(", spaces, var, domain)?;
54 pretty_print_expr_inner(body, buf, indent + 1)?;
55 writeln!(buf, "{})", spaces)?;
56 }
57 TLExpr::ForAll { var, domain, body } => {
58 writeln!(buf, "{}∀{}:{}.(", spaces, var, domain)?;
59 pretty_print_expr_inner(body, buf, indent + 1)?;
60 writeln!(buf, "{})", spaces)?;
61 }
62 TLExpr::Aggregate {
63 op,
64 var,
65 domain,
66 body,
67 group_by,
68 } => {
69 write!(buf, "{}AGG_{:?}({}:{}", spaces, op, var, domain)?;
70 if let Some(group_vars) = group_by {
71 write!(buf, " GROUP BY {:?}", group_vars)?;
72 }
73 writeln!(buf, ")(")?;
74 pretty_print_expr_inner(body, buf, indent + 1)?;
75 writeln!(buf, "{})", spaces)?;
76 }
77 TLExpr::Imply(premise, conclusion) => {
78 writeln!(buf, "{}IMPLY(", spaces)?;
79 pretty_print_expr_inner(premise, buf, indent + 1)?;
80 writeln!(buf, "{}⇒", spaces)?;
81 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
82 writeln!(buf, "{})", spaces)?;
83 }
84 TLExpr::Score(e) => {
85 writeln!(buf, "{}SCORE(", spaces)?;
86 pretty_print_expr_inner(e, buf, indent + 1)?;
87 writeln!(buf, "{})", spaces)?;
88 }
89 TLExpr::Add(l, r) => {
90 writeln!(buf, "{}ADD(", spaces)?;
91 pretty_print_expr_inner(l, buf, indent + 1)?;
92 writeln!(buf, "{},", spaces)?;
93 pretty_print_expr_inner(r, buf, indent + 1)?;
94 writeln!(buf, "{})", spaces)?;
95 }
96 TLExpr::Sub(l, r) => {
97 writeln!(buf, "{}SUB(", spaces)?;
98 pretty_print_expr_inner(l, buf, indent + 1)?;
99 writeln!(buf, "{},", spaces)?;
100 pretty_print_expr_inner(r, buf, indent + 1)?;
101 writeln!(buf, "{})", spaces)?;
102 }
103 TLExpr::Mul(l, r) => {
104 writeln!(buf, "{}MUL(", spaces)?;
105 pretty_print_expr_inner(l, buf, indent + 1)?;
106 writeln!(buf, "{},", spaces)?;
107 pretty_print_expr_inner(r, buf, indent + 1)?;
108 writeln!(buf, "{})", spaces)?;
109 }
110 TLExpr::Div(l, r) => {
111 writeln!(buf, "{}DIV(", spaces)?;
112 pretty_print_expr_inner(l, buf, indent + 1)?;
113 writeln!(buf, "{},", spaces)?;
114 pretty_print_expr_inner(r, buf, indent + 1)?;
115 writeln!(buf, "{})", spaces)?;
116 }
117 TLExpr::Pow(l, r) => {
118 writeln!(buf, "{}POW(", spaces)?;
119 pretty_print_expr_inner(l, buf, indent + 1)?;
120 writeln!(buf, "{},", spaces)?;
121 pretty_print_expr_inner(r, buf, indent + 1)?;
122 writeln!(buf, "{})", spaces)?;
123 }
124 TLExpr::Mod(l, r) => {
125 writeln!(buf, "{}MOD(", spaces)?;
126 pretty_print_expr_inner(l, buf, indent + 1)?;
127 writeln!(buf, "{},", spaces)?;
128 pretty_print_expr_inner(r, buf, indent + 1)?;
129 writeln!(buf, "{})", spaces)?;
130 }
131 TLExpr::Min(l, r) => {
132 writeln!(buf, "{}MIN(", spaces)?;
133 pretty_print_expr_inner(l, buf, indent + 1)?;
134 writeln!(buf, "{},", spaces)?;
135 pretty_print_expr_inner(r, buf, indent + 1)?;
136 writeln!(buf, "{})", spaces)?;
137 }
138 TLExpr::Max(l, r) => {
139 writeln!(buf, "{}MAX(", spaces)?;
140 pretty_print_expr_inner(l, buf, indent + 1)?;
141 writeln!(buf, "{},", spaces)?;
142 pretty_print_expr_inner(r, buf, indent + 1)?;
143 writeln!(buf, "{})", spaces)?;
144 }
145 TLExpr::Abs(e) => {
146 writeln!(buf, "{}ABS(", spaces)?;
147 pretty_print_expr_inner(e, buf, indent + 1)?;
148 writeln!(buf, "{})", spaces)?;
149 }
150 TLExpr::Floor(e) => {
151 writeln!(buf, "{}FLOOR(", spaces)?;
152 pretty_print_expr_inner(e, buf, indent + 1)?;
153 writeln!(buf, "{})", spaces)?;
154 }
155 TLExpr::Ceil(e) => {
156 writeln!(buf, "{}CEIL(", spaces)?;
157 pretty_print_expr_inner(e, buf, indent + 1)?;
158 writeln!(buf, "{})", spaces)?;
159 }
160 TLExpr::Round(e) => {
161 writeln!(buf, "{}ROUND(", spaces)?;
162 pretty_print_expr_inner(e, buf, indent + 1)?;
163 writeln!(buf, "{})", spaces)?;
164 }
165 TLExpr::Sqrt(e) => {
166 writeln!(buf, "{}SQRT(", spaces)?;
167 pretty_print_expr_inner(e, buf, indent + 1)?;
168 writeln!(buf, "{})", spaces)?;
169 }
170 TLExpr::Exp(e) => {
171 writeln!(buf, "{}EXP(", spaces)?;
172 pretty_print_expr_inner(e, buf, indent + 1)?;
173 writeln!(buf, "{})", spaces)?;
174 }
175 TLExpr::Log(e) => {
176 writeln!(buf, "{}LOG(", spaces)?;
177 pretty_print_expr_inner(e, buf, indent + 1)?;
178 writeln!(buf, "{})", spaces)?;
179 }
180 TLExpr::Sin(e) => {
181 writeln!(buf, "{}SIN(", spaces)?;
182 pretty_print_expr_inner(e, buf, indent + 1)?;
183 writeln!(buf, "{})", spaces)?;
184 }
185 TLExpr::Cos(e) => {
186 writeln!(buf, "{}COS(", spaces)?;
187 pretty_print_expr_inner(e, buf, indent + 1)?;
188 writeln!(buf, "{})", spaces)?;
189 }
190 TLExpr::Tan(e) => {
191 writeln!(buf, "{}TAN(", spaces)?;
192 pretty_print_expr_inner(e, buf, indent + 1)?;
193 writeln!(buf, "{})", spaces)?;
194 }
195 TLExpr::Box(e) => {
196 writeln!(buf, "{}BOX(", spaces)?;
197 pretty_print_expr_inner(e, buf, indent + 1)?;
198 writeln!(buf, "{})", spaces)?;
199 }
200 TLExpr::Diamond(e) => {
201 writeln!(buf, "{}DIAMOND(", spaces)?;
202 pretty_print_expr_inner(e, buf, indent + 1)?;
203 writeln!(buf, "{})", spaces)?;
204 }
205 TLExpr::Next(e) => {
206 writeln!(buf, "{}NEXT(", spaces)?;
207 pretty_print_expr_inner(e, buf, indent + 1)?;
208 writeln!(buf, "{})", spaces)?;
209 }
210 TLExpr::Eventually(e) => {
211 writeln!(buf, "{}EVENTUALLY(", spaces)?;
212 pretty_print_expr_inner(e, buf, indent + 1)?;
213 writeln!(buf, "{})", spaces)?;
214 }
215 TLExpr::Always(e) => {
216 writeln!(buf, "{}ALWAYS(", spaces)?;
217 pretty_print_expr_inner(e, buf, indent + 1)?;
218 writeln!(buf, "{})", spaces)?;
219 }
220 TLExpr::Until { before, after } => {
221 writeln!(buf, "{}UNTIL(", spaces)?;
222 pretty_print_expr_inner(before, buf, indent + 1)?;
223 writeln!(buf, "{},", spaces)?;
224 pretty_print_expr_inner(after, buf, indent + 1)?;
225 writeln!(buf, "{})", spaces)?;
226 }
227
228 TLExpr::TNorm { kind, left, right } => {
230 writeln!(buf, "{}T-NORM_{:?}(", spaces, kind)?;
231 pretty_print_expr_inner(left, buf, indent + 1)?;
232 writeln!(buf, "{},", spaces)?;
233 pretty_print_expr_inner(right, buf, indent + 1)?;
234 writeln!(buf, "{})", spaces)?;
235 }
236 TLExpr::TCoNorm { kind, left, right } => {
237 writeln!(buf, "{}T-CONORM_{:?}(", spaces, kind)?;
238 pretty_print_expr_inner(left, buf, indent + 1)?;
239 writeln!(buf, "{},", spaces)?;
240 pretty_print_expr_inner(right, buf, indent + 1)?;
241 writeln!(buf, "{})", spaces)?;
242 }
243 TLExpr::FuzzyNot { kind, expr } => {
244 writeln!(buf, "{}FUZZY-NOT_{:?}(", spaces, kind)?;
245 pretty_print_expr_inner(expr, buf, indent + 1)?;
246 writeln!(buf, "{})", spaces)?;
247 }
248 TLExpr::FuzzyImplication {
249 kind,
250 premise,
251 conclusion,
252 } => {
253 writeln!(buf, "{}FUZZY-IMPLY_{:?}(", spaces, kind)?;
254 pretty_print_expr_inner(premise, buf, indent + 1)?;
255 writeln!(buf, "{}⇒", spaces)?;
256 pretty_print_expr_inner(conclusion, buf, indent + 1)?;
257 writeln!(buf, "{})", spaces)?;
258 }
259
260 TLExpr::SoftExists {
262 var,
263 domain,
264 body,
265 temperature,
266 } => {
267 writeln!(
268 buf,
269 "{}SOFT-∃{}:{}[T={}](",
270 spaces, var, domain, temperature
271 )?;
272 pretty_print_expr_inner(body, buf, indent + 1)?;
273 writeln!(buf, "{})", spaces)?;
274 }
275 TLExpr::SoftForAll {
276 var,
277 domain,
278 body,
279 temperature,
280 } => {
281 writeln!(
282 buf,
283 "{}SOFT-∀{}:{}[T={}](",
284 spaces, var, domain, temperature
285 )?;
286 pretty_print_expr_inner(body, buf, indent + 1)?;
287 writeln!(buf, "{})", spaces)?;
288 }
289 TLExpr::WeightedRule { weight, rule } => {
290 writeln!(buf, "{}WEIGHTED[{}](", spaces, weight)?;
291 pretty_print_expr_inner(rule, buf, indent + 1)?;
292 writeln!(buf, "{})", spaces)?;
293 }
294 TLExpr::ProbabilisticChoice { alternatives } => {
295 writeln!(buf, "{}PROB-CHOICE[", spaces)?;
296 for (i, (prob, expr)) in alternatives.iter().enumerate() {
297 if i > 0 {
298 writeln!(buf, "{},", spaces)?;
299 }
300 writeln!(buf, "{} {}: ", spaces, prob)?;
301 pretty_print_expr_inner(expr, buf, indent + 2)?;
302 }
303 writeln!(buf, "{}]", spaces)?;
304 }
305
306 TLExpr::Release { released, releaser } => {
308 writeln!(buf, "{}RELEASE(", spaces)?;
309 pretty_print_expr_inner(released, buf, indent + 1)?;
310 writeln!(buf, "{},", spaces)?;
311 pretty_print_expr_inner(releaser, buf, indent + 1)?;
312 writeln!(buf, "{})", spaces)?;
313 }
314 TLExpr::WeakUntil { before, after } => {
315 writeln!(buf, "{}WEAK-UNTIL(", spaces)?;
316 pretty_print_expr_inner(before, buf, indent + 1)?;
317 writeln!(buf, "{},", spaces)?;
318 pretty_print_expr_inner(after, buf, indent + 1)?;
319 writeln!(buf, "{})", spaces)?;
320 }
321 TLExpr::StrongRelease { released, releaser } => {
322 writeln!(buf, "{}STRONG-RELEASE(", spaces)?;
323 pretty_print_expr_inner(released, buf, indent + 1)?;
324 writeln!(buf, "{},", spaces)?;
325 pretty_print_expr_inner(releaser, buf, indent + 1)?;
326 writeln!(buf, "{})", spaces)?;
327 }
328
329 TLExpr::Eq(l, r) => {
330 writeln!(buf, "{}EQ(", spaces)?;
331 pretty_print_expr_inner(l, buf, indent + 1)?;
332 writeln!(buf, "{},", spaces)?;
333 pretty_print_expr_inner(r, buf, indent + 1)?;
334 writeln!(buf, "{})", spaces)?;
335 }
336 TLExpr::Lt(l, r) => {
337 writeln!(buf, "{}LT(", spaces)?;
338 pretty_print_expr_inner(l, buf, indent + 1)?;
339 writeln!(buf, "{},", spaces)?;
340 pretty_print_expr_inner(r, buf, indent + 1)?;
341 writeln!(buf, "{})", spaces)?;
342 }
343 TLExpr::Gt(l, r) => {
344 writeln!(buf, "{}GT(", spaces)?;
345 pretty_print_expr_inner(l, buf, indent + 1)?;
346 writeln!(buf, "{},", spaces)?;
347 pretty_print_expr_inner(r, buf, indent + 1)?;
348 writeln!(buf, "{})", spaces)?;
349 }
350 TLExpr::Lte(l, r) => {
351 writeln!(buf, "{}LTE(", spaces)?;
352 pretty_print_expr_inner(l, buf, indent + 1)?;
353 writeln!(buf, "{},", spaces)?;
354 pretty_print_expr_inner(r, buf, indent + 1)?;
355 writeln!(buf, "{})", spaces)?;
356 }
357 TLExpr::Gte(l, r) => {
358 writeln!(buf, "{}GTE(", spaces)?;
359 pretty_print_expr_inner(l, buf, indent + 1)?;
360 writeln!(buf, "{},", spaces)?;
361 pretty_print_expr_inner(r, buf, indent + 1)?;
362 writeln!(buf, "{})", spaces)?;
363 }
364 TLExpr::IfThenElse {
365 condition,
366 then_branch,
367 else_branch,
368 } => {
369 writeln!(buf, "{}IF(", spaces)?;
370 pretty_print_expr_inner(condition, buf, indent + 1)?;
371 writeln!(buf, "{}) THEN(", spaces)?;
372 pretty_print_expr_inner(then_branch, buf, indent + 1)?;
373 writeln!(buf, "{}) ELSE(", spaces)?;
374 pretty_print_expr_inner(else_branch, buf, indent + 1)?;
375 writeln!(buf, "{})", spaces)?;
376 }
377 TLExpr::Let { var, value, body } => {
378 writeln!(buf, "{}LET {} =(", spaces, var)?;
379 pretty_print_expr_inner(value, buf, indent + 1)?;
380 writeln!(buf, "{}) IN(", spaces)?;
381 pretty_print_expr_inner(body, buf, indent + 1)?;
382 writeln!(buf, "{})", spaces)?;
383 }
384 TLExpr::Constant(value) => {
385 writeln!(buf, "{}{}", spaces, value)?;
386 }
387 }
388
389 Ok(())
390}
391
392fn term_to_string(term: &Term) -> String {
393 match term {
394 Term::Var(name) => format!("?{}", name),
395 Term::Const(name) => name.clone(),
396 Term::Typed {
397 value,
398 type_annotation,
399 } => format!("{}:{}", term_to_string(value), type_annotation.type_name),
400 }
401}
402
403#[derive(Debug, Clone, PartialEq, Eq)]
405pub struct ExprStats {
406 pub node_count: usize,
408 pub max_depth: usize,
410 pub predicate_count: usize,
412 pub quantifier_count: usize,
414 pub logical_op_count: usize,
416 pub arithmetic_op_count: usize,
418 pub comparison_op_count: usize,
420 pub free_var_count: usize,
422}
423
424impl ExprStats {
425 pub fn compute(expr: &TLExpr) -> Self {
427 let mut stats = ExprStats {
428 node_count: 0,
429 max_depth: 0,
430 predicate_count: 0,
431 quantifier_count: 0,
432 logical_op_count: 0,
433 arithmetic_op_count: 0,
434 comparison_op_count: 0,
435 free_var_count: expr.free_vars().len(),
436 };
437
438 stats.max_depth = Self::compute_recursive(expr, &mut stats, 0);
439 stats
440 }
441
442 fn compute_recursive(expr: &TLExpr, stats: &mut ExprStats, depth: usize) -> usize {
443 stats.node_count += 1;
444 let mut max_child_depth = depth;
445
446 match expr {
447 TLExpr::Pred { .. } => {
448 stats.predicate_count += 1;
449 }
450 TLExpr::And(l, r) | TLExpr::Or(l, r) | TLExpr::Imply(l, r) => {
451 stats.logical_op_count += 1;
452 let left_depth = Self::compute_recursive(l, stats, depth + 1);
453 let right_depth = Self::compute_recursive(r, stats, depth + 1);
454 max_child_depth = left_depth.max(right_depth);
455 }
456 TLExpr::Not(e) | TLExpr::Score(e) => {
457 stats.logical_op_count += 1;
458 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
459 }
460 TLExpr::Exists { body, .. } | TLExpr::ForAll { body, .. } => {
461 stats.quantifier_count += 1;
462 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
463 }
464 TLExpr::Aggregate { body, .. } => {
465 stats.quantifier_count += 1; max_child_depth = Self::compute_recursive(body, stats, depth + 1);
467 }
468 TLExpr::Add(l, r)
469 | TLExpr::Sub(l, r)
470 | TLExpr::Mul(l, r)
471 | TLExpr::Div(l, r)
472 | TLExpr::Pow(l, r)
473 | TLExpr::Mod(l, r)
474 | TLExpr::Min(l, r)
475 | TLExpr::Max(l, r) => {
476 stats.arithmetic_op_count += 1;
477 let left_depth = Self::compute_recursive(l, stats, depth + 1);
478 let right_depth = Self::compute_recursive(r, stats, depth + 1);
479 max_child_depth = left_depth.max(right_depth);
480 }
481 TLExpr::Abs(e)
482 | TLExpr::Floor(e)
483 | TLExpr::Ceil(e)
484 | TLExpr::Round(e)
485 | TLExpr::Sqrt(e)
486 | TLExpr::Exp(e)
487 | TLExpr::Log(e)
488 | TLExpr::Sin(e)
489 | TLExpr::Cos(e)
490 | TLExpr::Tan(e)
491 | TLExpr::Box(e)
492 | TLExpr::Diamond(e)
493 | TLExpr::Next(e)
494 | TLExpr::Eventually(e)
495 | TLExpr::Always(e) => {
496 stats.arithmetic_op_count += 1;
497 max_child_depth = Self::compute_recursive(e, stats, depth + 1);
498 }
499 TLExpr::Until { before, after } => {
500 stats.logical_op_count += 1;
501 let depth_before = Self::compute_recursive(before, stats, depth + 1);
502 let depth_after = Self::compute_recursive(after, stats, depth + 1);
503 max_child_depth = depth_before.max(depth_after);
504 }
505 TLExpr::Eq(l, r)
506 | TLExpr::Lt(l, r)
507 | TLExpr::Gt(l, r)
508 | TLExpr::Lte(l, r)
509 | TLExpr::Gte(l, r) => {
510 stats.comparison_op_count += 1;
511 let left_depth = Self::compute_recursive(l, stats, depth + 1);
512 let right_depth = Self::compute_recursive(r, stats, depth + 1);
513 max_child_depth = left_depth.max(right_depth);
514 }
515 TLExpr::IfThenElse {
516 condition,
517 then_branch,
518 else_branch,
519 } => {
520 let cond_depth = Self::compute_recursive(condition, stats, depth + 1);
521 let then_depth = Self::compute_recursive(then_branch, stats, depth + 1);
522 let else_depth = Self::compute_recursive(else_branch, stats, depth + 1);
523 max_child_depth = cond_depth.max(then_depth).max(else_depth);
524 }
525 TLExpr::Let { value, body, .. } => {
526 let value_depth = Self::compute_recursive(value, stats, depth + 1);
527 let body_depth = Self::compute_recursive(body, stats, depth + 1);
528 max_child_depth = value_depth.max(body_depth);
529 }
530
531 TLExpr::TNorm { left, right, .. } | TLExpr::TCoNorm { left, right, .. } => {
533 stats.logical_op_count += 1;
534 let left_depth = Self::compute_recursive(left, stats, depth + 1);
535 let right_depth = Self::compute_recursive(right, stats, depth + 1);
536 max_child_depth = left_depth.max(right_depth);
537 }
538 TLExpr::FuzzyNot { expr, .. } => {
539 stats.logical_op_count += 1;
540 max_child_depth = Self::compute_recursive(expr, stats, depth + 1);
541 }
542 TLExpr::FuzzyImplication {
543 premise,
544 conclusion,
545 ..
546 } => {
547 stats.logical_op_count += 1;
548 let prem_depth = Self::compute_recursive(premise, stats, depth + 1);
549 let conc_depth = Self::compute_recursive(conclusion, stats, depth + 1);
550 max_child_depth = prem_depth.max(conc_depth);
551 }
552
553 TLExpr::SoftExists { body, .. } | TLExpr::SoftForAll { body, .. } => {
555 stats.quantifier_count += 1;
556 max_child_depth = Self::compute_recursive(body, stats, depth + 1);
557 }
558 TLExpr::WeightedRule { rule, .. } => {
559 stats.logical_op_count += 1;
560 max_child_depth = Self::compute_recursive(rule, stats, depth + 1);
561 }
562 TLExpr::ProbabilisticChoice { alternatives } => {
563 stats.logical_op_count += 1;
564 let mut max_alt_depth = depth;
565 for (_, expr) in alternatives {
566 let alt_depth = Self::compute_recursive(expr, stats, depth + 1);
567 max_alt_depth = max_alt_depth.max(alt_depth);
568 }
569 max_child_depth = max_alt_depth;
570 }
571
572 TLExpr::Release { released, releaser }
574 | TLExpr::WeakUntil {
575 before: released,
576 after: releaser,
577 }
578 | TLExpr::StrongRelease { released, releaser } => {
579 stats.logical_op_count += 1;
580 let rel_depth = Self::compute_recursive(released, stats, depth + 1);
581 let reler_depth = Self::compute_recursive(releaser, stats, depth + 1);
582 max_child_depth = rel_depth.max(reler_depth);
583 }
584
585 TLExpr::Constant(_) => {
586 }
588 }
589
590 max_child_depth
591 }
592}
593
594#[derive(Debug, Clone, PartialEq)]
596pub struct GraphStats {
597 pub tensor_count: usize,
599 pub node_count: usize,
601 pub output_count: usize,
603 pub einsum_count: usize,
605 pub elem_unary_count: usize,
607 pub elem_binary_count: usize,
609 pub reduce_count: usize,
611 pub avg_inputs_per_node: f64,
613}
614
615impl GraphStats {
616 pub fn compute(graph: &EinsumGraph) -> Self {
618 let mut stats = GraphStats {
619 tensor_count: graph.tensors.len(),
620 node_count: graph.nodes.len(),
621 output_count: graph.outputs.len(),
622 einsum_count: 0,
623 elem_unary_count: 0,
624 elem_binary_count: 0,
625 reduce_count: 0,
626 avg_inputs_per_node: 0.0,
627 };
628
629 let mut total_inputs = 0;
630
631 for node in &graph.nodes {
632 total_inputs += node.inputs.len();
633
634 match &node.op {
635 crate::graph::OpType::Einsum { .. } => stats.einsum_count += 1,
636 crate::graph::OpType::ElemUnary { .. } => stats.elem_unary_count += 1,
637 crate::graph::OpType::ElemBinary { .. } => stats.elem_binary_count += 1,
638 crate::graph::OpType::Reduce { .. } => stats.reduce_count += 1,
639 }
640 }
641
642 if stats.node_count > 0 {
643 stats.avg_inputs_per_node = total_inputs as f64 / stats.node_count as f64;
644 }
645
646 stats
647 }
648}
649
650pub fn pretty_print_graph(graph: &EinsumGraph) -> String {
652 let mut buffer = String::new();
653 writeln!(buffer, "EinsumGraph {{").unwrap();
654 writeln!(buffer, " Tensors: {} total", graph.tensors.len()).unwrap();
655
656 for (idx, name) in graph.tensors.iter().enumerate() {
657 writeln!(buffer, " t{}: {}", idx, name).unwrap();
658 }
659
660 writeln!(buffer, " Nodes: {} total", graph.nodes.len()).unwrap();
661 for (idx, node) in graph.nodes.iter().enumerate() {
662 write!(buffer, " n{}: ", idx).unwrap();
663 match &node.op {
664 crate::graph::OpType::Einsum { spec } => {
665 write!(buffer, "Einsum(\"{}\")", spec).unwrap()
666 }
667 crate::graph::OpType::ElemUnary { op } => write!(buffer, "ElemUnary({})", op).unwrap(),
668 crate::graph::OpType::ElemBinary { op } => {
669 write!(buffer, "ElemBinary({})", op).unwrap()
670 }
671 crate::graph::OpType::Reduce { op, axes } => {
672 write!(buffer, "Reduce({}, axes={:?})", op, axes).unwrap()
673 }
674 }
675 write!(buffer, " <- [").unwrap();
676 for (i, input) in node.inputs.iter().enumerate() {
677 if i > 0 {
678 write!(buffer, ", ").unwrap();
679 }
680 write!(buffer, "t{}", input).unwrap();
681 }
682 writeln!(buffer, "]").unwrap();
683 }
684
685 writeln!(buffer, " Outputs: {:?}", graph.outputs).unwrap();
686 writeln!(buffer, "}}").unwrap();
687
688 buffer
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone predicate is a single leaf with one free variable.
    #[test]
    fn test_expr_stats_simple() {
        let stats = ExprStats::compute(&TLExpr::pred("P", vec![Term::var("x")]));

        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.predicate_count, 1);
        assert_eq!(stats.quantifier_count, 0);
        assert_eq!(stats.free_var_count, 1);
    }

    /// `∀x:Domain. P(x) ∧ Q(x)`.
    #[test]
    fn test_expr_stats_complex() {
        let conjunction = TLExpr::and(
            TLExpr::pred("P", vec![Term::var("x")]),
            TLExpr::pred("Q", vec![Term::var("x")]),
        );
        let stats = ExprStats::compute(&TLExpr::forall("x", "Domain", conjunction));

        // ForAll + And + two predicates.
        assert_eq!(stats.node_count, 4);
        assert_eq!(stats.predicate_count, 2);
        assert_eq!(stats.quantifier_count, 1);
        assert_eq!(stats.logical_op_count, 1);
        // `x` is bound by the quantifier.
        assert_eq!(stats.free_var_count, 0);
    }

    /// `score(x) * 2.0 + 1.0`.
    #[test]
    fn test_expr_stats_arithmetic() {
        let scaled = TLExpr::mul(
            TLExpr::pred("score", vec![Term::var("x")]),
            TLExpr::constant(2.0),
        );
        let stats = ExprStats::compute(&TLExpr::add(scaled, TLExpr::constant(1.0)));

        // One MUL and one ADD.
        assert_eq!(stats.arithmetic_op_count, 2);
        assert_eq!(stats.predicate_count, 1);
    }

    #[test]
    fn test_graph_stats() {
        let mut graph = EinsumGraph::new();
        let input = graph.add_tensor("input");
        let output = graph.add_tensor("output");

        let identity = crate::EinsumNode {
            inputs: vec![input],
            outputs: vec![output],
            op: crate::graph::OpType::Einsum {
                spec: "i->i".to_string(),
            },
            metadata: None,
        };
        graph.add_node(identity).unwrap();
        graph.add_output(output).unwrap();

        let stats = GraphStats::compute(&graph);

        assert_eq!(stats.tensor_count, 2);
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.output_count, 1);
        assert_eq!(stats.einsum_count, 1);
        assert_eq!(stats.avg_inputs_per_node, 1.0);
    }

    #[test]
    fn test_pretty_print_expr() {
        let rendered = pretty_print_expr(&TLExpr::pred("Person", vec![Term::var("x")]));
        assert!(rendered.contains("Person(?x)"));
    }

    #[test]
    fn test_pretty_print_graph() {
        let mut graph = EinsumGraph::new();
        let _input = graph.add_tensor("input");

        let rendered = pretty_print_graph(&graph);
        assert!(rendered.contains("t0: input"));
        assert!(rendered.contains("Tensors: 1 total"));
    }
}