/// Expands to the product of a comma-separated list of expressions.
///
/// * `mul!()` expands to the literal `1` (the empty product).
/// * `mul!(a)` expands to `a` unchanged.
/// * `mul!(a, b, ...)` expands to `a * mul!(b, ...)`.
///
/// The expansion is a plain expression, so it can be used wherever the
/// arguments themselves would be accepted, e.g. as an array length when
/// every argument is a constant.
///
/// # Examples
///
/// ```ignore
/// assert_eq!(mul!(2, 3, 4), 24);
/// let buf: [u8; mul!(2, 3)] = [0; 6];
/// ```
#[macro_export]
macro_rules! mul {
    // Empty product.
    () => (1);
    // A single factor is passed through untouched.
    ($head:expr) => ($head);
    // Peel off the first factor and recurse on the rest. The recursive call
    // goes through `$crate` so the macro keeps resolving itself even when a
    // downstream crate invokes it by path without importing `mul!`.
    ($head:expr, $($tail:expr),+) => ($head * $crate::mul!($($tail),*));
}
7
/// Expands to the *number* of comma-separated expressions it is given.
///
/// Despite its name this macro does not add its arguments together: every
/// argument contributes the literal `1`, and the arguments themselves never
/// appear in the expansion (so they are not evaluated).
///
/// * `sum!()` expands to `0`.
/// * `sum!(a)` expands to `1`.
/// * `sum!(a, b, ...)` expands to `1 + sum!(b, ...)`.
///
/// # Examples
///
/// ```ignore
/// assert_eq!(sum!(10, 20, 30), 3);
/// ```
#[macro_export]
macro_rules! sum {
    // No arguments: the count is zero.
    () => (0);
    // Exactly one argument.
    ($head:expr) => (1);
    // Count the first argument and recurse on the rest. The recursive call
    // goes through `$crate` so the macro keeps resolving itself even when a
    // downstream crate invokes it by path without importing `sum!`.
    ($head:expr, $($tail:expr),+) => (1 + $crate::sum!($($tail),*));
}
14
/// Declares a tensor type with dimensions fixed at compile time.
///
/// Every form forwards to `make_tensor!` to generate the struct and its
/// common trait implementations; the one- and two-dimensional forms
/// additionally implement a shape-specific marker trait from
/// `tensor_macros::traits`:
///
/// | Invocation                  | Extra impl                                   |
/// |-----------------------------|----------------------------------------------|
/// | `tensor!(Name: N)`          | `Vector` with `COLS = N`                     |
/// | `tensor!(Name: row N)`      | `RowVector` with `ROWS = N`                  |
/// | `tensor!(Name: R x C)`      | `Matrix` with `ROWS = R`, `COLS = C`         |
/// | `tensor!(Name: A x B x ..)` | none                                         |
///
/// Dimensions must be literals, separated by the identifier `x`.
///
/// # Examples
///
/// ```ignore
/// tensor!(Vec3: 3);
/// tensor!(Row3: row 3);
/// tensor!(Mat2x3: 2 x 3);
/// tensor!(Cube: 2 x 2 x 2);
/// ```
#[macro_export]
macro_rules! tensor {
    // One dimension: `Name: N`.
    ($name:ident: $dim:literal) => {
        // `$crate::` keeps the helper resolvable when only `tensor!` itself
        // has been imported by the calling crate.
        $crate::make_tensor!($name $dim);

        impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Vector for $name<T> {
            const COLS: usize = $dim;
        }
    };

    // One dimension, flagged with the `row` keyword: `Name: row N`.
    ($name:ident: row $dim:literal) => {
        $crate::make_tensor!($name $dim);

        impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::RowVector for $name<T> {
            const ROWS: usize = $dim;
        }
    };

    // Exactly two dimensions: `Name: R x C`.
    ($name:ident: $dim1:literal x $dim2:literal) => {
        $crate::make_tensor!($name $dim1 x $dim2);

        impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Matrix for $name<T> {
            const ROWS: usize = $dim1;
            const COLS: usize = $dim2;
        }
    };

    // Fallback for any other `x`-separated dimension list. Shorter lists are
    // captured by the more specific arms above because arms are tried in order.
    ($name:ident: $($dim:literal)x+ ) => (
        $crate::make_tensor!($name $($dim) x *);
    )
}
68
/// Generates the tuple struct `$name<T>` and its core trait implementations.
///
/// Invoked as `make_tensor!(Name D1 x D2 x ...)` where every dimension is a
/// literal. The struct wraps a single flat array `[T; D1 * D2 * ...]`.
///
/// Generated items (all for `T: tensor_macros::traits::TensorTrait`):
///
/// * a private `new()` constructor returning `Default::default()`;
/// * `tensor_macros::traits::Tensor` with `SIZE` = product of the dimensions,
///   `NDIM` = number of dimensions, and `dims()` / `get_dims()` returning the
///   dimension list;
/// * `Copy`, `Clone`, `PartialEq`, `Debug` (via `debug_tensor!`), `Default`;
/// * `TryFrom<&[T]>` and `TryFrom<Vec<T>>`, which fail with
///   `TensorError::Size` when the input has fewer than `SIZE` elements and
///   otherwise copy the first `SIZE` elements (extra elements are ignored);
/// * `From<T>`, filling every element with the given value;
/// * element-wise `Add`, `AddAssign`, `CwiseMul`, `CwiseMulAssign`, and
///   scalar `MulAssign`;
/// * whatever `make_index_fn!` expands to.
///
/// NOTE(review): `debug_tensor!` and `make_index_fn!` are defined outside this
/// part of the file and are still invoked by bare name; if they are exported
/// from the crate root they could be reached through `$crate::` as well.
#[macro_export]
macro_rules! make_tensor {
    ($name:ident $($dim:literal)x+ ) => {
        // Flat storage: one slot per element of the tensor.
        pub struct $name<T: tensor_macros::traits::TensorTrait> ([T; $crate::mul!($($dim),*)]);

        impl<T: tensor_macros::traits::TensorTrait> $name<T> {
            #[allow(dead_code)]
            fn new() -> Self {
                Default::default()
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Tensor for $name<T> {
            type Value = T;

            // Total number of elements.
            const SIZE: usize = $crate::mul!($($dim),*);
            // `sum!` counts its arguments, so this is the number of dimensions.
            const NDIM: usize = $crate::sum!($($dim),*);

            fn dims() -> Vec<usize> {
                vec!($($dim),*)
            }

            fn get_dims(&self) -> Vec<usize> {
                Self::dims()
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> Copy for $name<T> { }

        impl<T: tensor_macros::traits::TensorTrait> Clone for $name<T> {
            fn clone(&self) -> Self {
                // The type is `Copy`, so a bitwise copy is the clone; no
                // uninitialised scratch buffer is needed.
                *self
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> PartialEq for $name<T> {
            fn eq(&self, other: &Self) -> bool {
                // Equal unless some pair of corresponding elements differs.
                !self.0.iter().zip(other.0.iter()).any(|(p, q)| p != q)
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::fmt::Debug for $name<T> {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                debug_tensor!(f, self; $($dim),*;);

                Ok(())
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> Default for $name<T> {
            fn default() -> Self {
                $name::<T>([Default::default(); $crate::mul!($($dim),*)])
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<&[T]> for $name<T> {
            type Error = tensor_macros::traits::TensorError;

            fn try_from(v: &[T]) -> Result<Self, Self::Error> {
                // Only inputs that are too short are rejected; anything past
                // the first SIZE elements is ignored.
                if v.len() < $crate::mul!($($dim),*) {
                    Err(tensor_macros::traits::TensorError::Size)
                } else {
                    let mut a: [T; $crate::mul!($($dim),*)] =
                        [Default::default(); $crate::mul!($($dim),*)];
                    a.copy_from_slice(&v[..$crate::mul!($($dim),*)]);
                    Ok($name::<T>(a))
                }
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<Vec<T>> for $name<T> {
            type Error = tensor_macros::traits::TensorError;

            fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
                // Same rules as the slice conversion, so reuse it.
                <Self as std::convert::TryFrom<&[T]>>::try_from(&v[..])
            }
        }

        impl<T: tensor_macros::traits::TensorTrait> std::convert::From<T> for $name<T> {
            fn from(t: T) -> Self {
                $name::<T>([t; $crate::mul!($($dim),*)])
            }
        }

        impl<T, U, V> std::ops::Add<$name<U>> for $name<T>
            where Self: tensor_macros::traits::Tensor<Value=T>,
                  T: tensor_macros::traits::TensorTrait + std::ops::Add<U, Output=V>,
                  U: tensor_macros::traits::TensorTrait,
                  V: tensor_macros::traits::TensorTrait,
        {
            type Output = $name<V>;

            fn add(self, other: $name<U>) -> Self::Output {
                // Start from default values (the same initialisation the
                // `Default` impl uses) instead of uninitialised memory, then
                // overwrite every slot with the element-wise sum.
                let mut data: [V; $crate::mul!($($dim),*)] =
                    [Default::default(); $crate::mul!($($dim),*)];

                for (slot, (lhs, rhs)) in data.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
                    *slot = *lhs + *rhs;
                }

                $name::<V>(data)
            }
        }

        impl<T, U> std::ops::AddAssign<$name<U>> for $name<T>
            where T: tensor_macros::traits::TensorTrait + std::ops::AddAssign<U>,
                  U: tensor_macros::traits::TensorTrait,
        {
            fn add_assign(&mut self, other: $name<U>) {
                // Indexing goes through the impls produced by `make_index_fn!`.
                for i in 0..$crate::mul!($($dim),*) {
                    self[i] += other[i];
                }
            }
        }

        impl<T, U, V> tensor_macros::traits::CwiseMul<$name<U>> for $name<T>
            where T: tensor_macros::traits::TensorTrait + std::ops::Mul<U, Output=V>,
                  U: tensor_macros::traits::TensorTrait,
                  V: tensor_macros::traits::TensorTrait,
        {
            type Output = $name<V>;

            fn cwise_mul(self, other: $name<U>) -> Self::Output {
                // Default-initialise, then overwrite every slot with the
                // element-wise product; no uninitialised memory involved.
                let mut data: [V; $crate::mul!($($dim),*)] =
                    [Default::default(); $crate::mul!($($dim),*)];

                for (slot, (lhs, rhs)) in data.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
                    *slot = *lhs * *rhs;
                }

                $name::<V>(data)
            }
        }

        impl<T, U> tensor_macros::traits::CwiseMulAssign<$name<U>> for $name<T>
            where T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
                  U: tensor_macros::traits::TensorTrait,
        {
            fn cwise_mul_assign(&mut self, other: $name<U>) {
                for i in 0..$crate::mul!($($dim),*) {
                    self[i] *= other[i];
                }
            }
        }

        // Scalar scaling: every element is multiplied by a clone of `other`.
        impl<T, U> std::ops::MulAssign<U> for $name<T>
            where T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
                  U: Clone,
        {
            fn mul_assign(&mut self, other: U) {
                for i in 0..$crate::mul!($($dim),*) {
                    self[i] *= other.clone();
                }
            }
        }

        make_index_fn!($name; $($dim),*);
    };
}