1pub mod dim_dyn;
2pub mod dim_static;
3
4pub use dim_dyn::larger_shape;
5pub use dim_dyn::DimDyn;
6pub(crate) use dim_dyn::{into_dyn, smaller_shape};
7pub use dim_static::{Dim0, Dim1, Dim2, Dim3, Dim4};
8
9use std::{
10 fmt::Debug,
11 ops::{Index, IndexMut},
12};
13
/// Common interface for small per-axis `usize` descriptors (shapes, strides,
/// indices).
///
/// Implementors can be indexed per axis, iterated by value, copied, and built
/// from a `&[usize]` slice.
pub trait DimTrait:
    Index<usize, Output = usize>
    + IndexMut<usize>
    + IntoIterator<Item = usize>
    + Clone
    + Copy
    + Default
    + PartialEq
    + Debug
    + for<'a> From<&'a [usize]>
    + for<'a> From<&'a Self>
    + 'static
{
    /// Number of axes (entries) held by this value.
    fn len(&self) -> usize;

    /// Whether this value holds no axes at all.
    fn is_empty(&self) -> bool;

    /// Returns `true` when some component of `index` is not strictly below the
    /// matching extent of `self`.
    ///
    /// Components are compared pairwise and the pairing stops at the shorter
    /// side, so trailing axes of `self` beyond `index` are ignored.
    ///
    /// # Panics
    ///
    /// Panics with `"Dimension mismatch"` if `index` has more axes than `self`.
    fn is_overflow<D: DimTrait>(&self, index: D) -> bool {
        assert!(self.len() >= index.len(), "Dimension mismatch");

        let bounds = *self;
        // In bounds only if every position is strictly below its extent.
        !index
            .into_iter()
            .zip(bounds)
            .all(|(position, extent)| position < extent)
    }

    /// Total element count: the product of all extents.
    ///
    /// A value with no axes yields `1` (empty product).
    fn num_elm(&self) -> usize {
        IntoIterator::into_iter(*self).product()
    }

    /// Borrows the components as a contiguous slice.
    fn slice(&self) -> &[usize];

    /// Returns `true` for a rank-0 value, or for any value whose extents
    /// multiply to exactly one element (e.g. `[1, 1]`).
    fn is_scalar(&self) -> bool {
        if self.len() == 0 {
            return true;
        }
        self.num_elm() == 1
    }
}
44
45pub trait LessDimTrait: DimTrait {
46 type LessDim: DimTrait;
47
48 fn remove_axis(&self, axis: usize) -> Self::LessDim {
49 let mut default = DimDyn::default();
50 for i in 0..self.len() {
51 if i == axis {
52 continue;
53 }
54 default.push_dim(self[i]);
55 }
56 Self::LessDim::from(default.slice())
57 }
58}
59
60pub trait GreaterDimTrait: DimTrait {
61 type GreaterDim: DimTrait;
62}
63
64#[expect(clippy::missing_panics_doc)]
65pub fn cal_offset<D1: DimTrait, D2: DimTrait>(shape: D1, stride: D2) -> usize {
66 assert!(shape.len() == stride.len(), "Dimension mismatch");
67 shape.into_iter().zip(stride).map(|(x, y)| x * y).sum()
68}
69
70pub fn default_stride<D: DimTrait>(shape: D) -> D {
71 let mut stride = shape;
72 let n = shape.len();
73
74 if n == 0 {
75 return stride;
76 }
77
78 if n == 1 {
79 stride[0] = 1;
80 return stride;
81 }
82
83 stride[n - 1] = 1;
85
86 for i in (0..n - 1).rev() {
88 stride[i] = stride[i + 1] * shape[i + 1];
89 }
90
91 stride
92}