vortex_expr/
lib.rs

1#![feature(option_zip)]
2
3use std::any::Any;
4use std::fmt::{Debug, Display};
5use std::sync::Arc;
6
7use dyn_hash::DynHash;
8
9mod binary;
10
11mod analysis;
12mod between;
13mod cast;
14mod field;
15pub mod forms;
16mod get_item;
17mod is_null;
18mod let_;
19mod like;
20mod list_contains;
21mod literal;
22mod merge;
23mod not;
24mod operators;
25mod pack;
26pub mod pruning;
27#[cfg(feature = "proto")]
28mod registry;
29mod scope;
30mod select;
31pub mod transform;
32pub mod traversal;
33mod var;
34
35pub use analysis::*;
36pub use between::*;
37pub use binary::*;
38pub use cast::*;
39pub use get_item::*;
40pub use is_null::*;
41pub use let_::*;
42pub use like::*;
43pub use list_contains::*;
44pub use literal::*;
45pub use merge::*;
46pub use not::*;
47pub use operators::*;
48pub use pack::*;
49#[cfg(feature = "proto")]
50pub use registry::deserialize_expr;
51pub use scope::*;
52pub use select::*;
53pub use var::*;
54use vortex_array::{Array, ArrayRef};
55use vortex_dtype::{DType, FieldName, FieldPath};
56use vortex_error::{VortexResult, VortexUnwrap};
57#[cfg(feature = "proto")]
58use vortex_proto::expr;
59#[cfg(feature = "proto")]
60use vortex_proto::expr::{Expr, kind};
61use vortex_utils::aliases::hash_set::HashSet;
62
63use crate::traversal::{Node, ReferenceCollector};
64
65pub type ExprRef = Arc<dyn VortexExpr>;
66
67#[cfg(feature = "proto")]
68pub trait Id {
69    fn id(&self) -> &'static str;
70}
71
/// Builds an expression from a serialized `kind` payload and a list of child expressions.
///
/// Implementors must also implement [`Id`] and be [`Sync`]. Only compiled when the `proto`
/// feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprDeserialize: Id + Sync {
    /// Constructs an [`ExprRef`] from `kind` and `children`, or returns an error.
    fn deserialize(&self, kind: &kind::Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef>;
}
76
/// Serialization hooks required of every [`VortexExpr`] when the `proto` feature is enabled.
#[cfg(feature = "proto")]
pub trait ExprSerializable {
    /// Returns the static string identifier written into the serialized expression.
    fn id(&self) -> &'static str;

    /// Produces the `kind` payload for this expression, or an error if that is not possible.
    fn serialize_kind(&self) -> VortexResult<kind::Kind>;
}
83
/// Method-less stand-in for `ExprSerializable` used when the `proto` feature is disabled.
#[cfg(not(feature = "proto"))]
pub trait ExprSerializable {}

/// Blanket impl: without `proto`, every (sized) type satisfies `ExprSerializable`.
#[cfg(not(feature = "proto"))]
impl<T> ExprSerializable for T {}
88/// Represents logical operation on [`ArrayRef`]s
89pub trait VortexExpr:
90    Debug + Send + Sync + DynEq + DynHash + Display + ExprSerializable + AnalysisExpr
91{
92    /// Convert expression reference to reference of [`Any`] type
93    fn as_any(&self) -> &dyn Any;
94
95    /// Compute result of expression on given batch producing a new batch
96    fn evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
97        let result = self.unchecked_evaluate(scope)?;
98        assert_eq!(
99            result.dtype(),
100            &self.return_dtype(&scope.into())?,
101            "Expression {} returned dtype {} but declared return_dtype of {}",
102            self,
103            result.dtype(),
104            self.return_dtype(&scope.into())?,
105        );
106        Ok(result)
107    }
108
109    /// Compute result of expression on given batch producing a new batch
110    ///
111    /// "Unchecked" means that this function lacks a debug assertion that the returned array matches
112    /// the [VortexExpr::return_dtype] method. Use instead the [VortexExpr::evaluate] function which
113    /// includes such an assertion.
114    fn unchecked_evaluate(&self, ctx: &Scope) -> VortexResult<ArrayRef>;
115
116    fn children(&self) -> Vec<&ExprRef>;
117
118    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef;
119
120    /// Compute the type of the array returned by [VortexExpr::evaluate].
121    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType>;
122}
123
124pub trait VortexExprExt {
125    /// Accumulate all field references from this expression and its children in a set
126    fn references(&self) -> HashSet<FieldName>;
127
128    #[cfg(feature = "proto")]
129    fn serialize(&self) -> VortexResult<Expr>;
130}
131
132impl VortexExprExt for ExprRef {
133    fn references(&self) -> HashSet<FieldName> {
134        let mut collector = ReferenceCollector::new();
135        // The collector is infallible, so we can unwrap the result
136        self.accept(&mut collector).vortex_unwrap();
137        collector.into_fields()
138    }
139
140    #[cfg(feature = "proto")]
141    fn serialize(&self) -> VortexResult<Expr> {
142        let children = self
143            .children()
144            .iter()
145            .map(|e| e.serialize())
146            .collect::<VortexResult<_>>()?;
147
148        Ok(Expr {
149            id: self.id().to_string(),
150            children,
151            kind: Some(expr::Kind {
152                kind: Some(self.serialize_kind()?),
153            }),
154        })
155    }
156}
157
158#[derive(Clone, Debug, Hash, PartialEq, Eq)]
159pub struct AccessPath {
160    field_path: FieldPath,
161    identifier: Identifier,
162}
163
164impl AccessPath {
165    pub fn root_field(path: FieldName) -> Self {
166        Self {
167            field_path: FieldPath::from_name(path),
168            identifier: Identifier::Identity,
169        }
170    }
171
172    pub fn new(path: FieldPath, identifier: Identifier) -> Self {
173        Self {
174            field_path: path,
175            identifier,
176        }
177    }
178
179    pub fn identifier(&self) -> &Identifier {
180        &self.identifier
181    }
182
183    pub fn field_path(&self) -> &FieldPath {
184        &self.field_path
185    }
186}
187
188/// Splits top level and operations into separate expressions
189pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
190    let mut conjunctions = vec![];
191    split_inner(expr, &mut conjunctions);
192    conjunctions
193}
194
195fn split_inner(expr: &ExprRef, exprs: &mut Vec<ExprRef>) {
196    match expr.as_any().downcast_ref::<BinaryExpr>() {
197        Some(bexp) if bexp.op() == Operator::And => {
198            split_inner(bexp.lhs(), exprs);
199            split_inner(bexp.rhs(), exprs);
200        }
201        Some(_) | None => {
202            exprs.push(expr.clone());
203        }
204    }
205}
206
// Adapted from apache/datafusion https://github.com/apache/datafusion/blob/f31ca5b927c040ce03f6a3c8c8dc3d7f4ef5be34/datafusion/physical-expr-common/src/physical_expr.rs#L156
/// Object-safe equality against a type-erased value.
///
/// [`VortexExpr`] can't be constrained by [`Eq`] directly because it must remain object safe.
/// To ease implementation, a blanket implementation is provided for all [`Eq`] types.
pub trait DynEq {
    /// Returns `true` when `other` has the same concrete type as `self` and compares equal to it.
    fn dyn_eq(&self, other: &dyn Any) -> bool;
}

impl<T: Eq + Any> DynEq for T {
    fn dyn_eq(&self, other: &dyn Any) -> bool {
        // A failed downcast means the concrete types differ, so the values are unequal.
        other
            .downcast_ref::<Self>()
            .is_some_and(|same_type| same_type == self)
    }
}
219
220impl PartialEq for dyn VortexExpr {
221    fn eq(&self, other: &Self) -> bool {
222        self.dyn_eq(other.as_any())
223    }
224}
225
226impl Eq for dyn VortexExpr {}
227
228dyn_hash::hash_trait_object!(VortexExpr);
229
/// Helpers for tests; compiled only when the `test-harness` feature is enabled.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use std::sync::Arc;

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Builds a non-nullable struct dtype with the fields `a` (non-nullable `i32`), `col1` and
    /// `col2` (nullable `u16`), and `bool1` and `bool2` (non-nullable bool).
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = || DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            [
                "a".into(),
                "col1".into(),
                "col2".into(),
                "bool1".into(),
                "bool2".into(),
            ]
            .into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16(),
                nullable_u16(),
                required_bool(),
                required_bool(),
            ],
        );

        DType::Struct(Arc::new(fields), Nullability::NonNullable)
    }
}
259
260#[cfg(test)]
261mod tests {
262    use vortex_dtype::{DType, Nullability, PType, StructFields};
263    use vortex_scalar::Scalar;
264
265    use super::*;
266
267    #[test]
268    fn basic_expr_split_test() {
269        let lhs = get_item("col1", root());
270        let rhs = lit(1);
271        let expr = eq(lhs, rhs);
272        let conjunction = split_conjunction(&expr);
273        assert_eq!(conjunction.len(), 1);
274    }
275
276    #[test]
277    fn basic_conjunction_split_test() {
278        let lhs = get_item("col1", root());
279        let rhs = lit(1);
280        let expr = and(lhs, rhs);
281        let conjunction = split_conjunction(&expr);
282        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
283    }
284
285    #[test]
286    fn expr_display() {
287        assert_eq!(col("a").to_string(), "$.a");
288        assert_eq!(root().to_string(), "$");
289
290        let col1: Arc<dyn VortexExpr> = col("col1");
291        let col2: Arc<dyn VortexExpr> = col("col2");
292        assert_eq!(
293            and(col1.clone(), col2.clone()).to_string(),
294            "($.col1 and $.col2)"
295        );
296        assert_eq!(
297            or(col1.clone(), col2.clone()).to_string(),
298            "($.col1 or $.col2)"
299        );
300        assert_eq!(
301            eq(col1.clone(), col2.clone()).to_string(),
302            "($.col1 = $.col2)"
303        );
304        assert_eq!(
305            not_eq(col1.clone(), col2.clone()).to_string(),
306            "($.col1 != $.col2)"
307        );
308        assert_eq!(
309            gt(col1.clone(), col2.clone()).to_string(),
310            "($.col1 > $.col2)"
311        );
312        assert_eq!(
313            gt_eq(col1.clone(), col2.clone()).to_string(),
314            "($.col1 >= $.col2)"
315        );
316        assert_eq!(
317            lt(col1.clone(), col2.clone()).to_string(),
318            "($.col1 < $.col2)"
319        );
320        assert_eq!(
321            lt_eq(col1.clone(), col2.clone()).to_string(),
322            "($.col1 <= $.col2)"
323        );
324
325        assert_eq!(
326            or(
327                lt(col1.clone(), col2.clone()),
328                not_eq(col1.clone(), col2.clone()),
329            )
330            .to_string(),
331            "(($.col1 < $.col2) or ($.col1 != $.col2))"
332        );
333
334        assert_eq!(not(col1.clone()).to_string(), "!$.col1");
335
336        assert_eq!(
337            select(vec![FieldName::from("col1")], root()).to_string(),
338            "${col1}"
339        );
340        assert_eq!(
341            select(
342                vec![FieldName::from("col1"), FieldName::from("col2")],
343                root()
344            )
345            .to_string(),
346            "${col1, col2}"
347        );
348        assert_eq!(
349            select_exclude(
350                vec![FieldName::from("col1"), FieldName::from("col2")],
351                root()
352            )
353            .to_string(),
354            "$~{col1, col2}"
355        );
356
357        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
358        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
359        assert_eq!(
360            lit(Scalar::from(i64::MAX)).to_string(),
361            "9223372036854775807i64"
362        );
363        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
364        assert_eq!(
365            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
366            "null"
367        );
368
369        assert_eq!(
370            lit(Scalar::struct_(
371                DType::Struct(
372                    Arc::new(StructFields::new(
373                        Arc::from([Arc::from("dog"), Arc::from("cat")]),
374                        vec![
375                            DType::Primitive(PType::U32, Nullability::NonNullable),
376                            DType::Utf8(Nullability::NonNullable)
377                        ],
378                    )),
379                    Nullability::NonNullable
380                ),
381                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
382            ))
383            .to_string(),
384            "{dog: 32u32, cat: \"rufus\"}"
385        );
386    }
387
388    #[cfg(feature = "proto")]
389    mod tests_proto {
390
391        use crate::{VortexExprExt, deserialize_expr, eq, lit, root};
392
393        #[test]
394        fn round_trip_serde() {
395            let expr = eq(root(), lit(1));
396            let res = expr.serialize().unwrap();
397            let final_ = deserialize_expr(&res).unwrap();
398
399            assert_eq!(&expr, &final_);
400        }
401    }
402}