1use std::fmt::Debug;
2use std::ops::Index;
3use std::ops::IndexMut;
4
/// Bounds required of a tensor element (scalar) type.
///
/// An element must be:
/// - a cheap `Copy` value with a `Default`;
/// - comparable (`PartialEq`) and printable for debugging (`Debug`);
/// - closed under `+=` and `*=`, and usable with `+` and `*`.
///   The right-hand side of each operator defaults to `Self`.
///
/// Note that the `Add`/`Mul` bounds leave `Output` unconstrained. Generic code
/// therefore cannot assume `a + b` or `a * b` produces `Self`; only the
/// `*Assign` forms are guaranteed to keep the value in `Self`.
pub trait TensorTrait
where
    Self: Copy + Clone + Default,
    Self: PartialEq + Debug,
    Self: std::ops::Add + std::ops::AddAssign,
    Self: std::ops::Mul + std::ops::MulAssign,
{
}
17
18impl<T> TensorTrait for T where
19 T: PartialEq
20 + Debug
21 + Default
22 + std::ops::Add
23 + std::ops::AddAssign
24 + std::ops::Mul
25 + std::ops::MulAssign
26 + Copy
27 + Clone
28{
29}
30
31pub trait Tensor:
32 Index<usize> + IndexMut<usize> + PartialEq + Debug + Default + Copy + Clone
33{
34 type Value: TensorTrait;
35
36 const SIZE: usize;
37 const NDIM: usize;
38
39 fn dims() -> Vec<usize>;
40 fn get_dims(&self) -> Vec<usize>;
41
42 }
44
45pub trait CwiseMul<Rhs: Tensor> {
46 type Output: Tensor;
47 fn cwise_mul(self, other: Rhs) -> Self::Output;
48}
49
50pub trait CwiseMulAssign<Rhs: Tensor> {
51 fn cwise_mul_assign(&mut self, other: Rhs);
52}
53
54pub trait TensorTranspose<T, TT>: Tensor<Value = TT>
57where
58 T: Tensor<Value = TT>,
59 TT: TensorTrait,
64{
65 fn transpose(self) -> T;
66}
67
/// Compile-time row and column counts of a matrix type.
pub trait Matrix {
    /// Number of rows declared by the implementor.
    const ROWS: usize;

    /// Number of columns declared by the implementor.
    const COLS: usize;
}
72
/// Compile-time length information for a vector type.
///
/// NOTE(review): this trait exposes only a column count, while `RowVector`
/// exposes only a row count. That looks inverted relative to the usual
/// row-/column-vector convention. Confirm the intended orientation before
/// relying on it.
pub trait Vector {
    /// Number of columns declared by the implementor.
    const COLS: usize;
}
76
/// Compile-time length information for a row-vector type.
///
/// NOTE(review): this trait exposes only a row count, while `Vector` exposes
/// only a column count. That looks inverted relative to the usual
/// row-/column-vector convention. Confirm the intended orientation before
/// relying on it.
pub trait RowVector {
    /// Number of rows declared by the implementor.
    const ROWS: usize;
}
80
/// Errors reported by tensor operations.
///
/// Implements `Display` and `std::error::Error`. It can therefore be printed,
/// boxed as `Box<dyn Error>`, and propagated with `?` into generic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorError {
    /// A size-related failure.
    /// NOTE(review): presumably a size/dimension mismatch between operands;
    /// confirm against the call sites that return it.
    Size,
}

impl std::fmt::Display for TensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match so adding a variant forces a message to be written.
        match self {
            TensorError::Size => f.write_str("invalid tensor size"),
        }
    }
}

impl std::error::Error for TensorError {}