1use anyhow::Result;
2use half::{bf16, f16};
3
4use crate::{global_backend, DType, Dim, OpsTrait, StorageTrait, Tensor, TensorBase, UninitVec};
5
6impl<S: StorageTrait> TensorBase<S> {
7 pub fn argmin_argmax<D: Dim + Clone>(&self, dim: D) -> Result<(Tensor, Tensor)> {
8 self.argmin_argmax_impl(dim, false)
9 }
10
11 pub fn argmin_argmax_keepdim<D: Dim + Clone>(&self, dim: D) -> Result<(Tensor, Tensor)> {
12 self.argmin_argmax_impl(dim, true)
13 }
14
15 pub fn argmin_argmax_impl<D: Dim + Clone>(
16 &self,
17 dim: D,
18 keepdim: bool,
19 ) -> Result<(Tensor, Tensor)> {
20 let dim_index = dim.to_dim(self.rank())?;
21
22 if self.shape()[dim_index] == 0 {
23 anyhow::bail!("Cannot find argmin/argmax of dimension with size 0");
24 }
25
26 if self.is_contiguous() && self.can_reduce_over_last_dims(&[dim_index]) {
28 let backend = global_backend();
29 let shape = self.shape();
30 let reduce_size = shape[dim_index];
31 let output_size = self.numel() / reduce_size;
32 let (new_shape, _) =
33 crate::reduce::reduce_shape_stride(self.shape, &[dim_index], keepdim);
34
35 match self.dtype() {
36 DType::Fp32 => {
37 let data = self.as_slice::<f32>()?;
38
39 let mut out_argmin = UninitVec::<u64>::new(output_size);
40 let mut out_argmax = UninitVec::<u64>::new(output_size);
41
42 let dst_argmin = out_argmin.as_mut_slice();
43 let dst_argmax = out_argmax.as_mut_slice();
44
45 for i in 0..output_size {
46 let start = i * reduce_size;
47 let end = start + reduce_size;
48 let (min_idx, max_idx) = backend.min_max_i_f32(&data[start..end]);
49
50 dst_argmin[i] = min_idx;
51 dst_argmax[i] = max_idx;
52 }
53
54 let out_argmin = unsafe { out_argmin.finalize() };
55 let out_argmax = unsafe { out_argmax.finalize() };
56
57 Ok((
58 Tensor::from_vec(out_argmin, new_shape)?,
59 Tensor::from_vec(out_argmax, new_shape)?,
60 ))
61 }
62 DType::Fp64 => {
63 let data = self.as_slice::<f64>()?;
64
65 let mut out_argmin = UninitVec::<u64>::new(output_size);
66 let mut out_argmax = UninitVec::<u64>::new(output_size);
67
68 let dst_argmin = out_argmin.as_mut_slice();
69 let dst_argmax = out_argmax.as_mut_slice();
70
71 for i in 0..output_size {
72 let start = i * reduce_size;
73 let end = start + reduce_size;
74 let (min_idx, max_idx) = backend.min_max_i_f64(&data[start..end]);
75
76 dst_argmin[i] = min_idx;
77 dst_argmax[i] = max_idx;
78 }
79
80 let out_argmin = unsafe { out_argmin.finalize() };
81 let out_argmax = unsafe { out_argmax.finalize() };
82
83 Ok((
84 Tensor::from_vec(out_argmin, new_shape)?,
85 Tensor::from_vec(out_argmax, new_shape)?,
86 ))
87 }
88 DType::Fp16 => {
89 let data = self.as_slice::<f16>()?;
90
91 let mut out_argmin = UninitVec::<u64>::new(output_size);
92 let mut out_argmax = UninitVec::<u64>::new(output_size);
93
94 let dst_argmin = out_argmin.as_mut_slice();
95 let dst_argmax = out_argmax.as_mut_slice();
96
97 for i in 0..output_size {
98 let start = i * reduce_size;
99 let end = start + reduce_size;
100 let (min_idx, max_idx) = backend.min_max_i_f16(&data[start..end]);
101
102 dst_argmin[i] = min_idx;
103 dst_argmax[i] = max_idx;
104 }
105
106 let out_argmin = unsafe { out_argmin.finalize() };
107 let out_argmax = unsafe { out_argmax.finalize() };
108
109 Ok((
110 Tensor::from_vec(out_argmin, new_shape)?,
111 Tensor::from_vec(out_argmax, new_shape)?,
112 ))
113 }
114 DType::Bf16 => {
115 let data = self.as_slice::<bf16>()?;
116
117 let mut out_argmin = UninitVec::<u64>::new(output_size);
118 let mut out_argmax = UninitVec::<u64>::new(output_size);
119
120 let dst_argmin = out_argmin.as_mut_slice();
121 let dst_argmax = out_argmax.as_mut_slice();
122
123 for i in 0..output_size {
124 let start = i * reduce_size;
125 let end = start + reduce_size;
126 let (min_idx, max_idx) = backend.min_max_i_bf16(&data[start..end]);
127
128 dst_argmin[i] = min_idx;
129 dst_argmax[i] = max_idx;
130 }
131
132 let out_argmin = unsafe { out_argmin.finalize() };
133 let out_argmax = unsafe { out_argmax.finalize() };
134
135 Ok((
136 Tensor::from_vec(out_argmin, new_shape)?,
137 Tensor::from_vec(out_argmax, new_shape)?,
138 ))
139 }
140 DType::Int8 => {
141 let data = self.as_slice::<i8>()?;
142
143 let mut out_argmin = UninitVec::<u64>::new(output_size);
144 let mut out_argmax = UninitVec::<u64>::new(output_size);
145
146 let dst_argmin = out_argmin.as_mut_slice();
147 let dst_argmax = out_argmax.as_mut_slice();
148
149 for i in 0..output_size {
150 let start = i * reduce_size;
151 let end = start + reduce_size;
152 let (min_idx, max_idx) = backend.min_max_i_i8(&data[start..end]);
153
154 dst_argmin[i] = min_idx;
155 dst_argmax[i] = max_idx;
156 }
157
158 let out_argmin = unsafe { out_argmin.finalize() };
159 let out_argmax = unsafe { out_argmax.finalize() };
160
161 Ok((
162 Tensor::from_vec(out_argmin, new_shape)?,
163 Tensor::from_vec(out_argmax, new_shape)?,
164 ))
165 }
166 DType::Int16 => {
167 let data = self.as_slice::<i16>()?;
168
169 let mut out_argmin = UninitVec::<u64>::new(output_size);
170 let mut out_argmax = UninitVec::<u64>::new(output_size);
171
172 let dst_argmin = out_argmin.as_mut_slice();
173 let dst_argmax = out_argmax.as_mut_slice();
174
175 for i in 0..output_size {
176 let start = i * reduce_size;
177 let end = start + reduce_size;
178 let (min_idx, max_idx) = backend.min_max_i_i16(&data[start..end]);
179
180 dst_argmin[i] = min_idx;
181 dst_argmax[i] = max_idx;
182 }
183
184 let out_argmin = unsafe { out_argmin.finalize() };
185 let out_argmax = unsafe { out_argmax.finalize() };
186
187 Ok((
188 Tensor::from_vec(out_argmin, new_shape)?,
189 Tensor::from_vec(out_argmax, new_shape)?,
190 ))
191 }
192 DType::Int32 => {
193 let data = self.as_slice::<i32>()?;
194
195 let mut out_argmin = UninitVec::<u64>::new(output_size);
196 let mut out_argmax = UninitVec::<u64>::new(output_size);
197
198 let dst_argmin = out_argmin.as_mut_slice();
199 let dst_argmax = out_argmax.as_mut_slice();
200
201 for i in 0..output_size {
202 let start = i * reduce_size;
203 let end = start + reduce_size;
204 let (min_idx, max_idx) = backend.min_max_i_i32(&data[start..end]);
205
206 dst_argmin[i] = min_idx;
207 dst_argmax[i] = max_idx;
208 }
209
210 let out_argmin = unsafe { out_argmin.finalize() };
211 let out_argmax = unsafe { out_argmax.finalize() };
212
213 Ok((
214 Tensor::from_vec(out_argmin, new_shape)?,
215 Tensor::from_vec(out_argmax, new_shape)?,
216 ))
217 }
218 DType::Int64 => {
219 let data = self.as_slice::<i64>()?;
220
221 let mut out_argmin = UninitVec::<u64>::new(output_size);
222 let mut out_argmax = UninitVec::<u64>::new(output_size);
223
224 let dst_argmin = out_argmin.as_mut_slice();
225 let dst_argmax = out_argmax.as_mut_slice();
226
227 for i in 0..output_size {
228 let start = i * reduce_size;
229 let end = start + reduce_size;
230 let (min_idx, max_idx) = backend.min_max_i_i64(&data[start..end]);
231
232 dst_argmin[i] = min_idx;
233 dst_argmax[i] = max_idx;
234 }
235
236 let out_argmin = unsafe { out_argmin.finalize() };
237 let out_argmax = unsafe { out_argmax.finalize() };
238
239 Ok((
240 Tensor::from_vec(out_argmin, new_shape)?,
241 Tensor::from_vec(out_argmax, new_shape)?,
242 ))
243 }
244 DType::Uint8 => {
245 let data = self.as_slice::<u8>()?;
246
247 let mut out_argmin = UninitVec::<u64>::new(output_size);
248 let mut out_argmax = UninitVec::<u64>::new(output_size);
249
250 let dst_argmin = out_argmin.as_mut_slice();
251 let dst_argmax = out_argmax.as_mut_slice();
252
253 for i in 0..output_size {
254 let start = i * reduce_size;
255 let end = start + reduce_size;
256 let (min_idx, max_idx) = backend.min_max_i_u8(&data[start..end]);
257
258 dst_argmin[i] = min_idx;
259 dst_argmax[i] = max_idx;
260 }
261
262 let out_argmin = unsafe { out_argmin.finalize() };
263 let out_argmax = unsafe { out_argmax.finalize() };
264
265 Ok((
266 Tensor::from_vec(out_argmin, new_shape)?,
267 Tensor::from_vec(out_argmax, new_shape)?,
268 ))
269 }
270 DType::Uint16 => {
271 let data = self.as_slice::<u16>()?;
272
273 let mut out_argmin = UninitVec::<u64>::new(output_size);
274 let mut out_argmax = UninitVec::<u64>::new(output_size);
275
276 let dst_argmin = out_argmin.as_mut_slice();
277 let dst_argmax = out_argmax.as_mut_slice();
278
279 for i in 0..output_size {
280 let start = i * reduce_size;
281 let end = start + reduce_size;
282 let (min_idx, max_idx) = backend.min_max_i_u16(&data[start..end]);
283
284 dst_argmin[i] = min_idx;
285 dst_argmax[i] = max_idx;
286 }
287
288 let out_argmin = unsafe { out_argmin.finalize() };
289 let out_argmax = unsafe { out_argmax.finalize() };
290
291 Ok((
292 Tensor::from_vec(out_argmin, new_shape)?,
293 Tensor::from_vec(out_argmax, new_shape)?,
294 ))
295 }
296 DType::Uint32 => {
297 let data = self.as_slice::<u32>()?;
298
299 let mut out_argmin = UninitVec::<u64>::new(output_size);
300 let mut out_argmax = UninitVec::<u64>::new(output_size);
301
302 let dst_argmin = out_argmin.as_mut_slice();
303 let dst_argmax = out_argmax.as_mut_slice();
304
305 for i in 0..output_size {
306 let start = i * reduce_size;
307 let end = start + reduce_size;
308 let (min_idx, max_idx) = backend.min_max_i_u32(&data[start..end]);
309
310 dst_argmin[i] = min_idx;
311 dst_argmax[i] = max_idx;
312 }
313
314 let out_argmin = unsafe { out_argmin.finalize() };
315 let out_argmax = unsafe { out_argmax.finalize() };
316
317 Ok((
318 Tensor::from_vec(out_argmin, new_shape)?,
319 Tensor::from_vec(out_argmax, new_shape)?,
320 ))
321 }
322 DType::Uint64 => {
323 let data = self.as_slice::<u64>()?;
324
325 let mut out_argmin = UninitVec::<u64>::new(output_size);
326 let mut out_argmax = UninitVec::<u64>::new(output_size);
327
328 let dst_argmin = out_argmin.as_mut_slice();
329 let dst_argmax = out_argmax.as_mut_slice();
330
331 for i in 0..output_size {
332 let start = i * reduce_size;
333 let end = start + reduce_size;
334 let (min_idx, max_idx) = backend.min_max_i_u64(&data[start..end]);
335
336 dst_argmin[i] = min_idx;
337 dst_argmax[i] = max_idx;
338 }
339
340 let out_argmin = unsafe { out_argmin.finalize() };
341 let out_argmax = unsafe { out_argmax.finalize() };
342
343 Ok((
344 Tensor::from_vec(out_argmin, new_shape)?,
345 Tensor::from_vec(out_argmax, new_shape)?,
346 ))
347 }
348 _ => anyhow::bail!("Argmin/Argmax not supported for dtype {:?}", self.dtype()),
349 }
350 } else {
351 let (new_shape, _) = crate::reduce_shape_stride(self.shape, &[dim_index], keepdim);
352
353 let result_size = new_shape.iter().product();
354 macro_rules! noncontig_argmin_argmax {
355 ($t:ty, $min_init:expr, $max_init:expr) => {{
356 let mut mins = vec![$min_init; result_size];
357 let mut maxs = vec![$max_init; result_size];
358 let mut argmins = vec![0u64; result_size];
359 let mut argmaxs = vec![0u64; result_size];
360 let mut idx_buf = vec![0; new_shape.len()];
361
362 for elem in self.iter() {
363 let i = elem.indices;
364 let ptr = unsafe { elem.as_ptr(self.as_ptr()) };
365 let val = unsafe { *(ptr as *const $t) };
366 let mut current_dim = 0;
367 for k in 0..self.rank() {
368 if k == dim_index {
369 if keepdim {
370 idx_buf[current_dim] = 0;
371 current_dim += 1;
372 }
373 } else {
374 idx_buf[current_dim] = i[k];
375 current_dim += 1;
376 }
377 }
378
379 let mut linear = 0;
380 let mut stride = 1;
381 for j in (0..new_shape.len()).rev() {
382 linear += idx_buf[j] * stride;
383 stride *= new_shape[j];
384 }
385
386 if val < mins[linear] {
387 mins[linear] = val;
388 argmins[linear] = i[dim_index] as u64;
389 }
390 if val > maxs[linear] {
391 maxs[linear] = val;
392 argmaxs[linear] = i[dim_index] as u64;
393 }
394 }
395
396 Ok((
397 Tensor::from_vec(argmins, new_shape)?,
398 Tensor::from_vec(argmaxs, new_shape)?,
399 ))
400 }};
401 }
402 match self.dtype() {
403 DType::Fp32 => noncontig_argmin_argmax!(f32, f32::INFINITY, f32::NEG_INFINITY),
404 DType::Fp64 => noncontig_argmin_argmax!(f64, f64::INFINITY, f64::NEG_INFINITY),
405 DType::Fp16 => noncontig_argmin_argmax!(
406 f16,
407 f16::from_f32(f32::INFINITY),
408 f16::from_f32(f32::NEG_INFINITY)
409 ),
410 DType::Bf16 => noncontig_argmin_argmax!(
411 bf16,
412 bf16::from_f32(f32::INFINITY),
413 bf16::from_f32(f32::NEG_INFINITY)
414 ),
415 DType::Int8 => noncontig_argmin_argmax!(i8, i8::MAX, i8::MIN),
416 DType::Int16 => noncontig_argmin_argmax!(i16, i16::MAX, i16::MIN),
417 DType::Int32 => noncontig_argmin_argmax!(i32, i32::MAX, i32::MIN),
418 DType::Int64 => noncontig_argmin_argmax!(i64, i64::MAX, i64::MIN),
419 DType::Uint8 => noncontig_argmin_argmax!(u8, u8::MAX, u8::MIN),
420 DType::Uint16 => noncontig_argmin_argmax!(u16, u16::MAX, u16::MIN),
421 DType::Uint32 => noncontig_argmin_argmax!(u32, u32::MAX, u32::MIN),
422 DType::Uint64 => noncontig_argmin_argmax!(u64, u64::MAX, u64::MIN),
423 _ => anyhow::bail!("Argmin/Argmax not supported for dtype {:?}", self.dtype()),
424 }
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use crate::*;
    use anyhow::Result;

    #[test]
    fn test_argmin_argmax_1d_basic() -> Result<()> {
        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
        let tensor = Tensor::from_vec(data, [7])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;

        assert_eq!(argmin_result.dims(), &[] as &[usize]);
        assert_eq!(argmax_result.dims(), &[] as &[usize]);

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert_eq!(argmin_val, 1);
        assert_eq!(argmax_val, 5);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_2d_dim0() -> Result<()> {
        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
        let tensor = Tensor::from_vec(data, [2, 3])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;

        assert_eq!(argmin_result.dims(), &[3]);
        assert_eq!(argmax_result.dims(), &[3]);

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[0, 0, 1]);
        assert_eq!(argmax_vals, &[1, 1, 0]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_2d_dim1() -> Result<()> {
        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
        let tensor = Tensor::from_vec(data, [2, 3])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;

        assert_eq!(argmin_result.dims(), &[2]);
        assert_eq!(argmax_result.dims(), &[2]);

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[0, 2]);
        assert_eq!(argmax_vals, &[1, 1]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_3d_basic() -> Result<()> {
        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, [2, 3, 4])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;
        assert_eq!(argmin_result.dims(), &[3, 4]);
        assert_eq!(argmax_result.dims(), &[3, 4]);

        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;
        assert_eq!(argmin_result.dims(), &[2, 4]);
        assert_eq!(argmax_result.dims(), &[2, 4]);

        let (argmin_result, argmax_result) = tensor.argmin_argmax(2)?;
        assert_eq!(argmin_result.dims(), &[2, 3]);
        assert_eq!(argmax_result.dims(), &[2, 3]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_keepdim_1d() -> Result<()> {
        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0];
        let tensor = Tensor::from_vec(data, [5])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;

        assert_eq!(argmin_result.dims(), &[1]);
        assert_eq!(argmax_result.dims(), &[1]);

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert_eq!(argmin_val, 1);
        assert_eq!(argmax_val, 4);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_keepdim_2d() -> Result<()> {
        let data = vec![1.0f32, 5.0, 3.0, 2.0, 8.0, 1.0];
        let tensor = Tensor::from_vec(data, [2, 3])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;
        assert_eq!(argmin_result.dims(), &[1, 3]);
        assert_eq!(argmax_result.dims(), &[1, 3]);

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(1)?;
        assert_eq!(argmin_result.dims(), &[2, 1]);
        assert_eq!(argmax_result.dims(), &[2, 1]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_keepdim_3d() -> Result<()> {
        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, [2, 3, 4])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(0)?;
        assert_eq!(argmin_result.dims(), &[1, 3, 4]);
        assert_eq!(argmax_result.dims(), &[1, 3, 4]);

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(1)?;
        assert_eq!(argmin_result.dims(), &[2, 1, 4]);
        assert_eq!(argmax_result.dims(), &[2, 1, 4]);

        let (argmin_result, argmax_result) = tensor.argmin_argmax_keepdim(2)?;
        assert_eq!(argmin_result.dims(), &[2, 3, 1]);
        assert_eq!(argmax_result.dims(), &[2, 3, 1]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_non_contiguous_2d() -> Result<()> {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_vec(data, [2, 3])?;

        // [[1, 2, 3], [4, 5, 6]] transposed -> [[1, 4], [2, 5], [3, 6]]
        let permuted = tensor.clone().permute([1, 0])?;

        let (argmin_result, argmax_result) = permuted.argmin_argmax(0)?;
        assert_eq!(argmin_result.dims(), &[2]);
        assert_eq!(argmax_result.dims(), &[2]);

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[0, 0]);
        assert_eq!(argmax_vals, &[2, 2]);

        let (argmin_result, argmax_result) = permuted.argmin_argmax(1)?;
        assert_eq!(argmin_result.dims(), &[3]);
        assert_eq!(argmax_result.dims(), &[3]);

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[0, 0, 0]);
        assert_eq!(argmax_vals, &[1, 1, 1]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_non_contiguous_3d() -> Result<()> {
        let data: Vec<f32> = (1..=24).map(|x| x as f32).collect();
        let tensor = Tensor::from_vec(data, [2, 3, 4])?;

        // permuted[c][a][b] = 1 + 12a + 4b + c, strictly increasing along every axis, so the
        // argmin is always 0 and the argmax is always the last index of the reduced axis.
        let permuted = tensor.clone().permute([2, 0, 1])?;

        let (argmin_result, argmax_result) = permuted.argmin_argmax(0)?;
        assert_eq!(argmin_result.dims(), &[2, 3]);
        assert_eq!(argmax_result.dims(), &[2, 3]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 3));

        let (argmin_result, argmax_result) = permuted.argmin_argmax(1)?;
        assert_eq!(argmin_result.dims(), &[4, 3]);
        assert_eq!(argmax_result.dims(), &[4, 3]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 1));

        let (argmin_result, argmax_result) = permuted.argmin_argmax(2)?;
        assert_eq!(argmin_result.dims(), &[4, 2]);
        assert_eq!(argmax_result.dims(), &[4, 2]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 2));

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_keepdim_non_contiguous() -> Result<()> {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_vec(data, [2, 2, 2])?;

        // permuted[i][j][k] = 1 + 4k + 2j + i, strictly increasing along every axis.
        let permuted = tensor.clone().permute([2, 1, 0])?;

        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(0)?;
        assert_eq!(argmin_result.dims(), &[1, 2, 2]);
        assert_eq!(argmax_result.dims(), &[1, 2, 2]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 1));

        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(1)?;
        assert_eq!(argmin_result.dims(), &[2, 1, 2]);
        assert_eq!(argmax_result.dims(), &[2, 1, 2]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 1));

        let (argmin_result, argmax_result) = permuted.argmin_argmax_keepdim(2)?;
        assert_eq!(argmin_result.dims(), &[2, 2, 1]);
        assert_eq!(argmax_result.dims(), &[2, 2, 1]);
        assert!(argmin_result.as_slice::<u64>()?.iter().all(|&v| v == 0));
        assert!(argmax_result.as_slice::<u64>()?.iter().all(|&v| v == 1));

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_non_contiguous_values() -> Result<()> {
        // [[3, 1, 2], [0, 5, 4]] transposed -> [[3, 0], [1, 5], [2, 4]]
        let tensor = Tensor::from_vec(vec![3.0f32, 1.0, 2.0, 0.0, 5.0, 4.0], [2, 3])?;
        let transposed = tensor.permute([1, 0])?;

        // Columns of the transposed view: [3, 1, 2] and [0, 5, 4].
        let (argmin_result, argmax_result) = transposed.argmin_argmax(0)?;
        assert_eq!(argmin_result.as_slice::<u64>()?, &[1, 0]);
        assert_eq!(argmax_result.as_slice::<u64>()?, &[0, 1]);

        // Rows of the transposed view: [3, 0], [1, 5], [2, 4].
        let (argmin_result, argmax_result) = transposed.argmin_argmax(1)?;
        assert_eq!(argmin_result.as_slice::<u64>()?, &[1, 0, 0]);
        assert_eq!(argmax_result.as_slice::<u64>()?, &[0, 1, 1]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_non_contiguous_first_occurrence() -> Result<()> {
        // [[2, 1, 1], [7, 7, 0]] transposed -> [[2, 7], [1, 7], [1, 0]]
        let tensor = Tensor::from_vec(vec![2.0f32, 1.0, 1.0, 7.0, 7.0, 0.0], [2, 3])?;
        let transposed = tensor.permute([1, 0])?;

        // Column [2, 1, 1]: the first minimum is at 1. Column [7, 7, 0]: the first maximum is at 0.
        let (argmin_result, argmax_result) = transposed.argmin_argmax(0)?;
        assert_eq!(argmin_result.as_slice::<u64>()?, &[1, 2]);
        assert_eq!(argmax_result.as_slice::<u64>()?, &[0, 0]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_non_contiguous_integer_extremes() -> Result<()> {
        // Values equal to the strided path's sentinels (i32::MAX / i32::MIN).
        // [[MAX, MIN], [MAX, MIN]] transposed -> [[MAX, MAX], [MIN, MIN]]
        let tensor = Tensor::from_vec(vec![i32::MAX, i32::MIN, i32::MAX, i32::MIN], [2, 2])?;
        let transposed = tensor.permute([1, 0])?;

        let (argmin_result, argmax_result) = transposed.argmin_argmax(1)?;
        assert_eq!(argmin_result.as_slice::<u64>()?, &[0, 0]);
        assert_eq!(argmax_result.as_slice::<u64>()?, &[0, 0]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_different_data_types() -> Result<()> {
        let data_i32 = vec![5i32, 1, 9, 3, 7, 2];
        let tensor_i32 = Tensor::from_vec(data_i32, [2, 3])?;
        let (argmin_result, argmax_result) = tensor_i32.argmin_argmax(1)?;

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[1, 2]);
        assert_eq!(argmax_vals, &[2, 1]);

        let data_u32 = vec![10u32, 20, 5, 15];
        let tensor_u32 = Tensor::from_vec(data_u32, [2, 2])?;
        let (argmin_result, argmax_result) = tensor_u32.argmin_argmax(0)?;

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[1, 1]);
        assert_eq!(argmax_vals, &[0, 0]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_special_values() -> Result<()> {
        let data_inf = vec![1.0f32, f32::INFINITY, 3.0, f32::NEG_INFINITY];
        let tensor_inf = Tensor::from_vec(data_inf, [4])?;
        let (argmin_result, argmax_result) = tensor_inf.argmin_argmax(0)?;

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert_eq!(argmin_val, 3);
        assert_eq!(argmax_val, 1);

        // NaN handling is backend-defined on the contiguous path; only require a valid index.
        let data_nan = vec![1.0f32, f32::NAN, 3.0];
        let tensor_nan = Tensor::from_vec(data_nan, [3])?;
        let (argmin_result, argmax_result) = tensor_nan.argmin_argmax(0)?;

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert!(argmin_val < 3);
        assert!(argmax_val < 3);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_edge_cases() -> Result<()> {
        let single = Tensor::from_vec(vec![42.0f32], [1])?;
        let (argmin_result, argmax_result) = single.argmin_argmax(0)?;

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert_eq!(argmin_val, 0);
        assert_eq!(argmax_val, 0);

        let same = Tensor::from_vec(vec![5.0f32, 5.0, 5.0, 5.0], [2, 2])?;
        let (argmin_result, argmax_result) = same.argmin_argmax(0)?;

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        assert_eq!(argmin_vals, &[0, 0]);
        assert_eq!(argmax_vals, &[0, 0]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_rectangular_tensors() -> Result<()> {
        let data_1x5 = vec![5.0f32, 1.0, 9.0, 3.0, 7.0];
        let tensor_1x5 = Tensor::from_vec(data_1x5, [1, 5])?;

        let (argmin_result, argmax_result) = tensor_1x5.argmin_argmax(0)?;
        assert_eq!(argmin_result.dims(), &[5]);
        assert_eq!(argmax_result.dims(), &[5]);

        let (argmin_result, argmax_result) = tensor_1x5.argmin_argmax(1)?;
        assert_eq!(argmin_result.dims(), &[1]);
        assert_eq!(argmax_result.dims(), &[1]);

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];
        assert_eq!(argmin_val, 1);
        assert_eq!(argmax_val, 2);

        let data_5x1 = vec![10.0f32, 20.0, 5.0, 30.0, 15.0];
        let tensor_5x1 = Tensor::from_vec(data_5x1, [5, 1])?;

        let (argmin_result, argmax_result) = tensor_5x1.argmin_argmax(0)?;
        assert_eq!(argmin_result.dims(), &[1]);
        assert_eq!(argmax_result.dims(), &[1]);

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];
        assert_eq!(argmin_val, 2);
        assert_eq!(argmax_val, 3);

        let (argmin_result, argmax_result) = tensor_5x1.argmin_argmax(1)?;
        assert_eq!(argmin_result.dims(), &[5]);
        assert_eq!(argmax_result.dims(), &[5]);

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_consistency_with_individual_ops() -> Result<()> {
        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0];
        let tensor = Tensor::from_vec(data, [2, 4])?;

        for dim in 0..2 {
            let (argmin_result, argmax_result) = tensor.argmin_argmax(dim)?;
            let individual_argmin = tensor.argmin(dim)?;
            let individual_argmax = tensor.argmax(dim)?;

            assert_eq!(
                argmin_result.as_slice::<u64>()?,
                individual_argmin.as_slice::<u64>()?
            );
            assert_eq!(
                argmax_result.as_slice::<u64>()?,
                individual_argmax.as_slice::<u64>()?
            );
        }

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_large_tensor() -> Result<()> {
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| (i % 100) as f32).collect();
        let tensor = Tensor::from_vec(data, [10, 100])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(1)?;

        assert_eq!(argmin_result.dims(), &[10]);
        assert_eq!(argmax_result.dims(), &[10]);

        let argmin_vals = argmin_result.as_slice::<u64>()?;
        let argmax_vals = argmax_result.as_slice::<u64>()?;

        for (&argmin_val, &argmax_val) in argmin_vals.iter().zip(argmax_vals.iter()) {
            assert_eq!(argmin_val, 0);
            assert_eq!(argmax_val, 99);
        }

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_empty_tensor_error() -> Result<()> {
        let empty_tensor = Tensor::from_vec(Vec::<f32>::new(), [0])?;

        let result = empty_tensor.argmin_argmax(0);
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(error_msg.contains("size 0"));

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_invalid_dimension() -> Result<()> {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_vec(data, [3])?;

        assert!(tensor.argmin_argmax(1).is_err());
        assert!(tensor.argmin_argmax(2).is_err());

        Ok(())
    }

    #[test]
    fn test_argmin_argmax_first_occurrence() -> Result<()> {
        let data = vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 5.0, 2.0];
        let tensor = Tensor::from_vec(data, [7])?;

        let (argmin_result, argmax_result) = tensor.argmin_argmax(0)?;

        let argmin_val = argmin_result.as_slice::<u64>()?[0];
        let argmax_val = argmax_result.as_slice::<u64>()?[0];

        assert_eq!(argmin_val, 1);
        assert_eq!(argmax_val, 4);

        Ok(())
    }
}