1use std::hash::{Hash, Hasher};
14use std::sync::Arc;
15
16use arcref::ArcRef;
17use vortex_dtype::FieldName;
18use vortex_error::VortexUnwrap;
19use vortex_utils::aliases::hash_set::HashSet;
20
21use crate::expr::traversal::{NodeExt, ReferenceCollector};
22
23pub mod aliases;
24mod analysis;
25#[cfg(feature = "arbitrary")]
26pub mod arbitrary;
27pub mod display;
28mod expression;
29mod exprs;
30mod field;
31pub mod forms;
32pub mod proto;
33pub mod pruning;
34pub mod session;
35pub mod transform;
36pub mod traversal;
37mod view;
38mod vtable;
39
40pub use analysis::*;
41pub use expression::*;
42pub use exprs::*;
43pub use view::*;
44pub use vtable::*;
45
/// String identifier used for expressions, stored as an `ArcRef<str>`.
// NOTE(review): presumably this is the type returned by `Expression::id()` — confirm.
pub type ExprId = ArcRef<str>;
47
48pub trait VortexExprExt {
49 fn field_references(&self) -> HashSet<FieldName>;
51}
52
53impl VortexExprExt for Expression {
54 fn field_references(&self) -> HashSet<FieldName> {
55 let mut collector = ReferenceCollector::new();
56 self.accept(&mut collector).vortex_unwrap();
58 collector.into_fields()
59 }
60}
61
62pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
64 let mut conjunctions = vec![];
65 split_inner(expr, &mut conjunctions);
66 conjunctions
67}
68
69fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
70 match expr.as_opt::<Binary>() {
71 Some(bexp) if bexp.operator() == Operator::And => {
72 split_inner(bexp.lhs(), exprs);
73 split_inner(bexp.rhs(), exprs);
74 }
75 Some(_) | None => {
76 exprs.push(expr.clone());
77 }
78 }
79}
80
81#[derive(Clone)]
83pub struct ExactExpr(pub Expression);
84
85impl PartialEq for ExactExpr {
86 fn eq(&self, other: &Self) -> bool {
87 self.0.id() == other.0.id() && Arc::ptr_eq(self.0.data(), other.0.data())
88 }
89}
90impl Eq for ExactExpr {}
91
92impl Hash for ExactExpr {
93 fn hash<H: Hasher>(&self, state: &mut H) {
94 self.0.hash(state);
95 }
96}
97
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::{DType, Nullability, PType, StructFields};

    /// Returns a non-nullable struct dtype for use as a test fixture.
    ///
    /// Fields, in order:
    /// - `a`: non-nullable `i32`
    /// - `col1`, `col2`: nullable `u16`
    /// - `bool1`, `bool2`: non-nullable `bool`
    pub fn struct_dtype() -> DType {
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let non_nullable_bool = DType::Bool(Nullability::NonNullable);

        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            vec![
                DType::Primitive(PType::I32, Nullability::NonNullable),
                nullable_u16.clone(),
                nullable_u16,
                non_nullable_bool.clone(),
                non_nullable_bool,
            ],
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
118
#[cfg(test)]
mod tests {
    use vortex_dtype::{DType, FieldNames, Nullability, PType, StructFields};
    use vortex_scalar::Scalar;

    use super::*;
    use crate::expr::exprs::binary::{and, eq, gt, gt_eq, lt, lt_eq, not_eq, or};
    use crate::expr::exprs::get_item::{col, get_item};
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::{select, select_exclude};

    /// Splits `expr` and renders each conjunct with `Display`. Results can then be compared
    /// as strings, relying only on the formatting already pinned down by `expr_display`.
    fn conjunct_strings(expr: &Expression) -> Vec<String> {
        split_conjunction(expr)
            .into_iter()
            .map(|conjunct| conjunct.to_string())
            .collect()
    }

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn nested_conjunction_split_preserves_order() {
        // Left-nested: ((a and b) and c)
        let left = and(and(col("a"), col("b")), col("c"));
        assert_eq!(conjunct_strings(&left), ["$.a", "$.b", "$.c"]);

        // Right-nested: (a and (b and c))
        let right = and(col("a"), and(col("b"), col("c")));
        assert_eq!(conjunct_strings(&right), ["$.a", "$.b", "$.c"]);
    }

    #[test]
    fn non_and_operators_are_not_split() {
        // A top-level OR is a single conjunct.
        assert_eq!(conjunct_strings(&or(col("a"), col("b"))), ["($.a or $.b)"]);

        // An AND beneath an OR is left intact; only the top-level AND chain is split.
        let mixed = and(col("a"), or(col("b"), and(col("c"), col("d"))));
        assert_eq!(conjunct_strings(&mixed), ["$.a", "($.b or ($.c and $.d))"]);
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1.clone(), col2),).to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1).to_string(), "not($.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${~ col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}