train_station/tensor/transform/stack.rs
1//! Tensor stacking operations
2//!
3//! This module provides tensor stacking functionality that combines multiple
4//! tensors along a new dimension. Stacking is a fundamental tensor transformation
5//! operation used in machine learning for combining multiple feature maps,
6//! creating batch dimensions, and implementing complex tensor manipulations
7//! that require adding new axes to tensor data.
8//!
9//! # Operations
10//!
11//! * `stack()` - Stack multiple tensors along a new dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: AVX2 acceleration for large block copies
16//! * **Memory Efficient**: Optimized block-wise copying with minimal allocations
17//! * **Contiguous Output**: Always produces a contiguous tensor for optimal performance
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//! * **Shape Validation**: Comprehensive error checking for compatible tensor shapes
20//!
21//! # Examples
22//!
23//! ```
24//! use train_station::Tensor;
25//!
26//! // Stack two 1D tensors along dimension 0
27//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
28//! let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
29//! let stacked = Tensor::stack(&[a, b], 0);
30//! assert_eq!(stacked.shape().dims, vec![2, 3]);
31//! assert_eq!(stacked.get(&[0, 0]), 1.0);
32//! assert_eq!(stacked.get(&[1, 2]), 6.0);
33//! ```
34//!
35//! ```
36//! use train_station::Tensor;
37//!
38//! // Stack multiple 2D tensors along dimension 1
39//! let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
40//! let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
41//! let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
42//! let stacked = Tensor::stack(&[a, b, c], 1);
43//! assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
44//! ```
45
46use crate::gradtrack::{GradEngine, GradFn};
47use crate::tensor::core::Tensor;
48
49// SIMD optimizations for performance-critical operations
50#[cfg(target_arch = "x86_64")]
51use std::arch::x86_64::*;
52
53impl Tensor {
54 /// Stack a list of tensors along a new dimension
55 ///
56 /// Combines multiple tensors by adding a new dimension at the specified
57 /// position. All input tensors must have identical shapes, and the output
58 /// tensor will have a new dimension of size equal to the number of input
59 /// tensors. This operation is similar to PyTorch's `torch.stack` function.
60 ///
61 /// The stacking operation creates a new axis in the output tensor, unlike
62 /// concatenation which operates along existing dimensions. This makes
63 /// stacking useful for creating batch dimensions, combining feature maps,
64 /// and implementing operations that require adding new tensor axes.
65 ///
66 /// # Arguments
67 ///
68 /// * `tensors` - Array of tensors to stack. All tensors must have identical shapes.
69 /// * `dim` - Index of the new axis in the output shape (0 <= dim <= rank)
70 ///
71 /// # Returns
72 ///
73 /// A new tensor with the stacked data. The output shape is the input shape
74 /// with a new dimension of size `tensors.len()` inserted at position `dim`.
75 ///
76 /// # Panics
77 ///
78 /// * If the tensor array is empty
79 /// * If any tensor has a different shape than the first tensor
80 /// * If `dim` is out of bounds (dim > rank of input tensors)
81 ///
82 /// # Examples
83 ///
84 /// ```
85 /// use train_station::Tensor;
86 ///
87 /// // Stack two 1D tensors along dimension 0
88 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
89 /// let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
90 /// let stacked = Tensor::stack(&[a, b], 0);
91 /// assert_eq!(stacked.shape().dims, vec![2, 3]);
92 /// assert_eq!(stacked.get(&[0, 0]), 1.0);
93 /// assert_eq!(stacked.get(&[1, 2]), 6.0);
94 /// ```
95 ///
96 /// ```
97 /// use train_station::Tensor;
98 ///
99 /// // Stack multiple 2D tensors along dimension 1
100 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
101 /// let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
102 /// let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
103 /// let stacked = Tensor::stack(&[a, b, c], 1);
104 /// assert_eq!(stacked.shape().dims, vec![2, 3, 2]);
105 /// assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
106 /// assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
107 /// ```
108 ///
109 /// ```
110 /// use train_station::Tensor;
111 ///
112 /// // Stack with gradient tracking
113 /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
114 /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
115 /// a.set_requires_grad(true);
116 /// b.set_requires_grad(true);
117 ///
118 /// let stacked = Tensor::stack(&[a, b], 0);
119 /// assert!(stacked.requires_grad());
120 /// assert_eq!(stacked.shape().dims, vec![2, 2]);
121 /// ```
122 ///
123 /// ```
124 /// use train_station::Tensor;
125 ///
126 /// // Stack 3D tensors along the last dimension
127 /// let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
128 /// let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
129 /// let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
130 /// let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
131 /// let stacked = Tensor::stack(&[a, b], 3);
132 /// assert_eq!(stacked.shape().dims, vec![2, 2, 2, 2]);
133 /// assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
134 /// assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
135 /// ```
136 ///
137 /// # Performance
138 ///
139 /// - **Time Complexity**: O(n) where n is the total number of elements
140 /// - **Memory Usage**: Allocates new contiguous tensor for output
141 /// - **SIMD Optimization**: Uses AVX2 acceleration for large block copies
142 /// - **Block-wise Copying**: Optimized copying strategy for better cache performance
143 /// - **Gradient Tracking**: Preserves gradient requirements and tracking
144 ///
145 /// # Relationship to Other Operations
146 ///
147 /// This operation is related to other tensor transformations:
148 /// - `cat()` - Concatenates tensors along existing dimensions
149 /// - `unsqueeze()` - Adds a single dimension of size 1
150 /// - `reshape()` - Changes tensor shape without adding dimensions
151 ///
152 /// # Memory Layout
153 ///
154 /// The output tensor is always contiguous, with elements arranged so that
155 /// the stacked dimension is the fastest-changing index. This ensures optimal
156 /// performance for subsequent operations and maintains compatibility with
157 /// SIMD optimizations.
158 ///
159 /// # Gradient Computation
160 ///
161 /// During backward passes, gradients are split along the stacked dimension
162 /// and distributed back to the original input tensors. This is implemented
163 /// using the same gradient function as concatenation, treating the stack
164 /// operation as concatenation along a new axis.
165 #[track_caller]
166 pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor {
167 assert!(!tensors.is_empty(), "stack requires at least one tensor");
168
169 // Validate all shapes identical
170 let base_dims = tensors[0].shape().dims.clone();
171 for t in tensors.iter() {
172 assert_eq!(
173 t.shape().dims,
174 base_dims,
175 "All tensors must have identical shapes for stack"
176 );
177 }
178
179 let rank = base_dims.len();
180 assert!(
181 dim <= rank,
182 "stack dim {} out of bounds for rank {}",
183 dim,
184 rank
185 );
186
187 // Compute output shape by inserting new axis of size = tensors.len()
188 let mut out_dims = Vec::with_capacity(rank + 1);
189 out_dims.extend_from_slice(&base_dims[..dim]);
190 out_dims.push(tensors.len());
191 out_dims.extend_from_slice(&base_dims[dim..]);
192
193 // Materialize into a new contiguous tensor
194 let mut output = Tensor::new(out_dims.clone());
195
196 // Copy block-wise: treat stack dim separately
197 // For output shape [pre..., K=tensors.len(), post...]
198 // inner = product(post...), outer = product(pre...)
199 let inner: usize = base_dims[dim..].iter().product();
200 let outer: usize = base_dims[..dim].iter().product();
201
202 unsafe {
203 let dst_ptr = output.as_mut_ptr();
204 for outer_idx in 0..outer {
205 for (k, t) in tensors.iter().enumerate() {
206 // Ensure contiguous source
207 let src = if t.is_contiguous() {
208 t.clone()
209 } else {
210 t.contiguous()
211 };
212 // Source offset: within each tensor, block size is inner
213 let src_base = outer_idx * inner;
214 let src_ptr = src.as_ptr().add(src_base);
215
216 // Destination offset computes with inserted axis
217 // out block along stacked axis of length K, each block is inner
218 let dst_base = outer_idx * (tensors.len() * inner) + k * inner;
219 optimized_block_copy(src_ptr, dst_ptr.add(dst_base), inner);
220 }
221 }
222 }
223
224 // GradTrack: stack is like cat with a new axis; gradient splits along that axis
225 let any_requires = tensors.iter().any(|t| t.requires_grad());
226 if any_requires {
227 output.set_requires_grad(true);
228 // For GradFn::Cat, provide sizes along concat dim and input shapes
229 let mut input_ids = Vec::with_capacity(tensors.len());
230 let mut input_sizes = Vec::with_capacity(tensors.len());
231 let mut input_shapes = Vec::with_capacity(tensors.len());
232 for t in tensors.iter() {
233 if t.requires_grad() {
234 input_ids.push(t.id());
235 }
236 input_sizes.push(1); // each slice along new axis has length 1
237 input_shapes.push(t.shape().dims.clone());
238 }
239 let grad_fn = GradFn::Cat {
240 dim,
241 input_sizes,
242 input_shapes,
243 };
244 output.set_grad_fn(grad_fn.clone());
245 GradEngine::register_operation(output.id(), input_ids, grad_fn);
246 }
247
248 output
249 }
250}
251
252/// Optimized block copy with SIMD acceleration for large blocks
253///
254/// Performs efficient memory copying with automatic SIMD optimization
255/// for large data blocks. This function automatically selects the best
256/// copying strategy based on block size and available CPU features.
257///
258/// # Arguments
259///
260/// * `src` - Source pointer to copy from
261/// * `dst` - Destination pointer to copy to
262/// * `count` - Number of f32 elements to copy
263///
264/// # Safety
265///
266/// The caller must ensure:
267/// * `src` and `dst` are valid pointers to f32 data
268/// * `src` and `dst` do not overlap (non-overlapping memory regions)
269/// * `count` elements are accessible from both pointers
270/// * The memory regions are properly aligned for SIMD operations
271///
272/// # Performance
273///
274/// - **Small blocks (≤32 elements)**: Direct memory copy
275/// - **Large blocks (≥64 elements)**: AVX2 SIMD acceleration when available
276/// - **Medium blocks**: Unrolled scalar copy for optimal performance
277/// - **Memory bandwidth**: Optimized for maximum throughput
278///
279/// # Examples
280///
281/// This function is used internally by the `stack()` operation for
282/// efficient memory copying. It automatically selects the best copying
283/// strategy based on block size and available CPU features.
284#[inline]
285unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
286 if count == 0 {
287 return;
288 }
289
290 // For small blocks, use standard copy
291 if count <= 32 {
292 std::ptr::copy_nonoverlapping(src, dst, count);
293 return;
294 }
295
296 #[cfg(target_arch = "x86_64")]
297 {
298 if is_x86_feature_detected!("avx2") && count >= 64 {
299 simd_block_copy_avx2(src, dst, count);
300 return;
301 }
302 }
303
304 // Fallback to optimized scalar copy with unrolling
305 scalar_block_copy_unrolled(src, dst, count);
306}
307
/// Copies `count` `f32` values using 256-bit AVX2 loads and stores
///
/// The copy proceeds in three phases: strides of 32 elements (four
/// 8-lane vectors loaded, then four stored), then single 8-lane vectors,
/// and finally a `std::ptr::copy_nonoverlapping` for the last 0..=7 elements.
///
/// # Arguments
///
/// * `src` - Pointer to the first element to read
/// * `dst` - Pointer to the first element to write
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 is available on the executing CPU
/// * `src` is valid for reading and `dst` for writing `count` f32 values
/// * The two regions do not overlap
///
/// Only unaligned load/store intrinsics are used, so the pointers do not
/// need 32-byte alignment.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    // f32 lanes in one 256-bit register
    const LANES: usize = 8;
    // Elements handled by one iteration of the wide loop
    const STRIDE: usize = 4 * LANES;

    let mut pos = 0usize;

    // Wide phase: four vectors per iteration, all loads issued before the stores
    while count - pos >= STRIDE {
        let from = src.add(pos);
        let to = dst.add(pos);

        let r0 = _mm256_loadu_ps(from);
        let r1 = _mm256_loadu_ps(from.add(LANES));
        let r2 = _mm256_loadu_ps(from.add(2 * LANES));
        let r3 = _mm256_loadu_ps(from.add(3 * LANES));

        _mm256_storeu_ps(to, r0);
        _mm256_storeu_ps(to.add(LANES), r1);
        _mm256_storeu_ps(to.add(2 * LANES), r2);
        _mm256_storeu_ps(to.add(3 * LANES), r3);

        pos += STRIDE;
    }

    // Narrow phase: one vector at a time for what is left
    while count - pos >= LANES {
        let reg = _mm256_loadu_ps(src.add(pos));
        _mm256_storeu_ps(dst.add(pos), reg);
        pos += LANES;
    }

    // Tail: fewer than LANES elements remain
    let tail = count - pos;
    if tail > 0 {
        std::ptr::copy_nonoverlapping(src.add(pos), dst.add(pos), tail);
    }
}
375
/// Copies `count` `f32` values with a manually unrolled scalar loop
///
/// Elements are moved eight at a time with plain loads and stores; any
/// leftover (fewer than eight) is handed to `std::ptr::copy_nonoverlapping`.
/// No architecture-specific intrinsics are used.
///
/// # Arguments
///
/// * `src` - Pointer to the first element to read
/// * `dst` - Pointer to the first element to write
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` is valid for reading and `dst` for writing `count` f32 values
/// * Both pointers are aligned for `f32`
/// * The two regions do not overlap
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    // Elements copied per loop iteration
    const UNROLL: usize = 8;

    // Largest multiple of UNROLL that fits in `count`
    let bulk = count - count % UNROLL;

    let mut pos = 0;
    while pos < bulk {
        let from = src.add(pos);
        let to = dst.add(pos);

        *to = *from;
        *to.add(1) = *from.add(1);
        *to.add(2) = *from.add(2);
        *to.add(3) = *from.add(3);
        *to.add(4) = *from.add(4);
        *to.add(5) = *from.add(5);
        *to.add(6) = *from.add(6);
        *to.add(7) = *from.add(7);

        pos += UNROLL;
    }

    // Leftover elements that did not fill a whole unrolled step
    let tail = count - bulk;
    if tail != 0 {
        std::ptr::copy_nonoverlapping(src.add(bulk), dst.add(bulk), tail);
    }
}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor from literal values, panicking if the shape is rejected.
    fn tensor_of(values: &[f32], dims: Vec<usize>) -> Tensor {
        Tensor::from_slice(values, dims).unwrap()
    }

    /// Two 1D tensors stacked on axis 0 become the rows of a 2x3 tensor.
    #[test]
    fn test_stack_basic() {
        let first = tensor_of(&[1.0, 2.0, 3.0], vec![3]);
        let second = tensor_of(&[4.0, 5.0, 6.0], vec![3]);

        let result = Tensor::stack(&[first, second], 0);

        assert_eq!(result.shape().dims, vec![2, 3]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 2]), 6.0);
    }

    /// Three inputs produce a leading axis of length three.
    #[test]
    fn test_stack_multiple_tensors() {
        let inputs = [
            tensor_of(&[1.0, 2.0], vec![2]),
            tensor_of(&[3.0, 4.0], vec![2]),
            tensor_of(&[5.0, 6.0], vec![2]),
        ];

        let result = Tensor::stack(&inputs, 0);

        assert_eq!(result.shape().dims, vec![3, 2]);
        assert_eq!(result.get(&[0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1]), 4.0);
        assert_eq!(result.get(&[2, 1]), 6.0);
    }

    /// Stacking 2x2 inputs on axis 1 inserts the new axis in the middle.
    #[test]
    fn test_stack_2d_tensors() {
        let first = tensor_of(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let second = tensor_of(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]);

        let result = Tensor::stack(&[first, second], 1);

        assert_eq!(result.shape().dims, vec![2, 2, 2]);
        assert_eq!(result.get(&[0, 0, 0]), 1.0);
        assert_eq!(result.get(&[1, 1, 1]), 8.0);
    }

    /// The output requires gradients when the inputs do.
    #[test]
    fn test_stack_with_gradients() {
        let mut first = tensor_of(&[1.0, 2.0], vec![2]);
        let mut second = tensor_of(&[3.0, 4.0], vec![2]);
        first.set_requires_grad(true);
        second.set_requires_grad(true);

        let result = Tensor::stack(&[first, second], 0);

        assert!(result.requires_grad());
        assert_eq!(result.shape().dims, vec![2, 2]);
    }

    /// An empty input slice is rejected.
    #[test]
    #[should_panic(expected = "stack requires at least one tensor")]
    fn test_stack_empty() {
        Tensor::stack(&[], 0);
    }

    /// Inputs whose shapes differ are rejected.
    #[test]
    #[should_panic(expected = "All tensors must have identical shapes")]
    fn test_stack_different_shapes() {
        let short = tensor_of(&[1.0, 2.0], vec![2]);
        let long = tensor_of(&[3.0, 4.0, 5.0], vec![3]);
        Tensor::stack(&[short, long], 0);
    }

    /// An axis index greater than the input rank is rejected.
    #[test]
    #[should_panic(expected = "stack dim 2 out of bounds for rank 1")]
    fn test_stack_dim_out_of_bounds() {
        let first = tensor_of(&[1.0, 2.0], vec![2]);
        let second = tensor_of(&[3.0, 4.0], vec![2]);
        Tensor::stack(&[first, second], 2);
    }
}