vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod array;
5mod operations;
6mod validity;
7mod visitor;
8
9use std::fmt::Display;
10use std::fmt::Formatter;
11use std::hash::Hash;
12use std::hash::Hasher;
13use std::marker::PhantomData;
14use std::ops::Deref;
15
16use itertools::Itertools;
17use vortex_error::VortexExpect;
18use vortex_error::VortexResult;
19use vortex_error::vortex_bail;
20use vortex_error::vortex_ensure;
21use vortex_session::VortexSession;
22
23use crate::Array;
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::arrays::scalar_fn::array::ScalarFnArray;
27use crate::arrays::scalar_fn::metadata::ScalarFnMetadata;
28use crate::arrays::scalar_fn::rules::PARENT_RULES;
29use crate::arrays::scalar_fn::rules::RULES;
30use crate::buffer::BufferHandle;
31use crate::dtype::DType;
32use crate::executor::ExecutionCtx;
33use crate::expr::Expression;
34use crate::matcher::Matcher;
35use crate::scalar_fn;
36use crate::scalar_fn::Arity;
37use crate::scalar_fn::ChildName;
38use crate::scalar_fn::ExecutionArgs;
39use crate::scalar_fn::ScalarFnId;
40use crate::scalar_fn::ScalarFnVTableExt;
41use crate::serde::ArrayChildren;
42use crate::vtable;
43use crate::vtable::ArrayId;
44use crate::vtable::VTable;
45
46vtable!(ScalarFn);
47
48#[derive(Clone, Debug)]
49pub struct ScalarFnVTable;
50
51impl VTable for ScalarFnVTable {
52 type Array = ScalarFnArray;
53 type Metadata = ScalarFnMetadata;
54 type ArrayVTable = Self;
55 type OperationsVTable = Self;
56 type ValidityVTable = Self;
57 type VisitorVTable = Self;
58
59 fn id(array: &Self::Array) -> ArrayId {
60 array.scalar_fn.id()
61 }
62
63 fn metadata(array: &Self::Array) -> VortexResult<Self::Metadata> {
64 let child_dtypes = array.children().iter().map(|c| c.dtype().clone()).collect();
65 Ok(ScalarFnMetadata {
66 scalar_fn: array.scalar_fn.clone(),
67 child_dtypes,
68 })
69 }
70
71 fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
72 Ok(None)
74 }
75
76 fn deserialize(
77 _bytes: &[u8],
78 _dtype: &DType,
79 _len: usize,
80 _buffers: &[BufferHandle],
81 _session: &VortexSession,
82 ) -> VortexResult<Self::Metadata> {
83 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
84 }
85
86 fn build(
87 dtype: &DType,
88 len: usize,
89 metadata: &ScalarFnMetadata,
90 _buffers: &[BufferHandle],
91 children: &dyn ArrayChildren,
92 ) -> VortexResult<Self::Array> {
93 let children: Vec<_> = metadata
94 .child_dtypes
95 .iter()
96 .enumerate()
97 .map(|(idx, child_dtype)| children.get(idx, child_dtype, len))
98 .try_collect()?;
99
100 #[cfg(debug_assertions)]
101 {
102 let child_dtypes: Vec<_> = children.iter().map(|c| c.dtype().clone()).collect();
103 vortex_error::vortex_ensure!(
104 &metadata.scalar_fn.return_dtype(&child_dtypes)? == dtype,
105 "Return dtype mismatch when building ScalarFnArray"
106 );
107 }
108
109 Ok(ScalarFnArray {
110 scalar_fn: metadata.scalar_fn.clone(),
112 dtype: dtype.clone(),
113 len,
114 children,
115 stats: Default::default(),
116 })
117 }
118
119 fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
120 vortex_ensure!(
121 children.len() == array.children.len(),
122 "ScalarFnArray expects {} children, got {}",
123 array.children.len(),
124 children.len()
125 );
126 array.children = children;
127 Ok(())
128 }
129
130 fn execute(array: &Self::Array, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
131 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn));
132 let args = ExecutionArgs {
133 inputs: array.children.clone(),
134 row_count: array.len,
135 ctx,
136 };
137 array.scalar_fn.execute(args)
138 }
139
140 fn reduce(array: &Self::Array) -> VortexResult<Option<ArrayRef>> {
141 RULES.evaluate(array)
142 }
143
144 fn reduce_parent(
145 array: &Self::Array,
146 parent: &ArrayRef,
147 child_idx: usize,
148 ) -> VortexResult<Option<ArrayRef>> {
149 PARENT_RULES.evaluate(array, parent, child_idx)
150 }
151}
152
153pub trait ScalarFnArrayExt: scalar_fn::ScalarFnVTable {
155 fn try_new_array(
156 &self,
157 len: usize,
158 options: Self::Options,
159 children: impl Into<Vec<ArrayRef>>,
160 ) -> VortexResult<ArrayRef> {
161 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
162
163 let children = children.into();
164 vortex_ensure!(
165 children.iter().all(|c| c.len() == len),
166 "All child arrays must have the same length as the scalar function array"
167 );
168
169 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
170 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
171
172 Ok(ScalarFnArray {
173 scalar_fn,
174 dtype,
175 len,
176 children,
177 stats: Default::default(),
178 }
179 .into_array())
180 }
181}
182impl<V: scalar_fn::ScalarFnVTable> ScalarFnArrayExt for V {}
183
184#[derive(Debug)]
186pub struct AnyScalarFn;
187impl Matcher for AnyScalarFn {
188 type Match<'a> = &'a ScalarFnArray;
189
190 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
191 array.as_opt::<ScalarFnVTable>()
192 }
193}
194
195#[derive(Debug, Default)]
197pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
198
199impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
200 type Match<'a> = ScalarFnArrayView<'a, F>;
201
202 fn matches(array: &dyn Array) -> bool {
203 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
204 scalar_fn_array.scalar_fn().is::<F>()
205 } else {
206 false
207 }
208 }
209
210 fn try_match(array: &dyn Array) -> Option<Self::Match<'_>> {
211 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
212 let scalar_fn_vtable = scalar_fn_array
213 .scalar_fn
214 .vtable_ref::<F>()
215 .vortex_expect("ScalarFn VTable type mismatch in ExactScalarFn matcher");
216 let scalar_fn_options = scalar_fn_array
217 .scalar_fn
218 .as_opt::<F>()
219 .vortex_expect("ScalarFn options type mismatch in ExactScalarFn matcher");
220 Some(ScalarFnArrayView {
221 array,
222 vtable: scalar_fn_vtable,
223 options: scalar_fn_options,
224 })
225 }
226}
227
228pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
229 array: &'a dyn Array,
230 pub vtable: &'a F,
231 pub options: &'a F::Options,
232}
233
234impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
235 type Target = dyn Array;
236
237 fn deref(&self) -> &Self::Target {
238 self.array
239 }
240}
241
/// Marker type for the scalar function that stands in for a concrete array.
#[derive(Clone)]
struct ArrayExpr;

/// Wrapper that supplies `PartialEq`, `Eq` and `Hash` for any `T` without inspecting it.
///
/// Two `FakeEq` values never compare equal (not even a value with itself), and hashing
/// contributes nothing to the hasher.
// NOTE(review): `eq` always returning `false` is not reflexive even though `Eq` is implemented;
// this looks deliberate given the type's name — confirm that no caller relies on reflexivity.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    // Intentionally feeds nothing into the hasher.
    fn hash<H: Hasher>(&self, _hasher: &mut H) {}
}
260
261impl Display for FakeEq<ArrayRef> {
262 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
263 write!(f, "{}", self.0.encoding_id())
264 }
265}
266
267impl scalar_fn::ScalarFnVTable for ArrayExpr {
268 type Options = FakeEq<ArrayRef>;
269
270 fn id(&self) -> ScalarFnId {
271 ScalarFnId::from("vortex.array")
272 }
273
274 fn arity(&self, _options: &Self::Options) -> Arity {
275 Arity::Exact(0)
276 }
277
278 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
279 todo!()
280 }
281
282 fn fmt_sql(
283 &self,
284 options: &Self::Options,
285 _expr: &Expression,
286 f: &mut Formatter<'_>,
287 ) -> std::fmt::Result {
288 write!(f, "{}", options.0.encoding_id())
289 }
290
291 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
292 Ok(options.0.dtype().clone())
293 }
294
295 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
296 crate::Executable::execute(options.0.clone(), args.ctx)
297 }
298
299 fn validity(
300 &self,
301 options: &Self::Options,
302 _expression: &Expression,
303 ) -> VortexResult<Option<Expression>> {
304 let validity_array = options.0.validity()?.to_array(options.0.len());
305 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
306 }
307}