train_station/tensor/reductions/
argmin.rs

1use crate::tensor::core::Tensor;
2
3impl Tensor {
4    /// Returns the index of the minimum value in the tensor
5    ///
6    /// This method finds the flat index of the minimum value across all elements
7    /// in the tensor. The result is a scalar tensor containing the index as a
8    /// floating-point value. This operation is non-differentiable and the output
9    /// never requires gradient tracking.
10    ///
11    /// # Returns
12    ///
13    /// A tensor with shape `[1]` containing the flat index of the minimum value
14    /// as a `f32`. If the input tensor is empty, returns `0.0`.
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// use train_station::Tensor;
20    ///
21    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
22    /// let min_index = tensor.argmin();
23    /// assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
24    /// ```
25    ///
26    /// ```
27    /// use train_station::Tensor;
28    ///
29    /// // Empty tensor case
30    /// let empty_tensor = Tensor::new(vec![0]);
31    /// let min_index = empty_tensor.argmin();
32    /// assert_eq!(min_index.get(&[0]), 0.0);
33    /// ```
34    pub fn argmin(&self) -> Tensor {
35        let mut out = Tensor::new(vec![1]);
36        if self.size() == 0 {
37            out.fill(0.0);
38            return out;
39        }
40
41        let mut best_val = f32::INFINITY;
42        let mut best_idx = 0usize;
43
44        if self.is_contiguous() {
45            // Fast path for contiguous tensors
46            unsafe {
47                let src = self.as_ptr();
48                for i in 0..self.size() {
49                    let v = *src.add(i);
50                    if v < best_val {
51                        best_val = v;
52                        best_idx = i;
53                    }
54                }
55            }
56        } else {
57            // Stride-aware path for non-contiguous tensors
58            let dims = self.shape().dims.clone();
59            for flat_idx in 0..self.size() {
60                // Convert flat index to multi-dimensional coordinates
61                let mut coords = vec![0; dims.len()];
62                let mut tmp = flat_idx;
63                for k in (0..dims.len()).rev() {
64                    coords[k] = tmp % dims[k];
65                    tmp /= dims[k];
66                }
67
68                // Get value using stride-aware offset
69                let offset = self.shape().offset(&coords);
70                let v = unsafe { *self.as_ptr().add(offset) };
71                if v < best_val {
72                    best_val = v;
73                    best_idx = flat_idx;
74                }
75            }
76        }
77
78        unsafe {
79            *out.as_mut_ptr() = best_idx as f32;
80        }
81        out
82    }
83
84    /// Returns the indices of minimum values along a specified dimension
85    ///
86    /// This method finds the indices of minimum values along the specified dimension.
87    /// The result contains the indices where the minimum values occur in that dimension.
88    /// This operation is non-differentiable and the output never requires gradient tracking.
89    ///
90    /// # Arguments
91    ///
92    /// * `dim` - The dimension along which to find minimum indices (0-based)
93    /// * `keepdim` - Whether to keep the reduced dimension in the output shape
94    ///   - If `true`, the reduced dimension is kept with size 1
95    ///   - If `false`, the reduced dimension is removed from the output shape
96    ///
97    /// # Returns
98    ///
99    /// A tensor containing the indices of minimum values along the specified dimension.
100    /// The output shape depends on `keepdim`:
101    /// - If `keepdim` is `true`, the reduced dimension has size 1
102    /// - If `keepdim` is `false`, the reduced dimension is removed
103    ///
104    /// # Panics
105    ///
106    /// * If `dim` is out of bounds for the tensor's rank
107    /// * If the dimension to reduce has size 0
108    ///
109    /// # Examples
110    ///
111    /// ```
112    /// use train_station::Tensor;
113    ///
114    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
115    ///
116    /// // Find minimum indices along dimension 1 (columns), keeping the dimension
117    /// let indices = tensor.argmin_dim(1, true);
118    /// assert_eq!(indices.shape().dims, vec![2, 1]);
119    /// assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
120    /// assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
121    /// ```
122    ///
123    /// ```
124    /// use train_station::Tensor;
125    ///
126    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
127    ///
128    /// // Find minimum indices along dimension 1 (columns), removing the dimension
129    /// let indices = tensor.argmin_dim(1, false);
130    /// assert_eq!(indices.shape().dims, vec![2]);
131    /// assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
132    /// assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
133    /// ```
134    ///
135    /// ```
136    /// use train_station::Tensor;
137    ///
138    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
139    ///
140    /// // Find minimum index in a 1D tensor
141    /// let index = tensor.argmin_dim(0, false);
142    /// assert_eq!(index.shape().dims, vec![1]);
143    /// assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
144    /// ```
145    pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor {
146        let rank = self.shape().rank();
147        assert!(
148            dim < rank,
149            "argmin_dim dim {} out of bounds for rank {}",
150            dim,
151            rank
152        );
153
154        let in_dims = self.shape().dims.clone();
155        let reduce_size = in_dims[dim];
156        assert!(reduce_size > 0, "cannot argmin over empty dimension");
157
158        // Build output shape
159        let mut out_dims = in_dims.clone();
160        if keepdim {
161            out_dims[dim] = 1;
162        } else {
163            out_dims.remove(dim);
164        }
165        if out_dims.is_empty() {
166            out_dims.push(1);
167        }
168
169        let mut out = Tensor::zeros(out_dims.clone());
170
171        // Use stride-aware approach to handle non-contiguous tensors correctly
172        let out_size = out.size();
173
174        unsafe {
175            let dst = out.as_mut_ptr();
176
177            // Iterate over all output positions
178            for out_idx in 0..out_size {
179                // Convert flat output index to multi-dimensional coordinates
180                let mut out_coords = vec![0; out_dims.len()];
181                let mut tmp = out_idx;
182                for k in (0..out_dims.len()).rev() {
183                    out_coords[k] = tmp % out_dims[k];
184                    tmp /= out_dims[k];
185                }
186
187                // Convert output coordinates to input coordinates
188                let mut in_coords = vec![0; rank];
189                if keepdim {
190                    // When keepdim=true, output coords map directly to input coords
191                    for (k, &out_coord) in out_coords.iter().enumerate() {
192                        if k == dim {
193                            in_coords[k] = 0; // Will be set in the loop below
194                        } else {
195                            in_coords[k] = out_coord;
196                        }
197                    }
198                } else {
199                    // When keepdim=false, we need to insert the missing dimension
200                    let mut out_coord_idx = 0;
201                    for (k, in_coord) in in_coords.iter_mut().enumerate() {
202                        if k == dim {
203                            *in_coord = 0; // Will be set in the loop below
204                        } else {
205                            *in_coord = out_coords[out_coord_idx];
206                            out_coord_idx += 1;
207                        }
208                    }
209                }
210
211                // Find the argmin along the specified dimension
212                let mut best_val = f32::INFINITY;
213                let mut best_j = 0usize;
214
215                for j in 0..reduce_size {
216                    in_coords[dim] = j;
217                    let in_offset = self.shape().offset(&in_coords);
218                    let v = *self.as_ptr().add(in_offset);
219                    if v < best_val {
220                        best_val = v;
221                        best_j = j;
222                    }
223                }
224
225                *dst.add(out_idx) = best_j as f32;
226            }
227        }
228
229        out
230    }
231}
232
233#[cfg(test)]
234mod tests {
235    use super::*;
236
237    // Level 1 Tests: Basic functionality with simple contiguous tensors
238    #[test]
239    fn test_argmin_level1_basic_1d() {
240        // Simple 1D case
241        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
242        let idx = x.argmin();
243        assert_eq!(idx.get(&[0]), 1.0); // -2.0 is at index 1
244        assert_eq!(idx.shape().dims, vec![1]);
245    }
246
247    #[test]
248    fn test_argmin_level1_basic_1d_edge_cases() {
249        // Single element
250        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
251        let idx = x.argmin();
252        assert_eq!(idx.get(&[0]), 0.0);
253
254        // All same values - should return first occurrence
255        let x = Tensor::from_slice(&[5.0, 5.0, 5.0], vec![3]).unwrap();
256        let idx = x.argmin();
257        assert_eq!(idx.get(&[0]), 0.0);
258
259        // Negative values
260        let x = Tensor::from_slice(&[-1.0, -5.0, -2.0], vec![3]).unwrap();
261        let idx = x.argmin();
262        assert_eq!(idx.get(&[0]), 1.0); // -5.0 is at index 1
263    }
264
265    #[test]
266    fn test_argmin_level1_basic_2d_contiguous() {
267        // Simple 2D case - whole tensor argmin
268        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
269        let idx = x.argmin();
270        assert_eq!(idx.get(&[0]), 5.0); // -3.0 is at flat index 5
271        assert_eq!(idx.shape().dims, vec![1]);
272    }
273
274    #[test]
275    fn test_argmin_level1_dim_2d_basic() {
276        // Test argmin_dim with 2D tensor
277        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
278        // Tensor looks like:
279        // [[3.0, -2.0, 5.0],
280        //  [-1.0, 0.0, -3.0]]
281
282        // Along dimension 1 (columns), keepdim=true
283        let idx1 = x.argmin_dim(1, true);
284        assert_eq!(idx1.shape().dims, vec![2, 1]);
285        assert_eq!(idx1.get(&[0, 0]), 1.0); // Row 0: -2.0 is at column index 1
286        assert_eq!(idx1.get(&[1, 0]), 2.0); // Row 1: -3.0 is at column index 2
287
288        // Along dimension 1 (columns), keepdim=false
289        let idx1_no_keep = x.argmin_dim(1, false);
290        assert_eq!(idx1_no_keep.shape().dims, vec![2]);
291        assert_eq!(idx1_no_keep.get(&[0]), 1.0);
292        assert_eq!(idx1_no_keep.get(&[1]), 2.0);
293
294        // Along dimension 0 (rows), keepdim=true
295        let idx0 = x.argmin_dim(0, true);
296        assert_eq!(idx0.shape().dims, vec![1, 3]);
297        assert_eq!(idx0.get(&[0, 0]), 1.0); // Column 0: -1.0 is at row index 1
298        assert_eq!(idx0.get(&[0, 1]), 0.0); // Column 1: -2.0 is at row index 0
299        assert_eq!(idx0.get(&[0, 2]), 1.0); // Column 2: -3.0 is at row index 1
300    }
301
302    #[test]
303    fn test_argmin_level1_3d_basic() {
304        // Test with 3D tensor
305        let data = vec![
306            1.0, -2.0, // [0,0,:] = [1.0, -2.0]
307            3.0, 4.0, // [0,1,:] = [3.0, 4.0]
308            -5.0, 6.0, // [1,0,:] = [-5.0, 6.0]
309            7.0, -8.0, // [1,1,:] = [7.0, -8.0]
310        ];
311        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
312
313        // Whole tensor argmin
314        let idx = x.argmin();
315        assert_eq!(idx.get(&[0]), 7.0); // -8.0 is at flat index 7
316
317        // Along dimension 2 (innermost), keepdim=false
318        let idx2 = x.argmin_dim(2, false);
319        assert_eq!(idx2.shape().dims, vec![2, 2]);
320        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, -2.0] -> min at index 1
321        assert_eq!(idx2.get(&[0, 1]), 0.0); // [3.0, 4.0] -> min at index 0
322        assert_eq!(idx2.get(&[1, 0]), 0.0); // [-5.0, 6.0] -> min at index 0
323        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, -8.0] -> min at index 1
324    }
325
326    // Level 2 Tests: Complex shapes, higher dimensions, and edge cases
327    #[test]
328    fn test_argmin_level2_large_tensors() {
329        // Test with larger tensors
330        let data: Vec<f32> = (0..1000).map(|i| (i as f32) * 0.1 - 50.0).collect();
331        // Values from -50.0 to 49.9, minimum at index 0
332        let x = Tensor::from_slice(&data, vec![1000]).unwrap();
333        let idx = x.argmin();
334        assert_eq!(idx.get(&[0]), 0.0);
335
336        // Reshape to 2D
337        let x_2d = Tensor::from_slice(&data, vec![25, 40]).unwrap();
338        let idx_2d = x_2d.argmin();
339        assert_eq!(idx_2d.get(&[0]), 0.0);
340
341        // Test along different dimensions
342        let idx_dim0 = x_2d.argmin_dim(0, false);
343        assert_eq!(idx_dim0.shape().dims, vec![40]);
344        assert_eq!(idx_dim0.get(&[0]), 0.0); // Column 0: minimum at row 0
345
346        let idx_dim1 = x_2d.argmin_dim(1, false);
347        assert_eq!(idx_dim1.shape().dims, vec![25]);
348        assert_eq!(idx_dim1.get(&[0]), 0.0); // Row 0: minimum at column 0
349    }
350
351    #[test]
352    fn test_argmin_level2_4d_tensor() {
353        // Test with 4D tensor [2, 3, 4, 5] = 120 elements
354        let data: Vec<f32> = (0..120).map(|i| 120.0 - i as f32).collect();
355        // Values from 120.0 down to 1.0, minimum at last index
356        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();
357
358        // Global argmin
359        let idx = x.argmin();
360        assert_eq!(idx.get(&[0]), 119.0); // minimum value 1.0 is at index 119
361
362        // Test argmin along dimension 3 (innermost)
363        let idx3 = x.argmin_dim(3, false);
364        assert_eq!(idx3.shape().dims, vec![2, 3, 4]);
365        // Each slice along dim 3 has values decreasing, so min is always at index 4
366        assert_eq!(idx3.get(&[0, 0, 0]), 4.0);
367        assert_eq!(idx3.get(&[1, 2, 3]), 4.0);
368
369        // Test argmin along dimension 0 (outermost)
370        let idx0 = x.argmin_dim(0, false);
371        assert_eq!(idx0.shape().dims, vec![3, 4, 5]);
372        // For each position, the minimum is in the second batch (index 1)
373        assert_eq!(idx0.get(&[0, 0, 0]), 1.0);
374        assert_eq!(idx0.get(&[2, 3, 4]), 1.0);
375    }
376
377    #[test]
378    fn test_argmin_level2_special_values() {
379        // Test with special floating point values
380        let data = vec![
381            f32::NAN,       // 0
382            f32::INFINITY,  // 1
383            -f32::INFINITY, // 2 <- this should be minimum
384            0.0,            // 3
385            -0.0,           // 4
386            1.0,            // 5
387        ];
388        let x = Tensor::from_slice(&data, vec![6]).unwrap();
389        let idx = x.argmin();
390        assert_eq!(idx.get(&[0]), 2.0); // -infinity at index 2
391
392        // Test with all NaN
393        let nan_data = vec![f32::NAN, f32::NAN, f32::NAN];
394        let x_nan = Tensor::from_slice(&nan_data, vec![3]).unwrap();
395        let idx_nan = x_nan.argmin();
396        // With all NaN, should return first index
397        assert_eq!(idx_nan.get(&[0]), 0.0);
398
399        // Test with mix of normal values and NaN
400        let mixed_data = vec![1.0, f32::NAN, -5.0, f32::NAN, 3.0];
401        let x_mixed = Tensor::from_slice(&mixed_data, vec![5]).unwrap();
402        let idx_mixed = x_mixed.argmin();
403        assert_eq!(idx_mixed.get(&[0]), 2.0); // -5.0 at index 2
404    }
405
406    #[test]
407    fn test_argmin_level2_ties() {
408        // Test behavior with tied minimum values (should return first occurrence)
409        let data = vec![3.0, -2.0, 5.0, -2.0, 0.0, -2.0]; // -2.0 appears at indices 1, 3, 5
410        let x = Tensor::from_slice(&data, vec![6]).unwrap();
411        let idx = x.argmin();
412        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of -2.0
413
414        // Test with 2D tensor and ties
415        let x_2d = Tensor::from_slice(&data, vec![2, 3]).unwrap();
416        // [[3.0, -2.0, 5.0],
417        //  [-2.0, 0.0, -2.0]]
418
419        let idx_dim0 = x_2d.argmin_dim(0, false);
420        assert_eq!(idx_dim0.shape().dims, vec![3]);
421        assert_eq!(idx_dim0.get(&[0]), 1.0); // Column 0: min(-2.0 vs 3.0) -> row 1
422        assert_eq!(idx_dim0.get(&[1]), 0.0); // Column 1: min(-2.0 vs 0.0) -> row 0
423        assert_eq!(idx_dim0.get(&[2]), 1.0); // Column 2: min(5.0 vs -2.0) -> row 1
424
425        let idx_dim1 = x_2d.argmin_dim(1, false);
426        assert_eq!(idx_dim1.shape().dims, vec![2]);
427        assert_eq!(idx_dim1.get(&[0]), 1.0); // Row 0: min of [3.0, -2.0, 5.0] -> col 1
428        assert_eq!(idx_dim1.get(&[1]), 0.0); // Row 1: min of [-2.0, 0.0, -2.0] -> col 0 (first)
429    }
430
431    #[test]
432    fn test_argmin_level2_broadcasting_dims() {
433        // Test with dimensions of size 1 (singleton dimensions)
434        let data = vec![5.0, -3.0, 7.0, 1.0, -8.0, 2.0];
435        let x = Tensor::from_slice(&data, vec![1, 6, 1]).unwrap();
436
437        let idx = x.argmin();
438        assert_eq!(idx.get(&[0]), 4.0); // -8.0 at flat index 4
439
440        // Test argmin along different dimensions
441        let idx_dim0 = x.argmin_dim(0, false);
442        assert_eq!(idx_dim0.shape().dims, vec![6, 1]);
443
444        let idx_dim1 = x.argmin_dim(1, false);
445        assert_eq!(idx_dim1.shape().dims, vec![1, 1]);
446        assert_eq!(idx_dim1.get(&[0, 0]), 4.0); // -8.0 at position 4 along dim 1
447
448        let idx_dim2 = x.argmin_dim(2, false);
449        assert_eq!(idx_dim2.shape().dims, vec![1, 6]);
450    }
451
452    #[test]
453    fn test_argmin_level2_complex_3d() {
454        // Complex 3D case with multiple batch dimensions
455        let data = vec![
456            // Batch 0, Channel 0: [[1, 2], [3, 4]]
457            1.0, 2.0, 3.0, 4.0, // Batch 0, Channel 1: [[5, 6], [7, 8]]
458            5.0, 6.0, 7.0, 8.0, // Batch 0, Channel 2: [[-1, 0], [9, 10]]
459            -1.0, 0.0, 9.0, 10.0, // Batch 1, Channel 0: [[11, 12], [13, 14]]
460            11.0, 12.0, 13.0, 14.0, // Batch 1, Channel 1: [[15, 16], [17, 18]]
461            15.0, 16.0, 17.0, 18.0, // Batch 1, Channel 2: [[19, 20], [21, -5]]
462            19.0, 20.0, 21.0, -5.0,
463        ];
464        let x = Tensor::from_slice(&data, vec![2, 3, 2, 2]).unwrap();
465
466        // Global minimum
467        let idx = x.argmin();
468        assert_eq!(idx.get(&[0]), 23.0); // -5.0 is at flat index 23
469
470        // Argmin along dimension 1 (channels)
471        let idx_dim1 = x.argmin_dim(1, false);
472        assert_eq!(idx_dim1.shape().dims, vec![2, 2, 2]);
473        // At position [0,0,0]: min(1.0, 5.0, -1.0) = -1.0 at channel 2
474        assert_eq!(idx_dim1.get(&[0, 0, 0]), 2.0);
475        // At position [1,1,1]: min(14.0, 18.0, -5.0) = -5.0 at channel 2
476        assert_eq!(idx_dim1.get(&[1, 1, 1]), 2.0);
477    }
478
479    // Level 3 Tests: Non-contiguous tensors, views, and strided memory layouts
480    #[test]
481    fn test_argmin_level3_transpose_view() {
482        // Create a 2x3 tensor and transpose it to get a non-contiguous view
483        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, -5.0], vec![2, 3]).unwrap();
484        // Original: [[1.0, 3.0, 2.0],
485        //            [4.0, 0.0, -5.0]]
486
487        let x_t = x.transpose(0, 1);
488        // Transposed: [[1.0, 4.0],
489        //              [3.0, 0.0],
490        //              [2.0, -5.0]]
491        assert_eq!(x_t.shape().dims, vec![3, 2]);
492        assert!(!x_t.is_contiguous()); // Should be a view
493
494        // Test global argmin on transposed view
495        let idx = x_t.argmin();
496        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value -5.0
497
498        // Test argmin along dim=0 of transposed tensor
499        let idx0 = x_t.argmin_dim(0, false);
500        assert_eq!(idx0.shape().dims, vec![2]);
501        assert_eq!(idx0.get(&[0]), 0.0); // col 0: [1.0, 3.0, 2.0] -> min 1.0 at index 0
502        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, -5.0] -> min -5.0 at index 2
503
504        // Test argmin along dim=1 of transposed tensor
505        let idx1 = x_t.argmin_dim(1, false);
506        assert_eq!(idx1.shape().dims, vec![3]);
507        assert_eq!(idx1.get(&[0]), 0.0); // row 0: [1.0, 4.0] -> min 1.0 at index 0
508        assert_eq!(idx1.get(&[1]), 1.0); // row 1: [3.0, 0.0] -> min 0.0 at index 1
509        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, -5.0] -> min -5.0 at index 1
510    }
511
512    #[test]
513    fn test_argmin_level3_slice_view() {
514        // Create a 3x4 tensor and take a slice
515        let data = vec![
516            1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
517        ];
518        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
519        // [[1, 2, 3, 4],
520        //  [5, -6, 7, 8],
521        //  [9, 10, 11, 12]]
522
523        // Select middle row (creates a view)
524        let middle_row = x.select(0, 1);
525        // [5, -6, 7, 8]
526        assert_eq!(middle_row.shape().dims, vec![4]);
527
528        let idx = middle_row.argmin();
529        assert_eq!(idx.get(&[0]), 1.0); // index 1 has value -6.0
530
531        // Test argmin_dim on 1D slice (should work the same as global argmin)
532        let idx_dim = middle_row.argmin_dim(0, false);
533        assert_eq!(idx_dim.shape().dims, vec![1]);
534        assert_eq!(idx_dim.get(&[0]), 1.0);
535
536        // Test with column slice
537        let second_col = x.select(1, 1);
538        // [2, -6, 10]
539        assert_eq!(second_col.shape().dims, vec![3]);
540        let idx_col = second_col.argmin();
541        assert_eq!(idx_col.get(&[0]), 1.0); // -6.0 at index 1
542    }
543
544    #[test]
545    fn test_argmin_level3_permuted_3d() {
546        // Test 3D tensor with permuted dimensions
547        let data = (0..24).map(|i| 24.0 - i as f32).collect::<Vec<_>>();
548        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
549        // Shape [2, 3, 4] with values 24.0 down to 1.0
550        // Minimum value 1.0 is at the last position
551
552        // Permute to [4, 2, 3] (swap dims 0 and 2)
553        let x_perm = x.permute(vec![2, 1, 0]);
554        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
555        assert!(!x_perm.is_contiguous());
556
557        // Global argmin should still find the minimum value (1.0)
558        let idx = x_perm.argmin();
559        assert_eq!(idx.get(&[0]), 23.0); // The min value 1.0 is still at flat index 23
560
561        // Test argmin along each dimension of permuted tensor
562        let idx0 = x_perm.argmin_dim(0, false); // [3, 2]
563        assert_eq!(idx0.shape().dims, vec![3, 2]);
564
565        let idx1 = x_perm.argmin_dim(1, false); // [4, 2]
566        assert_eq!(idx1.shape().dims, vec![4, 2]);
567
568        let idx2 = x_perm.argmin_dim(2, false); // [4, 3]
569        assert_eq!(idx2.shape().dims, vec![4, 3]);
570
571        // Verify some specific values
572        // Since values decrease from 24.0 to 1.0, the permuted tensor should have
573        // minimum values at the later positions in the original ordering
574    }
575
576    #[test]
577    fn test_argmin_level3_nested_views() {
578        // Test nested transformations (transpose then select)
579        let data = vec![
580            1.0, 2.0, -3.0, 4.0, 5.0, 6.0, 7.0, -8.0, 9.0, 10.0, 11.0, 12.0,
581        ];
582        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();
583
584        // First transpose, then select a row
585        let x_t = x.transpose(0, 1); // [3, 4]
586        let row = x_t.select(0, 1); // Select second row: [2, 5, -8, 11]
587        assert_eq!(row.shape().dims, vec![4]);
588
589        let idx = row.argmin();
590        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value -8.0
591    }
592
593    #[test]
594    fn test_argmin_level3_strided_memory() {
595        // Test with highly strided memory patterns
596        let data: Vec<f32> = (0..60).map(|i| i as f32 - 30.0).collect();
597        let x = Tensor::from_slice(&data, vec![3, 4, 5]).unwrap();
598        // Values from -30.0 to 29.0
599
600        // Create complex views that result in non-contiguous memory
601        let x_perm = x.permute(vec![2, 0, 1]); // [5, 3, 4]
602        assert!(!x_perm.is_contiguous());
603
604        // Test global argmin
605        let idx = x_perm.argmin();
606        assert_eq!(idx.get(&[0]), 0.0); // -30.0 is at index 0
607
608        // Test dimension-wise argmin on permuted tensor
609        let idx0 = x_perm.argmin_dim(0, false);
610        assert_eq!(idx0.shape().dims, vec![3, 4]);
611
612        let idx1 = x_perm.argmin_dim(1, false);
613        assert_eq!(idx1.shape().dims, vec![5, 4]);
614
615        let idx2 = x_perm.argmin_dim(2, false);
616        assert_eq!(idx2.shape().dims, vec![5, 3]);
617    }
618
619    #[test]
620    fn test_argmin_level3_multiple_transformations() {
621        // Test with multiple chained transformations
622        let data = vec![
623            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
624            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, -24.0,
625        ];
626        let x = Tensor::from_slice(&data, vec![4, 6]).unwrap();
627
628        // Chain multiple transformations
629        let x_t = x.transpose(0, 1); // [6, 4]
630        let x_subset = x_t.select(0, 5); // Select last row: [6, 12, 18, -24]
631
632        // Note: select might create contiguous tensors in some cases, so we don't assert non-contiguous
633        assert_eq!(x_subset.shape().dims, vec![4]);
634
635        let idx = x_subset.argmin();
636        assert_eq!(idx.get(&[0]), 3.0); // -24.0 at index 3
637
638        // Test on a slice of the transposed tensor
639        let partial_col = x_t.select(1, 2); // Select third column: [15, 16, 17, 18, 19, 20]
640        let idx_partial = partial_col.argmin();
641        assert_eq!(idx_partial.get(&[0]), 0.0); // 15.0 at index 0
642
643        // Test argmin on the non-contiguous transposed tensor
644        assert!(!x_t.is_contiguous());
645        let idx_trans = x_t.argmin();
646        assert_eq!(idx_trans.get(&[0]), 23.0); // -24.0 is still at flat index 23
647    }
648
649    #[test]
650    fn test_argmin_level3_view_consistency() {
651        // Test that argmin results are consistent between original and view
652        let data = vec![
653            5.0, -2.0, 8.0, 1.0, // row 0: min -2.0 at col 1
654            3.0, 9.0, -4.0, 7.0, // row 1: min -4.0 at col 2
655            6.0, 0.0, 2.0, -1.0, // row 2: min -1.0 at col 3
656        ];
657        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
658        // Global minimum is -4.0 at flat index 6
659
660        // Test argmin on original tensor
661        let idx_orig = x.argmin();
662        assert_eq!(idx_orig.get(&[0]), 6.0); // -4.0 at index 6
663
664        // Create a view by transposing and test consistency
665        let x_t = x.transpose(0, 1);
666        // Transposed tensor:
667        // [[5.0, 3.0, 6.0],     // col 0 of original -> row 0: min 3.0 at index 1
668        //  [-2.0, 9.0, 0.0],    // col 1 of original -> row 1: min -2.0 at index 0
669        //  [8.0, -4.0, 2.0],    // col 2 of original -> row 2: min -4.0 at index 1
670        //  [1.0, 7.0, -1.0]]    // col 3 of original -> row 3: min -1.0 at index 2
671
672        let idx_view = x_t.argmin();
673        // The minimum value is still -4.0, but its flat index in the view may differ
674        // Let's just check that both find the minimum value correctly
675
676        // Extract actual minimum values to verify they're the same
677        let min_val_orig = unsafe {
678            let flat_idx = idx_orig.get(&[0]) as usize;
679            *x.as_ptr().add(flat_idx)
680        };
681        let min_val_view = unsafe {
682            let flat_idx = idx_view.get(&[0]) as usize;
683            let dims = x_t.shape().dims.clone();
684            let mut coords = vec![0; dims.len()];
685            let mut tmp = flat_idx;
686            for k in (0..dims.len()).rev() {
687                coords[k] = tmp % dims[k];
688                tmp /= dims[k];
689            }
690            let offset = x_t.shape().offset(&coords);
691            *x_t.as_ptr().add(offset)
692        };
693
694        assert_eq!(min_val_orig, -4.0);
695        assert_eq!(min_val_view, -4.0);
696
697        // Test simpler consistency: argmin along specific dimensions
698        let idx_dim0_orig = x.argmin_dim(0, false); // argmin along rows -> [4] (min of each column)
699        let idx_dim1_trans = x_t.argmin_dim(1, false); // argmin along columns -> [4] (min of each row)
700
701        // These should give the same results since we're reducing along corresponding dims
702        assert_eq!(idx_dim0_orig.shape().dims, vec![4]);
703        assert_eq!(idx_dim1_trans.shape().dims, vec![4]);
704
705        // Original columns vs transposed rows should match
706        assert_eq!(idx_dim0_orig.get(&[0]), 1.0); // col 0: min(5,3,6) = 3 at row 1
707        assert_eq!(idx_dim0_orig.get(&[1]), 0.0); // col 1: min(-2,9,0) = -2 at row 0
708        assert_eq!(idx_dim0_orig.get(&[2]), 1.0); // col 2: min(8,-4,2) = -4 at row 1
709        assert_eq!(idx_dim0_orig.get(&[3]), 2.0); // col 3: min(1,7,-1) = -1 at row 2
710
711        assert_eq!(idx_dim1_trans.get(&[0]), 1.0); // corresponds to col 0
712        assert_eq!(idx_dim1_trans.get(&[1]), 0.0); // corresponds to col 1
713        assert_eq!(idx_dim1_trans.get(&[2]), 1.0); // corresponds to col 2
714        assert_eq!(idx_dim1_trans.get(&[3]), 2.0); // corresponds to col 3
715    }
716
717    // Keep the old basic tests for compatibility
718    #[test]
719    fn test_argmin_basic() {
720        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
721        let idx = x.argmin();
722        unsafe {
723            assert_eq!(*idx.as_ptr(), 1.0);
724        }
725    }
726
727    #[test]
728    fn test_argmin_dim() {
729        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
730        let idx0 = x.argmin_dim(1, true);
731        assert_eq!(idx0.shape().dims, vec![2, 1]);
732        assert_eq!(idx0.get(&[0, 0]), 1.0);
733        assert_eq!(idx0.get(&[1, 0]), 2.0);
734    }
735}