Skip to main content

vortex_array/expr/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::hash::Hash;
14use std::hash::Hasher;
15use std::sync::Arc;
16
17use vortex_error::VortexExpect;
18use vortex_utils::aliases::hash_set::HashSet;
19
20use crate::dtype::FieldName;
21use crate::expr::traversal::NodeExt;
22use crate::expr::traversal::ReferenceCollector;
23use crate::scalar_fn::fns::binary::Binary;
24use crate::scalar_fn::fns::operators::Operator;
25
26pub mod aliases;
27pub mod analysis;
28#[cfg(feature = "arbitrary")]
29pub mod arbitrary;
30pub mod display;
31pub(crate) mod expression;
32mod exprs;
33pub(crate) mod field;
34pub mod forms;
35mod optimize;
36pub mod proto;
37pub mod pruning;
38pub mod stats;
39pub mod transform;
40pub mod traversal;
41
42pub use analysis::*;
43pub use expression::*;
44pub use exprs::*;
45pub use pruning::StatsCatalog;
46
47pub trait VortexExprExt {
48    /// Accumulate all field references from this expression and its children in a set
49    fn field_references(&self) -> HashSet<FieldName>;
50}
51
52impl VortexExprExt for Expression {
53    fn field_references(&self) -> HashSet<FieldName> {
54        let mut collector = ReferenceCollector::new();
55        // The collector is infallible, so we can unwrap the result
56        self.accept(&mut collector)
57            .vortex_expect("reference collector should never fail");
58        collector.into_fields()
59    }
60}
61
62/// Splits top level and operations into separate expressions.
63pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
64    let mut conjunctions = vec![];
65    split_inner(expr, &mut conjunctions);
66    conjunctions
67}
68
69fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
70    match expr.as_opt::<Binary>() {
71        Some(operator) if *operator == Operator::And => {
72            split_inner(expr.child(0), exprs);
73            split_inner(expr.child(1), exprs);
74        }
75        Some(_) | None => {
76            exprs.push(expr.clone());
77        }
78    }
79}
80
81/// An expression wrapper that performs pointer equality on child expressions.
82#[derive(Clone)]
83pub struct ExactExpr(pub Expression);
84impl PartialEq for ExactExpr {
85    fn eq(&self, other: &Self) -> bool {
86        self.0.scalar_fn() == other.0.scalar_fn()
87            && Arc::ptr_eq(self.0.children(), other.0.children())
88    }
89}
90impl Eq for ExactExpr {}
91
92impl Hash for ExactExpr {
93    fn hash<H: Hasher>(&self, state: &mut H) {
94        self.0.hash(state);
95    }
96}
97
#[cfg(feature = "_test-harness")]
pub mod test_harness {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;

    /// Builds the non-nullable struct dtype shared by expression tests.
    ///
    /// Fields:
    /// - `a`: non-nullable `i32`
    /// - `col1` and `col2`: nullable `u16`
    /// - `bool1` and `bool2`: non-nullable bool
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"];
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = DType::Bool(Nullability::NonNullable);
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            required_bool.clone(),
            required_bool,
        ];
        DType::Struct(
            StructFields::new(names.into(), field_dtypes),
            Nullability::NonNullable,
        )
    }
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::scalar::Scalar;

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Conjunctions nested on either side are flattened, and the resulting
    /// conjuncts keep their left-to-right order.
    #[test]
    fn nested_conjunction_split_preserves_order() {
        let left_nested = and(and(col("a"), col("col1")), col("col2"));
        let right_nested = and(col("a"), and(col("col1"), col("col2")));
        for expr in [left_nested, right_nested] {
            let conjuncts: Vec<String> = split_conjunction(&expr)
                .iter()
                .map(ToString::to_string)
                .collect();
            assert_eq!(conjuncts, ["$.a", "$.col1", "$.col2"]);
        }
    }

    /// Only `AND` is split: an `OR`, even one wrapping a conjunction, stays whole.
    #[test]
    fn disjunction_is_not_split() {
        let expr = or(and(col("a"), col("col1")), col("col2"));
        let conjuncts = split_conjunction(&expr);
        assert_eq!(conjuncts.len(), 1, "Conjunction is {conjuncts:?}");
        assert_eq!(conjuncts[0].to_string(), expr.to_string());
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1.clone(), col2),).to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1).to_string(), "not($.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${~ col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}