train_station/tensor/reductions/min.rs
1use crate::gradtrack::{GradEngine, GradFn};
2use crate::tensor::core::Tensor;
3
4impl Tensor {
5 /// Computes the minimum value over all elements in the tensor
6 ///
7 /// Returns a scalar tensor containing the minimum value. For empty tensors,
8 /// returns positive infinity. This operation supports gradient tracking
9 /// through the GradTrack system.
10 ///
11 /// # Returns
12 ///
13 /// A tensor with shape `[1]` containing the minimum value
14 ///
15 /// # Examples
16 ///
17 /// ```
18 /// use train_station::Tensor;
19 ///
20 /// let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
21 /// let min_val = tensor.min();
22 /// assert_eq!(min_val.get(&[0]), 1.0);
23 /// ```
24 ///
25 /// ```
26 /// use train_station::Tensor;
27 ///
28 /// // Empty tensor case
29 /// let empty_tensor = Tensor::new(vec![0]);
30 /// let min_val = empty_tensor.min();
31 /// assert_eq!(min_val.get(&[0]), f32::INFINITY);
32 /// ```
33 ///
34 /// # GradTrack Support
35 ///
36 /// When `requires_grad` is true, this operation is tracked for automatic
37 /// differentiation. The gradient computation uses the saved input and output
38 /// for efficient backward pass.
39 pub fn min(&self) -> Tensor {
40 let mut out = Tensor::new(vec![1]);
41 if self.size() == 0 {
42 out.fill(f32::INFINITY);
43 } else {
44 let mut m = f32::INFINITY;
45
46 if self.is_contiguous() {
47 // Fast path for contiguous tensors
48 unsafe {
49 let src = self.as_ptr();
50 let size = self.size();
51 m = *src;
52 let mut i = 1usize;
53 while i + 4 <= size {
54 let x0 = *src.add(i);
55 let x1 = *src.add(i + 1);
56 let x2 = *src.add(i + 2);
57 let x3 = *src.add(i + 3);
58 m = m.min(x0).min(x1).min(x2).min(x3);
59 i += 4;
60 }
61 while i < size {
62 m = m.min(*src.add(i));
63 i += 1;
64 }
65 }
66 } else {
67 // Stride-aware path for non-contiguous tensors
68 let dims = self.shape().dims.clone();
69 for flat_idx in 0..self.size() {
70 // Convert flat index to multi-dimensional coordinates
71 let mut coords = vec![0; dims.len()];
72 let mut tmp = flat_idx;
73 for k in (0..dims.len()).rev() {
74 coords[k] = tmp % dims[k];
75 tmp /= dims[k];
76 }
77
78 // Get value using stride-aware offset
79 let offset = self.shape().offset(&coords);
80 let value = unsafe { *self.as_ptr().add(offset) };
81 if flat_idx == 0 {
82 m = value;
83 } else {
84 m = m.min(value);
85 }
86 }
87 }
88
89 unsafe {
90 *out.as_mut_ptr() = m;
91 }
92 }
93
94 if self.requires_grad() {
95 let mut result = out.clone();
96 result.set_requires_grad_internal(true);
97 let grad_fn = GradFn::ReduceMin {
98 saved_output: Box::new(out.clone()),
99 saved_input: Box::new(self.clone()),
100 input_shape: self.shape().dims.clone(),
101 };
102 result.set_grad_fn(grad_fn.clone());
103 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
104 return result;
105 }
106
107 out
108 }
109
110 /// Computes the minimum value over specified dimensions
111 ///
112 /// Reduces the tensor along the specified dimensions by computing the minimum
113 /// value in each reduction group. The `keepdim` parameter determines whether
114 /// reduced dimensions are kept with size 1 or removed entirely.
115 ///
116 /// # Arguments
117 ///
118 /// * `dims` - Dimensions to reduce over (must be valid for the tensor's rank)
119 /// * `keepdim` - If true, reduced dimensions are kept with size 1; if false, they are removed
120 ///
121 /// # Returns
122 ///
123 /// A tensor with the specified dimensions reduced
124 ///
125 /// # Examples
126 ///
127 /// ```
128 /// use train_station::Tensor;
129 ///
130 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
131 ///
132 /// // Min over columns (dim 1), keeping dimensions
133 /// let min_cols = tensor.min_dims(&[1], true);
134 /// assert_eq!(min_cols.shape().dims, vec![2, 1]);
135 /// assert_eq!(min_cols.get(&[0, 0]), 1.0);
136 /// assert_eq!(min_cols.get(&[1, 0]), 4.0);
137 /// ```
138 ///
139 /// ```
140 /// use train_station::Tensor;
141 ///
142 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
143 ///
144 /// // Min over rows (dim 0), removing dimensions
145 /// let min_rows = tensor.min_dims(&[0], false);
146 /// assert_eq!(min_rows.shape().dims, vec![3]);
147 /// assert_eq!(min_rows.get(&[0]), 1.0);
148 /// assert_eq!(min_rows.get(&[1]), 2.0);
149 /// assert_eq!(min_rows.get(&[2]), 3.0);
150 /// ```
151 ///
152 /// ```
153 /// use train_station::Tensor;
154 ///
155 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
156 ///
157 /// // Min over multiple dimensions
158 /// let min_all = tensor.min_dims(&[0, 1], false);
159 /// assert_eq!(min_all.shape().dims, vec![1]);
160 /// assert_eq!(min_all.get(&[0]), 1.0);
161 /// ```
162 ///
163 /// # Panics
164 ///
165 /// Panics if:
166 /// * `dims` is empty
167 /// * Any dimension in `dims` is out of bounds for the tensor's rank
168 ///
169 /// # GradTrack Support
170 ///
171 /// When `requires_grad` is true, this operation is tracked for automatic
172 /// differentiation. The gradient computation preserves the original input
173 /// shape and handles broadcasting correctly.
174 pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
175 assert!(!dims.is_empty(), "min_dims requires at least one dimension");
176 let rank = self.shape().rank();
177 for &d in dims {
178 assert!(
179 d < rank,
180 "min_dims dim {} out of bounds for rank {}",
181 d,
182 rank
183 );
184 }
185
186 // Build output shape
187 let mut out_dims = self.shape().dims.clone();
188 let mut reduced: Vec<usize> = dims.to_vec();
189 reduced.sort_unstable();
190 reduced.dedup();
191 for &d in reduced.iter() {
192 out_dims[d] = if keepdim { 1 } else { 0 };
193 }
194 if !keepdim {
195 out_dims.retain(|&s| s != 0);
196 }
197 if out_dims.is_empty() {
198 out_dims.push(1);
199 }
200 let mut out = Tensor::zeros(out_dims.clone());
201
202 // Compute min along reduced dims
203 let in_shape = self.shape().dims.clone();
204 let out_rank = out.shape().rank();
205 let mut in_coords = vec![0usize; rank];
206 unsafe {
207 let dst = out.as_mut_ptr();
208 // Initialize output with +inf
209 for i in 0..out.size() {
210 *dst.add(i) = f32::INFINITY;
211 }
212 for lin in 0..self.size() {
213 let mut tmp = lin;
214 for i in (0..rank).rev() {
215 let s = in_shape[i];
216 in_coords[i] = if s == 0 { 0 } else { tmp % s };
217 if s != 0 {
218 tmp /= s;
219 }
220 }
221
222 // Get input value using stride-aware offset
223 let in_offset = self.shape().offset(&in_coords);
224 let val = *self.as_ptr().add(in_offset);
225
226 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
227 for (i, &c) in in_coords.iter().enumerate().take(rank) {
228 if reduced.contains(&i) {
229 if keepdim {
230 out_coords.push(0);
231 }
232 } else {
233 out_coords.push(c);
234 }
235 }
236 let off = if out_coords.is_empty() {
237 0
238 } else {
239 out.shape().offset(&out_coords)
240 };
241 let cur = *dst.add(off);
242 if val < cur {
243 *dst.add(off) = val;
244 }
245 }
246 }
247
248 if self.requires_grad() {
249 let mut result = out.clone();
250 result.set_requires_grad_internal(true);
251 let grad_fn = GradFn::ReduceMinDims {
252 dims: reduced,
253 keepdim,
254 input_shape: self.shape().dims.clone(),
255 saved_output: Box::new(out.clone()),
256 saved_input: Box::new(self.clone()),
257 };
258 result.set_grad_fn(grad_fn.clone());
259 GradEngine::register_operation(result.id(), vec![self.id()], grad_fn);
260 return result;
261 }
262
263 out
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_min_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min();
        assert_eq!(m.shape().dims, vec![1]);
        unsafe {
            assert_eq!(*m.as_ptr(), -3.0);
        }
    }

    #[test]
    fn test_min_contiguous_minimum_in_tail() {
        // 7 elements: the minimum sits in the final, non-multiple-of-4 tail.
        let x = Tensor::from_slice(&[5.0, 4.0, 3.0, 2.0, 1.0, 0.5, -1.0], vec![7]).unwrap();
        assert_eq!(x.min().get(&[0]), -1.0);
    }

    #[test]
    fn test_min_contiguous_minimum_first() {
        // The minimum is the very first element, which seeds the reduction.
        let x = Tensor::from_slice(&[-7.0, 4.0, 3.0, 2.0, 1.0], vec![5]).unwrap();
        assert_eq!(x.min().get(&[0]), -7.0);
    }

    #[test]
    fn test_min_empty_is_infinity() {
        let x = Tensor::new(vec![0]);
        let m = x.min();
        assert_eq!(m.shape().dims, vec![1]);
        assert_eq!(m.get(&[0]), f32::INFINITY);
    }

    #[test]
    fn test_min_dims_forward() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            for i in 0..6 {
                *x.as_mut_ptr().add(i) = (i as f32) - 3.0;
            }
        }
        let m = x.min_dims(&[1], true);
        assert_eq!(m.shape().dims, vec![2, 1]);
        assert_eq!(m.get(&[0, 0]), -3.0);
        assert_eq!(m.get(&[1, 0]), 0.0);
    }

    #[test]
    fn test_min_dims_all_dims_without_keepdim() {
        let x = Tensor::from_slice(&[4.0, 2.0, 3.0, 1.5], vec![2, 2]).unwrap();
        let m = x.min_dims(&[0, 1], false);
        assert_eq!(m.shape().dims, vec![1]);
        assert_eq!(m.get(&[0]), 1.5);
    }

    #[test]
    fn test_min_dims_duplicate_dims_are_ignored() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let m = x.min_dims(&[1, 1], true);
        assert_eq!(m.shape().dims, vec![2, 1]);
        assert_eq!(m.get(&[0, 0]), 1.0);
        assert_eq!(m.get(&[1, 0]), 4.0);
    }

    #[test]
    #[should_panic(expected = "at least one dimension")]
    fn test_min_dims_empty_dims_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let _ = x.min_dims(&[], false);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_min_dims_out_of_bounds_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let _ = x.min_dims(&[1], false);
    }

    #[test]
    fn test_min_non_contiguous_transpose() {
        // Test min on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(1,2,3,4,5,6) = 1
        assert_eq!(min_orig.get(&[0]), 1.0);
        assert_eq!(min_view.get(&[0]), 1.0);
    }

    #[test]
    fn test_min_dims_non_contiguous() {
        // Test min_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Min along dim 0 of transposed tensor
        let min_dim0 = x_t.min_dims(&[0], false);
        assert_eq!(min_dim0.shape().dims, vec![2]);
        // Should be [min(1,2,3), min(4,5,6)] = [1, 4]
        assert_eq!(min_dim0.get(&[0]), 1.0);
        assert_eq!(min_dim0.get(&[1]), 4.0);

        // Min along dim 1 of transposed tensor
        let min_dim1 = x_t.min_dims(&[1], false);
        assert_eq!(min_dim1.shape().dims, vec![3]);
        // Should be [min(1,4), min(2,5), min(3,6)] = [1, 2, 3]
        assert_eq!(min_dim1.get(&[0]), 1.0);
        assert_eq!(min_dim1.get(&[1]), 2.0);
        assert_eq!(min_dim1.get(&[2]), 3.0);
    }

    #[test]
    fn test_min_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Reverse the axis order: [2, 3, 4] -> [4, 3, 2]
        let x_perm = x.permute(vec![2, 1, 0]);
        assert_eq!(x_perm.shape().dims, vec![4, 3, 2]);
        assert!(!x_perm.is_contiguous());

        let min_orig = x.min();
        let min_perm = x_perm.min();

        // Should give same result
        assert_eq!(min_orig.get(&[0]), min_perm.get(&[0]));

        // Expected min: min(0,1,2,...,23) = 0
        assert_eq!(min_orig.get(&[0]), 0.0);
    }

    #[test]
    fn test_min_with_negative_values() {
        // Test min with negative values on non-contiguous tensor
        let x = Tensor::from_slice(&[-5.0, -2.0, -8.0, -1.0, -3.0, -6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1);
        assert!(!x_t.is_contiguous());

        let min_orig = x.min();
        let min_view = x_t.min();

        // Both should give the same result: min(-5,-2,-8,-1,-3,-6) = -8
        assert_eq!(min_orig.get(&[0]), -8.0);
        assert_eq!(min_view.get(&[0]), -8.0);
    }
}