zenu_matrix/index/
index_dyn_impl.rs1use crate::{dim::DimTrait, shape_stride::ShapeStride};
2
3use super::IndexAxisTrait;
4
/// Selects a single position (`index`) along one dimension (`axis`) of a
/// multi-dimensional shape.
///
/// The `IndexAxisTrait` implementation for this type removes `axis` from the
/// shape/stride and offsets by `stride[axis] * index`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Index {
    /// Dimension being indexed.
    axis: usize,
    /// Position selected along `axis`.
    index: usize,
}

impl Index {
    /// Creates an index selecting position `index` along dimension `axis`.
    ///
    /// No validation is done here; the values are checked against a concrete
    /// shape only when the index is applied.
    #[must_use]
    pub const fn new(axis: usize, index: usize) -> Self {
        Self { axis, index }
    }

    /// Returns the dimension this index selects along.
    #[must_use]
    pub const fn axis(&self) -> usize {
        self.axis
    }

    /// Returns the position selected along [`Index::axis`].
    #[must_use]
    pub const fn index(&self) -> usize {
        self.index
    }
}
27
28impl IndexAxisTrait for Index {
29 fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
30 &self,
31 shape: Din,
32 stride: Din,
33 ) -> ShapeStride<Dout> {
34 let mut shape_v = Vec::new();
35 let mut stride_v = Vec::new();
36 for i in 0..shape.len() {
37 if i == self.axis {
38 continue;
39 }
40 shape_v.push(shape[i]);
41 stride_v.push(stride[i]);
42 }
43
44 let new_shape = Dout::from(&shape_v as &[usize]);
45 let new_stride = Dout::from(&stride_v as &[usize]);
46 ShapeStride::new(new_shape, new_stride)
47 }
48 fn offset<Din: DimTrait>(&self, stride: Din) -> usize {
49 stride[self.axis] * self.index
50 }
51}