// train_station/tensor/reductions/argmin.rs

1use crate::tensor::core::Tensor;
2
3impl Tensor {
4    /// Returns the index of the minimum value in the tensor
5    ///
6    /// This method finds the flat index of the minimum value across all elements
7    /// in the tensor. The result is a scalar tensor containing the index as a
8    /// floating-point value. This operation is non-differentiable and the output
9    /// never requires gradient tracking.
10    ///
11    /// # Returns
12    ///
13    /// A tensor with shape `[1]` containing the flat index of the minimum value
14    /// as a `f32`. If the input tensor is empty, returns `0.0`.
15    ///
16    /// # Examples
17    ///
18    /// ```
19    /// use train_station::Tensor;
20    ///
21    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
22    /// let min_index = tensor.argmin();
23    /// assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
24    /// ```
25    ///
26    /// ```
27    /// use train_station::Tensor;
28    ///
29    /// // Empty tensor case
30    /// let empty_tensor = Tensor::new(vec![0]);
31    /// let min_index = empty_tensor.argmin();
32    /// assert_eq!(min_index.get(&[0]), 0.0);
33    /// ```
34    #[track_caller]
35    pub fn argmin(&self) -> Tensor {
36        let mut out = Tensor::new(vec![1]);
37        if self.size() == 0 {
38            out.fill(0.0);
39            return out;
40        }
41
42        let mut best_val = f32::INFINITY;
43        let mut best_idx = 0usize;
44
45        if self.is_contiguous() {
46            // Fast path for contiguous tensors
47            unsafe {
48                let src = self.as_ptr();
49                for i in 0..self.size() {
50                    let v = *src.add(i);
51                    if v < best_val {
52                        best_val = v;
53                        best_idx = i;
54                    }
55                }
56            }
57        } else {
58            // Stride-aware path for non-contiguous tensors
59            let dims = self.shape().dims.clone();
60            for flat_idx in 0..self.size() {
61                // Convert flat index to multi-dimensional coordinates
62                let mut coords = vec![0; dims.len()];
63                let mut tmp = flat_idx;
64                for k in (0..dims.len()).rev() {
65                    coords[k] = tmp % dims[k];
66                    tmp /= dims[k];
67                }
68
69                // Get value using stride-aware offset
70                let offset = self.shape().offset(&coords);
71                let v = unsafe { *self.as_ptr().add(offset) };
72                if v < best_val {
73                    best_val = v;
74                    best_idx = flat_idx;
75                }
76            }
77        }
78
79        unsafe {
80            *out.as_mut_ptr() = best_idx as f32;
81        }
82        out
83    }
84
85    /// Returns the indices of minimum values along a specified dimension
86    ///
87    /// This method finds the indices of minimum values along the specified dimension.
88    /// The result contains the indices where the minimum values occur in that dimension.
89    /// This operation is non-differentiable and the output never requires gradient tracking.
90    ///
91    /// # Arguments
92    ///
93    /// * `dim` - The dimension along which to find minimum indices (0-based)
94    /// * `keepdim` - Whether to keep the reduced dimension in the output shape
95    ///   - If `true`, the reduced dimension is kept with size 1
96    ///   - If `false`, the reduced dimension is removed from the output shape
97    ///
98    /// # Returns
99    ///
100    /// A tensor containing the indices of minimum values along the specified dimension.
101    /// The output shape depends on `keepdim`:
102    /// - If `keepdim` is `true`, the reduced dimension has size 1
103    /// - If `keepdim` is `false`, the reduced dimension is removed
104    ///
105    /// # Panics
106    ///
107    /// * If `dim` is out of bounds for the tensor's rank
108    /// * If the dimension to reduce has size 0
109    ///
110    /// # Examples
111    ///
112    /// ```
113    /// use train_station::Tensor;
114    ///
115    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
116    ///
117    /// // Find minimum indices along dimension 1 (columns), keeping the dimension
118    /// let indices = tensor.argmin_dim(1, true);
119    /// assert_eq!(indices.shape().dims, vec![2, 1]);
120    /// assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
121    /// assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
122    /// ```
123    ///
124    /// ```
125    /// use train_station::Tensor;
126    ///
127    /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
128    ///
129    /// // Find minimum indices along dimension 1 (columns), removing the dimension
130    /// let indices = tensor.argmin_dim(1, false);
131    /// assert_eq!(indices.shape().dims, vec![2]);
132    /// assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
133    /// assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
134    /// ```
135    ///
136    /// ```
137    /// use train_station::Tensor;
138    ///
139    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
140    ///
141    /// // Find minimum index in a 1D tensor
142    /// let index = tensor.argmin_dim(0, false);
143    /// assert_eq!(index.shape().dims, vec![1]);
144    /// assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
145    /// ```
146    #[track_caller]
147    pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor {
148        let rank = self.shape().rank();
149        assert!(
150            dim < rank,
151            "argmin_dim dim {} out of bounds for rank {}",
152            dim,
153            rank
154        );
155
156        let in_dims = self.shape().dims.clone();
157        let reduce_size = in_dims[dim];
158        assert!(reduce_size > 0, "cannot argmin over empty dimension");
159
160        // Build output shape
161        let mut out_dims = in_dims.clone();
162        if keepdim {
163            out_dims[dim] = 1;
164        } else {
165            out_dims.remove(dim);
166        }
167        if out_dims.is_empty() {
168            out_dims.push(1);
169        }
170
171        let mut out = Tensor::zeros(out_dims.clone());
172
173        // Use stride-aware approach to handle non-contiguous tensors correctly
174        let out_size = out.size();
175
176        unsafe {
177            let dst = out.as_mut_ptr();
178
179            // Iterate over all output positions
180            for out_idx in 0..out_size {
181                // Convert flat output index to multi-dimensional coordinates
182                let mut out_coords = vec![0; out_dims.len()];
183                let mut tmp = out_idx;
184                for k in (0..out_dims.len()).rev() {
185                    out_coords[k] = tmp % out_dims[k];
186                    tmp /= out_dims[k];
187                }
188
189                // Convert output coordinates to input coordinates
190                let mut in_coords = vec![0; rank];
191                if keepdim {
192                    // When keepdim=true, output coords map directly to input coords
193                    for (k, &out_coord) in out_coords.iter().enumerate() {
194                        if k == dim {
195                            in_coords[k] = 0; // Will be set in the loop below
196                        } else {
197                            in_coords[k] = out_coord;
198                        }
199                    }
200                } else {
201                    // When keepdim=false, we need to insert the missing dimension
202                    let mut out_coord_idx = 0;
203                    for (k, in_coord) in in_coords.iter_mut().enumerate() {
204                        if k == dim {
205                            *in_coord = 0; // Will be set in the loop below
206                        } else {
207                            *in_coord = out_coords[out_coord_idx];
208                            out_coord_idx += 1;
209                        }
210                    }
211                }
212
213                // Find the argmin along the specified dimension
214                let mut best_val = f32::INFINITY;
215                let mut best_j = 0usize;
216
217                for j in 0..reduce_size {
218                    in_coords[dim] = j;
219                    let in_offset = self.shape().offset(&in_coords);
220                    let v = *self.as_ptr().add(in_offset);
221                    if v < best_val {
222                        best_val = v;
223                        best_j = j;
224                    }
225                }
226
227                *dst.add(out_idx) = best_j as f32;
228            }
229        }
230
231        out
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    // Level 1 Tests: Basic functionality with simple contiguous tensors
    #[test]
    fn test_argmin_level1_basic_1d() {
        // Simple 1D case
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -2.0 is at index 1
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values - should return first occurrence
        let x = Tensor::from_slice(&[5.0, 5.0, 5.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-1.0, -5.0, -2.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -5.0 is at index 1
    }

    #[test]
    fn test_argmin_level1_basic_2d_contiguous() {
        // Simple 2D case - whole tensor argmin
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // -3.0 is at flat index 5
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_dim_2d_basic() {
        // Test argmin_dim with 2D tensor
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        // Tensor looks like:
        // [[3.0, -2.0, 5.0],
        //  [-1.0, 0.0, -3.0]]

        // Along dimension 1 (columns), keepdim=true
        let idx1 = x.argmin_dim(1, true);
        assert_eq!(idx1.shape().dims, vec![2, 1]);
        assert_eq!(idx1.get(&[0, 0]), 1.0); // Row 0: -2.0 is at column index 1
        assert_eq!(idx1.get(&[1, 0]), 2.0); // Row 1: -3.0 is at column index 2

        // Along dimension 1 (columns), keepdim=false
        let idx1_no_keep = x.argmin_dim(1, false);
        assert_eq!(idx1_no_keep.shape().dims, vec![2]);
        assert_eq!(idx1_no_keep.get(&[0]), 1.0);
        assert_eq!(idx1_no_keep.get(&[1]), 2.0);

        // Along dimension 0 (rows), keepdim=true
        let idx0 = x.argmin_dim(0, true);
        assert_eq!(idx0.shape().dims, vec![1, 3]);
        assert_eq!(idx0.get(&[0, 0]), 1.0); // Column 0: -1.0 is at row index 1
        assert_eq!(idx0.get(&[0, 1]), 0.0); // Column 1: -2.0 is at row index 0
        assert_eq!(idx0.get(&[0, 2]), 1.0); // Column 2: -3.0 is at row index 1
    }

    #[test]
    fn test_argmin_level1_3d_basic() {
        // Test with 3D tensor
        let data = vec![
            1.0, -2.0, // [0,0,:] = [1.0, -2.0]
            3.0, 4.0, // [0,1,:] = [3.0, 4.0]
            -5.0, 6.0, // [1,0,:] = [-5.0, 6.0]
            7.0, -8.0, // [1,1,:] = [7.0, -8.0]
        ];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Whole tensor argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 7.0); // -8.0 is at flat index 7

        // Along dimension 2 (innermost), keepdim=false
        let idx2 = x.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, -2.0] -> min at index 1
        assert_eq!(idx2.get(&[0, 1]), 0.0); // [3.0, 4.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 0]), 0.0); // [-5.0, 6.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, -8.0] -> min at index 1
    }

    // Level 2 Tests: Complex shapes, higher dimensions, and edge cases
    #[test]
    fn test_argmin_level2_large_tensors() {
        // Test with larger tensors
        let data: Vec<f32> = (0..1000).map(|i| (i as f32) * 0.1 - 50.0).collect();
        // Values from -50.0 to 49.9, minimum at index 0
        let x = Tensor::from_slice(&data, vec![1000]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Reshape to 2D
        let x_2d = Tensor::from_slice(&data, vec![25, 40]).unwrap();
        let idx_2d = x_2d.argmin();
        assert_eq!(idx_2d.get(&[0]), 0.0);

        // Test along different dimensions
        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![40]);
        assert_eq!(idx_dim0.get(&[0]), 0.0); // Column 0: minimum at row 0

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![25]);
        assert_eq!(idx_dim1.get(&[0]), 0.0); // Row 0: minimum at column 0
    }

    #[test]
    fn test_argmin_level2_4d_tensor() {
        // Test with 4D tensor [2, 3, 4, 5] = 120 elements
        let data: Vec<f32> = (0..120).map(|i| 120.0 - i as f32).collect();
        // Values from 120.0 down to 1.0, minimum at last index
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();

        // Global argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 119.0); // minimum value 1.0 is at index 119

        // Test argmin along dimension 3 (innermost)
        let idx3 = x.argmin_dim(3, false);
        assert_eq!(idx3.shape().dims, vec![2, 3, 4]);
        // Each slice along dim 3 has values decreasing, so min is always at index 4
        assert_eq!(idx3.get(&[0, 0, 0]), 4.0);
        assert_eq!(idx3.get(&[1, 2, 3]), 4.0);

        // Test argmin along dimension 0 (outermost)
        let idx0 = x.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4, 5]);
        // For each position, the minimum is in the second batch (index 1)
        assert_eq!(idx0.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx0.get(&[2, 3, 4]), 1.0);
    }

    #[test]
    fn test_argmin_level2_special_values() {
        // Test with special floating point values
        let data = vec![
            f32::NAN,       // 0
            f32::INFINITY,  // 1
            -f32::INFINITY, // 2 <- this should be minimum
            0.0,            // 3
            -0.0,           // 4
            1.0,            // 5
        ];
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // -infinity at index 2

        // Test with all NaN
        let nan_data = vec![f32::NAN, f32::NAN, f32::NAN];
        let x_nan = Tensor::from_slice(&nan_data, vec![3]).unwrap();
        let idx_nan = x_nan.argmin();
        // With all NaN, should return first index
        assert_eq!(idx_nan.get(&[0]), 0.0);

        // Test with mix of normal values and NaN
        let mixed_data = vec![1.0, f32::NAN, -5.0, f32::NAN, 3.0];
        let x_mixed = Tensor::from_slice(&mixed_data, vec![5]).unwrap();
        let idx_mixed = x_mixed.argmin();
        assert_eq!(idx_mixed.get(&[0]), 2.0); // -5.0 at index 2
    }

    #[test]
    fn test_argmin_level2_ties() {
        // Test behavior with tied minimum values (should return first occurrence)
        let data = vec![3.0, -2.0, 5.0, -2.0, 0.0, -2.0]; // -2.0 appears at indices 1, 3, 5
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of -2.0

        // Test with 2D tensor and ties
        let x_2d = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        // [[3.0, -2.0, 5.0],
        //  [-2.0, 0.0, -2.0]]

        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![3]);
        assert_eq!(idx_dim0.get(&[0]), 1.0); // Column 0: min(-2.0 vs 3.0) -> row 1
        assert_eq!(idx_dim0.get(&[1]), 0.0); // Column 1: min(-2.0 vs 0.0) -> row 0
        assert_eq!(idx_dim0.get(&[2]), 1.0); // Column 2: min(5.0 vs -2.0) -> row 1

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2]);
        assert_eq!(idx_dim1.get(&[0]), 1.0); // Row 0: min of [3.0, -2.0, 5.0] -> col 1
        assert_eq!(idx_dim1.get(&[1]), 0.0); // Row 1: min of [-2.0, 0.0, -2.0] -> col 0 (first)
    }

    #[test]
    fn test_argmin_level2_broadcasting_dims() {
        // Test with dimensions of size 1 (singleton dimensions)
        let data = vec![5.0, -3.0, 7.0, 1.0, -8.0, 2.0];
        let x = Tensor::from_slice(&data, vec![1, 6, 1]).unwrap();

        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 4.0); // -8.0 at flat index 4

        // Test argmin along different dimensions
        let idx_dim0 = x.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![6, 1]);

        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![1, 1]);
        assert_eq!(idx_dim1.get(&[0, 0]), 4.0); // -8.0 at position 4 along dim 1

        let idx_dim2 = x.argmin_dim(2, false);
        assert_eq!(idx_dim2.shape().dims, vec![1, 6]);

        // Reducing over a size-1 axis always selects its only index, 0
        for i in 0..6 {
            assert_eq!(idx_dim0.get(&[i, 0]), 0.0);
            assert_eq!(idx_dim2.get(&[0, i]), 0.0);
        }
    }

    #[test]
    fn test_argmin_level2_complex_3d() {
        // Complex 4D case [batch, channel, h, w] = [2, 3, 2, 2]
        let data = vec![
            // Batch 0, Channel 0: [[1, 2], [3, 4]]
            1.0, 2.0, 3.0, 4.0,
            // Batch 0, Channel 1: [[5, 6], [7, 8]]
            5.0, 6.0, 7.0, 8.0,
            // Batch 0, Channel 2: [[-1, 0], [9, 10]]
            -1.0, 0.0, 9.0, 10.0,
            // Batch 1, Channel 0: [[11, 12], [13, 14]]
            11.0, 12.0, 13.0, 14.0,
            // Batch 1, Channel 1: [[15, 16], [17, 18]]
            15.0, 16.0, 17.0, 18.0,
            // Batch 1, Channel 2: [[19, 20], [21, -5]]
            19.0, 20.0, 21.0, -5.0,
        ];
        let x = Tensor::from_slice(&data, vec![2, 3, 2, 2]).unwrap();

        // Global minimum
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 23.0); // -5.0 is at flat index 23

        // Argmin along dimension 1 (channels)
        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2, 2, 2]);
        // At position [0,0,0]: min(1.0, 5.0, -1.0) = -1.0 at channel 2
        assert_eq!(idx_dim1.get(&[0, 0, 0]), 2.0);
        // At position [1,1,1]: min(14.0, 18.0, -5.0) = -5.0 at channel 2
        assert_eq!(idx_dim1.get(&[1, 1, 1]), 2.0);
    }

    // Level 3 Tests: Non-contiguous tensors, views, and strided memory layouts
    #[test]
    fn test_argmin_level3_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, -5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, -5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, -5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmin on transposed view
        let idx = x_t.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value -5.0

        // Test argmin along dim=0 of transposed tensor
        let idx0 = x_t.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 0.0); // col 0: [1.0, 3.0, 2.0] -> min 1.0 at index 0
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, -5.0] -> min -5.0 at index 2

        // Same reduction with keepdim=true on the view
        let idx0_keep = x_t.argmin_dim(0, true);
        assert_eq!(idx0_keep.shape().dims, vec![1, 2]);
        assert_eq!(idx0_keep.get(&[0, 0]), 0.0);
        assert_eq!(idx0_keep.get(&[0, 1]), 2.0);

        // Test argmin along dim=1 of transposed tensor
        let idx1 = x_t.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 0.0); // row 0: [1.0, 4.0] -> min 1.0 at index 0
        assert_eq!(idx1.get(&[1]), 1.0); // row 1: [3.0, 0.0] -> min 0.0 at index 1
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, -5.0] -> min -5.0 at index 1
    }

    #[test]
    fn test_argmin_level3_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, -6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, -6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // index 1 has value -6.0

        // Test argmin_dim on 1D slice (should work the same as global argmin)
        let idx_dim = middle_row.argmin_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 1.0);

        // Test with column slice
        let second_col = x.select(1, 1);
        // [2, -6, 10]
        assert_eq!(second_col.shape().dims, vec![3]);
        let idx_col = second_col.argmin();
        assert_eq!(idx_col.get(&[0]), 1.0); // -6.0 at index 1
    }

    #[test]
    fn test_argmin_level3_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| 24.0 - i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4]: x[i][j][k] = 24 - (12i + 4j + k), from 24.0 down to 1.0
        // Minimum value 1.0 is at the last position

        // Permute to [4, 3, 2] (swap dims 0 and 2): x_perm[a][b][c] = x[c][b][a]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmin should still find the minimum value (1.0)
        let idx = x_perm.argmin();
        // 1.0 is x[1][2][3] = x_perm[3][2][1] -> logical flat index 3*6 + 2*2 + 1 = 23
        assert_eq!(idx.get(&[0]), 23.0);

        // Test argmin along each dimension of permuted tensor
        let idx0 = x_perm.argmin_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmin_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmin_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);

        // x_perm[a][b][c] = 24 - (12c + 4b + a) decreases along every axis, so the
        // minimum along each axis is always at that axis's last index.
        for b in 0..3 {
            for c in 0..2 {
                assert_eq!(idx0.get(&[b, c]), 3.0);
            }
        }
        for a in 0..4 {
            for c in 0..2 {
                assert_eq!(idx1.get(&[a, c]), 2.0);
            }
        }
        for a in 0..4 {
            for b in 0..3 {
                assert_eq!(idx2.get(&[a, b]), 1.0);
            }
        }
    }

    #[test]
    fn test_argmin_level3_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, -3.0, 4.0, 5.0, 6.0, 7.0, -8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, -8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value -8.0
    }

    #[test]
    fn test_argmin_level3_strided_memory() {
        // Test with highly strided memory patterns
        let data: Vec<f32> = (0..60).map(|i| i as f32 - 30.0).collect();
        let x = Tensor::from_slice(&data, vec![3, 4, 5]).unwrap();
        // Values from -30.0 to 29.0: x[i][j][k] = 20i + 5j + k - 30

        // Create complex views that result in non-contiguous memory
        let x_perm = x.permute(vec![2, 0, 1]); // [5, 3, 4]
        assert!(!x_perm.is_contiguous());

        // Test global argmin
        let idx = x_perm.argmin();
        assert_eq!(idx.get(&[0]), 0.0); // -30.0 is at index 0

        // Test dimension-wise argmin on permuted tensor
        let idx0 = x_perm.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4]);

        let idx1 = x_perm.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![5, 4]);

        let idx2 = x_perm.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![5, 3]);

        // x_perm[a][b][c] = 20b + 5c + a - 30 increases along every axis, so the
        // minimum along each axis is always at index 0.
        for b in 0..3 {
            for c in 0..4 {
                assert_eq!(idx0.get(&[b, c]), 0.0);
            }
        }
        for a in 0..5 {
            for c in 0..4 {
                assert_eq!(idx1.get(&[a, c]), 0.0);
            }
        }
        for a in 0..5 {
            for b in 0..3 {
                assert_eq!(idx2.get(&[a, b]), 0.0);
            }
        }
    }

    #[test]
    fn test_argmin_level3_multiple_transformations() {
        // Test with multiple chained transformations
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, -24.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 6]).unwrap();

        // Chain multiple transformations
        let x_t = x.transpose(0, 1); // [6, 4]
        let x_subset = x_t.select(0, 5); // Select last row: [6, 12, 18, -24]

        // Note: select might create contiguous tensors in some cases, so we don't assert non-contiguous
        assert_eq!(x_subset.shape().dims, vec![4]);

        let idx = x_subset.argmin();
        assert_eq!(idx.get(&[0]), 3.0); // -24.0 at index 3

        // Test on a slice of the transposed tensor
        // Third column of x_t is the third row of x: [13, 14, 15, 16, 17, 18]
        let partial_col = x_t.select(1, 2);
        let idx_partial = partial_col.argmin();
        assert_eq!(idx_partial.get(&[0]), 0.0); // 13.0 at index 0

        // Test argmin on the non-contiguous transposed tensor
        assert!(!x_t.is_contiguous());
        let idx_trans = x_t.argmin();
        assert_eq!(idx_trans.get(&[0]), 23.0); // -24.0 is still at flat index 23
    }

    #[test]
    fn test_argmin_level3_view_consistency() {
        // Test that argmin results are consistent between original and view
        let data = vec![
            5.0, -2.0, 8.0, 1.0, // row 0: min -2.0 at col 1
            3.0, 9.0, -4.0, 7.0, // row 1: min -4.0 at col 2
            6.0, 0.0, 2.0, -1.0, // row 2: min -1.0 at col 3
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // Global minimum is -4.0 at flat index 6

        // Test argmin on original tensor
        let idx_orig = x.argmin();
        assert_eq!(idx_orig.get(&[0]), 6.0); // -4.0 at index 6

        // Create a view by transposing and test consistency
        let x_t = x.transpose(0, 1);
        // Transposed tensor:
        // [[5.0, 3.0, 6.0],     // col 0 of original -> row 0: min 3.0 at index 1
        //  [-2.0, 9.0, 0.0],    // col 1 of original -> row 1: min -2.0 at index 0
        //  [8.0, -4.0, 2.0],    // col 2 of original -> row 2: min -4.0 at index 1
        //  [1.0, 7.0, -1.0]]    // col 3 of original -> row 3: min -1.0 at index 2

        let idx_view = x_t.argmin();
        // The minimum value is still -4.0, but its flat index in the view may differ
        // Let's just check that both find the minimum value correctly

        // Extract actual minimum values to verify they're the same
        let min_val_orig = unsafe {
            let flat_idx = idx_orig.get(&[0]) as usize;
            *x.as_ptr().add(flat_idx)
        };
        let min_val_view = unsafe {
            let flat_idx = idx_view.get(&[0]) as usize;
            let dims = x_t.shape().dims.clone();
            let mut coords = vec![0; dims.len()];
            let mut tmp = flat_idx;
            for k in (0..dims.len()).rev() {
                coords[k] = tmp % dims[k];
                tmp /= dims[k];
            }
            let offset = x_t.shape().offset(&coords);
            *x_t.as_ptr().add(offset)
        };

        assert_eq!(min_val_orig, -4.0);
        assert_eq!(min_val_view, -4.0);

        // Test simpler consistency: argmin along specific dimensions
        let idx_dim0_orig = x.argmin_dim(0, false); // argmin along rows -> [4] (min of each column)
        let idx_dim1_trans = x_t.argmin_dim(1, false); // argmin along columns -> [4] (min of each row)

        // These should give the same results since we're reducing along corresponding dims
        assert_eq!(idx_dim0_orig.shape().dims, vec![4]);
        assert_eq!(idx_dim1_trans.shape().dims, vec![4]);

        // Original columns vs transposed rows should match
        assert_eq!(idx_dim0_orig.get(&[0]), 1.0); // col 0: min(5,3,6) = 3 at row 1
        assert_eq!(idx_dim0_orig.get(&[1]), 0.0); // col 1: min(-2,9,0) = -2 at row 0
        assert_eq!(idx_dim0_orig.get(&[2]), 1.0); // col 2: min(8,-4,2) = -4 at row 1
        assert_eq!(idx_dim0_orig.get(&[3]), 2.0); // col 3: min(1,7,-1) = -1 at row 2

        assert_eq!(idx_dim1_trans.get(&[0]), 1.0); // corresponds to col 0
        assert_eq!(idx_dim1_trans.get(&[1]), 0.0); // corresponds to col 1
        assert_eq!(idx_dim1_trans.get(&[2]), 1.0); // corresponds to col 2
        assert_eq!(idx_dim1_trans.get(&[3]), 2.0); // corresponds to col 3
    }

    // Keep the old basic tests for compatibility
    #[test]
    fn test_argmin_basic() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        unsafe {
            assert_eq!(*idx.as_ptr(), 1.0);
        }
    }

    #[test]
    fn test_argmin_dim() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx0 = x.argmin_dim(1, true);
        assert_eq!(idx0.shape().dims, vec![2, 1]);
        assert_eq!(idx0.get(&[0, 0]), 1.0);
        assert_eq!(idx0.get(&[1, 0]), 2.0);
    }
}