vortex_tensor/vector/
matcher.rs1use vortex_array::dtype::DType;
5use vortex_array::dtype::PType;
6use vortex_array::dtype::extension::ExtDTypeRef;
7use vortex_array::dtype::extension::Matcher;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_ensure;
11use vortex_error::vortex_panic;
12
13use crate::vector::Vector;
14
/// Extension-dtype [`Matcher`] that matches any [`Vector`] dtype, whatever its element type or
/// number of dimensions.
///
/// A successful match yields a [`VectorMatcherMetadata`] describing the matched vector.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AnyVector;
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
28pub struct VectorMatcherMetadata {
29 element_ptype: PType,
33
34 dimensions: u32,
36}
37
38impl Matcher for AnyVector {
39 type Match<'a> = VectorMatcherMetadata;
40
41 fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
42 if !ext_dtype.is::<Vector>() {
43 return None;
44 }
45
46 let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
47 vortex_panic!("`Vector` type somehow did not have a `FixedSizeList` storage type")
48 };
49
50 let dimensions = *list_size;
51
52 assert!(element_dtype.is_float(), "element dtype must be primitive");
53 assert!(
54 !element_dtype.is_nullable(),
55 "element dtype must be non-nullable"
56 );
57 let element_ptype = element_dtype.as_ptype();
58
59 let vector_metadata = VectorMatcherMetadata::try_new(element_ptype, dimensions)
60 .vortex_expect("`Vector` type somehow did not have float elements");
61
62 Some(vector_metadata)
63 }
64}
65
66impl VectorMatcherMetadata {
67 pub fn try_new(element_ptype: PType, dimensions: u32) -> VortexResult<Self> {
73 vortex_ensure!(element_ptype.is_float());
74
75 Ok(Self {
76 element_ptype,
77 dimensions,
78 })
79 }
80
81 pub fn element_ptype(&self) -> PType {
83 self.element_ptype
84 }
85
86 pub fn dimensions(&self) -> u32 {
88 self.dimensions
89 }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;

    use super::*;
    use crate::fixed_shape::FixedShapeTensor;
    use crate::fixed_shape::FixedShapeTensorMetadata;

    /// Builds the non-nullable `FixedSizeList<Primitive>` storage dtype used by `Vector`.
    fn vector_storage_dtype(element_ptype: PType, dimensions: u32) -> DType {
        DType::FixedSizeList(
            Arc::new(DType::Primitive(element_ptype, Nullability::NonNullable)),
            dimensions,
            Nullability::NonNullable,
        )
    }

    #[test]
    fn matches_vector_dtype_metadata() -> VortexResult<()> {
        let ext_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, vector_storage_dtype(PType::F32, 256))?
                .erased();

        let metadata = ext_dtype.metadata::<AnyVector>();
        assert_eq!(metadata.element_ptype(), PType::F32);
        assert_eq!(metadata.dimensions(), 256);
        Ok(())
    }

    #[test]
    fn matches_f64_vector_via_metadata_opt() -> VortexResult<()> {
        let ext_dtype =
            ExtDType::<Vector>::try_new(EmptyMetadata, vector_storage_dtype(PType::F64, 3))?
                .erased();

        assert_eq!(
            ext_dtype.metadata_opt::<AnyVector>(),
            Some(VectorMatcherMetadata::try_new(PType::F64, 3)?)
        );
        Ok(())
    }

    #[test]
    fn does_not_match_fixed_shape_tensor() -> VortexResult<()> {
        let ext_dtype = ExtDType::<FixedShapeTensor>::try_new(
            FixedShapeTensorMetadata::new(vec![16, 16]),
            vector_storage_dtype(PType::F32, 256),
        )?
        .erased();

        assert!(ext_dtype.metadata_opt::<AnyVector>().is_none());
        Ok(())
    }

    #[test]
    fn metadata_accepts_float_element_ptypes() -> VortexResult<()> {
        for ptype in [PType::F16, PType::F32, PType::F64] {
            let metadata = VectorMatcherMetadata::try_new(ptype, 3)?;
            assert_eq!(metadata.element_ptype(), ptype);
            assert_eq!(metadata.dimensions(), 3);
        }
        Ok(())
    }

    #[test]
    fn metadata_rejects_non_float_element_ptypes() {
        for ptype in [PType::I32, PType::U8, PType::I64] {
            assert!(VectorMatcherMetadata::try_new(ptype, 3).is_err());
        }
    }
}