Skip to main content

tang_expr/
scalar.rs

1//! `Scalar` trait implementation for `ExprId`.
2//!
3//! Every Scalar method decomposes into the 9 RISC primitives by inserting
4//! nodes into the thread-local expression graph.
5
6use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7
8use tang::Scalar;
9
10use crate::node::ExprId;
11use crate::with_graph;
12
13// --- Operator impls (all delegate to graph ops) ---
14
15impl Add for ExprId {
16    type Output = Self;
17    #[inline]
18    fn add(self, rhs: Self) -> Self {
19        with_graph(|g| g.add(self, rhs))
20    }
21}
22
23impl Sub for ExprId {
24    type Output = Self;
25    #[inline]
26    fn sub(self, rhs: Self) -> Self {
27        // sub(a, b) = add(a, neg(b))
28        let nb = with_graph(|g| g.neg(rhs));
29        with_graph(|g| g.add(self, nb))
30    }
31}
32
33impl Mul for ExprId {
34    type Output = Self;
35    #[inline]
36    fn mul(self, rhs: Self) -> Self {
37        with_graph(|g| g.mul(self, rhs))
38    }
39}
40
41impl Div for ExprId {
42    type Output = Self;
43    #[inline]
44    fn div(self, rhs: Self) -> Self {
45        // div(a, b) = mul(a, recip(b))
46        let rb = with_graph(|g| g.recip(rhs));
47        with_graph(|g| g.mul(self, rb))
48    }
49}
50
51impl Neg for ExprId {
52    type Output = Self;
53    #[inline]
54    fn neg(self) -> Self {
55        with_graph(|g| g.neg(self))
56    }
57}
58
59impl AddAssign for ExprId {
60    #[inline]
61    fn add_assign(&mut self, rhs: Self) {
62        *self = *self + rhs;
63    }
64}
65
66impl SubAssign for ExprId {
67    #[inline]
68    fn sub_assign(&mut self, rhs: Self) {
69        *self = *self - rhs;
70    }
71}
72
73impl MulAssign for ExprId {
74    #[inline]
75    fn mul_assign(&mut self, rhs: Self) {
76        *self = *self * rhs;
77    }
78}
79
80impl DivAssign for ExprId {
81    #[inline]
82    fn div_assign(&mut self, rhs: Self) {
83        *self = *self / rhs;
84    }
85}
86
87// --- Scalar impl ---
88
89impl Scalar for ExprId {
90    const ZERO: Self = ExprId::ZERO;
91    const ONE: Self = ExprId::ONE;
92    const TWO: Self = ExprId::TWO;
93    // These can't be const because they need graph insertion.
94    // We use ZERO as placeholder — the actual values are injected lazily.
95    const HALF: Self = ExprId(u32::MAX - 1);
96    const PI: Self = ExprId(u32::MAX - 2);
97    const TAU: Self = ExprId(u32::MAX - 3);
98    const FRAC_PI_2: Self = ExprId(u32::MAX - 4);
99    const EPSILON: Self = ExprId(u32::MAX - 5);
100    const INFINITY: Self = ExprId(u32::MAX - 6);
101    const NEG_INFINITY: Self = ExprId(u32::MAX - 7);
102
103    #[inline]
104    fn sqrt(self) -> Self {
105        with_graph(|g| g.sqrt(self))
106    }
107
108    #[inline]
109    fn abs(self) -> Self {
110        // abs(x) = sqrt(x * x)
111        let xx = with_graph(|g| g.mul(self, self));
112        with_graph(|g| g.sqrt(xx))
113    }
114
115    #[inline]
116    fn sin(self) -> Self {
117        with_graph(|g| g.sin(self))
118    }
119
120    #[inline]
121    fn cos(self) -> Self {
122        // cos(x) = sin(x + PI/2)
123        let half_pi = Self::from_f64(std::f64::consts::FRAC_PI_2);
124        let shifted = with_graph(|g| g.add(self, half_pi));
125        with_graph(|g| g.sin(shifted))
126    }
127
128    #[inline]
129    fn tan(self) -> Self {
130        // tan(x) = sin(x) * recip(cos(x))
131        let s = self.sin();
132        let c = self.cos();
133        let rc = with_graph(|g| g.recip(c));
134        with_graph(|g| g.mul(s, rc))
135    }
136
137    #[inline]
138    fn asin(self) -> Self {
139        // asin(x) = atan2(x, sqrt(1 - x*x))
140        let one = ExprId::ONE;
141        let xx = with_graph(|g| g.mul(self, self));
142        let diff = with_graph(|g| {
143            let neg_xx = g.neg(xx);
144            g.add(one, neg_xx)
145        });
146        let sq = with_graph(|g| g.sqrt(diff));
147        with_graph(|g| g.atan2(self, sq))
148    }
149
150    #[inline]
151    fn acos(self) -> Self {
152        // acos(x) = atan2(sqrt(1 - x*x), x)
153        let one = ExprId::ONE;
154        let xx = with_graph(|g| g.mul(self, self));
155        let diff = with_graph(|g| {
156            let neg_xx = g.neg(xx);
157            g.add(one, neg_xx)
158        });
159        let sq = with_graph(|g| g.sqrt(diff));
160        with_graph(|g| g.atan2(sq, self))
161    }
162
163    #[inline]
164    fn atan2(self, other: Self) -> Self {
165        with_graph(|g| g.atan2(self, other))
166    }
167
168    #[inline]
169    fn sin_cos(self) -> (Self, Self) {
170        (self.sin(), self.cos())
171    }
172
173    #[inline]
174    fn min(self, other: Self) -> Self {
175        // min(a, b) = 0.5 * (a + b - sqrt((a-b)^2))
176        let half = Self::from_f64(0.5);
177        let sum = self + other;
178        let diff = self - other;
179        let diff_sq = with_graph(|g| g.mul(diff, diff));
180        let abs_diff = with_graph(|g| g.sqrt(diff_sq));
181        let neg_abs = with_graph(|g| g.neg(abs_diff));
182        let inner = with_graph(|g| g.add(sum, neg_abs));
183        with_graph(|g| g.mul(half, inner))
184    }
185
186    #[inline]
187    fn max(self, other: Self) -> Self {
188        // max(a, b) = 0.5 * (a + b + sqrt((a-b)^2))
189        let half = Self::from_f64(0.5);
190        let sum = self + other;
191        let diff = self - other;
192        let diff_sq = with_graph(|g| g.mul(diff, diff));
193        let abs_diff = with_graph(|g| g.sqrt(diff_sq));
194        let inner = with_graph(|g| g.add(sum, abs_diff));
195        with_graph(|g| g.mul(half, inner))
196    }
197
198    #[inline]
199    fn clamp(self, lo: Self, hi: Self) -> Self {
200        self.max(lo).min(hi)
201    }
202
203    #[inline]
204    fn recip(self) -> Self {
205        with_graph(|g| g.recip(self))
206    }
207
208    #[inline]
209    fn powi(self, n: i32) -> Self {
210        match n {
211            0 => ExprId::ONE,
212            1 => self,
213            2 => with_graph(|g| g.mul(self, self)),
214            3 => {
215                let sq = with_graph(|g| g.mul(self, self));
216                with_graph(|g| g.mul(sq, self))
217            }
218            4 => {
219                let sq = with_graph(|g| g.mul(self, self));
220                with_graph(|g| g.mul(sq, sq))
221            }
222            -1 => with_graph(|g| g.recip(self)),
223            -2 => {
224                let sq = with_graph(|g| g.mul(self, self));
225                with_graph(|g| g.recip(sq))
226            }
227            _ => self.powf(Self::from_f64(n as f64)),
228        }
229    }
230
231    #[inline]
232    fn copysign(self, sign: Self) -> Self {
233        // copysign(x, s) = abs(x) * signum(s)
234        let ax = self.abs();
235        let ss = sign.signum();
236        with_graph(|g| g.mul(ax, ss))
237    }
238
239    #[inline]
240    fn signum(self) -> Self {
241        // signum(x) = x * recip(sqrt(x * x))
242        let xx = with_graph(|g| g.mul(self, self));
243        let abs_x = with_graph(|g| g.sqrt(xx));
244        let r = with_graph(|g| g.recip(abs_x));
245        with_graph(|g| g.mul(self, r))
246    }
247
248    #[inline]
249    fn floor(self) -> Self {
250        // Only works for literals
251        with_graph(|g| {
252            if let Some(v) = g.node(self).as_f64() {
253                g.lit(v.floor())
254            } else {
255                panic!("floor() requires a literal expression")
256            }
257        })
258    }
259
260    #[inline]
261    fn ceil(self) -> Self {
262        with_graph(|g| {
263            if let Some(v) = g.node(self).as_f64() {
264                g.lit(v.ceil())
265            } else {
266                panic!("ceil() requires a literal expression")
267            }
268        })
269    }
270
271    #[inline]
272    fn round(self) -> Self {
273        with_graph(|g| {
274            if let Some(v) = g.node(self).as_f64() {
275                g.lit(v.round())
276            } else {
277                panic!("round() requires a literal expression")
278            }
279        })
280    }
281
282    #[inline]
283    fn exp(self) -> Self {
284        // exp(x) = exp2(x * log2(e))
285        let log2_e = Self::from_f64(std::f64::consts::LOG2_E);
286        let scaled = with_graph(|g| g.mul(self, log2_e));
287        with_graph(|g| g.exp2(scaled))
288    }
289
290    #[inline]
291    fn ln(self) -> Self {
292        // ln(x) = log2(x) * ln(2)
293        let ln_2 = Self::from_f64(std::f64::consts::LN_2);
294        let l = with_graph(|g| g.log2(self));
295        with_graph(|g| g.mul(l, ln_2))
296    }
297
298    #[inline]
299    fn powf(self, p: Self) -> Self {
300        // powf(x, p) = exp2(p * log2(x))
301        let l = with_graph(|g| g.log2(self));
302        let pl = with_graph(|g| g.mul(p, l));
303        with_graph(|g| g.exp2(pl))
304    }
305
306    #[inline]
307    fn sinh(self) -> Self {
308        // sinh(x) = 0.5 * (exp(x) - exp(-x))
309        let half = Self::from_f64(0.5);
310        let ex = self.exp();
311        let neg_x = with_graph(|g| g.neg(self));
312        let enx = Scalar::exp(neg_x);
313        let diff = ex - enx;
314        with_graph(|g| g.mul(half, diff))
315    }
316
317    #[inline]
318    fn cosh(self) -> Self {
319        // cosh(x) = 0.5 * (exp(x) + exp(-x))
320        let half = Self::from_f64(0.5);
321        let ex = self.exp();
322        let neg_x = with_graph(|g| g.neg(self));
323        let enx = Scalar::exp(neg_x);
324        let sum = ex + enx;
325        with_graph(|g| g.mul(half, sum))
326    }
327
328    #[inline]
329    fn tanh(self) -> Self {
330        // tanh(x) = sinh(x) / cosh(x)
331        let s = self.sinh();
332        let c = self.cosh();
333        let rc = with_graph(|g| g.recip(c));
334        with_graph(|g| g.mul(s, rc))
335    }
336
337    #[inline]
338    fn acosh(self) -> Self {
339        // acosh(x) = ln(x + sqrt(x*x - 1))
340        let one = ExprId::ONE;
341        let xx = with_graph(|g| g.mul(self, self));
342        let diff = with_graph(|g| {
343            let neg_one = g.neg(one);
344            g.add(xx, neg_one)
345        });
346        let sq = with_graph(|g| g.sqrt(diff));
347        let sum = with_graph(|g| g.add(self, sq));
348        Scalar::ln(sum)
349    }
350
351    #[inline]
352    fn asinh(self) -> Self {
353        // asinh(x) = ln(x + sqrt(x*x + 1))
354        let one = ExprId::ONE;
355        let xx = with_graph(|g| g.mul(self, self));
356        let sum_inner = with_graph(|g| g.add(xx, one));
357        let sq = with_graph(|g| g.sqrt(sum_inner));
358        let sum = with_graph(|g| g.add(self, sq));
359        Scalar::ln(sum)
360    }
361
362    #[inline]
363    fn atanh(self) -> Self {
364        // atanh(x) = 0.5 * ln((1+x) / (1-x))
365        let half = Self::from_f64(0.5);
366        let one = ExprId::ONE;
367        let one_plus = with_graph(|g| g.add(one, self));
368        let neg_x = with_graph(|g| g.neg(self));
369        let one_minus = with_graph(|g| g.add(one, neg_x));
370        let ratio = one_plus / one_minus;
371        let l = Scalar::ln(ratio);
372        with_graph(|g| g.mul(half, l))
373    }
374
375    #[inline]
376    fn from_f64(v: f64) -> Self {
377        with_graph(|g| g.lit(v))
378    }
379
380    #[inline]
381    fn to_f64(self) -> f64 {
382        panic!("cannot evaluate symbolic ExprId to f64 — use ExprGraph::eval() instead")
383    }
384
385    #[inline]
386    fn from_i32(v: i32) -> Self {
387        with_graph(|g| g.lit(v as f64))
388    }
389
390    #[inline]
391    fn select(cond: Self, a: Self, b: Self) -> Self {
392        with_graph(|g| g.select(cond, a, b))
393    }
394}
395
396// --- Display for ExprId (needed by Scalar bound) ---
397// Note: ExprId already has Display in node.rs, just showing "eN".
398// The detailed expression display is in display.rs via ExprGraph::fmt_expr.
399
#[cfg(test)]
mod tests {
    use tang::Scalar;

    use crate::{trace, ExprId};

    /// Tracing `3 + 4` yields a graph that evaluates to 7.
    #[test]
    fn basic_arithmetic() {
        let (graph, sum) = trace(|| ExprId::from_f64(3.0) + ExprId::from_f64(4.0));
        // No variables are involved, so the input slice is empty.
        let evaluated = graph.eval::<f64>(sum, &[]);
        assert!((evaluated - 7.0).abs() < 1e-10);
    }

    /// A `0.0` literal resolves to the `ZERO` id and adds no node to the graph.
    #[test]
    fn var_trace() {
        let (graph, traced) = trace(|| <ExprId as Scalar>::from_f64(0.0));
        assert_eq!(traced, ExprId::ZERO);
        // Only ZERO, ONE and TWO are present.
        assert_eq!(graph.len(), 3);
    }

    /// Different literal values are assigned different ids.
    #[test]
    fn constants_are_lits() {
        let (_graph, (half, pi)) = trace(|| {
            (
                ExprId::from_f64(0.5),
                ExprId::from_f64(std::f64::consts::PI),
            )
        });
        assert_ne!(half, pi);
    }
}