Skip to main content

scivex_core/
jit.rs

//! Expression JIT for element-wise tensor operations.
//!
//! Instead of allocating an intermediate tensor for each element-wise
//! operation, this module builds an expression tree and evaluates it in a
//! single fused pass over the data. This eliminates temporary allocations and
//! improves cache locality for chains of element-wise ops.
//!
//! Despite the name, no machine code is generated: the tree is interpreted
//! recursively for every output element.
//!
//! # Examples
//!
//! ```
//! use scivex_core::Tensor;
//! use scivex_core::jit::Expr;
//!
//! let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
//! let b = Tensor::from_vec(vec![4.0_f64, 5.0, 6.0], vec![3]).unwrap();
//! let c = Tensor::from_vec(vec![0.5_f64, 0.5, 0.5], vec![3]).unwrap();
//!
//! // (a + b) * c — no intermediate tensor for a + b
//! let result = Expr::input(&a)
//!     .add(Expr::input(&b))
//!     .mul(Expr::input(&c))
//!     .eval()
//!     .unwrap();
//!
//! assert_eq!(result.as_slice(), &[2.5, 3.5, 4.5]);
//! ```
27
28use crate::Float;
29use crate::Tensor;
30use crate::error::{CoreError, Result};
31
32/// An expression node in the computation graph.
33///
34/// Each variant represents either a leaf (tensor input or scalar constant)
35/// or a fused element-wise operation. The tree is evaluated in one pass
36/// via [`Expr::eval`], avoiding intermediate tensor allocations.
37pub enum Expr<'a, T: Float> {
38    /// A tensor input (leaf node).
39    Input(&'a Tensor<T>),
40    /// A scalar constant, broadcast to match the output shape.
41    Scalar(T),
42    /// Element-wise addition.
43    Add(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
44    /// Element-wise subtraction.
45    Sub(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
46    /// Element-wise multiplication.
47    Mul(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
48    /// Element-wise division.
49    Div(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
50    /// Unary negation.
51    Neg(Box<Expr<'a, T>>),
52    /// Element-wise square root.
53    Sqrt(Box<Expr<'a, T>>),
54    /// Element-wise exponential.
55    Exp(Box<Expr<'a, T>>),
56    /// Element-wise natural logarithm.
57    Ln(Box<Expr<'a, T>>),
58    /// Element-wise absolute value.
59    Abs(Box<Expr<'a, T>>),
60    /// Element-wise sine.
61    Sin(Box<Expr<'a, T>>),
62    /// Element-wise cosine.
63    Cos(Box<Expr<'a, T>>),
64    /// Element-wise power.
65    Pow(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
66    /// Fused multiply-add: `a * b + c`.
67    Fma(Box<Expr<'a, T>>, Box<Expr<'a, T>>, Box<Expr<'a, T>>),
68    /// Clamp values to `[min, max]`.
69    Clamp(Box<Expr<'a, T>>, T, T),
70}
71
72#[allow(clippy::should_implement_trait)]
73impl<'a, T: Float> Expr<'a, T> {
74    /// Create an input expression referencing a tensor.
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// use scivex_core::Tensor;
80    /// use scivex_core::jit::Expr;
81    ///
82    /// let t = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
83    /// let result = Expr::input(&t).eval().unwrap();
84    /// assert_eq!(result.as_slice(), &[1.0, 2.0, 3.0]);
85    /// ```
86    pub fn input(tensor: &'a Tensor<T>) -> Self {
87        Expr::Input(tensor)
88    }
89
90    /// Create a scalar constant expression.
91    pub fn scalar(val: T) -> Self {
92        Expr::Scalar(val)
93    }
94
95    /// Element-wise addition: `self + other`.
96    pub fn add(self, other: Self) -> Self {
97        Expr::Add(Box::new(self), Box::new(other))
98    }
99
100    /// Element-wise subtraction: `self - other`.
101    pub fn sub(self, other: Self) -> Self {
102        Expr::Sub(Box::new(self), Box::new(other))
103    }
104
105    /// Element-wise multiplication: `self * other`.
106    pub fn mul(self, other: Self) -> Self {
107        Expr::Mul(Box::new(self), Box::new(other))
108    }
109
110    /// Element-wise division: `self / other`.
111    pub fn div(self, other: Self) -> Self {
112        Expr::Div(Box::new(self), Box::new(other))
113    }
114
115    /// Unary negation: `-self`.
116    pub fn neg(self) -> Self {
117        Expr::Neg(Box::new(self))
118    }
119
120    /// Element-wise square root.
121    pub fn sqrt(self) -> Self {
122        Expr::Sqrt(Box::new(self))
123    }
124
125    /// Element-wise exponential.
126    pub fn exp(self) -> Self {
127        Expr::Exp(Box::new(self))
128    }
129
130    /// Element-wise natural logarithm.
131    pub fn ln(self) -> Self {
132        Expr::Ln(Box::new(self))
133    }
134
135    /// Element-wise absolute value.
136    pub fn abs(self) -> Self {
137        Expr::Abs(Box::new(self))
138    }
139
140    /// Element-wise sine.
141    pub fn sin(self) -> Self {
142        Expr::Sin(Box::new(self))
143    }
144
145    /// Element-wise cosine.
146    pub fn cos(self) -> Self {
147        Expr::Cos(Box::new(self))
148    }
149
150    /// Element-wise power: `self ^ other`.
151    pub fn pow(self, other: Self) -> Self {
152        Expr::Pow(Box::new(self), Box::new(other))
153    }
154
155    /// Fused multiply-add: `self * b + c`.
156    pub fn fma(self, b: Self, c: Self) -> Self {
157        Expr::Fma(Box::new(self), Box::new(b), Box::new(c))
158    }
159
160    /// Clamp values to `[min, max]`.
161    pub fn clamp(self, min: T, max: T) -> Self {
162        Expr::Clamp(Box::new(self), min, max)
163    }
164
165    /// Evaluate the expression tree, producing a result tensor.
166    ///
167    /// All [`Expr::Input`] tensors referenced anywhere in the tree must have
168    /// the same shape. Scalar nodes are broadcast to match. Returns an error
169    /// if input shapes disagree or if no shape can be determined (i.e., the
170    /// entire expression is purely scalar with no tensor inputs).
171    pub fn eval(&self) -> Result<Tensor<T>> {
172        let shape = collect_shape(self)?;
173        let numel: usize = shape.iter().product();
174        let mut result = Vec::with_capacity(numel);
175        for i in 0..numel {
176            result.push(self.eval_at(i));
177        }
178        Tensor::from_vec(result, shape)
179    }
180
181    /// Evaluate the expression at a single flat index.
182    fn eval_at(&self, idx: usize) -> T {
183        match self {
184            Expr::Input(t) => t.as_slice()[idx],
185            Expr::Scalar(v) => *v,
186            Expr::Add(a, b) => a.eval_at(idx) + b.eval_at(idx),
187            Expr::Sub(a, b) => a.eval_at(idx) - b.eval_at(idx),
188            Expr::Mul(a, b) => a.eval_at(idx) * b.eval_at(idx),
189            Expr::Div(a, b) => a.eval_at(idx) / b.eval_at(idx),
190            Expr::Neg(a) => T::zero() - a.eval_at(idx),
191            Expr::Sqrt(a) => a.eval_at(idx).sqrt(),
192            Expr::Exp(a) => a.eval_at(idx).exp(),
193            Expr::Ln(a) => a.eval_at(idx).ln(),
194            Expr::Abs(a) => a.eval_at(idx).abs(),
195            Expr::Sin(a) => a.eval_at(idx).sin(),
196            Expr::Cos(a) => a.eval_at(idx).cos(),
197            Expr::Pow(a, b) => a.eval_at(idx).powf(b.eval_at(idx)),
198            Expr::Fma(a, b, c) => a.eval_at(idx) * b.eval_at(idx) + c.eval_at(idx),
199            Expr::Clamp(a, min, max) => {
200                let v = a.eval_at(idx);
201                if v < *min {
202                    *min
203                } else if v > *max {
204                    *max
205                } else {
206                    v
207                }
208            }
209        }
210    }
211}
212
213/// Traverse the expression tree, collect all input tensor shapes, and verify
214/// they are identical. Returns the common shape, or an error if shapes differ.
215///
216/// If the expression contains no `Input` nodes (pure scalar), returns a
217/// scalar shape `[1]`.
218fn collect_shape<T: Float>(expr: &Expr<'_, T>) -> Result<Vec<usize>> {
219    let mut shape: Option<Vec<usize>> = None;
220    collect_shape_inner(expr, &mut shape)?;
221    Ok(shape.unwrap_or_else(|| vec![1]))
222}
223
224fn collect_shape_inner<T: Float>(expr: &Expr<'_, T>, shape: &mut Option<Vec<usize>>) -> Result<()> {
225    match expr {
226        Expr::Input(t) => {
227            let s = t.shape();
228            match shape {
229                Some(existing) if existing.as_slice() != s => {
230                    return Err(CoreError::DimensionMismatch {
231                        expected: existing.clone(),
232                        got: s.to_vec(),
233                    });
234                }
235                None => {
236                    *shape = Some(s.to_vec());
237                }
238                _ => {}
239            }
240            Ok(())
241        }
242        Expr::Scalar(_) => Ok(()),
243        Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) | Expr::Div(a, b) | Expr::Pow(a, b) => {
244            collect_shape_inner(a, shape)?;
245            collect_shape_inner(b, shape)
246        }
247        Expr::Neg(a)
248        | Expr::Sqrt(a)
249        | Expr::Exp(a)
250        | Expr::Ln(a)
251        | Expr::Abs(a)
252        | Expr::Sin(a)
253        | Expr::Cos(a)
254        | Expr::Clamp(a, _, _) => collect_shape_inner(a, shape),
255        Expr::Fma(a, b, c) => {
256            collect_shape_inner(a, shape)?;
257            collect_shape_inner(b, shape)?;
258            collect_shape_inner(c, shape)
259        }
260    }
261}
262
263/// Convenience function: evaluate an expression built from tensors.
264///
265/// Equivalent to calling [`Expr::eval`] directly.
266pub fn eval_expr<T: Float>(expr: &Expr<'_, T>) -> Result<Tensor<T>> {
267    expr.eval()
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_basic_arithmetic() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], vec![2, 2]).unwrap();

        // (a + b) * c
        let result = Expr::input(&a)
            .add(Expr::input(&b))
            .mul(Expr::input(&c))
            .eval()
            .unwrap();

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice(), &[22.0, 44.0, 66.0, 88.0]);
    }

    #[test]
    fn test_expr_unary_ops() {
        let a = Tensor::from_vec(vec![-4.0_f64, -9.0, -16.0], vec![3]).unwrap();

        // sqrt(abs(a))
        let result = Expr::input(&a).abs().sqrt().eval().unwrap();

        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_expr_fma() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();

        // a * b + c via fma
        let result = Expr::input(&a)
            .fma(Expr::input(&b), Expr::input(&c))
            .eval()
            .unwrap();

        // Expected: [1*4+10, 2*5+20, 3*6+30] = [14, 30, 48]
        assert_eq!(result.as_slice(), &[14.0, 30.0, 48.0]);
    }

    #[test]
    fn test_expr_scalar_broadcast() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // a + 2.0
        let result = Expr::input(&a).add(Expr::scalar(2.0)).eval().unwrap();

        assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_expr_pure_scalar_has_unit_shape() {
        // With no tensor inputs the shape falls back to `[1]`.
        let result = Expr::scalar(2.0_f64)
            .add(Expr::scalar(3.0))
            .eval()
            .unwrap();

        assert_eq!(result.shape(), &[1]);
        assert_eq!(result.as_slice(), &[5.0]);
    }

    #[test]
    fn test_expr_sub_div_neg() {
        let a = Tensor::from_vec(vec![8.0_f64, 6.0, 4.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![2.0, 3.0, 4.0], vec![3]).unwrap();

        // -((a - b) / b)
        let result = Expr::input(&a)
            .sub(Expr::input(&b))
            .div(Expr::input(&b))
            .neg()
            .eval()
            .unwrap();

        // Expected: [-(6/2), -(3/3), -(0/4)] = [-3, -1, 0]
        assert_eq!(result.as_slice(), &[-3.0, -1.0, 0.0]);
    }

    #[test]
    fn test_expr_clamp() {
        let a = Tensor::from_vec(vec![-1.0_f64, 0.5, 2.0], vec![3]).unwrap();

        let result = Expr::input(&a).clamp(0.0, 1.0).eval().unwrap();

        assert_eq!(result.as_slice(), &[0.0, 0.5, 1.0]);
    }

    #[test]
    fn test_expr_pow() {
        let a = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();

        // a ^ 2.0
        let result = Expr::input(&a).pow(Expr::scalar(2.0)).eval().unwrap();

        let expected = [4.0_f64, 9.0];
        for (i, (&got, &exp)) in result.as_slice().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-12,
                "index {i}: got {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_expr_shape_mismatch() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let err = Expr::input(&a).add(Expr::input(&b)).eval();
        assert!(err.is_err());

        match err.unwrap_err() {
            CoreError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, vec![3]);
                assert_eq!(got, vec![4]);
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_expr_complex_chain() {
        // exp(a * 0.5) + cos(b)
        let a = Tensor::from_vec(vec![0.0_f64, 2.0, 4.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![0.0, core::f64::consts::PI, 0.0], vec![3]).unwrap();

        let result = Expr::input(&a)
            .mul(Expr::scalar(0.5))
            .exp()
            .add(Expr::input(&b).cos())
            .eval()
            .unwrap();

        let expected = [
            (0.0_f64 * 0.5).exp() + 0.0_f64.cos(), // 1.0 + 1.0 = 2.0
            (2.0_f64 * 0.5).exp() + core::f64::consts::PI.cos(), // e^1 + (-1)
            (4.0_f64 * 0.5).exp() + 0.0_f64.cos(), // e^2 + 1.0
        ];

        let result_slice = result.as_slice();
        for (i, (&got, &exp)) in result_slice.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-12,
                "index {i}: got {got}, expected {exp}"
            );
        }
    }
}