train_station/tensor/reductions/sum.rs

1//! Sum reduction operations for tensors
2//!
3//! This module provides sum reduction operations that compute the sum of tensor elements.
4//! These operations support both global summation and dimension-wise summation with
5//! automatic gradient tracking when enabled.
6//!
7//! # Operations
8//!
9//! * `sum()` - Sum all elements into a scalar tensor
10//! * `sum_dims()` - Sum elements along specified dimensions
11//!
12//! # Examples
13//!
14//! ```
15//! use train_station::Tensor;
16//!
17//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
18//! let total = tensor.sum();
19//! assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
20//! ```
21
22use crate::gradtrack::{GradEngine, GradFn};
23use crate::tensor::core::Tensor;
24
25impl Tensor {
26    /// Returns the sum of all elements in the tensor
27    ///
28    /// This operation computes the sum of all elements across all dimensions,
29    /// reducing the tensor to a scalar value. The output is a tensor with shape \[1\]
30    /// containing the sum as a float.
31    ///
32    /// When `requires_grad` is enabled, this operation supports automatic gradient
33    /// tracking through the GradTrack system.
34    ///
35    /// # Returns
36    ///
37    /// A tensor with shape \[1\] containing the sum of all elements
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// use train_station::Tensor;
43    ///
44    /// // Basic sum calculation
45    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
46    /// let total = tensor.sum();
47    /// assert_eq!(total.shape().dims(), vec![1]);
48    /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
49    /// ```
50    ///
51    /// ```
52    /// use train_station::Tensor;
53    ///
54    /// // Sum with gradient tracking
55    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
56    ///     .unwrap()
57    ///     .with_requires_grad();
58    /// let mut total = tensor.sum();
59    /// total.backward(None);
60    /// let grad = tensor.grad_owned().expect("gradient should exist");
61    /// // Gradient should be [1.0, 1.0, 1.0] for each element
62    /// assert_eq!(grad.get(&[0]), 1.0);
63    /// assert_eq!(grad.get(&[1]), 1.0);
64    /// assert_eq!(grad.get(&[2]), 1.0);
65    /// ```
66    ///
67    /// ```
68    /// use train_station::Tensor;
69    ///
70    /// // Sum of empty tensor
71    /// let tensor = Tensor::new(vec![0]);
72    /// let total = tensor.sum();
73    /// assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
74    /// ```
75    ///
76    /// # Performance
77    ///
78    /// Uses optimized contiguous tensor path with 4x loop unrolling for better
79    /// performance. Non-contiguous tensors use stride-aware iteration.
80    #[track_caller]
81    pub fn sum(&self) -> Tensor {
82        let mut out = Tensor::new(vec![1]);
83        if self.size() == 0 {
84            out.fill(0.0);
85        } else {
86            let mut acc0 = 0.0f32;
87
88            if self.is_contiguous() {
89                // Fast path for contiguous tensors
90                unsafe {
91                    let src = self.as_ptr();
92                    let size = self.size();
93                    let mut i = 0usize;
94                    // Unrolled loop for better throughput
95                    while i + 4 <= size {
96                        let x0 = *src.add(i);
97                        let x1 = *src.add(i + 1);
98                        let x2 = *src.add(i + 2);
99                        let x3 = *src.add(i + 3);
100                        acc0 += x0 + x1 + x2 + x3;
101                        i += 4;
102                    }
103                    while i < size {
104                        acc0 += *src.add(i);
105                        i += 1;
106                    }
107                }
108            } else {
109                // Stride-aware path for non-contiguous tensors
110                let dims = self.shape().dims().to_vec();
111                for flat_idx in 0..self.size() {
112                    // Convert flat index to multi-dimensional coordinates
113                    let mut coords = vec![0; dims.len()];
114                    let mut tmp = flat_idx;
115                    for k in (0..dims.len()).rev() {
116                        coords[k] = tmp % dims[k];
117                        tmp /= dims[k];
118                    }
119
120                    // Get value using stride-aware offset
121                    let offset = self.shape().offset(&coords);
122                    let value = unsafe { *self.as_ptr().add(offset) };
123                    acc0 += value;
124                }
125            }
126
127            unsafe {
128                *out.as_mut_ptr() = acc0;
129            }
130        }
131
132        if self.requires_grad() {
133            out.set_requires_grad_internal(true);
134            let grad_fn = GradFn::ReduceSum {
135                input_shape: self.shape().dims().to_vec(),
136            };
137            out.set_grad_fn(grad_fn.clone());
138            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
139        }
140
141        out
142    }
143
144    /// Returns the sum of elements along specified dimensions
145    ///
146    /// This operation computes the sum of elements along the specified dimensions,
147    /// reducing the tensor while optionally preserving the reduced dimensions as
148    /// size-1 dimensions.
149    ///
150    /// The output shape depends on the `keepdim` parameter:
151    /// * If `keepdim` is `true`, the reduced dimensions are kept with size 1
152    /// * If `keepdim` is `false`, the reduced dimensions are removed
153    ///
154    /// When `requires_grad` is enabled, this operation supports automatic gradient
155    /// tracking through the GradTrack system.
156    ///
157    /// # Arguments
158    ///
159    /// * `dims` - Vector of dimension indices to sum over (must be valid for tensor rank)
160    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
161    ///
162    /// # Returns
163    ///
164    /// A tensor with sum computed over the specified dimensions
165    ///
166    /// # Panics
167    ///
168    /// * If `dims` is empty
169    /// * If any dimension index is out of bounds for the tensor rank
170    ///
171    /// # Examples
172    ///
173    /// ```
174    /// use train_station::Tensor;
175    ///
176    /// // Sum along rows (dimension 0) with keepdim=false
177    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
178    /// let row_sums = matrix.sum_dims(&[0], false);
179    /// assert_eq!(row_sums.shape().dims(), vec![2]);
180    /// assert_eq!(row_sums.get(&[0]), 4.0); // 1 + 3 = 4
181    /// assert_eq!(row_sums.get(&[1]), 6.0); // 2 + 4 = 6
182    /// ```
183    ///
184    /// ```
185    /// use train_station::Tensor;
186    ///
187    /// // Sum along columns (dimension 1) with keepdim=true
188    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
189    /// let col_sums = matrix.sum_dims(&[1], true);
190    /// assert_eq!(col_sums.shape().dims(), vec![2, 1]);
191    /// assert_eq!(col_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
192    /// assert_eq!(col_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
193    /// ```
194    ///
195    /// ```
196    /// use train_station::Tensor;
197    ///
198    /// // Sum over multiple dimensions
199    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
200    /// let total = tensor.sum_dims(&[0, 1], false);
201    /// assert_eq!(total.shape().dims(), vec![1]);
202    /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
203    /// ```
204    ///
205    /// ```
206    /// use train_station::Tensor;
207    ///
208    /// // Sum with gradient tracking
209    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
210    ///     .unwrap()
211    ///     .with_requires_grad();
212    /// let mut row_sums = tensor.sum_dims(&[0], false);
213    /// row_sums.backward(None);
214    /// let grad = tensor.grad_owned().expect("gradient should exist");
215    /// // Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
216    /// assert_eq!(grad.get(&[0, 0]), 1.0);
217    /// assert_eq!(grad.get(&[0, 1]), 1.0);
218    /// assert_eq!(grad.get(&[1, 0]), 1.0);
219    /// assert_eq!(grad.get(&[1, 1]), 1.0);
220    /// ```
221    ///
222    /// # Performance
223    ///
224    /// Uses efficient coordinate-based iteration that works correctly with
225    /// both contiguous and non-contiguous tensor layouts.
226    #[track_caller]
227    pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
228        assert!(!dims.is_empty(), "sum_dims requires at least one dimension");
229        let rank = self.shape().rank();
230        for &d in dims {
231            assert!(
232                d < rank,
233                "sum_dims dim {} out of bounds for rank {}",
234                d,
235                rank
236            );
237        }
238
239        // Build output shape
240        let mut out_dims = self.shape().dims().to_vec();
241        let mut reduced: Vec<usize> = dims.to_vec();
242        reduced.sort_unstable();
243        reduced.dedup();
244        for &d in reduced.iter() {
245            out_dims[d] = if keepdim { 1 } else { 0 };
246        }
247        if !keepdim {
248            out_dims.retain(|&s| s != 0);
249        }
250        if out_dims.is_empty() {
251            out_dims.push(1);
252        }
253        let mut out = Tensor::zeros(out_dims.clone());
254
255        // Accumulate along reduced dims
256        let in_shape = self.shape().dims().to_vec();
257        let out_rank = out.shape().rank();
258        let mut in_coords = vec![0usize; rank];
259        unsafe {
260            let dst = out.as_mut_ptr();
261            // Iterate over all input elements, map to output index
262            for lin in 0..self.size() {
263                let mut tmp = lin;
264                for i in (0..rank).rev() {
265                    let s = in_shape[i];
266                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
267                    if s != 0 {
268                        tmp /= s;
269                    }
270                }
271
272                // Get input value using stride-aware offset
273                let in_offset = self.shape().offset(&in_coords);
274                let value = *self.as_ptr().add(in_offset);
275
276                // build output coords
277                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
278                for (i, &c) in in_coords.iter().enumerate().take(rank) {
279                    if reduced.contains(&i) {
280                        if keepdim {
281                            out_coords.push(0);
282                        }
283                    } else {
284                        out_coords.push(c);
285                    }
286                }
287                let off = if out_coords.is_empty() {
288                    0
289                } else {
290                    out.shape().offset(&out_coords)
291                };
292                *dst.add(off) += value;
293            }
294        }
295
296        if self.requires_grad() {
297            out.set_requires_grad_internal(true);
298            let grad_fn = GradFn::ReduceSumDims {
299                dims: reduced,
300                input_shape: self.shape().dims().to_vec(),
301                keepdim,
302            };
303            out.set_grad_fn(grad_fn.clone());
304            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
305        }
306
307        out
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_forward_basic() {
        // Fill a 2x3 tensor with 0.0, 0.5, 1.0, ..., 2.5 and sum it.
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            let p = x.as_mut_ptr();
            for i in 0..6 {
                *p.add(i) = 0.5 * (i as f32);
            }
        }
        let total = x.sum();
        assert_eq!(total.shape().dims(), vec![1]);
        let value = unsafe { *total.as_ptr() };
        assert!((value - 7.5).abs() < 1e-6);
    }

    #[test]
    fn test_sum_autograd_all_ones_grad() {
        // d(sum(x))/dx_i = 1 for every element.
        let mut x = Tensor::zeros(vec![2, 2]).with_requires_grad();
        unsafe {
            let p = x.as_mut_ptr();
            for i in 0..4 {
                *p.add(i) = i as f32;
            }
        }
        let mut total = x.sum();
        total.backward(None);
        let gx = x.grad_owned().expect("grad missing");
        unsafe {
            for i in 0..4 {
                assert_eq!(*gx.as_ptr().add(i), 1.0);
            }
        }
    }

    #[test]
    fn test_sum_chain_autograd() {
        // d/dx of sum(2x + 1) = 2 for every element.
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut total = x.mul_scalar(2.0).add_scalar(1.0).sum();
        total.backward(None);
        let gx = x.grad_owned().expect("grad missing");
        unsafe {
            for i in 0..4 {
                assert_eq!(*gx.as_ptr().add(i), 2.0);
            }
        }
    }

    #[test]
    fn test_sum_non_contiguous_transpose() {
        // Summing a transposed (non-contiguous) view must equal summing the base.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // 1 + 2 + 3 + 4 + 5 + 6 = 21, regardless of layout.
        assert_eq!(base.sum().get(&[0]), 21.0);
        assert_eq!(view.sum().get(&[0]), 21.0);
    }

    #[test]
    fn test_sum_dims_non_contiguous() {
        // sum_dims on a transposed (non-contiguous) tensor.
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 of the view collapses the original rows: [6, 15].
        let along0 = view.sum_dims(&[0], false);
        assert_eq!(along0.shape().dims(), vec![2]);
        for (i, expected) in [6.0f32, 15.0].iter().enumerate() {
            assert_eq!(along0.get(&[i]), *expected);
        }

        // Reducing dim 1 of the view pairs original columns: [5, 7, 9].
        let along1 = view.sum_dims(&[1], false);
        assert_eq!(along1.shape().dims(), vec![3]);
        for (i, expected) in [5.0f32, 7.0, 9.0].iter().enumerate() {
            assert_eq!(along1.get(&[i]), *expected);
        }
    }

    #[test]
    fn test_sum_permuted_tensor() {
        // Summing a permuted (non-contiguous) tensor matches the original.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let base = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Reverse the axes: [2, 3, 4] -> [4, 3, 2].
        let perm = base.permute(vec![2, 1, 0]);
        assert!(!perm.is_contiguous());

        let total = base.sum();
        // 0 + 1 + ... + 23 = 23 * 24 / 2 = 276, for both layouts.
        assert_eq!(total.get(&[0]), 276.0);
        assert_eq!(perm.sum().get(&[0]), total.get(&[0]));
    }
}