1use crate::Float;
29use crate::Tensor;
30use crate::error::{CoreError, Result};
31
32pub enum Expr<'a, T: Float> {
38 Input(&'a Tensor<T>),
40 Scalar(T),
42 Add(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
44 Sub(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
46 Mul(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
48 Div(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
50 Neg(Box<Expr<'a, T>>),
52 Sqrt(Box<Expr<'a, T>>),
54 Exp(Box<Expr<'a, T>>),
56 Ln(Box<Expr<'a, T>>),
58 Abs(Box<Expr<'a, T>>),
60 Sin(Box<Expr<'a, T>>),
62 Cos(Box<Expr<'a, T>>),
64 Pow(Box<Expr<'a, T>>, Box<Expr<'a, T>>),
66 Fma(Box<Expr<'a, T>>, Box<Expr<'a, T>>, Box<Expr<'a, T>>),
68 Clamp(Box<Expr<'a, T>>, T, T),
70}
71
72#[allow(clippy::should_implement_trait)]
73impl<'a, T: Float> Expr<'a, T> {
74 pub fn input(tensor: &'a Tensor<T>) -> Self {
87 Expr::Input(tensor)
88 }
89
90 pub fn scalar(val: T) -> Self {
92 Expr::Scalar(val)
93 }
94
95 pub fn add(self, other: Self) -> Self {
97 Expr::Add(Box::new(self), Box::new(other))
98 }
99
100 pub fn sub(self, other: Self) -> Self {
102 Expr::Sub(Box::new(self), Box::new(other))
103 }
104
105 pub fn mul(self, other: Self) -> Self {
107 Expr::Mul(Box::new(self), Box::new(other))
108 }
109
110 pub fn div(self, other: Self) -> Self {
112 Expr::Div(Box::new(self), Box::new(other))
113 }
114
115 pub fn neg(self) -> Self {
117 Expr::Neg(Box::new(self))
118 }
119
120 pub fn sqrt(self) -> Self {
122 Expr::Sqrt(Box::new(self))
123 }
124
125 pub fn exp(self) -> Self {
127 Expr::Exp(Box::new(self))
128 }
129
130 pub fn ln(self) -> Self {
132 Expr::Ln(Box::new(self))
133 }
134
135 pub fn abs(self) -> Self {
137 Expr::Abs(Box::new(self))
138 }
139
140 pub fn sin(self) -> Self {
142 Expr::Sin(Box::new(self))
143 }
144
145 pub fn cos(self) -> Self {
147 Expr::Cos(Box::new(self))
148 }
149
150 pub fn pow(self, other: Self) -> Self {
152 Expr::Pow(Box::new(self), Box::new(other))
153 }
154
155 pub fn fma(self, b: Self, c: Self) -> Self {
157 Expr::Fma(Box::new(self), Box::new(b), Box::new(c))
158 }
159
160 pub fn clamp(self, min: T, max: T) -> Self {
162 Expr::Clamp(Box::new(self), min, max)
163 }
164
165 pub fn eval(&self) -> Result<Tensor<T>> {
172 let shape = collect_shape(self)?;
173 let numel: usize = shape.iter().product();
174 let mut result = Vec::with_capacity(numel);
175 for i in 0..numel {
176 result.push(self.eval_at(i));
177 }
178 Tensor::from_vec(result, shape)
179 }
180
181 fn eval_at(&self, idx: usize) -> T {
183 match self {
184 Expr::Input(t) => t.as_slice()[idx],
185 Expr::Scalar(v) => *v,
186 Expr::Add(a, b) => a.eval_at(idx) + b.eval_at(idx),
187 Expr::Sub(a, b) => a.eval_at(idx) - b.eval_at(idx),
188 Expr::Mul(a, b) => a.eval_at(idx) * b.eval_at(idx),
189 Expr::Div(a, b) => a.eval_at(idx) / b.eval_at(idx),
190 Expr::Neg(a) => T::zero() - a.eval_at(idx),
191 Expr::Sqrt(a) => a.eval_at(idx).sqrt(),
192 Expr::Exp(a) => a.eval_at(idx).exp(),
193 Expr::Ln(a) => a.eval_at(idx).ln(),
194 Expr::Abs(a) => a.eval_at(idx).abs(),
195 Expr::Sin(a) => a.eval_at(idx).sin(),
196 Expr::Cos(a) => a.eval_at(idx).cos(),
197 Expr::Pow(a, b) => a.eval_at(idx).powf(b.eval_at(idx)),
198 Expr::Fma(a, b, c) => a.eval_at(idx) * b.eval_at(idx) + c.eval_at(idx),
199 Expr::Clamp(a, min, max) => {
200 let v = a.eval_at(idx);
201 if v < *min {
202 *min
203 } else if v > *max {
204 *max
205 } else {
206 v
207 }
208 }
209 }
210 }
211}
212
213fn collect_shape<T: Float>(expr: &Expr<'_, T>) -> Result<Vec<usize>> {
219 let mut shape: Option<Vec<usize>> = None;
220 collect_shape_inner(expr, &mut shape)?;
221 Ok(shape.unwrap_or_else(|| vec![1]))
222}
223
224fn collect_shape_inner<T: Float>(expr: &Expr<'_, T>, shape: &mut Option<Vec<usize>>) -> Result<()> {
225 match expr {
226 Expr::Input(t) => {
227 let s = t.shape();
228 match shape {
229 Some(existing) if existing.as_slice() != s => {
230 return Err(CoreError::DimensionMismatch {
231 expected: existing.clone(),
232 got: s.to_vec(),
233 });
234 }
235 None => {
236 *shape = Some(s.to_vec());
237 }
238 _ => {}
239 }
240 Ok(())
241 }
242 Expr::Scalar(_) => Ok(()),
243 Expr::Add(a, b) | Expr::Sub(a, b) | Expr::Mul(a, b) | Expr::Div(a, b) | Expr::Pow(a, b) => {
244 collect_shape_inner(a, shape)?;
245 collect_shape_inner(b, shape)
246 }
247 Expr::Neg(a)
248 | Expr::Sqrt(a)
249 | Expr::Exp(a)
250 | Expr::Ln(a)
251 | Expr::Abs(a)
252 | Expr::Sin(a)
253 | Expr::Cos(a)
254 | Expr::Clamp(a, _, _) => collect_shape_inner(a, shape),
255 Expr::Fma(a, b, c) => {
256 collect_shape_inner(a, shape)?;
257 collect_shape_inner(b, shape)?;
258 collect_shape_inner(c, shape)
259 }
260 }
261}
262
263pub fn eval_expr<T: Float>(expr: &Expr<'_, T>) -> Result<Tensor<T>> {
267 expr.eval()
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_basic_arithmetic() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], vec![2, 2]).unwrap();
        let c = Tensor::from_vec(vec![2.0, 2.0, 2.0, 2.0], vec![2, 2]).unwrap();

        // (a + b) * c
        let result = Expr::input(&a)
            .add(Expr::input(&b))
            .mul(Expr::input(&c))
            .eval()
            .unwrap();

        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.as_slice(), &[22.0, 44.0, 66.0, 88.0]);
    }

    #[test]
    fn test_expr_unary_ops() {
        let a = Tensor::from_vec(vec![-4.0_f64, -9.0, -16.0], vec![3]).unwrap();

        // sqrt(|a|) — exact for perfect squares.
        let result = Expr::input(&a).abs().sqrt().eval().unwrap();

        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.as_slice(), &[2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_expr_fma() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![4.0, 5.0, 6.0], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![10.0, 20.0, 30.0], vec![3]).unwrap();

        // a * b + c
        let result = Expr::input(&a)
            .fma(Expr::input(&b), Expr::input(&c))
            .eval()
            .unwrap();

        assert_eq!(result.as_slice(), &[14.0, 30.0, 48.0]);
    }

    #[test]
    fn test_expr_scalar_broadcast() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let result = Expr::input(&a).add(Expr::scalar(2.0)).eval().unwrap();

        assert_eq!(result.as_slice(), &[3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_expr_scalar_only_has_unit_shape() {
        // With no tensor inputs the result falls back to shape [1].
        let result = Expr::scalar(3.0_f64).mul(Expr::scalar(4.0)).eval().unwrap();

        assert_eq!(result.shape(), &[1]);
        assert_eq!(result.as_slice(), &[12.0]);
    }

    #[test]
    fn test_expr_sub_div_neg() {
        let a = Tensor::from_vec(vec![6.0_f64, 8.0], vec![2]).unwrap();
        let b = Tensor::from_vec(vec![2.0, 4.0], vec![2]).unwrap();

        // -((a - b) / b)
        let result = Expr::input(&a)
            .sub(Expr::input(&b))
            .div(Expr::input(&b))
            .neg()
            .eval()
            .unwrap();

        assert_eq!(result.as_slice(), &[-2.0, -1.0]);
    }

    #[test]
    fn test_expr_clamp() {
        let a = Tensor::from_vec(vec![-1.0_f64, 0.5, 2.0, f64::NAN], vec![4]).unwrap();

        let result = Expr::input(&a).clamp(0.0, 1.0).eval().unwrap();
        let values = result.as_slice();

        assert_eq!(&values[..3], &[0.0, 0.5, 1.0]);
        // NaN fails both bound comparisons and is passed through.
        assert!(values[3].is_nan());
    }

    #[test]
    fn test_expr_pow_with_scalar_exponent() {
        let a = Tensor::from_vec(vec![2.0_f64, 3.0], vec![2]).unwrap();

        let result = Expr::input(&a).pow(Expr::scalar(2.0)).eval().unwrap();
        let expected = [4.0_f64, 9.0];

        assert_eq!(result.as_slice().len(), expected.len());
        for (i, (&got, &exp)) in result.as_slice().iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-12,
                "index {i}: got {got}, expected {exp}"
            );
        }
    }

    #[test]
    fn test_expr_shape_mismatch() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        let err = Expr::input(&a).add(Expr::input(&b)).eval();
        assert!(err.is_err());

        match err.unwrap_err() {
            CoreError::DimensionMismatch { expected, got } => {
                assert_eq!(expected, vec![3]);
                assert_eq!(got, vec![4]);
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_expr_fma_shape_mismatch_reports_first_shape() {
        let a = Tensor::from_vec(vec![1.0_f64, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let c = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();

        // The first input visited (a) fixes the expected shape; c is the mismatch.
        match Expr::input(&a).fma(Expr::input(&b), Expr::input(&c)).eval() {
            Err(CoreError::DimensionMismatch { expected, got }) => {
                assert_eq!(expected, vec![3]);
                assert_eq!(got, vec![4]);
            }
            Err(other) => panic!("expected DimensionMismatch, got {other:?}"),
            Ok(_) => panic!("expected an error for mismatched shapes"),
        }
    }

    #[test]
    fn test_expr_complex_chain() {
        let a = Tensor::from_vec(vec![0.0_f64, 2.0, 4.0], vec![3]).unwrap();
        let b = Tensor::from_vec(vec![0.0, core::f64::consts::PI, 0.0], vec![3]).unwrap();

        // exp(a * 0.5) + cos(b)
        let result = Expr::input(&a)
            .mul(Expr::scalar(0.5))
            .exp()
            .add(Expr::input(&b).cos())
            .eval()
            .unwrap();

        let expected = [
            (0.0_f64 * 0.5).exp() + 0.0_f64.cos(),
            (2.0_f64 * 0.5).exp() + core::f64::consts::PI.cos(),
            (4.0_f64 * 0.5).exp() + 0.0_f64.cos(),
        ];

        let result_slice = result.as_slice();
        // `zip` stops at the shorter side, so a truncated result would otherwise
        // pass unnoticed; check the lengths explicitly first.
        assert_eq!(result_slice.len(), expected.len());
        for (i, (&got, &exp)) in result_slice.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 1e-12,
                "index {i}: got {got}, expected {exp}"
            );
        }
    }
}
372}