vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::any::Any;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18use dyn_hash::DynHash;
19pub use exprs::*;
20pub mod aliases;
21mod analysis;
22#[cfg(feature = "arbitrary")]
23pub mod arbitrary;
24pub mod dyn_traits;
25mod encoding;
26mod exprs;
27mod field;
28pub mod forms;
29mod operator;
30pub mod proto;
31pub mod pruning;
32mod registry;
33mod scope;
34mod scope_vars;
35pub mod transform;
36pub mod traversal;
37mod vtable;
38
39pub use analysis::*;
40pub use between::*;
41pub use binary::*;
42pub use cast::*;
43pub use encoding::*;
44pub use get_item::*;
45pub use is_null::*;
46pub use like::*;
47pub use list_contains::*;
48pub use literal::*;
49pub use merge::*;
50pub use not::*;
51pub use operator::*;
52pub use operators::*;
53pub use pack::*;
54pub use registry::*;
55pub use root::*;
56pub use scope::*;
57pub use scope_vars::*;
58pub use select::*;
59use vortex_array::{Array, ArrayRef, SerializeMetadata};
60use vortex_dtype::{DType, FieldName, FieldPath};
61use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
62use vortex_utils::aliases::hash_set::HashSet;
63pub use vtable::*;
64
65use crate::dyn_traits::DynEq;
66use crate::traversal::{NodeExt, ReferenceCollector};
67
/// Conversion of a value into an expression.
pub trait IntoExpr {
    /// Convert this type into an expression reference.
    fn into_expr(self) -> ExprRef;
}

/// A shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
74
75/// Represents logical operation on [`ArrayRef`]s
76pub trait VortexExpr:
77    'static + Send + Sync + Debug + Display + DynEq + DynHash + private::Sealed + AnalysisExpr
78{
79    /// Convert expression reference to reference of [`Any`] type
80    fn as_any(&self) -> &dyn Any;
81
82    /// Convert the expression to an [`ExprRef`].
83    fn to_expr(&self) -> ExprRef;
84
85    /// Return the encoding of the expression.
86    fn encoding(&self) -> ExprEncodingRef;
87
88    /// Serialize the metadata of this expression into a bytes vector.
89    ///
90    /// Returns `None` if the expression does not support serialization.
91    fn metadata(&self) -> Option<Vec<u8>> {
92        None
93    }
94
95    /// Compute result of expression on given batch producing a new batch
96    ///
97    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
98    /// the [VortexExpr::return_dtype] method. Use instead the
99    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
100    /// function which includes such an assertion.
101    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
102
103    /// Returns the children of this expression.
104    fn children(&self) -> Vec<&ExprRef>;
105
106    /// Returns a new instance of this expression with the children replaced.
107    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
108
109    /// Compute the type of the array returned by
110    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
111    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;
112}
113
// Uses the `dyn_hash` macro to implement `Hash` for `dyn VortexExpr` trait objects by
// forwarding to the `DynHash` supertrait, so that `ExprRef` values can be hashed.
dyn_hash::hash_trait_object!(VortexExpr);
115
116impl PartialEq for dyn VortexExpr {
117    fn eq(&self, other: &Self) -> bool {
118        self.dyn_eq(other.as_any())
119    }
120}
121
122impl Eq for dyn VortexExpr {}
123
124impl dyn VortexExpr + '_ {
125    pub fn id(&self) -> ExprId {
126        self.encoding().id()
127    }
128
129    pub fn is<V: VTable>(&self) -> bool {
130        self.as_opt::<V>().is_some()
131    }
132
133    pub fn as_<V: VTable>(&self) -> &V::Expr {
134        self.as_opt::<V>()
135            .vortex_expect("Expr is not of the expected type")
136    }
137
138    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
139        VortexExpr::as_any(self)
140            .downcast_ref::<ExprAdapter<V>>()
141            .map(|e| &e.0)
142    }
143
144    /// Compute result of expression on given batch producing a new batch
145    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
146        let result = self.unchecked_evaluate(scope)?;
147        assert_eq!(
148            result.dtype(),
149            &self.return_dtype(scope.dtype())?,
150            "Expression {} returned dtype {} but declared return_dtype of {}",
151            self,
152            result.dtype(),
153            self.return_dtype(scope.dtype())?,
154        );
155        Ok(result)
156    }
157}
158
159pub trait VortexExprExt {
160    /// Accumulate all field references from this expression and its children in a set
161    fn field_references(&self) -> HashSet<FieldName>;
162}
163
164impl VortexExprExt for ExprRef {
165    fn field_references(&self) -> HashSet<FieldName> {
166        let mut collector = ReferenceCollector::new();
167        // The collector is infallible, so we can unwrap the result
168        self.accept(&mut collector).vortex_unwrap();
169        collector.into_fields()
170    }
171}
172
173#[derive(Clone)]
174#[repr(transparent)]
175pub struct ExprAdapter<V: VTable>(V::Expr);
176
177impl<V: VTable> VortexExpr for ExprAdapter<V> {
178    fn as_any(&self) -> &dyn Any {
179        self
180    }
181
182    fn to_expr(&self) -> ExprRef {
183        Arc::new(ExprAdapter::<V>(self.0.clone()))
184    }
185
186    fn encoding(&self) -> ExprEncodingRef {
187        V::encoding(&self.0)
188    }
189
190    fn metadata(&self) -> Option<Vec<u8>> {
191        V::metadata(&self.0).map(|m| m.serialize())
192    }
193
194    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
195        V::evaluate(&self.0, ctx)
196    }
197
198    fn children(&self) -> Vec<&ExprRef> {
199        V::children(&self.0)
200    }
201
202    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
203        if self.children().len() != children.len() {
204            vortex_bail!(
205                "Expected {} children, got {}",
206                self.children().len(),
207                children.len()
208            );
209        }
210        Ok(V::with_children(&self.0, children)?.to_expr())
211    }
212
213    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
214        V::return_dtype(&self.0, scope)
215    }
216}
217
218impl<V: VTable> Debug for ExprAdapter<V> {
219    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
220        Debug::fmt(&self.0, f)
221    }
222}
223
224impl<V: VTable> Display for ExprAdapter<V> {
225    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
226        Display::fmt(&self.0, f)
227    }
228}
229
230impl<V: VTable> PartialEq for ExprAdapter<V> {
231    fn eq(&self, other: &Self) -> bool {
232        PartialEq::eq(&self.0, &other.0)
233    }
234}
235
236impl<V: VTable> Eq for ExprAdapter<V> {}
237
238impl<V: VTable> Hash for ExprAdapter<V> {
239    fn hash<H: Hasher>(&self, state: &mut H) {
240        Hash::hash(&self.0, state);
241    }
242}
243
244impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
245    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
246        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
247    }
248
249    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
250        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
251    }
252
253    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
254        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
255    }
256
257    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
258        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
259    }
260
261    fn field_path(&self) -> Option<FieldPath> {
262        <V::Expr as AnalysisExpr>::field_path(&self.0)
263    }
264}
265
266mod private {
267    use super::*;
268
269    pub trait Sealed {}
270
271    impl<V: VTable> Sealed for ExprAdapter<V> {}
272}
273
274/// Splits top level and operations into separate expressions.
275pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
276    let mut conjunctions = vec![];
277    split_inner(expr, &mut conjunctions);
278    conjunctions
279}
280
281fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
282    match expr.as_opt::<BinaryVTable>() {
283        Some(bexp) if bexp.op() == Operator::And => {
284            split_inner(bexp.lhs(), exprs);
285            split_inner(bexp.rhs(), exprs);
286        }
287        Some(_) | None => {
288            exprs.push(expr.clone());
289        }
290    }
291}
292
293/// An expression wrapper that performs pointer equality.
294#[derive(Clone)]
295pub struct ExactExpr(pub ExprRef);
296
297impl PartialEq for ExactExpr {
298    fn eq(&self, other: &Self) -> bool {
299        Arc::ptr_eq(&self.0, &other.0)
300    }
301}
302
303impl Eq for ExactExpr {}
304
305impl Hash for ExactExpr {
306    fn hash<H: Hasher>(&self, state: &mut H) {
307        Arc::as_ptr(&self.0).hash(state)
308    }
309}
310
#[cfg(feature = "test-harness")]
pub mod test_harness {

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype for use as a shared test fixture.
    ///
    /// Fields, in order: `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`), and
    /// `bool1` / `bool2` (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        let field_names = ["a", "col1", "col2", "bool1", "bool2"];
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        let fields = StructFields::new(field_names.into(), field_dtypes);
        DType::Struct(fields, Nullability::NonNullable)
    }
}
332
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Left- and right-nested `and`s flatten to the same left-to-right list of conjuncts,
    /// while a nested `or` is kept intact as a single conjunct.
    #[test]
    fn nested_conjunction_split_preserves_order() {
        let left_nested = and(and(col("a"), col("b")), or(col("c"), col("d")));
        let rendered: Vec<String> = split_conjunction(&left_nested)
            .iter()
            .map(|e| e.to_string())
            .collect();
        assert_eq!(rendered, ["$.a", "$.b", "($.c or $.d)"]);

        let right_nested = and(col("a"), and(col("b"), col("c")));
        let rendered: Vec<String> = split_conjunction(&right_nested)
            .iter()
            .map(|e| e.to_string())
            .collect();
        assert_eq!(rendered, ["$.a", "$.b", "$.c"]);
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "(!$.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}