zenu_matrix/index/
index_dyn_impl.rs

1use crate::{dim::DimTrait, shape_stride::ShapeStride};
2
3use super::IndexAxisTrait;
4
/// Selects a single position along one axis of a matrix.
///
/// `Index::new(axis, index)` fixes `axis` at position `index`. Applying it through
/// [`IndexAxisTrait`] drops that axis from the resulting shape/stride.
#[derive(Clone, Copy, Debug)]
pub struct Index {
    /// The axis being indexed.
    axis: usize,
    /// The position selected along `axis`.
    index: usize,
}

impl Index {
    /// Creates an index that selects position `index` along axis `axis`.
    ///
    /// No bounds checking is performed here; the values are stored as given.
    /// This is a `const fn`, so it can be used to build `const` indices.
    #[must_use]
    pub const fn new(axis: usize, index: usize) -> Self {
        Self { axis, index }
    }

    /// Returns the axis this index selects along.
    #[must_use]
    pub const fn axis(&self) -> usize {
        self.axis
    }

    /// Returns the position selected along [`axis`](Self::axis).
    #[must_use]
    pub const fn index(&self) -> usize {
        self.index
    }
}
27
28impl IndexAxisTrait for Index {
29    fn get_shape_stride<Din: DimTrait, Dout: DimTrait>(
30        &self,
31        shape: Din,
32        stride: Din,
33    ) -> ShapeStride<Dout> {
34        let mut shape_v = Vec::new();
35        let mut stride_v = Vec::new();
36        for i in 0..shape.len() {
37            if i == self.axis {
38                continue;
39            }
40            shape_v.push(shape[i]);
41            stride_v.push(stride[i]);
42        }
43
44        let new_shape = Dout::from(&shape_v as &[usize]);
45        let new_stride = Dout::from(&stride_v as &[usize]);
46        ShapeStride::new(new_shape, new_stride)
47    }
48    fn offset<Din: DimTrait>(&self, stride: Din) -> usize {
49        stride[self.axis] * self.index
50    }
51}