1use crate::tensor::scalar::Scalar;
2use crate::units::Unit;
3use std::marker::PhantomData;
4use std::ops::*;
5
6use crate::complex::c64;
7use crate::tensor::element::*;
8
9use crate::dimension::Dimension;
10
11#[derive(Copy, Clone)]
12pub struct Tensor<E: TensorElement, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
13where
14 [(); LAYERS * ROWS * COLS]:,
15{
16 pub data: [E; LAYERS * ROWS * COLS],
17 pub _phantom: PhantomData<D>,
18}
19
20impl<E: TensorElement, D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E, D, LAYERS, ROWS, COLS>
21where
22 [(); LAYERS * ROWS * COLS]:,
23{
24 pub fn new<U: Unit<Dimension = D>>(values: [E; LAYERS * ROWS * COLS]) -> Self {
25 let data: [E; LAYERS * ROWS * COLS] = values
26 .iter()
27 .map(|&v| E::from(U::to_base(v.into())))
28 .collect::<Vec<_>>()
29 .try_into()
30 .unwrap();
31
32 Self {
33 data,
34 _phantom: PhantomData,
35 }
36 }
37
38 pub fn zero() -> Self {
39 let data: [E; LAYERS * ROWS * COLS] = [E::zero(); LAYERS * ROWS * COLS];
40
41 Tensor {
42 data,
43 _phantom: PhantomData,
44 }
45 }
46
47 pub fn random<U: Unit<Dimension = D>>(min: E, max: E) -> Self {
48 let base_min: E = E::from(U::to_base(min.into()));
49 let base_max: E = E::from(U::to_base(max.into()));
50 let data: [E; LAYERS * ROWS * COLS] = (0..LAYERS * ROWS * COLS)
51 .map(|_| {
52 E::from(U::from_base(((base_max - base_min) + base_min).weak_mul(rand::random::<f64>()).into()))
53 })
54 .collect::<Vec<_>>()
55 .try_into()
56 .unwrap();
57
58 Tensor {
59 data,
60 _phantom: PhantomData,
61 }
62 }
63
64 pub fn get<S: Unit<Dimension = D>>(&self) -> [E; LAYERS * ROWS * COLS] {
65 self.data
66 .iter()
67 .map(|&v| E::from(S::from_base(v.into())))
68 .collect::<Vec<_>>()
69 .try_into()
70 .unwrap()
71 }
72
73 pub fn get_at(&self, layer: usize, row: usize, col: usize) -> Scalar<E,D> {
74 assert!(layer < LAYERS && row < ROWS && col < COLS);
75 let idx = layer * (ROWS * COLS) + row * COLS + col;
76 Scalar::<E,D> {
77 data: [self.data[idx]],
78 _phantom: PhantomData,
79 }
80 }
81
82 pub fn set_at(&mut self, layer: usize, row: usize, col: usize, value: Scalar<E,D>) {
83 assert!(layer < LAYERS && row < ROWS && col < COLS);
84 let idx = layer * (ROWS * COLS) + row * COLS + col;
85 self.data[idx] = value.data[0];
86 }
87}
88
89impl<E: TensorElement, D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E, D, LAYERS, ROWS, COLS>
90where
91 [(); LAYERS * ROWS * COLS]:,
92{
93 pub fn dtype(&self) -> &'static str {
94 std::any::type_name::<E>()
95 }
96
97 pub fn cast<T: TensorElement>(&self) -> Tensor<T, D, LAYERS, ROWS, COLS>
98 where
99 T: TensorElement,
100 {
101 let data: [T; LAYERS * ROWS * COLS] = self
102 .data
103 .iter()
104 .map(|&v| (T::from(v.into() as c64)))
105 .collect::<Vec<_>>()
106 .try_into()
107 .unwrap();
108
109 Tensor {
110 data,
111 _phantom: PhantomData,
112 }
113 }
114}
115
116pub type Vector<E: TensorElement, D, const N: usize> = Tensor<E, D, 1, N, 1>;
118
119pub type Matrix<E: TensorElement,D, const N: usize, const M: usize> = Tensor<E,D, 1, N, M>;
121
122pub type Vec2<E: TensorElement,D> = Vector<E,D, 2>;
126pub type Vec3<E: TensorElement,D> = Vector<E,D, 3>;
128pub type Vec4<E: TensorElement,D> = Vector<E,D, 4>;
130
131pub type Mat2<E: TensorElement,D> = Matrix<E,D, 2, 2>;
133pub type Mat3<E: TensorElement,D> = Matrix<E,D, 3, 3>;
135pub type Mat4<E: TensorElement,D> = Matrix<E,D, 4, 4>;
137
138impl<E: TensorElement,D> Vec2<E,D> {
139 pub fn raw_tuple(&self) -> (E, E)
141 where
142 E: TensorElement,
143 {
144 (self.data[0], self.data[1])
145 }
146
147 pub fn raw_tuple_as<T: From<E>>(&self) -> (T, T)
149 where
150 E: TensorElement,
151 {
152 (T::from(self.data[0]), T::from(self.data[1]))
153 }
154}
155
156impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E, D, LAYERS, ROWS, COLS>
157where
158 [(); LAYERS * ROWS * COLS]:,
159 E: TensorElement,
160{
161 pub fn raw_vec(&self) -> Vec<E> {
163 self.data.to_vec()
164 }
165
166 pub fn raw_vec_as<T: From<E>>(&self) -> Vec<T> {
168 self.data.iter().map(|&x| T::from(x)).collect()
169 }
170}
171
172impl<E: TensorElement,D> Vec2<E,D>
174where
175 E: TensorElement,
176{
177 pub fn x(&self) -> Scalar<E,D> {
178 Scalar::<E,D> {
179 data: [self.data[0]],
180 _phantom: PhantomData,
181 }
182 }
183
184 pub fn y(&self) -> Scalar<E,D> {
185 Scalar::<E,D> {
186 data: [self.data[1]],
187 _phantom: PhantomData,
188 }
189 }
190}
191
192impl<E: TensorElement,D> Vec3<E,D>
194where
195 E: TensorElement,
196{
197 pub fn x(&self) -> Scalar<E,D> {
198 Scalar::<E,D> {
199 data: [self.data[0]],
200 _phantom: PhantomData,
201 }
202 }
203
204 pub fn y(&self) -> Scalar<E,D> {
205 Scalar::<E,D> {
206 data: [self.data[1]],
207 _phantom: PhantomData,
208 }
209 }
210
211 pub fn z(&self) -> Scalar<E,D> {
212 Scalar::<E,D> {
213 data: [self.data[2]],
214 _phantom: PhantomData,
215 }
216 }
217}
218
219impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> AddAssign for Tensor<E, D, LAYERS, ROWS, COLS>
221where
222 [(); LAYERS * ROWS * COLS]:,
223 E: TensorElement + AddAssign,
224{
225 fn add_assign(&mut self, other: Self) {
226 for i in 0..LAYERS {
227 for j in 0..ROWS {
228 for k in 0..COLS {
229 let idx = i * (ROWS * COLS) + j * COLS + k;
230 self.data[idx] += other.data[idx];
231 }
232 }
233 }
234 }
235}
236
237impl<const L: i32, const M: i32, const T: i32, const Θ: i32, const I: i32, const N: i32, const J: i32>
240 std::fmt::Display for Dimension<L, M, T, Θ, I, N, J>
241{
242 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
243 let mut parts = Vec::new();
245 if L != 0 {
246 parts.push(format!("L^{}", L));
247 }
248 if M != 0 {
249 parts.push(format!("M^{}", M));
250 }
251 if T != 0 {
252 parts.push(format!("T^{}", T));
253 }
254 if Θ != 0 {
255 parts.push(format!("Θ^{}", Θ));
256 }
257 if I != 0 {
258 parts.push(format!("I^{}", I));
259 }
260 if N != 0 {
261 parts.push(format!("N^{}", N));
262 }
263 if J != 0 {
264 parts.push(format!("J^{}", J));
265 }
266 if parts.is_empty() {
267 write!(f, "Dimensionless")
268 } else {
269 write!(f, "{}", parts.join(" * "))
270 }
271 }
272}
273
274impl<E: TensorElement,D: std::fmt::Display + Default, const LAYERS: usize, const ROWS: usize, const COLS: usize> std::fmt::Display
275 for Tensor<E,D, LAYERS, ROWS, COLS>
276where
277 [(); LAYERS * ROWS * COLS]:,
278{
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 writeln!(f, "Tensor [{}x{}x{}]: {}", LAYERS, ROWS, COLS, D::default())?;
281 for l in 0..LAYERS {
283 writeln!(f, "-- Layer {} --", l)?;
284 for i in 0..ROWS {
285 write!(f, "(")?;
286 for j in 0..COLS {
287 let idx = l * (ROWS * COLS) + i * COLS + j;
288 write!(f, " {} ", self.data[idx])?;
289 }
290 writeln!(f, ")")?;
291 }
292 }
293 Ok(())
294 }
295}
296
297impl<E: TensorElement,D: std::fmt::Debug + Default, const LAYERS: usize, const ROWS: usize, const COLS: usize> std::fmt::Debug
298 for Tensor<E,D, LAYERS, ROWS, COLS>
299where
300 [(); LAYERS * ROWS * COLS]:,
301{
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("Tensor")
304 .field("dimension", &D::default())
305 .field("shape", &format!("{}x{}x{}", LAYERS, ROWS, COLS))
306 .field("data", &self.data)
307 .finish()
308 }
309}
310