Skip to main content

torsh_functional/manipulation/
splitting.rs

1//! Tensor splitting operations
2//!
3//! This module provides comprehensive tensor splitting functionality for dividing tensors
4//! into multiple sub-tensors along specified dimensions. These operations are essential for
5//! data partitioning, parallel processing, and implementing neural network architectures
6//! that require tensor decomposition.
7
8use torsh_core::Result as TorshResult;
9use torsh_tensor::Tensor;
10
11/// Split tensor into multiple sub-tensors
12///
13/// ## Mathematical Background
14///
15/// Tensor splitting partitions a tensor A ∈ ℝ^(d₁×d₂×...×dₙ) along dimension k into
16/// multiple sub-tensors {B₁, B₂, ..., Bₘ} such that:
17///
18/// ```text
19/// concatenate([B₁, B₂, ..., Bₘ], dim=k) = A
20/// ```text
21///
22/// ## Splitting Modes
23///
24/// ### Fixed Size Splitting
25/// For split size s, creates ⌈dₖ/s⌉ sub-tensors where:
26/// - First ⌊dₖ/s⌋ tensors have size s along dimension k
27/// - Last tensor has size dₖ mod s (if non-zero)
28///
29/// ### Section-Based Splitting
30/// For n sections, creates n sub-tensors with sizes:
31/// - Base size: b = ⌊dₖ/n⌋
32/// - First r = dₖ mod n tensors: size b+1
33/// - Remaining n-r tensors: size b
34///
35/// ### Index-Based Splitting
36/// For indices [i₁, i₂, ..., iₘ], creates m+1 sub-tensors:
37/// - A[:i₁], A[i₁:i₂], ..., A[iₘ:]
38///
39/// ## Parameters
40/// * `tensor` - Input tensor to split
41/// * `split_size_or_sections` - Splitting specification (size, sections, or indices)
42/// * `dim` - Dimension along which to split (negative indexing supported)
43///
44/// ## Returns
45/// * Vector of sub-tensors resulting from the split operation
46///
47/// ## Applications
48/// - **Data batching**: Split large datasets into smaller batches
49/// - **Parallel processing**: Distribute tensor chunks across workers
50/// - **Memory management**: Process large tensors in smaller chunks
51/// - **Model parallelism**: Split layers across multiple devices
52///
53/// ## Examples
54/// ```rust
55/// # use torsh_functional::manipulation::split;
56/// # use torsh_functional::manipulation::SplitArg;
57/// # use torsh_tensor::creation::ones;
58/// let tensor = ones(&[12, 4])?;
59///
60/// // Split into chunks of size 3
61/// let chunks = split(&tensor, SplitArg::Size(3), 0)?; // 4 chunks of [3,4]
62///
63/// // Split into 4 equal sections
64/// let sections = split(&tensor, SplitArg::Sections(4), 0)?; // 4 chunks of [3,4]
65///
66/// // Split at specific indices
67/// let splits = split(&tensor, SplitArg::Indices(vec![3, 8]), 0)?; // [3,4], [5,4], [4,4]
68/// # Ok::<(), Box<dyn std::error::Error>>(())
69/// ```text
70pub fn split(
71    tensor: &Tensor,
72    split_size_or_sections: SplitArg,
73    dim: isize,
74) -> TorshResult<Vec<Tensor>> {
75    let shape = tensor.shape();
76    let ndim = shape.ndim() as isize;
77
78    // Normalize dimension
79    let dim = if dim < 0 { ndim + dim } else { dim } as usize;
80
81    if dim >= shape.ndim() {
82        return Err(torsh_core::TorshError::invalid_argument_with_context(
83            &format!(
84                "Dimension {} out of range for tensor with {} dimensions",
85                dim,
86                shape.ndim()
87            ),
88            "split",
89        ));
90    }
91
92    match split_size_or_sections {
93        SplitArg::Size(size) => {
94            // Split into chunks of given size
95            let dim_size = shape.dims()[dim];
96            let num_splits = dim_size.div_ceil(size);
97
98            let mut splits = Vec::new();
99            for i in 0..num_splits {
100                let start = i * size;
101                let end = ((i + 1) * size).min(dim_size);
102
103                // Create slice for this split
104                let split = tensor.slice(dim as usize, start, end)?.to_tensor()?;
105                splits.push(split);
106            }
107
108            Ok(splits)
109        }
110        SplitArg::Sections(sections) => {
111            // Split into given number of sections
112            let dim_size = shape.dims()[dim];
113            let base_size = dim_size / sections;
114            let remainder = dim_size % sections;
115
116            let mut splits = Vec::new();
117            let mut offset = 0;
118
119            for i in 0..sections {
120                let size = if i < remainder {
121                    base_size + 1
122                } else {
123                    base_size
124                };
125
126                // Create slice for this split
127                let split = tensor
128                    .slice(dim as usize, offset, offset + size)?
129                    .to_tensor()?;
130                splits.push(split);
131
132                offset += size;
133            }
134
135            Ok(splits)
136        }
137        SplitArg::Indices(indices) => {
138            // Split at the specified indices
139            let mut splits = Vec::new();
140            let mut start = 0;
141
142            for &index in &indices {
143                let split = tensor.slice(dim as usize, start, index)?.to_tensor()?;
144                splits.push(split);
145                start = index;
146            }
147
148            // Add final split from last index to end
149            let dim_size = shape.dims()[dim];
150            if start < dim_size {
151                let split = tensor.slice(dim as usize, start, dim_size)?.to_tensor()?;
152                splits.push(split);
153            }
154
155            Ok(splits)
156        }
157    }
158}
159
/// Split argument specification for tensor splitting operations (used by [`split`]).
///
/// ## Variants
///
/// ### Size
/// Split into chunks of the specified size along the dimension.
/// The last chunk may be smaller if the dimension size is not evenly divisible.
/// The size must be greater than zero.
///
/// ### Sections
/// Split into the specified number of approximately equal sections.
/// If not evenly divisible, earlier sections will be one element larger.
/// The number of sections must be greater than zero.
///
/// ### Indices
/// Split at the specified indices, creating up to len(indices)+1 sub-tensors.
/// Indices must be in non-decreasing order and within bounds of the dimension.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SplitArg {
    /// Split into chunks of fixed size
    Size(usize),
    /// Split into specified number of sections
    Sections(usize),
    /// Split at specified indices
    Indices(Vec<usize>),
}
184
185/// Split tensor into approximately equal chunks
186///
187/// ## Mathematical Background
188///
189/// Chunks a tensor into approximately equal pieces along the specified dimension.
190/// For tensor with dimension size d and n chunks:
191///
192/// ```text
193/// chunk_size = ⌈d/n⌉
194/// num_full_chunks = n - (n * chunk_size - d)
195/// ```text
196///
197/// The first `num_full_chunks` will have size `chunk_size`, and remaining chunks
198/// will have size `chunk_size - 1`.
199///
200/// ## Parameters
201/// * `tensor` - Input tensor to chunk
202/// * `chunks` - Number of chunks to create
203/// * `dim` - Dimension along which to chunk (negative indexing supported)
204///
205/// ## Returns
206/// * Vector of approximately equal-sized tensor chunks
207///
208/// ## Examples
209/// ```rust
210/// # use torsh_functional::manipulation::chunk;
211/// # use torsh_tensor::creation::ones;
212/// let tensor = ones(&[10, 3])?; // 10 rows, 3 columns
213///
214/// // Split into 3 chunks along first dimension
215/// let chunks = chunk(&tensor, 3, 0)?; // [4,3], [3,3], [3,3]
216///
217/// // Split into 4 chunks along second dimension
218/// let chunks = chunk(&tensor, 4, 1)?; // [10,1], [10,1], [10,1], [10,0] (empty)
219/// # Ok::<(), Box<dyn std::error::Error>>(())
220/// ```text
221pub fn chunk(tensor: &Tensor, chunks: usize, dim: isize) -> TorshResult<Vec<Tensor>> {
222    let shape = tensor.shape();
223    let ndim = shape.ndim() as isize;
224
225    // Normalize dimension
226    let dim = if dim < 0 { ndim + dim } else { dim } as usize;
227
228    if dim >= shape.ndim() {
229        return Err(torsh_core::TorshError::invalid_argument_with_context(
230            &format!(
231                "Dimension {} out of range for tensor with {} dimensions",
232                dim,
233                shape.ndim()
234            ),
235            "chunk",
236        ));
237    }
238
239    split(tensor, SplitArg::Sections(chunks), dim as isize)
240}
241
242/// Split tensor into sections at specified indices
243///
244/// ## Mathematical Background
245///
246/// Performs tensor splitting at explicitly specified indices along a dimension.
247/// For tensor A with dimension size d and indices [i₁, i₂, ..., iₘ]:
248///
249/// ```text
250/// Result = [A[..., :i₁, ...], A[..., i₁:i₂, ...], ..., A[..., iₘ:, ...]]
251/// ```text
252///
253/// Where the ellipsis represents all other dimensions.
254///
255/// ## Index Validation
256/// - All indices must be within bounds: 0 ≤ iⱼ ≤ d
257/// - Indices should be in ascending order for meaningful results
258/// - Empty sections are allowed (consecutive identical indices)
259///
260/// ## Parameters
261/// * `tensor` - Input tensor to split
262/// * `indices_or_sections` - Either number of sections or explicit indices
263/// * `dim` - Dimension along which to split (negative indexing supported)
264///
265/// ## Returns
266/// * Vector of tensor sections
267///
268/// ## Applications
269/// - **Sequence processing**: Split variable-length sequences at boundaries
270/// - **Data preprocessing**: Extract regions of interest from images/signals
271/// - **Batch processing**: Create non-uniform batches based on data characteristics
272///
273/// ## Examples
274/// ```rust
275/// # use torsh_functional::manipulation::{tensor_split, TensorSplitArg};
276/// # use torsh_tensor::creation::ones;
277/// let tensor = ones(&[8, 4])?;
278///
279/// // Split into 3 sections
280/// let sections = tensor_split(&tensor, TensorSplitArg::Sections(3), 0)?;
281/// // Results: [3,4], [3,4], [2,4]
282///
283/// // Split at specific indices
284/// let splits = tensor_split(&tensor, TensorSplitArg::Indices(vec![2, 5]), 0)?;
285/// // Results: [2,4], [3,4], [3,4]
286/// # Ok::<(), Box<dyn std::error::Error>>(())
287/// ```text
288pub fn tensor_split(
289    tensor: &Tensor,
290    indices_or_sections: TensorSplitArg,
291    dim: isize,
292) -> TorshResult<Vec<Tensor>> {
293    let shape = tensor.shape();
294    let ndim = shape.ndim() as isize;
295
296    // Normalize dimension
297    let dim = if dim < 0 { ndim + dim } else { dim } as usize;
298
299    if dim >= shape.ndim() {
300        return Err(torsh_core::TorshError::invalid_argument_with_context(
301            &format!(
302                "Dimension {} out of range for tensor with {} dimensions",
303                dim,
304                shape.ndim()
305            ),
306            "tensor_split",
307        ));
308    }
309
310    match indices_or_sections {
311        TensorSplitArg::Sections(sections) => {
312            split(tensor, SplitArg::Sections(sections), dim as isize)
313        }
314        TensorSplitArg::Indices(indices) => {
315            let dim_size = shape.dims()[dim];
316            let mut splits = Vec::new();
317            let mut prev_idx = 0;
318
319            for &idx in &indices {
320                if idx > dim_size {
321                    return Err(torsh_core::TorshError::invalid_argument_with_context(
322                        &format!(
323                            "Split index {} out of range for dimension size {}",
324                            idx, dim_size
325                        ),
326                        "tensor_split",
327                    ));
328                }
329
330                if idx > prev_idx {
331                    let split = tensor.slice(dim as usize, prev_idx, idx)?.to_tensor()?;
332                    splits.push(split);
333                }
334                prev_idx = idx;
335            }
336
337            // Add final split if needed
338            if prev_idx < dim_size {
339                let split = tensor
340                    .slice(dim as usize, prev_idx, dim_size)?
341                    .to_tensor()?;
342                splits.push(split);
343            }
344
345            Ok(splits)
346        }
347    }
348}
349
/// Split argument specification for tensor_split operations (used by [`tensor_split`],
/// `hsplit`, `vsplit` and `dsplit`).
///
/// ## Variants
///
/// ### Sections
/// Split into the specified number of approximately equal sections.
/// The number of sections must be greater than zero.
///
/// ### Indices
/// Split at the specified indices. The indices define the boundaries
/// where the tensor should be divided; they must be in non-decreasing order
/// and within bounds of the split dimension.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorSplitArg {
    /// Split into specified number of sections
    Sections(usize),
    /// Split at specified indices
    Indices(Vec<usize>),
}
367
368/// Split tensor horizontally (along second dimension)
369///
370/// ## Mathematical Background
371///
372/// Horizontal splitting divides a tensor along its second dimension (columns for 2D matrices).
373/// For tensor A ∈ ℝ^(m×n×...), hsplit creates sub-tensors along dimension 1:
374///
375/// ```text
376/// A = [A₁ | A₂ | ... | Aₖ]  (column-wise concatenation)
377/// ```text
378///
379/// ## Requirements
380/// - Input tensor must have at least 2 dimensions
381/// - Equivalent to `tensor_split(tensor, indices_or_sections, 1)`
382///
383/// ## Parameters
384/// * `tensor` - Input tensor (≥2D required)
385/// * `indices_or_sections` - Split specification (sections or indices)
386///
387/// ## Returns
388/// * Vector of horizontally split tensors
389///
390/// ## Applications
391/// - **Image processing**: Split images into vertical strips
392/// - **Feature extraction**: Separate different feature groups in matrices
393/// - **Data analysis**: Split datasets by column groups
394///
395/// ## Examples
396/// ```rust
397/// # use torsh_functional::manipulation::{hsplit, TensorSplitArg};
398/// # use torsh_tensor::creation::ones;
399/// let image = ones(&[100, 200, 3])?; // Height × Width × Channels
400///
401/// // Split into 4 vertical strips
402/// let strips = hsplit(&image, TensorSplitArg::Sections(4))?;
403/// // Each strip: [100, 50, 3]
404///
405/// // Split at specific column positions
406/// let splits = hsplit(&image, TensorSplitArg::Indices(vec![50, 150]))?;
407/// // Results: [100,50,3], [100,100,3], [100,50,3]
408/// # Ok::<(), Box<dyn std::error::Error>>(())
409/// ```text
410pub fn hsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
411    let shape = tensor.shape();
412    if shape.ndim() < 2 {
413        return Err(torsh_core::TorshError::invalid_argument_with_context(
414            "Input tensor must have at least 2 dimensions for hsplit",
415            "hsplit",
416        ));
417    }
418
419    tensor_split(tensor, indices_or_sections, 1)
420}
421
422/// Split tensor vertically (along first dimension)
423///
424/// ## Mathematical Background
425///
426/// Vertical splitting divides a tensor along its first dimension (rows for 2D matrices).
427/// For tensor A ∈ ℝ^(m×n×...), vsplit creates sub-tensors along dimension 0:
428///
429/// ```text
430/// A = [A₁; A₂; ...; Aₖ]  (row-wise concatenation)
431/// ```text
432///
433/// ## Requirements
434/// - Input tensor must have at least 2 dimensions
435/// - Equivalent to `tensor_split(tensor, indices_or_sections, 0)`
436///
437/// ## Parameters
438/// * `tensor` - Input tensor (≥2D required)
439/// * `indices_or_sections` - Split specification (sections or indices)
440///
441/// ## Returns
442/// * Vector of vertically split tensors
443///
444/// ## Applications
445/// - **Image processing**: Split images into horizontal strips
446/// - **Batch processing**: Divide batches into smaller sub-batches
447/// - **Time series**: Split sequences into temporal segments
448///
449/// ## Examples
450/// ```rust
451/// # use torsh_functional::manipulation::{vsplit, TensorSplitArg};
452/// # use torsh_tensor::creation::ones;
453/// let batch = ones(&[64, 784])?; // Batch size × Features
454///
455/// // Split into 4 mini-batches
456/// let mini_batches = vsplit(&batch, TensorSplitArg::Sections(4))?;
457/// // Each mini-batch: [16, 784]
458///
459/// // Split at specific row positions
460/// let splits = vsplit(&batch, TensorSplitArg::Indices(vec![16, 48]))?;
461/// // Results: [16,784], [32,784], [16,784]
462/// # Ok::<(), Box<dyn std::error::Error>>(())
463/// ```text
464pub fn vsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
465    let shape = tensor.shape();
466    if shape.ndim() < 2 {
467        return Err(torsh_core::TorshError::invalid_argument_with_context(
468            "Input tensor must have at least 2 dimensions for vsplit",
469            "vsplit",
470        ));
471    }
472
473    tensor_split(tensor, indices_or_sections, 0)
474}
475
476/// Split tensor along depth dimension (third dimension)
477///
478/// ## Mathematical Background
479///
480/// Depth splitting divides a tensor along its third dimension (depth for 3D tensors).
481/// For tensor A ∈ ℝ^(m×n×d×...), dsplit creates sub-tensors along dimension 2:
482///
483/// ```text
484/// A[:,:,k₁:k₂,:] for each split k₁:k₂
485/// ```text
486///
487/// ## Requirements
488/// - Input tensor must have at least 3 dimensions
489/// - Equivalent to `tensor_split(tensor, indices_or_sections, 2)`
490///
491/// ## Parameters
492/// * `tensor` - Input tensor (≥3D required)
493/// * `indices_or_sections` - Split specification (sections or indices)
494///
495/// ## Returns
496/// * Vector of depth-wise split tensors
497///
498/// ## Applications
499/// - **3D data processing**: Split volumetric data along depth
500/// - **Video analysis**: Split video frames into temporal chunks
501/// - **Multi-channel data**: Separate different channels or modalities
502/// - **Neural networks**: Split feature maps along channel dimension
503///
504/// ## Examples
505/// ```rust
506/// # use torsh_functional::manipulation::{dsplit, TensorSplitArg};
507/// # use torsh_tensor::creation::ones;
508/// let volume = ones(&[64, 64, 32])?; // Height × Width × Depth
509///
510/// // Split into 4 depth sections
511/// let sections = dsplit(&volume, TensorSplitArg::Sections(4))?;
512/// // Each section: [64, 64, 8]
513///
514/// // Split at specific depth positions
515/// let splits = dsplit(&volume, TensorSplitArg::Indices(vec![8, 24]))?;
516/// // Results: [64,64,8], [64,64,16], [64,64,8]
517/// # Ok::<(), Box<dyn std::error::Error>>(())
518/// ```text
519pub fn dsplit(tensor: &Tensor, indices_or_sections: TensorSplitArg) -> TorshResult<Vec<Tensor>> {
520    let shape = tensor.shape();
521    if shape.ndim() < 3 {
522        return Err(torsh_core::TorshError::invalid_argument_with_context(
523            "Input tensor must have at least 3 dimensions for dsplit",
524            "dsplit",
525        ));
526    }
527
528    tensor_split(tensor, indices_or_sections, 2)
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    /// Checks that `parts` contains exactly one tensor per entry of `expected`,
    /// and that the i-th tensor has the i-th expected shape.
    #[track_caller]
    fn assert_shapes(parts: &[Tensor], expected: &[&[usize]]) {
        assert_eq!(parts.len(), expected.len());
        for (part, want) in parts.iter().zip(expected) {
            assert_eq!(part.shape().dims(), *want);
        }
    }

    #[test]
    fn test_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        // Three equal sections along the leading axis.
        let by_sections = split(&input, SplitArg::Sections(3), 0)?;
        assert_shapes(&by_sections, &[&[2, 4], &[2, 4], &[2, 4]]);

        // Explicit split points.
        let by_indices = split(&input, SplitArg::Indices(vec![2, 4]), 0)?;
        assert_shapes(&by_indices, &[&[2, 4], &[2, 4], &[2, 4]]);

        Ok(())
    }

    #[test]
    fn test_chunk() -> TorshResult<()> {
        let input = randn(&[8, 3], None, None, None)?;

        // 8 rows into 3 chunks: the first 8 % 3 = 2 chunks get the extra row.
        let rows = chunk(&input, 3, 0)?;
        assert_shapes(&rows, &[&[3, 3], &[3, 3], &[2, 3]]);

        // 3 columns into 2 chunks: ceil(3/2) = 2, then the remaining 1.
        let cols = chunk(&input, 2, 1)?;
        assert_shapes(&cols, &[&[8, 2], &[8, 1]]);

        Ok(())
    }

    #[test]
    fn test_tensor_split() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        let by_sections = tensor_split(&input, TensorSplitArg::Sections(3), 0)?;
        assert_shapes(&by_sections, &[&[2, 4], &[2, 4], &[2, 4]]);

        let by_indices = tensor_split(&input, TensorSplitArg::Indices(vec![2, 4]), 0)?;
        assert_shapes(&by_indices, &[&[2, 4], &[2, 4], &[2, 4]]);

        Ok(())
    }

    #[test]
    fn test_hsplit() -> TorshResult<()> {
        let input = randn(&[4, 6], None, None, None)?;

        let by_sections = hsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[&[4, 2], &[4, 2], &[4, 2]]);

        let by_indices = hsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[&[4, 2], &[4, 2], &[4, 2]]);

        Ok(())
    }

    #[test]
    fn test_vsplit() -> TorshResult<()> {
        let input = randn(&[6, 4], None, None, None)?;

        let by_sections = vsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[&[2, 4], &[2, 4], &[2, 4]]);

        let by_indices = vsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[&[2, 4], &[2, 4], &[2, 4]]);

        Ok(())
    }

    #[test]
    fn test_dsplit() -> TorshResult<()> {
        let input = randn(&[2, 3, 6], None, None, None)?;

        let by_sections = dsplit(&input, TensorSplitArg::Sections(3))?;
        assert_shapes(&by_sections, &[&[2, 3, 2], &[2, 3, 2], &[2, 3, 2]]);

        let by_indices = dsplit(&input, TensorSplitArg::Indices(vec![2, 4]))?;
        assert_shapes(&by_indices, &[&[2, 3, 2], &[2, 3, 2], &[2, 3, 2]]);

        Ok(())
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for hsplit")]
    fn test_hsplit_invalid_dimensions() {
        let vector = randn(&[5], None, None, None).unwrap(); // rank-1 input
        hsplit(&vector, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 2 dimensions for vsplit")]
    fn test_vsplit_invalid_dimensions() {
        let vector = randn(&[5], None, None, None).unwrap(); // rank-1 input
        vsplit(&vector, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    #[should_panic(expected = "Input tensor must have at least 3 dimensions for dsplit")]
    fn test_dsplit_invalid_dimensions() {
        let matrix = randn(&[3, 4], None, None, None).unwrap(); // rank-2 input
        dsplit(&matrix, TensorSplitArg::Sections(2)).unwrap();
    }

    #[test]
    fn test_split_size_mode() -> TorshResult<()> {
        let input = randn(&[10, 4], None, None, None)?;

        // 10 rows in pieces of 3: three full pieces plus a 1-row remainder.
        let pieces = split(&input, SplitArg::Size(3), 0)?;
        assert_shapes(&pieces, &[&[3, 4], &[3, 4], &[3, 4], &[1, 4]]);

        Ok(())
    }

    #[test]
    fn test_negative_dimension_indexing() -> TorshResult<()> {
        let input = randn(&[4, 6, 8], None, None, None)?;

        // `-1` must resolve to the same axis as the explicit last index `2`.
        let from_end = split(&input, SplitArg::Sections(2), -1)?;
        let explicit = split(&input, SplitArg::Sections(2), 2)?;

        assert_eq!(from_end.len(), explicit.len());
        assert_eq!(from_end[0].shape().dims(), explicit[0].shape().dims());

        Ok(())
    }
}