smpl_utils/
array.rs

1use ndarray as nd;
2use num_traits;
3pub trait Gather1D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
4    fn gather(&self, indices: &[usize]) -> nd::Array1<T>;
5}
6impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather1D<T> for nd::Array1<T> {
7    fn gather(&self, indices: &[usize]) -> nd::Array1<T> {
8        let mut res = nd::Array1::<T>::zeros(indices.len());
9        for (i_out, &i_in) in indices.iter().enumerate() {
10            res[i_out] = self[i_in];
11        }
12        res
13    }
14}
15pub trait Gather2D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
16    fn gather(&self, indices_rows: &[usize], indices_cols: &[usize]) -> nd::Array2<T>;
17}
18impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather2D<T> for nd::Array2<T> {
19    fn gather(&self, indices_rows: &[usize], indices_cols: &[usize]) -> nd::Array2<T> {
20        let mut res = nd::Array2::zeros((indices_rows.len(), indices_cols.len()));
21        for (i_out, &i_in) in indices_rows.iter().enumerate() {
22            for (j_out, &j_in) in indices_cols.iter().enumerate() {
23                res[(i_out, j_out)] = self[(i_in, j_in)];
24            }
25        }
26        res
27    }
28}
29pub trait Gather3D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
30    fn gather(&self, indices_rows: &[usize], indices_cols: &[usize], indices_depth: &[usize]) -> nd::Array3<T>;
31}
32impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather3D<T> for nd::Array3<T> {
33    fn gather(&self, indices_rows: &[usize], indices_cols: &[usize], indices_depth: &[usize]) -> nd::Array3<T> {
34        let mut res = nd::Array3::zeros((indices_rows.len(), indices_cols.len(), indices_depth.len()));
35        for (i_out, &i_in) in indices_rows.iter().enumerate() {
36            for (j_out, &j_in) in indices_cols.iter().enumerate() {
37                for (k_out, &k_in) in indices_depth.iter().enumerate() {
38                    res[(i_out, j_out, k_out)] = self[(i_in, j_in, k_in)];
39                }
40            }
41        }
42        res
43    }
44}
45pub trait Scatter1D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
46    fn scatter(&self, indices: &[usize], dst: &mut nd::Array1<T>);
47}
48impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Scatter1D<T> for nd::Array1<T> {
49    fn scatter(&self, indices: &[usize], dst: &mut nd::Array1<T>) {
50        for (i_in, &i_out) in indices.iter().enumerate() {
51            dst[i_out] = self[i_in];
52        }
53    }
54}
55pub trait Scatter2D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
56    fn scatter(&self, indices_rows: &[usize], indices_cols: &[usize], dst: &mut nd::Array2<T>);
57}
58impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Scatter2D<T> for nd::Array2<T> {
59    fn scatter(&self, indices_rows: &[usize], indices_cols: &[usize], dst: &mut nd::Array2<T>) {
60        for (i_in, &i_out) in indices_rows.iter().enumerate() {
61            for (j_in, &j_out) in indices_cols.iter().enumerate() {
62                dst[(i_out, j_out)] = self[(i_in, j_in)];
63            }
64        }
65    }
66}