train_station/tensor/transform/
stack.rs

1//! Tensor stacking operations
2//!
3//! This module provides tensor stacking functionality that combines multiple
4//! tensors along a new dimension. Stacking is a fundamental tensor transformation
5//! operation used in machine learning for combining multiple feature maps,
6//! creating batch dimensions, and implementing complex tensor manipulations
7//! that require adding new axes to tensor data.
8//!
9//! # Operations
10//!
11//! * `stack()` - Stack multiple tensors along a new dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: AVX2 acceleration for large block copies
16//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
17//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Stack two 1D tensors along dimension 0
27//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
28//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
29//! let stacked = Tensor::stack(&[a, b], 0);
30//! assert_eq!(stacked.shape().dims, vec![2, 3]);
31//! assert_eq!(stacked.get(&[0, 0]), 1.0);
32//! assert_eq!(stacked.get(&[1, 2]), 6.0);
33//! ```
34//!
35//! ```
36//! use train_station::Tensor;
37//!
38//! // Stack multiple 2D tensors along dimension 1
39//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
40//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
41//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
42//! let stacked = Tensor::stack(&[a, b, c], 1);
43//! assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
44//! ```
45
46use crate::gradtrack::{GradEngine, GradFn};
47use crate::tensor::core::Tensor;
48
49// SIMD optimizations for performance-critical operations
50#[cfg(target_arch = "x86_64")]
51use std::arch::x86_64::*;
52
53impl Tensor {
54    /// Stack a list of tensors along a new dimension
55    ///
56    /// Combines multiple tensors by adding a new dimension at the specified
57    /// position. All input tensors must have identical shapes, and the output
58    /// tensor will have a new dimension of size equal to the number of input
59    /// tensors. This operation is similar to PyTorch's `torch.stack` function.
60    ///
61    /// The stacking operation creates a new axis in the output tensor, unlike
62    /// concatenation which operates along existing dimensions. This makes
63    /// stacking useful for creating batch dimensions, combining feature maps,
64    /// and implementing operations that require adding new tensor axes.
65    ///
66    /// # Arguments
67    ///
68    /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
69    /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
70    ///
71    /// # Returns
72    ///
73    /// A new tensor with the stacked data. The output shape is the input shape
74    /// with a new dimension of size `tensors.len()` inserted at position `dim`.
75    ///
76    /// # Panics
77    ///
78    /// * If the tensor array is empty
79    /// * If any tensor has a different shape than the first tensor
80    /// * If `dim` is out of bounds (dim > rank of input tensors)
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Stack two 1D tensors along dimension 0
88    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
89    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
90    /// let stacked = Tensor::stack(&[a, b], 0);
91    /// assert_eq!(stacked.shape().dims, vec![2, 3]);
92    /// assert_eq!(stacked.get(&[0, 0]), 1.0);
93    /// assert_eq!(stacked.get(&[1, 2]), 6.0);
94    /// ```
95    ///
96    /// ```
97    /// use train_station::Tensor;
98    ///
99    /// // Stack multiple 2D tensors along dimension 1
100    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
101    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
102    /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
103    /// let stacked = Tensor::stack(&[a, b, c], 1);
104    /// assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
105    /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
106    /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
107    /// ```
108    ///
109    /// ```
110    /// use train_station::Tensor;
111    ///
112    /// // Stack with gradient tracking
113    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
114    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
115    /// a.set_requires_grad(true);
116    /// b.set_requires_grad(true);
117    ///
118    /// let stacked = Tensor::stack(&[a, b], 0);
119    /// assert!(stacked.requires_grad());
120    /// assert_eq!(stacked.shape().dims, vec![2, 2]);
121    /// ```
122    ///
123    /// ```
124    /// use train_station::Tensor;
125    ///
126    /// // Stack 3D tensors along the last dimension
127    /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
128    /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
129    /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
130    /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
131    /// let stacked = Tensor::stack(&[a, b], 3);
132    /// assert_eq!(stacked.shape().dims, vec![2, 2, 2, 2]);
133    /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
134    /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
135    /// ```
136    ///
137    /// # Performance
138    ///
139    /// - **Time Complexity**: O(n) where n is the total number of elements
140    /// - **Memory Usage**: Allocates new contiguous tensor for output
141    /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
142    /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
143    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
144    ///
145    /// # Relationship to Other Operations
146    ///
147    /// This operation is related to other tensor transformations:
148    /// - `cat()` - Concatenates tensors along existing dimensions
149    /// - `unsqueeze()` - Adds a single dimension of size 1
150    /// - `reshape()` - Changes tensor shape without adding dimensions
151    ///
152    /// # Memory Layout
153    ///
154    /// The output tensor is always contiguous, with elements arranged so that
155    /// the stacked dimension is the fastest-changing index. This ensures optimal
156    /// performance for subsequent operations and maintains compatibility with
157    /// SIMD optimizations.
158    ///
159    /// # Gradient Computation
160    ///
161    /// During backward passes, gradients are split along the stacked dimension
162    /// and distributed back to the original input tensors. This is implemented
163    /// using the same gradient function as concatenation, treating the stack
164    /// operation as concatenation along a new axis.
165    pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
166        assert!(!tensors.is_empty(), "stack requires at least one tensor");
167
168        // Validate all shapes identical
169        let base_dims = tensors[0].shape().dims.clone();
170        for t in tensors.iter() {
171            assert_eq!(
172                t.shape().dims,
173                base_dims,
174                "All tensors must have identical shapes for stack"
175            );
176        }
177
178        let rank = base_dims.len();
179        assert!(
180            dim <= rank,
181            "stack dim {} out of bounds for rank {}",
182            dim,
183            rank
184        );
185
186        // Compute output shape by inserting new axis of size = tensors.len()
187        let mut out_dims = Vec::with_capacity(rank + 1);
188        out_dims.extend_from_slice(&base_dims[..dim]);
189        out_dims.push(tensors.len());
190        out_dims.extend_from_slice(&base_dims[dim..]);
191
192        // Materialize into a new contiguous tensor
193        let mut output = Tensor::new(out_dims.clone());
194
195        // Copy block-wise: treat stack dim separately
196        // For output shape [pre..., K=tensors.len(), post...]
197        // inner = product(post...), outer = product(pre...)
198        let inner: usize = base_dims[dim..].iter().product();
199        let outer: usize = base_dims[..dim].iter().product();
200
201        unsafe {
202            let dst_ptr = output.as_mut_ptr();
203            for outer_idx in 0..outer {
204                for (k, t) in tensors.iter().enumerate() {
205                    // Ensure contiguous source
206                    let src = if t.is_contiguous() {
207                        t.clone()
208                    } else {
209                        t.contiguous()
210                    };
211                    // Source offset: within each tensor, block size is inner
212                    let src_base = outer_idx * inner;
213                    let src_ptr = src.as_ptr().add(src_base);
214
215                    // Destination offset computes with inserted axis
216                    // out block along stacked axis of length K, each block is inner
217                    let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
218                    optimized_block_copy(src_ptr, dst_ptr.add(dst_base), inner);
219                }
220            }
221        }
222
223        // GradTrack: stack is like cat with a new axis; gradient splits along that axis
224        let any_requires = tensors.iter().any(|t| t.requires_grad());
225        if any_requires {
226            output.set_requires_grad(true);
227            // For GradFn::Cat, provide sizes along concat dim and input shapes
228            let mut input_ids = Vec::with_capacity(tensors.len());
229            let mut input_sizes = Vec::with_capacity(tensors.len());
230            let mut input_shapes = Vec::with_capacity(tensors.len());
231            for t in tensors.iter() {
232                if t.requires_grad() {
233                    input_ids.push(t.id());
234                }
235                input_sizes.push(1); // each slice along new axis has length 1
236                input_shapes.push(t.shape().dims.clone());
237            }
238            let grad_fn = GradFn::Cat {
239                dim,
240                input_sizes,
241                input_shapes,
242            };
243            output.set_grad_fn(grad_fn.clone());
244            GradEngine::register_operation(output.id(), input_ids, grad_fn);
245        }
246
247        output
248    }
249}
250
251/// Optimized block copy with SIMD acceleration for large blocks
252///
253/// Performs efficient memory copying with automatic SIMD optimization
254/// for large data blocks. This function automatically selects the best
255/// copying strategy based on block size and available CPU features.
256///
257/// # Arguments
258///
259/// * `src` - Source pointer to copy from
260/// * `dst` - Destination pointer to copy to
261/// * `count` - Number of f32 elements to copy
262///
263/// # Safety
264///
265/// The caller must ensure:
266/// * `src` and `dst` are valid pointers to f32 data
267/// * `src` and `dst` do not overlap (non-overlapping memory regions)
268/// * `count` elements are accessible from both pointers
269/// * The memory regions are properly aligned for SIMD operations
270///
271/// # Performance
272///
273/// - **Small blocks (≤32 elements)**: Direct memory copy
274/// - **Large blocks (≥64 elements)**: AVX2 SIMD acceleration when available
275/// - **Medium blocks**: Unrolled scalar copy for optimal performance
276/// - **Memory bandwidth**: Optimized for maximum throughput
277///
278/// # Examples
279///
280/// This function is used internally by the `stack()` operation for
281/// efficient memory copying. It automatically selects the best copying
282/// strategy based on block size and available CPU features.
283#[inline]
284unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
285    if count == 0 {
286        return;
287    }
288
289    // For small blocks, use standard copy
290    if count <= 32 {
291        std::ptr::copy_nonoverlapping(src, dst, count);
292        return;
293    }
294
295    #[cfg(target_arch = "x86_64")]
296    {
297        if is_x86_feature_detected!("avx2") && count >= 64 {
298            simd_block_copy_avx2(src, dst, count);
299            return;
300        }
301    }
302
303    // Fallback to optimized scalar copy with unrolling
304    scalar_block_copy_unrolled(src, dst, count);
305}
306
/// Copy `count` `f32` values using AVX2 256-bit loads and stores
///
/// The copy proceeds in three phases: 32-element strides (four 256-bit
/// registers per iteration), then single 8-element registers, then a
/// plain `copy_nonoverlapping` for whatever is left (fewer than 8 elements).
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 instructions are available on the running CPU
/// * `src` is valid for reads and `dst` for writes of `count` f32 elements
/// * Memory regions do not overlap
/// * Both pointers are aligned for `f32`; unaligned vector loads/stores are
///   used, so 32-byte alignment is not required
///
/// # Performance
///
/// - **Throughput**: 32 elements per iteration (4 AVX2 vectors)
/// - **Remainder**: 8-element vectors, then a scalar tail copy
///
/// # Examples
///
/// This function is used internally by `optimized_block_copy()` on x86_64
/// processors once AVX2 support has been detected.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    // f32 lanes in one 256-bit register, and elements per unrolled stride
    const LANES: usize = 8;
    const STRIDE: usize = 4 * LANES;

    let mut done = 0usize;

    // Phase 1: four registers per iteration; all loads are issued before the stores
    while count - done >= STRIDE {
        let from = src.add(done);
        let to = dst.add(done);

        let r0 = _mm256_loadu_ps(from);
        let r1 = _mm256_loadu_ps(from.add(LANES));
        let r2 = _mm256_loadu_ps(from.add(2 * LANES));
        let r3 = _mm256_loadu_ps(from.add(3 * LANES));

        _mm256_storeu_ps(to, r0);
        _mm256_storeu_ps(to.add(LANES), r1);
        _mm256_storeu_ps(to.add(2 * LANES), r2);
        _mm256_storeu_ps(to.add(3 * LANES), r3);

        done += STRIDE;
    }

    // Phase 2: one register at a time while a full vector still fits
    while count - done >= LANES {
        let reg = _mm256_loadu_ps(src.add(done));
        _mm256_storeu_ps(dst.add(done), reg);
        done += LANES;
    }

    // Phase 3: fewer than LANES elements remain
    let tail = count - done;
    if tail != 0 {
        std::ptr::copy_nonoverlapping(src.add(done), dst.add(done), tail);
    }
}
374
/// Copy `count` `f32` values with an 8-way unrolled scalar loop
///
/// Elements are copied in groups of eight; any leftover elements (fewer
/// than eight) are copied with `std::ptr::copy_nonoverlapping`.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` is valid for reads and `dst` for writes of `count` f32 elements
/// * Memory regions do not overlap
/// * Both pointers are aligned for `f32`
///
/// # Performance
///
/// - **Throughput**: 8 elements per iteration
/// - **Remainder**: handled with a standard memory copy
/// - **Compatibility**: uses no architecture-specific instructions
///
/// # Examples
///
/// This function is used internally by `optimized_block_copy()` when the
/// AVX2 path is not taken.
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const GROUP: usize = 8;

    // Number of elements covered by whole groups of GROUP
    let bulk = count - count % GROUP;

    let mut base = 0usize;
    while base < bulk {
        let from = src.add(base);
        let to = dst.add(base);
        // Fixed trip count of GROUP element copies per group
        for lane in 0..GROUP {
            *to.add(lane) = *from.add(lane);
        }
        base += GROUP;
    }

    // Leftover elements that did not fill a whole group
    let tail = count - bulk;
    if tail != 0 {
        std::ptr::copy_nonoverlapping(src.add(bulk), dst.add(bulk), tail);
    }
}
429
#[cfg(test)]
mod tests {
    //! Unit tests for `Tensor::stack` and the block-copy helpers.
    //!
    //! Besides the basic shape/value checks, these tests deliberately use
    //! blocks larger than 32 elements so that the unrolled-scalar and AVX2
    //! copy paths are exercised, not just the small-block path.

    use super::*;

    #[test]
    fn test_stack_basic() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let y = Tensor::stack(&[a, b], 0);
        assert_eq!(y.shape().dims, vec![2, 3]);
        assert_eq!(y.get(&[0, 0]), 1.0);
        assert_eq!(y.get(&[1, 2]), 6.0);
    }

    #[test]
    fn test_stack_multiple_tensors() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let c = Tensor::from_slice(&[5.0, 6.0], vec![2]).unwrap();
        let stacked = Tensor::stack(&[a, b, c], 0);
        assert_eq!(stacked.shape().dims, vec![3, 2]);
        assert_eq!(stacked.get(&[0, 0]), 1.0);
        assert_eq!(stacked.get(&[1, 1]), 4.0);
        assert_eq!(stacked.get(&[2, 1]), 6.0);
    }

    #[test]
    fn test_stack_2d_tensors() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let stacked = Tensor::stack(&[a, b], 1);
        assert_eq!(stacked.shape().dims, vec![2, 2, 2]);
        assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
        assert_eq!(stacked.get(&[1, 1, 1]), 8.0);
    }

    /// Stacking at `dim == rank` makes the new axis the innermost one.
    #[test]
    fn test_stack_last_dim() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let stacked = Tensor::stack(&[a, b], 1);
        assert_eq!(stacked.shape().dims, vec![3, 2]);
        assert_eq!(stacked.get(&[0, 0]), 1.0);
        assert_eq!(stacked.get(&[0, 1]), 4.0);
        assert_eq!(stacked.get(&[2, 0]), 3.0);
        assert_eq!(stacked.get(&[2, 1]), 6.0);
    }

    /// A single input gains a leading axis of length 1.
    #[test]
    fn test_stack_single_tensor() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let stacked = Tensor::stack(&[a], 0);
        assert_eq!(stacked.shape().dims, vec![1, 3]);
        assert_eq!(stacked.get(&[0, 0]), 1.0);
        assert_eq!(stacked.get(&[0, 2]), 3.0);
    }

    /// Blocks of 40 and 100 elements go through the unrolled-scalar and
    /// (where available) AVX2 copy paths rather than the small-block path.
    #[test]
    fn test_stack_large_blocks() {
        for &n in &[40usize, 100] {
            let a_data: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let b_data: Vec<f32> = (n..2 * n).map(|i| i as f32).collect();
            let a = Tensor::from_slice(&a_data, vec![n]).unwrap();
            let b = Tensor::from_slice(&b_data, vec![n]).unwrap();
            let stacked = Tensor::stack(&[a, b], 0);
            assert_eq!(stacked.shape().dims, vec![2, n]);
            for i in 0..n {
                assert_eq!(stacked.get(&[0, i]), i as f32);
                assert_eq!(stacked.get(&[1, i]), (n + i) as f32);
            }
        }
    }

    /// Exercises every size class of `optimized_block_copy` directly and
    /// checks that nothing is written past the requested length.
    #[test]
    fn test_optimized_block_copy_all_paths() {
        for &n in &[0usize, 1, 32, 33, 63, 64, 65, 100, 257] {
            let src: Vec<f32> = (0..n).map(|i| i as f32).collect();
            let mut dst = vec![-1.0f32; n + 1];
            unsafe { optimized_block_copy(src.as_ptr(), dst.as_mut_ptr(), n) };
            assert_eq!(&dst[..n], &src[..]);
            assert_eq!(dst[n], -1.0, "wrote past the end for n = {}", n);
        }
    }

    #[test]
    fn test_stack_with_gradients() {
        let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        a.set_requires_grad(true);
        b.set_requires_grad(true);

        let stacked = Tensor::stack(&[a, b], 0);
        assert!(stacked.requires_grad());
        assert_eq!(stacked.shape().dims, vec![2, 2]);
    }

    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0, 5.0], vec![3]).unwrap();
        Tensor::stack(&[a, b], 0);
    }

    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        Tensor::stack(&[a, b], 2);
    }
}