vortex_array/expr/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
8//! The implementation takes inspiration from [Postgres] and [Apache Datafusion].
9//!
10//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
11//! [Apache Datafusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::hash::Hash;
14use std::hash::Hasher;
15use std::sync::Arc;
16
17use arcref::ArcRef;
18use vortex_dtype::FieldName;
19use vortex_error::VortexUnwrap;
20use vortex_utils::aliases::hash_set::HashSet;
21
22use crate::expr::traversal::NodeExt;
23use crate::expr::traversal::ReferenceCollector;
24
25pub mod aliases;
26pub mod analysis;
27#[cfg(feature = "arbitrary")]
28pub mod arbitrary;
29pub mod display;
30mod expression;
31mod exprs;
32mod field;
33pub mod forms;
34pub mod functions;
35pub mod proto;
36pub mod pruning;
37pub mod session;
38pub mod stats;
39pub mod transform;
40pub mod traversal;
41mod view;
42mod vtable;
43
44pub use analysis::*;
45pub use expression::*;
46pub use exprs::*;
47pub use pruning::StatsCatalog;
48pub use view::*;
49pub use vtable::*;
50
/// Identifier type used by this module: a string held in an `ArcRef<str>`.
// NOTE(review): presumably this is what `Expression::id()` returns — confirm against `expression.rs`.
pub type ExprId = ArcRef<str>;
52
/// Helper trait for collecting the fields an expression refers to.
///
/// This module implements it for [`Expression`].
pub trait VortexExprExt {
    /// Accumulate all field references from this expression and its children in a set.
    fn field_references(&self) -> HashSet<FieldName>;
}
57
58impl VortexExprExt for Expression {
59    fn field_references(&self) -> HashSet<FieldName> {
60        let mut collector = ReferenceCollector::new();
61        // The collector is infallible, so we can unwrap the result
62        self.accept(&mut collector).vortex_unwrap();
63        collector.into_fields()
64    }
65}
66
67/// Splits top level and operations into separate expressions.
68pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
69    let mut conjunctions = vec![];
70    split_inner(expr, &mut conjunctions);
71    conjunctions
72}
73
74fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
75    match expr.as_opt::<Binary>() {
76        Some(bexp) if bexp.operator() == Operator::And => {
77            split_inner(bexp.lhs(), exprs);
78            split_inner(bexp.rhs(), exprs);
79        }
80        Some(_) | None => {
81            exprs.push(expr.clone());
82        }
83    }
84}
85
/// An expression wrapper that performs pointer equality.
///
/// Two wrappers compare equal when the wrapped expressions have the same id and their data
/// `Arc`s point to the same allocation; the expressions are not compared structurally.
#[derive(Clone)]
pub struct ExactExpr(pub Expression);
89
90impl PartialEq for ExactExpr {
91    fn eq(&self, other: &Self) -> bool {
92        self.0.id() == other.0.id() && Arc::ptr_eq(self.0.data(), other.0.data())
93    }
94}
95impl Eq for ExactExpr {}
96
97impl Hash for ExactExpr {
98    fn hash<H: Hasher>(&self, state: &mut H) {
99        self.0.hash(state);
100    }
101}
102
/// Shared fixtures, compiled only with the `test-harness` feature.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    /// Builds a non-nullable struct dtype with five fields:
    ///
    /// - `a`: non-nullable `i32`
    /// - `col1`, `col2`: nullable `u16`
    /// - `bool1`, `bool2`: non-nullable booleans
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"];
        let dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];

        DType::Struct(
            StructFields::new(names.into(), dtypes),
            Nullability::NonNullable,
        )
    }
}
126
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;

    /// An expression that is not an `and` is returned as a single conjunct.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A single top-level `and` is split into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the textual rendering of the expression kinds exercised below.
    #[test]
    fn expr_display() {
        // Column access and the root expression.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");

        // Every binary operator renders infix and wrapped in parentheses.
        let binary_cases = [
            (and(col1.clone(), col2.clone()).to_string(), "($.col1 and $.col2)"),
            (or(col1.clone(), col2.clone()).to_string(), "($.col1 or $.col2)"),
            (eq(col1.clone(), col2.clone()).to_string(), "($.col1 = $.col2)"),
            (not_eq(col1.clone(), col2.clone()).to_string(), "($.col1 != $.col2)"),
            (gt(col1.clone(), col2.clone()).to_string(), "($.col1 > $.col2)"),
            (gt_eq(col1.clone(), col2.clone()).to_string(), "($.col1 >= $.col2)"),
            (lt(col1.clone(), col2.clone()).to_string(), "($.col1 < $.col2)"),
            (lt_eq(col1.clone(), col2.clone()).to_string(), "($.col1 <= $.col2)"),
        ];
        for (rendered, expected) in binary_cases {
            assert_eq!(rendered, expected);
        }

        // Nested binary expressions keep their own parentheses.
        let nested = or(lt(col1.clone(), col2.clone()), not_eq(col1.clone(), col2));
        assert_eq!(nested.to_string(), "(($.col1 < $.col2) or ($.col1 != $.col2))");

        assert_eq!(not(col1).to_string(), "not($.col1)");

        // Field selection and exclusion on the root.
        let two_fields = || vec![FieldName::from("col1"), FieldName::from("col2")];
        let single = select(vec![FieldName::from("col1")], root());
        assert_eq!(single.to_string(), "${col1}");
        assert_eq!(select(two_fields(), root()).to_string(), "${col1, col2}");
        assert_eq!(
            select_exclude(two_fields(), root()).to_string(),
            "${~ col1, col2}"
        );

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0u8)).to_string(), "0u8"),
            (lit(Scalar::from(0.0f32)).to_string(), "0f32"),
            (lit(Scalar::from(i64::MAX)).to_string(), "9223372036854775807i64"),
            (lit(Scalar::from(true)).to_string(), "true"),
            (lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(), "null"),
        ];
        for (rendered, expected) in literal_cases {
            assert_eq!(rendered, expected);
        }

        // Struct literal: fields are rendered by name, in declaration order.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }
}