zenu_matrix/
impl_ops.rs

1use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
2
3use crate::{
4    device::Device,
5    dim::{larger_shape, DimDyn, DimTrait},
6    matrix::{Matrix, Owned, Ref, Repr},
7    num::Num,
8};
9
/// Expands `call_on_self!(recv, method, a, b, ...)` into the method call
/// `recv.method(a, b, ...)`.
///
/// Both the receiver and the method are taken as identifiers, which lets other macros in this
/// module pass the name of the method to invoke as one of their own arguments.
macro_rules! call_on_self {
    ($receiver:ident, $method:ident, $($arg:expr),*) => {
        $receiver.$method($($arg),*)
    };
}
15
/// Generates every operator-overload impl for one element-wise arithmetic operation.
///
/// Arguments, in order:
/// - `$trait`, `$trait_method`: the binary operator trait and its method (e.g. `Add`, `add`).
/// - `$assign_trait`, `$assign_trait_method`: the compound-assignment trait and its method
///   (e.g. `AddAssign`, `add_assign`).
/// - `$scalar`: method called as `out.$scalar(&lhs, scalar)` on a mutable view of a freshly
///   allocated output matrix.
/// - `$scalar_assign`: method called as `dst.$scalar_assign(scalar)` on a mutable
///   dynamic-rank view.
/// - `$array`: method called as `out.$array(&lhs, &rhs)` on a mutable view of a freshly
///   allocated output matrix.
/// - `$array_assign`: method called as `dst.$array_assign(&rhs)` on a mutable dynamic-rank view.
///
/// Generated impls:
/// - `Matrix <op> T` and `&Matrix <op> T` for dynamic-rank (`DimDyn`) matrices only.
/// - `Matrix <op> Matrix` for all four owned/borrowed operand combinations, any ranks.
/// - `<op>= T` and `<op>= Matrix` for `Owned` matrices of any rank and for mutable
///   dynamic-rank views (`Ref<&mut T>`, `DimDyn`).
///
/// Every binary operator returns a newly allocated `Matrix<Owned<T>, DimDyn, D>`. The
/// by-value impls forward to the by-reference impls, so the output-shape and allocation
/// logic is written exactly once per operation.
macro_rules! impl_arithmetic_ops {
    (
        $trait:ident,
        $trait_method:ident,
        $assign_trait:ident,
        $assign_trait_method:ident,
        $scalar:ident,
        $scalar_assign:ident,
        $array:ident,
        $array_assign:ident
    ) => {
        // $trait<T> for Matrix<R, DimDyn, D>: forwards to the `&Matrix` impl below.
        impl<T: Num, R: Repr<Item = T>, D: Device> $trait<T> for Matrix<R, DimDyn, D> {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: T) -> Self::Output {
                $trait::$trait_method(&self, rhs)
            }
        }

        // $trait<T> for &Matrix<R, DimDyn, D>: matrix <op> scalar into a new buffer
        // shaped like `self`.
        impl<T: Num, R: Repr<Item = T>, D: Device> $trait<T> for &Matrix<R, DimDyn, D> {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: T) -> Self::Output {
                let mut owned = Matrix::alloc_like(self);
                let s = self.to_ref().into_dyn_dim();
                {
                    // The mutable view of `owned` is confined to this scope so it is gone
                    // before `owned` is returned.
                    let mut ref_mut = owned.to_ref_mut();
                    call_on_self!(ref_mut, $scalar, &s, rhs);
                }
                owned
            }
        }

        // $trait<Matrix<RO, SO, D>> for Matrix<RS, SS, D>: forwards to the `& <op> &` impl.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<Matrix<RO, SO, D>> for Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(&self, &rhs)
            }
        }

        // $trait<&Matrix<RO, SO, D>> for Matrix<RS, SS, D>: forwards to the `& <op> &` impl.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<&Matrix<RO, SO, D>> for Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: &Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(&self, rhs)
            }
        }

        // $trait<Matrix<RO, SO, D>> for &Matrix<RS, SS, D>: forwards to the `& <op> &` impl.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<Matrix<RO, SO, D>> for &Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: Matrix<RO, SO, D>) -> Self::Output {
                $trait::$trait_method(self, &rhs)
            }
        }

        // $trait<&Matrix<RO, SO, D>> for &Matrix<RS, SS, D>: the single implementation of
        // matrix <op> matrix that all other matrix/matrix impls forward to.
        impl<
                T: Num,
                RS: Repr<Item = T>,
                SS: DimTrait,
                RO: Repr<Item = T>,
                SO: DimTrait,
                D: Device,
            > $trait<&Matrix<RO, SO, D>> for &Matrix<RS, SS, D>
        {
            type Output = Matrix<Owned<T>, DimDyn, D>;

            fn $trait_method(self, rhs: &Matrix<RO, SO, D>) -> Self::Output {
                // Equal ranks: the output shape comes from `larger_shape`. Different ranks:
                // the higher-rank operand's shape is used unchanged.
                // NOTE(review): the rank-mismatch branch assumes the lower-rank operand never
                // enlarges any axis of the higher-rank one (e.g. [3, 1] with [4] yields
                // [3, 1], not [3, 4]) — confirm this matches what `$array` supports.
                let larger = if self.shape().len() == rhs.shape().len() {
                    DimDyn::from(larger_shape(self.shape(), rhs.shape()))
                } else if self.shape().len() > rhs.shape().len() {
                    DimDyn::from(self.shape().slice())
                } else {
                    DimDyn::from(rhs.shape().slice())
                };
                let mut owned: Matrix<Owned<T>, DimDyn, D> = Matrix::alloc(larger.slice());
                {
                    let mut ref_mut = owned.to_ref_mut();
                    let s = self.to_ref().into_dyn_dim();
                    let rhs = rhs.to_ref().into_dyn_dim();
                    call_on_self!(ref_mut, $array, &s, &rhs);
                }
                owned
            }
        }

        // $assign_trait<T> for Matrix<Ref<&mut T>, DimDyn, D>: in-place scalar update of a
        // mutable dynamic-rank view.
        impl<T: Num, D: Device> $assign_trait<T> for Matrix<Ref<&mut T>, DimDyn, D> {
            fn $assign_trait_method(&mut self, rhs: T) {
                call_on_self!(self, $scalar_assign, rhs);
            }
        }

        // $assign_trait<T> for Matrix<Owned<T>, S, D>: in-place scalar update through a
        // dynamic-rank mutable view, so any rank is accepted.
        impl<T: Num, S: DimTrait, D: Device> $assign_trait<T> for Matrix<Owned<T>, S, D> {
            fn $assign_trait_method(&mut self, rhs: T) {
                let mut ref_mut = self.to_ref_mut().into_dyn_dim();
                call_on_self!(ref_mut, $scalar_assign, rhs);
            }
        }

        // $assign_trait<Matrix<RO, SO, D>> for Matrix<Owned<T>, SS, D>: in-place matrix
        // update; both sides are viewed as dynamic-rank.
        impl<T: Num, SS: DimTrait, RO: Repr<Item = T>, SO: DimTrait, D: Device>
            $assign_trait<Matrix<RO, SO, D>> for Matrix<Owned<T>, SS, D>
        {
            fn $assign_trait_method(&mut self, rhs: Matrix<RO, SO, D>) {
                let mut ref_mut = self.to_ref_mut().into_dyn_dim();
                let rhs = rhs.to_ref().into_dyn_dim();
                // NOTE(review): `$array_assign` may share its name with the trait method
                // (e.g. `add_assign`); this presumably resolves to an inherent method on the
                // dynamic-rank mutable view, which takes priority over trait methods — confirm.
                call_on_self!(ref_mut, $array_assign, &rhs);
            }
        }

        // $assign_trait<Matrix<R, SO, D>> for Matrix<Ref<&mut T>, DimDyn, D>: in-place matrix
        // update of a mutable dynamic-rank view.
        impl<T: Num, R: Repr<Item = T>, SO: DimTrait, D: Device> $assign_trait<Matrix<R, SO, D>>
            for Matrix<Ref<&mut T>, DimDyn, D>
        {
            fn $assign_trait_method(&mut self, rhs: Matrix<R, SO, D>) {
                let rhs = rhs.to_ref().into_dyn_dim();
                call_on_self!(self, $array_assign, &rhs);
            }
        }
    };
}
215impl_arithmetic_ops!(
216    Add,
217    add,
218    AddAssign,
219    add_assign,
220    add_scalar,
221    add_scalar_assign,
222    add_array,
223    add_assign
224);
225impl_arithmetic_ops!(
226    Sub,
227    sub,
228    SubAssign,
229    sub_assign,
230    sub_scalar,
231    sub_scalar_assign,
232    sub_array,
233    sub_assign
234);
235impl_arithmetic_ops!(
236    Mul,
237    mul,
238    MulAssign,
239    mul_assign,
240    mul_scalar,
241    mul_scalar_assign,
242    mul_array,
243    mul_assign
244);
245impl_arithmetic_ops!(
246    Div,
247    div,
248    DivAssign,
249    div_assign,
250    div_scalar,
251    div_scalar_assign,
252    div_array,
253    div_assign
254);