tch_plus/tensor/
ops.rs

1//! Implement various ops traits for tensors
2use super::Tensor;
3use crate::Scalar;
4use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
5
/// Returns its argument unchanged.
///
/// Passed as the `$rev` adapter to `impl_op_basic!` for `Add` and `Mul`,
/// where swapping the operands needs no correction of the result.
fn id<T>(v: T) -> T {
    std::convert::identity(v)
}
9
10fn neg(t: Tensor) -> Tensor {
11    t.neg()
12}
13
14fn inv(t: Tensor) -> Tensor {
15    t.pow_tensor_scalar(-1)
16}
17
// Implements the binary operator trait `$trait` (method `$func`) between two
// tensors for all four owned/borrowed combinations of the operands.
//
// Every impl forwards to the tensor method `$op`, always handing it the
// right-hand side by reference, and yields a new `Tensor`.
macro_rules! impl_op {
    ($trait:ident, $func:ident, $op:ident) => {
        // Tensor <op> Tensor
        impl $trait<Tensor> for Tensor {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                self.$op(&rhs)
            }
        }

        // Tensor <op> &Tensor
        impl $trait<&Tensor> for Tensor {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                self.$op(rhs)
            }
        }

        // &Tensor <op> &Tensor (lifetimes elided, like the impl below)
        impl $trait<&Tensor> for &Tensor {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                self.$op(rhs)
            }
        }

        // &Tensor <op> Tensor
        impl $trait<Tensor> for &Tensor {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                self.$op(&rhs)
            }
        }
    };
}
53
54impl<S> Add<S> for &Tensor
55where
56    S: Into<Scalar>,
57{
58    type Output = Tensor;
59
60    fn add(self, rhs: S) -> Self::Output {
61        self.g_add_scalar(rhs)
62    }
63}
64
65impl<S> Add<S> for Tensor
66where
67    S: Into<Scalar>,
68{
69    type Output = Tensor;
70
71    fn add(self, rhs: S) -> Self::Output {
72        (&self).add(rhs)
73    }
74}
75
76impl<S> Sub<S> for &Tensor
77where
78    S: Into<Scalar>,
79{
80    type Output = Tensor;
81
82    fn sub(self, rhs: S) -> Self::Output {
83        self.g_sub_scalar(rhs)
84    }
85}
86
87impl<S> Sub<S> for Tensor
88where
89    S: Into<Scalar>,
90{
91    type Output = Tensor;
92
93    fn sub(self, rhs: S) -> Self::Output {
94        (&self).sub(rhs)
95    }
96}
97
98impl<S> Mul<S> for &Tensor
99where
100    S: Into<Scalar>,
101{
102    type Output = Tensor;
103
104    fn mul(self, rhs: S) -> Self::Output {
105        self.g_mul_scalar(rhs)
106    }
107}
108
109impl<S> Mul<S> for Tensor
110where
111    S: Into<Scalar>,
112{
113    type Output = Tensor;
114
115    fn mul(self, rhs: S) -> Self::Output {
116        (&self).mul(rhs)
117    }
118}
119
120impl<S> Div<S> for &Tensor
121where
122    S: Into<Scalar>,
123{
124    type Output = Tensor;
125
126    fn div(self, rhs: S) -> Self::Output {
127        self.g_div_scalar(rhs)
128    }
129}
130
131impl<S> Div<S> for Tensor
132where
133    S: Into<Scalar>,
134{
135    type Output = Tensor;
136
137    fn div(self, rhs: S) -> Self::Output {
138        (&self).div(rhs)
139    }
140}
141
// Implements `scalar <op> tensor` (trait `$trait`, method `$func`) for the
// primitive types `i32`, `i64`, `f32` and `f64` on the left-hand side, with
// either an owned or a borrowed tensor on the right.
//
// The borrowed impls evaluate the swapped form `tensor.$op(scalar)` and pass
// the result through `$rev` to undo the swap. The owned impls borrow their
// tensor and defer to the borrowed ones. `i32` and `f32` are widened to `i64`
// and `f64` before the call.
macro_rules! impl_op_basic {
    /* `$rev` must satisfy rev(op(b, a)) = op(a, b) */
    ($trait:ident, $func:ident, $op:ident, $rev:ident) => {
        impl $trait<&Tensor> for i32 {
            type Output = Tensor;

            fn $func(self, tensor: &Tensor) -> Tensor {
                $rev(tensor.$op(i64::from(self)))
            }
        }

        impl $trait<&Tensor> for i64 {
            type Output = Tensor;

            fn $func(self, tensor: &Tensor) -> Tensor {
                $rev(tensor.$op(self))
            }
        }

        impl $trait<&Tensor> for f32 {
            type Output = Tensor;

            fn $func(self, tensor: &Tensor) -> Tensor {
                $rev(tensor.$op(f64::from(self)))
            }
        }

        impl $trait<&Tensor> for f64 {
            type Output = Tensor;

            fn $func(self, tensor: &Tensor) -> Tensor {
                $rev(tensor.$op(self))
            }
        }

        impl $trait<Tensor> for i32 {
            type Output = Tensor;

            fn $func(self, tensor: Tensor) -> Tensor {
                <Self as $trait<&Tensor>>::$func(self, &tensor)
            }
        }

        impl $trait<Tensor> for i64 {
            type Output = Tensor;

            fn $func(self, tensor: Tensor) -> Tensor {
                <Self as $trait<&Tensor>>::$func(self, &tensor)
            }
        }

        impl $trait<Tensor> for f32 {
            type Output = Tensor;

            fn $func(self, tensor: Tensor) -> Tensor {
                <Self as $trait<&Tensor>>::$func(self, &tensor)
            }
        }

        impl $trait<Tensor> for f64 {
            type Output = Tensor;

            fn $func(self, tensor: Tensor) -> Tensor {
                <Self as $trait<&Tensor>>::$func(self, &tensor)
            }
        }
    };
}
210
// Implements the compound-assignment trait `$trait` (method `$func`) for a
// tensor with another tensor on the right-hand side, owned or borrowed.
//
// The tensor method `$op` is called on `self` and the value it returns is
// deliberately discarded.
macro_rules! impl_op_assign {
    ($trait:ident, $func:ident, $op:ident) => {
        impl $trait<&Tensor> for Tensor {
            fn $func(&mut self, other: &Tensor) {
                let _ = self.$op(other);
            }
        }

        impl $trait<Tensor> for Tensor {
            fn $func(&mut self, other: Tensor) {
                // Same operation as the borrowed impl above.
                <Tensor as $trait<&Tensor>>::$func(self, &other)
            }
        }
    };
}
226
// Implements the compound-assignment trait `$trait` (method `$func`) for a
// tensor with an `i32`, `i64`, `f32` or `f64` on the right-hand side.
//
// `i32` and `f32` are widened to `i64` and `f64` before calling the tensor
// method `$op`; the value `$op` returns is deliberately discarded.
macro_rules! impl_op_assign_basic {
    ($trait:ident, $func:ident, $op:ident) => {
        impl $trait<i32> for Tensor {
            fn $func(&mut self, scalar: i32) {
                let _ = self.$op(i64::from(scalar));
            }
        }

        impl $trait<i64> for Tensor {
            fn $func(&mut self, scalar: i64) {
                let _ = self.$op(scalar);
            }
        }

        impl $trait<f32> for Tensor {
            fn $func(&mut self, scalar: f32) {
                let _ = self.$op(f64::from(scalar));
            }
        }

        impl $trait<f64> for Tensor {
            fn $func(&mut self, scalar: f64) {
                let _ = self.$op(scalar);
            }
        }
    };
}
254
255impl_op!(Add, add, g_add);
256impl_op_basic!(Add, add, g_add_scalar, id);
257impl_op_assign!(AddAssign, add_assign, g_add_);
258impl_op_assign_basic!(AddAssign, add_assign, g_add_scalar_);
259
260impl_op!(Mul, mul, g_mul);
261impl_op_basic!(Mul, mul, g_mul_scalar, id);
262impl_op_assign!(MulAssign, mul_assign, g_mul_);
263impl_op_assign_basic!(MulAssign, mul_assign, g_mul_scalar_);
264
265impl_op!(Div, div, g_div);
266impl_op_basic!(Div, div, g_div_scalar, inv);
267impl_op_assign!(DivAssign, div_assign, g_div_);
268impl_op_assign_basic!(DivAssign, div_assign, g_div_scalar_);
269
270impl_op!(Sub, sub, g_sub);
271impl_op_basic!(Sub, sub, g_sub_scalar, neg);
272impl_op_assign!(SubAssign, sub_assign, g_sub_);
273impl_op_assign_basic!(SubAssign, sub_assign, g_sub_scalar_);
274
275impl Neg for Tensor {
276    type Output = Tensor;
277
278    fn neg(self) -> Tensor {
279        self.f_neg().unwrap()
280    }
281}
282
283impl Neg for &Tensor {
284    type Output = Tensor;
285
286    fn neg(self) -> Tensor {
287        self.f_neg().unwrap()
288    }
289}
290
291impl PartialEq for Tensor {
292    fn eq(&self, other: &Tensor) -> bool {
293        if self.size() != other.size() {
294            return false;
295        }
296        match self.f_eq_tensor(other) {
297            Err(_) => false,
298            Ok(v) => match v.f_all() {
299                Err(_) => false,
300                Ok(v) => match i64::try_from(v) {
301                    Err(_) => false,
302                    Ok(v) => v > 0,
303                },
304            },
305        }
306    }
307}