slsl/reduce/
argmin_argmax.rs

1use anyhow::Result;
2use half::{bf16, f16};
3
4use crate::{global_backend, DType, Dim, OpsTrait, StorageTrait, Tensor, TensorBase, UninitVec};
5
6impl<S: StorageTrait> TensorBase<S> {
7    pub fn argmin_argmax<D: Dim + Clone>(&self, dim: D) -> Result<(Tensor, Tensor)> {
8        self.argmin_argmax_impl(dim, false)
9    }
10
11    pub fn argmin_argmax_keepdim<D: Dim + Clone>(&self, dim: D) -> Result<(Tensor, Tensor)> {
12        self.argmin_argmax_impl(dim, true)
13    }
14
15    pub fn argmin_argmax_impl<D: Dim + Clone>(
16        &self,
17        dim: D,
18        keepdim: bool,
19    ) -> Result<(Tensor, Tensor)> {
20        let dim_index = dim.to_dim(self.rank())?;
21
22        if self.shape()[dim_index] == 0 {
23            anyhow::bail!("Cannot find argmin/argmax of dimension with size 0");
24        }
25
26        // Use dimension-agnostic optimization when possible
27        if self.is_contiguous() && self.can_reduce_over_last_dims(&[dim_index]) {
28            let backend = global_backend();
29            let shape = self.shape();
30            let reduce_size = shape[dim_index];
31            let output_size = self.numel() / reduce_size;
32            let (new_shape, _) =
33                crate::reduce::reduce_shape_stride(self.shape, &[dim_index], keepdim);
34
35            match self.dtype() {
36                DType::Fp32 => {
37                    let data = self.as_slice::<f32>()?;
38
39                    let mut out_argmin = UninitVec::<u64>::new(output_size);
40                    let mut out_argmax = UninitVec::<u64>::new(output_size);
41
42                    let dst_argmin = out_argmin.as_mut_slice();
43                    let dst_argmax = out_argmax.as_mut_slice();
44
45                    for i in 0..output_size {
46                        let start = i * reduce_size;
47                        let end = start + reduce_size;
48                        let (min_idx, max_idx) = backend.min_max_i_f32(&data[start..end]);
49
50                        dst_argmin[i] = min_idx;
51                        dst_argmax[i] = max_idx;
52                    }
53
54                    let out_argmin = unsafe { out_argmin.finalize() };
55                    let out_argmax = unsafe { out_argmax.finalize() };
56
57                    Ok((
58                        Tensor::from_vec(out_argmin, new_shape)?,
59                        Tensor::from_vec(out_argmax, new_shape)?,
60                    ))
61                }
62                DType::Fp64 => {
63                    let data = self.as_slice::<f64>()?;
64
65                    let mut out_argmin = UninitVec::<u64>::new(output_size);
66                    let mut out_argmax = UninitVec::<u64>::new(output_size);
67
68                    let dst_argmin = out_argmin.as_mut_slice();
69                    let dst_argmax = out_argmax.as_mut_slice();
70
71                    for i in 0..output_size {
72                        let start = i * reduce_size;
73                        let end = start + reduce_size;
74                        let (min_idx, max_idx) = backend.min_max_i_f64(&data[start..end]);
75
76                        dst_argmin[i] = min_idx;
77                        dst_argmax[i] = max_idx;
78                    }
79
80                    let out_argmin = unsafe { out_argmin.finalize() };
81                    let out_argmax = unsafe { out_argmax.finalize() };
82
83                    Ok((
84                        Tensor::from_vec(out_argmin, new_shape)?,
85                        Tensor::from_vec(out_argmax, new_shape)?,
86                    ))
87                }
88                DType::Fp16 => {
89                    let data = self.as_slice::<f16>()?;
90
91                    let mut out_argmin = UninitVec::<u64>::new(output_size);
92                    let mut out_argmax = UninitVec::<u64>::new(output_size);
93
94                    let dst_argmin = out_argmin.as_mut_slice();
95                    let dst_argmax = out_argmax.as_mut_slice();
96
97                    for i in 0..output_size {
98                        let start = i * reduce_size;
99                        let end = start + reduce_size;
100                        let (min_idx, max_idx) = backend.min_max_i_f16(&data[start..end]);
101
102                        dst_argmin[i] = min_idx;
103                        dst_argmax[i] = max_idx;
104                    }
105
106                    let out_argmin = unsafe { out_argmin.finalize() };
107                    let out_argmax = unsafe { out_argmax.finalize() };
108
109                    Ok((
110                        Tensor::from_vec(out_argmin, new_shape)?,
111                        Tensor::from_vec(out_argmax, new_shape)?,
112                    ))
113                }
114                DType::Bf16 => {
115                    let data = self.as_slice::<bf16>()?;
116
117                    let mut out_argmin = UninitVec::<u64>::new(output_size);
118                    let mut out_argmax = UninitVec::<u64>::new(output_size);
119
120                    let dst_argmin = out_argmin.as_mut_slice();
121                    let dst_argmax = out_argmax.as_mut_slice();
122
123                    for i in 0..output_size {
124                        let start = i * reduce_size;
125                        let end = start + reduce_size;
126                        let (min_idx, max_idx) = backend.min_max_i_bf16(&data[start..end]);
127
128                        dst_argmin[i] = min_idx;
129                        dst_argmax[i] = max_idx;
130                    }
131
132                    let out_argmin = unsafe { out_argmin.finalize() };
133                    let out_argmax = unsafe { out_argmax.finalize() };
134
135                    Ok((
136                        Tensor::from_vec(out_argmin, new_shape)?,
137                        Tensor::from_vec(out_argmax, new_shape)?,
138                    ))
139                }
140                DType::Int8 => {
141                    let data = self.as_slice::<i8>()?;
142
143                    let mut out_argmin = UninitVec::<u64>::new(output_size);
144                    let mut out_argmax = UninitVec::<u64>::new(output_size);
145
146                    let dst_argmin = out_argmin.as_mut_slice();
147                    let dst_argmax = out_argmax.as_mut_slice();
148
149                    for i in 0..output_size {
150                        let start = i * reduce_size;
151                        let end = start + reduce_size;
152                        let (min_idx, max_idx) = backend.min_max_i_i8(&data[start..end]);
153
154                        dst_argmin[i] = min_idx;
155                        dst_argmax[i] = max_idx;
156                    }
157
158                    let out_argmin = unsafe { out_argmin.finalize() };
159                    let out_argmax = unsafe { out_argmax.finalize() };
160
161                    Ok((
162                        Tensor::from_vec(out_argmin, new_shape)?,
163                        Tensor::from_vec(out_argmax, new_shape)?,
164                    ))
165                }
166                DType::Int16 => {
167                    let data = self.as_slice::<i16>()?;
168
169                    let mut out_argmin = UninitVec::<u64>::new(output_size);
170                    let mut out_argmax = UninitVec::<u64>::new(output_size);
171
172                    let dst_argmin = out_argmin.as_mut_slice();
173                    let dst_argmax = out_argmax.as_mut_slice();
174
175                    for i in 0..output_size {
176                        let start = i * reduce_size;
177                        let end = start + reduce_size;
178                        let (min_idx, max_idx) = backend.min_max_i_i16(&data[start..end]);
179
180                        dst_argmin[i] = min_idx;
181                        dst_argmax[i] = max_idx;
182                    }
183
184                    let out_argmin = unsafe { out_argmin.finalize() };
185                    let out_argmax = unsafe { out_argmax.finalize() };
186
187                    Ok((
188                        Tensor::from_vec(out_argmin, new_shape)?,
189                        Tensor::from_vec(out_argmax, new_shape)?,
190                    ))
191                }
192                DType::Int32 => {
193                    let data = self.as_slice::<i32>()?;
194
195                    let mut out_argmin = UninitVec::<u64>::new(output_size);
196                    let mut out_argmax = UninitVec::<u64>::new(output_size);
197
198                    let dst_argmin = out_argmin.as_mut_slice();
199                    let dst_argmax = out_argmax.as_mut_slice();
200
201                    for i in 0..output_size {
202                        let start = i * reduce_size;
203                        let end = start + reduce_size;
204                        let (min_idx, max_idx) = backend.min_max_i_i32(&data[start..end]);
205
206                        dst_argmin[i] = min_idx;
207                        dst_argmax[i] = max_idx;
208                    }
209
210                    let out_argmin = unsafe { out_argmin.finalize() };
211                    let out_argmax = unsafe { out_argmax.finalize() };
212
213                    Ok((
214                        Tensor::from_vec(out_argmin, new_shape)?,
215                        Tensor::from_vec(out_argmax, new_shape)?,
216                    ))
217                }
218                DType::Int64 => {
219                    let data = self.as_slice::<i64>()?;
220
221                    let mut out_argmin = UninitVec::<u64>::new(output_size);
222                    let mut out_argmax = UninitVec::<u64>::new(output_size);
223
224                    let dst_argmin = out_argmin.as_mut_slice();
225                    let dst_argmax = out_argmax.as_mut_slice();
226
227                    for i in 0..output_size {
228                        let start = i * reduce_size;
229                        let end = start + reduce_size;
230                        let (min_idx, max_idx) = backend.min_max_i_i64(&data[start..end]);
231
232                        dst_argmin[i] = min_idx;
233                        dst_argmax[i] = max_idx;
234                    }
235
236                    let out_argmin = unsafe { out_argmin.finalize() };
237                    let out_argmax = unsafe { out_argmax.finalize() };
238
239                    Ok((
240                        Tensor::from_vec(out_argmin, new_shape)?,
241                        Tensor::from_vec(out_argmax, new_shape)?,
242                    ))
243                }
244                DType::Uint8 => {
245                    let data = self.as_slice::<u8>()?;
246
247                    let mut out_argmin = UninitVec::<u64>::new(output_size);
248                    let mut out_argmax = UninitVec::<u64>::new(output_size);
249
250                    let dst_argmin = out_argmin.as_mut_slice();
251                    let dst_argmax = out_argmax.as_mut_slice();
252
253                    for i in 0..output_size {
254                        let start = i * reduce_size;
255                        let end = start + reduce_size;
256                        let (min_idx, max_idx) = backend.min_max_i_u8(&data[start..end]);
257
258                        dst_argmin[i] = min_idx;
259                        dst_argmax[i] = max_idx;
260                    }
261
262                    let out_argmin = unsafe { out_argmin.finalize() };
263                    let out_argmax = unsafe { out_argmax.finalize() };
264
265                    Ok((
266                        Tensor::from_vec(out_argmin, new_shape)?,
267                        Tensor::from_vec(out_argmax, new_shape)?,
268                    ))
269                }
270                DType::Uint16 => {
271                    let data = self.as_slice::<u16>()?;
272
273                    let mut out_argmin = UninitVec::<u64>::new(output_size);
274                    let mut out_argmax = UninitVec::<u64>::new(output_size);
275
276                    let dst_argmin = out_argmin.as_mut_slice();
277                    let dst_argmax = out_argmax.as_mut_slice();
278
279                    for i in 0..output_size {
280                        let start = i * reduce_size;
281                        let end = start + reduce_size;
282                        let (min_idx, max_idx) = backend.min_max_i_u16(&data[start..end]);
283
284                        dst_argmin[i] = min_idx;
285                        dst_argmax[i] = max_idx;
286                    }
287
288                    let out_argmin = unsafe { out_argmin.finalize() };
289                    let out_argmax = unsafe { out_argmax.finalize() };
290
291                    Ok((
292                        Tensor::from_vec(out_argmin, new_shape)?,
293                        Tensor::from_vec(out_argmax, new_shape)?,
294                    ))
295                }
296                DType::Uint32 => {
297                    let data = self.as_slice::<u32>()?;
298
299                    let mut out_argmin = UninitVec::<u64>::new(output_size);
300                    let mut out_argmax = UninitVec::<u64>::new(output_size);
301
302                    let dst_argmin = out_argmin.as_mut_slice();
303                    let dst_argmax = out_argmax.as_mut_slice();
304
305                    for i in 0..output_size {
306                        let start = i * reduce_size;
307                        let end = start + reduce_size;
308                        let (min_idx, max_idx) = backend.min_max_i_u32(&data[start..end]);
309
310                        dst_argmin[i] = min_idx;
311                        dst_argmax[i] = max_idx;
312                    }
313
314                    let out_argmin = unsafe { out_argmin.finalize() };
315                    let out_argmax = unsafe { out_argmax.finalize() };
316
317                    Ok((
318                        Tensor::from_vec(out_argmin, new_shape)?,
319                        Tensor::from_vec(out_argmax, new_shape)?,
320                    ))
321                }
322                DType::Uint64 => {
323                    let data = self.as_slice::<u64>()?;
324
325                    let mut out_argmin = UninitVec::<u64>::new(output_size);
326                    let mut out_argmax = UninitVec::<u64>::new(output_size);
327
328                    let dst_argmin = out_argmin.as_mut_slice();
329                    let dst_argmax = out_argmax.as_mut_slice();
330
331                    for i in 0..output_size {
332                        let start = i * reduce_size;
333                        let end = start + reduce_size;
334                        let (min_idx, max_idx) = backend.min_max_i_u64(&data[start..end]);
335
336                        dst_argmin[i] = min_idx;
337                        dst_argmax[i] = max_idx;
338                    }
339
340                    let out_argmin = unsafe { out_argmin.finalize() };
341                    let out_argmax = unsafe { out_argmax.finalize() };
342
343                    Ok((
344                        Tensor::from_vec(out_argmin, new_shape)?,
345                        Tensor::from_vec(out_argmax, new_shape)?,
346                    ))
347                }
348                _ => anyhow::bail!("Argmin/Argmax not supported for dtype {:?}", self.dtype()),
349            }
350        } else {
351            let (new_shape, _) = crate::reduce_shape_stride(self.shape, &[dim_index], keepdim);
352
353            let result_size = new_shape.iter().product();
354            macro_rules! noncontig_argmin_argmax {
355                ($t:ty, $min_init:expr, $max_init:expr) => {{
356                    let mut mins = vec![$min_init; result_size];
357                    let mut maxs = vec![$max_init; result_size];
358                    let mut argmins = vec![0u64; result_size];
359                    let mut argmaxs = vec![0u64; result_size];
360                    let mut idx_buf = vec![0; new_shape.len()];
361
362                    for elem in self.iter() {
363                        let i = elem.indices;
364                        let ptr = unsafe { elem.as_ptr(self.as_ptr()) };
365                        let val = unsafe { *(ptr as *const $t) };
366                        let mut current_dim = 0;
367                        for k in 0..self.rank() {
368                            if k == dim_index {
369                                if keepdim {
370                                    idx_buf[current_dim] = 0;
371                                    current_dim += 1;
372                                }
373                            } else {
374                                idx_buf[current_dim] = i[k];
375                                current_dim += 1;
376                            }
377                        }
378
379                        let mut linear = 0;
380                        let mut stride = 1;
381                        for j in (0..new_shape.len()).rev() {
382                            linear += idx_buf[j] * stride;
383                            stride *= new_shape[j];
384                        }
385
386                        if val < mins[linear] {
387                            mins[linear] = val;
388                            argmins[linear] = i[dim_index] as u64;
389                        }
390                        if val > maxs[linear] {
391                            maxs[linear] = val;
392                            argmaxs[linear] = i[dim_index] as u64;
393                        }
394                    }
395
396                    Ok((
397                        Tensor::from_vec(argmins, new_shape)?,
398                        Tensor::from_vec(argmaxs, new_shape)?,
399                    ))
400                }};
401            }
402            match self.dtype() {
403                DType::Fp32 => noncontig_argmin_argmax!(f32, f32::INFINITY, f32::NEG_INFINITY),
404                DType::Fp64 => noncontig_argmin_argmax!(f64, f64::INFINITY, f64::NEG_INFINITY),
405                DType::Fp16 => noncontig_argmin_argmax!(
406                    f16,
407                    f16::from_f32(f32::INFINITY),
408                    f16::from_f32(f32::NEG_INFINITY)
409                ),
410                DType::Bf16 => noncontig_argmin_argmax!(
411                    bf16,
412                    bf16::from_f32(f32::INFINITY),
413                    bf16::from_f32(f32::NEG_INFINITY)
414                ),
415                DType::Int8 => noncontig_argmin_argmax!(i8, i8::MAX, i8::MIN),
416                DType::Int16 => noncontig_argmin_argmax!(i16, i16::MAX, i16::MIN),
417                DType::Int32 => noncontig_argmin_argmax!(i32, i32::MAX, i32::MIN),
418                DType::Int64 => noncontig_argmin_argmax!(i64, i64::MAX, i64::MIN),
419                DType::Uint8 => noncontig_argmin_argmax!(u8, u8::MAX, u8::MIN),
420                DType::Uint16 => noncontig_argmin_argmax!(u16, u16::MAX, u16::MIN),
421                DType::Uint32 => noncontig_argmin_argmax!(u32, u32::MAX, u32::MIN),
422                DType::Uint64 => noncontig_argmin_argmax!(u64, u64::MAX, u64::MIN),
423                _ => anyhow::bail!("Argmin/Argmax not supported for dtype {:?}", self.dtype()),
424            }
425        }
426    }
427}
428
429#[cfg(test)]
430mod tests {
431    use crate::*;
432    use anyhow::Result;
433
434    #[test]
435    fn test_argmin_argmax_1d_basic() -> Result<()> {
436        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
437        let tensor = Tensor::from_vec(data, [7])?;
438
439        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
440
441        // Should return scalars (empty shape)
442        assert_eq!(argmin_result.dims(), &[] as &[usize]);
443        assert_eq!(argmax_result.dims(), &[] as &[usize]);
444
445        let argmin_val = argmin_result.as_slice::<u64>()?[0];
446        let argmax_val = argmax_result.as_slice::<u64>()?[0];
447
448        // First occurrence of minimum value 1.0 is at index 1
449        assert_eq!(argmin_val, 1);
450        // Maximum value 9.0 is at index 5
451        assert_eq!(argmax_val, 5);
452
453        Ok(())
454    }
455
456    #[test]
457    fn test_argmin_argmax_2d_dim0() -> Result<()> {
458        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
459        let tensor = Tensor::from_vec(data, [2, 3])?;
460
461        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
462
463        assert_eq!(argmin_result.dims(), &[3]);
464        assert_eq!(argmax_result.dims(), &[3]);
465
466        let argmin_vals = argmin_result.as_slice::<u64>()?;
467        let argmax_vals = argmax_result.as_slice::<u64>()?;
468
469        // Argmin along dim 0: [argmin(1,2), argmin(5,8), argmin(3,1)] = [0, 0, 1]
470        assert_eq!(argmin_vals, &[0, 0, 1]);
471        // Argmax along dim 0: [argmax(1,2), argmax(5,8), argmax(3,1)] = [1, 1, 0]
472        assert_eq!(argmax_vals, &[1, 1, 0]);
473
474        Ok(())
475    }
476
477    #[test]
478    fn test_argmin_argmax_2d_dim1() -> Result<()> {
479        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
480        let tensor = Tensor::from_vec(data, [2, 3])?;
481
482        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;
483
484        assert_eq!(argmin_result.dims(), &[2]);
485        assert_eq!(argmax_result.dims(), &[2]);
486
487        let argmin_vals = argmin_result.as_slice::<u64>()?;
488        let argmax_vals = argmax_result.as_slice::<u64>()?;
489
490        // Argmin along dim 1: [argmin(1,5,3), argmin(2,8,1)] = [0, 2]
491        assert_eq!(argmin_vals, &[0, 2]);
492        // Argmax along dim 1: [argmax(1,5,3), argmax(2,8,1)] = [1, 1]
493        assert_eq!(argmax_vals, &[1, 1]);
494
495        Ok(())
496    }
497
498    #[test]
499    fn test_argmin_argmax_3d_basic() -> Result<()> {
500        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
501        let tensor = Tensor::from_vec(data, [2, 3, 4])?;
502
503        // Test along dimension 0
504        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
505        assert_eq!(argmin_result.dims(), &[3, 4]);
506        assert_eq!(argmax_result.dims(), &[3, 4]);
507
508        // Test along dimension 1
509        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;
510        assert_eq!(argmin_result.dims(), &[2, 4]);
511        assert_eq!(argmax_result.dims(), &[2, 4]);
512
513        // Test along dimension 2
514        let (argmin_result, argmax_result) = tensor.argmin_argmax(2)?;
515        assert_eq!(argmin_result.dims(), &[2, 3]);
516        assert_eq!(argmax_result.dims(), &[2, 3]);
517
518        Ok(())
519    }
520
521    #[test]
522    fn test_argmin_argmax_keepdim_1d() -> Result<()> {
523        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0];
524        let tensor = Tensor::from_vec(data, [5])?;
525
526        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;
527
528        // Should keep dimension as [1]
529        assert_eq!(argmin_result.dims(), &[1]);
530        assert_eq!(argmax_result.dims(), &[1]);
531
532        let argmin_val = argmin_result.as_slice::<u64>()?[0];
533        let argmax_val = argmax_result.as_slice::<u64>()?[0];
534
535        // First occurrence of minimum value 1.0 is at index 1
536        assert_eq!(argmin_val, 1);
537        // Maximum value 5.0 is at index 4
538        assert_eq!(argmax_val, 4);
539
540        Ok(())
541    }
542
543    #[test]
544    fn test_argmin_argmax_keepdim_2d() -> Result<()> {
545        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
546        let tensor = Tensor::from_vec(data, [2, 3])?;
547
548        // Test keepdim along dimension 0
549        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;
550        assert_eq!(argmin_result.dims(), &[1, 3]);
551        assert_eq!(argmax_result.dims(), &[1, 3]);
552
553        // Test keepdim along dimension 1
554        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(1)?;
555        assert_eq!(argmin_result.dims(), &[2, 1]);
556        assert_eq!(argmax_result.dims(), &[2, 1]);
557
558        Ok(())
559    }
560
561    #[test]
562    fn test_argmin_argmax_keepdim_3d() -> Result<()> {
563        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
564        let tensor = Tensor::from_vec(data, [2, 3, 4])?;
565
566        // Test keepdim along different dimensions
567        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;
568        assert_eq!(argmin_result.dims(), &[1, 3, 4]);
569        assert_eq!(argmax_result.dims(), &[1, 3, 4]);
570
571        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(1)?;
572        assert_eq!(argmin_result.dims(), &[2, 1, 4]);
573        assert_eq!(argmax_result.dims(), &[2, 1, 4]);
574
575        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(2)?;
576        assert_eq!(argmin_result.dims(), &[2, 3, 1]);
577        assert_eq!(argmax_result.dims(), &[2, 3, 1]);
578
579        Ok(())
580    }
581
582    #[test]
583    fn test_argmin_argmax_non_contiguous_2d() -> Result<()> {
584        // Test argmin_argmax with non-contiguous tensor using permute
585        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
586        let tensor = Tensor::from_vec(data, [2, 3])?;
587
588        // Create non-contiguous tensor by permuting dimensions
589        let permuted = tensor.clone().permute([1, 0])?; // [3, 2]
590
591        // Test argmin_argmax along different dimensions
592        let (argmin_result, argmax_result) = permuted.argmin_argmax(0)?;
593        assert_eq!(argmin_result.dims(), &[2]);
594        assert_eq!(argmax_result.dims(), &[2]);
595
596        let argmin_vals = argmin_result.as_slice::<u64>()?;
597        let argmax_vals = argmax_result.as_slice::<u64>()?;
598
599        // After permute: [[1,4], [2,5], [3,6]]
600        // Argmin along dim 0: [argmin(1,2,3), argmin(4,5,6)] = [0, 0]
601        assert_eq!(argmin_vals, &[0, 0]);
602        // Argmax along dim 0: [argmax(1,2,3), argmax(4,5,6)] = [2, 2]
603        assert_eq!(argmax_vals, &[2, 2]);
604
605        let (argmin_result, argmax_result) = permuted.argmin_argmax(1)?;
606        assert_eq!(argmin_result.dims(), &[3]);
607        assert_eq!(argmax_result.dims(), &[3]);
608
609        let argmin_vals = argmin_result.as_slice::<u64>()?;
610        let argmax_vals = argmax_result.as_slice::<u64>()?;
611
612        // Argmin along dim 1: [argmin(1,4), argmin(2,5), argmin(3,6)] = [0, 0, 0]
613        assert_eq!(argmin_vals, &[0, 0, 0]);
614        // Argmax along dim 1: [argmax(1,4), argmax(2,5), argmax(3,6)] = [1, 1, 1]
615        assert_eq!(argmax_vals, &[1, 1, 1]);
616
617        Ok(())
618    }
619
620    #[test]
621    fn test_argmin_argmax_non_contiguous_3d() -> Result<()> {
622        // Test argmin_argmax with 3D non-contiguous tensor
623        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
624        let tensor = Tensor::from_vec(data, [2, 3, 4])?;
625
626        // Create non-contiguous tensor by permuting dimensions
627        let permuted = tensor.clone().permute([2, 0, 1])?; // [4, 2, 3]
628
629        // Test argmin_argmax along different dimensions
630        let (argmin_result, argmax_result) = permuted.argmin_argmax(0)?;
631        assert_eq!(argmin_result.dims(), &[2, 3]);
632        assert_eq!(argmax_result.dims(), &[2, 3]);
633
634        let (argmin_result, argmax_result) = permuted.argmin_argmax(1)?;
635        assert_eq!(argmin_result.dims(), &[4, 3]);
636        assert_eq!(argmax_result.dims(), &[4, 3]);
637
638        let (argmin_result, argmax_result) = permuted.argmin_argmax(2)?;
639        assert_eq!(argmin_result.dims(), &[4, 2]);
640        assert_eq!(argmax_result.dims(), &[4, 2]);
641
642        Ok(())
643    }
644
645    #[test]
646    fn test_argmin_argmax_keepdim_non_contiguous() -> Result<()> {
647        // Test argmin_argmax_keepdim with non-contiguous tensor
648        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
649        let tensor = Tensor::from_vec(data, [2, 2, 2])?;
650
651        // Create non-contiguous tensor
652        let permuted = tensor.clone().permute([2, 1, 0])?; // [2, 2, 2]
653
654        // Test argmin_argmax_keepdim along different dimensions
655        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(0)?;
656        assert_eq!(argmin_result.dims(), &[1, 2, 2]);
657        assert_eq!(argmax_result.dims(), &[1, 2, 2]);
658
659        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(1)?;
660        assert_eq!(argmin_result.dims(), &[2, 1, 2]);
661        assert_eq!(argmax_result.dims(), &[2, 1, 2]);
662
663        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(2)?;
664        assert_eq!(argmin_result.dims(), &[2, 2, 1]);
665        assert_eq!(argmax_result.dims(), &[2, 2, 1]);
666
667        Ok(())
668    }
669
670    #[test]
671    fn test_argmin_argmax_different_data_types() -> Result<()> {
672        // Test argmin_argmax with different data types
673
674        // Test with i32
675        let data_i32 = vec![5i32, 1, 9, 3, 7, 2];
676        let tensor_i32 = Tensor::from_vec(data_i32, [2, 3])?;
677        let (argmin_result, argmax_result) = tensor_i32.argmin_argmax(1)?;
678
679        let argmin_vals = argmin_result.as_slice::<u64>()?;
680        let argmax_vals = argmax_result.as_slice::<u64>()?;
681
682        // Argmin along dim 1: [argmin(5,1,9), argmin(3,7,2)] = [1, 2]
683        assert_eq!(argmin_vals, &[1, 2]);
684        // Argmax along dim 1: [argmax(5,1,9), argmax(3,7,2)] = [2, 1]
685        assert_eq!(argmax_vals, &[2, 1]);
686
687        // Test with u32
688        let data_u32 = vec![10u32, 20, 5, 15];
689        let tensor_u32 = Tensor::from_vec(data_u32, [2, 2])?;
690        let (argmin_result, argmax_result) = tensor_u32.argmin_argmax(0)?;
691
692        let argmin_vals = argmin_result.as_slice::<u64>()?;
693        let argmax_vals = argmax_result.as_slice::<u64>()?;
694
695        // Argmin along dim 0: [argmin(10,5), argmin(20,15)] = [1, 1]
696        assert_eq!(argmin_vals, &[1, 1]);
697        // Argmax along dim 0: [argmax(10,5), argmax(20,15)] = [0, 0]
698        assert_eq!(argmax_vals, &[0, 0]);
699
700        Ok(())
701    }
702
703    #[test]
704    fn test_argmin_argmax_special_values() -> Result<()> {
705        // Test argmin_argmax with special floating point values
706
707        // Test with infinity
708        let data_inf = vec![1.0f32, f32::INFINITY, 3.0, f32::NEG_INFINITY];
709        let tensor_inf = Tensor::from_vec(data_inf, [4])?;
710        let (argmin_result, argmax_result) = tensor_inf.argmin_argmax(0)?;
711
712        let argmin_val = argmin_result.as_slice::<u64>()?[0];
713        let argmax_val = argmax_result.as_slice::<u64>()?[0];
714
715        // NEG_INFINITY is at index 3, INFINITY is at index 1
716        assert_eq!(argmin_val, 3);
717        assert_eq!(argmax_val, 1);
718
719        // Test with NaN
720        let data_nan = vec![1.0f32, f32::NAN, 3.0];
721        let tensor_nan = Tensor::from_vec(data_nan, [3])?;
722        let (argmin_result, argmax_result) = tensor_nan.argmin_argmax(0)?;
723
724        let argmin_val = argmin_result.as_slice::<u64>()?[0];
725        let argmax_val = argmax_result.as_slice::<u64>()?[0];
726
727        // NaN behavior: typically returns the index of NaN or first non-NaN value
728        // The exact behavior depends on the backend implementation
729        assert!(argmin_val < 3);
730        assert!(argmax_val < 3);
731
732        Ok(())
733    }
734
735    #[test]
736    fn test_argmin_argmax_edge_cases() -> Result<()> {
737        // Test various edge cases
738
739        // Single element tensor
740        let single = Tensor::from_vec(vec![42.0f32], [1])?;
741        let (argmin_result, argmax_result) = single.argmin_argmax(0)?;
742
743        let argmin_val = argmin_result.as_slice::<u64>()?[0];
744        let argmax_val = argmax_result.as_slice::<u64>()?[0];
745
746        assert_eq!(argmin_val, 0);
747        assert_eq!(argmax_val, 0);
748
749        // Tensor with all same values
750        let same = Tensor::from_vec(vec![5.0f32, 5.0, 5.0, 5.0], [2, 2])?;
751        let (argmin_result, argmax_result) = same.argmin_argmax(0)?;
752
753        let argmin_vals = argmin_result.as_slice::<u64>()?;
754        let argmax_vals = argmax_result.as_slice::<u64>()?;
755
756        // Should return first occurrence (index 0) for both min and max
757        assert_eq!(argmin_vals, &[0, 0]);
758        assert_eq!(argmax_vals, &[0, 0]);
759
760        Ok(())
761    }
762
763    #[test]
764    fn test_argmin_argmax_rectangular_tensors() -> Result<()> {
765        // Test argmin_argmax with rectangular (non-square) tensors
766
767        // 1x5 tensor
768        let data_1x5 = vec![5.0f32, 1.0, 9.0, 3.0, 7.0];
769        let tensor_1x5 = Tensor::from_vec(data_1x5, [1, 5])?;
770
771        let (argmin_result, argmax_result) = tensor_1x5.argmin_argmax(0)?;
772        assert_eq!(argmin_result.dims(), &[5]);
773        assert_eq!(argmax_result.dims(), &[5]);
774
775        let (argmin_result, argmax_result) = tensor_1x5.argmin_argmax(1)?;
776        assert_eq!(argmin_result.dims(), &[1]);
777        assert_eq!(argmax_result.dims(), &[1]);
778
779        let argmin_val = argmin_result.as_slice::<u64>()?[0];
780        let argmax_val = argmax_result.as_slice::<u64>()?[0];
781        assert_eq!(argmin_val, 1); // min value 1.0 at index 1
782        assert_eq!(argmax_val, 2); // max value 9.0 at index 2
783
784        // 5x1 tensor
785        let data_5x1 = vec![10.0f32, 20.0, 5.0, 30.0, 15.0];
786        let tensor_5x1 = Tensor::from_vec(data_5x1, [5, 1])?;
787
788        let (argmin_result, argmax_result) = tensor_5x1.argmin_argmax(0)?;
789        assert_eq!(argmin_result.dims(), &[1]);
790        assert_eq!(argmax_result.dims(), &[1]);
791
792        let argmin_val = argmin_result.as_slice::<u64>()?[0];
793        let argmax_val = argmax_result.as_slice::<u64>()?[0];
794        assert_eq!(argmin_val, 2); // min value 5.0 at index 2
795        assert_eq!(argmax_val, 3); // max value 30.0 at index 3
796
797        let (argmin_result, argmax_result) = tensor_5x1.argmin_argmax(1)?;
798        assert_eq!(argmin_result.dims(), &[5]);
799        assert_eq!(argmax_result.dims(), &[5]);
800
801        Ok(())
802    }
803
804    #[test]
805    fn test_argmin_argmax_consistency_with_individual_ops() -> Result<()> {
806        // Test that argmin_argmax results are consistent with individual argmin and argmax operations
807        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
808        let tensor = Tensor::from_vec(data, [2, 4])?;
809
810        // Test along dimension 0
811        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
812        let individual_argmin = tensor.argmin(0)?;
813        let individual_argmax = tensor.argmax(0)?;
814
815        let argmin_vals = argmin_result.as_slice::<u64>()?;
816        let argmax_vals = argmax_result.as_slice::<u64>()?;
817        let individual_argmin_vals = individual_argmin.as_slice::<u64>()?;
818        let individual_argmax_vals = individual_argmax.as_slice::<u64>()?;
819
820        for (i, (&argmin_val, &argmax_val)) in
821            argmin_vals.iter().zip(argmax_vals.iter()).enumerate()
822        {
823            assert_eq!(argmin_val, individual_argmin_vals[i]);
824            assert_eq!(argmax_val, individual_argmax_vals[i]);
825        }
826
827        // Test along dimension 1
828        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;
829        let individual_argmin = tensor.argmin(1)?;
830        let individual_argmax = tensor.argmax(1)?;
831
832        let argmin_vals = argmin_result.as_slice::<u64>()?;
833        let argmax_vals = argmax_result.as_slice::<u64>()?;
834        let individual_argmin_vals = individual_argmin.as_slice::<u64>()?;
835        let individual_argmax_vals = individual_argmax.as_slice::<u64>()?;
836
837        for (i, (&argmin_val, &argmax_val)) in
838            argmin_vals.iter().zip(argmax_vals.iter()).enumerate()
839        {
840            assert_eq!(argmin_val, individual_argmin_vals[i]);
841            assert_eq!(argmax_val, individual_argmax_vals[i]);
842        }
843
844        Ok(())
845    }
846
847    #[test]
848    fn test_argmin_argmax_large_tensor() -> Result<()> {
849        // Test argmin_argmax with a larger tensor
850        let size = 1000;
851        let data: Vec<f32> = (0..size).map(|i| (i % 100) as f32).collect();
852        let tensor = Tensor::from_vec(data, [10, 100])?;
853
854        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;
855
856        assert_eq!(argmin_result.dims(), &[10]);
857        assert_eq!(argmax_result.dims(), &[10]);
858
859        let argmin_vals = argmin_result.as_slice::<u64>()?;
860        let argmax_vals = argmax_result.as_slice::<u64>()?;
861
862        // Each row should have argmin=0 and argmax=99
863        for (&argmin_val, &argmax_val) in argmin_vals.iter().zip(argmax_vals.iter()) {
864            assert_eq!(argmin_val, 0);
865            assert_eq!(argmax_val, 99);
866        }
867
868        Ok(())
869    }
870
871    #[test]
872    fn test_argmin_argmax_empty_tensor_error() -> Result<()> {
873        // Test that argmin_argmax fails gracefully with empty tensor
874        let empty_tensor = Tensor::from_vec(Vec::<f32>::new(), [0])?;
875
876        let result = empty_tensor.argmin_argmax(0);
877        assert!(result.is_err());
878
879        let error_msg = result.unwrap_err().to_string();
880        assert!(error_msg.contains("size 0"));
881
882        Ok(())
883    }
884
885    #[test]
886    fn test_argmin_argmax_invalid_dimension() -> Result<()> {
887        // Test argmin_argmax with invalid dimension
888        let data = vec![1.0f32, 2.0, 3.0];
889        let tensor = Tensor::from_vec(data, [3])?;
890
891        // Should fail for dimension >= rank
892        assert!(tensor.argmin_argmax(1).is_err());
893        assert!(tensor.argmin_argmax(2).is_err());
894
895        Ok(())
896    }
897
898    #[test]
899    fn test_argmin_argmax_first_occurrence() -> Result<()> {
900        // Test that argmin_argmax returns the first occurrence of min/max values
901        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 5.0, 2.0];
902        let tensor = Tensor::from_vec(data, [7])?;
903
904        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
905
906        let argmin_val = argmin_result.as_slice::<u64>()?[0];
907        let argmax_val = argmax_result.as_slice::<u64>()?[0];
908
909        // First occurrence of minimum value 1.0 is at index 1
910        assert_eq!(argmin_val, 1);
911        // First occurrence of maximum value 5.0 is at index 4
912        assert_eq!(argmax_val, 4);
913
914        Ok(())
915    }
916}