1use super::Tensor;
3use crate::Scalar;
4use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
5
/// Identity adapter: hands its argument straight back.
///
/// Passed as the `$rev` fix-up to `impl_op_basic!` for the commutative
/// operators (`Add`, `Mul`). For these, `scalar op tensor` can reuse
/// `tensor op scalar` unchanged.
fn id<T>(value: T) -> T {
    value
}
9
10fn neg(t: Tensor) -> Tensor {
11 t.neg()
12}
13
14fn inv(t: Tensor) -> Tensor {
15 t.pow_tensor_scalar(-1)
16}
17
/// Implements a binary operator trait for every owned/borrowed pairing of
/// `Tensor` operands.
///
/// * `$trait` - the `std::ops` trait to implement (e.g. `Add`).
/// * `$func`  - that trait's method name (e.g. `add`).
/// * `$op`    - the `Tensor` method doing the work. It is called on a
///   borrowed left operand with a `&Tensor` right operand and returns the
///   resulting `Tensor`.
///
/// Only the `&Tensor op &Tensor` impl calls `$op` directly. The three other
/// pairings borrow whatever they own and forward to it.
macro_rules! impl_op {
    ($trait:ident, $func:ident, $op:ident) => {
        impl $trait<&Tensor> for &Tensor {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                self.$op(rhs)
            }
        }

        impl $trait<&Tensor> for Tensor {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                <&Tensor as $trait<&Tensor>>::$func(&self, rhs)
            }
        }

        impl $trait<Tensor> for &Tensor {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                <&Tensor as $trait<&Tensor>>::$func(self, &rhs)
            }
        }

        impl $trait<Tensor> for Tensor {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                <&Tensor as $trait<&Tensor>>::$func(&self, &rhs)
            }
        }
    };
}
53
54impl<S> Add<S> for &Tensor
55where
56 S: Into<Scalar>,
57{
58 type Output = Tensor;
59
60 fn add(self, rhs: S) -> Self::Output {
61 self.g_add_scalar(rhs)
62 }
63}
64
65impl<S> Add<S> for Tensor
66where
67 S: Into<Scalar>,
68{
69 type Output = Tensor;
70
71 fn add(self, rhs: S) -> Self::Output {
72 (&self).add(rhs)
73 }
74}
75
76impl<S> Sub<S> for &Tensor
77where
78 S: Into<Scalar>,
79{
80 type Output = Tensor;
81
82 fn sub(self, rhs: S) -> Self::Output {
83 self.g_sub_scalar(rhs)
84 }
85}
86
87impl<S> Sub<S> for Tensor
88where
89 S: Into<Scalar>,
90{
91 type Output = Tensor;
92
93 fn sub(self, rhs: S) -> Self::Output {
94 (&self).sub(rhs)
95 }
96}
97
98impl<S> Mul<S> for &Tensor
99where
100 S: Into<Scalar>,
101{
102 type Output = Tensor;
103
104 fn mul(self, rhs: S) -> Self::Output {
105 self.g_mul_scalar(rhs)
106 }
107}
108
109impl<S> Mul<S> for Tensor
110where
111 S: Into<Scalar>,
112{
113 type Output = Tensor;
114
115 fn mul(self, rhs: S) -> Self::Output {
116 (&self).mul(rhs)
117 }
118}
119
120impl<S> Div<S> for &Tensor
121where
122 S: Into<Scalar>,
123{
124 type Output = Tensor;
125
126 fn div(self, rhs: S) -> Self::Output {
127 self.g_div_scalar(rhs)
128 }
129}
130
131impl<S> Div<S> for Tensor
132where
133 S: Into<Scalar>,
134{
135 type Output = Tensor;
136
137 fn div(self, rhs: S) -> Self::Output {
138 (&self).div(rhs)
139 }
140}
141
/// Implements `scalar op tensor` for the primitives `i32`, `i64`, `f32` and
/// `f64`, with the tensor given either owned or borrowed.
///
/// * `$trait` / `$func` - the `std::ops` trait and its method (e.g. `Sub`, `sub`).
/// * `$op`  - the `Tensor` method computing `tensor op scalar`.
/// * `$rev` - a `Tensor -> Tensor` function applied to that result to obtain
///   `scalar op tensor`.
///
/// `i32` is cast to `i64` and `f32` to `f64` before reaching `$op`. `i64` and
/// `f64` are passed through unchanged.
macro_rules! impl_op_basic {
    // Internal: a primitive that is cast with `as` before calling `$op`.
    (@prim $trait:ident, $func:ident, $op:ident, $rev:ident, $prim:ty as $wide:ty) => {
        impl $trait<&Tensor> for $prim {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                $rev(rhs.$op(self as $wide))
            }
        }

        impl $trait<Tensor> for $prim {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                <$prim as $trait<&Tensor>>::$func(self, &rhs)
            }
        }
    };
    // Internal: a primitive handed to `$op` as-is.
    (@prim $trait:ident, $func:ident, $op:ident, $rev:ident, $prim:ty) => {
        impl $trait<&Tensor> for $prim {
            type Output = Tensor;

            fn $func(self, rhs: &Tensor) -> Self::Output {
                $rev(rhs.$op(self))
            }
        }

        impl $trait<Tensor> for $prim {
            type Output = Tensor;

            fn $func(self, rhs: Tensor) -> Self::Output {
                <$prim as $trait<&Tensor>>::$func(self, &rhs)
            }
        }
    };
    // Public entry point.
    ($trait:ident, $func:ident, $op:ident, $rev:ident) => {
        impl_op_basic!(@prim $trait, $func, $op, $rev, i32 as i64);
        impl_op_basic!(@prim $trait, $func, $op, $rev, i64);
        impl_op_basic!(@prim $trait, $func, $op, $rev, f32 as f64);
        impl_op_basic!(@prim $trait, $func, $op, $rev, f64);
    };
}
210
/// Implements a compound-assignment operator (`+=`, `-=`, ...) with a `Tensor`
/// on the left and an owned or borrowed `Tensor` on the right.
///
/// `$op` is the in-place `Tensor` method, invoked through the `&mut self`
/// receiver with a `&Tensor` argument. The value it returns is discarded.
/// The owned-rhs impl borrows `rhs` and forwards to the borrowed-rhs impl.
macro_rules! impl_op_assign {
    ($trait:ident, $func:ident, $op:ident) => {
        impl $trait<&Tensor> for Tensor {
            fn $func(&mut self, rhs: &Tensor) {
                // Only the in-place mutation matters; the returned value is dropped.
                drop(self.$op(rhs));
            }
        }

        impl $trait<Tensor> for Tensor {
            fn $func(&mut self, rhs: Tensor) {
                <Tensor as $trait<&Tensor>>::$func(self, &rhs);
            }
        }
    };
}
226
/// Implements a compound-assignment operator with a `Tensor` on the left and
/// a primitive `i32`, `i64`, `f32` or `f64` on the right.
///
/// `$op` is the in-place scalar `Tensor` method, invoked through the
/// `&mut self` receiver. The value it returns is discarded. `i32` is cast to
/// `i64` and `f32` to `f64` first. `i64` and `f64` are passed through as-is.
macro_rules! impl_op_assign_basic {
    // Internal: a primitive that is cast with `as` before calling `$op`.
    (@prim $trait:ident, $func:ident, $op:ident, $prim:ty as $wide:ty) => {
        impl $trait<$prim> for Tensor {
            fn $func(&mut self, rhs: $prim) {
                drop(self.$op(rhs as $wide));
            }
        }
    };
    // Internal: a primitive handed to `$op` unchanged.
    (@prim $trait:ident, $func:ident, $op:ident, $prim:ty) => {
        impl $trait<$prim> for Tensor {
            fn $func(&mut self, rhs: $prim) {
                drop(self.$op(rhs));
            }
        }
    };
    // Public entry point.
    ($trait:ident, $func:ident, $op:ident) => {
        impl_op_assign_basic!(@prim $trait, $func, $op, i32 as i64);
        impl_op_assign_basic!(@prim $trait, $func, $op, i64);
        impl_op_assign_basic!(@prim $trait, $func, $op, f32 as f64);
        impl_op_assign_basic!(@prim $trait, $func, $op, f64);
    };
}
254
255impl_op!(Add, add, g_add);
256impl_op_basic!(Add, add, g_add_scalar, id);
257impl_op_assign!(AddAssign, add_assign, g_add_);
258impl_op_assign_basic!(AddAssign, add_assign, g_add_scalar_);
259
260impl_op!(Mul, mul, g_mul);
261impl_op_basic!(Mul, mul, g_mul_scalar, id);
262impl_op_assign!(MulAssign, mul_assign, g_mul_);
263impl_op_assign_basic!(MulAssign, mul_assign, g_mul_scalar_);
264
265impl_op!(Div, div, g_div);
266impl_op_basic!(Div, div, g_div_scalar, inv);
267impl_op_assign!(DivAssign, div_assign, g_div_);
268impl_op_assign_basic!(DivAssign, div_assign, g_div_scalar_);
269
270impl_op!(Sub, sub, g_sub);
271impl_op_basic!(Sub, sub, g_sub_scalar, neg);
272impl_op_assign!(SubAssign, sub_assign, g_sub_);
273impl_op_assign_basic!(SubAssign, sub_assign, g_sub_scalar_);
274
275impl Neg for Tensor {
276 type Output = Tensor;
277
278 fn neg(self) -> Tensor {
279 self.f_neg().unwrap()
280 }
281}
282
283impl Neg for &Tensor {
284 type Output = Tensor;
285
286 fn neg(self) -> Tensor {
287 self.f_neg().unwrap()
288 }
289}
290
291impl PartialEq for Tensor {
292 fn eq(&self, other: &Tensor) -> bool {
293 if self.size() != other.size() {
294 return false;
295 }
296 match self.f_eq_tensor(other) {
297 Err(_) => false,
298 Ok(v) => match v.f_all() {
299 Err(_) => false,
300 Ok(v) => match i64::try_from(v) {
301 Err(_) => false,
302 Ok(v) => v > 0,
303 },
304 },
305 }
306 }
307}