1use std::hash::Hash;
14use std::hash::Hasher;
15use std::sync::Arc;
16
17use vortex_error::VortexExpect;
18use vortex_utils::aliases::hash_set::HashSet;
19
20use crate::dtype::FieldName;
21use crate::expr::traversal::NodeExt;
22use crate::expr::traversal::ReferenceCollector;
23use crate::scalar_fn::fns::binary::Binary;
24use crate::scalar_fn::fns::operators::Operator;
25
26pub mod aliases;
27pub mod analysis;
28#[cfg(feature = "arbitrary")]
29pub mod arbitrary;
30pub mod display;
31pub(crate) mod expression;
32mod exprs;
33pub(crate) mod field;
34pub mod forms;
35mod optimize;
36pub mod proto;
37pub mod pruning;
38pub mod stats;
39pub mod transform;
40pub mod traversal;
41
42pub use analysis::*;
43pub use expression::*;
44pub use exprs::*;
45pub use pruning::StatsCatalog;
46
47pub trait VortexExprExt {
48 fn field_references(&self) -> HashSet<FieldName>;
50}
51
52impl VortexExprExt for Expression {
53 fn field_references(&self) -> HashSet<FieldName> {
54 let mut collector = ReferenceCollector::new();
55 self.accept(&mut collector)
57 .vortex_expect("reference collector should never fail");
58 collector.into_fields()
59 }
60}
61
62pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
64 let mut conjunctions = vec![];
65 split_inner(expr, &mut conjunctions);
66 conjunctions
67}
68
69fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
70 match expr.as_opt::<Binary>() {
71 Some(operator) if *operator == Operator::And => {
72 split_inner(expr.child(0), exprs);
73 split_inner(expr.child(1), exprs);
74 }
75 Some(_) | None => {
76 exprs.push(expr.clone());
77 }
78 }
79}
80
81#[derive(Clone)]
83pub struct ExactExpr(pub Expression);
84impl PartialEq for ExactExpr {
85 fn eq(&self, other: &Self) -> bool {
86 self.0.scalar_fn() == other.0.scalar_fn()
87 && Arc::ptr_eq(self.0.children(), other.0.children())
88 }
89}
90impl Eq for ExactExpr {}
91
92impl Hash for ExactExpr {
93 fn hash<H: Hasher>(&self, state: &mut H) {
94 self.0.hash(state);
95 }
96}
97
/// Shared fixtures, compiled only when the `_test-harness` feature is enabled.
#[cfg(feature = "_test-harness")]
pub mod test_harness {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;

    /// Builds a non-nullable struct dtype with five fields:
    ///
    /// | field   | dtype               |
    /// |---------|---------------------|
    /// | `a`     | `i32`, non-nullable |
    /// | `col1`  | `u16`, nullable     |
    /// | `col2`  | `u16`, nullable     |
    /// | `bool1` | bool, non-nullable  |
    /// | `bool2` | bool, non-nullable  |
    pub fn struct_dtype() -> DType {
        let nullable_u16 = || DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = || DType::Bool(Nullability::NonNullable);

        // Kept in the same order as the field names passed below.
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16(),
            nullable_u16(),
            required_bool(),
            required_bool(),
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
121
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::scalar::Scalar;

    /// An expression that is not an `AND` splits into exactly one term.
    #[test]
    fn basic_expr_split_test() {
        let expr = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&expr).len(), 1);
    }

    /// A single `AND` splits into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` output of the expression constructors.
    #[test]
    fn expr_display() {
        // Field access and the root reference.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let left: Expression = col("col1");
        let right: Expression = col("col2");

        // Binary operators, paired with their expected rendering.
        let binary_cases = [
            (and(left.clone(), right.clone()), "($.col1 and $.col2)"),
            (or(left.clone(), right.clone()), "($.col1 or $.col2)"),
            (eq(left.clone(), right.clone()), "($.col1 = $.col2)"),
            (not_eq(left.clone(), right.clone()), "($.col1 != $.col2)"),
            (gt(left.clone(), right.clone()), "($.col1 > $.col2)"),
            (gt_eq(left.clone(), right.clone()), "($.col1 >= $.col2)"),
            (lt(left.clone(), right.clone()), "($.col1 < $.col2)"),
            (lt_eq(left.clone(), right.clone()), "($.col1 <= $.col2)"),
        ];
        for (expr, rendered) in binary_cases {
            assert_eq!(expr.to_string(), rendered);
        }

        // Nested binary expressions keep the parentheses of each operand.
        let nested = or(
            lt(left.clone(), right.clone()),
            not_eq(left.clone(), right),
        );
        assert_eq!(
            nested.to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(left).to_string(), "not($.col1)");

        // Field selection, both inclusive and exclusive.
        let both_columns = || vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(both_columns(), root()).to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(both_columns(), root()).to_string(),
            "${~ col1, col2}"
        );

        // Scalar literals.
        let literal_cases = [
            (lit(Scalar::from(0u8)), "0u8"),
            (lit(Scalar::from(0.0f32)), "0f32"),
            (lit(Scalar::from(i64::MAX)), "9223372036854775807i64"),
            (lit(Scalar::from(true)), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                "null",
            ),
        ];
        for (expr, rendered) in literal_cases {
            assert_eq!(expr.to_string(), rendered);
        }

        // Struct literal with one numeric and one string field.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }
}