1use std::any::Any;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use arcref::ArcRef;
14use vortex_dtype::DType;
15use vortex_error::VortexExpect;
16use vortex_error::VortexResult;
17use vortex_error::vortex_bail;
18use vortex_vector::Datum;
19use vortex_vector::VectorOps;
20
21use crate::ArrayRef;
22use crate::expr::ExprId;
23use crate::expr::StatsCatalog;
24use crate::expr::expression::Expression;
25use crate::expr::scalar_fn::ScalarFn;
26use crate::expr::stats::Stat;
27
28pub trait VTable: 'static + Sized + Send + Sync {
40 type Options: 'static + Send + Sync + Clone + Debug + Display + PartialEq + Eq + Hash;
42
43 fn id(&self) -> ExprId;
45
46 fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 _ = options;
52 Ok(None)
53 }
54
55 fn deserialize(&self, _metadata: &[u8]) -> VortexResult<Self::Options> {
57 vortex_bail!("Expression {} is not deserializable", self.id());
58 }
59
60 fn arity(&self, options: &Self::Options) -> Arity;
62
63 fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName;
65
66 fn fmt_sql(
71 &self,
72 options: &Self::Options,
73 expr: &Expression,
74 f: &mut Formatter<'_>,
75 ) -> fmt::Result;
76
77 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType>;
79
80 fn evaluate(
84 &self,
85 options: &Self::Options,
86 expr: &Expression,
87 scope: &ArrayRef,
88 ) -> VortexResult<ArrayRef> {
89 _ = options;
90 _ = expr;
91 _ = scope;
92 vortex_bail!("Expression {} does not support evaluation", self.id());
93 }
94
95 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
99 _ = options;
100 drop(args);
101 vortex_bail!("Expression {} does not support execution", self.id());
102 }
103
104 fn reduce(
111 &self,
112 options: &Self::Options,
113 node: &dyn ReduceNode,
114 ctx: &dyn ReduceCtx,
115 ) -> VortexResult<Option<ReduceNodeRef>> {
116 _ = options;
117 _ = node;
118 _ = ctx;
119 Ok(None)
120 }
121
122 fn simplify(
124 &self,
125 options: &Self::Options,
126 expr: &Expression,
127 ctx: &dyn SimplifyCtx,
128 ) -> VortexResult<Option<Expression>> {
129 _ = options;
130 _ = expr;
131 _ = ctx;
132 Ok(None)
133 }
134
135 fn simplify_untyped(
137 &self,
138 options: &Self::Options,
139 expr: &Expression,
140 ) -> VortexResult<Option<Expression>> {
141 _ = options;
142 _ = expr;
143 Ok(None)
144 }
145
146 fn stat_falsification(
148 &self,
149 options: &Self::Options,
150 expr: &Expression,
151 catalog: &dyn StatsCatalog,
152 ) -> Option<Expression> {
153 _ = options;
154 _ = expr;
155 _ = catalog;
156 None
157 }
158
159 fn stat_expression(
161 &self,
162 options: &Self::Options,
163 expr: &Expression,
164 stat: Stat,
165 catalog: &dyn StatsCatalog,
166 ) -> Option<Expression> {
167 _ = options;
168 _ = expr;
169 _ = expr;
170 _ = stat;
171 _ = catalog;
172 None
173 }
174
175 fn is_null_sensitive(&self, options: &Self::Options) -> bool {
191 _ = options;
192 true
193 }
194
195 fn is_fallible(&self, options: &Self::Options) -> bool {
203 _ = options;
204 true
205 }
206}
207
208pub trait ReduceCtx {
210 fn new_node(
212 &self,
213 scalar_fn: ScalarFn,
214 children: &[ReduceNodeRef],
215 ) -> VortexResult<ReduceNodeRef>;
216}
217
218pub type ReduceNodeRef = Arc<dyn ReduceNode>;
219
220pub trait ReduceNode {
222 fn as_any(&self) -> &dyn Any;
224
225 fn node_dtype(&self) -> VortexResult<DType>;
227
228 fn scalar_fn(&self) -> Option<&ScalarFn>;
230
231 fn child(&self, idx: usize) -> ReduceNodeRef;
233
234 fn child_count(&self) -> usize;
236
237 fn children(&self) -> Vec<ReduceNodeRef> {
239 (0..self.child_count()).map(|i| self.child(i)).collect()
240 }
241}
242
/// How many children an expression accepts.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Arity {
    /// Exactly this many children.
    Exact(usize),
    /// At least `min` children and, when `max` is `Some`, at most `max` (both inclusive).
    Variadic { min: usize, max: Option<usize> },
}

impl Display for Arity {
    /// Renders `n` for a fixed count (including a variadic range whose bounds coincide),
    /// `min..max` for a bounded range and `min+` for an unbounded one.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match *self {
            Arity::Exact(count) => write!(f, "{count}"),
            Arity::Variadic { min, max: None } => write!(f, "{min}+"),
            Arity::Variadic { min, max: Some(max) } if min == max => write!(f, "{min}"),
            Arity::Variadic { min, max: Some(max) } => write!(f, "{min}..{max}"),
        }
    }
}

impl Arity {
    /// Returns whether `arg_count` children satisfy this arity.
    pub fn matches(&self, arg_count: usize) -> bool {
        match *self {
            Arity::Exact(expected) => expected == arg_count,
            Arity::Variadic { min, max } => {
                let at_least_min = arg_count >= min;
                let at_most_max = match max {
                    Some(upper) => arg_count <= upper,
                    None => true,
                };
                at_least_min && at_most_max
            }
        }
    }
}
282
283pub trait SimplifyCtx {
287 fn return_dtype(&self, expr: &Expression) -> VortexResult<DType>;
289}
290
291pub struct ExecutionArgs {
293 pub datums: Vec<Datum>,
295 pub dtypes: Vec<DType>,
297 pub row_count: usize,
299 pub return_dtype: DType,
301}
302
/// Options type for expressions that carry no options.
///
/// Its `Display` output is the empty string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct EmptyOptions;

impl Display for EmptyOptions {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("")
    }
}
310
311pub trait VTableExt: VTable {
313 fn bind(&'static self, options: Self::Options) -> ScalarFn {
315 ScalarFn::new_static(self, options)
316 }
317
318 fn new_expr(
320 &'static self,
321 options: Self::Options,
322 children: impl Into<Arc<[Expression]>>,
323 ) -> Expression {
324 Self::try_new_expr(self, options, children).vortex_expect("Failed to create expression")
325 }
326
327 fn try_new_expr(
329 &'static self,
330 options: Self::Options,
331 children: impl Into<Arc<[Expression]>>,
332 ) -> VortexResult<Expression> {
333 Expression::try_new(self.bind(options), children.into())
334 }
335}
336impl<V: VTable> VTableExt for V {}
337
338pub type ChildName = ArcRef<str>;
340
341pub struct NotSupported;
343
344pub trait DynExprVTable: 'static + Send + Sync + private::Sealed {
349 fn as_any(&self) -> &dyn Any;
350
351 fn id(&self) -> ExprId;
352 fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result;
353
354 fn options_serialize(&self, options: &dyn Any) -> VortexResult<Option<Vec<u8>>>;
355 fn options_deserialize(&self, metadata: &[u8]) -> VortexResult<Box<dyn Any + Send + Sync>>;
356 fn options_clone(&self, options: &dyn Any) -> Box<dyn Any + Send + Sync>;
357 fn options_eq(&self, a: &dyn Any, b: &dyn Any) -> bool;
358 fn options_hash(&self, options: &dyn Any, hasher: &mut dyn Hasher);
359 fn options_display(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result;
360 fn options_debug(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result;
361
362 fn return_dtype(&self, options: &dyn Any, arg_types: &[DType]) -> VortexResult<DType>;
363 fn simplify(
364 &self,
365 expression: &Expression,
366 ctx: &dyn SimplifyCtx,
367 ) -> VortexResult<Option<Expression>>;
368 fn simplify_untyped(&self, expression: &Expression) -> VortexResult<Option<Expression>>;
369 fn execute(&self, options: &dyn Any, args: ExecutionArgs) -> VortexResult<Datum>;
370 fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef>;
371 fn reduce(
372 &self,
373 options: &dyn Any,
374 node: &dyn ReduceNode,
375 ctx: &dyn ReduceCtx,
376 ) -> VortexResult<Option<ReduceNodeRef>>;
377
378 fn arity(&self, options: &dyn Any) -> Arity;
379 fn child_name(&self, options: &dyn Any, child_idx: usize) -> ChildName;
380 fn stat_falsification(
381 &self,
382 expression: &Expression,
383 catalog: &dyn StatsCatalog,
384 ) -> Option<Expression>;
385 fn stat_expression(
386 &self,
387 expression: &Expression,
388 stat: Stat,
389 catalog: &dyn StatsCatalog,
390 ) -> Option<Expression>;
391 fn is_null_sensitive(&self, options: &dyn Any) -> bool;
392 fn is_fallible(&self, options: &dyn Any) -> bool;
393}
394
/// Wraps a [`VTable`] so it can be used as a [`DynExprVTable`].
///
/// `#[repr(transparent)]` gives this wrapper exactly the layout of `V`.
/// [`ExprVTable::new_static`] relies on that to reinterpret a `&'static V` as a
/// `&'static VTableAdapter<V>`, so the attribute must not be removed.
#[repr(transparent)]
pub struct VTableAdapter<V>(V);
397
398impl<V: VTable> DynExprVTable for VTableAdapter<V> {
399 #[inline(always)]
400 fn as_any(&self) -> &dyn Any {
401 &self.0
402 }
403
404 #[inline(always)]
405 fn id(&self) -> ExprId {
406 V::id(&self.0)
407 }
408
409 fn fmt_sql(&self, expression: &Expression, f: &mut Formatter<'_>) -> fmt::Result {
410 V::fmt_sql(
411 &self.0,
412 downcast::<V>(expression.options().as_any()),
413 expression,
414 f,
415 )
416 }
417
418 fn options_serialize(&self, options: &dyn Any) -> VortexResult<Option<Vec<u8>>> {
419 V::serialize(&self.0, downcast::<V>(options))
420 }
421
422 fn options_deserialize(&self, bytes: &[u8]) -> VortexResult<Box<dyn Any + Send + Sync>> {
423 Ok(Box::new(V::deserialize(&self.0, bytes)?))
424 }
425
426 fn options_clone(&self, options: &dyn Any) -> Box<dyn Any + Send + Sync> {
427 let options = options
428 .downcast_ref::<V::Options>()
429 .vortex_expect("Failed to downcast expression options to expected type");
430 Box::new(options.clone())
431 }
432
433 fn options_eq(&self, a: &dyn Any, b: &dyn Any) -> bool {
434 downcast::<V>(a) == downcast::<V>(b)
435 }
436
437 fn options_hash(&self, options: &dyn Any, mut hasher: &mut dyn Hasher) {
438 downcast::<V>(options).hash(&mut hasher);
439 }
440
441 fn options_display(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
442 Display::fmt(downcast::<V>(options), fmt)
443 }
444
445 fn options_debug(&self, options: &dyn Any, fmt: &mut Formatter<'_>) -> fmt::Result {
446 Debug::fmt(downcast::<V>(options), fmt)
447 }
448
449 fn return_dtype(&self, options: &dyn Any, arg_dtypes: &[DType]) -> VortexResult<DType> {
450 V::return_dtype(&self.0, downcast::<V>(options), arg_dtypes)
451 }
452
453 fn simplify(
454 &self,
455 expression: &Expression,
456 ctx: &dyn SimplifyCtx,
457 ) -> VortexResult<Option<Expression>> {
458 V::simplify(
459 &self.0,
460 downcast::<V>(expression.options().as_any()),
461 expression,
462 ctx,
463 )
464 }
465
466 fn simplify_untyped(&self, expression: &Expression) -> VortexResult<Option<Expression>> {
467 V::simplify_untyped(
468 &self.0,
469 downcast::<V>(expression.options().as_any()),
470 expression,
471 )
472 }
473
474 fn execute(&self, options: &dyn Any, args: ExecutionArgs) -> VortexResult<Datum> {
475 let options = downcast::<V>(options);
476
477 let expected_row_count = args.row_count;
478 #[cfg(debug_assertions)]
479 let expected_dtype = args.return_dtype.clone();
480
481 let result = V::execute(&self.0, options, args)?;
482
483 if let Datum::Vector(v) = &result {
484 assert_eq!(
485 v.len(),
486 expected_row_count,
487 "Expression execution {} returned vector of length {}, but expected {}",
488 self.0.id(),
489 v.len(),
490 expected_row_count,
491 );
492 }
493
494 #[cfg(debug_assertions)]
496 {
497 use vortex_vector::datum_matches_dtype;
498
499 if !datum_matches_dtype(&result, &expected_dtype) {
500 vortex_bail!(
501 "Expression execution returned datum of invalid dtype. Expected {}, got {:?}",
502 expected_dtype,
503 result
504 );
505 }
506 }
507
508 Ok(result)
509 }
510
511 fn evaluate(&self, expression: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
512 V::evaluate(
513 &self.0,
514 downcast::<V>(expression.options().as_any()),
515 expression,
516 scope,
517 )
518 }
519
520 fn reduce(
521 &self,
522 options: &dyn Any,
523 node: &dyn ReduceNode,
524 ctx: &dyn ReduceCtx,
525 ) -> VortexResult<Option<ReduceNodeRef>> {
526 V::reduce(&self.0, downcast::<V>(options), node, ctx)
527 }
528
529 fn arity(&self, options: &dyn Any) -> Arity {
530 V::arity(&self.0, downcast::<V>(options))
531 }
532
533 fn child_name(&self, options: &dyn Any, child_idx: usize) -> ChildName {
534 V::child_name(&self.0, downcast::<V>(options), child_idx)
535 }
536
537 fn stat_falsification(
538 &self,
539 expression: &Expression,
540 catalog: &dyn StatsCatalog,
541 ) -> Option<Expression> {
542 V::stat_falsification(
543 &self.0,
544 downcast::<V>(expression.options().as_any()),
545 expression,
546 catalog,
547 )
548 }
549
550 fn stat_expression(
551 &self,
552 expression: &Expression,
553 stat: Stat,
554 catalog: &dyn StatsCatalog,
555 ) -> Option<Expression> {
556 V::stat_expression(
557 &self.0,
558 downcast::<V>(expression.options().as_any()),
559 expression,
560 stat,
561 catalog,
562 )
563 }
564
565 fn is_null_sensitive(&self, options: &dyn Any) -> bool {
566 V::is_null_sensitive(&self.0, downcast::<V>(options))
567 }
568
569 fn is_fallible(&self, options: &dyn Any) -> bool {
570 V::is_fallible(&self.0, downcast::<V>(options))
571 }
572}
573
574fn downcast<V: VTable>(options: &dyn Any) -> &V::Options {
575 options
576 .downcast_ref::<V::Options>()
577 .vortex_expect("Invalid options type for expression")
578}
579
580mod private {
581 use crate::expr::VTable;
582 use crate::expr::VTableAdapter;
583
584 pub trait Sealed {}
585 impl<V: VTable> Sealed for VTableAdapter<V> {}
586}
587
588#[derive(Clone)]
590pub struct ExprVTable(ArcRef<dyn DynExprVTable>);
591
592impl ExprVTable {
593 pub(crate) fn as_dyn(&self) -> &dyn DynExprVTable {
596 self.0.as_ref()
597 }
598
599 pub fn as_any(&self) -> &dyn Any {
601 self.0.as_any()
602 }
603
604 pub fn new<V: VTable>(vtable: V) -> Self {
606 Self(ArcRef::new_arc(Arc::new(VTableAdapter(vtable))))
607 }
608
609 pub const fn new_static<V: VTable>(vtable: &'static V) -> Self {
611 let adapted: &'static VTableAdapter<V> =
613 unsafe { &*(vtable as *const V as *const VTableAdapter<V>) };
614 Self(ArcRef::new_ref(adapted as &'static dyn DynExprVTable))
615 }
616
617 pub fn id(&self) -> ExprId {
619 self.0.id()
620 }
621
622 pub fn is<V: VTable>(&self) -> bool {
624 self.0.as_any().is::<V>()
625 }
626
627 pub fn deserialize(&self, metadata: &[u8]) -> VortexResult<ScalarFn> {
629 Ok(unsafe {
630 ScalarFn::new_unchecked(self.clone(), self.as_dyn().options_deserialize(metadata)?)
631 })
632 }
633}
634
635impl PartialEq for ExprVTable {
636 fn eq(&self, other: &Self) -> bool {
637 self.0.id() == other.0.id()
638 }
639}
640impl Eq for ExprVTable {}
641
642impl Hash for ExprVTable {
643 fn hash<H: Hasher>(&self, state: &mut H) {
644 self.0.id().hash(state);
645 }
646}
647
648impl Display for ExprVTable {
649 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
650 write!(f, "{}", self.as_dyn().id())
651 }
652}
653
654impl Debug for ExprVTable {
655 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
656 write!(f, "{}", self.as_dyn().id())
657 }
658}
659
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;

    use super::*;
    use crate::expr::exprs::between::between;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::cast::cast;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::is_null::is_null;
    use crate::expr::exprs::list_contains::list_contains;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::merge::merge;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;
    use crate::expr::proto::ExprSerializeProtoExt;
    use crate::expr::proto::deserialize_expr_proto;
    use crate::expr::session::ExprRegistry;
    use crate::expr::session::ExprSession;

    /// Expression registry from a default session, shared by every case (`#[once]`).
    #[fixture]
    #[once]
    fn registry() -> ExprRegistry {
        ExprSession::default().registry().clone()
    }

    /// Every expression must come back unchanged from a protobuf serialize/deserialize
    /// round trip through the default registry.
    #[rstest]
    // Root and struct field selection
    #[case(root())]
    #[case(select(["hello", "world"], root()))]
    #[case(select_exclude(["world", "hello"], root()))]
    // Literals
    #[case(lit(42i32))]
    #[case(lit(std::f64::consts::PI))]
    #[case(lit(true))]
    #[case(lit("hello"))]
    // Column and field access
    #[case(col("column_name"))]
    #[case(get_item("field", root()))]
    // Comparisons
    #[case(eq(col("a"), lit(10)))]
    #[case(not_eq(col("a"), lit(10)))]
    #[case(gt(col("a"), lit(10)))]
    #[case(gt_eq(col("a"), lit(10)))]
    #[case(lt(col("a"), lit(10)))]
    #[case(lt_eq(col("a"), lit(10)))]
    // Boolean logic
    #[case(and(col("a"), col("b")))]
    #[case(or(col("a"), col("b")))]
    #[case(not(col("a")))]
    // Arithmetic
    #[case(checked_add(col("a"), lit(5)))]
    // Null checks
    #[case(is_null(col("nullable_col")))]
    // Casts
    #[case(cast(
        col("a"),
        DType::Primitive(vortex_dtype::PType::I64, vortex_dtype::Nullability::NonNullable)
    ))]
    // Range checks
    #[case(between(col("a"), lit(10), lit(20), crate::compute::BetweenOptions { lower_strict: crate::compute::StrictComparison::NonStrict, upper_strict: crate::compute::StrictComparison::NonStrict }))]
    // List operations
    #[case(list_contains(col("list_col"), lit("item")))]
    // Struct construction
    #[case(pack([("field1", col("a")), ("field2", col("b"))], vortex_dtype::Nullability::NonNullable))]
    #[case(merge([col("struct1"), col("struct2")]))]
    // Nested expressions
    #[case(and(gt(col("a"), lit(0)), lt(col("a"), lit(100))))]
    #[case(or(is_null(col("a")), eq(col("a"), lit(0))))]
    #[case(not(and(eq(col("status"), lit("active")), gt(col("age"), lit(18)))))]
    fn test_expr_serde_round_trip(
        registry: &ExprRegistry,
        #[case] expr: Expression,
    ) -> VortexResult<()> {
        let serialized_pb = expr.serialize_proto()?;
        let deserialized_expr = deserialize_expr_proto(&serialized_pb, registry)?;

        assert_eq!(&expr, &deserialized_expr);

        Ok(())
    }
}