// tensor_macros/traits.rs

1use std::fmt::Debug;
2use std::ops::Index;
3use std::ops::IndexMut;
4
/// Marker trait for types usable as tensor elements.
///
/// It declares no items of its own; it only bundles the bounds an element
/// type must satisfy:
///
/// * value semantics: [`Copy`] and [`Clone`],
/// * comparison, diagnostics and a default value: [`PartialEq`], [`Debug`],
///   [`Default`],
/// * arithmetic: `+`, `+=`, `*` and `*=` with `Self` as the right-hand side.
///
/// Note that the `Add` and `Mul` bounds do not constrain `Output`, so generic
/// code cannot assume from this trait alone that `a + b` or `a * b` yields
/// `Self`. The compound-assignment operators have no such ambiguity.
pub trait TensorTrait:
    Copy
    + Clone
    + PartialEq
    + Debug
    + Default
    + std::ops::Add
    + std::ops::AddAssign
    + std::ops::Mul
    + std::ops::MulAssign
{
}
17
18impl<T> TensorTrait for T where
19    T: PartialEq
20        + Debug
21        + Default
22        + std::ops::Add
23        + std::ops::AddAssign
24        + std::ops::Mul
25        + std::ops::MulAssign
26        + Copy
27        + Clone
28{
29}
30
31pub trait Tensor:
32    Index<usize> + IndexMut<usize> + PartialEq + Debug + Default + Copy + Clone
33{
34    type Value: TensorTrait;
35
36    const SIZE: usize;
37    const NDIM: usize;
38
39    fn dims() -> Vec<usize>;
40    fn get_dims(&self) -> Vec<usize>;
41
42    // fn transpose(self) -> TensorTranspose<Self>;
43}
44
45pub trait CwiseMul<Rhs: Tensor> {
46    type Output: Tensor;
47    fn cwise_mul(self, other: Rhs) -> Self::Output;
48}
49
50pub trait CwiseMulAssign<Rhs: Tensor> {
51    fn cwise_mul_assign(&mut self, other: Rhs);
52}
53
54// pub struct TensorTranspose<T: TensorTrait, TT: Tensor<Value = T>>(TT);
55
56pub trait TensorTranspose<T, TT>: Tensor<Value = TT>
57where
58    T: Tensor<Value = TT>,
59    // + std::ops::Add
60    // + std::ops::AddAssign
61    // + std::ops::Mul<TT>
62    // + std::ops::MulAssign<TT>,
63    TT: TensorTrait,
64{
65    fn transpose(self) -> T;
66}
67
/// Compile-time shape information for matrix-like types.
pub trait Matrix {
    /// Number of rows.
    const ROWS: usize;
    /// Number of columns.
    const COLS: usize;
}
72
/// Compile-time length information for vector-like types.
///
/// NOTE(review): this trait exposes `COLS` while `RowVector` exposes `ROWS`;
/// confirm that this matches the orientation convention used by the
/// implementors.
pub trait Vector {
    /// Number of columns.
    const COLS: usize;
}
76
/// Compile-time length information for row-vector-like types.
///
/// NOTE(review): this trait exposes `ROWS` while `Vector` exposes `COLS`;
/// confirm that this matches the orientation convention used by the
/// implementors.
pub trait RowVector {
    /// Number of rows.
    const ROWS: usize;
}
80
/// Errors reported by tensor operations.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and similar
/// error-handling types. It is a field-less enum, so it is also `Copy` and
/// `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorError {
    /// A size-related failure. The exact condition that raises it is not
    /// visible in this file; presumably a size mismatch (confirm at the
    /// call sites).
    Size,
}

impl std::fmt::Display for TensorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a new message.
        match self {
            TensorError::Size => f.write_str("tensor size error"),
        }
    }
}

impl std::error::Error for TensorError {}