use std::sync::{Arc, RwLock};

use torsh_core::{
    dtype::TensorElement,
    error::{Result, TorshError},
    shape::Shape,
};

use crate::core_ops::{Operation, Tensor};
impl<T: TensorElement + Copy> Tensor<T> {
    /// Returns the size of the given dimension; negative indices count from the end.
    pub fn size(&self, dim: i32) -> Result<usize> {
        self.shape().size(dim)
    }
    /// Returns a tensor with the same data and a new shape. At most one
    /// dimension may be -1, in which case its size is inferred from the
    /// remaining dimensions and the total element count.
    pub fn view(&self, shape: &[i32]) -> Result<Self> {
        let infer_count = shape.iter().filter(|&&x| x == -1).count();
        if infer_count > 1 {
            return Err(TorshError::InvalidShape(
                "Only one dimension can be inferred (only one -1 allowed)".to_string(),
            ));
        }

        let new_shape: Result<Vec<usize>> = shape
            .iter()
            .map(|&d| {
                if d == -1 {
                    // Infer this dimension from the product of the known ones.
                    let known_dims: Result<Vec<usize>> = shape
                        .iter()
                        .filter(|&&x| x != -1)
                        .map(|&x| {
                            if x < 0 {
                                Err(TorshError::InvalidShape(format!(
                                    "Invalid dimension size: {x} (negative dimensions not allowed except -1)"
                                )))
                            } else {
                                Ok(x as usize)
                            }
                        })
                        .collect();

                    let known_dims = known_dims?;

                    let known_product = known_dims.iter().try_fold(1usize, |acc, &dim| {
                        acc.checked_mul(dim).ok_or_else(|| {
                            TorshError::InvalidShape(
                                "Shape dimensions too large (would overflow)".to_string(),
                            )
                        })
                    })?;

                    if known_product == 0 {
                        return Err(TorshError::InvalidShape(
                            "Cannot infer dimension with zero-sized dimensions".to_string(),
                        ));
                    }

                    let total = self.numel();
                    if total % known_product != 0 {
                        return Err(TorshError::InvalidShape(
                            "Cannot infer dimension: size is not divisible".to_string(),
                        ));
                    }

                    Ok(total / known_product)
                } else if d < 0 {
                    Err(TorshError::InvalidShape(format!(
                        "Invalid dimension size: {d}"
                    )))
                } else {
                    Ok(d as usize)
                }
            })
            .collect();

        let new_shape = new_shape?;

        // Guard against overflow before comparing element counts.
        let new_numel = new_shape.iter().try_fold(1usize, |acc, &dim| {
            acc.checked_mul(dim).ok_or_else(|| {
                TorshError::InvalidShape(
                    "Reshaped tensor would be too large (would overflow)".to_string(),
                )
            })
        })?;

        if new_numel != self.numel() {
            return Err(TorshError::InvalidShape(format!(
                "Shape {:?} is invalid for tensor of size {}",
                new_shape,
                self.numel()
            )));
        }

        // This implementation copies the data rather than aliasing storage.
        let data = self.to_vec()?;
        Self::from_data(data, new_shape, self.device)
    }

    /// Creates a zero-copy view with the given shape. Requires the tensor to
    /// be contiguous and the new shape to preserve the element count.
    pub fn view_as(&self, shape: &[usize]) -> Result<Self> {
        let new_numel = shape.iter().product::<usize>();
        if new_numel != self.numel() {
            return Err(TorshError::InvalidShape(format!(
                "Shape {:?} is invalid for tensor of size {}",
                shape,
                self.numel()
            )));
        }

        if !self.is_contiguous() {
            return Err(TorshError::InvalidShape(
                "Cannot create efficient view of non-contiguous tensor".to_string(),
            ));
        }

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(shape.to_vec()),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: None,
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Returns a zero-copy slice `[start, end)` along `dim`, sharing storage
    /// with `self` via an adjusted storage offset.
    pub fn slice_tensor(&self, dim: usize, start: usize, end: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        let shape = self.shape.dims();
        if start >= shape[dim] || end > shape[dim] || start >= end {
            return Err(TorshError::InvalidArgument(format!(
                "Invalid slice range [{}:{}] for dimension {} of size {}",
                start, end, dim, shape[dim]
            )));
        }

        let mut new_shape = shape.to_vec();
        new_shape[dim] = end - start;

        let current_strides = self.strides();
        let offset_adjustment = start * current_strides[dim];

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(current_strides),
            storage_offset: self.storage_offset + offset_adjustment,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Returns a zero-copy transpose that swaps `dim0` and `dim1` by swapping
    /// the corresponding shape entries and strides.
    pub fn transpose_view(&self, dim0: usize, dim1: usize) -> Result<Self> {
        if dim0 >= self.ndim() || dim1 >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimensions {} and {} out of range for tensor with {} dimensions",
                dim0,
                dim1,
                self.ndim()
            )));
        }

        if dim0 == dim1 {
            return Ok(self.clone());
        }

        let mut new_shape = self.shape.dims().to_vec();
        let mut new_strides = self.strides();

        new_shape.swap(dim0, dim1);
        new_strides.swap(dim0, dim1);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Removes dimension `dim`, which must have size 1, without copying data.
    pub fn squeeze_tensor(&self, dim: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        let shape = self.shape.dims();
        if shape[dim] != 1 {
            return Err(TorshError::InvalidArgument(format!(
                "Cannot squeeze dimension {} of size {}",
                dim, shape[dim]
            )));
        }

        let mut new_shape = shape.to_vec();
        new_shape.remove(dim);

        let mut new_strides = self.strides();
        new_strides.remove(dim);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Inserts a size-1 dimension at `dim` (which may equal `ndim` to append).
    pub fn unsqueeze_tensor(&self, dim: usize) -> Result<Self> {
        if dim > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for insertion in tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }

        let mut new_shape = self.shape.dims().to_vec();
        new_shape.insert(dim, 1);

        let mut new_strides = self.strides();
        // A size-1 dimension can take any stride; use 1 when appending at the
        // end, otherwise reuse the stride of the dimension it displaces.
        let new_stride = if dim == new_shape.len() - 1 {
            1
        } else {
            new_strides[dim]
        };
        new_strides.insert(dim, new_stride);

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Swaps two dimensions, accepting negative indices. 2-D tensors take a
    /// data-copying fast path; everything else returns a strided view.
    pub fn transpose(&self, dim0: i32, dim1: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim0 = if dim0 < 0 {
            (ndim as i32 + dim0) as usize
        } else {
            dim0 as usize
        };
        let dim1 = if dim1 < 0 {
            (ndim as i32 + dim1) as usize
        } else {
            dim1 as usize
        };

        if dim0 >= ndim || dim1 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimensions {} and {} out of range for tensor with {} dimensions",
                dim0, dim1, ndim
            )));
        }

        if ndim == 2 && dim0 != dim1 {
            self.transpose_2d()
        } else {
            self.transpose_view(dim0, dim1)
        }
    }

    /// Materializes the transpose of a 2-D tensor by copying in column order.
    fn transpose_2d(&self) -> Result<Self> {
        let shape = self.shape.dims();
        if shape.len() != 2 {
            return Err(TorshError::InvalidArgument(
                "transpose_2d only works with 2D tensors".to_string(),
            ));
        }

        let (rows, cols) = (shape[0], shape[1]);
        let data = self.to_vec()?;
        let mut transposed_data = Vec::with_capacity(data.len());

        for col in 0..cols {
            for row in 0..rows {
                transposed_data.push(data[row * cols + col]);
            }
        }

        Self::from_data(transposed_data, vec![cols, rows], self.device)
    }

    /// Reorders dimensions according to `dims` (negative indices allowed)
    /// without copying data; each dimension must appear exactly once.
    pub fn permute(&self, dims: &[i32]) -> Result<Self> {
        let ndim = self.ndim();

        if dims.len() != ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Number of dimensions in permutation ({}) doesn't match tensor dimensions ({})",
                dims.len(),
                ndim
            )));
        }

        let perm_dims: Result<Vec<usize>> = dims
            .iter()
            .map(|&d| {
                let dim = if d < 0 { ndim as i32 + d } else { d } as usize;
                if dim >= ndim {
                    Err(TorshError::InvalidArgument(format!(
                        "Dimension {} out of range for tensor with {} dimensions",
                        d, ndim
                    )))
                } else {
                    Ok(dim)
                }
            })
            .collect();

        let perm_dims = perm_dims?;

        // A valid permutation sorts to exactly 0..ndim.
        let mut sorted_dims = perm_dims.clone();
        sorted_dims.sort_unstable();
        for i in 0..ndim {
            if sorted_dims[i] != i {
                return Err(TorshError::InvalidArgument(
                    "Permutation must contain each dimension exactly once".to_string(),
                ));
            }
        }

        let old_shape = self.shape.dims();
        let old_strides = self.strides();

        let new_shape: Vec<usize> = perm_dims.iter().map(|&i| old_shape[i]).collect();
        let new_strides: Vec<usize> = perm_dims.iter().map(|&i| old_strides[i]).collect();

        Ok(Self {
            storage: self.storage.clone(),
            shape: Shape::new(new_shape),
            device: self.device,
            requires_grad: self.requires_grad,
            grad: Arc::new(RwLock::new(None)),
            operation: Operation::Leaf,
            strides: Some(new_strides),
            storage_offset: self.storage_offset,
            base_tensor: if self.is_view() {
                self.base_tensor.clone()
            } else {
                Some(Arc::downgrade(&Arc::new(self.clone())))
            },
        })
    }

    /// Removes the size-1 dimension at `dim`; negative indices count from the end.
    pub fn squeeze(&self, dim: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim) as usize
        } else {
            dim as usize
        };

        self.squeeze_tensor(dim)
    }

    /// Removes all size-1 dimensions; a tensor whose dimensions are all 1
    /// becomes a scalar (empty shape).
    pub fn squeeze_all(&self) -> Result<Self> {
        let shape = self.shape.dims();
        let new_shape: Vec<usize> = shape.iter().copied().filter(|&s| s != 1).collect();

        // When every dimension is squeezed away, `new_shape` is already the
        // empty vector, so a single path handles both cases.
        let data = self.to_vec()?;
        Self::from_data(data, new_shape, self.device)
    }

    /// Inserts a size-1 dimension at `dim`; negative indices follow the
    /// PyTorch convention, so -1 appends at the end.
    pub fn unsqueeze(&self, dim: i32) -> Result<Self> {
        let ndim = self.ndim();
        let dim = if dim < 0 {
            (ndim as i32 + dim + 1) as usize
        } else {
            dim as usize
        };

        self.unsqueeze_tensor(dim)
    }

    /// Alias for [`view`](Self::view); accepts -1 for one inferred dimension.
    pub fn reshape(&self, shape: &[i32]) -> Result<Self> {
        self.view(shape)
    }

    /// Returns true if the tensor's strides match the default row-major
    /// (C-contiguous) strides for its shape.
    pub fn is_contiguous(&self) -> bool {
        let default_strides = self.compute_default_strides();
        let current_strides = self.strides();

        current_strides == default_strides
    }

    /// Returns `self` if already contiguous, otherwise a contiguous copy.
    pub fn contiguous(&self) -> Result<Self> {
        if self.is_contiguous() {
            Ok(self.clone())
        } else {
            let data = self.to_vec()?;
            Self::from_data(data, self.shape.dims().to_vec(), self.device)
        }
    }

    /// Broadcasts the tensor to `shape`: size-1 dimensions are repeated and
    /// leading dimensions may be added. This implementation materializes the
    /// result rather than returning a zero-stride view.
    pub fn expand(&self, shape: &[usize]) -> Result<Self> {
        let old_shape = self.shape.dims();

        if shape.len() < old_shape.len() {
            return Err(TorshError::InvalidShape(
                "Cannot expand to smaller number of dimensions".to_string(),
            ));
        }

        // Trailing dimensions are aligned, so existing dims map to the end of
        // the target shape.
        let offset = shape.len() - old_shape.len();
        for (i, &old_dim) in old_shape.iter().enumerate() {
            let new_dim = shape[offset + i];
            if old_dim != 1 && old_dim != new_dim {
                return Err(TorshError::InvalidShape(format!(
                    "Cannot expand dimension {} from {} to {}",
                    i, old_dim, new_dim
                )));
            }
        }

        let source_data = self.to_vec()?;
        let target_numel = shape.iter().product();
        let mut result_data = Vec::with_capacity(target_numel);

        self.expand_data_recursive(&source_data, &mut result_data, shape, old_shape, 0, 0)?;

        Self::from_data(result_data, shape.to_vec(), self.device)
    }

    /// Recursively walks the target shape in row-major order, copying (and
    /// repeating, for broadcast dimensions) elements from the source data.
    fn expand_data_recursive(
        &self,
        source: &[T],
        dest: &mut Vec<T>,
        target_shape: &[usize],
        source_shape: &[usize],
        target_dim: usize,
        source_offset: usize,
    ) -> Result<()> {
        if target_dim == target_shape.len() {
            dest.push(source[source_offset]);
            return Ok(());
        }

        let target_size = target_shape[target_dim];
        // Trailing dimensions are aligned; comparing against the rank
        // difference (rather than subtracting first) avoids usize underflow
        // when the target has extra leading dimensions.
        let rank_diff = target_shape.len() - source_shape.len();

        if target_dim >= rank_diff {
            let source_dim_idx = target_dim - rank_diff;
            let source_size = source_shape[source_dim_idx];
            let stride = if source_dim_idx + 1 < source_shape.len() {
                source_shape[source_dim_idx + 1..].iter().product()
            } else {
                1
            };

            if source_size == 1 {
                // Broadcast dimension: reuse the same source offset.
                for _ in 0..target_size {
                    self.expand_data_recursive(
                        source,
                        dest,
                        target_shape,
                        source_shape,
                        target_dim + 1,
                        source_offset,
                    )?;
                }
            } else {
                for i in 0..target_size {
                    self.expand_data_recursive(
                        source,
                        dest,
                        target_shape,
                        source_shape,
                        target_dim + 1,
                        source_offset + i * stride,
                    )?;
                }
            }
        } else {
            // Extra leading target dimension: repeat the entire source block.
            for _ in 0..target_size {
                self.expand_data_recursive(
                    source,
                    dest,
                    target_shape,
                    source_shape,
                    target_dim + 1,
                    source_offset,
                )?;
            }
        }

        Ok(())
    }

    /// Moves dimensions from `source` positions to `destination` positions,
    /// keeping the remaining dimensions in their original relative order
    /// (PyTorch `movedim` semantics).
    pub fn movedim(&self, source: &[isize], destination: &[isize]) -> Result<Self> {
        if source.len() != destination.len() {
            return Err(TorshError::InvalidArgument(
                "source and destination must have the same length".to_string(),
            ));
        }

        let ndim = self.ndim();

        // Normalize negative indices and bounds-check a dimension list.
        let normalize = |dims: &[isize]| -> Result<Vec<usize>> {
            dims.iter()
                .map(|&d| {
                    let dim = if d < 0 {
                        (ndim as isize + d) as usize
                    } else {
                        d as usize
                    };
                    if dim >= ndim {
                        Err(TorshError::InvalidArgument(format!(
                            "Dimension {} out of range for {}-D tensor",
                            d, ndim
                        )))
                    } else {
                        Ok(dim)
                    }
                })
                .collect()
        };
        let norm_source = normalize(source)?;
        let norm_dest = normalize(destination)?;

        // Reject repeated dimensions in either list.
        let check_unique = |dims: &[usize], what: &str| -> Result<()> {
            for i in 0..dims.len() {
                for j in i + 1..dims.len() {
                    if dims[i] == dims[j] {
                        return Err(TorshError::InvalidArgument(format!(
                            "repeated dim in {what}"
                        )));
                    }
                }
            }
            Ok(())
        };
        check_unique(&norm_source, "source")?;
        check_unique(&norm_dest, "destination")?;

        // Build the full permutation: moved dims take their destinations, and
        // the remaining dims fill the unused slots in their original order.
        let mut result_perm = vec![0; ndim];
        let mut used = vec![false; ndim];

        for (&src, &dst) in norm_source.iter().zip(norm_dest.iter()) {
            result_perm[dst] = src;
            used[dst] = true;
        }

        let remaining_dims: Vec<usize> = (0..ndim).filter(|d| !norm_source.contains(d)).collect();

        let mut remaining_idx = 0;
        for i in 0..ndim {
            if !used[i] {
                result_perm[i] = remaining_dims[remaining_idx];
                remaining_idx += 1;
            }
        }

        let perm_i32: Vec<i32> = result_perm.iter().map(|&d| d as i32).collect();
        self.permute(&perm_i32)
    }

    /// Alias for [`movedim`](Self::movedim).
    pub fn moveaxis(&self, source: &[isize], destination: &[isize]) -> Result<Self> {
        self.movedim(source, destination)
    }

    /// Swaps two axes (negative indices allowed) via a permutation view.
    pub fn swapaxes(&self, axis0: isize, axis1: isize) -> Result<Self> {
        let ndim = self.ndim();

        let dim0 = if axis0 < 0 {
            (ndim as isize + axis0) as usize
        } else {
            axis0 as usize
        };
        let dim1 = if axis1 < 0 {
            (ndim as isize + axis1) as usize
        } else {
            axis1 as usize
        };

        if dim0 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                axis0, ndim
            )));
        }
        if dim1 >= ndim {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for {}-D tensor",
                axis1, ndim
            )));
        }

        let mut perm: Vec<i32> = (0..ndim as i32).collect();
        perm.swap(dim0, dim1);

        self.permute(&perm)
    }

    /// Alias for [`swapaxes`](Self::swapaxes).
    pub fn swapdims(&self, dim0: isize, dim1: isize) -> Result<Self> {
        self.swapaxes(dim0, dim1)
    }

    /// Broadcasts to `shape`; currently implemented as a materializing
    /// [`expand`](Self::expand).
    pub fn broadcast_to(&self, shape: &[usize]) -> Result<Self> {
        self.expand(shape)
    }

    /// Broadcasts `self` to the shape of `other`.
    pub fn expand_as(&self, other: &Self) -> Result<Self> {
        self.broadcast_to(other.shape().dims())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    #[test]
    fn test_tensor_view() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let reshaped = tensor.view(&[3, 2]).expect("view should succeed");
        assert_eq!(reshaped.shape().dims(), &[3, 2]);
        assert_eq!(reshaped.numel(), 6);
    }
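
    // Additional sketch: a single -1 with no other dimensions should flatten
    // the tensor, exercising only the inference path used above.
    #[test]
    fn test_tensor_view_flatten() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let flat = tensor.view(&[-1]).expect("view should succeed");
        assert_eq!(flat.shape().dims(), &[6]);
    }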

    #[test]
    fn test_tensor_view_with_inference() {
        let data = vec![1.0f32; 24];
        let tensor = Tensor::from_data(data, vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let reshaped = tensor.view(&[6, -1]).expect("view should succeed");
        assert_eq!(reshaped.shape().dims(), &[6, 4]);
    }
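
    // Failure-path sketch: inference must divide the element count evenly, so
    // a 24-element tensor cannot be viewed as [5, -1].
    #[test]
    fn test_tensor_view_inference_not_divisible() {
        let data = vec![1.0f32; 24];
        let tensor = Tensor::from_data(data, vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.view(&[5, -1]).is_err());
    }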

    #[test]
    fn test_tensor_slice() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let slice = tensor.slice_tensor(1, 1, 3).expect("slice should succeed");
        assert_eq!(slice.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_tensor_transpose() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let transposed = tensor.transpose(0, 1).expect("transpose should succeed");
        assert_eq!(transposed.shape().dims(), &[2, 2]);
        assert_eq!(
            transposed.get(&[0, 1]).expect("data access should succeed"),
            3.0
        );
        assert_eq!(
            transposed.get(&[1, 0]).expect("data access should succeed"),
            2.0
        );
    }

    #[test]
    fn test_tensor_squeeze_unsqueeze() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![1, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let squeezed = tensor.squeeze(0).expect("squeeze should succeed");
        assert_eq!(squeezed.shape().dims(), &[3]);

        let unsqueezed = squeezed.unsqueeze(0).expect("unsqueeze should succeed");
        assert_eq!(unsqueezed.shape().dims(), &[1, 3]);
    }
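
    // Negative-index sketch: unsqueeze(-1) should append a size-1 dimension,
    // per the PyTorch-style convention implemented in `unsqueeze`.
    #[test]
    fn test_unsqueeze_negative_dim() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let unsqueezed = tensor.unsqueeze(-1).expect("unsqueeze should succeed");
        assert_eq!(unsqueezed.shape().dims(), &[3, 1]);
    }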

    #[test]
    fn test_tensor_permute() {
        let data = vec![1.0f32; 24];
        let tensor = Tensor::from_data(data, vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let permuted = tensor.permute(&[2, 0, 1]).expect("permute should succeed");
        assert_eq!(permuted.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_is_contiguous() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");
        assert!(tensor.is_contiguous());

        let transposed = tensor
            .transpose_view(0, 1)
            .expect("transpose view should succeed");
        assert!(!transposed.is_contiguous());

        let contiguous = transposed.contiguous().expect("contiguous should succeed");
        assert!(contiguous.is_contiguous());
    }
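
    // Sketch for squeeze_all: every size-1 dimension is dropped in one call.
    #[test]
    fn test_squeeze_all() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![1, 3, 1], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let squeezed = tensor.squeeze_all().expect("squeeze_all should succeed");
        assert_eq!(squeezed.shape().dims(), &[3]);
    }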

    #[test]
    fn test_expand() {
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let expanded = tensor.expand(&[3, 2]).expect("expand should succeed");
        assert_eq!(expanded.shape().dims(), &[3, 2]);
        assert_eq!(expanded.numel(), 6);
    }
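
    // Value-level sketch: `expand` materializes its result, so each broadcast
    // row should repeat the source row. Assumes `get` indexes row-major on a
    // freshly materialized tensor, as in `test_tensor_transpose` above.
    #[test]
    fn test_expand_values() {
        let data = vec![1.0f32, 2.0];
        let tensor = Tensor::from_data(data, vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let expanded = tensor.expand(&[3, 2]).expect("expand should succeed");
        assert_eq!(
            expanded.get(&[2, 0]).expect("data access should succeed"),
            1.0
        );
        assert_eq!(
            expanded.get(&[2, 1]).expect("data access should succeed"),
            2.0
        );
    }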

    #[test]
    fn test_view_error_handling() {
        let data = vec![1.0f32, 2.0, 3.0];
        let tensor = Tensor::from_data(data, vec![3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.view(&[2, 2]).is_err());

        assert!(tensor.view(&[-1, -1]).is_err());
    }
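
    // Sketch of the `view_as` contiguity guard: a strided transpose view
    // cannot be re-viewed without copying, so `view_as` should refuse it.
    #[test]
    fn test_view_as_requires_contiguous() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let transposed = tensor
            .transpose_view(0, 1)
            .expect("transpose view should succeed");
        assert!(transposed.view_as(&[4]).is_err());
    }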

    #[test]
    fn test_movedim_single() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.movedim(&[0], &[2]).expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[3, 4, 2]);
    }

    #[test]
    fn test_movedim_multiple() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor
            .movedim(&[0, 1], &[2, 0])
            .expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[3, 4, 2]);
    }
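
    // Failure-path sketch: repeated dimensions in `source` are rejected
    // before any permutation is built.
    #[test]
    fn test_movedim_repeated_source() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.movedim(&[0, 0], &[1, 2]).is_err());
    }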

    #[test]
    fn test_movedim_negative_indices() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.movedim(&[-1], &[0]).expect("movedim should succeed");
        assert_eq!(result.shape().dims(), &[4, 2, 3]);
    }

    #[test]
    fn test_moveaxis_alias() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result1 = tensor.movedim(&[0], &[2]).expect("movedim should succeed");
        let result2 = tensor
            .moveaxis(&[0], &[2])
            .expect("moveaxis should succeed");
        assert_eq!(result1.shape().dims(), result2.shape().dims());
    }

    #[test]
    fn test_swapaxes_simple() {
        let tensor = Tensor::from_data(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");

        let result = tensor.swapaxes(0, 1).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[3, 2]);
    }
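
    // Bounds sketch: an axis index at or beyond ndim is rejected.
    #[test]
    fn test_swapaxes_out_of_range() {
        let tensor = Tensor::from_data(vec![1.0f32; 6], vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.swapaxes(0, 2).is_err());
    }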

    #[test]
    fn test_swapaxes_3d() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.swapaxes(0, 2).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[4, 3, 2]);
    }

    #[test]
    fn test_swapaxes_negative_indices() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.swapaxes(-1, -2).expect("swapaxes should succeed");
        assert_eq!(result.shape().dims(), &[2, 4, 3]);
    }

    #[test]
    fn test_swapdims_alias() {
        let tensor = Tensor::from_data(vec![1.0f32; 24], vec![2, 3, 4], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result1 = tensor.swapaxes(0, 2).expect("swapaxes should succeed");
        let result2 = tensor.swapdims(0, 2).expect("swapdims should succeed");
        assert_eq!(result1.shape().dims(), result2.shape().dims());
    }

    #[test]
    fn test_broadcast_to_same_shape() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor
            .broadcast_to(&[2, 2])
            .expect("broadcast_to should succeed");
        assert_eq!(result.shape().dims(), &[2, 2]);
    }

    #[test]
    fn test_broadcast_to_expand_dim() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor
            .broadcast_to(&[3, 2])
            .expect("broadcast_to should succeed");
        assert_eq!(result.shape().dims(), &[3, 2]);
    }
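
    // Failure-path sketch: `broadcast_to` delegates to `expand`, which refuses
    // to drop dimensions.
    #[test]
    fn test_broadcast_to_fewer_dims_fails() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        assert!(tensor.broadcast_to(&[4]).is_err());
    }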

    #[test]
    fn test_expand_as_basic() {
        let tensor = Tensor::from_data(vec![1.0f32, 2.0], vec![1, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let target = Tensor::from_data(vec![0.0f32; 6], vec![3, 2], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let result = tensor.expand_as(&target).expect("expand_as should succeed");
        assert_eq!(result.shape().dims(), target.shape().dims());
        assert_eq!(result.shape().dims(), &[3, 2]);
    }
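
    // Alias sketch: `reshape` is a thin wrapper over `view`, so the two
    // should agree on the resulting shape.
    #[test]
    fn test_reshape_matches_view() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let reshaped = tensor.reshape(&[3, 2]).expect("reshape should succeed");
        let viewed = tensor.view(&[3, 2]).expect("view should succeed");
        assert_eq!(reshaped.shape().dims(), viewed.shape().dims());
    }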
}