vortex_array/arrays/scalar_fn/vtable/
mod.rs1mod operations;
4mod validity;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::marker::PhantomData;
10use std::ops::Deref;
11
12use itertools::Itertools;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_ensure;
16use vortex_error::vortex_panic;
17use vortex_session::VortexSession;
18
19use crate::ArrayEq;
20use crate::ArrayHash;
21use crate::ArrayRef;
22use crate::IntoArray;
23use crate::Precision;
24use crate::array::Array;
25use crate::array::ArrayId;
26use crate::array::ArrayParts;
27use crate::array::ArrayView;
28use crate::array::VTable;
29use crate::arrays::scalar_fn::array::ScalarFnArrayExt;
30use crate::arrays::scalar_fn::array::ScalarFnData;
31use crate::arrays::scalar_fn::rules::PARENT_RULES;
32use crate::arrays::scalar_fn::rules::RULES;
33use crate::buffer::BufferHandle;
34use crate::dtype::DType;
35use crate::executor::ExecutionCtx;
36use crate::executor::ExecutionResult;
37use crate::expr::Expression;
38use crate::matcher::Matcher;
39use crate::scalar_fn;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ScalarFnId;
44use crate::scalar_fn::ScalarFnVTableExt;
45use crate::scalar_fn::VecExecutionArgs;
46use crate::serde::ArrayChildren;
47
/// Array type backed by [`ScalarFnVTable`]; executing it applies the wrapped
/// scalar function to the array's children.
pub type ScalarFnArray = Array<ScalarFnVTable>;
50
/// Vtable for [`ScalarFnArray`], identified by the id of the scalar function
/// it wraps.
#[derive(Clone, Debug)]
pub struct ScalarFnVTable {
    /// Id of the wrapped scalar function; this is also what `VTable::id`
    /// reports for the array.
    pub(super) id: ScalarFnId,
}
55
56impl ArrayHash for ScalarFnData {
57 fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
58 self.scalar_fn().hash(state);
59 }
60}
61
62impl ArrayEq for ScalarFnData {
63 fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
64 self.scalar_fn() == other.scalar_fn()
65 }
66}
67
68impl VTable for ScalarFnVTable {
69 type ArrayData = ScalarFnData;
70 type OperationsVTable = Self;
71 type ValidityVTable = Self;
72
73 fn id(&self) -> ArrayId {
74 self.id
75 }
76
77 fn validate(
78 &self,
79 data: &ScalarFnData,
80 dtype: &DType,
81 len: usize,
82 slots: &[Option<ArrayRef>],
83 ) -> VortexResult<()> {
84 vortex_ensure!(
85 data.scalar_fn.id() == self.id,
86 "ScalarFnArray data scalar_fn does not match vtable"
87 );
88 vortex_ensure!(
89 slots.iter().flatten().all(|c| c.len() == len),
90 "All child arrays must have the same length as the scalar function array"
91 );
92
93 let child_dtypes = slots
94 .iter()
95 .flatten()
96 .map(|c| c.dtype().clone())
97 .collect_vec();
98 vortex_ensure!(
99 data.scalar_fn.return_dtype(&child_dtypes)? == *dtype,
100 "ScalarFnArray dtype does not match scalar function return dtype"
101 );
102 Ok(())
103 }
104
105 fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
106 0
107 }
108
109 fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
110 vortex_panic!("ScalarFnArray buffer index {idx} out of bounds")
111 }
112
113 fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
114 None
115 }
116
117 fn serialize(
118 _array: ArrayView<'_, Self>,
119 _session: &VortexSession,
120 ) -> VortexResult<Option<Vec<u8>>> {
121 Ok(None)
123 }
124
125 fn deserialize(
126 &self,
127 _dtype: &DType,
128 _len: usize,
129 _metadata: &[u8],
130 _buffers: &[BufferHandle],
131 _children: &dyn ArrayChildren,
132 _session: &VortexSession,
133 ) -> VortexResult<ArrayParts<Self>> {
134 vortex_bail!("Deserialization of ScalarFnVTable metadata is not supported");
135 }
136
137 fn slot_name(array: ArrayView<'_, Self>, idx: usize) -> String {
138 array
139 .scalar_fn()
140 .signature()
141 .child_name(idx)
142 .as_ref()
143 .to_string()
144 }
145
146 fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
147 ctx.log(format_args!("scalar_fn({}): executing", array.scalar_fn()));
148 let args = VecExecutionArgs::new(array.children(), array.len());
149 array
150 .scalar_fn()
151 .execute(&args, ctx)
152 .map(ExecutionResult::done)
153 }
154
155 fn reduce(array: ArrayView<'_, Self>) -> VortexResult<Option<ArrayRef>> {
156 RULES.evaluate(array)
157 }
158
159 fn reduce_parent(
160 array: ArrayView<'_, Self>,
161 parent: &ArrayRef,
162 child_idx: usize,
163 ) -> VortexResult<Option<ArrayRef>> {
164 PARENT_RULES.evaluate(array, parent, child_idx)
165 }
166}
167
168pub trait ScalarFnFactoryExt: scalar_fn::ScalarFnVTable {
170 fn try_new_array(
171 &self,
172 len: usize,
173 options: Self::Options,
174 children: impl Into<Vec<ArrayRef>>,
175 ) -> VortexResult<ArrayRef> {
176 let scalar_fn = scalar_fn::ScalarFn::new(self.clone(), options).erased();
177
178 let children = children.into();
179 vortex_ensure!(
180 children.iter().all(|c| c.len() == len),
181 "All child arrays must have the same length as the scalar function array"
182 );
183
184 let child_dtypes = children.iter().map(|c| c.dtype().clone()).collect_vec();
185 let dtype = scalar_fn.return_dtype(&child_dtypes)?;
186
187 let data = ScalarFnData {
188 scalar_fn: scalar_fn.clone(),
189 };
190 let vtable = ScalarFnVTable { id: scalar_fn.id() };
191 Ok(unsafe {
192 Array::from_parts_unchecked(
193 ArrayParts::new(vtable, dtype, len, data)
194 .with_slots(children.into_iter().map(Some).collect()),
195 )
196 }
197 .into_array())
198 }
199}
// Blanket impl: every scalar function vtable gets the `ScalarFnFactoryExt` methods.
impl<V: scalar_fn::ScalarFnVTable> ScalarFnFactoryExt for V {}
201
/// Matcher that accepts any array backed by [`ScalarFnVTable`], whichever
/// scalar function it wraps.
#[derive(Debug)]
pub struct AnyScalarFn;
impl Matcher for AnyScalarFn {
    /// A successful match is a view of the array typed with [`ScalarFnVTable`].
    type Match<'a> = ArrayView<'a, ScalarFnVTable>;

    /// Reports whether `array` is a [`ScalarFnVTable`] array.
    fn matches(array: &ArrayRef) -> bool {
        array.is::<ScalarFnVTable>()
    }

    /// Returns `array` viewed as a [`ScalarFnVTable`] array, or `None` when it
    /// is not one.
    fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
        array.as_opt::<ScalarFnVTable>()
    }
}
216
/// Matcher that accepts only scalar function arrays whose wrapped function is
/// of the concrete type `F`.
#[derive(Debug, Default)]
pub struct ExactScalarFn<F: scalar_fn::ScalarFnVTable>(PhantomData<F>);
220
221impl<F: scalar_fn::ScalarFnVTable> Matcher for ExactScalarFn<F> {
222 type Match<'a> = ScalarFnArrayView<'a, F>;
223
224 fn matches(array: &ArrayRef) -> bool {
225 if let Some(scalar_fn_array) = array.as_opt::<ScalarFnVTable>() {
226 scalar_fn_array.data().scalar_fn().is::<F>()
227 } else {
228 false
229 }
230 }
231
232 fn try_match(array: &ArrayRef) -> Option<Self::Match<'_>> {
233 let scalar_fn_array = array.as_opt::<ScalarFnVTable>()?;
234 let scalar_fn_data = scalar_fn_array.data();
235 let scalar_fn = scalar_fn_data.scalar_fn().downcast_ref::<F>()?;
236 Some(ScalarFnArrayView {
237 array,
238 vtable: scalar_fn.vtable(),
239 options: scalar_fn.options(),
240 })
241 }
242}
243
/// Borrowed view of a scalar function array whose function is known to be of
/// type `F`.
pub struct ScalarFnArrayView<'a, F: scalar_fn::ScalarFnVTable> {
    /// The matched array itself.
    array: &'a ArrayRef,
    /// The concrete vtable of the wrapped scalar function.
    pub vtable: &'a F,
    /// The options the wrapped scalar function was bound with.
    pub options: &'a F::Options,
}
249
250impl<F: scalar_fn::ScalarFnVTable> Deref for ScalarFnArrayView<'_, F> {
251 type Target = ArrayRef;
252
253 fn deref(&self) -> &Self::Target {
254 self.array
255 }
256}
257
/// Scalar function vtable for a function that takes no children and evaluates
/// to the array carried in its options.
#[derive(Clone)]
struct ArrayExpr;
261
/// Wrapper that gives any `T` the `PartialEq`, `Eq` and `Hash` impls without
/// looking at the wrapped value: two wrappers never compare equal, and hashing
/// feeds nothing into the hasher.
///
/// NOTE(review): `eq` returns `false` even for a value compared with itself,
/// which is not the reflexivity `Eq` documents. The type's name suggests this
/// is deliberate — confirm before relying on it in hash-based collections.
#[derive(Clone, Debug)]
struct FakeEq<T>(T);

impl<T> PartialEq for FakeEq<T> {
    fn eq(&self, _: &Self) -> bool {
        false
    }
}

impl<T> Eq for FakeEq<T> {}

impl<T> Hash for FakeEq<T> {
    // Intentionally writes nothing to the hasher.
    fn hash<S>(&self, _: &mut S)
    where
        S: Hasher,
    {
    }
}
276
277impl Display for FakeEq<ArrayRef> {
278 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
279 write!(f, "{}", self.0.encoding_id())
280 }
281}
282
283impl scalar_fn::ScalarFnVTable for ArrayExpr {
284 type Options = FakeEq<ArrayRef>;
285
286 fn id(&self) -> ScalarFnId {
287 ScalarFnId::new("vortex.array")
288 }
289
290 fn arity(&self, _options: &Self::Options) -> Arity {
291 Arity::Exact(0)
292 }
293
294 fn child_name(&self, _options: &Self::Options, _child_idx: usize) -> ChildName {
295 todo!()
296 }
297
298 fn fmt_sql(
299 &self,
300 options: &Self::Options,
301 _expr: &Expression,
302 f: &mut Formatter<'_>,
303 ) -> std::fmt::Result {
304 write!(f, "{}", options.0.encoding_id())
305 }
306
307 fn return_dtype(&self, options: &Self::Options, _arg_dtypes: &[DType]) -> VortexResult<DType> {
308 Ok(options.0.dtype().clone())
309 }
310
311 fn execute(
312 &self,
313 options: &Self::Options,
314 _args: &dyn ExecutionArgs,
315 ctx: &mut ExecutionCtx,
316 ) -> VortexResult<ArrayRef> {
317 crate::Executable::execute(options.0.clone(), ctx)
318 }
319
320 fn validity(
321 &self,
322 options: &Self::Options,
323 _expression: &Expression,
324 ) -> VortexResult<Option<Expression>> {
325 let validity_array = options.0.validity()?.to_array(options.0.len());
326 Ok(Some(ArrayExpr.new_expr(FakeEq(validity_array), [])))
327 }
328}