Skip to main content

tensorlogic_ir/
display.rs

//! `Display` trait implementations for IR types.
//!
//! Provides human-readable string representations of terms, logical expressions,
//! and einsum graphs for use in debugging output and error messages.
4
5use std::fmt;
6
7use crate::{
8    expr::{AggregateOp, TLExpr},
9    graph::{EinsumGraph, EinsumNode, OpType},
10    term::Term,
11};
12
13impl fmt::Display for Term {
14    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
15        match self {
16            Term::Var(name) => write!(f, "?{}", name),
17            Term::Const(name) => write!(f, "{}", name),
18            Term::Typed {
19                value,
20                type_annotation,
21            } => write!(f, "{}:{}", value, type_annotation.type_name),
22        }
23    }
24}
25
26impl fmt::Display for AggregateOp {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        match self {
29            AggregateOp::Count => write!(f, "COUNT"),
30            AggregateOp::Sum => write!(f, "SUM"),
31            AggregateOp::Average => write!(f, "AVG"),
32            AggregateOp::Max => write!(f, "MAX"),
33            AggregateOp::Min => write!(f, "MIN"),
34            AggregateOp::Product => write!(f, "PROD"),
35            AggregateOp::Any => write!(f, "ANY"),
36            AggregateOp::All => write!(f, "ALL"),
37        }
38    }
39}
40
41impl fmt::Display for TLExpr {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            TLExpr::Pred { name, args } => {
45                write!(f, "{}(", name)?;
46                for (i, arg) in args.iter().enumerate() {
47                    if i > 0 {
48                        write!(f, ", ")?;
49                    }
50                    write!(f, "{}", arg)?;
51                }
52                write!(f, ")")
53            }
54            TLExpr::And(l, r) => write!(f, "({} ∧ {})", l, r),
55            TLExpr::Or(l, r) => write!(f, "({} ∨ {})", l, r),
56            TLExpr::Not(e) => write!(f, "¬{}", e),
57            TLExpr::Exists { var, domain, body } => {
58                write!(f, "∃{}:{}. {}", var, domain, body)
59            }
60            TLExpr::ForAll { var, domain, body } => {
61                write!(f, "∀{}:{}. {}", var, domain, body)
62            }
63            TLExpr::Aggregate {
64                op,
65                var,
66                domain,
67                body,
68                group_by,
69            } => {
70                write!(f, "{}({}:{}, ", op, var, domain)?;
71                if let Some(group_vars) = group_by {
72                    write!(f, "GROUP BY [")?;
73                    for (i, gv) in group_vars.iter().enumerate() {
74                        if i > 0 {
75                            write!(f, ", ")?;
76                        }
77                        write!(f, "{}", gv)?;
78                    }
79                    write!(f, "], ")?;
80                }
81                write!(f, "{})", body)
82            }
83            TLExpr::Imply(premise, conclusion) => write!(f, "({} → {})", premise, conclusion),
84            TLExpr::Score(e) => write!(f, "score({})", e),
85            TLExpr::Add(l, r) => write!(f, "({} + {})", l, r),
86            TLExpr::Sub(l, r) => write!(f, "({} - {})", l, r),
87            TLExpr::Mul(l, r) => write!(f, "({} * {})", l, r),
88            TLExpr::Div(l, r) => write!(f, "({} / {})", l, r),
89            TLExpr::Pow(l, r) => write!(f, "({} ^ {})", l, r),
90            TLExpr::Mod(l, r) => write!(f, "({} % {})", l, r),
91            TLExpr::Min(l, r) => write!(f, "min({}, {})", l, r),
92            TLExpr::Max(l, r) => write!(f, "max({}, {})", l, r),
93            TLExpr::Abs(e) => write!(f, "abs({})", e),
94            TLExpr::Floor(e) => write!(f, "floor({})", e),
95            TLExpr::Ceil(e) => write!(f, "ceil({})", e),
96            TLExpr::Round(e) => write!(f, "round({})", e),
97            TLExpr::Sqrt(e) => write!(f, "sqrt({})", e),
98            TLExpr::Exp(e) => write!(f, "exp({})", e),
99            TLExpr::Log(e) => write!(f, "log({})", e),
100            TLExpr::Sin(e) => write!(f, "sin({})", e),
101            TLExpr::Cos(e) => write!(f, "cos({})", e),
102            TLExpr::Tan(e) => write!(f, "tan({})", e),
103            TLExpr::Eq(l, r) => write!(f, "({} = {})", l, r),
104            TLExpr::Lt(l, r) => write!(f, "({} < {})", l, r),
105            TLExpr::Gt(l, r) => write!(f, "({} > {})", l, r),
106            TLExpr::Lte(l, r) => write!(f, "({} ≤ {})", l, r),
107            TLExpr::Gte(l, r) => write!(f, "({} ≥ {})", l, r),
108            TLExpr::IfThenElse {
109                condition,
110                then_branch,
111                else_branch,
112            } => write!(
113                f,
114                "if {} then {} else {}",
115                condition, then_branch, else_branch
116            ),
117            TLExpr::Let { var, value, body } => {
118                write!(f, "let {} = {} in {}", var, value, body)
119            }
120            TLExpr::Box(e) => write!(f, "□{}", e),
121            TLExpr::Diamond(e) => write!(f, "◇{}", e),
122            TLExpr::Next(e) => write!(f, "X{}", e),
123            TLExpr::Eventually(e) => write!(f, "F{}", e),
124            TLExpr::Always(e) => write!(f, "G{}", e),
125            TLExpr::Until { before, after } => write!(f, "({} U {})", before, after),
126            // Fuzzy logic operators
127            TLExpr::TNorm { kind, left, right } => {
128                write!(f, "({} ⊤_{:?} {})", left, kind, right)
129            }
130            TLExpr::TCoNorm { kind, left, right } => {
131                write!(f, "({} ⊥_{:?} {})", left, kind, right)
132            }
133            TLExpr::FuzzyNot { kind, expr } => write!(f, "¬_{:?}({})", kind, expr),
134            TLExpr::FuzzyImplication {
135                kind,
136                premise,
137                conclusion,
138            } => write!(f, "({} →_{:?} {})", premise, kind, conclusion),
139            // Probabilistic operators
140            TLExpr::SoftExists {
141                var,
142                domain,
143                body,
144                temperature,
145            } => write!(f, "∃^{{{}}}{}:{}. {}", temperature, var, domain, body),
146            TLExpr::SoftForAll {
147                var,
148                domain,
149                body,
150                temperature,
151            } => write!(f, "∀^{{{}}}{}:{}. {}", temperature, var, domain, body),
152            TLExpr::WeightedRule { weight, rule } => write!(f, "{}::{}", weight, rule),
153            TLExpr::ProbabilisticChoice { alternatives } => {
154                write!(f, "choice[")?;
155                for (i, (prob, expr)) in alternatives.iter().enumerate() {
156                    if i > 0 {
157                        write!(f, ", ")?;
158                    }
159                    write!(f, "{}: {}", prob, expr)?;
160                }
161                write!(f, "]")
162            }
163            // Extended temporal logic
164            TLExpr::Release { released, releaser } => write!(f, "({} R {})", released, releaser),
165            TLExpr::WeakUntil { before, after } => write!(f, "({} W {})", before, after),
166            TLExpr::StrongRelease { released, releaser } => {
167                write!(f, "({} M {})", released, releaser)
168            }
169            // Alpha.3 enhancements
170            TLExpr::Lambda {
171                var,
172                var_type,
173                body,
174            } => {
175                if let Some(ty) = var_type {
176                    write!(f, "λ{}:{}. {}", var, ty, body)
177                } else {
178                    write!(f, "λ{}. {}", var, body)
179                }
180            }
181            TLExpr::Apply { function, argument } => write!(f, "({} {})", function, argument),
182            TLExpr::SetMembership { element, set } => write!(f, "({} ∈ {})", element, set),
183            TLExpr::SetUnion { left, right } => write!(f, "({} ∪ {})", left, right),
184            TLExpr::SetIntersection { left, right } => write!(f, "({} ∩ {})", left, right),
185            TLExpr::SetDifference { left, right } => write!(f, "({} \\ {})", left, right),
186            TLExpr::SetCardinality { set } => write!(f, "|{}|", set),
187            TLExpr::EmptySet => write!(f, "∅"),
188            TLExpr::SetComprehension {
189                var,
190                domain,
191                condition,
192            } => write!(f, "{{ {}:{} | {} }}", var, domain, condition),
193            TLExpr::CountingExists {
194                var,
195                domain,
196                body,
197                min_count,
198            } => write!(f, "∃≥{}{}:{}. {}", min_count, var, domain, body),
199            TLExpr::CountingForAll {
200                var,
201                domain,
202                body,
203                min_count,
204            } => write!(f, "∀≥{}{}:{}. {}", min_count, var, domain, body),
205            TLExpr::ExactCount {
206                var,
207                domain,
208                body,
209                count,
210            } => write!(f, "∃={}{}:{}. {}", count, var, domain, body),
211            TLExpr::Majority { var, domain, body } => {
212                write!(f, "Majority {}:{}. {}", var, domain, body)
213            }
214            TLExpr::LeastFixpoint { var, body } => write!(f, "μ{}. {}", var, body),
215            TLExpr::GreatestFixpoint { var, body } => write!(f, "ν{}. {}", var, body),
216            TLExpr::Nominal { name } => write!(f, "@{}", name),
217            TLExpr::At { nominal, formula } => write!(f, "@{} {}", nominal, formula),
218            TLExpr::Somewhere { formula } => write!(f, "E {}", formula),
219            TLExpr::Everywhere { formula } => write!(f, "A {}", formula),
220            TLExpr::AllDifferent { variables } => {
221                write!(f, "alldiff([")?;
222                for (i, var) in variables.iter().enumerate() {
223                    if i > 0 {
224                        write!(f, ", ")?;
225                    }
226                    write!(f, "{}", var)?;
227                }
228                write!(f, "])")
229            }
230            TLExpr::GlobalCardinality {
231                variables,
232                values,
233                min_occurrences,
234                max_occurrences,
235            } => {
236                write!(f, "gcc([")?;
237                for (i, var) in variables.iter().enumerate() {
238                    if i > 0 {
239                        write!(f, ", ")?;
240                    }
241                    write!(f, "{}", var)?;
242                }
243                write!(f, "], [")?;
244                for (i, val) in values.iter().enumerate() {
245                    if i > 0 {
246                        write!(f, ", ")?;
247                    }
248                    write!(f, "{}:[{},{}]", val, min_occurrences[i], max_occurrences[i])?;
249                }
250                write!(f, "])")
251            }
252            TLExpr::Abducible { name, cost } => write!(f, "abd({}:{})", name, cost),
253            TLExpr::Explain { formula } => write!(f, "explain({})", formula),
254            TLExpr::Constant(value) => write!(f, "{}", value),
255        }
256    }
257}
258
259impl fmt::Display for OpType {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        match self {
262            OpType::Einsum { spec } => write!(f, "einsum({})", spec),
263            OpType::ElemUnary { op } => write!(f, "{}(·)", op),
264            OpType::ElemBinary { op } => write!(f, "{}(·, ·)", op),
265            OpType::Reduce { op, axes } => write!(f, "{}(·, axes={:?})", op, axes),
266        }
267    }
268}
269
270impl fmt::Display for EinsumNode {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        write!(f, "{} ", self.op)?;
273        write!(f, "inputs={:?}", self.inputs)?;
274        write!(f, " outputs={:?}", self.outputs)
275    }
276}
277
278impl fmt::Display for EinsumGraph {
279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280        writeln!(f, "EinsumGraph {{")?;
281        writeln!(f, "  tensors: {:?}", self.tensors)?;
282        writeln!(f, "  nodes: [")?;
283        for (i, node) in self.nodes.iter().enumerate() {
284            writeln!(f, "    {}: {}", i, node)?;
285        }
286        writeln!(f, "  ]")?;
287        writeln!(f, "  outputs: {:?}", self.outputs)?;
288        write!(f, "}}")
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_display_term() {
        let var = Term::var("x");
        assert_eq!(format!("{}", var), "?x");

        let const_term = Term::constant("alice");
        assert_eq!(format!("{}", const_term), "alice");

        let typed = Term::typed_var("x", "Int");
        assert_eq!(format!("{}", typed), "?x:Int");
    }

    #[test]
    fn test_display_aggregate_op() {
        // Cover every variant so a changed keyword is caught.
        let cases = [
            (AggregateOp::Count, "COUNT"),
            (AggregateOp::Sum, "SUM"),
            (AggregateOp::Average, "AVG"),
            (AggregateOp::Max, "MAX"),
            (AggregateOp::Min, "MIN"),
            (AggregateOp::Product, "PROD"),
            (AggregateOp::Any, "ANY"),
            (AggregateOp::All, "ALL"),
        ];
        for (op, expected) in cases.iter() {
            assert_eq!(format!("{}", op), *expected);
        }
    }

    #[test]
    fn test_display_simple_expr() {
        let pred = TLExpr::pred("Person", vec![Term::var("x")]);
        assert_eq!(format!("{}", pred), "Person(?x)");
    }

    #[test]
    fn test_display_logical_ops() {
        let p = TLExpr::pred("P", vec![Term::var("x")]);
        let q = TLExpr::pred("Q", vec![Term::var("y")]);

        let and_expr = TLExpr::and(p.clone(), q.clone());
        assert_eq!(format!("{}", and_expr), "(P(?x) ∧ Q(?y))");

        let or_expr = TLExpr::or(p.clone(), q);
        assert_eq!(format!("{}", or_expr), "(P(?x) ∨ Q(?y))");

        let not_expr = TLExpr::negate(p);
        assert_eq!(format!("{}", not_expr), "¬P(?x)");
    }

    #[test]
    fn test_display_quantifiers() {
        let body = TLExpr::pred("P", vec![Term::var("x")]);

        let exists = TLExpr::exists("x", "Domain", body.clone());
        assert_eq!(format!("{}", exists), "∃x:Domain. P(?x)");

        let forall = TLExpr::forall("x", "Domain", body);
        assert_eq!(format!("{}", forall), "∀x:Domain. P(?x)");
    }

    #[test]
    fn test_display_aggregate() {
        let body = TLExpr::pred("Value", vec![Term::var("x")]);

        let sum = TLExpr::sum("x", "Domain", body.clone());
        assert_eq!(format!("{}", sum), "SUM(x:Domain, Value(?x))");

        let count = TLExpr::count("x", "Domain", body);
        assert_eq!(format!("{}", count), "COUNT(x:Domain, Value(?x))");
    }

    #[test]
    fn test_display_aggregate_with_group_by() {
        let body = TLExpr::pred("Value", vec![Term::var("x"), Term::var("y")]);

        let agg = TLExpr::aggregate_with_group_by(
            AggregateOp::Sum,
            "x",
            "Domain",
            body,
            vec!["y".to_string()],
        );

        // Pin the full string: a bare `contains("y")` would be satisfied by the
        // body's `?y` alone and never check the GROUP BY list.
        assert_eq!(
            format!("{}", agg),
            "SUM(x:Domain, GROUP BY [y], Value(?x, ?y))"
        );
    }

    #[test]
    fn test_display_arithmetic() {
        let x = TLExpr::constant(5.0);
        let y = TLExpr::constant(3.0);

        let add = TLExpr::add(x.clone(), y.clone());
        assert_eq!(format!("{}", add), "(5 + 3)");

        let mul = TLExpr::mul(x, y);
        assert_eq!(format!("{}", mul), "(5 * 3)");
    }

    #[test]
    fn test_display_comparison() {
        let x = TLExpr::pred("X", vec![Term::var("i")]);
        let threshold = TLExpr::constant(0.5);

        let gt = TLExpr::gt(x, threshold);
        assert_eq!(format!("{}", gt), "(X(?i) > 0.5)");
    }

    #[test]
    fn test_display_conditional() {
        let cond = TLExpr::pred("IsAdult", vec![Term::var("x")]);
        let then_br = TLExpr::constant(1.0);
        let else_br = TLExpr::constant(0.0);

        let if_expr = TLExpr::if_then_else(cond, then_br, else_br);
        assert_eq!(format!("{}", if_expr), "if IsAdult(?x) then 1 else 0");
    }

    #[test]
    fn test_display_einsum_node() {
        let node = EinsumNode::new("ij,jk->ik", vec![0, 1], vec![2]);
        assert_eq!(
            format!("{}", node),
            "einsum(ij,jk->ik) inputs=[0, 1] outputs=[2]"
        );
    }

    #[test]
    fn test_display_graph() {
        let mut graph = EinsumGraph::new();
        let t0 = graph.add_tensor("input");
        let t1 = graph.add_tensor("output");

        graph
            .add_node(EinsumNode::new("i->i", vec![t0], vec![t1]))
            .unwrap();
        graph.add_output(t1).unwrap();

        let display = format!("{}", graph);
        assert!(display.starts_with("EinsumGraph {\n"));
        assert!(display.ends_with('}'));
        assert!(display.contains("tensors"));
        // Match the quoted (Debug-formatted) tensor names: bare "input"/"output"
        // would also match the `inputs=` / `outputs:` labels and prove nothing.
        assert!(display.contains("\"input\""));
        assert!(display.contains("\"output\""));
        let node_line = format!("    0: einsum(i->i) inputs=[{}] outputs=[{}]", t0, t1);
        assert!(display.contains(&node_line));
        assert!(display.contains(&format!("  outputs: [{}]", t1)));
    }
}