1use crate::complex::c64;
4use crate::dimension::{Dimension, InvertDimension, MultiplyDimensions};
5use crate::tensor::element::TensorElement;
6use crate::tensor::Tensor;
7use std::marker::PhantomData;
8use std::ops::{Add, Mul, Neg, Sub, Div};
9use crate::*;
10
11impl<E: TensorElement + Add<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
16Add for Tensor<E, D, LAYERS, ROWS, COLS>
17where
18 [(); LAYERS * ROWS * COLS]:,
19{
20 type Output = Self;
21
22 fn add(self, other: Self) -> Self {
23 let data: [E; LAYERS * ROWS * COLS] = self
24 .data
25 .iter()
26 .zip(other.data.iter())
27 .map(|(&a, &b)| a + b)
28 .collect::<Vec<_>>()
29 .try_into()
30 .unwrap();
31
32 Self {
33 data,
34 _phantom: PhantomData,
35 }
36 }
37}
38
39impl<E: TensorElement + Sub<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
40Sub for Tensor<E, D, LAYERS, ROWS, COLS>
41where
42 [(); LAYERS * ROWS * COLS]:,
43{
44 type Output = Self;
45
46 fn sub(self, other: Self) -> Self {
47 let data: [E; LAYERS * ROWS * COLS] = self
48 .data
49 .iter()
50 .zip(other.data.iter())
51 .map(|(&a, &b)| a - b)
52 .collect::<Vec<_>>()
53 .try_into()
54 .unwrap();
55
56 Self {
57 data,
58 _phantom: PhantomData,
59 }
60 }
61}
62
63impl<
64 E: TensorElement + Mul<Output = E> + Add<Output = E> + Copy,
65 const LAYERS: usize,
66 const L1: i32,
67 const M1: i32,
68 const T1: i32,
69 const Θ1: i32,
70 const I1: i32,
71 const N1: i32,
72 const J1: i32,
73 const ROWS: usize,
74 const COMMON: usize,
75> Tensor<E, Dimension<L1, M1, T1, Θ1, I1, N1, J1>, LAYERS, ROWS, COMMON>
76where
77 [(); LAYERS * ROWS * COMMON]:,
78{
79 pub fn matmul<
81 const L2: i32,
82 const M2: i32,
83 const T2: i32,
84 const Θ2: i32,
85 const I2: i32,
86 const N2: i32,
87 const J2: i32,
88 const COLS: usize,
89 >(
90 self,
91 other: Tensor<E, Dimension<L2, M2, T2, Θ2, I2, N2, J2>, LAYERS, COMMON, COLS>,
92 ) -> Tensor<
93 E,
94 <Dimension<L1, M1, T1, Θ1, I1, N1, J1> as MultiplyDimensions<
95 Dimension<L2, M2, T2, Θ2, I2, N2, J2>
96 >>::Output,
97 LAYERS,
98 ROWS,
99 COLS,
100 >
101 where
102 Dimension<L1, M1, T1, Θ1, I1, N1, J1>: MultiplyDimensions<Dimension<L2, M2, T2, Θ2, I2, N2, J2>>,
103 [(); LAYERS * COMMON * COLS]:,
104 [(); LAYERS * ROWS * COLS]:,
105 [(); COLS]:,
106 {
107 let mut result = vec![E::zero(); LAYERS * ROWS * COLS];
109
110 for layer in 0..LAYERS {
112 for row in 0..ROWS {
113 for col in 0..COLS {
114 let mut sum = E::zero();
115 for k in 0..COMMON {
117 let index_a = layer * (ROWS * COMMON) + row * COMMON + k;
118 let index_b = layer * (COMMON * COLS) + k * COLS + col;
119 sum = sum + self.data[index_a] * other.data[index_b];
120 }
121 let index_result = layer * (ROWS * COLS) + row * COLS + col;
122 result[index_result] = sum;
123 }
124 }
125 }
126
127 let data: [E; LAYERS * ROWS * COLS] =
129 result.into_iter().collect::<Vec<E>>().try_into().unwrap();
130
131 Tensor {
132 data,
133 _phantom: PhantomData,
134 }
135 }
136}
137
138impl<E: TensorElement + Mul<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
139Tensor<E, D, LAYERS, ROWS, COLS>
140where
141 [(); LAYERS * ROWS * COLS]:,
142{
143 pub fn scale<DS>(
145 self,
146 scalar: Tensor<E, DS, 1, 1, 1>,
147 ) -> Tensor<E, <D as MultiplyDimensions<DS>>::Output, LAYERS, ROWS, COLS>
148 where
149 D: MultiplyDimensions<DS>,
150 <D as MultiplyDimensions<DS>>::Output:,
151 {
152 let s = scalar.data[0];
153 let data: [E; LAYERS * ROWS * COLS] = self
154 .data
155 .iter()
156 .map(|&v| v * s)
157 .collect::<Vec<_>>()
158 .try_into()
159 .unwrap();
160
161 Tensor {
162 data,
163 _phantom: PhantomData::<<D as MultiplyDimensions<DS>>::Output>,
164 }
165 }
166}
167
168impl<E, D, DS, const LAYERS: usize, const ROWS: usize, const COLS: usize>
169 Mul<Tensor<E, DS, 1, 1, 1>> for Tensor<E, D, LAYERS, ROWS, COLS>
170where
171 E: TensorElement + Mul<Output = E> + Copy,
172 D: MultiplyDimensions<DS>,
173 [(); LAYERS * ROWS * COLS]:,
174{
175 type Output = Tensor<E, <D as MultiplyDimensions<DS>>::Output, LAYERS, ROWS, COLS>;
176
177 fn mul(self, rhs: Tensor<E, DS, 1, 1, 1>) -> Self::Output {
178 self.scale(rhs)
179 }
180}
181
182
183impl<E, D, DS, const LAYERS: usize, const ROWS: usize, const COLS: usize>
184 Div<Tensor<E, DS, 1, 1, 1>> for Tensor<E, D, LAYERS, ROWS, COLS>
185where
186 E: TensorElement + Div<Output = E> + Copy,
187 DS: InvertDimension,
188 D: MultiplyDimensions<<DS as InvertDimension>::Output>,
189 [(); LAYERS * ROWS * COLS]:,
190{
191 type Output = Tensor<
192 E,
193 <D as MultiplyDimensions<<DS as InvertDimension>::Output>>::Output,
194 LAYERS,
195 ROWS,
196 COLS
197 >;
198
199 fn div(self, rhs: Tensor<E, DS, 1, 1, 1>) -> Self::Output {
200 self.scale(rhs.inv())
201 }
202}
203
204impl<E: TensorElement + Div<Output = E> + Copy + PartialEq, D> Tensor<E, D, 1, 1, 1>
205where
206 [(); 1]:,
207{
208 pub fn inv(self) -> Tensor<E, <D as InvertDimension>::Output, 1, 1, 1>
209 where
210 D: InvertDimension,
211 {
212 let data: [E; 1] = [E::one() / self.data[0]];
213 Tensor {
214 data,
215 _phantom: PhantomData,
216 }
217 }
218}
219
220impl<E: TensorElement + Neg<Output = E> + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
221Neg for Tensor<E, D, LAYERS, ROWS, COLS>
222where
223 [(); LAYERS * ROWS * COLS]:,
224{
225 type Output = Self;
226
227 fn neg(self) -> Self {
228 let data: [E; LAYERS * ROWS * COLS] = self
229 .data
230 .iter()
231 .map(|&v| -v)
232 .collect::<Vec<_>>()
233 .try_into()
234 .unwrap();
235
236 Self {
237 data,
238 _phantom: PhantomData,
239 }
240 }
241}
242
243impl<E, D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E, D, LAYERS, ROWS, COLS>
244where
245 E: TensorElement + Into<c64> + Copy,
246 [(); LAYERS * ROWS * COLS]:,
247{
248 pub fn to_c64(&self) -> Tensor<c64, D, LAYERS, ROWS, COLS> {
250 let data: [c64; LAYERS * ROWS * COLS] = self
251 .data
252 .iter()
253 .map(|&v| v.into())
254 .collect::<Vec<_>>()
255 .try_into()
256 .unwrap();
257 Tensor {
258 data,
259 _phantom: PhantomData,
260 }
261 }
262}
263
264impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
266where
267 [(); LAYERS * ROWS * COLS]:,
268{
269 pub fn conjugate(self) -> Self {
270 let data: [E; LAYERS * ROWS * COLS] = self
271 .data
272 .iter()
273 .map(|&v| v.conjugate())
274 .collect::<Vec<_>>()
275 .try_into()
276 .unwrap();
277
278 Self {
279 data,
280 _phantom: PhantomData,
281 }
282 }
283
284 pub fn conjugate_transpose(self) -> Tensor<E,D, LAYERS, COLS, ROWS>
286 where
287 [(); LAYERS * COLS * ROWS]:,
288 {
289 self.transpose().conjugate()
290 }
291}
292
293impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
294where
295 [(); LAYERS * ROWS * COLS]:,
296{
297 pub fn transpose(self) -> Tensor<E,D, LAYERS, COLS, ROWS>
299 where
300 [(); LAYERS * COLS * ROWS]:,
301 {
302 let mut transposed = [E::zero(); LAYERS * COLS * ROWS];
303 for l in 0..LAYERS {
304 for i in 0..ROWS {
305 for j in 0..COLS {
306 let src = l * (ROWS * COLS) + i * COLS + j;
308 let dst = l * (COLS * ROWS) + j * ROWS + i;
309 transposed[dst] = self.data[src];
310 }
311 }
312 }
313 Tensor::<E,D, LAYERS, COLS, ROWS> {
314 data: transposed,
315 _phantom: PhantomData,
316 }
317 }
318}
319
320
321
322
323impl<
325 E: TensorElement + Into<f64> + Copy,
326 const L: i32,
327 const M: i32,
328 const T: i32,
329 const Θ: i32,
330 const I: i32,
331 const N: i32,
332 const J: i32,
333 const ROWS: usize
334>
335Tensor<E, Dimension<L, M, T, Θ, I, N, J>, 1, ROWS, 1>
336where
337 [(); 1 * ROWS * 1]:,
338 [(); 1 * 1 * ROWS]:,
339 [(); ROWS * 1 * 1]:,
340{
341 pub fn norm(
342 self
343 ) -> Tensor<f64, Dimension<L, M, T, Θ, I, N, J>, 1, 1, 1>
344 where
345 [(); { <() as ConstAdd<L, L>>::OUTPUT } as usize]:,
346 [(); { <() as ConstAdd<M, M>>::OUTPUT } as usize]:,
347 [(); { <() as ConstAdd<T, T>>::OUTPUT } as usize]:,
348 [(); { <() as ConstAdd<Θ, Θ>>::OUTPUT } as usize]:,
349 [(); { <() as ConstAdd<I, I>>::OUTPUT } as usize]:,
350 [(); { <() as ConstAdd<N, N>>::OUTPUT } as usize]:,
351 [(); { <() as ConstAdd<J, J>>::OUTPUT } as usize]:,
352 {
353 let ct: Tensor<E, Dimension<L, M, T, Θ, I, N, J>, 1, 1, ROWS> = self.conjugate_transpose();
354 let i: Tensor<E, Dimension<_, _, _, _, _, _, _>, 1, 1, 1> = ct.matmul(self);
355
356 let val: c64 = i.data[0].into();
358 let sqrt_val = f64::from(val.sqrt());
359
360 Tensor {
361 data: [sqrt_val],
362 _phantom: PhantomData,
363 }
364 }
365
366 pub fn dist(
367 self,
368 other: Self,
369 ) -> Tensor<f64, Dimension<L, M, T, Θ, I, N, J>, 1, 1, 1>
370 where
371 [(); 1 * ROWS * 1]:,
372 [(); 1 * 1 * ROWS]:,
373 [(); { <() as ConstAdd<L, L>>::OUTPUT } as usize]:,
374 [(); { <() as ConstAdd<M, M>>::OUTPUT } as usize]:,
375 [(); { <() as ConstAdd<T, T>>::OUTPUT } as usize]:,
376 [(); { <() as ConstAdd<Θ, Θ>>::OUTPUT } as usize]:,
377 [(); { <() as ConstAdd<I, I>>::OUTPUT } as usize]:,
378 [(); { <() as ConstAdd<N, N>>::OUTPUT } as usize]:,
379 [(); { <() as ConstAdd<J, J>>::OUTPUT } as usize]:,
380 {
381 let sub = self - other;
382 sub.norm()
383 }
384
385}
386
387impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> PartialEq for Tensor<E,D, LAYERS, ROWS, COLS>
389where
390 [(); LAYERS * ROWS * COLS]:,
391{
392 fn eq(&self, other: &Self) -> bool {
393 self.data
394 .iter()
395 .zip(other.data.iter())
396 .all(|(&a, &b)| a == b)
397 }
398}
399
/// Declares that tensor equality (the `PartialEq` impl) is a total equivalence relation.
///
/// NOTE(review): the `c64: Eq` bound does not mention any of this impl's generic
/// parameters, so it places no requirement on `E`. It either holds for every `E` or for
/// none (depending on whether `c64` implements `Eq`). If the intent was to require
/// `E: Eq`, this bound does not express that — confirm.
impl<E:TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Eq for Tensor<E,D, LAYERS, ROWS, COLS>
where
    [(); LAYERS * ROWS * COLS]:,
    c64: Eq,
{
}
407
408impl<E: TensorElement,D> PartialOrd for Tensor<E,D, 1, 1, 1>
410where
411 [(); 1]:,
412{
413 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
414 self.data[0].partial_cmp(&other.data[0])
415 }
416}
417
/// Dot product `aᵀ · b`: transposes `a` and matrix-multiplies it by `b`.
///
/// No complex conjugation is applied (see `inner_product!` for that). Each argument
/// expression is evaluated exactly once, `a` before `b`.
#[macro_export]
macro_rules! dot {
    ($a:expr, $b:expr) => {{
        let lhs = $a;
        let rhs = $b;
        lhs.transpose().matmul(rhs)
    }};
}
429
/// Inner product `aᴴ · b`: conjugate-transposes `a` and matrix-multiplies it by `b`.
///
/// Each argument expression is evaluated exactly once, `a` before `b`.
#[macro_export]
macro_rules! inner_product {
    ($a:expr, $b:expr) => {{
        let lhs = $a;
        let rhs = $b;
        lhs.conjugate_transpose().matmul(rhs)
    }};
}
440
/// Short alias for `inner_product!`: `ip!(x, y)` computes `xᴴ · y`.
#[macro_export]
macro_rules! ip {
    ($x:expr, $y:expr) => {
        // Use a `$crate::` path so the expansion resolves even when the caller has not
        // imported `inner_product!` itself (e.g. when `ip!` is used from another crate).
        $crate::inner_product!($x, $y)
    };
}
447
448impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
452where
453 [(); LAYERS * ROWS * COLS]:,
454{
455 pub fn size(&self) -> usize {
457 LAYERS * ROWS * COLS
458 }
459
460 pub fn shape(&self) -> (usize, usize, usize) {
461 (LAYERS, ROWS, COLS)
462 }
463
464 pub fn layers(&self) -> usize {
466 LAYERS
467 }
468
469 pub fn rows(&self) -> usize {
471 ROWS
472 }
473
474 pub fn cols(&self) -> usize {
476 COLS
477 }
478
479 pub fn data(&self) -> &[E] {
481 &self.data
482 }
483
484 pub fn reshape<const L: usize, const R: usize, const C: usize>(
485 &self,
486 ) -> Tensor<E,D, L, R, C>
487 where
488 [(); L * R * C]:,
489 {
490 assert_eq!(LAYERS * ROWS * COLS, L * R * C);
491 let data: [E; L * R * C] = self
492 .data
493 .iter()
494 .copied()
495 .collect::<Vec<_>>()
496 .try_into()
497 .unwrap();
498
499 Tensor {
500 data,
501 _phantom: PhantomData,
502 }
503 }
504
505 pub fn flatten(&self) -> Tensor<E,D, 1, 1, {LAYERS * ROWS * COLS}>
506 where
507 [(); LAYERS * ROWS * COLS]:,
508 [(); 1 * 1 * (LAYERS * ROWS * COLS)]:,
509 {
510 self.reshape::<1, 1, {LAYERS * ROWS * COLS}>()
511 }
512
513}
514
515impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
518where
519 [(); LAYERS * ROWS * COLS]:,
520{
521 pub fn and(self, other: Self) -> Self {
522 let data: [E; LAYERS * ROWS * COLS] = self
523 .data
524 .iter()
525 .zip(other.data.iter())
526 .map(|(&a, &b)| if a != E::zero() && b != E::zero() { E::one() } else { E::zero() })
527 .collect::<Vec<_>>()
528 .try_into()
529 .unwrap();
530
531 Self {
532 data,
533 _phantom: PhantomData,
534 }
535 }
536}
537
538impl<E: TensorElement + PartialEq + Copy, D, const LAYERS: usize, const ROWS: usize, const COLS: usize>
539Tensor<E, D, LAYERS, ROWS, COLS>
540where
541 [(); LAYERS * ROWS * COLS]:,
542{
543 pub fn or(self, other: Self) -> Self {
544 let data: [E; LAYERS * ROWS * COLS] = self
545 .data
546 .iter()
547 .zip(other.data.iter())
548 .map(|(&a, &b)| if a != E::zero() || b != E::zero() { E::one() } else { E::zero() })
549 .collect::<Vec<_>>()
550 .try_into()
551 .unwrap();
552
553 Self {
554 data,
555 _phantom: PhantomData,
556 }
557 }
558}
559
560
561
562impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
565where
566 [(); LAYERS * ROWS * COLS]:,
567{
568 pub fn eq(self, other: Self) -> Self {
569 let data: [E; LAYERS * ROWS * COLS] = self
570 .data
571 .iter()
572 .zip(other.data.iter())
573 .map(|(&a, &b)| if a == b { E::one() } else { E::zero() })
574 .collect::<Vec<_>>()
575 .try_into()
576 .unwrap();
577
578 Self {
579 data,
580 _phantom: PhantomData,
581 }
582 }
583}
584impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
585where
586 [(); LAYERS * ROWS * COLS]:,
587{
588 pub fn ne(self, other: Self) -> Self {
589 let data: [E; LAYERS * ROWS * COLS] = self
590 .data
591 .iter()
592 .zip(other.data.iter())
593 .map(|(&a, &b)| if a != b { E::one() } else { E::one() })
594 .collect::<Vec<_>>()
595 .try_into()
596 .unwrap();
597
598 Self {
599 data,
600 _phantom: PhantomData,
601 }
602 }
603}
604impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
605where
606 [(); LAYERS * ROWS * COLS]:,
607{
608 pub fn gt(self, other: Self) -> Self {
609 let data: [E; LAYERS * ROWS * COLS] = self
610 .data
611 .iter()
612 .zip(other.data.iter())
613 .map(|(&a, &b)| if a > b { E::one() } else { E::zero() })
614 .collect::<Vec<_>>()
615 .try_into()
616 .unwrap();
617 Self {
618 data,
619 _phantom: PhantomData,
620 }
621 }
622}
623
624impl<E: TensorElement,D, const LAYERS: usize, const ROWS: usize, const COLS: usize> Tensor<E,D, LAYERS, ROWS, COLS>
625where
626 [(); LAYERS * ROWS * COLS]:,
627{
628 pub fn ge(self, other: Self) -> Self {
629 let data: [E; LAYERS * ROWS * COLS] = self
630 .data
631 .iter()
632 .zip(other.data.iter())
633 .map(|(&a, &b)| if a >= b { E::zero() } else { E::one() })
634 .collect::<Vec<_>>()
635 .try_into()
636 .unwrap();
637
638 Self {
639 data,
640 _phantom: PhantomData,
641 }
642 }
643}