vortex_array/scalar_fn/
erased.rs1use std::any::type_name;
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_err;
17use vortex_utils::debug_with::DebugWith;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::EmptyOptions;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::ScalarFnVTableExt;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::is_not_null::IsNotNull;
35use crate::scalar_fn::options::ScalarFnOptions;
36use crate::scalar_fn::signature::ScalarFnSignature;
37use crate::scalar_fn::typed::DynScalarFn;
38use crate::scalar_fn::typed::TypedScalarFnInstance;
39
40#[derive(Clone)]
48pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
49
50impl ScalarFnRef {
51 pub fn id(&self) -> ScalarFnId {
53 self.0.id()
54 }
55
56 pub fn is<V: ScalarFnVTable>(&self) -> bool {
58 self.0.as_any().is::<TypedScalarFnInstance<V>>()
59 }
60
61 pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
63 self.0
64 .as_any()
65 .downcast_ref::<TypedScalarFnInstance<V>>()
66 .map(|sf| sf.options())
67 }
68
69 pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
75 self.as_opt::<V>()
76 .vortex_expect("Expression options type mismatch")
77 }
78
79 pub fn try_downcast<V: ScalarFnVTable>(
83 self,
84 ) -> Result<Arc<TypedScalarFnInstance<V>>, ScalarFnRef> {
85 if self.0.as_any().is::<TypedScalarFnInstance<V>>() {
86 let ptr = Arc::into_raw(self.0) as *const TypedScalarFnInstance<V>;
87 Ok(unsafe { Arc::from_raw(ptr) })
88 } else {
89 Err(self)
90 }
91 }
92
93 pub fn downcast<V: ScalarFnVTable>(self) -> Arc<TypedScalarFnInstance<V>> {
99 self.try_downcast::<V>()
100 .map_err(|this| {
101 vortex_err!(
102 "Failed to downcast ScalarFnRef {} to {}",
103 this.0.id(),
104 type_name::<V>(),
105 )
106 })
107 .vortex_expect("Failed to downcast ScalarFnRef")
108 }
109
110 pub fn downcast_ref<V: ScalarFnVTable>(&self) -> Option<&TypedScalarFnInstance<V>> {
112 self.0.as_any().downcast_ref::<TypedScalarFnInstance<V>>()
113 }
114
115 pub fn options(&self) -> ScalarFnOptions<'_> {
117 ScalarFnOptions { inner: &*self.0 }
118 }
119
120 pub fn signature(&self) -> ScalarFnSignature<'_> {
122 ScalarFnSignature { inner: &*self.0 }
123 }
124
125 pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
127 self.0.return_dtype(arg_types)
128 }
129
130 pub fn coerce_args(&self, arg_types: &[DType]) -> VortexResult<Vec<DType>> {
132 self.0.coerce_args(arg_types)
133 }
134
135 pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
137 Ok(self.0.validity(expr)?.unwrap_or_else(|| {
138 IsNotNull.new_expr(EmptyOptions, [expr.clone()])
140 }))
141 }
142
143 pub fn execute(
145 &self,
146 args: &dyn ExecutionArgs,
147 ctx: &mut ExecutionCtx,
148 ) -> VortexResult<ArrayRef> {
149 self.0.execute(args, ctx)
150 }
151
152 pub fn reduce(
154 &self,
155 node: &dyn ReduceNode,
156 ctx: &dyn ReduceCtx,
157 ) -> VortexResult<Option<ReduceNodeRef>> {
158 self.0.reduce(node, ctx)
159 }
160
161 pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
167 self.0.fmt_sql(expr, f)
168 }
169
170 pub(crate) fn simplify(
172 &self,
173 expr: &Expression,
174 ctx: &dyn SimplifyCtx,
175 ) -> VortexResult<Option<Expression>> {
176 self.0.simplify(expr, ctx)
177 }
178
179 pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
181 self.0.simplify_untyped(expr)
182 }
183
184 pub(crate) fn stat_falsification(
186 &self,
187 expr: &Expression,
188 catalog: &dyn StatsCatalog,
189 ) -> Option<Expression> {
190 self.0.stat_falsification(expr, catalog)
191 }
192
193 pub(crate) fn stat_expression(
195 &self,
196 expr: &Expression,
197 stat: Stat,
198 catalog: &dyn StatsCatalog,
199 ) -> Option<Expression> {
200 self.0.stat_expression(expr, stat, catalog)
201 }
202}
203
204impl Debug for ScalarFnRef {
205 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("ScalarFnRef")
207 .field("vtable", &self.0.id())
208 .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
209 .finish()
210 }
211}
212
213impl Display for ScalarFnRef {
214 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
215 write!(f, "{}(", self.0.id())?;
216 self.0.options_display(f)?;
217 write!(f, ")")
218 }
219}
220
221impl PartialEq for ScalarFnRef {
222 fn eq(&self, other: &Self) -> bool {
223 self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
224 }
225}
226impl Eq for ScalarFnRef {}
227
228impl Hash for ScalarFnRef {
229 fn hash<H: Hasher>(&self, state: &mut H) {
230 self.0.id().hash(state);
231 self.0.options_hash(state);
232 }
233}