Skip to main content

vortex_tensor/fixed_shape/
matcher.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::dtype::DType;
5use vortex_array::dtype::PType;
6use vortex_array::dtype::extension::ExtDTypeRef;
7use vortex_array::dtype::extension::Matcher;
8use vortex_error::VortexExpect;
9use vortex_error::vortex_panic;
10
11use crate::fixed_shape::FixedShapeTensor;
12use crate::fixed_shape::FixedShapeTensorMetadata;
13
/// [`Matcher`] that recognizes any [`FixedShapeTensor`] extension dtype, whatever its logical
/// shape or element type.
///
/// A successful match yields a [`FixedShapeTensorMatcherMetadata`].
#[derive(Debug, Clone, Copy)]
pub struct AnyFixedShapeTensor;
15
16/// Convenience metadata for fixed-shape tensors.
17///
18/// Fixed-shape tensors already store their logical metadata directly, but callers also often need
19/// the flattened storage list size and element primitive type from the storage dtype.
20#[derive(Debug, Clone, Copy, PartialEq, Eq)]
21pub struct FixedShapeTensorMatcherMetadata<'a> {
22    /// The logical fixed-shape tensor metadata stored on the extension dtype.
23    metadata: &'a FixedShapeTensorMetadata,
24
25    /// The primitive element type of the tensor storage.
26    ///
27    /// Fixed-shape tensors currently require non-nullable primitive elements.
28    element_ptype: PType,
29
30    /// The flattened element count for each tensor row in storage order.
31    ///
32    /// This matches the `FixedSizeList` list size in the storage dtype, which is the product of
33    /// the logical shape dimensions.
34    flat_list_size: usize,
35}
36
37impl Matcher for AnyFixedShapeTensor {
38    type Match<'a> = FixedShapeTensorMatcherMetadata<'a>;
39
40    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
41        if !ext_dtype.is::<FixedShapeTensor>() {
42            return None;
43        }
44
45        let metadata = ext_dtype
46            .metadata_opt::<FixedShapeTensor>()
47            .vortex_expect("`FixedShapeTensor` type somehow did not have metadata");
48
49        let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
50            vortex_panic!(
51                "`FixedShapeTensor` type somehow did not have a `FixedSizeList` storage type"
52            )
53        };
54
55        assert!(
56            element_dtype.is_primitive(),
57            "element dtype must be primitive"
58        );
59        assert!(
60            !element_dtype.is_nullable(),
61            "element dtype must be non-nullable"
62        );
63
64        Some(FixedShapeTensorMatcherMetadata {
65            metadata,
66            element_ptype: element_dtype.as_ptype(),
67            flat_list_size: *list_size as usize,
68        })
69    }
70}
71
72impl FixedShapeTensorMatcherMetadata<'_> {
73    /// Returns the underlying fixed-shape tensor metadata.
74    pub fn metadata(&self) -> &FixedShapeTensorMetadata {
75        self.metadata
76    }
77
78    /// Returns the tensor element type.
79    pub fn element_ptype(&self) -> PType {
80        self.element_ptype
81    }
82
83    /// Returns the flattened element count for each tensor row.
84    pub fn list_size(&self) -> usize {
85        self.flat_list_size
86    }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;

    use super::*;
    use crate::vector::Vector;

    /// Builds a non-nullable `FixedSizeList` of `list_size` non-nullable `element_ptype` values.
    fn tensor_storage_dtype(element_ptype: PType, list_size: u32) -> DType {
        let element = DType::Primitive(element_ptype, Nullability::NonNullable);
        DType::FixedSizeList(Arc::new(element), list_size, Nullability::NonNullable)
    }

    #[test]
    fn matches_fixed_shape_tensor_dtype_metadata() -> VortexResult<()> {
        let tensor_metadata = FixedShapeTensorMetadata::new(vec![2, 3, 4]);
        let storage = tensor_storage_dtype(PType::F32, 24);
        let erased = ExtDType::<FixedShapeTensor>::try_new(tensor_metadata, storage)?.erased();

        let matched = erased.metadata::<AnyFixedShapeTensor>();
        assert_eq!(matched.element_ptype(), PType::F32);
        assert_eq!(matched.list_size(), 24);
        assert_eq!(matched.metadata().logical_shape(), &[2, 3, 4]);
        Ok(())
    }

    #[test]
    fn does_not_match_vector() -> VortexResult<()> {
        let storage = tensor_storage_dtype(PType::F32, 24);
        let erased = ExtDType::<Vector>::try_new(EmptyMetadata, storage)?.erased();

        let matched = erased.metadata_opt::<AnyFixedShapeTensor>();
        assert!(matched.is_none());
        Ok(())
    }
}