Skip to main content

vortex_array/scalar_fn/
erased.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Type-erased scalar function ([`ScalarFnRef`]).
5
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::debug_with::DebugWith;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::expr::Expression;
21use crate::expr::StatsCatalog;
22use crate::expr::stats::Stat;
23use crate::scalar_fn::EmptyOptions;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ReduceCtx;
26use crate::scalar_fn::ReduceNode;
27use crate::scalar_fn::ReduceNodeRef;
28use crate::scalar_fn::ScalarFnId;
29use crate::scalar_fn::ScalarFnVTable;
30use crate::scalar_fn::ScalarFnVTableExt;
31use crate::scalar_fn::SimplifyCtx;
32use crate::scalar_fn::fns::is_null::IsNull;
33use crate::scalar_fn::fns::not::Not;
34use crate::scalar_fn::options::ScalarFnOptions;
35use crate::scalar_fn::signature::ScalarFnSignature;
36use crate::scalar_fn::typed::DynScalarFn;
37use crate::scalar_fn::typed::ScalarFnInner;
38
39/// A type-erased scalar function, pairing a vtable with bound options behind a trait object.
40///
41/// This stores a [`ScalarFnVTable`] and its options behind an `Arc<dyn DynScalarFn>`, allowing
42/// heterogeneous storage inside [`Expression`] and [`crate::arrays::ScalarFnArray`].
43///
44/// Use [`super::ScalarFn::new()`] to construct, and [`super::ScalarFn::erased()`] to obtain a
45/// [`ScalarFnRef`].
46#[derive(Clone)]
47pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
48
49impl ScalarFnRef {
50    /// Returns the ID of this scalar function.
51    pub fn id(&self) -> ScalarFnId {
52        self.0.id()
53    }
54
55    /// Returns whether the scalar function is of the given vtable type.
56    pub fn is<V: ScalarFnVTable>(&self) -> bool {
57        self.0.as_any().is::<ScalarFnInner<V>>()
58    }
59
60    /// Returns the typed options for this scalar function if it matches the given vtable type.
61    pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
62        self.downcast_inner::<V>().map(|inner| &inner.options)
63    }
64
65    /// Returns a reference to the typed vtable if it matches the given vtable type.
66    pub fn vtable_ref<V: ScalarFnVTable>(&self) -> Option<&V> {
67        self.downcast_inner::<V>().map(|inner| &inner.vtable)
68    }
69
70    /// Downcast the inner to the concrete `ScalarFnInner<V>`.
71    fn downcast_inner<V: ScalarFnVTable>(&self) -> Option<&ScalarFnInner<V>> {
72        self.0.as_any().downcast_ref::<ScalarFnInner<V>>()
73    }
74
75    /// Returns the typed options for this scalar function if it matches the given vtable type.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the vtable type does not match.
80    pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
81        self.as_opt::<V>()
82            .vortex_expect("Expression options type mismatch")
83    }
84
85    /// The type-erased options for this scalar function.
86    pub fn options(&self) -> ScalarFnOptions<'_> {
87        ScalarFnOptions { inner: &*self.0 }
88    }
89
90    /// Signature information for this scalar function.
91    pub fn signature(&self) -> ScalarFnSignature<'_> {
92        ScalarFnSignature { inner: &*self.0 }
93    }
94
95    /// Compute the return [`DType`] of this expression given the input argument types.
96    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
97        self.0.return_dtype(arg_types)
98    }
99
100    /// Transforms the expression into one representing the validity of this expression.
101    pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
102        Ok(self.0.validity(expr)?.unwrap_or_else(|| {
103            // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback.
104            // TODO(ngates): add an IsNotNull expression.
105            Not.new_expr(
106                EmptyOptions,
107                [IsNull.new_expr(EmptyOptions, [expr.clone()])],
108            )
109        }))
110    }
111
112    /// Execute the expression given the input arguments.
113    pub fn execute(
114        &self,
115        args: &dyn ExecutionArgs,
116        ctx: &mut ExecutionCtx,
117    ) -> VortexResult<ArrayRef> {
118        self.0.execute(args, ctx)
119    }
120
121    /// Perform abstract reduction on this scalar function node.
122    pub fn reduce(
123        &self,
124        node: &dyn ReduceNode,
125        ctx: &dyn ReduceCtx,
126    ) -> VortexResult<Option<ReduceNodeRef>> {
127        self.0.reduce(node, ctx)
128    }
129
130    // ------------------------------------------------------------------
131    // Expression-taking methods — used by expr/ module via pub(crate)
132    // ------------------------------------------------------------------
133
134    /// Format this expression in SQL-style format.
135    pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
136        self.0.fmt_sql(expr, f)
137    }
138
139    /// Simplify the expression using type information.
140    pub(crate) fn simplify(
141        &self,
142        expr: &Expression,
143        ctx: &dyn SimplifyCtx,
144    ) -> VortexResult<Option<Expression>> {
145        self.0.simplify(expr, ctx)
146    }
147
148    /// Simplify the expression without type information.
149    pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
150        self.0.simplify_untyped(expr)
151    }
152
153    /// Compute stat falsification expression.
154    pub(crate) fn stat_falsification(
155        &self,
156        expr: &Expression,
157        catalog: &dyn StatsCatalog,
158    ) -> Option<Expression> {
159        self.0.stat_falsification(expr, catalog)
160    }
161
162    /// Compute stat expression.
163    pub(crate) fn stat_expression(
164        &self,
165        expr: &Expression,
166        stat: Stat,
167        catalog: &dyn StatsCatalog,
168    ) -> Option<Expression> {
169        self.0.stat_expression(expr, stat, catalog)
170    }
171}
172
173impl Debug for ScalarFnRef {
174    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("ScalarFnRef")
176            .field("vtable", &self.0.id())
177            .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
178            .finish()
179    }
180}
181
182impl Display for ScalarFnRef {
183    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184        write!(f, "{}(", self.0.id())?;
185        self.0.options_display(f)?;
186        write!(f, ")")
187    }
188}
189
190impl PartialEq for ScalarFnRef {
191    fn eq(&self, other: &Self) -> bool {
192        self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
193    }
194}
195impl Eq for ScalarFnRef {}
196
197impl Hash for ScalarFnRef {
198    fn hash<H: Hasher>(&self, state: &mut H) {
199        self.0.id().hash(state);
200        self.0.options_hash(state);
201    }
202}