1use vortex_array::dtype::PType;
7use vortex_array::dtype::extension::ExtDTypeRef;
8use vortex_array::dtype::extension::Matcher;
9
10use crate::fixed_shape::AnyFixedShapeTensor;
11use crate::fixed_shape::FixedShapeTensorMatcherMetadata;
12use crate::vector::AnyVector;
13use crate::vector::VectorMatcherMetadata;
14
/// Extension-dtype [`Matcher`] that accepts either a fixed-shape tensor or a vector dtype.
///
/// Matching is attempted against [`AnyFixedShapeTensor`] first and [`AnyVector`] second. A
/// successful match yields a [`TensorMatch`] wrapping the metadata of whichever matcher
/// succeeded.
pub struct AnyTensor;
22
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub enum TensorMatch<'a> {
26 FixedShapeTensor(FixedShapeTensorMatcherMetadata<'a>),
28
29 Vector(VectorMatcherMetadata),
33}
34
35impl TensorMatch<'_> {
36 pub fn element_ptype(self) -> PType {
38 match self {
39 Self::FixedShapeTensor(metadata) => metadata.element_ptype(),
40 Self::Vector(metadata) => metadata.element_ptype(),
41 }
42 }
43
44 pub fn list_size(self) -> usize {
46 match self {
47 Self::FixedShapeTensor(metadata) => metadata.list_size(),
48 Self::Vector(metadata) => metadata.dimensions() as usize,
49 }
50 }
51}
52
53impl Matcher for AnyTensor {
54 type Match<'a> = TensorMatch<'a>;
55
56 fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
57 if let Some(metadata) = ext_dtype.metadata_opt::<AnyFixedShapeTensor>() {
58 return Some(TensorMatch::FixedShapeTensor(metadata));
59 }
60
61 if let Some(metadata) = ext_dtype.metadata_opt::<AnyVector>() {
63 return Some(TensorMatch::Vector(metadata));
64 }
65
66 None
67 }
68}