1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
pub mod dim_dyn;
pub mod dim_static;

pub(crate) use dim_dyn::into_dyn;
pub use dim_dyn::DimDyn;
pub use dim_static::{Dim0, Dim1, Dim2, Dim3, Dim4};

use std::{
    fmt::Debug,
    ops::{Index, IndexMut},
};

/// Common interface for dimension-like values: a `Copy` sequence of `usize`
/// entries that can be indexed, iterated by value, and built from a slice.
///
/// The same trait is used for shapes, strides, and per-axis indices (the free
/// functions in this module take `DimTrait` arguments in all three roles).
pub trait DimTrait:
    Index<usize, Output = usize>
    + IndexMut<usize>
    + IntoIterator<Item = usize>
    + Clone
    + Copy
    + Default
    + PartialEq
    + Debug
    + for<'a> From<&'a [usize]>
    + 'static
{
    /// Number of entries (axes) held by this value.
    fn len(&self) -> usize;

    /// Whether this value holds no entries; defined by each implementor
    /// (presumably equivalent to `self.len() == 0` — confirm per impl).
    fn is_empty(&self) -> bool;

    /// Returns `true` when `index` is out of bounds for the shape `self`, i.e.
    /// when some entry of `index` is greater than or equal to the entry of
    /// `self` at the same position.
    ///
    /// Entries are paired positionally. If `index` is shorter than `self`,
    /// only the leading axes covered by `index` are checked.
    ///
    /// # Panics
    ///
    /// Panics with `"Dimension mismatch"` if `index` has more entries than `self`.
    fn is_overflow<D: DimTrait>(&self, index: D) -> bool {
        assert!(index.len() <= self.len(), "Dimension mismatch");

        let bounds = *self;
        // In bounds only if every paired coordinate is strictly below its extent;
        // `all` stops at the first offending axis.
        !index
            .into_iter()
            .zip(bounds)
            .all(|(coord, extent)| coord < extent)
    }

    /// Total number of elements described by this shape: the product of all
    /// entries. A value with no entries yields `1`, and any zero entry yields `0`.
    fn num_elm(&self) -> usize {
        (*self)
            .into_iter()
            .fold(1, |count, extent| count * extent)
    }

    /// Borrows the entries as a plain slice.
    fn slice(&self) -> &[usize];
}

/// A dimension type paired with an associated "lesser" dimension type.
///
/// NOTE(review): going by the name, `LessDim` is presumably the type with one
/// fewer axis (e.g. a rank-3 type mapping to a rank-2 type); the impls are not
/// visible here — confirm against them.
pub trait LessDimTrait: DimTrait {
    /// The associated lesser dimension type; itself a `DimTrait`.
    type LessDim: DimTrait;
}

/// A dimension type paired with an associated "greater" dimension type.
///
/// NOTE(review): going by the name, `GreaterDim` is presumably the type with one
/// more axis (e.g. a rank-2 type mapping to a rank-3 type); the impls are not
/// visible here — confirm against them.
pub trait GreaterDimTrait: DimTrait {
    /// The associated greater dimension type; itself a `DimTrait`.
    type GreaterDim: DimTrait;
}

/// Returns `Σ shape[i] * stride[i]`, the dot product of two equally long
/// dimension values.
///
/// When the first argument holds a per-axis coordinate and the second holds
/// the strides of a layout (for example from [`default_stride`]), the result is
/// the linear offset of that element.
///
/// # Panics
///
/// Panics with `"Dimension mismatch"` if the two arguments differ in length.
pub fn cal_offset<D1: DimTrait, D2: DimTrait>(shape: D1, stride: D2) -> usize {
    if shape.len() == stride.len() {
        shape
            .into_iter()
            .zip(stride)
            .fold(0, |offset, (coord, step)| offset + coord * step)
    } else {
        panic!("Dimension mismatch")
    }
}

/// Computes row-major (C-contiguous) strides for `shape`.
///
/// The last axis gets stride `1`, and each earlier axis gets the product of
/// all extents after it. A value with no entries is returned unchanged.
///
/// ```text
/// shape [2, 3, 4]  ->  stride [12, 4, 1]
/// ```
pub fn default_stride<D: DimTrait>(shape: D) -> D {
    // Start from a copy of `shape` so the result has the same type and rank;
    // every entry is overwritten below unless there are no entries at all.
    let mut stride = shape;

    if let Some(last) = shape.len().checked_sub(1) {
        // The innermost axis is contiguous.
        stride[last] = 1;

        // Walk outward: one step along `axis` skips a full run of the next-inner axis.
        for axis in (0..last).rev() {
            stride[axis] = stride[axis + 1] * shape[axis + 1];
        }
    }

    stride
}