vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::{Debug, Display, Formatter};
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8
9use dyn_eq::DynEq;
10use dyn_hash::DynHash;
11pub use exprs::*;
12pub mod aliases;
13mod analysis;
14#[cfg(feature = "arbitrary")]
15pub mod arbitrary;
16mod encoding;
17mod exprs;
18mod field;
19pub mod forms;
20pub mod proto;
21pub mod pruning;
22mod registry;
23mod scope;
24mod scope_vars;
25pub mod transform;
26pub mod traversal;
27mod vtable;
28
29pub use analysis::*;
30pub use between::*;
31pub use binary::*;
32pub use cast::*;
33pub use encoding::*;
34pub use get_item::*;
35pub use is_null::*;
36pub use like::*;
37pub use list_contains::*;
38pub use literal::*;
39pub use merge::*;
40pub use not::*;
41pub use operators::*;
42pub use pack::*;
43pub use registry::*;
44pub use root::*;
45pub use scope::*;
46pub use select::*;
47use vortex_array::{Array, ArrayRef, SerializeMetadata};
48use vortex_dtype::{DType, FieldName, FieldPath};
49use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
50use vortex_utils::aliases::hash_set::HashSet;
51pub use vtable::*;
52
53use crate::traversal::{Node, ReferenceCollector};
54
/// Conversion of a value into an [`ExprRef`].
///
/// The conversion takes `self` by value, so the source value is consumed.
pub trait IntoExpr {
    /// Convert this type into an expression reference.
    fn into_expr(self) -> ExprRef;
}
59
/// A shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
61
/// Represents a logical operation on [`ArrayRef`]s.
///
/// The trait has `private::Sealed` as a supertrait; the only `Sealed` implementation in this
/// file is for [`ExprAdapter`], which makes the adapter the only implementor defined here.
pub trait VortexExpr:
    'static + Send + Sync + Debug + Display + DynEq + DynHash + private::Sealed + AnalysisExpr
{
    /// Convert the expression reference to a reference of [`Any`] type, e.g. for downcasting.
    fn as_any(&self) -> &dyn Any;

    /// Convert the expression to an [`ExprRef`].
    fn to_expr(&self) -> ExprRef;

    /// Return the encoding of the expression.
    fn encoding(&self) -> ExprEncodingRef;

    /// Serialize the metadata of this expression into a bytes vector.
    ///
    /// Returns `None` if the expression does not support serialization. The default
    /// implementation always returns `None`.
    fn metadata(&self) -> Option<Vec<u8>> {
        None
    }

    /// Compute the result of the expression on the given batch, producing a new batch.
    ///
    /// "Unchecked" means that this function does not assert that the dtype of the returned
    /// array matches the one reported by [`VortexExpr::return_dtype`]. Prefer
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate),
    /// which includes such an assertion.
    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;

    /// Returns the children of this expression.
    fn children(&self) -> Vec<&ExprRef>;

    /// Returns a new instance of this expression with the children replaced.
    ///
    /// The [`ExprAdapter`] implementation returns an error when `children` does not have the
    /// same length as [`VortexExpr::children`].
    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;

    /// Compute the type of the array returned by
    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
    ///
    /// `scope` is the dtype of the scope the expression is evaluated against.
    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;
}
100
// NOTE(review): per the `dyn_eq` / `dyn_hash` crates, these macros are expected to generate the
// `PartialEq`/`Eq` and `Hash` impls for `dyn VortexExpr` trait objects — confirm against the
// crate documentation.
dyn_eq::eq_trait_object!(VortexExpr);
dyn_hash::hash_trait_object!(VortexExpr);
103
104impl dyn VortexExpr + '_ {
105    pub fn id(&self) -> ExprId {
106        self.encoding().id()
107    }
108
109    pub fn is<V: VTable>(&self) -> bool {
110        self.as_opt::<V>().is_some()
111    }
112
113    pub fn as_<V: VTable>(&self) -> &V::Expr {
114        self.as_opt::<V>()
115            .vortex_expect("Expr is not of the expected type")
116    }
117
118    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
119        VortexExpr::as_any(self)
120            .downcast_ref::<ExprAdapter<V>>()
121            .map(|e| &e.0)
122    }
123
124    /// Compute result of expression on given batch producing a new batch
125    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
126        let result = self.unchecked_evaluate(scope)?;
127        assert_eq!(
128            result.dtype(),
129            &self.return_dtype(scope.dtype())?,
130            "Expression {} returned dtype {} but declared return_dtype of {}",
131            self,
132            result.dtype(),
133            self.return_dtype(scope.dtype())?,
134        );
135        Ok(result)
136    }
137}
138
/// Extension methods for expressions.
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set.
    fn field_references(&self) -> HashSet<FieldName>;
}
143
144impl VortexExprExt for ExprRef {
145    fn field_references(&self) -> HashSet<FieldName> {
146        let mut collector = ReferenceCollector::new();
147        // The collector is infallible, so we can unwrap the result
148        self.accept(&mut collector).vortex_unwrap();
149        collector.into_fields()
150    }
151}
152
/// Wraps the concrete expression type of a vtable (`V::Expr`) and implements [`VortexExpr`] for
/// it by delegating to the functions of `V`.
///
/// The wrapper is `#[repr(transparent)]` over its single `V::Expr` field.
#[derive(Clone)]
#[repr(transparent)]
pub struct ExprAdapter<V: VTable>(V::Expr);
156
157impl<V: VTable> VortexExpr for ExprAdapter<V> {
158    fn as_any(&self) -> &dyn Any {
159        self
160    }
161
162    fn to_expr(&self) -> ExprRef {
163        Arc::new(ExprAdapter::<V>(self.0.clone()))
164    }
165
166    fn encoding(&self) -> ExprEncodingRef {
167        V::encoding(&self.0)
168    }
169
170    fn metadata(&self) -> Option<Vec<u8>> {
171        V::metadata(&self.0).map(|m| m.serialize())
172    }
173
174    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
175        V::evaluate(&self.0, ctx)
176    }
177
178    fn children(&self) -> Vec<&ExprRef> {
179        V::children(&self.0)
180    }
181
182    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
183        if self.children().len() != children.len() {
184            vortex_bail!(
185                "Expected {} children, got {}",
186                self.children().len(),
187                children.len()
188            );
189        }
190        Ok(V::with_children(&self.0, children)?.to_expr())
191    }
192
193    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
194        V::return_dtype(&self.0, scope)
195    }
196}
197
198impl<V: VTable> Debug for ExprAdapter<V> {
199    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
200        Debug::fmt(&self.0, f)
201    }
202}
203
204impl<V: VTable> Display for ExprAdapter<V> {
205    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206        Display::fmt(&self.0, f)
207    }
208}
209
210impl<V: VTable> PartialEq for ExprAdapter<V> {
211    fn eq(&self, other: &Self) -> bool {
212        PartialEq::eq(&self.0, &other.0)
213    }
214}
215
216impl<V: VTable> Eq for ExprAdapter<V> {}
217
218impl<V: VTable> Hash for ExprAdapter<V> {
219    fn hash<H: Hasher>(&self, state: &mut H) {
220        Hash::hash(&self.0, state);
221    }
222}
223
224impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
225    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
226        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
227    }
228
229    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
230        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
231    }
232
233    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
234        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
235    }
236
237    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
238        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
239    }
240
241    fn field_path(&self) -> Option<FieldPath> {
242        <V::Expr as AnalysisExpr>::field_path(&self.0)
243    }
244}
245
/// Private module holding the `Sealed` marker trait.
///
/// `Sealed` is a supertrait of [`VortexExpr`]; because this module is not `pub`, code outside
/// this crate cannot name `Sealed` and therefore cannot implement [`VortexExpr`] itself.
mod private {
    use super::*;

    /// Marker trait required by [`VortexExpr`].
    pub trait Sealed {}

    // The adapter is the only type that is sealed in this file.
    impl<V: VTable> Sealed for ExprAdapter<V> {}
}
253
254/// Splits top level and operations into separate expressions
255pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
256    let mut conjunctions = vec![];
257    split_inner(expr, &mut conjunctions);
258    conjunctions
259}
260
261fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
262    match expr.as_opt::<BinaryVTable>() {
263        Some(bexp) if bexp.op() == Operator::And => {
264            split_inner(bexp.lhs(), exprs);
265            split_inner(bexp.rhs(), exprs);
266        }
267        Some(_) | None => {
268            exprs.push(expr.clone());
269        }
270    }
271}
272
273/// An expression wrapper that performs pointer equality.
274#[derive(Clone)]
275pub struct ExactExpr(pub ExprRef);
276
277impl PartialEq for ExactExpr {
278    fn eq(&self, other: &Self) -> bool {
279        Arc::ptr_eq(&self.0, &other.0)
280    }
281}
282
283impl Eq for ExactExpr {}
284
285impl Hash for ExactExpr {
286    fn hash<H: Hasher>(&self, state: &mut H) {
287        Arc::as_ptr(&self.0).hash(state)
288    }
289}
290
#[cfg(feature = "test-harness")]
pub mod test_harness {

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Builds the non-nullable struct dtype used as a fixture.
    ///
    /// Fields, in order: `a` (non-nullable `i32`), `col1` and `col2` (nullable `u16`),
    /// `bool1` and `bool2` (non-nullable `bool`).
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"].into();
        let dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        let fields = StructFields::new(names, dtypes);
        DType::Struct(fields, Nullability::NonNullable)
    }
}
312
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that the `Display` rendering of an expression equals the expected string.
    macro_rules! assert_display {
        ($expr:expr, $expected:literal) => {
            assert_eq!($expr.to_string(), $expected)
        };
    }

    /// A comparison has no top-level `and`, so it is returned as a single expression.
    #[test]
    fn basic_expr_split_test() {
        let comparison = eq(get_item("col1", root()), lit(1));
        let parts = split_conjunction(&comparison);
        assert_eq!(parts.len(), 1);
    }

    /// A top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let both = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&both);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of columns, operators, selections and literals.
    #[test]
    fn expr_display() {
        // Column access and the root expression.
        assert_display!(col("a"), "$.a");
        assert_display!(root(), "$");

        let left: ExprRef = col("col1");
        let right: ExprRef = col("col2");

        // Binary operators.
        assert_display!(and(left.clone(), right.clone()), "($.col1 and $.col2)");
        assert_display!(or(left.clone(), right.clone()), "($.col1 or $.col2)");
        assert_display!(eq(left.clone(), right.clone()), "($.col1 = $.col2)");
        assert_display!(not_eq(left.clone(), right.clone()), "($.col1 != $.col2)");
        assert_display!(gt(left.clone(), right.clone()), "($.col1 > $.col2)");
        assert_display!(gt_eq(left.clone(), right.clone()), "($.col1 >= $.col2)");
        assert_display!(lt(left.clone(), right.clone()), "($.col1 < $.col2)");
        assert_display!(lt_eq(left.clone(), right.clone()), "($.col1 <= $.col2)");

        // Nested binary operators.
        assert_display!(
            or(
                lt(left.clone(), right.clone()),
                not_eq(left.clone(), right.clone()),
            ),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        // Negation.
        assert_display!(not(left.clone()), "(!$.col1)");

        // Field selection and exclusion.
        assert_display!(select(vec![FieldName::from("col1")], root()), "${col1}");
        assert_display!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "${col1, col2}"
        );
        assert_display!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            ),
            "$~{col1, col2}"
        );

        // Scalar literals.
        assert_display!(lit(Scalar::from(0u8)), "0u8");
        assert_display!(lit(Scalar::from(0.0f32)), "0f32");
        assert_display!(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_display!(lit(Scalar::from(true)), "true");
        assert_display!(lit(Scalar::null(DType::Bool(Nullability::Nullable))), "null");

        // Struct literal.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet_values = vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())];
        assert_display!(
            lit(Scalar::struct_(pet_dtype, pet_values)),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}