vortex_array/expr/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Vortex's expression language.
5//!
6//! All expressions are serializable, and own their own wire format.
7//!
//! The implementation takes inspiration from [Postgres] and [Apache DataFusion].
//!
//! [Postgres]: https://www.postgresql.org/docs/current/sql-expressions.html
//! [Apache DataFusion]: https://github.com/apache/datafusion/tree/5fac581efbaffd0e6a9edf931182517524526afd/datafusion/expr
12
13use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16use arcref::ArcRef;
17use vortex_dtype::FieldName;
18use vortex_error::VortexUnwrap;
19use vortex_utils::aliases::hash_set::HashSet;
20
21use crate::expr::traversal::{NodeExt, ReferenceCollector};
22
23pub mod aliases;
24mod analysis;
25#[cfg(feature = "arbitrary")]
26pub mod arbitrary;
27pub mod display;
28mod expression;
29mod exprs;
30mod field;
31pub mod forms;
32pub mod proto;
33pub mod pruning;
34pub mod session;
35pub mod transform;
36pub mod traversal;
37mod view;
38mod vtable;
39
40pub use analysis::*;
41pub use expression::*;
42pub use exprs::*;
43pub use view::*;
44pub use vtable::*;
45
46pub type ExprId = ArcRef<str>;
47
48pub trait VortexExprExt {
49    /// Accumulate all field references from this expression and its children in a set
50    fn field_references(&self) -> HashSet<FieldName>;
51}
52
53impl VortexExprExt for Expression {
54    fn field_references(&self) -> HashSet<FieldName> {
55        let mut collector = ReferenceCollector::new();
56        // The collector is infallible, so we can unwrap the result
57        self.accept(&mut collector).vortex_unwrap();
58        collector.into_fields()
59    }
60}
61
62/// Splits top level and operations into separate expressions.
63pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
64    let mut conjunctions = vec![];
65    split_inner(expr, &mut conjunctions);
66    conjunctions
67}
68
69fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
70    match expr.as_opt::<Binary>() {
71        Some(bexp) if bexp.operator() == Operator::And => {
72            split_inner(bexp.lhs(), exprs);
73            split_inner(bexp.rhs(), exprs);
74        }
75        Some(_) | None => {
76            exprs.push(expr.clone());
77        }
78    }
79}
80
81/// An expression wrapper that performs pointer equality.
82#[derive(Clone)]
83pub struct ExactExpr(pub Expression);
84
85impl PartialEq for ExactExpr {
86    fn eq(&self, other: &Self) -> bool {
87        self.0.id() == other.0.id() && Arc::ptr_eq(self.0.data(), other.0.data())
88    }
89}
90impl Eq for ExactExpr {}
91
92impl Hash for ExactExpr {
93    fn hash<H: Hasher>(&self, state: &mut H) {
94        self.0.hash(state);
95    }
96}
97
#[cfg(feature = "test-harness")]
pub mod test_harness {
    //! Fixtures made available behind the `test-harness` feature.

    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Builds a non-nullable struct dtype with the fields
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`
    /// (`?` marks a nullable field).
    pub fn struct_dtype() -> DType {
        let names = ["a", "col1", "col2", "bool1", "bool2"];
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_null_bool = DType::Bool(Nullability::NonNullable);
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            non_null_bool.clone(),
            non_null_bool,
        ];
        DType::Struct(
            StructFields::new(names.into(), field_dtypes),
            Nullability::NonNullable,
        )
    }
}
118
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;
    use crate::expr::exprs::binary::{and, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::{select, select_exclude};

    /// An expression whose root is not an `and` is kept whole.
    #[test]
    fn basic_expr_split_test() {
        let predicate = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&predicate).len(), 1);
    }

    /// A single top-level `and` yields both of its operands.
    #[test]
    fn basic_conjunction_split_test() {
        let predicate = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&predicate);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Checks the `Display` rendering of each kind of expression.
    #[test]
    fn expr_display() {
        // Column references and the root.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        // Binary operators render infix and fully parenthesised.
        let c1: Expression = col("col1");
        let c2: Expression = col("col2");
        assert_eq!(and(c1.clone(), c2.clone()).to_string(), "($.col1 and $.col2)");
        assert_eq!(or(c1.clone(), c2.clone()).to_string(), "($.col1 or $.col2)");
        assert_eq!(eq(c1.clone(), c2.clone()).to_string(), "($.col1 = $.col2)");
        assert_eq!(not_eq(c1.clone(), c2.clone()).to_string(), "($.col1 != $.col2)");
        assert_eq!(gt(c1.clone(), c2.clone()).to_string(), "($.col1 > $.col2)");
        assert_eq!(gt_eq(c1.clone(), c2.clone()).to_string(), "($.col1 >= $.col2)");
        assert_eq!(lt(c1.clone(), c2.clone()).to_string(), "($.col1 < $.col2)");
        assert_eq!(lt_eq(c1.clone(), c2.clone()).to_string(), "($.col1 <= $.col2)");

        let nested = or(lt(c1.clone(), c2.clone()), not_eq(c1.clone(), c2));
        assert_eq!(nested.to_string(), "(($.col1 < $.col2) or ($.col1 != $.col2))");

        // Unary negation.
        assert_eq!(not(c1).to_string(), "not($.col1)");

        // Projections.
        assert_eq!(select(vec![FieldName::from("col1")], root()).to_string(), "${col1}");
        let both = [FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(select(both.to_vec(), root()).to_string(), "${col1, col2}");
        assert_eq!(select_exclude(both.to_vec(), root()).to_string(), "${~ col1, col2}");

        // Scalar literals.
        let show = |value: Scalar| lit(value).to_string();
        assert_eq!(show(Scalar::from(0u8)), "0u8");
        assert_eq!(show(Scalar::from(0.0f32)), "0f32");
        assert_eq!(show(Scalar::from(i64::MAX)), "9223372036854775807i64");
        assert_eq!(show(Scalar::from(true)), "true");
        assert_eq!(show(Scalar::null(DType::Bool(Nullability::Nullable))), "null");

        // Struct literals list their fields in declaration order.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(show(pet), "{dog: 32u32, cat: \"rufus\"}");
    }
}