1use std::fmt;
6
7use crate::{
8 expr::{AggregateOp, TLExpr},
9 graph::{EinsumGraph, EinsumNode, OpType},
10 term::Term,
11};
12
13impl fmt::Display for Term {
14 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 Term::Var(name) => write!(f, "?{}", name),
17 Term::Const(name) => write!(f, "{}", name),
18 Term::Typed {
19 value,
20 type_annotation,
21 } => write!(f, "{}:{}", value, type_annotation.type_name),
22 }
23 }
24}
25
26impl fmt::Display for AggregateOp {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 AggregateOp::Count => write!(f, "COUNT"),
30 AggregateOp::Sum => write!(f, "SUM"),
31 AggregateOp::Average => write!(f, "AVG"),
32 AggregateOp::Max => write!(f, "MAX"),
33 AggregateOp::Min => write!(f, "MIN"),
34 AggregateOp::Product => write!(f, "PROD"),
35 AggregateOp::Any => write!(f, "ANY"),
36 AggregateOp::All => write!(f, "ALL"),
37 }
38 }
39}
40
41impl fmt::Display for TLExpr {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 TLExpr::Pred { name, args } => {
45 write!(f, "{}(", name)?;
46 for (i, arg) in args.iter().enumerate() {
47 if i > 0 {
48 write!(f, ", ")?;
49 }
50 write!(f, "{}", arg)?;
51 }
52 write!(f, ")")
53 }
54 TLExpr::And(l, r) => write!(f, "({} ∧ {})", l, r),
55 TLExpr::Or(l, r) => write!(f, "({} ∨ {})", l, r),
56 TLExpr::Not(e) => write!(f, "¬{}", e),
57 TLExpr::Exists { var, domain, body } => {
58 write!(f, "∃{}:{}. {}", var, domain, body)
59 }
60 TLExpr::ForAll { var, domain, body } => {
61 write!(f, "∀{}:{}. {}", var, domain, body)
62 }
63 TLExpr::Aggregate {
64 op,
65 var,
66 domain,
67 body,
68 group_by,
69 } => {
70 write!(f, "{}({}:{}, ", op, var, domain)?;
71 if let Some(group_vars) = group_by {
72 write!(f, "GROUP BY [")?;
73 for (i, gv) in group_vars.iter().enumerate() {
74 if i > 0 {
75 write!(f, ", ")?;
76 }
77 write!(f, "{}", gv)?;
78 }
79 write!(f, "], ")?;
80 }
81 write!(f, "{})", body)
82 }
83 TLExpr::Imply(premise, conclusion) => write!(f, "({} → {})", premise, conclusion),
84 TLExpr::Score(e) => write!(f, "score({})", e),
85 TLExpr::Add(l, r) => write!(f, "({} + {})", l, r),
86 TLExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
87 TLExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
88 TLExpr::Div(l, r) => write!(f, "({} / {})", l, r),
89 TLExpr::Pow(l, r) => write!(f, "({} ^ {})", l, r),
90 TLExpr::Mod(l, r) => write!(f, "({} % {})", l, r),
91 TLExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
92 TLExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
93 TLExpr::Abs(e) => write!(f, "abs({})", e),
94 TLExpr::Floor(e) => write!(f, "floor({})", e),
95 TLExpr::Ceil(e) => write!(f, "ceil({})", e),
96 TLExpr::Round(e) => write!(f, "round({})", e),
97 TLExpr::Sqrt(e) => write!(f, "sqrt({})", e),
98 TLExpr::Exp(e) => write!(f, "exp({})", e),
99 TLExpr::Log(e) => write!(f, "log({})", e),
100 TLExpr::Sin(e) => write!(f, "sin({})", e),
101 TLExpr::Cos(e) => write!(f, "cos({})", e),
102 TLExpr::Tan(e) => write!(f, "tan({})", e),
103 TLExpr::Eq(l, r) => write!(f, "({} = {})", l, r),
104 TLExpr::Lt(l, r) => write!(f, "({} < {})", l, r),
105 TLExpr::Gt(l, r) => write!(f, "({} > {})", l, r),
106 TLExpr::Lte(l, r) => write!(f, "({} ≤ {})", l, r),
107 TLExpr::Gte(l, r) => write!(f, "({} ≥ {})", l, r),
108 TLExpr::IfThenElse {
109 condition,
110 then_branch,
111 else_branch,
112 } => write!(
113 f,
114 "if {} then {} else {}",
115 condition, then_branch, else_branch
116 ),
117 TLExpr::Let { var, value, body } => {
118 write!(f, "let {} = {} in {}", var, value, body)
119 }
120 TLExpr::Box(e) => write!(f, "□{}", e),
121 TLExpr::Diamond(e) => write!(f, "◇{}", e),
122 TLExpr::Next(e) => write!(f, "X{}", e),
123 TLExpr::Eventually(e) => write!(f, "F{}", e),
124 TLExpr::Always(e) => write!(f, "G{}", e),
125 TLExpr::Until { before, after } => write!(f, "({} U {})", before, after),
126 TLExpr::TNorm { kind, left, right } => {
128 write!(f, "({} ⊤_{:?} {})", left, kind, right)
129 }
130 TLExpr::TCoNorm { kind, left, right } => {
131 write!(f, "({} ⊥_{:?} {})", left, kind, right)
132 }
133 TLExpr::FuzzyNot { kind, expr } => write!(f, "¬_{:?}({})", kind, expr),
134 TLExpr::FuzzyImplication {
135 kind,
136 premise,
137 conclusion,
138 } => write!(f, "({} →_{:?} {})", premise, kind, conclusion),
139 TLExpr::SoftExists {
141 var,
142 domain,
143 body,
144 temperature,
145 } => write!(f, "∃^{{{}}}{}:{}. {}", temperature, var, domain, body),
146 TLExpr::SoftForAll {
147 var,
148 domain,
149 body,
150 temperature,
151 } => write!(f, "∀^{{{}}}{}:{}. {}", temperature, var, domain, body),
152 TLExpr::WeightedRule { weight, rule } => write!(f, "{}::{}", weight, rule),
153 TLExpr::ProbabilisticChoice { alternatives } => {
154 write!(f, "choice[")?;
155 for (i, (prob, expr)) in alternatives.iter().enumerate() {
156 if i > 0 {
157 write!(f, ", ")?;
158 }
159 write!(f, "{}: {}", prob, expr)?;
160 }
161 write!(f, "]")
162 }
163 TLExpr::Release { released, releaser } => write!(f, "({} R {})", released, releaser),
165 TLExpr::WeakUntil { before, after } => write!(f, "({} W {})", before, after),
166 TLExpr::StrongRelease { released, releaser } => {
167 write!(f, "({} M {})", released, releaser)
168 }
169 TLExpr::Lambda {
171 var,
172 var_type,
173 body,
174 } => {
175 if let Some(ty) = var_type {
176 write!(f, "λ{}:{}. {}", var, ty, body)
177 } else {
178 write!(f, "λ{}. {}", var, body)
179 }
180 }
181 TLExpr::Apply { function, argument } => write!(f, "({} {})", function, argument),
182 TLExpr::SetMembership { element, set } => write!(f, "({} ∈ {})", element, set),
183 TLExpr::SetUnion { left, right } => write!(f, "({} ∪ {})", left, right),
184 TLExpr::SetIntersection { left, right } => write!(f, "({} ∩ {})", left, right),
185 TLExpr::SetDifference { left, right } => write!(f, "({} \\ {})", left, right),
186 TLExpr::SetCardinality { set } => write!(f, "|{}|", set),
187 TLExpr::EmptySet => write!(f, "∅"),
188 TLExpr::SetComprehension {
189 var,
190 domain,
191 condition,
192 } => write!(f, "{{ {}:{} | {} }}", var, domain, condition),
193 TLExpr::CountingExists {
194 var,
195 domain,
196 body,
197 min_count,
198 } => write!(f, "∃≥{}{}:{}. {}", min_count, var, domain, body),
199 TLExpr::CountingForAll {
200 var,
201 domain,
202 body,
203 min_count,
204 } => write!(f, "∀≥{}{}:{}. {}", min_count, var, domain, body),
205 TLExpr::ExactCount {
206 var,
207 domain,
208 body,
209 count,
210 } => write!(f, "∃={}{}:{}. {}", count, var, domain, body),
211 TLExpr::Majority { var, domain, body } => {
212 write!(f, "Majority {}:{}. {}", var, domain, body)
213 }
214 TLExpr::LeastFixpoint { var, body } => write!(f, "μ{}. {}", var, body),
215 TLExpr::GreatestFixpoint { var, body } => write!(f, "ν{}. {}", var, body),
216 TLExpr::Nominal { name } => write!(f, "@{}", name),
217 TLExpr::At { nominal, formula } => write!(f, "@{} {}", nominal, formula),
218 TLExpr::Somewhere { formula } => write!(f, "E {}", formula),
219 TLExpr::Everywhere { formula } => write!(f, "A {}", formula),
220 TLExpr::AllDifferent { variables } => {
221 write!(f, "alldiff([")?;
222 for (i, var) in variables.iter().enumerate() {
223 if i > 0 {
224 write!(f, ", ")?;
225 }
226 write!(f, "{}", var)?;
227 }
228 write!(f, "])")
229 }
230 TLExpr::GlobalCardinality {
231 variables,
232 values,
233 min_occurrences,
234 max_occurrences,
235 } => {
236 write!(f, "gcc([")?;
237 for (i, var) in variables.iter().enumerate() {
238 if i > 0 {
239 write!(f, ", ")?;
240 }
241 write!(f, "{}", var)?;
242 }
243 write!(f, "], [")?;
244 for (i, val) in values.iter().enumerate() {
245 if i > 0 {
246 write!(f, ", ")?;
247 }
248 write!(f, "{}:[{},{}]", val, min_occurrences[i], max_occurrences[i])?;
249 }
250 write!(f, "])")
251 }
252 TLExpr::Abducible { name, cost } => write!(f, "abd({}:{})", name, cost),
253 TLExpr::Explain { formula } => write!(f, "explain({})", formula),
254 TLExpr::Constant(value) => write!(f, "{}", value),
255 }
256 }
257}
258
259impl fmt::Display for OpType {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 match self {
262 OpType::Einsum { spec } => write!(f, "einsum({})", spec),
263 OpType::ElemUnary { op } => write!(f, "{}(·)", op),
264 OpType::ElemBinary { op } => write!(f, "{}(·, ·)", op),
265 OpType::Reduce { op, axes } => write!(f, "{}(·, axes={:?})", op, axes),
266 }
267 }
268}
269
270impl fmt::Display for EinsumNode {
271 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272 write!(f, "{} ", self.op)?;
273 write!(f, "inputs={:?}", self.inputs)?;
274 write!(f, " outputs={:?}", self.outputs)
275 }
276}
277
278impl fmt::Display for EinsumGraph {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 writeln!(f, "EinsumGraph {{")?;
281 writeln!(f, " tensors: {:?}", self.tensors)?;
282 writeln!(f, " nodes: [")?;
283 for (i, node) in self.nodes.iter().enumerate() {
284 writeln!(f, " {}: {}", i, node)?;
285 }
286 writeln!(f, " ]")?;
287 writeln!(f, " outputs: {:?}", self.outputs)?;
288 write!(f, "}}")
289 }
290}
291
#[cfg(test)]
mod tests {
    //! Display-format tests. Where the rendering is fully determined by the
    //! `Display` impls in this file, the exact string is asserted rather than
    //! just fragments of it.

    use super::*;

    #[test]
    fn test_display_term() {
        let var = Term::var("x");
        assert_eq!(format!("{}", var), "?x");

        let const_term = Term::constant("alice");
        assert_eq!(format!("{}", const_term), "alice");

        let typed = Term::typed_var("x", "Int");
        assert_eq!(format!("{}", typed), "?x:Int");
    }

    #[test]
    fn test_display_aggregate_op() {
        // Cover every variant so a renamed keyword is caught.
        let cases = [
            (AggregateOp::Count, "COUNT"),
            (AggregateOp::Sum, "SUM"),
            (AggregateOp::Average, "AVG"),
            (AggregateOp::Max, "MAX"),
            (AggregateOp::Min, "MIN"),
            (AggregateOp::Product, "PROD"),
            (AggregateOp::Any, "ANY"),
            (AggregateOp::All, "ALL"),
        ];
        for (op, expected) in cases {
            assert_eq!(format!("{}", op), expected);
        }
    }

    #[test]
    fn test_display_simple_expr() {
        let pred = TLExpr::pred("Person", vec![Term::var("x")]);
        assert_eq!(format!("{}", pred), "Person(?x)");
    }

    #[test]
    fn test_display_logical_ops() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let and_expr = TLExpr::and(p.clone(), q.clone());
        assert_eq!(format!("{}", and_expr), "(P(?x) ∧ Q(?y))");

        let or_expr = TLExpr::or(p.clone(), q);
        assert_eq!(format!("{}", or_expr), "(P(?x) ∨ Q(?y))");

        let not_expr = TLExpr::negate(p);
        assert_eq!(format!("{}", not_expr), "¬P(?x)");
    }

    #[test]
    fn test_display_quantifiers() {
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let exists = TLExpr::exists("x", "Domain", body.clone());
        assert_eq!(format!("{}", exists), "∃x:Domain. P(?x)");

        let forall = TLExpr::forall("x", "Domain", body);
        assert_eq!(format!("{}", forall), "∀x:Domain. P(?x)");
    }

    #[test]
    fn test_display_aggregate() {
        let body = TLExpr::pred("Value", vec![Term::var("x")]);

        let sum = TLExpr::sum("x", "Domain", body.clone());
        assert_eq!(format!("{}", sum), "SUM(x:Domain, Value(?x))");

        let count = TLExpr::count("x", "Domain", body);
        assert_eq!(format!("{}", count), "COUNT(x:Domain, Value(?x))");
    }

    #[test]
    fn test_display_aggregate_with_group_by() {
        let body = TLExpr::pred("Value", vec![Term::var("x"), Term::var("y")]);

        let agg = TLExpr::aggregate_with_group_by(
            AggregateOp::Sum,
            "x",
            "Domain",
            body,
            vec!["y".to_string()],
        );

        // The grouping clause sits between the binder and the body.
        assert_eq!(
            format!("{}", agg),
            "SUM(x:Domain, GROUP BY [y], Value(?x, ?y))"
        );
    }

    #[test]
    fn test_display_arithmetic() {
        let x = TLExpr::constant(5.0);
        let y = TLExpr::constant(3.0);

        let add = TLExpr::add(x.clone(), y.clone());
        assert_eq!(format!("{}", add), "(5 + 3)");

        let mul = TLExpr::mul(x, y);
        assert_eq!(format!("{}", mul), "(5 * 3)");
    }

    #[test]
    fn test_display_comparison() {
        let x = TLExpr::pred("X", vec![Term::var("i")]);
        let threshold = TLExpr::constant(0.5);

        let gt = TLExpr::gt(x, threshold);
        assert_eq!(format!("{}", gt), "(X(?i) > 0.5)");
    }

    #[test]
    fn test_display_conditional() {
        let cond = TLExpr::pred("IsAdult", vec![Term::var("x")]);
        let then_br = TLExpr::constant(1.0);
        let else_br = TLExpr::constant(0.0);

        let if_expr = TLExpr::if_then_else(cond, then_br, else_br);
        assert_eq!(format!("{}", if_expr), "if IsAdult(?x) then 1 else 0");
    }

    #[test]
    fn test_display_einsum_node() {
        let node = EinsumNode::new("ij,jk->ik", vec![0, 1], vec![2]);
        assert_eq!(
            format!("{}", node),
            "einsum(ij,jk->ik) inputs=[0, 1] outputs=[2]"
        );
    }

    #[test]
    fn test_display_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::new("i->i", vec![t0], vec![t1]))
            .unwrap();
        graph.add_output(t1).unwrap();

        let display = format!("{}", graph);
        assert!(display.starts_with("EinsumGraph {"));
        assert!(display.ends_with('}'));
        assert!(display.contains("tensors"));
        assert!(display.contains("input"));
        assert!(display.contains("output"));
        // Each node is rendered through `EinsumNode`'s `Display` impl.
        assert!(display.contains("einsum(i->i)"));
    }
}