Skip to main content

tang_expr/
display.rs

1//! Pretty-printing for expressions.
2
3use crate::graph::ExprGraph;
4use crate::node::{ExprId, Node};
5
6impl ExprGraph {
7    /// Format an expression as a human-readable string.
8    pub fn fmt_expr(&self, expr: ExprId) -> String {
9        match self.node(expr) {
10            Node::Var(n) => format!("x{n}"),
11            Node::Lit(bits) => {
12                let v = f64::from_bits(bits);
13                if v == 0.0 {
14                    "0".to_string()
15                } else if v == 1.0 {
16                    "1".to_string()
17                } else if v == 2.0 {
18                    "2".to_string()
19                } else if v == -1.0 {
20                    "-1".to_string()
21                } else {
22                    format!("{v}")
23                }
24            }
25            Node::Add(a, b) => {
26                format!("({} + {})", self.fmt_expr(a), self.fmt_expr(b))
27            }
28            Node::Mul(a, b) => {
29                format!("({} * {})", self.fmt_expr(a), self.fmt_expr(b))
30            }
31            Node::Neg(a) => format!("(-{})", self.fmt_expr(a)),
32            Node::Recip(a) => format!("(1 / {})", self.fmt_expr(a)),
33            Node::Sqrt(a) => format!("sqrt({})", self.fmt_expr(a)),
34            Node::Sin(a) => format!("sin({})", self.fmt_expr(a)),
35            Node::Atan2(y, x) => {
36                format!("atan2({}, {})", self.fmt_expr(y), self.fmt_expr(x))
37            }
38            Node::Exp2(a) => format!("exp2({})", self.fmt_expr(a)),
39            Node::Log2(a) => format!("log2({})", self.fmt_expr(a)),
40            Node::Select(c, a, b) => {
41                format!(
42                    "select({}, {}, {})",
43                    self.fmt_expr(c),
44                    self.fmt_expr(a),
45                    self.fmt_expr(b)
46                )
47            }
48        }
49    }
50}
51
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;
    use crate::node::ExprId;

    /// Addition and multiplication render as parenthesized infix; `sin` renders in call syntax.
    #[test]
    fn display_simple() {
        let mut graph = ExprGraph::new();
        let x0 = graph.var(0);
        let x1 = graph.var(1);

        // Nodes are built in the same order as before: sum, product, sine.
        let cases = [
            (graph.add(x0, x1), "(x0 + x1)"),
            (graph.mul(x0, x1), "(x0 * x1)"),
            (graph.sin(x0), "sin(x0)"),
        ];
        for (id, expected) in cases {
            assert_eq!(graph.fmt_expr(id), expected, "rendering of {expected}");
        }
    }

    /// The predefined constant ids format as bare integers on a fresh graph.
    #[test]
    fn display_constants() {
        let graph = ExprGraph::new();
        let cases = [(ExprId::ZERO, "0"), (ExprId::ONE, "1"), (ExprId::TWO, "2")];
        for (id, expected) in cases {
            assert_eq!(graph.fmt_expr(id), expected, "rendering of {expected}");
        }
    }
}