vortex_array/scalar_fn/
erased.rs1use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::debug_with::DebugWith;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::dtype::DType;
20use crate::expr::Expression;
21use crate::expr::StatsCatalog;
22use crate::expr::stats::Stat;
23use crate::scalar_fn::EmptyOptions;
24use crate::scalar_fn::ExecutionArgs;
25use crate::scalar_fn::ReduceCtx;
26use crate::scalar_fn::ReduceNode;
27use crate::scalar_fn::ReduceNodeRef;
28use crate::scalar_fn::ScalarFnId;
29use crate::scalar_fn::ScalarFnVTable;
30use crate::scalar_fn::ScalarFnVTableExt;
31use crate::scalar_fn::SimplifyCtx;
32use crate::scalar_fn::fns::is_null::IsNull;
33use crate::scalar_fn::fns::not::Not;
34use crate::scalar_fn::options::ScalarFnOptions;
35use crate::scalar_fn::signature::ScalarFnSignature;
36use crate::scalar_fn::typed::DynScalarFn;
37use crate::scalar_fn::typed::ScalarFnInner;
38
39#[derive(Clone)]
47pub struct ScalarFnRef(pub(super) Arc<dyn DynScalarFn>);
48
49impl ScalarFnRef {
50 pub fn id(&self) -> ScalarFnId {
52 self.0.id()
53 }
54
55 pub fn is<V: ScalarFnVTable>(&self) -> bool {
57 self.0.as_any().is::<ScalarFnInner<V>>()
58 }
59
60 pub fn as_opt<V: ScalarFnVTable>(&self) -> Option<&V::Options> {
62 self.downcast_inner::<V>().map(|inner| &inner.options)
63 }
64
65 pub fn vtable_ref<V: ScalarFnVTable>(&self) -> Option<&V> {
67 self.downcast_inner::<V>().map(|inner| &inner.vtable)
68 }
69
70 fn downcast_inner<V: ScalarFnVTable>(&self) -> Option<&ScalarFnInner<V>> {
72 self.0.as_any().downcast_ref::<ScalarFnInner<V>>()
73 }
74
75 pub fn as_<V: ScalarFnVTable>(&self) -> &V::Options {
81 self.as_opt::<V>()
82 .vortex_expect("Expression options type mismatch")
83 }
84
85 pub fn options(&self) -> ScalarFnOptions<'_> {
87 ScalarFnOptions { inner: &*self.0 }
88 }
89
90 pub fn signature(&self) -> ScalarFnSignature<'_> {
92 ScalarFnSignature { inner: &*self.0 }
93 }
94
95 pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
97 self.0.return_dtype(arg_types)
98 }
99
100 pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
102 Ok(self.0.validity(expr)?.unwrap_or_else(|| {
103 Not.new_expr(
106 EmptyOptions,
107 [IsNull.new_expr(EmptyOptions, [expr.clone()])],
108 )
109 }))
110 }
111
112 pub fn execute(
114 &self,
115 args: &dyn ExecutionArgs,
116 ctx: &mut ExecutionCtx,
117 ) -> VortexResult<ArrayRef> {
118 self.0.execute(args, ctx)
119 }
120
121 pub fn reduce(
123 &self,
124 node: &dyn ReduceNode,
125 ctx: &dyn ReduceCtx,
126 ) -> VortexResult<Option<ReduceNodeRef>> {
127 self.0.reduce(node, ctx)
128 }
129
130 pub(crate) fn fmt_sql(&self, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
136 self.0.fmt_sql(expr, f)
137 }
138
139 pub(crate) fn simplify(
141 &self,
142 expr: &Expression,
143 ctx: &dyn SimplifyCtx,
144 ) -> VortexResult<Option<Expression>> {
145 self.0.simplify(expr, ctx)
146 }
147
148 pub(crate) fn simplify_untyped(&self, expr: &Expression) -> VortexResult<Option<Expression>> {
150 self.0.simplify_untyped(expr)
151 }
152
153 pub(crate) fn stat_falsification(
155 &self,
156 expr: &Expression,
157 catalog: &dyn StatsCatalog,
158 ) -> Option<Expression> {
159 self.0.stat_falsification(expr, catalog)
160 }
161
162 pub(crate) fn stat_expression(
164 &self,
165 expr: &Expression,
166 stat: Stat,
167 catalog: &dyn StatsCatalog,
168 ) -> Option<Expression> {
169 self.0.stat_expression(expr, stat, catalog)
170 }
171}
172
173impl Debug for ScalarFnRef {
174 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("ScalarFnRef")
176 .field("vtable", &self.0.id())
177 .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
178 .finish()
179 }
180}
181
182impl Display for ScalarFnRef {
183 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
184 write!(f, "{}(", self.0.id())?;
185 self.0.options_display(f)?;
186 write!(f, ")")
187 }
188}
189
190impl PartialEq for ScalarFnRef {
191 fn eq(&self, other: &Self) -> bool {
192 self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
193 }
194}
195impl Eq for ScalarFnRef {}
196
197impl Hash for ScalarFnRef {
198 fn hash<H: Hasher>(&self, state: &mut H) {
199 self.0.id().hash(state);
200 self.0.options_hash(state);
201 }
202}