1use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
7
8use tang::Scalar;
9
10use crate::node::ExprId;
11use crate::with_graph;
12
13impl Add for ExprId {
16 type Output = Self;
17 #[inline]
18 fn add(self, rhs: Self) -> Self {
19 with_graph(|g| g.add(self, rhs))
20 }
21}
22
23impl Sub for ExprId {
24 type Output = Self;
25 #[inline]
26 fn sub(self, rhs: Self) -> Self {
27 let nb = with_graph(|g| g.neg(rhs));
29 with_graph(|g| g.add(self, nb))
30 }
31}
32
33impl Mul for ExprId {
34 type Output = Self;
35 #[inline]
36 fn mul(self, rhs: Self) -> Self {
37 with_graph(|g| g.mul(self, rhs))
38 }
39}
40
41impl Div for ExprId {
42 type Output = Self;
43 #[inline]
44 fn div(self, rhs: Self) -> Self {
45 let rb = with_graph(|g| g.recip(rhs));
47 with_graph(|g| g.mul(self, rb))
48 }
49}
50
51impl Neg for ExprId {
52 type Output = Self;
53 #[inline]
54 fn neg(self) -> Self {
55 with_graph(|g| g.neg(self))
56 }
57}
58
59impl AddAssign for ExprId {
60 #[inline]
61 fn add_assign(&mut self, rhs: Self) {
62 *self = *self + rhs;
63 }
64}
65
66impl SubAssign for ExprId {
67 #[inline]
68 fn sub_assign(&mut self, rhs: Self) {
69 *self = *self - rhs;
70 }
71}
72
73impl MulAssign for ExprId {
74 #[inline]
75 fn mul_assign(&mut self, rhs: Self) {
76 *self = *self * rhs;
77 }
78}
79
80impl DivAssign for ExprId {
81 #[inline]
82 fn div_assign(&mut self, rhs: Self) {
83 *self = *self / rhs;
84 }
85}
86
87impl Scalar for ExprId {
90 const ZERO: Self = ExprId::ZERO;
91 const ONE: Self = ExprId::ONE;
92 const TWO: Self = ExprId::TWO;
93 const HALF: Self = ExprId(u32::MAX - 1);
96 const PI: Self = ExprId(u32::MAX - 2);
97 const TAU: Self = ExprId(u32::MAX - 3);
98 const FRAC_PI_2: Self = ExprId(u32::MAX - 4);
99 const EPSILON: Self = ExprId(u32::MAX - 5);
100 const INFINITY: Self = ExprId(u32::MAX - 6);
101 const NEG_INFINITY: Self = ExprId(u32::MAX - 7);
102
103 #[inline]
104 fn sqrt(self) -> Self {
105 with_graph(|g| g.sqrt(self))
106 }
107
108 #[inline]
109 fn abs(self) -> Self {
110 let xx = with_graph(|g| g.mul(self, self));
112 with_graph(|g| g.sqrt(xx))
113 }
114
115 #[inline]
116 fn sin(self) -> Self {
117 with_graph(|g| g.sin(self))
118 }
119
120 #[inline]
121 fn cos(self) -> Self {
122 let half_pi = Self::from_f64(std::f64::consts::FRAC_PI_2);
124 let shifted = with_graph(|g| g.add(self, half_pi));
125 with_graph(|g| g.sin(shifted))
126 }
127
128 #[inline]
129 fn tan(self) -> Self {
130 let s = self.sin();
132 let c = self.cos();
133 let rc = with_graph(|g| g.recip(c));
134 with_graph(|g| g.mul(s, rc))
135 }
136
137 #[inline]
138 fn asin(self) -> Self {
139 let one = ExprId::ONE;
141 let xx = with_graph(|g| g.mul(self, self));
142 let diff = with_graph(|g| {
143 let neg_xx = g.neg(xx);
144 g.add(one, neg_xx)
145 });
146 let sq = with_graph(|g| g.sqrt(diff));
147 with_graph(|g| g.atan2(self, sq))
148 }
149
150 #[inline]
151 fn acos(self) -> Self {
152 let one = ExprId::ONE;
154 let xx = with_graph(|g| g.mul(self, self));
155 let diff = with_graph(|g| {
156 let neg_xx = g.neg(xx);
157 g.add(one, neg_xx)
158 });
159 let sq = with_graph(|g| g.sqrt(diff));
160 with_graph(|g| g.atan2(sq, self))
161 }
162
163 #[inline]
164 fn atan2(self, other: Self) -> Self {
165 with_graph(|g| g.atan2(self, other))
166 }
167
168 #[inline]
169 fn sin_cos(self) -> (Self, Self) {
170 (self.sin(), self.cos())
171 }
172
173 #[inline]
174 fn min(self, other: Self) -> Self {
175 let half = Self::from_f64(0.5);
177 let sum = self + other;
178 let diff = self - other;
179 let diff_sq = with_graph(|g| g.mul(diff, diff));
180 let abs_diff = with_graph(|g| g.sqrt(diff_sq));
181 let neg_abs = with_graph(|g| g.neg(abs_diff));
182 let inner = with_graph(|g| g.add(sum, neg_abs));
183 with_graph(|g| g.mul(half, inner))
184 }
185
186 #[inline]
187 fn max(self, other: Self) -> Self {
188 let half = Self::from_f64(0.5);
190 let sum = self + other;
191 let diff = self - other;
192 let diff_sq = with_graph(|g| g.mul(diff, diff));
193 let abs_diff = with_graph(|g| g.sqrt(diff_sq));
194 let inner = with_graph(|g| g.add(sum, abs_diff));
195 with_graph(|g| g.mul(half, inner))
196 }
197
198 #[inline]
199 fn clamp(self, lo: Self, hi: Self) -> Self {
200 self.max(lo).min(hi)
201 }
202
203 #[inline]
204 fn recip(self) -> Self {
205 with_graph(|g| g.recip(self))
206 }
207
208 #[inline]
209 fn powi(self, n: i32) -> Self {
210 match n {
211 0 => ExprId::ONE,
212 1 => self,
213 2 => with_graph(|g| g.mul(self, self)),
214 3 => {
215 let sq = with_graph(|g| g.mul(self, self));
216 with_graph(|g| g.mul(sq, self))
217 }
218 4 => {
219 let sq = with_graph(|g| g.mul(self, self));
220 with_graph(|g| g.mul(sq, sq))
221 }
222 -1 => with_graph(|g| g.recip(self)),
223 -2 => {
224 let sq = with_graph(|g| g.mul(self, self));
225 with_graph(|g| g.recip(sq))
226 }
227 _ => self.powf(Self::from_f64(n as f64)),
228 }
229 }
230
231 #[inline]
232 fn copysign(self, sign: Self) -> Self {
233 let ax = self.abs();
235 let ss = sign.signum();
236 with_graph(|g| g.mul(ax, ss))
237 }
238
239 #[inline]
240 fn signum(self) -> Self {
241 let xx = with_graph(|g| g.mul(self, self));
243 let abs_x = with_graph(|g| g.sqrt(xx));
244 let r = with_graph(|g| g.recip(abs_x));
245 with_graph(|g| g.mul(self, r))
246 }
247
248 #[inline]
249 fn floor(self) -> Self {
250 with_graph(|g| {
252 if let Some(v) = g.node(self).as_f64() {
253 g.lit(v.floor())
254 } else {
255 panic!("floor() requires a literal expression")
256 }
257 })
258 }
259
260 #[inline]
261 fn ceil(self) -> Self {
262 with_graph(|g| {
263 if let Some(v) = g.node(self).as_f64() {
264 g.lit(v.ceil())
265 } else {
266 panic!("ceil() requires a literal expression")
267 }
268 })
269 }
270
271 #[inline]
272 fn round(self) -> Self {
273 with_graph(|g| {
274 if let Some(v) = g.node(self).as_f64() {
275 g.lit(v.round())
276 } else {
277 panic!("round() requires a literal expression")
278 }
279 })
280 }
281
282 #[inline]
283 fn exp(self) -> Self {
284 let log2_e = Self::from_f64(std::f64::consts::LOG2_E);
286 let scaled = with_graph(|g| g.mul(self, log2_e));
287 with_graph(|g| g.exp2(scaled))
288 }
289
290 #[inline]
291 fn ln(self) -> Self {
292 let ln_2 = Self::from_f64(std::f64::consts::LN_2);
294 let l = with_graph(|g| g.log2(self));
295 with_graph(|g| g.mul(l, ln_2))
296 }
297
298 #[inline]
299 fn powf(self, p: Self) -> Self {
300 let l = with_graph(|g| g.log2(self));
302 let pl = with_graph(|g| g.mul(p, l));
303 with_graph(|g| g.exp2(pl))
304 }
305
306 #[inline]
307 fn sinh(self) -> Self {
308 let half = Self::from_f64(0.5);
310 let ex = self.exp();
311 let neg_x = with_graph(|g| g.neg(self));
312 let enx = Scalar::exp(neg_x);
313 let diff = ex - enx;
314 with_graph(|g| g.mul(half, diff))
315 }
316
317 #[inline]
318 fn cosh(self) -> Self {
319 let half = Self::from_f64(0.5);
321 let ex = self.exp();
322 let neg_x = with_graph(|g| g.neg(self));
323 let enx = Scalar::exp(neg_x);
324 let sum = ex + enx;
325 with_graph(|g| g.mul(half, sum))
326 }
327
328 #[inline]
329 fn tanh(self) -> Self {
330 let s = self.sinh();
332 let c = self.cosh();
333 let rc = with_graph(|g| g.recip(c));
334 with_graph(|g| g.mul(s, rc))
335 }
336
337 #[inline]
338 fn acosh(self) -> Self {
339 let one = ExprId::ONE;
341 let xx = with_graph(|g| g.mul(self, self));
342 let diff = with_graph(|g| {
343 let neg_one = g.neg(one);
344 g.add(xx, neg_one)
345 });
346 let sq = with_graph(|g| g.sqrt(diff));
347 let sum = with_graph(|g| g.add(self, sq));
348 Scalar::ln(sum)
349 }
350
351 #[inline]
352 fn asinh(self) -> Self {
353 let one = ExprId::ONE;
355 let xx = with_graph(|g| g.mul(self, self));
356 let sum_inner = with_graph(|g| g.add(xx, one));
357 let sq = with_graph(|g| g.sqrt(sum_inner));
358 let sum = with_graph(|g| g.add(self, sq));
359 Scalar::ln(sum)
360 }
361
362 #[inline]
363 fn atanh(self) -> Self {
364 let half = Self::from_f64(0.5);
366 let one = ExprId::ONE;
367 let one_plus = with_graph(|g| g.add(one, self));
368 let neg_x = with_graph(|g| g.neg(self));
369 let one_minus = with_graph(|g| g.add(one, neg_x));
370 let ratio = one_plus / one_minus;
371 let l = Scalar::ln(ratio);
372 with_graph(|g| g.mul(half, l))
373 }
374
375 #[inline]
376 fn from_f64(v: f64) -> Self {
377 with_graph(|g| g.lit(v))
378 }
379
380 #[inline]
381 fn to_f64(self) -> f64 {
382 panic!("cannot evaluate symbolic ExprId to f64 — use ExprGraph::eval() instead")
383 }
384
385 #[inline]
386 fn from_i32(v: i32) -> Self {
387 with_graph(|g| g.lit(v as f64))
388 }
389
390 #[inline]
391 fn select(cond: Self, a: Self, b: Self) -> Self {
392 with_graph(|g| g.select(cond, a, b))
393 }
394}
395
396#[cfg(test)]
401mod tests {
402 use tang::Scalar;
403
404 use crate::{trace, ExprId};
405
406 #[test]
407 fn basic_arithmetic() {
408 let (g, result) = trace(|| {
409 let x = ExprId::from_f64(3.0);
410 let y = ExprId::from_f64(4.0);
411 x + y
412 });
413 let val = g.eval::<f64>(result, &[]);
415 assert!((val - 7.0).abs() < 1e-10);
416 }
417
418 #[test]
419 fn var_trace() {
420 let (g, result) = trace(|| {
421 let x: ExprId = Scalar::from_f64(0.0); x
423 });
424 assert_eq!(result, ExprId::ZERO);
425 assert_eq!(g.len(), 3); }
427
428 #[test]
429 fn constants_are_lits() {
430 let (_g, (half, pi)) = trace(|| {
431 let h = ExprId::from_f64(0.5);
432 let p = ExprId::from_f64(std::f64::consts::PI);
433 (h, p)
434 });
435 assert_ne!(half, pi);
436 }
437}