Skip to main content

tang_expr/
eval.rs

1//! Generic evaluation of expression graphs.
2
3use tang::Scalar;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
8impl ExprGraph {
9    /// Evaluate an expression with concrete scalar inputs.
10    ///
11    /// `inputs[n]` provides the value for `Var(n)`. Walks the graph in
12    /// topological order (which is just index order, since children are
13    /// always created before parents).
14    pub fn eval<S: Scalar>(&self, expr: ExprId, inputs: &[S]) -> S {
15        let n = expr.0 as usize + 1;
16        let mut vals: Vec<S> = Vec::with_capacity(n);
17
18        for i in 0..n {
19            let v = match self.node(ExprId(i as u32)) {
20                Node::Var(idx) => inputs[idx as usize],
21                Node::Lit(bits) => S::from_f64(f64::from_bits(bits)),
22                Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
23                Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
24                Node::Neg(a) => -vals[a.0 as usize],
25                Node::Recip(a) => vals[a.0 as usize].recip(),
26                Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
27                Node::Sin(a) => vals[a.0 as usize].sin(),
28                Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
29                Node::Exp2(a) => {
30                    // exp2(x) = 2^x = exp(x * ln(2))
31                    let x = vals[a.0 as usize];
32                    (x * S::from_f64(std::f64::consts::LN_2)).exp()
33                }
34                Node::Log2(a) => {
35                    // log2(x) = ln(x) / ln(2)
36                    let x = vals[a.0 as usize];
37                    x.ln() * S::from_f64(std::f64::consts::LOG2_E)
38                }
39                Node::Select(c, a, b) => {
40                    S::select(vals[c.0 as usize], vals[a.0 as usize], vals[b.0 as usize])
41                }
42            };
43            vals.push(v);
44        }
45
46        vals[expr.0 as usize]
47    }
48
49    /// Evaluate multiple output expressions, sharing intermediate values.
50    pub fn eval_many<S: Scalar>(&self, exprs: &[ExprId], inputs: &[S]) -> Vec<S> {
51        if exprs.is_empty() {
52            return Vec::new();
53        }
54        let max_id = exprs.iter().map(|e| e.0).max().unwrap() as usize;
55        let n = max_id + 1;
56        let mut vals: Vec<S> = Vec::with_capacity(n);
57
58        for i in 0..n {
59            let v = match self.node(ExprId(i as u32)) {
60                Node::Var(idx) => inputs[idx as usize],
61                Node::Lit(bits) => S::from_f64(f64::from_bits(bits)),
62                Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
63                Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
64                Node::Neg(a) => -vals[a.0 as usize],
65                Node::Recip(a) => vals[a.0 as usize].recip(),
66                Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
67                Node::Sin(a) => vals[a.0 as usize].sin(),
68                Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
69                Node::Exp2(a) => {
70                    let x = vals[a.0 as usize];
71                    (x * S::from_f64(std::f64::consts::LN_2)).exp()
72                }
73                Node::Log2(a) => {
74                    let x = vals[a.0 as usize];
75                    x.ln() * S::from_f64(std::f64::consts::LOG2_E)
76                }
77                Node::Select(c, a, b) => {
78                    S::select(vals[c.0 as usize], vals[a.0 as usize], vals[b.0 as usize])
79                }
80            };
81            vals.push(v);
82        }
83
84        exprs.iter().map(|e| vals[e.0 as usize]).collect()
85    }
86}
87
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    /// Absolute tolerance shared by every numeric check below.
    const TOL: f64 = 1e-10;

    /// Fails the calling test unless `got` lies strictly within `TOL` of `want`.
    #[track_caller]
    fn assert_close(got: f64, want: f64) {
        assert!((got - want).abs() < TOL);
    }

    /// Builds `select(cond, 3.0, 7.0)` from literals and evaluates it as `f64`.
    fn eval_select_of_lits(cond: f64) -> f64 {
        let mut graph = ExprGraph::new();
        let c = graph.lit(cond);
        let if_true = graph.lit(3.0);
        let if_false = graph.lit(7.0);
        let chosen = graph.select(c, if_true, if_false);
        graph.eval(chosen, &[])
    }

    #[test]
    fn eval_add_lits() {
        let mut graph = ExprGraph::new();
        let lhs = graph.lit(3.0);
        let rhs = graph.lit(4.0);
        let total = graph.add(lhs, rhs);
        assert_close(graph.eval::<f64>(total, &[]), 7.0);
    }

    #[test]
    fn eval_with_vars() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let x_plus_y = graph.add(x, y);
        let expr = graph.mul(x_plus_y, x);
        // (x + y) * x with x = 3, y = 4 gives 7 * 3 = 21.
        assert_close(graph.eval::<f64>(expr, &[3.0, 4.0]), 21.0);
    }

    #[test]
    fn eval_sqrt() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let root = graph.sqrt(x);
        assert_close(graph.eval::<f64>(root, &[9.0]), 3.0);
    }

    #[test]
    fn eval_sin() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let sine = graph.sin(x);
        assert_close(graph.eval::<f64>(sine, &[std::f64::consts::FRAC_PI_2]), 1.0);
    }

    #[test]
    fn eval_select_positive_cond() {
        assert_close(eval_select_of_lits(1.0), 3.0);
    }

    #[test]
    fn eval_select_negative_cond() {
        assert_close(eval_select_of_lits(-1.0), 7.0);
    }

    #[test]
    fn eval_select_zero_cond() {
        // A zero condition is not strictly positive, so the false branch wins.
        assert_close(eval_select_of_lits(0.0), 7.0);
    }

    #[test]
    fn eval_many_outputs() {
        let mut graph = ExprGraph::new();
        let x = graph.var(0);
        let y = graph.var(1);
        let sum = graph.add(x, y);
        let product = graph.mul(x, y);
        let outputs: Vec<f64> = graph.eval_many(&[sum, product], &[3.0, 4.0]);
        assert_close(outputs[0], 7.0);
        assert_close(outputs[1], 12.0);
    }
}