1use std::collections::HashSet;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
/// A compiled single-output expression, as returned by [`ExprGraph::compile`].
///
/// The closure takes the input slice (indexed by the variable index stored in
/// each `Node::Var`) and returns the expression's value.
pub type CompiledExpr = Box<dyn Fn(&[f64]) -> f64>;

/// A compiled multi-output expression, as returned by [`ExprGraph::compile_many`].
///
/// The closure takes the input slice and writes the value of the `k`-th compiled
/// expression into `outputs[k]`.
pub type CompiledMany = Box<dyn Fn(&[f64], &mut [f64])>;
13
14impl ExprGraph {
15 pub fn compile(&self, expr: ExprId) -> CompiledExpr {
20 let live = self.live_set(&[expr]);
21 let nodes = self.collect_eval_order(&live, expr.0 as usize + 1);
22 let out_idx = expr.0 as usize;
23
24 Box::new(move |inputs: &[f64]| {
25 let mut vals = vec![0.0f64; out_idx + 1];
26 for &(i, ref node) in &nodes {
27 vals[i] = eval_node(node, &vals, inputs);
28 }
29 vals[out_idx]
30 })
31 }
32
33 pub fn compile_many(&self, exprs: &[ExprId]) -> CompiledMany {
37 if exprs.is_empty() {
38 return Box::new(|_, _| {});
39 }
40
41 let live = self.live_set(exprs);
42 let max_id = exprs.iter().map(|e| e.0).max().unwrap() as usize;
43 let nodes = self.collect_eval_order(&live, max_id + 1);
44 let out_indices: Vec<usize> = exprs.iter().map(|e| e.0 as usize).collect();
45
46 Box::new(move |inputs: &[f64], outputs: &mut [f64]| {
47 let mut vals = vec![0.0f64; max_id + 1];
48 for &(i, ref node) in &nodes {
49 vals[i] = eval_node(node, &vals, inputs);
50 }
51 for (k, &idx) in out_indices.iter().enumerate() {
52 outputs[k] = vals[idx];
53 }
54 })
55 }
56
57 pub fn live_set(&self, outputs: &[ExprId]) -> HashSet<usize> {
59 let mut live = HashSet::new();
60 let mut stack: Vec<usize> = outputs.iter().map(|e| e.0 as usize).collect();
61 while let Some(i) = stack.pop() {
62 if !live.insert(i) {
63 continue;
64 }
65 match self.node(ExprId(i as u32)) {
66 Node::Var(_) | Node::Lit(_) => {}
67 Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
68 stack.push(a.0 as usize);
69 stack.push(b.0 as usize);
70 }
71 Node::Neg(a)
72 | Node::Recip(a)
73 | Node::Sqrt(a)
74 | Node::Sin(a)
75 | Node::Exp2(a)
76 | Node::Log2(a) => {
77 stack.push(a.0 as usize);
78 }
79 Node::Select(c, a, b) => {
80 stack.push(c.0 as usize);
81 stack.push(a.0 as usize);
82 stack.push(b.0 as usize);
83 }
84 }
85 }
86 live
87 }
88
89 fn collect_eval_order(&self, live: &HashSet<usize>, count: usize) -> Vec<(usize, Node)> {
91 (0..count)
92 .filter(|i| live.contains(i))
93 .map(|i| (i, self.node(ExprId(i as u32))))
94 .collect()
95 }
96}
97
98#[inline]
99fn eval_node(node: &Node, vals: &[f64], inputs: &[f64]) -> f64 {
100 match *node {
101 Node::Var(idx) => inputs[idx as usize],
102 Node::Lit(bits) => f64::from_bits(bits),
103 Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
104 Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
105 Node::Neg(a) => -vals[a.0 as usize],
106 Node::Recip(a) => 1.0 / vals[a.0 as usize],
107 Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
108 Node::Sin(a) => vals[a.0 as usize].sin(),
109 Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
110 Node::Exp2(a) => vals[a.0 as usize].exp2(),
111 Node::Log2(a) => vals[a.0 as usize].log2(),
112 Node::Select(c, a, b) => {
113 if vals[c.0 as usize] > 0.0 {
114 vals[a.0 as usize]
115 } else {
116 vals[b.0 as usize]
117 }
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    #[test]
    fn compile_add_lits() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let sum = g.add(a, b);
        let f = g.compile(sum);
        assert!((f(&[]) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn compile_with_vars() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(sum, x);
        let f = g.compile(prod);
        // (3 + 4) * 3 = 21
        assert!((f(&[3.0, 4.0]) - 21.0).abs() < 1e-10);
    }

    #[test]
    fn compile_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let f = g.compile(s);
        assert!((f(&[std::f64::consts::FRAC_PI_2]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn compile_many_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let f = g.compile_many(&[sum, prod]);
        let mut out = [0.0; 2];
        f(&[3.0, 4.0], &mut out);
        assert!((out[0] - 7.0).abs() < 1e-10);
        assert!((out[1] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn compile_dead_code_elimination() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let _dead = g.sin(x);
        let result = g.mul(x, x);
        let f = g.compile(result);
        assert!((f(&[5.0]) - 25.0).abs() < 1e-10);
    }

    #[test]
    fn compile_ignores_nodes_created_later() {
        // Nodes built after the compiled expression (even ones using it) are dead.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let square = g.mul(x, x);
        let _later = g.sin(square);
        let f = g.compile(square);
        assert!((f(&[5.0]) - 25.0).abs() < 1e-10);
    }

    #[test]
    fn compile_matches_eval() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let inputs = [3.0, 4.0];
        let eval_result: f64 = g.eval(dist, &inputs);
        let f = g.compile(dist);
        let compile_result = f(&inputs);
        assert!((eval_result - compile_result).abs() < 1e-10);
        assert!((compile_result - 5.0).abs() < 1e-10);
    }

    #[test]
    fn compile_many_empty_is_noop() {
        let g = ExprGraph::new();
        let f = g.compile_many(&[]);
        let mut out = [1.5; 2];
        f(&[], &mut out);
        assert_eq!(out, [1.5, 1.5]);
    }

    #[test]
    fn compile_many_follows_request_order() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        // Request the newest node first; outputs must follow the request order.
        let f = g.compile_many(&[prod, sum, x]);
        let mut out = [0.0; 3];
        f(&[3.0, 4.0], &mut out);
        assert!((out[0] - 12.0).abs() < 1e-10);
        assert!((out[1] - 7.0).abs() < 1e-10);
        assert!((out[2] - 3.0).abs() < 1e-10);
    }

    #[test]
    fn compile_many_leaves_extra_outputs_untouched() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let f = g.compile_many(&[sum]);
        let mut out = [0.0, -1.0];
        f(&[3.0, 4.0], &mut out);
        assert!((out[0] - 7.0).abs() < 1e-10);
        assert_eq!(out[1], -1.0);
    }

    #[test]
    #[should_panic]
    fn compile_many_panics_on_short_output() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let f = g.compile_many(&[sum, prod]);
        let mut out = [0.0; 1];
        f(&[3.0, 4.0], &mut out);
    }
}