train_station/tensor/transform/
stack.rs

1//! Tensor stacking operations
2//!
3//! This module provides tensor stacking functionality that combines multiple
4//! tensors along a new dimension. Stacking is a fundamental tensor transformation
5//! operation used in machine learning for combining multiple feature maps,
6//! creating batch dimensions, and implementing complex tensor manipulations
7//! that require adding new axes to tensor data.
8//!
9//! # Operations
10//!
11//! * `stack()` - Stack multiple tensors along a new dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: AVX2 acceleration for large block copies
16//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
17//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Stack two 1D tensors along dimension 0
27//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
28//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
29//! let stacked = Tensor::stack(&[a, b], 0);
30//! assert_eq!(stacked.shape().dims, vec![2, 3]);
31//! assert_eq!(stacked.get(&[0, 0]), 1.0);
32//! assert_eq!(stacked.get(&[1, 2]), 6.0);
33//! ```
34//!
35//! ```
36//! use train_station::Tensor;
37//!
38//! // Stack multiple 2D tensors along dimension 1
39//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
40//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
41//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
42//! let stacked = Tensor::stack(&[a, b, c], 1);
43//! assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
44//! ```
45
46use crate::gradtrack::{GradEngine, GradFn};
47use crate::tensor::core::Tensor;
48
49// SIMD optimizations for performance-critical operations
50#[cfg(target_arch = "x86_64")]
51use std::arch::x86_64::*;
52
53impl Tensor {
54    /// Stack a list of tensors along a new dimension
55    ///
56    /// Combines multiple tensors by adding a new dimension at the specified
57    /// position. All input tensors must have identical shapes, and the output
58    /// tensor will have a new dimension of size equal to the number of input
59    /// tensors. This operation is similar to PyTorch's `torch.stack` function.
60    ///
61    /// The stacking operation creates a new axis in the output tensor, unlike
62    /// concatenation which operates along existing dimensions. This makes
63    /// stacking useful for creating batch dimensions, combining feature maps,
64    /// and implementing operations that require adding new tensor axes.
65    ///
66    /// # Arguments
67    ///
68    /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
69    /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
70    ///
71    /// # Returns
72    ///
73    /// A new tensor with the stacked data. The output shape is the input shape
74    /// with a new dimension of size `tensors.len()` inserted at position `dim`.
75    ///
76    /// # Panics
77    ///
78    /// * If the tensor array is empty
79    /// * If any tensor has a different shape than the first tensor
80    /// * If `dim` is out of bounds (dim > rank of input tensors)
81    ///
82    /// # Examples
83    ///
84    /// ```
85    /// use train_station::Tensor;
86    ///
87    /// // Stack two 1D tensors along dimension 0
88    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
89    /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
90    /// let stacked = Tensor::stack(&[a, b], 0);
91    /// assert_eq!(stacked.shape().dims, vec![2, 3]);
92    /// assert_eq!(stacked.get(&[0, 0]), 1.0);
93    /// assert_eq!(stacked.get(&[1, 2]), 6.0);
94    /// ```
95    ///
96    /// ```
97    /// use train_station::Tensor;
98    ///
99    /// // Stack multiple 2D tensors along dimension 1
100    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
101    /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
102    /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
103    /// let stacked = Tensor::stack(&[a, b, c], 1);
104    /// assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
105    /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
106    /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
107    /// ```
108    ///
109    /// ```
110    /// use train_station::Tensor;
111    ///
112    /// // Stack with gradient tracking
113    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
114    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
115    /// a.set_requires_grad(true);
116    /// b.set_requires_grad(true);
117    ///
118    /// let stacked = Tensor::stack(&[a, b], 0);
119    /// assert!(stacked.requires_grad());
120    /// assert_eq!(stacked.shape().dims, vec![2, 2]);
121    /// ```
122    ///
123    /// ```
124    /// use train_station::Tensor;
125    ///
126    /// // Stack 3D tensors along the last dimension
127    /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
128    /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
129    /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
130    /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
131    /// let stacked = Tensor::stack(&[a, b], 3);
132    /// assert_eq!(stacked.shape().dims, vec![2, 2, 2, 2]);
133    /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
134    /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
135    /// ```
136    ///
137    /// # Performance
138    ///
139    /// - **Time Complexity**: O(n) where n is the total number of elements
140    /// - **Memory Usage**: Allocates new contiguous tensor for output
141    /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
142    /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
143    /// - **Gradient Tracking**: Preserves gradient requirements and tracking
144    ///
145    /// # Relationship to Other Operations
146    ///
147    /// This operation is related to other tensor transformations:
148    /// - `cat()` - Concatenates tensors along existing dimensions
149    /// - `unsqueeze()` - Adds a single dimension of size 1
150    /// - `reshape()` - Changes tensor shape without adding dimensions
151    ///
152    /// # Memory Layout
153    ///
154    /// The output tensor is always contiguous, with elements arranged so that
155    /// the stacked dimension is the fastest-changing index. This ensures optimal
156    /// performance for subsequent operations and maintains compatibility with
157    /// SIMD optimizations.
158    ///
159    /// # Gradient Computation
160    ///
161    /// During backward passes, gradients are split along the stacked dimension
162    /// and distributed back to the original input tensors. This is implemented
163    /// using the same gradient function as concatenation, treating the stack
164    /// operation as concatenation along a new axis.
165    #[track_caller]
166    pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
167        assert!(!tensors.is_empty(), "stack requires at least one tensor");
168
169        // Validate all shapes identical
170        let base_dims = tensors[0].shape().dims.clone();
171        for t in tensors.iter() {
172            assert_eq!(
173                t.shape().dims,
174                base_dims,
175                "All tensors must have identical shapes for stack"
176            );
177        }
178
179        let rank = base_dims.len();
180        assert!(
181            dim <= rank,
182            "stack dim {} out of bounds for rank {}",
183            dim,
184            rank
185        );
186
187        // Compute output shape by inserting new axis of size = tensors.len()
188        let mut out_dims = Vec::with_capacity(rank + 1);
189        out_dims.extend_from_slice(&base_dims[..dim]);
190        out_dims.push(tensors.len());
191        out_dims.extend_from_slice(&base_dims[dim..]);
192
193        // Materialize into a new contiguous tensor
194        let mut output = Tensor::new(out_dims.clone());
195
196        // Copy block-wise: treat stack dim separately
197        // For output shape [pre..., K=tensors.len(), post...]
198        // inner = product(post...), outer = product(pre...)
199        let inner: usize = base_dims[dim..].iter().product();
200        let outer: usize = base_dims[..dim].iter().product();
201
202        unsafe {
203            let dst_ptr = output.as_mut_ptr();
204            for outer_idx in 0..outer {
205                for (k, t) in tensors.iter().enumerate() {
206                    // Ensure contiguous source
207                    let src = if t.is_contiguous() {
208                        t.clone()
209                    } else {
210                        t.contiguous()
211                    };
212                    // Source offset: within each tensor, block size is inner
213                    let src_base = outer_idx * inner;
214                    let src_ptr = src.as_ptr().add(src_base);
215
216                    // Destination offset computes with inserted axis
217                    // out block along stacked axis of length K, each block is inner
218                    let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
219                    optimized_block_copy(src_ptr, dst_ptr.add(dst_base), inner);
220                }
221            }
222        }
223
224        // GradTrack: stack is like cat with a new axis; gradient splits along that axis
225        let any_requires = tensors.iter().any(|t| t.requires_grad());
226        if any_requires {
227            output.set_requires_grad(true);
228            // For GradFn::Cat, provide sizes along concat dim and input shapes
229            let mut input_ids = Vec::with_capacity(tensors.len());
230            let mut input_sizes = Vec::with_capacity(tensors.len());
231            let mut input_shapes = Vec::with_capacity(tensors.len());
232            for t in tensors.iter() {
233                if t.requires_grad() {
234                    input_ids.push(t.id());
235                }
236                input_sizes.push(1); // each slice along new axis has length 1
237                input_shapes.push(t.shape().dims.clone());
238            }
239            let grad_fn = GradFn::Cat {
240                dim,
241                input_sizes,
242                input_shapes,
243            };
244            output.set_grad_fn(grad_fn.clone());
245            GradEngine::register_operation(output.id(), input_ids, grad_fn);
246        }
247
248        output
249    }
250}
251
252/// Optimized block copy with SIMD acceleration for large blocks
253///
254/// Performs efficient memory copying with automatic SIMD optimization
255/// for large data blocks. This function automatically selects the best
256/// copying strategy based on block size and available CPU features.
257///
258/// # Arguments
259///
260/// * `src` - Source pointer to copy from
261/// * `dst` - Destination pointer to copy to
262/// * `count` - Number of f32 elements to copy
263///
264/// # Safety
265///
266/// The caller must ensure:
267/// * `src` and `dst` are valid pointers to f32 data
268/// * `src` and `dst` do not overlap (non-overlapping memory regions)
269/// * `count` elements are accessible from both pointers
270/// * The memory regions are properly aligned for SIMD operations
271///
272/// # Performance
273///
274/// - **Small blocks (≤32 elements)**: Direct memory copy
275/// - **Large blocks (≥64 elements)**: AVX2 SIMD acceleration when available
276/// - **Medium blocks**: Unrolled scalar copy for optimal performance
277/// - **Memory bandwidth**: Optimized for maximum throughput
278///
279/// # Examples
280///
281/// This function is used internally by the `stack()` operation for
282/// efficient memory copying. It automatically selects the best copying
283/// strategy based on block size and available CPU features.
284#[inline]
285unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
286    if count == 0 {
287        return;
288    }
289
290    // For small blocks, use standard copy
291    if count <= 32 {
292        std::ptr::copy_nonoverlapping(src, dst, count);
293        return;
294    }
295
296    #[cfg(target_arch = "x86_64")]
297    {
298        if is_x86_feature_detected!("avx2") && count >= 64 {
299            simd_block_copy_avx2(src, dst, count);
300            return;
301        }
302    }
303
304    // Fallback to optimized scalar copy with unrolling
305    scalar_block_copy_unrolled(src, dst, count);
306}
307
/// AVX2 block copy: 32 f32 elements (four 256-bit vectors) per main-loop pass
///
/// The main loop issues four independent unaligned load/store pairs per
/// iteration for instruction-level parallelism, then drains the tail with
/// single 8-element vectors and finally a scalar `copy_nonoverlapping`.
///
/// # Safety
///
/// The caller must guarantee:
/// * AVX2 is available on the executing CPU (checked by the dispatcher)
/// * `src` is valid for reads and `dst` for writes of `count` f32 elements
/// * the regions do not overlap
///
/// Unaligned intrinsics (`loadu`/`storeu`) are used throughout, so no special
/// alignment is required.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    let mut i = 0;

    // Main loop: four AVX2 vectors (32 elements) per pass.
    while i + 32 <= count {
        let v0 = _mm256_loadu_ps(src.add(i));
        let v1 = _mm256_loadu_ps(src.add(i + 8));
        let v2 = _mm256_loadu_ps(src.add(i + 16));
        let v3 = _mm256_loadu_ps(src.add(i + 24));

        _mm256_storeu_ps(dst.add(i), v0);
        _mm256_storeu_ps(dst.add(i + 8), v1);
        _mm256_storeu_ps(dst.add(i + 16), v2);
        _mm256_storeu_ps(dst.add(i + 24), v3);

        i += 32;
    }

    // Drain whole 8-element vectors that remain.
    while i + 8 <= count {
        _mm256_storeu_ps(dst.add(i), _mm256_loadu_ps(src.add(i)));
        i += 8;
    }

    // Scalar tail: fewer than 8 elements left.
    if i < count {
        std::ptr::copy_nonoverlapping(src.add(i), dst.add(i), count - i);
    }
}
375
/// Scalar block copy with an 8-wide unrolled main loop
///
/// Portable fallback used when SIMD is unavailable or not beneficial. The
/// main loop copies eight elements per pass (the fixed-bound inner loop is
/// trivially unrolled by the compiler); any remainder is handled with a
/// single `copy_nonoverlapping`.
///
/// # Safety
///
/// The caller must guarantee:
/// * `src` is valid for reads and `dst` for writes of `count` f32 elements
/// * the two regions do not overlap
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const UNROLL: usize = 8;
    let mut i = 0;

    // Main loop: copy UNROLL elements per pass.
    while i + UNROLL <= count {
        for lane in 0..UNROLL {
            *dst.add(i + lane) = *src.add(i + lane);
        }
        i += UNROLL;
    }

    // Remainder: fewer than UNROLL elements left.
    if i < count {
        std::ptr::copy_nonoverlapping(src.add(i), dst.add(i), count - i);
    }
}
430
#[cfg(test)]
mod tests {
    use super::*;

    // Stacking two 1D tensors along dim 0 produces a [2, 3] tensor with
    // each input occupying one row.
    #[test]
    fn test_stack_basic() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let y = Tensor::stack(&[a, b], 0);
        assert_eq!(y.shape().dims, vec![2, 3]);
        assert_eq!(y.get(&[0, 0]), 1.0);
        assert_eq!(y.get(&[1, 2]), 6.0);
    }

    // The new axis size equals the number of inputs (3 here), and elements
    // remain addressable per-input along that axis.
    #[test]
    fn test_stack_multiple_tensors() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let c = Tensor::from_slice(&[5.0, 6.0], vec![2]).unwrap();
        let stacked = Tensor::stack(&[a, b, c], 0);
        assert_eq!(stacked.shape().dims, vec![3, 2]);
        assert_eq!(stacked.get(&[0, 0]), 1.0);
        assert_eq!(stacked.get(&[1, 1]), 4.0);
        assert_eq!(stacked.get(&[2, 1]), 6.0);
    }

    // Stacking along an interior dimension (dim=1 of 2x2 inputs) inserts
    // the new axis in the middle: [2, 2] -> [2, 2, 2].
    #[test]
    fn test_stack_2d_tensors() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let stacked = Tensor::stack(&[a, b], 1);
        assert_eq!(stacked.shape().dims, vec![2, 2, 2]);
        assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
        assert_eq!(stacked.get(&[1, 1, 1]), 8.0);
    }

    // If any input requires gradients, the stacked output must too.
    #[test]
    fn test_stack_with_gradients() {
        let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        a.set_requires_grad(true);
        b.set_requires_grad(true);

        let stacked = Tensor::stack(&[a, b], 0);
        assert!(stacked.requires_grad());
        assert_eq!(stacked.shape().dims, vec![2, 2]);
    }

    // An empty input slice is a programming error and must panic.
    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    // Mismatched input shapes must panic with the shape-validation message.
    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0, 5.0], vec![3]).unwrap();
        Tensor::stack(&[a, b], 0);
    }

    // dim may be at most rank (inclusive); rank+2 here is out of bounds.
    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        Tensor::stack(&[a, b], 2);
    }
}