train_station/tensor/reductions/min.rs
use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;

impl Tensor {
    /// Computes the minimum value over all elements in the tensor
    ///
    /// Returns a scalar tensor containing the minimum value. For empty tensors,
    /// returns positive infinity. This operation supports gradient tracking
    /// through the GradTrack system.
    ///
    /// # Returns
    ///
    /// A tensor with shape `[1]` containing the minimum value
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
    /// let min_val = tensor.min();
    /// assert_eq!(min_val.get(&[0]), 1.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Empty tensor case
    /// let empty_tensor = Tensor::new(vec![0]);
    /// let min_val = empty_tensor.min();
    /// assert_eq!(min_val.get(&[0]), f32::INFINITY);
    /// ```
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation saves both the input and the
    /// output so the backward pass can be computed efficiently.
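    ///
    /// A minimal sketch of the backward behavior (marked `ignore` since it
    /// assumes a public `set_requires_grad` setter and a `backward` method,
    /// neither of which is defined in this file):
    ///
    /// ```ignore
    /// use train_station::Tensor;
    ///
    /// let mut x = Tensor::from_slice(&[3.0, 1.0, 2.0], vec![3]).unwrap();
    /// x.set_requires_grad(true); // assumed setter
    /// let m = x.min();
    /// m.backward(); // gradient flows only to the element(s) equal to the minimum
    /// ```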
    #[track_caller]
    pub fn min(&self) -> Tensor {
        let mut out = Tensor::new(vec![1]);
        if self.size() == 0 {
            out.fill(f32::INFINITY);
        } else {
            let mut m = f32::INFINITY;

            if self.is_contiguous() {
                // Fast path for contiguous tensors
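                // Manual 4-way unroll: amortizes loop overhead and lets the
                // four loads issue together; the scalar loop below handles
                // the remainder.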
                unsafe {
                    let src = self.as_ptr();
                    let size = self.size();
                    m = *src;
                    let mut i = 1usize;
                    while i + 4 <= size {
                        let x0 = *src.add(i);
                        let x1 = *src.add(i + 1);
                        let x2 = *src.add(i + 2);
                        let x3 = *src.add(i + 3);
                        m = m.min(x0).min(x1).min(x2).min(x3);
                        i += 4;
                    }
                    while i < size {
                        m = m.min(*src.add(i));
                        i += 1;
                    }
                }
            } else {
                // Stride-aware path for non-contiguous tensors
                let dims = self.shape().dims().to_vec();
                for flat_idx in 0..self.size() {
                    // Convert flat index to multi-dimensional coordinates
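                    // e.g. with dims = [2, 3], flat_idx 4 maps to coords [1, 1]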
                    let mut coords = vec![0; dims.len()];
                    let mut tmp = flat_idx;
                    for k in (0..dims.len()).rev() {
                        coords[k] = tmp % dims[k];
                        tmp /= dims[k];
                    }

                    // Get value using stride-aware offset
                    let offset = self.shape().offset(&coords);
                    let value = unsafe { *self.as_ptr().add(offset) };
                    if flat_idx == 0 {
                        m = value;
                    } else {
                        m = m.min(value);
                    }
                }
            }

            unsafe {
                *out.as_mut_ptr() = m;
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMin {
                saved_output: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
                input_shape: self.shape().dims().to_vec(),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }

    /// Computes the minimum value over specified dimensions
    ///
    /// Reduces the tensor along the specified dimensions by computing the minimum
    /// value in each reduction group. The `keepdim` parameter determines whether
    /// reduced dimensions are kept with size 1 or removed entirely.
    ///
    /// # Arguments
    ///
    /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
    /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
    ///
    /// # Returns
    ///
    /// A tensor with the specified dimensions reduced
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Min over columns (dim 1), keeping dimensions
    /// let min_cols = tensor.min_dims(&[1], true);
    /// assert_eq!(min_cols.shape().dims(), vec![2, 1]);
    /// assert_eq!(min_cols.get(&[0, 0]), 1.0);
    /// assert_eq!(min_cols.get(&[1, 0]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    ///
    /// // Min over rows (dim 0), removing dimensions
    /// let min_rows = tensor.min_dims(&[0], false);
    /// assert_eq!(min_rows.shape().dims(), vec![3]);
    /// assert_eq!(min_rows.get(&[0]), 1.0);
    /// assert_eq!(min_rows.get(&[1]), 2.0);
    /// assert_eq!(min_rows.get(&[2]), 3.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    ///
    /// // Min over multiple dimensions
    /// let min_all = tensor.min_dims(&[0, 1], false);
    /// assert_eq!(min_all.shape().dims(), vec![1]);
    /// assert_eq!(min_all.get(&[0]), 1.0);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if:
    /// * `dims` is empty
    /// * Any dimension in `dims` is out of bounds for the tensor's rank
    ///
    /// # GradTrack Support
    ///
    /// When `requires_grad` is true, this operation is tracked for automatic
    /// differentiation. The gradient computation preserves the original input
    /// shape and handles broadcasting correctly.
    #[track_caller]
    pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
        assert!(!dims.is_empty(), "min_dims requires at least one dimension");
        let rank = self.shape().rank();
        for &d in dims {
            assert!(
                d < rank,
                "min_dims dim {} out of bounds for rank {}",
                d,
                rank
            );
        }

        // Build output shape
        let mut out_dims = self.shape().dims().to_vec();
        let mut reduced: Vec<usize> = dims.to_vec();
        reduced.sort_unstable();
        reduced.dedup();
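        // Reduced dims are marked with a 0 sentinel: kept as size 1 when
        // `keepdim` is set, stripped below otherwise.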
        for &d in reduced.iter() {
            out_dims[d] = if keepdim { 1 } else { 0 };
        }
        if !keepdim {
            out_dims.retain(|&s| s != 0);
        }
        if out_dims.is_empty() {
            out_dims.push(1);
        }
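        // e.g. input shape [2, 3] with dims = [1]: keepdim -> [2, 1], otherwise -> [2]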
        let mut out = Tensor::zeros(out_dims.clone());

        // Compute min along reduced dims
        let in_shape = self.shape().dims().to_vec();
        let out_rank = out.shape().rank();
        let mut in_coords = vec![0usize; rank];
        unsafe {
            let dst = out.as_mut_ptr();
            // Initialize output with +inf
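            // (+inf is the identity for `min`, so any real element replaces it)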
            for i in 0..out.size() {
                *dst.add(i) = f32::INFINITY;
            }
            for lin in 0..self.size() {
                let mut tmp = lin;
                for i in (0..rank).rev() {
                    let s = in_shape[i];
                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
                    if s != 0 {
                        tmp /= s;
                    }
                }

                // Get input value using stride-aware offset
                let in_offset = self.shape().offset(&in_coords);
                let val = *self.as_ptr().add(in_offset);

                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
                for (i, &c) in in_coords.iter().enumerate().take(rank) {
                    if reduced.contains(&i) {
                        if keepdim {
                            out_coords.push(0);
                        }
                    } else {
                        out_coords.push(c);
                    }
                }
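                // e.g. in_coords [1, 2] with reduced = [1] gives out_coords
                // [1, 0] when keepdim, [1] otherwise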
                let off = if out_coords.is_empty() {
                    0
                } else {
                    out.shape().offset(&out_coords)
                };
                let cur = *dst.add(off);
                if val < cur {
                    *dst.add(off) = val;
                }
            }
        }

        if self.requires_grad() {
            let mut result = out.clone();
            result.set_requires_grad_internal(true);
            let grad_fn = GradFn::ReduceMinDims {
                dims: reduced,
                keepdim,
                input_shape: self.shape().dims().to_vec(),
                saved_output: Box::new(out.clone()),
                saved_input: Box::new(self.clone()),
            };
            result.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
            return result;
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_min_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min();
        assert_eq!(m.shape().dims(), vec![1]);
        unsafe {
            assert_eq!(*m.as_ptr(), -3.0);
        }
    }

    #[test]
    fn test_min_dims_forward() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min_dims(&[1], true);
        assert_eq!(m.shape().dims(), vec![2, 1]);
        assert_eq!(m.get(&[0, 0]), -3.0);
        assert_eq!(m.get(&[1, 0]), 0.0);
    }

    #[test]
    fn test_min_non_contiguous_transpose() {
        // Test min on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(1,2,3,4,5,6) = 1
        assert_eq!(min_orig.get(&[0]), 1.0);
        assert_eq!(min_view.get(&[0]), 1.0);
    }

    #[test]
    fn test_min_dims_non_contiguous() {
        // Test min_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Min along dim 0 of transposed tensor
        let min_dim0 = x_t.min_dims(&[0], false);
        assert_eq!(min_dim0.shape().dims(), vec![2]);
        // Should be [min(1,2,3), min(4,5,6)] = [1, 4]
        assert_eq!(min_dim0.get(&[0]), 1.0);
        assert_eq!(min_dim0.get(&[1]), 4.0);

        // Min along dim 1 of transposed tensor
        let min_dim1 = x_t.min_dims(&[1], false);
        assert_eq!(min_dim1.shape().dims(), vec![3]);
        // Should be [min(1,4), min(2,5), min(3,6)] = [1, 2, 3]
        assert_eq!(min_dim1.get(&[0]), 1.0);
        assert_eq!(min_dim1.get(&[1]), 2.0);
        assert_eq!(min_dim1.get(&[2]), 3.0);
    }

    #[test]
    fn test_min_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Permute dimensions [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let min_orig = x.min();
        let min_perm = x_perm.min();

        // Should give the same result
        assert_eq!(min_orig.get(&[0]), min_perm.get(&[0]));

        // Expected min: min(0,1,2,...,23) = 0
        assert_eq!(min_orig.get(&[0]), 0.0);
    }

    #[test]
    fn test_min_with_negative_values() {
        // Test min with negative values on non-contiguous tensor
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1);
        assert!(!x_t.is_contiguous());

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(-5,-2,-8,-1,-3,-6) = -8
        assert_eq!(min_orig.get(&[0]), -8.0);
        assert_eq!(min_view.get(&[0]), -8.0);
    }
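
    #[test]
    fn test_min_dims_reduce_all() {
        // Extra coverage (a sketch using only APIs exercised above): reducing
        // over every dimension at once collapses to a single-element tensor.
        let x = Tensor::from_slice(&[4.0, 2.0, 7.0, 5.0], vec![2, 2]).unwrap();
        let m = x.min_dims(&[0, 1], false);
        assert_eq!(m.shape().dims(), vec![1]);
        assert_eq!(m.get(&[0]), 2.0);
    }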
}