Skip to main content

torsh_tensor/
expression_templates.rs

//! Expression Templates for Lazy, Fused Element-wise Tensor Operations
//!
//! This module provides expression templates for chains of element-wise tensor operations.
//! An expression records *what* to compute as a nested type (for example
//! `AddScalarExpr<T, MulExpr<T, TensorExpr<'_, T>, TensorExpr<'_, T>>>`) and defers the actual
//! computation until it is evaluated. The whole chain then runs in a single pass over the
//! elements instead of materializing one temporary tensor per operation.
//!
//! # Features
//!
//! - **Compile-time composition**: The shape of an operation chain is encoded in the type
//!   system, so the compiler can inline and fuse the per-element work
//! - **Lazy evaluation**: Computation is deferred until `eval_vec` or `eval_tensor` is called
//! - **Type-safe**: Full type checking at compile time
//! - **No intermediate results**: Results of intermediate operations in a chain are never
//!   stored. Note that each leaf tensor's data is copied once (via `Tensor::to_vec`) when its
//!   expression is created, so this is not entirely allocation-free
//!
//! # Example
//!
//! ```rust
//! use torsh_tensor::{Tensor, expression_templates::*};
//!
//! // Without expression templates, every step allocates a temporary tensor:
//! // let temp1 = a.add(&b)?;      // allocates a temporary
//! // let temp2 = temp1.mul(&c)?;  // allocates another temporary
//! // let result = temp2.add_scalar(1.0)?;
//!
//! // With expression templates, the chain is built lazily and evaluated in one pass:
//! // let result = a.expr()?       // or `expr(&a)?`
//! //     .add(b.expr()?)
//! //     .mul(c.expr()?)
//! //     .add_scalar(1.0)
//! //     .eval_vec();             // one output allocation, fused computation
//! ```
31
32// Framework infrastructure - components designed for future use
33#![allow(dead_code)]
34use std::marker::PhantomData;
35use std::ops::{Add, Div, Mul, Sub};
36
37use torsh_core::{dtype::TensorElement, error::Result};
38
39use crate::Tensor;
40
41/// Trait representing an expression that can be evaluated to produce a value
42pub trait Expression<T: TensorElement> {
43    /// Evaluate the expression at a specific index
44    fn eval_at(&self, index: usize) -> T;
45
46    /// Get the size (number of elements) of the expression
47    fn size(&self) -> usize;
48
49    /// Evaluate the entire expression into a Vec
50    fn eval_vec(&self) -> Vec<T> {
51        (0..self.size()).map(|i| self.eval_at(i)).collect()
52    }
53
54    /// Evaluate the expression into a tensor
55    fn eval_tensor(
56        &self,
57        shape: Vec<usize>,
58        device: torsh_core::device::DeviceType,
59    ) -> Result<Tensor<T>>
60    where
61        T: Copy,
62    {
63        let data = self.eval_vec();
64        Tensor::from_data(data, shape, device)
65    }
66}
67
68/// Expression representing a tensor reference
69pub struct TensorExpr<'a, T: TensorElement> {
70    data: Vec<T>,
71    size: usize,
72    _phantom: PhantomData<&'a T>,
73}
74
75impl<'a, T: TensorElement + Copy> TensorExpr<'a, T> {
76    /// Create a new tensor expression from a tensor
77    pub fn new(tensor: &'a Tensor<T>) -> Result<Self> {
78        let data = tensor.to_vec()?;
79        let size = data.len();
80
81        Ok(Self {
82            data,
83            size,
84            _phantom: PhantomData,
85        })
86    }
87}
88
89impl<'a, T: TensorElement> Expression<T> for TensorExpr<'a, T> {
90    fn eval_at(&self, index: usize) -> T {
91        self.data[index]
92    }
93
94    fn size(&self) -> usize {
95        self.size
96    }
97
98    fn eval_vec(&self) -> Vec<T> {
99        self.data.clone()
100    }
101}
102
103/// Expression representing scalar addition
104pub struct AddScalarExpr<T: TensorElement, E: Expression<T>> {
105    expr: E,
106    scalar: T,
107}
108
109impl<T: TensorElement + Add<Output = T>, E: Expression<T>> Expression<T> for AddScalarExpr<T, E> {
110    fn eval_at(&self, index: usize) -> T {
111        self.expr.eval_at(index) + self.scalar
112    }
113
114    fn size(&self) -> usize {
115        self.expr.size()
116    }
117}
118
119/// Expression representing scalar multiplication
120pub struct MulScalarExpr<T: TensorElement, E: Expression<T>> {
121    expr: E,
122    scalar: T,
123}
124
125impl<T: TensorElement + Mul<Output = T>, E: Expression<T>> Expression<T> for MulScalarExpr<T, E> {
126    fn eval_at(&self, index: usize) -> T {
127        self.expr.eval_at(index) * self.scalar
128    }
129
130    fn size(&self) -> usize {
131        self.expr.size()
132    }
133}
134
135/// Expression representing scalar subtraction
136pub struct SubScalarExpr<T: TensorElement, E: Expression<T>> {
137    expr: E,
138    scalar: T,
139}
140
141impl<T: TensorElement + Sub<Output = T>, E: Expression<T>> Expression<T> for SubScalarExpr<T, E> {
142    fn eval_at(&self, index: usize) -> T {
143        self.expr.eval_at(index) - self.scalar
144    }
145
146    fn size(&self) -> usize {
147        self.expr.size()
148    }
149}
150
151/// Expression representing scalar division
152pub struct DivScalarExpr<T: TensorElement, E: Expression<T>> {
153    expr: E,
154    scalar: T,
155}
156
157impl<T: TensorElement + Div<Output = T>, E: Expression<T>> Expression<T> for DivScalarExpr<T, E> {
158    fn eval_at(&self, index: usize) -> T {
159        self.expr.eval_at(index) / self.scalar
160    }
161
162    fn size(&self) -> usize {
163        self.expr.size()
164    }
165}
166
167/// Expression representing element-wise addition
168pub struct AddExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
169    left: E1,
170    right: E2,
171    _phantom: PhantomData<T>,
172}
173
174impl<T: TensorElement + Add<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
175    for AddExpr<T, E1, E2>
176{
177    fn eval_at(&self, index: usize) -> T {
178        self.left.eval_at(index) + self.right.eval_at(index)
179    }
180
181    fn size(&self) -> usize {
182        self.left.size().min(self.right.size())
183    }
184}
185
186/// Expression representing element-wise multiplication
187pub struct MulExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
188    left: E1,
189    right: E2,
190    _phantom: PhantomData<T>,
191}
192
193impl<T: TensorElement + Mul<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
194    for MulExpr<T, E1, E2>
195{
196    fn eval_at(&self, index: usize) -> T {
197        self.left.eval_at(index) * self.right.eval_at(index)
198    }
199
200    fn size(&self) -> usize {
201        self.left.size().min(self.right.size())
202    }
203}
204
205/// Expression representing element-wise subtraction
206pub struct SubExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
207    left: E1,
208    right: E2,
209    _phantom: PhantomData<T>,
210}
211
212impl<T: TensorElement + Sub<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
213    for SubExpr<T, E1, E2>
214{
215    fn eval_at(&self, index: usize) -> T {
216        self.left.eval_at(index) - self.right.eval_at(index)
217    }
218
219    fn size(&self) -> usize {
220        self.left.size().min(self.right.size())
221    }
222}
223
224/// Expression representing element-wise division
225pub struct DivExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
226    left: E1,
227    right: E2,
228    _phantom: PhantomData<T>,
229}
230
231impl<T: TensorElement + Div<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
232    for DivExpr<T, E1, E2>
233{
234    fn eval_at(&self, index: usize) -> T {
235        self.left.eval_at(index) / self.right.eval_at(index)
236    }
237
238    fn size(&self) -> usize {
239        self.left.size().min(self.right.size())
240    }
241}
242
243/// Expression representing negation
244pub struct NegExpr<T: TensorElement, E: Expression<T>> {
245    expr: E,
246    _phantom: PhantomData<T>,
247}
248
249impl<T: TensorElement + std::ops::Neg<Output = T>, E: Expression<T>> Expression<T>
250    for NegExpr<T, E>
251{
252    fn eval_at(&self, index: usize) -> T {
253        -self.expr.eval_at(index)
254    }
255
256    fn size(&self) -> usize {
257        self.expr.size()
258    }
259}
260
261/// Expression builder for creating fused operation chains
262pub struct ExprBuilder<T: TensorElement, E: Expression<T>> {
263    expr: E,
264    _phantom: PhantomData<T>,
265}
266
267impl<T: TensorElement, E: Expression<T>> ExprBuilder<T, E> {
268    /// Create a new expression builder
269    pub fn new(expr: E) -> Self {
270        Self {
271            expr,
272            _phantom: PhantomData,
273        }
274    }
275
276    /// Add a scalar to the expression
277    pub fn add_scalar(self, scalar: T) -> ExprBuilder<T, AddScalarExpr<T, E>>
278    where
279        T: Add<Output = T>,
280    {
281        ExprBuilder::new(AddScalarExpr {
282            expr: self.expr,
283            scalar,
284        })
285    }
286
287    /// Multiply the expression by a scalar
288    pub fn mul_scalar(self, scalar: T) -> ExprBuilder<T, MulScalarExpr<T, E>>
289    where
290        T: Mul<Output = T>,
291    {
292        ExprBuilder::new(MulScalarExpr {
293            expr: self.expr,
294            scalar,
295        })
296    }
297
298    /// Subtract a scalar from the expression
299    pub fn sub_scalar(self, scalar: T) -> ExprBuilder<T, SubScalarExpr<T, E>>
300    where
301        T: Sub<Output = T>,
302    {
303        ExprBuilder::new(SubScalarExpr {
304            expr: self.expr,
305            scalar,
306        })
307    }
308
309    /// Divide the expression by a scalar
310    pub fn div_scalar(self, scalar: T) -> ExprBuilder<T, DivScalarExpr<T, E>>
311    where
312        T: Div<Output = T>,
313    {
314        ExprBuilder::new(DivScalarExpr {
315            expr: self.expr,
316            scalar,
317        })
318    }
319
320    /// Add another expression element-wise
321    pub fn add<E2: Expression<T>>(
322        self,
323        other: ExprBuilder<T, E2>,
324    ) -> ExprBuilder<T, AddExpr<T, E, E2>>
325    where
326        T: Add<Output = T>,
327    {
328        ExprBuilder::new(AddExpr {
329            left: self.expr,
330            right: other.expr,
331            _phantom: PhantomData,
332        })
333    }
334
335    /// Multiply another expression element-wise
336    pub fn mul<E2: Expression<T>>(
337        self,
338        other: ExprBuilder<T, E2>,
339    ) -> ExprBuilder<T, MulExpr<T, E, E2>>
340    where
341        T: Mul<Output = T>,
342    {
343        ExprBuilder::new(MulExpr {
344            left: self.expr,
345            right: other.expr,
346            _phantom: PhantomData,
347        })
348    }
349
350    /// Subtract another expression element-wise
351    pub fn sub<E2: Expression<T>>(
352        self,
353        other: ExprBuilder<T, E2>,
354    ) -> ExprBuilder<T, SubExpr<T, E, E2>>
355    where
356        T: Sub<Output = T>,
357    {
358        ExprBuilder::new(SubExpr {
359            left: self.expr,
360            right: other.expr,
361            _phantom: PhantomData,
362        })
363    }
364
365    /// Divide by another expression element-wise
366    pub fn div<E2: Expression<T>>(
367        self,
368        other: ExprBuilder<T, E2>,
369    ) -> ExprBuilder<T, DivExpr<T, E, E2>>
370    where
371        T: Div<Output = T>,
372    {
373        ExprBuilder::new(DivExpr {
374            left: self.expr,
375            right: other.expr,
376            _phantom: PhantomData,
377        })
378    }
379
380    /// Negate the expression
381    pub fn neg(self) -> ExprBuilder<T, NegExpr<T, E>>
382    where
383        T: std::ops::Neg<Output = T>,
384    {
385        ExprBuilder::new(NegExpr {
386            expr: self.expr,
387            _phantom: PhantomData,
388        })
389    }
390
391    /// Evaluate the expression into a Vec
392    pub fn eval_vec(&self) -> Vec<T> {
393        self.expr.eval_vec()
394    }
395
396    /// Evaluate the expression into a Tensor
397    pub fn eval_tensor(
398        &self,
399        shape: Vec<usize>,
400        device: torsh_core::device::DeviceType,
401    ) -> Result<Tensor<T>>
402    where
403        T: Copy,
404    {
405        self.expr.eval_tensor(shape, device)
406    }
407}
408
409/// Create an expression from a tensor reference
410pub fn expr<'a, T: TensorElement + Copy>(
411    tensor: &'a Tensor<T>,
412) -> Result<ExprBuilder<T, TensorExpr<'a, T>>> {
413    let tensor_expr = TensorExpr::new(tensor)?;
414    Ok(ExprBuilder::new(tensor_expr))
415}
416
417/// Trait for tensors that support expression templates
418pub trait TensorExprExt<T: TensorElement> {
419    /// Convert the tensor to an expression builder
420    fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>>
421    where
422        T: Copy;
423}
424
425impl<T: TensorElement + Copy> TensorExprExt<T> for Tensor<T> {
426    fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>> {
427        expr(self)
428    }
429}
430
#[cfg(test)]
mod tests {
    // Unit tests for the expression-template builder API.
    use super::*;
    use crate::creation::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_scalar_operations() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .add_scalar(1.0)
            .mul_scalar(2.0)
            .eval_vec();

        assert_eq!(result, vec![4.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_scalar_sub_and_div() {
        let tensor =
            tensor_1d(&[2.0f32, 4.0, 6.0, 8.0]).expect("tensor_1d creation should succeed");

        // (x - 2) / 2
        let result = tensor
            .expr()
            .expect("expression should exist")
            .sub_scalar(2.0)
            .div_scalar(2.0)
            .eval_vec();

        assert_eq!(result, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_free_function_matches_method() {
        let tensor = tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");

        let via_fn = expr(&tensor)
            .expect("expression should exist")
            .mul_scalar(3.0)
            .eval_vec();
        let via_method = tensor
            .expr()
            .expect("expression should exist")
            .mul_scalar(3.0)
            .eval_vec();

        assert_eq!(via_fn, via_method);
        assert_eq!(via_fn, vec![3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_element_wise_operations() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_complex_expression() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");

        // (a + b) * 2 + 1
        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .mul_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();

        assert_eq!(result, vec![7.0, 9.0, 11.0, 13.0]);
    }

    #[test]
    fn test_negation() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, -3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .neg()
            .eval_vec();

        assert_eq!(result, vec![-1.0, -2.0, 3.0, -4.0]);
    }

    #[test]
    fn test_eval_tensor() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .mul_scalar(2.0)
            .eval_tensor(vec![4], DeviceType::Cpu)
            .expect("eval_tensor should succeed");

        let data = result.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_division() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 4.0, 5.0, 8.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .div(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![5.0, 5.0, 6.0, 5.0]);
    }

    #[test]
    fn test_subtraction() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .sub(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![9.0, 18.0, 27.0, 36.0]);
    }

    #[test]
    fn test_multiple_operations_chain() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        let c = tensor_1d(&[3.0f32, 3.0, 3.0, 3.0]).expect("tensor_1d creation should succeed");

        // ((a + b) * c) / 2 + 1
        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .mul(c.expr().expect("expression should exist"))
            .div_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();

        assert_eq!(result, vec![5.5, 7.0, 8.5, 10.0]);
    }
}