vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::any::Any;
14use std::fmt::{Debug, Display, Formatter};
15use std::hash::{Hash, Hasher};
16use std::sync::Arc;
17
18use dyn_hash::DynHash;
19pub use exprs::*;
20pub mod aliases;
21mod analysis;
22#[cfg(feature = "arbitrary")]
23pub mod arbitrary;
24pub mod dyn_traits;
25mod encoding;
26mod exprs;
27mod field;
28pub mod forms;
29pub mod proto;
30pub mod pruning;
31mod registry;
32mod scope;
33mod scope_vars;
34pub mod transform;
35pub mod traversal;
36mod vtable;
37
38pub use analysis::*;
39pub use between::*;
40pub use binary::*;
41pub use cast::*;
42pub use encoding::*;
43pub use get_item::*;
44pub use is_null::*;
45pub use like::*;
46pub use list_contains::*;
47pub use literal::*;
48pub use merge::*;
49pub use not::*;
50pub use operators::*;
51pub use pack::*;
52pub use registry::*;
53pub use root::*;
54pub use scope::*;
55pub use scope_vars::*;
56pub use select::*;
57use vortex_array::{Array, ArrayRef, SerializeMetadata};
58use vortex_dtype::{DType, FieldName, FieldPath};
59use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
60use vortex_utils::aliases::hash_set::HashSet;
61pub use vtable::*;
62
63use crate::dyn_traits::DynEq;
64use crate::traversal::{NodeExt, ReferenceCollector};
65
/// Conversion of a value into an [`ExprRef`].
///
/// The conversion takes `self` by value, so the source value is consumed.
pub trait IntoExpr {
    /// Convert this type into an expression reference.
    fn into_expr(self) -> ExprRef;
}
70
/// A shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
72
/// Represents a logical operation on [`ArrayRef`]s.
///
/// The trait is sealed through the `private::Sealed` supertrait, so it can only be implemented
/// by types for which this crate provides a `Sealed` impl. Expressions are usually handled as
/// [`ExprRef`] trait objects.
pub trait VortexExpr:
    'static + Send + Sync + Debug + Display + DynEq + DynHash + private::Sealed + AnalysisExpr
{
    /// Convert expression reference to reference of [`Any`] type.
    ///
    /// This is the hook used for downcasting a trait object back to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Convert the expression to an [`ExprRef`].
    fn to_expr(&self) -> ExprRef;

    /// Return the encoding of the expression.
    fn encoding(&self) -> ExprEncodingRef;

    /// Serialize the metadata of this expression into a bytes vector.
    ///
    /// Returns `None` if the expression does not support serialization. The default
    /// implementation always returns `None`.
    fn metadata(&self) -> Option<Vec<u8>> {
        None
    }

    /// Compute the result of the expression on the given scope, producing a new array.
    ///
    /// "Unchecked" means that this function does not assert that the dtype of the returned
    /// array matches the one declared by [`VortexExpr::return_dtype`]. Prefer
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate), which performs that
    /// assertion.
    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;

    /// Returns the children of this expression.
    fn children(&self) -> Vec<&ExprRef>;

    /// Returns a new instance of this expression with the children replaced.
    ///
    /// `children` holds the replacement child expressions; the call is fallible.
    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;

    /// Compute the type of the array returned by
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
    ///
    /// `scope` is the [`DType`] of the scope the expression is evaluated against.
    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;
}
111
// Generates the `Hash` implementation for `dyn VortexExpr` trait objects (provided by the
// `dyn_hash` crate, building on the `DynHash` supertrait of `VortexExpr`).
dyn_hash::hash_trait_object!(VortexExpr);
113
114impl PartialEq for dyn VortexExpr {
115    fn eq(&self, other: &Self) -> bool {
116        self.dyn_eq(other.as_any())
117    }
118}
119
// Marker impl: relies on the `PartialEq` impl for `dyn VortexExpr` defined in this file.
impl Eq for dyn VortexExpr {}
121
122impl dyn VortexExpr + '_ {
123    pub fn id(&self) -> ExprId {
124        self.encoding().id()
125    }
126
127    pub fn is<V: VTable>(&self) -> bool {
128        self.as_opt::<V>().is_some()
129    }
130
131    pub fn as_<V: VTable>(&self) -> &V::Expr {
132        self.as_opt::<V>()
133            .vortex_expect("Expr is not of the expected type")
134    }
135
136    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
137        VortexExpr::as_any(self)
138            .downcast_ref::<ExprAdapter<V>>()
139            .map(|e| &e.0)
140    }
141
142    /// Compute result of expression on given batch producing a new batch
143    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
144        let result = self.unchecked_evaluate(scope)?;
145        assert_eq!(
146            result.dtype(),
147            &self.return_dtype(scope.dtype())?,
148            "Expression {} returned dtype {} but declared return_dtype of {}",
149            self,
150            result.dtype(),
151            self.return_dtype(scope.dtype())?,
152        );
153        Ok(result)
154    }
155}
156
/// Extension methods for expressions; implemented in this file for [`ExprRef`].
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn field_references(&self) -> HashSet<FieldName>;
}
161
162impl VortexExprExt for ExprRef {
163    fn field_references(&self) -> HashSet<FieldName> {
164        let mut collector = ReferenceCollector::new();
165        // The collector is infallible, so we can unwrap the result
166        self.accept(&mut collector).vortex_unwrap();
167        collector.into_fields()
168    }
169}
170
/// Newtype around the expression type of a vtable (`V::Expr`).
///
/// The trait implementations on this wrapper forward to the functions of `V` and to the
/// wrapped value, which is how a `V::Expr` is exposed as a [`VortexExpr`]. The wrapper is
/// `#[repr(transparent)]` over its single field.
#[derive(Clone)]
#[repr(transparent)]
pub struct ExprAdapter<V: VTable>(V::Expr);
174
175impl<V: VTable> VortexExpr for ExprAdapter<V> {
176    fn as_any(&self) -> &dyn Any {
177        self
178    }
179
180    fn to_expr(&self) -> ExprRef {
181        Arc::new(ExprAdapter::<V>(self.0.clone()))
182    }
183
184    fn encoding(&self) -> ExprEncodingRef {
185        V::encoding(&self.0)
186    }
187
188    fn metadata(&self) -> Option<Vec<u8>> {
189        V::metadata(&self.0).map(|m| m.serialize())
190    }
191
192    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
193        V::evaluate(&self.0, ctx)
194    }
195
196    fn children(&self) -> Vec<&ExprRef> {
197        V::children(&self.0)
198    }
199
200    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
201        if self.children().len() != children.len() {
202            vortex_bail!(
203                "Expected {} children, got {}",
204                self.children().len(),
205                children.len()
206            );
207        }
208        Ok(V::with_children(&self.0, children)?.to_expr())
209    }
210
211    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
212        V::return_dtype(&self.0, scope)
213    }
214}
215
216impl<V: VTable> Debug for ExprAdapter<V> {
217    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
218        Debug::fmt(&self.0, f)
219    }
220}
221
222impl<V: VTable> Display for ExprAdapter<V> {
223    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
224        Display::fmt(&self.0, f)
225    }
226}
227
228impl<V: VTable> PartialEq for ExprAdapter<V> {
229    fn eq(&self, other: &Self) -> bool {
230        PartialEq::eq(&self.0, &other.0)
231    }
232}
233
// Marker impl: relies on the `PartialEq` impl for `ExprAdapter<V>` defined in this file.
impl<V: VTable> Eq for ExprAdapter<V> {}
235
236impl<V: VTable> Hash for ExprAdapter<V> {
237    fn hash<H: Hasher>(&self, state: &mut H) {
238        Hash::hash(&self.0, state);
239    }
240}
241
242impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
243    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
244        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
245    }
246
247    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
248        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
249    }
250
251    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
252        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
253    }
254
255    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
256        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
257    }
258
259    fn field_path(&self) -> Option<FieldPath> {
260        <V::Expr as AnalysisExpr>::field_path(&self.0)
261    }
262}
263
264mod private {
265    use super::*;
266
267    pub trait Sealed {}
268
269    impl<V: VTable> Sealed for ExprAdapter<V> {}
270}
271
272/// Splits top level and operations into separate expressions.
273pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
274    let mut conjunctions = vec![];
275    split_inner(expr, &mut conjunctions);
276    conjunctions
277}
278
279fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
280    match expr.as_opt::<BinaryVTable>() {
281        Some(bexp) if bexp.op() == Operator::And => {
282            split_inner(bexp.lhs(), exprs);
283            split_inner(bexp.rhs(), exprs);
284        }
285        Some(_) | None => {
286            exprs.push(expr.clone());
287        }
288    }
289}
290
/// An expression wrapper that performs pointer equality.
///
/// Two wrappers compare equal only when their inner [`ExprRef`]s point at the same allocation
/// (see the `PartialEq` impl, which uses `Arc::ptr_eq`), independent of whether the wrapped
/// expressions are structurally equal.
#[derive(Clone)]
pub struct ExactExpr(pub ExprRef);
294
295impl PartialEq for ExactExpr {
296    fn eq(&self, other: &Self) -> bool {
297        Arc::ptr_eq(&self.0, &other.0)
298    }
299}
300
// Marker impl: relies on the pointer-based `PartialEq` impl for `ExactExpr` in this file.
impl Eq for ExactExpr {}
302
303impl Hash for ExactExpr {
304    fn hash<H: Hasher>(&self, state: &mut H) {
305        Arc::as_ptr(&self.0).hash(state)
306    }
307}
308
/// Shared fixtures, available when the `test-harness` feature is enabled.
#[cfg(feature = "test-harness")]
pub mod test_harness {

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Builds a non-nullable struct dtype with the fields `a` (non-nullable `i32`), `col1` and
    /// `col2` (nullable `u16`), and `bool1` and `bool2` (non-nullable booleans).
    pub fn struct_dtype() -> DType {
        // Listed in the same order as the field names passed to `StructFields::new` below.
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );
        DType::Struct(fields, Nullability::NonNullable)
    }
}
330
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that the `Display` output of `value` is exactly `expected`.
    fn assert_renders(value: impl std::fmt::Display, expected: &str) {
        assert_eq!(value.to_string(), expected);
    }

    /// A single comparison has no top-level `and`, so it stays one expression.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the textual rendering of the expression kinds.
    #[test]
    fn expr_display() {
        // Column access and the root expression.
        assert_renders(col("a"), "$.a");
        assert_renders(root(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators.
        assert_renders(and(col1.clone(), col2.clone()), "($.col1 and $.col2)");
        assert_renders(or(col1.clone(), col2.clone()), "($.col1 or $.col2)");
        assert_renders(eq(col1.clone(), col2.clone()), "($.col1 = $.col2)");
        assert_renders(not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)");
        assert_renders(gt(col1.clone(), col2.clone()), "($.col1 > $.col2)");
        assert_renders(gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)");
        assert_renders(lt(col1.clone(), col2.clone()), "($.col1 < $.col2)");
        assert_renders(lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions.
        let nested = or(
            lt(col1.clone(), col2.clone()),
            not_eq(col1.clone(), col2.clone()),
        );
        assert_renders(nested, "(($.col1 < $.col2) or ($.col1 != $.col2))");

        // Negation.
        assert_renders(not(col1.clone()), "(!$.col1)");

        // Field selection and exclusion.
        assert_renders(select(vec![FieldName::from("col1")], root()), "${col1}");
        assert_renders(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root(),
            ),
            "${col1, col2}",
        );
        assert_renders(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root(),
            ),
            "$~{col1, col2}",
        );

        // Scalar literals.
        assert_renders(lit(Scalar::from(0u8)), "0u8");
        assert_renders(lit(Scalar::from(0.0f32)), "0f32");
        assert_renders(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_renders(lit(Scalar::from(true)), "true");
        assert_renders(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet_values = vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())];
        assert_renders(
            lit(Scalar::struct_(pet_dtype, pet_values)),
            "{dog: 32u32, cat: \"rufus\"}",
        );
    }
}