1use vortex_array::dtype::PType;
7use vortex_array::dtype::extension::ExtDTypeRef;
8use vortex_array::dtype::extension::Matcher;
9
10use crate::types::fixed_shape_tensor::AnyFixedShapeTensor;
11use crate::types::fixed_shape_tensor::FixedShapeTensorMatcherMetadata;
12use crate::types::vector::AnyVector;
13use crate::types::vector::VectorMatcherMetadata;
14
/// Marker type used as a [`Matcher`] for tensor-like extension dtypes.
///
/// The `Matcher` implementation for this type accepts an extension dtype when it
/// carries either [`AnyFixedShapeTensor`] metadata or [`AnyVector`] metadata, and
/// reports which of the two was found through [`TensorMatch`].
pub struct AnyTensor;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub enum TensorMatch<'a> {
26 FixedShapeTensor(FixedShapeTensorMatcherMetadata<'a>),
28
29 Vector(VectorMatcherMetadata),
33}
34
35impl TensorMatch<'_> {
36 pub fn element_ptype(self) -> PType {
38 match self {
39 Self::FixedShapeTensor(metadata) => metadata.element_ptype(),
40 Self::Vector(metadata) => metadata.element_ptype(),
41 }
42 }
43
44 pub fn list_size(self) -> u32 {
46 match self {
47 Self::FixedShapeTensor(metadata) => metadata.flat_list_size(),
48 Self::Vector(metadata) => metadata.dimensions(),
49 }
50 }
51}
52
53impl Matcher for AnyTensor {
54 type Match<'a> = TensorMatch<'a>;
55
56 fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
57 if let Some(metadata) = ext_dtype.metadata_opt::<AnyFixedShapeTensor>() {
58 return Some(TensorMatch::FixedShapeTensor(metadata));
59 }
60
61 if let Some(metadata) = ext_dtype.metadata_opt::<AnyVector>() {
63 return Some(TensorMatch::Vector(metadata));
64 }
65
66 None
67 }
68}