// train_station/tensor/reductions/argmin.rs
1use crate::tensor::core::Tensor;
2
3impl Tensor {
4 /// Returns the index of the minimum value in the tensor
5 ///
6 /// This method finds the flat index of the minimum value across all elements
7 /// in the tensor. The result is a scalar tensor containing the index as a
8 /// floating-point value. This operation is non-differentiable and the output
9 /// never requires gradient tracking.
10 ///
11 /// # Returns
12 ///
13 /// A tensor with shape `[1]` containing the flat index of the minimum value
14 /// as a `f32`. If the input tensor is empty, returns `0.0`.
15 ///
16 /// # Examples
17 ///
18 /// ```
19 /// use train_station::Tensor;
20 ///
21 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
22 /// let min_index = tensor.argmin();
23 /// assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
24 /// ```
25 ///
26 /// ```
27 /// use train_station::Tensor;
28 ///
29 /// // Empty tensor case
30 /// let empty_tensor = Tensor::new(vec![0]);
31 /// let min_index = empty_tensor.argmin();
32 /// assert_eq!(min_index.get(&[0]), 0.0);
33 /// ```
34 #[track_caller]
35 pub fn argmin(&self) -> Tensor {
36 let mut out = Tensor::new(vec![1]);
37 if self.size() == 0 {
38 out.fill(0.0);
39 return out;
40 }
41
42 let mut best_val = f32::INFINITY;
43 let mut best_idx = 0usize;
44
45 if self.is_contiguous() {
46 // Fast path for contiguous tensors
47 unsafe {
48 let src = self.as_ptr();
49 for i in 0..self.size() {
50 let v = *src.add(i);
51 if v < best_val {
52 best_val = v;
53 best_idx = i;
54 }
55 }
56 }
57 } else {
58 // Stride-aware path for non-contiguous tensors
59 let dims = self.shape().dims.clone();
60 for flat_idx in 0..self.size() {
61 // Convert flat index to multi-dimensional coordinates
62 let mut coords = vec![0; dims.len()];
63 let mut tmp = flat_idx;
64 for k in (0..dims.len()).rev() {
65 coords[k] = tmp % dims[k];
66 tmp /= dims[k];
67 }
68
69 // Get value using stride-aware offset
70 let offset = self.shape().offset(&coords);
71 let v = unsafe { *self.as_ptr().add(offset) };
72 if v < best_val {
73 best_val = v;
74 best_idx = flat_idx;
75 }
76 }
77 }
78
79 unsafe {
80 *out.as_mut_ptr() = best_idx as f32;
81 }
82 out
83 }
84
85 /// Returns the indices of minimum values along a specified dimension
86 ///
87 /// This method finds the indices of minimum values along the specified dimension.
88 /// The result contains the indices where the minimum values occur in that dimension.
89 /// This operation is non-differentiable and the output never requires gradient tracking.
90 ///
91 /// # Arguments
92 ///
93 /// * `dim` - The dimension along which to find minimum indices (0-based)
94 /// * `keepdim` - Whether to keep the reduced dimension in the output shape
95 /// - If `true`, the reduced dimension is kept with size 1
96 /// - If `false`, the reduced dimension is removed from the output shape
97 ///
98 /// # Returns
99 ///
100 /// A tensor containing the indices of minimum values along the specified dimension.
101 /// The output shape depends on `keepdim`:
102 /// - If `keepdim` is `true`, the reduced dimension has size 1
103 /// - If `keepdim` is `false`, the reduced dimension is removed
104 ///
105 /// # Panics
106 ///
107 /// * If `dim` is out of bounds for the tensor's rank
108 /// * If the dimension to reduce has size 0
109 ///
110 /// # Examples
111 ///
112 /// ```
113 /// use train_station::Tensor;
114 ///
115 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
116 ///
117 /// // Find minimum indices along dimension 1 (columns), keeping the dimension
118 /// let indices = tensor.argmin_dim(1, true);
119 /// assert_eq!(indices.shape().dims, vec![2, 1]);
120 /// assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
121 /// assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
122 /// ```
123 ///
124 /// ```
125 /// use train_station::Tensor;
126 ///
127 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
128 ///
129 /// // Find minimum indices along dimension 1 (columns), removing the dimension
130 /// let indices = tensor.argmin_dim(1, false);
131 /// assert_eq!(indices.shape().dims, vec![2]);
132 /// assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
133 /// assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
134 /// ```
135 ///
136 /// ```
137 /// use train_station::Tensor;
138 ///
139 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
140 ///
141 /// // Find minimum index in a 1D tensor
142 /// let index = tensor.argmin_dim(0, false);
143 /// assert_eq!(index.shape().dims, vec![1]);
144 /// assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
145 /// ```
146 #[track_caller]
147 pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor {
148 let rank = self.shape().rank();
149 assert!(
150 dim < rank,
151 "argmin_dim dim {} out of bounds for rank {}",
152 dim,
153 rank
154 );
155
156 let in_dims = self.shape().dims.clone();
157 let reduce_size = in_dims[dim];
158 assert!(reduce_size > 0, "cannot argmin over empty dimension");
159
160 // Build output shape
161 let mut out_dims = in_dims.clone();
162 if keepdim {
163 out_dims[dim] = 1;
164 } else {
165 out_dims.remove(dim);
166 }
167 if out_dims.is_empty() {
168 out_dims.push(1);
169 }
170
171 let mut out = Tensor::zeros(out_dims.clone());
172
173 // Use stride-aware approach to handle non-contiguous tensors correctly
174 let out_size = out.size();
175
176 unsafe {
177 let dst = out.as_mut_ptr();
178
179 // Iterate over all output positions
180 for out_idx in 0..out_size {
181 // Convert flat output index to multi-dimensional coordinates
182 let mut out_coords = vec![0; out_dims.len()];
183 let mut tmp = out_idx;
184 for k in (0..out_dims.len()).rev() {
185 out_coords[k] = tmp % out_dims[k];
186 tmp /= out_dims[k];
187 }
188
189 // Convert output coordinates to input coordinates
190 let mut in_coords = vec![0; rank];
191 if keepdim {
192 // When keepdim=true, output coords map directly to input coords
193 for (k, &out_coord) in out_coords.iter().enumerate() {
194 if k == dim {
195 in_coords[k] = 0; // Will be set in the loop below
196 } else {
197 in_coords[k] = out_coord;
198 }
199 }
200 } else {
201 // When keepdim=false, we need to insert the missing dimension
202 let mut out_coord_idx = 0;
203 for (k, in_coord) in in_coords.iter_mut().enumerate() {
204 if k == dim {
205 *in_coord = 0; // Will be set in the loop below
206 } else {
207 *in_coord = out_coords[out_coord_idx];
208 out_coord_idx += 1;
209 }
210 }
211 }
212
213 // Find the argmin along the specified dimension
214 let mut best_val = f32::INFINITY;
215 let mut best_j = 0usize;
216
217 for j in 0..reduce_size {
218 in_coords[dim] = j;
219 let in_offset = self.shape().offset(&in_coords);
220 let v = *self.as_ptr().add(in_offset);
221 if v < best_val {
222 best_val = v;
223 best_j = j;
224 }
225 }
226
227 *dst.add(out_idx) = best_j as f32;
228 }
229 }
230
231 out
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    // Tests are organized in three levels: basic contiguous tensors (level 1),
    // larger shapes / special values / ties (level 2), and non-contiguous
    // views such as transposes, selects, and permutes (level 3).

    // Level 1 Tests: Basic functionality with simple contiguous tensors
    #[test]
    fn test_argmin_level1_basic_1d() {
        // Simple 1D case
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -2.0 is at index 1
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values - should return first occurrence
        let x = Tensor::from_slice(&[5.0, 5.0, 5.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-1.0, -5.0, -2.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -5.0 is at index 1
    }

    #[test]
    fn test_argmin_level1_basic_2d_contiguous() {
        // Simple 2D case - whole tensor argmin
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // -3.0 is at flat index 5
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_dim_2d_basic() {
        // Test argmin_dim with 2D tensor
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        // Tensor looks like:
        // [[3.0, -2.0, 5.0],
        //  [-1.0, 0.0, -3.0]]

        // Along dimension 1 (columns), keepdim=true
        let idx1 = x.argmin_dim(1, true);
        assert_eq!(idx1.shape().dims, vec![2, 1]);
        assert_eq!(idx1.get(&[0, 0]), 1.0); // Row 0: -2.0 is at column index 1
        assert_eq!(idx1.get(&[1, 0]), 2.0); // Row 1: -3.0 is at column index 2

        // Along dimension 1 (columns), keepdim=false
        let idx1_no_keep = x.argmin_dim(1, false);
        assert_eq!(idx1_no_keep.shape().dims, vec![2]);
        assert_eq!(idx1_no_keep.get(&[0]), 1.0);
        assert_eq!(idx1_no_keep.get(&[1]), 2.0);

        // Along dimension 0 (rows), keepdim=true
        let idx0 = x.argmin_dim(0, true);
        assert_eq!(idx0.shape().dims, vec![1, 3]);
        assert_eq!(idx0.get(&[0, 0]), 1.0); // Column 0: -1.0 is at row index 1
        assert_eq!(idx0.get(&[0, 1]), 0.0); // Column 1: -2.0 is at row index 0
        assert_eq!(idx0.get(&[0, 2]), 1.0); // Column 2: -3.0 is at row index 1
    }

    #[test]
    fn test_argmin_level1_3d_basic() {
        // Test with 3D tensor
        let data = vec![
            1.0, -2.0, // [0,0,:] = [1.0, -2.0]
            3.0, 4.0, // [0,1,:] = [3.0, 4.0]
            -5.0, 6.0, // [1,0,:] = [-5.0, 6.0]
            7.0, -8.0, // [1,1,:] = [7.0, -8.0]
        ];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Whole tensor argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 7.0); // -8.0 is at flat index 7

        // Along dimension 2 (innermost), keepdim=false
        let idx2 = x.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, -2.0] -> min at index 1
        assert_eq!(idx2.get(&[0, 1]), 0.0); // [3.0, 4.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 0]), 0.0); // [-5.0, 6.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, -8.0] -> min at index 1
    }

    // Level 2 Tests: Complex shapes, higher dimensions, and edge cases
    #[test]
    fn test_argmin_level2_large_tensors() {
        // Test with larger tensors
        let data: Vec<f32> = (0..1000).map(|i| (i as f32) * 0.1 - 50.0).collect();
        // Values from -50.0 to 49.9, minimum at index 0
        let x = Tensor::from_slice(&data, vec![1000]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Reshape to 2D
        let x_2d = Tensor::from_slice(&data, vec![25, 40]).unwrap();
        let idx_2d = x_2d.argmin();
        assert_eq!(idx_2d.get(&[0]), 0.0);

        // Test along different dimensions
        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![40]);
        assert_eq!(idx_dim0.get(&[0]), 0.0); // Column 0: minimum at row 0

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![25]);
        assert_eq!(idx_dim1.get(&[0]), 0.0); // Row 0: minimum at column 0
    }

    #[test]
    fn test_argmin_level2_4d_tensor() {
        // Test with 4D tensor [2, 3, 4, 5] = 120 elements
        let data: Vec<f32> = (0..120).map(|i| 120.0 - i as f32).collect();
        // Values from 120.0 down to 1.0, minimum at last index
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();

        // Global argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 119.0); // minimum value 1.0 is at index 119

        // Test argmin along dimension 3 (innermost)
        let idx3 = x.argmin_dim(3, false);
        assert_eq!(idx3.shape().dims, vec![2, 3, 4]);
        // Each slice along dim 3 has values decreasing, so min is always at index 4
        assert_eq!(idx3.get(&[0, 0, 0]), 4.0);
        assert_eq!(idx3.get(&[1, 2, 3]), 4.0);

        // Test argmin along dimension 0 (outermost)
        let idx0 = x.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4, 5]);
        // For each position, the minimum is in the second batch (index 1)
        assert_eq!(idx0.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx0.get(&[2, 3, 4]), 1.0);
    }

    #[test]
    fn test_argmin_level2_special_values() {
        // Test with special floating point values
        let data = vec![
            f32::NAN,          // 0
            f32::INFINITY,     // 1
            -f32::INFINITY,    // 2 <- this should be minimum
            0.0,               // 3
            -0.0,              // 4
            1.0,               // 5
        ];
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // -infinity at index 2

        // Test with all NaN
        let nan_data = vec![f32::NAN, f32::NAN, f32::NAN];
        let x_nan = Tensor::from_slice(&nan_data, vec![3]).unwrap();
        let idx_nan = x_nan.argmin();
        // With all NaN, should return first index
        assert_eq!(idx_nan.get(&[0]), 0.0);

        // Test with mix of normal values and NaN
        let mixed_data = vec![1.0, f32::NAN, -5.0, f32::NAN, 3.0];
        let x_mixed = Tensor::from_slice(&mixed_data, vec![5]).unwrap();
        let idx_mixed = x_mixed.argmin();
        assert_eq!(idx_mixed.get(&[0]), 2.0); // -5.0 at index 2
    }

    #[test]
    fn test_argmin_level2_ties() {
        // Test behavior with tied minimum values (should return first occurrence)
        let data = vec![3.0, -2.0, 5.0, -2.0, 0.0, -2.0]; // -2.0 appears at indices 1, 3, 5
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of -2.0

        // Test with 2D tensor and ties
        let x_2d = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        // [[3.0, -2.0, 5.0],
        //  [-2.0, 0.0, -2.0]]

        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![3]);
        assert_eq!(idx_dim0.get(&[0]), 1.0); // Column 0: min(-2.0 vs 3.0) -> row 1
        assert_eq!(idx_dim0.get(&[1]), 0.0); // Column 1: min(-2.0 vs 0.0) -> row 0
        assert_eq!(idx_dim0.get(&[2]), 1.0); // Column 2: min(5.0 vs -2.0) -> row 1

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2]);
        assert_eq!(idx_dim1.get(&[0]), 1.0); // Row 0: min of [3.0, -2.0, 5.0] -> col 1
        assert_eq!(idx_dim1.get(&[1]), 0.0); // Row 1: min of [-2.0, 0.0, -2.0] -> col 0 (first)
    }

    #[test]
    fn test_argmin_level2_broadcasting_dims() {
        // Test with dimensions of size 1 (singleton dimensions)
        let data = vec![5.0, -3.0, 7.0, 1.0, -8.0, 2.0];
        let x = Tensor::from_slice(&data, vec![1, 6, 1]).unwrap();

        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 4.0); // -8.0 at flat index 4

        // Test argmin along different dimensions
        let idx_dim0 = x.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![6, 1]);

        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![1, 1]);
        assert_eq!(idx_dim1.get(&[0, 0]), 4.0); // -8.0 at position 4 along dim 1

        let idx_dim2 = x.argmin_dim(2, false);
        assert_eq!(idx_dim2.shape().dims, vec![1, 6]);
    }

    #[test]
    fn test_argmin_level2_complex_3d() {
        // Complex 3D case with multiple batch dimensions
        let data = vec![
            // Batch 0, Channel 0: [[1, 2], [3, 4]]
            1.0, 2.0, 3.0, 4.0, // Batch 0, Channel 1: [[5, 6], [7, 8]]
            5.0, 6.0, 7.0, 8.0, // Batch 0, Channel 2: [[-1, 0], [9, 10]]
            -1.0, 0.0, 9.0, 10.0, // Batch 1, Channel 0: [[11, 12], [13, 14]]
            11.0, 12.0, 13.0, 14.0, // Batch 1, Channel 1: [[15, 16], [17, 18]]
            15.0, 16.0, 17.0, 18.0, // Batch 1, Channel 2: [[19, 20], [21, -5]]
            19.0, 20.0, 21.0, -5.0,
        ];
        let x = Tensor::from_slice(&data, vec![2, 3, 2, 2]).unwrap();

        // Global minimum
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 23.0); // -5.0 is at flat index 23

        // Argmin along dimension 1 (channels)
        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2, 2, 2]);
        // At position [0,0,0]: min(1.0, 5.0, -1.0) = -1.0 at channel 2
        assert_eq!(idx_dim1.get(&[0, 0, 0]), 2.0);
        // At position [1,1,1]: min(14.0, 18.0, -5.0) = -5.0 at channel 2
        assert_eq!(idx_dim1.get(&[1, 1, 1]), 2.0);
    }

    // Level 3 Tests: Non-contiguous tensors, views, and strided memory layouts
    #[test]
    fn test_argmin_level3_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, -5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, -5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, -5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmin on transposed view
        let idx = x_t.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value -5.0

        // Test argmin along dim=0 of transposed tensor
        let idx0 = x_t.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 0.0); // col 0: [1.0, 3.0, 2.0] -> min 1.0 at index 0
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, -5.0] -> min -5.0 at index 2

        // Test argmin along dim=1 of transposed tensor
        let idx1 = x_t.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 0.0); // row 0: [1.0, 4.0] -> min 1.0 at index 0
        assert_eq!(idx1.get(&[1]), 1.0); // row 1: [3.0, 0.0] -> min 0.0 at index 1
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, -5.0] -> min -5.0 at index 1
    }

    #[test]
    fn test_argmin_level3_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, -6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, -6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // index 1 has value -6.0

        // Test argmin_dim on 1D slice (should work the same as global argmin)
        let idx_dim = middle_row.argmin_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 1.0);

        // Test with column slice
        let second_col = x.select(1, 1);
        // [2, -6, 10]
        assert_eq!(second_col.shape().dims, vec![3]);
        let idx_col = second_col.argmin();
        assert_eq!(idx_col.get(&[0]), 1.0); // -6.0 at index 1
    }

    #[test]
    fn test_argmin_level3_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| 24.0 - i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 24.0 down to 1.0
        // Minimum value 1.0 is at the last position

        // Permute to [4, 3, 2] (reverse the dimension order)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmin should still find the minimum value (1.0)
        let idx = x_perm.argmin();
        assert_eq!(idx.get(&[0]), 23.0); // The min value 1.0 is still at flat index 23

        // Test argmin along each dimension of permuted tensor
        let idx0 = x_perm.argmin_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmin_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmin_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);

        // Verify some specific values
        // Since values decrease from 24.0 to 1.0, the permuted tensor should have
        // minimum values at the later positions in the original ordering
    }

    #[test]
    fn test_argmin_level3_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, -3.0, 4.0, 5.0, 6.0, 7.0, -8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, -8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value -8.0
    }

    #[test]
    fn test_argmin_level3_strided_memory() {
        // Test with highly strided memory patterns
        let data: Vec<f32> = (0..60).map(|i| i as f32 - 30.0).collect();
        let x = Tensor::from_slice(&data, vec![3, 4, 5]).unwrap();
        // Values from -30.0 to 29.0

        // Create complex views that result in non-contiguous memory
        let x_perm = x.permute(vec![2, 0, 1]); // [5, 3, 4]
        assert!(!x_perm.is_contiguous());

        // Test global argmin
        let idx = x_perm.argmin();
        assert_eq!(idx.get(&[0]), 0.0); // -30.0 is at index 0

        // Test dimension-wise argmin on permuted tensor
        let idx0 = x_perm.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4]);

        let idx1 = x_perm.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![5, 4]);

        let idx2 = x_perm.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![5, 3]);
    }

    #[test]
    fn test_argmin_level3_multiple_transformations() {
        // Test with multiple chained transformations
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, -24.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 6]).unwrap();

        // Chain multiple transformations
        let x_t = x.transpose(0, 1); // [6, 4]
        let x_subset = x_t.select(0, 5); // Select last row: [6, 12, 18, -24]

        // Note: select might create contiguous tensors in some cases, so we don't assert non-contiguous
        assert_eq!(x_subset.shape().dims, vec![4]);

        let idx = x_subset.argmin();
        assert_eq!(idx.get(&[0]), 3.0); // -24.0 at index 3

        // Test on a slice of the transposed tensor
        let partial_col = x_t.select(1, 2); // Select third column of x_t = row 2 of x: [13, 14, 15, 16, 17, 18]
        let idx_partial = partial_col.argmin();
        assert_eq!(idx_partial.get(&[0]), 0.0); // smallest value (13.0) at index 0

        // Test argmin on the non-contiguous transposed tensor
        assert!(!x_t.is_contiguous());
        let idx_trans = x_t.argmin();
        assert_eq!(idx_trans.get(&[0]), 23.0); // -24.0 is still at flat index 23
    }

    #[test]
    fn test_argmin_level3_view_consistency() {
        // Test that argmin results are consistent between original and view
        let data = vec![
            5.0, -2.0, 8.0, 1.0, // row 0: min -2.0 at col 1
            3.0, 9.0, -4.0, 7.0, // row 1: min -4.0 at col 2
            6.0, 0.0, 2.0, -1.0, // row 2: min -1.0 at col 3
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // Global minimum is -4.0 at flat index 6

        // Test argmin on original tensor
        let idx_orig = x.argmin();
        assert_eq!(idx_orig.get(&[0]), 6.0); // -4.0 at index 6

        // Create a view by transposing and test consistency
        let x_t = x.transpose(0, 1);
        // Transposed tensor:
        // [[5.0, 3.0, 6.0],   // col 0 of original -> row 0: min 3.0 at index 1
        //  [-2.0, 9.0, 0.0],  // col 1 of original -> row 1: min -2.0 at index 0
        //  [8.0, -4.0, 2.0],  // col 2 of original -> row 2: min -4.0 at index 1
        //  [1.0, 7.0, -1.0]]  // col 3 of original -> row 3: min -1.0 at index 2

        let idx_view = x_t.argmin();
        // The minimum value is still -4.0, but its flat index in the view may differ
        // Let's just check that both find the minimum value correctly

        // Extract actual minimum values to verify they're the same
        let min_val_orig = unsafe {
            let flat_idx = idx_orig.get(&[0]) as usize;
            *x.as_ptr().add(flat_idx)
        };
        let min_val_view = unsafe {
            let flat_idx = idx_view.get(&[0]) as usize;
            let dims = x_t.shape().dims.clone();
            let mut coords = vec![0; dims.len()];
            let mut tmp = flat_idx;
            for k in (0..dims.len()).rev() {
                coords[k] = tmp % dims[k];
                tmp /= dims[k];
            }
            let offset = x_t.shape().offset(&coords);
            *x_t.as_ptr().add(offset)
        };

        assert_eq!(min_val_orig, -4.0);
        assert_eq!(min_val_view, -4.0);

        // Test simpler consistency: argmin along specific dimensions
        let idx_dim0_orig = x.argmin_dim(0, false); // argmin along rows -> [4] (min of each column)
        let idx_dim1_trans = x_t.argmin_dim(1, false); // argmin along columns -> [4] (min of each row)

        // These should give the same results since we're reducing along corresponding dims
        assert_eq!(idx_dim0_orig.shape().dims, vec![4]);
        assert_eq!(idx_dim1_trans.shape().dims, vec![4]);

        // Original columns vs transposed rows should match
        assert_eq!(idx_dim0_orig.get(&[0]), 1.0); // col 0: min(5,3,6) = 3 at row 1
        assert_eq!(idx_dim0_orig.get(&[1]), 0.0); // col 1: min(-2,9,0) = -2 at row 0
        assert_eq!(idx_dim0_orig.get(&[2]), 1.0); // col 2: min(8,-4,2) = -4 at row 1
        assert_eq!(idx_dim0_orig.get(&[3]), 2.0); // col 3: min(1,7,-1) = -1 at row 2

        assert_eq!(idx_dim1_trans.get(&[0]), 1.0); // corresponds to col 0
        assert_eq!(idx_dim1_trans.get(&[1]), 0.0); // corresponds to col 1
        assert_eq!(idx_dim1_trans.get(&[2]), 1.0); // corresponds to col 2
        assert_eq!(idx_dim1_trans.get(&[3]), 2.0); // corresponds to col 3
    }

    // Keep the old basic tests for compatibility
    #[test]
    fn test_argmin_basic() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        unsafe {
            assert_eq!(*idx.as_ptr(), 1.0);
        }
    }

    #[test]
    fn test_argmin_dim() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx0 = x.argmin_dim(1, true);
        assert_eq!(idx0.shape().dims, vec![2, 1]);
        assert_eq!(idx0.get(&[0, 0]), 1.0);
        assert_eq!(idx0.get(&[1, 0]), 2.0);
    }
}
737}