// train_station/tensor/reductions/sum.rs
1//! Sum reduction operations for tensors
2//!
3//! This module provides sum reduction operations that compute the sum of tensor elements.
4//! These operations support both global summation and dimension-wise summation with
5//! automatic gradient tracking when enabled.
6//!
7//! # Operations
8//!
9//! * `sum()` - Sum all elements into a scalar tensor
10//! * `sum_dims()` - Sum elements along specified dimensions
11//!
12//! # Examples
13//!
14//! ```
15//! use train_station::Tensor;
16//!
17//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
18//! let total = tensor.sum();
19//! assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
20//! ```
21
22use crate::gradtrack::{GradEngine, GradFn};
23use crate::tensor::core::Tensor;
24
25impl Tensor {
26 /// Returns the sum of all elements in the tensor
27 ///
28 /// This operation computes the sum of all elements across all dimensions,
29 /// reducing the tensor to a scalar value. The output is a tensor with shape \[1\]
30 /// containing the sum as a float.
31 ///
32 /// When `requires_grad` is enabled, this operation supports automatic gradient
33 /// tracking through the GradTrack system.
34 ///
35 /// # Returns
36 ///
37 /// A tensor with shape \[1\] containing the sum of all elements
38 ///
39 /// # Examples
40 ///
41 /// ```
42 /// use train_station::Tensor;
43 ///
44 /// // Basic sum calculation
45 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
46 /// let total = tensor.sum();
47 /// assert_eq!(total.shape().dims, vec![1]);
48 /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
49 /// ```
50 ///
51 /// ```
52 /// use train_station::Tensor;
53 ///
54 /// // Sum with gradient tracking
55 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
56 /// .unwrap()
57 /// .with_requires_grad();
58 /// let mut total = tensor.sum();
59 /// total.backward(None);
60 /// let grad = tensor.grad_by_value().expect("gradient should exist");
61 /// // Gradient should be [1.0, 1.0, 1.0] for each element
62 /// assert_eq!(grad.get(&[0]), 1.0);
63 /// assert_eq!(grad.get(&[1]), 1.0);
64 /// assert_eq!(grad.get(&[2]), 1.0);
65 /// ```
66 ///
67 /// ```
68 /// use train_station::Tensor;
69 ///
70 /// // Sum of empty tensor
71 /// let tensor = Tensor::new(vec![0]);
72 /// let total = tensor.sum();
73 /// assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
74 /// ```
75 ///
76 /// # Performance
77 ///
78 /// Uses optimized contiguous tensor path with 4x loop unrolling for better
79 /// performance. Non-contiguous tensors use stride-aware iteration.
80 #[track_caller]
81 pub fn sum(&self) -> Tensor {
82 let mut out = Tensor::new(vec![1]);
83 if self.size() == 0 {
84 out.fill(0.0);
85 } else {
86 let mut acc0 = 0.0f32;
87
88 if self.is_contiguous() {
89 // Fast path for contiguous tensors
90 unsafe {
91 let src = self.as_ptr();
92 let size = self.size();
93 let mut i = 0usize;
94 // Unrolled loop for better throughput
95 while i + 4 <= size {
96 let x0 = *src.add(i);
97 let x1 = *src.add(i + 1);
98 let x2 = *src.add(i + 2);
99 let x3 = *src.add(i + 3);
100 acc0 += x0 + x1 + x2 + x3;
101 i += 4;
102 }
103 while i < size {
104 acc0 += *src.add(i);
105 i += 1;
106 }
107 }
108 } else {
109 // Stride-aware path for non-contiguous tensors
110 let dims = self.shape().dims.clone();
111 for flat_idx in 0..self.size() {
112 // Convert flat index to multi-dimensional coordinates
113 let mut coords = vec![0; dims.len()];
114 let mut tmp = flat_idx;
115 for k in (0..dims.len()).rev() {
116 coords[k] = tmp % dims[k];
117 tmp /= dims[k];
118 }
119
120 // Get value using stride-aware offset
121 let offset = self.shape().offset(&coords);
122 let value = unsafe { *self.as_ptr().add(offset) };
123 acc0 += value;
124 }
125 }
126
127 unsafe {
128 *out.as_mut_ptr() = acc0;
129 }
130 }
131
132 if self.requires_grad() {
133 out.set_requires_grad_internal(true);
134 let grad_fn = GradFn::ReduceSum {
135 input_shape: self.shape().dims.clone(),
136 };
137 out.set_grad_fn(grad_fn.clone());
138 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
139 }
140
141 out
142 }
143
144 /// Returns the sum of elements along specified dimensions
145 ///
146 /// This operation computes the sum of elements along the specified dimensions,
147 /// reducing the tensor while optionally preserving the reduced dimensions as
148 /// size-1 dimensions.
149 ///
150 /// The output shape depends on the `keepdim` parameter:
151 /// * If `keepdim` is `true`, the reduced dimensions are kept with size 1
152 /// * If `keepdim` is `false`, the reduced dimensions are removed
153 ///
154 /// When `requires_grad` is enabled, this operation supports automatic gradient
155 /// tracking through the GradTrack system.
156 ///
157 /// # Arguments
158 ///
159 /// * `dims` - Vector of dimension indices to sum over (must be valid for tensor rank)
160 /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
161 ///
162 /// # Returns
163 ///
164 /// A tensor with sum computed over the specified dimensions
165 ///
166 /// # Panics
167 ///
168 /// * If `dims` is empty
169 /// * If any dimension index is out of bounds for the tensor rank
170 ///
171 /// # Examples
172 ///
173 /// ```
174 /// use train_station::Tensor;
175 ///
176 /// // Sum along rows (dimension 0) with keepdim=false
177 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
178 /// let row_sums = matrix.sum_dims(&[0], false);
179 /// assert_eq!(row_sums.shape().dims, vec![2]);
180 /// assert_eq!(row_sums.get(&[0]), 4.0); // 1 + 3 = 4
181 /// assert_eq!(row_sums.get(&[1]), 6.0); // 2 + 4 = 6
182 /// ```
183 ///
184 /// ```
185 /// use train_station::Tensor;
186 ///
187 /// // Sum along columns (dimension 1) with keepdim=true
188 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
189 /// let col_sums = matrix.sum_dims(&[1], true);
190 /// assert_eq!(col_sums.shape().dims, vec![2, 1]);
191 /// assert_eq!(col_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
192 /// assert_eq!(col_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
193 /// ```
194 ///
195 /// ```
196 /// use train_station::Tensor;
197 ///
198 /// // Sum over multiple dimensions
199 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
200 /// let total = tensor.sum_dims(&[0, 1], false);
201 /// assert_eq!(total.shape().dims, vec![1]);
202 /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
203 /// ```
204 ///
205 /// ```
206 /// use train_station::Tensor;
207 ///
208 /// // Sum with gradient tracking
209 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
210 /// .unwrap()
211 /// .with_requires_grad();
212 /// let mut row_sums = tensor.sum_dims(&[0], false);
213 /// row_sums.backward(None);
214 /// let grad = tensor.grad_by_value().expect("gradient should exist");
215 /// // Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
216 /// assert_eq!(grad.get(&[0, 0]), 1.0);
217 /// assert_eq!(grad.get(&[0, 1]), 1.0);
218 /// assert_eq!(grad.get(&[1, 0]), 1.0);
219 /// assert_eq!(grad.get(&[1, 1]), 1.0);
220 /// ```
221 ///
222 /// # Performance
223 ///
224 /// Uses efficient coordinate-based iteration that works correctly with
225 /// both contiguous and non-contiguous tensor layouts.
226 #[track_caller]
227 pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
228 assert!(!dims.is_empty(), "sum_dims requires at least one dimension");
229 let rank = self.shape().rank();
230 for &d in dims {
231 assert!(
232 d < rank,
233 "sum_dims dim {} out of bounds for rank {}",
234 d,
235 rank
236 );
237 }
238
239 // Build output shape
240 let mut out_dims = self.shape().dims.clone();
241 let mut reduced: Vec<usize> = dims.to_vec();
242 reduced.sort_unstable();
243 reduced.dedup();
244 for &d in reduced.iter() {
245 out_dims[d] = if keepdim { 1 } else { 0 };
246 }
247 if !keepdim {
248 out_dims.retain(|&s| s != 0);
249 }
250 if out_dims.is_empty() {
251 out_dims.push(1);
252 }
253 let mut out = Tensor::zeros(out_dims.clone());
254
255 // Accumulate along reduced dims
256 let in_shape = self.shape().dims.clone();
257 let out_rank = out.shape().rank();
258 let mut in_coords = vec![0usize; rank];
259 unsafe {
260 let dst = out.as_mut_ptr();
261 // Iterate over all input elements, map to output index
262 for lin in 0..self.size() {
263 let mut tmp = lin;
264 for i in (0..rank).rev() {
265 let s = in_shape[i];
266 in_coords[i] = if s == 0 { 0 } else { tmp % s };
267 if s != 0 {
268 tmp /= s;
269 }
270 }
271
272 // Get input value using stride-aware offset
273 let in_offset = self.shape().offset(&in_coords);
274 let value = *self.as_ptr().add(in_offset);
275
276 // build output coords
277 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
278 for (i, &c) in in_coords.iter().enumerate().take(rank) {
279 if reduced.contains(&i) {
280 if keepdim {
281 out_coords.push(0);
282 }
283 } else {
284 out_coords.push(c);
285 }
286 }
287 let off = if out_coords.is_empty() {
288 0
289 } else {
290 out.shape().offset(&out_coords)
291 };
292 *dst.add(off) += value;
293 }
294 }
295
296 if self.requires_grad() {
297 out.set_requires_grad_internal(true);
298 let grad_fn = GradFn::ReduceSumDims {
299 dims: reduced,
300 input_shape: self.shape().dims.clone(),
301 keepdim,
302 };
303 out.set_grad_fn(grad_fn.clone());
304 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
305 }
306
307 out
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Global sum over a 2x3 tensor filled with 0.0, 0.5, ..., 2.5.
    #[test]
    fn test_sum_forward_basic() {
        let mut input = Tensor::zeros(vec![2, 3]);
        unsafe {
            for idx in 0..6 {
                *input.as_mut_ptr().add(idx) = (idx as f32) * 0.5;
            }
        }
        let total = input.sum();
        assert_eq!(total.shape().dims, vec![1]);
        // 0.0 + 0.5 + 1.0 + 1.5 + 2.0 + 2.5 = 7.5
        unsafe {
            assert!((*total.as_ptr() - 7.5).abs() < 1e-6);
        }
    }

    /// Backward through sum() should yield a gradient of all ones.
    #[test]
    fn test_sum_autograd_all_ones_grad() {
        let mut input = Tensor::zeros(vec![2, 2]).with_requires_grad();
        unsafe {
            for idx in 0..4 {
                *input.as_mut_ptr().add(idx) = idx as f32;
            }
        }
        let mut total = input.sum();
        total.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        unsafe {
            for idx in 0..4 {
                assert_eq!(*grad.as_ptr().add(idx), 1.0);
            }
        }
    }

    /// Backward through sum(2x + 1): every element's gradient is 2.
    #[test]
    fn test_sum_chain_autograd() {
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let scaled = input.mul_scalar(2.0).add_scalar(1.0);
        let mut total = scaled.sum();
        total.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // d/dx of sum(2x+1) = 2 for each element
        unsafe {
            for idx in 0..4 {
                assert_eq!(*grad.as_ptr().add(idx), 2.0);
            }
        }
    }

    /// sum() over a transposed (non-contiguous) view must match the original.
    #[test]
    fn test_sum_non_contiguous_transpose() {
        // Original layout: [[1, 2, 3], [4, 5, 6]]
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

        // Transposed view: [[1, 4], [2, 5], [3, 6]] — must not be contiguous.
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        let base_total = base.sum();
        let view_total = view.sum();

        // Both reductions cover the same six elements: 1+2+3+4+5+6 = 21
        assert_eq!(base_total.get(&[0]), 21.0);
        assert_eq!(view_total.get(&[0]), 21.0);
    }

    /// sum_dims() on a non-contiguous transpose, along each axis.
    #[test]
    fn test_sum_dims_non_contiguous() {
        let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Transposed view with shape [3, 2]
        let view = base.transpose(0, 1);
        assert!(!view.is_contiguous());

        // Reduce dim 0 of the view: [1+2+3, 4+5+6] = [6, 15]
        let along0 = view.sum_dims(&[0], false);
        assert_eq!(along0.shape().dims, vec![2]);
        assert_eq!(along0.get(&[0]), 6.0);
        assert_eq!(along0.get(&[1]), 15.0);

        // Reduce dim 1 of the view: [1+4, 2+5, 3+6] = [5, 7, 9]
        let along1 = view.sum_dims(&[1], false);
        assert_eq!(along1.shape().dims, vec![3]);
        assert_eq!(along1.get(&[0]), 5.0);
        assert_eq!(along1.get(&[1]), 7.0);
        assert_eq!(along1.get(&[2]), 9.0);
    }

    /// sum() over a permuted 3-D tensor must match the contiguous original.
    #[test]
    fn test_sum_permuted_tensor() {
        let values: Vec<f32> = (0..24).map(|v| v as f32).collect();
        let base = Tensor::from_slice(&values, vec![2, 3, 4]).unwrap();

        // Reverse the axis order: shape [2, 3, 4] -> [4, 3, 2]
        let permuted = base.permute(vec![2, 1, 0]);
        assert!(!permuted.is_contiguous());

        let base_total = base.sum();
        let permuted_total = permuted.sum();

        // Same elements, same sum regardless of layout.
        assert_eq!(base_total.get(&[0]), permuted_total.get(&[0]));

        // 0 + 1 + ... + 23 = 23*24/2 = 276
        assert_eq!(base_total.get(&[0]), 276.0);
    }
}
425}