use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hash;
use std::hash::Hasher;
use std::sync::Arc;

use arcref::ArcRef;
use vortex_dtype::DType;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_session::VortexSession;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::expr::ExprId;
use crate::expr::StatsCatalog;
use crate::expr::expression::Expression;
use crate::expr::scalar_fn::ScalarFn;
use crate::expr::stats::Stat;

28pub trait VTable: 'static + Sized + Send + Sync {
40 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43 fn id(&self) -> ExprId;
45
46 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 _ = options;
52 Ok(None)
53 }
54
55 fn deserialize(
57 &self,
58 _metadata: &[u8],
59 _session: &VortexSession,
60 ) -> VortexResult<Self::Options> {
61 vortex_bail!("Expression {} is not deserializable", self.id());
62 }
63
64 fn arity(&self, options: &Self::Options) -> Arity;
66
67 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
69
70 fn fmt_sql(
75 &self,
76 options: &Self::Options,
77 expr: &Expression,
78 f: &mut Formatter<'_>,
79 ) -> fmt::Result;
80
81 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType>;
83
84 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef>;
96
97 fn reduce(
104 &self,
105 options: &Self::Options,
106 node: &dyn ReduceNode,
107 ctx: &dyn ReduceCtx,
108 ) -> VortexResult<Option<ReduceNodeRef>> {
109 _ = options;
110 _ = node;
111 _ = ctx;
112 Ok(None)
113 }
114
115 fn simplify(
117 &self,
118 options: &Self::Options,
119 expr: &Expression,
120 ctx: &dyn SimplifyCtx,
121 ) -> VortexResult<Option<Expression>> {
122 _ = options;
123 _ = expr;
124 _ = ctx;
125 Ok(None)
126 }
127
128 fn simplify_untyped(
130 &self,
131 options: &Self::Options,
132 expr: &Expression,
133 ) -> VortexResult<Option<Expression>> {
134 _ = options;
135 _ = expr;
136 Ok(None)
137 }
138
139 fn stat_falsification(
141 &self,
142 options: &Self::Options,
143 expr: &Expression,
144 catalog: &dyn StatsCatalog,
145 ) -> Option<Expression> {
146 _ = options;
147 _ = expr;
148 _ = catalog;
149 None
150 }
151
152 fn stat_expression(
154 &self,
155 options: &Self::Options,
156 expr: &Expression,
157 stat: Stat,
158 catalog: &dyn StatsCatalog,
159 ) -> Option<Expression> {
160 _ = options;
161 _ = expr;
162 _ = expr;
163 _ = stat;
164 _ = catalog;
165 None
166 }
167
168 fn validity(
175 &self,
176 options: &Self::Options,
177 expression: &Expression,
178 ) -> VortexResult<Option<Expression>> {
179 _ = (options, expression);
180 Ok(None)
181 }
182
183 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
199 _ = options;
200 true
201 }
202
203 fn is_fallible(&self, options: &Self::Options) -> bool {
211 _ = options;
212 true
213 }
214}
215
216pub trait ReduceCtx {
218 fn new_node(
220 &self,
221 scalar_fn: ScalarFn,
222 children: &[ReduceNodeRef],
223 ) -> VortexResult<ReduceNodeRef>;
224}
225
226pub type ReduceNodeRef = Arc<dyn ReduceNode>;
227
228pub trait ReduceNode {
230 fn as_any(&self) -> &dyn Any;
232
233 fn node_dtype(&self) -> VortexResult<DType>;
235
236 fn scalar_fn(&self) -> Option<&ScalarFn>;
238
239 fn child(&self, idx: usize) -> ReduceNodeRef;
241
242 fn child_count(&self) -> usize;
244}
245
/// Number of children an expression accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many children.
    Exact(usize),
    /// Between `min` and `max` children, both bounds inclusive; `max == None`
    /// leaves the upper end unbounded.
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `Exact(n)` as `n`, a bounded range as `min..max` (collapsed to a
    /// single number when both bounds are equal) and an open range as `min+`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(n) => write!(f, "{}", n),
            Arity::Variadic {
                min,
                max: Some(max),
            } if min == max => write!(f, "{}", min),
            Arity::Variadic {
                min,
                max: Some(max),
            } => write!(f, "{}..{}", min, max),
            Arity::Variadic { min, max: None } => write!(f, "{}+", min),
        }
    }
}

impl Arity {
    /// Returns `true` when `arg_count` children satisfy this arity.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => expected == arg_count,
            Arity::Variadic { min, max } => {
                // A missing upper bound accepts any count at or above `min`.
                let within_upper = match max {
                    Some(upper) => arg_count <= upper,
                    None => true,
                };
                arg_count >= min && within_upper
            }
        }
    }
}

286pub trait SimplifyCtx {
290 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
292}
293
294pub struct ExecutionArgs<'a> {
296 pub inputs: Vec<ArrayRef>,
298 pub row_count: usize,
300 pub ctx: &'a mut ExecutionCtx,
302}
303
/// Options type for expressions that carry no configuration.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    /// Writes nothing: empty options have no textual representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}

312pub trait VTableExt: VTable {
314 fn bind(&'static self, options: Self::Options) -> ScalarFn {
316 ScalarFn::new_static(self, options)
317 }
318
319 fn new_expr(
321 &'static self,
322 options: Self::Options,
323 children: impl IntoIterator<Item = Expression>,
324 ) -> Expression {
325 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
326 }
327
328 fn try_new_expr(
330 &'static self,
331 options: Self::Options,
332 children: impl IntoIterator<Item = Expression>,
333 ) -> VortexResult<Expression> {
334 Expression::try_new(self.bind(options), children)
335 }
336}
337impl<V: VTable> VTableExt for V {}
338
339pub type ChildName = ArcRef<str>;
341
/// Unit marker type.
// NOTE(review): presumably signals that an operation is not supported; no use
// is visible in this part of the file — confirm against callers.
pub struct NotSupported;

345pub trait DynExprVTable: 'static + Send + Sync + private::Sealed {
350 fn as_any(&self) -> &dyn Any;
351
352 fn id(&self) -> ExprId;
353 fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result;
354
355 fn options_serialize(&self, options: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
356 fn options_deserialize(
357 &self,
358 metadata: &[u8],
359 session: &VortexSession,
360 ) -> VortexResult<Box<dyn Any + Send + Sync>>;
361 fn options_clone(&self, options: &dyn Any) -> Box<dyn Any + Send + Sync>;
362 fn options_eq(&self, a: &dyn Any, b: &dyn Any) -> bool;
363 fn options_hash(&self, options: &dyn Any, hasher: &mut dyn Hasher);
364 fn options_display(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result;
365 fn options_debug(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result;
366
367 fn return_dtype(&self, options: &dyn Any, arg_types: &[DType]) -> VortexResult<DType>;
368 fn simplify(
369 &self,
370 expression: &Expression,
371 ctx: &dyn SimplifyCtx,
372 ) -> VortexResult<Option<Expression>>;
373 fn simplify_untyped(&self, expression: &Expression) -> VortexResult<Option<Expression>>;
374 fn validity(&self, expression: &Expression) -> VortexResult<Option<Expression>>;
375 fn execute(&self, options: &dyn Any, args: ExecutionArgs) -> VortexResult<ArrayRef>;
376 fn reduce(
377 &self,
378 options: &dyn Any,
379 node: &dyn ReduceNode,
380 ctx: &dyn ReduceCtx,
381 ) -> VortexResult<Option<ReduceNodeRef>>;
382
383 fn arity(&self, options: &dyn Any) -> Arity;
384 fn child_name(&self, options: &dyn Any, child_idx: usize) -> ChildName;
385 fn stat_falsification(
386 &self,
387 expression: &Expression,
388 catalog: &dyn StatsCatalog,
389 ) -> Option<Expression>;
390 fn stat_expression(
391 &self,
392 expression: &Expression,
393 stat: Stat,
394 catalog: &dyn StatsCatalog,
395 ) -> Option<Expression>;
396 fn is_null_sensitive(&self, options: &dyn Any) -> bool;
397 fn is_fallible(&self, options: &dyn Any) -> bool;
398}
399
/// Wrapper that implements [`DynExprVTable`] for any [`VTable`].
///
/// `#[repr(transparent)]` is load-bearing: `ExprVTable::new_static` casts a
/// `&'static V` to a `&'static VTableAdapter<V>`, which is only valid while the
/// two types share the same layout.
#[repr(transparent)]
pub struct VTableAdapter<V>(V);

403impl<V: VTable> DynExprVTable for VTableAdapter<V> {
404 #[inline(always)]
405 fn as_any(&self) -> &dyn Any {
406 &self.0
407 }
408
409 #[inline(always)]
410 fn id(&self) -> ExprId {
411 V::id(&self.0)
412 }
413
414 fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result {
415 V::fmt_sql(
416 &self.0,
417 downcast::<V>(expression.options().as_any()),
418 expression,
419 f,
420 )
421 }
422
423 fn options_serialize(&self, options: &dyn Any) -> VortexResult<Option<Vec<u8>>> {
424 V::serialize(&self.0, downcast::<V>(options))
425 }
426
427 fn options_deserialize(
428 &self,
429 bytes: &[u8],
430 session: &VortexSession,
431 ) -> VortexResult<Box<dyn Any + Send + Sync>> {
432 Ok(Box::new(V::deserialize(&self.0, bytes, session)?))
433 }
434
435 fn options_clone(&self, options: &dyn Any) -> Box<dyn Any + Send + Sync> {
436 let options = options
437 .downcast_ref::<V::Options>()
438 .vortex_expect("Failed to downcast expression options to expected type");
439 Box::new(options.clone())
440 }
441
442 fn options_eq(&self, a: &dyn Any, b: &dyn Any) -> bool {
443 downcast::<V>(a) == downcast::<V>(b)
444 }
445
446 fn options_hash(&self, options: &dyn Any, mut hasher: &mut dyn Hasher) {
447 downcast::<V>(options).hash(&mut hasher);
448 }
449
450 fn options_display(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
451 Display::fmt(downcast::<V>(options), fmt)
452 }
453
454 fn options_debug(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
455 Debug::fmt(downcast::<V>(options), fmt)
456 }
457
458 fn return_dtype(&self, options: &dyn Any, arg_dtypes: &[DType]) -> VortexResult<DType> {
459 V::return_dtype(&self.0, downcast::<V>(options), arg_dtypes)
460 }
461
462 fn simplify(
463 &self,
464 expression: &Expression,
465 ctx: &dyn SimplifyCtx,
466 ) -> VortexResult<Option<Expression>> {
467 V::simplify(
468 &self.0,
469 downcast::<V>(expression.options().as_any()),
470 expression,
471 ctx,
472 )
473 }
474
475 fn simplify_untyped(&self, expression: &Expression) -> VortexResult<Option<Expression>> {
476 V::simplify_untyped(
477 &self.0,
478 downcast::<V>(expression.options().as_any()),
479 expression,
480 )
481 }
482
483 fn validity(&self, expression: &Expression) -> VortexResult<Option<Expression>> {
484 V::validity(
485 &self.0,
486 downcast::<V>(expression.options().as_any()),
487 expression,
488 )
489 }
490
491 fn execute(&self, options: &dyn Any, args: ExecutionArgs) -> VortexResult<ArrayRef> {
492 let options = downcast::<V>(options);
493
494 let expected_row_count = args.row_count;
495 #[cfg(debug_assertions)]
496 let expected_dtype = {
497 let args_dtypes: Vec<DType> = args
498 .inputs
499 .iter()
500 .map(|array| array.dtype().clone())
501 .collect();
502 V::return_dtype(&self.0, options, &args_dtypes)
503 }?;
504
505 let result = V::execute(&self.0, options, args)?;
506
507 assert_eq!(
508 result.len(),
509 expected_row_count,
510 "Expression execution {} returned vector of length {}, but expected {}",
511 self.0.id(),
512 result.len(),
513 expected_row_count,
514 );
515
516 #[cfg(debug_assertions)]
518 {
519 vortex_error::vortex_ensure!(
520 result.dtype() == &expected_dtype,
521 "Expression execution {} returned vector of invalid dtype. Expected {}, got {}",
522 self.0.id(),
523 expected_dtype,
524 result.dtype(),
525 );
526 }
527
528 Ok(result)
529 }
530
531 fn reduce(
532 &self,
533 options: &dyn Any,
534 node: &dyn ReduceNode,
535 ctx: &dyn ReduceCtx,
536 ) -> VortexResult<Option<ReduceNodeRef>> {
537 V::reduce(&self.0, downcast::<V>(options), node, ctx)
538 }
539
540 fn arity(&self, options: &dyn Any) -> Arity {
541 V::arity(&self.0, downcast::<V>(options))
542 }
543
544 fn child_name(&self, options: &dyn Any, child_idx: usize) -> ChildName {
545 V::child_name(&self.0, downcast::<V>(options), child_idx)
546 }
547
548 fn stat_falsification(
549 &self,
550 expression: &Expression,
551 catalog: &dyn StatsCatalog,
552 ) -> Option<Expression> {
553 V::stat_falsification(
554 &self.0,
555 downcast::<V>(expression.options().as_any()),
556 expression,
557 catalog,
558 )
559 }
560
561 fn stat_expression(
562 &self,
563 expression: &Expression,
564 stat: Stat,
565 catalog: &dyn StatsCatalog,
566 ) -> Option<Expression> {
567 V::stat_expression(
568 &self.0,
569 downcast::<V>(expression.options().as_any()),
570 expression,
571 stat,
572 catalog,
573 )
574 }
575
576 fn is_null_sensitive(&self, options: &dyn Any) -> bool {
577 V::is_null_sensitive(&self.0, downcast::<V>(options))
578 }
579
580 fn is_fallible(&self, options: &dyn Any) -> bool {
581 V::is_fallible(&self.0, downcast::<V>(options))
582 }
583}
584
585fn downcast<V: VTable>(options: &dyn Any) -> &V::Options {
586 options
587 .downcast_ref::<V::Options>()
588 .vortex_expect("Invalid options type for expression")
589}
590
591mod private {
592 use crate::expr::VTable;
593 use crate::expr::VTableAdapter;
594
595 pub trait Sealed {}
596 impl<V: VTable> Sealed for VTableAdapter<V> {}
597}
598
599#[derive(Clone)]
601pub struct ExprVTable(ArcRef<dyn DynExprVTable>);
602
603impl ExprVTable {
604 pub(crate) fn as_dyn(&self) -> &dyn DynExprVTable {
607 self.0.as_ref()
608 }
609
610 pub fn as_any(&self) -> &dyn Any {
612 self.0.as_any()
613 }
614
615 pub fn new<V: VTable>(vtable: V) -> Self {
617 Self(ArcRef::new_arc(Arc::new(VTableAdapter(vtable))))
618 }
619
620 pub const fn new_static<V: VTable>(vtable: &'static V) -> Self {
622 let adapted: &'static VTableAdapter<V> =
624 unsafe { &*(vtable as *const V as *const VTableAdapter<V>) };
625 Self(ArcRef::new_ref(adapted as &'static dyn DynExprVTable))
626 }
627
628 pub fn id(&self) -> ExprId {
630 self.0.id()
631 }
632
633 pub fn is<V: VTable>(&self) -> bool {
635 self.0.as_any().is::<V>()
636 }
637
638 pub fn deserialize(&self, metadata: &[u8], session: &VortexSession) -> VortexResult<ScalarFn> {
640 Ok(unsafe {
641 ScalarFn::new_unchecked(
642 self.clone(),
643 self.as_dyn().options_deserialize(metadata, session)?,
644 )
645 })
646 }
647}
648
649impl PartialEq for ExprVTable {
650 fn eq(&self, other: &Self) -> bool {
651 self.0.id() == other.0.id()
652 }
653}
654impl Eq for ExprVTable {}
655
656impl Hash for ExprVTable {
657 fn hash<H: Hasher>(&self, state: &mut H) {
658 self.0.id().hash(state);
659 }
660}
661
662impl Display for ExprVTable {
663 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
664 write!(f, "{}", self.as_dyn().id())
665 }
666}
667
668impl Debug for ExprVTable {
669 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
670 write!(f, "{}", self.as_dyn().id())
671 }
672}
673
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::LEGACY_SESSION;
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::fill_null::fill_null;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::is_null::is_null;
    use crate::expr::exprs::list_contains::list_contains;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;
    use crate::expr::proto::ExprSerializeProtoExt;

    /// Serializes each expression to protobuf, reads it back, and checks that
    /// the round-tripped expression equals the original.
    #[rstest]
    // Root and field selection
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    // Literals
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column access
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Comparisons
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Boolean logic
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    // Arithmetic
    #[case(checked_add(col("a"), lit(5)))]
    // Null handling
    #[case(is_null(col("nullable_col")))]
    #[case(fill_null(col("a"), lit(0)))]
    // Cast
    #[case(cast(
        col("a"),
        DType::Primitive(vortex_dtype::PType::I64, vortex_dtype::Nullability::NonNullable)
    ))]
    // Between
    #[case(between(
        col("a"),
        lit(10),
        lit(20),
        crate::expr::BetweenOptions {
            lower_strict: crate::expr::StrictComparison::NonStrict,
            upper_strict: crate::expr::StrictComparison::NonStrict,
        }
    ))]
    // Lists and structs
    #[case(list_contains(col("list_col"), lit("item")))]
    #[case(pack(
        [("field1", col("a")), ("field2", col("b"))],
        vortex_dtype::Nullability::NonNullable
    ))]
    #[case(merge([col("struct1"), col("struct2")]))]
    // Nested expressions
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn text_expr_serde_round_trip(#[case] expr: Expression) -> VortexResult<()> {
        let serialized_pb = expr.serialize_proto()?;
        let deserialized_expr = Expression::from_proto(&serialized_pb, &LEGACY_SESSION)?;

        assert_eq!(&expr, &deserialized_expr);

        Ok(())
    }
}