// train_station/tensor/reductions/argmax.rs
1//! Argmax reduction operations for tensors
2//!
3//! This module provides argmax operations that find the indices of maximum values
4//! in tensors. These operations are non-differentiable and never require gradients.
5//!
6//! # Operations
7//!
8//! * `argmax()` - Find the index of the maximum value across all elements
9//! * `argmax_dim()` - Find the indices of maximum values along a specific dimension
10//!
11//! # Examples
12//!
13//! ```
14//! use train_station::Tensor;
15//!
16//! let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
17//! let max_idx = tensor.argmax();
18//! assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has the maximum value 5.0
19//! ```
20
21use crate::tensor::core::Tensor;
22
23impl Tensor {
24 /// Returns the index of the maximum value across all elements in the tensor
25 ///
26 /// This operation finds the flat index (0-based) of the element with the highest value.
27 /// If multiple elements have the same maximum value, the index of the first occurrence
28 /// is returned. The output is a scalar tensor with shape \[1\] containing the index as a float.
29 ///
30 /// This operation is non-differentiable and the output never requires gradients.
31 ///
32 /// # Returns
33 ///
34 /// A tensor with shape \[1\] containing the flat index of the maximum value
35 ///
36 /// # Examples
37 ///
38 /// ```
39 /// use train_station::Tensor;
40 ///
41 /// // 1D tensor
42 /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
43 /// let max_idx = tensor.argmax();
44 /// assert_eq!(max_idx.shape().dims, vec![1]);
45 /// assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
46 /// ```
47 ///
48 /// ```
49 /// use train_station::Tensor;
50 ///
51 /// // 2D tensor
52 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
53 /// let max_idx = tensor.argmax();
54 /// assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
55 /// ```
56 ///
57 /// ```
58 /// use train_station::Tensor;
59 ///
60 /// // Tied values return first occurrence
61 /// let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
62 /// let max_idx = tensor.argmax();
63 /// assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
64 /// ```
65 #[track_caller]
66 pub fn argmax(&self) -> Tensor {
67 let mut out = Tensor::new(vec![1]);
68 if self.size() == 0 {
69 out.fill(0.0);
70 return out;
71 }
72
73 let mut best_val = f32::NEG_INFINITY;
74 let mut best_idx = 0usize;
75
76 if self.is_contiguous() {
77 // Fast path for contiguous tensors
78 unsafe {
79 let src = self.as_ptr();
80 for i in 0..self.size() {
81 let v = *src.add(i);
82 if v > best_val {
83 best_val = v;
84 best_idx = i;
85 }
86 }
87 }
88 } else {
89 // Stride-aware path for non-contiguous tensors
90 let dims = self.shape().dims.clone();
91 for flat_idx in 0..self.size() {
92 // Convert flat index to multi-dimensional coordinates
93 let mut coords = vec![0; dims.len()];
94 let mut tmp = flat_idx;
95 for k in (0..dims.len()).rev() {
96 coords[k] = tmp % dims[k];
97 tmp /= dims[k];
98 }
99
100 // Get value using stride-aware offset
101 let offset = self.shape().offset(&coords);
102 let v = unsafe { *self.as_ptr().add(offset) };
103 if v > best_val {
104 best_val = v;
105 best_idx = flat_idx;
106 }
107 }
108 }
109
110 unsafe {
111 *out.as_mut_ptr() = best_idx as f32;
112 }
113 out
114 }
115
116 /// Returns the indices of maximum values along a specified dimension
117 ///
118 /// This operation finds the indices of maximum values along the specified dimension.
119 /// For each slice along the dimension, it returns the index of the maximum value.
120 /// If multiple elements have the same maximum value, the index of the first occurrence
121 /// is returned.
122 ///
123 /// The output shape depends on the `keepdim` parameter:
124 /// * If `keepdim` is `true`, the reduced dimension is kept with size 1
125 /// * If `keepdim` is `false`, the reduced dimension is removed
126 ///
127 /// This operation is non-differentiable and the output never requires gradients.
128 ///
129 /// # Arguments
130 ///
131 /// * `dim` - The dimension along which to find argmax indices (0-based)
132 /// * `keepdim` - Whether to keep the reduced dimension with size 1
133 ///
134 /// # Returns
135 ///
136 /// A tensor containing the indices of maximum values along the specified dimension
137 ///
138 /// # Panics
139 ///
140 /// Panics if `dim` is out of bounds for the tensor's rank or if the dimension size is 0.
141 ///
142 /// # Examples
143 ///
144 /// ```
145 /// use train_station::Tensor;
146 ///
147 /// // 2D tensor: [[1.0, 3.0, 2.0],
148 /// // [4.0, 0.0, 5.0]]
149 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
150 ///
151 /// // argmax along columns (dim=1)
152 /// let col_max_idx = tensor.argmax_dim(1, false);
153 /// assert_eq!(col_max_idx.shape().dims, vec![2]);
154 /// assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
155 /// assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
156 /// ```
157 ///
158 /// ```
159 /// use train_station::Tensor;
160 ///
161 /// // argmax along rows (dim=0) with keepdim
162 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
163 /// let row_max_idx = tensor.argmax_dim(0, true);
164 /// assert_eq!(row_max_idx.shape().dims, vec![1, 3]);
165 /// assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
166 /// assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
167 /// assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
168 /// ```
169 ///
170 /// ```
171 /// use train_station::Tensor;
172 ///
173 /// // 1D tensor edge case
174 /// let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
175 /// let max_idx = tensor.argmax_dim(0, false);
176 /// assert_eq!(max_idx.shape().dims, vec![1]); // Special case: becomes [1] not []
177 /// assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
178 /// ```
179 #[track_caller]
180 pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor {
181 let rank = self.shape().rank();
182 assert!(
183 dim < rank,
184 "argmax_dim dim {} out of bounds for rank {}",
185 dim,
186 rank
187 );
188
189 let in_dims = self.shape().dims.clone();
190 let reduce_size = in_dims[dim];
191 assert!(reduce_size > 0, "cannot argmax over empty dimension");
192
193 // Build output shape
194 let mut out_dims = in_dims.clone();
195 if keepdim {
196 out_dims[dim] = 1;
197 } else {
198 out_dims.remove(dim);
199 }
200 if out_dims.is_empty() {
201 out_dims.push(1);
202 }
203
204 let mut out = Tensor::zeros(out_dims.clone());
205
206 // Use stride-aware approach to handle non-contiguous tensors correctly
207 let out_size = out.size();
208
209 unsafe {
210 let dst = out.as_mut_ptr();
211
212 // Iterate over all output positions
213 for out_idx in 0..out_size {
214 // Convert flat output index to multi-dimensional coordinates
215 let mut out_coords = vec![0; out_dims.len()];
216 let mut tmp = out_idx;
217 for k in (0..out_dims.len()).rev() {
218 out_coords[k] = tmp % out_dims[k];
219 tmp /= out_dims[k];
220 }
221
222 // Convert output coordinates to input coordinates
223 let mut in_coords = vec![0; rank];
224 if keepdim {
225 // When keepdim=true, output coords map directly to input coords
226 for k in 0..rank {
227 if k == dim {
228 in_coords[k] = 0; // Will be set in the loop below
229 } else {
230 in_coords[k] = out_coords[k];
231 }
232 }
233 } else {
234 // When keepdim=false, we need to insert the missing dimension
235 let mut out_coord_idx = 0;
236 for (k, in_coord) in in_coords.iter_mut().enumerate().take(rank) {
237 if k == dim {
238 *in_coord = 0; // Will be set in the loop below
239 } else {
240 *in_coord = out_coords[out_coord_idx];
241 out_coord_idx += 1;
242 }
243 }
244 }
245
246 // Find the argmax along the specified dimension
247 let mut best_val = f32::NEG_INFINITY;
248 let mut best_j = 0usize;
249
250 for j in 0..reduce_size {
251 in_coords[dim] = j;
252 let in_offset = self.shape().offset(&in_coords);
253 let v = *self.as_ptr().add(in_offset);
254 if v > best_val {
255 best_val = v;
256 best_j = j;
257 }
258 }
259
260 *dst.add(out_idx) = best_j as f32;
261 }
262 }
263
264 out
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    // ====== LEVEL 1: Basic functionality tests for contiguous tensors ======

    #[test]
    fn test_argmax_level1_basic_1d() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();

        // Check output shape: always a scalar-like [1] tensor
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.size(), 1);

        // Check result
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value 5.0
    }

    #[test]
    fn test_argmax_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values (should return first occurrence)
        let x = Tensor::from_slice(&[3.0, 3.0, 3.0], vec![3]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value -1.0
    }

    #[test]
    fn test_argmax_level1_basic_2d_contiguous() {
        // Test argmax over all elements for 2D tensor
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        let idx = x.argmax();

        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 has value 5.0
    }

    #[test]
    fn test_argmax_level1_dim_2d_basic() {
        // Test argmax_dim for simple 2D case
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();

        // argmax along dim=1 (along columns within each row)
        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1]);
        assert_eq!(idx1_keepdim.get(&[0, 0]), 1.0); // row 0: max at index 1 (value 3.0)
        assert_eq!(idx1_keepdim.get(&[1, 0]), 2.0); // row 1: max at index 2 (value 5.0)

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2]);
        assert_eq!(idx1_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx1_no_keepdim.get(&[1]), 2.0);

        // argmax along dim=0 (along rows within each column)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3]);
        assert_eq!(idx0_keepdim.get(&[0, 0]), 1.0); // col 0: max at index 1 (value 4.0)
        assert_eq!(idx0_keepdim.get(&[0, 1]), 0.0); // col 1: max at index 0 (value 3.0)
        assert_eq!(idx0_keepdim.get(&[0, 2]), 1.0); // col 2: max at index 1 (value 5.0)

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3]);
        assert_eq!(idx0_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx0_no_keepdim.get(&[1]), 0.0);
        assert_eq!(idx0_no_keepdim.get(&[2]), 1.0);
    }

    #[test]
    fn test_argmax_level1_3d_basic() {
        // Test 3D tensor: shape [2, 2, 2]
        // Data: [[[1.0, 2.0], [3.0, 4.0]],
        //        [[5.0, 6.0], [7.0, 8.0]]]
        let x =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 7.0); // flat index 7 has value 8.0

        // argmax along dim=2 (innermost dimension)
        let idx2 = x.argmax_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, 2.0] -> max at index 1
        assert_eq!(idx2.get(&[0, 1]), 1.0); // [3.0, 4.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 0]), 1.0); // [5.0, 6.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, 8.0] -> max at index 1
    }

    // ====== LEVEL 2: Non-contiguous tensors (views, permuted) ======

    #[test]
    fn test_argmax_level2_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, 5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, 5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmax on transposed view
        // Note: argmax returns the flat index in the VIEW's logical order.
        let idx = x_t.argmax();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value 5.0

        // Test argmax along dim=0 of transposed tensor
        let idx0 = x_t.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [1.0, 3.0, 2.0] -> max 3.0 at index 1
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, 5.0] -> max 5.0 at index 2

        // Test argmax along dim=1 of transposed tensor
        let idx1 = x_t.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [1.0, 4.0] -> max 4.0 at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [3.0, 0.0] -> max 3.0 at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, 5.0] -> max 5.0 at index 1
    }

    #[test]
    fn test_argmax_level2_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, 6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, 6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 8.0

        // Test argmax_dim on 1D slice (should work the same as global argmax)
        let idx_dim = middle_row.argmax_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 3.0);
    }

    #[test]
    fn test_argmax_level2_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 0 to 23

        // Permute to [4, 3, 2] (reverse the dimension order: swap dims 0 and 2)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmax should still find the maximum value (23)
        let idx = x_perm.argmax();
        // Max value 23.0 sits at original coords [1, 2, 3] -> permuted coords
        // [3, 2, 1], which is flat index 23 in the [4, 3, 2] view
        // (coincidentally equal to the max value itself).
        assert_eq!(idx.get(&[0]), 23.0);

        // Test argmax along each dimension of permuted tensor
        let idx0 = x_perm.argmax_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmax_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmax_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_argmax_level2_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, 8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 11.0
    }

    // ====== LEVEL 3: Complex multi-dimensional cases and edge scenarios ======

    #[test]
    fn test_argmax_level3_4d_tensor() {
        // Test 4D tensor with various reduction dimensions
        let data = (0..120).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();
        // Shape [2, 3, 4, 5] with values 0 to 119

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 119.0); // Maximum value 119.0 at flat index 119

        // Test argmax along each dimension (shape checks for keepdim and not)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3, 4, 5]);

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3, 4, 5]);

        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1, 4, 5]);

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2, 4, 5]);

        let idx2_keepdim = x.argmax_dim(2, true);
        assert_eq!(idx2_keepdim.shape().dims, vec![2, 3, 1, 5]);

        let idx2_no_keepdim = x.argmax_dim(2, false);
        assert_eq!(idx2_no_keepdim.shape().dims, vec![2, 3, 5]);

        let idx3_keepdim = x.argmax_dim(3, true);
        assert_eq!(idx3_keepdim.shape().dims, vec![2, 3, 4, 1]);

        let idx3_no_keepdim = x.argmax_dim(3, false);
        assert_eq!(idx3_no_keepdim.shape().dims, vec![2, 3, 4]);

        // Check some specific values for the innermost dimension (dim=3)
        // Values increase with flat index, so for each [i, j, k, :] slice the
        // max is at the last position: index 4 in the size-5 dimension
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(idx3_no_keepdim.get(&[i, j, k]), 4.0);
                    assert_eq!(idx3_keepdim.get(&[i, j, k, 0]), 4.0);
                }
            }
        }
    }

    #[test]
    fn test_argmax_level3_edge_cases_keepdim() {
        // Test edge case: 1D tensor with keepdim
        let x1d = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();

        let idx_keepdim = x1d.argmax_dim(0, true);
        assert_eq!(idx_keepdim.shape().dims, vec![1]);
        assert_eq!(idx_keepdim.get(&[0]), 2.0);

        let idx_no_keepdim = x1d.argmax_dim(0, false);
        assert_eq!(idx_no_keepdim.shape().dims, vec![1]); // Special case: becomes [1] not []
        assert_eq!(idx_no_keepdim.get(&[0]), 2.0);

        // Test edge case: dimension of size 1
        let x_size_1 = Tensor::from_slice(&[42.0], vec![1]).unwrap();

        let idx = x_size_1.argmax_dim(0, true);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);

        let idx = x_size_1.argmax_dim(0, false);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);
    }

    #[test]
    fn test_argmax_level3_ties_handling() {
        // Test that tied values return the first occurrence (PyTorch behavior)
        let x = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 5.0], vec![5]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of max value 5.0

        // Test with 2D ties
        let x2d = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 1.0, 5.0], vec![3, 2]).unwrap();

        // argmax along dim=0 (columns)
        let idx0 = x2d.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [3, 5, 1] -> first 5 at index 1
        assert_eq!(idx0.get(&[1]), 0.0); // col 1: [5, 2, 5] -> first 5 at index 0

        // argmax along dim=1 (rows)
        let idx1 = x2d.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [3, 5] -> max at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [5, 2] -> max at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [1, 5] -> max at index 1
    }

    #[test]
    fn test_argmax_level3_extreme_values() {
        // Test with extreme floating point values
        let x = Tensor::from_slice(
            &[f32::NEG_INFINITY, -1e10, 0.0, 1e10, f32::INFINITY, f32::NAN],
            vec![6],
        )
        .unwrap();

        let idx = x.argmax();
        // NaN comparison behavior: NaN is not > any value, so INFINITY should win
        assert_eq!(idx.get(&[0]), 4.0); // f32::INFINITY at index 4

        // Test negative values only
        let x_neg = Tensor::from_slice(&[-10.0, -5.0, -15.0, -1.0], vec![4]).unwrap();
        let idx = x_neg.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // -1.0 is the maximum at index 3
    }

    #[test]
    fn test_argmax_level3_large_dimensions() {
        // Test with one very large dimension
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| (size - 1 - i) as f32).collect(); // Decreasing values
        let x = Tensor::from_slice(&data, vec![size]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0); // First element has max value (size-1)

        // Test with multiple dimensions where one is large
        let data2: Vec<f32> = (0..(10 * 100)).map(|i| i as f32).collect();
        let x2 = Tensor::from_slice(&data2, vec![10, 100]).unwrap();

        let idx = x2.argmax();
        assert_eq!(idx.get(&[0]), 999.0); // Last element has max value

        // Test argmax along the large dimension
        let idx_dim1 = x2.argmax_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![10]);
        // Each row's max should be at index 99 (last column)
        for i in 0..10 {
            assert_eq!(idx_dim1.get(&[i]), 99.0);
        }
    }

    #[test]
    fn test_argmax_level3_consistency_with_pytorch_behavior() {
        // Test specific patterns that should match PyTorch exactly

        // Pattern 1: 3D tensor, reduce middle dimension
        let x = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, // [0, 0, :]
                5.0, 6.0, 7.0, 8.0, // [0, 1, :]
                9.0, 8.0, 7.0, 6.0, // [1, 0, :]
                5.0, 4.0, 3.0, 2.0, // [1, 1, :]
            ],
            vec![2, 2, 4],
        )
        .unwrap();

        // Reduce along dim=1 (middle dimension)
        let idx = x.argmax_dim(1, true);
        assert_eq!(idx.shape().dims, vec![2, 1, 4]);

        // For [0, :, j] where j=0,1,2,3: values are [1,5], [2,6], [3,7], [4,8]
        // Max indices should be [1,1,1,1] (second slice wins)
        assert_eq!(idx.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx.get(&[0, 0, 1]), 1.0);
        assert_eq!(idx.get(&[0, 0, 2]), 1.0);
        assert_eq!(idx.get(&[0, 0, 3]), 1.0);

        // For [1, :, j] where j=0,1,2,3: values are [9,5], [8,4], [7,3], [6,2]
        // Max indices should be [0,0,0,0] (first slice wins)
        assert_eq!(idx.get(&[1, 0, 0]), 0.0);
        assert_eq!(idx.get(&[1, 0, 1]), 0.0);
        assert_eq!(idx.get(&[1, 0, 2]), 0.0);
        assert_eq!(idx.get(&[1, 0, 3]), 0.0);
    }
}