1use crate::{Tensor, TensorElement};
12use torsh_core::error::{Result, TorshError};
13
14impl<T: TensorElement + Copy + Default> Tensor<T> {
15 pub fn stack(tensors: &[Self], dim: isize) -> Result<Self> {
31 if tensors.is_empty() {
32 return Err(TorshError::InvalidArgument(
33 "stack requires at least one tensor".to_string(),
34 ));
35 }
36
37 let first_shape = tensors[0].shape().to_vec();
39 for tensor in tensors.iter().skip(1) {
40 if tensor.shape().dims() != first_shape.as_slice() {
41 return Err(TorshError::ShapeMismatch {
42 expected: first_shape.clone(),
43 got: tensor.shape().to_vec(),
44 });
45 }
46 }
47
48 let ndim = first_shape.len();
49 let dim = if dim < 0 {
50 ((ndim + 1) as isize + dim) as usize
51 } else {
52 dim as usize
53 };
54
55 if dim > ndim {
56 return Err(TorshError::InvalidArgument(format!(
57 "Dimension {} out of range for stacking {}-D tensors",
58 dim, ndim
59 )));
60 }
61
62 let mut output_shape = first_shape.to_vec();
64 output_shape.insert(dim, tensors.len());
65
66 let elem_count: usize = first_shape.iter().product();
68 let mut result_data = Vec::with_capacity(elem_count * tensors.len());
69
70 let outer_size: usize = first_shape[..dim].iter().product();
72 let inner_size: usize = first_shape[dim..].iter().product();
73
74 for outer in 0..outer_size {
75 for tensor in tensors {
76 let data = tensor.to_vec()?;
77 for inner in 0..inner_size {
78 let idx = outer * inner_size + inner;
79 result_data.push(data[idx]);
80 }
81 }
82 }
83
84 let device = tensors[0].device.clone();
85 Self::from_data(result_data, output_shape, device)
86 }
87
88 pub fn chunk(&self, chunks: usize, dim: isize) -> Result<Vec<Self>> {
93 if chunks == 0 {
94 return Err(TorshError::InvalidArgument(
95 "chunks must be greater than 0".to_string(),
96 ));
97 }
98
99 let ndim = self.ndim();
100 let dim = if dim < 0 {
101 (ndim as isize + dim) as usize
102 } else {
103 dim as usize
104 };
105
106 if dim >= ndim {
107 return Err(TorshError::InvalidArgument(format!(
108 "Dimension {} out of range for {}-D tensor",
109 dim, ndim
110 )));
111 }
112
113 let dim_size = self.shape().dims()[dim];
114 let chunk_size = (dim_size + chunks - 1) / chunks; let mut result = Vec::new();
117 let mut start = 0;
118
119 while start < dim_size {
120 let end = (start + chunk_size).min(dim_size);
121 let slice_tensor = self.narrow(dim as i32, start as i64, end - start)?;
122 result.push(slice_tensor);
123 start = end;
124 }
125
126 Ok(result)
127 }
128
129 pub fn split(&self, split_size: usize, dim: isize) -> Result<Vec<Self>> {
134 if split_size == 0 {
135 return Err(TorshError::InvalidArgument(
136 "split_size must be greater than 0".to_string(),
137 ));
138 }
139
140 let ndim = self.ndim();
141 let dim = if dim < 0 {
142 (ndim as isize + dim) as usize
143 } else {
144 dim as usize
145 };
146
147 if dim >= ndim {
148 return Err(TorshError::InvalidArgument(format!(
149 "Dimension {} out of range for {}-D tensor",
150 dim, ndim
151 )));
152 }
153
154 let dim_size = self.shape().dims()[dim];
155 let mut result = Vec::new();
156 let mut start = 0;
157
158 while start < dim_size {
159 let size = split_size.min(dim_size - start);
160 let slice_tensor = self.narrow(dim as i32, start as i64, size)?;
161 result.push(slice_tensor);
162 start += split_size;
163 }
164
165 Ok(result)
166 }
167
168 pub fn flip(&self, dims: &[isize]) -> Result<Self> {
173 if dims.is_empty() {
174 return Ok(self.clone());
175 }
176
177 let ndim = self.ndim();
178
179 let mut norm_dims = Vec::new();
181 for &dim in dims {
182 let d = if dim < 0 {
183 (ndim as isize + dim) as usize
184 } else {
185 dim as usize
186 };
187
188 if d >= ndim {
189 return Err(TorshError::InvalidArgument(format!(
190 "Dimension {} out of range for {}-D tensor",
191 dim, ndim
192 )));
193 }
194 norm_dims.push(d);
195 }
196
197 let data = self.to_vec()?;
199 let shape = self.shape().to_vec();
200 let mut result_data = vec![T::default(); data.len()];
201
202 let mut strides = vec![1; ndim];
204 for i in (0..ndim - 1).rev() {
205 strides[i] = strides[i + 1] * shape[i + 1];
206 }
207
208 for i in 0..data.len() {
210 let mut indices = vec![0; ndim];
211 let mut remainder = i;
212
213 for d in 0..ndim {
214 indices[d] = remainder / strides[d];
215 remainder %= strides[d];
216 }
217
218 for &flip_dim in &norm_dims {
220 indices[flip_dim] = shape[flip_dim] - 1 - indices[flip_dim];
221 }
222
223 let mut flipped_idx = 0;
225 for d in 0..ndim {
226 flipped_idx += indices[d] * strides[d];
227 }
228
229 result_data[flipped_idx] = data[i];
230 }
231
232 Self::from_data(result_data, shape.to_vec(), self.device)
233 }
234
235 pub fn fliplr(&self) -> Result<Self> {
240 if self.ndim() < 2 {
241 return Err(TorshError::InvalidArgument(
242 "fliplr requires at least 2 dimensions".to_string(),
243 ));
244 }
245 self.flip(&[-1])
246 }
247
248 pub fn flipud(&self) -> Result<Self> {
253 if self.ndim() < 1 {
254 return Err(TorshError::InvalidArgument(
255 "flipud requires at least 1 dimension".to_string(),
256 ));
257 }
258 self.flip(&[0])
259 }
260
261 pub fn roll(&self, shifts: &[isize], dims: &[isize]) -> Result<Self> {
266 if shifts.len() != dims.len() {
267 return Err(TorshError::InvalidArgument(
268 "shifts and dims must have the same length".to_string(),
269 ));
270 }
271
272 if dims.is_empty() {
273 let data = self.to_vec()?;
275 let shift = if shifts.is_empty() { 0 } else { shifts[0] };
276 let n = data.len();
277 let shift = ((shift % n as isize) + n as isize) as usize % n;
278
279 let mut result_data = vec![T::default(); n];
280 for (i, &val) in data.iter().enumerate() {
281 result_data[(i + shift) % n] = val;
282 }
283
284 return Self::from_data(result_data, self.shape().dims().to_vec(), self.device);
285 }
286
287 let ndim = self.ndim();
288
289 let mut norm_dims = Vec::new();
291 for &dim in dims {
292 let d = if dim < 0 {
293 (ndim as isize + dim) as usize
294 } else {
295 dim as usize
296 };
297
298 if d >= ndim {
299 return Err(TorshError::InvalidArgument(format!(
300 "Dimension {} out of range for {}-D tensor",
301 dim, ndim
302 )));
303 }
304 norm_dims.push(d);
305 }
306
307 let data = self.to_vec()?;
308 let shape = self.shape().to_vec();
309 let mut result_data = vec![T::default(); data.len()];
310
311 let mut strides = vec![1; ndim];
313 for i in (0..ndim - 1).rev() {
314 strides[i] = strides[i + 1] * shape[i + 1];
315 }
316
317 for i in 0..data.len() {
319 let mut indices = vec![0; ndim];
320 let mut remainder = i;
321
322 for d in 0..ndim {
323 indices[d] = remainder / strides[d];
324 remainder %= strides[d];
325 }
326
327 for (dim_idx, &roll_dim) in norm_dims.iter().enumerate() {
329 let shift = shifts[dim_idx];
330 let dim_size = shape[roll_dim] as isize;
331 let rolled =
332 ((indices[roll_dim] as isize + shift) % dim_size + dim_size) % dim_size;
333 indices[roll_dim] = rolled as usize;
334 }
335
336 let mut rolled_idx = 0;
338 for d in 0..ndim {
339 rolled_idx += indices[d] * strides[d];
340 }
341
342 result_data[rolled_idx] = data[i];
343 }
344
345 Self::from_data(result_data, shape.to_vec(), self.device)
346 }
347
348 pub fn rot90(&self, k: isize, dims: &[isize]) -> Result<Self> {
353 if dims.len() != 2 {
354 return Err(TorshError::InvalidArgument(
355 "dims must contain exactly 2 dimensions".to_string(),
356 ));
357 }
358
359 let ndim = self.ndim();
360 if ndim < 2 {
361 return Err(TorshError::InvalidArgument(
362 "rot90 requires at least 2 dimensions".to_string(),
363 ));
364 }
365
366 let dim0 = if dims[0] < 0 {
368 (ndim as isize + dims[0]) as usize
369 } else {
370 dims[0] as usize
371 };
372
373 let dim1 = if dims[1] < 0 {
374 (ndim as isize + dims[1]) as usize
375 } else {
376 dims[1] as usize
377 };
378
379 if dim0 >= ndim || dim1 >= ndim {
380 return Err(TorshError::InvalidArgument("dims out of range".to_string()));
381 }
382
383 if dim0 == dim1 {
384 return Err(TorshError::InvalidArgument(
385 "dims must be different".to_string(),
386 ));
387 }
388
389 let k = ((k % 4) + 4) % 4;
391
392 let mut result = self.clone();
393 for _ in 0..k {
394 result = result.transpose_view(dim0, dim1)?;
396 result = result.flip(&[dim1 as isize])?;
397 }
398
399 Ok(result)
400 }
401
402 pub fn tile(&self, repeats: &[usize]) -> Result<Self> {
407 if repeats.is_empty() {
408 return Ok(self.clone());
409 }
410
411 let shape = self.shape().to_vec();
412 let ndim = shape.len();
413
414 let mut new_shape = shape.to_vec();
416 if repeats.len() > ndim {
417 let diff = repeats.len() - ndim;
418 for _ in 0..diff {
419 new_shape.insert(0, 1);
420 }
421 }
422
423 let mut output_shape = new_shape.clone();
425 let repeat_offset = if repeats.len() < output_shape.len() {
426 output_shape.len() - repeats.len()
427 } else {
428 0
429 };
430
431 for (i, &rep) in repeats.iter().enumerate() {
432 let idx = repeat_offset + i;
433 if idx < output_shape.len() {
434 output_shape[idx] *= rep;
435 }
436 }
437
438 self.repeat(repeats)
440 }
441
442 pub fn repeat_interleave(&self, repeats: usize, dim: Option<isize>) -> Result<Self> {
457 if repeats == 0 {
458 return Err(TorshError::InvalidArgument(
459 "repeats must be positive".to_string(),
460 ));
461 }
462
463 match dim {
464 None => {
465 let data = self.to_vec()?;
467 let mut result_data = Vec::with_capacity(data.len() * repeats);
468
469 for &val in data.iter() {
470 for _ in 0..repeats {
471 result_data.push(val);
472 }
473 }
474
475 Self::from_data(result_data, vec![data.len() * repeats], self.device)
476 }
477 Some(d) => {
478 let ndim = self.ndim();
479 let dim = if d < 0 {
480 (ndim as isize + d) as usize
481 } else {
482 d as usize
483 };
484
485 if dim >= ndim {
486 return Err(TorshError::InvalidArgument(format!(
487 "Dimension {} out of range for {}-D tensor",
488 d, ndim
489 )));
490 }
491
492 let shape = self.shape().to_vec();
493 let data = self.to_vec()?;
494
495 let mut output_shape = shape.clone();
497 output_shape[dim] *= repeats;
498
499 let dim_size = shape[dim];
501 let outer_size: usize = shape[..dim].iter().product();
502 let inner_size: usize = shape[dim + 1..].iter().product();
503
504 let mut result_data = Vec::with_capacity(data.len() * repeats);
505
506 for outer in 0..outer_size {
507 for d in 0..dim_size {
508 for _ in 0..repeats {
509 for inner in 0..inner_size {
510 let idx = outer * dim_size * inner_size + d * inner_size + inner;
511 result_data.push(data[idx]);
512 }
513 }
514 }
515 }
516
517 Self::from_data(result_data, output_shape, self.device)
518 }
519 }
520 }
521
522 pub fn unflatten(&self, dim: isize, sizes: &[usize]) -> Result<Self> {
537 if sizes.is_empty() {
538 return Err(TorshError::InvalidArgument(
539 "sizes cannot be empty".to_string(),
540 ));
541 }
542
543 let shape = self.shape().to_vec();
544 let ndim = shape.len();
545
546 let dim = if dim < 0 {
548 (ndim as isize + dim) as usize
549 } else {
550 dim as usize
551 };
552
553 if dim >= ndim {
554 return Err(TorshError::InvalidArgument(format!(
555 "Dimension {} out of range for {}-D tensor",
556 dim, ndim
557 )));
558 }
559
560 let sizes_product: usize = sizes.iter().product();
562 if sizes_product != shape[dim] {
563 return Err(TorshError::InvalidArgument(format!(
564 "sizes product {} does not match dimension size {}",
565 sizes_product, shape[dim]
566 )));
567 }
568
569 let mut new_shape = Vec::new();
571 new_shape.extend_from_slice(&shape[..dim]);
572 new_shape.extend_from_slice(sizes);
573 new_shape.extend_from_slice(&shape[dim + 1..]);
574
575 let data = self.to_vec()?;
577 Self::from_data(data, new_shape, self.device)
578 }
579
    /// Gathers values from `self` using an index tensor.
    ///
    /// With `dim == None`, `self` is treated as flattened and each entry of
    /// `indices` addresses an element by flat position; the output takes the
    /// shape of `indices`. With `Some(d)`, `indices` must have the same rank
    /// as `self` and matching sizes on every dimension except `d`; each
    /// output element at position `p` is `self[p with p[d] := indices[p]]`.
    ///
    /// # Errors
    /// Returns `InvalidArgument` for an out-of-range dimension or index
    /// value, and `ShapeMismatch` when the index tensor's shape is
    /// incompatible.
    pub fn take_along_dim(&self, indices: &Tensor<i64>, dim: Option<isize>) -> Result<Self> {
        match dim {
            None => {
                // Flattened gather: indices address the linear buffer.
                let data = self.to_vec()?;
                let idx_data = indices.to_vec()?;

                let mut result = Vec::with_capacity(idx_data.len());

                for &idx in idx_data.iter() {
                    // Reject negative or past-the-end flat indices.
                    if idx < 0 || idx as usize >= data.len() {
                        return Err(TorshError::InvalidArgument(format!(
                            "Index {} out of range for tensor with {} elements",
                            idx,
                            data.len()
                        )));
                    }
                    result.push(data[idx as usize]);
                }

                // Output inherits the index tensor's shape.
                Self::from_data(result, indices.shape().to_vec(), self.device)
            }
            Some(d) => {
                let ndim = self.ndim();
                // Normalize a negative dimension against the rank.
                let dim = if d < 0 {
                    (ndim as isize + d) as usize
                } else {
                    d as usize
                };

                if dim >= ndim {
                    return Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for {}-D tensor",
                        d, ndim
                    )));
                }

                let self_shape = self.shape().to_vec();
                let indices_shape = indices.shape().to_vec();

                // Ranks must match exactly (no broadcasting here).
                if self_shape.len() != indices_shape.len() {
                    return Err(TorshError::ShapeMismatch {
                        expected: self_shape.clone(),
                        got: indices_shape.clone(),
                    });
                }

                // All dimensions except `dim` must agree in size.
                for (i, (&s, &idx_s)) in self_shape.iter().zip(indices_shape.iter()).enumerate() {
                    if i != dim && s != idx_s {
                        return Err(TorshError::ShapeMismatch {
                            expected: self_shape.clone(),
                            got: indices_shape.clone(),
                        });
                    }
                }

                let data = self.to_vec()?;
                let idx_data = indices.to_vec()?;

                // View both buffers as outer × dim × inner blocks (row-major).
                let dim_size = self_shape[dim];
                let outer_size: usize = self_shape[..dim].iter().product();
                let inner_size: usize = self_shape[dim + 1..].iter().product();

                let indices_dim_size = indices_shape[dim];
                let mut result = Vec::with_capacity(idx_data.len());

                for outer in 0..outer_size {
                    for d in 0..indices_dim_size {
                        for inner in 0..inner_size {
                            // Flat position in the index tensor.
                            let idx_flat =
                                outer * indices_dim_size * inner_size + d * inner_size + inner;
                            let gather_idx = idx_data[idx_flat];

                            if gather_idx < 0 || gather_idx as usize >= dim_size {
                                return Err(TorshError::InvalidArgument(format!(
                                    "Index {} out of range for dimension size {}",
                                    gather_idx, dim_size
                                )));
                            }

                            // Same outer/inner coordinates in `self`, with the
                            // `dim` coordinate taken from the index tensor.
                            let src_idx = outer * dim_size * inner_size
                                + (gather_idx as usize) * inner_size
                                + inner;

                            result.push(data[src_idx]);
                        }
                    }
                }

                // Output takes the index tensor's shape.
                Self::from_data(result, indices_shape, self.device)
            }
        }
    }
689}
690
#[cfg(test)]
mod tests {
    //! Unit tests for the shape-manipulation ops above. All fixtures are
    //! small f32 CPU tensors so expected values can be written inline.
    use super::*;
    use torsh_core::device::DeviceType;

    // Stacking two [2] tensors along dim 0 yields a [2, 2] tensor.
    #[test]
    fn test_stack_1d() {
        let a = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor a");
        let b = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor b");

        let result = Tensor::stack(&[a, b], 0).expect("stack should succeed for 1d tensors");

        assert_eq!(result.shape().dims(), &[2, 2]);
        let data = result.data().expect("failed to get stacked tensor data");
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    // dim = -1 inserts the new axis at the end of the output shape.
    #[test]
    fn test_stack_negative_dim() {
        let a = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor a");
        let b = Tensor::from_data(vec![3.0f32, 4.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor b");

        let result = Tensor::stack(&[a, b], -1).expect("stack should succeed with negative dim");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    // 6 elements into 3 chunks -> three equal chunks of 2.
    #[test]
    fn test_chunk_even() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for chunk_even");

        let chunks = tensor.chunk(3, 0).expect("chunk into 3 should succeed");
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].shape().dims(), &[2]);
        assert_eq!(
            chunks[0].data().expect("failed to get chunk 0 data"),
            vec![1.0, 2.0]
        );
        assert_eq!(
            chunks[1].data().expect("failed to get chunk 1 data"),
            vec![3.0, 4.0]
        );
        assert_eq!(
            chunks[2].data().expect("failed to get chunk 2 data"),
            vec![5.0, 6.0]
        );
    }

    // 5 elements into 2 chunks -> ceil(5/2)=3 then the remaining 2.
    #[test]
    fn test_chunk_uneven() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("failed to create tensor for chunk_uneven");

        let chunks = tensor.chunk(2, 0).expect("uneven chunk should succeed");
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].shape().dims(), &[3]);
        assert_eq!(chunks[1].shape().dims(), &[2]);
    }

    #[test]
    fn test_split_even() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for split_even");

        let splits = tensor.split(2, 0).expect("split by 2 should succeed");
        assert_eq!(splits.len(), 3);
        assert_eq!(
            splits[0].data().expect("failed to get split 0 data"),
            vec![1.0, 2.0]
        );
        assert_eq!(
            splits[1].data().expect("failed to get split 1 data"),
            vec![3.0, 4.0]
        );
    }

    // 5 elements split by 2 -> sizes [2, 2, 1] (last piece is the remainder).
    #[test]
    fn test_split_uneven() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("failed to create tensor for split_uneven");

        let splits = tensor.split(2, 0).expect("uneven split should succeed");
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].shape().dims(), &[2]);
        assert_eq!(splits[1].shape().dims(), &[2]);
        assert_eq!(splits[2].shape().dims(), &[1]);
    }

    #[test]
    fn test_flip_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create 1d tensor for flip");

        let result = tensor.flip(&[0]).expect("flip dim 0 should succeed");
        assert_eq!(
            result.data().expect("failed to get flipped data"),
            vec![4.0, 3.0, 2.0, 1.0]
        );
    }

    // Flipping dim 0 of a 2x2 tensor swaps the rows.
    #[test]
    fn test_flip_2d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create 2d tensor for flip");

        let result = tensor.flip(&[0]).expect("flip 2d dim 0 should succeed");
        assert_eq!(
            result.data().expect("failed to get 2d flipped data"),
            vec![3.0, 4.0, 1.0, 2.0]
        );
    }

    // fliplr reverses columns within each row.
    #[test]
    fn test_fliplr() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for fliplr");

        let result = tensor.fliplr().expect("fliplr should succeed");
        assert_eq!(
            result.data().expect("failed to get fliplr data"),
            vec![2.0, 1.0, 4.0, 3.0]
        );
    }

    // flipud reverses the row order.
    #[test]
    fn test_flipud() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for flipud");

        let result = tensor.flipud().expect("flipud should succeed");
        assert_eq!(
            result.data().expect("failed to get flipud data"),
            vec![3.0, 4.0, 1.0, 2.0]
        );
    }

    // A positive shift moves elements toward higher indices, wrapping around.
    #[test]
    fn test_roll_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for roll");

        let result = tensor.roll(&[1], &[0]).expect("roll by 1 should succeed");
        assert_eq!(
            result.data().expect("failed to get rolled data"),
            vec![4.0, 1.0, 2.0, 3.0]
        );
    }

    // A negative shift rolls in the opposite direction.
    #[test]
    fn test_roll_negative() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for negative roll");

        let result = tensor
            .roll(&[-1], &[0])
            .expect("negative roll should succeed");
        assert_eq!(
            result.data().expect("failed to get negatively rolled data"),
            vec![2.0, 3.0, 4.0, 1.0]
        );
    }

    // A quarter turn preserves the (square) shape; only the shape is checked.
    #[test]
    fn test_rot90_once() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for rot90");

        let result = tensor.rot90(1, &[0, 1]).expect("rot90 once should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    // Two quarter turns = 180 degrees, which fully reverses a 2x2 tensor.
    #[test]
    fn test_rot90_twice() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for rot90 twice");

        let result = tensor
            .rot90(2, &[0, 1])
            .expect("rot90 twice should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
        assert_eq!(
            result.data().expect("failed to get rot90 twice data"),
            vec![4.0, 3.0, 2.0, 1.0]
        );
    }

    #[test]
    fn test_tile_1d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![2], DeviceType::Cpu)
            .expect("failed to create tensor for tile 1d");

        let result = tensor.tile(&[2]).expect("tile 1d should succeed");
        assert_eq!(result.shape().dims(), &[4]);
        assert_eq!(
            result.data().expect("failed to get tiled 1d data"),
            vec![1.0, 2.0, 1.0, 2.0]
        );
    }

    #[test]
    fn test_tile_2d() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("failed to create tensor for tile 2d");

        let result = tensor.tile(&[2, 1]).expect("tile 2d should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    // dim=None flattens and repeats each element consecutively.
    #[test]
    fn test_repeat_interleave_flatten() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0], vec![3], DeviceType::Cpu)
            .expect("failed to create tensor for repeat_interleave");

        let result = tensor
            .repeat_interleave(2, None)
            .expect("repeat_interleave flatten should succeed");
        assert_eq!(result.shape().dims(), &[6]);
        assert_eq!(
            result.data().expect("failed to get repeat_interleave data"),
            vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
        );
    }

    // Repeating along dim 0 doubles that dimension's size.
    #[test]
    fn test_repeat_interleave_dim() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("failed to create tensor for repeat_interleave dim");

        let result = tensor
            .repeat_interleave(2, Some(0))
            .expect("repeat_interleave along dim 0 should succeed");
        assert_eq!(result.shape().dims(), &[4, 2]);
    }

    // Unflatten keeps element order; only the shape changes.
    #[test]
    fn test_unflatten_basic() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![6],
            DeviceType::Cpu,
        )
        .expect("failed to create tensor for unflatten");

        let result = tensor
            .unflatten(0, &[2, 3])
            .expect("unflatten to [2,3] should succeed");
        assert_eq!(result.shape().dims(), &[2, 3]);
        assert_eq!(
            result.data().expect("failed to get unflattened data"),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_unflatten_multiple_dims() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![24], DeviceType::Cpu)
            .expect("failed to create tensor for unflatten multiple dims");

        let result = tensor
            .unflatten(0, &[2, 3, 4])
            .expect("unflatten to [2,3,4] should succeed");
        assert_eq!(result.shape().dims(), &[2, 3, 4]);
    }

    // dim=None gathers by flat index; output takes the indices' shape.
    #[test]
    fn test_take_along_dim_flatten() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![4], DeviceType::Cpu)
            .expect("failed to create tensor for take_along_dim");

        let indices = Tensor::from_data(vec![0i64, 2, 1], vec![3], DeviceType::Cpu)
            .expect("failed to create indices tensor");

        let result = tensor
            .take_along_dim(&indices, None)
            .expect("take_along_dim flatten should succeed");
        assert_eq!(result.shape().dims(), &[3]);
        assert_eq!(
            result
                .data()
                .expect("failed to get take_along_dim flatten data"),
            vec![1.0, 3.0, 2.0]
        );
    }

    // Gathering along dim 1 permutes each row independently.
    #[test]
    fn test_take_along_dim_2d() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("failed to create 2d tensor for take_along_dim");

        let indices = Tensor::from_data(vec![0i64, 2, 1, 1, 0, 2], vec![2, 3], DeviceType::Cpu)
            .expect("failed to create 2d indices tensor");

        let result = tensor
            .take_along_dim(&indices, Some(1))
            .expect("take_along_dim 2d should succeed");
        assert_eq!(result.shape().dims(), &[2, 3]);
        assert_eq!(
            result.data().expect("failed to get take_along_dim 2d data"),
            vec![1.0, 3.0, 2.0, 5.0, 4.0, 6.0]
        );
    }
}