// train_station/tensor/transform/stack.rs
1//! Tensor stacking operations
2//!
3//! This module provides tensor stacking functionality that combines multiple
4//! tensors along a new dimension. Stacking is a fundamental tensor transformation
5//! operation used in machine learning for combining multiple feature maps,
6//! creating batch dimensions, and implementing complex tensor manipulations
7//! that require adding new axes to tensor data.
8//!
9//! # Operations
10//!
11//! * `stack()` - Stack multiple tensors along a new dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: AVX2 acceleration for large block copies
16//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
17//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Stack two 1D tensors along dimension 0
27//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
28//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
29//! let stacked = Tensor::stack(&[a, b], 0);
30//! assert_eq!(stacked.shape().dims, vec![2, 3]);
31//! assert_eq!(stacked.get(&[0, 0]), 1.0);
32//! assert_eq!(stacked.get(&[1, 2]), 6.0);
33//! ```
34//!
35//! ```
36//! use train_station::Tensor;
37//!
38//! // Stack multiple 2D tensors along dimension 1
39//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
40//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
41//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
42//! let stacked = Tensor::stack(&[a, b, c], 1);
43//! assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
44//! ```
45
46use crate::gradtrack::{GradEngine, GradFn};
47use crate::tensor::core::Tensor;
48
49// SIMD optimizations for performance-critical operations
50#[cfg(target_arch = "x86_64")]
51use std::arch::x86_64::*;
52
53impl Tensor {
54 /// Stack a list of tensors along a new dimension
55 ///
56 /// Combines multiple tensors by adding a new dimension at the specified
57 /// position. All input tensors must have identical shapes, and the output
58 /// tensor will have a new dimension of size equal to the number of input
59 /// tensors. This operation is similar to PyTorch's `torch.stack` function.
60 ///
61 /// The stacking operation creates a new axis in the output tensor, unlike
62 /// concatenation which operates along existing dimensions. This makes
63 /// stacking useful for creating batch dimensions, combining feature maps,
64 /// and implementing operations that require adding new tensor axes.
65 ///
66 /// # Arguments
67 ///
68 /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
69 /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
70 ///
71 /// # Returns
72 ///
73 /// A new tensor with the stacked data. The output shape is the input shape
74 /// with a new dimension of size `tensors.len()` inserted at position `dim`.
75 ///
76 /// # Panics
77 ///
78 /// * If the tensor array is empty
79 /// * If any tensor has a different shape than the first tensor
80 /// * If `dim` is out of bounds (dim > rank of input tensors)
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Stack two 1D tensors along dimension 0
88 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
89 /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
90 /// let stacked = Tensor::stack(&[a, b], 0);
91 /// assert_eq!(stacked.shape().dims, vec![2, 3]);
92 /// assert_eq!(stacked.get(&[0, 0]), 1.0);
93 /// assert_eq!(stacked.get(&[1, 2]), 6.0);
94 /// ```
95 ///
96 /// ```
97 /// use train_station::Tensor;
98 ///
99 /// // Stack multiple 2D tensors along dimension 1
100 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
101 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
102 /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
103 /// let stacked = Tensor::stack(&[a, b, c], 1);
104 /// assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
105 /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
106 /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
107 /// ```
108 ///
109 /// ```
110 /// use train_station::Tensor;
111 ///
112 /// // Stack with gradient tracking
113 /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
114 /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
115 /// a.set_requires_grad(true);
116 /// b.set_requires_grad(true);
117 ///
118 /// let stacked = Tensor::stack(&[a, b], 0);
119 /// assert!(stacked.requires_grad());
120 /// assert_eq!(stacked.shape().dims, vec![2, 2]);
121 /// ```
122 ///
123 /// ```
124 /// use train_station::Tensor;
125 ///
126 /// // Stack 3D tensors along the last dimension
127 /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
128 /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
129 /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
130 /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
131 /// let stacked = Tensor::stack(&[a, b], 3);
132 /// assert_eq!(stacked.shape().dims, vec![2, 2, 2, 2]);
133 /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
134 /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
135 /// ```
136 ///
137 /// # Performance
138 ///
139 /// - **Time Complexity**: O(n) where n is the total number of elements
140 /// - **Memory Usage**: Allocates new contiguous tensor for output
141 /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
142 /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
143 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
144 ///
145 /// # Relationship to Other Operations
146 ///
147 /// This operation is related to other tensor transformations:
148 /// - `cat()` - Concatenates tensors along existing dimensions
149 /// - `unsqueeze()` - Adds a single dimension of size 1
150 /// - `reshape()` - Changes tensor shape without adding dimensions
151 ///
152 /// # Memory Layout
153 ///
154 /// The output tensor is always contiguous, with elements arranged so that
155 /// the stacked dimension is the fastest-changing index. This ensures optimal
156 /// performance for subsequent operations and maintains compatibility with
157 /// SIMD optimizations.
158 ///
159 /// # Gradient Computation
160 ///
161 /// During backward passes, gradients are split along the stacked dimension
162 /// and distributed back to the original input tensors. This is implemented
163 /// using the same gradient function as concatenation, treating the stack
164 /// operation as concatenation along a new axis.
165 pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
166 assert!(!tensors.is_empty(), "stack requires at least one tensor");
167
168 // Validate all shapes identical
169 let base_dims = tensors[0].shape().dims.clone();
170 for t in tensors.iter() {
171 assert_eq!(
172 t.shape().dims,
173 base_dims,
174 "All tensors must have identical shapes for stack"
175 );
176 }
177
178 let rank = base_dims.len();
179 assert!(
180 dim <= rank,
181 "stack dim {} out of bounds for rank {}",
182 dim,
183 rank
184 );
185
186 // Compute output shape by inserting new axis of size = tensors.len()
187 let mut out_dims = Vec::with_capacity(rank + 1);
188 out_dims.extend_from_slice(&base_dims[..dim]);
189 out_dims.push(tensors.len());
190 out_dims.extend_from_slice(&base_dims[dim..]);
191
192 // Materialize into a new contiguous tensor
193 let mut output = Tensor::new(out_dims.clone());
194
195 // Copy block-wise: treat stack dim separately
196 // For output shape [pre..., K=tensors.len(), post...]
197 // inner = product(post...), outer = product(pre...)
198 let inner: usize = base_dims[dim..].iter().product();
199 let outer: usize = base_dims[..dim].iter().product();
200
201 unsafe {
202 let dst_ptr = output.as_mut_ptr();
203 for outer_idx in 0..outer {
204 for (k, t) in tensors.iter().enumerate() {
205 // Ensure contiguous source
206 let src = if t.is_contiguous() {
207 t.clone()
208 } else {
209 t.contiguous()
210 };
211 // Source offset: within each tensor, block size is inner
212 let src_base = outer_idx * inner;
213 let src_ptr = src.as_ptr().add(src_base);
214
215 // Destination offset computes with inserted axis
216 // out block along stacked axis of length K, each block is inner
217 let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
218 optimized_block_copy(src_ptr, dst_ptr.add(dst_base), inner);
219 }
220 }
221 }
222
223 // GradTrack: stack is like cat with a new axis; gradient splits along that axis
224 let any_requires = tensors.iter().any(|t| t.requires_grad());
225 if any_requires {
226 output.set_requires_grad(true);
227 // For GradFn::Cat, provide sizes along concat dim and input shapes
228 let mut input_ids = Vec::with_capacity(tensors.len());
229 let mut input_sizes = Vec::with_capacity(tensors.len());
230 let mut input_shapes = Vec::with_capacity(tensors.len());
231 for t in tensors.iter() {
232 if t.requires_grad() {
233 input_ids.push(t.id());
234 }
235 input_sizes.push(1); // each slice along new axis has length 1
236 input_shapes.push(t.shape().dims.clone());
237 }
238 let grad_fn = GradFn::Cat {
239 dim,
240 input_sizes,
241 input_shapes,
242 };
243 output.set_grad_fn(grad_fn.clone());
244 GradEngine::register_operation(output.id(), input_ids, grad_fn);
245 }
246
247 output
248 }
249}
250
251/// Optimized block copy with SIMD acceleration for large blocks
252///
253/// Performs efficient memory copying with automatic SIMD optimization
254/// for large data blocks. This function automatically selects the best
255/// copying strategy based on block size and available CPU features.
256///
257/// # Arguments
258///
259/// * `src` - Source pointer to copy from
260/// * `dst` - Destination pointer to copy to
261/// * `count` - Number of f32 elements to copy
262///
263/// # Safety
264///
265/// The caller must ensure:
266/// * `src` and `dst` are valid pointers to f32 data
267/// * `src` and `dst` do not overlap (non-overlapping memory regions)
268/// * `count` elements are accessible from both pointers
269/// * The memory regions are properly aligned for SIMD operations
270///
271/// # Performance
272///
273/// - **Small blocks (≤32 elements)**: Direct memory copy
274/// - **Large blocks (≥64 elements)**: AVX2 SIMD acceleration when available
275/// - **Medium blocks**: Unrolled scalar copy for optimal performance
276/// - **Memory bandwidth**: Optimized for maximum throughput
277///
278/// # Examples
279///
280/// This function is used internally by the `stack()` operation for
281/// efficient memory copying. It automatically selects the best copying
282/// strategy based on block size and available CPU features.
283#[inline]
284unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
285 if count == 0 {
286 return;
287 }
288
289 // For small blocks, use standard copy
290 if count <= 32 {
291 std::ptr::copy_nonoverlapping(src, dst, count);
292 return;
293 }
294
295 #[cfg(target_arch = "x86_64")]
296 {
297 if is_x86_feature_detected!("avx2") && count >= 64 {
298 simd_block_copy_avx2(src, dst, count);
299 return;
300 }
301 }
302
303 // Fallback to optimized scalar copy with unrolling
304 scalar_block_copy_unrolled(src, dst, count);
305}
306
/// SIMD-optimized block copy using AVX2 instructions
///
/// Copies `count` f32 elements using 256-bit AVX2 vectors. The main loop is
/// 4x unrolled (32 elements per iteration for instruction-level parallelism);
/// remaining elements are handled with single 8-element vectors, then a final
/// scalar copy for the tail.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 instructions are available on the target CPU
/// * `src` and `dst` are valid for `count` f32 elements
/// * Memory regions do not overlap
///
/// No 32-byte alignment is required: only unaligned loads/stores
/// (`_mm256_loadu_ps` / `_mm256_storeu_ps`) are used, so ordinary f32
/// alignment is sufficient.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    let simd_count = count / 32; // Process 32 elements per iteration (4x AVX2 vectors)
    let mut offset = 0;

    // Unrolled SIMD loop for maximum throughput
    for _ in 0..simd_count {
        // Process 4 AVX2 vectors (32 elements) per iteration
        let vec1 = _mm256_loadu_ps(src.add(offset));
        let vec2 = _mm256_loadu_ps(src.add(offset + 8));
        let vec3 = _mm256_loadu_ps(src.add(offset + 16));
        let vec4 = _mm256_loadu_ps(src.add(offset + 24));

        _mm256_storeu_ps(dst.add(offset), vec1);
        _mm256_storeu_ps(dst.add(offset + 8), vec2);
        _mm256_storeu_ps(dst.add(offset + 16), vec3);
        _mm256_storeu_ps(dst.add(offset + 24), vec4);

        offset += 32;
    }

    // Handle remaining elements with 8-element SIMD blocks
    let remaining_full_blocks = (count - offset) / 8;
    for _ in 0..remaining_full_blocks {
        let vec = _mm256_loadu_ps(src.add(offset));
        _mm256_storeu_ps(dst.add(offset), vec);
        offset += 8;
    }

    // Handle final elements (fewer than 8) with a scalar copy
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}
374
/// Optimized scalar block copy with loop unrolling
///
/// Portable copy path used when SIMD instructions are unavailable or not
/// worthwhile. Copies `count` f32 elements in 8-element unrolled chunks and
/// finishes any tail with a single `copy_nonoverlapping`.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` and `dst` are valid pointers to f32 data
/// * Memory regions do not overlap
/// * All `count` elements are accessible from both pointers
///
/// # Performance
///
/// - **Throughput**: 8 elements per chunk; the constant-bound inner loop is
///   fully unrolled by the compiler for instruction-level parallelism
/// - **Fallback**: remaining elements use a standard memory copy
/// - **Compatibility**: works on all CPU architectures
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const UNROLL: usize = 8;
    let mut pos = 0;

    // Copy full 8-element chunks; the fixed 0..UNROLL range unrolls cleanly.
    while pos + UNROLL <= count {
        for lane in 0..UNROLL {
            *dst.add(pos + lane) = *src.add(pos + lane);
        }
        pos += UNROLL;
    }

    // Copy the remaining tail (fewer than 8 elements), if any.
    if pos < count {
        std::ptr::copy_nonoverlapping(src.add(pos), dst.add(pos), count - pos);
    }
}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 1-D tensors stacked along a new leading axis.
    #[test]
    fn test_stack_basic() {
        let first = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
        let second = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
        let result = Tensor::stack(&[first, second], 0);
        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 2]), 6.0);
    }

    /// Three inputs produce a new axis of size 3.
    #[test]
    fn test_stack_multiple_tensors() {
        let inputs = [
            Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap(),
            Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap(),
            Tensor::from_slice(&[5.0, 6.0], vec![2]).unwrap(),
        ];
        let result = Tensor::stack(&inputs, 0);
        assert_eq!(result.shape().dims, vec![3, 2]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1]), 4.0);
        assert_eq!(result.get(&[2, 1]), 6.0);
    }

    /// Stacking 2-D tensors along an interior axis (dim = 1).
    #[test]
    fn test_stack_2d_tensors() {
        let first = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let second = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
        let result = Tensor::stack(&[first, second], 1);
        assert_eq!(result.shape().dims, vec![2, 2, 2]);
        assert_eq!(result.get(&[0, 0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1, 1]), 8.0);
    }

    /// Output requires grad when any input requires grad.
    #[test]
    fn test_stack_with_gradients() {
        let mut first = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let mut second = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        first.set_requires_grad(true);
        second.set_requires_grad(true);

        let result = Tensor::stack(&[first, second], 0);
        assert!(result.requires_grad());
        assert_eq!(result.shape().dims, vec![2, 2]);
    }

    /// Stacking nothing is a programming error and must panic.
    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    /// Shape mismatch between inputs must panic.
    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let short = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let long = Tensor::from_slice(&[3.0, 4.0, 5.0], vec![3]).unwrap();
        Tensor::stack(&[short, long], 0);
    }

    /// dim may be at most the input rank; anything larger must panic.
    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let first = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let second = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        Tensor::stack(&[first, second], 2);
    }
}
499}