train_station/tensor/transform/split.rs
1//! Tensor splitting operations
2//!
3//! This module provides tensor splitting functionality that divides a tensor
4//! into multiple smaller tensors along a specified dimension. Splitting is a
5//! fundamental tensor transformation operation used in machine learning for
6//! dividing data into batches, creating multiple outputs from a single tensor,
7//! and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `split()` - Split tensor into chunks of equal size along a dimension
12//! * `split_with_sizes()` - Split tensor into chunks with explicit sizes along a dimension
13//!
14//! # Performance Characteristics
15//!
16//! * **View Operations**: First chunk returns a view when possible (zero-copy)
17//! * **Copy Operations**: Subsequent chunks require data copying for non-zero offsets
18//! * **Memory Efficient**: Minimizes memory allocation through view reuse
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Divides tensor along specified dimension while preserving other dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Split tensor into equal-sized chunks
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let parts = tensor.split(1, 1);
30//! assert_eq!(parts.len(), 3);
31//! assert_eq!(parts[0].shape().dims, vec![2, 1]);
32//! assert_eq!(parts[1].shape().dims, vec![2, 1]);
33//! assert_eq!(parts[2].shape().dims, vec![2, 1]);
34//! ```
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! // Split tensor with explicit sizes
40//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
41//! let parts = tensor.split_with_sizes(&[2, 3], 1);
42//! assert_eq!(parts.len(), 2);
43//! assert_eq!(parts[0].shape().dims, vec![1, 2]);
44//! assert_eq!(parts[1].shape().dims, vec![1, 3]);
45//! ```
46//!
47//! # Gradient Tracking
48//!
49//! The split operations support automatic gradient tracking through
50//! the GradTrack system. When `requires_grad` is enabled, each split
51//! piece registers a gradient function that scatters gradients back
52//! to the original tensor during backward passes.
53
54use crate::gradtrack::{GradEngine, GradFn};
55use crate::tensor::core::Tensor;
56
57impl Tensor {
58 /// Split tensor into chunks of equal size along specified dimension
59 ///
60 /// Divides the tensor into multiple smaller tensors along the specified
61 /// dimension, where each chunk (except possibly the last) has the same size.
62 /// The last chunk may be smaller if the dimension size is not evenly
63 /// divisible by the split size.
64 ///
65 /// This operation returns a vector of tensors, where each tensor is a
66 /// view or copy of a portion of the original tensor. The first chunk
67 /// is returned as a view when possible (zero-copy), while subsequent
68 /// chunks may require data copying for non-zero base offsets.
69 ///
70 /// # Arguments
71 ///
72 /// * `split_size` - Size of each chunk along the specified dimension (must be > 0)
73 /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
74 ///
75 /// # Returns
76 ///
77 /// A vector of tensors, each representing a chunk of the original tensor.
78 /// The number of chunks depends on the dimension size and split size.
79 ///
80 /// # Panics
81 ///
82 /// * If tensor rank is 0 (scalar tensors cannot be split)
83 /// * If `dim` is out of bounds for the tensor rank
84 /// * If `split_size` is 0
85 ///
86 /// # Examples
87 ///
88 /// ```
89 /// use train_station::Tensor;
90 ///
91 /// // Split 2D tensor into equal chunks along dimension 1
92 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
93 /// let parts = tensor.split(1, 1);
94 /// assert_eq!(parts.len(), 3);
95 /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
96 /// assert_eq!(parts[1].shape().dims, vec![2, 1]);
97 /// assert_eq!(parts[2].shape().dims, vec![2, 1]);
98 /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
99 /// assert_eq!(parts[1].get(&[0, 0]), 2.0);
100 /// assert_eq!(parts[2].get(&[1, 0]), 6.0);
101 /// ```
102 ///
103 /// ```
104 /// use train_station::Tensor;
105 ///
106 /// // Split with uneven division (last chunk smaller)
107 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
108 /// let parts = tensor.split(2, 1);
109 /// assert_eq!(parts.len(), 3);
110 /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
111 /// assert_eq!(parts[1].shape().dims, vec![1, 2]);
112 /// assert_eq!(parts[2].shape().dims, vec![1, 1]); // Last chunk smaller
113 /// ```
114 ///
115 /// ```
116 /// use train_station::Tensor;
117 ///
118 /// // Split with gradient tracking
119 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
120 /// tensor.set_requires_grad(true);
121 ///
122 /// let parts = tensor.split(1, 1);
123 /// assert_eq!(parts.len(), 2);
124 /// assert!(parts[0].requires_grad());
125 /// assert!(parts[1].requires_grad());
126 /// ```
127 ///
128 /// ```
129 /// use train_station::Tensor;
130 ///
131 /// // Split 1D tensor
132 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
133 /// let parts = tensor.split(2, 0);
134 /// assert_eq!(parts.len(), 3);
135 /// assert_eq!(parts[0].shape().dims, vec![2]);
136 /// assert_eq!(parts[1].shape().dims, vec![2]);
137 /// assert_eq!(parts[2].shape().dims, vec![2]);
138 /// ```
139 ///
140 /// # Performance
141 ///
142 /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
143 /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
144 /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
145 /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
146 ///
147 /// # Relationship to Other Operations
148 ///
149 /// This operation is related to other tensor transformations:
150 /// - `split_with_sizes()` - More general version with explicit chunk sizes
151 /// - `cat()` - Inverse operation that concatenates tensors back together
152 /// - `chunk()` - Alternative splitting operation with different semantics
153 ///
154 /// # Memory Layout
155 ///
156 /// The first chunk maintains the same underlying data as a view when
157 /// the base offset is zero. Subsequent chunks may require data copying
158 /// to handle non-zero base offsets, ensuring proper memory layout.
159 #[track_caller]
160 pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor> {
161 assert!(self.shape().rank() > 0, "split requires non-zero rank");
162 assert!(
163 dim < self.shape().rank(),
164 "split dim {} out of bounds for rank {}",
165 dim,
166 self.shape().rank()
167 );
168 assert!(split_size > 0, "split_size must be > 0");
169 let dim_size = self.shape().dims[dim];
170 if dim_size == 0 {
171 return vec![];
172 }
173
174 let mut sizes = Vec::new();
175 let mut remaining = dim_size;
176 while remaining > 0 {
177 let len = remaining.min(split_size);
178 sizes.push(len);
179 remaining -= len;
180 }
181 self.split_with_sizes(&sizes, dim)
182 }
183
184 /// Split tensor into chunks with explicit sizes along specified dimension
185 ///
186 /// Divides the tensor into multiple smaller tensors along the specified
187 /// dimension according to the provided size specifications. Each chunk
188 /// has the exact size specified in the `split_sizes` array, and the sum
189 /// of all sizes must equal the size of the specified dimension.
190 ///
191 /// This operation provides precise control over the size of each resulting
192 /// chunk, unlike `split()` which creates equal-sized chunks. The first
193 /// chunk is returned as a view when possible (zero-copy), while subsequent
194 /// chunks may require data copying for non-zero base offsets.
195 ///
196 /// # Arguments
197 ///
198 /// * `split_sizes` - Array specifying the size of each chunk along the dimension
199 /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
200 ///
201 /// # Returns
202 ///
203 /// A vector of tensors, each representing a chunk of the original tensor
204 /// with the specified size. The number of chunks equals the length of `split_sizes`.
205 ///
206 /// # Panics
207 ///
208 /// * If tensor rank is 0 (scalar tensors cannot be split)
209 /// * If `dim` is out of bounds for the tensor rank
210 /// * If sum of `split_sizes` does not equal the size of the specified dimension
211 ///
212 /// # Examples
213 ///
214 /// ```
215 /// use train_station::Tensor;
216 ///
217 /// // Split with explicit sizes
218 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
219 /// let parts = tensor.split_with_sizes(&[2, 3], 1);
220 /// assert_eq!(parts.len(), 2);
221 /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
222 /// assert_eq!(parts[1].shape().dims, vec![1, 3]);
223 /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
224 /// assert_eq!(parts[0].get(&[0, 1]), 2.0);
225 /// assert_eq!(parts[1].get(&[0, 0]), 3.0);
226 /// ```
227 ///
228 /// ```
229 /// use train_station::Tensor;
230 ///
231 /// // Split 2D tensor with different chunk sizes
232 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
233 /// let parts = tensor.split_with_sizes(&[1, 2], 1);
234 /// assert_eq!(parts.len(), 2);
235 /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
236 /// assert_eq!(parts[1].shape().dims, vec![2, 2]);
237 /// ```
238 ///
239 /// ```
240 /// use train_station::Tensor;
241 ///
242 /// // Split with gradient tracking
243 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
244 /// tensor.set_requires_grad(true);
245 ///
246 /// let parts = tensor.split_with_sizes(&[1, 1], 1);
247 /// assert_eq!(parts.len(), 2);
248 /// assert!(parts[0].requires_grad());
249 /// assert!(parts[1].requires_grad());
250 /// ```
251 ///
252 /// ```
253 /// use train_station::Tensor;
254 ///
255 /// // Split 1D tensor with explicit sizes
256 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
257 /// let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
258 /// assert_eq!(parts.len(), 3);
259 /// assert_eq!(parts[0].shape().dims, vec![2]);
260 /// assert_eq!(parts[1].shape().dims, vec![2]);
261 /// assert_eq!(parts[2].shape().dims, vec![2]);
262 /// ```
263 ///
264 /// # Performance
265 ///
266 /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
267 /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
268 /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
269 /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
270 ///
271 /// # Relationship to Other Operations
272 ///
273 /// This operation is related to other tensor transformations:
274 /// - `split()` - Simplified version with equal-sized chunks
275 /// - `cat()` - Inverse operation that concatenates tensors back together
276 /// - `chunk()` - Alternative splitting operation with different semantics
277 ///
278 /// # Memory Layout
279 ///
280 /// The first chunk maintains the same underlying data as a view when
281 /// the base offset is zero. Subsequent chunks may require data copying
282 /// to handle non-zero base offsets, ensuring proper memory layout.
283 /// Zero-sized chunks are handled by creating empty tensors with
284 /// appropriate shapes.
285 #[track_caller]
286 pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor> {
287 assert!(self.shape().rank() > 0, "split requires non-zero rank");
288 assert!(
289 dim < self.shape().rank(),
290 "split dim {} out of bounds for rank {}",
291 dim,
292 self.shape().rank()
293 );
294 let dim_size = self.shape().dims[dim];
295 let total: usize = split_sizes.iter().sum();
296 assert!(
297 total == dim_size,
298 "sum of split sizes {} must equal size {} of dim {}",
299 total,
300 dim_size,
301 dim
302 );
303
304 let mut outputs = Vec::with_capacity(split_sizes.len());
305 let mut start = 0usize;
306 for &len in split_sizes {
307 if len == 0 {
308 outputs.push(Tensor::zeros(
309 self.shape()
310 .dims
311 .iter()
312 .enumerate()
313 .map(|(i, &d)| if i == dim { 0 } else { d })
314 .collect(),
315 ));
316 continue;
317 }
318 // Build new dims/strides with updated length along `dim`
319 let mut new_dims = self.shape().dims.clone();
320 new_dims[dim] = len;
321 let new_strides = self.strides().to_vec();
322
323 let base_offset = start * self.stride(dim);
324
325 let mut piece: Tensor;
326 if base_offset == 0 {
327 // True view for the first chunk
328 let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
329 piece = self.create_view_with_shape(view_shape);
330 } else {
331 // Materialize contiguous copy for non-zero base offset
332 piece = Tensor::new(new_dims.clone());
333 let rank = new_dims.len();
334 let numel = piece.size();
335 let mut coords = vec![0usize; rank];
336 for lin in 0..numel {
337 let mut tmp = lin;
338 for i in (0..rank).rev() {
339 let s = new_dims[i];
340 coords[i] = if s == 0 { 0 } else { tmp % s };
341 if s != 0 {
342 tmp /= s;
343 }
344 }
345 // Map to source coords
346 let mut src_coords = coords.clone();
347 src_coords[dim] = start + coords[dim];
348 let src_off = self.shape().offset(&src_coords);
349 unsafe {
350 *piece.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
351 }
352 }
353 }
354
355 // GradTrack: register backward to scatter this piece's grad into original input range
356 if self.requires_grad() {
357 piece.set_requires_grad_internal(true);
358 let grad_fn = GradFn::Split {
359 dim,
360 start,
361 length: len,
362 input_shape: self.shape().dims.clone(),
363 };
364 piece.set_grad_fn(grad_fn.clone());
365 GradEngine::register_operation(piece.id(), vec![self.id()], grad_fn);
366 }
367
368 outputs.push(piece);
369 start += len;
370 }
371
372 outputs
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_split_equal_forward() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let tensor = Tensor::from_slice(&values, vec![2, 6]).unwrap();
        let chunks = tensor.split(2, 1);
        assert_eq!(chunks.len(), 3);
        // Every chunk divides the 6-wide dimension evenly into width 2.
        for chunk in &chunks {
            assert_eq!(chunk.shape().dims, vec![2, 2]);
        }
        // Spot-check elements across chunk boundaries
        assert_eq!(chunks[0].get(&[0, 0]), 0.0);
        assert_eq!(chunks[1].get(&[0, 0]), 2.0);
        assert_eq!(chunks[2].get(&[1, 1]), 11.0);
    }

    #[test]
    fn test_split_with_sizes_forward() {
        let values: Vec<f32> = (0..15).map(|v| (v as f32) * 0.1).collect();
        let tensor = Tensor::from_slice(&values, vec![3, 5]).unwrap();
        let chunks = tensor.split_with_sizes(&[2, 1, 2], 1);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape().dims, vec![3, 2]);
        assert_eq!(chunks[1].shape().dims, vec![3, 1]);
        assert_eq!(chunks[2].shape().dims, vec![3, 2]);
        // Middle chunk holds source column 2; [2, 0] maps to flat index 2*5+2 = 12.
        assert_eq!(chunks[1].get(&[2, 0]), (2 * 5 + 2) as f32 * 0.1);
    }

    #[test]
    fn test_split_gradients_scatter() {
        let values: Vec<f32> = (0..10).map(|v| (v as f32) * 0.5 - 1.0).collect();
        let x = Tensor::from_slice(&values, vec![2, 5])
            .unwrap()
            .with_requires_grad();
        let chunks = x.split_with_sizes(&[2, 3], 1);
        // Stitch the pieces back together, then backprop with implicit ones.
        let mut rebuilt = Tensor::cat(&chunks, 1);
        rebuilt.backward(None);
        let grad = x.grad_by_value().expect("grad missing");
        // Every source position should receive gradient 1.0
        for row in 0..2 {
            for col in 0..5 {
                assert_eq!(grad.get(&[row, col]), 1.0);
            }
        }
    }

    #[test]
    fn test_split_1d_three_parts_grad() {
        let values: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let x = Tensor::from_slice(&values, vec![6])
            .unwrap()
            .with_requires_grad();
        let chunks = x.split_with_sizes(&[2, 2, 2], 0);
        // Concatenate before backward to avoid view/contiguity mismatches
        let mut rebuilt = Tensor::cat(&chunks, 0);
        rebuilt.backward(None);
        let grad = x.grad_by_value().expect("grad missing");
        for idx in 0..6 {
            assert_eq!(grad.get(&[idx]), 1.0);
        }
    }
}
439}