train_station/tensor/transform/split.rs
1//! Tensor splitting operations
2//!
3//! This module provides tensor splitting functionality that divides a tensor
4//! into multiple smaller tensors along a specified dimension. Splitting is a
5//! fundamental tensor transformation operation used in machine learning for
6//! dividing data into batches, creating multiple outputs from a single tensor,
7//! and implementing complex tensor manipulations.
8//!
9//! # Operations
10//!
11//! * `split()` - Split tensor into chunks of equal size along a dimension
12//! * `split_with_sizes()` - Split tensor into chunks with explicit sizes along a dimension
13//!
14//! # Performance Characteristics
15//!
16//! * **View Operations**: First chunk returns a view when possible (zero-copy)
17//! * **Copy Operations**: Subsequent chunks require data copying for non-zero offsets
18//! * **Memory Efficient**: Minimizes memory allocation through view reuse
19//! * **Gradient Tracking**: Full GradTrack support for automatic differentiation
20//! * **Shape Transformation**: Divides tensor along specified dimension while preserving other dimensions
21//!
22//! # Examples
23//!
24//! ```
25//! use train_station::Tensor;
26//!
27//! // Split tensor into equal-sized chunks
28//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
29//! let parts = tensor.split(1, 1);
30//! assert_eq!(parts.len(), 3);
31//! assert_eq!(parts[0].shape().dims, vec![2, 1]);
32//! assert_eq!(parts[1].shape().dims, vec![2, 1]);
33//! assert_eq!(parts[2].shape().dims, vec![2, 1]);
34//! ```
35//!
36//! ```
37//! use train_station::Tensor;
38//!
39//! // Split tensor with explicit sizes
40//! let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
41//! let parts = tensor.split_with_sizes(&[2, 3], 1);
42//! assert_eq!(parts.len(), 2);
43//! assert_eq!(parts[0].shape().dims, vec![1, 2]);
44//! assert_eq!(parts[1].shape().dims, vec![1, 3]);
45//! ```
46//!
47//! # Gradient Tracking
48//!
49//! The split operations support automatic gradient tracking through
50//! the GradTrack system. When `requires_grad` is enabled, each split
51//! piece registers a gradient function that scatters gradients back
52//! to the original tensor during backward passes.
53
54use crate::gradtrack::{GradEngine, GradFn};
55use crate::tensor::core::Tensor;
56
57impl Tensor {
58 /// Split tensor into chunks of equal size along specified dimension
59 ///
60 /// Divides the tensor into multiple smaller tensors along the specified
61 /// dimension, where each chunk (except possibly the last) has the same size.
62 /// The last chunk may be smaller if the dimension size is not evenly
63 /// divisible by the split size.
64 ///
65 /// This operation returns a vector of tensors, where each tensor is a
66 /// view or copy of a portion of the original tensor. The first chunk
67 /// is returned as a view when possible (zero-copy), while subsequent
68 /// chunks may require data copying for non-zero base offsets.
69 ///
70 /// # Arguments
71 ///
72 /// * `split_size` - Size of each chunk along the specified dimension (must be > 0)
73 /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
74 ///
75 /// # Returns
76 ///
77 /// A vector of tensors, each representing a chunk of the original tensor.
78 /// The number of chunks depends on the dimension size and split size.
79 ///
80 /// # Panics
81 ///
82 /// * If tensor rank is 0 (scalar tensors cannot be split)
83 /// * If `dim` is out of bounds for the tensor rank
84 /// * If `split_size` is 0
85 ///
86 /// # Examples
87 ///
88 /// ```
89 /// use train_station::Tensor;
90 ///
91 /// // Split 2D tensor into equal chunks along dimension 1
92 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
93 /// let parts = tensor.split(1, 1);
94 /// assert_eq!(parts.len(), 3);
95 /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
96 /// assert_eq!(parts[1].shape().dims, vec![2, 1]);
97 /// assert_eq!(parts[2].shape().dims, vec![2, 1]);
98 /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
99 /// assert_eq!(parts[1].get(&[0, 0]), 2.0);
100 /// assert_eq!(parts[2].get(&[1, 0]), 6.0);
101 /// ```
102 ///
103 /// ```
104 /// use train_station::Tensor;
105 ///
106 /// // Split with uneven division (last chunk smaller)
107 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
108 /// let parts = tensor.split(2, 1);
109 /// assert_eq!(parts.len(), 3);
110 /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
111 /// assert_eq!(parts[1].shape().dims, vec![1, 2]);
112 /// assert_eq!(parts[2].shape().dims, vec![1, 1]); // Last chunk smaller
113 /// ```
114 ///
115 /// ```
116 /// use train_station::Tensor;
117 ///
118 /// // Split with gradient tracking
119 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
120 /// tensor.set_requires_grad(true);
121 ///
122 /// let parts = tensor.split(1, 1);
123 /// assert_eq!(parts.len(), 2);
124 /// assert!(parts[0].requires_grad());
125 /// assert!(parts[1].requires_grad());
126 /// ```
127 ///
128 /// ```
129 /// use train_station::Tensor;
130 ///
131 /// // Split 1D tensor
132 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
133 /// let parts = tensor.split(2, 0);
134 /// assert_eq!(parts.len(), 3);
135 /// assert_eq!(parts[0].shape().dims, vec![2]);
136 /// assert_eq!(parts[1].shape().dims, vec![2]);
137 /// assert_eq!(parts[2].shape().dims, vec![2]);
138 /// ```
139 ///
140 /// # Performance
141 ///
142 /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
143 /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
144 /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
145 /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
146 ///
147 /// # Relationship to Other Operations
148 ///
149 /// This operation is related to other tensor transformations:
150 /// - `split_with_sizes()` - More general version with explicit chunk sizes
151 /// - `cat()` - Inverse operation that concatenates tensors back together
152 /// - `chunk()` - Alternative splitting operation with different semantics
153 ///
154 /// # Memory Layout
155 ///
156 /// The first chunk maintains the same underlying data as a view when
157 /// the base offset is zero. Subsequent chunks may require data copying
158 /// to handle non-zero base offsets, ensuring proper memory layout.
159 pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor> {
160 assert!(self.shape().rank() > 0, "split requires non-zero rank");
161 assert!(
162 dim < self.shape().rank(),
163 "split dim {} out of bounds for rank {}",
164 dim,
165 self.shape().rank()
166 );
167 assert!(split_size > 0, "split_size must be > 0");
168 let dim_size = self.shape().dims[dim];
169 if dim_size == 0 {
170 return vec![];
171 }
172
173 let mut sizes = Vec::new();
174 let mut remaining = dim_size;
175 while remaining > 0 {
176 let len = remaining.min(split_size);
177 sizes.push(len);
178 remaining -= len;
179 }
180 self.split_with_sizes(&sizes, dim)
181 }
182
183 /// Split tensor into chunks with explicit sizes along specified dimension
184 ///
185 /// Divides the tensor into multiple smaller tensors along the specified
186 /// dimension according to the provided size specifications. Each chunk
187 /// has the exact size specified in the `split_sizes` array, and the sum
188 /// of all sizes must equal the size of the specified dimension.
189 ///
190 /// This operation provides precise control over the size of each resulting
191 /// chunk, unlike `split()` which creates equal-sized chunks. The first
192 /// chunk is returned as a view when possible (zero-copy), while subsequent
193 /// chunks may require data copying for non-zero base offsets.
194 ///
195 /// # Arguments
196 ///
197 /// * `split_sizes` - Array specifying the size of each chunk along the dimension
198 /// * `dim` - Dimension along which to split the tensor (must be < tensor rank)
199 ///
200 /// # Returns
201 ///
202 /// A vector of tensors, each representing a chunk of the original tensor
203 /// with the specified size. The number of chunks equals the length of `split_sizes`.
204 ///
205 /// # Panics
206 ///
207 /// * If tensor rank is 0 (scalar tensors cannot be split)
208 /// * If `dim` is out of bounds for the tensor rank
209 /// * If sum of `split_sizes` does not equal the size of the specified dimension
210 ///
211 /// # Examples
212 ///
213 /// ```
214 /// use train_station::Tensor;
215 ///
216 /// // Split with explicit sizes
217 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
218 /// let parts = tensor.split_with_sizes(&[2, 3], 1);
219 /// assert_eq!(parts.len(), 2);
220 /// assert_eq!(parts[0].shape().dims, vec![1, 2]);
221 /// assert_eq!(parts[1].shape().dims, vec![1, 3]);
222 /// assert_eq!(parts[0].get(&[0, 0]), 1.0);
223 /// assert_eq!(parts[0].get(&[0, 1]), 2.0);
224 /// assert_eq!(parts[1].get(&[0, 0]), 3.0);
225 /// ```
226 ///
227 /// ```
228 /// use train_station::Tensor;
229 ///
230 /// // Split 2D tensor with different chunk sizes
231 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
232 /// let parts = tensor.split_with_sizes(&[1, 2], 1);
233 /// assert_eq!(parts.len(), 2);
234 /// assert_eq!(parts[0].shape().dims, vec![2, 1]);
235 /// assert_eq!(parts[1].shape().dims, vec![2, 2]);
236 /// ```
237 ///
238 /// ```
239 /// use train_station::Tensor;
240 ///
241 /// // Split with gradient tracking
242 /// let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
243 /// tensor.set_requires_grad(true);
244 ///
245 /// let parts = tensor.split_with_sizes(&[1, 1], 1);
246 /// assert_eq!(parts.len(), 2);
247 /// assert!(parts[0].requires_grad());
248 /// assert!(parts[1].requires_grad());
249 /// ```
250 ///
251 /// ```
252 /// use train_station::Tensor;
253 ///
254 /// // Split 1D tensor with explicit sizes
255 /// let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
256 /// let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
257 /// assert_eq!(parts.len(), 3);
258 /// assert_eq!(parts[0].shape().dims, vec![2]);
259 /// assert_eq!(parts[1].shape().dims, vec![2]);
260 /// assert_eq!(parts[2].shape().dims, vec![2]);
261 /// ```
262 ///
263 /// # Performance
264 ///
265 /// - **First Chunk**: O(1) - Returns a view when possible (zero-copy)
266 /// - **Subsequent Chunks**: O(n) - May require data copying for non-zero offsets
267 /// - **Memory Usage**: Minimal allocation for view operations, copying for non-zero offsets
268 /// - **Gradient Tracking**: Each chunk preserves gradient requirements and tracking
269 ///
270 /// # Relationship to Other Operations
271 ///
272 /// This operation is related to other tensor transformations:
273 /// - `split()` - Simplified version with equal-sized chunks
274 /// - `cat()` - Inverse operation that concatenates tensors back together
275 /// - `chunk()` - Alternative splitting operation with different semantics
276 ///
277 /// # Memory Layout
278 ///
279 /// The first chunk maintains the same underlying data as a view when
280 /// the base offset is zero. Subsequent chunks may require data copying
281 /// to handle non-zero base offsets, ensuring proper memory layout.
282 /// Zero-sized chunks are handled by creating empty tensors with
283 /// appropriate shapes.
284 pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor> {
285 assert!(self.shape().rank() > 0, "split requires non-zero rank");
286 assert!(
287 dim < self.shape().rank(),
288 "split dim {} out of bounds for rank {}",
289 dim,
290 self.shape().rank()
291 );
292 let dim_size = self.shape().dims[dim];
293 let total: usize = split_sizes.iter().sum();
294 assert!(
295 total == dim_size,
296 "sum of split sizes {} must equal size {} of dim {}",
297 total,
298 dim_size,
299 dim
300 );
301
302 let mut outputs = Vec::with_capacity(split_sizes.len());
303 let mut start = 0usize;
304 for &len in split_sizes {
305 if len == 0 {
306 outputs.push(Tensor::zeros(
307 self.shape()
308 .dims
309 .iter()
310 .enumerate()
311 .map(|(i, &d)| if i == dim { 0 } else { d })
312 .collect(),
313 ));
314 continue;
315 }
316 // Build new dims/strides with updated length along `dim`
317 let mut new_dims = self.shape().dims.clone();
318 new_dims[dim] = len;
319 let new_strides = self.strides().to_vec();
320
321 let base_offset = start * self.stride(dim);
322
323 let mut piece: Tensor;
324 if base_offset == 0 {
325 // True view for the first chunk
326 let view_shape = crate::tensor::Shape::as_view(new_dims, new_strides);
327 piece = self.create_view_with_shape(view_shape);
328 } else {
329 // Materialize contiguous copy for non-zero base offset
330 piece = Tensor::new(new_dims.clone());
331 let rank = new_dims.len();
332 let numel = piece.size();
333 let mut coords = vec![0usize; rank];
334 for lin in 0..numel {
335 let mut tmp = lin;
336 for i in (0..rank).rev() {
337 let s = new_dims[i];
338 coords[i] = if s == 0 { 0 } else { tmp % s };
339 if s != 0 {
340 tmp /= s;
341 }
342 }
343 // Map to source coords
344 let mut src_coords = coords.clone();
345 src_coords[dim] = start + coords[dim];
346 let src_off = self.shape().offset(&src_coords);
347 unsafe {
348 *piece.as_mut_ptr().add(lin) = *self.as_ptr().add(src_off);
349 }
350 }
351 }
352
353 // GradTrack: register backward to scatter this piece's grad into original input range
354 if self.requires_grad() {
355 piece.set_requires_grad_internal(true);
356 let grad_fn = GradFn::Split {
357 dim,
358 start,
359 length: len,
360 input_shape: self.shape().dims.clone(),
361 };
362 piece.set_grad_fn(grad_fn.clone());
363 GradEngine::register_operation(piece.id(), vec![self.id()], grad_fn);
364 }
365
366 outputs.push(piece);
367 start += len;
368 }
369
370 outputs
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Even split of a [2, 6] tensor along dim 1 yields three [2, 2] chunks
    /// holding the expected source values.
    #[test]
    fn test_split_equal_forward() {
        let values: Vec<f32> = (0..12).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![2, 6]).unwrap();
        let chunks = input.split(2, 1);
        assert_eq!(chunks.len(), 3);
        for chunk in &chunks {
            assert_eq!(chunk.shape().dims, vec![2, 2]);
        }
        // Spot-check one value per chunk
        assert_eq!(chunks[0].get(&[0, 0]), 0.0);
        assert_eq!(chunks[1].get(&[0, 0]), 2.0);
        assert_eq!(chunks[2].get(&[1, 1]), 11.0);
    }

    /// Explicit sizes [2, 1, 2] along dim 1 of a [3, 5] tensor produce chunks
    /// of matching widths, and the middle chunk reads source column 2.
    #[test]
    fn test_split_with_sizes_forward() {
        let values: Vec<f32> = (0..15).map(|v| (v as f32) * 0.1).collect();
        let input = Tensor::from_slice(&values, vec![3, 5]).unwrap();
        let chunks = input.split_with_sizes(&[2, 1, 2], 1);
        assert_eq!(chunks.len(), 3);
        let expected_widths = [2usize, 1, 2];
        for (chunk, &width) in chunks.iter().zip(expected_widths.iter()) {
            assert_eq!(chunk.shape().dims, vec![3, width]);
        }
        // Row 2, column 0 of the middle chunk is source element (2, 2)
        let expected = (2 * 5 + 2) as f32 * 0.1;
        assert_eq!(chunks[1].get(&[2, 0]), expected);
    }

    /// Gradients flowing back through split + cat reach every input element.
    #[test]
    fn test_split_gradients_scatter() {
        let values: Vec<f32> = (0..10).map(|v| (v as f32) * 0.5 - 1.0).collect();
        let input = Tensor::from_slice(&values, vec![2, 5])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 3], 1);
        // Re-join the chunks, then run backward with the implicit all-ones gradient
        let mut joined = Tensor::cat(&chunks, 1);
        joined.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        // Every position of the input must have received exactly 1.0
        for flat in 0..input.size() {
            let (row, col) = (flat / 5, flat % 5);
            assert_eq!(grad.get(&[row, col]), 1.0);
        }
    }

    /// Same gradient check for a 1D tensor split into three chunks.
    #[test]
    fn test_split_1d_three_parts_grad() {
        let values: Vec<f32> = (0..6).map(|v| v as f32).collect();
        let input = Tensor::from_slice(&values, vec![6])
            .unwrap()
            .with_requires_grad();
        let chunks = input.split_with_sizes(&[2, 2, 2], 0);
        // Backward through the concatenation to avoid view/contig mismatches
        let mut joined = Tensor::cat(&chunks, 0);
        joined.backward(None);
        let grad = input.grad_by_value().expect("grad missing");
        for idx in 0..6 {
            assert_eq!(grad.get(&[idx]), 1.0);
        }
    }
}