// torsh_tensor/manipulation.rs

//! Tensor manipulation operations
//!
//! This module provides PyTorch-compatible tensor manipulation operations, including:
//! - Stacking: stack
//! - Splitting: chunk, split
//! - Flipping: flip, fliplr, flipud
//! - Rolling and rotating: roll, rot90
//! - Tiling and repeating: tile, repeat_interleave
//! - Utilities: unflatten, take_along_dim

use crate::{Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};

impl<T: TensorElement + Copy + Default> Tensor<T> {
    /// Stack tensors along a new dimension
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.stack(tensors, dim)`
    ///
    /// # Arguments
    /// * `tensors` - Sequence of tensors to stack
    /// * `dim` - Dimension along which to stack
    ///
    /// # Examples
    /// ```ignore
    /// let a = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?;
    /// let b = Tensor::from_data(vec![3.0, 4.0], vec![2], DeviceType::Cpu)?;
    /// let result = Tensor::stack(&[a, b], 0)?; // shape: [2, 2]
    /// ```
    pub fn stack(tensors: &[Self], dim: isize) -> Result<Self> {
        if tensors.is_empty() {
            return Err(TorshError::InvalidArgument(
                "stack requires at least one tensor".to_string(),
            ));
        }

        // Verify all tensors have the same shape
        let first_shape = tensors[0].shape().to_vec();
        for tensor in tensors.iter().skip(1) {
            if tensor.shape().dims() != first_shape.as_slice() {
                return Err(TorshError::ShapeMismatch {
                    expected: first_shape.clone(),
                    got: tensor.shape().to_vec(),
                });
            }
        }

        let ndim = first_shape.len();
        let dim = if dim < 0 {
            ((ndim + 1) as isize + dim) as usize
        } else {
            dim as usize
        };

        // The new dimension may sit at any position up to and including `ndim`
        if dim > ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for stacking {}-D tensors",
                dim, ndim
            )));
        }

        // Calculate output shape: insert new dimension at position `dim`
        let mut output_shape = first_shape.clone();
        output_shape.insert(dim, tensors.len());

        // Fetch each tensor's data once up front rather than re-copying it on
        // every iteration of the outer loop below
        let data_vecs: Vec<Vec<T>> = tensors
            .iter()
            .map(|t| t.to_vec())
            .collect::<Result<Vec<_>>>()?;

        // Stack tensors by interleaving data
        let elem_count: usize = first_shape.iter().product();
        let mut result_data = Vec::with_capacity(elem_count * tensors.len());

        // Calculate sizes for proper data layout
        let outer_size: usize = first_shape[..dim].iter().product();
        let inner_size: usize = first_shape[dim..].iter().product();

        for outer in 0..outer_size {
            for data in &data_vecs {
                for inner in 0..inner_size {
                    let idx = outer * inner_size + inner;
                    result_data.push(data[idx]);
                }
            }
        }

        let device = tensors[0].device.clone();
        Self::from_data(result_data, output_shape, device)
    }

    /// Split tensor into a given number of chunks
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.chunk(tensor, chunks, dim)`. As in PyTorch, fewer
    /// than `chunks` chunks may be returned when the dimension does not divide
    /// evenly.
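    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`), mirroring the constructors used
    /// in the `stack` example above:
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)?;
    /// let parts = x.chunk(2, 0)?; // shapes: [3] and [2]
    /// ```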
    pub fn chunk(&self, chunks: usize, dim: isize) -> Result<Vec<Self>> {
        if chunks == 0 {
            return Err(TorshError::InvalidArgument(
                "chunks must be greater than 0".to_string(),
            ));
        }

        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as isize + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                dim, ndim
            )));
        }

        let dim_size = self.shape().dims()[dim];
        let chunk_size = (dim_size + chunks - 1) / chunks; // Ceiling division

        let mut result = Vec::new();
        let mut start = 0;

        while start < dim_size {
            let end = (start + chunk_size).min(dim_size);
            let slice_tensor = self.narrow(dim as i32, start as i64, end - start)?;
            result.push(slice_tensor);
            start = end;
        }

        Ok(result)
    }

    /// Split tensor into parts of a given size
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.split(tensor, split_size, dim)`; the last part is
    /// smaller when the dimension does not divide evenly.
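    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)?;
    /// let parts = x.split(2, 0)?; // shapes: [2], [2], [1]
    /// ```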
    pub fn split(&self, split_size: usize, dim: isize) -> Result<Vec<Self>> {
        if split_size == 0 {
            return Err(TorshError::InvalidArgument(
                "split_size must be greater than 0".to_string(),
            ));
        }

        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as isize + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                dim, ndim
            )));
        }

        let dim_size = self.shape().dims()[dim];
        let mut result = Vec::new();
        let mut start = 0;

        while start < dim_size {
            let size = split_size.min(dim_size - start);
            let slice_tensor = self.narrow(dim as i32, start as i64, size)?;
            result.push(slice_tensor);
            start += split_size;
        }

        Ok(result)
    }

    /// Flip tensor along given dimensions
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.flip(tensor, dims)`
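    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)?;
    /// let y = x.flip(&[0])?; // [4.0, 3.0, 2.0, 1.0]
    /// ```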
    pub fn flip(&self, dims: &[isize]) -> Result<Self> {
        if dims.is_empty() {
            return Ok(self.clone());
        }

        let ndim = self.ndim();

        // Normalize dimensions
        let mut norm_dims = Vec::new();
        for &dim in dims {
            let d = if dim < 0 {
                (ndim as isize + dim) as usize
            } else {
                dim as usize
            };

            if d >= ndim {
                return Err(TorshError::InvalidArgument(format!(
                    "Dimension {} out of range for {}-D tensor",
                    dim, ndim
                )));
            }
            norm_dims.push(d);
        }

        // Flip data
        let data = self.to_vec()?;
        let shape = self.shape().to_vec();
        let mut result_data = vec![T::default(); data.len()];

        // Calculate row-major strides
        let mut strides = vec![1; ndim];
        for i in (0..ndim - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }

        // Copy data with flipped indices
        for i in 0..data.len() {
            let mut indices = vec![0; ndim];
            let mut remainder = i;

            for d in 0..ndim {
                indices[d] = remainder / strides[d];
                remainder %= strides[d];
            }

            // Flip specified dimensions
            for &flip_dim in &norm_dims {
                indices[flip_dim] = shape[flip_dim] - 1 - indices[flip_dim];
            }

            // Calculate flipped index
            let mut flipped_idx = 0;
            for d in 0..ndim {
                flipped_idx += indices[d] * strides[d];
            }

            result_data[flipped_idx] = data[i];
        }

        Self::from_data(result_data, shape, self.device)
    }

    /// Flip tensor left-right (dimension 1)
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.fliplr(tensor)`
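    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)?;
    /// let y = x.fliplr()?; // [[2.0, 1.0], [4.0, 3.0]]
    /// ```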
    pub fn fliplr(&self) -> Result<Self> {
        if self.ndim() < 2 {
            return Err(TorshError::InvalidArgument(
                "fliplr requires at least 2 dimensions".to_string(),
            ));
        }
        // torch.fliplr flips dimension 1 (columns), not the last dimension;
        // the two only coincide for 2-D tensors
        self.flip(&[1])
    }

    /// Flip tensor up-down (first dimension)
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.flipud(tensor)`
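    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)?;
    /// let y = x.flipud()?; // [[3.0, 4.0], [1.0, 2.0]]
    /// ```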
    pub fn flipud(&self) -> Result<Self> {
        if self.ndim() < 1 {
            return Err(TorshError::InvalidArgument(
                "flipud requires at least 1 dimension".to_string(),
            ));
        }
        self.flip(&[0])
    }

    /// Roll tensor elements along given dimensions
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.roll(tensor, shifts, dims)`. As in PyTorch, an
    /// empty `dims` with a single shift rolls the flattened tensor.
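    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)?;
    /// let y = x.roll(&[1], &[0])?; // [4.0, 1.0, 2.0, 3.0]
    /// ```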
    pub fn roll(&self, shifts: &[isize], dims: &[isize]) -> Result<Self> {
        if dims.is_empty() {
            // Roll the flattened tensor; this path takes at most one shift.
            // (Checking `shifts.len() == dims.len()` first would make the
            // flattened path unreachable with a non-zero shift.)
            if shifts.len() > 1 {
                return Err(TorshError::InvalidArgument(
                    "rolling without dims takes a single shift".to_string(),
                ));
            }

            let data = self.to_vec()?;
            let n = data.len();
            if n == 0 {
                return Ok(self.clone());
            }

            let shift = shifts.first().copied().unwrap_or(0);
            let shift = ((shift % n as isize) + n as isize) as usize % n;

            let mut result_data = vec![T::default(); n];
            for (i, &val) in data.iter().enumerate() {
                result_data[(i + shift) % n] = val;
            }

            return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
        }

        if shifts.len() != dims.len() {
            return Err(TorshError::InvalidArgument(
                "shifts and dims must have the same length".to_string(),
            ));
        }

        let ndim = self.ndim();

        // Normalize dimensions
        let mut norm_dims = Vec::new();
        for &dim in dims {
            let d = if dim < 0 {
                (ndim as isize + dim) as usize
            } else {
                dim as usize
            };

            if d >= ndim {
                return Err(TorshError::InvalidArgument(format!(
                    "Dimension {} out of range for {}-D tensor",
                    dim, ndim
                )));
            }
            norm_dims.push(d);
        }

        let data = self.to_vec()?;
        let shape = self.shape().to_vec();
        let mut result_data = vec![T::default(); data.len()];

        // Calculate row-major strides
        let mut strides = vec![1; ndim];
        for i in (0..ndim - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }

        // Copy data with rolled indices
        for i in 0..data.len() {
            let mut indices = vec![0; ndim];
            let mut remainder = i;

            for d in 0..ndim {
                indices[d] = remainder / strides[d];
                remainder %= strides[d];
            }

            // Roll specified dimensions
            for (dim_idx, &roll_dim) in norm_dims.iter().enumerate() {
                let shift = shifts[dim_idx];
                let dim_size = shape[roll_dim] as isize;
                let rolled =
                    ((indices[roll_dim] as isize + shift) % dim_size + dim_size) % dim_size;
                indices[roll_dim] = rolled as usize;
            }

            // Calculate rolled index
            let mut rolled_idx = 0;
            for d in 0..ndim {
                rolled_idx += indices[d] * strides[d];
            }

            result_data[rolled_idx] = data[i];
        }

        Self::from_data(result_data, shape, self.device)
    }

    /// Rotate a tensor by 90 degrees `k` times in the plane given by `dims`
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.rot90(tensor, k, dims)`; positive `k` rotates from
    /// the first towards the second axis (counter-clockwise)
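    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)?;
    /// let y = x.rot90(1, &[0, 1])?; // [[2.0, 4.0], [1.0, 3.0]]
    /// ```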
    pub fn rot90(&self, k: isize, dims: &[isize]) -> Result<Self> {
        if dims.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "dims must contain exactly 2 dimensions".to_string(),
            ));
        }

        let ndim = self.ndim();
        if ndim < 2 {
            return Err(TorshError::InvalidArgument(
                "rot90 requires at least 2 dimensions".to_string(),
            ));
        }

        // Normalize dimensions
        let dim0 = if dims[0] < 0 {
            (ndim as isize + dims[0]) as usize
        } else {
            dims[0] as usize
        };

        let dim1 = if dims[1] < 0 {
            (ndim as isize + dims[1]) as usize
        } else {
            dims[1] as usize
        };

        if dim0 >= ndim || dim1 >= ndim {
            return Err(TorshError::InvalidArgument("dims out of range".to_string()));
        }

        if dim0 == dim1 {
            return Err(TorshError::InvalidArgument(
                "dims must be different".to_string(),
            ));
        }

        // Normalize k to [0, 4)
        let k = ((k % 4) + 4) % 4;

        let mut result = self.clone();
        for _ in 0..k {
            // One counter-clockwise quarter turn: transpose the two dims, then
            // flip along dim0. (Flipping along dim1 instead would rotate
            // clockwise, i.e. the k = -1 direction.)
            result = result.transpose_view(dim0, dim1)?;
            result = result.flip(&[dim0 as isize])?;
        }

        Ok(result)
    }

    /// Tile tensor by repeating it along each dimension
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.tile(tensor, repeats)`
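    ///
    /// # Examples
    /// Illustrative sketch (marked `ignore`):
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0], vec![2], DeviceType::Cpu)?;
    /// let y = x.tile(&[2])?; // [1.0, 2.0, 1.0, 2.0]
    /// ```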
    pub fn tile(&self, repeats: &[usize]) -> Result<Self> {
        if repeats.is_empty() {
            return Ok(self.clone());
        }

        // torch.tile pads `repeats` with leading 1s when it has fewer entries
        // than the tensor has dimensions, so tiling a [2, 3] tensor with [2]
        // repeats along the last dimension only. When `repeats` is longer than
        // the tensor's rank, `repeat` prepends the extra dimensions itself.
        let ndim = self.ndim();
        let mut full_repeats = vec![1; ndim.saturating_sub(repeats.len())];
        full_repeats.extend_from_slice(repeats);

        // Delegate to the existing repeat method from data_ops
        self.repeat(&full_repeats)
    }

    /// Repeat elements of a tensor along a dimension
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.repeat_interleave(tensor, repeats, dim)`
    ///
    /// # Arguments
    /// * `repeats` - Number of times to repeat each element
    /// * `dim` - Dimension along which to repeat (None = flatten first)
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0], vec![3], DeviceType::Cpu)?;
    /// let y = x.repeat_interleave(2, None)?; // [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
    /// ```
    pub fn repeat_interleave(&self, repeats: usize, dim: Option<isize>) -> Result<Self> {
        if repeats == 0 {
            return Err(TorshError::InvalidArgument(
                "repeats must be positive".to_string(),
            ));
        }

        match dim {
            None => {
                // Flatten and repeat each element
                let data = self.to_vec()?;
                let mut result_data = Vec::with_capacity(data.len() * repeats);

                for &val in data.iter() {
                    for _ in 0..repeats {
                        result_data.push(val);
                    }
                }

                Self::from_data(result_data, vec![data.len() * repeats], self.device)
            }
            Some(d) => {
                let ndim = self.ndim();
                let dim = if d < 0 {
                    (ndim as isize + d) as usize
                } else {
                    d as usize
                };

                if dim >= ndim {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for {}-D tensor",
                        d, ndim
                    )));
                }

                let shape = self.shape().to_vec();
                let data = self.to_vec()?;

                // Calculate output shape
                let mut output_shape = shape.clone();
                output_shape[dim] *= repeats;

                // Repeat along the specified dimension
                let dim_size = shape[dim];
                let outer_size: usize = shape[..dim].iter().product();
                let inner_size: usize = shape[dim + 1..].iter().product();

                let mut result_data = Vec::with_capacity(data.len() * repeats);

                for outer in 0..outer_size {
                    for pos in 0..dim_size {
                        for _ in 0..repeats {
                            for inner in 0..inner_size {
                                let idx =
                                    outer * dim_size * inner_size + pos * inner_size + inner;
                                result_data.push(data[idx]);
                            }
                        }
                    }
                }

                Self::from_data(result_data, output_shape, self.device)
            }
        }
    }

    /// Unflatten a dimension into multiple dimensions
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.unflatten(tensor, dim, sizes)`
    ///
    /// # Arguments
    /// * `dim` - Dimension to unflatten
    /// * `sizes` - Target sizes for the unflattened dimensions
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6], DeviceType::Cpu)?;
    /// let y = x.unflatten(0, &[2, 3])?; // Shape becomes [2, 3]
    /// ```
    pub fn unflatten(&self, dim: isize, sizes: &[usize]) -> Result<Self> {
        if sizes.is_empty() {
            return Err(TorshError::InvalidArgument(
                "sizes cannot be empty".to_string(),
            ));
        }

        let shape = self.shape().to_vec();
        let ndim = shape.len();

        // Normalize dimension
        let dim = if dim < 0 {
            (ndim as isize + dim) as usize
        } else {
            dim as usize
        };

        if dim >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                dim, ndim
            )));
        }

        // Verify that the product of sizes equals the dimension size
        let sizes_product: usize = sizes.iter().product();
        if sizes_product != shape[dim] {
            return Err(TorshError::InvalidArgument(format!(
                "sizes product {} does not match dimension size {}",
                sizes_product, shape[dim]
            )));
        }

        // Build new shape
        let mut new_shape = Vec::new();
        new_shape.extend_from_slice(&shape[..dim]);
        new_shape.extend_from_slice(sizes);
        new_shape.extend_from_slice(&shape[dim + 1..]);

        // Reshape (data layout doesn't change, only shape interpretation)
        let data = self.to_vec()?;
        Self::from_data(data, new_shape, self.device)
    }

    /// Gather values along a dimension using indices
    ///
    /// # PyTorch Compatibility
    /// Equivalent to `torch.take_along_dim(tensor, indices, dim)`
    ///
    /// # Arguments
    /// * `indices` - Indices to gather
    /// * `dim` - Dimension along which to gather (None = flatten first)
    ///
    /// # Examples
    /// ```ignore
    /// let x = Tensor::from_data(vec![1.0, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)?;
    /// let indices = Tensor::from_data(vec![0i64, 2], vec![2], DeviceType::Cpu)?;
    /// let y = x.take_along_dim(&indices, None)?; // [1.0, 3.0]
    /// ```
    pub fn take_along_dim(&self, indices: &Tensor<i64>, dim: Option<isize>) -> Result<Self> {
        match dim {
            None => {
                // Flatten both tensors and use simple indexing
                let data = self.to_vec()?;
                let idx_data = indices.to_vec()?;

                let mut result = Vec::with_capacity(idx_data.len());

                for &idx in idx_data.iter() {
                    if idx < 0 || idx as usize >= data.len() {
                        return Err(TorshError::InvalidArgument(format!(
                            "Index {} out of range for tensor with {} elements",
                            idx,
                            data.len()
                        )));
                    }
                    result.push(data[idx as usize]);
                }

                Self::from_data(result, indices.shape().to_vec(), self.device)
            }
            Some(d) => {
                let ndim = self.ndim();
                let dim = if d < 0 {
                    (ndim as isize + d) as usize
                } else {
                    d as usize
                };

                if dim >= ndim {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for {}-D tensor",
                        d, ndim
                    )));
                }

                let self_shape = self.shape().to_vec();
                let indices_shape = indices.shape().to_vec();

                // Verify shapes match except at the gather dimension
                if self_shape.len() != indices_shape.len() {
                    return Err(TorshError::ShapeMismatch {
                        expected: self_shape.clone(),
                        got: indices_shape.clone(),
                    });
                }

                for (i, (&s, &idx_s)) in self_shape.iter().zip(indices_shape.iter()).enumerate() {
                    if i != dim && s != idx_s {
                        return Err(TorshError::ShapeMismatch {
                            expected: self_shape.clone(),
                            got: indices_shape.clone(),
                        });
                    }
                }

                let data = self.to_vec()?;
                let idx_data = indices.to_vec()?;

                let dim_size = self_shape[dim];
                let outer_size: usize = self_shape[..dim].iter().product();
                let inner_size: usize = self_shape[dim + 1..].iter().product();

                let indices_dim_size = indices_shape[dim];
                let mut result = Vec::with_capacity(idx_data.len());

                for outer in 0..outer_size {
                    for pos in 0..indices_dim_size {
                        for inner in 0..inner_size {
                            let idx_flat =
                                outer * indices_dim_size * inner_size + pos * inner_size + inner;
                            let gather_idx = idx_data[idx_flat];

                            if gather_idx < 0 || gather_idx as usize >= dim_size {
                                return Err(TorshError::InvalidArgument(format!(
                                    "Index {} out of range for dimension size {}",
                                    gather_idx, dim_size
                                )));
                            }

                            let src_idx = outer * dim_size * inner_size
                                + (gather_idx as usize) * inner_size
                                + inner;

                            result.push(data[src_idx]);
                        }
                    }
                }

                Self::from_data(result, indices_shape, self.device)
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    // Stack tests
    #[test]
    fn test_stack_1d() {
        let a = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor a");
        let b = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor b");

        let result = Tensor::stack(&[a, b], 0).expect("stack should succeed for 1d tensors");

        assert_eq!(result.shape().dims(), &[2, 2]);
        let data = result.data().expect("failed to get stacked tensor data");
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_stack_negative_dim() {
        let a = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor a");
        let b = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor b");

        let result = Tensor::stack(&[a, b], -1).expect("stack should succeed with negative dim");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    // Chunk tests
    #[test]
    fn test_chunk_even() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for chunk_even");

        let chunks = tensor.chunk(3, 0).expect("chunk into 3 should succeed");
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape().dims(), &[2]);
        assert_eq!(
            chunks[0].data().expect("failed to get chunk 0 data"),
            vec![1.0, 2.0]
        );
        assert_eq!(
            chunks[1].data().expect("failed to get chunk 1 data"),
            vec![3.0, 4.0]
        );
        assert_eq!(
            chunks[2].data().expect("failed to get chunk 2 data"),
            vec![5.0, 6.0]
        );
    }

    #[test]
    fn test_chunk_uneven() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("failed to create tensor for chunk_uneven");

        let chunks = tensor.chunk(2, 0).expect("uneven chunk should succeed");
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].shape().dims(), &[3]);
        assert_eq!(chunks[1].shape().dims(), &[2]);
    }

    // Split tests
    #[test]
    fn test_split_even() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for split_even");

        let splits = tensor.split(2, 0).expect("split by 2 should succeed");
        assert_eq!(splits.len(), 3);
        assert_eq!(
            splits[0].data().expect("failed to get split 0 data"),
            vec![1.0, 2.0]
        );
        assert_eq!(
            splits[1].data().expect("failed to get split 1 data"),
            vec![3.0, 4.0]
        );
    }

    #[test]
    fn test_split_uneven() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("failed to create tensor for split_uneven");

        let splits = tensor.split(2, 0).expect("uneven split should succeed");
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].shape().dims(), &[2]);
        assert_eq!(splits[1].shape().dims(), &[2]);
        assert_eq!(splits[2].shape().dims(), &[1]); // Last one smaller
    }

    // Flip tests
    #[test]
    fn test_flip_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create 1d tensor for flip");

        let result = tensor.flip(&[0]).expect("flip dim 0 should succeed");
        assert_eq!(
            result.data().expect("failed to get flipped data"),
            vec![4.0, 3.0, 2.0, 1.0]
        );
    }

    #[test]
    fn test_flip_2d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create 2d tensor for flip");

        let result = tensor.flip(&[0]).expect("flip 2d dim 0 should succeed");
        assert_eq!(
            result.data().expect("failed to get 2d flipped data"),
            vec![3.0, 4.0, 1.0, 2.0]
        );
    }

    #[test]
    fn test_fliplr() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for fliplr");

        let result = tensor.fliplr().expect("fliplr should succeed");
        assert_eq!(
            result.data().expect("failed to get fliplr data"),
            vec![2.0, 1.0, 4.0, 3.0]
        );
    }

    #[test]
    fn test_flipud() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for flipud");

        let result = tensor.flipud().expect("flipud should succeed");
        assert_eq!(
            result.data().expect("failed to get flipud data"),
            vec![3.0, 4.0, 1.0, 2.0]
        );
    }

    // Roll tests
    #[test]
    fn test_roll_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for roll");

        let result = tensor.roll(&[1], &[0]).expect("roll by 1 should succeed");
        assert_eq!(
            result.data().expect("failed to get rolled data"),
            vec![4.0, 1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_roll_negative() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for negative roll");

        let result = tensor
            .roll(&[-1], &[0])
            .expect("negative roll should succeed");
        assert_eq!(
            result.data().expect("failed to get negatively rolled data"),
            vec![2.0, 3.0, 4.0, 1.0]
        );
    }

    // Rot90 tests
    #[test]
    fn test_rot90_once() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for rot90");

        let result = tensor.rot90(1, &[0, 1]).expect("rot90 once should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
        // After a 90° counter-clockwise rotation: [[1,2],[3,4]] -> [[2,4],[1,3]]
        assert_eq!(
            result.data().expect("failed to get rot90 once data"),
            vec![2.0, 4.0, 1.0, 3.0]
        );
    }

    #[test]
    fn test_rot90_twice() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for rot90 twice");

        let result = tensor
            .rot90(2, &[0, 1])
            .expect("rot90 twice should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
        // After 180° rotation: [[1,2],[3,4]] -> [[4,3],[2,1]]
        assert_eq!(
            result.data().expect("failed to get rot90 twice data"),
            vec![4.0, 3.0, 2.0, 1.0]
        );
    }

    // Tile tests
    #[test]
    fn test_tile_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor for tile 1d");

        let result = tensor.tile(&[2]).expect("tile 1d should succeed");
        assert_eq!(result.shape().dims(), &[4]);
        assert_eq!(
            result.data().expect("failed to get tiled 1d data"),
            vec![1.0, 2.0, 1.0, 2.0]
        );
    }

    #[test]
    fn test_tile_2d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("failed to create tensor for tile 2d");

        let result = tensor.tile(&[2, 1]).expect("tile 2d should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    // Repeat interleave tests
    #[test]
    fn test_repeat_interleave_flatten() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("failed to create tensor for repeat_interleave");

        let result = tensor
            .repeat_interleave(2, None)
            .expect("repeat_interleave flatten should succeed");
        assert_eq!(result.shape().dims(), &[6]);
        assert_eq!(
            result.data().expect("failed to get repeat_interleave data"),
            vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
        );
    }

    #[test]
    fn test_repeat_interleave_dim() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for repeat_interleave dim");

        let result = tensor
            .repeat_interleave(2, Some(0))
            .expect("repeat_interleave along dim 0 should succeed");
        assert_eq!(result.shape().dims(), &[4, 2]);
    }

    // Unflatten tests
    #[test]
    fn test_unflatten_basic() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for unflatten");

        let result = tensor
            .unflatten(0, &[2, 3])
            .expect("unflatten to [2,3] should succeed");
        assert_eq!(result.shape().dims(), &[2, 3]);
        assert_eq!(
            result.data().expect("failed to get unflattened data"),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_unflatten_multiple_dims() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![24], DeviceType::Cpu)
            .expect("failed to create tensor for unflatten multiple dims");

        let result = tensor
            .unflatten(0, &[2, 3, 4])
            .expect("unflatten to [2,3,4] should succeed");
        assert_eq!(result.shape().dims(), &[2, 3, 4]);
    }

    // Take along dim tests
    #[test]
    fn test_take_along_dim_flatten() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for take_along_dim");

        let indices = Tensor::from_data(vec![0i64, 2, 1], vec![3], DeviceType::Cpu)
            .expect("failed to create indices tensor");

        let result = tensor
            .take_along_dim(&indices, None)
            .expect("take_along_dim flatten should succeed");
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(
            result
                .data()
                .expect("failed to get take_along_dim flatten data"),
            vec![1.0, 3.0, 2.0]
        );
    }

    #[test]
    fn test_take_along_dim_2d() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("failed to create 2d tensor for take_along_dim");

        let indices = Tensor::from_data(vec![0i64, 2, 1, 1, 0, 2], vec![2, 3], DeviceType::Cpu)
            .expect("failed to create 2d indices tensor");

        let result = tensor
            .take_along_dim(&indices, Some(1))
            .expect("take_along_dim 2d should succeed");
        assert_eq!(result.shape().dims(), &[2, 3]);
        // Row 0: [1.0, 2.0, 3.0] indexed by [0, 2, 1] = [1.0, 3.0, 2.0]
        // Row 1: [4.0, 5.0, 6.0] indexed by [1, 0, 2] = [5.0, 4.0, 6.0]
        assert_eq!(
            result.data().expect("failed to get take_along_dim 2d data"),
            vec![1.0, 3.0, 2.0, 5.0, 4.0, 6.0]
        );
    }
}