// train_station/tensor/reductions/max.rs

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Computes the maximum value over all elements in the tensor
6    ///
7    /// Returns a scalar tensor containing the maximum value. For empty tensors,
8    /// returns negative infinity. This operation supports gradient tracking
9    /// through the GradTrack system.
10    ///
11    /// # Returns
12    ///
13    /// A tensor with shape `[1]` containing the maximum value
14    ///
15    /// # Examples
16    ///
17    /// ```
18    /// use train_station::Tensor;
19    ///
20    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21    /// let max_val = tensor.max();
22    /// assert_eq!(max_val.get(&[0]), 5.0);
23    /// ```
24    ///
25    /// # GradTrack Support
26    ///
27    /// When `requires_grad` is true, this operation is tracked for automatic
28    /// differentiation. The gradient computation uses the saved input and output
29    /// for efficient backward pass.
30    pub fn max(&self) -> Tensor {
31        let mut out = Tensor::new(vec![1]);
32        if self.size() == 0 {
33            out.fill(f32::NEG_INFINITY);
34        } else {
35            let mut m = f32::NEG_INFINITY;
36
37            if self.is_contiguous() {
38                // Fast path for contiguous tensors
39                unsafe {
40                    let src = self.as_ptr();
41                    let size = self.size();
42                    m = *src;
43                    let mut i = 1usize;
44                    while i + 4 <= size {
45                        let x0 = *src.add(i);
46                        let x1 = *src.add(i + 1);
47                        let x2 = *src.add(i + 2);
48                        let x3 = *src.add(i + 3);
49                        m = m.max(x0).max(x1).max(x2).max(x3);
50                        i += 4;
51                    }
52                    while i < size {
53                        m = m.max(*src.add(i));
54                        i += 1;
55                    }
56                }
57            } else {
58                // Stride-aware path for non-contiguous tensors
59                let dims = self.shape().dims.clone();
60                for flat_idx in 0..self.size() {
61                    // Convert flat index to multi-dimensional coordinates
62                    let mut coords = vec![0; dims.len()];
63                    let mut tmp = flat_idx;
64                    for k in (0..dims.len()).rev() {
65                        coords[k] = tmp % dims[k];
66                        tmp /= dims[k];
67                    }
68
69                    // Get value using stride-aware offset
70                    let offset = self.shape().offset(&coords);
71                    let value = unsafe { *self.as_ptr().add(offset) };
72                    if flat_idx == 0 {
73                        m = value;
74                    } else {
75                        m = m.max(value);
76                    }
77                }
78            }
79
80            unsafe {
81                *out.as_mut_ptr() = m;
82            }
83        }
84
85        if self.requires_grad() {
86            let mut result = out.clone();
87            result.set_requires_grad_internal(true);
88            let grad_fn = GradFn::ReduceMax {
89                saved_output: Box::new(out.clone()),
90                saved_input: Box::new(self.clone()),
91                input_shape: self.shape().dims.clone(),
92            };
93            result.set_grad_fn(grad_fn.clone());
94            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
95            return result;
96        }
97
98        out
99    }
100
101    /// Computes the maximum value over specified dimensions
102    ///
103    /// Reduces the tensor along the specified dimensions by computing the maximum
104    /// value in each reduction group. The `keepdim` parameter determines whether
105    /// reduced dimensions are kept with size 1 or removed entirely.
106    ///
107    /// # Arguments
108    ///
109    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
110    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
111    ///
112    /// # Returns
113    ///
114    /// A tensor with the specified dimensions reduced
115    ///
116    /// # Examples
117    ///
118    /// ```
119    /// use train_station::Tensor;
120    ///
121    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
122    ///
123    /// // Max over columns (dim 1), keeping dimensions
124    /// let max_cols = tensor.max_dims(&[1], true);
125    /// assert_eq!(max_cols.shape().dims, vec![2, 1]);
126    /// assert_eq!(max_cols.get(&[0, 0]), 3.0);
127    /// assert_eq!(max_cols.get(&[1, 0]), 6.0);
128    ///
129    /// // Max over rows (dim 0), removing dimensions
130    /// let max_rows = tensor.max_dims(&[0], false);
131    /// assert_eq!(max_rows.shape().dims, vec![3]);
132    /// assert_eq!(max_rows.get(&[0]), 4.0);
133    /// assert_eq!(max_rows.get(&[1]), 5.0);
134    /// assert_eq!(max_rows.get(&[2]), 6.0);
135    /// ```
136    ///
137    /// # Panics
138    ///
139    /// Panics if:
140    /// * `dims` is empty
141    /// * Any dimension in `dims` is out of bounds for the tensor's rank
142    ///
143    /// # GradTrack Support
144    ///
145    /// When `requires_grad` is true, this operation is tracked for automatic
146    /// differentiation. The gradient computation preserves the original input
147    /// shape and handles broadcasting correctly.
148    pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
149        assert!(!dims.is_empty(), "max_dims requires at least one dimension");
150        let rank = self.shape().rank();
151        for &d in dims {
152            assert!(
153                d < rank,
154                "max_dims dim {} out of bounds for rank {}",
155                d,
156                rank
157            );
158        }
159
160        let mut out_dims = self.shape().dims.clone();
161        let mut reduced: Vec<usize> = dims.to_vec();
162        reduced.sort_unstable();
163        reduced.dedup();
164        for &d in reduced.iter() {
165            out_dims[d] = if keepdim { 1 } else { 0 };
166        }
167        if !keepdim {
168            out_dims.retain(|&s| s != 0);
169        }
170        if out_dims.is_empty() {
171            out_dims.push(1);
172        }
173        let mut out = Tensor::zeros(out_dims.clone());
174
175        let in_shape = self.shape().dims.clone();
176        let out_rank = out.shape().rank();
177        let mut in_coords = vec![0usize; rank];
178        unsafe {
179            let dst = out.as_mut_ptr();
180            for i in 0..out.size() {
181                *dst.add(i) = f32::NEG_INFINITY;
182            }
183            for lin in 0..self.size() {
184                let mut tmp = lin;
185                for i in (0..rank).rev() {
186                    let s = in_shape[i];
187                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
188                    if s != 0 {
189                        tmp /= s;
190                    }
191                }
192
193                // Get input value using stride-aware offset
194                let in_offset = self.shape().offset(&in_coords);
195                let val = *self.as_ptr().add(in_offset);
196
197                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
198                for (i, &c) in in_coords.iter().enumerate().take(rank) {
199                    if reduced.contains(&i) {
200                        if keepdim {
201                            out_coords.push(0);
202                        }
203                    } else {
204                        out_coords.push(c);
205                    }
206                }
207                let off = if out_coords.is_empty() {
208                    0
209                } else {
210                    out.shape().offset(&out_coords)
211                };
212                let cur = *dst.add(off);
213                if val > cur {
214                    *dst.add(off) = val;
215                }
216            }
217        }
218
219        if self.requires_grad() {
220            let mut result = out.clone();
221            result.set_requires_grad_internal(true);
222            let grad_fn = GradFn::ReduceMaxDims {
223                dims: reduced,
224                keepdim,
225                input_shape: self.shape().dims.clone(),
226                saved_output: Box::new(out.clone()),
227                saved_input: Box::new(self.clone()),
228            };
229            result.set_grad_fn(grad_fn.clone());
230            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
231            return result;
232        }
233
234        out
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a contiguous 2x3 tensor whose elements, in storage order,
    /// are -3.0, -2.0, -1.0, 0.0, 1.0, 2.0.
    fn shifted_iota_2x3() -> Tensor {
        let mut t = Tensor::zeros(vec![2, 3]);
        unsafe {
            let p = t.as_mut_ptr();
            for i in 0..6 {
                *p.add(i) = (i as f32) - 3.0;
            }
        }
        t
    }

    #[test]
    fn test_max_forward_basic() {
        let t = shifted_iota_2x3();
        let result = t.max();
        assert_eq!(result.shape().dims, vec![1]);
        unsafe {
            assert_eq!(*result.as_ptr(), 2.0);
        }
    }

    #[test]
    fn test_max_dims_forward() {
        let t = shifted_iota_2x3();
        let reduced = t.max_dims(&[1], true);
        assert_eq!(reduced.shape().dims, vec![2, 1]);
        // Row maxima of [[-3, -2, -1], [0, 1, 2]].
        assert_eq!(reduced.get(&[0, 0]), -1.0);
        assert_eq!(reduced.get(&[1, 0]), 2.0);
    }

    #[test]
    fn test_max_non_contiguous_transpose() {
        // A transposed view shares storage but is not contiguous; max over
        // the view must agree with max over the contiguous original.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // base: [[1, 2, 3], [4, 5, 6]]; view: [[1, 4], [2, 5], [3, 6]].
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // Global maximum of 1..=6 is 6 regardless of layout.
        assert_eq!(base.max().get(&[0]), 6.0);
        assert_eq!(view.max().get(&[0]), 6.0);
    }

    #[test]
    fn test_max_dims_non_contiguous() {
        // Dimension-wise reduction must also honor strides of a view.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reduce dim 0 of the view: [max(1,2,3), max(4,5,6)] = [3, 6].
        let over_dim0 = view.max_dims(&[0], false);
        assert_eq!(over_dim0.shape().dims, vec![2]);
        assert_eq!(over_dim0.get(&[0]), 3.0);
        assert_eq!(over_dim0.get(&[1]), 6.0);

        // Reduce dim 1 of the view: [max(1,4), max(2,5), max(3,6)] = [4, 5, 6].
        let over_dim1 = view.max_dims(&[1], false);
        assert_eq!(over_dim1.shape().dims, vec![3]);
        assert_eq!(over_dim1.get(&[0]), 4.0);
        assert_eq!(over_dim1.get(&[1]), 5.0);
        assert_eq!(over_dim1.get(&[2]), 6.0);
    }

    #[test]
    fn test_max_permuted_tensor() {
        // A full axis permutation produces a non-contiguous view; the global
        // max must match the original's.
        let values: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let base = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let expected = base.max();
        let actual = permuted.max();
        assert_eq!(expected.get(&[0]), actual.get(&[0]));
        // max(0, 1, ..., 23) == 23.
        assert_eq!(expected.get(&[0]), 23.0);
    }

    #[test]
    fn test_max_with_negative_values() {
        // All-negative input: the maximum is the least negative element, and
        // contiguous and view layouts must agree.
        let base = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        assert_eq!(base.max().get(&[0]), -1.0);
        assert_eq!(view.max().get(&[0]), -1.0);
    }
}