zenu_matrix/dim/
mod.rs

1pub mod dim_dyn;
2pub mod dim_static;
3
4pub use dim_dyn::larger_shape;
5pub use dim_dyn::DimDyn;
6pub(crate) use dim_dyn::{into_dyn, smaller_shape};
7pub use dim_static::{Dim0, Dim1, Dim2, Dim3, Dim4};
8
9use std::{
10    fmt::Debug,
11    ops::{Index, IndexMut},
12};
13
/// Common interface of the dimension values used in this module (shapes, strides, indices).
///
/// An implementor is a `Copy` value made of `usize` entries. The entries can be read and
/// written by position, iterated by value, and built from a `&[usize]`.
pub trait DimTrait:
    Index<usize, Output = usize>
    + IndexMut<usize>
    + IntoIterator<Item = usize>
    + Clone
    + Copy
    + Default
    + PartialEq
    + Debug
    + for<'a> From<&'a [usize]>
    + for<'a> From<&'a Self>
    + 'static
{
    /// Returns the number of entries held by this value.
    fn len(&self) -> usize;

    /// Returns whether this value holds no entries.
    ///
    /// Each implementor defines this itself; presumably it agrees with `len() == 0`.
    fn is_empty(&self) -> bool;

    /// Returns `true` if some entry of `index` is greater than or equal to the entry of
    /// `self` it is paired with.
    ///
    /// Entries are paired in iteration order. When `index` has fewer entries than `self`,
    /// the unpaired trailing entries of `self` are not checked.
    ///
    /// # Panics
    ///
    /// Panics with "Dimension mismatch" if `index` has more entries than `self`.
    fn is_overflow<D: DimTrait>(&self, index: D) -> bool {
        assert!(self.len() >= index.len(), "Dimension mismatch");

        // "Some entry overflows" is the negation of "every entry is in bounds".
        let in_bounds = index
            .into_iter()
            .zip(*self)
            .all(|(position, extent)| position < extent);
        !in_bounds
    }

    /// Returns the product of all entries, which is `1` when there are none.
    fn num_elm(&self) -> usize {
        // Iteration consumes the value, so work on a copy.
        let extents = *self;
        extents.into_iter().product()
    }

    /// Borrows the entries as a `&[usize]`.
    fn slice(&self) -> &[usize];

    /// Returns `true` when there are no entries or when their product is exactly one.
    fn is_scalar(&self) -> bool {
        match self.len() {
            0 => true,
            _ => self.num_elm() == 1,
        }
    }
}
44
45pub trait LessDimTrait: DimTrait {
46    type LessDim: DimTrait;
47
48    fn remove_axis(&self, axis: usize) -> Self::LessDim {
49        let mut default = DimDyn::default();
50        for i in 0..self.len() {
51            if i == axis {
52                continue;
53            }
54            default.push_dim(self[i]);
55        }
56        Self::LessDim::from(default.slice())
57    }
58}
59
60pub trait GreaterDimTrait: DimTrait {
61    type GreaterDim: DimTrait;
62}
63
64#[expect(clippy::missing_panics_doc)]
65pub fn cal_offset<D1: DimTrait, D2: DimTrait>(shape: D1, stride: D2) -> usize {
66    assert!(shape.len() == stride.len(), "Dimension mismatch");
67    shape.into_iter().zip(stride).map(|(x, y)| x * y).sum()
68}
69
70pub fn default_stride<D: DimTrait>(shape: D) -> D {
71    let mut stride = shape;
72    let n = shape.len();
73
74    if n == 0 {
75        return stride;
76    }
77
78    if n == 1 {
79        stride[0] = 1;
80        return stride;
81    }
82
83    // 最後の次元のストライドは常に1
84    stride[n - 1] = 1;
85
86    // 残りの次元に対して、後ろから前へ計算
87    for i in (0..n - 1).rev() {
88        stride[i] = stride[i + 1] * shape[i + 1];
89    }
90
91    stride
92}