train_station/tensor/reductions/sum.rs
1//! Sum reduction operations for tensors
2//!
3//! This module provides sum reduction operations that compute the sum of tensor elements.
4//! These operations support both global summation and dimension-wise summation with
5//! automatic gradient tracking when enabled.
6//!
7//! # Operations
8//!
9//! * `sum()` - Sum all elements into a scalar tensor
10//! * `sum_dims()` - Sum elements along specified dimensions
11//!
12//! # Examples
13//!
14//! ```
15//! use train_station::Tensor;
16//!
17//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
18//! let total = tensor.sum();
19//! assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
20//! ```
21
22use crate::gradtrack::{GradEngine, GradFn};
23use crate::tensor::core::Tensor;
24
25impl Tensor {
26 /// Returns the sum of all elements in the tensor
27 ///
28 /// This operation computes the sum of all elements across all dimensions,
29 /// reducing the tensor to a scalar value. The output is a tensor with shape \[1\]
30 /// containing the sum as a float.
31 ///
32 /// When `requires_grad` is enabled, this operation supports automatic gradient
33 /// tracking through the GradTrack system.
34 ///
35 /// # Returns
36 ///
37 /// A tensor with shape \[1\] containing the sum of all elements
38 ///
39 /// # Examples
40 ///
41 /// ```
42 /// use train_station::Tensor;
43 ///
44 /// // Basic sum calculation
45 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
46 /// let total = tensor.sum();
47 /// assert_eq!(total.shape().dims, vec![1]);
48 /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
49 /// ```
50 ///
51 /// ```
52 /// use train_station::Tensor;
53 ///
54 /// // Sum with gradient tracking
55 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
56 /// .unwrap()
57 /// .with_requires_grad();
58 /// let mut total = tensor.sum();
59 /// total.backward(None);
60 /// let grad = tensor.grad_by_value().expect("gradient should exist");
61 /// // Gradient should be [1.0, 1.0, 1.0] for each element
62 /// assert_eq!(grad.get(&[0]), 1.0);
63 /// assert_eq!(grad.get(&[1]), 1.0);
64 /// assert_eq!(grad.get(&[2]), 1.0);
65 /// ```
66 ///
67 /// ```
68 /// use train_station::Tensor;
69 ///
70 /// // Sum of empty tensor
71 /// let tensor = Tensor::new(vec![0]);
72 /// let total = tensor.sum();
73 /// assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
74 /// ```
75 ///
76 /// # Performance
77 ///
78 /// Uses optimized contiguous tensor path with 4x loop unrolling for better
79 /// performance. Non-contiguous tensors use stride-aware iteration.
80 pub fn sum(&self) -> Tensor {
81 let mut out = Tensor::new(vec![1]);
82 if self.size() == 0 {
83 out.fill(0.0);
84 } else {
85 let mut acc0 = 0.0f32;
86
87 if self.is_contiguous() {
88 // Fast path for contiguous tensors
89 unsafe {
90 let src = self.as_ptr();
91 let size = self.size();
92 let mut i = 0usize;
93 // Unrolled loop for better throughput
94 while i + 4 <= size {
95 let x0 = *src.add(i);
96 let x1 = *src.add(i + 1);
97 let x2 = *src.add(i + 2);
98 let x3 = *src.add(i + 3);
99 acc0 += x0 + x1 + x2 + x3;
100 i += 4;
101 }
102 while i < size {
103 acc0 += *src.add(i);
104 i += 1;
105 }
106 }
107 } else {
108 // Stride-aware path for non-contiguous tensors
109 let dims = self.shape().dims.clone();
110 for flat_idx in 0..self.size() {
111 // Convert flat index to multi-dimensional coordinates
112 let mut coords = vec![0; dims.len()];
113 let mut tmp = flat_idx;
114 for k in (0..dims.len()).rev() {
115 coords[k] = tmp % dims[k];
116 tmp /= dims[k];
117 }
118
119 // Get value using stride-aware offset
120 let offset = self.shape().offset(&coords);
121 let value = unsafe { *self.as_ptr().add(offset) };
122 acc0 += value;
123 }
124 }
125
126 unsafe {
127 *out.as_mut_ptr() = acc0;
128 }
129 }
130
131 if self.requires_grad() {
132 out.set_requires_grad_internal(true);
133 let grad_fn = GradFn::ReduceSum {
134 input_shape: self.shape().dims.clone(),
135 };
136 out.set_grad_fn(grad_fn.clone());
137 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
138 }
139
140 out
141 }
142
143 /// Returns the sum of elements along specified dimensions
144 ///
145 /// This operation computes the sum of elements along the specified dimensions,
146 /// reducing the tensor while optionally preserving the reduced dimensions as
147 /// size-1 dimensions.
148 ///
149 /// The output shape depends on the `keepdim` parameter:
150 /// * If `keepdim` is `true`, the reduced dimensions are kept with size 1
151 /// * If `keepdim` is `false`, the reduced dimensions are removed
152 ///
153 /// When `requires_grad` is enabled, this operation supports automatic gradient
154 /// tracking through the GradTrack system.
155 ///
156 /// # Arguments
157 ///
158 /// * `dims` - Vector of dimension indices to sum over (must be valid for tensor rank)
159 /// * `keepdim` - Whether to keep reduced dimensions as size-1 dimensions
160 ///
161 /// # Returns
162 ///
163 /// A tensor with sum computed over the specified dimensions
164 ///
165 /// # Panics
166 ///
167 /// * If `dims` is empty
168 /// * If any dimension index is out of bounds for the tensor rank
169 ///
170 /// # Examples
171 ///
172 /// ```
173 /// use train_station::Tensor;
174 ///
175 /// // Sum along rows (dimension 0) with keepdim=false
176 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
177 /// let row_sums = matrix.sum_dims(&[0], false);
178 /// assert_eq!(row_sums.shape().dims, vec![2]);
179 /// assert_eq!(row_sums.get(&[0]), 4.0); // 1 + 3 = 4
180 /// assert_eq!(row_sums.get(&[1]), 6.0); // 2 + 4 = 6
181 /// ```
182 ///
183 /// ```
184 /// use train_station::Tensor;
185 ///
186 /// // Sum along columns (dimension 1) with keepdim=true
187 /// let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
188 /// let col_sums = matrix.sum_dims(&[1], true);
189 /// assert_eq!(col_sums.shape().dims, vec![2, 1]);
190 /// assert_eq!(col_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
191 /// assert_eq!(col_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
192 /// ```
193 ///
194 /// ```
195 /// use train_station::Tensor;
196 ///
197 /// // Sum over multiple dimensions
198 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
199 /// let total = tensor.sum_dims(&[0, 1], false);
200 /// assert_eq!(total.shape().dims, vec![1]);
201 /// assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
202 /// ```
203 ///
204 /// ```
205 /// use train_station::Tensor;
206 ///
207 /// // Sum with gradient tracking
208 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
209 /// .unwrap()
210 /// .with_requires_grad();
211 /// let mut row_sums = tensor.sum_dims(&[0], false);
212 /// row_sums.backward(None);
213 /// let grad = tensor.grad_by_value().expect("gradient should exist");
214 /// // Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
215 /// assert_eq!(grad.get(&[0, 0]), 1.0);
216 /// assert_eq!(grad.get(&[0, 1]), 1.0);
217 /// assert_eq!(grad.get(&[1, 0]), 1.0);
218 /// assert_eq!(grad.get(&[1, 1]), 1.0);
219 /// ```
220 ///
221 /// # Performance
222 ///
223 /// Uses efficient coordinate-based iteration that works correctly with
224 /// both contiguous and non-contiguous tensor layouts.
225 pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor {
226 assert!(!dims.is_empty(), "sum_dims requires at least one dimension");
227 let rank = self.shape().rank();
228 for &d in dims {
229 assert!(
230 d < rank,
231 "sum_dims dim {} out of bounds for rank {}",
232 d,
233 rank
234 );
235 }
236
237 // Build output shape
238 let mut out_dims = self.shape().dims.clone();
239 let mut reduced: Vec<usize> = dims.to_vec();
240 reduced.sort_unstable();
241 reduced.dedup();
242 for &d in reduced.iter() {
243 out_dims[d] = if keepdim { 1 } else { 0 };
244 }
245 if !keepdim {
246 out_dims.retain(|&s| s != 0);
247 }
248 if out_dims.is_empty() {
249 out_dims.push(1);
250 }
251 let mut out = Tensor::zeros(out_dims.clone());
252
253 // Accumulate along reduced dims
254 let in_shape = self.shape().dims.clone();
255 let out_rank = out.shape().rank();
256 let mut in_coords = vec![0usize; rank];
257 unsafe {
258 let dst = out.as_mut_ptr();
259 // Iterate over all input elements, map to output index
260 for lin in 0..self.size() {
261 let mut tmp = lin;
262 for i in (0..rank).rev() {
263 let s = in_shape[i];
264 in_coords[i] = if s == 0 { 0 } else { tmp % s };
265 if s != 0 {
266 tmp /= s;
267 }
268 }
269
270 // Get input value using stride-aware offset
271 let in_offset = self.shape().offset(&in_coords);
272 let value = *self.as_ptr().add(in_offset);
273
274 // build output coords
275 let mut out_coords: Vec<usize> = Vec::with_capacity(out_rank);
276 for (i, &c) in in_coords.iter().enumerate().take(rank) {
277 if reduced.contains(&i) {
278 if keepdim {
279 out_coords.push(0);
280 }
281 } else {
282 out_coords.push(c);
283 }
284 }
285 let off = if out_coords.is_empty() {
286 0
287 } else {
288 out.shape().offset(&out_coords)
289 };
290 *dst.add(off) += value;
291 }
292 }
293
294 if self.requires_grad() {
295 out.set_requires_grad_internal(true);
296 let grad_fn = GradFn::ReduceSumDims {
297 dims: reduced,
298 input_shape: self.shape().dims.clone(),
299 keepdim,
300 };
301 out.set_grad_fn(grad_fn.clone());
302 GradEngine::register_operation(out.id(), vec![self.id()], grad_fn);
303 }
304
305 out
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sum_forward_basic() {
        // Elements 0.0, 0.5, 1.0, 1.5, 2.0, 2.5 sum to 7.5.
        let data: Vec<f32> = (0..6).map(|i| (i as f32) * 0.5).collect();
        let x = Tensor::from_slice(&data, vec![2, 3]).unwrap();
        let s = x.sum();
        assert_eq!(s.shape().dims, vec![1]);
        assert!((s.get(&[0]) - 7.5).abs() < 1e-6);
    }

    #[test]
    fn test_sum_autograd_all_ones_grad() {
        let x = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let mut s = x.sum();
        s.backward(None);
        let gx = x.grad_by_value().expect("grad missing");
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(gx.get(&[i, j]), 1.0);
            }
        }
    }

    #[test]
    fn test_sum_chain_autograd() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
            .unwrap()
            .with_requires_grad();
        let y = x.mul_scalar(2.0).add_scalar(1.0);
        let mut s = y.sum();
        s.backward(None);
        let gx = x.grad_by_value().expect("grad missing");
        // d/dx of sum(2x+1) = 2 for each element
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(gx.get(&[i, j]), 2.0);
            }
        }
    }

    #[test]
    fn test_sum_non_contiguous_transpose() {
        // Test sum on transposed tensor (non-contiguous view)
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        // Original: [[1, 2, 3], [4, 5, 6]]

        let x_t = x.transpose(0, 1);
        // Transposed: [[1, 4], [2, 5], [3, 6]]
        assert!(!x_t.is_contiguous()); // Should be a view

        let sum_orig = x.sum();
        let sum_view = x_t.sum();

        // Both should give the same result: 1+2+3+4+5+6 = 21
        assert_eq!(sum_orig.get(&[0]), 21.0);
        assert_eq!(sum_view.get(&[0]), 21.0);
    }

    #[test]
    fn test_sum_dims_non_contiguous() {
        // Test sum_dims on non-contiguous tensor
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
        let x_t = x.transpose(0, 1); // [3, 2]
        assert!(!x_t.is_contiguous());

        // Sum along dim 0 of transposed tensor
        let sum_dim0 = x_t.sum_dims(&[0], false);
        assert_eq!(sum_dim0.shape().dims, vec![2]);
        // Should be [1+2+3, 4+5+6] = [6, 15]
        assert_eq!(sum_dim0.get(&[0]), 6.0);
        assert_eq!(sum_dim0.get(&[1]), 15.0);

        // Sum along dim 1 of transposed tensor
        let sum_dim1 = x_t.sum_dims(&[1], false);
        assert_eq!(sum_dim1.shape().dims, vec![3]);
        // Should be [1+4, 2+5, 3+6] = [5, 7, 9]
        assert_eq!(sum_dim1.get(&[0]), 5.0);
        assert_eq!(sum_dim1.get(&[1]), 7.0);
        assert_eq!(sum_dim1.get(&[2]), 9.0);
    }

    #[test]
    fn test_sum_permuted_tensor() {
        // Test with permuted tensor
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        // Permute dimensions [2, 3, 4] -> [4, 3, 2] (reverse the axis order)
        let x_perm = x.permute(vec![2, 1, 0]);
        assert!(!x_perm.is_contiguous());

        let sum_orig = x.sum();
        let sum_perm = x_perm.sum();

        // Should give same result
        assert_eq!(sum_orig.get(&[0]), sum_perm.get(&[0]));

        // Expected sum: 0+1+2+...+23 = 23*24/2 = 276
        assert_eq!(sum_orig.get(&[0]), 276.0);
    }

    #[test]
    fn test_sum_dims_keepdim_multiple_axes() {
        // x[i][j][k] = 12i + 4j + k; summing over i and k gives 32j + 60.
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let x = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();

        let kept = x.sum_dims(&[0, 2], true);
        assert_eq!(kept.shape().dims, vec![1, 3, 1]);
        assert_eq!(kept.get(&[0, 0, 0]), 60.0);
        assert_eq!(kept.get(&[0, 1, 0]), 92.0);
        assert_eq!(kept.get(&[0, 2, 0]), 124.0);

        // Axis order in `dims` must not matter.
        let dropped = x.sum_dims(&[2, 0], false);
        assert_eq!(dropped.shape().dims, vec![3]);
        assert_eq!(dropped.get(&[0]), 60.0);
        assert_eq!(dropped.get(&[1]), 92.0);
        assert_eq!(dropped.get(&[2]), 124.0);
    }

    #[test]
    fn test_sum_dims_duplicate_axes_are_deduplicated() {
        let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let s = x.sum_dims(&[1, 1], false);
        assert_eq!(s.shape().dims, vec![2]);
        assert_eq!(s.get(&[0]), 3.0);
        assert_eq!(s.get(&[1]), 7.0);
    }

    #[test]
    #[should_panic(expected = "sum_dims requires at least one dimension")]
    fn test_sum_dims_empty_dims_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let _ = x.sum_dims(&[], false);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn test_sum_dims_out_of_bounds_panics() {
        let x = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let _ = x.sum_dims(&[1], false);
    }
}
423}