Skip to main content

vortex_array/scalar_fn/
erased.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Type-erased scalar function ([`ScalarFnRef`]).
5
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::debug_with::DebugWith;
16
17use crate::ArrayRef;
18use crate::dtype::DType;
19use crate::expr::Expression;
20use crate::expr::StatsCatalog;
21use crate::expr::stats::Stat;
22use crate::scalar_fn::EmptyOptions;
23use crate::scalar_fn::ExecutionArgs;
24use crate::scalar_fn::ReduceCtx;
25use crate::scalar_fn::ReduceNode;
26use crate::scalar_fn::ReduceNodeRef;
27use crate::scalar_fn::ScalarFnId;
28use crate::scalar_fn::ScalarFnVTable;
29use crate::scalar_fn::ScalarFnVTableExt;
30use crate::scalar_fn::SimplifyCtx;
31use crate::scalar_fn::fns::is_null::IsNull;
32use crate::scalar_fn::fns::not::Not;
33use crate::scalar_fn::options::ScalarFnOptions;
34use crate::scalar_fn::signature::ScalarFnSignature;
35use crate::scalar_fn::typed::DynScalarFn;
36use crate::scalar_fn::typed::ScalarFnInner;
37
38/// A type-erased scalar function, pairing a vtable with bound options behind a trait object.
39///
40/// This stores a [`ScalarFnVTable`] and its options behind an `Arc<dyn DynScalarFn>`, allowing
41/// heterogeneous storage inside [`Expression`] and [`crate::arrays::ScalarFnArray`].
42///
43/// Use [`super::ScalarFn::new()`] to construct, and [`super::ScalarFn::erased()`] to obtain a
44/// [`ScalarFnRef`].
45#[derive(Clone)]
46pub struct ScalarFnRef(pub(crate) Arc<dyn DynScalarFn>);
47
48impl ScalarFnRef {
49    /// Returns the ID of this scalar function.
50    pub fn id(&self) -> ScalarFnId {
51        self.0.id()
52    }
53
54    /// Returns whether the scalar function is of the given vtable type.
55    pub fn is<V: ScalarFnVTable>(&self) -> bool {
56        self.0.as_any().is::<ScalarFnInner<V>>()
57    }
58
59    /// Returns the typed options for this scalar function if it matches the given vtable type.
60    pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
61        self.downcast_inner::<V>().map(|inner| &inner.options)
62    }
63
64    /// Returns a reference to the typed vtable if it matches the given vtable type.
65    pub fn vtable_ref<V: ScalarFnVTable>(&self) -> Option<&V> {
66        self.downcast_inner::<V>().map(|inner| &inner.vtable)
67    }
68
69    /// Downcast the inner to the concrete `ScalarFnInner<V>`.
70    fn downcast_inner<V: ScalarFnVTable>(&self) -> Option<&ScalarFnInner<V>> {
71        self.0.as_any().downcast_ref::<ScalarFnInner<V>>()
72    }
73
74    /// Returns the typed options for this scalar function if it matches the given vtable type.
75    ///
76    /// # Panics
77    ///
78    /// Panics if the vtable type does not match.
79    pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
80        self.as_opt::<V>()
81            .vortex_expect("Expression options type mismatch")
82    }
83
84    /// The type-erased options for this scalar function.
85    pub fn options(&self) -> ScalarFnOptions<'_> {
86        ScalarFnOptions { inner: &*self.0 }
87    }
88
89    /// Signature information for this scalar function.
90    pub fn signature(&self) -> ScalarFnSignature<'_> {
91        ScalarFnSignature { inner: &*self.0 }
92    }
93
94    /// Compute the return [`DType`] of this expression given the input argument types.
95    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
96        self.0.return_dtype(arg_types)
97    }
98
99    /// Transforms the expression into one representing the validity of this expression.
100    pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
101        Ok(self.0.validity(expr)?.unwrap_or_else(|| {
102            // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback.
103            // TODO(ngates): add an IsNotNull expression.
104            Not.new_expr(
105                EmptyOptions,
106                [IsNull.new_expr(EmptyOptions, [expr.clone()])],
107            )
108        }))
109    }
110
111    /// Execute the expression given the input arguments.
112    pub fn execute(&self, ctx: ExecutionArgs) -> VortexResult<ArrayRef> {
113        self.0.execute(ctx)
114    }
115
116    /// Perform abstract reduction on this scalar function node.
117    pub fn reduce(
118        &self,
119        node: &dyn ReduceNode,
120        ctx: &dyn ReduceCtx,
121    ) -> VortexResult<Option<ReduceNodeRef>> {
122        self.0.reduce(node, ctx)
123    }
124
125    // ------------------------------------------------------------------
126    // Expression-taking methods — used by expr/ module via pub(crate)
127    // ------------------------------------------------------------------
128
129    /// Format this expression in SQL-style format.
130    pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
131        self.0.fmt_sql(expr, f)
132    }
133
134    /// Simplify the expression using type information.
135    pub(crate) fn simplify(
136        &self,
137        expr: &Expression,
138        ctx: &dyn SimplifyCtx,
139    ) -> VortexResult<Option<Expression>> {
140        self.0.simplify(expr, ctx)
141    }
142
143    /// Simplify the expression without type information.
144    pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
145        self.0.simplify_untyped(expr)
146    }
147
148    /// Compute stat falsification expression.
149    pub(crate) fn stat_falsification(
150        &self,
151        expr: &Expression,
152        catalog: &dyn StatsCatalog,
153    ) -> Option<Expression> {
154        self.0.stat_falsification(expr, catalog)
155    }
156
157    /// Compute stat expression.
158    pub(crate) fn stat_expression(
159        &self,
160        expr: &Expression,
161        stat: Stat,
162        catalog: &dyn StatsCatalog,
163    ) -> Option<Expression> {
164        self.0.stat_expression(expr, stat, catalog)
165    }
166}
167
168impl Debug for ScalarFnRef {
169    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170        f.debug_struct("ScalarFnRef")
171            .field("vtable", &self.0.id())
172            .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
173            .finish()
174    }
175}
176
177impl Display for ScalarFnRef {
178    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
179        write!(f, "{}(", self.0.id())?;
180        self.0.options_display(f)?;
181        write!(f, ")")
182    }
183}
184
185impl PartialEq for ScalarFnRef {
186    fn eq(&self, other: &Self) -> bool {
187        self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
188    }
189}
190impl Eq for ScalarFnRef {}
191
192impl Hash for ScalarFnRef {
193    fn hash<H: Hasher>(&self, state: &mut H) {
194        self.0.id().hash(state);
195        self.0.options_hash(state);
196    }
197}