train_station/tensor/transform/
cat.rs

1//! Tensor concatenation operations
2//!
3//! This module provides tensor concatenation functionality that joins multiple
4//! tensors along a specified dimension. Concatenation is a fundamental tensor
5//! transformation operation used in machine learning for combining data from
6//! multiple sources, building batch operations, and creating complex tensor
7//! structures.
8//!
9//! # Operations
10//!
11//! * `cat()` - Concatenate multiple tensors along a specified dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: Uses AVX2 instructions for large block copies when available
16//! * **Memory Efficient**: Minimizes temporary allocations by reusing contiguous data
17//! * **Stride Aware**: Handles non-contiguous tensors efficiently with materialization
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Concatenate 1D tensors
26//! let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
27//! let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
28//! let result = Tensor::cat(&[a, b], 0);
29//! assert_eq!(result.shape().dims, vec![4]);
30//!
31//! // Concatenate 2D tensors along different dimensions
32//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
33//! let y = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
34//! let result = Tensor::cat(&[x, y], 1);
35//! assert_eq!(result.shape().dims, vec![2, 3]);
36//! ```
37
38use crate::gradtrack::{GradEngine, GradFn};
39use crate::tensor::core::Tensor;
40
41// SIMD optimizations for performance-critical operations
42#[cfg(target_arch = "x86_64")]
43use std::arch::x86_64::*;
44
impl Tensor {
    /// Concatenate tensors along a given dimension
    ///
    /// Joins multiple tensors along the specified dimension, creating a new tensor
    /// with the combined data. All input tensors must have the same rank and
    /// matching dimensions except for the concatenation dimension.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensors to concatenate (must not be empty)
    /// * `dim` - Dimension along which to concatenate (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor containing the concatenated data with shape where the
    /// concatenation dimension is the sum of all input tensor sizes along that dimension.
    ///
    /// # Panics
    ///
    /// * If `tensors` is empty
    /// * If `dim` is out of bounds for the tensor rank
    /// * If tensors have different ranks
    /// * If tensors have mismatched dimensions (except along concatenation dimension)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 1D tensors
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert_eq!(result.shape().dims, vec![4]);
    /// assert_eq!(result.get(&[0]), 1.0);
    /// assert_eq!(result.get(&[1]), 2.0);
    /// assert_eq!(result.get(&[2]), 3.0);
    /// assert_eq!(result.get(&[3]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 2D tensors along dimension 1
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
    /// let result = Tensor::cat(&[a, b], 1);
    /// assert_eq!(result.shape().dims, vec![2, 3]);
    /// assert_eq!(result.get(&[0, 0]), 1.0);
    /// assert_eq!(result.get(&[0, 1]), 2.0);
    /// assert_eq!(result.get(&[0, 2]), 5.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate with gradient tracking
    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// a.set_requires_grad(true);
    /// b.set_requires_grad(true);
    ///
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert!(result.requires_grad());
    /// ```
    pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor {
        assert!(!tensors.is_empty(), "cat requires at least one tensor");

        // The first tensor's rank is taken as the reference; `dim` must index into it.
        let rank = tensors[0].shape().rank();
        assert!(
            dim < rank,
            "concat dim {} out of bounds for rank {}",
            dim,
            rank
        );

        // Validate shapes and compute output dims
        let base_shape = tensors[0].shape().dims.clone();
        for t in tensors.iter() {
            assert_eq!(t.shape().rank(), rank, "All tensors must have same rank");
            for (i, (&a, &b)) in base_shape.iter().zip(t.shape().dims.iter()).enumerate() {
                if i != dim {
                    assert_eq!(
                        a, b,
                        "All dims except concat dim must match (dim {}: {} vs {})",
                        i, a, b
                    );
                }
            }
        }

        // Output shape equals the inputs' shape except along `dim`, where it is
        // the sum of every input's extent.
        let mut out_dims = base_shape.clone();
        let mut concat_len = 0usize;
        for t in tensors.iter() {
            concat_len += t.shape().dims[dim];
        }
        out_dims[dim] = concat_len;

        let mut output = Tensor::new(out_dims.clone());

        // Calculate block sizes for contiguous copy:
        // `inner` = elements spanned by one step along `dim` (product of trailing dims),
        // `outer` = number of independent slabs to copy (product of leading dims).
        let inner: usize = out_dims[dim + 1..].iter().product();
        let outer: usize = out_dims[..dim].iter().product();

        // Prepare source buffers once to avoid per-iteration cloning/copying
        // Each entry holds a pointer to contiguous data and the length along `dim`
        struct SourceInfo {
            base_ptr: *const f32,
            len_along_dim: usize,
        }

        // `temp_contiguous` owns any materialized copies so the raw pointers
        // stored in `sources` remain valid for the duration of the copy loop.
        let mut temp_contiguous: Vec<Tensor> = Vec::new();
        let mut sources: Vec<SourceInfo> = Vec::with_capacity(tensors.len());
        for t in tensors.iter() {
            let len_d = t.shape().dims[dim];
            if len_d == 0 {
                // Skip empty tensors; keep alignment in running count during copy
                sources.push(SourceInfo {
                    base_ptr: std::ptr::null(),
                    len_along_dim: 0,
                });
                continue;
            }
            if t.is_contiguous() {
                // Contiguous input: read directly from its buffer.
                let base_ptr = unsafe { t.as_ptr() };
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            } else {
                // Materialize once and keep it alive in `temp_contiguous`
                let cont = t.contiguous();
                let base_ptr = unsafe { cont.as_ptr() };
                temp_contiguous.push(cont);
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            }
        }

        // SAFETY: every non-null `base_ptr` points to contiguous data kept alive by
        // either the input tensor or `temp_contiguous`; `output` was freshly allocated
        // with `outer * concat_len * inner` elements, so each destination range
        // [dst_base, dst_base + copy_elems) is in bounds and disjoint from the sources.
        unsafe {
            let dst_ptr = output.as_mut_ptr();
            for outer_idx in 0..outer {
                // `running` tracks how far along `dim` we have written for this slab.
                let mut running = 0usize;
                for src in &sources {
                    let len_d = src.len_along_dim;
                    if len_d == 0 {
                        continue;
                    }
                    let copy_elems = len_d * inner;

                    // Source base offset for this outer index
                    let src_base = outer_idx * (len_d * inner);
                    let src_ptr = src.base_ptr.add(src_base);

                    // Destination base offset
                    let dst_base = outer_idx * (concat_len * inner) + running * inner;
                    let dst_cur = dst_ptr.add(dst_base);

                    optimized_block_copy(src_ptr, dst_cur, copy_elems);
                    running += len_d;
                }
            }
        }

        // GradTrack setup if any input requires_grad
        let any_requires = tensors.iter().any(|t| t.requires_grad());
        if any_requires {
            output.set_requires_grad(true);
            let mut input_ids = Vec::with_capacity(tensors.len());
            let mut grad_input_sizes = Vec::new();
            let mut grad_input_shapes = Vec::new();
            // NOTE(review): sizes/shapes are recorded only for grad-requiring inputs.
            // If a non-grad tensor precedes a grad tensor along `dim`, its extent is
            // absent here, so the offset of later tensors within the concatenated
            // gradient cannot be reconstructed from this data alone — confirm that
            // GradFn::Cat's backward accounts for skipped inputs.
            for t in tensors.iter() {
                if t.requires_grad() {
                    input_ids.push(t.id());
                    grad_input_sizes.push(t.shape().dims[dim]);
                    grad_input_shapes.push(t.shape().dims.clone());
                }
            }
            let grad_fn = GradFn::Cat {
                dim,
                input_sizes: grad_input_sizes,
                input_shapes: grad_input_shapes,
            };
            output.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(output.id(), input_ids, grad_fn);
        }

        output
    }
}
237
238/// Optimized block copy with SIMD acceleration for large blocks
239///
240/// Performs efficient memory copying with automatic SIMD optimization when
241/// available. Uses AVX2 instructions for large blocks and falls back to
242/// unrolled scalar operations for smaller blocks or when SIMD is not available.
243///
244/// # Arguments
245///
246/// * `src` - Source pointer to copy from
247/// * `dst` - Destination pointer to copy to
248/// * `count` - Number of f32 elements to copy
249///
250/// # Safety
251///
252/// The caller must ensure:
253/// * `src` points to valid memory with at least `count` f32 elements
254/// * `dst` points to valid writable memory with at least `count` f32 elements
255/// * The source and destination regions do not overlap
256/// * The pointers are properly aligned for the target architecture
257///
258/// # Performance
259///
260/// * **Large blocks (≥64 elements)**: Uses AVX2 SIMD instructions when available
261/// * **Medium blocks (32-63 elements)**: Uses unrolled scalar operations
262/// * **Small blocks (<32 elements)**: Uses standard library copy
263#[inline]
264unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
265    if count == 0 {
266        return;
267    }
268
269    // For small blocks, use standard copy
270    if count <= 32 {
271        std::ptr::copy_nonoverlapping(src, dst, count);
272        return;
273    }
274
275    #[cfg(target_arch = "x86_64")]
276    {
277        if is_x86_feature_detected!("avx2") && count >= 64 {
278            simd_block_copy_avx2(src, dst, count);
279            return;
280        }
281    }
282
283    // Fallback to optimized scalar copy with unrolling
284    scalar_block_copy_unrolled(src, dst, count);
285}
286
/// SIMD-optimized block copy using AVX2 instructions
///
/// Streams `count` f32 elements from `src` to `dst` using unaligned AVX2
/// loads and stores. The main loop moves 32 elements (four 256-bit vectors)
/// per iteration; leftovers are drained with single-vector copies and a final
/// scalar-range copy.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 instructions are available on the executing CPU
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
///
/// No particular alignment is required: only unaligned load/store intrinsics
/// (`_mm256_loadu_ps` / `_mm256_storeu_ps`) are used.
///
/// # Performance
///
/// * **Main loop**: 32 elements per iteration (4 AVX2 vectors)
/// * **Remainder**: 8 elements per iteration, then a scalar tail copy
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    let mut idx = 0usize;

    // Main loop: four 8-lane vectors (32 elements) per iteration for throughput.
    while idx + 32 <= count {
        let v0 = _mm256_loadu_ps(src.add(idx));
        let v1 = _mm256_loadu_ps(src.add(idx + 8));
        let v2 = _mm256_loadu_ps(src.add(idx + 16));
        let v3 = _mm256_loadu_ps(src.add(idx + 24));

        _mm256_storeu_ps(dst.add(idx), v0);
        _mm256_storeu_ps(dst.add(idx + 8), v1);
        _mm256_storeu_ps(dst.add(idx + 16), v2);
        _mm256_storeu_ps(dst.add(idx + 24), v3);

        idx += 32;
    }

    // Drain any whole 8-element vectors that remain.
    while idx + 8 <= count {
        _mm256_storeu_ps(dst.add(idx), _mm256_loadu_ps(src.add(idx)));
        idx += 8;
    }

    // Fewer than 8 elements left: plain scalar copy for the tail.
    if idx < count {
        std::ptr::copy_nonoverlapping(src.add(idx), dst.add(idx), count - idx);
    }
}
349
/// Unrolled scalar block copy for optimal performance
///
/// Copies `count` f32 elements from `src` to `dst` in groups of eight to cut
/// loop overhead and give the compiler room for instruction-level
/// parallelism; any remaining tail uses the standard library copy.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
///
/// # Performance
///
/// * **Main loop**: 8 elements per iteration, manually unrolled
/// * **Tail**: standard library copy for the final < 8 elements
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const LANES: usize = 8;
    let mut idx = 0usize;

    // Main loop: copy eight elements per iteration.
    while idx + LANES <= count {
        *dst.add(idx) = *src.add(idx);
        *dst.add(idx + 1) = *src.add(idx + 1);
        *dst.add(idx + 2) = *src.add(idx + 2);
        *dst.add(idx + 3) = *src.add(idx + 3);
        *dst.add(idx + 4) = *src.add(idx + 4);
        *dst.add(idx + 5) = *src.add(idx + 5);
        *dst.add(idx + 6) = *src.add(idx + 6);
        *dst.add(idx + 7) = *src.add(idx + 7);
        idx += LANES;
    }

    // Tail: fewer than eight elements remain.
    if idx < count {
        std::ptr::copy_nonoverlapping(src.add(idx), dst.add(idx), count - idx);
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// 1D concatenation along the only axis.
    #[test]
    fn test_cat_1d() {
        let lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let rhs = Tensor::from_slice(&[3.0], vec![1]).unwrap();
        let out = Tensor::cat(&[lhs, rhs], 0);
        assert_eq!(out.shape().dims, vec![3]);
        assert_eq!(out.get(&[0]), 1.0);
        assert_eq!(out.get(&[2]), 3.0);
    }

    /// 2D concatenation along the inner (column) axis.
    #[test]
    fn test_cat_2d_dim1() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let right = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
        let out = Tensor::cat(&[left, right], 1);
        assert_eq!(out.shape().dims, vec![2, 3]);
        assert_eq!(out.get(&[0, 2]), 5.0);
        assert_eq!(out.get(&[1, 2]), 6.0);
    }

    /// Mismatched non-concat dimensions must panic.
    #[test]
    #[should_panic]
    fn test_cat_mismatch() {
        let first = Tensor::new(vec![2, 2]);
        let second = Tensor::new(vec![3, 1]);
        let _ = Tensor::cat(&[first, second], 1);
    }
}