vortex_array/expr/
mod.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Vortex's expression language.
//!
//! All expressions are serializable, and each one owns its wire format.
//!
//! The implementation takes inspiration from [Postgres] and [Apache DataFusion].
//!
//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
//! [Apache DataFusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::hash::Hash;
14use std::hash::Hasher;
15use std::sync::Arc;
16
17use arcref::ArcRef;
18use vortex_dtype::FieldName;
19use vortex_error::VortexExpect;
20use vortex_utils::aliases::hash_set::HashSet;
21
22use crate::expr::traversal::NodeExt;
23use crate::expr::traversal::ReferenceCollector;
24
25pub mod aliases;
26pub mod analysis;
27#[cfg(feature = "arbitrary")]
28pub mod arbitrary;
29pub mod display;
30mod expression;
31mod exprs;
32mod field;
33pub mod forms;
34mod optimize;
35mod options;
36pub mod proto;
37pub mod pruning;
38mod scalar_fn;
39pub mod session;
40mod signature;
41pub mod stats;
42pub mod transform;
43pub mod traversal;
44mod vtable;
45
46pub use analysis::*;
47pub use expression::*;
48pub use exprs::*;
49pub use pruning::StatsCatalog;
50pub use scalar_fn::*;
51pub use vtable::*;
52
53pub type ExprId = ArcRef<str>;
54
55pub trait VortexExprExt {
56    /// Accumulate all field references from this expression and its children in a set
57    fn field_references(&self) -> HashSet<FieldName>;
58}
59
60impl VortexExprExt for Expression {
61    fn field_references(&self) -> HashSet<FieldName> {
62        let mut collector = ReferenceCollector::new();
63        // The collector is infallible, so we can unwrap the result
64        self.accept(&mut collector)
65            .vortex_expect("reference collector should never fail");
66        collector.into_fields()
67    }
68}
69
70/// Splits top level and operations into separate expressions.
71pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
72    let mut conjunctions = vec![];
73    split_inner(expr, &mut conjunctions);
74    conjunctions
75}
76
77fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
78    match expr.as_opt::<Binary>() {
79        Some(operator) if *operator == Operator::And => {
80            split_inner(expr.child(0), exprs);
81            split_inner(expr.child(1), exprs);
82        }
83        Some(_) | None => {
84            exprs.push(expr.clone());
85        }
86    }
87}
88
89/// An expression wrapper that performs pointer equality on child expressions.
90#[derive(Clone)]
91pub struct ExactExpr(pub Expression);
92impl PartialEq for ExactExpr {
93    fn eq(&self, other: &Self) -> bool {
94        self.0.scalar_fn() == other.0.scalar_fn()
95            && Arc::ptr_eq(self.0.children(), other.0.children())
96    }
97}
98impl Eq for ExactExpr {}
99
100impl Hash for ExactExpr {
101    fn hash<H: Hasher>(&self, state: &mut H) {
102        self.0.hash(state);
103    }
104}
105
/// Shared test fixtures, compiled only when the `test-harness` feature is enabled.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    /// Returns a non-nullable struct dtype with five fields:
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`
    /// (where `?` marks a nullable field).
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_null_bool = DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16.clone(),
                nullable_u16,
                non_null_bool.clone(),
                non_null_bool,
            ],
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
129
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Nested `and` nodes are flattened, and conjuncts come back in left-to-right order.
    #[test]
    fn nested_conjunction_split_test() {
        let expr = and(
            and(col("a"), col("col1")),
            and(col("col2"), col("bool1")),
        );
        let conjunction: Vec<String> = split_conjunction(&expr)
            .iter()
            .map(ToString::to_string)
            .collect();
        assert_eq!(conjunction, ["$.a", "$.col1", "$.col2", "$.bool1"]);
    }

    /// Only top-level `and`s are split; an `and` nested beneath an `or` is left intact.
    #[test]
    fn conjunction_under_disjunction_is_not_split() {
        let expr = or(and(col("a"), col("col1")), col("col2"));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1, "Conjunction is {conjunction:?}");
        assert_eq!(conjunction[0].to_string(), "(($.a and $.col1) or $.col2)");
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1.clone(), col2)).to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1).to_string(), "not($.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${~ col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}
272}