train_station/tensor/reductions/min.rs

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the minimum value over all elements in the tensor
    ///
    /// Returns a scalar tensor containing the minimum value. For empty tensors,
    /// returns positive infinity. This operation supports gradient tracking
    /// through the GradTrack system.
    ///
    /// # Returns
    ///
    /// A tensor with shape `[1]` containing the minimum value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
    /// let min_val = tensor.min();
    /// assert_eq!(min_val.get(&[0]), 1.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Empty tensor case
    /// let empty_tensor = Tensor::new(vec![0]);
    /// let min_val = empty_tensor.min();
    /// assert_eq!(min_val.get(&[0]), f32::INFINITY);
    /// ```
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation uses the saved input and output
    /// for an efficient backward pass.
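    ///
    /// A minimal sketch of the gradient semantics follows. The `backward` call
    /// is an assumed name for the crate's backward API rather than one confirmed
    /// by this file, so the block is illustrative only:
    ///
    /// ```ignore
    /// use train_station::Tensor;
    ///
    /// // Assumes gradient tracking has already been enabled on `x`
    /// let x = Tensor::from_slice(&[3.0, 1.0, 2.0], vec![3]).unwrap();
    /// let m = x.min(); // 1.0, attained at index 1
    /// m.backward(); // hypothetical API name
    /// // The upstream gradient is routed to the argmin position only:
    /// // d(min)/dx = [0.0, 1.0, 0.0]
    /// ```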
    pub fn min(&self) -> Tensor {
        let mut out = Tensor::new(vec![1]);
        if self.size() == 0 {
            out.fill(f32::INFINITY);
        } else {
            let mut m = f32::INFINITY;

            if self.is_contiguous() {
                // Fast path for contiguous tensors: 4-way unrolled linear scan
                unsafe {
                    let src = self.as_ptr();
                    let size = self.size();
                    m = *src;
                    let mut i = 1usize;
                    while i + 4 <= size {
                        let x0 = *src.add(i);
                        let x1 = *src.add(i + 1);
                        let x2 = *src.add(i + 2);
                        let x3 = *src.add(i + 3);
                        m = m.min(x0).min(x1).min(x2).min(x3);
                        i += 4;
                    }
                    // Scalar tail for the remaining 0-3 elements
                    while i < size {
                        m = m.min(*src.add(i));
                        i += 1;
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims.clone();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    if flat_idx == 0 {
                        m = value;
                    } else {
                        m = m.min(value);
                    }
                }
            }

            unsafe {
                *out.as_mut_ptr() = m;
            }
        }

        if self.requires_grad() {
            // Register the reduction with GradTrack so the backward pass can
            // route gradients to the argmin position
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMin {
                saved_output: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims.clone(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the minimum value over the specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the minimum
    /// value in each reduction group. The `keepdim` parameter determines whether
    /// reduced dimensions are kept with size 1 or removed entirely.
    ///
    /// # Arguments
    ///
    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimensions reduced
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Min over columns (dim 1), keeping dimensions
    /// let min_cols = tensor.min_dims(&[1], true);
    /// assert_eq!(min_cols.shape().dims, vec![2, 1]);
    /// assert_eq!(min_cols.get(&[0, 0]), 1.0);
    /// assert_eq!(min_cols.get(&[1, 0]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Min over rows (dim 0), removing dimensions
    /// let min_rows = tensor.min_dims(&[0], false);
    /// assert_eq!(min_rows.shape().dims, vec![3]);
    /// assert_eq!(min_rows.get(&[0]), 1.0);
    /// assert_eq!(min_rows.get(&[1]), 2.0);
    /// assert_eq!(min_rows.get(&[2]), 3.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    ///
    /// // Min over multiple dimensions
    /// let min_all = tensor.min_dims(&[0, 1], false);
    /// assert_eq!(min_all.shape().dims, vec![1]);
    /// assert_eq!(min_all.get(&[0]), 1.0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `dims` is empty
    /// * Any dimension in `dims` is out of bounds for the tensor's rank
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation preserves the original input
    /// shape and handles broadcasting correctly.
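    ///
    /// A minimal sketch of the per-group gradient routing follows. The
    /// `backward` call is an assumed name rather than one confirmed by this
    /// file, so the block is illustrative only:
    ///
    /// ```ignore
    /// use train_station::Tensor;
    ///
    /// // Assumes gradient tracking has already been enabled on `x`
    /// let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    /// let m = x.min_dims(&[1], true); // [[1.0], [4.0]]
    /// m.backward(); // hypothetical API name
    /// // Each row's upstream gradient flows only to that row's argmin:
    /// // d(min)/dx = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
    /// ```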
    pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(!dims.is_empty(), "min_dims requires at least one dimension");
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "min_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape, marking reduced dims with a sentinel that no
        // real dimension can take so that legitimate zero-sized dims survive
        let mut out_dims = self.shape().dims.clone();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { usize::MAX };
        }
        if !keepdim {
            out_dims.retain(|&s| s != usize::MAX);
        }
        // Reducing over every dimension yields a scalar with shape [1]
        if out_dims.is_empty() {
            out_dims.push(1);
        }
        let mut out = Tensor::zeros(out_dims.clone());

        // Compute min along reduced dims
        let in_shape = self.shape().dims.clone();
        let out_rank = out.shape().rank();
        let mut in_coords = vec![0usize; rank];
        unsafe {
            let dst = out.as_mut_ptr();
            // Initialize output with +inf, the identity element of min
            for i in 0..out.size() {
                *dst.add(i) = f32::INFINITY;
            }
            for lin in 0..self.size() {
                // Convert flat index to input coordinates
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }

                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&in_coords);
                let val = *self.as_ptr().add(in_offset);

                // Map input coordinates to output coordinates: reduced dims
                // are pinned to 0 (keepdim) or dropped entirely
                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in in_coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                let cur = *dst.add(off);
                if val < cur {
                    *dst.add(off) = val;
                }
            }
        }

        if self.requires_grad() {
            // Register the reduction with GradTrack, saving the input and
            // output so the backward pass can locate each group's argmin
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMinDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims.clone(),
                saved_output: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_min_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min();
        assert_eq!(m.shape().dims, vec![1]);
        unsafe {
            assert_eq!(*m.as_ptr(), -3.0);
        }
    }
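
    // Mirrors the documented empty-tensor behavior: the minimum over zero
    // elements is the identity of `min`, i.e. +inf.
    #[test]
    fn test_min_empty_tensor() {
        let empty = Tensor::new(vec![0]);
        let m = empty.min();
        assert_eq!(m.shape().dims, vec![1]);
        assert_eq!(m.get(&[0]), f32::INFINITY);
    }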

    #[test]
    fn test_min_dims_forward() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min_dims(&[1], true);
        assert_eq!(m.shape().dims, vec![2, 1]);
        assert_eq!(m.get(&[0, 0]), -3.0);
        assert_eq!(m.get(&[1, 0]), 0.0);
    }
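
    // Mirrors the multi-dimension doc example: reducing over every dimension
    // collapses the tensor to shape [1].
    #[test]
    fn test_min_dims_all_dims() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let m = x.min_dims(&[0, 1], false);
        assert_eq!(m.shape().dims, vec![1]);
        assert_eq!(m.get(&[0]), 1.0);
    }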

    #[test]
    fn test_min_non_contiguous_transpose() {
        // Test min on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(1,2,3,4,5,6) = 1
        assert_eq!(min_orig.get(&[0]), 1.0);
        assert_eq!(min_view.get(&[0]), 1.0);
    }

    #[test]
    fn test_min_dims_non_contiguous() {
        // Test min_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Min along dim 0 of transposed tensor
        let min_dim0 = x_t.min_dims(&[0], false);
        assert_eq!(min_dim0.shape().dims, vec![2]);
        // Should be [min(1,2,3), min(4,5,6)] = [1, 4]
        assert_eq!(min_dim0.get(&[0]), 1.0);
        assert_eq!(min_dim0.get(&[1]), 4.0);

        // Min along dim 1 of transposed tensor
        let min_dim1 = x_t.min_dims(&[1], false);
        assert_eq!(min_dim1.shape().dims, vec![3]);
        // Should be [min(1,4), min(2,5), min(3,6)] = [1, 2, 3]
        assert_eq!(min_dim1.get(&[0]), 1.0);
        assert_eq!(min_dim1.get(&[1]), 2.0);
        assert_eq!(min_dim1.get(&[2]), 3.0);
    }

    #[test]
    fn test_min_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Permute dimensions: shape [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let min_orig = x.min();
        let min_perm = x_perm.min();

        // Should give same result
        assert_eq!(min_orig.get(&[0]), min_perm.get(&[0]));

        // Expected min: min(0,1,2,...,23) = 0
        assert_eq!(min_orig.get(&[0]), 0.0);
    }

    #[test]
    fn test_min_with_negative_values() {
        // Test min with negative values on non-contiguous tensor
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1);
        assert!(!x_t.is_contiguous());

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(-5,-2,-8,-1,-3,-6) = -8
        assert_eq!(min_orig.get(&[0]), -8.0);
        assert_eq!(min_view.get(&[0]), -8.0);
    }
}