train_station/tensor/reductions/
argmax.rs

1//! Argmax reduction operations for tensors
2//!
3//! This module provides argmax operations that find the indices of maximum values
4//! in tensors. These operations are non-differentiable and never require gradients.
5//!
6//! # Operations
7//!
8//! * `argmax()` - Find the index of the maximum value across all elements
9//! * `argmax_dim()` - Find the indices of maximum values along a specific dimension
10//!
11//! # Examples
12//!
13//! ```
14//! use train_station::Tensor;
15//!
16//! let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
17//! let max_idx = tensor.argmax();
18//! assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has the maximum value 5.0
19//! ```
20
21use crate::tensor::core::Tensor;
22
23impl Tensor {
24    /// Returns the index of the maximum value across all elements in the tensor
25    ///
26    /// This operation finds the flat index (0-based) of the element with the highest value.
27    /// If multiple elements have the same maximum value, the index of the first occurrence
28    /// is returned. The output is a scalar tensor with shape \[1\] containing the index as a float.
29    ///
30    /// This operation is non-differentiable and the output never requires gradients.
31    ///
32    /// # Returns
33    ///
34    /// A tensor with shape \[1\] containing the flat index of the maximum value
35    ///
36    /// # Examples
37    ///
38    /// ```
39    /// use train_station::Tensor;
40    ///
41    /// // 1D tensor
42    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
43    /// let max_idx = tensor.argmax();
44    /// assert_eq!(max_idx.shape().dims, vec![1]);
45    /// assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
46    /// ```
47    ///
48    /// ```
49    /// use train_station::Tensor;
50    ///
51    /// // 2D tensor
52    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
53    /// let max_idx = tensor.argmax();
54    /// assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
55    /// ```
56    ///
57    /// ```
58    /// use train_station::Tensor;
59    ///
60    /// // Tied values return first occurrence
61    /// let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
62    /// let max_idx = tensor.argmax();
63    /// assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
64    /// ```
65    pub fn argmax(&self) -> Tensor {
66        let mut out = Tensor::new(vec![1]);
67        if self.size() == 0 {
68            out.fill(0.0);
69            return out;
70        }
71
72        let mut best_val = f32::NEG_INFINITY;
73        let mut best_idx = 0usize;
74
75        if self.is_contiguous() {
76            // Fast path for contiguous tensors
77            unsafe {
78                let src = self.as_ptr();
79                for i in 0..self.size() {
80                    let v = *src.add(i);
81                    if v > best_val {
82                        best_val = v;
83                        best_idx = i;
84                    }
85                }
86            }
87        } else {
88            // Stride-aware path for non-contiguous tensors
89            let dims = self.shape().dims.clone();
90            for flat_idx in 0..self.size() {
91                // Convert flat index to multi-dimensional coordinates
92                let mut coords = vec![0; dims.len()];
93                let mut tmp = flat_idx;
94                for k in (0..dims.len()).rev() {
95                    coords[k] = tmp % dims[k];
96                    tmp /= dims[k];
97                }
98
99                // Get value using stride-aware offset
100                let offset = self.shape().offset(&coords);
101                let v = unsafe { *self.as_ptr().add(offset) };
102                if v > best_val {
103                    best_val = v;
104                    best_idx = flat_idx;
105                }
106            }
107        }
108
109        unsafe {
110            *out.as_mut_ptr() = best_idx as f32;
111        }
112        out
113    }
114
115    /// Returns the indices of maximum values along a specified dimension
116    ///
117    /// This operation finds the indices of maximum values along the specified dimension.
118    /// For each slice along the dimension, it returns the index of the maximum value.
119    /// If multiple elements have the same maximum value, the index of the first occurrence
120    /// is returned.
121    ///
122    /// The output shape depends on the `keepdim` parameter:
123    /// * If `keepdim` is `true`, the reduced dimension is kept with size 1
124    /// * If `keepdim` is `false`, the reduced dimension is removed
125    ///
126    /// This operation is non-differentiable and the output never requires gradients.
127    ///
128    /// # Arguments
129    ///
130    /// * `dim` - The dimension along which to find argmax indices (0-based)
131    /// * `keepdim` - Whether to keep the reduced dimension with size 1
132    ///
133    /// # Returns
134    ///
135    /// A tensor containing the indices of maximum values along the specified dimension
136    ///
137    /// # Panics
138    ///
139    /// Panics if `dim` is out of bounds for the tensor's rank or if the dimension size is 0.
140    ///
141    /// # Examples
142    ///
143    /// ```
144    /// use train_station::Tensor;
145    ///
146    /// // 2D tensor: [[1.0, 3.0, 2.0],
147    /// //             [4.0, 0.0, 5.0]]
148    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
149    ///
150    /// // argmax along columns (dim=1)
151    /// let col_max_idx = tensor.argmax_dim(1, false);
152    /// assert_eq!(col_max_idx.shape().dims, vec![2]);
153    /// assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
154    /// assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
155    /// ```
156    ///
157    /// ```
158    /// use train_station::Tensor;
159    ///
160    /// // argmax along rows (dim=0) with keepdim
161    /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
162    /// let row_max_idx = tensor.argmax_dim(0, true);
163    /// assert_eq!(row_max_idx.shape().dims, vec![1, 3]);
164    /// assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
165    /// assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
166    /// assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
167    /// ```
168    ///
169    /// ```
170    /// use train_station::Tensor;
171    ///
172    /// // 1D tensor edge case
173    /// let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
174    /// let max_idx = tensor.argmax_dim(0, false);
175    /// assert_eq!(max_idx.shape().dims, vec![1]); // Special case: becomes [1] not []
176    /// assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
177    /// ```
178    pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor {
179        let rank = self.shape().rank();
180        assert!(
181            dim < rank,
182            "argmax_dim dim {} out of bounds for rank {}",
183            dim,
184            rank
185        );
186
187        let in_dims = self.shape().dims.clone();
188        let reduce_size = in_dims[dim];
189        assert!(reduce_size > 0, "cannot argmax over empty dimension");
190
191        // Build output shape
192        let mut out_dims = in_dims.clone();
193        if keepdim {
194            out_dims[dim] = 1;
195        } else {
196            out_dims.remove(dim);
197        }
198        if out_dims.is_empty() {
199            out_dims.push(1);
200        }
201
202        let mut out = Tensor::zeros(out_dims.clone());
203
204        // Use stride-aware approach to handle non-contiguous tensors correctly
205        let out_size = out.size();
206
207        unsafe {
208            let dst = out.as_mut_ptr();
209
210            // Iterate over all output positions
211            for out_idx in 0..out_size {
212                // Convert flat output index to multi-dimensional coordinates
213                let mut out_coords = vec![0; out_dims.len()];
214                let mut tmp = out_idx;
215                for k in (0..out_dims.len()).rev() {
216                    out_coords[k] = tmp % out_dims[k];
217                    tmp /= out_dims[k];
218                }
219
220                // Convert output coordinates to input coordinates
221                let mut in_coords = vec![0; rank];
222                if keepdim {
223                    // When keepdim=true, output coords map directly to input coords
224                    for k in 0..rank {
225                        if k == dim {
226                            in_coords[k] = 0; // Will be set in the loop below
227                        } else {
228                            in_coords[k] = out_coords[k];
229                        }
230                    }
231                } else {
232                    // When keepdim=false, we need to insert the missing dimension
233                    let mut out_coord_idx = 0;
234                    for (k, in_coord) in in_coords.iter_mut().enumerate().take(rank) {
235                        if k == dim {
236                            *in_coord = 0; // Will be set in the loop below
237                        } else {
238                            *in_coord = out_coords[out_coord_idx];
239                            out_coord_idx += 1;
240                        }
241                    }
242                }
243
244                // Find the argmax along the specified dimension
245                let mut best_val = f32::NEG_INFINITY;
246                let mut best_j = 0usize;
247
248                for j in 0..reduce_size {
249                    in_coords[dim] = j;
250                    let in_offset = self.shape().offset(&in_coords);
251                    let v = *self.as_ptr().add(in_offset);
252                    if v > best_val {
253                        best_val = v;
254                        best_j = j;
255                    }
256                }
257
258                *dst.add(out_idx) = best_j as f32;
259            }
260        }
261
262        out
263    }
264}
265
// Unit tests organized in three levels: basic contiguous behavior, view/stride
// handling (transpose, select, permute), and complex/edge scenarios.
#[cfg(test)]
mod tests {
    use super::*;

    // ====== LEVEL 1: Basic functionality tests for contiguous tensors ======

    #[test]
    fn test_argmax_level1_basic_1d() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();

        // Check output shape
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.size(), 1);

        // Check result
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value 5.0
    }

    #[test]
    fn test_argmax_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values (should return first occurrence)
        let x = Tensor::from_slice(&[3.0, 3.0, 3.0], vec![3]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value -1.0
    }

    #[test]
    fn test_argmax_level1_basic_2d_contiguous() {
        // Test argmax over all elements for 2D tensor
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        let idx = x.argmax();

        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 has value 5.0
    }

    #[test]
    fn test_argmax_level1_dim_2d_basic() {
        // Test argmax_dim for simple 2D case
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();

        // argmax along dim=1 (along columns within each row)
        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1]);
        assert_eq!(idx1_keepdim.get(&[0, 0]), 1.0); // row 0: max at index 1 (value 3.0)
        assert_eq!(idx1_keepdim.get(&[1, 0]), 2.0); // row 1: max at index 2 (value 5.0)

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2]);
        assert_eq!(idx1_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx1_no_keepdim.get(&[1]), 2.0);

        // argmax along dim=0 (along rows within each column)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3]);
        assert_eq!(idx0_keepdim.get(&[0, 0]), 1.0); // col 0: max at index 1 (value 4.0)
        assert_eq!(idx0_keepdim.get(&[0, 1]), 0.0); // col 1: max at index 0 (value 3.0)
        assert_eq!(idx0_keepdim.get(&[0, 2]), 1.0); // col 2: max at index 1 (value 5.0)

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3]);
        assert_eq!(idx0_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx0_no_keepdim.get(&[1]), 0.0);
        assert_eq!(idx0_no_keepdim.get(&[2]), 1.0);
    }

    #[test]
    fn test_argmax_level1_3d_basic() {
        // Test 3D tensor: shape [2, 2, 2]
        // Data: [[[1.0, 2.0], [3.0, 4.0]],
        //        [[5.0, 6.0], [7.0, 8.0]]]
        let x =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 7.0); // flat index 7 has value 8.0

        // argmax along dim=2 (innermost dimension)
        let idx2 = x.argmax_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, 2.0] -> max at index 1
        assert_eq!(idx2.get(&[0, 1]), 1.0); // [3.0, 4.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 0]), 1.0); // [5.0, 6.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, 8.0] -> max at index 1
    }

    // ====== LEVEL 2: Non-contiguous tensors (views, permuted) ======

    #[test]
    fn test_argmax_level2_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, 5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, 5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmax on transposed view
        let idx = x_t.argmax();
        // Logical (row-major) order of the transposed view is
        // [1.0, 4.0, 3.0, 0.0, 2.0, 5.0]; max 5.0 sits at flat index 5.
        assert_eq!(idx.get(&[0]), 5.0);

        // Test argmax along dim=0 of transposed tensor
        let idx0 = x_t.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [1.0, 3.0, 2.0] -> max 3.0 at index 1
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, 5.0] -> max 5.0 at index 2

        // Test argmax along dim=1 of transposed tensor
        let idx1 = x_t.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [1.0, 4.0] -> max 4.0 at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [3.0, 0.0] -> max 3.0 at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, 5.0] -> max 5.0 at index 1
    }

    #[test]
    fn test_argmax_level2_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, 6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, 6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 8.0

        // Test argmax_dim on 1D slice (should work the same as global argmax)
        let idx_dim = middle_row.argmax_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 3.0);
    }

    #[test]
    fn test_argmax_level2_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 0 to 23

        // Permute with order [2, 1, 0]: shape [2, 3, 4] becomes [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmax: max value 23 (original coords [1, 2, 3]) lands at
        // permuted coords [3, 2, 1] -> flat index 3*6 + 2*2 + 1 = 23.
        // (Index and value coinciding at 23 is a property of this layout.)
        let idx = x_perm.argmax();
        assert_eq!(idx.get(&[0]), 23.0);

        // Test argmax along each dimension of permuted tensor
        let idx0 = x_perm.argmax_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmax_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmax_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_argmax_level2_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, 8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 11.0
    }

    // ====== LEVEL 3: Complex multi-dimensional cases and edge scenarios ======

    #[test]
    fn test_argmax_level3_4d_tensor() {
        // Test 4D tensor with various reduction dimensions
        let data = (0..120).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();
        // Shape [2, 3, 4, 5] with values 0 to 119

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 119.0); // Maximum value 119.0 at flat index 119

        // Test argmax along each dimension
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3, 4, 5]);

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3, 4, 5]);

        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1, 4, 5]);

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2, 4, 5]);

        let idx2_keepdim = x.argmax_dim(2, true);
        assert_eq!(idx2_keepdim.shape().dims, vec![2, 3, 1, 5]);

        let idx2_no_keepdim = x.argmax_dim(2, false);
        assert_eq!(idx2_no_keepdim.shape().dims, vec![2, 3, 5]);

        let idx3_keepdim = x.argmax_dim(3, true);
        assert_eq!(idx3_keepdim.shape().dims, vec![2, 3, 4, 1]);

        let idx3_no_keepdim = x.argmax_dim(3, false);
        assert_eq!(idx3_no_keepdim.shape().dims, vec![2, 3, 4]);

        // Check some specific values for the innermost dimension (dim=3)
        // For each [i, j, k, :] slice, argmax should be 4 (index of max in size-5 dimension)
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(idx3_no_keepdim.get(&[i, j, k]), 4.0);
                    assert_eq!(idx3_keepdim.get(&[i, j, k, 0]), 4.0);
                }
            }
        }
    }

    #[test]
    fn test_argmax_level3_edge_cases_keepdim() {
        // Test edge case: 1D tensor with keepdim
        let x1d = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();

        let idx_keepdim = x1d.argmax_dim(0, true);
        assert_eq!(idx_keepdim.shape().dims, vec![1]);
        assert_eq!(idx_keepdim.get(&[0]), 2.0);

        let idx_no_keepdim = x1d.argmax_dim(0, false);
        assert_eq!(idx_no_keepdim.shape().dims, vec![1]); // Special case: becomes [1] not []
        assert_eq!(idx_no_keepdim.get(&[0]), 2.0);

        // Test edge case: dimension of size 1
        let x_size_1 = Tensor::from_slice(&[42.0], vec![1]).unwrap();

        let idx = x_size_1.argmax_dim(0, true);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);

        let idx = x_size_1.argmax_dim(0, false);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);
    }

    #[test]
    fn test_argmax_level3_ties_handling() {
        // Test that tied values return the first occurrence (PyTorch behavior)
        let x = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 5.0], vec![5]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of max value 5.0

        // Test with 2D ties
        let x2d = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 1.0, 5.0], vec![3, 2]).unwrap();

        // argmax along dim=0 (columns)
        let idx0 = x2d.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [3, 5, 1] -> first 5 at index 1
        assert_eq!(idx0.get(&[1]), 0.0); // col 1: [5, 2, 5] -> first 5 at index 0

        // argmax along dim=1 (rows)
        let idx1 = x2d.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [3, 5] -> max at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [5, 2] -> max at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [1, 5] -> max at index 1
    }

    #[test]
    fn test_argmax_level3_extreme_values() {
        // Test with extreme floating point values
        let x = Tensor::from_slice(
            &[f32::NEG_INFINITY, -1e10, 0.0, 1e10, f32::INFINITY, f32::NAN],
            vec![6],
        )
        .unwrap();

        let idx = x.argmax();
        // NaN comparison behavior: NaN is not > any value, so INFINITY should win
        assert_eq!(idx.get(&[0]), 4.0); // f32::INFINITY at index 4

        // Test negative values only
        let x_neg = Tensor::from_slice(&[-10.0, -5.0, -15.0, -1.0], vec![4]).unwrap();
        let idx = x_neg.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // -1.0 is the maximum at index 3
    }

    #[test]
    fn test_argmax_level3_large_dimensions() {
        // Test with one very large dimension
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| (size - 1 - i) as f32).collect(); // Decreasing values
        let x = Tensor::from_slice(&data, vec![size]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0); // First element has max value (size-1)

        // Test with multiple dimensions where one is large
        let data2: Vec<f32> = (0..(10 * 100)).map(|i| i as f32).collect();
        let x2 = Tensor::from_slice(&data2, vec![10, 100]).unwrap();

        let idx = x2.argmax();
        assert_eq!(idx.get(&[0]), 999.0); // Last element has max value

        // Test argmax along the large dimension
        let idx_dim1 = x2.argmax_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![10]);
        // Each row's max should be at index 99 (last column)
        for i in 0..10 {
            assert_eq!(idx_dim1.get(&[i]), 99.0);
        }
    }

    #[test]
    fn test_argmax_level3_consistency_with_pytorch_behavior() {
        // Test specific patterns that should match PyTorch exactly

        // Pattern 1: 3D tensor, reduce middle dimension
        let x = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, // [0, 0, :]
                5.0, 6.0, 7.0, 8.0, // [0, 1, :]
                9.0, 8.0, 7.0, 6.0, // [1, 0, :]
                5.0, 4.0, 3.0, 2.0, // [1, 1, :]
            ],
            vec![2, 2, 4],
        )
        .unwrap();

        // Reduce along dim=1 (middle dimension)
        let idx = x.argmax_dim(1, true);
        assert_eq!(idx.shape().dims, vec![2, 1, 4]);

        // For [0, :, j] where j=0,1,2,3: values are [1,5], [2,6], [3,7], [4,8]
        // Max indices should be [1,1,1,1] (second slice wins)
        assert_eq!(idx.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx.get(&[0, 0, 1]), 1.0);
        assert_eq!(idx.get(&[0, 0, 2]), 1.0);
        assert_eq!(idx.get(&[0, 0, 3]), 1.0);

        // For [1, :, j] where j=0,1,2,3: values are [9,5], [8,4], [7,3], [6,2]
        // Max indices should be [0,0,0,0] (first slice wins)
        assert_eq!(idx.get(&[1, 0, 0]), 0.0);
        assert_eq!(idx.get(&[1, 0, 1]), 0.0);
        assert_eq!(idx.get(&[1, 0, 2]), 0.0);
        assert_eq!(idx.get(&[1, 0, 3]), 0.0);
    }
}