1use ndarray as nd;
2use num_traits;
3pub trait Gather1D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
4 fn gather(&self, indices: &[usize]) -> nd::Array1<T>;
5}
6impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather1D<T> for nd::Array1<T> {
7 fn gather(&self, indices: &[usize]) -> nd::Array1<T> {
8 let mut res = nd::Array1::<T>::zeros(indices.len());
9 for (i_out, &i_in) in indices.iter().enumerate() {
10 res[i_out] = self[i_in];
11 }
12 res
13 }
14}
15pub trait Gather2D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
16 fn gather(&self, indices_rows: &[usize], indices_cols: &[usize]) -> nd::Array2<T>;
17}
18impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather2D<T> for nd::Array2<T> {
19 fn gather(&self, indices_rows: &[usize], indices_cols: &[usize]) -> nd::Array2<T> {
20 let mut res = nd::Array2::zeros((indices_rows.len(), indices_cols.len()));
21 for (i_out, &i_in) in indices_rows.iter().enumerate() {
22 for (j_out, &j_in) in indices_cols.iter().enumerate() {
23 res[(i_out, j_out)] = self[(i_in, j_in)];
24 }
25 }
26 res
27 }
28}
29pub trait Gather3D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
30 fn gather(&self, indices_rows: &[usize], indices_cols: &[usize], indices_depth: &[usize]) -> nd::Array3<T>;
31}
32impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Gather3D<T> for nd::Array3<T> {
33 fn gather(&self, indices_rows: &[usize], indices_cols: &[usize], indices_depth: &[usize]) -> nd::Array3<T> {
34 let mut res = nd::Array3::zeros((indices_rows.len(), indices_cols.len(), indices_depth.len()));
35 for (i_out, &i_in) in indices_rows.iter().enumerate() {
36 for (j_out, &j_in) in indices_cols.iter().enumerate() {
37 for (k_out, &k_in) in indices_depth.iter().enumerate() {
38 res[(i_out, j_out, k_out)] = self[(i_in, j_in, k_in)];
39 }
40 }
41 }
42 res
43 }
44}
45pub trait Scatter1D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
46 fn scatter(&self, indices: &[usize], dst: &mut nd::Array1<T>);
47}
48impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Scatter1D<T> for nd::Array1<T> {
49 fn scatter(&self, indices: &[usize], dst: &mut nd::Array1<T>) {
50 for (i_in, &i_out) in indices.iter().enumerate() {
51 dst[i_out] = self[i_in];
52 }
53 }
54}
55pub trait Scatter2D<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> {
56 fn scatter(&self, indices_rows: &[usize], indices_cols: &[usize], dst: &mut nd::Array2<T>);
57}
58impl<T: nd::ScalarOperand + num_traits::identities::Zero + Copy> Scatter2D<T> for nd::Array2<T> {
59 fn scatter(&self, indices_rows: &[usize], indices_cols: &[usize], dst: &mut nd::Array2<T>) {
60 for (i_in, &i_out) in indices_rows.iter().enumerate() {
61 for (j_in, &j_out) in indices_cols.iter().enumerate() {
62 dst[(i_out, j_out)] = self[(i_in, j_in)];
63 }
64 }
65 }
66}