Skip to main content

vortex_tensor/
matcher.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Matcher for tensor-like extension types.
5
6use vortex_array::dtype::PType;
7use vortex_array::dtype::extension::ExtDTypeRef;
8use vortex_array::dtype::extension::Matcher;
9
10use crate::types::fixed_shape_tensor::AnyFixedShapeTensor;
11use crate::types::fixed_shape_tensor::FixedShapeTensorMatcherMetadata;
12use crate::types::vector::AnyVector;
13use crate::types::vector::VectorMatcherMetadata;
14
/// Marker type used to match every tensor-like extension dtype.
///
/// The tensor kinds recognised at the moment are:
///
/// - `FixedShapeTensor`
/// - `Vector`
///
/// This is a field-less unit struct: it carries no data and only serves as the type on which
/// the matching logic is implemented.
pub struct AnyTensor;
22
23/// The matched variant of a tensor-like extension type.
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub enum TensorMatch<'a> {
26    /// A [`FixedShapeTensor`](crate::fixed_shape_tensor::FixedShapeTensor) extension type.
27    FixedShapeTensor(FixedShapeTensorMatcherMetadata<'a>),
28
29    /// A [`Vector`](crate::vector::Vector) extension type.
30    ///
31    /// Note that we store an owned type here wrapping (copyable) data from the dtype.
32    Vector(VectorMatcherMetadata),
33}
34
35impl TensorMatch<'_> {
36    /// Returns the tensor element type for this tensor-like dtype.
37    pub fn element_ptype(self) -> PType {
38        match self {
39            Self::FixedShapeTensor(metadata) => metadata.element_ptype(),
40            Self::Vector(metadata) => metadata.element_ptype(),
41        }
42    }
43
44    /// Returns the flattened element count for each logical tensor row.
45    pub fn list_size(self) -> u32 {
46        match self {
47            Self::FixedShapeTensor(metadata) => metadata.flat_list_size(),
48            Self::Vector(metadata) => metadata.dimensions(),
49        }
50    }
51}
52
53impl Matcher for AnyTensor {
54    type Match<'a> = TensorMatch<'a>;
55
56    fn try_match<'a>(ext_dtype: &'a ExtDTypeRef) -> Option<Self::Match<'a>> {
57        if let Some(metadata) = ext_dtype.metadata_opt::<AnyFixedShapeTensor>() {
58            return Some(TensorMatch::FixedShapeTensor(metadata));
59        }
60
61        // Special logic for vectors to get convenience metadata (instead of `EmptyMetadata`).
62        if let Some(metadata) = ext_dtype.metadata_opt::<AnyVector>() {
63            return Some(TensorMatch::Vector(metadata));
64        }
65
66        None
67    }
68}