1use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
2
3use crate::{
4 device::Device,
5 dim::{larger_shape, DimDyn, DimTrait},
6 matrix::{Matrix, Owned, Ref, Repr},
7 num::Num,
8};
9
/// Invokes the method named `$method` on `$receiver`, forwarding the trailing argument
/// expressions in order; the expansion is simply `$receiver.$method(arg, ...)`.
///
/// The receiver is matched as an identifier, so `self` can be passed from inside a method
/// body. A comma must always follow the method name, even when no arguments are given
/// (e.g. `call_on_self!(x, len,)`).
macro_rules! call_on_self {
    ($receiver:ident, $method:ident, $($arg:expr),*) => {
        $receiver.$method($($arg),*)
    };
}
15
/// Implements one arithmetic operator family (`+`, `-`, `*` or `/`) for [`Matrix`].
///
/// Arguments, in order:
/// - `$trait`, `$trait_method`: the binary operator trait and its method (e.g. `Add`, `add`).
/// - `$assign_trait`, `$assign_trait_method`: the compound-assignment trait and its method
///   (e.g. `AddAssign`, `add_assign`).
/// - `$scalar`, `$scalar_assign`: methods on a mutable `DimDyn` view. `$scalar(&lhs, scalar)`
///   is called on the view of a freshly allocated output; `$scalar_assign(scalar)` is called
///   on the view being updated in place.
/// - `$array`, `$array_assign`: the matrix-matrix counterparts, called as
///   `$array(&lhs, &rhs)` and `$array_assign(&rhs)`.
///
/// Generated impls:
/// - `matrix <op> scalar` for owned and borrowed `DimDyn` matrices.
/// - `matrix <op> matrix` for every owned/borrowed combination of the two operands.
/// - `matrix <op>= scalar`, and `matrix <op>= matrix` with the right-hand matrix by value
///   or by reference, for owned matrices of any dimension and for mutable `DimDyn` views.
///
/// Every binary operation returns a new `Matrix<Owned<T>, DimDyn, D>`.
///
/// NOTE: `$array_assign` may share its name with `$assign_trait_method`. That works only
/// because the call resolves to an inherent method on the mutable view, which takes
/// precedence over trait methods; without that inherent method the assign impls below would
/// resolve to themselves and recurse.
macro_rules! impl_arithmetic_ops {
    (
        $trait:ident,
        $trait_method:ident,
        $assign_trait:ident,
        $assign_trait_method:ident,
        $scalar:ident,
        $scalar_assign:ident,
        $array:ident,
        $array_assign:ident
    ) => {
        // `&matrix <op> scalar`: the output is allocated with `Matrix::alloc_like(self)`
        // and filled by the `$scalar` kernel.
        impl<T: Num, R: Repr<Item = T>, D: Device> $trait<T> for &Matrix<R, DimDyn, D> {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: T) -> Self::Output {
                let mut owned = Matrix::alloc_like(self);
                let s = self.to_ref().into_dyn_dim();
                {
                    let mut ref_mut = owned.to_ref_mut();
                    call_on_self!(ref_mut, $scalar, &s, rhs);
                }
                owned
            }
        }

        // `matrix <op> scalar`: identical computation, delegated to the borrowed impl.
        impl<T: Num, R: Repr<Item = T>, D: Device> $trait<T> for Matrix<R, DimDyn, D> {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: T) -> Self::Output {
                $trait::$trait_method(&self, rhs)
            }
        }

        // `&lhs <op> &rhs`: the single place where the matrix-matrix output shape is
        // computed; the three by-value variants below delegate here.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<&Matrix<RO, SO, D>> for &Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: &Matrix<RO, SO, D>) -> Self::Output {
                let (lhs_shape, rhs_shape) = (self.shape(), rhs.shape());
                // Equal ranks: combine both shapes with `larger_shape`.
                // Different ranks: use the higher-rank operand's shape; the `$array`
                // kernel is left to broadcast (or reject) the lower-rank operand.
                let larger = match lhs_shape.len().cmp(&rhs_shape.len()) {
                    std::cmp::Ordering::Equal => {
                        DimDyn::from(larger_shape(lhs_shape, rhs_shape))
                    }
                    std::cmp::Ordering::Greater => DimDyn::from(lhs_shape.slice()),
                    std::cmp::Ordering::Less => DimDyn::from(rhs_shape.slice()),
                };
                let mut owned: Matrix<Owned<T>, DimDyn, D> = Matrix::alloc(larger.slice());
                {
                    let mut ref_mut = owned.to_ref_mut();
                    let s = self.to_ref().into_dyn_dim();
                    let rhs = rhs.to_ref().into_dyn_dim();
                    call_on_self!(ref_mut, $array, &s, &rhs);
                }
                owned
            }
        }

        // `lhs <op> rhs`: both operands by value; borrows them and delegates.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<Matrix<RO, SO, D>> for Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(&self, &rhs)
            }
        }

        // `lhs <op> &rhs`: left operand by value; borrows it and delegates.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<&Matrix<RO, SO, D>> for Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: &Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(&self, rhs)
            }
        }

        // `&lhs <op> rhs`: right operand by value; borrows it and delegates.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<Matrix<RO, SO, D>> for &Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(self, &rhs)
            }
        }

        // `view <op>= scalar` on a mutable `DimDyn` view.
        impl<T: Num, D: Device> $assign_trait<T> for Matrix<Ref<&mut T>, DimDyn, D> {
            fn $assign_trait_method(&mut self, rhs: T) {
                call_on_self!(self, $scalar_assign, rhs);
            }
        }

        // `owned <op>= scalar` on an owned matrix of any dimension.
        impl<T: Num, S: DimTrait, D: Device> $assign_trait<T> for Matrix<Owned<T>, S, D> {
            fn $assign_trait_method(&mut self, rhs: T) {
                let mut ref_mut = self.to_ref_mut().into_dyn_dim();
                call_on_self!(ref_mut, $scalar_assign, rhs);
            }
        }

        // `owned <op>= matrix`: right-hand matrix by value.
        impl<T: Num, SS: DimTrait, RO: Repr<Item = T>, SO: DimTrait, D: Device>
            $assign_trait<Matrix<RO, SO, D>> for Matrix<Owned<T>, SS, D>
        {
            fn $assign_trait_method(&mut self, rhs: Matrix<RO, SO, D>) {
                let mut ref_mut = self.to_ref_mut().into_dyn_dim();
                let rhs = rhs.to_ref().into_dyn_dim();
                call_on_self!(ref_mut, $array_assign, &rhs);
            }
        }

        // `owned <op>= &matrix`: lets callers keep ownership of the right-hand matrix.
        impl<T: Num, SS: DimTrait, RO: Repr<Item = T>, SO: DimTrait, D: Device>
            $assign_trait<&Matrix<RO, SO, D>> for Matrix<Owned<T>, SS, D>
        {
            fn $assign_trait_method(&mut self, rhs: &Matrix<RO, SO, D>) {
                let mut ref_mut = self.to_ref_mut().into_dyn_dim();
                let rhs = rhs.to_ref().into_dyn_dim();
                call_on_self!(ref_mut, $array_assign, &rhs);
            }
        }

        // `view <op>= matrix`: right-hand matrix by value.
        impl<T: Num, R: Repr<Item = T>, SO: DimTrait, D: Device> $assign_trait<Matrix<R, SO, D>>
            for Matrix<Ref<&mut T>, DimDyn, D>
        {
            fn $assign_trait_method(&mut self, rhs: Matrix<R, SO, D>) {
                let rhs = rhs.to_ref().into_dyn_dim();
                call_on_self!(self, $array_assign, &rhs);
            }
        }

        // `view <op>= &matrix`: right-hand matrix by reference.
        impl<T: Num, R: Repr<Item = T>, SO: DimTrait, D: Device> $assign_trait<&Matrix<R, SO, D>>
            for Matrix<Ref<&mut T>, DimDyn, D>
        {
            fn $assign_trait_method(&mut self, rhs: &Matrix<R, SO, D>) {
                let rhs = rhs.to_ref().into_dyn_dim();
                call_on_self!(self, $array_assign, &rhs);
            }
        }
    };
}
215impl_arithmetic_ops!(
216 Add,
217 add,
218 AddAssign,
219 add_assign,
220 add_scalar,
221 add_scalar_assign,
222 add_array,
223 add_assign
224);
225impl_arithmetic_ops!(
226 Sub,
227 sub,
228 SubAssign,
229 sub_assign,
230 sub_scalar,
231 sub_scalar_assign,
232 sub_array,
233 sub_assign
234);
235impl_arithmetic_ops!(
236 Mul,
237 mul,
238 MulAssign,
239 mul_assign,
240 mul_scalar,
241 mul_scalar_assign,
242 mul_array,
243 mul_assign
244);
245impl_arithmetic_ops!(
246 Div,
247 div,
248 DivAssign,
249 div_assign,
250 div_scalar,
251 div_scalar_assign,
252 div_array,
253 div_assign
254);