1use super::*;
2
/// Generates the four `std::ops::$trait` impls needed so a binary operator
/// works on any mix of owned and borrowed `Tensor` operands:
/// `&a op &b`, `&a op b`, `a op &b` and `a op b`.
///
/// Parameters:
/// - `$trait`: name of the operator trait under `std::ops` (e.g. `Add`).
/// - `$method`: that trait's method name (e.g. `add`).
/// - `$try_method`: fallible method invoked as `lhs.$try_method(rhs)` with
///   `lhs: &Tensor` and `rhs: &Tensor`; its result is unwrapped with `expect`.
/// - `$error_msg`: message handed to `expect` if `$try_method` fails.
///
/// Only the borrowed/borrowed impl calls `$try_method`; the other three take
/// references to their owned operands and forward to it. Every generated
/// method is `#[track_caller]`, so a failure is reported at the caller's
/// operator expression rather than inside this macro.
macro_rules! impl_binary_op {
    ($trait:ident, $method:ident, $try_method:ident, $error_msg:expr) => {
        // Borrowed lhs, borrowed rhs: the only impl that performs the operation.
        impl std::ops::$trait for &Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, rhs: &Tensor) -> Tensor {
                let result = self.$try_method(rhs);
                result.expect($error_msg)
            }
        }

        // Borrowed lhs, owned rhs: borrow `rhs` and forward.
        impl std::ops::$trait<Tensor> for &Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, rhs: Tensor) -> Tensor {
                let rhs_ref: &Tensor = &rhs;
                self.$method(rhs_ref)
            }
        }

        // Owned lhs, borrowed rhs: borrow `self` and forward.
        impl std::ops::$trait<&Tensor> for Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, rhs: &Tensor) -> Tensor {
                let lhs_ref: &Tensor = &self;
                lhs_ref.$method(rhs)
            }
        }

        // Owned lhs, owned rhs: borrow both and forward.
        impl std::ops::$trait for Tensor {
            type Output = Tensor;

            #[track_caller]
            fn $method(self, rhs: Tensor) -> Tensor {
                let (lhs_ref, rhs_ref): (&Tensor, &Tensor) = (&self, &rhs);
                lhs_ref.$method(rhs_ref)
            }
        }
    };
}
49
50impl_binary_op!(Add, add, try_add, "Addition failed");
52impl_binary_op!(Sub, sub, try_sub, "Subtraction failed");
53impl_binary_op!(Mul, mul, try_mul, "Multiplication failed");
54impl_binary_op!(Div, div, try_div, "Division failed");
55
56impl std::ops::Neg for &Tensor {
58 type Output = Tensor;
59
60 #[track_caller]
61 fn neg(self) -> Tensor {
62 self.try_neg().expect("Negation failed")
63 }
64}
65
66impl std::ops::Neg for Tensor {
67 type Output = Tensor;
68
69 #[track_caller]
70 fn neg(self) -> Tensor {
71 (&self).neg()
72 }
73}