train_station/tensor/reductions/
min.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Computes the minimum value over all elements in the tensor
6    ///
7    /// Returns a scalar tensor containing the minimum value. For empty tensors,
8    /// returns positive infinity. This operation supports gradient tracking
9    /// through the GradTrack system.
10    ///
11    /// # Returns
12    ///
13    /// A tensor with shape `[1]` containing the minimum value
14    ///
15    /// # Examples
16    ///
17    /// ```
18    /// use train_station::Tensor;
19    ///
20    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21    /// let min_val = tensor.min();
22    /// assert_eq!(min_val.get(&[0]), 1.0);
23    /// ```
24    ///
25    /// ```
26    /// use train_station::Tensor;
27    ///
28    /// // Empty tensor case
29    /// let empty_tensor = Tensor::new(vec![0]);
30    /// let min_val = empty_tensor.min();
31    /// assert_eq!(min_val.get(&[0]), f32::INFINITY);
32    /// ```
33    ///
34    /// # GradTrack Support
35    ///
36    /// When `requires_grad` is true, this operation is tracked for automatic
37    /// differentiation. The gradient computation uses the saved input and output
38    /// for efficient backward pass.
39    #[track_caller]
40    pub fn min(&self) -> Tensor {
41        let mut out = Tensor::new(vec![1]);
42        if self.size() == 0 {
43            out.fill(f32::INFINITY);
44        } else {
45            let mut m = f32::INFINITY;
46
47            if self.is_contiguous() {
48                // Fast path for contiguous tensors
49                unsafe {
50                    let src = self.as_ptr();
51                    let size = self.size();
52                    m = *src;
53                    let mut i = 1usize;
54                    while i + 4 <= size {
55                        let x0 = *src.add(i);
56                        let x1 = *src.add(i + 1);
57                        let x2 = *src.add(i + 2);
58                        let x3 = *src.add(i + 3);
59                        m = m.min(x0).min(x1).min(x2).min(x3);
60                        i += 4;
61                    }
62                    while i < size {
63                        m = m.min(*src.add(i));
64                        i += 1;
65                    }
66                }
67            } else {
68                // Stride-aware path for non-contiguous tensors
69                let dims = self.shape().dims().to_vec();
70                for flat_idx in 0..self.size() {
71                    // Convert flat index to multi-dimensional coordinates
72                    let mut coords = vec![0; dims.len()];
73                    let mut tmp = flat_idx;
74                    for k in (0..dims.len()).rev() {
75                        coords[k] = tmp % dims[k];
76                        tmp /= dims[k];
77                    }
78
79                    // Get value using stride-aware offset
80                    let offset = self.shape().offset(&coords);
81                    let value = unsafe { *self.as_ptr().add(offset) };
82                    if flat_idx == 0 {
83                        m = value;
84                    } else {
85                        m = m.min(value);
86                    }
87                }
88            }
89
90            unsafe {
91                *out.as_mut_ptr() = m;
92            }
93        }
94
95        if self.requires_grad() {
96            let mut result = out.clone();
97            result.set_requires_grad_internal(true);
98            let grad_fn = GradFn::ReduceMin {
99                saved_output: Box::new(out.clone()),
100                saved_input: Box::new(self.clone()),
101                input_shape: self.shape().dims().to_vec(),
102            };
103            result.set_grad_fn(grad_fn.clone());
104            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
105            return result;
106        }
107
108        out
109    }
110
111    /// Computes the minimum value over specified dimensions
112    ///
113    /// Reduces the tensor along the specified dimensions by computing the minimum
114    /// value in each reduction group. The `keepdim` parameter determines whether
115    /// reduced dimensions are kept with size 1 or removed entirely.
116    ///
117    /// # Arguments
118    ///
119    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
120    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
121    ///
122    /// # Returns
123    ///
124    /// A tensor with the specified dimensions reduced
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// use train_station::Tensor;
130    ///
131    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
132    ///
133    /// // Min over columns (dim 1), keeping dimensions
134    /// let min_cols = tensor.min_dims(&[1], true);
135    /// assert_eq!(min_cols.shape().dims(), vec![2, 1]);
136    /// assert_eq!(min_cols.get(&[0, 0]), 1.0);
137    /// assert_eq!(min_cols.get(&[1, 0]), 4.0);
138    /// ```
139    ///
140    /// ```
141    /// use train_station::Tensor;
142    ///
143    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
144    ///
145    /// // Min over rows (dim 0), removing dimensions
146    /// let min_rows = tensor.min_dims(&[0], false);
147    /// assert_eq!(min_rows.shape().dims(), vec![3]);
148    /// assert_eq!(min_rows.get(&[0]), 1.0);
149    /// assert_eq!(min_rows.get(&[1]), 2.0);
150    /// assert_eq!(min_rows.get(&[2]), 3.0);
151    /// ```
152    ///
153    /// ```
154    /// use train_station::Tensor;
155    ///
156    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
157    ///
158    /// // Min over multiple dimensions
159    /// let min_all = tensor.min_dims(&[0, 1], false);
160    /// assert_eq!(min_all.shape().dims(), vec![1]);
161    /// assert_eq!(min_all.get(&[0]), 1.0);
162    /// ```
163    ///
164    /// # Panics
165    ///
166    /// Panics if:
167    /// * `dims` is empty
168    /// * Any dimension in `dims` is out of bounds for the tensor's rank
169    ///
170    /// # GradTrack Support
171    ///
172    /// When `requires_grad` is true, this operation is tracked for automatic
173    /// differentiation. The gradient computation preserves the original input
174    /// shape and handles broadcasting correctly.
175    #[track_caller]
176    pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
177        assert!(!dims.is_empty(), "min_dims requires at least one dimension");
178        let rank = self.shape().rank();
179        for &d in dims {
180            assert!(
181                d < rank,
182                "min_dims dim {} out of bounds for rank {}",
183                d,
184                rank
185            );
186        }
187
188        // Build output shape
189        let mut out_dims = self.shape().dims().to_vec();
190        let mut reduced: Vec<usize> = dims.to_vec();
191        reduced.sort_unstable();
192        reduced.dedup();
193        for &d in reduced.iter() {
194            out_dims[d] = if keepdim { 1 } else { 0 };
195        }
196        if !keepdim {
197            out_dims.retain(|&s| s != 0);
198        }
199        if out_dims.is_empty() {
200            out_dims.push(1);
201        }
202        let mut out = Tensor::zeros(out_dims.clone());
203
204        // Compute min along reduced dims
205        let in_shape = self.shape().dims().to_vec();
206        let out_rank = out.shape().rank();
207        let mut in_coords = vec![0usize; rank];
208        unsafe {
209            let dst = out.as_mut_ptr();
210            // Initialize output with +inf
211            for i in 0..out.size() {
212                *dst.add(i) = f32::INFINITY;
213            }
214            for lin in 0..self.size() {
215                let mut tmp = lin;
216                for i in (0..rank).rev() {
217                    let s = in_shape[i];
218                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
219                    if s != 0 {
220                        tmp /= s;
221                    }
222                }
223
224                // Get input value using stride-aware offset
225                let in_offset = self.shape().offset(&in_coords);
226                let val = *self.as_ptr().add(in_offset);
227
228                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
229                for (i, &c) in in_coords.iter().enumerate().take(rank) {
230                    if reduced.contains(&i) {
231                        if keepdim {
232                            out_coords.push(0);
233                        }
234                    } else {
235                        out_coords.push(c);
236                    }
237                }
238                let off = if out_coords.is_empty() {
239                    0
240                } else {
241                    out.shape().offset(&out_coords)
242                };
243                let cur = *dst.add(off);
244                if val < cur {
245                    *dst.add(off) = val;
246                }
247            }
248        }
249
250        if self.requires_grad() {
251            let mut result = out.clone();
252            result.set_requires_grad_internal(true);
253            let grad_fn = GradFn::ReduceMinDims {
254                dims: reduced,
255                keepdim,
256                input_shape: self.shape().dims().to_vec(),
257                saved_output: Box::new(out.clone()),
258                saved_input: Box::new(self.clone()),
259            };
260            result.set_grad_fn(grad_fn.clone());
261            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
262            return result;
263        }
264
265        out
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 2x3 tensor [[-3, -2, -1], [0, 1, 2]] shared by the forward tests.
    fn ramp_2x3() -> Tensor {
        let values: Vec<f32> = (0..6).map(|i| i as f32 - 3.0).collect();
        Tensor::from_slice(&values, vec![2, 3]).unwrap()
    }

    #[test]
    fn test_min_forward_basic() {
        let result = ramp_2x3().min();
        assert_eq!(result.shape().dims(), vec![1]);
        assert_eq!(result.get(&[0]), -3.0);
    }

    #[test]
    fn test_min_dims_forward() {
        let per_row = ramp_2x3().min_dims(&[1], true);
        assert_eq!(per_row.shape().dims(), vec![2, 1]);
        // Row 0 is [-3, -2, -1]; row 1 is [0, 1, 2].
        assert_eq!(per_row.get(&[0, 0]), -3.0);
        assert_eq!(per_row.get(&[1, 0]), 0.0);
    }

    #[test]
    fn test_min_non_contiguous_transpose() {
        // [[1, 2, 3], [4, 5, 6]] and its strided transpose [[1, 4], [2, 5], [3, 6]].
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // The global minimum does not depend on layout.
        let mins = [base.min(), view.min()];
        for m in &mins {
            assert_eq!(m.get(&[0]), 1.0);
        }
    }

    #[test]
    fn test_min_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 collapses the three rows: [min(1,2,3), min(4,5,6)].
        let over_rows = view.min_dims(&[0], false);
        assert_eq!(over_rows.shape().dims(), vec![2]);
        for (i, &want) in [1.0_f32, 4.0].iter().enumerate() {
            assert_eq!(over_rows.get(&[i]), want);
        }

        // Reducing dim 1 collapses each pair: [min(1,4), min(2,5), min(3,6)].
        let over_cols = view.min_dims(&[1], false);
        assert_eq!(over_cols.shape().dims(), vec![3]);
        for (i, &want) in [1.0_f32, 2.0, 3.0].iter().enumerate() {
            assert_eq!(over_cols.get(&[i]), want);
        }
    }

    #[test]
    fn test_min_permuted_tensor() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Reverse the axis order: [2, 3, 4] -> [4, 3, 2].
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let contiguous_min = x.min().get(&[0]);
        let permuted_min = x_perm.min().get(&[0]);

        // Layout must not affect the result, and the smallest of 0..24 is 0.
        assert_eq!(contiguous_min, permuted_min);
        assert_eq!(contiguous_min, 0.0);
    }

    #[test]
    fn test_min_with_negative_values() {
        let base =
            Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // Smallest of {-5, -2, -8, -1, -3, -6} is -8 regardless of layout.
        let mins = [base.min(), view.min()];
        for m in &mins {
            assert_eq!(m.get(&[0]), -8.0);
        }
    }
}