// train_station/tensor/transform/cat.rs

//! Tensor concatenation operations
//!
//! This module provides tensor concatenation functionality that joins multiple
//! tensors along a specified dimension. Concatenation is a fundamental tensor
//! transformation operation used in machine learning for combining data from
//! multiple sources, building batch operations, and creating complex tensor
//! structures.
//!
//! # Operations
//!
//! * `cat()` - Concatenate multiple tensors along a specified dimension
//!
//! # Performance Characteristics
//!
//! * **SIMD Optimized**: Uses AVX2 instructions for large block copies when available
//! * **Memory Efficient**: Minimizes temporary allocations by reusing contiguous data
//! * **Stride Aware**: Handles non-contiguous tensors efficiently with materialization
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Concatenate 1D tensors
//! let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
//! let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
//! let result = Tensor::cat(&[a, b], 0);
//! assert_eq!(result.shape().dims(), vec![4]);
//!
//! // Concatenate 2D tensors along different dimensions
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
//! let y = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
//! let result = Tensor::cat(&[x, y], 1);
//! assert_eq!(result.shape().dims(), vec![2, 3]);
//! ```

use crate::gradtrack::{GradEngine, GradFn};
use crate::tensor::core::Tensor;
use crate::tensor::iterator::collect::optimized_copy;

impl Tensor {
    /// Concatenate tensors along a given dimension
    ///
    /// Joins multiple tensors along the specified dimension, creating a new tensor
    /// with the combined data. All input tensors must have the same rank and
    /// matching dimensions except for the concatenation dimension.
    ///
    /// # Arguments
    ///
    /// * `tensors` - Slice of tensors to concatenate (must not be empty)
    /// * `dim` - Dimension along which to concatenate (must be < tensor rank)
    ///
    /// # Returns
    ///
    /// A new tensor containing the concatenated data with shape where the
    /// concatenation dimension is the sum of all input tensor sizes along that dimension.
    ///
    /// # Panics
    ///
    /// * If `tensors` is empty
    /// * If `dim` is out of bounds for the tensor rank
    /// * If tensors have different ranks
    /// * If tensors have mismatched dimensions (except along concatenation dimension)
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 1D tensors
    /// let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert_eq!(result.shape().dims(), vec![4]);
    /// assert_eq!(result.get(&[0]), 1.0);
    /// assert_eq!(result.get(&[1]), 2.0);
    /// assert_eq!(result.get(&[2]), 3.0);
    /// assert_eq!(result.get(&[3]), 4.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate 2D tensors along dimension 1
    /// let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    /// let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
    /// let result = Tensor::cat(&[a, b], 1);
    /// assert_eq!(result.shape().dims(), vec![2, 3]);
    /// assert_eq!(result.get(&[0, 0]), 1.0);
    /// assert_eq!(result.get(&[0, 1]), 2.0);
    /// assert_eq!(result.get(&[0, 2]), 5.0);
    /// ```
    ///
    /// ```
    /// use train_station::Tensor;
    ///
    /// // Concatenate with gradient tracking
    /// let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
    /// let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
    /// a.set_requires_grad(true);
    /// b.set_requires_grad(true);
    ///
    /// let result = Tensor::cat(&[a, b], 0);
    /// assert!(result.requires_grad());
    /// ```
    #[track_caller]
    pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor {
        assert!(!tensors.is_empty(), "cat requires at least one tensor");

        // The first tensor fixes the expected rank for every input.
        let rank = tensors[0].shape().rank();
        assert!(
            dim < rank,
            "concat dim {} out of bounds for rank {}",
            dim,
            rank
        );

        // Validate shapes and compute output dims
        let base_shape = tensors[0].shape().dims();
        for t in tensors.iter() {
            assert_eq!(t.shape().rank(), rank, "All tensors must have same rank");
            for (i, (&a, &b)) in base_shape.iter().zip(t.shape().dims().iter()).enumerate() {
                if i != dim {
                    assert_eq!(
                        a, b,
                        "All dims except concat dim must match (dim {}: {} vs {})",
                        i, a, b
                    );
                }
            }
        }

        // Output shape is the base shape with the concat dimension replaced by
        // the sum of every input's extent along `dim`.
        let mut out_dims = base_shape.to_vec();
        let mut concat_len = 0usize;
        for t in tensors.iter() {
            concat_len += t.shape().dims()[dim];
        }
        out_dims[dim] = concat_len;

        let mut output = Tensor::new(out_dims.to_vec());

        // Calculate block sizes for contiguous copy. In row-major order each
        // source can be viewed as [outer, len_along_dim, inner], so one
        // contiguous run of `len_along_dim * inner` elements is copied per
        // outer index.
        let inner: usize = out_dims[dim + 1..].iter().product();
        let outer: usize = out_dims[..dim].iter().product();

        // Prepare source buffers once to avoid per-iteration cloning/copying
        // Each entry holds a pointer to contiguous data and the length along `dim`
        struct SourceInfo {
            base_ptr: *const f32,
            len_along_dim: usize,
        }

        // `temp_contiguous` owns any materialized copies for the remainder of
        // this function, so the raw pointers stored in `sources` stay valid.
        let mut temp_contiguous: Vec<Tensor> = Vec::new();
        let mut sources: Vec<SourceInfo> = Vec::with_capacity(tensors.len());
        for t in tensors.iter() {
            let len_d = t.shape().dims()[dim];
            if len_d == 0 {
                // Skip empty tensors; keep alignment in running count during copy
                // (null sentinel is never dereferenced because len is checked first).
                sources.push(SourceInfo {
                    base_ptr: std::ptr::null(),
                    len_along_dim: 0,
                });
                continue;
            }
            if t.is_contiguous() {
                // SAFETY: `t` is borrowed for the whole function, so its buffer
                // outlives every use of this pointer in the copy loop below.
                let base_ptr = unsafe { t.as_ptr() };
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            } else {
                // Materialize once and keep it alive in `temp_contiguous`
                let cont = t.contiguous();
                // SAFETY: `cont` is moved into `temp_contiguous` immediately
                // below, so its buffer outlives every use of this pointer.
                let base_ptr = unsafe { cont.as_ptr() };
                temp_contiguous.push(cont);
                sources.push(SourceInfo {
                    base_ptr,
                    len_along_dim: len_d,
                });
            }
        }

        // SAFETY: each non-null source points to a contiguous buffer holding
        // `outer * len_along_dim * inner` elements (kept alive by `tensors` or
        // `temp_contiguous`); destination offsets stay within the output's
        // `outer * concat_len * inner` elements because `running` never exceeds
        // `concat_len - len_d`. Sources and the freshly allocated output are
        // distinct allocations, so the copies never overlap.
        unsafe {
            let dst_ptr = output.as_mut_ptr();
            for outer_idx in 0..outer {
                // `running` tracks how far along the concat dimension has been
                // filled for the current outer index.
                let mut running = 0usize;
                for src in &sources {
                    let len_d = src.len_along_dim;
                    if len_d == 0 {
                        continue;
                    }
                    let copy_elems = len_d * inner;

                    // Source base offset for this outer index
                    let src_base = outer_idx * (len_d * inner);
                    let src_ptr = src.base_ptr.add(src_base);

                    // Destination base offset
                    let dst_base = outer_idx * (concat_len * inner) + running * inner;
                    let dst_cur = dst_ptr.add(dst_base);

                    optimized_copy(src_ptr, dst_cur, copy_elems);
                    running += len_d;
                }
            }
        }

        // GradTrack setup if any input requires_grad
        let any_requires = tensors.iter().any(|t| t.requires_grad());
        if any_requires {
            output.set_requires_grad(true);
            let mut input_ids = Vec::with_capacity(tensors.len());
            let mut grad_input_sizes = Vec::new();
            let mut grad_input_shapes = Vec::new();
            // NOTE(review): only gradient-requiring inputs are recorded below.
            // If a non-grad tensor precedes a grad tensor along `dim`, the
            // recorded sizes no longer describe the grad tensor's offset within
            // the output, so `GradFn::Cat` backward could slice the wrong
            // region of the output gradient — confirm against the backward
            // implementation whether it needs the sizes of ALL inputs.
            for t in tensors.iter() {
                if t.requires_grad() {
                    input_ids.push(t.id());
                    grad_input_sizes.push(t.shape().dims()[dim]);
                    grad_input_shapes.push(t.shape().dims().to_vec());
                }
            }
            let grad_fn = GradFn::Cat {
                dim,
                input_sizes: grad_input_sizes,
                input_shapes: grad_input_shapes,
            };
            output.set_grad_fn(grad_fn.clone());
            GradEngine::register_operation(output.id(), input_ids, grad_fn);
        }

        output
    }
}
235
// Reuse iterator::collect::optimized_copy for all contiguous block copies
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Joining two 1-D tensors along dim 0 places their elements back to back.
    #[test]
    fn test_cat_1d() {
        let lhs = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
        let rhs = Tensor::from_slice(&[3.0], vec![1]).unwrap();
        let joined = Tensor::cat(&[lhs, rhs], 0);
        assert_eq!(joined.shape().dims(), vec![3]);
        assert_eq!(joined.get(&[0]), 1.0);
        assert_eq!(joined.get(&[2]), 3.0);
    }

    /// Column-wise concatenation appends the 2x1 tensor as a third column.
    #[test]
    fn test_cat_2d_dim1() {
        let left = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let right = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
        let joined = Tensor::cat(&[left, right], 1);
        assert_eq!(joined.shape().dims(), vec![2, 3]);
        assert_eq!(joined.get(&[0, 2]), 5.0);
        assert_eq!(joined.get(&[1, 2]), 6.0);
    }

    /// Mismatched sizes on a non-concat dimension must panic.
    #[test]
    #[should_panic]
    fn test_cat_mismatch() {
        let square = Tensor::new(vec![2, 2]);
        let tall = Tensor::new(vec![3, 1]);
        let _ = Tensor::cat(&[square, tall], 1);
    }
}
269}