train_station/tensor/transform/
split.rs

1//! Tensor splitting operations
2//!
3//! This module provides tensor splitting functionality that divides a tensor
4//! into multiple smaller tensors along a specified dimension. Splitting is a
5//! fundamental tensor transformation operation used in machine learning for
6//! dividing data into batches, creating multiple outputs from a single tensor,
7//! and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `split()` - Split tensor into chunks of equal size along a dimension
12//! * `split_with_sizes()` - Split tensor into chunks with explicit sizes along a dimension
13//!
14//! # Performance Characteristics
15//!
16//! * **View Operations**: First chunk returns a view when possible (zero-copy)
17//! * **Copy Operations**: Subsequent chunks require data copying for non-zero offsets
18//! * **Memory Efficient**: Minimizes memory allocation through view reuse
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Divides tensor along specified dimension while preserving other dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Split tensor into equal-sized chunks
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let parts = tensor.split(1, 1);
30//! assert_eq!(parts.len(), 3);
31//! assert_eq!(parts[0].shape().dims, vec![2, 1]);
32//! assert_eq!(parts[1].shape().dims, vec![2, 1]);
33//! assert_eq!(parts[2].shape().dims, vec![2, 1]);
34//! ```
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! // Split tensor with explicit sizes
40//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
41//! let parts = tensor.split_with_sizes(&[2, 3], 1);
42//! assert_eq!(parts.len(), 2);
43//! assert_eq!(parts[0].shape().dims, vec![1, 2]);
44//! assert_eq!(parts[1].shape().dims, vec![1, 3]);
45//! ```
46//!
47//! # Gradient Tracking
48//!
49//! The split operations support automatic gradient tracking through
50//! the GradTrack system. When `requires_grad` is enabled, each split
51//! piece registers a gradient function that scatters gradients back
52//! to the original tensor during backward passes.
53
54use crate::gradtrack::{GradEngine, GradFn};
55use crate::tensor::core::Tensor;
56
57impl Tensor {
58    /// Split tensor into chunks of equal size along specified dimension
59    ///
60    /// Divides the tensor into multiple smaller tensors along the specified
61    /// dimension, where each chunk (except possibly the last) has the same size.
62    /// The last chunk may be smaller if the dimension size is not evenly
63    /// divisible by the split size.
64    ///
65    /// This operation returns a vector of tensors, where each tensor is a
66    /// view or copy of a portion of the original tensor. The first chunk
67    /// is returned as a view when possible (zero-copy), while subsequent
68    /// chunks may require data copying for non-zero base offsets.
69    ///
70    /// # Arguments
71    ///
72    /// * `split_size` - Size of each chunk along the specified dimension (must be > 0)
73    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
74    ///
75    /// # Returns
76    ///
77    /// A vector of tensors, each representing a chunk of the original tensor.
78    /// The number of chunks depends on the dimension size and split size.
79    ///
80    /// # Panics
81    ///
82    /// * If tensor rank is 0 (scalar tensors cannot be split)
83    /// * If `dim` is out of bounds for the tensor rank
84    /// * If `split_size` is 0
85    ///
86    /// # Examples
87    ///
88    /// ```
89    /// use train_station::Tensor;
90    ///
91    /// // Split 2D tensor into equal chunks along dimension 1
92    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
93    /// let parts = tensor.split(1, 1);
94    /// assert_eq!(parts.len(), 3);
95    /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
96    /// assert_eq!(parts[1].shape().dims, vec![2, 1]);
97    /// assert_eq!(parts[2].shape().dims, vec![2, 1]);
98    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
99    /// assert_eq!(parts[1].get(&[0, 0]), 2.0);
100    /// assert_eq!(parts[2].get(&[1, 0]), 6.0);
101    /// ```
102    ///
103    /// ```
104    /// use train_station::Tensor;
105    ///
106    /// // Split with uneven division (last chunk smaller)
107    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
108    /// let parts = tensor.split(2, 1);
109    /// assert_eq!(parts.len(), 3);
110    /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
111    /// assert_eq!(parts[1].shape().dims, vec![1, 2]);
112    /// assert_eq!(parts[2].shape().dims, vec![1, 1]); // Last chunk smaller
113    /// ```
114    ///
115    /// ```
116    /// use train_station::Tensor;
117    ///
118    /// // Split with gradient tracking
119    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
120    /// tensor.set_requires_grad(true);
121    ///
122    /// let parts = tensor.split(1, 1);
123    /// assert_eq!(parts.len(), 2);
124    /// assert!(parts[0].requires_grad());
125    /// assert!(parts[1].requires_grad());
126    /// ```
127    ///
128    /// ```
129    /// use train_station::Tensor;
130    ///
131    /// // Split 1D tensor
132    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
133    /// let parts = tensor.split(2, 0);
134    /// assert_eq!(parts.len(), 3);
135    /// assert_eq!(parts[0].shape().dims, vec![2]);
136    /// assert_eq!(parts[1].shape().dims, vec![2]);
137    /// assert_eq!(parts[2].shape().dims, vec![2]);
138    /// ```
139    ///
140    /// # Performance
141    ///
142    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
143    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
144    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
145    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
146    ///
147    /// # Relationship to Other Operations
148    ///
149    /// This operation is related to other tensor transformations:
150    /// - `split_with_sizes()` - More general version with explicit chunk sizes
151    /// - `cat()` - Inverse operation that concatenates tensors back together
152    /// - `chunk()` - Alternative splitting operation with different semantics
153    ///
154    /// # Memory Layout
155    ///
156    /// The first chunk maintains the same underlying data as a view when
157    /// the base offset is zero. Subsequent chunks may require data copying
158    /// to handle non-zero base offsets, ensuring proper memory layout.
159    #[track_caller]
160    pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor> {
161        assert!(self.shape().rank() > 0, "split requires non-zero rank");
162        assert!(
163            dim < self.shape().rank(),
164            "split dim {} out of bounds for rank {}",
165            dim,
166            self.shape().rank()
167        );
168        assert!(split_size > 0, "split_size must be > 0");
169        let dim_size = self.shape().dims[dim];
170        if dim_size == 0 {
171            return vec![];
172        }
173
174        let mut sizes = Vec::new();
175        let mut remaining = dim_size;
176        while remaining > 0 {
177            let len = remaining.min(split_size);
178            sizes.push(len);
179            remaining -= len;
180        }
181        self.split_with_sizes(&sizes, dim)
182    }
183
184    /// Split tensor into chunks with explicit sizes along specified dimension
185    ///
186    /// Divides the tensor into multiple smaller tensors along the specified
187    /// dimension according to the provided size specifications. Each chunk
188    /// has the exact size specified in the `split_sizes` array, and the sum
189    /// of all sizes must equal the size of the specified dimension.
190    ///
191    /// This operation provides precise control over the size of each resulting
192    /// chunk, unlike `split()` which creates equal-sized chunks. The first
193    /// chunk is returned as a view when possible (zero-copy), while subsequent
194    /// chunks may require data copying for non-zero base offsets.
195    ///
196    /// # Arguments
197    ///
198    /// * `split_sizes` - Array specifying the size of each chunk along the dimension
199    /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
200    ///
201    /// # Returns
202    ///
203    /// A vector of tensors, each representing a chunk of the original tensor
204    /// with the specified size. The number of chunks equals the length of `split_sizes`.
205    ///
206    /// # Panics
207    ///
208    /// * If tensor rank is 0 (scalar tensors cannot be split)
209    /// * If `dim` is out of bounds for the tensor rank
210    /// * If sum of `split_sizes` does not equal the size of the specified dimension
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use train_station::Tensor;
216    ///
217    /// // Split with explicit sizes
218    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
219    /// let parts = tensor.split_with_sizes(&[2, 3], 1);
220    /// assert_eq!(parts.len(), 2);
221    /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
222    /// assert_eq!(parts[1].shape().dims, vec![1, 3]);
223    /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
224    /// assert_eq!(parts[0].get(&[0, 1]), 2.0);
225    /// assert_eq!(parts[1].get(&[0, 0]), 3.0);
226    /// ```
227    ///
228    /// ```
229    /// use train_station::Tensor;
230    ///
231    /// // Split 2D tensor with different chunk sizes
232    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
233    /// let parts = tensor.split_with_sizes(&[1, 2], 1);
234    /// assert_eq!(parts.len(), 2);
235    /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
236    /// assert_eq!(parts[1].shape().dims, vec![2, 2]);
237    /// ```
238    ///
239    /// ```
240    /// use train_station::Tensor;
241    ///
242    /// // Split with gradient tracking
243    /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
244    /// tensor.set_requires_grad(true);
245    ///
246    /// let parts = tensor.split_with_sizes(&[1, 1], 1);
247    /// assert_eq!(parts.len(), 2);
248    /// assert!(parts[0].requires_grad());
249    /// assert!(parts[1].requires_grad());
250    /// ```
251    ///
252    /// ```
253    /// use train_station::Tensor;
254    ///
255    /// // Split 1D tensor with explicit sizes
256    /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
257    /// let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
258    /// assert_eq!(parts.len(), 3);
259    /// assert_eq!(parts[0].shape().dims, vec![2]);
260    /// assert_eq!(parts[1].shape().dims, vec![2]);
261    /// assert_eq!(parts[2].shape().dims, vec![2]);
262    /// ```
263    ///
264    /// # Performance
265    ///
266    /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
267    /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
268    /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
269    /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
270    ///
271    /// # Relationship to Other Operations
272    ///
273    /// This operation is related to other tensor transformations:
274    /// - `split()` - Simplified version with equal-sized chunks
275    /// - `cat()` - Inverse operation that concatenates tensors back together
276    /// - `chunk()` - Alternative splitting operation with different semantics
277    ///
278    /// # Memory Layout
279    ///
280    /// The first chunk maintains the same underlying data as a view when
281    /// the base offset is zero. Subsequent chunks may require data copying
282    /// to handle non-zero base offsets, ensuring proper memory layout.
283    /// Zero-sized chunks are handled by creating empty tensors with
284    /// appropriate shapes.
285    #[track_caller]
286    pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor> {
287        assert!(self.shape().rank() > 0, "split requires non-zero rank");
288        assert!(
289            dim < self.shape().rank(),
290            "split dim {} out of bounds for rank {}",
291            dim,
292            self.shape().rank()
293        );
294        let dim_size = self.shape().dims[dim];
295        let total: usize = split_sizes.iter().sum();
296        assert!(
297            total == dim_size,
298            "sum of split sizes {} must equal size {} of dim {}",
299            total,
300            dim_size,
301            dim
302        );
303
304        let mut outputs = Vec::with_capacity(split_sizes.len());
305        let mut start = 0usize;
306        for &len in split_sizes {
307            if len == 0 {
308                outputs.push(Tensor::zeros(
309                    self.shape()
310                        .dims
311                        .iter()
312                        .enumerate()
313                        .map(|(i, &d)| if i == dim { 0 } else { d })
314                        .collect(),
315                ));
316                continue;
317            }
318            // Build new dims/strides with updated length along `dim`
319            let mut new_dims = self.shape().dims.clone();
320            new_dims[dim] = len;
321            let new_strides = self.strides().to_vec();
322
323            let base_offset = start * self.stride(dim);
324
325            let mut piece: Tensor;
326            if base_offset == 0 {
327                // True view for the first chunk
328                let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
329                piece = self.create_view_with_shape(view_shape);
330            } else {
331                // Materialize contiguous copy for non-zero base offset
332                piece = Tensor::new(new_dims.clone());
333                let rank = new_dims.len();
334                let numel = piece.size();
335                let mut coords = vec![0usize; rank];
336                for lin in 0..numel {
337                    let mut tmp = lin;
338                    for i in (0..rank).rev() {
339                        let s = new_dims[i];
340                        coords[i] = if s == 0 { 0 } else { tmp % s };
341                        if s != 0 {
342                            tmp /= s;
343                        }
344                    }
345                    // Map to source coords
346                    let mut src_coords = coords.clone();
347                    src_coords[dim] = start + coords[dim];
348                    let src_off = self.shape().offset(&src_coords);
349                    unsafe {
350                        *piece.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
351                    }
352                }
353            }
354
355            // GradTrack: register backward to scatter this piece's grad into original input range
356            if self.requires_grad() {
357                piece.set_requires_grad_internal(true);
358                let grad_fn = GradFn::Split {
359                    dim,
360                    start,
361                    length: len,
362                    input_shape: self.shape().dims.clone(),
363                };
364                piece.set_grad_fn(grad_fn.clone());
365                GradEngine::register_operation(piece.id(), vec![self.id()], grad_fn);
366            }
367
368            outputs.push(piece);
369            start += len;
370        }
371
372        outputs
373    }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Equal-size split along dim 1: chunk count, shapes, and sampled values.
    #[test]
    fn test_split_equal_forward() {
        let values: Vec<f32> = (0..12).map(|i| i as f32).collect();
        let input = Tensor::from_slice(&values, vec![2, 6]).unwrap();
        let chunks = input.split(2, 1);
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks[..3] {
            assert_eq!(chunk.shape().dims, vec![2, 2]);
        }
        // Spot-check one value per chunk
        assert_eq!(chunks[0].get(&[0, 0]), 0.0);
        assert_eq!(chunks[1].get(&[0, 0]), 2.0);
        assert_eq!(chunks[2].get(&[1, 1]), 11.0);
    }

    /// Explicit-size split along dim 1: shapes follow the requested sizes.
    #[test]
    fn test_split_with_sizes_forward() {
        let values: Vec<f32> = (0..15).map(|i| (i as f32) * 0.1).collect();
        let input = Tensor::from_slice(&values, vec![3, 5]).unwrap();
        let chunks = input.split_with_sizes(&[2, 1, 2], 1);
        assert_eq!(chunks.len(), 3);
        let expected_shapes = [vec![3, 2], vec![3, 1], vec![3, 2]];
        for (chunk, expected) in chunks.iter().zip(expected_shapes.iter()) {
            assert_eq!(&chunk.shape().dims, expected);
        }
        // Middle chunk holds source column 2; row 2 maps to flat index 2 * 5 + 2
        assert_eq!(chunks[1].get(&[2, 0]), (2 * 5 + 2) as f32 * 0.1);
    }

    /// Gradients of all chunks scatter back to cover the full 2D input.
    #[test]
    fn test_split_gradients_scatter() {
        let values: Vec<f32> = (0..10).map(|i| (i as f32) * 0.5 - 1.0).collect();
        let input = Tensor::from_slice(&values, vec![2, 5])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 3], 1);
        // Rebuild the full tensor by concatenation, then backward with implicit ones
        let mut rebuilt = Tensor::cat(&chunks, 1);
        rebuilt.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // Every input position must receive a gradient of 1.0
        for flat in 0..input.size() {
            let (row, col) = (flat / 5, flat % 5);
            assert_eq!(grad.get(&[row, col]), 1.0);
        }
    }

    /// Three-way split of a 1D tensor: every element receives gradient 1.0.
    #[test]
    fn test_split_1d_three_parts_grad() {
        let values: Vec<f32> = (0..6).map(|i| i as f32).collect();
        let input = Tensor::from_slice(&values, vec![6])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 2, 2], 0);
        // Concatenate then backward to avoid view/contig mismatches
        let mut rebuilt = Tensor::cat(&chunks, 0);
        rebuilt.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        for idx in 0..6 {
            assert_eq!(grad.get(&[idx]), 1.0);
        }
    }
}