/// Implements element indexing for a transposed tensor newtype `$name<T>`.
///
/// `$name<T>` is expected to be a tuple struct whose field `.0` is the
/// untransposed tensor and implements `Index<usize, Output = T>` /
/// `IndexMut<usize>`. The dimensions passed here are that untransposed
/// tensor's dimensions, in order.
///
/// Public form: `make_transpose_index_fn!(Name; d0, d1, ...)`. Arms starting
/// with `@flat`, or carrying the extra `;`-separated accumulator lists, are
/// internal.
///
/// Generated impls:
/// * always: `Index<usize>` / `IndexMut<usize>`, which forward the flat index
///   unchanged to `.0`;
/// * with two or more dimensions, additionally `Index<(usize, ..., usize)>` /
///   `IndexMut<(usize, ..., usize)>` with one component per dimension. The
///   offset is computed by `make_transpose_index_val!`, which reverses the
///   index tuple before flattening it against the dimensions.
#[macro_export]
macro_rules! make_transpose_index_fn {
    // Internal: flat `usize` indexing, forwarded to the wrapped tensor.
    (@flat $name:ident) => {
        impl<T: tensor_macros::traits::TensorTrait> std::ops::Index<usize> for $name<T> {
            type Output = T;

            fn index(&self, i: usize) -> &Self::Output {
                &self.0[i]
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::ops::IndexMut<usize> for $name<T> {
            fn index_mut(&mut self, i: usize) -> &mut T {
                &mut self.0[i]
            }
        }
    };

    // One dimension: flat indexing only. The tuple arm below would emit the
    // parenthesised type `(usize)`, which is plain `usize`, and so would
    // duplicate the flat impl.
    ($name:ident; $dim:literal) => {
        $crate::make_transpose_index_fn!(@flat $name);
    };

    // Several dimensions: flat indexing plus tuple indexing. The trailing
    // `;;;` starts the three empty accumulators consumed by the next arm.
    ($name:ident; $($dims:literal),+) => {
        $crate::make_transpose_index_fn!(@flat $name);
        $crate::make_transpose_index_fn!($name; $($dims),*;;;);
    };

    // Internal: per remaining dimension, append one binding `i`, one `usize`
    // tuple component and the dimension itself to the accumulators. Each
    // recursion step is a separate expansion, so macro_rules hygiene gives
    // every `i` its own syntax context and they bind as distinct parameters.
    ($name:ident; $dim:literal $(,$dims:literal)*; $($i:ident),*; $($t:ty),*; $($dims_bk:literal),*) => {
        $crate::make_transpose_index_fn!($name; $($dims),*; $($i,)* i; $($t,)* usize; $($dims_bk,)* $dim);
    };

    // Internal: all dimensions consumed; emit the tuple-indexing impls.
    ($name:ident; ; $($i:ident),*; $($t:ty),*; $($dims:literal),*) => {
        impl<T: tensor_macros::traits::TensorTrait> std::ops::Index<( $($t),* )> for $name<T> {
            type Output = T;

            fn index(&self, ( $($i),* ): ( $($t),* )) -> &Self::Output {
                &self.0[
                    $crate::make_transpose_index_val!($($dims),*; $($i),*;)
                ]
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::ops::IndexMut<( $($t),* )> for $name<T> {
            fn index_mut(&mut self, ( $($i),* ): ( $($t),* )) -> &mut T {
                &mut self.0[
                    $crate::make_transpose_index_val!($($dims),*; $($i),*;)
                ]
            }
        }
    };
}
60
/// Computes the flat storage offset of an element of a transposed tensor.
///
/// Public form: `make_transpose_index_val!(d0, ..., dn; i0, ..., in;)`.
/// `d0..dn` are the dimensions of the untransposed tensor. `i0..in` are the
/// indices into the transposed tensor, and the trailing `;` starts an empty
/// accumulator. The `~`-prefixed arms are internal.
///
/// The index list is reversed, then paired position-by-position with the
/// dimensions. Each reversed index is multiplied by `mul!` of the dimensions
/// after its partner, which yields
///
/// `in * mul!(d1, .., dn) + i(n-1) * mul!(d2, .., dn) + ... + i0 * mul!() + 0`
///
/// That is the row-major offset of `(in, ..., i0)` in a `d0 x ... x dn`
/// tensor, assuming `mul!` (defined elsewhere) is the product of its
/// arguments and `mul!()` is 1 — TODO confirm.
///
/// NOTE(review): individual components are not range-checked against their
/// dimension. An out-of-range component can produce an in-range offset that
/// refers to a different element.
#[macro_export]
macro_rules! make_transpose_index_val {
    // Phase 1: move the first pending index to the front of the trailing
    // accumulator, which reverses the index order one step at a time.
    ($($dims:literal),*; $i:expr $(,$is:expr)* ; $($js:expr),*) => (
        $crate::make_transpose_index_val!($($dims),*; $($is),*; $i $(,$js)*)
    );
    // Phase 1 finished: switch to the `~` form with the reversed indices.
    ($($dims:literal),*; ; $($js:expr),*) => (
        $crate::make_transpose_index_val!(~$($dims),*; $($js),*)
    );
    // Phase 2: scale the next index by the product of the dimensions after
    // the current one. `$dim` only advances the pairing; it is not emitted.
    (~$dim:literal $(,$dims:literal)*; $i:expr $(,$is:expr)* ) => (
        $i * mul!($($dims),*) + $crate::make_transpose_index_val!(~$($dims),*; $($is),*)
    );
    // Both lists exhausted.
    (~;) => (0)
}
74
/// Declares `$to<T>`, a transposed wrapper around the tensor type `$from<T>`.
///
/// Two public forms are accepted:
///
/// ```ignore
/// // `$to`'s dimensions are `$from`'s in reverse order.
/// transpose!(From: 2 x 3 x 4 => To);
/// // Dimensions spelled out explicitly.
/// transpose!(From: 2 x 3 x 4 => To: 4 x 3 x 2);
/// ```
///
/// Arms starting with `~` are internal. The short form is rewritten into the
/// explicit form.
///
/// The generated `$to<T>` is a tuple struct holding a `$from<T>`:
/// * `TensorTranspose` converts in both directions by wrapping and unwrapping.
/// * `Tensor::dims()` reports the `$to` dimensions. `SIZE` and `NDIM` are
///   taken from `$from`.
/// * Tuple indexing `t[(i0, .., in)]` reads `$from`'s flat storage at the
///   offset of the reversed tuple (see `make_transpose_index_fn!`).
/// * Flat `usize` indexing, `Default`, `From<T>`, `TryFrom<&[T]>`,
///   `TryFrom<Vec<T>>`, `PartialEq`, `Add`, `CwiseMul` and the `*Assign`
///   operators all delegate to, or iterate over, `$from`'s flat storage. A
///   slice passed to `try_from` is therefore laid out as `$from` expects.
///
/// In the explicit form the two dimension lists are not checked against each
/// other. Tuple indexing always reverses, so `$to_dim` is presumably meant to
/// be `$from_dim` reversed — TODO confirm.
///
/// The expansion calls `mul!` and `debug_tensor!` by bare name, so those
/// macros must be resolvable where `transpose!` is invoked.
#[macro_export]
macro_rules! transpose {
    // Short form: derive `$to`'s dimensions by reversing `$from`'s.
    ($from:ident: $($dim:literal)x+ => $to:ident) => {
        $crate::transpose!(~ $from, $to; $($dim),*;;);
    };
    // Internal: append the next dimension to the forward list (`fd`) and
    // prepend it to the reversed list (`td`).
    (~ $from:ident, $to:ident; $d:literal $(,$dims:literal)*; $($fd:literal),*; $($td:literal),*) => {
        $crate::transpose!(~ $from, $to; $($dims),*; $($fd,)* $d; $d $(,$td)*);
    };
    // Internal: every dimension has been moved; expand the explicit form.
    (~ $from:ident, $to:ident;; $($fd:literal),*; $($td:literal),*) => {
        $crate::transpose!($from: $($fd)x* => $to: $($td)x*);
    };
    // Explicit form: emit the wrapper type and all of its impls.
    ($from:ident: $($from_dim:literal)x+ => $to:ident: $($to_dim:literal)x+) => {
        pub struct $to<T: tensor_macros::traits::TensorTrait>($from<T>);

        impl<T: tensor_macros::traits::TensorTrait> $to<T> {
            // Private convenience constructor. A given invocation may never
            // call it, hence the allow.
            #[allow(dead_code)]
            fn new() -> Self {
                Default::default()
            }
        }

        impl<T> tensor_macros::traits::TensorTranspose<$to<T>, T> for $from<T>
        where
            T: tensor_macros::traits::TensorTrait,
        {
            fn transpose(self) -> $to<T> {
                $to(self)
            }
        }

        impl<T> tensor_macros::traits::TensorTranspose<$from<T>, T> for $to<T>
        where
            T: tensor_macros::traits::TensorTrait,
        {
            fn transpose(self) -> $from<T> {
                self.0
            }
        }

        impl<T> tensor_macros::traits::Tensor for $to<T>
        where
            T: tensor_macros::traits::TensorTrait,
        {
            type Value = T;

            const SIZE: usize = <$from<T> as tensor_macros::traits::Tensor>::SIZE;
            const NDIM: usize = <$from<T> as tensor_macros::traits::Tensor>::NDIM;

            fn dims() -> Vec<usize> {
                vec!($($to_dim),*)
            }

            fn get_dims(&self) -> Vec<usize> {
                Self::dims()
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> Copy for $to<T> { }

        impl<T: tensor_macros::traits::TensorTrait> Clone for $to<T> {
            // `$to<T>` is `Copy` (impl just above), so the canonical clone is
            // a plain copy.
            fn clone(&self) -> Self {
                *self
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> PartialEq for $to<T> {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> Default for $to<T> {
            fn default() -> Self {
                $to($from::default())
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<&[T]> for $to<T> {
            type Error = tensor_macros::traits::TensorError;

            fn try_from(v: &[T]) -> Result<Self, Self::Error> {
                Ok($to($from::try_from(v)?))
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<Vec<T>> for $to<T> {
            type Error = tensor_macros::traits::TensorError;

            fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
                Ok($to($from::try_from(v)?))
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::From<T> for $to<T> {
            fn from(t: T) -> Self {
                $to($from::from(t))
            }
        }

        // Indexing is generated against `$from`'s dimensions, because the
        // wrapped storage is laid out as `$from`.
        $crate::make_transpose_index_fn!($to; $($from_dim),*);

        impl<T: tensor_macros::traits::TensorTrait> std::fmt::Debug for $to<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // NOTE(review): whether write errors propagate depends on
                // `debug_tensor!` (defined elsewhere) — TODO confirm.
                debug_tensor!(f, self; $($to_dim),*;);

                Ok(())
            }
        }

        impl<T, U, V> std::ops::Add<$to<U>> for $to<T>
        where
            Self: tensor_macros::traits::Tensor<Value = T>,
            T: tensor_macros::traits::TensorTrait + std::ops::Add<U, Output = V>,
            U: tensor_macros::traits::TensorTrait,
            V: tensor_macros::traits::TensorTrait,
        {
            type Output = $to<V>;

            fn add(self, other: $to<U>) -> Self::Output {
                $to(self.0 + other.0)
            }
        }

        impl<T, U> std::ops::AddAssign<$to<U>> for $to<T>
        where
            T: tensor_macros::traits::TensorTrait + std::ops::AddAssign<U>,
            U: tensor_macros::traits::TensorTrait,
        {
            fn add_assign(&mut self, other: $to<U>) {
                // Element-wise over the flat storage; both operands share
                // `$from`'s layout.
                for i in 0..mul!($($to_dim),*) {
                    self[i] += other[i];
                }
            }
        }

        impl<T, U, V> tensor_macros::traits::CwiseMul<$to<U>> for $to<T>
        where
            T: tensor_macros::traits::TensorTrait + std::ops::Mul<U, Output = V>,
            U: tensor_macros::traits::TensorTrait,
            V: tensor_macros::traits::TensorTrait,
        {
            type Output = $to<V>;

            fn cwise_mul(self, other: $to<U>) -> Self::Output {
                $to(self.0.cwise_mul(other.0))
            }
        }

        impl<T, U> tensor_macros::traits::CwiseMulAssign<$to<U>> for $to<T>
        where
            T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
            U: tensor_macros::traits::TensorTrait,
        {
            fn cwise_mul_assign(&mut self, other: $to<U>) {
                for i in 0..mul!($($to_dim),*) {
                    self[i] *= other[i];
                }
            }
        }

        impl<T, U> std::ops::MulAssign<U> for $to<T>
        where
            T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
            U: Clone,
        {
            // Scales every element by `other`, cloned once per element.
            fn mul_assign(&mut self, other: U) {
                for i in 0..mul!($($to_dim),*) {
                    self[i] *= other.clone();
                }
            }
        }
    };
}