// train_station/tensor/transform/cat.rs
1//! Tensor concatenation operations
2//!
3//! This module provides tensor concatenation functionality that joins multiple
4//! tensors along a specified dimension. Concatenation is a fundamental tensor
5//! transformation operation used in machine learning for combining data from
6//! multiple sources, building batch operations, and creating complex tensor
7//! structures.
8//!
9//! # Operations
10//!
11//! * `cat()` - Concatenate multiple tensors along a specified dimension
12//!
13//! # Performance Characteristics
14//!
15//! * **SIMD Optimized**: Uses AVX2 instructions for large block copies when available
16//! * **Memory Efficient**: Minimizes temporary allocations by reusing contiguous data
17//! * **Stride Aware**: Handles non-contiguous tensors efficiently with materialization
18//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
19//!
20//! # Examples
21//!
22//! ```
23//! use train_station::Tensor;
24//!
25//! // Concatenate 1D tensors
26//! let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
27//! let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
28//! let result = Tensor::cat(&[a, b], 0);
29//! assert_eq!(result.shape().dims, vec![4]);
30//!
31//! // Concatenate 2D tensors along different dimensions
32//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
33//! let y = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
34//! let result = Tensor::cat(&[x, y], 1);
35//! assert_eq!(result.shape().dims, vec![2, 3]);
36//! ```
37
38use crate::gradtrack::{GradEngine, GradFn};
39use crate::tensor::core::Tensor;
40
41// SIMD optimizations for performance-critical operations
42#[cfg(target_arch = "x86_64")]
43use std::arch::x86_64::*;
44
impl Tensor {
    /// Concatenate tensors along a given dimension
    ///
    /// Joins multiple tensors along the specified dimension, creating a new tensor
    /// with the combined data. All input tensors must have the same rank and
    /// matching dimensions except for the concatenation dimension.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensors to concatenate (must not be empty)
    /// * `dim` - Dimension along which to concatenate (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor containing the concatenated data with shape where the
    /// concatenation dimension is the sum of all input tensor sizes along that dimension.
    ///
    /// # Panics
    ///
    /// * If `tensors` is empty
    /// * If `dim` is out of bounds for the tensor rank
    /// * If tensors have different ranks
    /// * If tensors have mismatched dimensions (except along concatenation dimension)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 1D tensors
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert_eq!(result.shape().dims, vec![4]);
    /// assert_eq!(result.get(&[0]), 1.0);
    /// assert_eq!(result.get(&[1]), 2.0);
    /// assert_eq!(result.get(&[2]), 3.0);
    /// assert_eq!(result.get(&[3]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 2D tensors along dimension 1
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
    /// let result = Tensor::cat(&[a, b], 1);
    /// assert_eq!(result.shape().dims, vec![2, 3]);
    /// assert_eq!(result.get(&[0, 0]), 1.0);
    /// assert_eq!(result.get(&[0, 1]), 2.0);
    /// assert_eq!(result.get(&[0, 2]), 5.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate with gradient tracking
    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// a.set_requires_grad(true);
    /// b.set_requires_grad(true);
    ///
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert!(result.requires_grad());
    /// ```
    pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor {
        assert!(!tensors.is_empty(), "cat requires at least one tensor");

        let rank = tensors[0].shape().rank();
        assert!(
            dim < rank,
            "concat dim {} out of bounds for rank {}",
            dim,
            rank
        );

        // Validate shapes and compute output dims: every dimension except `dim`
        // must match the first tensor's shape exactly.
        let base_shape = tensors[0].shape().dims.clone();
        for t in tensors.iter() {
            assert_eq!(t.shape().rank(), rank, "All tensors must have same rank");
            for (i, (&a, &b)) in base_shape.iter().zip(t.shape().dims.iter()).enumerate() {
                if i != dim {
                    assert_eq!(
                        a, b,
                        "All dims except concat dim must match (dim {}: {} vs {})",
                        i, a, b
                    );
                }
            }
        }

        // Output shape is the base shape with `dim` widened to the sum of all
        // input extents along `dim`.
        let mut out_dims = base_shape.clone();
        let mut concat_len = 0usize;
        for t in tensors.iter() {
            concat_len += t.shape().dims[dim];
        }
        out_dims[dim] = concat_len;

        let mut output = Tensor::new(out_dims.clone());

        // View every tensor as a logical 3-D layout [outer, dim, inner] so each
        // (outer_idx, source) pair contributes one contiguous run of
        // `len_along_dim * inner` elements.
        let inner: usize = out_dims[dim + 1..].iter().product();
        let outer: usize = out_dims[..dim].iter().product();

        // Prepare source buffers once to avoid per-iteration cloning/copying.
        // Each entry holds a pointer to contiguous data and the length along `dim`.
        struct SourceInfo {
            base_ptr: *const f32,
            len_along_dim: usize,
        }

        // `temp_contiguous` owns any materialized copies of non-contiguous
        // inputs; it must stay alive until the copy loop below finishes,
        // because `sources` holds raw pointers into those buffers.
        let mut temp_contiguous: Vec<Tensor> = Vec::new();
        let mut sources: Vec<SourceInfo> = Vec::with_capacity(tensors.len());
        for t in tensors.iter() {
            let len_d = t.shape().dims[dim];
            if len_d == 0 {
                // Skip empty tensors; keep alignment in running count during copy.
                // A null base_ptr is safe here because zero-length sources are
                // never dereferenced in the copy loop.
                sources.push(SourceInfo {
                    base_ptr: std::ptr::null(),
                    len_along_dim: 0,
                });
                continue;
            }
            if t.is_contiguous() {
                let base_ptr = unsafe { t.as_ptr() };
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            } else {
                // Materialize once and keep it alive in `temp_contiguous`
                let cont = t.contiguous();
                let base_ptr = unsafe { cont.as_ptr() };
                temp_contiguous.push(cont);
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            }
        }

        // SAFETY: every source pointer refers to contiguous data of at least
        // `outer * len_along_dim * inner` elements (guaranteed by the shape
        // validation above), the destination buffer holds exactly
        // `outer * concat_len * inner` elements, and source/destination never
        // overlap because `output` is freshly allocated.
        unsafe {
            let dst_ptr = output.as_mut_ptr();
            for outer_idx in 0..outer {
                // `running` tracks how far along the output's concat dimension
                // we have written for this outer index.
                let mut running = 0usize;
                for src in &sources {
                    let len_d = src.len_along_dim;
                    if len_d == 0 {
                        continue;
                    }
                    let copy_elems = len_d * inner;

                    // Source base offset for this outer index
                    let src_base = outer_idx * (len_d * inner);
                    let src_ptr = src.base_ptr.add(src_base);

                    // Destination base offset
                    let dst_base = outer_idx * (concat_len * inner) + running * inner;
                    let dst_cur = dst_ptr.add(dst_base);

                    optimized_block_copy(src_ptr, dst_cur, copy_elems);
                    running += len_d;
                }
            }
        }

        // GradTrack setup if any input requires_grad
        let any_requires = tensors.iter().any(|t| t.requires_grad());
        if any_requires {
            output.set_requires_grad(true);
            let mut input_ids = Vec::with_capacity(tensors.len());
            let mut grad_input_sizes = Vec::new();
            let mut grad_input_shapes = Vec::new();
            // NOTE(review): only tensors with requires_grad are recorded here.
            // If GradFn::Cat's backward slices the output gradient using these
            // sizes as consecutive offsets, a non-grad tensor placed between
            // grad tensors would misalign the slices — verify against the
            // backward implementation.
            for t in tensors.iter() {
                if t.requires_grad() {
                    input_ids.push(t.id());
                    grad_input_sizes.push(t.shape().dims[dim]);
                    grad_input_shapes.push(t.shape().dims.clone());
                }
            }
            let grad_fn = GradFn::Cat {
                dim,
                input_sizes: grad_input_sizes,
                input_shapes: grad_input_shapes,
            };
            output.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(output.id(), input_ids, grad_fn);
        }

        output
    }
}
237
238/// Optimized block copy with SIMD acceleration for large blocks
239///
240/// Performs efficient memory copying with automatic SIMD optimization when
241/// available. Uses AVX2 instructions for large blocks and falls back to
242/// unrolled scalar operations for smaller blocks or when SIMD is not available.
243///
244/// # Arguments
245///
246/// * `src` - Source pointer to copy from
247/// * `dst` - Destination pointer to copy to
248/// * `count` - Number of f32 elements to copy
249///
250/// # Safety
251///
252/// The caller must ensure:
253/// * `src` points to valid memory with at least `count` f32 elements
254/// * `dst` points to valid writable memory with at least `count` f32 elements
255/// * The source and destination regions do not overlap
256/// * The pointers are properly aligned for the target architecture
257///
258/// # Performance
259///
260/// * **Large blocks (≥64 elements)**: Uses AVX2 SIMD instructions when available
261/// * **Medium blocks (32-63 elements)**: Uses unrolled scalar operations
262/// * **Small blocks (<32 elements)**: Uses standard library copy
263#[inline]
264unsafe fn optimized_block_copy(src: *const f32, dst: *mut f32, count: usize) {
265 if count == 0 {
266 return;
267 }
268
269 // For small blocks, use standard copy
270 if count <= 32 {
271 std::ptr::copy_nonoverlapping(src, dst, count);
272 return;
273 }
274
275 #[cfg(target_arch = "x86_64")]
276 {
277 if is_x86_feature_detected!("avx2") && count >= 64 {
278 simd_block_copy_avx2(src, dst, count);
279 return;
280 }
281 }
282
283 // Fallback to optimized scalar copy with unrolling
284 scalar_block_copy_unrolled(src, dst, count);
285}
286
/// SIMD-optimized block copy using AVX2 instructions
///
/// Copies `count` f32 values using 256-bit AVX2 loads/stores. The main loop
/// moves 32 elements (four vectors) per pass, a drain loop handles remaining
/// 8-element groups, and any final tail uses a plain copy.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * AVX2 instructions are available on the target CPU
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
///
/// Unaligned load/store intrinsics are used throughout, so no particular
/// pointer alignment is required beyond f32 alignment.
#[cfg(target_arch = "x86_64")]
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn simd_block_copy_avx2(src: *const f32, dst: *mut f32, count: usize) {
    let mut offset = 0usize;

    // Main loop: four 256-bit lanes (32 f32) per iteration for throughput.
    while offset + 32 <= count {
        let lane0 = _mm256_loadu_ps(src.add(offset));
        let lane1 = _mm256_loadu_ps(src.add(offset + 8));
        let lane2 = _mm256_loadu_ps(src.add(offset + 16));
        let lane3 = _mm256_loadu_ps(src.add(offset + 24));

        _mm256_storeu_ps(dst.add(offset), lane0);
        _mm256_storeu_ps(dst.add(offset + 8), lane1);
        _mm256_storeu_ps(dst.add(offset + 16), lane2);
        _mm256_storeu_ps(dst.add(offset + 24), lane3);

        offset += 32;
    }

    // Drain loop: single 8-element vectors for what remains.
    while offset + 8 <= count {
        let lane = _mm256_loadu_ps(src.add(offset));
        _mm256_storeu_ps(dst.add(offset), lane);
        offset += 8;
    }

    // Tail: fewer than 8 elements left — fall back to a scalar memcpy.
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}
349
/// Unrolled scalar block copy for optimal performance
///
/// Copies `count` f32 values in fixed 8-element groups to reduce loop
/// overhead and improve instruction-level parallelism, finishing any
/// leftover tail with a standard library copy.
///
/// # Arguments
///
/// * `src` - Source pointer to copy from
/// * `dst` - Destination pointer to copy to
/// * `count` - Number of f32 elements to copy
///
/// # Safety
///
/// The caller must ensure:
/// * `src` points to valid memory with at least `count` f32 elements
/// * `dst` points to valid writable memory with at least `count` f32 elements
/// * The source and destination regions do not overlap
#[inline]
unsafe fn scalar_block_copy_unrolled(src: *const f32, dst: *mut f32, count: usize) {
    const GROUP: usize = 8;
    let mut offset = 0usize;

    // Copy whole 8-element groups; the fixed-trip inner loop unrolls cleanly.
    while offset + GROUP <= count {
        for lane in 0..GROUP {
            *dst.add(offset + lane) = *src.add(offset + lane);
        }
        offset += GROUP;
    }

    // Tail: fewer than 8 elements remain.
    if offset < count {
        std::ptr::copy_nonoverlapping(src.add(offset), dst.add(offset), count - offset);
    }
}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 1-D tensors concatenate end to end along dim 0.
    #[test]
    fn test_cat_1d() {
        let first = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let second = Tensor::from_slice(&[3.0], vec![1]).unwrap();
        let joined = Tensor::cat(&[first, second], 0);
        assert_eq!(joined.shape().dims, vec![3]);
        assert_eq!(joined.get(&[0]), 1.0);
        assert_eq!(joined.get(&[2]), 3.0);
    }

    /// Concatenating along dim 1 appends columns to every row.
    #[test]
    fn test_cat_2d_dim1() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let right = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
        let joined = Tensor::cat(&[left, right], 1);
        assert_eq!(joined.shape().dims, vec![2, 3]);
        assert_eq!(joined.get(&[0, 2]), 5.0);
        assert_eq!(joined.get(&[1, 2]), 6.0);
    }

    /// A mismatch in a non-concat dimension must panic.
    #[test]
    #[should_panic]
    fn test_cat_mismatch() {
        let a = Tensor::new(vec![2, 2]);
        let b = Tensor::new(vec![3, 1]);
        let _ = Tensor::cat(&[a, b], 1);
    }
}
430}