// train_station/tensor/reductions/argmax.rs
//! Argmax reduction operations for tensors
//!
//! This module provides argmax operations that find the indices of maximum values
//! in tensors. These operations are non-differentiable and never require gradients.
//!
//! # Operations
//!
//! * `argmax()` - Find the flat index of the maximum value across all elements
//! * `argmax_dim()` - Find the indices of maximum values along a specific dimension
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
//! let max_idx = tensor.argmax();
//! assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has the maximum value 5.0
//! ```

21use crate::tensor::core::Tensor;
22
23impl Tensor {
24 /// Returns the index of the maximum value across all elements in the tensor
25 ///
26 /// This operation finds the flat index (0-based) of the element with the highest value.
27 /// If multiple elements have the same maximum value, the index of the first occurrence
28 /// is returned. The output is a scalar tensor with shape \[1\] containing the index as a float.
29 ///
30 /// This operation is non-differentiable and the output never requires gradients.
31 ///
32 /// # Returns
33 ///
34 /// A tensor with shape \[1\] containing the flat index of the maximum value
35 ///
36 /// # Examples
37 ///
38 /// ```
39 /// use train_station::Tensor;
40 ///
41 /// // 1D tensor
42 /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
43 /// let max_idx = tensor.argmax();
44 /// assert_eq!(max_idx.shape().dims, vec![1]);
45 /// assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
46 /// ```
47 ///
48 /// ```
49 /// use train_station::Tensor;
50 ///
51 /// // 2D tensor
52 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
53 /// let max_idx = tensor.argmax();
54 /// assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
55 /// ```
56 ///
57 /// ```
58 /// use train_station::Tensor;
59 ///
60 /// // Tied values return first occurrence
61 /// let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
62 /// let max_idx = tensor.argmax();
63 /// assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
64 /// ```
65 pub fn argmax(&self) -> Tensor {
66 let mut out = Tensor::new(vec![1]);
67 if self.size() == 0 {
68 out.fill(0.0);
69 return out;
70 }
71
72 let mut best_val = f32::NEG_INFINITY;
73 let mut best_idx = 0usize;
74
75 if self.is_contiguous() {
76 // Fast path for contiguous tensors
77 unsafe {
78 let src = self.as_ptr();
79 for i in 0..self.size() {
80 let v = *src.add(i);
81 if v > best_val {
82 best_val = v;
83 best_idx = i;
84 }
85 }
86 }
87 } else {
88 // Stride-aware path for non-contiguous tensors
89 let dims = self.shape().dims.clone();
90 for flat_idx in 0..self.size() {
91 // Convert flat index to multi-dimensional coordinates
92 let mut coords = vec![0; dims.len()];
93 let mut tmp = flat_idx;
94 for k in (0..dims.len()).rev() {
95 coords[k] = tmp % dims[k];
96 tmp /= dims[k];
97 }
98
99 // Get value using stride-aware offset
100 let offset = self.shape().offset(&coords);
101 let v = unsafe { *self.as_ptr().add(offset) };
102 if v > best_val {
103 best_val = v;
104 best_idx = flat_idx;
105 }
106 }
107 }
108
109 unsafe {
110 *out.as_mut_ptr() = best_idx as f32;
111 }
112 out
113 }
114
115 /// Returns the indices of maximum values along a specified dimension
116 ///
117 /// This operation finds the indices of maximum values along the specified dimension.
118 /// For each slice along the dimension, it returns the index of the maximum value.
119 /// If multiple elements have the same maximum value, the index of the first occurrence
120 /// is returned.
121 ///
122 /// The output shape depends on the `keepdim` parameter:
123 /// * If `keepdim` is `true`, the reduced dimension is kept with size 1
124 /// * If `keepdim` is `false`, the reduced dimension is removed
125 ///
126 /// This operation is non-differentiable and the output never requires gradients.
127 ///
128 /// # Arguments
129 ///
130 /// * `dim` - The dimension along which to find argmax indices (0-based)
131 /// * `keepdim` - Whether to keep the reduced dimension with size 1
132 ///
133 /// # Returns
134 ///
135 /// A tensor containing the indices of maximum values along the specified dimension
136 ///
137 /// # Panics
138 ///
139 /// Panics if `dim` is out of bounds for the tensor's rank or if the dimension size is 0.
140 ///
141 /// # Examples
142 ///
143 /// ```
144 /// use train_station::Tensor;
145 ///
146 /// // 2D tensor: [[1.0, 3.0, 2.0],
147 /// // [4.0, 0.0, 5.0]]
148 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
149 ///
150 /// // argmax along columns (dim=1)
151 /// let col_max_idx = tensor.argmax_dim(1, false);
152 /// assert_eq!(col_max_idx.shape().dims, vec![2]);
153 /// assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
154 /// assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
155 /// ```
156 ///
157 /// ```
158 /// use train_station::Tensor;
159 ///
160 /// // argmax along rows (dim=0) with keepdim
161 /// let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
162 /// let row_max_idx = tensor.argmax_dim(0, true);
163 /// assert_eq!(row_max_idx.shape().dims, vec![1, 3]);
164 /// assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
165 /// assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
166 /// assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
167 /// ```
168 ///
169 /// ```
170 /// use train_station::Tensor;
171 ///
172 /// // 1D tensor edge case
173 /// let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
174 /// let max_idx = tensor.argmax_dim(0, false);
175 /// assert_eq!(max_idx.shape().dims, vec![1]); // Special case: becomes [1] not []
176 /// assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
177 /// ```
178 pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor {
179 let rank = self.shape().rank();
180 assert!(
181 dim < rank,
182 "argmax_dim dim {} out of bounds for rank {}",
183 dim,
184 rank
185 );
186
187 let in_dims = self.shape().dims.clone();
188 let reduce_size = in_dims[dim];
189 assert!(reduce_size > 0, "cannot argmax over empty dimension");
190
191 // Build output shape
192 let mut out_dims = in_dims.clone();
193 if keepdim {
194 out_dims[dim] = 1;
195 } else {
196 out_dims.remove(dim);
197 }
198 if out_dims.is_empty() {
199 out_dims.push(1);
200 }
201
202 let mut out = Tensor::zeros(out_dims.clone());
203
204 // Use stride-aware approach to handle non-contiguous tensors correctly
205 let out_size = out.size();
206
207 unsafe {
208 let dst = out.as_mut_ptr();
209
210 // Iterate over all output positions
211 for out_idx in 0..out_size {
212 // Convert flat output index to multi-dimensional coordinates
213 let mut out_coords = vec![0; out_dims.len()];
214 let mut tmp = out_idx;
215 for k in (0..out_dims.len()).rev() {
216 out_coords[k] = tmp % out_dims[k];
217 tmp /= out_dims[k];
218 }
219
220 // Convert output coordinates to input coordinates
221 let mut in_coords = vec![0; rank];
222 if keepdim {
223 // When keepdim=true, output coords map directly to input coords
224 for k in 0..rank {
225 if k == dim {
226 in_coords[k] = 0; // Will be set in the loop below
227 } else {
228 in_coords[k] = out_coords[k];
229 }
230 }
231 } else {
232 // When keepdim=false, we need to insert the missing dimension
233 let mut out_coord_idx = 0;
234 for (k, in_coord) in in_coords.iter_mut().enumerate().take(rank) {
235 if k == dim {
236 *in_coord = 0; // Will be set in the loop below
237 } else {
238 *in_coord = out_coords[out_coord_idx];
239 out_coord_idx += 1;
240 }
241 }
242 }
243
244 // Find the argmax along the specified dimension
245 let mut best_val = f32::NEG_INFINITY;
246 let mut best_j = 0usize;
247
248 for j in 0..reduce_size {
249 in_coords[dim] = j;
250 let in_offset = self.shape().offset(&in_coords);
251 let v = *self.as_ptr().add(in_offset);
252 if v > best_val {
253 best_val = v;
254 best_j = j;
255 }
256 }
257
258 *dst.add(out_idx) = best_j as f32;
259 }
260 }
261
262 out
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    // Tests are organized in three levels:
    //  * Level 1 — contiguous tensors, basic shapes
    //  * Level 2 — non-contiguous views (transpose, select, permute)
    //  * Level 3 — higher ranks, ties, extreme values, larger sizes

    // ====== LEVEL 1: Basic functionality tests for contiguous tensors ======

    #[test]
    fn test_argmax_level1_basic_1d() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();

        // Check output shape: argmax always returns a shape-[1] tensor
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.size(), 1);

        // Check result
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value 5.0
    }

    #[test]
    fn test_argmax_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values (should return first occurrence)
        let x = Tensor::from_slice(&[3.0, 3.0, 3.0], vec![3]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0], vec![4]).unwrap();
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value -1.0
    }

    #[test]
    fn test_argmax_level1_basic_2d_contiguous() {
        // Test argmax over all elements for 2D tensor
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        let idx = x.argmax();

        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 has value 5.0
    }

    #[test]
    fn test_argmax_level1_dim_2d_basic() {
        // Test argmax_dim for simple 2D case
        // Data: [[1.0, 3.0, 2.0],
        //        [4.0, 0.0, 5.0]]
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();

        // argmax along dim=1 (along columns within each row)
        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1]);
        assert_eq!(idx1_keepdim.get(&[0, 0]), 1.0); // row 0: max at index 1 (value 3.0)
        assert_eq!(idx1_keepdim.get(&[1, 0]), 2.0); // row 1: max at index 2 (value 5.0)

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2]);
        assert_eq!(idx1_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx1_no_keepdim.get(&[1]), 2.0);

        // argmax along dim=0 (along rows within each column)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3]);
        assert_eq!(idx0_keepdim.get(&[0, 0]), 1.0); // col 0: max at index 1 (value 4.0)
        assert_eq!(idx0_keepdim.get(&[0, 1]), 0.0); // col 1: max at index 0 (value 3.0)
        assert_eq!(idx0_keepdim.get(&[0, 2]), 1.0); // col 2: max at index 1 (value 5.0)

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3]);
        assert_eq!(idx0_no_keepdim.get(&[0]), 1.0);
        assert_eq!(idx0_no_keepdim.get(&[1]), 0.0);
        assert_eq!(idx0_no_keepdim.get(&[2]), 1.0);
    }

    #[test]
    fn test_argmax_level1_3d_basic() {
        // Test 3D tensor: shape [2, 2, 2]
        // Data: [[[1.0, 2.0], [3.0, 4.0]],
        //        [[5.0, 6.0], [7.0, 8.0]]]
        let x =
            Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], vec![2, 2, 2]).unwrap();

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 7.0); // flat index 7 has value 8.0

        // argmax along dim=2 (innermost dimension)
        let idx2 = x.argmax_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, 2.0] -> max at index 1
        assert_eq!(idx2.get(&[0, 1]), 1.0); // [3.0, 4.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 0]), 1.0); // [5.0, 6.0] -> max at index 1
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, 8.0] -> max at index 1
    }

    // ====== LEVEL 2: Non-contiguous tensors (views, permuted) ======

    #[test]
    fn test_argmax_level2_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, 5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, 5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmax on transposed view
        // (flat index is logical, i.e. relative to the transposed layout)
        let idx = x_t.argmax();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value 5.0

        // Test argmax along dim=0 of transposed tensor
        let idx0 = x_t.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [1.0, 3.0, 2.0] -> max 3.0 at index 1
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, 5.0] -> max 5.0 at index 2

        // Test argmax along dim=1 of transposed tensor
        let idx1 = x_t.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [1.0, 4.0] -> max 4.0 at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [3.0, 0.0] -> max 3.0 at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, 5.0] -> max 5.0 at index 1
    }

    #[test]
    fn test_argmax_level2_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, 6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, 6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 8.0

        // Test argmax_dim on 1D slice (should work the same as global argmax)
        let idx_dim = middle_row.argmax_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 3.0);
    }

    #[test]
    fn test_argmax_level2_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 0 to 23

        // Permute to [4, 3, 2] (reverse the dimension order)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmax should still find the maximum value (23);
        // in the permuted layout its logical flat index happens to be 23 as well
        let idx = x_perm.argmax();
        assert_eq!(idx.get(&[0]), 23.0);

        // Test argmax along each dimension of permuted tensor (shape checks only)
        let idx0 = x_perm.argmax_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmax_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmax_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);
    }

    #[test]
    fn test_argmax_level2_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, 8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // index 3 has value 11.0
    }

    // ====== LEVEL 3: Complex multi-dimensional cases and edge scenarios ======

    #[test]
    fn test_argmax_level3_4d_tensor() {
        // Test 4D tensor with various reduction dimensions
        let data = (0..120).map(|i| i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();
        // Shape [2, 3, 4, 5] with values 0 to 119

        // Global argmax
        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 119.0); // Maximum value 119.0 at flat index 119

        // Test argmax along each dimension (keepdim and not)
        let idx0_keepdim = x.argmax_dim(0, true);
        assert_eq!(idx0_keepdim.shape().dims, vec![1, 3, 4, 5]);

        let idx0_no_keepdim = x.argmax_dim(0, false);
        assert_eq!(idx0_no_keepdim.shape().dims, vec![3, 4, 5]);

        let idx1_keepdim = x.argmax_dim(1, true);
        assert_eq!(idx1_keepdim.shape().dims, vec![2, 1, 4, 5]);

        let idx1_no_keepdim = x.argmax_dim(1, false);
        assert_eq!(idx1_no_keepdim.shape().dims, vec![2, 4, 5]);

        let idx2_keepdim = x.argmax_dim(2, true);
        assert_eq!(idx2_keepdim.shape().dims, vec![2, 3, 1, 5]);

        let idx2_no_keepdim = x.argmax_dim(2, false);
        assert_eq!(idx2_no_keepdim.shape().dims, vec![2, 3, 5]);

        let idx3_keepdim = x.argmax_dim(3, true);
        assert_eq!(idx3_keepdim.shape().dims, vec![2, 3, 4, 1]);

        let idx3_no_keepdim = x.argmax_dim(3, false);
        assert_eq!(idx3_no_keepdim.shape().dims, vec![2, 3, 4]);

        // Check some specific values for the innermost dimension (dim=3)
        // Values increase with flat index, so for each [i, j, k, :] slice the
        // argmax is 4 (the last index of the size-5 dimension)
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(idx3_no_keepdim.get(&[i, j, k]), 4.0);
                    assert_eq!(idx3_keepdim.get(&[i, j, k, 0]), 4.0);
                }
            }
        }
    }

    #[test]
    fn test_argmax_level3_edge_cases_keepdim() {
        // Test edge case: 1D tensor with keepdim
        let x1d = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();

        let idx_keepdim = x1d.argmax_dim(0, true);
        assert_eq!(idx_keepdim.shape().dims, vec![1]);
        assert_eq!(idx_keepdim.get(&[0]), 2.0);

        let idx_no_keepdim = x1d.argmax_dim(0, false);
        assert_eq!(idx_no_keepdim.shape().dims, vec![1]); // Special case: becomes [1] not []
        assert_eq!(idx_no_keepdim.get(&[0]), 2.0);

        // Test edge case: dimension of size 1
        let x_size_1 = Tensor::from_slice(&[42.0], vec![1]).unwrap();

        let idx = x_size_1.argmax_dim(0, true);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);

        let idx = x_size_1.argmax_dim(0, false);
        assert_eq!(idx.shape().dims, vec![1]);
        assert_eq!(idx.get(&[0]), 0.0);
    }

    #[test]
    fn test_argmax_level3_ties_handling() {
        // Test that tied values return the first occurrence (PyTorch behavior)
        let x = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 5.0], vec![5]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of max value 5.0

        // Test with 2D ties
        let x2d = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0, 1.0, 5.0], vec![3, 2]).unwrap();

        // argmax along dim=0 (columns)
        let idx0 = x2d.argmax_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 1.0); // col 0: [3, 5, 1] -> first 5 at index 1
        assert_eq!(idx0.get(&[1]), 0.0); // col 1: [5, 2, 5] -> first 5 at index 0

        // argmax along dim=1 (rows)
        let idx1 = x2d.argmax_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 1.0); // row 0: [3, 5] -> max at index 1
        assert_eq!(idx1.get(&[1]), 0.0); // row 1: [5, 2] -> max at index 0
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [1, 5] -> max at index 1
    }

    #[test]
    fn test_argmax_level3_extreme_values() {
        // Test with extreme floating point values
        let x = Tensor::from_slice(
            &[f32::NEG_INFINITY, -1e10, 0.0, 1e10, f32::INFINITY, f32::NAN],
            vec![6],
        )
        .unwrap();

        let idx = x.argmax();
        // NaN comparison behavior: NaN is not > any value, so INFINITY should win
        assert_eq!(idx.get(&[0]), 4.0); // f32::INFINITY at index 4

        // Test negative values only
        let x_neg = Tensor::from_slice(&[-10.0, -5.0, -15.0, -1.0], vec![4]).unwrap();
        let idx = x_neg.argmax();
        assert_eq!(idx.get(&[0]), 3.0); // -1.0 is the maximum at index 3
    }

    #[test]
    fn test_argmax_level3_large_dimensions() {
        // Test with one very large dimension
        let size = 1000;
        let data: Vec<f32> = (0..size).map(|i| (size - 1 - i) as f32).collect(); // Decreasing values
        let x = Tensor::from_slice(&data, vec![size]).unwrap();

        let idx = x.argmax();
        assert_eq!(idx.get(&[0]), 0.0); // First element has max value (size-1)

        // Test with multiple dimensions where one is large
        let data2: Vec<f32> = (0..(10 * 100)).map(|i| i as f32).collect();
        let x2 = Tensor::from_slice(&data2, vec![10, 100]).unwrap();

        let idx = x2.argmax();
        assert_eq!(idx.get(&[0]), 999.0); // Last element has max value

        // Test argmax along the large dimension
        let idx_dim1 = x2.argmax_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![10]);
        // Each row's max should be at index 99 (last column)
        for i in 0..10 {
            assert_eq!(idx_dim1.get(&[i]), 99.0);
        }
    }

    #[test]
    fn test_argmax_level3_consistency_with_pytorch_behavior() {
        // Test specific patterns that should match PyTorch exactly

        // Pattern 1: 3D tensor, reduce middle dimension
        let x = Tensor::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, // [0, 0, :]
                5.0, 6.0, 7.0, 8.0, // [0, 1, :]
                9.0, 8.0, 7.0, 6.0, // [1, 0, :]
                5.0, 4.0, 3.0, 2.0, // [1, 1, :]
            ],
            vec![2, 2, 4],
        )
        .unwrap();

        // Reduce along dim=1 (middle dimension)
        let idx = x.argmax_dim(1, true);
        assert_eq!(idx.shape().dims, vec![2, 1, 4]);

        // For [0, :, j] where j=0,1,2,3: values are [1,5], [2,6], [3,7], [4,8]
        // Max indices should be [1,1,1,1] (second slice wins)
        assert_eq!(idx.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx.get(&[0, 0, 1]), 1.0);
        assert_eq!(idx.get(&[0, 0, 2]), 1.0);
        assert_eq!(idx.get(&[0, 0, 3]), 1.0);

        // For [1, :, j] where j=0,1,2,3: values are [9,5], [8,4], [7,3], [6,2]
        // Max indices should be [0,0,0,0] (first slice wins)
        assert_eq!(idx.get(&[1, 0, 0]), 0.0);
        assert_eq!(idx.get(&[1, 0, 1]), 0.0);
        assert_eq!(idx.get(&[1, 0, 2]), 0.0);
        assert_eq!(idx.get(&[1, 0, 3]), 0.0);
    }
}
650}