train_station/tensor/transform/cat.rs
1//! Tensor concatenation operations
2//!
3//! This module provides tensor concatenation functionality that joins multiple
4//! tensors along a specified dimension. Concatenation is a fundamental tensor
5//! transformation operation used in machine learning for combining data from
6//! multiple sources, building batch operations, and creating complex tensor
7//! structures.
8//!
9//! # Operations
10//!
11//! * `cat()` - Concatenate multiple tensors along a specified dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: Uses AVX2 instructions for large block copies when available
16//! * **Memory Efficient**: Minimizes temporary allocations by reusing contiguous data
17//! * **Stride Aware**: Handles non-contiguous tensors efficiently with materialization
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Concatenate 1D tensors
26//! let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
27//! let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
28//! let result = Tensor::cat(&[a, b], 0);
29//! assert_eq!(result.shape().dims, vec![4]);
30//!
31//! // Concatenate 2D tensors along different dimensions
32//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
33//! let y = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
34//! let result = Tensor::cat(&[x, y], 1);
35//! assert_eq!(result.shape().dims, vec![2, 3]);
36//! ```
37
38use crate::gradtrack::{GradEngine, GradFn};
39use crate::tensor::core::Tensor;
40
41// SIMD optimizations for performance-critical operations
42#[cfg(target_arch = "x86_64")]
43use std::arch::x86_64::*;
44
impl Tensor {
    /// Concatenate tensors along a given dimension
    ///
    /// Joins multiple tensors along the specified dimension, creating a new tensor
    /// with the combined data. All input tensors must have the same rank and
    /// matching dimensions except for the concatenation dimension.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensors to concatenate (must not be empty)
    /// * `dim` - Dimension along which to concatenate (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor containing the concatenated data with shape where the
    /// concatenation dimension is the sum of all input tensor sizes along that dimension.
    ///
    /// # Panics
    ///
    /// * If `tensors` is empty
    /// * If `dim` is out of bounds for the tensor rank
    /// * If tensors have different ranks
    /// * If tensors have mismatched dimensions (except along concatenation dimension)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 1D tensors
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert_eq!(result.shape().dims, vec![4]);
    /// assert_eq!(result.get(&[0]), 1.0);
    /// assert_eq!(result.get(&[1]), 2.0);
    /// assert_eq!(result.get(&[2]), 3.0);
    /// assert_eq!(result.get(&[3]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 2D tensors along dimension 1
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
    /// let result = Tensor::cat(&[a, b], 1);
    /// assert_eq!(result.shape().dims, vec![2, 3]);
    /// assert_eq!(result.get(&[0, 0]), 1.0);
    /// assert_eq!(result.get(&[0, 1]), 2.0);
    /// assert_eq!(result.get(&[0, 2]), 5.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate with gradient tracking
    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// a.set_requires_grad(true);
    /// b.set_requires_grad(true);
    ///
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert!(result.requires_grad());
    /// ```
    #[track_caller]
    pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor {
        assert!(!tensors.is_empty(), "cat requires at least one tensor");

        // The first tensor's rank defines the expected rank for every input.
        let rank = tensors[0].shape().rank();
        assert!(
            dim < rank,
            "concat dim {} out of bounds for rank {}",
            dim,
            rank
        );

        // Validate shapes and compute output dims: every dimension except
        // `dim` must match across all inputs.
        let base_shape = tensors[0].shape().dims.clone();
        for t in tensors.iter() {
            assert_eq!(t.shape().rank(), rank, "All tensors must have same rank");
            for (i, (&a, &b)) in base_shape.iter().zip(t.shape().dims.iter()).enumerate() {
                if i != dim {
                    assert_eq!(
                        a, b,
                        "All dims except concat dim must match (dim {}: {} vs {})",
                        i, a, b
                    );
                }
            }
        }

        // Output shape: same as the inputs, except the concat dimension is the
        // sum of all input extents along `dim`.
        let mut out_dims = base_shape.clone();
        let mut concat_len = 0usize;
        for t in tensors.iter() {
            concat_len += t.shape().dims[dim];
        }
        out_dims[dim] = concat_len;

        let mut output = Tensor::new(out_dims.clone());

        // Calculate block sizes for contiguous copy:
        // `inner` = number of elements in one slice below `dim`,
        // `outer` = number of independent slabs above `dim`.
        // One (outer_idx, source) pair therefore copies `len_d * inner` elements.
        let inner: usize = out_dims[dim + 1..].iter().product();
        let outer: usize = out_dims[..dim].iter().product();

        // Prepare source buffers once to avoid per-iteration cloning/copying
        // Each entry holds a pointer to contiguous data and the length along `dim`
        struct SourceInfo {
            base_ptr: *const f32,
            len_along_dim: usize,
        }

        // `temp_contiguous` owns any materialized copies so their pointers in
        // `sources` stay valid for the whole copy loop below.
        let mut temp_contiguous: Vec<Tensor> = Vec::new();
        let mut sources: Vec<SourceInfo> = Vec::with_capacity(tensors.len());
        for t in tensors.iter() {
            let len_d = t.shape().dims[dim];
            if len_d == 0 {
                // Skip empty tensors; keep alignment in running count during copy
                sources.push(SourceInfo {
                    base_ptr: std::ptr::null(),
                    len_along_dim: 0,
                });
                continue;
            }
            if t.is_contiguous() {
                let base_ptr = unsafe { t.as_ptr() };
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            } else {
                // Materialize once and keep it alive in `temp_contiguous`
                let cont = t.contiguous();
                let base_ptr = unsafe { cont.as_ptr() };
                temp_contiguous.push(cont);
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            }
        }

        // SAFETY: `output` was freshly allocated with `outer * concat_len * inner`
        // elements; each non-null source pointer references contiguous data with
        // `outer * len_d * inner` elements that stays alive for the duration of
        // the loop (owned by the caller's tensors or by `temp_contiguous`); the
        // source and destination buffers never overlap.
        unsafe {
            let dst_ptr = output.as_mut_ptr();
            for outer_idx in 0..outer {
                // `running` tracks how far along the concat dimension we have
                // written within the current outer slab.
                let mut running = 0usize;
                for src in &sources {
                    let len_d = src.len_along_dim;
                    if len_d == 0 {
                        continue;
                    }
                    let copy_elems = len_d * inner;

                    // Source base offset for this outer index
                    let src_base = outer_idx * (len_d * inner);
                    let src_ptr = src.base_ptr.add(src_base);

                    // Destination base offset
                    let dst_base = outer_idx * (concat_len * inner) + running * inner;
                    let dst_cur = dst_ptr.add(dst_base);

                    optimized_block_copy(src_ptr, dst_cur, copy_elems);
                    running += len_d;
                }
            }
        }

        // GradTrack setup if any input requires_grad
        // NOTE(review): only tensors with requires_grad() are recorded in
        // `input_sizes`/`input_shapes`. If `GradFn::Cat` computes backward
        // slice offsets by accumulating `input_sizes`, those offsets will be
        // wrong whenever a non-grad tensor precedes a grad-requiring one in
        // `tensors` — verify against the gradtrack Cat backward implementation.
        let any_requires = tensors.iter().any(|t| t.requires_grad());
        if any_requires {
            output.set_requires_grad(true);
            let mut input_ids = Vec::with_capacity(tensors.len());
            let mut grad_input_sizes = Vec::new();
            let mut grad_input_shapes = Vec::new();
            for t in tensors.iter() {
                if t.requires_grad() {
                    input_ids.push(t.id());
                    grad_input_sizes.push(t.shape().dims[dim]);
                    grad_input_shapes.push(t.shape().dims.clone());
                }
            }
            let grad_fn = GradFn::Cat {
                dim,
                input_sizes: grad_input_sizes,
                input_shapes: grad_input_shapes,
            };
            output.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(output.id(), input_ids, grad_fn);
        }

        output
    }
}
238
/// Optimized block copy with SIMD acceleration for large blocks
///
/// Performs efficient memory copying with automatic SIMD optimization when
/// available. Uses AVX2 instructions for large blocks and falls back to
/// unrolled scalar operations for smaller blocks or when SIMD is not available.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
/// * The pointers are properly aligned for `f32` access
///
/// # Performance
///
/// * **Large blocks (>= 64 elements)**: Uses AVX2 SIMD instructions when available
/// * **Medium blocks (33-63 elements)**: Uses unrolled scalar operations
/// * **Small blocks (<= 32 elements)**: Uses standard library copy
#[inline]
unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
    if count == 0 {
        return;
    }

    // For small blocks (<= 32 elements), a plain memcpy-style copy is cheapest:
    // no feature-detection or dispatch overhead.
    if count <= 32 {
        std::ptr::copy_nonoverlapping(src, dst, count);
        return;
    }

    #[cfg(target_arch = "x86_64")]
    {
        // Runtime dispatch: only take the AVX2 path when the CPU supports it
        // and the block is large enough (>= 64 elements) to amortize setup.
        if is_x86_feature_detected!("avx2") && count >= 64 {
            simd_block_copy_avx2(src, dst, count);
            return;
        }
    }

    // Fallback to optimized scalar copy with unrolling
    // (33-63 elements on x86_64 with AVX2, or any size > 32 otherwise).
    scalar_block_copy_unrolled(src, dst, count);
}
287
/// SIMD-optimized block copy using AVX2 instructions
///
/// Performs high-performance memory copying using AVX2 vector instructions.
/// Processes 32 elements per iteration using 4 AVX2 vectors, with additional
/// optimizations for remaining elements.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 instructions are available on the target CPU
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
/// * Pointers are aligned for `f32` access; no 32-byte alignment is required
///   because only the unaligned intrinsics (`_mm256_loadu_ps` /
///   `_mm256_storeu_ps`) are used
///
/// # Performance
///
/// * **Main loop**: Processes 32 elements per iteration (4 AVX2 vectors)
/// * **Remaining blocks**: Processes 8 elements per iteration for partial blocks
/// * **Final elements**: Uses standard copy for remaining elements
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    let simd_count = count / 32; // Process 32 elements per iteration (4x AVX2 vectors)
    let mut offset = 0;

    // Unrolled SIMD loop for maximum throughput
    for _ in 0..simd_count {
        // Process 4 AVX2 vectors (32 elements) per iteration.
        // All loads are issued before the stores to help the CPU overlap them.
        let vec1 = _mm256_loadu_ps(src.add(offset));
        let vec2 = _mm256_loadu_ps(src.add(offset + 8));
        let vec3 = _mm256_loadu_ps(src.add(offset + 16));
        let vec4 = _mm256_loadu_ps(src.add(offset + 24));

        _mm256_storeu_ps(dst.add(offset), vec1);
        _mm256_storeu_ps(dst.add(offset + 8), vec2);
        _mm256_storeu_ps(dst.add(offset + 16), vec3);
        _mm256_storeu_ps(dst.add(offset + 24), vec4);

        offset += 32;
    }

    // Handle remaining elements with 8-element SIMD blocks (one AVX2 vector each)
    let remaining_full_blocks = (count - offset) / 8;
    for _ in 0..remaining_full_blocks {
        let vec = _mm256_loadu_ps(src.add(offset));
        _mm256_storeu_ps(dst.add(offset), vec);
        offset += 8;
    }

    // Handle the final tail (< 8 elements) with a scalar memcpy
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}
350
/// Unrolled scalar block copy for optimal performance
///
/// Copies `count` f32 elements from `src` to `dst` using a manually unrolled
/// scalar loop (8 elements per iteration) to reduce loop overhead and improve
/// instruction-level parallelism, then finishes any tail with a standard copy.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
///
/// # Performance
///
/// * **Main loop**: 8 elements per iteration via manual unrolling
/// * **Tail**: standard library copy for the final `count % 8` elements
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const UNROLL: usize = 8;
    let mut idx = 0usize;

    // Copy full 8-element groups while at least one whole group remains.
    while idx + UNROLL <= count {
        *dst.add(idx) = *src.add(idx);
        *dst.add(idx + 1) = *src.add(idx + 1);
        *dst.add(idx + 2) = *src.add(idx + 2);
        *dst.add(idx + 3) = *src.add(idx + 3);
        *dst.add(idx + 4) = *src.add(idx + 4);
        *dst.add(idx + 5) = *src.add(idx + 5);
        *dst.add(idx + 6) = *src.add(idx + 6);
        *dst.add(idx + 7) = *src.add(idx + 7);
        idx += UNROLL;
    }

    // Copy whatever partial group is left (0..=7 elements).
    if idx < count {
        std::ptr::copy_nonoverlapping(src.add(idx), dst.add(idx), count - idx);
    }
}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// 1D concatenation of two tensors with different lengths.
    #[test]
    fn test_cat_1d() {
        let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let b = Tensor::from_slice(&[3.0], vec![1]).unwrap();
        let y = Tensor::cat(&[a, b], 0);
        assert_eq!(y.shape().dims, vec![3]);
        assert_eq!(y.get(&[0]), 1.0);
        assert_eq!(y.get(&[2]), 3.0);
    }

    /// More than two inputs concatenated in one call.
    #[test]
    fn test_cat_1d_three_tensors() {
        let a = Tensor::from_slice(&[1.0], vec![1]).unwrap();
        let b = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap();
        let c = Tensor::from_slice(&[4.0], vec![1]).unwrap();
        let y = Tensor::cat(&[a, b, c], 0);
        assert_eq!(y.shape().dims, vec![4]);
        assert_eq!(y.get(&[1]), 2.0);
        assert_eq!(y.get(&[3]), 4.0);
    }

    /// Concatenation along the leading (outer) dimension of 2D tensors.
    #[test]
    fn test_cat_2d_dim0() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0], vec![1, 2]).unwrap();
        let y = Tensor::cat(&[a, b], 0);
        assert_eq!(y.shape().dims, vec![3, 2]);
        assert_eq!(y.get(&[0, 0]), 1.0);
        assert_eq!(y.get(&[2, 0]), 5.0);
        assert_eq!(y.get(&[2, 1]), 6.0);
    }

    /// Concatenation along the trailing (inner) dimension of 2D tensors.
    #[test]
    fn test_cat_2d_dim1() {
        let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
        let y = Tensor::cat(&[a, b], 1);
        assert_eq!(y.shape().dims, vec![2, 3]);
        assert_eq!(y.get(&[0, 2]), 5.0);
        assert_eq!(y.get(&[1, 2]), 6.0);
    }

    /// Output requires grad when at least one input does.
    #[test]
    fn test_cat_propagates_requires_grad() {
        let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        a.set_requires_grad(true);
        let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
        let y = Tensor::cat(&[a, b], 0);
        assert!(y.requires_grad());
    }

    /// Mismatched non-concat dimensions must panic.
    #[test]
    #[should_panic]
    fn test_cat_mismatch() {
        let a = Tensor::new(vec![2, 2]);
        let b = Tensor::new(vec![3, 1]);
        let _ = Tensor::cat(&[a, b], 1);
    }

    /// An empty input slice must panic.
    #[test]
    #[should_panic]
    fn test_cat_empty_input() {
        let _ = Tensor::cat(&[], 0);
    }

    /// A concat dimension >= rank must panic.
    #[test]
    #[should_panic]
    fn test_cat_dim_out_of_bounds() {
        let a = Tensor::new(vec![2]);
        let _ = Tensor::cat(&[a], 1);
    }
}
431}