vortex_array/scalar_fn/
erased.rs1use std::any::type_name;
7use std::fmt::Debug;
8use std::fmt::Display;
9use std::fmt::Formatter;
10use std::hash::Hash;
11use std::hash::Hasher;
12use std::sync::Arc;
13
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_err;
17use vortex_utils::debug_with::DebugWith;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::stats::Stat;
25use crate::scalar_fn::EmptyOptions;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ReduceCtx;
28use crate::scalar_fn::ReduceNode;
29use crate::scalar_fn::ReduceNodeRef;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::scalar_fn::ScalarFnVTableExt;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::is_null::IsNull;
35use crate::scalar_fn::fns::not::Not;
36use crate::scalar_fn::options::ScalarFnOptions;
37use crate::scalar_fn::signature::ScalarFnSignature;
38use crate::scalar_fn::typed::DynScalarFn;
39use crate::scalar_fn::typed::ScalarFn;
40
41#[derive(Clone)]
49pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
50
51impl ScalarFnRef {
52 pub fn id(&self) -> ScalarFnId {
54 self.0.id()
55 }
56
57 pub fn is<V: ScalarFnVTable>(&self) -> bool {
59 self.0.as_any().is::<ScalarFn<V>>()
60 }
61
62 pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
64 self.0
65 .as_any()
66 .downcast_ref::<ScalarFn<V>>()
67 .map(|sf| sf.options())
68 }
69
70 pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
76 self.as_opt::<V>()
77 .vortex_expect("Expression options type mismatch")
78 }
79
80 pub fn try_downcast<V: ScalarFnVTable>(self) -> Result<Arc<ScalarFn<V>>, ScalarFnRef> {
84 if self.0.as_any().is::<ScalarFn<V>>() {
85 let ptr = Arc::into_raw(self.0) as *const ScalarFn<V>;
86 Ok(unsafe { Arc::from_raw(ptr) })
87 } else {
88 Err(self)
89 }
90 }
91
92 pub fn downcast<V: ScalarFnVTable>(self) -> Arc<ScalarFn<V>> {
98 self.try_downcast::<V>()
99 .map_err(|this| {
100 vortex_err!(
101 "Failed to downcast ScalarFnRef {} to {}",
102 this.0.id(),
103 type_name::<V>(),
104 )
105 })
106 .vortex_expect("Failed to downcast ScalarFnRef")
107 }
108
109 pub fn downcast_ref<V: ScalarFnVTable>(&self) -> Option<&ScalarFn<V>> {
111 self.0.as_any().downcast_ref::<ScalarFn<V>>()
112 }
113
114 pub fn options(&self) -> ScalarFnOptions<'_> {
116 ScalarFnOptions { inner: &*self.0 }
117 }
118
119 pub fn signature(&self) -> ScalarFnSignature<'_> {
121 ScalarFnSignature { inner: &*self.0 }
122 }
123
124 pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
126 self.0.return_dtype(arg_types)
127 }
128
129 pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
131 Ok(self.0.validity(expr)?.unwrap_or_else(|| {
132 Not.new_expr(
135 EmptyOptions,
136 [IsNull.new_expr(EmptyOptions, [expr.clone()])],
137 )
138 }))
139 }
140
141 pub fn execute(
143 &self,
144 args: &dyn ExecutionArgs,
145 ctx: &mut ExecutionCtx,
146 ) -> VortexResult<ArrayRef> {
147 self.0.execute(args, ctx)
148 }
149
150 pub fn reduce(
152 &self,
153 node: &dyn ReduceNode,
154 ctx: &dyn ReduceCtx,
155 ) -> VortexResult<Option<ReduceNodeRef>> {
156 self.0.reduce(node, ctx)
157 }
158
159 pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
165 self.0.fmt_sql(expr, f)
166 }
167
168 pub(crate) fn simplify(
170 &self,
171 expr: &Expression,
172 ctx: &dyn SimplifyCtx,
173 ) -> VortexResult<Option<Expression>> {
174 self.0.simplify(expr, ctx)
175 }
176
177 pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
179 self.0.simplify_untyped(expr)
180 }
181
182 pub(crate) fn stat_falsification(
184 &self,
185 expr: &Expression,
186 catalog: &dyn StatsCatalog,
187 ) -> Option<Expression> {
188 self.0.stat_falsification(expr, catalog)
189 }
190
191 pub(crate) fn stat_expression(
193 &self,
194 expr: &Expression,
195 stat: Stat,
196 catalog: &dyn StatsCatalog,
197 ) -> Option<Expression> {
198 self.0.stat_expression(expr, stat, catalog)
199 }
200}
201
202impl Debug for ScalarFnRef {
203 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("ScalarFnRef")
205 .field("vtable", &self.0.id())
206 .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
207 .finish()
208 }
209}
210
211impl Display for ScalarFnRef {
212 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
213 write!(f, "{}(", self.0.id())?;
214 self.0.options_display(f)?;
215 write!(f, ")")
216 }
217}
218
219impl PartialEq for ScalarFnRef {
220 fn eq(&self, other: &Self) -> bool {
221 self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
222 }
223}
224impl Eq for ScalarFnRef {}
225
226impl Hash for ScalarFnRef {
227 fn hash<H: Hasher>(&self, state: &mut H) {
228 self.0.id().hash(state);
229 self.0.options_hash(state);
230 }
231}