1use std::any::Any;
5use std::fmt;
6use std::fmt::{Debug, Display, Formatter};
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9
10use arcref::ArcRef;
11use vortex_dtype::{DType, FieldPath};
12use vortex_error::{VortexExpect, VortexResult, vortex_err};
13
14use crate::ArrayRef;
15use crate::expr::expression::Expression;
16use crate::expr::{ExprId, ExpressionView, StatsCatalog};
17
18pub trait VTable: 'static + Sized + Send + Sync {
31 type Instance: 'static + Send + Sync + Debug + PartialEq + Eq + Hash;
33
34 fn id(&self) -> ExprId;
36
37 fn serialize(&self, _instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
42 Ok(None)
43 }
44
45 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
49 Ok(None)
50 }
51
52 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()>;
54
55 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName;
57
58 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> fmt::Result;
63
64 #[allow(clippy::use_debug)]
68 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> fmt::Result {
69 write!(f, "{:?}", instance)
70 }
71
72 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType>;
74
75 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef>;
77
78 fn stat_falsification(
80 &self,
81 _expr: &ExpressionView<Self>,
82 _catalog: &mut dyn StatsCatalog,
83 ) -> Option<Expression> {
84 None
85 }
86
87 fn stat_max(
89 &self,
90 _expr: &ExpressionView<Self>,
91 _catalog: &mut dyn StatsCatalog,
92 ) -> Option<Expression> {
93 None
94 }
95
96 fn stat_min(
98 &self,
99 _expr: &ExpressionView<Self>,
100 _catalog: &mut dyn StatsCatalog,
101 ) -> Option<Expression> {
102 None
103 }
104
105 fn stat_nan_count(
107 &self,
108 _expr: &ExpressionView<Self>,
109 _catalog: &mut dyn StatsCatalog,
110 ) -> Option<Expression> {
111 None
112 }
113
114 fn stat_field_path(&self, _expr: &ExpressionView<Self>) -> Option<FieldPath> {
116 None
117 }
118}
119
120pub trait VTableExt: VTable {
122 fn new_expr(
123 &'static self,
124 instance: Self::Instance,
125 children: impl Into<Arc<[Expression]>>,
126 ) -> Expression {
127 Self::try_new_expr(self, instance, children)
128 .vortex_expect("Failed to create expression instance")
129 }
130
131 fn try_new_expr(
132 &'static self,
133 instance: Self::Instance,
134 children: impl Into<Arc<[Expression]>>,
135 ) -> VortexResult<Expression> {
136 Expression::try_new(
137 ExprVTable::from_static(self),
138 Arc::new(instance),
139 children.into(),
140 )
141 }
142}
143impl<V: VTable> VTableExt for V {}
144
/// Name of a child expression, as returned by [`VTable::child_name`] and
/// `DynExprVTable::child_name`.
pub type ChildName = ArcRef<str>;

/// Unit marker type.
// NOTE(review): not referenced in the visible part of this file; presumably it signals an
// unsupported operation somewhere else. Confirm with its users.
pub struct NotSupported;
150
/// Object-safe, type-erased counterpart of [`VTable`].
///
/// Instance data crosses this boundary as `&dyn Any` (or `Arc<dyn Any + Send + Sync>` when
/// deserialized). [`VTableAdapter`] downcasts it back to `V::Instance`, so any `instance`
/// argument must belong to this vtable. The `private::Sealed` supertrait is only
/// implemented for `VTableAdapter`, which keeps other types from implementing this trait.
pub trait DynExprVTable: 'static + Send + Sync + private::Sealed {
    /// Returns `self` as `&dyn Any`, used by [`ExprVTable::is`] / [`ExprVTable::as_opt`] to
    /// recover the concrete adapter type.
    fn as_any(&self) -> &dyn Any;
    /// See [`VTable::id`].
    fn id(&self) -> ExprId;
    /// Type-erased [`VTable::serialize`]; `instance` must be this vtable's instance type.
    fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
    /// Type-erased [`VTable::deserialize`]; the result is the boxed `V::Instance`.
    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>>;
    /// Type-erased [`VTable::child_name`]; `instance` must be this vtable's instance type.
    fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName;
    /// See [`VTable::validate`]; `expression` is viewed through [`ExpressionView`].
    fn validate(&self, expression: &Expression) -> VortexResult<()>;
    /// See [`VTable::fmt_sql`].
    fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result;
    /// Type-erased [`VTable::fmt_data`]; `instance` must be this vtable's instance type.
    fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result;
    /// See [`VTable::return_dtype`].
    fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType>;
    /// See [`VTable::evaluate`].
    fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef>;

    /// See [`VTable::stat_falsification`].
    fn stat_falsification(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// See [`VTable::stat_max`].
    fn stat_max(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// See [`VTable::stat_min`].
    fn stat_min(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// See [`VTable::stat_nan_count`].
    fn stat_nan_count(
        &self,
        expression: &Expression,
        catalog: &mut dyn StatsCatalog,
    ) -> Option<Expression>;
    /// See [`VTable::stat_field_path`].
    fn stat_field_path(&self, expression: &Expression) -> Option<FieldPath>;

    /// Compares `instance` (which must belong to this vtable) with `other` using
    /// `V::Instance`'s `PartialEq`.
    fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool;
    /// Feeds `instance` (which must belong to this vtable) into `state` using
    /// `V::Instance`'s `Hash`.
    fn dyn_hash(&self, instance: &dyn Any, state: &mut dyn Hasher);
}
192
/// Adapts a concrete [`VTable`] to the object-safe [`DynExprVTable`] interface.
///
/// `#[repr(transparent)]` guarantees that `VTableAdapter<V>` has exactly the layout of `V`.
/// [`ExprVTable::from_static`] relies on this to reinterpret a `&'static V` as a
/// `&'static VTableAdapter<V>` via a pointer cast.
#[repr(transparent)]
pub struct VTableAdapter<V>(V);
195
196impl<V: VTable> DynExprVTable for VTableAdapter<V> {
197 #[inline(always)]
198 fn as_any(&self) -> &dyn Any {
199 self
200 }
201
202 #[inline(always)]
203 fn id(&self) -> ExprId {
204 V::id(&self.0)
205 }
206
207 fn serialize(&self, instance: &dyn Any) -> VortexResult<Option<Vec<u8>>> {
208 let instance = instance
209 .downcast_ref::<V::Instance>()
210 .vortex_expect("Failed to downcast expression instance to expected type");
211 V::serialize(&self.0, instance)
212 }
213
214 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Arc<dyn Any + Send + Sync>>> {
215 Ok(V::deserialize(&self.0, metadata)?
216 .map(|data| Arc::new(data) as Arc<dyn Any + Send + Sync>))
217 }
218
219 fn child_name(&self, instance: &dyn Any, child_idx: usize) -> ChildName {
220 let instance = instance
221 .downcast_ref::<V::Instance>()
222 .vortex_expect("Failed to downcast expression instance to expected type");
223 V::child_name(&self.0, instance, child_idx)
224 }
225
226 fn validate(&self, expression: &Expression) -> VortexResult<()> {
227 let expr = ExpressionView::new(expression);
228 V::validate(&self.0, &expr)
229 }
230
231 fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result {
232 let expr = ExpressionView::new(expression);
233 V::fmt_sql(&self.0, &expr, f)
234 }
235
236 fn fmt_data(&self, instance: &dyn Any, f: &mut Formatter<'_>) -> fmt::Result {
237 let instance = instance
238 .downcast_ref::<V::Instance>()
239 .vortex_expect("Failed to downcast expression instance to expected type");
240 V::fmt_data(&self.0, instance, f)
241 }
242
243 fn return_dtype(&self, expression: &Expression, scope: &DType) -> VortexResult<DType> {
244 let expr = ExpressionView::new(expression);
245 V::return_dtype(&self.0, &expr, scope)
246 }
247
248 fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
249 let expr = ExpressionView::new(expression);
250 V::evaluate(&self.0, &expr, scope)
251 }
252
253 fn stat_falsification(
254 &self,
255 expression: &Expression,
256 catalog: &mut dyn StatsCatalog,
257 ) -> Option<Expression> {
258 let expr = ExpressionView::new(expression);
259 V::stat_falsification(&self.0, &expr, catalog)
260 }
261
262 fn stat_max(
263 &self,
264 expression: &Expression,
265 catalog: &mut dyn StatsCatalog,
266 ) -> Option<Expression> {
267 let expr = ExpressionView::new(expression);
268 V::stat_max(&self.0, &expr, catalog)
269 }
270
271 fn stat_min(
272 &self,
273 expression: &Expression,
274 catalog: &mut dyn StatsCatalog,
275 ) -> Option<Expression> {
276 let expr = ExpressionView::new(expression);
277 V::stat_min(&self.0, &expr, catalog)
278 }
279
280 fn stat_nan_count(
281 &self,
282 expression: &Expression,
283 catalog: &mut dyn StatsCatalog,
284 ) -> Option<Expression> {
285 let expr = ExpressionView::new(expression);
286 V::stat_nan_count(&self.0, &expr, catalog)
287 }
288
289 fn stat_field_path(&self, expression: &Expression) -> Option<FieldPath> {
290 let expr = ExpressionView::new(expression);
291 V::stat_field_path(&self.0, &expr)
292 }
293
294 fn dyn_eq(&self, instance: &dyn Any, other: &dyn Any) -> bool {
295 let this_instance = instance
296 .downcast_ref::<V::Instance>()
297 .vortex_expect("Failed to downcast expression instance to expected type");
298 let other_instance = other
299 .downcast_ref::<V::Instance>()
300 .vortex_expect("Failed to downcast expression instance to expected type");
301 this_instance == other_instance
302 }
303
304 fn dyn_hash(&self, instance: &dyn Any, mut state: &mut dyn Hasher) {
305 let this_instance = instance
306 .downcast_ref::<V::Instance>()
307 .vortex_expect("Failed to downcast expression instance to expected type");
308 this_instance.hash(&mut state);
309 }
310}
311
// Sealing support for `DynExprVTable`. `Sealed` lives in this private module, so code outside
// this module cannot name it. Its only implementor is `VTableAdapter`.
mod private {
    use crate::expr::{VTable, VTableAdapter};

    /// Supertrait of `DynExprVTable`; implemented only for `VTableAdapter`.
    pub trait Sealed {}
    impl<V: VTable> Sealed for VTableAdapter<V> {}
}
318
319#[derive(Clone)]
321pub struct ExprVTable(ArcRef<dyn DynExprVTable>);
322
323impl ExprVTable {
324 pub(crate) fn as_dyn(&self) -> &dyn DynExprVTable {
327 self.0.as_ref()
328 }
329
330 pub const fn from_static<V: VTable>(vtable: &'static V) -> Self {
332 let adapted: &'static VTableAdapter<V> =
334 unsafe { &*(vtable as *const V as *const VTableAdapter<V>) };
335 Self(ArcRef::new_ref(adapted as &'static dyn DynExprVTable))
336 }
337
338 pub fn id(&self) -> ExprId {
340 self.0.id()
341 }
342
343 pub fn is<V: VTable>(&self) -> bool {
345 self.0.as_any().is::<VTableAdapter<V>>()
346 }
347
348 pub fn as_opt<V: VTable>(&self) -> Option<&V> {
350 self.0
351 .as_any()
352 .downcast_ref::<VTableAdapter<V>>()
353 .map(|adapter| &adapter.0)
354 }
355
356 pub fn deserialize(
358 &self,
359 metadata: &[u8],
360 children: Arc<[Expression]>,
361 ) -> VortexResult<Expression> {
362 let instance_data = self.as_dyn().deserialize(metadata)?.ok_or_else(|| {
363 vortex_err!(
364 "Expression vtable {} is not deserializable",
365 self.as_dyn().id()
366 )
367 })?;
368 Expression::try_new(self.clone(), instance_data, children)
369 }
370}
371
372impl PartialEq for ExprVTable {
373 fn eq(&self, other: &Self) -> bool {
374 self.0.id() == other.0.id()
375 }
376}
377impl Eq for ExprVTable {}
378
379impl Display for ExprVTable {
380 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
381 write!(f, "{}", self.as_dyn().id())
382 }
383}
384
385impl Debug for ExprVTable {
386 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
387 write!(f, "{}", self.as_dyn().id())
388 }
389}
390
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};
    use vortex_dtype::{Nullability, PType};

    use super::*;
    use crate::compute::{BetweenOptions, StrictComparison};
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::{and, checked_add, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::is_null::is_null;
    use crate::expr::exprs::list_contains::list_contains;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::{select, select_exclude};
    use crate::expr::proto::{ExprSerializeProtoExt, deserialize_expr_proto};
    use crate::expr::session::{ExprRegistry, ExprSession};

    /// Registry of the default session, built once and shared by every case.
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        ExprSession::default().registry().clone()
    }

    /// Every expression must survive a protobuf serialize/deserialize round trip unchanged.
    #[rstest]
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    #[case(checked_add(col("a"), lit(5)))]
    #[case(is_null(col("nullable_col")))]
    #[case(cast(
        col("a"),
        DType::Primitive(PType::I64, Nullability::NonNullable)
    ))]
    #[case(between(
        col("a"),
        lit(10),
        lit(20),
        BetweenOptions {
            lower_strict: StrictComparison::NonStrict,
            upper_strict: StrictComparison::NonStrict,
        }
    ))]
    #[case(list_contains(col("list_col"), lit("item")))]
    #[case(pack([("field1", col("a")), ("field2", col("b"))], Nullability::NonNullable))]
    #[case(merge([col("struct1"), col("struct2")]))]
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn test_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: Expression,
    ) -> VortexResult<()> {
        let serialized_pb = (&expr).serialize_proto()?;
        let deserialized_expr = deserialize_expr_proto(&serialized_pb, registry)?;

        assert_eq!(&expr, &deserialized_expr);

        Ok(())
    }
}