vortex_expr/
lib.rs

1use std::any::Any;
2use std::fmt::{Debug, Display};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use dyn_hash::DynHash;
7pub use exprs::*;
8mod analysis;
9#[cfg(feature = "arbitrary")]
10pub mod arbitrary;
11mod exprs;
12mod field;
13pub mod forms;
14pub mod pruning;
15#[cfg(feature = "proto")]
16mod registry;
17mod scope;
18pub mod transform;
19pub mod traversal;
20
21pub use analysis::*;
22pub use between::*;
23pub use binary::*;
24pub use cast::*;
25pub use get_item::*;
26pub use is_null::*;
27pub use like::*;
28pub use list_contains::*;
29pub use literal::*;
30pub use merge::*;
31pub use not::*;
32pub use operators::*;
33pub use pack::*;
34#[cfg(feature = "proto")]
35pub use registry::deserialize_expr;
36pub use scope::*;
37pub use select::*;
38pub use var::*;
39use vortex_array::{Array, ArrayRef};
40use vortex_dtype::{DType, FieldName, FieldPath};
41use vortex_error::{VortexResult, VortexUnwrap};
42#[cfg(feature = "proto")]
43use vortex_proto::expr;
44#[cfg(feature = "proto")]
45use vortex_proto::expr::{Expr, kind};
46use vortex_utils::aliases::hash_set::HashSet;
47
48use crate::traversal::{Node, ReferenceCollector, VarsCollector};
49
/// Shared, reference-counted handle to a type-erased [`VortexExpr`].
pub type ExprRef = Arc<dyn VortexExpr>;
51
/// Provides a static string identifier for an expression kind.
///
/// NOTE(review): presumably this is the same identifier that [`ExprSerializable::id`] writes into
/// a serialized expression, so the deserialization registry can find the matching
/// [`ExprDeserialize`] — confirm against the `registry` module.
#[cfg(feature = "proto")]
pub trait Id {
    /// The static identifier of this expression kind.
    fn id(&self) -> &'static str;
}
56
/// Rebuilds an expression from its serialized protobuf kind.
///
/// The expression's children are passed in already converted to [`ExprRef`]s.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Build an [`ExprRef`] from the protobuf `kind` payload and the node's `children`.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
61
/// Per-node serialization hooks used by [`VortexExprExt::serialize`].
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// The identifier stored in the serialized [`Expr`]'s `id` field.
    fn id(&self) -> &'static str;

    /// Encode this node's own payload as a protobuf kind.
    ///
    /// Children are serialized separately by [`VortexExprExt::serialize`].
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}

// Without the `proto` feature there is nothing to serialize. `ExprSerializable` becomes an empty
// marker trait with a blanket impl for all sized types. This way the `ExprSerializable`
// supertrait bound on `VortexExpr` holds in both configurations.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
73/// Represents logical operation on [`ArrayRef`]s
74pub trait VortexExpr:
75    Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
76{
77    /// Convert expression reference to reference of [`Any`] type
78    fn as_any(&self) -> &dyn Any;
79
80    /// Compute result of expression on given batch producing a new batch
81    fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
82        let result = self.unchecked_evaluate(scope)?;
83        assert_eq!(
84            result.dtype(),
85            &self.return_dtype(&scope.into())?,
86            "Expression {} returned dtype {} but declared return_dtype of {}",
87            self,
88            result.dtype(),
89            self.return_dtype(&scope.into())?,
90        );
91        Ok(result)
92    }
93
94    /// Compute result of expression on given batch producing a new batch
95    ///
96    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
97    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
98    /// includes such an assertion.
99    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
100
101    fn children(&self) -> Vec<&ExprRef>;
102
103    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
104
105    /// Compute the type of the array returned by [VortexExpr::evaluate].
106    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
107}
108
/// Tree-level helpers for expressions, implemented for [`ExprRef`] in this module.
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set
    fn field_references(&self) -> HashSet<FieldName>;

    /// Collect the set of variable [`Identifier`]s referenced within this expression tree.
    fn vars(&self) -> HashSet<Identifier>;

    /// Serialize this expression, recursively including its children, into the protobuf
    /// [`Expr`] representation.
    #[cfg(feature = "proto")]
    fn serialize(&self) -> VortexResult<Expr>;
}
118
119impl VortexExprExt for ExprRef {
120    fn field_references(&self) -> HashSet<FieldName> {
121        let mut collector = ReferenceCollector::new();
122        // The collector is infallible, so we can unwrap the result
123        self.accept(&mut collector).vortex_unwrap();
124        collector.into_fields()
125    }
126
127    fn vars(&self) -> HashSet<Identifier> {
128        let mut collector = VarsCollector::new();
129        // The collector is infallible, so we can unwrap the result
130        self.accept(&mut collector).vortex_unwrap();
131        collector.into_vars()
132    }
133
134    #[cfg(feature = "proto")]
135    fn serialize(&self) -> VortexResult<Expr> {
136        let children = self
137            .children()
138            .iter()
139            .map(|e| e.serialize())
140            .collect::<VortexResult<_>>()?;
141
142        Ok(Expr {
143            id: self.id().to_string(),
144            children,
145            kind: Some(expr::Kind {
146                kind: Some(self.serialize_kind()?),
147            }),
148        })
149    }
150}
151
152#[derive(Clone, Debug, Hash, PartialEq, Eq)]
153pub struct AccessPath {
154    field_path: FieldPath,
155    identifier: Identifier,
156}
157
158impl AccessPath {
159    pub fn root_field(path: FieldName) -> Self {
160        Self {
161            field_path: FieldPath::from_name(path),
162            identifier: Identifier::Identity,
163        }
164    }
165
166    pub fn new(path: FieldPath, identifier: Identifier) -> Self {
167        Self {
168            field_path: path,
169            identifier,
170        }
171    }
172
173    pub fn identifier(&self) -> &Identifier {
174        &self.identifier
175    }
176
177    pub fn field_path(&self) -> &FieldPath {
178        &self.field_path
179    }
180}
181
182/// Splits top level and operations into separate expressions
183pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
184    let mut conjunctions = vec![];
185    split_inner(expr, &mut conjunctions);
186    conjunctions
187}
188
189fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
190    match expr.as_any().downcast_ref::<BinaryExpr>() {
191        Some(bexp) if bexp.op() == Operator::And => {
192            split_inner(bexp.lhs(), exprs);
193            split_inner(bexp.rhs(), exprs);
194        }
195        Some(_) | None => {
196            exprs.push(expr.clone());
197        }
198    }
199}
200
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe equality, used to compare [`VortexExpr`] trait objects.
///
/// [`VortexExpr`] can't be constrained by [`Eq`] directly because it must remain object safe.
/// A blanket implementation covers every `'static` type that implements [`Eq`].
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, which is never equal.
        matches!(other.downcast_ref::<T>(), Some(candidate) if candidate == self)
    }
}
213
214impl PartialEq for dyn VortexExpr {
215    fn eq(&self, other: &Self) -> bool {
216        self.dyn_eq(other.as_any())
217    }
218}
219
220impl Eq for dyn VortexExpr {}
221
// Implements `Hash` for `dyn VortexExpr` trait objects via the `dyn_hash` crate, delegating to
// the `DynHash` supertrait of `VortexExpr`.
dyn_hash::hash_trait_object!(VortexExpr);
223
224/// An expression wrapper that performs pointer equality.
225#[derive(Clone)]
226pub struct ExactExpr(pub ExprRef);
227
228impl PartialEq for ExactExpr {
229    fn eq(&self, other: &Self) -> bool {
230        Arc::ptr_eq(&self.0, &other.0)
231    }
232}
233
234impl Eq for ExactExpr {}
235
236impl Hash for ExactExpr {
237    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
238        Arc::as_ptr(&self.0).hash(state)
239    }
240}
241
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Non-nullable struct dtype shared as a fixture by expression tests.
    ///
    /// Fields, in order: `a: i32`, `col1: u16` (nullable), `col2: u16` (nullable),
    /// `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let non_null_bool = || DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16(),
                nullable_u16(),
                non_null_bool(),
                non_null_bool(),
            ],
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
271
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;

    /// Asserts that `expr` renders through its `Display` impl as exactly `expected`.
    fn assert_display(expr: impl std::fmt::Display, expected: &str) {
        assert_eq!(expr.to_string(), expected);
    }

    #[test]
    fn basic_expr_split_test() {
        // A single comparison has no top-level `and`, so it remains one conjunct.
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        // `lhs and rhs` splits into its two operands.
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn expr_display() {
        assert_display(col("a"), "$.a");
        assert_display(root(), "$");

        let col1: Arc<dyn VortexExpr> = col("col1");
        let col2: Arc<dyn VortexExpr> = col("col2");

        // Binary operators render infix and parenthesized.
        assert_display(and(col1.clone(), col2.clone()), "($.col1 and $.col2)");
        assert_display(or(col1.clone(), col2.clone()), "($.col1 or $.col2)");
        assert_display(eq(col1.clone(), col2.clone()), "($.col1 = $.col2)");
        assert_display(not_eq(col1.clone(), col2.clone()), "($.col1 != $.col2)");
        assert_display(gt(col1.clone(), col2.clone()), "($.col1 > $.col2)");
        assert_display(gt_eq(col1.clone(), col2.clone()), "($.col1 >= $.col2)");
        assert_display(lt(col1.clone(), col2.clone()), "($.col1 < $.col2)");
        assert_display(lt_eq(col1.clone(), col2.clone()), "($.col1 <= $.col2)");

        // Nested binary expressions keep every level parenthesized.
        assert_display(
            or(
                lt(col1.clone(), col2.clone()),
                not_eq(col1.clone(), col2.clone()),
            ),
            "(($.col1 < $.col2) or ($.col1 != $.col2))",
        );

        assert_display(not(col1.clone()), "!$.col1");

        // Field selection and exclusion.
        let both_columns = || vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_display(select(vec![FieldName::from("col1")], root()), "${col1}");
        assert_display(select(both_columns(), root()), "${col1, col2}");
        assert_display(select_exclude(both_columns(), root()), "$~{col1, col2}");

        // Scalar literals.
        assert_display(lit(Scalar::from(0u8)), "0u8");
        assert_display(lit(Scalar::from(0.0f32)), "0f32");
        assert_display(lit(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_display(lit(Scalar::from(true)), "true");
        assert_display(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))),
            "null",
        );

        // Struct literal.
        let pet_dtype = DType::Struct(
            Arc::new(StructFields::new(
                Arc::from([Arc::from("dog"), Arc::from("cat")]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            )),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_display(lit(pet), "{dog: 32u32, cat: \"rufus\"}");
    }

    #[cfg(feature = "proto")]
    mod tests_proto {
        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};

        #[test]
        fn round_trip_serde() {
            let original = eq(root(), lit(1));
            let serialized = original.serialize().unwrap();
            let restored = deserialize_expr(&serialized).unwrap();

            assert_eq!(&original, &restored);
        }
    }
}