Skip to main content

tang_expr/
simplify.rs

1//! Pattern-matched simplification rules.
2
3use std::collections::HashMap;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9    /// Simplify an expression by applying rewrite rules to fixpoint.
10    ///
11    /// Bottom-up: simplify children first, then match parent. Iterates
12    /// until no more changes occur.
13    pub fn simplify(&mut self, expr: ExprId) -> ExprId {
14        let mut memo = HashMap::new();
15        self.simplify_inner(expr, &mut memo)
16    }
17
18    fn simplify_inner(&mut self, expr: ExprId, memo: &mut HashMap<ExprId, ExprId>) -> ExprId {
19        if let Some(&cached) = memo.get(&expr) {
20            return cached;
21        }
22
23        // First, simplify children
24        let simplified_children = match self.node(expr) {
25            Node::Var(_) | Node::Lit(_) => expr,
26            Node::Add(a, b) => {
27                let sa = self.simplify_inner(a, memo);
28                let sb = self.simplify_inner(b, memo);
29                self.add(sa, sb)
30            }
31            Node::Mul(a, b) => {
32                let sa = self.simplify_inner(a, memo);
33                let sb = self.simplify_inner(b, memo);
34                self.mul(sa, sb)
35            }
36            Node::Neg(a) => {
37                let sa = self.simplify_inner(a, memo);
38                self.neg(sa)
39            }
40            Node::Recip(a) => {
41                let sa = self.simplify_inner(a, memo);
42                self.recip(sa)
43            }
44            Node::Sqrt(a) => {
45                let sa = self.simplify_inner(a, memo);
46                self.sqrt(sa)
47            }
48            Node::Sin(a) => {
49                let sa = self.simplify_inner(a, memo);
50                self.sin(sa)
51            }
52            Node::Atan2(y, x) => {
53                let sy = self.simplify_inner(y, memo);
54                let sx = self.simplify_inner(x, memo);
55                self.atan2(sy, sx)
56            }
57            Node::Exp2(a) => {
58                let sa = self.simplify_inner(a, memo);
59                self.exp2(sa)
60            }
61            Node::Log2(a) => {
62                let sa = self.simplify_inner(a, memo);
63                self.log2(sa)
64            }
65            Node::Select(c, a, b) => {
66                let sc = self.simplify_inner(c, memo);
67                let sa = self.simplify_inner(a, memo);
68                let sb = self.simplify_inner(b, memo);
69                self.select(sc, sa, sb)
70            }
71        };
72
73        // Now apply rewrite rules on the node with simplified children
74        let result = self.rewrite(simplified_children);
75
76        // If rewrite changed something, simplify again (fixpoint)
77        let final_result = if result != simplified_children {
78            self.simplify_inner(result, memo)
79        } else {
80            result
81        };
82
83        memo.insert(expr, final_result);
84        final_result
85    }
86
87    /// Apply one round of rewrite rules.
88    fn rewrite(&mut self, expr: ExprId) -> ExprId {
89        match self.node(expr) {
90            // --- Identity / Annihilation ---
91
92            // Add(x, ZERO) → x
93            Node::Add(a, b) if b == ExprId::ZERO => a,
94            // Add(ZERO, x) → x
95            Node::Add(a, b) if a == ExprId::ZERO => b,
96
97            // Mul(x, ONE) → x
98            Node::Mul(a, b) if b == ExprId::ONE => a,
99            // Mul(ONE, x) → x
100            Node::Mul(a, b) if a == ExprId::ONE => b,
101            // Mul(x, ZERO) → ZERO
102            Node::Mul(_, b) if b == ExprId::ZERO => ExprId::ZERO,
103            // Mul(ZERO, x) → ZERO
104            Node::Mul(a, _) if a == ExprId::ZERO => ExprId::ZERO,
105
106            // Neg(Neg(x)) → x
107            Node::Neg(a) => match self.node(a) {
108                Node::Neg(inner) => inner,
109                // Neg(ZERO) → ZERO
110                _ if a == ExprId::ZERO => ExprId::ZERO,
111                // Constant folding: Neg(Lit(v)) → Lit(-v)
112                Node::Lit(bits) => {
113                    let v = f64::from_bits(bits);
114                    self.lit(-v)
115                }
116                _ => expr,
117            },
118
119            // Recip(Recip(x)) → x
120            Node::Recip(a) => match self.node(a) {
121                Node::Recip(inner) => inner,
122                // Constant folding: Recip(Lit(v)) → Lit(1/v)
123                Node::Lit(bits) => {
124                    let v = f64::from_bits(bits);
125                    self.lit(1.0 / v)
126                }
127                _ => expr,
128            },
129
130            // --- Cancellation ---
131
132            // Add(x, Neg(x)) → ZERO
133            Node::Add(a, b) => {
134                if let Node::Neg(inner) = self.node(b) {
135                    if inner == a {
136                        return ExprId::ZERO;
137                    }
138                }
139                if let Node::Neg(inner) = self.node(a) {
140                    if inner == b {
141                        return ExprId::ZERO;
142                    }
143                }
144                // Constant folding: Add(Lit(a), Lit(b)) → Lit(a+b)
145                if let (Some(va), Some(vb)) = (self.node(a).as_f64(), self.node(b).as_f64()) {
146                    return self.lit(va + vb);
147                }
148                expr
149            }
150
151            Node::Mul(a, b) => {
152                // Mul(x, Recip(x)) → ONE
153                if let Node::Recip(inner) = self.node(b) {
154                    if inner == a {
155                        return ExprId::ONE;
156                    }
157                }
158                if let Node::Recip(inner) = self.node(a) {
159                    if inner == b {
160                        return ExprId::ONE;
161                    }
162                }
163                // Constant folding: Mul(Lit(a), Lit(b)) → Lit(a*b)
164                if let (Some(va), Some(vb)) = (self.node(a).as_f64(), self.node(b).as_f64()) {
165                    return self.lit(va * vb);
166                }
167                expr
168            }
169
170            // Constant folding for unary ops
171            Node::Sqrt(a) => {
172                if let Some(v) = self.node(a).as_f64() {
173                    self.lit(v.sqrt())
174                } else {
175                    expr
176                }
177            }
178            Node::Sin(a) => {
179                if let Some(v) = self.node(a).as_f64() {
180                    self.lit(v.sin())
181                } else {
182                    expr
183                }
184            }
185            Node::Exp2(a) => {
186                if let Some(v) = self.node(a).as_f64() {
187                    self.lit(v.exp2())
188                } else {
189                    expr
190                }
191            }
192            Node::Log2(a) => {
193                if let Some(v) = self.node(a).as_f64() {
194                    self.lit(v.log2())
195                } else {
196                    expr
197                }
198            }
199
200            // Select constant folding
201            Node::Select(c, a, b) => {
202                if let Some(vc) = self.node(c).as_f64() {
203                    if vc > 0.0 { a } else { b }
204                } else {
205                    expr
206                }
207            }
208
209            _ => expr,
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;
    use crate::node::ExprId;

    #[test]
    fn simplify_add_zero() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sum = g.add(x, ExprId::ZERO);
        assert_eq!(g.simplify(sum), x);

        let sum2 = g.add(ExprId::ZERO, x);
        assert_eq!(g.simplify(sum2), x);
    }

    #[test]
    fn simplify_mul_one() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let prod = g.mul(x, ExprId::ONE);
        assert_eq!(g.simplify(prod), x);

        let prod2 = g.mul(ExprId::ONE, x);
        assert_eq!(g.simplify(prod2), x);
    }

    #[test]
    fn simplify_mul_zero() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let prod = g.mul(x, ExprId::ZERO);
        assert_eq!(g.simplify(prod), ExprId::ZERO);

        let prod2 = g.mul(ExprId::ZERO, x);
        assert_eq!(g.simplify(prod2), ExprId::ZERO);
    }

    #[test]
    fn simplify_neg_neg() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let nx = g.neg(x);
        let nnx = g.neg(nx);
        assert_eq!(g.simplify(nnx), x);
    }

    #[test]
    fn simplify_recip_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let rx = g.recip(x);
        let rrx = g.recip(rx);
        assert_eq!(g.simplify(rrx), x);
    }

    #[test]
    fn simplify_cancel_add_neg() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let nx = g.neg(x);
        let sum = g.add(x, nx);
        assert_eq!(g.simplify(sum), ExprId::ZERO);

        // Mirrored operand order is handled by the same rule.
        let sum2 = g.add(nx, x);
        assert_eq!(g.simplify(sum2), ExprId::ZERO);
    }

    #[test]
    fn simplify_cancel_mul_recip() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let rx = g.recip(x);
        let prod = g.mul(x, rx);
        assert_eq!(g.simplify(prod), ExprId::ONE);

        // Mirrored operand order is handled by the same rule.
        let prod2 = g.mul(rx, x);
        assert_eq!(g.simplify(prod2), ExprId::ONE);
    }

    #[test]
    fn simplify_constant_fold_add() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let sum = g.add(a, b);
        let s = g.simplify(sum);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 7.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_constant_fold_mul() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let prod = g.mul(a, b);
        let s = g.simplify(prod);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 12.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_constant_fold_unary() {
        let mut g = ExprGraph::new();

        let three = g.lit(3.0);
        let neg = g.neg(three);
        let s = g.simplify(neg);
        let result: f64 = g.eval(s, &[]);
        assert!((result + 3.0).abs() < 1e-10);

        let four = g.lit(4.0);
        let recip = g.recip(four);
        let s = g.simplify(recip);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 0.25).abs() < 1e-10);

        let nine = g.lit(9.0);
        let root = g.sqrt(nine);
        let s = g.simplify(root);
        let result: f64 = g.eval(s, &[]);
        assert!((result - 3.0).abs() < 1e-10);
    }

    #[test]
    fn simplify_select_constant_condition() {
        let mut g = ExprGraph::new();
        let a = g.var(0);
        let b = g.var(1);

        // A positive condition picks the first branch.
        let pos = g.lit(1.0);
        let sel = g.select(pos, a, b);
        assert_eq!(g.simplify(sel), a);

        // A non-positive condition picks the second branch.
        let neg = g.lit(-1.0);
        let sel2 = g.select(neg, a, b);
        assert_eq!(g.simplify(sel2), b);
    }

    #[test]
    fn simplify_neg_zero() {
        let mut g = ExprGraph::new();
        let nz = g.neg(ExprId::ZERO);
        let s = g.simplify(nz);
        // `rewrite` maps Neg(ZERO) straight to ZERO. Check the value rather
        // than the id so the test does not depend on whether the `neg` builder
        // already folds constants on its own.
        let result: f64 = g.eval(s, &[]);
        assert!(result == 0.0);
    }

    #[test]
    fn simplify_derivative() {
        // d/dx (x*x) = 2x after simplification
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let xx = g.mul(x, x);
        let d = g.diff(xx, 0);

        // Presumably `diff` yields Add(Mul(ONE, x), Mul(x, ONE)), which
        // simplifies to Add(x, x); only the numeric value is checked here.
        let s = g.simplify(d);
        let result: f64 = g.eval(s, &[5.0]);
        assert!((result - 10.0).abs() < 1e-10);
    }
}