// train_station/tensor/reductions/sum.rs

1//! Sum reduction operations for tensors
2//!
3//! This module provides sum reduction operations that compute the sum of tensor elements.
4//! These operations support both global summation and dimension-wise summation with
5//! automatic gradient tracking when enabled.
6//!
7//! # Operations
8//!
9//! * `sum()` - Sum all elements into a scalar tensor
10//! * `sum_dims()` - Sum elements along specified dimensions
11//!
12//! # Examples
13//!
14//! ```
15//! use train_station::Tensor;
16//!
17//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
18//! let total = tensor.sum();
19//! assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
20//! ```
21
22use crate::gradtrack::{GradEngine, GradFn};
23use crate::tensor::core::Tensor;
24
25impl Tensor {
26    /// Returns the sum of all elements in the tensor
27    ///
28    /// This operation computes the sum of all elements across all dimensions,
29    /// reducing the tensor to a scalar value. The output is a tensor with shape \[1\]
30    /// containing the sum as a float.
31    ///
32    /// When `requires_grad` is enabled, this operation supports automatic gradient
33    /// tracking through the GradTrack system.
34    ///
35    /// # Returns
36    ///
37    /// A tensor with shape \[1\] containing the sum of all elements
38    ///
39    /// # Examples
40    ///
41    /// ```
42    /// use train_station::Tensor;
43    ///
44    /// // Basic sum calculation
45    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
46    /// let total = tensor.sum();
47    /// assert_eq!(total.shape().dims, vec![1]);
48    /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
49    /// ```
50    ///
51    /// ```
52    /// use train_station::Tensor;
53    ///
54    /// // Sum with gradient tracking
55    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
56    ///     .unwrap()
57    ///     .with_requires_grad();
58    /// let mut total = tensor.sum();
59    /// total.backward(None);
60    /// let grad = tensor.grad_by_value().expect("gradient should exist");
61    /// // Gradient should be [1.0, 1.0, 1.0] for each element
62    /// assert_eq!(grad.get(&[0]), 1.0);
63    /// assert_eq!(grad.get(&[1]), 1.0);
64    /// assert_eq!(grad.get(&[2]), 1.0);
65    /// ```
66    ///
67    /// ```
68    /// use train_station::Tensor;
69    ///
70    /// // Sum of empty tensor
71    /// let tensor = Tensor::new(vec![0]);
72    /// let total = tensor.sum();
73    /// assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
74    /// ```
75    ///
76    /// # Performance
77    ///
78    /// Uses optimized contiguous tensor path with 4x loop unrolling for better
79    /// performance. Non-contiguous tensors use stride-aware iteration.
80    pub fn sum(&self) -> Tensor {
81        let mut out = Tensor::new(vec![1]);
82        if self.size() == 0 {
83            out.fill(0.0);
84        } else {
85            let mut acc0 = 0.0f32;
86
87            if self.is_contiguous() {
88                // Fast path for contiguous tensors
89                unsafe {
90                    let src = self.as_ptr();
91                    let size = self.size();
92                    let mut i = 0usize;
93                    // Unrolled loop for better throughput
94                    while i + 4 <= size {
95                        let x0 = *src.add(i);
96                        let x1 = *src.add(i + 1);
97                        let x2 = *src.add(i + 2);
98                        let x3 = *src.add(i + 3);
99                        acc0 += x0 + x1 + x2 + x3;
100                        i += 4;
101                    }
102                    while i < size {
103                        acc0 += *src.add(i);
104                        i += 1;
105                    }
106                }
107            } else {
108                // Stride-aware path for non-contiguous tensors
109                let dims = self.shape().dims.clone();
110                for flat_idx in 0..self.size() {
111                    // Convert flat index to multi-dimensional coordinates
112                    let mut coords = vec![0; dims.len()];
113                    let mut tmp = flat_idx;
114                    for k in (0..dims.len()).rev() {
115                        coords[k] = tmp % dims[k];
116                        tmp /= dims[k];
117                    }
118
119                    // Get value using stride-aware offset
120                    let offset = self.shape().offset(&coords);
121                    let value = unsafe { *self.as_ptr().add(offset) };
122                    acc0 += value;
123                }
124            }
125
126            unsafe {
127                *out.as_mut_ptr() = acc0;
128            }
129        }
130
131        if self.requires_grad() {
132            out.set_requires_grad_internal(true);
133            let grad_fn = GradFn::ReduceSum {
134                input_shape: self.shape().dims.clone(),
135            };
136            out.set_grad_fn(grad_fn.clone());
137            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
138        }
139
140        out
141    }
142
143    /// Returns the sum of elements along specified dimensions
144    ///
145    /// This operation computes the sum of elements along the specified dimensions,
146    /// reducing the tensor while optionally preserving the reduced dimensions as
147    /// size-1 dimensions.
148    ///
149    /// The output shape depends on the `keepdim` parameter:
150    /// * If `keepdim` is `true`, the reduced dimensions are kept with size 1
151    /// * If `keepdim` is `false`, the reduced dimensions are removed
152    ///
153    /// When `requires_grad` is enabled, this operation supports automatic gradient
154    /// tracking through the GradTrack system.
155    ///
156    /// # Arguments
157    ///
158    /// * `dims` - Vector of dimension indices to sum over (must be valid for tensor rank)
159    /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
160    ///
161    /// # Returns
162    ///
163    /// A tensor with sum computed over the specified dimensions
164    ///
165    /// # Panics
166    ///
167    /// * If `dims` is empty
168    /// * If any dimension index is out of bounds for the tensor rank
169    ///
170    /// # Examples
171    ///
172    /// ```
173    /// use train_station::Tensor;
174    ///
175    /// // Sum along rows (dimension 0) with keepdim=false
176    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
177    /// let row_sums = matrix.sum_dims(&[0], false);
178    /// assert_eq!(row_sums.shape().dims, vec![2]);
179    /// assert_eq!(row_sums.get(&[0]), 4.0); // 1 + 3 = 4
180    /// assert_eq!(row_sums.get(&[1]), 6.0); // 2 + 4 = 6
181    /// ```
182    ///
183    /// ```
184    /// use train_station::Tensor;
185    ///
186    /// // Sum along columns (dimension 1) with keepdim=true
187    /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
188    /// let col_sums = matrix.sum_dims(&[1], true);
189    /// assert_eq!(col_sums.shape().dims, vec![2, 1]);
190    /// assert_eq!(col_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
191    /// assert_eq!(col_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
192    /// ```
193    ///
194    /// ```
195    /// use train_station::Tensor;
196    ///
197    /// // Sum over multiple dimensions
198    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
199    /// let total = tensor.sum_dims(&[0, 1], false);
200    /// assert_eq!(total.shape().dims, vec![1]);
201    /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
202    /// ```
203    ///
204    /// ```
205    /// use train_station::Tensor;
206    ///
207    /// // Sum with gradient tracking
208    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
209    ///     .unwrap()
210    ///     .with_requires_grad();
211    /// let mut row_sums = tensor.sum_dims(&[0], false);
212    /// row_sums.backward(None);
213    /// let grad = tensor.grad_by_value().expect("gradient should exist");
214    /// // Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
215    /// assert_eq!(grad.get(&[0, 0]), 1.0);
216    /// assert_eq!(grad.get(&[0, 1]), 1.0);
217    /// assert_eq!(grad.get(&[1, 0]), 1.0);
218    /// assert_eq!(grad.get(&[1, 1]), 1.0);
219    /// ```
220    ///
221    /// # Performance
222    ///
223    /// Uses efficient coordinate-based iteration that works correctly with
224    /// both contiguous and non-contiguous tensor layouts.
225    pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
226        assert!(!dims.is_empty(), "sum_dims requires at least one dimension");
227        let rank = self.shape().rank();
228        for &d in dims {
229            assert!(
230                d < rank,
231                "sum_dims dim {} out of bounds for rank {}",
232                d,
233                rank
234            );
235        }
236
237        // Build output shape
238        let mut out_dims = self.shape().dims.clone();
239        let mut reduced: Vec<usize> = dims.to_vec();
240        reduced.sort_unstable();
241        reduced.dedup();
242        for &d in reduced.iter() {
243            out_dims[d] = if keepdim { 1 } else { 0 };
244        }
245        if !keepdim {
246            out_dims.retain(|&s| s != 0);
247        }
248        if out_dims.is_empty() {
249            out_dims.push(1);
250        }
251        let mut out = Tensor::zeros(out_dims.clone());
252
253        // Accumulate along reduced dims
254        let in_shape = self.shape().dims.clone();
255        let out_rank = out.shape().rank();
256        let mut in_coords = vec![0usize; rank];
257        unsafe {
258            let dst = out.as_mut_ptr();
259            // Iterate over all input elements, map to output index
260            for lin in 0..self.size() {
261                let mut tmp = lin;
262                for i in (0..rank).rev() {
263                    let s = in_shape[i];
264                    in_coords[i] = if s == 0 { 0 } else { tmp % s };
265                    if s != 0 {
266                        tmp /= s;
267                    }
268                }
269
270                // Get input value using stride-aware offset
271                let in_offset = self.shape().offset(&in_coords);
272                let value = *self.as_ptr().add(in_offset);
273
274                // build output coords
275                let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
276                for (i, &c) in in_coords.iter().enumerate().take(rank) {
277                    if reduced.contains(&i) {
278                        if keepdim {
279                            out_coords.push(0);
280                        }
281                    } else {
282                        out_coords.push(c);
283                    }
284                }
285                let off = if out_coords.is_empty() {
286                    0
287                } else {
288                    out.shape().offset(&out_coords)
289                };
290                *dst.add(off) += value;
291            }
292        }
293
294        if self.requires_grad() {
295            out.set_requires_grad_internal(true);
296            let grad_fn = GradFn::ReduceSumDims {
297                dims: reduced,
298                input_shape: self.shape().dims.clone(),
299                keepdim,
300            };
301            out.set_grad_fn(grad_fn.clone());
302            GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
303        }
304
305        out
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of [0.0, 0.5, 1.0, 1.5, 2.0, 2.5] is 7.5 with output shape [1].
    #[test]
    fn test_sum_forward_basic() {
        let mut x = Tensor::zeros(vec![2, 3]);
        unsafe {
            let p = x.as_mut_ptr();
            for i in 0..6 {
                *p.add(i) = 0.5 * i as f32;
            }
        }
        let total = x.sum();
        assert_eq!(total.shape().dims, vec![1]);
        let got = unsafe { *total.as_ptr() };
        assert!((got - 7.5).abs() < 1e-6);
    }

    /// d(sum(x))/dx is 1 for every element.
    #[test]
    fn test_sum_autograd_all_ones_grad() {
        let mut x = Tensor::zeros(vec![2, 2]).with_requires_grad();
        unsafe {
            let p = x.as_mut_ptr();
            for i in 0..4 {
                *p.add(i) = i as f32;
            }
        }
        let mut total = x.sum();
        total.backward(None);
        let grad = x.grad_by_value().expect("grad missing");
        unsafe {
            for i in 0..4 {
                assert_eq!(*grad.as_ptr().add(i), 1.0);
            }
        }
    }

    /// Gradient flows through a chained expression: d/dx sum(2x + 1) = 2.
    #[test]
    fn test_sum_chain_autograd() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut total = x.mul_scalar(2.0).add_scalar(1.0).sum();
        total.backward(None);
        let grad = x.grad_by_value().expect("grad missing");
        unsafe {
            for i in 0..4 {
                assert_eq!(*grad.as_ptr().add(i), 2.0);
            }
        }
    }

    /// A transposed view (non-contiguous) must sum to the same total as its base.
    #[test]
    fn test_sum_non_contiguous_transpose() {
        // Base layout: [[1, 2, 3], [4, 5, 6]]
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        // View layout: [[1, 4], [2, 5], [3, 6]]
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous()); // Should be a view

        // Element order must not matter: 1+2+3+4+5+6 = 21 either way.
        assert_eq!(base.sum().get(&[0]), 21.0);
        assert_eq!(view.sum().get(&[0]), 21.0);
    }

    /// Dimension-wise sums on a non-contiguous transposed view.
    #[test]
    fn test_sum_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let view = base.transpose(0, 1); // shape [3, 2]
        assert!(!view.is_contiguous());

        // Reducing dim 0 of the view: [1+2+3, 4+5+6] = [6, 15]
        let along0 = view.sum_dims(&[0], false);
        assert_eq!(along0.shape().dims, vec![2]);
        assert_eq!(along0.get(&[0]), 6.0);
        assert_eq!(along0.get(&[1]), 15.0);

        // Reducing dim 1 of the view: [1+4, 2+5, 3+6] = [5, 7, 9]
        let along1 = view.sum_dims(&[1], false);
        assert_eq!(along1.shape().dims, vec![3]);
        assert_eq!(along1.get(&[0]), 5.0);
        assert_eq!(along1.get(&[1]), 7.0);
        assert_eq!(along1.get(&[2]), 9.0);
    }

    /// Summing a permuted (non-contiguous) tensor matches the original total.
    #[test]
    fn test_sum_permuted_tensor() {
        let values: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let cube = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        // Reverse the axes: [2, 3, 4] -> [4, 3, 2]
        let flipped = cube.permute(vec![2, 1, 0]);
        assert!(!flipped.is_contiguous());

        // 0 + 1 + ... + 23 = 23 * 24 / 2 = 276 regardless of layout.
        assert_eq!(cube.sum().get(&[0]), 276.0);
        assert_eq!(flipped.sum().get(&[0]), 276.0);
    }
}