train_station/tensor/reductions/max.rs

1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5    /// Computes the maximum value over all elements in the tensor
6    ///
7    /// Returns a scalar tensor containing the maximum value. For empty tensors,
8    /// returns negative infinity. This operation supports gradient tracking
9    /// through the GradTrack system.
10    ///
11    /// # Returns
12    ///
13    /// A tensor with shape `[1]` containing the maximum value
14    ///
15    /// # Examples
16    ///
17    /// ```
18    /// use train_station::Tensor;
19    ///
20    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21    /// let max_val = tensor.max();
22    /// assert_eq!(max_val.get(&[0]), 5.0);
23    /// ```
24    ///
25    /// # GradTrack Support
26    ///
27    /// When `requires_grad` is true, this operation is tracked for automatic
28    /// differentiation. The gradient computation uses the saved input and output
29    /// for efficient backward pass.
30    #[track_caller]
31    pub fn max(&self) -> Tensor {
32        let mut out = Tensor::new(vec![1]);
33        if self.size() == 0 {
34            out.fill(f32::NEG_INFINITY);
35        } else {
36            let mut m = f32::NEG_INFINITY;
37
38            if self.is_contiguous() {
39                // Fast path for contiguous tensors
40                unsafe {
41                    let src = self.as_ptr();
42                    let size = self.size();
43                    m = *src;
44                    let mut i = 1usize;
45                    while i + 4 <= size {
46                        let x0 = *src.add(i);
47                        let x1 = *src.add(i + 1);
48                        let x2 = *src.add(i + 2);
49                        let x3 = *src.add(i + 3);
50                        m = m.max(x0).max(x1).max(x2).max(x3);
51                        i += 4;
52                    }
53                    while i < size {
54                        m = m.max(*src.add(i));
55                        i += 1;
56                    }
57                }
58            } else {
59                // Stride-aware path for non-contiguous tensors
60                let dims = self.shape().dims.clone();
61                for flat_idx in 0..self.size() {
62                    // Convert flat index to multi-dimensional coordinates
63                    let mut coords = vec![0; dims.len()];
64                    let mut tmp = flat_idx;
65                    for k in (0..dims.len()).rev() {
66                        coords[k] = tmp % dims[k];
67                        tmp /= dims[k];
68                    }
69
70                    // Get value using stride-aware offset
71                    let offset = self.shape().offset(&coords);
72                    let value = unsafe { *self.as_ptr().add(offset) };
73                    if flat_idx == 0 {
74                        m = value;
75                    } else {
76                        m = m.max(value);
77                    }
78                }
79            }
80
81            unsafe {
82                *out.as_mut_ptr() = m;
83            }
84        }
85
86        if self.requires_grad() {
87            let mut result = out.clone();
88            result.set_requires_grad_internal(true);
89            let grad_fn = GradFn::ReduceMax {
90                saved_output: Box::new(out.clone()),
91                saved_input: Box::new(self.clone()),
92                input_shape: self.shape().dims.clone(),
93            };
94            result.set_grad_fn(grad_fn.clone());
95            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
96            return result;
97        }
98
99        out
100    }
101
102    /// Computes the maximum value over specified dimensions
103    ///
104    /// Reduces the tensor along the specified dimensions by computing the maximum
105    /// value in each reduction group. The `keepdim` parameter determines whether
106    /// reduced dimensions are kept with size 1 or removed entirely.
107    ///
108    /// # Arguments
109    ///
110    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
111    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
112    ///
113    /// # Returns
114    ///
115    /// A tensor with the specified dimensions reduced
116    ///
117    /// # Examples
118    ///
119    /// ```
120    /// use train_station::Tensor;
121    ///
122    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
123    ///
124    /// // Max over columns (dim 1), keeping dimensions
125    /// let max_cols = tensor.max_dims(&[1], true);
126    /// assert_eq!(max_cols.shape().dims, vec![2, 1]);
127    /// assert_eq!(max_cols.get(&[0, 0]), 3.0);
128    /// assert_eq!(max_cols.get(&[1, 0]), 6.0);
129    ///
130    /// // Max over rows (dim 0), removing dimensions
131    /// let max_rows = tensor.max_dims(&[0], false);
132    /// assert_eq!(max_rows.shape().dims, vec![3]);
133    /// assert_eq!(max_rows.get(&[0]), 4.0);
134    /// assert_eq!(max_rows.get(&[1]), 5.0);
135    /// assert_eq!(max_rows.get(&[2]), 6.0);
136    /// ```
137    ///
138    /// # Panics
139    ///
140    /// Panics if:
141    /// * `dims` is empty
142    /// * Any dimension in `dims` is out of bounds for the tensor's rank
143    ///
144    /// # GradTrack Support
145    ///
146    /// When `requires_grad` is true, this operation is tracked for automatic
147    /// differentiation. The gradient computation preserves the original input
148    /// shape and handles broadcasting correctly.
149    #[track_caller]
150    pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
151        assert!(!dims.is_empty(), "max_dims requires at least one dimension");
152        let rank = self.shape().rank();
153        for &d in dims {
154            assert!(
155                d < rank,
156                "max_dims dim {} out of bounds for rank {}",
157                d,
158                rank
159            );
160        }
161
162        let mut out_dims = self.shape().dims.clone();
163        let mut reduced: Vec<usize> = dims.to_vec();
164        reduced.sort_unstable();
165        reduced.dedup();
166        for &d in reduced.iter() {
167            out_dims[d] = if keepdim { 1 } else { 0 };
168        }
169        if !keepdim {
170            out_dims.retain(|&s| s != 0);
171        }
172        if out_dims.is_empty() {
173            out_dims.push(1);
174        }
175        let mut out = Tensor::zeros(out_dims.clone());
176
177        let in_shape = self.shape().dims.clone();
178        let out_rank = out.shape().rank();
179        let mut in_coords = vec![0usize; rank];
180        unsafe {
181            let dst = out.as_mut_ptr();
182            for i in 0..out.size() {
183                *dst.add(i) = f32::NEG_INFINITY;
184            }
185            for lin in 0..self.size() {
186                let mut tmp = lin;
187                for i in (0..rank).rev() {
188                    let s = in_shape[i];
189                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
190                    if s != 0 {
191                        tmp /= s;
192                    }
193                }
194
195                // Get input value using stride-aware offset
196                let in_offset = self.shape().offset(&in_coords);
197                let val = *self.as_ptr().add(in_offset);
198
199                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
200                for (i, &c) in in_coords.iter().enumerate().take(rank) {
201                    if reduced.contains(&i) {
202                        if keepdim {
203                            out_coords.push(0);
204                        }
205                    } else {
206                        out_coords.push(c);
207                    }
208                }
209                let off = if out_coords.is_empty() {
210                    0
211                } else {
212                    out.shape().offset(&out_coords)
213                };
214                let cur = *dst.add(off);
215                if val > cur {
216                    *dst.add(off) = val;
217                }
218            }
219        }
220
221        if self.requires_grad() {
222            let mut result = out.clone();
223            result.set_requires_grad_internal(true);
224            let grad_fn = GradFn::ReduceMaxDims {
225                dims: reduced,
226                keepdim,
227                input_shape: self.shape().dims.clone(),
228                saved_output: Box::new(out.clone()),
229                saved_input: Box::new(self.clone()),
230            };
231            result.set_grad_fn(grad_fn.clone());
232            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
233            return result;
234        }
235
236        out
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_max_forward_basic() {
        // Values -3..=2 laid out row-major in a 2x3 tensor; global max is 2.
        let x = Tensor::from_slice(&[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], vec![2, 3]).unwrap();
        let result = x.max();
        assert_eq!(result.shape().dims, vec![1]);
        assert_eq!(result.get(&[0]), 2.0);
    }

    #[test]
    fn test_max_dims_forward() {
        // Row-major 2x3: rows are [-3, -2, -1] and [0, 1, 2].
        let x = Tensor::from_slice(&[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], vec![2, 3]).unwrap();
        let row_max = x.max_dims(&[1], true);
        assert_eq!(row_max.shape().dims, vec![2, 1]);
        assert_eq!(row_max.get(&[0, 0]), -1.0);
        assert_eq!(row_max.get(&[1, 0]), 2.0);
    }

    #[test]
    fn test_max_non_contiguous_transpose() {
        // max over a transposed view must agree with max over the base tensor.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // base: [[1, 2, 3], [4, 5, 6]]

        let view = base.transpose(0, 1);
        // view: [[1, 4], [2, 5], [3, 6]]
        assert!(!view.is_contiguous()); // transpose yields a strided view

        // Both orderings visit the same six values, so both report 6.
        assert_eq!(base.max().get(&[0]), 6.0);
        assert_eq!(view.max().get(&[0]), 6.0);
    }

    #[test]
    fn test_max_dims_non_contiguous() {
        // Dimension-wise max through a non-contiguous transposed view.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 of the view collapses the three rows:
        // [max(1,2,3), max(4,5,6)] = [3, 6]
        let over_rows = view.max_dims(&[0], false);
        assert_eq!(over_rows.shape().dims, vec![2]);
        assert_eq!(over_rows.get(&[0]), 3.0);
        assert_eq!(over_rows.get(&[1]), 6.0);

        // Reducing dim 1 of the view collapses each pair:
        // [max(1,4), max(2,5), max(3,6)] = [4, 5, 6]
        let over_cols = view.max_dims(&[1], false);
        assert_eq!(over_cols.shape().dims, vec![3]);
        assert_eq!(over_cols.get(&[0]), 4.0);
        assert_eq!(over_cols.get(&[1]), 5.0);
        assert_eq!(over_cols.get(&[2]), 6.0);
    }

    #[test]
    fn test_max_permuted_tensor() {
        // A full axis permutation must not change the global maximum.
        let values: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let base = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        // Reverse the axes: [2, 3, 4] -> [4, 3, 2]
        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let from_base = base.max();
        let from_permuted = permuted.max();

        assert_eq!(from_base.get(&[0]), from_permuted.get(&[0]));
        // Largest of 0.0..=23.0 is 23.
        assert_eq!(from_base.get(&[0]), 23.0);
    }

    #[test]
    fn test_max_with_negative_values() {
        // All-negative data: the result must stay above NEG_INFINITY seeding.
        let base = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // max(-5, -2, -8, -1, -3, -6) = -1 regardless of traversal order.
        assert_eq!(base.max().get(&[0]), -1.0);
        assert_eq!(view.max().get(&[0]), -1.0);
    }
}