// tensor_macros/tensor.rs

/// Multiplies its arguments together at compile time.
///
/// `mul!(a, b, c)` expands to `a * (b * c)` and `mul!()` expands to `1`, the
/// multiplicative identity, so the result is usable in constant contexts such
/// as array lengths (`make_tensor!` uses it for the flat element count).
///
/// The recursive call is `$crate::`-qualified so it resolves even when the
/// caller never imported `mul!` by name, e.g. when it is reached through
/// another macro's expansion.
#[macro_export]
macro_rules! mul {
    () => (1);
    ($head:expr) => ($head);
    ($head:expr, $($tail:expr),+) => ($head * $crate::mul!($($tail),*));
}

/// Counts its arguments at compile time.
///
/// Despite the name, the argument *values* are ignored: `sum!()` is `0` and
/// every argument contributes `1`, so `sum!(4, 5, 6)` is `3`. `make_tensor!`
/// uses it to derive the number of dimensions from the dimension list.
///
/// The recursive call is `$crate::`-qualified so it resolves even when the
/// caller never imported `sum!` by name.
#[macro_export]
macro_rules! sum {
    () => (0);
    ($head:expr) => (1);
    ($head:expr, $($tail:expr),+) => (1 + $crate::sum!($($tail),*));
}

/// Generates a tensor type
///
/// Generates a type with the given name and dimensions, separated by `x`
/// (e.g. `2 x 3 x 4`). There's no upper limit on the number of dimensions.
///
/// Vectors and matrices get extra trait implementations:
/// * one dimension (`V: 3`) implements `Vector`,
/// * one dimension with the `row` keyword (`R: row 3`) implements `RowVector`,
/// * two dimensions (`M: 2 x 3`) implements `Matrix`.
///
/// # Example
///
/// ```rust
/// #![feature(try_from)]
///
/// #[macro_use]
/// use tensor_macros::*;
/// use tensor_macros::traits::*;
///
/// tensor!(M23: 2 x 3);
///
/// assert_eq!(M23::<f64>::dims(), vec!(2, 3));
///
/// let m23: M23<f64> = Default::default();
/// assert_eq!(m23.get_dims(), vec!(2, 3));
/// ```
#[macro_export]
macro_rules! tensor {
	// Every arm calls `$crate::make_tensor!` rather than a bare `make_tensor!`
	// so the expansion works when the caller only imported `tensor!`.

	// One dimension: implements `Vector`.
	($name:ident: $dim:literal) => {
		$crate::make_tensor!($name $dim);

		// NOTE(review): the length is exposed as `COLS` here but as `ROWS` for
		// `RowVector`; those names come from the trait definitions, which are
		// not visible in this file — confirm the intended orientation.
		impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Vector for $name<T> {
			const COLS: usize = $dim;
		}
	};

	// One dimension with the `row` keyword: implements `RowVector`.
	($name:ident: row $dim:literal) => {
		$crate::make_tensor!($name $dim);

		impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::RowVector for $name<T> {
			const ROWS: usize = $dim;
		}
	};

	// Two dimensions `ROWS x COLS`: implements `Matrix`.
	($name:ident: $dim1:literal x $dim2:literal) => {
		$crate::make_tensor!($name $dim1 x $dim2);

		impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Matrix for $name<T> {
			const ROWS: usize = $dim1;
			const COLS: usize = $dim2;
		}
	};

	// Three or more dimensions (the arms above take the 1-D and 2-D cases):
	// only the plain tensor type is generated.
	($name:ident: $($dim:literal)x+ ) => (
		$crate::make_tensor!($name $($dim) x *);
	)
}

/// Generates a tensor type and its core trait implementations.
///
/// Prefer [`tensor!`], which calls this macro and additionally implements the
/// `Vector`, `RowVector` or `Matrix` trait where the shape calls for it.
///
/// The generated type is a tuple struct wrapping a flat array whose length is
/// the product of the given dimensions. It implements `Tensor`, `Copy`,
/// `Clone`, `PartialEq`, `Debug`, `Default`, `TryFrom<&[T]>`,
/// `TryFrom<Vec<T>>`, `From<T>`, element-wise `Add`/`AddAssign` and
/// `CwiseMul`/`CwiseMulAssign`, and a `MulAssign<U>` that multiplies every
/// element by the same value.
///
/// [`tensor!`]: macro.tensor.html
#[macro_export]
macro_rules! make_tensor {
	($name:ident $($dim:literal)x+ ) => {
		// TODO(review): a `#[derive(TensorTranspose)]` and a `<name>_transpose`
		// wrapper struct were sketched here as commented-out code but never
		// implemented.

		pub struct $name<T: tensor_macros::traits::TensorTrait> ([T; $crate::mul!($($dim),*)]);

		impl<T: tensor_macros::traits::TensorTrait> $name<T> {
			/// Creates a tensor with every element set to `T::default()`.
			#[allow(dead_code)]
			fn new() -> Self {
				Default::default()
			}
		}

		impl<T: tensor_macros::traits::TensorTrait> tensor_macros::traits::Tensor for $name<T> {
			type Value = T;

			const SIZE: usize = $crate::mul!($($dim),*);
			// `sum!` counts its arguments, so this is the number of dimensions.
			const NDIM: usize = $crate::sum!($($dim),*);

			fn dims() -> Vec<usize> {
				vec!($($dim),*)
			}

			fn get_dims(&self) -> Vec<usize> {
				Self::dims()
			}
		}

		impl<T: tensor_macros::traits::TensorTrait> Copy for $name<T> {}

		impl<T: tensor_macros::traits::TensorTrait> Clone for $name<T> {
			// The type is `Copy`, so cloning is a plain copy; building the result
			// in `std::mem::uninitialized()` memory was undefined behaviour.
			fn clone(&self) -> Self {
				*self
			}
		}

		impl<T: tensor_macros::traits::TensorTrait> PartialEq for $name<T> {
			// Both arrays have the same length, so slice equality compares
			// element by element.
			fn eq(&self, other: &Self) -> bool {
				self.0[..] == other.0[..]
			}
		}

		impl<T: tensor_macros::traits::TensorTrait> std::fmt::Debug for $name<T>  {
			fn fmt(&self, f:  &mut std::fmt::Formatter) -> std::fmt::Result {
				// NOTE(review): `debug_tensor!` (and `make_index_fn!` below) are
				// defined outside this file and are resolved at the call site.
				debug_tensor!(f, self; $($dim),*;);

				Ok(())
			}
		}

		impl<T: tensor_macros::traits::TensorTrait> Default for $name<T> {
			fn default() -> Self {
				$name::<T>([Default::default(); $crate::mul!($($dim),*)])
			}
		}

		// Copies the first `SIZE` elements of `v`; any extra elements are ignored.
		// Fails with `TensorError::Size` when `v` is shorter than the tensor.
		impl<'a, T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<&'a [T]> for $name<T> {
			type Error = tensor_macros::traits::TensorError;

			fn try_from(v: &'a [T]) -> Result<Self, Self::Error> {
				const LEN: usize = $crate::mul!($($dim),*);

				if v.len() < LEN {
					return Err(tensor_macros::traits::TensorError::Size);
				}

				let mut data: [T; LEN] = [Default::default(); LEN];
				data.copy_from_slice(&v[..LEN]);
				Ok($name::<T>(data))
			}
		}

		// Same contract as the slice conversion, which does the actual work.
		impl<T: tensor_macros::traits::TensorTrait> std::convert::TryFrom<Vec<T>> for $name<T> {
			type Error = tensor_macros::traits::TensorError;

			fn try_from(v: Vec<T>) -> Result<Self, Self::Error> {
				<Self as std::convert::TryFrom<&[T]>>::try_from(&v[..])
			}
		}

		// Fills every element with `t`.
		impl<T: tensor_macros::traits::TensorTrait> std::convert::From<T> for $name<T> {
			fn from(t: T) -> Self {
				$name::<T>([t; $crate::mul!($($dim),*)])
			}
		}

		impl<T, U, V> std::ops::Add<$name<U>> for $name<T>
			where Self: tensor_macros::traits::Tensor<Value=T>,
			T: tensor_macros::traits::TensorTrait + std::ops::Add<U, Output=V>,
			U: tensor_macros::traits::TensorTrait,
			V: tensor_macros::traits::TensorTrait,
		{
			type Output = $name<V>;

			// Element-wise sum. The output is default-initialised and then
			// overwritten; elements are `Copy + Default` (the `Default` and
			// `Copy` impls in this macro already rely on that), so no
			// uninitialized memory is needed.
			fn add(self, other: $name<U>) -> Self::Output {
				let mut data: [V; $crate::mul!($($dim),*)] = [Default::default(); $crate::mul!($($dim),*)];
				for ((out, &lhs), &rhs) in data.iter_mut().zip(self.0.iter()).zip(other.0.iter()) {
					*out = lhs + rhs;
				}
				$name::<V>(data)
			}
		}

		impl<T, U> std::ops::AddAssign<$name<U>> for $name<T>
			where T: tensor_macros::traits::TensorTrait + std::ops::AddAssign<U>,
			U: tensor_macros::traits::TensorTrait,
		{
			// Element-wise `+=` over the backing arrays.
			fn add_assign(&mut self, other: $name<U>) {
				for (lhs, &rhs) in self.0.iter_mut().zip(other.0.iter()) {
					*lhs += rhs;
				}
			}
		}

		impl<T, U, V> tensor_macros::traits::CwiseMul<$name<U>> for $name<T>
			where T: tensor_macros::traits::TensorTrait + std::ops::Mul<U, Output=V>,
			U: tensor_macros::traits::TensorTrait,
			V: tensor_macros::traits::TensorTrait,
		{
			type Output = $name<V>;

			// Element-wise product, built the same way as `add`.
			fn cwise_mul(self, other: $name<U>) -> Self::Output {
				let mut data: [V; $crate::mul!($($dim),*)] = [Default::default(); $crate::mul!($($dim),*)];
				for ((out, &lhs), &rhs) in data.iter_mut().zip(self.0.iter()).zip(other.0.iter()) {
					*out = lhs * rhs;
				}
				$name::<V>(data)
			}
		}

		impl<T, U> tensor_macros::traits::CwiseMulAssign<$name<U>> for $name<T>
			where T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
			U: tensor_macros::traits::TensorTrait,
		{
			// Element-wise `*=` over the backing arrays.
			fn cwise_mul_assign(&mut self, other: $name<U>) {
				for (lhs, &rhs) in self.0.iter_mut().zip(other.0.iter()) {
					*lhs *= rhs;
				}
			}
		}

		impl<T, U> std::ops::MulAssign<U> for $name<T>
			where T: tensor_macros::traits::TensorTrait + std::ops::MulAssign<U>,
			U: Clone,
		{
			// Multiplies every element by the same value; `U` is only `Clone`,
			// so each element receives its own clone.
			fn mul_assign(&mut self, other: U) {
				for elem in self.0.iter_mut() {
					*elem *= other.clone();
				}
			}
		}

		make_index_fn!($name; $($dim),*);
	};
}