1use tang::Scalar;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9 pub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S {
15 let n = expr.0 as usize + 1;
16 let mut vals: Vec<S> = Vec::with_capacity(n);
17
18 for i in 0..n {
19 let v = match self.node(ExprId(i as u32)) {
20 Node::Var(idx) => inputs[idx as usize],
21 Node::Lit(bits) => S::from_f64(f64::from_bits(bits)),
22 Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
23 Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
24 Node::Neg(a) => -vals[a.0 as usize],
25 Node::Recip(a) => vals[a.0 as usize].recip(),
26 Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
27 Node::Sin(a) => vals[a.0 as usize].sin(),
28 Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
29 Node::Exp2(a) => {
30 let x = vals[a.0 as usize];
32 (x * S::from_f64(std::f64::consts::LN_2)).exp()
33 }
34 Node::Log2(a) => {
35 let x = vals[a.0 as usize];
37 x.ln() * S::from_f64(std::f64::consts::LOG2_E)
38 }
39 Node::Select(c, a, b) => {
40 S::select(vals[c.0 as usize], vals[a.0 as usize], vals[b.0 as usize])
41 }
42 };
43 vals.push(v);
44 }
45
46 vals[expr.0 as usize]
47 }
48
49 pub fn eval_many<S: Scalar>(&self, exprs: &[ExprId], inputs: &[S]) -> Vec<S> {
51 if exprs.is_empty() {
52 return Vec::new();
53 }
54 let max_id = exprs.iter().map(|e| e.0).max().unwrap() as usize;
55 let n = max_id + 1;
56 let mut vals: Vec<S> = Vec::with_capacity(n);
57
58 for i in 0..n {
59 let v = match self.node(ExprId(i as u32)) {
60 Node::Var(idx) => inputs[idx as usize],
61 Node::Lit(bits) => S::from_f64(f64::from_bits(bits)),
62 Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
63 Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
64 Node::Neg(a) => -vals[a.0 as usize],
65 Node::Recip(a) => vals[a.0 as usize].recip(),
66 Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
67 Node::Sin(a) => vals[a.0 as usize].sin(),
68 Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
69 Node::Exp2(a) => {
70 let x = vals[a.0 as usize];
71 (x * S::from_f64(std::f64::consts::LN_2)).exp()
72 }
73 Node::Log2(a) => {
74 let x = vals[a.0 as usize];
75 x.ln() * S::from_f64(std::f64::consts::LOG2_E)
76 }
77 Node::Select(c, a, b) => {
78 S::select(vals[c.0 as usize], vals[a.0 as usize], vals[b.0 as usize])
79 }
80 };
81 vals.push(v);
82 }
83
84 exprs.iter().map(|e| vals[e.0 as usize]).collect()
85 }
86}
87
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    /// Absolute tolerance used when comparing evaluated results with expected values.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` differs from `expected` by less than [`TOLERANCE`].
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// Adding two literals: 3 + 4 = 7.
    #[test]
    fn eval_add_lits() {
        let mut graph = ExprGraph::new();
        let three = graph.lit(3.0);
        let four = graph.lit(4.0);
        let total = graph.add(three, four);
        let got: f64 = graph.eval(total, &[]);
        assert_close(got, 7.0);
    }

    /// Variables are read from the input slice: (3 + 4) * 3 = 21.
    #[test]
    fn eval_with_vars() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let x_plus_y = graph.add(x, y);
        let scaled = graph.mul(x_plus_y, x);
        let got: f64 = graph.eval(scaled, &[3.0, 4.0]);
        assert_close(got, 21.0);
    }

    /// sqrt(9) = 3.
    #[test]
    fn eval_sqrt() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let root = graph.sqrt(x);
        let got: f64 = graph.eval(root, &[9.0]);
        assert_close(got, 3.0);
    }

    /// sin(pi / 2) = 1.
    #[test]
    fn eval_sin() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let sine = graph.sin(x);
        let got: f64 = graph.eval(sine, &[std::f64::consts::FRAC_PI_2]);
        assert_close(got, 1.0);
    }

    /// A positive condition is expected to pick the first branch.
    #[test]
    fn eval_select_positive_cond() {
        let mut graph = ExprGraph::new();
        let cond = graph.lit(1.0);
        let first = graph.lit(3.0);
        let second = graph.lit(7.0);
        let chosen = graph.select(cond, first, second);
        let got: f64 = graph.eval(chosen, &[]);
        assert_close(got, 3.0);
    }

    /// A negative condition is expected to pick the second branch.
    #[test]
    fn eval_select_negative_cond() {
        let mut graph = ExprGraph::new();
        let cond = graph.lit(-1.0);
        let first = graph.lit(3.0);
        let second = graph.lit(7.0);
        let chosen = graph.select(cond, first, second);
        let got: f64 = graph.eval(chosen, &[]);
        assert_close(got, 7.0);
    }

    /// A zero condition is expected to pick the second branch.
    #[test]
    fn eval_select_zero_cond() {
        let mut graph = ExprGraph::new();
        let cond = graph.lit(0.0);
        let first = graph.lit(3.0);
        let second = graph.lit(7.0);
        let chosen = graph.select(cond, first, second);
        let got: f64 = graph.eval(chosen, &[]);
        assert_close(got, 7.0);
    }

    /// `eval_many` returns one value per requested expression, in request order.
    #[test]
    fn eval_many_outputs() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let total = graph.add(x, y);
        let product = graph.mul(x, y);
        let outputs: Vec<f64> = graph.eval_many(&[total, product], &[3.0, 4.0]);
        assert_close(outputs[0], 7.0);
        assert_close(outputs[1], 12.0);
    }
}