//! train_station/tensor/reductions/argmin.rs
//!
//! Flat (`argmin`) and per-dimension (`argmin_dim`) minimum-index reductions for `Tensor`.
1use crate::tensor::core::Tensor;
2
3impl Tensor {
4 /// Returns the index of the minimum value in the tensor
5 ///
6 /// This method finds the flat index of the minimum value across all elements
7 /// in the tensor. The result is a scalar tensor containing the index as a
8 /// floating-point value. This operation is non-differentiable and the output
9 /// never requires gradient tracking.
10 ///
11 /// # Returns
12 ///
13 /// A tensor with shape `[1]` containing the flat index of the minimum value
14 /// as a `f32`. If the input tensor is empty, returns `0.0`.
15 ///
16 /// # Examples
17 ///
18 /// ```
19 /// use train_station::Tensor;
20 ///
21 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
22 /// let min_index = tensor.argmin();
23 /// assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
24 /// ```
25 ///
26 /// ```
27 /// use train_station::Tensor;
28 ///
29 /// // Empty tensor case
30 /// let empty_tensor = Tensor::new(vec![0]);
31 /// let min_index = empty_tensor.argmin();
32 /// assert_eq!(min_index.get(&[0]), 0.0);
33 /// ```
34 pub fn argmin(&self) -> Tensor {
35 let mut out = Tensor::new(vec![1]);
36 if self.size() == 0 {
37 out.fill(0.0);
38 return out;
39 }
40
41 let mut best_val = f32::INFINITY;
42 let mut best_idx = 0usize;
43
44 if self.is_contiguous() {
45 // Fast path for contiguous tensors
46 unsafe {
47 let src = self.as_ptr();
48 for i in 0..self.size() {
49 let v = *src.add(i);
50 if v < best_val {
51 best_val = v;
52 best_idx = i;
53 }
54 }
55 }
56 } else {
57 // Stride-aware path for non-contiguous tensors
58 let dims = self.shape().dims.clone();
59 for flat_idx in 0..self.size() {
60 // Convert flat index to multi-dimensional coordinates
61 let mut coords = vec![0; dims.len()];
62 let mut tmp = flat_idx;
63 for k in (0..dims.len()).rev() {
64 coords[k] = tmp % dims[k];
65 tmp /= dims[k];
66 }
67
68 // Get value using stride-aware offset
69 let offset = self.shape().offset(&coords);
70 let v = unsafe { *self.as_ptr().add(offset) };
71 if v < best_val {
72 best_val = v;
73 best_idx = flat_idx;
74 }
75 }
76 }
77
78 unsafe {
79 *out.as_mut_ptr() = best_idx as f32;
80 }
81 out
82 }
83
84 /// Returns the indices of minimum values along a specified dimension
85 ///
86 /// This method finds the indices of minimum values along the specified dimension.
87 /// The result contains the indices where the minimum values occur in that dimension.
88 /// This operation is non-differentiable and the output never requires gradient tracking.
89 ///
90 /// # Arguments
91 ///
92 /// * `dim` - The dimension along which to find minimum indices (0-based)
93 /// * `keepdim` - Whether to keep the reduced dimension in the output shape
94 /// - If `true`, the reduced dimension is kept with size 1
95 /// - If `false`, the reduced dimension is removed from the output shape
96 ///
97 /// # Returns
98 ///
99 /// A tensor containing the indices of minimum values along the specified dimension.
100 /// The output shape depends on `keepdim`:
101 /// - If `keepdim` is `true`, the reduced dimension has size 1
102 /// - If `keepdim` is `false`, the reduced dimension is removed
103 ///
104 /// # Panics
105 ///
106 /// * If `dim` is out of bounds for the tensor's rank
107 /// * If the dimension to reduce has size 0
108 ///
109 /// # Examples
110 ///
111 /// ```
112 /// use train_station::Tensor;
113 ///
114 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
115 ///
116 /// // Find minimum indices along dimension 1 (columns), keeping the dimension
117 /// let indices = tensor.argmin_dim(1, true);
118 /// assert_eq!(indices.shape().dims, vec![2, 1]);
119 /// assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
120 /// assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
121 /// ```
122 ///
123 /// ```
124 /// use train_station::Tensor;
125 ///
126 /// let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
127 ///
128 /// // Find minimum indices along dimension 1 (columns), removing the dimension
129 /// let indices = tensor.argmin_dim(1, false);
130 /// assert_eq!(indices.shape().dims, vec![2]);
131 /// assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
132 /// assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
133 /// ```
134 ///
135 /// ```
136 /// use train_station::Tensor;
137 ///
138 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
139 ///
140 /// // Find minimum index in a 1D tensor
141 /// let index = tensor.argmin_dim(0, false);
142 /// assert_eq!(index.shape().dims, vec![1]);
143 /// assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
144 /// ```
145 pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor {
146 let rank = self.shape().rank();
147 assert!(
148 dim < rank,
149 "argmin_dim dim {} out of bounds for rank {}",
150 dim,
151 rank
152 );
153
154 let in_dims = self.shape().dims.clone();
155 let reduce_size = in_dims[dim];
156 assert!(reduce_size > 0, "cannot argmin over empty dimension");
157
158 // Build output shape
159 let mut out_dims = in_dims.clone();
160 if keepdim {
161 out_dims[dim] = 1;
162 } else {
163 out_dims.remove(dim);
164 }
165 if out_dims.is_empty() {
166 out_dims.push(1);
167 }
168
169 let mut out = Tensor::zeros(out_dims.clone());
170
171 // Use stride-aware approach to handle non-contiguous tensors correctly
172 let out_size = out.size();
173
174 unsafe {
175 let dst = out.as_mut_ptr();
176
177 // Iterate over all output positions
178 for out_idx in 0..out_size {
179 // Convert flat output index to multi-dimensional coordinates
180 let mut out_coords = vec![0; out_dims.len()];
181 let mut tmp = out_idx;
182 for k in (0..out_dims.len()).rev() {
183 out_coords[k] = tmp % out_dims[k];
184 tmp /= out_dims[k];
185 }
186
187 // Convert output coordinates to input coordinates
188 let mut in_coords = vec![0; rank];
189 if keepdim {
190 // When keepdim=true, output coords map directly to input coords
191 for (k, &out_coord) in out_coords.iter().enumerate() {
192 if k == dim {
193 in_coords[k] = 0; // Will be set in the loop below
194 } else {
195 in_coords[k] = out_coord;
196 }
197 }
198 } else {
199 // When keepdim=false, we need to insert the missing dimension
200 let mut out_coord_idx = 0;
201 for (k, in_coord) in in_coords.iter_mut().enumerate() {
202 if k == dim {
203 *in_coord = 0; // Will be set in the loop below
204 } else {
205 *in_coord = out_coords[out_coord_idx];
206 out_coord_idx += 1;
207 }
208 }
209 }
210
211 // Find the argmin along the specified dimension
212 let mut best_val = f32::INFINITY;
213 let mut best_j = 0usize;
214
215 for j in 0..reduce_size {
216 in_coords[dim] = j;
217 let in_offset = self.shape().offset(&in_coords);
218 let v = *self.as_ptr().add(in_offset);
219 if v < best_val {
220 best_val = v;
221 best_j = j;
222 }
223 }
224
225 *dst.add(out_idx) = best_j as f32;
226 }
227 }
228
229 out
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    // Test organization:
    //   Level 1 — basic functionality on small contiguous tensors
    //   Level 2 — larger shapes, higher ranks, special values, ties
    //   Level 3 — non-contiguous tensors: transposes, selects, permutations

    // Level 1 Tests: Basic functionality with simple contiguous tensors
    #[test]
    fn test_argmin_level1_basic_1d() {
        // Simple 1D case
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -2.0 is at index 1
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_basic_1d_edge_cases() {
        // Single element
        let x = Tensor::from_slice(&[42.0], vec![1]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // All same values - should return first occurrence
        let x = Tensor::from_slice(&[5.0, 5.0, 5.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Negative values
        let x = Tensor::from_slice(&[-1.0, -5.0, -2.0], vec![3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // -5.0 is at index 1
    }

    #[test]
    fn test_argmin_level1_basic_2d_contiguous() {
        // Simple 2D case - whole tensor argmin
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // -3.0 is at flat index 5
        assert_eq!(idx.shape().dims, vec![1]);
    }

    #[test]
    fn test_argmin_level1_dim_2d_basic() {
        // Test argmin_dim with 2D tensor
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        // Tensor looks like:
        // [[3.0, -2.0, 5.0],
        //  [-1.0, 0.0, -3.0]]

        // Along dimension 1 (columns), keepdim=true
        let idx1 = x.argmin_dim(1, true);
        assert_eq!(idx1.shape().dims, vec![2, 1]);
        assert_eq!(idx1.get(&[0, 0]), 1.0); // Row 0: -2.0 is at column index 1
        assert_eq!(idx1.get(&[1, 0]), 2.0); // Row 1: -3.0 is at column index 2

        // Along dimension 1 (columns), keepdim=false
        let idx1_no_keep = x.argmin_dim(1, false);
        assert_eq!(idx1_no_keep.shape().dims, vec![2]);
        assert_eq!(idx1_no_keep.get(&[0]), 1.0);
        assert_eq!(idx1_no_keep.get(&[1]), 2.0);

        // Along dimension 0 (rows), keepdim=true
        let idx0 = x.argmin_dim(0, true);
        assert_eq!(idx0.shape().dims, vec![1, 3]);
        assert_eq!(idx0.get(&[0, 0]), 1.0); // Column 0: -1.0 is at row index 1
        assert_eq!(idx0.get(&[0, 1]), 0.0); // Column 1: -2.0 is at row index 0
        assert_eq!(idx0.get(&[0, 2]), 1.0); // Column 2: -3.0 is at row index 1
    }

    #[test]
    fn test_argmin_level1_3d_basic() {
        // Test with 3D tensor
        let data = vec![
            1.0, -2.0, // [0,0,:] = [1.0, -2.0]
            3.0, 4.0, // [0,1,:] = [3.0, 4.0]
            -5.0, 6.0, // [1,0,:] = [-5.0, 6.0]
            7.0, -8.0, // [1,1,:] = [7.0, -8.0]
        ];
        let x = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();

        // Whole tensor argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 7.0); // -8.0 is at flat index 7

        // Along dimension 2 (innermost), keepdim=false
        let idx2 = x.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![2, 2]);
        assert_eq!(idx2.get(&[0, 0]), 1.0); // [1.0, -2.0] -> min at index 1
        assert_eq!(idx2.get(&[0, 1]), 0.0); // [3.0, 4.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 0]), 0.0); // [-5.0, 6.0] -> min at index 0
        assert_eq!(idx2.get(&[1, 1]), 1.0); // [7.0, -8.0] -> min at index 1
    }

    // Level 2 Tests: Complex shapes, higher dimensions, and edge cases
    #[test]
    fn test_argmin_level2_large_tensors() {
        // Test with larger tensors
        let data: Vec<f32> = (0..1000).map(|i| (i as f32) * 0.1 - 50.0).collect();
        // Values from -50.0 to 49.9, minimum at index 0
        let x = Tensor::from_slice(&data, vec![1000]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 0.0);

        // Reshape to 2D
        let x_2d = Tensor::from_slice(&data, vec![25, 40]).unwrap();
        let idx_2d = x_2d.argmin();
        assert_eq!(idx_2d.get(&[0]), 0.0);

        // Test along different dimensions
        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![40]);
        assert_eq!(idx_dim0.get(&[0]), 0.0); // Column 0: minimum at row 0

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![25]);
        assert_eq!(idx_dim1.get(&[0]), 0.0); // Row 0: minimum at column 0
    }

    #[test]
    fn test_argmin_level2_4d_tensor() {
        // Test with 4D tensor [2, 3, 4, 5] = 120 elements
        let data: Vec<f32> = (0..120).map(|i| 120.0 - i as f32).collect();
        // Values from 120.0 down to 1.0, minimum at last index
        let x = Tensor::from_slice(&data, vec![2, 3, 4, 5]).unwrap();

        // Global argmin
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 119.0); // minimum value 1.0 is at index 119

        // Test argmin along dimension 3 (innermost)
        let idx3 = x.argmin_dim(3, false);
        assert_eq!(idx3.shape().dims, vec![2, 3, 4]);
        // Each slice along dim 3 has values decreasing, so min is always at index 4
        assert_eq!(idx3.get(&[0, 0, 0]), 4.0);
        assert_eq!(idx3.get(&[1, 2, 3]), 4.0);

        // Test argmin along dimension 0 (outermost)
        let idx0 = x.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4, 5]);
        // For each position, the minimum is in the second batch (index 1)
        assert_eq!(idx0.get(&[0, 0, 0]), 1.0);
        assert_eq!(idx0.get(&[2, 3, 4]), 1.0);
    }

    #[test]
    fn test_argmin_level2_special_values() {
        // Test with special floating point values
        let data = vec![
            f32::NAN,          // 0
            f32::INFINITY,     // 1
            -f32::INFINITY,    // 2 <- this should be minimum
            0.0,               // 3
            -0.0,              // 4
            1.0,               // 5
        ];
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // -infinity at index 2

        // Test with all NaN
        let nan_data = vec![f32::NAN, f32::NAN, f32::NAN];
        let x_nan = Tensor::from_slice(&nan_data, vec![3]).unwrap();
        let idx_nan = x_nan.argmin();
        // With all NaN, should return first index
        assert_eq!(idx_nan.get(&[0]), 0.0);

        // Test with mix of normal values and NaN
        let mixed_data = vec![1.0, f32::NAN, -5.0, f32::NAN, 3.0];
        let x_mixed = Tensor::from_slice(&mixed_data, vec![5]).unwrap();
        let idx_mixed = x_mixed.argmin();
        assert_eq!(idx_mixed.get(&[0]), 2.0); // -5.0 at index 2
    }

    #[test]
    fn test_argmin_level2_ties() {
        // Test behavior with tied minimum values (should return first occurrence)
        let data = vec![3.0, -2.0, 5.0, -2.0, 0.0, -2.0]; // -2.0 appears at indices 1, 3, 5
        let x = Tensor::from_slice(&data, vec![6]).unwrap();
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // First occurrence of -2.0

        // Test with 2D tensor and ties
        let x_2d = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        // [[3.0, -2.0, 5.0],
        //  [-2.0, 0.0, -2.0]]

        let idx_dim0 = x_2d.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![3]);
        assert_eq!(idx_dim0.get(&[0]), 1.0); // Column 0: min(-2.0 vs 3.0) -> row 1
        assert_eq!(idx_dim0.get(&[1]), 0.0); // Column 1: min(-2.0 vs 0.0) -> row 0
        assert_eq!(idx_dim0.get(&[2]), 1.0); // Column 2: min(5.0 vs -2.0) -> row 1

        let idx_dim1 = x_2d.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2]);
        assert_eq!(idx_dim1.get(&[0]), 1.0); // Row 0: min of [3.0, -2.0, 5.0] -> col 1
        assert_eq!(idx_dim1.get(&[1]), 0.0); // Row 1: min of [-2.0, 0.0, -2.0] -> col 0 (first)
    }

    #[test]
    fn test_argmin_level2_broadcasting_dims() {
        // Test with dimensions of size 1 (singleton dimensions)
        let data = vec![5.0, -3.0, 7.0, 1.0, -8.0, 2.0];
        let x = Tensor::from_slice(&data, vec![1, 6, 1]).unwrap();

        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 4.0); // -8.0 at flat index 4

        // Test argmin along different dimensions
        let idx_dim0 = x.argmin_dim(0, false);
        assert_eq!(idx_dim0.shape().dims, vec![6, 1]);

        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![1, 1]);
        assert_eq!(idx_dim1.get(&[0, 0]), 4.0); // -8.0 at position 4 along dim 1

        let idx_dim2 = x.argmin_dim(2, false);
        assert_eq!(idx_dim2.shape().dims, vec![1, 6]);
    }

    #[test]
    fn test_argmin_level2_complex_3d() {
        // Complex 3D case with multiple batch dimensions
        let data = vec![
            // Batch 0, Channel 0: [[1, 2], [3, 4]]
            1.0, 2.0, 3.0, 4.0, // Batch 0, Channel 1: [[5, 6], [7, 8]]
            5.0, 6.0, 7.0, 8.0, // Batch 0, Channel 2: [[-1, 0], [9, 10]]
            -1.0, 0.0, 9.0, 10.0, // Batch 1, Channel 0: [[11, 12], [13, 14]]
            11.0, 12.0, 13.0, 14.0, // Batch 1, Channel 1: [[15, 16], [17, 18]]
            15.0, 16.0, 17.0, 18.0, // Batch 1, Channel 2: [[19, 20], [21, -5]]
            19.0, 20.0, 21.0, -5.0,
        ];
        let x = Tensor::from_slice(&data, vec![2, 3, 2, 2]).unwrap();

        // Global minimum
        let idx = x.argmin();
        assert_eq!(idx.get(&[0]), 23.0); // -5.0 is at flat index 23

        // Argmin along dimension 1 (channels)
        let idx_dim1 = x.argmin_dim(1, false);
        assert_eq!(idx_dim1.shape().dims, vec![2, 2, 2]);
        // At position [0,0,0]: min(1.0, 5.0, -1.0) = -1.0 at channel 2
        assert_eq!(idx_dim1.get(&[0, 0, 0]), 2.0);
        // At position [1,1,1]: min(14.0, 18.0, -5.0) = -5.0 at channel 2
        assert_eq!(idx_dim1.get(&[1, 1, 1]), 2.0);
    }

    // Level 3 Tests: Non-contiguous tensors, views, and strided memory layouts
    #[test]
    fn test_argmin_level3_transpose_view() {
        // Create a 2x3 tensor and transpose it to get a non-contiguous view
        let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, -5.0], vec![2, 3]).unwrap();
        // Original: [[1.0, 3.0, 2.0],
        //            [4.0, 0.0, -5.0]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1.0, 4.0],
        //              [3.0, 0.0],
        //              [2.0, -5.0]]
        assert_eq!(x_t.shape().dims, vec![3, 2]);
        assert!(!x_t.is_contiguous()); // Should be a view

        // Test global argmin on transposed view
        let idx = x_t.argmin();
        assert_eq!(idx.get(&[0]), 5.0); // flat index 5 still points to value -5.0

        // Test argmin along dim=0 of transposed tensor
        let idx0 = x_t.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![2]);
        assert_eq!(idx0.get(&[0]), 0.0); // col 0: [1.0, 3.0, 2.0] -> min 1.0 at index 0
        assert_eq!(idx0.get(&[1]), 2.0); // col 1: [4.0, 0.0, -5.0] -> min -5.0 at index 2

        // Test argmin along dim=1 of transposed tensor
        let idx1 = x_t.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![3]);
        assert_eq!(idx1.get(&[0]), 0.0); // row 0: [1.0, 4.0] -> min 1.0 at index 0
        assert_eq!(idx1.get(&[1]), 1.0); // row 1: [3.0, 0.0] -> min 0.0 at index 1
        assert_eq!(idx1.get(&[2]), 1.0); // row 2: [2.0, -5.0] -> min -5.0 at index 1
    }

    #[test]
    fn test_argmin_level3_slice_view() {
        // Create a 3x4 tensor and take a slice
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // [[1, 2, 3, 4],
        //  [5, -6, 7, 8],
        //  [9, 10, 11, 12]]

        // Select middle row (creates a view)
        let middle_row = x.select(0, 1);
        // [5, -6, 7, 8]
        assert_eq!(middle_row.shape().dims, vec![4]);

        let idx = middle_row.argmin();
        assert_eq!(idx.get(&[0]), 1.0); // index 1 has value -6.0

        // Test argmin_dim on 1D slice (should work the same as global argmin)
        let idx_dim = middle_row.argmin_dim(0, false);
        assert_eq!(idx_dim.shape().dims, vec![1]);
        assert_eq!(idx_dim.get(&[0]), 1.0);

        // Test with column slice
        let second_col = x.select(1, 1);
        // [2, -6, 10]
        assert_eq!(second_col.shape().dims, vec![3]);
        let idx_col = second_col.argmin();
        assert_eq!(idx_col.get(&[0]), 1.0); // -6.0 at index 1
    }

    #[test]
    fn test_argmin_level3_permuted_3d() {
        // Test 3D tensor with permuted dimensions
        let data = (0..24).map(|i| 24.0 - i as f32).collect::<Vec<_>>();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
        // Shape [2, 3, 4] with values 24.0 down to 1.0
        // Minimum value 1.0 is at the last position

        // Permute with [2, 1, 0] to reverse the dimension order -> shape [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        // Global argmin should still find the minimum value (1.0)
        let idx = x_perm.argmin();
        assert_eq!(idx.get(&[0]), 23.0); // The min value 1.0 is still at flat index 23

        // Test argmin along each dimension of permuted tensor
        let idx0 = x_perm.argmin_dim(0, false); // [3, 2]
        assert_eq!(idx0.shape().dims, vec![3, 2]);

        let idx1 = x_perm.argmin_dim(1, false); // [4, 2]
        assert_eq!(idx1.shape().dims, vec![4, 2]);

        let idx2 = x_perm.argmin_dim(2, false); // [4, 3]
        assert_eq!(idx2.shape().dims, vec![4, 3]);

        // Verify some specific values
        // Since values decrease from 24.0 to 1.0, the permuted tensor should have
        // minimum values at the later positions in the original ordering
    }

    #[test]
    fn test_argmin_level3_nested_views() {
        // Test nested transformations (transpose then select)
        let data = vec![
            1.0, 2.0, -3.0, 4.0, 5.0, 6.0, 7.0, -8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 3]).unwrap();

        // First transpose, then select a row
        let x_t = x.transpose(0, 1); // [3, 4]
        let row = x_t.select(0, 1); // Select second row: [2, 5, -8, 11]
        assert_eq!(row.shape().dims, vec![4]);

        let idx = row.argmin();
        assert_eq!(idx.get(&[0]), 2.0); // index 2 has value -8.0
    }

    #[test]
    fn test_argmin_level3_strided_memory() {
        // Test with highly strided memory patterns
        let data: Vec<f32> = (0..60).map(|i| i as f32 - 30.0).collect();
        let x = Tensor::from_slice(&data, vec![3, 4, 5]).unwrap();
        // Values from -30.0 to 29.0

        // Create complex views that result in non-contiguous memory
        let x_perm = x.permute(vec![2, 0, 1]); // [5, 3, 4]
        assert!(!x_perm.is_contiguous());

        // Test global argmin
        let idx = x_perm.argmin();
        assert_eq!(idx.get(&[0]), 0.0); // -30.0 is at index 0

        // Test dimension-wise argmin on permuted tensor
        let idx0 = x_perm.argmin_dim(0, false);
        assert_eq!(idx0.shape().dims, vec![3, 4]);

        let idx1 = x_perm.argmin_dim(1, false);
        assert_eq!(idx1.shape().dims, vec![5, 4]);

        let idx2 = x_perm.argmin_dim(2, false);
        assert_eq!(idx2.shape().dims, vec![5, 3]);
    }

    #[test]
    fn test_argmin_level3_multiple_transformations() {
        // Test with multiple chained transformations
        let data = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
            17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, -24.0,
        ];
        let x = Tensor::from_slice(&data, vec![4, 6]).unwrap();

        // Chain multiple transformations
        let x_t = x.transpose(0, 1); // [6, 4]
        let x_subset = x_t.select(0, 5); // Select last row: [6, 12, 18, -24]

        // Note: select might create contiguous tensors in some cases, so we don't assert non-contiguous
        assert_eq!(x_subset.shape().dims, vec![4]);

        let idx = x_subset.argmin();
        assert_eq!(idx.get(&[0]), 3.0); // -24.0 at index 3

        // Test on a slice of the transposed tensor
        let partial_col = x_t.select(1, 2); // Third column of x_t = row 2 of x: [13, 14, 15, 16, 17, 18]
        let idx_partial = partial_col.argmin();
        assert_eq!(idx_partial.get(&[0]), 0.0); // smallest value 13.0 at index 0

        // Test argmin on the non-contiguous transposed tensor
        assert!(!x_t.is_contiguous());
        let idx_trans = x_t.argmin();
        assert_eq!(idx_trans.get(&[0]), 23.0); // -24.0 is still at flat index 23
    }

    #[test]
    fn test_argmin_level3_view_consistency() {
        // Test that argmin results are consistent between original and view
        let data = vec![
            5.0, -2.0, 8.0, 1.0, // row 0: min -2.0 at col 1
            3.0, 9.0, -4.0, 7.0, // row 1: min -4.0 at col 2
            6.0, 0.0, 2.0, -1.0, // row 2: min -1.0 at col 3
        ];
        let x = Tensor::from_slice(&data, vec![3, 4]).unwrap();
        // Global minimum is -4.0 at flat index 6

        // Test argmin on original tensor
        let idx_orig = x.argmin();
        assert_eq!(idx_orig.get(&[0]), 6.0); // -4.0 at index 6

        // Create a view by transposing and test consistency
        let x_t = x.transpose(0, 1);
        // Transposed tensor:
        // [[5.0, 3.0, 6.0],   // col 0 of original -> row 0: min 3.0 at index 1
        //  [-2.0, 9.0, 0.0],  // col 1 of original -> row 1: min -2.0 at index 0
        //  [8.0, -4.0, 2.0],  // col 2 of original -> row 2: min -4.0 at index 1
        //  [1.0, 7.0, -1.0]]  // col 3 of original -> row 3: min -1.0 at index 2

        let idx_view = x_t.argmin();
        // The minimum value is still -4.0, but its flat index in the view may differ
        // Let's just check that both find the minimum value correctly

        // Extract actual minimum values to verify they're the same
        let min_val_orig = unsafe {
            let flat_idx = idx_orig.get(&[0]) as usize;
            *x.as_ptr().add(flat_idx)
        };
        let min_val_view = unsafe {
            let flat_idx = idx_view.get(&[0]) as usize;
            let dims = x_t.shape().dims.clone();
            let mut coords = vec![0; dims.len()];
            let mut tmp = flat_idx;
            for k in (0..dims.len()).rev() {
                coords[k] = tmp % dims[k];
                tmp /= dims[k];
            }
            let offset = x_t.shape().offset(&coords);
            *x_t.as_ptr().add(offset)
        };

        assert_eq!(min_val_orig, -4.0);
        assert_eq!(min_val_view, -4.0);

        // Test simpler consistency: argmin along specific dimensions
        let idx_dim0_orig = x.argmin_dim(0, false); // argmin along rows -> [4] (min of each column)
        let idx_dim1_trans = x_t.argmin_dim(1, false); // argmin along columns -> [4] (min of each row)

        // These should give the same results since we're reducing along corresponding dims
        assert_eq!(idx_dim0_orig.shape().dims, vec![4]);
        assert_eq!(idx_dim1_trans.shape().dims, vec![4]);

        // Original columns vs transposed rows should match
        assert_eq!(idx_dim0_orig.get(&[0]), 1.0); // col 0: min(5,3,6) = 3 at row 1
        assert_eq!(idx_dim0_orig.get(&[1]), 0.0); // col 1: min(-2,9,0) = -2 at row 0
        assert_eq!(idx_dim0_orig.get(&[2]), 1.0); // col 2: min(8,-4,2) = -4 at row 1
        assert_eq!(idx_dim0_orig.get(&[3]), 2.0); // col 3: min(1,7,-1) = -1 at row 2

        assert_eq!(idx_dim1_trans.get(&[0]), 1.0); // corresponds to col 0
        assert_eq!(idx_dim1_trans.get(&[1]), 0.0); // corresponds to col 1
        assert_eq!(idx_dim1_trans.get(&[2]), 1.0); // corresponds to col 2
        assert_eq!(idx_dim1_trans.get(&[3]), 2.0); // corresponds to col 3
    }

    // Keep the old basic tests for compatibility
    #[test]
    fn test_argmin_basic() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
        let idx = x.argmin();
        // Read the single result element directly from the backing buffer.
        unsafe {
            assert_eq!(*idx.as_ptr(), 1.0);
        }
    }

    #[test]
    fn test_argmin_dim() {
        let x = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
        let idx0 = x.argmin_dim(1, true);
        assert_eq!(idx0.shape().dims, vec![2, 1]);
        assert_eq!(idx0.get(&[0, 0]), 1.0);
        assert_eq!(idx0.get(&[1, 0]), 2.0);
    }
}
735}