train_station/tensor/reductions/max.rs
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Computes the maximum value over all elements in the tensor
6 ///
7 /// Returns a scalar tensor containing the maximum value. For empty tensors,
8 /// returns negative infinity. This operation supports gradient tracking
9 /// through the GradTrack system.
10 ///
11 /// # Returns
12 ///
13 /// A tensor with shape `[1]` containing the maximum value
14 ///
15 /// # Examples
16 ///
17 /// ```
18 /// use train_station::Tensor;
19 ///
20 /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21 /// let max_val = tensor.max();
22 /// assert_eq!(max_val.get(&[0]), 5.0);
23 /// ```
24 ///
25 /// # GradTrack Support
26 ///
27 /// When `requires_grad` is true, this operation is tracked for automatic
28 /// differentiation. The gradient computation uses the saved input and output
29 /// for efficient backward pass.
30 pub fn max(&self) -> Tensor {
31 let mut out = Tensor::new(vec![1]);
32 if self.size() == 0 {
33 out.fill(f32::NEG_INFINITY);
34 } else {
35 let mut m = f32::NEG_INFINITY;
36
37 if self.is_contiguous() {
38 // Fast path for contiguous tensors
39 unsafe {
40 let src = self.as_ptr();
41 let size = self.size();
42 m = *src;
43 let mut i = 1usize;
44 while i + 4 <= size {
45 let x0 = *src.add(i);
46 let x1 = *src.add(i + 1);
47 let x2 = *src.add(i + 2);
48 let x3 = *src.add(i + 3);
49 m = m.max(x0).max(x1).max(x2).max(x3);
50 i += 4;
51 }
52 while i < size {
53 m = m.max(*src.add(i));
54 i += 1;
55 }
56 }
57 } else {
58 // Stride-aware path for non-contiguous tensors
59 let dims = self.shape().dims.clone();
60 for flat_idx in 0..self.size() {
61 // Convert flat index to multi-dimensional coordinates
62 let mut coords = vec![0; dims.len()];
63 let mut tmp = flat_idx;
64 for k in (0..dims.len()).rev() {
65 coords[k] = tmp % dims[k];
66 tmp /= dims[k];
67 }
68
69 // Get value using stride-aware offset
70 let offset = self.shape().offset(&coords);
71 let value = unsafe { *self.as_ptr().add(offset) };
72 if flat_idx == 0 {
73 m = value;
74 } else {
75 m = m.max(value);
76 }
77 }
78 }
79
80 unsafe {
81 *out.as_mut_ptr() = m;
82 }
83 }
84
85 if self.requires_grad() {
86 let mut result = out.clone();
87 result.set_requires_grad_internal(true);
88 let grad_fn = GradFn::ReduceMax {
89 saved_output: Box::new(out.clone()),
90 saved_input: Box::new(self.clone()),
91 input_shape: self.shape().dims.clone(),
92 };
93 result.set_grad_fn(grad_fn.clone());
94 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
95 return result;
96 }
97
98 out
99 }
100
101 /// Computes the maximum value over specified dimensions
102 ///
103 /// Reduces the tensor along the specified dimensions by computing the maximum
104 /// value in each reduction group. The `keepdim` parameter determines whether
105 /// reduced dimensions are kept with size 1 or removed entirely.
106 ///
107 /// # Arguments
108 ///
109 /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
110 /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
111 ///
112 /// # Returns
113 ///
114 /// A tensor with the specified dimensions reduced
115 ///
116 /// # Examples
117 ///
118 /// ```
119 /// use train_station::Tensor;
120 ///
121 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
122 ///
123 /// // Max over columns (dim 1), keeping dimensions
124 /// let max_cols = tensor.max_dims(&[1], true);
125 /// assert_eq!(max_cols.shape().dims, vec![2, 1]);
126 /// assert_eq!(max_cols.get(&[0, 0]), 3.0);
127 /// assert_eq!(max_cols.get(&[1, 0]), 6.0);
128 ///
129 /// // Max over rows (dim 0), removing dimensions
130 /// let max_rows = tensor.max_dims(&[0], false);
131 /// assert_eq!(max_rows.shape().dims, vec![3]);
132 /// assert_eq!(max_rows.get(&[0]), 4.0);
133 /// assert_eq!(max_rows.get(&[1]), 5.0);
134 /// assert_eq!(max_rows.get(&[2]), 6.0);
135 /// ```
136 ///
137 /// # Panics
138 ///
139 /// Panics if:
140 /// * `dims` is empty
141 /// * Any dimension in `dims` is out of bounds for the tensor's rank
142 ///
143 /// # GradTrack Support
144 ///
145 /// When `requires_grad` is true, this operation is tracked for automatic
146 /// differentiation. The gradient computation preserves the original input
147 /// shape and handles broadcasting correctly.
148 pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
149 assert!(!dims.is_empty(), "max_dims requires at least one dimension");
150 let rank = self.shape().rank();
151 for &d in dims {
152 assert!(
153 d < rank,
154 "max_dims dim {} out of bounds for rank {}",
155 d,
156 rank
157 );
158 }
159
160 let mut out_dims = self.shape().dims.clone();
161 let mut reduced: Vec<usize> = dims.to_vec();
162 reduced.sort_unstable();
163 reduced.dedup();
164 for &d in reduced.iter() {
165 out_dims[d] = if keepdim { 1 } else { 0 };
166 }
167 if !keepdim {
168 out_dims.retain(|&s| s != 0);
169 }
170 if out_dims.is_empty() {
171 out_dims.push(1);
172 }
173 let mut out = Tensor::zeros(out_dims.clone());
174
175 let in_shape = self.shape().dims.clone();
176 let out_rank = out.shape().rank();
177 let mut in_coords = vec![0usize; rank];
178 unsafe {
179 let dst = out.as_mut_ptr();
180 for i in 0..out.size() {
181 *dst.add(i) = f32::NEG_INFINITY;
182 }
183 for lin in 0..self.size() {
184 let mut tmp = lin;
185 for i in (0..rank).rev() {
186 let s = in_shape[i];
187 in_coords[i] = if s == 0 { 0 } else { tmp % s };
188 if s != 0 {
189 tmp /= s;
190 }
191 }
192
193 // Get input value using stride-aware offset
194 let in_offset = self.shape().offset(&in_coords);
195 let val = *self.as_ptr().add(in_offset);
196
197 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
198 for (i, &c) in in_coords.iter().enumerate().take(rank) {
199 if reduced.contains(&i) {
200 if keepdim {
201 out_coords.push(0);
202 }
203 } else {
204 out_coords.push(c);
205 }
206 }
207 let off = if out_coords.is_empty() {
208 0
209 } else {
210 out.shape().offset(&out_coords)
211 };
212 let cur = *dst.add(off);
213 if val > cur {
214 *dst.add(off) = val;
215 }
216 }
217 }
218
219 if self.requires_grad() {
220 let mut result = out.clone();
221 result.set_requires_grad_internal(true);
222 let grad_fn = GradFn::ReduceMaxDims {
223 dims: reduced,
224 keepdim,
225 input_shape: self.shape().dims.clone(),
226 saved_output: Box::new(out.clone()),
227 saved_input: Box::new(self.clone()),
228 };
229 result.set_grad_fn(grad_fn.clone());
230 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
231 return result;
232 }
233
234 out
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `(i as f32) - 3.0` into the first `count` flat slots of `t`,
    /// producing the sequence -3, -2, ..., count - 4.
    fn fill_shifted(t: &mut Tensor, count: usize) {
        unsafe {
            let p = t.as_mut_ptr();
            for i in 0..count {
                *p.add(i) = (i as f32) - 3.0;
            }
        }
    }

    #[test]
    fn test_max_forward_basic() {
        let mut t = Tensor::zeros(vec![2, 3]);
        fill_shifted(&mut t, 6);
        let result = t.max();
        assert_eq!(result.shape().dims, vec![1]);
        // Values are -3..=2, so the global maximum is 2.
        unsafe {
            assert_eq!(*result.as_ptr(), 2.0);
        }
    }

    #[test]
    fn test_max_dims_forward() {
        let mut t = Tensor::zeros(vec![2, 3]);
        fill_shifted(&mut t, 6);
        let result = t.max_dims(&[1], true);
        assert_eq!(result.shape().dims, vec![2, 1]);
        // Row maxima of [[-3, -2, -1], [0, 1, 2]].
        assert_eq!(result.get(&[0, 0]), -1.0);
        assert_eq!(result.get(&[1, 0]), 2.0);
    }

    #[test]
    fn test_max_non_contiguous_transpose() {
        // The global maximum over a transposed (non-contiguous) view must
        // match the maximum over the original tensor.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // max(1, 2, 3, 4, 5, 6) == 6 through both layouts.
        assert_eq!(base.max().get(&[0]), 6.0);
        assert_eq!(view.max().get(&[0]), 6.0);
    }

    #[test]
    fn test_max_dims_non_contiguous() {
        // max_dims over a transposed tensor must respect the view's layout.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 of the view: [max(1,2,3), max(4,5,6)] = [3, 6].
        let over_dim0 = view.max_dims(&[0], false);
        assert_eq!(over_dim0.shape().dims, vec![2]);
        assert_eq!(over_dim0.get(&[0]), 3.0);
        assert_eq!(over_dim0.get(&[1]), 6.0);

        // Reducing dim 1 of the view: [max(1,4), max(2,5), max(3,6)] = [4, 5, 6].
        let over_dim1 = view.max_dims(&[1], false);
        assert_eq!(over_dim1.shape().dims, vec![3]);
        assert_eq!(over_dim1.get(&[0]), 4.0);
        assert_eq!(over_dim1.get(&[1]), 5.0);
        assert_eq!(over_dim1.get(&[2]), 6.0);
    }

    #[test]
    fn test_max_permuted_tensor() {
        // A full axis permutation must not change the global maximum.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let base = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // permute([2, 1, 0]) reverses the axes: [2, 3, 4] -> [4, 3, 2].
        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let expected = base.max();
        let actual = permuted.max();
        assert_eq!(expected.get(&[0]), actual.get(&[0]));
        // max(0, 1, ..., 23) == 23.
        assert_eq!(expected.get(&[0]), 23.0);
    }

    #[test]
    fn test_max_with_negative_values() {
        // All-negative data exercises the NEG_INFINITY seeding on both the
        // contiguous and the stride-aware paths.
        let base = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // max(-5, -2, -8, -1, -3, -6) == -1 through both layouts.
        assert_eq!(base.max().get(&[0]), -1.0);
        assert_eq!(view.max().get(&[0]), -1.0);
    }
}