//! Tensor concatenation operations
//!
//! This module provides tensor concatenation functionality that joins multiple
//! tensors along a specified dimension. Concatenation is a fundamental tensor
//! transformation operation used in machine learning for combining data from
//! multiple sources, building batch operations, and creating complex tensor
//! structures.
//!
//! # Operations
//!
//! * `cat()` - Concatenate multiple tensors along a specified dimension
//!
//! # Performance Characteristics
//!
//! * **SIMD Optimized**: Uses AVX2 instructions for large block copies when available
//! * **Memory Efficient**: Minimizes temporary allocations by reusing contiguous data
//! * **Stride Aware**: Handles non-contiguous tensors efficiently with materialization
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Concatenate 1D tensors
//! let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
//! let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let result = Tensor::cat(&[a, b], 0);
//! assert_eq!(result.shape().dims(), vec![4]);
//!
//! // Concatenate 2D tensors along different dimensions
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let y = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
//! let result = Tensor::cat(&[x, y], 1);
//! assert_eq!(result.shape().dims(), vec![2, 3]);
//! ```

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;
use crate::tensor::iterator::collect::optimized_copy;

42impl Tensor {
43 /// Concatenate tensors along a given dimension
44 ///
45 /// Joins multiple tensors along the specified dimension, creating a new tensor
46 /// with the combined data. All input tensors must have the same rank and
47 /// matching dimensions except for the concatenation dimension.
48 ///
49 /// # Arguments
50 ///
51 /// * `tensors` - Slice of tensors to concatenate (must not be empty)
52 /// * `dim` - Dimension along which to concatenate (must be < tensor rank)
53 ///
54 /// # Returns
55 ///
56 /// A new tensor containing the concatenated data with shape where the
57 /// concatenation dimension is the sum of all input tensor sizes along that dimension.
58 ///
59 /// # Panics
60 ///
61 /// * If `tensors` is empty
62 /// * If `dim` is out of bounds for the tensor rank
63 /// * If tensors have different ranks
64 /// * If tensors have mismatched dimensions (except along concatenation dimension)
65 ///
66 /// # Examples
67 ///
68 /// ```
69 /// use train_station::Tensor;
70 ///
71 /// // Concatenate 1D tensors
72 /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
73 /// let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
74 /// let result = Tensor::cat(&[a, b], 0);
75 /// assert_eq!(result.shape().dims(), vec![4]);
76 /// assert_eq!(result.get(&[0]), 1.0);
77 /// assert_eq!(result.get(&[1]), 2.0);
78 /// assert_eq!(result.get(&[2]), 3.0);
79 /// assert_eq!(result.get(&[3]), 4.0);
80 /// ```
81 ///
82 /// ```
83 /// use train_station::Tensor;
84 ///
85 /// // Concatenate 2D tensors along dimension 1
86 /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
87 /// let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
88 /// let result = Tensor::cat(&[a, b], 1);
89 /// assert_eq!(result.shape().dims(), vec![2, 3]);
90 /// assert_eq!(result.get(&[0, 0]), 1.0);
91 /// assert_eq!(result.get(&[0, 1]), 2.0);
92 /// assert_eq!(result.get(&[0, 2]), 5.0);
93 /// ```
94 ///
95 /// ```
96 /// use train_station::Tensor;
97 ///
98 /// // Concatenate with gradient tracking
99 /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
100 /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
101 /// a.set_requires_grad(true);
102 /// b.set_requires_grad(true);
103 ///
104 /// let result = Tensor::cat(&[a, b], 0);
105 /// assert!(result.requires_grad());
106 /// ```
107 #[track_caller]
108 pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor {
109 assert!(!tensors.is_empty(), "cat requires at least one tensor");
110
111 let rank = tensors[0].shape().rank();
112 assert!(
113 dim < rank,
114 "concat dim {} out of bounds for rank {}",
115 dim,
116 rank
117 );
118
119 // Validate shapes and compute output dims
120 let base_shape = tensors[0].shape().dims();
121 for t in tensors.iter() {
122 assert_eq!(t.shape().rank(), rank, "All tensors must have same rank");
123 for (i, (&a, &b)) in base_shape.iter().zip(t.shape().dims().iter()).enumerate() {
124 if i != dim {
125 assert_eq!(
126 a, b,
127 "All dims except concat dim must match (dim {}: {} vs {})",
128 i, a, b
129 );
130 }
131 }
132 }
133
134 let mut out_dims = base_shape.to_vec();
135 let mut concat_len = 0usize;
136 for t in tensors.iter() {
137 concat_len += t.shape().dims()[dim];
138 }
139 out_dims[dim] = concat_len;
140
141 let mut output = Tensor::new(out_dims.to_vec());
142
143 // Calculate block sizes for contiguous copy
144 let inner: usize = out_dims[dim + 1..].iter().product();
145 let outer: usize = out_dims[..dim].iter().product();
146
147 // Prepare source buffers once to avoid per-iteration cloning/copying
148 // Each entry holds a pointer to contiguous data and the length along `dim`
149 struct SourceInfo {
150 base_ptr: *const f32,
151 len_along_dim: usize,
152 }
153
154 let mut temp_contiguous: Vec<Tensor> = Vec::new();
155 let mut sources: Vec<SourceInfo> = Vec::with_capacity(tensors.len());
156 for t in tensors.iter() {
157 let len_d = t.shape().dims()[dim];
158 if len_d == 0 {
159 // Skip empty tensors; keep alignment in running count during copy
160 sources.push(SourceInfo {
161 base_ptr: std::ptr::null(),
162 len_along_dim: 0,
163 });
164 continue;
165 }
166 if t.is_contiguous() {
167 let base_ptr = unsafe { t.as_ptr() };
168 sources.push(SourceInfo {
169 base_ptr,
170 len_along_dim: len_d,
171 });
172 } else {
173 // Materialize once and keep it alive in `temp_contiguous`
174 let cont = t.contiguous();
175 let base_ptr = unsafe { cont.as_ptr() };
176 temp_contiguous.push(cont);
177 sources.push(SourceInfo {
178 base_ptr,
179 len_along_dim: len_d,
180 });
181 }
182 }
183
184 unsafe {
185 let dst_ptr = output.as_mut_ptr();
186 for outer_idx in 0..outer {
187 let mut running = 0usize;
188 for src in &sources {
189 let len_d = src.len_along_dim;
190 if len_d == 0 {
191 continue;
192 }
193 let copy_elems = len_d * inner;
194
195 // Source base offset for this outer index
196 let src_base = outer_idx * (len_d * inner);
197 let src_ptr = src.base_ptr.add(src_base);
198
199 // Destination base offset
200 let dst_base = outer_idx * (concat_len * inner) + running * inner;
201 let dst_cur = dst_ptr.add(dst_base);
202
203 optimized_copy(src_ptr, dst_cur, copy_elems);
204 running += len_d;
205 }
206 }
207 }
208
209 // GradTrack setup if any input requires_grad
210 let any_requires = tensors.iter().any(|t| t.requires_grad());
211 if any_requires {
212 output.set_requires_grad(true);
213 let mut input_ids = Vec::with_capacity(tensors.len());
214 let mut grad_input_sizes = Vec::new();
215 let mut grad_input_shapes = Vec::new();
216 for t in tensors.iter() {
217 if t.requires_grad() {
218 input_ids.push(t.id());
219 grad_input_sizes.push(t.shape().dims()[dim]);
220 grad_input_shapes.push(t.shape().dims().to_vec());
221 }
222 }
223 let grad_fn = GradFn::Cat {
224 dim,
225 input_sizes: grad_input_sizes,
226 input_shapes: grad_input_shapes,
227 };
228 output.set_grad_fn(grad_fn.clone());
229 GradEngine::register_operation(output.id(), input_ids, grad_fn);
230 }
231
232 output
233 }
234}

// Reuse iterator::collect::optimized_copy for all contiguous block copies

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cat_1d() {
        // Joining [1, 2] and [3] along dim 0 yields a length-3 tensor [1, 2, 3].
        let first = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let second = Tensor::from_slice(&[3.0], vec![1]).unwrap();
        let out = Tensor::cat(&[first, second], 0);
        assert_eq!(out.shape().dims(), vec![3]);
        assert_eq!(out.get(&[0]), 1.0);
        assert_eq!(out.get(&[2]), 3.0);
    }

    #[test]
    fn test_cat_2d_dim1() {
        // A 2x2 block concatenated with a 2x1 column produces a 2x3 matrix
        // whose last column comes from the second tensor.
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let right = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
        let out = Tensor::cat(&[left, right], 1);
        assert_eq!(out.shape().dims(), vec![2, 3]);
        assert_eq!(out.get(&[0, 2]), 5.0);
        assert_eq!(out.get(&[1, 2]), 6.0);
    }

    #[test]
    #[should_panic]
    fn test_cat_mismatch() {
        // Row counts differ (2 vs 3) while concatenating along columns: must panic.
        let lhs = Tensor::new(vec![2, 2]);
        let rhs = Tensor::new(vec![3, 1]);
        let _ = Tensor::cat(&[lhs, rhs], 1);
    }
}
269}