1use std::fmt;
6
7use crate::{
8 expr::{AggregateOp, TLExpr},
9 graph::{EinsumGraph, EinsumNode, OpType},
10 term::Term,
11};
12
13impl fmt::Display for Term {
14 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 Term::Var(name) => write!(f, "?{}", name),
17 Term::Const(name) => write!(f, "{}", name),
18 Term::Typed {
19 value,
20 type_annotation,
21 } => write!(f, "{}:{}", value, type_annotation.type_name),
22 }
23 }
24}
25
26impl fmt::Display for AggregateOp {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 AggregateOp::Count => write!(f, "COUNT"),
30 AggregateOp::Sum => write!(f, "SUM"),
31 AggregateOp::Average => write!(f, "AVG"),
32 AggregateOp::Max => write!(f, "MAX"),
33 AggregateOp::Min => write!(f, "MIN"),
34 AggregateOp::Product => write!(f, "PROD"),
35 AggregateOp::Any => write!(f, "ANY"),
36 AggregateOp::All => write!(f, "ALL"),
37 }
38 }
39}
40
41impl fmt::Display for TLExpr {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 TLExpr::Pred { name, args } => {
45 write!(f, "{}(", name)?;
46 for (i, arg) in args.iter().enumerate() {
47 if i > 0 {
48 write!(f, ", ")?;
49 }
50 write!(f, "{}", arg)?;
51 }
52 write!(f, ")")
53 }
54 TLExpr::And(l, r) => write!(f, "({} ∧ {})", l, r),
55 TLExpr::Or(l, r) => write!(f, "({} ∨ {})", l, r),
56 TLExpr::Not(e) => write!(f, "¬{}", e),
57 TLExpr::Exists { var, domain, body } => {
58 write!(f, "∃{}:{}. {}", var, domain, body)
59 }
60 TLExpr::ForAll { var, domain, body } => {
61 write!(f, "∀{}:{}. {}", var, domain, body)
62 }
63 TLExpr::Aggregate {
64 op,
65 var,
66 domain,
67 body,
68 group_by,
69 } => {
70 write!(f, "{}({}:{}, ", op, var, domain)?;
71 if let Some(group_vars) = group_by {
72 write!(f, "GROUP BY [")?;
73 for (i, gv) in group_vars.iter().enumerate() {
74 if i > 0 {
75 write!(f, ", ")?;
76 }
77 write!(f, "{}", gv)?;
78 }
79 write!(f, "], ")?;
80 }
81 write!(f, "{})", body)
82 }
83 TLExpr::Imply(premise, conclusion) => write!(f, "({} → {})", premise, conclusion),
84 TLExpr::Score(e) => write!(f, "score({})", e),
85 TLExpr::Add(l, r) => write!(f, "({} + {})", l, r),
86 TLExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
87 TLExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
88 TLExpr::Div(l, r) => write!(f, "({} / {})", l, r),
89 TLExpr::Pow(l, r) => write!(f, "({} ^ {})", l, r),
90 TLExpr::Mod(l, r) => write!(f, "({} % {})", l, r),
91 TLExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
92 TLExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
93 TLExpr::Abs(e) => write!(f, "abs({})", e),
94 TLExpr::Floor(e) => write!(f, "floor({})", e),
95 TLExpr::Ceil(e) => write!(f, "ceil({})", e),
96 TLExpr::Round(e) => write!(f, "round({})", e),
97 TLExpr::Sqrt(e) => write!(f, "sqrt({})", e),
98 TLExpr::Exp(e) => write!(f, "exp({})", e),
99 TLExpr::Log(e) => write!(f, "log({})", e),
100 TLExpr::Sin(e) => write!(f, "sin({})", e),
101 TLExpr::Cos(e) => write!(f, "cos({})", e),
102 TLExpr::Tan(e) => write!(f, "tan({})", e),
103 TLExpr::Eq(l, r) => write!(f, "({} = {})", l, r),
104 TLExpr::Lt(l, r) => write!(f, "({} < {})", l, r),
105 TLExpr::Gt(l, r) => write!(f, "({} > {})", l, r),
106 TLExpr::Lte(l, r) => write!(f, "({} ≤ {})", l, r),
107 TLExpr::Gte(l, r) => write!(f, "({} ≥ {})", l, r),
108 TLExpr::IfThenElse {
109 condition,
110 then_branch,
111 else_branch,
112 } => write!(
113 f,
114 "if {} then {} else {}",
115 condition, then_branch, else_branch
116 ),
117 TLExpr::Let { var, value, body } => {
118 write!(f, "let {} = {} in {}", var, value, body)
119 }
120 TLExpr::Box(e) => write!(f, "□{}", e),
121 TLExpr::Diamond(e) => write!(f, "◇{}", e),
122 TLExpr::Next(e) => write!(f, "X{}", e),
123 TLExpr::Eventually(e) => write!(f, "F{}", e),
124 TLExpr::Always(e) => write!(f, "G{}", e),
125 TLExpr::Until { before, after } => write!(f, "({} U {})", before, after),
126 TLExpr::TNorm { kind, left, right } => {
128 write!(f, "({} ⊤_{:?} {})", left, kind, right)
129 }
130 TLExpr::TCoNorm { kind, left, right } => {
131 write!(f, "({} ⊥_{:?} {})", left, kind, right)
132 }
133 TLExpr::FuzzyNot { kind, expr } => write!(f, "¬_{:?}({})", kind, expr),
134 TLExpr::FuzzyImplication {
135 kind,
136 premise,
137 conclusion,
138 } => write!(f, "({} →_{:?} {})", premise, kind, conclusion),
139 TLExpr::SoftExists {
141 var,
142 domain,
143 body,
144 temperature,
145 } => write!(f, "∃^{{{}}}{}:{}. {}", temperature, var, domain, body),
146 TLExpr::SoftForAll {
147 var,
148 domain,
149 body,
150 temperature,
151 } => write!(f, "∀^{{{}}}{}:{}. {}", temperature, var, domain, body),
152 TLExpr::WeightedRule { weight, rule } => write!(f, "{}::{}", weight, rule),
153 TLExpr::ProbabilisticChoice { alternatives } => {
154 write!(f, "choice[")?;
155 for (i, (prob, expr)) in alternatives.iter().enumerate() {
156 if i > 0 {
157 write!(f, ", ")?;
158 }
159 write!(f, "{}: {}", prob, expr)?;
160 }
161 write!(f, "]")
162 }
163 TLExpr::Release { released, releaser } => write!(f, "({} R {})", released, releaser),
165 TLExpr::WeakUntil { before, after } => write!(f, "({} W {})", before, after),
166 TLExpr::StrongRelease { released, releaser } => {
167 write!(f, "({} M {})", released, releaser)
168 }
169 TLExpr::Constant(value) => write!(f, "{}", value),
170 }
171 }
172}
173
174impl fmt::Display for OpType {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 match self {
177 OpType::Einsum { spec } => write!(f, "einsum({})", spec),
178 OpType::ElemUnary { op } => write!(f, "{}(·)", op),
179 OpType::ElemBinary { op } => write!(f, "{}(·, ·)", op),
180 OpType::Reduce { op, axes } => write!(f, "{}(·, axes={:?})", op, axes),
181 }
182 }
183}
184
185impl fmt::Display for EinsumNode {
186 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187 write!(f, "{} ", self.op)?;
188 write!(f, "inputs={:?}", self.inputs)?;
189 write!(f, " outputs={:?}", self.outputs)
190 }
191}
192
193impl fmt::Display for EinsumGraph {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 writeln!(f, "EinsumGraph {{")?;
196 writeln!(f, " tensors: {:?}", self.tensors)?;
197 writeln!(f, " nodes: [")?;
198 for (i, node) in self.nodes.iter().enumerate() {
199 writeln!(f, " {}: {}", i, node)?;
200 }
201 writeln!(f, " ]")?;
202 writeln!(f, " outputs: {:?}", self.outputs)?;
203 write!(f, "}}")
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display_term() {
        let var = Term::var("x");
        assert_eq!(format!("{}", var), "?x");

        let const_term = Term::constant("alice");
        assert_eq!(format!("{}", const_term), "alice");

        let typed = Term::typed_var("x", "Int");
        assert_eq!(format!("{}", typed), "?x:Int");
    }

    #[test]
    fn test_display_aggregate_op() {
        // Cover every variant so a renamed keyword cannot slip through.
        assert_eq!(format!("{}", AggregateOp::Count), "COUNT");
        assert_eq!(format!("{}", AggregateOp::Sum), "SUM");
        assert_eq!(format!("{}", AggregateOp::Average), "AVG");
        assert_eq!(format!("{}", AggregateOp::Max), "MAX");
        assert_eq!(format!("{}", AggregateOp::Min), "MIN");
        assert_eq!(format!("{}", AggregateOp::Product), "PROD");
        assert_eq!(format!("{}", AggregateOp::Any), "ANY");
        assert_eq!(format!("{}", AggregateOp::All), "ALL");
    }

    #[test]
    fn test_display_simple_expr() {
        let pred = TLExpr::pred("Person", vec![Term::var("x")]);
        assert_eq!(format!("{}", pred), "Person(?x)");
    }

    #[test]
    fn test_display_predicate_arity() {
        // A predicate without arguments still prints its parentheses.
        let no_args: Vec<Term> = Vec::new();
        let nullary = TLExpr::pred("Rain", no_args);
        assert_eq!(format!("{}", nullary), "Rain()");

        // Multiple arguments are separated by ", ".
        let binary = TLExpr::pred("Knows", vec![Term::var("x"), Term::constant("bob")]);
        assert_eq!(format!("{}", binary), "Knows(?x, bob)");
    }

    #[test]
    fn test_display_logical_ops() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let and_expr = TLExpr::and(p.clone(), q.clone());
        assert_eq!(format!("{}", and_expr), "(P(?x) ∧ Q(?y))");

        let or_expr = TLExpr::or(p.clone(), q);
        assert_eq!(format!("{}", or_expr), "(P(?x) ∨ Q(?y))");

        let not_expr = TLExpr::negate(p);
        assert_eq!(format!("{}", not_expr), "¬P(?x)");

        // Negating a compound expression relies on the operand's own parentheses.
        let not_and = TLExpr::negate(and_expr);
        assert_eq!(format!("{}", not_and), "¬(P(?x) ∧ Q(?y))");
    }

    #[test]
    fn test_display_quantifiers() {
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let exists = TLExpr::exists("x", "Domain", body.clone());
        assert_eq!(format!("{}", exists), "∃x:Domain. P(?x)");

        let forall = TLExpr::forall("x", "Domain", body);
        assert_eq!(format!("{}", forall), "∀x:Domain. P(?x)");
    }

    #[test]
    fn test_display_aggregate() {
        let body = TLExpr::pred("Value", vec![Term::var("x")]);

        let sum = TLExpr::sum("x", "Domain", body.clone());
        assert_eq!(format!("{}", sum), "SUM(x:Domain, Value(?x))");

        let count = TLExpr::count("x", "Domain", body);
        assert_eq!(format!("{}", count), "COUNT(x:Domain, Value(?x))");
    }

    #[test]
    fn test_display_aggregate_with_group_by() {
        let body = TLExpr::pred("Value", vec![Term::var("x"), Term::var("y")]);

        let agg = TLExpr::aggregate_with_group_by(
            AggregateOp::Sum,
            "x",
            "Domain",
            body,
            vec!["y".to_string()],
        );

        // Exact match: a bare `contains("y")` would already be satisfied by the
        // body's `?y`, so it could not detect a missing GROUP BY list.
        assert_eq!(
            format!("{}", agg),
            "SUM(x:Domain, GROUP BY [y], Value(?x, ?y))"
        );
    }

    #[test]
    fn test_display_arithmetic() {
        let x = TLExpr::constant(5.0);
        let y = TLExpr::constant(3.0);

        let add = TLExpr::add(x.clone(), y.clone());
        assert_eq!(format!("{}", add), "(5 + 3)");

        let mul = TLExpr::mul(x, y);
        assert_eq!(format!("{}", mul), "(5 * 3)");
    }

    #[test]
    fn test_display_comparison() {
        let x = TLExpr::pred("X", vec![Term::var("i")]);
        let threshold = TLExpr::constant(0.5);

        let gt = TLExpr::gt(x, threshold);
        assert_eq!(format!("{}", gt), "(X(?i) > 0.5)");
    }

    #[test]
    fn test_display_conditional() {
        let cond = TLExpr::pred("IsAdult", vec![Term::var("x")]);
        let then_br = TLExpr::constant(1.0);
        let else_br = TLExpr::constant(0.0);

        let if_expr = TLExpr::if_then_else(cond, then_br, else_br);
        assert_eq!(format!("{}", if_expr), "if IsAdult(?x) then 1 else 0");
    }

    #[test]
    fn test_display_einsum_node() {
        let node = EinsumNode::new("ij,jk->ik", vec![0, 1], vec![2]);
        assert_eq!(
            format!("{}", node),
            "einsum(ij,jk->ik) inputs=[0, 1] outputs=[2]"
        );
    }

    #[test]
    fn test_display_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::new("i->i", vec![t0], vec![t1]))
            .unwrap();
        graph.add_output(t1).unwrap();

        let display = format!("{}", graph);
        assert!(display.starts_with("EinsumGraph {"));
        assert!(display.ends_with('}'));
        assert!(display.contains("tensors: "));
        // Tensor names are printed via `{:?}` and so appear quoted. Checking the
        // quoted form avoids matching the `inputs=` / `outputs` labels that the
        // node and footer lines always print.
        assert!(display.contains("\"input\""));
        assert!(display.contains("\"output\""));
        // The node line is prefixed by its index and shows its tensor ids.
        assert!(display.contains(&format!(
            "0: einsum(i->i) inputs=[{:?}] outputs=[{:?}]",
            t0, t1
        )));
        assert!(display.contains(&format!("outputs: [{:?}]", t1)));
    }
}