Skip to main content

vortex_array/scalar_fn/
erased.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Type-erased scalar function ([`ScalarFnRef`]).
5
6use std::any::type_name;
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_err;
17use vortex_utils::debug_with::DebugWith;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::EmptyOptions;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::ScalarFnVTableExt;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::is_not_null::IsNotNull;
35use crate::scalar_fn::options::ScalarFnOptions;
36use crate::scalar_fn::signature::ScalarFnSignature;
37use crate::scalar_fn::typed::DynScalarFn;
38use crate::scalar_fn::typed::TypedScalarFnInstance;
39
/// A type-erased scalar function, pairing a vtable with bound options behind a trait object.
///
/// This stores a [`ScalarFnVTable`] and its options behind an `Arc<dyn DynScalarFn>`, allowing
/// heterogeneous storage inside [`Expression`] and [`crate::arrays::ScalarFnArray`].
///
/// Use [`super::TypedScalarFnInstance::new()`] to construct, and [`super::TypedScalarFnInstance::erased()`] to
/// obtain a [`ScalarFnRef`].
///
/// Cloning a [`ScalarFnRef`] clones the inner `Arc`, so clones share the same underlying
/// instance rather than copying it.
///
/// Two [`ScalarFnRef`]s compare equal when their function IDs match and their options compare
/// equal; the [`Hash`] implementation feeds the same two components into the hasher.
#[derive(Clone)]
pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
49
50impl ScalarFnRef {
51    /// Returns the ID of this scalar function.
52    pub fn id(&self) -> ScalarFnId {
53        self.0.id()
54    }
55
56    /// Returns whether the scalar function is of the given vtable type.
57    pub fn is<V: ScalarFnVTable>(&self) -> bool {
58        self.0.as_any().is::<TypedScalarFnInstance<V>>()
59    }
60
61    /// Returns the typed options for this scalar function if it matches the given vtable type.
62    pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
63        self.0
64            .as_any()
65            .downcast_ref::<TypedScalarFnInstance<V>>()
66            .map(|sf| sf.options())
67    }
68
69    /// Returns the typed options for this scalar function if it matches the given vtable type.
70    ///
71    /// # Panics
72    ///
73    /// Panics if the vtable type does not match.
74    pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
75        self.as_opt::<V>()
76            .vortex_expect("Expression options type mismatch")
77    }
78
79    /// Downcast to the concrete [`TypedScalarFnInstance`].
80    ///
81    /// Returns `Err(self)` if the downcast fails.
82    pub fn try_downcast<V: ScalarFnVTable>(
83        self,
84    ) -> Result<Arc<TypedScalarFnInstance<V>>, ScalarFnRef> {
85        if self.0.as_any().is::<TypedScalarFnInstance<V>>() {
86            let ptr = Arc::into_raw(self.0) as *const TypedScalarFnInstance<V>;
87            Ok(unsafe { Arc::from_raw(ptr) })
88        } else {
89            Err(self)
90        }
91    }
92
93    /// Downcast to the concrete [`TypedScalarFnInstance`].
94    ///
95    /// # Panics
96    ///
97    /// Panics if the downcast fails.
98    pub fn downcast<V: ScalarFnVTable>(self) -> Arc<TypedScalarFnInstance<V>> {
99        self.try_downcast::<V>()
100            .map_err(|this| {
101                vortex_err!(
102                    "Failed to downcast ScalarFnRef {} to {}",
103                    this.0.id(),
104                    type_name::<V>(),
105                )
106            })
107            .vortex_expect("Failed to downcast ScalarFnRef")
108    }
109
110    /// Try to downcast into a typed [`TypedScalarFnInstance`].
111    pub fn downcast_ref<V: ScalarFnVTable>(&self) -> Option<&TypedScalarFnInstance<V>> {
112        self.0.as_any().downcast_ref::<TypedScalarFnInstance<V>>()
113    }
114
115    /// The type-erased options for this scalar function.
116    pub fn options(&self) -> ScalarFnOptions<'_> {
117        ScalarFnOptions { inner: &*self.0 }
118    }
119
120    /// Signature information for this scalar function.
121    pub fn signature(&self) -> ScalarFnSignature<'_> {
122        ScalarFnSignature { inner: &*self.0 }
123    }
124
125    /// Compute the return [`DType`] of this expression given the input argument types.
126    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
127        self.0.return_dtype(arg_types)
128    }
129
130    /// Coerce the argument types for this scalar function.
131    pub fn coerce_args(&self, arg_types: &[DType]) -> VortexResult<Vec<DType>> {
132        self.0.coerce_args(arg_types)
133    }
134
135    /// Transforms the expression into one representing the validity of this expression.
136    pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
137        Ok(self.0.validity(expr)?.unwrap_or_else(|| {
138            // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback.
139            IsNotNull.new_expr(EmptyOptions, [expr.clone()])
140        }))
141    }
142
143    /// Execute the expression given the input arguments.
144    pub fn execute(
145        &self,
146        args: &dyn ExecutionArgs,
147        ctx: &mut ExecutionCtx,
148    ) -> VortexResult<ArrayRef> {
149        self.0.execute(args, ctx)
150    }
151
152    /// Perform abstract reduction on this scalar function node.
153    pub fn reduce(
154        &self,
155        node: &dyn ReduceNode,
156        ctx: &dyn ReduceCtx,
157    ) -> VortexResult<Option<ReduceNodeRef>> {
158        self.0.reduce(node, ctx)
159    }
160
161    // ------------------------------------------------------------------
162    // Expression-taking methods — used by expr/ module via pub(crate)
163    // ------------------------------------------------------------------
164
165    /// Format this expression in SQL-style format.
166    pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
167        self.0.fmt_sql(expr, f)
168    }
169
170    /// Simplify the expression using type information.
171    pub(crate) fn simplify(
172        &self,
173        expr: &Expression,
174        ctx: &dyn SimplifyCtx,
175    ) -> VortexResult<Option<Expression>> {
176        self.0.simplify(expr, ctx)
177    }
178
179    /// Simplify the expression without type information.
180    pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
181        self.0.simplify_untyped(expr)
182    }
183
184    /// Compute stat falsification expression.
185    pub(crate) fn stat_falsification(
186        &self,
187        expr: &Expression,
188        catalog: &dyn StatsCatalog,
189    ) -> Option<Expression> {
190        self.0.stat_falsification(expr, catalog)
191    }
192
193    /// Compute stat expression.
194    pub(crate) fn stat_expression(
195        &self,
196        expr: &Expression,
197        stat: Stat,
198        catalog: &dyn StatsCatalog,
199    ) -> Option<Expression> {
200        self.0.stat_expression(expr, stat, catalog)
201    }
202}
203
204impl Debug for ScalarFnRef {
205    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("ScalarFnRef")
207            .field("vtable", &self.0.id())
208            .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
209            .finish()
210    }
211}
212
213impl Display for ScalarFnRef {
214    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215        write!(f, "{}(", self.0.id())?;
216        self.0.options_display(f)?;
217        write!(f, ")")
218    }
219}
220
221impl PartialEq for ScalarFnRef {
222    fn eq(&self, other: &Self) -> bool {
223        self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
224    }
225}
226impl Eq for ScalarFnRef {}
227
228impl Hash for ScalarFnRef {
229    fn hash<H: Hasher>(&self, state: &mut H) {
230        self.0.id().hash(state);
231        self.0.options_hash(state);
232    }
233}