1#![allow(dead_code)]
34use std::marker::PhantomData;
35use std::ops::{Add, Div, Mul, Sub};
36
37use torsh_core::{dtype::TensorElement, error::Result};
38
39use crate::Tensor;
40
41pub trait Expression<T: TensorElement> {
43 fn eval_at(&self, index: usize) -> T;
45
46 fn size(&self) -> usize;
48
49 fn eval_vec(&self) -> Vec<T> {
51 (0..self.size()).map(|i| self.eval_at(i)).collect()
52 }
53
54 fn eval_tensor(
56 &self,
57 shape: Vec<usize>,
58 device: torsh_core::device::DeviceType,
59 ) -> Result<Tensor<T>>
60 where
61 T: Copy,
62 {
63 let data = self.eval_vec();
64 Tensor::from_data(data, shape, device)
65 }
66}
67
68pub struct TensorExpr<'a, T: TensorElement> {
70 data: Vec<T>,
71 size: usize,
72 _phantom: PhantomData<&'a T>,
73}
74
75impl<'a, T: TensorElement + Copy> TensorExpr<'a, T> {
76 pub fn new(tensor: &'a Tensor<T>) -> Result<Self> {
78 let data = tensor.to_vec()?;
79 let size = data.len();
80
81 Ok(Self {
82 data,
83 size,
84 _phantom: PhantomData,
85 })
86 }
87}
88
89impl<'a, T: TensorElement> Expression<T> for TensorExpr<'a, T> {
90 fn eval_at(&self, index: usize) -> T {
91 self.data[index]
92 }
93
94 fn size(&self) -> usize {
95 self.size
96 }
97
98 fn eval_vec(&self) -> Vec<T> {
99 self.data.clone()
100 }
101}
102
103pub struct AddScalarExpr<T: TensorElement, E: Expression<T>> {
105 expr: E,
106 scalar: T,
107}
108
109impl<T: TensorElement + Add<Output = T>, E: Expression<T>> Expression<T> for AddScalarExpr<T, E> {
110 fn eval_at(&self, index: usize) -> T {
111 self.expr.eval_at(index) + self.scalar
112 }
113
114 fn size(&self) -> usize {
115 self.expr.size()
116 }
117}
118
119pub struct MulScalarExpr<T: TensorElement, E: Expression<T>> {
121 expr: E,
122 scalar: T,
123}
124
125impl<T: TensorElement + Mul<Output = T>, E: Expression<T>> Expression<T> for MulScalarExpr<T, E> {
126 fn eval_at(&self, index: usize) -> T {
127 self.expr.eval_at(index) * self.scalar
128 }
129
130 fn size(&self) -> usize {
131 self.expr.size()
132 }
133}
134
135pub struct SubScalarExpr<T: TensorElement, E: Expression<T>> {
137 expr: E,
138 scalar: T,
139}
140
141impl<T: TensorElement + Sub<Output = T>, E: Expression<T>> Expression<T> for SubScalarExpr<T, E> {
142 fn eval_at(&self, index: usize) -> T {
143 self.expr.eval_at(index) - self.scalar
144 }
145
146 fn size(&self) -> usize {
147 self.expr.size()
148 }
149}
150
151pub struct DivScalarExpr<T: TensorElement, E: Expression<T>> {
153 expr: E,
154 scalar: T,
155}
156
157impl<T: TensorElement + Div<Output = T>, E: Expression<T>> Expression<T> for DivScalarExpr<T, E> {
158 fn eval_at(&self, index: usize) -> T {
159 self.expr.eval_at(index) / self.scalar
160 }
161
162 fn size(&self) -> usize {
163 self.expr.size()
164 }
165}
166
167pub struct AddExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
169 left: E1,
170 right: E2,
171 _phantom: PhantomData<T>,
172}
173
174impl<T: TensorElement + Add<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
175 for AddExpr<T, E1, E2>
176{
177 fn eval_at(&self, index: usize) -> T {
178 self.left.eval_at(index) + self.right.eval_at(index)
179 }
180
181 fn size(&self) -> usize {
182 self.left.size().min(self.right.size())
183 }
184}
185
186pub struct MulExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
188 left: E1,
189 right: E2,
190 _phantom: PhantomData<T>,
191}
192
193impl<T: TensorElement + Mul<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
194 for MulExpr<T, E1, E2>
195{
196 fn eval_at(&self, index: usize) -> T {
197 self.left.eval_at(index) * self.right.eval_at(index)
198 }
199
200 fn size(&self) -> usize {
201 self.left.size().min(self.right.size())
202 }
203}
204
205pub struct SubExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
207 left: E1,
208 right: E2,
209 _phantom: PhantomData<T>,
210}
211
212impl<T: TensorElement + Sub<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
213 for SubExpr<T, E1, E2>
214{
215 fn eval_at(&self, index: usize) -> T {
216 self.left.eval_at(index) - self.right.eval_at(index)
217 }
218
219 fn size(&self) -> usize {
220 self.left.size().min(self.right.size())
221 }
222}
223
224pub struct DivExpr<T: TensorElement, E1: Expression<T>, E2: Expression<T>> {
226 left: E1,
227 right: E2,
228 _phantom: PhantomData<T>,
229}
230
231impl<T: TensorElement + Div<Output = T>, E1: Expression<T>, E2: Expression<T>> Expression<T>
232 for DivExpr<T, E1, E2>
233{
234 fn eval_at(&self, index: usize) -> T {
235 self.left.eval_at(index) / self.right.eval_at(index)
236 }
237
238 fn size(&self) -> usize {
239 self.left.size().min(self.right.size())
240 }
241}
242
243pub struct NegExpr<T: TensorElement, E: Expression<T>> {
245 expr: E,
246 _phantom: PhantomData<T>,
247}
248
249impl<T: TensorElement + std::ops::Neg<Output = T>, E: Expression<T>> Expression<T>
250 for NegExpr<T, E>
251{
252 fn eval_at(&self, index: usize) -> T {
253 -self.expr.eval_at(index)
254 }
255
256 fn size(&self) -> usize {
257 self.expr.size()
258 }
259}
260
261pub struct ExprBuilder<T: TensorElement, E: Expression<T>> {
263 expr: E,
264 _phantom: PhantomData<T>,
265}
266
267impl<T: TensorElement, E: Expression<T>> ExprBuilder<T, E> {
268 pub fn new(expr: E) -> Self {
270 Self {
271 expr,
272 _phantom: PhantomData,
273 }
274 }
275
276 pub fn add_scalar(self, scalar: T) -> ExprBuilder<T, AddScalarExpr<T, E>>
278 where
279 T: Add<Output = T>,
280 {
281 ExprBuilder::new(AddScalarExpr {
282 expr: self.expr,
283 scalar,
284 })
285 }
286
287 pub fn mul_scalar(self, scalar: T) -> ExprBuilder<T, MulScalarExpr<T, E>>
289 where
290 T: Mul<Output = T>,
291 {
292 ExprBuilder::new(MulScalarExpr {
293 expr: self.expr,
294 scalar,
295 })
296 }
297
298 pub fn sub_scalar(self, scalar: T) -> ExprBuilder<T, SubScalarExpr<T, E>>
300 where
301 T: Sub<Output = T>,
302 {
303 ExprBuilder::new(SubScalarExpr {
304 expr: self.expr,
305 scalar,
306 })
307 }
308
309 pub fn div_scalar(self, scalar: T) -> ExprBuilder<T, DivScalarExpr<T, E>>
311 where
312 T: Div<Output = T>,
313 {
314 ExprBuilder::new(DivScalarExpr {
315 expr: self.expr,
316 scalar,
317 })
318 }
319
320 pub fn add<E2: Expression<T>>(
322 self,
323 other: ExprBuilder<T, E2>,
324 ) -> ExprBuilder<T, AddExpr<T, E, E2>>
325 where
326 T: Add<Output = T>,
327 {
328 ExprBuilder::new(AddExpr {
329 left: self.expr,
330 right: other.expr,
331 _phantom: PhantomData,
332 })
333 }
334
335 pub fn mul<E2: Expression<T>>(
337 self,
338 other: ExprBuilder<T, E2>,
339 ) -> ExprBuilder<T, MulExpr<T, E, E2>>
340 where
341 T: Mul<Output = T>,
342 {
343 ExprBuilder::new(MulExpr {
344 left: self.expr,
345 right: other.expr,
346 _phantom: PhantomData,
347 })
348 }
349
350 pub fn sub<E2: Expression<T>>(
352 self,
353 other: ExprBuilder<T, E2>,
354 ) -> ExprBuilder<T, SubExpr<T, E, E2>>
355 where
356 T: Sub<Output = T>,
357 {
358 ExprBuilder::new(SubExpr {
359 left: self.expr,
360 right: other.expr,
361 _phantom: PhantomData,
362 })
363 }
364
365 pub fn div<E2: Expression<T>>(
367 self,
368 other: ExprBuilder<T, E2>,
369 ) -> ExprBuilder<T, DivExpr<T, E, E2>>
370 where
371 T: Div<Output = T>,
372 {
373 ExprBuilder::new(DivExpr {
374 left: self.expr,
375 right: other.expr,
376 _phantom: PhantomData,
377 })
378 }
379
380 pub fn neg(self) -> ExprBuilder<T, NegExpr<T, E>>
382 where
383 T: std::ops::Neg<Output = T>,
384 {
385 ExprBuilder::new(NegExpr {
386 expr: self.expr,
387 _phantom: PhantomData,
388 })
389 }
390
391 pub fn eval_vec(&self) -> Vec<T> {
393 self.expr.eval_vec()
394 }
395
396 pub fn eval_tensor(
398 &self,
399 shape: Vec<usize>,
400 device: torsh_core::device::DeviceType,
401 ) -> Result<Tensor<T>>
402 where
403 T: Copy,
404 {
405 self.expr.eval_tensor(shape, device)
406 }
407}
408
409pub fn expr<'a, T: TensorElement + Copy>(
411 tensor: &'a Tensor<T>,
412) -> Result<ExprBuilder<T, TensorExpr<'a, T>>> {
413 let tensor_expr = TensorExpr::new(tensor)?;
414 Ok(ExprBuilder::new(tensor_expr))
415}
416
417pub trait TensorExprExt<T: TensorElement> {
419 fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>>
421 where
422 T: Copy;
423}
424
425impl<T: TensorElement + Copy> TensorExprExt<T> for Tensor<T> {
426 fn expr(&self) -> Result<ExprBuilder<T, TensorExpr<'_, T>>> {
427 expr(self)
428 }
429}
430
#[cfg(test)]
mod tests {
    //! Unit tests for the lazy expression builder.
    //!
    //! `expect` messages name the step that actually failed (tensor creation, expression
    //! creation, or tensor materialization) so a failure points at the right call.

    use super::*;
    use crate::creation::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_scalar_operations() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .add_scalar(1.0)
            .mul_scalar(2.0)
            .eval_vec();

        assert_eq!(result, vec![4.0, 6.0, 8.0, 10.0]);
    }

    #[test]
    fn test_scalar_sub_and_div() {
        let tensor =
            tensor_1d(&[2.0f32, 4.0, 6.0, 8.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .sub_scalar(2.0)
            .div_scalar(2.0)
            .eval_vec();

        assert_eq!(result, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_element_wise_operations() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_complex_expression() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .mul_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();

        assert_eq!(result, vec![7.0, 9.0, 11.0, 13.0]);
    }

    #[test]
    fn test_negation() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, -3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .neg()
            .eval_vec();

        assert_eq!(result, vec![-1.0, -2.0, 3.0, -4.0]);
    }

    #[test]
    fn test_eval_tensor() {
        let tensor =
            tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = tensor
            .expr()
            .expect("expression should exist")
            .mul_scalar(2.0)
            .eval_tensor(vec![4], DeviceType::Cpu)
            .expect("eval_tensor should succeed");

        let data = result.to_vec().expect("to_vec conversion should succeed");
        assert_eq!(data, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_division() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 4.0, 5.0, 8.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .div(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![5.0, 5.0, 6.0, 5.0]);
    }

    #[test]
    fn test_subtraction() {
        let a = tensor_1d(&[10.0f32, 20.0, 30.0, 40.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .sub(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![9.0, 18.0, 27.0, 36.0]);
    }

    #[test]
    fn test_mismatched_lengths_truncate_to_shorter_operand() {
        // Pins current behavior: binary nodes report min(left, right) as their size
        // instead of erroring on a length mismatch.
        let a = tensor_1d(&[1.0f32, 2.0, 3.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[10.0f32, 20.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .eval_vec();

        assert_eq!(result, vec![11.0, 22.0]);
    }

    #[test]
    fn test_multiple_operations_chain() {
        let a = tensor_1d(&[1.0f32, 2.0, 3.0, 4.0]).expect("tensor_1d creation should succeed");
        let b = tensor_1d(&[2.0f32, 2.0, 2.0, 2.0]).expect("tensor_1d creation should succeed");
        let c = tensor_1d(&[3.0f32, 3.0, 3.0, 3.0]).expect("tensor_1d creation should succeed");

        let result = a
            .expr()
            .expect("expression should exist")
            .add(b.expr().expect("expression should exist"))
            .mul(c.expr().expect("expression should exist"))
            .div_scalar(2.0)
            .add_scalar(1.0)
            .eval_vec();

        assert_eq!(result, vec![5.5, 7.0, 8.5, 10.0]);
    }
}