Skip to main content

vortex_tensor/fixed_shape/
metadata.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt;
5
6use itertools::Either;
7use vortex::error::VortexExpect;
8use vortex::error::VortexResult;
9use vortex::error::vortex_ensure;
10use vortex::error::vortex_ensure_eq;
11
/// Metadata for a `FixedShapeTensor` extension type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FixedShapeTensorMetadata {
    /// The logical shape of the tensor.
    ///
    /// `logical_shape[i]` is the size of the `i`-th logical dimension. When a `permutation` is
    /// present, the physical shape (i.e., the row-major memory layout) is derived as
    /// `physical_shape[permutation[i]] = logical_shape[i]`.
    ///
    /// May be empty (0D scalar tensor) or contain dimensions of size 0 (degenerate tensor).
    logical_shape: Vec<usize>,

    /// Optional names for each logical dimension. Each name corresponds to an entry in
    /// `logical_shape`.
    ///
    /// If names exist, there must be an equal number of names to logical dimensions.
    /// `None` means the dimensions are unnamed.
    dim_names: Option<Vec<String>>,

    /// The permutation of the tensor's dimensions. `permutation[i]` is the physical dimension
    /// index that logical dimension `i` maps to.
    ///
    /// If this is `None`, then the logical and physical layouts are identical, equivalent to
    /// the identity permutation `[0, 1, ..., N-1]`.
    permutation: Option<Vec<usize>>,
}
37
38impl FixedShapeTensorMetadata {
39    /// Creates a new [`FixedShapeTensorMetadata`] with the given logical `shape`.
40    ///
41    /// Use [`with_dim_names`](Self::with_dim_names) and
42    /// [`with_permutation`](Self::with_permutation) to further configure the metadata.
43    pub fn new(shape: Vec<usize>) -> Self {
44        Self {
45            logical_shape: shape,
46            dim_names: None,
47            permutation: None,
48        }
49    }
50
51    /// Sets the dimension names for this tensor. An empty vec is normalized to `None` since a
52    /// 0-dimensional tensor has no dimensions to name.
53    ///
54    /// The number of names must match the number of logical dimensions.
55    pub fn with_dim_names(mut self, names: Vec<String>) -> VortexResult<Self> {
56        if !names.is_empty() {
57            vortex_ensure_eq!(
58                names.len(),
59                self.logical_shape.len(),
60                "dim_names length ({}) must match logical_shape length ({})",
61                names.len(),
62                self.logical_shape.len()
63            );
64            self.dim_names = Some(names);
65        }
66
67        Ok(self)
68    }
69
70    /// Sets the permutation for this tensor. An empty vec is normalized to `None` since a
71    /// 0-dimensional tensor has no dimensions to permute.
72    ///
73    /// The permutation must be a valid permutation of `[0, 1, ..., N-1]` where `N` is the
74    /// number of logical dimensions.
75    pub fn with_permutation(mut self, permutation: Vec<usize>) -> VortexResult<Self> {
76        if !permutation.is_empty() {
77            vortex_ensure_eq!(
78                permutation.len(),
79                self.logical_shape.len(),
80                "permutation length ({}) must match logical_shape length ({})",
81                permutation.len(),
82                self.logical_shape.len()
83            );
84
85            // Verify this is actually a permutation of [0..N).
86            let mut seen = vec![false; permutation.len()];
87            for &p in &permutation {
88                vortex_ensure!(
89                    p < permutation.len(),
90                    "permutation index {p} is out of range for {} dimensions",
91                    permutation.len()
92                );
93                vortex_ensure!(!seen[p], "permutation contains duplicate index {p}");
94                seen[p] = true;
95            }
96
97            self.permutation = Some(permutation);
98        }
99
100        Ok(self)
101    }
102
103    /// Returns the number of dimensions (rank) of the tensor.
104    pub fn ndim(&self) -> usize {
105        self.logical_shape.len()
106    }
107
108    /// Returns the logical dimensions of the tensor as a slice.
109    pub fn logical_shape(&self) -> &[usize] {
110        &self.logical_shape
111    }
112
113    /// Returns the dimension names, if set.
114    pub fn dim_names(&self) -> Option<&[String]> {
115        self.dim_names.as_deref()
116    }
117
118    /// Returns the permutation, if set.
119    pub fn permutation(&self) -> Option<&[usize]> {
120        self.permutation.as_deref()
121    }
122
123    /// Returns an iterator over the physical shape of the tensor.
124    ///
125    /// The physical shape describes the row-major memory layout. It is derived from the logical
126    /// shape by placing each logical dimension's size at its physical position:
127    /// `physical_shape[permutation[i]] = logical_shape[i]`.
128    ///
129    /// When no permutation is present, the physical shape is identical to the logical shape.
130    pub fn physical_shape(&self) -> impl Iterator<Item = usize> + '_ {
131        let ndim = self.logical_shape.len();
132        let permutation = self.permutation.as_deref();
133
134        match permutation {
135            None => Either::Left(self.logical_shape.iter().copied()),
136            Some(perm) => Either::Right(
137                (0..ndim).map(move |p| self.logical_shape[Self::inverse_perm(perm, p)]),
138            ),
139        }
140    }
141
142    /// Returns an iterator over the strides for each logical dimension of the tensor.
143    ///
144    /// The stride for a logical dimension is the number of elements to skip in the flat backing
145    /// array in order to move one step along that logical dimension.
146    ///
147    /// When a permutation is present, the physical memory is laid out in row-major order over the
148    /// physical dimensions (the logical dimensions reordered by the permutation), so the strides
149    /// are derived from that physical layout.
150    pub fn strides(&self) -> impl Iterator<Item = usize> + '_ {
151        let ndim = self.logical_shape.len();
152        let permutation = self.permutation.as_deref();
153
154        match permutation {
155            None => Either::Left(
156                (0..ndim).map(|i| self.logical_shape[i + 1..].iter().product::<usize>()),
157            ),
158            Some(permutation) => {
159                Either::Right((0..ndim).map(|i| self.permuted_stride(i, permutation)))
160            }
161        }
162    }
163
164    /// Computes the stride for logical dimension `i` given a `permutation`.
165    ///
166    /// The stride is the product of `logical_shape[j]` for all logical dimensions `j` whose
167    /// physical position (`perm[j]`) comes after the physical position of dimension `i`
168    /// (`perm[i]`).
169    fn permuted_stride(&self, i: usize, perm: &[usize]) -> usize {
170        let phys = perm[i];
171
172        // Each call scans the full permutation, making `strides()` O(ndim^2) overall. Tensor rank
173        // is typically small, so avoiding a Vec allocation is a net win.
174        perm.iter()
175            .enumerate()
176            .filter(|&(_, &p)| p > phys)
177            .map(|(l, _)| self.logical_shape[l])
178            .product::<usize>()
179    }
180
181    /// Returns the logical dimension index that maps to physical position `p`. This is the
182    /// inverse of the permutation: if `perm[i] == p`, returns `i`.
183    ///
184    /// Each call is a linear scan of `perm`, making callers that invoke this for every physical
185    /// position O(ndim^2) overall. Tensor rank is typically small (2–5), so avoiding a Vec
186    /// allocation for the full inverse permutation is a net win.
187    fn inverse_perm(perm: &[usize], p: usize) -> usize {
188        perm.iter()
189            .position(|&pi| pi == p)
190            .vortex_expect("permutation must contain every physical position exactly once")
191    }
192}
193
194impl fmt::Display for FixedShapeTensorMetadata {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        write!(f, "Tensor(")?;
197
198        match &self.dim_names {
199            Some(names) => {
200                for (i, (dim, name)) in self.logical_shape.iter().zip(names.iter()).enumerate() {
201                    if i > 0 {
202                        write!(f, ", ")?;
203                    }
204                    write!(f, "{name}: {dim}")?;
205                }
206            }
207            None => {
208                for (i, dim) in self.logical_shape.iter().enumerate() {
209                    if i > 0 {
210                        write!(f, ", ")?;
211                    }
212                    write!(f, "{dim}")?;
213                }
214            }
215        }
216
217        if let Some(perm) = &self.permutation {
218            for (i, p) in perm.iter().enumerate() {
219                if i > 0 {
220                    write!(f, ", ")?;
221                }
222                write!(f, "{p}")?;
223            }
224            write!(f, "]")?;
225        }
226
227        write!(f, ")")
228    }
229}
230
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Reference implementation that computes permuted strides in an explicit, step-by-step way.
    ///
    /// 1. Build the physical shape: `physical_shape[perm[i]] = logical_shape[i]`.
    /// 2. Compute row-major strides over the physical shape.
    /// 3. Map back to logical: `logical_stride[i] = physical_strides[perm[i]]`.
    fn slow_strides(shape: &[usize], perm: &[usize]) -> Vec<usize> {
        let rank = shape.len();

        // Scatter each logical dimension's size into its physical slot.
        let mut phys_shape = vec![0usize; rank];
        for (logical, &physical) in perm.iter().enumerate() {
            phys_shape[physical] = shape[logical];
        }

        // Row-major strides over the physical shape, built back-to-front.
        let mut phys_strides = vec![1usize; rank];
        for i in (0..rank.saturating_sub(1)).rev() {
            phys_strides[i] = phys_strides[i + 1] * phys_shape[i + 1];
        }

        // Gather each logical dimension's stride from its physical slot.
        perm.iter().map(|&p| phys_strides[p]).collect()
    }

    // -- Row-major strides (no permutation) --

    #[rstest]
    #[case::scalar_0d(vec![],        vec![])]
    #[case::vector_1d(vec![5],       vec![1])]
    #[case::matrix_2d(vec![3, 4],    vec![4, 1])]
    #[case::tensor_3d(vec![2, 3, 4], vec![12, 4, 1])]
    #[case::zero_dim( vec![3, 0, 4], vec![0, 4, 1])]
    fn strides_row_major(#[case] shape: Vec<usize>, #[case] expected: Vec<usize>) {
        let meta = FixedShapeTensorMetadata::new(shape);
        let actual: Vec<usize> = meta.strides().collect();
        assert_eq!(actual, expected);
    }

    // -- Permuted strides --
    //
    // Each case is checked against the expected value and cross-validated against the
    // `slow_strides` reference implementation.

    #[rstest]
    // 2D transpose: physical shape = [4, 3].
    #[case::transpose_2d(vec![3, 4],    vec![1, 0],    vec![1, 3])]
    // 3D: physical shape = [3, 4, 2].
    #[case::perm_3d_201( vec![2, 3, 4], vec![2, 0, 1], vec![1, 8, 2])]
    // 3D with zero-sized dimension: physical shape = [4, 3, 0].
    #[case::zero_dim(    vec![3, 0, 4], vec![1, 2, 0], vec![0, 1, 0])]
    fn strides_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let meta = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        let actual: Vec<usize> = meta.strides().collect();
        let reference = slow_strides(&shape, &perm);
        assert_eq!(actual, expected);
        assert_eq!(actual, reference);
        Ok(())
    }

    #[test]
    fn strides_identity_permutation_matches_row_major() -> VortexResult<()> {
        let plain = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let with_identity =
            FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![0, 1, 2])?;
        let expected: Vec<usize> = plain.strides().collect();
        let actual: Vec<usize> = with_identity.strides().collect();
        assert_eq!(expected, actual);
        Ok(())
    }

    /// Cross-validates the fast `permuted_stride` against the reference `slow_strides` across a
    /// broader set of shapes and permutations.
    #[rstest]
    #[case::perm_3d_120(vec![2, 3, 4],    vec![1, 2, 0])]
    #[case::perm_3d_021(vec![2, 3, 4],    vec![0, 2, 1])]
    #[case::identity_3d(vec![2, 3, 4],    vec![0, 1, 2])]
    #[case::zero_lead(  vec![0, 3, 4],    vec![2, 0, 1])]
    #[case::rev_4d(     vec![2, 3, 4, 5], vec![3, 2, 1, 0])]
    #[case::swap_4d(    vec![2, 3, 4, 5], vec![1, 0, 3, 2])]
    #[case::half_4d(    vec![2, 3, 4, 5], vec![2, 3, 0, 1])]
    fn strides_match_slow_reference(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
    ) -> VortexResult<()> {
        let meta = FixedShapeTensorMetadata::new(shape.clone()).with_permutation(perm.clone())?;
        let actual: Vec<usize> = meta.strides().collect();
        assert_eq!(actual, slow_strides(&shape, &perm));
        Ok(())
    }

    // -- Physical shape --

    #[test]
    fn physical_shape_no_permutation() {
        let meta = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let actual: Vec<usize> = meta.physical_shape().collect();
        assert_eq!(actual, vec![2, 3, 4]);
    }

    #[rstest]
    // Logical [3, 4] with perm [1, 0] → physical [4, 3].
    #[case::transpose_2d(vec![3, 4],    vec![1, 0],    vec![4, 3])]
    // Logical [2, 3, 4] with perm [2, 0, 1] → physical [3, 4, 2].
    #[case::perm_3d(     vec![2, 3, 4], vec![2, 0, 1], vec![3, 4, 2])]
    // Identity: physical = logical.
    #[case::identity(    vec![2, 3, 4], vec![0, 1, 2], vec![2, 3, 4])]
    // Logical [3, 0, 4] with perm [1, 2, 0] → physical [4, 3, 0].
    #[case::zero_dim(    vec![3, 0, 4], vec![1, 2, 0], vec![4, 3, 0])]
    fn physical_shape_permuted(
        #[case] shape: Vec<usize>,
        #[case] perm: Vec<usize>,
        #[case] expected: Vec<usize>,
    ) -> VortexResult<()> {
        let meta = FixedShapeTensorMetadata::new(shape).with_permutation(perm)?;
        let actual: Vec<usize> = meta.physical_shape().collect();
        assert_eq!(actual, expected);
        Ok(())
    }

    // -- Builder validation failures --

    #[test]
    fn dim_names_wrong_length() {
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_wrong_length() {
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_out_of_range() {
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 5]);
        assert!(outcome.is_err());
    }

    #[test]
    fn permutation_duplicate_index() {
        let outcome = FixedShapeTensorMetadata::new(vec![2, 3]).with_permutation(vec![0, 0]);
        assert!(outcome.is_err());
    }
}