// train_station/tensor/transform/stack.rs

//! Tensor stacking operations
//!
//! This module provides tensor stacking functionality that combines multiple
//! tensors along a new dimension. Stacking is a fundamental tensor transformation
//! operation used in machine learning for combining multiple feature maps,
//! creating batch dimensions, and implementing complex tensor manipulations
//! that require adding new axes to tensor data.
//!
//! # Operations
//!
//! * `stack()` - Stack multiple tensors along a new dimension
//!
//! # Performance Characteristics
//!
//! * **SIMD Optimized**: AVX2 acceleration for large block copies
//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Stack two 1D tensors along dimension 0
//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
//! let stacked = Tensor::stack(&[a, b], 0);
//! assert_eq!(stacked.shape().dims(), vec![2, 3]);
//! assert_eq!(stacked.get(&[0, 0]), 1.0);
//! assert_eq!(stacked.get(&[1, 2]), 6.0);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Stack multiple 2D tensors along dimension 1
//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
//! let stacked = Tensor::stack(&[a, b, c], 1);
//! assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
//! ```
use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;
use crate::tensor::iterator::collect::optimized_copy;
49
50impl Tensor {
51    /// Stack a list of tensors along a new dimension
52    ///
53    /// Combines multiple tensors by adding a new dimension at the specified
54    /// position. All input tensors must have identical shapes, and the output
55    /// tensor will have a new dimension of size equal to the number of input
56    /// tensors. This operation is similar to PyTorch's `torch.stack` function.
57    ///
58    /// The stacking operation creates a new axis in the output tensor, unlike
59    /// concatenation which operates along existing dimensions. This makes
60    /// stacking useful for creating batch dimensions, combining feature maps,
61    /// and implementing operations that require adding new tensor axes.
62    ///
63    /// # Arguments
64    ///
65    /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
66    /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
67    ///
68    /// # Returns
69    ///
70    /// A new tensor with the stacked data. The output shape is the input shape
71    /// with a new dimension of size `tensors.len()` inserted at position `dim`.
72    ///
73    /// # Panics
74    ///
75    /// * If the tensor array is empty
76    /// * If any tensor has a different shape than the first tensor
77    /// * If `dim` is out of bounds (dim > rank of input tensors)
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// use train_station::Tensor;
83    ///
84    /// // Stack two 1D tensors along dimension 0
85    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
86    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
87    /// let stacked = Tensor::stack(&[a, b], 0);
88    /// assert_eq!(stacked.shape().dims(), vec![2, 3]);
89    /// assert_eq!(stacked.get(&[0, 0]), 1.0);
90    /// assert_eq!(stacked.get(&[1, 2]), 6.0);
91    /// ```
92    ///
93    /// ```
94    /// use train_station::Tensor;
95    ///
96    /// // Stack multiple 2D tensors along dimension 1
97    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
98    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
99    /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
100    /// let stacked = Tensor::stack(&[a, b, c], 1);
101    /// assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
102    /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
103    /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
104    /// ```
105    ///
106    /// ```
107    /// use train_station::Tensor;
108    ///
109    /// // Stack with gradient tracking
110    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
111    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
112    /// a.set_requires_grad(true);
113    /// b.set_requires_grad(true);
114    ///
115    /// let stacked = Tensor::stack(&[a, b], 0);
116    /// assert!(stacked.requires_grad());
117    /// assert_eq!(stacked.shape().dims(), vec![2, 2]);
118    /// ```
119    ///
120    /// ```
121    /// use train_station::Tensor;
122    ///
123    /// // Stack 3D tensors along the last dimension
124    /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
125    /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
126    /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
127    /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
128    /// let stacked = Tensor::stack(&[a, b], 3);
129    /// assert_eq!(stacked.shape().dims(), vec![2, 2, 2, 2]);
130    /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
131    /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
132    /// ```
133    ///
134    /// # Performance
135    ///
136    /// - **Time Complexity**: O(n) where n is the total number of elements
137    /// - **Memory Usage**: Allocates new contiguous tensor for output
138    /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
139    /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
140    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
141    ///
142    /// # Relationship to Other Operations
143    ///
144    /// This operation is related to other tensor transformations:
145    /// - `cat()` - Concatenates tensors along existing dimensions
146    /// - `unsqueeze()` - Adds a single dimension of size 1
147    /// - `reshape()` - Changes tensor shape without adding dimensions
148    ///
149    /// # Memory Layout
150    ///
151    /// The output tensor is always contiguous, with elements arranged so that
152    /// the stacked dimension is the fastest-changing index. This ensures optimal
153    /// performance for subsequent operations and maintains compatibility with
154    /// SIMD optimizations.
155    ///
156    /// # Gradient Computation
157    ///
158    /// During backward passes, gradients are split along the stacked dimension
159    /// and distributed back to the original input tensors. This is implemented
160    /// using the same gradient function as concatenation, treating the stack
161    /// operation as concatenation along a new axis.
162    #[track_caller]
163    pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
164        assert!(!tensors.is_empty(), "stack requires at least one tensor");
165
166        // Validate all shapes identical
167        let base_dims = tensors[0].shape().dims();
168        for t in tensors.iter() {
169            assert_eq!(
170                t.shape().dims(),
171                base_dims,
172                "All tensors must have identical shapes for stack"
173            );
174        }
175
176        let rank = base_dims.len();
177        assert!(
178            dim <= rank,
179            "stack dim {} out of bounds for rank {}",
180            dim,
181            rank
182        );
183
184        // Compute output shape by inserting new axis of size = tensors.len()
185        let mut out_dims = Vec::with_capacity(rank + 1);
186        out_dims.extend_from_slice(&base_dims[..dim]);
187        out_dims.push(tensors.len());
188        out_dims.extend_from_slice(&base_dims[dim..]);
189
190        // Materialize into a new contiguous tensor
191        let mut output = Tensor::new(out_dims.clone());
192
193        // Copy block-wise: treat stack dim separately
194        // For output shape [pre..., K=tensors.len(), post...]
195        // inner = product(post...), outer = product(pre...)
196        let inner: usize = base_dims[dim..].iter().product();
197        let outer: usize = base_dims[..dim].iter().product();
198
199        unsafe {
200            let dst_ptr = output.as_mut_ptr();
201            for outer_idx in 0..outer {
202                for (k, t) in tensors.iter().enumerate() {
203                    // Ensure contiguous source
204                    let src = if t.is_contiguous() {
205                        t.clone()
206                    } else {
207                        t.contiguous()
208                    };
209                    // Source offset: within each tensor, block size is inner
210                    let src_base = outer_idx * inner;
211                    let src_ptr = src.as_ptr().add(src_base);
212
213                    // Destination offset computes with inserted axis
214                    // out block along stacked axis of length K, each block is inner
215                    let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
216                    optimized_copy(src_ptr, dst_ptr.add(dst_base), inner);
217                }
218            }
219        }
220
221        // GradTrack: stack is like cat with a new axis; gradient splits along that axis
222        let any_requires = tensors.iter().any(|t| t.requires_grad());
223        if any_requires {
224            output.set_requires_grad(true);
225            // For GradFn::Cat, provide sizes along concat dim and input shapes
226            let mut input_ids = Vec::with_capacity(tensors.len());
227            let mut input_sizes = Vec::with_capacity(tensors.len());
228            let mut input_shapes = Vec::with_capacity(tensors.len());
229            for t in tensors.iter() {
230                if t.requires_grad() {
231                    input_ids.push(t.id());
232                }
233                input_sizes.push(1); // each slice along new axis has length 1
234                input_shapes.push(t.shape().dims().to_vec());
235            }
236            let grad_fn = GradFn::Cat {
237                dim,
238                input_sizes,
239                input_shapes,
240            };
241            output.set_grad_fn(grad_fn.clone());
242            GradEngine::register_operation(output.id(), input_ids, grad_fn);
243        }
244
245        output
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Stacking two 1D tensors along dim 0 yields a [2, 3] tensor.
    #[test]
    fn test_stack_basic() {
        let first = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let second = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let result = Tensor::stack(&[first, second], 0);
        assert_eq!(result.shape().dims(), vec![2, 3]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 2]), 6.0);
    }

    /// Three inputs produce a leading axis of size 3.
    #[test]
    fn test_stack_multiple_tensors() {
        let inputs = [
            Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap(),
            Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap(),
            Tensor::from_slice(&[5.0, 6.0], vec![2]).unwrap(),
        ];
        let result = Tensor::stack(&inputs, 0);
        assert_eq!(result.shape().dims(), vec![3, 2]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1]), 4.0);
        assert_eq!(result.get(&[2, 1]), 6.0);
    }

    /// Stacking 2x2 tensors along dim 1 inserts the new axis in the middle.
    #[test]
    fn test_stack_2d_tensors() {
        let lhs = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let rhs = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = Tensor::stack(&[lhs, rhs], 1);
        assert_eq!(result.shape().dims(), vec![2, 2, 2]);
        assert_eq!(result.get(&[0, 0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1, 1]), 8.0);
    }

    /// Gradient tracking on any input propagates to the stacked output.
    #[test]
    fn test_stack_with_gradients() {
        let mut lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut rhs = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        lhs.set_requires_grad(true);
        rhs.set_requires_grad(true);

        let result = Tensor::stack(&[lhs, rhs], 0);
        assert!(result.requires_grad());
        assert_eq!(result.shape().dims(), vec![2, 2]);
    }

    /// An empty input slice is rejected.
    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    /// Mismatched input shapes are rejected.
    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let short = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_slice(&[3.0, 4.0, 5.0], vec![3]).unwrap();
        Tensor::stack(&[short, long], 0);
    }

    /// A dim beyond rank + 1 positions is rejected.
    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let rhs = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        Tensor::stack(&[lhs, rhs], 2);
    }
}