// torsh_functional/manipulation/splitting.rs
1//! Tensor splitting operations
2//!
3//! This module provides comprehensive tensor splitting functionality for dividing tensors
4//! into multiple sub-tensors along specified dimensions. These operations are essential for
5//! data partitioning, parallel processing, and implementing neural network architectures
6//! that require tensor decomposition.
7
8use torsh_core::Result as TorshResult;
9use torsh_tensor::Tensor;
10
11/// Split tensor into multiple sub-tensors
12///
13/// ## Mathematical Background
14///
15/// Tensor splitting partitions a tensor A ∈ ℝ^(d₁×d₂×...×dₙ) along dimension k into
16/// multiple sub-tensors {B₁, B₂, ..., Bₘ} such that:
17///
18/// ```text
19/// concatenate([B₁, B₂, ..., Bₘ], dim=k) = A
20/// ```text
21///
22/// ## Splitting Modes
23///
24/// ### Fixed Size Splitting
25/// For split size s, creates ⌈dₖ/s⌉ sub-tensors where:
26/// - First ⌊dₖ/s⌋ tensors have size s along dimension k
27/// - Last tensor has size dₖ mod s (if non-zero)
28///
29/// ### Section-Based Splitting
30/// For n sections, creates n sub-tensors with sizes:
31/// - Base size: b = ⌊dₖ/n⌋
32/// - First r = dₖ mod n tensors: size b+1
33/// - Remaining n-r tensors: size b
34///
35/// ### Index-Based Splitting
36/// For indices [i₁, i₂, ..., iₘ], creates m+1 sub-tensors:
37/// - A[:i₁], A[i₁:i₂], ..., A[iₘ:]
38///
39/// ## Parameters
40/// * `tensor` - Input tensor to split
41/// * `split_size_or_sections` - Splitting specification (size, sections, or indices)
42/// * `dim` - Dimension along which to split (negative indexing supported)
43///
44/// ## Returns
45/// * Vector of sub-tensors resulting from the split operation
46///
47/// ## Applications
48/// - **Data batching**: Split large datasets into smaller batches
49/// - **Parallel processing**: Distribute tensor chunks across workers
50/// - **Memory management**: Process large tensors in smaller chunks
51/// - **Model parallelism**: Split layers across multiple devices
52///
53/// ## Examples
54/// ```rust
55/// # use torsh_functional::manipulation::split;
56/// # use torsh_functional::manipulation::SplitArg;
57/// # use torsh_tensor::creation::ones;
58/// let tensor = ones(&[12, 4])?;
59///
60/// // Split into chunks of size 3
61/// let chunks = split(&tensor, SplitArg::Size(3), 0)?; // 4 chunks of [3,4]
62///
63/// // Split into 4 equal sections
64/// let sections = split(&tensor, SplitArg::Sections(4), 0)?; // 4 chunks of [3,4]
65///
66/// // Split at specific indices
67/// let splits = split(&tensor, SplitArg::Indices(vec![3, 8]), 0)?; // [3,4], [5,4], [4,4]
68/// # Ok::<(), Box<dyn std::error::Error>>(())
69/// ```text
70pub fn split(
71 tensor: &Tensor,
72 split_size_or_sections: SplitArg,
73 dim: isize,
74) -> TorshResult<Vec<Tensor>> {
75 let shape = tensor.shape();
76 let ndim = shape.ndim() as isize;
77
78 // Normalize dimension
79 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
80
81 if dim >= shape.ndim() {
82 return Err(torsh_core::TorshError::invalid_argument_with_context(
83 &format!(
84 "Dimension {} out of range for tensor with {} dimensions",
85 dim,
86 shape.ndim()
87 ),
88 "split",
89 ));
90 }
91
92 match split_size_or_sections {
93 SplitArg::Size(size) => {
94 // Split into chunks of given size
95 let dim_size = shape.dims()[dim];
96 let num_splits = dim_size.div_ceil(size);
97
98 let mut splits = Vec::new();
99 for i in 0..num_splits {
100 let start = i * size;
101 let end = ((i + 1) * size).min(dim_size);
102
103 // Create slice for this split
104 let split = tensor.slice(dim as usize, start, end)?.to_tensor()?;
105 splits.push(split);
106 }
107
108 Ok(splits)
109 }
110 SplitArg::Sections(sections) => {
111 // Split into given number of sections
112 let dim_size = shape.dims()[dim];
113 let base_size = dim_size / sections;
114 let remainder = dim_size % sections;
115
116 let mut splits = Vec::new();
117 let mut offset = 0;
118
119 for i in 0..sections {
120 let size = if i < remainder {
121 base_size + 1
122 } else {
123 base_size
124 };
125
126 // Create slice for this split
127 let split = tensor
128 .slice(dim as usize, offset, offset + size)?
129 .to_tensor()?;
130 splits.push(split);
131
132 offset += size;
133 }
134
135 Ok(splits)
136 }
137 SplitArg::Indices(indices) => {
138 // Split at the specified indices
139 let mut splits = Vec::new();
140 let mut start = 0;
141
142 for &index in &indices {
143 let split = tensor.slice(dim as usize, start, index)?.to_tensor()?;
144 splits.push(split);
145 start = index;
146 }
147
148 // Add final split from last index to end
149 let dim_size = shape.dims()[dim];
150 if start < dim_size {
151 let split = tensor.slice(dim as usize, start, dim_size)?.to_tensor()?;
152 splits.push(split);
153 }
154
155 Ok(splits)
156 }
157 }
158}
159
/// Split argument specification for tensor splitting operations
///
/// ## Variants
///
/// ### Size
/// Split into chunks of specified size along the dimension.
/// The last chunk may be smaller if the dimension size is not evenly divisible.
///
/// ### Sections
/// Split into the specified number of approximately equal sections.
/// If not evenly divisible, earlier sections will be one element larger.
///
/// ### Indices
/// Split at the specified indices, creating len(indices)+1 sub-tensors.
/// Indices must be sorted and within bounds of the dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SplitArg {
    /// Split into chunks of fixed size
    Size(usize),
    /// Split into specified number of sections
    Sections(usize),
    /// Split at specified indices
    Indices(Vec<usize>),
}
184
185/// Split tensor into approximately equal chunks
186///
187/// ## Mathematical Background
188///
189/// Chunks a tensor into approximately equal pieces along the specified dimension.
190/// For tensor with dimension size d and n chunks:
191///
192/// ```text
193/// chunk_size = ⌈d/n⌉
194/// num_full_chunks = n - (n * chunk_size - d)
195/// ```text
196///
197/// The first `num_full_chunks` will have size `chunk_size`, and remaining chunks
198/// will have size `chunk_size - 1`.
199///
200/// ## Parameters
201/// * `tensor` - Input tensor to chunk
202/// * `chunks` - Number of chunks to create
203/// * `dim` - Dimension along which to chunk (negative indexing supported)
204///
205/// ## Returns
206/// * Vector of approximately equal-sized tensor chunks
207///
208/// ## Examples
209/// ```rust
210/// # use torsh_functional::manipulation::chunk;
211/// # use torsh_tensor::creation::ones;
212/// let tensor = ones(&[10, 3])?; // 10 rows, 3 columns
213///
214/// // Split into 3 chunks along first dimension
215/// let chunks = chunk(&tensor, 3, 0)?; // [4,3], [3,3], [3,3]
216///
217/// // Split into 4 chunks along second dimension
218/// let chunks = chunk(&tensor, 4, 1)?; // [10,1], [10,1], [10,1], [10,0] (empty)
219/// # Ok::<(), Box<dyn std::error::Error>>(())
220/// ```text
221pub fn chunk(tensor: &Tensor, chunks: usize, dim: isize) -> TorshResult<Vec<Tensor>> {
222 let shape = tensor.shape();
223 let ndim = shape.ndim() as isize;
224
225 // Normalize dimension
226 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
227
228 if dim >= shape.ndim() {
229 return Err(torsh_core::TorshError::invalid_argument_with_context(
230 &format!(
231 "Dimension {} out of range for tensor with {} dimensions",
232 dim,
233 shape.ndim()
234 ),
235 "chunk",
236 ));
237 }
238
239 split(tensor, SplitArg::Sections(chunks), dim as isize)
240}
241
242/// Split tensor into sections at specified indices
243///
244/// ## Mathematical Background
245///
246/// Performs tensor splitting at explicitly specified indices along a dimension.
247/// For tensor A with dimension size d and indices [i₁, i₂, ..., iₘ]:
248///
249/// ```text
250/// Result = [A[..., :i₁, ...], A[..., i₁:i₂, ...], ..., A[..., iₘ:, ...]]
251/// ```text
252///
253/// Where the ellipsis represents all other dimensions.
254///
255/// ## Index Validation
256/// - All indices must be within bounds: 0 ≤ iⱼ ≤ d
257/// - Indices should be in ascending order for meaningful results
258/// - Empty sections are allowed (consecutive identical indices)
259///
260/// ## Parameters
261/// * `tensor` - Input tensor to split
262/// * `indices_or_sections` - Either number of sections or explicit indices
263/// * `dim` - Dimension along which to split (negative indexing supported)
264///
265/// ## Returns
266/// * Vector of tensor sections
267///
268/// ## Applications
269/// - **Sequence processing**: Split variable-length sequences at boundaries
270/// - **Data preprocessing**: Extract regions of interest from images/signals
271/// - **Batch processing**: Create non-uniform batches based on data characteristics
272///
273/// ## Examples
274/// ```rust
275/// # use torsh_functional::manipulation::{tensor_split, TensorSplitArg};
276/// # use torsh_tensor::creation::ones;
277/// let tensor = ones(&[8, 4])?;
278///
279/// // Split into 3 sections
280/// let sections = tensor_split(&tensor, TensorSplitArg::Sections(3), 0)?;
281/// // Results: [3,4], [3,4], [2,4]
282///
283/// // Split at specific indices
284/// let splits = tensor_split(&tensor, TensorSplitArg::Indices(vec![2, 5]), 0)?;
285/// // Results: [2,4], [3,4], [3,4]
286/// # Ok::<(), Box<dyn std::error::Error>>(())
287/// ```text
288pub fn tensor_split(
289 tensor: &Tensor,
290 indices_or_sections: TensorSplitArg,
291 dim: isize,
292) -> TorshResult<Vec<Tensor>> {
293 let shape = tensor.shape();
294 let ndim = shape.ndim() as isize;
295
296 // Normalize dimension
297 let dim = if dim < 0 { ndim + dim } else { dim } as usize;
298
299 if dim >= shape.ndim() {
300 return Err(torsh_core::TorshError::invalid_argument_with_context(
301 &format!(
302 "Dimension {} out of range for tensor with {} dimensions",
303 dim,
304 shape.ndim()
305 ),
306 "tensor_split",
307 ));
308 }
309
310 match indices_or_sections {
311 TensorSplitArg::Sections(sections) => {
312 split(tensor, SplitArg::Sections(sections), dim as isize)
313 }
314 TensorSplitArg::Indices(indices) => {
315 let dim_size = shape.dims()[dim];
316 let mut splits = Vec::new();
317 let mut prev_idx = 0;
318
319 for &idx in &indices {
320 if idx > dim_size {
321 return Err(torsh_core::TorshError::invalid_argument_with_context(
322 &format!(
323 "Split index {} out of range for dimension size {}",
324 idx, dim_size
325 ),
326 "tensor_split",
327 ));
328 }
329
330 if idx > prev_idx {
331 let split = tensor.slice(dim as usize, prev_idx, idx)?.to_tensor()?;
332 splits.push(split);
333 }
334 prev_idx = idx;
335 }
336
337 // Add final split if needed
338 if prev_idx < dim_size {
339 let split = tensor
340 .slice(dim as usize, prev_idx, dim_size)?
341 .to_tensor()?;
342 splits.push(split);
343 }
344
345 Ok(splits)
346 }
347 }
348}
349
/// Split argument specification for tensor_split operations
///
/// ## Variants
///
/// ### Sections
/// Split into the specified number of approximately equal sections.
///
/// ### Indices
/// Split at the specified indices. The indices define the boundaries
/// where the tensor should be divided.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSplitArg {
    /// Split into specified number of sections
    Sections(usize),
    /// Split at specified indices
    Indices(Vec<usize>),
}
367
368/// Split tensor horizontally (along second dimension)
369///
370/// ## Mathematical Background
371///
372/// Horizontal splitting divides a tensor along its second dimension (columns for 2D matrices).
373/// For tensor A ∈ ℝ^(m×n×...), hsplit creates sub-tensors along dimension 1:
374///
375/// ```text
376/// A = [A₁ | A₂ | ... | Aₖ] (column-wise concatenation)
377/// ```text
378///
379/// ## Requirements
380/// - Input tensor must have at least 2 dimensions
381/// - Equivalent to `tensor_split(tensor, indices_or_sections, 1)`
382///
383/// ## Parameters
384/// * `tensor` - Input tensor (≥2D required)
385/// * `indices_or_sections` - Split specification (sections or indices)
386///
387/// ## Returns
388/// * Vector of horizontally split tensors
389///
390/// ## Applications
391/// - **Image processing**: Split images into vertical strips
392/// - **Feature extraction**: Separate different feature groups in matrices
393/// - **Data analysis**: Split datasets by column groups
394///
395/// ## Examples
396/// ```rust
397/// # use torsh_functional::manipulation::{hsplit, TensorSplitArg};
398/// # use torsh_tensor::creation::ones;
399/// let image = ones(&[100, 200, 3])?; // Height × Width × Channels
400///
401/// // Split into 4 vertical strips
402/// let strips = hsplit(&image, TensorSplitArg::Sections(4))?;
403/// // Each strip: [100, 50, 3]
404///
405/// // Split at specific column positions
406/// let splits = hsplit(&image, TensorSplitArg::Indices(vec![50, 150]))?;
407/// // Results: [100,50,3], [100,100,3], [100,50,3]
408/// # Ok::<(), Box<dyn std::error::Error>>(())
409/// ```text
410pub fn hsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
411 let shape = tensor.shape();
412 if shape.ndim() < 2 {
413 return Err(torsh_core::TorshError::invalid_argument_with_context(
414 "Input tensor must have at least 2 dimensions for hsplit",
415 "hsplit",
416 ));
417 }
418
419 tensor_split(tensor, indices_or_sections, 1)
420}
421
422/// Split tensor vertically (along first dimension)
423///
424/// ## Mathematical Background
425///
426/// Vertical splitting divides a tensor along its first dimension (rows for 2D matrices).
427/// For tensor A ∈ ℝ^(m×n×...), vsplit creates sub-tensors along dimension 0:
428///
429/// ```text
430/// A = [A₁; A₂; ...; Aₖ] (row-wise concatenation)
431/// ```text
432///
433/// ## Requirements
434/// - Input tensor must have at least 2 dimensions
435/// - Equivalent to `tensor_split(tensor, indices_or_sections, 0)`
436///
437/// ## Parameters
438/// * `tensor` - Input tensor (≥2D required)
439/// * `indices_or_sections` - Split specification (sections or indices)
440///
441/// ## Returns
442/// * Vector of vertically split tensors
443///
444/// ## Applications
445/// - **Image processing**: Split images into horizontal strips
446/// - **Batch processing**: Divide batches into smaller sub-batches
447/// - **Time series**: Split sequences into temporal segments
448///
449/// ## Examples
450/// ```rust
451/// # use torsh_functional::manipulation::{vsplit, TensorSplitArg};
452/// # use torsh_tensor::creation::ones;
453/// let batch = ones(&[64, 784])?; // Batch size × Features
454///
455/// // Split into 4 mini-batches
456/// let mini_batches = vsplit(&batch, TensorSplitArg::Sections(4))?;
457/// // Each mini-batch: [16, 784]
458///
459/// // Split at specific row positions
460/// let splits = vsplit(&batch, TensorSplitArg::Indices(vec![16, 48]))?;
461/// // Results: [16,784], [32,784], [16,784]
462/// # Ok::<(), Box<dyn std::error::Error>>(())
463/// ```text
464pub fn vsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
465 let shape = tensor.shape();
466 if shape.ndim() < 2 {
467 return Err(torsh_core::TorshError::invalid_argument_with_context(
468 "Input tensor must have at least 2 dimensions for vsplit",
469 "vsplit",
470 ));
471 }
472
473 tensor_split(tensor, indices_or_sections, 0)
474}
475
476/// Split tensor along depth dimension (third dimension)
477///
478/// ## Mathematical Background
479///
480/// Depth splitting divides a tensor along its third dimension (depth for 3D tensors).
481/// For tensor A ∈ ℝ^(m×n×d×...), dsplit creates sub-tensors along dimension 2:
482///
483/// ```text
484/// A[:,:,k₁:k₂,:] for each split k₁:k₂
485/// ```text
486///
487/// ## Requirements
488/// - Input tensor must have at least 3 dimensions
489/// - Equivalent to `tensor_split(tensor, indices_or_sections, 2)`
490///
491/// ## Parameters
492/// * `tensor` - Input tensor (≥3D required)
493/// * `indices_or_sections` - Split specification (sections or indices)
494///
495/// ## Returns
496/// * Vector of depth-wise split tensors
497///
498/// ## Applications
499/// - **3D data processing**: Split volumetric data along depth
500/// - **Video analysis**: Split video frames into temporal chunks
501/// - **Multi-channel data**: Separate different channels or modalities
502/// - **Neural networks**: Split feature maps along channel dimension
503///
504/// ## Examples
505/// ```rust
506/// # use torsh_functional::manipulation::{dsplit, TensorSplitArg};
507/// # use torsh_tensor::creation::ones;
508/// let volume = ones(&[64, 64, 32])?; // Height × Width × Depth
509///
510/// // Split into 4 depth sections
511/// let sections = dsplit(&volume, TensorSplitArg::Sections(4))?;
512/// // Each section: [64, 64, 8]
513///
514/// // Split at specific depth positions
515/// let splits = dsplit(&volume, TensorSplitArg::Indices(vec![8, 24]))?;
516/// // Results: [64,64,8], [64,64,16], [64,64,8]
517/// # Ok::<(), Box<dyn std::error::Error>>(())
518/// ```text
519pub fn dsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
520 let shape = tensor.shape();
521 if shape.ndim() < 3 {
522 return Err(torsh_core::TorshError::invalid_argument_with_context(
523 "Input tensor must have at least 3 dimensions for dsplit",
524 "dsplit",
525 ));
526 }
527
528 tensor_split(tensor, indices_or_sections, 2)
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    /// Assert that every tensor in `pieces` has exactly the dims `want`.
    fn assert_all_dims(pieces: &[Tensor], want: &[usize]) {
        for piece in pieces {
            assert_eq!(piece.shape().dims(), want);
        }
    }

    #[test]
    fn test_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        // Six rows divide evenly into three [2, 4] sections.
        let pieces = split(&input, SplitArg::Sections(3), 0)?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        // Splitting at rows 2 and 4 also yields three [2, 4] pieces.
        let pieces = split(&input, SplitArg::Indices(vec![2, 4]), 0)?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        Ok(())
    }

    #[test]
    fn test_chunk() -> TorshResult<()> {
        let input = randn(&[8, 3], None, None, None)?;

        // Eight rows into three chunks -> sizes 3, 3, 2.
        let chunks = chunk(&input, 3, 0)?;
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape().dims(), &[3, 3]);
        assert_eq!(chunks[1].shape().dims(), &[3, 3]);
        assert_eq!(chunks[2].shape().dims(), &[2, 3]);

        // Three columns into two chunks -> sizes 2 (ceil(3/2)) and 1.
        let chunks = chunk(&input, 2, 1)?;
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].shape().dims(), &[8, 2]);
        assert_eq!(chunks[1].shape().dims(), &[8, 1]);

        Ok(())
    }

    #[test]
    fn test_tensor_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        // Section-based splitting matches `split` with `SplitArg::Sections`.
        let pieces = tensor_split(&input, TensorSplitArg::Sections(3), 0)?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        // Index-based splitting at rows 2 and 4.
        let pieces = tensor_split(&input, TensorSplitArg::Indices(vec![2, 4]), 0)?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        Ok(())
    }

    #[test]
    fn test_hsplit() -> TorshResult<()> {
        let input = randn(&[4, 6], None, None, None)?;

        // Column-wise sections.
        let pieces = hsplit(&input, TensorSplitArg::Sections(3))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[4, 2]);

        // Column-wise split at explicit positions.
        let pieces = hsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[4, 2]);

        Ok(())
    }

    #[test]
    fn test_vsplit() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        // Row-wise sections.
        let pieces = vsplit(&input, TensorSplitArg::Sections(3))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        // Row-wise split at explicit positions.
        let pieces = vsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 4]);

        Ok(())
    }

    #[test]
    fn test_dsplit() -> TorshResult<()> {
        let input = randn(&[2, 3, 6], None, None, None)?;

        // Depth-wise sections.
        let pieces = dsplit(&input, TensorSplitArg::Sections(3))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 3, 2]);

        // Depth-wise split at explicit positions.
        let pieces = dsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_eq!(pieces.len(), 3);
        assert_all_dims(&pieces, &[2, 3, 2]);

        Ok(())
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for hsplit")]
    fn test_hsplit_invalid_dimensions() {
        let one_dim = randn(&[5], None, None, None).unwrap(); // 1D input is rejected
        hsplit(&one_dim, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for vsplit")]
    fn test_vsplit_invalid_dimensions() {
        let one_dim = randn(&[5], None, None, None).unwrap(); // 1D input is rejected
        vsplit(&one_dim, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 3 dimensions for dsplit")]
    fn test_dsplit_invalid_dimensions() {
        let two_dim = randn(&[3, 4], None, None, None).unwrap(); // 2D input is rejected
        dsplit(&two_dim, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    fn test_split_size_mode() -> TorshResult<()> {
        let input = randn(&[10, 4], None, None, None)?;

        // Ten rows in chunks of three -> three full [3, 4] chunks plus a [1, 4] tail.
        let pieces = split(&input, SplitArg::Size(3), 0)?;
        assert_eq!(pieces.len(), 4);
        assert_all_dims(&pieces[..3], &[3, 4]);
        assert_eq!(pieces[3].shape().dims(), &[1, 4]);

        Ok(())
    }

    #[test]
    fn test_negative_dimension_indexing() -> TorshResult<()> {
        let input = randn(&[4, 6, 8], None, None, None)?;

        // dim = -1 must behave exactly like the explicit last dimension.
        let via_negative = split(&input, SplitArg::Sections(2), -1)?;
        let via_explicit = split(&input, SplitArg::Sections(2), 2)?;

        assert_eq!(via_negative.len(), via_explicit.len());
        assert_eq!(
            via_negative[0].shape().dims(),
            via_explicit[0].shape().dims()
        );

        Ok(())
    }
}