1use torsh_core::{
16 dtype::TensorElement,
17 error::{Result, TorshError},
18};
19
20use crate::Tensor;
21
22pub trait TensorManipulationExt<T: TensorElement> {
24 fn squeeze_all(&self) -> Result<Tensor<T>>;
26
27 fn squeeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>>;
29
30 fn unsqueeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>>;
32
33 fn add_batch_dim(&self) -> Result<Tensor<T>>;
35
36 fn remove_batch_dim(&self) -> Result<Tensor<T>>;
38
39 fn atleast_nd(&self, n: usize) -> Result<Tensor<T>>;
41
42 fn to_channel_last(&self) -> Result<Tensor<T>>;
44
45 fn to_channel_first(&self) -> Result<Tensor<T>>;
47
48 fn swap_dims(&self, dim0: i32, dim1: i32) -> Result<Tensor<T>>;
50
51 fn move_dim(&self, src: i32, dst: i32) -> Result<Tensor<T>>;
53
54 fn expand_to(&self, target_shape: &[usize]) -> Result<Tensor<T>>;
56
57 fn repeat_along(&self, dim: i32, repeats: usize) -> Result<Tensor<T>>;
59}
60
61impl<T: TensorElement + Copy> TensorManipulationExt<T> for Tensor<T> {
62 fn squeeze_all(&self) -> Result<Tensor<T>> {
63 let shape_binding = self.shape();
64 let shape = shape_binding.dims();
65 let new_shape: Vec<usize> = shape.iter().filter(|&&s| s != 1).copied().collect();
66
67 if new_shape.is_empty() {
68 self.reshape(&[1])
70 } else {
71 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
72 self.reshape(&new_shape_i32)
73 }
74 }
75
76 fn squeeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>> {
77 let shape_binding = self.shape();
78 let shape = shape_binding.dims();
79 let ndim = shape.len() as i32;
80
81 let normalized_dims: Result<Vec<usize>> = dims
83 .iter()
84 .map(|&d| {
85 let normalized = if d < 0 { ndim + d } else { d };
86 if normalized < 0 || normalized >= ndim {
87 Err(TorshError::InvalidArgument(format!(
88 "Dimension {} out of range for tensor with {} dimensions",
89 d, ndim
90 )))
91 } else {
92 Ok(normalized as usize)
93 }
94 })
95 .collect();
96
97 let normalized_dims = normalized_dims?;
98
99 for &dim in &normalized_dims {
101 if shape[dim] != 1 {
102 return Err(TorshError::InvalidArgument(format!(
103 "Cannot squeeze dimension {} of size {}",
104 dim, shape[dim]
105 )));
106 }
107 }
108
109 let new_shape: Vec<usize> = shape
111 .iter()
112 .enumerate()
113 .filter(|(i, _)| !normalized_dims.contains(i))
114 .map(|(_, &s)| s)
115 .collect();
116
117 if new_shape.is_empty() {
118 self.reshape(&[1])
119 } else {
120 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
121 self.reshape(&new_shape_i32)
122 }
123 }
124
125 fn unsqueeze_dims(&self, dims: &[i32]) -> Result<Tensor<T>> {
126 let mut result = self.clone();
127
128 let mut sorted_dims: Vec<i32> = dims.to_vec();
130 sorted_dims.sort_unstable();
131
132 for (i, &dim) in sorted_dims.iter().enumerate() {
134 let adjusted_dim = dim + i as i32;
136 result = result.unsqueeze(adjusted_dim)?;
137 }
138
139 Ok(result)
140 }
141
142 fn add_batch_dim(&self) -> Result<Tensor<T>> {
143 self.unsqueeze(0)
144 }
145
146 fn remove_batch_dim(&self) -> Result<Tensor<T>> {
147 let shape_binding = self.shape();
148 let shape = shape_binding.dims();
149 if shape.is_empty() {
150 return Err(TorshError::InvalidArgument(
151 "Cannot remove batch dim from scalar tensor".to_string(),
152 ));
153 }
154
155 if shape[0] != 1 {
156 return Err(TorshError::InvalidArgument(format!(
157 "Batch dimension has size {}, expected 1",
158 shape[0]
159 )));
160 }
161
162 self.squeeze(0)
163 }
164
165 fn atleast_nd(&self, n: usize) -> Result<Tensor<T>> {
166 let shape_binding = self.shape();
167 let shape = shape_binding.dims();
168 let current_ndim = shape.len();
169
170 if current_ndim >= n {
171 return Ok(self.clone());
172 }
173
174 let mut new_shape = shape.to_vec();
175 for _ in current_ndim..n {
176 new_shape.push(1);
177 }
178
179 let new_shape_i32: Vec<i32> = new_shape.iter().map(|&s| s as i32).collect();
180 self.reshape(&new_shape_i32)
181 }
182
183 fn to_channel_last(&self) -> Result<Tensor<T>> {
184 let shape_binding = self.shape();
185 let shape = shape_binding.dims();
186
187 match shape.len() {
188 4 => {
189 self.permute(&[0, 2, 3, 1])
191 }
192 3 => {
193 self.permute(&[1, 2, 0])
195 }
196 _ => Err(TorshError::InvalidArgument(
197 "to_channel_last expects 3D or 4D tensor".to_string(),
198 )),
199 }
200 }
201
202 fn to_channel_first(&self) -> Result<Tensor<T>> {
203 let shape_binding = self.shape();
204 let shape = shape_binding.dims();
205
206 match shape.len() {
207 4 => {
208 self.permute(&[0, 3, 1, 2])
210 }
211 3 => {
212 self.permute(&[2, 0, 1])
214 }
215 _ => Err(TorshError::InvalidArgument(
216 "to_channel_first expects 3D or 4D tensor".to_string(),
217 )),
218 }
219 }
220
221 fn swap_dims(&self, dim0: i32, dim1: i32) -> Result<Tensor<T>> {
222 self.transpose(dim0, dim1)
223 }
224
225 fn move_dim(&self, src: i32, dst: i32) -> Result<Tensor<T>> {
226 let ndim = self.shape().dims().len() as i32;
227
228 let src = if src < 0 { ndim + src } else { src };
230 let dst = if dst < 0 { ndim + dst } else { dst };
231
232 if src < 0 || src >= ndim || dst < 0 || dst >= ndim {
233 return Err(TorshError::InvalidArgument(
234 "Dimension out of range".to_string(),
235 ));
236 }
237
238 if src == dst {
239 return Ok(self.clone());
240 }
241
242 let mut perm: Vec<i32> = (0..ndim).collect();
244 let src_dim = perm.remove(src as usize);
245
246 perm.insert(dst as usize, src_dim);
247
248 self.permute(&perm)
249 }
250
251 fn expand_to(&self, target_shape: &[usize]) -> Result<Tensor<T>> {
252 let shape_binding = self.shape();
253 let current_shape = shape_binding.dims();
254
255 if current_shape.len() > target_shape.len() {
256 return Err(TorshError::InvalidArgument(
257 "Cannot expand to shape with fewer dimensions".to_string(),
258 ));
259 }
260
261 for (i, ¤t_size) in current_shape.iter().rev().enumerate() {
263 let target_idx = target_shape.len() - 1 - i;
264 let target_size = target_shape[target_idx];
265
266 if current_size != 1 && current_size != target_size {
267 return Err(TorshError::InvalidArgument(format!(
268 "Cannot expand dimension {} from {} to {}",
269 target_idx, current_size, target_size
270 )));
271 }
272 }
273
274 self.expand(target_shape)
275 }
276
277 fn repeat_along(&self, dim: i32, repeats: usize) -> Result<Tensor<T>> {
278 let unsqueezed = self.unsqueeze(dim)?;
280 let shape_binding = unsqueezed.shape();
281 let shape = shape_binding.dims();
282
283 let mut repeat_shape = vec![1; shape.len()];
284 let normalized_dim = if dim < 0 {
285 (shape.len() as i32 + dim) as usize
286 } else {
287 dim as usize
288 };
289
290 repeat_shape[normalized_dim] = repeats;
291
292 unsqueezed.repeat(&repeat_shape)
293 }
294}
295
296pub mod shape_utils {
298 use super::*;
299
300 pub fn numel(shape: &[usize]) -> usize {
302 shape.iter().product()
303 }
304
305 pub fn are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
307 let len1 = shape1.len();
308 let len2 = shape2.len();
309 let max_len = len1.max(len2);
310
311 for i in 0..max_len {
312 let dim1 = if i < len1 { shape1[len1 - 1 - i] } else { 1 };
313
314 let dim2 = if i < len2 { shape2[len2 - 1 - i] } else { 1 };
315
316 if dim1 != 1 && dim2 != 1 && dim1 != dim2 {
317 return false;
318 }
319 }
320
321 true
322 }
323
324 pub fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
326 if !are_broadcastable(shape1, shape2) {
327 return None;
328 }
329
330 let len1 = shape1.len();
331 let len2 = shape2.len();
332 let max_len = len1.max(len2);
333
334 let mut result = Vec::with_capacity(max_len);
335
336 for i in 0..max_len {
337 let dim1 = if i < len1 { shape1[len1 - 1 - i] } else { 1 };
338
339 let dim2 = if i < len2 { shape2[len2 - 1 - i] } else { 1 };
340
341 result.push(dim1.max(dim2));
342 }
343
344 result.reverse();
345 Some(result)
346 }
347
348 pub fn infer_shape(shape: &[i32], total_elements: usize) -> Result<Vec<usize>> {
350 let mut result = Vec::new();
351 let mut unknown_idx = None;
352 let mut known_product = 1usize;
353
354 for (i, &dim) in shape.iter().enumerate() {
355 if dim == -1 {
356 if unknown_idx.is_some() {
357 return Err(TorshError::InvalidArgument(
358 "Only one dimension can be inferred".to_string(),
359 ));
360 }
361 unknown_idx = Some(i);
362 result.push(0); } else if dim < 0 {
364 return Err(TorshError::InvalidArgument(format!(
365 "Invalid dimension size: {}",
366 dim
367 )));
368 } else {
369 result.push(dim as usize);
370 known_product *= dim as usize;
371 }
372 }
373
374 if let Some(idx) = unknown_idx {
375 if known_product == 0 {
376 return Err(TorshError::InvalidArgument(
377 "Cannot infer dimension with zero-sized dimensions".to_string(),
378 ));
379 }
380
381 if total_elements % known_product != 0 {
382 return Err(TorshError::InvalidArgument(
383 "Cannot infer dimension: size is not divisible".to_string(),
384 ));
385 }
386
387 result[idx] = total_elements / known_product;
388 }
389
390 Ok(result)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::*;

    #[test]
    fn test_squeeze_all() {
        let squeezed = zeros::<f32>(&[1, 3, 1, 4, 1])
            .expect("zeros creation should succeed")
            .squeeze_all()
            .expect("squeeze_all should succeed");

        assert_eq!(squeezed.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_squeeze_dims() {
        let squeezed = zeros::<f32>(&[1, 3, 1, 4])
            .expect("zeros creation should succeed")
            .squeeze_dims(&[0, 2])
            .expect("squeeze_dims should succeed");

        assert_eq!(squeezed.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_unsqueeze_dims() {
        let unsqueezed = zeros::<f32>(&[3, 4])
            .expect("zeros creation should succeed")
            .unsqueeze_dims(&[0, 2])
            .expect("unsqueeze_dims should succeed");

        assert_eq!(unsqueezed.shape().dims(), &[1, 3, 4, 1]);
    }

    #[test]
    fn test_add_remove_batch_dim() {
        let base = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");

        let batched = base
            .add_batch_dim()
            .expect("add_batch_dim should succeed");
        assert_eq!(batched.shape().dims(), &[1, 3, 4]);

        let unbatched = batched
            .remove_batch_dim()
            .expect("remove_batch_dim should succeed");
        assert_eq!(unbatched.shape().dims(), &[3, 4]);
    }

    #[test]
    fn test_atleast_nd() {
        let padded = zeros::<f32>(&[3, 4])
            .expect("zeros creation should succeed")
            .atleast_nd(4)
            .expect("atleast_nd should succeed");

        assert_eq!(padded.shape().dims(), &[3, 4, 1, 1]);
    }

    #[test]
    fn test_channel_conversions() {
        let nchw = zeros::<f32>(&[2, 3, 4, 5]).expect("zeros creation should succeed");

        let nhwc = nchw
            .to_channel_last()
            .expect("channel conversion should succeed");
        assert_eq!(nhwc.shape().dims(), &[2, 4, 5, 3]);

        let round_trip = nhwc
            .to_channel_first()
            .expect("channel conversion should succeed");
        assert_eq!(round_trip.shape().dims(), &[2, 3, 4, 5]);
    }

    #[test]
    fn test_move_dim() {
        let moved = zeros::<f32>(&[2, 3, 4, 5])
            .expect("zeros creation should succeed")
            .move_dim(1, 3)
            .expect("move_dim should succeed");

        assert_eq!(moved.shape().dims(), &[2, 4, 5, 3]);
    }

    #[test]
    fn test_shape_utils_broadcastable() {
        assert!(shape_utils::are_broadcastable(&[3, 1, 4], &[1, 5, 4]));
        assert!(shape_utils::are_broadcastable(&[3, 4], &[3, 4]));
        assert!(shape_utils::are_broadcastable(&[1], &[3, 4]));

        assert!(!shape_utils::are_broadcastable(&[3, 4], &[2, 4]));
    }

    #[test]
    fn test_shape_utils_broadcast_shape() {
        assert_eq!(
            shape_utils::broadcast_shape(&[3, 1, 4], &[1, 5, 4]),
            Some(vec![3, 5, 4])
        );
        assert_eq!(shape_utils::broadcast_shape(&[3, 4], &[2, 4]), None);
    }

    #[test]
    fn test_shape_utils_infer_shape() {
        let with_placeholder = shape_utils::infer_shape(&[2, -1, 3], 24)
            .expect("shape inference should succeed");
        assert_eq!(with_placeholder, vec![2, 4, 3]);

        let fully_specified = shape_utils::infer_shape(&[3, 4], 12)
            .expect("shape inference should succeed");
        assert_eq!(fully_specified, vec![3, 4]);
    }

    #[test]
    fn test_squeeze_dims_invalid() {
        let tensor = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");

        assert!(tensor.squeeze_dims(&[0]).is_err());
    }

    #[test]
    fn test_remove_batch_dim_invalid() {
        let tensor = zeros::<f32>(&[3, 4]).expect("zeros creation should succeed");

        assert!(tensor.remove_batch_dim().is_err());
    }
}