1use super::GenTensor;
2use crate::tensor_trait::reduction::ReduceTensor;
3
4impl<T> GenTensor<T> where T: num_traits::Float {
5 fn _argmax_min(&self, dim: Option<&[usize]>, keep_dim: bool, max: bool) -> Self {
6 if keep_dim {
7 panic!("argmax cannot keep dim");
8 }
9 let dim2aggregate = if let Some(dim_val) = dim {
10 (0..self.size().len()).filter(|x| dim_val.contains(x)).collect()
11 } else {
12 self.size().to_vec()
13 };
14 let dim = dim2aggregate;
15
16 let mut aggregated = false;
18 let ret_dim: Vec<usize> = (0..self.size().len()).map(|x|
19 if dim.contains(&x) {
20 if !aggregated {
21 aggregated = true;
22 dim.len()
23 } else {
24 1
25 }
26 } else {
27 self.size()[x]
28 }
29 ).collect();
30 let mut ret = Self::zeros(&ret_dim);
31 let kept_dim: Vec<usize> = (0..self.size().len()).filter(|x| !dim.contains(x)).collect();
34 let mut index = vec![0; kept_dim.len()]; loop {
37 let mut patch_index: Vec::<(usize, usize)> = Vec::new();
38 let mut output_index: Vec<usize> = Vec::new();
39 let mut kept_dim_step = 0;
40 let mut aggregated = false;
41 for i in 0..self.size().len() {
42 if dim.contains(&i) {
43 patch_index.push((0, self.size()[i]));
44 if !aggregated {
45 output_index.push(0);
46 aggregated = true;
47 }
48 } else {
49 patch_index.push((index[kept_dim_step], index[kept_dim_step]+1));
50 output_index.push(index[kept_dim_step]);
51 kept_dim_step += 1;
52 }
53 }
54 let the_patch = self.get_patch(&patch_index, None);
58 let mut max_value = the_patch.get_data()[0];
59 let mut max_index = 0;
60 for (elem_index, i) in the_patch.get_data().iter().enumerate() {
61 if max {
62 if max_value < *i {
63 max_value = *i;
64 max_index = elem_index;
65 }
66 } else if max_value > *i {
67 max_value = *i;
68 max_index = elem_index;
69 }
70 }
71 let dimpos_elem = the_patch.index2dimpos(max_index);
72 let mut dimpos_elem2 = Vec::new();
73 for (dim_index, v) in dimpos_elem.iter().enumerate() {
74 if dim.contains(&dim_index) {
75 dimpos_elem2.push(*v);
76 }
77 }
78 let dimpos_elem = dimpos_elem2;
79 for (set_index, i) in dimpos_elem.iter().enumerate() {
81 let mut dest_index = output_index.to_vec();
82 dest_index[dim[0]] = set_index;
83 ret.set(&dest_index, T::from(*i).unwrap());
85 }
86
87 for i in 0..index.len() {
88 index[kept_dim.len() -i -1] += 1;
89 if index[kept_dim.len() -i -1] >= self.size()[kept_dim[kept_dim.len() -i -1]] {
90 index[kept_dim.len() -i -1] = 0;
91 } else {
92 break
93 }
94 }
95
96 if index == vec![0; kept_dim.len()] {
97 break
98 }
99 }
100
101 ret
102 }
103}
104
105impl<T> ReduceTensor for GenTensor<T> where T: num_traits::Float {
106
107 fn argmax(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
108 self._argmax_min(dim, keep_dim, true)
109 }
110 fn argmin(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
111 self._argmax_min(dim, keep_dim, false)
112 }
113 fn dist() {unimplemented!();}
114 fn logsumexp(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
115 self._iter_patch(dim, keep_dim,
116 |x| {
117 let mut max = x[0];
118 for i in x {
119 if max < *i {
120 max = *i;
121 }
122 }
123
124 let mut sum = T::zero();
125 for i in x {
126 sum = sum + (*i - max).exp();
127 }
128 max + sum.ln()
129 }
130 )
131 }
132 fn mean(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
134 self._iter_patch(dim, keep_dim,
135 |x| {
136 let n = x.len();
137 let mut sum = T::zero();
138 for i in x {
139 sum = sum + *i;
140 }
141 sum / T::from(n).expect("")
142 }
143 )
144 }
145 fn median(){unimplemented!();}
146 fn mode() {unimplemented!();}
147 fn prod(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
148 self._iter_patch(dim, keep_dim,
149 |x| {
150 let mut p = T::one();
151 for i in x {
152 p = p * (*i);
153 }
154 p
155 }
156 )
157 }
158 fn std(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
159 self._iter_patch(dim, keep_dim,
160 |x| {
161 let n = x.len();
162 let mut sum = T::zero();
163 let mut sum2 = T::zero();
164 for i in x {
165 sum = sum + *i;
166 sum2 = sum2 + *i*(*i);
167 }
168 let sum2 = sum2 / T::from(n).expect("");
169 let sum = sum / T::from(n).expect("");
170 (sum2 - sum*sum).sqrt()
171 }
172 )
173 }
174 fn std_mean() {unimplemented!();}
175 fn sum(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
184 self._iter_patch(dim, keep_dim,
185 |x| {
186 let mut sum = T::zero();
187 for i in x {
188 sum = sum + *i;
189 }
190 sum
191 }
192 )
193 }
194 fn unique(){unimplemented!();}
195 fn unique_consecutive() {unimplemented!();}
196 fn var(&self, dim: Option<&[usize]>, keep_dim: bool) -> GenTensor<T> {
197 self._iter_patch(dim, keep_dim,
198 |x| {
199 let n = x.len();
200 let mut sum = T::zero();
201 let mut sum2 = T::zero();
202 for i in x {
203 sum = sum + *i;
204 sum2 = sum2 + *i*(*i);
205 }
206 let sum2 = sum2 / T::from(n).expect("");
207 let sum = sum / T::from(n).expect("");
208 sum2 - sum*sum
209 }
210 )
211 }
212
213 fn var_mean() {unimplemented!();}
214
215 fn max(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
216 self._iter_patch(dim, keep_dim,
217 |x| {
218 let mut max = x[0];
219 for i in x {
220 if max < *i {
221 max = *i;
222 }
223 }
224 max
225 }
226 )
227 }
228 fn min(&self, dim: Option<&[usize]>, keep_dim: bool) -> Self {
229 self._iter_patch(dim, keep_dim,
230 |x| {
231 let mut min = x[0];
232 for i in x {
233 if min < *i {
234 min = *i;
235 }
236 }
237 min
238 }
239 )
240 }
241}
242
#[cfg(test)]
mod tests {
    use crate::tensor_impl::gen_tensor::GenTensor;
    use super::*;

    #[test]
    fn argmax() {
        // 3x2 tensor, row-major: [[1, 2], [3, 4], [5, 6]].
        let t = GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2]);

        let over_rows = t.argmax(Some(&[0]), false);
        println!("{:?}", over_rows);
        assert_eq!(over_rows, GenTensor::<f32>::new_raw(&[2., 2.,], &[1, 2]));

        let over_cols = t.argmax(Some(&[1]), false);
        println!("{:?}", over_cols);
        assert_eq!(over_cols, GenTensor::<f32>::new_raw(&[1., 1., 1.,], &[3, 1]));
    }

    #[test]
    fn argmin() {
        // Same tensor as `argmax`; the smallest entries sit at index 0 everywhere.
        let t = GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2]);

        let over_rows = t.argmin(Some(&[0]), false);
        println!("{:?}", over_rows);
        assert_eq!(over_rows, GenTensor::<f32>::new_raw(&[0., 0.,], &[1, 2]));

        let over_cols = t.argmin(Some(&[1]), false);
        println!("{:?}", over_cols);
        assert_eq!(over_cols, GenTensor::<f32>::new_raw(&[0., 0., 0.,], &[3, 1]));
    }

    #[test]
    fn logsumexp() {
        let t = GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2]);
        let reduced = t.logsumexp(Some(&[1]), false);
        assert_eq!(reduced, GenTensor::<f32>::new_raw(&[2.3132617, 4.3132615, 6.3132615], &[3]));
    }

    #[test]
    fn mean() {
        let ones = GenTensor::<f32>::fill(1., &vec![3, 4, 3]);

        // Reducing dim 1 without keep_dim drops it from the shape.
        let squeezed = ones.mean(Some(&[1]), false);
        assert_eq!(*squeezed.size(), vec![3, 3]);
        assert_eq!(squeezed.numel(), 9);

        // With keep_dim the reduced dim stays as size 1.
        let kept = ones.mean(Some(&[1]), true);
        assert_eq!(*kept.size(), vec![3, 1, 3]);
        assert_eq!(kept.numel(), 9);
    }

    #[test]
    fn var() {
        let t = GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2]);

        let over_rows = t.var(Some(&[0]), false);
        assert_eq!(*over_rows.size(), vec![2]);
        assert_eq!(over_rows.numel(), 2);
        assert_eq!(over_rows, GenTensor::<f32>::new_raw(&[2.666667, 2.666666], &[2]));

        let over_cols = t.var(Some(&[1]), true);
        assert_eq!(*over_cols.size(), vec![3, 1]);
        assert_eq!(over_cols.numel(), 3);
        assert_eq!(over_cols, GenTensor::<f32>::new_raw(&[0.25, 0.25, 0.25], &[3, 1]));
    }

    #[test]
    fn std() {
        let t = GenTensor::<f32>::new_raw(&[1., 2., 3., 4., 5., 6.], &[3, 2]);

        let over_rows = t.std(Some(&[0]), false);
        assert_eq!(*over_rows.size(), vec![2]);
        assert_eq!(over_rows.numel(), 2);
        assert_eq!(over_rows, GenTensor::<f32>::new_raw(&[1.6329932, 1.632993], &[2]));

        let over_cols = t.std(Some(&[1]), true);
        assert_eq!(*over_cols.size(), vec![3, 1]);
        assert_eq!(over_cols.numel(), 3);
        assert_eq!(over_cols, GenTensor::<f32>::new_raw(&[0.5, 0.5, 0.5], &[3, 1]));
    }
}
322