train_station/tensor/transform/stack.rs
1//! Tensor stacking operations
2//!
3//! This module provides tensor stacking functionality that combines multiple
4//! tensors along a new dimension. Stacking is a fundamental tensor transformation
5//! operation used in machine learning for combining multiple feature maps,
6//! creating batch dimensions, and implementing complex tensor manipulations
7//! that require adding new axes to tensor data.
8//!
9//! # Operations
10//!
11//! * `stack()` - Stack multiple tensors along a new dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: AVX2 acceleration for large block copies
16//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
17//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Stack two 1D tensors along dimension 0
27//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
28//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
29//! let stacked = Tensor::stack(&[a, b], 0);
30//! assert_eq!(stacked.shape().dims(), vec![2, 3]);
31//! assert_eq!(stacked.get(&[0, 0]), 1.0);
32//! assert_eq!(stacked.get(&[1, 2]), 6.0);
33//! ```
34//!
35//! ```
36//! use train_station::Tensor;
37//!
38//! // Stack multiple 2D tensors along dimension 1
39//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
40//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
41//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
42//! let stacked = Tensor::stack(&[a, b, c], 1);
43//! assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
44//! ```
45
46use crate::gradtrack::{GradEngine, GradFn};
47use crate::tensor::core::Tensor;
48use crate::tensor::iterator::collect::optimized_copy;
49
50impl Tensor {
51 /// Stack a list of tensors along a new dimension
52 ///
53 /// Combines multiple tensors by adding a new dimension at the specified
54 /// position. All input tensors must have identical shapes, and the output
55 /// tensor will have a new dimension of size equal to the number of input
56 /// tensors. This operation is similar to PyTorch's `torch.stack` function.
57 ///
58 /// The stacking operation creates a new axis in the output tensor, unlike
59 /// concatenation which operates along existing dimensions. This makes
60 /// stacking useful for creating batch dimensions, combining feature maps,
61 /// and implementing operations that require adding new tensor axes.
62 ///
63 /// # Arguments
64 ///
65 /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
66 /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
67 ///
68 /// # Returns
69 ///
70 /// A new tensor with the stacked data. The output shape is the input shape
71 /// with a new dimension of size `tensors.len()` inserted at position `dim`.
72 ///
73 /// # Panics
74 ///
75 /// * If the tensor array is empty
76 /// * If any tensor has a different shape than the first tensor
77 /// * If `dim` is out of bounds (dim > rank of input tensors)
78 ///
79 /// # Examples
80 ///
81 /// ```
82 /// use train_station::Tensor;
83 ///
84 /// // Stack two 1D tensors along dimension 0
85 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
86 /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
87 /// let stacked = Tensor::stack(&[a, b], 0);
88 /// assert_eq!(stacked.shape().dims(), vec![2, 3]);
89 /// assert_eq!(stacked.get(&[0, 0]), 1.0);
90 /// assert_eq!(stacked.get(&[1, 2]), 6.0);
91 /// ```
92 ///
93 /// ```
94 /// use train_station::Tensor;
95 ///
96 /// // Stack multiple 2D tensors along dimension 1
97 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
99 /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
100 /// let stacked = Tensor::stack(&[a, b, c], 1);
101 /// assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
102 /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
103 /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
104 /// ```
105 ///
106 /// ```
107 /// use train_station::Tensor;
108 ///
109 /// // Stack with gradient tracking
110 /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
111 /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
112 /// a.set_requires_grad(true);
113 /// b.set_requires_grad(true);
114 ///
115 /// let stacked = Tensor::stack(&[a, b], 0);
116 /// assert!(stacked.requires_grad());
117 /// assert_eq!(stacked.shape().dims(), vec![2, 2]);
118 /// ```
119 ///
120 /// ```
121 /// use train_station::Tensor;
122 ///
123 /// // Stack 3D tensors along the last dimension
124 /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
125 /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
126 /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
127 /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
128 /// let stacked = Tensor::stack(&[a, b], 3);
129 /// assert_eq!(stacked.shape().dims(), vec![2, 2, 2, 2]);
130 /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
131 /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
132 /// ```
133 ///
134 /// # Performance
135 ///
136 /// - **Time Complexity**: O(n) where n is the total number of elements
137 /// - **Memory Usage**: Allocates new contiguous tensor for output
138 /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
139 /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
140 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
141 ///
142 /// # Relationship to Other Operations
143 ///
144 /// This operation is related to other tensor transformations:
145 /// - `cat()` - Concatenates tensors along existing dimensions
146 /// - `unsqueeze()` - Adds a single dimension of size 1
147 /// - `reshape()` - Changes tensor shape without adding dimensions
148 ///
149 /// # Memory Layout
150 ///
151 /// The output tensor is always contiguous, with elements arranged so that
152 /// the stacked dimension is the fastest-changing index. This ensures optimal
153 /// performance for subsequent operations and maintains compatibility with
154 /// SIMD optimizations.
155 ///
156 /// # Gradient Computation
157 ///
158 /// During backward passes, gradients are split along the stacked dimension
159 /// and distributed back to the original input tensors. This is implemented
160 /// using the same gradient function as concatenation, treating the stack
161 /// operation as concatenation along a new axis.
162 #[track_caller]
163 pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
164 assert!(!tensors.is_empty(), "stack requires at least one tensor");
165
166 // Validate all shapes identical
167 let base_dims = tensors[0].shape().dims();
168 for t in tensors.iter() {
169 assert_eq!(
170 t.shape().dims(),
171 base_dims,
172 "All tensors must have identical shapes for stack"
173 );
174 }
175
176 let rank = base_dims.len();
177 assert!(
178 dim <= rank,
179 "stack dim {} out of bounds for rank {}",
180 dim,
181 rank
182 );
183
184 // Compute output shape by inserting new axis of size = tensors.len()
185 let mut out_dims = Vec::with_capacity(rank + 1);
186 out_dims.extend_from_slice(&base_dims[..dim]);
187 out_dims.push(tensors.len());
188 out_dims.extend_from_slice(&base_dims[dim..]);
189
190 // Materialize into a new contiguous tensor
191 let mut output = Tensor::new(out_dims.clone());
192
193 // Copy block-wise: treat stack dim separately
194 // For output shape [pre..., K=tensors.len(), post...]
195 // inner = product(post...), outer = product(pre...)
196 let inner: usize = base_dims[dim..].iter().product();
197 let outer: usize = base_dims[..dim].iter().product();
198
199 unsafe {
200 let dst_ptr = output.as_mut_ptr();
201 for outer_idx in 0..outer {
202 for (k, t) in tensors.iter().enumerate() {
203 // Ensure contiguous source
204 let src = if t.is_contiguous() {
205 t.clone()
206 } else {
207 t.contiguous()
208 };
209 // Source offset: within each tensor, block size is inner
210 let src_base = outer_idx * inner;
211 let src_ptr = src.as_ptr().add(src_base);
212
213 // Destination offset computes with inserted axis
214 // out block along stacked axis of length K, each block is inner
215 let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
216 optimized_copy(src_ptr, dst_ptr.add(dst_base), inner);
217 }
218 }
219 }
220
221 // GradTrack: stack is like cat with a new axis; gradient splits along that axis
222 let any_requires = tensors.iter().any(|t| t.requires_grad());
223 if any_requires {
224 output.set_requires_grad(true);
225 // For GradFn::Cat, provide sizes along concat dim and input shapes
226 let mut input_ids = Vec::with_capacity(tensors.len());
227 let mut input_sizes = Vec::with_capacity(tensors.len());
228 let mut input_shapes = Vec::with_capacity(tensors.len());
229 for t in tensors.iter() {
230 if t.requires_grad() {
231 input_ids.push(t.id());
232 }
233 input_sizes.push(1); // each slice along new axis has length 1
234 input_shapes.push(t.shape().dims().to_vec());
235 }
236 let grad_fn = GradFn::Cat {
237 dim,
238 input_sizes,
239 input_shapes,
240 };
241 output.set_grad_fn(grad_fn.clone());
242 GradEngine::register_operation(output.id(), input_ids, grad_fn);
243 }
244
245 output
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Stacking two 1D tensors along a new leading axis yields shape [2, 3]
    /// with row 0 holding the first tensor and row 1 the second.
    #[test]
    fn test_stack_basic() {
        let first = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let second = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let result = Tensor::stack(&[first, second], 0);
        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 2]), 6.0);
    }

    /// Three inputs produce a stacked axis of length 3; spot-check one
    /// element from each input slice.
    #[test]
    fn test_stack_multiple_tensors() {
        let inputs = [
            Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap(),
            Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap(),
            Tensor::from_slice(&[5.0, 6.0], vec![2]).unwrap(),
        ];
        let result = Tensor::stack(&inputs, 0);
        assert_eq!(result.shape().dims(), vec![3, 2]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1]), 4.0);
        assert_eq!(result.get(&[2, 1]), 6.0);
    }

    /// Stacking 2x2 tensors along dim 1 inserts the new axis in the middle,
    /// producing shape [2, 2, 2].
    #[test]
    fn test_stack_2d_tensors() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let right = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = Tensor::stack(&[left, right], 1);
        assert_eq!(result.shape().dims(), vec![2, 2, 2]);
        assert_eq!(result.get(&[0, 0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1, 1]), 8.0);
    }

    /// If any input requires gradients, the stacked output does too.
    #[test]
    fn test_stack_with_gradients() {
        let mut lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut rhs = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        lhs.set_requires_grad(true);
        rhs.set_requires_grad(true);

        let result = Tensor::stack(&[lhs, rhs], 0);
        assert!(result.requires_grad());
        assert_eq!(result.shape().dims(), vec![2, 2]);
    }

    /// An empty input slice is rejected with a descriptive panic.
    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    /// Inputs of mismatched shape are rejected.
    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let short = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_slice(&[3.0, 4.0, 5.0], vec![3]).unwrap();
        Tensor::stack(&[short, long], 0);
    }

    /// `dim` may be at most the input rank; one past that panics.
    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let rhs = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        Tensor::stack(&[lhs, rhs], 2);
    }
}
318}