Skip to main content

vortex_tensor/vector/
matcher.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::dtype::DType;
5use vortex_array::dtype::PType;
6use vortex_array::dtype::extension::ExtDTypeRef;
7use vortex_array::dtype::extension::Matcher;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12
13use crate::vector::Vector;
14
/// A [`Matcher`] that matches any [`Vector`] extension dtype, regardless of its element type or
/// number of dimensions.
///
/// On a successful match it yields a [`VectorMatcherMetadata`] describing the vector.
#[derive(Debug, Clone, Copy)]
pub struct AnyVector;
16
17/// Convenience metadata for vectors.
18///
19/// Unlike `FixedShapeTensor`, the [`Vector`] type has `EmptyMetadata` as its metadata because all
20/// of the important information is already stored in the dtype.
21///
22/// However, it is quite inconvenient to repeatedly unwrap the dtype to get the element type of the
23/// vector and the number of dimensions.
24///
25/// Thus, we allow the matcher to return this metadata so that we can access this information more
26/// easily.
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
28pub struct VectorMatcherMetadata {
29    /// The element type of the vectors. Note that vector elements are _always_ non-nullable.
30    ///
31    /// This MUST be a floating point type (f16, f32, f64).
32    element_ptype: PType,
33
34    /// The number of dimensions of the vector. This is always fixed.
35    dimensions: u32,
36}
37
38impl Matcher for AnyVector {
39    type Match<'a> = VectorMatcherMetadata;
40
41    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
42        if !ext_dtype.is::<Vector>() {
43            return None;
44        }
45
46        let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
47            vortex_panic!("`Vector` type somehow did not have a `FixedSizeList` storage type")
48        };
49
50        let dimensions = *list_size;
51
52        assert!(element_dtype.is_float(), "element dtype must be primitive");
53        assert!(
54            !element_dtype.is_nullable(),
55            "element dtype must be non-nullable"
56        );
57        let element_ptype = element_dtype.as_ptype();
58
59        let vector_metadata = VectorMatcherMetadata::try_new(element_ptype, dimensions)
60            .vortex_expect("`Vector` type somehow did not have float elements");
61
62        Some(vector_metadata)
63    }
64}
65
66impl VectorMatcherMetadata {
67    /// Tries to create a new `VectorMatcherMetadata`.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the element type is not a float.
72    pub fn try_new(element_ptype: PType, dimensions: u32) -> VortexResult<Self> {
73        vortex_ensure!(element_ptype.is_float());
74
75        Ok(Self {
76            element_ptype,
77            dimensions,
78        })
79    }
80
81    /// Returns the element type of the vectors.
82    pub fn element_ptype(&self) -> PType {
83        self.element_ptype
84    }
85
86    /// Returns the number of dimensions of the vector.
87    pub fn dimensions(&self) -> u32 {
88        self.dimensions
89    }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;

    use super::*;
    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Builds the `FixedSizeList` storage dtype used by `Vector`: a non-nullable list of
    /// `dimensions` non-nullable primitive elements.
    fn vector_storage_dtype(element_ptype: PType, dimensions: u32) -> DType {
        DType::FixedSizeList(
            Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)),
            dimensions,
            Nullability::NonNullable,
        )
    }

    #[test]
    fn matches_vector_dtype_metadata() -> VortexResult<()> {
        let ext_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, vector_storage_dtype(PType::F32, 256))?
                .erased();

        let metadata = ext_dtype.metadata::<AnyVector>();
        assert_eq!(metadata.element_ptype(), PType::F32);
        assert_eq!(metadata.dimensions(), 256);
        Ok(())
    }

    #[test]
    fn matches_every_float_element_type() -> VortexResult<()> {
        for ptype in [PType::F16, PType::F32, PType::F64] {
            let ext_dtype =
                ExtDType::<Vector>::try_new(EmptyMetadata, vector_storage_dtype(ptype, 3))?
                    .erased();

            let metadata = ext_dtype.metadata::<AnyVector>();
            assert_eq!(metadata.element_ptype(), ptype);
            assert_eq!(metadata.dimensions(), 3);
        }
        Ok(())
    }

    #[test]
    fn metadata_rejects_non_float_elements() {
        assert!(VectorMatcherMetadata::try_new(PType::I32, 4).is_err());
        assert!(VectorMatcherMetadata::try_new(PType::U8, 4).is_err());
        assert!(VectorMatcherMetadata::try_new(PType::F64, 4).is_ok());
    }

    #[test]
    fn does_not_match_fixed_shape_tensor() -> VortexResult<()> {
        let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(
            FixedShapeTensorMetadata::new(vec![16, 16]),
            vector_storage_dtype(PType::F32, 256),
        )?
        .erased();

        assert!(ext_dtype.metadata_opt::<AnyVector>().is_none());
        Ok(())
    }
}