vortex_expr/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::{Debug, Display, Formatter};
6use std::hash::{Hash, Hasher};
7use std::sync::Arc;
8
9use dyn_hash::DynHash;
10pub use exprs::*;
11pub mod aliases;
12mod analysis;
13#[cfg(feature = "arbitrary")]
14pub mod arbitrary;
15pub mod dyn_traits;
16mod encoding;
17mod exprs;
18mod field;
19pub mod forms;
20pub mod proto;
21pub mod pruning;
22mod registry;
23mod scope;
24mod scope_vars;
25pub mod transform;
26pub mod traversal;
27mod vtable;
28
29pub use analysis::*;
30pub use between::*;
31pub use binary::*;
32pub use cast::*;
33pub use encoding::*;
34pub use get_item::*;
35pub use is_null::*;
36pub use like::*;
37pub use list_contains::*;
38pub use literal::*;
39pub use merge::*;
40pub use not::*;
41pub use operators::*;
42pub use pack::*;
43pub use registry::*;
44pub use root::*;
45pub use scope::*;
46pub use select::*;
47use vortex_array::{Array, ArrayRef, SerializeMetadata};
48use vortex_dtype::{DType, FieldName, FieldPath};
49use vortex_error::{VortexExpect, VortexResult, VortexUnwrap, vortex_bail};
50use vortex_utils::aliases::hash_set::HashSet;
51pub use vtable::*;
52
53use crate::dyn_traits::DynEq;
54use crate::traversal::{Node, ReferenceCollector};
55
56pub trait IntoExpr {
57    /// Convert this type into an expression reference.
58    fn into_expr(self) -> ExprRef;
59}
60
61pub type ExprRef = Arc<dyn VortexExpr>;
62
63/// Represents logical operation on [`ArrayRef`]s
64pub trait VortexExpr:
65    'static + Send + Sync + Debug + Display + DynEq + DynHash + private::Sealed + AnalysisExpr
66{
67    /// Convert expression reference to reference of [`Any`] type
68    fn as_any(&self) -> &dyn Any;
69
70    /// Convert the expression to an [`ExprRef`].
71    fn to_expr(&self) -> ExprRef;
72
73    /// Return the encoding of the expression.
74    fn encoding(&self) -> ExprEncodingRef;
75
76    /// Serialize the metadata of this expression into a bytes vector.
77    ///
78    /// Returns `None` if the expression does not support serialization.
79    fn metadata(&self) -> Option<Vec<u8>> {
80        None
81    }
82
83    /// Compute result of expression on given batch producing a new batch
84    ///
85    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
86    /// the [VortexExpr::return_dtype] method. Use instead the
87    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
88    /// function which includes such an assertion.
89    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
90
91    /// Returns the children of this expression.
92    fn children(&self) -> Vec<&ExprRef>;
93
94    /// Returns a new instance of this expression with the children replaced.
95    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
96
97    /// Compute the type of the array returned by
98    /// [`VortexExpr::evaluate`](./trait.VortexExpr.html#method.evaluate).
99    fn return_dtype(&self, scope: &DType) -> VortexResult<DType>;
100}
101
102dyn_hash::hash_trait_object!(VortexExpr);
103
104impl PartialEq for dyn VortexExpr {
105    fn eq(&self, other: &Self) -> bool {
106        self.dyn_eq(other.as_any())
107    }
108}
109
110impl Eq for dyn VortexExpr {}
111
112impl dyn VortexExpr + '_ {
113    pub fn id(&self) -> ExprId {
114        self.encoding().id()
115    }
116
117    pub fn is<V: VTable>(&self) -> bool {
118        self.as_opt::<V>().is_some()
119    }
120
121    pub fn as_<V: VTable>(&self) -> &V::Expr {
122        self.as_opt::<V>()
123            .vortex_expect("Expr is not of the expected type")
124    }
125
126    pub fn as_opt<V: VTable>(&self) -> Option<&V::Expr> {
127        VortexExpr::as_any(self)
128            .downcast_ref::<ExprAdapter<V>>()
129            .map(|e| &e.0)
130    }
131
132    /// Compute result of expression on given batch producing a new batch
133    pub fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
134        let result = self.unchecked_evaluate(scope)?;
135        assert_eq!(
136            result.dtype(),
137            &self.return_dtype(scope.dtype())?,
138            "Expression {} returned dtype {} but declared return_dtype of {}",
139            self,
140            result.dtype(),
141            self.return_dtype(scope.dtype())?,
142        );
143        Ok(result)
144    }
145}
146
147pub trait VortexExprExt {
148    /// Accumulate all field references from this expression and its children in a set
149    fn field_references(&self) -> HashSet<FieldName>;
150}
151
152impl VortexExprExt for ExprRef {
153    fn field_references(&self) -> HashSet<FieldName> {
154        let mut collector = ReferenceCollector::new();
155        // The collector is infallible, so we can unwrap the result
156        self.accept(&mut collector).vortex_unwrap();
157        collector.into_fields()
158    }
159}
160
161#[derive(Clone)]
162#[repr(transparent)]
163pub struct ExprAdapter<V: VTable>(V::Expr);
164
165impl<V: VTable> VortexExpr for ExprAdapter<V> {
166    fn as_any(&self) -> &dyn Any {
167        self
168    }
169
170    fn to_expr(&self) -> ExprRef {
171        Arc::new(ExprAdapter::<V>(self.0.clone()))
172    }
173
174    fn encoding(&self) -> ExprEncodingRef {
175        V::encoding(&self.0)
176    }
177
178    fn metadata(&self) -> Option<Vec<u8>> {
179        V::metadata(&self.0).map(|m| m.serialize())
180    }
181
182    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef> {
183        V::evaluate(&self.0, ctx)
184    }
185
186    fn children(&self) -> Vec<&ExprRef> {
187        V::children(&self.0)
188    }
189
190    fn with_children(self: Arc<Self>, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
191        if self.children().len() != children.len() {
192            vortex_bail!(
193                "Expected {} children, got {}",
194                self.children().len(),
195                children.len()
196            );
197        }
198        Ok(V::with_children(&self.0, children)?.to_expr())
199    }
200
201    fn return_dtype(&self, scope: &DType) -> VortexResult<DType> {
202        V::return_dtype(&self.0, scope)
203    }
204}
205
206impl<V: VTable> Debug for ExprAdapter<V> {
207    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
208        Debug::fmt(&self.0, f)
209    }
210}
211
212impl<V: VTable> Display for ExprAdapter<V> {
213    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214        Display::fmt(&self.0, f)
215    }
216}
217
218impl<V: VTable> PartialEq for ExprAdapter<V> {
219    fn eq(&self, other: &Self) -> bool {
220        PartialEq::eq(&self.0, &other.0)
221    }
222}
223
224impl<V: VTable> Eq for ExprAdapter<V> {}
225
226impl<V: VTable> Hash for ExprAdapter<V> {
227    fn hash<H: Hasher>(&self, state: &mut H) {
228        Hash::hash(&self.0, state);
229    }
230}
231
232impl<V: VTable> AnalysisExpr for ExprAdapter<V> {
233    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
234        <V::Expr as AnalysisExpr>::stat_falsification(&self.0, catalog)
235    }
236
237    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
238        <V::Expr as AnalysisExpr>::max(&self.0, catalog)
239    }
240
241    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
242        <V::Expr as AnalysisExpr>::min(&self.0, catalog)
243    }
244
245    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
246        <V::Expr as AnalysisExpr>::nan_count(&self.0, catalog)
247    }
248
249    fn field_path(&self) -> Option<FieldPath> {
250        <V::Expr as AnalysisExpr>::field_path(&self.0)
251    }
252}
253
254mod private {
255    use super::*;
256
257    pub trait Sealed {}
258
259    impl<V: VTable> Sealed for ExprAdapter<V> {}
260}
261
262/// Splits top level and operations into separate expressions
263pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
264    let mut conjunctions = vec![];
265    split_inner(expr, &mut conjunctions);
266    conjunctions
267}
268
269fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
270    match expr.as_opt::<BinaryVTable>() {
271        Some(bexp) if bexp.op() == Operator::And => {
272            split_inner(bexp.lhs(), exprs);
273            split_inner(bexp.rhs(), exprs);
274        }
275        Some(_) | None => {
276            exprs.push(expr.clone());
277        }
278    }
279}
280
281/// An expression wrapper that performs pointer equality.
282#[derive(Clone)]
283pub struct ExactExpr(pub ExprRef);
284
285impl PartialEq for ExactExpr {
286    fn eq(&self, other: &Self) -> bool {
287        Arc::ptr_eq(&self.0, &other.0)
288    }
289}
290
291impl Eq for ExactExpr {}
292
293impl Hash for ExactExpr {
294    fn hash<H: Hasher>(&self, state: &mut H) {
295        Arc::as_ptr(&self.0).hash(state)
296    }
297}
298
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a fixed, non-nullable struct dtype for tests, with the fields:
    ///
    /// - `a`: non-nullable `i32`
    /// - `col1`, `col2`: nullable `u16`
    /// - `bool1`, `bool2`: non-nullable `bool`
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"];
        let fields = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];
        DType::Struct(
            StructFields::new(names.into(), fields),
            Nullability::NonNullable,
        )
    }
}
320
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// Renders each expression with `Display`, for order-sensitive comparisons.
    fn rendered(exprs: &[ExprRef]) -> Vec<String> {
        exprs.iter().map(|expr| expr.to_string()).collect()
    }

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Nested `AND`s are flattened, and the operands keep their left-to-right order.
    #[test]
    fn nested_conjunction_split_test() {
        let expr = and(and(col("a"), col("col1")), col("col2"));
        let conjunction = split_conjunction(&expr);
        assert_eq!(rendered(&conjunction), vec!["$.a", "$.col1", "$.col2"]);
    }

    /// Only `AND` is split; an `OR`, even one nested under an `AND`, is kept whole.
    #[test]
    fn disjunction_is_not_split_test() {
        let expr = or(col("a"), col("col1"));
        assert_eq!(split_conjunction(&expr).len(), 1);

        let expr = and(col("a"), or(col("col1"), col("col2")));
        let conjunction = split_conjunction(&expr);
        assert_eq!(rendered(&conjunction), vec!["$.a", "($.col1 or $.col2)"]);
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            )
            .to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1.clone()).to_string(), "(!$.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "$~{col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}