vortex_tensor/fixed_shape/
matcher.rs1use vortex_array::dtype::DType;
5use vortex_array::dtype::PType;
6use vortex_array::dtype::extension::ExtDTypeRef;
7use vortex_array::dtype::extension::Matcher;
8use vortex_error::VortexExpect;
9use vortex_error::vortex_panic;
10
11use crate::fixed_shape::FixedShapeTensor;
12use crate::fixed_shape::FixedShapeTensorMetadata;
13
/// [`Matcher`] that accepts any [`FixedShapeTensor`] extension dtype, regardless of its
/// logical shape or element type.
///
/// This is a stateless marker type; a successful match yields a
/// [`FixedShapeTensorMatcherMetadata`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct AnyFixedShapeTensor;
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
21pub struct FixedShapeTensorMatcherMetadata<'a> {
22 metadata: &'a FixedShapeTensorMetadata,
24
25 element_ptype: PType,
29
30 flat_list_size: usize,
35}
36
37impl Matcher for AnyFixedShapeTensor {
38 type Match<'a> = FixedShapeTensorMatcherMetadata<'a>;
39
40 fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
41 if !ext_dtype.is::<FixedShapeTensor>() {
42 return None;
43 }
44
45 let metadata = ext_dtype
46 .metadata_opt::<FixedShapeTensor>()
47 .vortex_expect("`FixedShapeTensor` type somehow did not have metadata");
48
49 let DType::FixedSizeList(element_dtype, list_size, _) = ext_dtype.storage_dtype() else {
50 vortex_panic!(
51 "`FixedShapeTensor` type somehow did not have a `FixedSizeList` storage type"
52 )
53 };
54
55 assert!(
56 element_dtype.is_primitive(),
57 "element dtype must be primitive"
58 );
59 assert!(
60 !element_dtype.is_nullable(),
61 "element dtype must be non-nullable"
62 );
63
64 Some(FixedShapeTensorMatcherMetadata {
65 metadata,
66 element_ptype: element_dtype.as_ptype(),
67 flat_list_size: *list_size as usize,
68 })
69 }
70}
71
72impl FixedShapeTensorMatcherMetadata<'_> {
73 pub fn metadata(&self) -> &FixedShapeTensorMetadata {
75 self.metadata
76 }
77
78 pub fn element_ptype(&self) -> PType {
80 self.element_ptype
81 }
82
83 pub fn list_size(&self) -> usize {
85 self.flat_list_size
86 }
87}
88
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::dtype::extension::ExtDType;
    use vortex_array::extension::EmptyMetadata;
    use vortex_error::VortexResult;

    use super::*;
    use crate::vector::Vector;

    /// Builds a non-nullable `FixedSizeList` storage dtype holding `list_size` non-nullable
    /// `ptype` elements.
    fn non_nullable_fsl(ptype: PType, list_size: u32) -> DType {
        let element = DType::Primitive(ptype, Nullability::NonNullable);
        DType::FixedSizeList(Arc::new(element), list_size, Nullability::NonNullable)
    }

    #[test]
    fn matches_fixed_shape_tensor_dtype_metadata() -> VortexResult<()> {
        let typed = ExtDType::<FixedShapeTensor>::try_new(
            FixedShapeTensorMetadata::new(vec![2, 3, 4]),
            non_nullable_fsl(PType::F32, 24),
        )?;
        let erased = typed.erased();

        let matched = erased.metadata::<AnyFixedShapeTensor>();
        assert_eq!(matched.element_ptype(), PType::F32);
        assert_eq!(matched.list_size(), 24);
        assert_eq!(matched.metadata().logical_shape(), &[2, 3, 4]);
        Ok(())
    }

    #[test]
    fn does_not_match_vector() -> VortexResult<()> {
        let storage = non_nullable_fsl(PType::F32, 24);
        let vector = ExtDType::<Vector>::try_new(EmptyMetadata, storage)?.erased();

        let matched = vector.metadata_opt::<AnyFixedShapeTensor>();
        assert!(matched.is_none());
        Ok(())
    }
}