vortex_array/
matchers.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::marker::PhantomData;
5
6use crate::ArrayRef;
7use crate::vtable::ArrayId;
8use crate::vtable::VTable;
9
10/// Trait for matching array types in optimizer rules
11pub trait Matcher: Send + Sync + 'static {
12    type View<'a>;
13
14    /// Return the key for this matcher
15    fn key(&self) -> MatchKey;
16
17    /// Try to match the given array to this matcher type
18    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>>;
19}
20
21#[derive(Clone, Debug, PartialEq, Eq, Hash)]
22pub enum MatchKey {
23    Any,
24    Array(ArrayId),
25}
26
27/// Matches any array type (wildcard matcher)
28#[derive(Debug)]
29pub struct AnyArray;
30
31impl Matcher for AnyArray {
32    type View<'a> = &'a ArrayRef;
33
34    fn key(&self) -> MatchKey {
35        MatchKey::Any
36    }
37
38    fn try_match<'a>(&self, array: &'a ArrayRef) -> Option<Self::View<'a>> {
39        Some(array)
40    }
41}
42
43/// Matches a specific Array by its encoding ID.
44#[derive(Debug)]
45pub struct Exact<V: VTable> {
46    id: ArrayId,
47    _phantom: PhantomData<V>,
48}
49
50impl<V: VTable> Matcher for Exact<V> {
51    type View<'a> = &'a V::Array;
52
53    fn key(&self) -> MatchKey {
54        MatchKey::Array(self.id.clone())
55    }
56
57    fn try_match<'a>(&self, parent: &'a ArrayRef) -> Option<Self::View<'a>> {
58        parent.as_opt::<V>()
59    }
60}
61
62impl<V: VTable> Exact<V> {
63    /// Create a new Exact matcher for the given ArrayId.
64    ///
65    /// # Safety
66    ///
67    /// The optimizer will attempt to downcast the array to the type V when matching.
68    /// If an array with the given ID does not match type V, the rule will silently not be applied.
69    pub unsafe fn new_unchecked(id: ArrayId) -> Self {
70        Self {
71            id,
72            _phantom: PhantomData,
73        }
74    }
75}
76
77impl<V: VTable> From<&'static V> for Exact<V> {
78    fn from(vtable: &'static V) -> Self {
79        Self {
80            id: vtable.id(),
81            _phantom: PhantomData,
82        }
83    }
84}