Skip to main content

scivex_core/tensor/
reshape.rs

1//! Shape manipulation: reshape, transpose, flatten, squeeze, unsqueeze,
2//! concatenate, and stack.
3
4use crate::Scalar;
5use crate::error::{CoreError, Result};
6
7use super::{Tensor, compute_strides};
8
9impl<T: Scalar> Tensor<T> {
10    /// Reshape the tensor to a new shape without copying data.
11    ///
12    /// The total number of elements must remain the same.
13    ///
14    /// # Examples
15    ///
16    /// ```
17    /// # use scivex_core::Tensor;
18    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6]).unwrap();
19    /// let t = t.reshape(vec![2, 3]).unwrap();
20    /// assert_eq!(t.shape(), &[2, 3]);
21    /// ```
22    pub fn reshape(mut self, new_shape: Vec<usize>) -> Result<Self> {
23        let new_numel: usize = new_shape.iter().product();
24        if new_numel != self.numel() {
25            return Err(CoreError::InvalidShape {
26                shape: new_shape,
27                reason: "new shape has different number of elements",
28            });
29        }
30        self.strides = compute_strides(&new_shape);
31        self.shape = new_shape;
32        Ok(self)
33    }
34
35    /// Return a reshaped view without consuming the tensor (copies data).
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// # use scivex_core::Tensor;
41    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
42    /// let r = t.reshaped(vec![2, 2]).unwrap();
43    /// assert_eq!(r.shape(), &[2, 2]);
44    /// assert_eq!(t.shape(), &[4]); // original unchanged
45    /// ```
46    pub fn reshaped(&self, new_shape: Vec<usize>) -> Result<Self> {
47        self.clone().reshape(new_shape)
48    }
49
50    /// Flatten the tensor into a 1-D tensor (consumes self, no copy).
51    ///
52    /// # Examples
53    ///
54    /// ```
55    /// # use scivex_core::Tensor;
56    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
57    /// let flat = t.flatten();
58    /// assert_eq!(flat.shape(), &[4]);
59    /// ```
60    pub fn flatten(self) -> Self {
61        let n = self.numel();
62        Tensor {
63            data: self.data,
64            shape: vec![n],
65            strides: vec![1],
66        }
67    }
68
69    /// Return a flattened copy of the tensor.
70    ///
71    /// # Examples
72    ///
73    /// ```
74    /// # use scivex_core::Tensor;
75    /// let t = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
76    /// let flat = t.flattened();
77    /// assert_eq!(flat.shape(), &[4]);
78    /// assert_eq!(t.shape(), &[2, 2]); // original unchanged
79    /// ```
80    pub fn flattened(&self) -> Self {
81        let n = self.numel();
82        Tensor {
83            data: self.data.clone(),
84            shape: vec![n],
85            strides: vec![1],
86        }
87    }
88
89    /// Transpose a 2-D tensor (matrix). Returns a new tensor with copied data.
90    ///
91    /// For higher-rank tensors, use [`permute`](Self::permute).
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// # use scivex_core::Tensor;
97    /// let t = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
98    /// let tt = t.transpose().unwrap();
99    /// assert_eq!(tt.shape(), &[3, 2]);
100    /// ```
101    pub fn transpose(&self) -> Result<Self> {
102        if self.ndim() != 2 {
103            return Err(CoreError::InvalidArgument {
104                reason: "transpose() requires a 2-D tensor; use permute() for higher ranks",
105            });
106        }
107        let (rows, cols) = (self.shape[0], self.shape[1]);
108        let mut data = vec![T::zero(); self.numel()];
109
110        for r in 0..rows {
111            for c in 0..cols {
112                data[c * rows + r] = self.data[r * cols + c];
113            }
114        }
115
116        Tensor::from_vec(data, vec![cols, rows])
117    }
118
119    /// Permute the dimensions of the tensor according to the given axes.
120    ///
121    /// `axes` must be a permutation of `0..ndim`.
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// # use scivex_core::Tensor;
127    /// let t = Tensor::<i32>::arange(24).reshape(vec![2, 3, 4]).unwrap();
128    /// let p = t.permute(&[2, 0, 1]).unwrap();
129    /// assert_eq!(p.shape(), &[4, 2, 3]);
130    /// ```
131    pub fn permute(&self, axes: &[usize]) -> Result<Self> {
132        if axes.len() != self.ndim() {
133            return Err(CoreError::InvalidArgument {
134                reason: "axes length must match tensor rank",
135            });
136        }
137
138        // Validate it's a valid permutation
139        let mut seen = vec![false; self.ndim()];
140        for &a in axes {
141            if a >= self.ndim() {
142                return Err(CoreError::AxisOutOfBounds {
143                    axis: a,
144                    ndim: self.ndim(),
145                });
146            }
147            if seen[a] {
148                return Err(CoreError::InvalidArgument {
149                    reason: "duplicate axis in permutation",
150                });
151            }
152            seen[a] = true;
153        }
154
155        let new_shape: Vec<usize> = axes.iter().map(|&a| self.shape[a]).collect();
156        let new_strides = compute_strides(&new_shape);
157        let new_numel: usize = new_shape.iter().product();
158        let mut data = vec![T::zero(); new_numel];
159
160        // Iterate over every element in the output
161        let mut out_index = vec![0usize; self.ndim()];
162        for item in &mut data {
163            // Map output index back to input index
164            let mut flat_in = 0;
165            for (out_ax, &in_ax) in axes.iter().enumerate() {
166                flat_in += out_index[out_ax] * self.strides[in_ax];
167            }
168            *item = self.data[flat_in];
169
170            // Increment the output index (odometer style)
171            for d in (0..self.ndim()).rev() {
172                out_index[d] += 1;
173                if out_index[d] < new_shape[d] {
174                    break;
175                }
176                out_index[d] = 0;
177            }
178        }
179
180        Ok(Tensor {
181            data,
182            shape: new_shape,
183            strides: new_strides,
184        })
185    }
186
187    /// Insert a dimension of size 1 at the given axis.
188    ///
189    /// # Examples
190    ///
191    /// ```
192    /// # use scivex_core::Tensor;
193    /// let t = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
194    /// let t = t.unsqueeze(0).unwrap();
195    /// assert_eq!(t.shape(), &[1, 3]);
196    /// ```
197    pub fn unsqueeze(mut self, axis: usize) -> Result<Self> {
198        if axis > self.ndim() {
199            return Err(CoreError::AxisOutOfBounds {
200                axis,
201                ndim: self.ndim(),
202            });
203        }
204        self.shape.insert(axis, 1);
205        self.strides = compute_strides(&self.shape);
206        Ok(self)
207    }
208
209    /// Remove all dimensions of size 1.
210    ///
211    /// # Examples
212    ///
213    /// ```
214    /// # use scivex_core::Tensor;
215    /// let t = Tensor::from_vec(vec![1, 2, 3], vec![1, 3, 1]).unwrap();
216    /// let t = t.squeeze();
217    /// assert_eq!(t.shape(), &[3]);
218    /// ```
219    pub fn squeeze(mut self) -> Self {
220        self.shape.retain(|&d| d != 1);
221        if self.shape.is_empty() && self.numel() == 1 {
222            self.shape = vec![];
223        }
224        self.strides = compute_strides(&self.shape);
225        self
226    }
227
228    /// Concatenate a list of tensors along the given axis.
229    ///
230    /// All tensors must have the same shape except along the concatenation axis.
231    ///
232    /// # Examples
233    ///
234    /// ```
235    /// # use scivex_core::Tensor;
236    /// let a = Tensor::from_vec(vec![1, 2, 3], vec![1, 3]).unwrap();
237    /// let b = Tensor::from_vec(vec![4, 5, 6], vec![1, 3]).unwrap();
238    /// let c = Tensor::concat(&[&a, &b], 0).unwrap();
239    /// assert_eq!(c.shape(), &[2, 3]);
240    /// ```
241    pub fn concat(tensors: &[&Tensor<T>], axis: usize) -> Result<Self> {
242        if tensors.is_empty() {
243            return Err(CoreError::InvalidArgument {
244                reason: "cannot concatenate zero tensors",
245            });
246        }
247
248        let ndim = tensors[0].ndim();
249        if axis >= ndim {
250            return Err(CoreError::AxisOutOfBounds { axis, ndim });
251        }
252
253        // Validate shapes match on all axes except `axis`
254        for t in &tensors[1..] {
255            if t.ndim() != ndim {
256                return Err(CoreError::DimensionMismatch {
257                    expected: tensors[0].shape.clone(),
258                    got: t.shape.clone(),
259                });
260            }
261            for (d, (&a, &b)) in tensors[0].shape.iter().zip(t.shape.iter()).enumerate() {
262                if d != axis && a != b {
263                    return Err(CoreError::DimensionMismatch {
264                        expected: tensors[0].shape.clone(),
265                        got: t.shape.clone(),
266                    });
267                }
268            }
269        }
270
271        let mut new_shape = tensors[0].shape.clone();
272        new_shape[axis] = tensors.iter().map(|t| t.shape[axis]).sum();
273
274        let outer: usize = new_shape[..axis].iter().product();
275        let inner: usize = new_shape[axis + 1..].iter().product();
276        let total: usize = new_shape.iter().product();
277
278        let mut data = Vec::with_capacity(total);
279
280        for o in 0..outer {
281            for t in tensors {
282                let axis_len = t.shape[axis];
283                let src_start = o * axis_len * inner;
284                let src_end = src_start + axis_len * inner;
285                data.extend_from_slice(&t.data[src_start..src_end]);
286            }
287        }
288
289        Tensor::from_vec(data, new_shape)
290    }
291
292    /// Stack tensors along a new axis inserted at position `axis`.
293    ///
294    /// All tensors must have identical shapes.
295    ///
296    /// # Examples
297    ///
298    /// ```
299    /// # use scivex_core::Tensor;
300    /// let a = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
301    /// let b = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
302    /// let c = Tensor::stack(&[&a, &b], 0).unwrap();
303    /// assert_eq!(c.shape(), &[2, 3]);
304    /// ```
305    pub fn stack(tensors: &[&Tensor<T>], axis: usize) -> Result<Self> {
306        if tensors.is_empty() {
307            return Err(CoreError::InvalidArgument {
308                reason: "cannot stack zero tensors",
309            });
310        }
311
312        let base_shape = &tensors[0].shape;
313        if axis > base_shape.len() {
314            return Err(CoreError::AxisOutOfBounds {
315                axis,
316                ndim: base_shape.len() + 1,
317            });
318        }
319
320        for t in &tensors[1..] {
321            if t.shape != *base_shape {
322                return Err(CoreError::DimensionMismatch {
323                    expected: base_shape.clone(),
324                    got: t.shape.clone(),
325                });
326            }
327        }
328
329        // Unsqueeze each tensor along the new axis, then concat
330        let expanded: Vec<Tensor<T>> = tensors
331            .iter()
332            // SAFETY: axis is validated above and is within bounds for all tensors.
333            .map(|t| {
334                (*t).clone()
335                    .unsqueeze(axis)
336                    .expect("axis is valid for all tensors since shapes were validated above")
337            })
338            .collect();
339        let refs: Vec<&Tensor<T>> = expanded.iter().collect();
340        Tensor::concat(&refs, axis)
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reshape() {
        let matrix = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![6])
            .unwrap()
            .reshape(vec![2, 3])
            .unwrap();
        assert_eq!(matrix.shape(), &[2, 3]);
        assert_eq!(matrix.strides(), &[3, 1]);
        assert_eq!(*matrix.get(&[1, 0]).unwrap(), 4);
    }

    #[test]
    fn test_reshape_invalid() {
        let vector = Tensor::from_vec(vec![1, 2, 3, 4], vec![4]).unwrap();
        let outcome = vector.reshape(vec![3, 2]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_flatten() {
        let flat = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3])
            .unwrap()
            .flatten();
        assert_eq!(flat.shape(), &[6]);
        assert_eq!(flat.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_transpose() {
        // Source matrix:
        // [[1, 2, 3],
        //  [4, 5, 6]]
        let source = Tensor::from_vec(vec![1, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let transposed = source.transpose().unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        let expectations = [([0, 0], 1), ([0, 1], 4), ([2, 0], 3), ([2, 1], 6)];
        for (index, value) in expectations {
            assert_eq!(*transposed.get(&index).unwrap(), value);
        }
    }

    #[test]
    fn test_transpose_not_2d() {
        let vector = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        assert!(vector.transpose().is_err());
    }

    #[test]
    fn test_permute() {
        // Shape [2, 3, 4] permuted by [2, 0, 1] becomes [4, 2, 3].
        let original = Tensor::<i32>::arange(24).reshape(vec![2, 3, 4]).unwrap();
        let permuted = original.permute(&[2, 0, 1]).unwrap();
        assert_eq!(permuted.shape(), &[4, 2, 3]);
        // The origin maps onto itself.
        assert_eq!(*permuted.get(&[0, 0, 0]).unwrap(), 0);
        // original[i, j, k] lands at permuted[k, i, j].
        let moved = *original.get(&[1, 2, 3]).unwrap();
        assert_eq!(*permuted.get(&[3, 1, 2]).unwrap(), moved);
    }

    #[test]
    fn test_unsqueeze_squeeze() {
        let expanded = Tensor::from_vec(vec![1, 2, 3], vec![3])
            .unwrap()
            .unsqueeze(0)
            .unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
        let restored = expanded.squeeze();
        assert_eq!(restored.shape(), &[3]);
    }

    #[test]
    fn test_concat() {
        let top = Tensor::from_vec(vec![1, 2, 3], vec![1, 3]).unwrap();
        let bottom = Tensor::from_vec(vec![4, 5, 6], vec![1, 3]).unwrap();
        let joined = Tensor::concat(&[&top, &bottom], 0).unwrap();
        assert_eq!(joined.shape(), &[2, 3]);
        assert_eq!(joined.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_concat_axis1() {
        let left = Tensor::from_vec(vec![1, 2, 3, 4], vec![2, 2]).unwrap();
        let right = Tensor::from_vec(vec![5, 6, 7, 8], vec![2, 2]).unwrap();
        let joined = Tensor::concat(&[&left, &right], 1).unwrap();
        assert_eq!(joined.shape(), &[2, 4]);
        // Each output row is a row of `left` followed by the matching row of `right`.
        assert_eq!(joined.as_slice(), &[1, 2, 5, 6, 3, 4, 7, 8]);
    }

    #[test]
    fn test_stack() {
        let first = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let second = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
        let stacked = Tensor::stack(&[&first, &second], 0).unwrap();
        assert_eq!(stacked.shape(), &[2, 3]);
        assert_eq!(stacked.as_slice(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_stack_axis1() {
        let first = Tensor::from_vec(vec![1, 2, 3], vec![3]).unwrap();
        let second = Tensor::from_vec(vec![4, 5, 6], vec![3]).unwrap();
        let stacked = Tensor::stack(&[&first, &second], 1).unwrap();
        assert_eq!(stacked.shape(), &[3, 2]);
        // Elements from the two inputs alternate along the new trailing axis.
        assert_eq!(stacked.as_slice(), &[1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_concat_shape_mismatch() {
        let wide = Tensor::from_vec(vec![1, 2, 3], vec![1, 3]).unwrap();
        let narrow = Tensor::from_vec(vec![4, 5], vec![1, 2]).unwrap();
        let result = Tensor::concat(&[&wide, &narrow], 0);
        assert!(result.is_err());
    }
}