// train_station/tensor/transform/split.rs

//! Tensor splitting operations
//!
//! This module provides tensor splitting functionality that divides a tensor
//! into multiple smaller tensors along a specified dimension. Splitting is a
//! fundamental tensor transformation used in machine learning to divide data
//! into batches, to produce multiple outputs from a single tensor, and to
//! implement more complex tensor manipulations.
//!
//! # Operations
//!
//! * `split()` - Split a tensor into chunks of equal size along a dimension
//!   (the last chunk is smaller when the dimension size does not divide evenly)
//! * `split_with_sizes()` - Split a tensor into chunks with explicit sizes along a dimension
//!
//! # Performance Characteristics
//!
//! * **View Operations**: The chunk that starts at offset zero is returned as a view (zero-copy)
//! * **Copy Operations**: Chunks that start at a non-zero offset are copied into new storage
//! * **Memory Efficient**: The zero-offset chunk reuses the input's storage instead of allocating
//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
//! * **Shape Transformation**: Divides the tensor along the specified dimension while preserving all other dimensions
//!
//! # Examples
//!
//! ```
//! use train_station::Tensor;
//!
//! // Split tensor into equal-sized chunks
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
//! let parts = tensor.split(1, 1);
//! assert_eq!(parts.len(), 3);
//! assert_eq!(parts[0].shape().dims(), vec![2, 1]);
//! assert_eq!(parts[1].shape().dims(), vec![2, 1]);
//! assert_eq!(parts[2].shape().dims(), vec![2, 1]);
//! ```
//!
//! ```
//! use train_station::Tensor;
//!
//! // Split tensor with explicit sizes
//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
//! let parts = tensor.split_with_sizes(&[2, 3], 1);
//! assert_eq!(parts.len(), 2);
//! assert_eq!(parts[0].shape().dims(), vec![1, 2]);
//! assert_eq!(parts[1].shape().dims(), vec![1, 3]);
//! ```
//!
//! # Gradient Tracking
//!
//! The split operations support automatic gradient tracking through the
//! GradTrack system. When `requires_grad` is enabled on the input, each
//! non-empty split piece registers a gradient function that scatters its
//! gradient back into the matching range of the original tensor during the
//! backward pass.
53
54use crate::gradtrack::{GradEngine, GradFn};
55use crate::tensor::core::Tensor;
56
57impl Tensor {
58    /// Split tensor into chunks of equal size along specified dimension
59    ///
60    /// Divides the tensor into multiple smaller tensors along the specified
61    /// dimension, where each chunk (except possibly the last) has the same size.
62    /// The last chunk may be smaller if the dimension size is not evenly
63    /// divisible by the split size.
64    ///
65    /// This operation returns a vector of tensors, where each tensor is a
66    /// view or copy of a portion of the original tensor. The first chunk
67    /// is returned as a view when possible (zero-copy), while subsequent
68    /// chunks may require data copying for non-zero base offsets.
69    ///
70    /// # Arguments
71    ///
72    /// * `split_size` - Size of each chunk along the specified dimension (must be > 0)
73    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
74    ///
75    /// # Returns
76    ///
77    /// A vector of tensors, each representing a chunk of the original tensor.
78    /// The number of chunks depends on the dimension size and split size.
79    ///
80    /// # Panics
81    ///
82    /// * If tensor rank is 0 (scalar tensors cannot be split)
83    /// * If `dim` is out of bounds for the tensor rank
84    /// * If `split_size` is 0
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use train_station::Tensor;
90    ///
91    /// // Split 2D tensor into equal chunks along dimension 1
92    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
93    /// let parts = tensor.split(1, 1);
94    /// assert_eq!(parts.len(), 3);
95    /// assert_eq!(parts[0].shape().dims(), vec![2, 1]);
96    /// assert_eq!(parts[1].shape().dims(), vec![2, 1]);
97    /// assert_eq!(parts[2].shape().dims(), vec![2, 1]);
98    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
99    /// assert_eq!(parts[1].get(&[0, 0]), 2.0);
100    /// assert_eq!(parts[2].get(&[1, 0]), 6.0);
101    /// ```
102    ///
103    /// ```
104    /// use train_station::Tensor;
105    ///
106    /// // Split with uneven division (last chunk smaller)
107    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
108    /// let parts = tensor.split(2, 1);
109    /// assert_eq!(parts.len(), 3);
110    /// assert_eq!(parts[0].shape().dims(), vec![1, 2]);
111    /// assert_eq!(parts[1].shape().dims(), vec![1, 2]);
112    /// assert_eq!(parts[2].shape().dims(), vec![1, 1]); // Last chunk smaller
113    /// ```
114    ///
115    /// ```
116    /// use train_station::Tensor;
117    ///
118    /// // Split with gradient tracking
119    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
120    /// tensor.set_requires_grad(true);
121    ///
122    /// let parts = tensor.split(1, 1);
123    /// assert_eq!(parts.len(), 2);
124    /// assert!(parts[0].requires_grad());
125    /// assert!(parts[1].requires_grad());
126    /// ```
127    ///
128    /// ```
129    /// use train_station::Tensor;
130    ///
131    /// // Split 1D tensor
132    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
133    /// let parts = tensor.split(2, 0);
134    /// assert_eq!(parts.len(), 3);
135    /// assert_eq!(parts[0].shape().dims(), vec![2]);
136    /// assert_eq!(parts[1].shape().dims(), vec![2]);
137    /// assert_eq!(parts[2].shape().dims(), vec![2]);
138    /// ```
139    ///
140    /// # Performance
141    ///
142    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
143    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
144    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
145    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
146    ///
147    /// # Relationship to Other Operations
148    ///
149    /// This operation is related to other tensor transformations:
150    /// - `split_with_sizes()` - More general version with explicit chunk sizes
151    /// - `cat()` - Inverse operation that concatenates tensors back together
152    /// - `chunk()` - Alternative splitting operation with different semantics
153    ///
154    /// # Memory Layout
155    ///
156    /// The first chunk maintains the same underlying data as a view when
157    /// the base offset is zero. Subsequent chunks may require data copying
158    /// to handle non-zero base offsets, ensuring proper memory layout.
159    #[track_caller]
160    pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor> {
161        assert!(self.shape().rank() > 0, "split requires non-zero rank");
162        assert!(
163            dim < self.shape().rank(),
164            "split dim {} out of bounds for rank {}",
165            dim,
166            self.shape().rank()
167        );
168        assert!(split_size > 0, "split_size must be > 0");
169        let dim_size = self.shape().dims()[dim];
170        if dim_size == 0 {
171            return vec![];
172        }
173
174        let mut sizes = Vec::new();
175        let mut remaining = dim_size;
176        while remaining > 0 {
177            let len = remaining.min(split_size);
178            sizes.push(len);
179            remaining -= len;
180        }
181        self.split_with_sizes(&sizes, dim)
182    }
183
184    /// Split tensor into chunks with explicit sizes along specified dimension
185    ///
186    /// Divides the tensor into multiple smaller tensors along the specified
187    /// dimension according to the provided size specifications. Each chunk
188    /// has the exact size specified in the `split_sizes` array, and the sum
189    /// of all sizes must equal the size of the specified dimension.
190    ///
191    /// This operation provides precise control over the size of each resulting
192    /// chunk, unlike `split()` which creates equal-sized chunks. The first
193    /// chunk is returned as a view when possible (zero-copy), while subsequent
194    /// chunks may require data copying for non-zero base offsets.
195    ///
196    /// # Arguments
197    ///
198    /// * `split_sizes` - Array specifying the size of each chunk along the dimension
199    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
200    ///
201    /// # Returns
202    ///
203    /// A vector of tensors, each representing a chunk of the original tensor
204    /// with the specified size. The number of chunks equals the length of `split_sizes`.
205    ///
206    /// # Panics
207    ///
208    /// * If tensor rank is 0 (scalar tensors cannot be split)
209    /// * If `dim` is out of bounds for the tensor rank
210    /// * If sum of `split_sizes` does not equal the size of the specified dimension
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use train_station::Tensor;
216    ///
217    /// // Split with explicit sizes
218    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
219    /// let parts = tensor.split_with_sizes(&[2, 3], 1);
220    /// assert_eq!(parts.len(), 2);
221    /// assert_eq!(parts[0].shape().dims(), vec![1, 2]);
222    /// assert_eq!(parts[1].shape().dims(), vec![1, 3]);
223    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
224    /// assert_eq!(parts[0].get(&[0, 1]), 2.0);
225    /// assert_eq!(parts[1].get(&[0, 0]), 3.0);
226    /// ```
227    ///
228    /// ```
229    /// use train_station::Tensor;
230    ///
231    /// // Split 2D tensor with different chunk sizes
232    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
233    /// let parts = tensor.split_with_sizes(&[1, 2], 1);
234    /// assert_eq!(parts.len(), 2);
235    /// assert_eq!(parts[0].shape().dims(), vec![2, 1]);
236    /// assert_eq!(parts[1].shape().dims(), vec![2, 2]);
237    /// ```
238    ///
239    /// ```
240    /// use train_station::Tensor;
241    ///
242    /// // Split with gradient tracking
243    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
244    /// tensor.set_requires_grad(true);
245    ///
246    /// let parts = tensor.split_with_sizes(&[1, 1], 1);
247    /// assert_eq!(parts.len(), 2);
248    /// assert!(parts[0].requires_grad());
249    /// assert!(parts[1].requires_grad());
250    /// ```
251    ///
252    /// ```
253    /// use train_station::Tensor;
254    ///
255    /// // Split 1D tensor with explicit sizes
256    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
257    /// let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
258    /// assert_eq!(parts.len(), 3);
259    /// assert_eq!(parts[0].shape().dims(), vec![2]);
260    /// assert_eq!(parts[1].shape().dims(), vec![2]);
261    /// assert_eq!(parts[2].shape().dims(), vec![2]);
262    /// ```
263    ///
264    /// # Performance
265    ///
266    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
267    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
268    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
269    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
270    ///
271    /// # Relationship to Other Operations
272    ///
273    /// This operation is related to other tensor transformations:
274    /// - `split()` - Simplified version with equal-sized chunks
275    /// - `cat()` - Inverse operation that concatenates tensors back together
276    /// - `chunk()` - Alternative splitting operation with different semantics
277    ///
278    /// # Memory Layout
279    ///
280    /// The first chunk maintains the same underlying data as a view when
281    /// the base offset is zero. Subsequent chunks may require data copying
282    /// to handle non-zero base offsets, ensuring proper memory layout.
283    /// Zero-sized chunks are handled by creating empty tensors with
284    /// appropriate shapes.
285    #[track_caller]
286    pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor> {
287        assert!(self.shape().rank() > 0, "split requires non-zero rank");
288        assert!(
289            dim < self.shape().rank(),
290            "split dim {} out of bounds for rank {}",
291            dim,
292            self.shape().rank()
293        );
294        let dim_size = self.shape().dims()[dim];
295        let total: usize = split_sizes.iter().sum();
296        assert!(
297            total == dim_size,
298            "sum of split sizes {} must equal size {} of dim {}",
299            total,
300            dim_size,
301            dim
302        );
303
304        let mut outputs = Vec::with_capacity(split_sizes.len());
305        let mut start = 0usize;
306        for &len in split_sizes {
307            if len == 0 {
308                outputs.push(Tensor::zeros(
309                    self.shape()
310                        .dims()
311                        .iter()
312                        .enumerate()
313                        .map(|(i, &d)| if i == dim { 0 } else { d })
314                        .collect(),
315                ));
316                continue;
317            }
318            // Build new dims/strides with updated length along `dim`
319            let mut new_dims = self.shape().dims().to_vec();
320            new_dims[dim] = len;
321            let new_strides = self.strides().to_vec();
322
323            let base_offset = start * self.stride(dim);
324
325            let mut piece: Tensor;
326            if base_offset == 0 {
327                // True view for the first chunk
328                let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
329                piece = self.create_view_with_shape(view_shape);
330            } else {
331                // Materialize contiguous copy for non-zero base offset
332                piece = Tensor::new(new_dims.clone());
333                // Use contiguous block copy per outer index to avoid per-element stride math
334                let inner: usize = new_dims[dim + 1..].iter().product();
335                let outer: usize = new_dims[..dim].iter().product();
336                unsafe {
337                    let dst_ptr = piece.as_mut_ptr();
338                    for outer_idx in 0..outer {
339                        let dst_row_base = outer_idx * (new_dims[dim] * inner);
340
341                        for k in 0..new_dims[dim] {
342                            // Calculate linear src pointer by mapping (outer, k, inner) to offset via offset() once per block
343                            // Fallback: compute src offset for the first element of the inner block and copy inner contiguous values
344                            // Build coordinates for the first element in this block
345                            let mut coords = vec![0usize; new_dims.len()];
346                            // Fill outer dims from outer_idx
347                            let mut tmp = outer_idx;
348                            for i in (0..dim).rev() {
349                                let s = new_dims[i];
350                                coords[i] = if s == 0 { 0 } else { tmp % s };
351                                if s != 0 {
352                                    tmp /= s;
353                                }
354                            }
355                            coords[dim] = k;
356                            // inner dims start at 0 here
357                            let mut src_coords = coords.clone();
358                            src_coords[dim] = start + k;
359                            let src_off = self.shape().offset(&src_coords);
360                            let dst_off = dst_row_base + k * inner;
361                            crate::tensor::iterator::collect::optimized_copy(
362                                self.as_ptr().add(src_off),
363                                dst_ptr.add(dst_off),
364                                inner,
365                            );
366                        }
367                    }
368                }
369            }
370
371            // GradTrack: register backward to scatter this piece's grad into original input range
372            if self.requires_grad() {
373                piece.set_requires_grad_internal(true);
374                let grad_fn = GradFn::Split {
375                    dim,
376                    start,
377                    length: len,
378                    input_shape: self.shape().dims().to_vec(),
379                };
380                piece.set_grad_fn(grad_fn.clone());
381                GradEngine::register_operation(piece.id(), vec![self.id()], grad_fn);
382            }
383
384            outputs.push(piece);
385            start += len;
386        }
387
388        outputs
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_equal_forward() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![2, 6]).unwrap();
        let chunks = input.split(2, 1);

        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert_eq!(chunk.shape().dims(), vec![2, 2]);
        }

        // Spot-check one element from each chunk against the source layout.
        assert_eq!(chunks[0].get(&[0, 0]), 0.0);
        assert_eq!(chunks[1].get(&[0, 0]), 2.0);
        assert_eq!(chunks[2].get(&[1, 1]), 11.0);
    }

    #[test]
    fn test_split_with_sizes_forward() {
        let values: Vec<f32> = (0..15).map(|v| (v as f32) * 0.1).collect();
        let input = Tensor::from_slice(&values, vec![3, 5]).unwrap();
        let chunks = input.split_with_sizes(&[2, 1, 2], 1);

        assert_eq!(chunks.len(), 3);
        for (chunk, width) in chunks.iter().zip([2usize, 1, 2].iter()) {
            assert_eq!(chunk.shape().dims(), vec![3, *width]);
        }

        // The middle chunk holds source column 2; row 2 maps to flat index 2 * 5 + 2.
        assert_eq!(chunks[1].get(&[2, 0]), values[2 * 5 + 2]);
    }

    #[test]
    fn test_split_gradients_scatter() {
        let values: Vec<f32> = (0..10).map(|v| (v as f32) * 0.5 - 1.0).collect();
        let input = Tensor::from_slice(&values, vec![2, 5])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 3], 1);

        // Stitch the chunks back together, then backprop implicit ones.
        let mut rebuilt = Tensor::cat(&chunks, 1);
        rebuilt.backward(None);

        let grad = input.grad_owned().expect("grad missing");
        for row in 0..2 {
            for col in 0..5 {
                assert_eq!(grad.get(&[row, col]), 1.0);
            }
        }
    }

    #[test]
    fn test_split_1d_three_parts_grad() {
        let values: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![6])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 2, 2], 0);

        // Going through cat keeps the backward path independent of view/contiguity details.
        let mut rebuilt = Tensor::cat(&chunks, 0);
        rebuilt.backward(None);

        let grad = input.grad_owned().expect("grad missing");
        for idx in 0..6 {
            assert_eq!(grad.get(&[idx]), 1.0);
        }
    }
}