train_station/tensor/transform/
split.rs

1//! Tensor splitting operations
2//!
3//! This module provides tensor splitting functionality that divides a tensor
4//! into multiple smaller tensors along a specified dimension. Splitting is a
5//! fundamental tensor transformation operation used in machine learning for
6//! dividing data into batches, creating multiple outputs from a single tensor,
7//! and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `split()` - Split tensor into chunks of equal size along a dimension
12//! * `split_with_sizes()` - Split tensor into chunks with explicit sizes along a dimension
13//!
14//! # Performance Characteristics
15//!
16//! * **View Operations**: First chunk returns a view when possible (zero-copy)
17//! * **Copy Operations**: Subsequent chunks require data copying for non-zero offsets
18//! * **Memory Efficient**: Minimizes memory allocation through view reuse
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Divides tensor along specified dimension while preserving other dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Split tensor into equal-sized chunks
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let parts = tensor.split(1, 1);
30//! assert_eq!(parts.len(), 3);
31//! assert_eq!(parts[0].shape().dims, vec![2, 1]);
32//! assert_eq!(parts[1].shape().dims, vec![2, 1]);
33//! assert_eq!(parts[2].shape().dims, vec![2, 1]);
34//! ```
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! // Split tensor with explicit sizes
40//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
41//! let parts = tensor.split_with_sizes(&[2, 3], 1);
42//! assert_eq!(parts.len(), 2);
43//! assert_eq!(parts[0].shape().dims, vec![1, 2]);
44//! assert_eq!(parts[1].shape().dims, vec![1, 3]);
45//! ```
46//!
47//! # Gradient Tracking
48//!
49//! The split operations support automatic gradient tracking through
50//! the GradTrack system. When `requires_grad` is enabled, each split
51//! piece registers a gradient function that scatters gradients back
52//! to the original tensor during backward passes.
53
54use crate::gradtrack::{GradEngine, GradFn};
55use crate::tensor::core::Tensor;
56
57impl Tensor {
58    /// Split tensor into chunks of equal size along specified dimension
59    ///
60    /// Divides the tensor into multiple smaller tensors along the specified
61    /// dimension, where each chunk (except possibly the last) has the same size.
62    /// The last chunk may be smaller if the dimension size is not evenly
63    /// divisible by the split size.
64    ///
65    /// This operation returns a vector of tensors, where each tensor is a
66    /// view or copy of a portion of the original tensor. The first chunk
67    /// is returned as a view when possible (zero-copy), while subsequent
68    /// chunks may require data copying for non-zero base offsets.
69    ///
70    /// # Arguments
71    ///
72    /// * `split_size` - Size of each chunk along the specified dimension (must be > 0)
73    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
74    ///
75    /// # Returns
76    ///
77    /// A vector of tensors, each representing a chunk of the original tensor.
78    /// The number of chunks depends on the dimension size and split size.
79    ///
80    /// # Panics
81    ///
82    /// * If tensor rank is 0 (scalar tensors cannot be split)
83    /// * If `dim` is out of bounds for the tensor rank
84    /// * If `split_size` is 0
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use train_station::Tensor;
90    ///
91    /// // Split 2D tensor into equal chunks along dimension 1
92    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
93    /// let parts = tensor.split(1, 1);
94    /// assert_eq!(parts.len(), 3);
95    /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
96    /// assert_eq!(parts[1].shape().dims, vec![2, 1]);
97    /// assert_eq!(parts[2].shape().dims, vec![2, 1]);
98    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
99    /// assert_eq!(parts[1].get(&[0, 0]), 2.0);
100    /// assert_eq!(parts[2].get(&[1, 0]), 6.0);
101    /// ```
102    ///
103    /// ```
104    /// use train_station::Tensor;
105    ///
106    /// // Split with uneven division (last chunk smaller)
107    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
108    /// let parts = tensor.split(2, 1);
109    /// assert_eq!(parts.len(), 3);
110    /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
111    /// assert_eq!(parts[1].shape().dims, vec![1, 2]);
112    /// assert_eq!(parts[2].shape().dims, vec![1, 1]); // Last chunk smaller
113    /// ```
114    ///
115    /// ```
116    /// use train_station::Tensor;
117    ///
118    /// // Split with gradient tracking
119    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
120    /// tensor.set_requires_grad(true);
121    ///
122    /// let parts = tensor.split(1, 1);
123    /// assert_eq!(parts.len(), 2);
124    /// assert!(parts[0].requires_grad());
125    /// assert!(parts[1].requires_grad());
126    /// ```
127    ///
128    /// ```
129    /// use train_station::Tensor;
130    ///
131    /// // Split 1D tensor
132    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
133    /// let parts = tensor.split(2, 0);
134    /// assert_eq!(parts.len(), 3);
135    /// assert_eq!(parts[0].shape().dims, vec![2]);
136    /// assert_eq!(parts[1].shape().dims, vec![2]);
137    /// assert_eq!(parts[2].shape().dims, vec![2]);
138    /// ```
139    ///
140    /// # Performance
141    ///
142    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
143    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
144    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
145    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
146    ///
147    /// # Relationship to Other Operations
148    ///
149    /// This operation is related to other tensor transformations:
150    /// - `split_with_sizes()` - More general version with explicit chunk sizes
151    /// - `cat()` - Inverse operation that concatenates tensors back together
152    /// - `chunk()` - Alternative splitting operation with different semantics
153    ///
154    /// # Memory Layout
155    ///
156    /// The first chunk maintains the same underlying data as a view when
157    /// the base offset is zero. Subsequent chunks may require data copying
158    /// to handle non-zero base offsets, ensuring proper memory layout.
159    pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor> {
160        assert!(self.shape().rank() > 0, "split requires non-zero rank");
161        assert!(
162            dim < self.shape().rank(),
163            "split dim {} out of bounds for rank {}",
164            dim,
165            self.shape().rank()
166        );
167        assert!(split_size > 0, "split_size must be > 0");
168        let dim_size = self.shape().dims[dim];
169        if dim_size == 0 {
170            return vec![];
171        }
172
173        let mut sizes = Vec::new();
174        let mut remaining = dim_size;
175        while remaining > 0 {
176            let len = remaining.min(split_size);
177            sizes.push(len);
178            remaining -= len;
179        }
180        self.split_with_sizes(&sizes, dim)
181    }
182
183    /// Split tensor into chunks with explicit sizes along specified dimension
184    ///
185    /// Divides the tensor into multiple smaller tensors along the specified
186    /// dimension according to the provided size specifications. Each chunk
187    /// has the exact size specified in the `split_sizes` array, and the sum
188    /// of all sizes must equal the size of the specified dimension.
189    ///
190    /// This operation provides precise control over the size of each resulting
191    /// chunk, unlike `split()` which creates equal-sized chunks. The first
192    /// chunk is returned as a view when possible (zero-copy), while subsequent
193    /// chunks may require data copying for non-zero base offsets.
194    ///
195    /// # Arguments
196    ///
197    /// * `split_sizes` - Array specifying the size of each chunk along the dimension
198    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
199    ///
200    /// # Returns
201    ///
202    /// A vector of tensors, each representing a chunk of the original tensor
203    /// with the specified size. The number of chunks equals the length of `split_sizes`.
204    ///
205    /// # Panics
206    ///
207    /// * If tensor rank is 0 (scalar tensors cannot be split)
208    /// * If `dim` is out of bounds for the tensor rank
209    /// * If sum of `split_sizes` does not equal the size of the specified dimension
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// use train_station::Tensor;
215    ///
216    /// // Split with explicit sizes
217    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
218    /// let parts = tensor.split_with_sizes(&[2, 3], 1);
219    /// assert_eq!(parts.len(), 2);
220    /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
221    /// assert_eq!(parts[1].shape().dims, vec![1, 3]);
222    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
223    /// assert_eq!(parts[0].get(&[0, 1]), 2.0);
224    /// assert_eq!(parts[1].get(&[0, 0]), 3.0);
225    /// ```
226    ///
227    /// ```
228    /// use train_station::Tensor;
229    ///
230    /// // Split 2D tensor with different chunk sizes
231    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
232    /// let parts = tensor.split_with_sizes(&[1, 2], 1);
233    /// assert_eq!(parts.len(), 2);
234    /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
235    /// assert_eq!(parts[1].shape().dims, vec![2, 2]);
236    /// ```
237    ///
238    /// ```
239    /// use train_station::Tensor;
240    ///
241    /// // Split with gradient tracking
242    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
243    /// tensor.set_requires_grad(true);
244    ///
245    /// let parts = tensor.split_with_sizes(&[1, 1], 1);
246    /// assert_eq!(parts.len(), 2);
247    /// assert!(parts[0].requires_grad());
248    /// assert!(parts[1].requires_grad());
249    /// ```
250    ///
251    /// ```
252    /// use train_station::Tensor;
253    ///
254    /// // Split 1D tensor with explicit sizes
255    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
256    /// let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
257    /// assert_eq!(parts.len(), 3);
258    /// assert_eq!(parts[0].shape().dims, vec![2]);
259    /// assert_eq!(parts[1].shape().dims, vec![2]);
260    /// assert_eq!(parts[2].shape().dims, vec![2]);
261    /// ```
262    ///
263    /// # Performance
264    ///
265    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
266    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
267    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
268    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
269    ///
270    /// # Relationship to Other Operations
271    ///
272    /// This operation is related to other tensor transformations:
273    /// - `split()` - Simplified version with equal-sized chunks
274    /// - `cat()` - Inverse operation that concatenates tensors back together
275    /// - `chunk()` - Alternative splitting operation with different semantics
276    ///
277    /// # Memory Layout
278    ///
279    /// The first chunk maintains the same underlying data as a view when
280    /// the base offset is zero. Subsequent chunks may require data copying
281    /// to handle non-zero base offsets, ensuring proper memory layout.
282    /// Zero-sized chunks are handled by creating empty tensors with
283    /// appropriate shapes.
284    pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor> {
285        assert!(self.shape().rank() > 0, "split requires non-zero rank");
286        assert!(
287            dim < self.shape().rank(),
288            "split dim {} out of bounds for rank {}",
289            dim,
290            self.shape().rank()
291        );
292        let dim_size = self.shape().dims[dim];
293        let total: usize = split_sizes.iter().sum();
294        assert!(
295            total == dim_size,
296            "sum of split sizes {} must equal size {} of dim {}",
297            total,
298            dim_size,
299            dim
300        );
301
302        let mut outputs = Vec::with_capacity(split_sizes.len());
303        let mut start = 0usize;
304        for &len in split_sizes {
305            if len == 0 {
306                outputs.push(Tensor::zeros(
307                    self.shape()
308                        .dims
309                        .iter()
310                        .enumerate()
311                        .map(|(i, &d)| if i == dim { 0 } else { d })
312                        .collect(),
313                ));
314                continue;
315            }
316            // Build new dims/strides with updated length along `dim`
317            let mut new_dims = self.shape().dims.clone();
318            new_dims[dim] = len;
319            let new_strides = self.strides().to_vec();
320
321            let base_offset = start * self.stride(dim);
322
323            let mut piece: Tensor;
324            if base_offset == 0 {
325                // True view for the first chunk
326                let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
327                piece = self.create_view_with_shape(view_shape);
328            } else {
329                // Materialize contiguous copy for non-zero base offset
330                piece = Tensor::new(new_dims.clone());
331                let rank = new_dims.len();
332                let numel = piece.size();
333                let mut coords = vec![0usize; rank];
334                for lin in 0..numel {
335                    let mut tmp = lin;
336                    for i in (0..rank).rev() {
337                        let s = new_dims[i];
338                        coords[i] = if s == 0 { 0 } else { tmp % s };
339                        if s != 0 {
340                            tmp /= s;
341                        }
342                    }
343                    // Map to source coords
344                    let mut src_coords = coords.clone();
345                    src_coords[dim] = start + coords[dim];
346                    let src_off = self.shape().offset(&src_coords);
347                    unsafe {
348                        *piece.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
349                    }
350                }
351            }
352
353            // GradTrack: register backward to scatter this piece's grad into original input range
354            if self.requires_grad() {
355                piece.set_requires_grad_internal(true);
356                let grad_fn = GradFn::Split {
357                    dim,
358                    start,
359                    length: len,
360                    input_shape: self.shape().dims.clone(),
361                };
362                piece.set_grad_fn(grad_fn.clone());
363                GradEngine::register_operation(piece.id(), vec![self.id()], grad_fn);
364            }
365
366            outputs.push(piece);
367            start += len;
368        }
369
370        outputs
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal-sized split of a 2x6 tensor along the last dimension.
    #[test]
    fn test_split_equal_forward() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![2, 6]).unwrap();
        let chunks = input.split(2, 1);
        assert_eq!(chunks.len(), 3);
        // Every chunk is 2x2 because 6 divides evenly by 2
        for chunk in chunks.iter().take(3) {
            assert_eq!(chunk.shape().dims, vec![2, 2]);
        }
        // Spot-check one value per chunk
        assert_eq!(chunks[0].get(&[0, 0]), 0.0);
        assert_eq!(chunks[1].get(&[0, 0]), 2.0);
        assert_eq!(chunks[2].get(&[1, 1]), 11.0);
    }

    /// Explicit, uneven sizes along the last dimension of a 3x5 tensor.
    #[test]
    fn test_split_with_sizes_forward() {
        let values: Vec<f32> = (0..15).map(|v| (v as f32) * 0.1).collect();
        let input = Tensor::from_slice(&values, vec![3, 5]).unwrap();
        let chunks = input.split_with_sizes(&[2, 1, 2], 1);
        assert_eq!(chunks.len(), 3);
        let expected_shapes = [vec![3, 2], vec![3, 1], vec![3, 2]];
        for (chunk, expected) in chunks.iter().zip(expected_shapes.iter()) {
            assert_eq!(&chunk.shape().dims, expected);
        }
        // Middle chunk, last row: source element at row 2, column 2
        assert_eq!(chunks[1].get(&[2, 0]), (2 * 5 + 2) as f32 * 0.1);
    }

    /// Gradients of the chunks scatter back over the whole 2x5 input.
    #[test]
    fn test_split_gradients_scatter() {
        let values: Vec<f32> = (0..10).map(|v| (v as f32) * 0.5 - 1.0).collect();
        let input = Tensor::from_slice(&values, vec![2, 5])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 3], 1);
        // Reassemble the pieces, then run backward with the implicit ones gradient
        let mut rebuilt = Tensor::cat(&chunks, 1);
        rebuilt.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // Each input position must have received exactly 1.0
        for flat in 0..input.size() {
            let (row, col) = (flat / 5, flat % 5);
            assert_eq!(grad.get(&[row, col]), 1.0);
        }
    }

    /// Three-way split of a 1D tensor with gradient tracking.
    #[test]
    fn test_split_1d_three_parts_grad() {
        let values: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![6])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 2, 2], 0);
        // Concatenate then backward to avoid view/contig mismatches
        let mut rebuilt = Tensor::cat(&chunks, 0);
        rebuilt.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        for idx in 0..6 {
            assert_eq!(grad.get(&[idx]), 1.0);
        }
    }
}