Skip to main content

vortex_array/scalar_fn/
erased.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Type-erased scalar function ([`ScalarFnRef`]).
5
6use std::any::type_name;
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_err;
17use vortex_utils::debug_with::DebugWith;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::EmptyOptions;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::ScalarFnVTableExt;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::is_null::IsNull;
35use crate::scalar_fn::fns::not::Not;
36use crate::scalar_fn::options::ScalarFnOptions;
37use crate::scalar_fn::signature::ScalarFnSignature;
38use crate::scalar_fn::typed::DynScalarFn;
39use crate::scalar_fn::typed::ScalarFn;
40
41/// A type-erased scalar function, pairing a vtable with bound options behind a trait object.
42///
43/// This stores a [`ScalarFnVTable`] and its options behind an `Arc<dyn DynScalarFn>`, allowing
44/// heterogeneous storage inside [`Expression`] and [`crate::arrays::ScalarFnArray`].
45///
46/// Use [`super::ScalarFn::new()`] to construct, and [`super::ScalarFn::erased()`] to obtain a
47/// [`ScalarFnRef`].
48#[derive(Clone)]
49pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
50
51impl ScalarFnRef {
52    /// Returns the ID of this scalar function.
53    pub fn id(&self) -> ScalarFnId {
54        self.0.id()
55    }
56
57    /// Returns whether the scalar function is of the given vtable type.
58    pub fn is<V: ScalarFnVTable>(&self) -> bool {
59        self.0.as_any().is::<ScalarFn<V>>()
60    }
61
62    /// Returns the typed options for this scalar function if it matches the given vtable type.
63    pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
64        self.0
65            .as_any()
66            .downcast_ref::<ScalarFn<V>>()
67            .map(|sf| sf.options())
68    }
69
70    /// Returns the typed options for this scalar function if it matches the given vtable type.
71    ///
72    /// # Panics
73    ///
74    /// Panics if the vtable type does not match.
75    pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
76        self.as_opt::<V>()
77            .vortex_expect("Expression options type mismatch")
78    }
79
80    /// Downcast to the concrete [`ScalarFn`].
81    ///
82    /// Returns `Err(self)` if the downcast fails.
83    pub fn try_downcast<V: ScalarFnVTable>(self) -> Result<Arc<ScalarFn<V>>, ScalarFnRef> {
84        if self.0.as_any().is::<ScalarFn<V>>() {
85            let ptr = Arc::into_raw(self.0) as *const ScalarFn<V>;
86            Ok(unsafe { Arc::from_raw(ptr) })
87        } else {
88            Err(self)
89        }
90    }
91
92    /// Downcast to the concrete [`ScalarFn`].
93    ///
94    /// # Panics
95    ///
96    /// Panics if the downcast fails.
97    pub fn downcast<V: ScalarFnVTable>(self) -> Arc<ScalarFn<V>> {
98        self.try_downcast::<V>()
99            .map_err(|this| {
100                vortex_err!(
101                    "Failed to downcast ScalarFnRef {} to {}",
102                    this.0.id(),
103                    type_name::<V>(),
104                )
105            })
106            .vortex_expect("Failed to downcast ScalarFnRef")
107    }
108
109    /// Try to downcast into a typed [`ScalarFn`].
110    pub fn downcast_ref<V: ScalarFnVTable>(&self) -> Option<&ScalarFn<V>> {
111        self.0.as_any().downcast_ref::<ScalarFn<V>>()
112    }
113
114    /// The type-erased options for this scalar function.
115    pub fn options(&self) -> ScalarFnOptions<'_> {
116        ScalarFnOptions { inner: &*self.0 }
117    }
118
119    /// Signature information for this scalar function.
120    pub fn signature(&self) -> ScalarFnSignature<'_> {
121        ScalarFnSignature { inner: &*self.0 }
122    }
123
124    /// Compute the return [`DType`] of this expression given the input argument types.
125    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
126        self.0.return_dtype(arg_types)
127    }
128
129    /// Coerce the argument types for this scalar function.
130    pub fn coerce_args(&self, arg_types: &[DType]) -> VortexResult<Vec<DType>> {
131        self.0.coerce_args(arg_types)
132    }
133
134    /// Transforms the expression into one representing the validity of this expression.
135    pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
136        Ok(self.0.validity(expr)?.unwrap_or_else(|| {
137            // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback.
138            // TODO(ngates): add an IsNotNull expression.
139            Not.new_expr(
140                EmptyOptions,
141                [IsNull.new_expr(EmptyOptions, [expr.clone()])],
142            )
143        }))
144    }
145
146    /// Execute the expression given the input arguments.
147    pub fn execute(
148        &self,
149        args: &dyn ExecutionArgs,
150        ctx: &mut ExecutionCtx,
151    ) -> VortexResult<ArrayRef> {
152        self.0.execute(args, ctx)
153    }
154
155    /// Perform abstract reduction on this scalar function node.
156    pub fn reduce(
157        &self,
158        node: &dyn ReduceNode,
159        ctx: &dyn ReduceCtx,
160    ) -> VortexResult<Option<ReduceNodeRef>> {
161        self.0.reduce(node, ctx)
162    }
163
164    // ------------------------------------------------------------------
165    // Expression-taking methods — used by expr/ module via pub(crate)
166    // ------------------------------------------------------------------
167
168    /// Format this expression in SQL-style format.
169    pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
170        self.0.fmt_sql(expr, f)
171    }
172
173    /// Simplify the expression using type information.
174    pub(crate) fn simplify(
175        &self,
176        expr: &Expression,
177        ctx: &dyn SimplifyCtx,
178    ) -> VortexResult<Option<Expression>> {
179        self.0.simplify(expr, ctx)
180    }
181
182    /// Simplify the expression without type information.
183    pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
184        self.0.simplify_untyped(expr)
185    }
186
187    /// Compute stat falsification expression.
188    pub(crate) fn stat_falsification(
189        &self,
190        expr: &Expression,
191        catalog: &dyn StatsCatalog,
192    ) -> Option<Expression> {
193        self.0.stat_falsification(expr, catalog)
194    }
195
196    /// Compute stat expression.
197    pub(crate) fn stat_expression(
198        &self,
199        expr: &Expression,
200        stat: Stat,
201        catalog: &dyn StatsCatalog,
202    ) -> Option<Expression> {
203        self.0.stat_expression(expr, stat, catalog)
204    }
205}
206
207impl Debug for ScalarFnRef {
208    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
209        f.debug_struct("ScalarFnRef")
210            .field("vtable", &self.0.id())
211            .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
212            .finish()
213    }
214}
215
216impl Display for ScalarFnRef {
217    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218        write!(f, "{}(", self.0.id())?;
219        self.0.options_display(f)?;
220        write!(f, ")")
221    }
222}
223
224impl PartialEq for ScalarFnRef {
225    fn eq(&self, other: &Self) -> bool {
226        self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
227    }
228}
229impl Eq for ScalarFnRef {}
230
231impl Hash for ScalarFnRef {
232    fn hash<H: Hasher>(&self, state: &mut H) {
233        self.0.id().hash(state);
234        self.0.options_hash(state);
235    }
236}