train_station/tensor/reductions/
argmax.rs

1//! Argmax reduction operations for tensors
2//!
3//! This module provides argmax operations that find the indices of maximum values
4//! in tensors. These operations are non-differentiable and never require gradients.
5//!
6//! # Operations
7//!
8//! * `argmax()` - Find the index of the maximum value across all elements
9//! * `argmax_dim()` - Find the indices of maximum values along a specific dimension
10//!
11//! # Examples
12//!
13//! ```
14//! use train_station::Tensor;
15//!
16//! let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
17//! let max_idx = tensor.argmax();
18//! assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has the maximum value 5.0
19//! ```
20
21use crate::tensor::core::Tensor;
22
23impl Tensor {
24    /// Returns the index of the maximum value across all elements in the tensor
25    ///
26    /// This operation finds the flat index (0-based) of the element with the highest value.
27    /// If multiple elements have the same maximum value, the index of the first occurrence
28    /// is returned. The output is a scalar tensor with shape \[1\] containing the index as a float.
29    ///
30    /// This operation is non-differentiable and the output never requires gradients.
31    ///
32    /// # Returns
33    ///
34    /// A tensor with shape \[1\] containing the flat index of the maximum value
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use train_station::Tensor;
40    ///
41    /// // 1D tensor
42    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
43    /// let max_idx = tensor.argmax();
44    /// assert_eq!(max_idx.shape().dims(), vec![1]);
45    /// assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
46    /// ```
47    ///
48    /// ```
49    /// use train_station::Tensor;
50    ///
51    /// // 2D tensor
52    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
53    /// let max_idx = tensor.argmax();
54    /// assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
55    /// ```
56    ///
57    /// ```
58    /// use train_station::Tensor;
59    ///
60    /// // Tied values return first occurrence
61    /// let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
62    /// let max_idx = tensor.argmax();
63    /// assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
64    /// ```
65    #[track_caller]
66    pub fn argmax(&self) -> Tensor {
67        let mut out = Tensor::new(vec![1]);
68        if self.size() == 0 {
69            out.fill(0.0);
70            return out;
71        }
72
73        let mut best_val = f32::NEG_INFINITY;
74        let mut best_idx = 0usize;
75
76        if self.is_contiguous() {
77            // Fast path for contiguous tensors
78            unsafe {
79                let src = self.as_ptr();
80                for i in 0..self.size() {
81                    let v = *src.add(i);
82                    if v > best_val {
83                        best_val = v;
84                        best_idx = i;
85                    }
86                }
87            }
88        } else {
89            // Stride-aware path for non-contiguous tensors
90            let dims = self.shape().dims().to_vec();
91            for flat_idx in 0..self.size() {
92                // Convert flat index to multi-dimensional coordinates
93                let mut coords = vec![0; dims.len()];
94                let mut tmp = flat_idx;
95                for k in (0..dims.len()).rev() {
96                    coords[k] = tmp % dims[k];
97                    tmp /= dims[k];
98                }
99
100                // Get value using stride-aware offset
101                let offset = self.shape().offset(&coords);
102                let v = unsafe { *self.as_ptr().add(offset) };
103                if v > best_val {
104                    best_val = v;
105                    best_idx = flat_idx;
106                }
107            }
108        }
109
110        unsafe {
111            *out.as_mut_ptr() = best_idx as f32;
112        }
113        out
114    }
115
116    /// Returns the indices of maximum values along a specified dimension
117    ///
118    /// This operation finds the indices of maximum values along the specified dimension.
119    /// For each slice along the dimension, it returns the index of the maximum value.
120    /// If multiple elements have the same maximum value, the index of the first occurrence
121    /// is returned.
122    ///
123    /// The output shape depends on the `keepdim` parameter:
124    /// * If `keepdim` is `true`, the reduced dimension is kept with size 1
125    /// * If `keepdim` is `false`, the reduced dimension is removed
126    ///
127    /// This operation is non-differentiable and the output never requires gradients.
128    ///
129    /// # Arguments
130    ///
131    /// * `dim` - The dimension along which to find argmax indices (0-based)
132    /// * `keepdim` - Whether to keep the reduced dimension with size 1
133    ///
134    /// # Returns
135    ///
136    /// A tensor containing the indices of maximum values along the specified dimension
137    ///
138    /// # Panics
139    ///
140    /// Panics if `dim` is out of bounds for the tensor's rank or if the dimension size is 0.
141    ///
142    /// # Examples
143    ///
144    /// ```
145    /// use train_station::Tensor;
146    ///
147    /// // 2D tensor: [[1.0, 3.0, 2.0],
148    /// //             [4.0, 0.0, 5.0]]
149    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
150    ///
151    /// // argmax along columns (dim=1)
152    /// let col_max_idx = tensor.argmax_dim(1, false);
153    /// assert_eq!(col_max_idx.shape().dims(), vec![2]);
154    /// assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
155    /// assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
156    /// ```
157    ///
158    /// ```
159    /// use train_station::Tensor;
160    ///
161    /// // argmax along rows (dim=0) with keepdim
162    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
163    /// let row_max_idx = tensor.argmax_dim(0, true);
164    /// assert_eq!(row_max_idx.shape().dims(), vec![1, 3]);
165    /// assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
166    /// assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
167    /// assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
168    /// ```
169    ///
170    /// ```
171    /// use train_station::Tensor;
172    ///
173    /// // 1D tensor edge case
174    /// let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
175    /// let max_idx = tensor.argmax_dim(0, false);
176    /// assert_eq!(max_idx.shape().dims(), vec![1]); // Special case: becomes [1] not []
177    /// assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
178    /// ```
179    #[track_caller]
180    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor {
181        let rank = self.shape().rank();
182        assert!(
183            dim < rank,
184            "argmax_dim dim {} out of bounds for rank {}",
185            dim,
186            rank
187        );
188
189        let in_dims = self.shape().dims().to_vec();
190        let reduce_size = in_dims[dim];
191        assert!(reduce_size > 0, "cannot argmax over empty dimension");
192
193        // Build output shape
194        let mut out_dims = in_dims.clone();
195        if keepdim {
196            out_dims[dim] = 1;
197        } else {
198            out_dims.remove(dim);
199        }
200        if out_dims.is_empty() {
201            out_dims.push(1);
202        }
203
204        let mut out = Tensor::zeros(out_dims.clone());
205
206        // Use stride-aware approach to handle non-contiguous tensors correctly
207        let out_size = out.size();
208
209        unsafe {
210            let dst = out.as_mut_ptr();
211
212            // Iterate over all output positions
213            for out_idx in 0..out_size {
214                // Convert flat output index to multi-dimensional coordinates
215                let mut out_coords = vec![0; out_dims.len()];
216                let mut tmp = out_idx;
217                for k in (0..out_dims.len()).rev() {
218                    out_coords[k] = tmp % out_dims[k];
219                    tmp /= out_dims[k];
220                }
221
222                // Convert output coordinates to input coordinates
223                let mut in_coords = vec![0; rank];
224                if keepdim {
225                    // When keepdim=true, output coords map directly to input coords
226                    for k in 0..rank {
227                        if k == dim {
228                            in_coords[k] = 0; // Will be set in the loop below
229                        } else {
230                            in_coords[k] = out_coords[k];
231                        }
232                    }
233                } else {
234                    // When keepdim=false, we need to insert the missing dimension
235                    let mut out_coord_idx = 0;
236                    for (k, in_coord) in in_coords.iter_mut().enumerate().take(rank) {
237                        if k == dim {
238                            *in_coord = 0; // Will be set in the loop below
239                        } else {
240                            *in_coord = out_coords[out_coord_idx];
241                            out_coord_idx += 1;
242                        }
243                    }
244                }
245
246                // Find the argmax along the specified dimension
247                let mut best_val = f32::NEG_INFINITY;
248                let mut best_j = 0usize;
249
250                for j in 0..reduce_size {
251                    in_coords[dim] = j;
252                    let in_offset = self.shape().offset(&in_coords);
253                    let v = *self.as_ptr().add(in_offset);
254                    if v > best_val {
255                        best_val = v;
256                        best_j = j;
257                    }
258                }
259
260                *dst.add(out_idx) = best_j as f32;
261            }
262        }
263
264        out
265    }
266}
267
#[cfg(test)]
mod tests {
    //! Unit tests for `argmax` and `argmax_dim`, grouped by difficulty:
    //! contiguous inputs (level 1), strided views (level 2), and
    //! higher-rank / edge-case scenarios (level 3).

    use super::*;

    // ====== LEVEL 1: Basic functionality tests for contiguous tensors ======

    #[test]
    fn test_argmax_level1_basic_1d() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();

        // Check output shape
        assert_eq!(idx.shape().dims(), vec![1]);
        assert_eq!(idx.size(), 1);

        // Check result
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value 5.0
    }

    #[test]
    fn test_argmax_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values (should return first occurrence)
        let x = Tensor::from_slice(&[3.0, 3.0, 3.0], vec![3]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value -1.0
    }

    #[test]
    fn test_argmax_level1_basic_2d_contiguous() {
        // Test argmax over all elements for 2D tensor
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        let idx = x.argmax();

        assert_eq!(idx.shape().dims(), vec![1]);
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 has value 5.0
    }

    #[test]
    fn test_argmax_level1_dim_2d_basic() {
        // Test argmax_dim for simple 2D case
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();

        // argmax along dim=1 (along columns within each row)
        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims(), vec![2, 1]);
        assert_eq!(idx1_keepdim.get(&[0, 0]), 1.0); // row 0: max at index 1 (value 3.0)
        assert_eq!(idx1_keepdim.get(&[1, 0]), 2.0); // row 1: max at index 2 (value 5.0)

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims(), vec![2]);
        assert_eq!(idx1_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx1_no_keepdim.get(&[1]), 2.0);

        // argmax along dim=0 (along rows within each column)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims(), vec![1, 3]);
        assert_eq!(idx0_keepdim.get(&[0, 0]), 1.0); // col 0: max at index 1 (value 4.0)
        assert_eq!(idx0_keepdim.get(&[0, 1]), 0.0); // col 1: max at index 0 (value 3.0)
        assert_eq!(idx0_keepdim.get(&[0, 2]), 1.0); // col 2: max at index 1 (value 5.0)

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims(), vec![3]);
        assert_eq!(idx0_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx0_no_keepdim.get(&[1]), 0.0);
        assert_eq!(idx0_no_keepdim.get(&[2]), 1.0);
    }

    #[test]
    fn test_argmax_level1_3d_basic() {
        // Test 3D tensor: shape [2, 2, 2]
        // Data: [[[1.0, 2.0], [3.0, 4.0]],
        //        [[5.0, 6.0], [7.0, 8.0]]]
        let x =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 7.0); // flat index 7 has value 8.0

        // argmax along dim=2 (innermost dimension)
        let idx2 = x.argmax_dim(2, false);
        assert_eq!(idx2.shape().dims(), vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, 2.0] -> max at index 1
        assert_eq!(idx2.get(&[0, 1]), 1.0); // [3.0, 4.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 0]), 1.0); // [5.0, 6.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, 8.0] -> max at index 1
    }

    // ====== LEVEL 2: Non-contiguous tensors (views, permuted) ======

    #[test]
    fn test_argmax_level2_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, 5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, 5.0]]
        assert_eq!(x_t.shape().dims(), vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmax on transposed view
        let idx = x_t.argmax();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value 5.0

        // Test argmax along dim=0 of transposed tensor
        let idx0 = x_t.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims(), vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [1.0, 3.0, 2.0] -> max 3.0 at index 1
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, 5.0] -> max 5.0 at index 2

        // Test argmax along dim=1 of transposed tensor
        let idx1 = x_t.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims(), vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [1.0, 4.0] -> max 4.0 at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [3.0, 0.0] -> max 3.0 at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, 5.0] -> max 5.0 at index 1
    }

    #[test]
    fn test_argmax_level2_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, 6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, 6, 7, 8]
        assert_eq!(middle_row.shape().dims(), vec![4]);

        let idx = middle_row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 8.0

        // Test argmax_dim on 1D slice (should work the same as global argmax)
        let idx_dim = middle_row.argmax_dim(0, false);
        assert_eq!(idx_dim.shape().dims(), vec![1]);
        assert_eq!(idx_dim.get(&[0]), 3.0);
    }

    #[test]
    fn test_argmax_level2_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 0 to 23

        // Permute to [4, 3, 2] (swap dims 0 and 2)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims(), vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // The maximum value 23.0 sits at x[1, 2, 3], i.e. x_perm[3, 2, 1], which is the
        // last element of the permuted view: flat index 3 * 6 + 2 * 2 + 1 = 23.
        let idx = x_perm.argmax();
        assert_eq!(idx.get(&[0]), 23.0);

        // x_perm[a, b, c] == x[c, b, a] == 12 * c + 4 * b + a, so values grow along
        // every axis and each reduction must select the last index of the reduced axis.
        let idx0 = x_perm.argmax_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims(), vec![3, 2]);
        for b in 0..3 {
            for c in 0..2 {
                assert_eq!(idx0.get(&[b, c]), 3.0);
            }
        }

        let idx1 = x_perm.argmax_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims(), vec![4, 2]);
        for a in 0..4 {
            for c in 0..2 {
                assert_eq!(idx1.get(&[a, c]), 2.0);
            }
        }

        let idx2 = x_perm.argmax_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims(), vec![4, 3]);
        for a in 0..4 {
            for b in 0..3 {
                assert_eq!(idx2.get(&[a, b]), 1.0);
            }
        }
    }

    #[test]
    fn test_argmax_level2_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, 8, 11]
        assert_eq!(row.shape().dims(), vec![4]);

        let idx = row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 11.0
    }

    // ====== LEVEL 3: Complex multi-dimensional cases and edge scenarios ======

    #[test]
    fn test_argmax_level3_4d_tensor() {
        // Test 4D tensor with various reduction dimensions
        let data = (0..120).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();
        // Shape [2, 3, 4, 5] with values 0 to 119

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 119.0); // Maximum value 119.0 at flat index 119

        // Test argmax along each dimension
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims(), vec![1, 3, 4, 5]);

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims(), vec![3, 4, 5]);

        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims(), vec![2, 1, 4, 5]);

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims(), vec![2, 4, 5]);

        let idx2_keepdim = x.argmax_dim(2, true);
        assert_eq!(idx2_keepdim.shape().dims(), vec![2, 3, 1, 5]);

        let idx2_no_keepdim = x.argmax_dim(2, false);
        assert_eq!(idx2_no_keepdim.shape().dims(), vec![2, 3, 5]);

        let idx3_keepdim = x.argmax_dim(3, true);
        assert_eq!(idx3_keepdim.shape().dims(), vec![2, 3, 4, 1]);

        let idx3_no_keepdim = x.argmax_dim(3, false);
        assert_eq!(idx3_no_keepdim.shape().dims(), vec![2, 3, 4]);

        // Check some specific values for the innermost dimension (dim=3)
        // For each [i, j, k, :] slice, argmax should be 4 (index of max in size-5 dimension)
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(idx3_no_keepdim.get(&[i, j, k]), 4.0);
                    assert_eq!(idx3_keepdim.get(&[i, j, k, 0]), 4.0);
                }
            }
        }
    }

    #[test]
    fn test_argmax_level3_edge_cases_keepdim() {
        // Test edge case: 1D tensor with keepdim
        let x1d = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();

        let idx_keepdim = x1d.argmax_dim(0, true);
        assert_eq!(idx_keepdim.shape().dims(), vec![1]);
        assert_eq!(idx_keepdim.get(&[0]), 2.0);

        let idx_no_keepdim = x1d.argmax_dim(0, false);
        assert_eq!(idx_no_keepdim.shape().dims(), vec![1]); // Special case: becomes [1] not []
        assert_eq!(idx_no_keepdim.get(&[0]), 2.0);

        // Test edge case: dimension of size 1
        let x_size_1 = Tensor::from_slice(&[42.0], vec![1]).unwrap();

        let idx = x_size_1.argmax_dim(0, true);
        assert_eq!(idx.shape().dims(), vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);

        let idx = x_size_1.argmax_dim(0, false);
        assert_eq!(idx.shape().dims(), vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);
    }

    #[test]
    fn test_argmax_level3_ties_handling() {
        // Test that tied values return the first occurrence (PyTorch behavior)
        let x = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 5.0], vec![5]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of max value 5.0

        // Test with 2D ties
        let x2d = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 1.0, 5.0], vec![3, 2]).unwrap();

        // argmax along dim=0 (columns)
        let idx0 = x2d.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims(), vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [3, 5, 1] -> first 5 at index 1
        assert_eq!(idx0.get(&[1]), 0.0); // col 1: [5, 2, 5] -> first 5 at index 0

        // argmax along dim=1 (rows)
        let idx1 = x2d.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims(), vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [3, 5] -> max at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [5, 2] -> max at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [1, 5] -> max at index 1
    }

    #[test]
    fn test_argmax_level3_extreme_values() {
        // Test with extreme floating point values
        let x = Tensor::from_slice(
            &[f32::NEG_INFINITY, -1e10, 0.0, 1e10, f32::INFINITY, f32::NAN],
            vec![6],
        )
        .unwrap();

        let idx = x.argmax();
        // NaN comparison behavior: NaN is not > any value, so INFINITY should win
        assert_eq!(idx.get(&[0]), 4.0); // f32::INFINITY at index 4

        // Test negative values only
        let x_neg = Tensor::from_slice(&[-10.0, -5.0, -15.0, -1.0], vec![4]).unwrap();
        let idx = x_neg.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // -1.0 is the maximum at index 3
    }

    #[test]
    fn test_argmax_level3_large_dimensions() {
        // Test with one very large dimension
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| (size - 1 - i) as f32).collect(); // Decreasing values
        let x = Tensor::from_slice(&data, vec![size]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0); // First element has max value (size-1)

        // Test with multiple dimensions where one is large
        let data2: Vec<f32> = (0..(10 * 100)).map(|i| i as f32).collect();
        let x2 = Tensor::from_slice(&data2, vec![10, 100]).unwrap();

        let idx = x2.argmax();
        assert_eq!(idx.get(&[0]), 999.0); // Last element has max value

        // Test argmax along the large dimension
        let idx_dim1 = x2.argmax_dim(1, false);
        assert_eq!(idx_dim1.shape().dims(), vec![10]);
        // Each row's max should be at index 99 (last column)
        for i in 0..10 {
            assert_eq!(idx_dim1.get(&[i]), 99.0);
        }
    }

    #[test]
    fn test_argmax_level3_consistency_with_pytorch_behavior() {
        // Test specific patterns that should match PyTorch exactly

        // Pattern 1: 3D tensor, reduce middle dimension
        let x = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, // [0, 0, :]
                5.0, 6.0, 7.0, 8.0, // [0, 1, :]
                9.0, 8.0, 7.0, 6.0, // [1, 0, :]
                5.0, 4.0, 3.0, 2.0, // [1, 1, :]
            ],
            vec![2, 2, 4],
        )
        .unwrap();

        // Reduce along dim=1 (middle dimension)
        let idx = x.argmax_dim(1, true);
        assert_eq!(idx.shape().dims(), vec![2, 1, 4]);

        // For [0, :, j] where j=0,1,2,3: values are [1,5], [2,6], [3,7], [4,8]
        // Max indices should be [1,1,1,1] (second slice wins)
        assert_eq!(idx.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx.get(&[0, 0, 1]), 1.0);
        assert_eq!(idx.get(&[0, 0, 2]), 1.0);
        assert_eq!(idx.get(&[0, 0, 3]), 1.0);

        // For [1, :, j] where j=0,1,2,3: values are [9,5], [8,4], [7,3], [6,2]
        // Max indices should be [0,0,0,0] (first slice wins)
        assert_eq!(idx.get(&[1, 0, 0]), 0.0);
        assert_eq!(idx.get(&[1, 0, 1]), 0.0);
        assert_eq!(idx.get(&[1, 0, 2]), 0.0);
        assert_eq!(idx.get(&[1, 0, 3]), 0.0);
    }
}