tensor_rs/tensor_impl/gen_tensor/
reduction.rs

1use super::GenTensor;
2use crate::tensor_trait::reduction::ReduceTensor;
3
4impl<T> GenTensor<T> where T: num_traits::Float {
5    fn _argmax_min(&self, dim: Option<&[usize]>, keep_dim: bool, max: bool) -> Self {
6        if keep_dim {
7            panic!("argmax cannot keep dim");
8        }
9        let dim2aggregate = if let Some(dim_val) = dim {
10            (0..self.size().len()).filter(|x| dim_val.contains(x)).collect()
11        } else {
12            self.size().to_vec()
13        };
14        let dim = dim2aggregate;
15        
16        // build return tensor dimension.
17        let mut aggregated = false;
18        let ret_dim: Vec<usize> = (0..self.size().len()).map(|x|
19                                                             if dim.contains(&x) {
20                                                                 if !aggregated {
21                                                                     aggregated = true;
22                                                                     dim.len()
23                                                                 } else {
24                                                                     1
25                                                                 }
26                                                             } else {
27                                                                 self.size()[x]
28                                                             }
29        ).collect();
30        let mut ret = Self::zeros(&ret_dim);
31        //println!("{:?}, {:?}, {:?}", ret.size(), self.size(), dim);
32
33        let kept_dim: Vec<usize> = (0..self.size().len()).filter(|x| !dim.contains(x)).collect();
34        let mut index = vec![0; kept_dim.len()]; // index for the loop.
35
36        loop {
37            let mut patch_index: Vec::<(usize, usize)> = Vec::new();
38            let mut output_index: Vec<usize> = Vec::new();
39            let mut kept_dim_step = 0;
40            let mut aggregated = false;
41            for i in 0..self.size().len() {
42                if dim.contains(&i) {
43                    patch_index.push((0, self.size()[i]));
44                    if !aggregated {
45                        output_index.push(0);
46                        aggregated = true;
47                    }
48                } else {
49                    patch_index.push((index[kept_dim_step], index[kept_dim_step]+1));
50                    output_index.push(index[kept_dim_step]);
51                    kept_dim_step += 1;
52                }
53            }
54            //println!("index: {:?}, patch_index: {:?}, output_index: {:?}", index, patch_index, output_index);
55
56            //let value = closure(self.get_patch(&patch_index, None).get_data());
57            let the_patch = self.get_patch(&patch_index, None);
58            let mut max_value = the_patch.get_data()[0];
59            let mut max_index = 0;
60            for (elem_index, i) in the_patch.get_data().iter().enumerate() {
61                if max {
62                    if max_value < *i {
63                        max_value = *i;
64                        max_index = elem_index;
65                    }
66                } else if max_value > *i {
67                    max_value = *i;
68                    max_index = elem_index;
69                }
70            }
71            let dimpos_elem = the_patch.index2dimpos(max_index);
72            let mut dimpos_elem2 = Vec::new();
73            for (dim_index, v) in dimpos_elem.iter().enumerate() {
74                if dim.contains(&dim_index) {
75                    dimpos_elem2.push(*v);
76                } 
77            }
78            let dimpos_elem = dimpos_elem2;
79            //println!("dispos_elem: {:?}", dimpos_elem);
80            for (set_index, i) in dimpos_elem.iter().enumerate() {
81                let mut dest_index = output_index.to_vec();
82                dest_index[dim[0]] = set_index;
83                //println!("dest_index: {:?}", dest_index);
84                ret.set(&dest_index, T::from(*i).unwrap());
85            }
86            
87            for i in 0..index.len() {
88                index[kept_dim.len() -i -1] += 1;
89                if index[kept_dim.len() -i -1] >= self.size()[kept_dim[kept_dim.len() -i -1]] {
90                    index[kept_dim.len() -i -1] = 0;
91                } else {
92                    break
93                }
94            }
95
96            if index == vec![0; kept_dim.len()] {
97                break
98            }
99        }
100        
101        ret
102    }
103}
104
105impl<T> ReduceTensor for GenTensor<T> where T: num_traits::Float {
106
107    fn argmax(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
108        self._argmax_min(dim, keep_dim, true)
109    }
110    fn argmin(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
111        self._argmax_min(dim, keep_dim, false)
112    }
113    fn dist() {unimplemented!();}
114    fn logsumexp(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
115        self._iter_patch(dim, keep_dim,
116                         |x| {
117                             let mut max = x[0];
118                             for i in x {
119                                 if max < *i {
120                                     max = *i;
121                                 }
122                             }
123
124                             let mut sum = T::zero();
125                             for i in x {
126                                 sum = sum + (*i - max).exp();
127                             }
128                             max + sum.ln()
129                         }
130        )
131    }
132    /// Returns the mean value of the tensor along dim row.
133    fn mean(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
134        self._iter_patch(dim, keep_dim,
135                         |x| {
136                             let n = x.len();
137                             let mut sum = T::zero();
138                             for i in x {
139                                 sum = sum + *i;
140                             }
141                             sum / T::from(n).expect("")
142                         }
143        )
144    }
145    fn median(){unimplemented!();}
146    fn mode() {unimplemented!();}
147    fn prod(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
148        self._iter_patch(dim, keep_dim,
149                         |x| {
150                             let mut p = T::one();
151                             for i in x {
152                                 p = p * (*i);
153                             }
154                             p
155                         }
156        )
157    }
158    fn std(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
159        self._iter_patch(dim, keep_dim,
160                         |x| {
161                             let n = x.len();
162                             let mut sum = T::zero();
163                             let mut sum2 = T::zero();
164                             for i in x {
165                                 sum = sum + *i;
166                                 sum2 = sum2 + *i*(*i);
167                             }
168                             let sum2 = sum2 / T::from(n).expect("");
169                             let sum = sum / T::from(n).expect("");
170                             (sum2 - sum*sum).sqrt()
171                         }
172        )
173    }
174    fn std_mean() {unimplemented!();}
175    //fn sum(&self, dim: usize, keepdim: bool) -> Self::TensorType {}
176    /// Returns the sum of all elements.
177    /// ```
178    /// # use crate::tensor_rs::tensor_impl::gen_tensor::*;
179    /// # use crate::tensor_rs::tensor_trait::reduction::ReduceTensor;
180    /// let m1 = GenTensor::<f64>::new_raw(&vec![1.,2.,3.,4.,], &vec![2,2]);
181    /// assert_eq!(m1.sum(None, false).get_scale(), 10.);
182    /// ```
183    fn sum(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
184        self._iter_patch(dim, keep_dim,
185                         |x| {
186                             let mut sum = T::zero();
187                             for i in x {
188                                 sum = sum + *i;
189                             }
190                             sum
191                         }
192        )
193    }
194    fn unique(){unimplemented!();}
195    fn unique_consecutive() {unimplemented!();}
196    fn var(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
197        self._iter_patch(dim, keep_dim,
198                         |x| {
199                             let n = x.len();
200                             let mut sum = T::zero();
201                             let mut sum2 = T::zero();
202                             for i in x {
203                                 sum = sum + *i;
204                                 sum2 = sum2 + *i*(*i);
205                             }
206                             let sum2 = sum2 / T::from(n).expect("");
207                             let sum = sum / T::from(n).expect("");
208                             sum2 - sum*sum
209                         }
210        )
211    }
212
213    fn var_mean() {unimplemented!();}
214
215    fn max(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
216        self._iter_patch(dim, keep_dim,
217                         |x| {
218                             let mut max = x[0];
219                             for i in x {
220                                 if max < *i {
221                                     max = *i;
222                                 }
223                             }
224                             max
225                         }
226        )
227    }
228    fn min(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
229        self._iter_patch(dim, keep_dim,
230                         |x| {
231                             let mut min = x[0];
232                             for i in x {
233                                 if min < *i {
234                                     min = *i;
235                                 }
236                             }
237                             min
238                         }
239        )
240    }
241}
242
#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    /// Fixture shared by most tests: the values 1..=6 with shape `[3, 2]`.
    fn ramp() -> GenTensor<f32> {
        GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2])
    }

    #[test]
    fn argmax() {
        let input = ramp();

        // Reducing over axis 0.
        let over_axis0 = input.argmax(Some(&[0]), false);
        println!("{:?}", over_axis0);
        let expected = GenTensor::<f32>::new_raw(&[2., 2.], &[1, 2]);
        assert_eq!(over_axis0, expected);

        // Reducing over axis 1.
        let over_axis1 = input.argmax(Some(&[1]), false);
        println!("{:?}", over_axis1);
        let expected = GenTensor::<f32>::new_raw(&[1., 1., 1.], &[3, 1]);
        assert_eq!(over_axis1, expected);
    }

    #[test]
    fn argmin() {
        let input = ramp();

        // Reducing over axis 0.
        let over_axis0 = input.argmin(Some(&[0]), false);
        println!("{:?}", over_axis0);
        let expected = GenTensor::<f32>::new_raw(&[0., 0.], &[1, 2]);
        assert_eq!(over_axis0, expected);

        // Reducing over axis 1.
        let over_axis1 = input.argmin(Some(&[1]), false);
        println!("{:?}", over_axis1);
        let expected = GenTensor::<f32>::new_raw(&[0., 0., 0.], &[3, 1]);
        assert_eq!(over_axis1, expected);
    }

    #[test]
    fn logsumexp() {
        let reduced = ramp().logsumexp(Some(&[1]), false);
        let expected = GenTensor::<f32>::new_raw(&[2.3132617, 4.3132615, 6.3132615], &[3]);
        assert_eq!(reduced, expected);
    }

    #[test]
    fn mean() {
        let ones = GenTensor::<f32>::fill(1., &vec![3, 4, 3]);

        // Without keep_dim the reduced axis disappears.
        let dropped = ones.mean(Some(&[1]), false);
        assert_eq!(dropped.size().to_vec(), vec![3, 3]);
        assert_eq!(dropped.numel(), 9);

        // With keep_dim the reduced axis stays with length 1.
        let kept = ones.mean(Some(&[1]), true);
        assert_eq!(kept.size().to_vec(), vec![3, 1, 3]);
        assert_eq!(kept.numel(), 9);
    }

    #[test]
    fn var() {
        let input = ramp();

        let dropped = input.var(Some(&[0]), false);
        assert_eq!(dropped.size().to_vec(), vec![2]);
        assert_eq!(dropped.numel(), 2);
        let expected = GenTensor::<f32>::new_raw(&[2.666667, 2.666666], &[2]);
        assert_eq!(dropped, expected);

        let kept = input.var(Some(&[1]), true);
        assert_eq!(kept.size().to_vec(), vec![3, 1]);
        assert_eq!(kept.numel(), 3);
        let expected = GenTensor::<f32>::new_raw(&[0.25, 0.25, 0.25], &[3, 1]);
        assert_eq!(kept, expected);
    }

    #[test]
    fn std() {
        let input = ramp();

        let dropped = input.std(Some(&[0]), false);
        assert_eq!(dropped.size().to_vec(), vec![2]);
        assert_eq!(dropped.numel(), 2);
        let expected = GenTensor::<f32>::new_raw(&[1.6329932, 1.632993], &[2]);
        assert_eq!(dropped, expected);

        let kept = input.std(Some(&[1]), true);
        assert_eq!(kept.size().to_vec(), vec![3, 1]);
        assert_eq!(kept.numel(), 3);
        let expected = GenTensor::<f32>::new_raw(&[0.5, 0.5, 0.5], &[3, 1]);
        assert_eq!(kept, expected);
    }
}
322