1use std::fmt;
6
7use crate::{
8 expr::{AggregateOp, TLExpr},
9 graph::{EinsumGraph, EinsumNode, OpType},
10 term::Term,
11};
12
13impl fmt::Display for Term {
14 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 Term::Var(name) => write!(f, "?{}", name),
17 Term::Const(name) => write!(f, "{}", name),
18 Term::Typed {
19 value,
20 type_annotation,
21 } => write!(f, "{}:{}", value, type_annotation.type_name),
22 }
23 }
24}
25
26impl fmt::Display for AggregateOp {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 AggregateOp::Count => write!(f, "COUNT"),
30 AggregateOp::Sum => write!(f, "SUM"),
31 AggregateOp::Average => write!(f, "AVG"),
32 AggregateOp::Max => write!(f, "MAX"),
33 AggregateOp::Min => write!(f, "MIN"),
34 AggregateOp::Product => write!(f, "PROD"),
35 AggregateOp::Any => write!(f, "ANY"),
36 AggregateOp::All => write!(f, "ALL"),
37 }
38 }
39}
40
41impl fmt::Display for TLExpr {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 TLExpr::Pred { name, args } => {
45 write!(f, "{}(", name)?;
46 for (i, arg) in args.iter().enumerate() {
47 if i > 0 {
48 write!(f, ", ")?;
49 }
50 write!(f, "{}", arg)?;
51 }
52 write!(f, ")")
53 }
54 TLExpr::And(l, r) => write!(f, "({} ∧ {})", l, r),
55 TLExpr::Or(l, r) => write!(f, "({} ∨ {})", l, r),
56 TLExpr::Not(e) => write!(f, "¬{}", e),
57 TLExpr::Exists { var, domain, body } => {
58 write!(f, "∃{}:{}. {}", var, domain, body)
59 }
60 TLExpr::ForAll { var, domain, body } => {
61 write!(f, "∀{}:{}. {}", var, domain, body)
62 }
63 TLExpr::Aggregate {
64 op,
65 var,
66 domain,
67 body,
68 group_by,
69 } => {
70 write!(f, "{}({}:{}, ", op, var, domain)?;
71 if let Some(group_vars) = group_by {
72 write!(f, "GROUP BY [")?;
73 for (i, gv) in group_vars.iter().enumerate() {
74 if i > 0 {
75 write!(f, ", ")?;
76 }
77 write!(f, "{}", gv)?;
78 }
79 write!(f, "], ")?;
80 }
81 write!(f, "{})", body)
82 }
83 TLExpr::Imply(premise, conclusion) => write!(f, "({} → {})", premise, conclusion),
84 TLExpr::Score(e) => write!(f, "score({})", e),
85 TLExpr::Add(l, r) => write!(f, "({} + {})", l, r),
86 TLExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
87 TLExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
88 TLExpr::Div(l, r) => write!(f, "({} / {})", l, r),
89 TLExpr::Pow(l, r) => write!(f, "({} ^ {})", l, r),
90 TLExpr::Mod(l, r) => write!(f, "({} % {})", l, r),
91 TLExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
92 TLExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
93 TLExpr::Abs(e) => write!(f, "abs({})", e),
94 TLExpr::Floor(e) => write!(f, "floor({})", e),
95 TLExpr::Ceil(e) => write!(f, "ceil({})", e),
96 TLExpr::Round(e) => write!(f, "round({})", e),
97 TLExpr::Sqrt(e) => write!(f, "sqrt({})", e),
98 TLExpr::Exp(e) => write!(f, "exp({})", e),
99 TLExpr::Log(e) => write!(f, "log({})", e),
100 TLExpr::Sin(e) => write!(f, "sin({})", e),
101 TLExpr::Cos(e) => write!(f, "cos({})", e),
102 TLExpr::Tan(e) => write!(f, "tan({})", e),
103 TLExpr::Eq(l, r) => write!(f, "({} = {})", l, r),
104 TLExpr::Lt(l, r) => write!(f, "({} < {})", l, r),
105 TLExpr::Gt(l, r) => write!(f, "({} > {})", l, r),
106 TLExpr::Lte(l, r) => write!(f, "({} ≤ {})", l, r),
107 TLExpr::Gte(l, r) => write!(f, "({} ≥ {})", l, r),
108 TLExpr::IfThenElse {
109 condition,
110 then_branch,
111 else_branch,
112 } => write!(
113 f,
114 "if {} then {} else {}",
115 condition, then_branch, else_branch
116 ),
117 TLExpr::Let { var, value, body } => {
118 write!(f, "let {} = {} in {}", var, value, body)
119 }
120 TLExpr::Box(e) => write!(f, "□{}", e),
121 TLExpr::Diamond(e) => write!(f, "◇{}", e),
122 TLExpr::Next(e) => write!(f, "X{}", e),
123 TLExpr::Eventually(e) => write!(f, "F{}", e),
124 TLExpr::Always(e) => write!(f, "G{}", e),
125 TLExpr::Until { before, after } => write!(f, "({} U {})", before, after),
126 TLExpr::TNorm { kind, left, right } => {
128 write!(f, "({} ⊤_{:?} {})", left, kind, right)
129 }
130 TLExpr::TCoNorm { kind, left, right } => {
131 write!(f, "({} ⊥_{:?} {})", left, kind, right)
132 }
133 TLExpr::FuzzyNot { kind, expr } => write!(f, "¬_{:?}({})", kind, expr),
134 TLExpr::FuzzyImplication {
135 kind,
136 premise,
137 conclusion,
138 } => write!(f, "({} →_{:?} {})", premise, kind, conclusion),
139 TLExpr::SoftExists {
141 var,
142 domain,
143 body,
144 temperature,
145 } => write!(f, "∃^{{{}}}{}:{}. {}", temperature, var, domain, body),
146 TLExpr::SoftForAll {
147 var,
148 domain,
149 body,
150 temperature,
151 } => write!(f, "∀^{{{}}}{}:{}. {}", temperature, var, domain, body),
152 TLExpr::WeightedRule { weight, rule } => write!(f, "{}::{}", weight, rule),
153 TLExpr::ProbabilisticChoice { alternatives } => {
154 write!(f, "choice[")?;
155 for (i, (prob, expr)) in alternatives.iter().enumerate() {
156 if i > 0 {
157 write!(f, ", ")?;
158 }
159 write!(f, "{}: {}", prob, expr)?;
160 }
161 write!(f, "]")
162 }
163 TLExpr::Release { released, releaser } => write!(f, "({} R {})", released, releaser),
165 TLExpr::WeakUntil { before, after } => write!(f, "({} W {})", before, after),
166 TLExpr::StrongRelease { released, releaser } => {
167 write!(f, "({} M {})", released, releaser)
168 }
169 TLExpr::Lambda {
171 var,
172 var_type,
173 body,
174 } => {
175 if let Some(ty) = var_type {
176 write!(f, "λ{}:{}. {}", var, ty, body)
177 } else {
178 write!(f, "λ{}. {}", var, body)
179 }
180 }
181 TLExpr::Apply { function, argument } => write!(f, "({} {})", function, argument),
182 TLExpr::SetMembership { element, set } => write!(f, "({} ∈ {})", element, set),
183 TLExpr::SetUnion { left, right } => write!(f, "({} ∪ {})", left, right),
184 TLExpr::SetIntersection { left, right } => write!(f, "({} ∩ {})", left, right),
185 TLExpr::SetDifference { left, right } => write!(f, "({} \\ {})", left, right),
186 TLExpr::SetCardinality { set } => write!(f, "|{}|", set),
187 TLExpr::EmptySet => write!(f, "∅"),
188 TLExpr::SetComprehension {
189 var,
190 domain,
191 condition,
192 } => write!(f, "{{ {}:{} | {} }}", var, domain, condition),
193 TLExpr::CountingExists {
194 var,
195 domain,
196 body,
197 min_count,
198 } => write!(f, "∃≥{}{}:{}. {}", min_count, var, domain, body),
199 TLExpr::CountingForAll {
200 var,
201 domain,
202 body,
203 min_count,
204 } => write!(f, "∀≥{}{}:{}. {}", min_count, var, domain, body),
205 TLExpr::ExactCount {
206 var,
207 domain,
208 body,
209 count,
210 } => write!(f, "∃={}{}:{}. {}", count, var, domain, body),
211 TLExpr::Majority { var, domain, body } => {
212 write!(f, "Majority {}:{}. {}", var, domain, body)
213 }
214 TLExpr::LeastFixpoint { var, body } => write!(f, "μ{}. {}", var, body),
215 TLExpr::GreatestFixpoint { var, body } => write!(f, "ν{}. {}", var, body),
216 TLExpr::Nominal { name } => write!(f, "@{}", name),
217 TLExpr::At { nominal, formula } => write!(f, "@{} {}", nominal, formula),
218 TLExpr::Somewhere { formula } => write!(f, "E {}", formula),
219 TLExpr::Everywhere { formula } => write!(f, "A {}", formula),
220 TLExpr::AllDifferent { variables } => {
221 write!(f, "alldiff([")?;
222 for (i, var) in variables.iter().enumerate() {
223 if i > 0 {
224 write!(f, ", ")?;
225 }
226 write!(f, "{}", var)?;
227 }
228 write!(f, "])")
229 }
230 TLExpr::GlobalCardinality {
231 variables,
232 values,
233 min_occurrences,
234 max_occurrences,
235 } => {
236 write!(f, "gcc([")?;
237 for (i, var) in variables.iter().enumerate() {
238 if i > 0 {
239 write!(f, ", ")?;
240 }
241 write!(f, "{}", var)?;
242 }
243 write!(f, "], [")?;
244 for (i, val) in values.iter().enumerate() {
245 if i > 0 {
246 write!(f, ", ")?;
247 }
248 write!(f, "{}:[{},{}]", val, min_occurrences[i], max_occurrences[i])?;
249 }
250 write!(f, "])")
251 }
252 TLExpr::Abducible { name, cost } => write!(f, "abd({}:{})", name, cost),
253 TLExpr::Explain { formula } => write!(f, "explain({})", formula),
254 TLExpr::Constant(value) => write!(f, "{}", value),
255 TLExpr::SymbolLiteral(s) => write!(f, ":{s}"),
256 TLExpr::Match { scrutinee, arms } => {
257 write!(f, "(match {scrutinee}")?;
258 for (pat, body) in arms {
259 write!(f, " [{pat} => {body}]")?;
260 }
261 write!(f, ")")
262 }
263 }
264 }
265}
266
267impl fmt::Display for OpType {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 match self {
270 OpType::Einsum { spec } => write!(f, "einsum({})", spec),
271 OpType::ElemUnary { op } => write!(f, "{}(·)", op),
272 OpType::ElemBinary { op } => write!(f, "{}(·, ·)", op),
273 OpType::Reduce { op, axes } => write!(f, "{}(·, axes={:?})", op, axes),
274 }
275 }
276}
277
278impl fmt::Display for EinsumNode {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 write!(f, "{} ", self.op)?;
281 write!(f, "inputs={:?}", self.inputs)?;
282 write!(f, " outputs={:?}", self.outputs)
283 }
284}
285
286impl fmt::Display for EinsumGraph {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 writeln!(f, "EinsumGraph {{")?;
289 writeln!(f, " tensors: {:?}", self.tensors)?;
290 writeln!(f, " nodes: [")?;
291 for (i, node) in self.nodes.iter().enumerate() {
292 writeln!(f, " {}: {}", i, node)?;
293 }
294 writeln!(f, " ]")?;
295 writeln!(f, " outputs: {:?}", self.outputs)?;
296 write!(f, "}}")
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display_term() {
        assert_eq!(Term::var("x").to_string(), "?x");
        assert_eq!(Term::constant("alice").to_string(), "alice");
        assert_eq!(Term::typed_var("x", "Int").to_string(), "?x:Int");
    }

    #[test]
    fn test_display_aggregate_op() {
        // Cover every variant so a changed keyword is caught.
        let cases = [
            (AggregateOp::Count, "COUNT"),
            (AggregateOp::Sum, "SUM"),
            (AggregateOp::Average, "AVG"),
            (AggregateOp::Max, "MAX"),
            (AggregateOp::Min, "MIN"),
            (AggregateOp::Product, "PROD"),
            (AggregateOp::Any, "ANY"),
            (AggregateOp::All, "ALL"),
        ];
        for (op, expected) in cases {
            assert_eq!(op.to_string(), expected);
        }
    }

    #[test]
    fn test_display_simple_expr() {
        let pred = TLExpr::pred("Person", vec![Term::var("x")]);
        assert_eq!(pred.to_string(), "Person(?x)");

        // Multiple arguments are separated by ", ".
        let binary = TLExpr::pred("Knows", vec![Term::var("x"), Term::constant("bob")]);
        assert_eq!(binary.to_string(), "Knows(?x, bob)");
    }

    #[test]
    fn test_display_logical_ops() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let and_expr = TLExpr::and(p.clone(), q.clone());
        assert_eq!(and_expr.to_string(), "(P(?x) ∧ Q(?y))");

        let or_expr = TLExpr::or(p.clone(), q);
        assert_eq!(or_expr.to_string(), "(P(?x) ∨ Q(?y))");

        let not_expr = TLExpr::negate(p);
        assert_eq!(not_expr.to_string(), "¬P(?x)");
    }

    #[test]
    fn test_display_quantifiers() {
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let exists = TLExpr::exists("x", "Domain", body.clone());
        assert_eq!(exists.to_string(), "∃x:Domain. P(?x)");

        let forall = TLExpr::forall("x", "Domain", body);
        assert_eq!(forall.to_string(), "∀x:Domain. P(?x)");
    }

    #[test]
    fn test_display_aggregate() {
        let body = TLExpr::pred("Value", vec![Term::var("x")]);

        let sum = TLExpr::sum("x", "Domain", body.clone());
        assert_eq!(sum.to_string(), "SUM(x:Domain, Value(?x))");

        let count = TLExpr::count("x", "Domain", body);
        assert_eq!(count.to_string(), "COUNT(x:Domain, Value(?x))");
    }

    #[test]
    fn test_display_aggregate_with_group_by() {
        let body = TLExpr::pred("Value", vec![Term::var("x"), Term::var("y")]);

        let agg = TLExpr::aggregate_with_group_by(
            AggregateOp::Sum,
            "x",
            "Domain",
            body,
            vec!["y".to_string()],
        );

        // Pin the full rendering so the clause order (binder, GROUP BY, body)
        // is checked, not just the presence of keywords.
        assert_eq!(
            agg.to_string(),
            "SUM(x:Domain, GROUP BY [y], Value(?x, ?y))"
        );
    }

    #[test]
    fn test_display_arithmetic() {
        let x = TLExpr::constant(5.0);
        let y = TLExpr::constant(3.0);

        let add = TLExpr::add(x.clone(), y.clone());
        assert_eq!(add.to_string(), "(5 + 3)");

        let mul = TLExpr::mul(x, y);
        assert_eq!(mul.to_string(), "(5 * 3)");
    }

    #[test]
    fn test_display_comparison() {
        let x = TLExpr::pred("X", vec![Term::var("i")]);
        let threshold = TLExpr::constant(0.5);

        let gt = TLExpr::gt(x, threshold);
        assert_eq!(gt.to_string(), "(X(?i) > 0.5)");
    }

    #[test]
    fn test_display_conditional() {
        let cond = TLExpr::pred("IsAdult", vec![Term::var("x")]);
        let then_br = TLExpr::constant(1.0);
        let else_br = TLExpr::constant(0.0);

        let if_expr = TLExpr::if_then_else(cond, then_br, else_br);
        assert_eq!(if_expr.to_string(), "if IsAdult(?x) then 1 else 0");
    }

    #[test]
    fn test_display_einsum_node() {
        let node = EinsumNode::new("ij,jk->ik", vec![0, 1], vec![2]);
        assert_eq!(
            node.to_string(),
            "einsum(ij,jk->ik) inputs=[0, 1] outputs=[2]"
        );
    }

    #[test]
    fn test_display_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::new("i->i", vec![t0], vec![t1]))
            .expect("node references tensors that were just added");
        graph
            .add_output(t1)
            .expect("output references a tensor that was just added");

        let display = graph.to_string();
        assert!(display.contains("EinsumGraph"));
        assert!(display.contains("tensors"));
        assert!(display.contains("input"));
        assert!(display.contains("output"));
    }
}