//! Maximum reduction operations for tensors.
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Computes the maximum value over all elements in the tensor
6 ///
7 /// Returns a scalar tensor containing the maximum value. For empty tensors,
8 /// returns negative infinity. This operation supports gradient tracking
9 /// through the GradTrack system.
10 ///
11 /// # Returns
12 ///
13 /// A tensor with shape `[1]` containing the maximum value
14 ///
15 /// # Examples
16 ///
17 /// ```
18 /// use train_station::Tensor;
19 ///
20 /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21 /// let max_val = tensor.max();
22 /// assert_eq!(max_val.get(&[0]), 5.0);
23 /// ```
24 ///
25 /// # GradTrack Support
26 ///
27 /// When `requires_grad` is true, this operation is tracked for automatic
28 /// differentiation. The gradient computation uses the saved input and output
29 /// for efficient backward pass.
30 #[track_caller]
31 pub fn max(&self) -> Tensor {
32 let mut out = Tensor::new(vec![1]);
33 if self.size() == 0 {
34 out.fill(f32::NEG_INFINITY);
35 } else {
36 let mut m = f32::NEG_INFINITY;
37
38 if self.is_contiguous() {
39 // Fast path for contiguous tensors
40 unsafe {
41 let src = self.as_ptr();
42 let size = self.size();
43 m = *src;
44 let mut i = 1usize;
45 while i + 4 <= size {
46 let x0 = *src.add(i);
47 let x1 = *src.add(i + 1);
48 let x2 = *src.add(i + 2);
49 let x3 = *src.add(i + 3);
50 m = m.max(x0).max(x1).max(x2).max(x3);
51 i += 4;
52 }
53 while i < size {
54 m = m.max(*src.add(i));
55 i += 1;
56 }
57 }
58 } else {
59 // Stride-aware path for non-contiguous tensors
60 let dims = self.shape().dims().to_vec();
61 for flat_idx in 0..self.size() {
62 // Convert flat index to multi-dimensional coordinates
63 let mut coords = vec![0; dims.len()];
64 let mut tmp = flat_idx;
65 for k in (0..dims.len()).rev() {
66 coords[k] = tmp % dims[k];
67 tmp /= dims[k];
68 }
69
70 // Get value using stride-aware offset
71 let offset = self.shape().offset(&coords);
72 let value = unsafe { *self.as_ptr().add(offset) };
73 if flat_idx == 0 {
74 m = value;
75 } else {
76 m = m.max(value);
77 }
78 }
79 }
80
81 unsafe {
82 *out.as_mut_ptr() = m;
83 }
84 }
85
86 if self.requires_grad() {
87 let mut result = out.clone();
88 result.set_requires_grad_internal(true);
89 let grad_fn = GradFn::ReduceMax {
90 saved_output: Box::new(out.clone()),
91 saved_input: Box::new(self.clone()),
92 input_shape: self.shape().dims().to_vec(),
93 };
94 result.set_grad_fn(grad_fn.clone());
95 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
96 return result;
97 }
98
99 out
100 }
101
102 /// Computes the maximum value over specified dimensions
103 ///
104 /// Reduces the tensor along the specified dimensions by computing the maximum
105 /// value in each reduction group. The `keepdim` parameter determines whether
106 /// reduced dimensions are kept with size 1 or removed entirely.
107 ///
108 /// # Arguments
109 ///
110 /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
111 /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
112 ///
113 /// # Returns
114 ///
115 /// A tensor with the specified dimensions reduced
116 ///
117 /// # Examples
118 ///
119 /// ```
120 /// use train_station::Tensor;
121 ///
122 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
123 ///
124 /// // Max over columns (dim 1), keeping dimensions
125 /// let max_cols = tensor.max_dims(&[1], true);
126 /// assert_eq!(max_cols.shape().dims(), vec![2, 1]);
127 /// assert_eq!(max_cols.get(&[0, 0]), 3.0);
128 /// assert_eq!(max_cols.get(&[1, 0]), 6.0);
129 ///
130 /// // Max over rows (dim 0), removing dimensions
131 /// let max_rows = tensor.max_dims(&[0], false);
132 /// assert_eq!(max_rows.shape().dims(), vec![3]);
133 /// assert_eq!(max_rows.get(&[0]), 4.0);
134 /// assert_eq!(max_rows.get(&[1]), 5.0);
135 /// assert_eq!(max_rows.get(&[2]), 6.0);
136 /// ```
137 ///
138 /// # Panics
139 ///
140 /// Panics if:
141 /// * `dims` is empty
142 /// * Any dimension in `dims` is out of bounds for the tensor's rank
143 ///
144 /// # GradTrack Support
145 ///
146 /// When `requires_grad` is true, this operation is tracked for automatic
147 /// differentiation. The gradient computation preserves the original input
148 /// shape and handles broadcasting correctly.
149 #[track_caller]
150 pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
151 assert!(!dims.is_empty(), "max_dims requires at least one dimension");
152 let rank = self.shape().rank();
153 for &d in dims {
154 assert!(
155 d < rank,
156 "max_dims dim {} out of bounds for rank {}",
157 d,
158 rank
159 );
160 }
161
162 let mut out_dims = self.shape().dims().to_vec();
163 let mut reduced: Vec<usize> = dims.to_vec();
164 reduced.sort_unstable();
165 reduced.dedup();
166 for &d in reduced.iter() {
167 out_dims[d] = if keepdim { 1 } else { 0 };
168 }
169 if !keepdim {
170 out_dims.retain(|&s| s != 0);
171 }
172 if out_dims.is_empty() {
173 out_dims.push(1);
174 }
175 let mut out = Tensor::zeros(out_dims.clone());
176
177 let in_shape = self.shape().dims().to_vec();
178 let out_rank = out.shape().rank();
179 let mut in_coords = vec![0usize; rank];
180 unsafe {
181 let dst = out.as_mut_ptr();
182 for i in 0..out.size() {
183 *dst.add(i) = f32::NEG_INFINITY;
184 }
185 for lin in 0..self.size() {
186 let mut tmp = lin;
187 for i in (0..rank).rev() {
188 let s = in_shape[i];
189 in_coords[i] = if s == 0 { 0 } else { tmp % s };
190 if s != 0 {
191 tmp /= s;
192 }
193 }
194
195 // Get input value using stride-aware offset
196 let in_offset = self.shape().offset(&in_coords);
197 let val = *self.as_ptr().add(in_offset);
198
199 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
200 for (i, &c) in in_coords.iter().enumerate().take(rank) {
201 if reduced.contains(&i) {
202 if keepdim {
203 out_coords.push(0);
204 }
205 } else {
206 out_coords.push(c);
207 }
208 }
209 let off = if out_coords.is_empty() {
210 0
211 } else {
212 out.shape().offset(&out_coords)
213 };
214 let cur = *dst.add(off);
215 if val > cur {
216 *dst.add(off) = val;
217 }
218 }
219 }
220
221 if self.requires_grad() {
222 let mut result = out.clone();
223 result.set_requires_grad_internal(true);
224 let grad_fn = GradFn::ReduceMaxDims {
225 dims: reduced,
226 keepdim,
227 input_shape: self.shape().dims().to_vec(),
228 saved_output: Box::new(out.clone()),
229 saved_input: Box::new(self.clone()),
230 };
231 result.set_grad_fn(grad_fn.clone());
232 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
233 return result;
234 }
235
236 out
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a contiguous [2, 3] tensor holding the ramp -3.0, -2.0, ..., 2.0.
    fn ramp_2x3() -> Tensor {
        let mut t = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *t.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        t
    }

    #[test]
    fn test_max_forward_basic() {
        let x = ramp_2x3();
        let m = x.max();
        assert_eq!(m.shape().dims(), vec![1]);
        unsafe {
            assert_eq!(*m.as_ptr(), 2.0);
        }
    }

    #[test]
    fn test_max_dims_forward() {
        let x = ramp_2x3();
        let m = x.max_dims(&[1], true);
        assert_eq!(m.shape().dims(), vec![2, 1]);
        assert_eq!(m.get(&[0, 0]), -1.0);
        assert_eq!(m.get(&[1, 0]), 2.0);
    }

    #[test]
    fn test_max_non_contiguous_transpose() {
        // A transposed view is non-contiguous but sees the same six values,
        // so its max must equal the contiguous original's max.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // [[1, 4], [2, 5], [3, 6]]
        assert!(!view.is_contiguous());

        assert_eq!(base.max().get(&[0]), 6.0);
        assert_eq!(view.max().get(&[0]), 6.0);
    }

    #[test]
    fn test_max_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 of the view: [max(1,2,3), max(4,5,6)] = [3, 6].
        let along_dim0 = view.max_dims(&[0], false);
        assert_eq!(along_dim0.shape().dims(), vec![2]);
        assert_eq!(along_dim0.get(&[0]), 3.0);
        assert_eq!(along_dim0.get(&[1]), 6.0);

        // Reducing dim 1 of the view: [max(1,4), max(2,5), max(3,6)] = [4, 5, 6].
        let along_dim1 = view.max_dims(&[1], false);
        assert_eq!(along_dim1.shape().dims(), vec![3]);
        for (i, expected) in [4.0f32, 5.0, 6.0].iter().enumerate() {
            assert_eq!(along_dim1.get(&[i]), *expected);
        }
    }

    #[test]
    fn test_max_permuted_tensor() {
        // Reversing all three axes produces a non-contiguous view; the global
        // max is permutation-invariant.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let base = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        let perm = base.permute(vec![2, 1, 0]);
        assert!(!perm.is_contiguous());

        let from_base = base.max();
        let from_perm = perm.max();
        assert_eq!(from_base.get(&[0]), from_perm.get(&[0]));
        assert_eq!(from_base.get(&[0]), 23.0);
    }

    #[test]
    fn test_max_with_negative_values() {
        // All-negative input: the result must be the largest (least negative)
        // element, not a spurious zero, on both layouts.
        let base = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        assert_eq!(base.max().get(&[0]), -1.0);
        assert_eq!(view.max().get(&[0]), -1.0);
    }
}