tensorlogic_ir/
display.rs

1//! Display trait implementations for IR types.
2//!
3//! Provides human-readable string representations for debugging and error messages.
4
5use std::fmt;
6
7use crate::{
8    expr::{AggregateOp, TLExpr},
9    graph::{EinsumGraph, EinsumNode, OpType},
10    term::Term,
11};
12
13impl fmt::Display for Term {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        match self {
16            Term::Var(name) => write!(f, "?{}", name),
17            Term::Const(name) => write!(f, "{}", name),
18            Term::Typed {
19                value,
20                type_annotation,
21            } => write!(f, "{}:{}", value, type_annotation.type_name),
22        }
23    }
24}
25
26impl fmt::Display for AggregateOp {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        match self {
29            AggregateOp::Count => write!(f, "COUNT"),
30            AggregateOp::Sum => write!(f, "SUM"),
31            AggregateOp::Average => write!(f, "AVG"),
32            AggregateOp::Max => write!(f, "MAX"),
33            AggregateOp::Min => write!(f, "MIN"),
34            AggregateOp::Product => write!(f, "PROD"),
35            AggregateOp::Any => write!(f, "ANY"),
36            AggregateOp::All => write!(f, "ALL"),
37        }
38    }
39}
40
41impl fmt::Display for TLExpr {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            TLExpr::Pred { name, args } => {
45                write!(f, "{}(", name)?;
46                for (i, arg) in args.iter().enumerate() {
47                    if i > 0 {
48                        write!(f, ", ")?;
49                    }
50                    write!(f, "{}", arg)?;
51                }
52                write!(f, ")")
53            }
54            TLExpr::And(l, r) => write!(f, "({} ∧ {})", l, r),
55            TLExpr::Or(l, r) => write!(f, "({} ∨ {})", l, r),
56            TLExpr::Not(e) => write!(f, "¬{}", e),
57            TLExpr::Exists { var, domain, body } => {
58                write!(f, "∃{}:{}. {}", var, domain, body)
59            }
60            TLExpr::ForAll { var, domain, body } => {
61                write!(f, "∀{}:{}. {}", var, domain, body)
62            }
63            TLExpr::Aggregate {
64                op,
65                var,
66                domain,
67                body,
68                group_by,
69            } => {
70                write!(f, "{}({}:{}, ", op, var, domain)?;
71                if let Some(group_vars) = group_by {
72                    write!(f, "GROUP BY [")?;
73                    for (i, gv) in group_vars.iter().enumerate() {
74                        if i > 0 {
75                            write!(f, ", ")?;
76                        }
77                        write!(f, "{}", gv)?;
78                    }
79                    write!(f, "], ")?;
80                }
81                write!(f, "{})", body)
82            }
83            TLExpr::Imply(premise, conclusion) => write!(f, "({} → {})", premise, conclusion),
84            TLExpr::Score(e) => write!(f, "score({})", e),
85            TLExpr::Add(l, r) => write!(f, "({} + {})", l, r),
86            TLExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
87            TLExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
88            TLExpr::Div(l, r) => write!(f, "({} / {})", l, r),
89            TLExpr::Pow(l, r) => write!(f, "({} ^ {})", l, r),
90            TLExpr::Mod(l, r) => write!(f, "({} % {})", l, r),
91            TLExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
92            TLExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
93            TLExpr::Abs(e) => write!(f, "abs({})", e),
94            TLExpr::Floor(e) => write!(f, "floor({})", e),
95            TLExpr::Ceil(e) => write!(f, "ceil({})", e),
96            TLExpr::Round(e) => write!(f, "round({})", e),
97            TLExpr::Sqrt(e) => write!(f, "sqrt({})", e),
98            TLExpr::Exp(e) => write!(f, "exp({})", e),
99            TLExpr::Log(e) => write!(f, "log({})", e),
100            TLExpr::Sin(e) => write!(f, "sin({})", e),
101            TLExpr::Cos(e) => write!(f, "cos({})", e),
102            TLExpr::Tan(e) => write!(f, "tan({})", e),
103            TLExpr::Eq(l, r) => write!(f, "({} = {})", l, r),
104            TLExpr::Lt(l, r) => write!(f, "({} < {})", l, r),
105            TLExpr::Gt(l, r) => write!(f, "({} > {})", l, r),
106            TLExpr::Lte(l, r) => write!(f, "({} ≤ {})", l, r),
107            TLExpr::Gte(l, r) => write!(f, "({} ≥ {})", l, r),
108            TLExpr::IfThenElse {
109                condition,
110                then_branch,
111                else_branch,
112            } => write!(
113                f,
114                "if {} then {} else {}",
115                condition, then_branch, else_branch
116            ),
117            TLExpr::Let { var, value, body } => {
118                write!(f, "let {} = {} in {}", var, value, body)
119            }
120            TLExpr::Box(e) => write!(f, "□{}", e),
121            TLExpr::Diamond(e) => write!(f, "◇{}", e),
122            TLExpr::Next(e) => write!(f, "X{}", e),
123            TLExpr::Eventually(e) => write!(f, "F{}", e),
124            TLExpr::Always(e) => write!(f, "G{}", e),
125            TLExpr::Until { before, after } => write!(f, "({} U {})", before, after),
126            // Fuzzy logic operators
127            TLExpr::TNorm { kind, left, right } => {
128                write!(f, "({} ⊤_{:?} {})", left, kind, right)
129            }
130            TLExpr::TCoNorm { kind, left, right } => {
131                write!(f, "({} ⊥_{:?} {})", left, kind, right)
132            }
133            TLExpr::FuzzyNot { kind, expr } => write!(f, "¬_{:?}({})", kind, expr),
134            TLExpr::FuzzyImplication {
135                kind,
136                premise,
137                conclusion,
138            } => write!(f, "({} →_{:?} {})", premise, kind, conclusion),
139            // Probabilistic operators
140            TLExpr::SoftExists {
141                var,
142                domain,
143                body,
144                temperature,
145            } => write!(f, "∃^{{{}}}{}:{}. {}", temperature, var, domain, body),
146            TLExpr::SoftForAll {
147                var,
148                domain,
149                body,
150                temperature,
151            } => write!(f, "∀^{{{}}}{}:{}. {}", temperature, var, domain, body),
152            TLExpr::WeightedRule { weight, rule } => write!(f, "{}::{}", weight, rule),
153            TLExpr::ProbabilisticChoice { alternatives } => {
154                write!(f, "choice[")?;
155                for (i, (prob, expr)) in alternatives.iter().enumerate() {
156                    if i > 0 {
157                        write!(f, ", ")?;
158                    }
159                    write!(f, "{}: {}", prob, expr)?;
160                }
161                write!(f, "]")
162            }
163            // Extended temporal logic
164            TLExpr::Release { released, releaser } => write!(f, "({} R {})", released, releaser),
165            TLExpr::WeakUntil { before, after } => write!(f, "({} W {})", before, after),
166            TLExpr::StrongRelease { released, releaser } => {
167                write!(f, "({} M {})", released, releaser)
168            }
169            TLExpr::Constant(value) => write!(f, "{}", value),
170        }
171    }
172}
173
174impl fmt::Display for OpType {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        match self {
177            OpType::Einsum { spec } => write!(f, "einsum({})", spec),
178            OpType::ElemUnary { op } => write!(f, "{}(·)", op),
179            OpType::ElemBinary { op } => write!(f, "{}(·, ·)", op),
180            OpType::Reduce { op, axes } => write!(f, "{}(·, axes={:?})", op, axes),
181        }
182    }
183}
184
185impl fmt::Display for EinsumNode {
186    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187        write!(f, "{} ", self.op)?;
188        write!(f, "inputs={:?}", self.inputs)?;
189        write!(f, " outputs={:?}", self.outputs)
190    }
191}
192
193impl fmt::Display for EinsumGraph {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        writeln!(f, "EinsumGraph {{")?;
196        writeln!(f, "  tensors: {:?}", self.tensors)?;
197        writeln!(f, "  nodes: [")?;
198        for (i, node) in self.nodes.iter().enumerate() {
199            writeln!(f, "    {}: {}", i, node)?;
200        }
201        writeln!(f, "  ]")?;
202        writeln!(f, "  outputs: {:?}", self.outputs)?;
203        write!(f, "}}")
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    // Variables, constants and typed variables each have their own layout.
    #[test]
    fn test_display_term() {
        assert_eq!(Term::var("x").to_string(), "?x");
        assert_eq!(Term::constant("alice").to_string(), "alice");
        assert_eq!(Term::typed_var("x", "Int").to_string(), "?x:Int");
    }

    // Aggregate operators render as upper-case keywords.
    #[test]
    fn test_display_aggregate_op() {
        assert_eq!(AggregateOp::Count.to_string(), "COUNT");
        assert_eq!(AggregateOp::Sum.to_string(), "SUM");
        assert_eq!(AggregateOp::Average.to_string(), "AVG");
    }

    // A predicate renders as its name followed by the argument list.
    #[test]
    fn test_display_simple_expr() {
        let person = TLExpr::pred("Person", vec![Term::var("x")]);
        assert_eq!(person.to_string(), "Person(?x)");
    }

    // Conjunction and disjunction are parenthesised; negation is a bare prefix.
    #[test]
    fn test_display_logical_ops() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let conjunction = TLExpr::and(p.clone(), q.clone());
        assert_eq!(conjunction.to_string(), "(P(?x) ∧ Q(?y))");

        let disjunction = TLExpr::or(p.clone(), q);
        assert_eq!(disjunction.to_string(), "(P(?x) ∨ Q(?y))");

        let negation = TLExpr::negate(p);
        assert_eq!(negation.to_string(), "¬P(?x)");
    }

    // Quantifiers render as `<symbol><var>:<domain>. <body>`.
    #[test]
    fn test_display_quantifiers() {
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let existential = TLExpr::exists("x", "Domain", body.clone());
        assert_eq!(existential.to_string(), "∃x:Domain. P(?x)");

        let universal = TLExpr::forall("x", "Domain", body);
        assert_eq!(universal.to_string(), "∀x:Domain. P(?x)");
    }

    // Aggregates without grouping render as `<OP>(<var>:<domain>, <body>)`.
    #[test]
    fn test_display_aggregate() {
        let body = TLExpr::pred("Value", vec![Term::var("x")]);

        let total = TLExpr::sum("x", "Domain", body.clone());
        assert_eq!(total.to_string(), "SUM(x:Domain, Value(?x))");

        let tally = TLExpr::count("x", "Domain", body);
        assert_eq!(tally.to_string(), "COUNT(x:Domain, Value(?x))");
    }

    // Grouped aggregates mention the operator, the clause and the group variable.
    #[test]
    fn test_display_aggregate_with_group_by() {
        let body = TLExpr::pred("Value", vec![Term::var("x"), Term::var("y")]);
        let grouped = TLExpr::aggregate_with_group_by(
            AggregateOp::Sum,
            "x",
            "Domain",
            body,
            vec!["y".to_string()],
        );

        let rendered = grouped.to_string();
        assert!(rendered.contains("SUM"));
        assert!(rendered.contains("GROUP BY"));
        assert!(rendered.contains("y"));
    }

    // Arithmetic operators are parenthesised infix expressions.
    #[test]
    fn test_display_arithmetic() {
        let five = TLExpr::constant(5.0);
        let three = TLExpr::constant(3.0);

        let sum = TLExpr::add(five.clone(), three.clone());
        assert_eq!(sum.to_string(), "(5 + 3)");

        let product = TLExpr::mul(five, three);
        assert_eq!(product.to_string(), "(5 * 3)");
    }

    // A comparison shows both its operator and the constant operand.
    #[test]
    fn test_display_comparison() {
        let lhs = TLExpr::pred("X", vec![Term::var("i")]);
        let threshold = TLExpr::constant(0.5);

        let rendered = TLExpr::gt(lhs, threshold).to_string();
        assert!(rendered.contains(">"));
        assert!(rendered.contains("0.5"));
    }

    // Conditionals spell out all three keywords.
    #[test]
    fn test_display_conditional() {
        let condition = TLExpr::pred("IsAdult", vec![Term::var("x")]);
        let when_true = TLExpr::constant(1.0);
        let when_false = TLExpr::constant(0.0);

        let rendered = TLExpr::if_then_else(condition, when_true, when_false).to_string();
        assert!(rendered.contains("if"));
        assert!(rendered.contains("then"));
        assert!(rendered.contains("else"));
    }

    // A node shows its operation, its spec and both index lists.
    #[test]
    fn test_display_einsum_node() {
        let rendered = EinsumNode::new("ij,jk->ik", vec![0, 1], vec![2]).to_string();
        assert!(rendered.contains("einsum"));
        assert!(rendered.contains("ij,jk->ik"));
        assert!(rendered.contains("inputs=[0, 1]"));
        assert!(rendered.contains("outputs=[2]"));
    }

    // A graph rendering includes its header and the registered tensor names.
    #[test]
    fn test_display_graph() {
        let mut graph = EinsumGraph::new();
        let source = graph.add_tensor("input");
        let sink = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::new("i->i", vec![source], vec![sink]))
            .unwrap();
        graph.add_output(sink).unwrap();

        let rendered = graph.to_string();
        assert!(rendered.contains("EinsumGraph"));
        assert!(rendered.contains("tensors"));
        assert!(rendered.contains("input"));
        assert!(rendered.contains("output"));
    }
}