1use std::hash::Hash;
43use std::hash::Hasher;
44use std::sync::Arc;
45
46use vortex_error::VortexExpect;
47use vortex_utils::aliases::hash_set::HashSet;
48
49use crate::dtype::FieldName;
50use crate::expr::traversal::NodeExt;
51use crate::expr::traversal::ReferenceCollector;
52use crate::scalar_fn::fns::binary::Binary;
53use crate::scalar_fn::fns::operators::Operator;
54
55pub mod aliases;
56pub mod analysis;
57#[cfg(feature = "arbitrary")]
58pub mod arbitrary;
59pub mod display;
60pub(crate) mod expression;
61mod exprs;
62pub(crate) mod field;
63pub mod forms;
64mod optimize;
65pub mod proto;
66pub mod stats;
67pub mod transform;
68pub mod traversal;
69
70pub use analysis::*;
71pub use expression::*;
72pub use exprs::*;
73
74pub trait VortexExprExt {
75 fn field_references(&self) -> HashSet<FieldName>;
77}
78
79impl VortexExprExt for Expression {
80 fn field_references(&self) -> HashSet<FieldName> {
81 let mut collector = ReferenceCollector::new();
82 self.accept(&mut collector)
84 .vortex_expect("reference collector should never fail");
85 collector.into_fields()
86 }
87}
88
89pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
91 let mut conjunctions = vec![];
92 split_inner(expr, &mut conjunctions);
93 conjunctions
94}
95
96fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
97 match expr.as_opt::<Binary>() {
98 Some(operator) if *operator == Operator::And => {
99 split_inner(expr.child(0), exprs);
100 split_inner(expr.child(1), exprs);
101 }
102 Some(_) | None => {
103 exprs.push(expr.clone());
104 }
105 }
106}
107
108#[derive(Clone)]
110pub struct ExactExpr(pub Expression);
111impl PartialEq for ExactExpr {
112 fn eq(&self, other: &Self) -> bool {
113 self.0.scalar_fn() == other.0.scalar_fn()
114 && Arc::ptr_eq(self.0.children(), other.0.children())
115 }
116}
117impl Eq for ExactExpr {}
118
119impl Hash for ExactExpr {
120 fn hash<H: Hasher>(&self, state: &mut H) {
121 self.0.hash(state);
122 }
123}
124
#[cfg(feature = "_test-harness")]
pub mod test_harness {
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;

    /// Builds the non-nullable struct dtype shared by expression tests:
    /// `a: i32`, `col1: u16?`, `col2: u16?`, `bool1: bool`, `bool2: bool`.
    pub fn struct_dtype() -> DType {
        // The two u16 columns and the two bool columns share a dtype each.
        let nullable_u16 = DType::Primitive(PType::U16, Nullability::Nullable);
        let required_bool = DType::Bool(Nullability::NonNullable);

        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            nullable_u16.clone(),
            nullable_u16,
            required_bool.clone(),
            required_bool,
        ];
        let fields = StructFields::new(
            ["a", "col1", "col2", "bool1", "bool2"].into(),
            field_dtypes,
        );

        DType::Struct(fields, Nullability::NonNullable)
    }
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::scalar::Scalar;

    /// Splits `expr` and renders each conjunct with its `Display` impl, so the
    /// result (including ordering) can be compared against string literals.
    fn split_to_strings(expr: &Expression) -> Vec<String> {
        split_conjunction(expr)
            .iter()
            .map(ToString::to_string)
            .collect()
    }

    #[test]
    fn basic_expr_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = eq(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 1);
    }

    #[test]
    fn basic_conjunction_split_test() {
        let lhs = get_item("col1", root());
        let rhs = lit(1);
        let expr = and(lhs, rhs);
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    #[test]
    fn nested_conjunction_split_preserves_order() {
        // `And` nodes nested on both sides are flattened, left to right.
        let expr = and(
            and(col("a"), col("col1")),
            and(col("col2"), col("bool1")),
        );
        assert_eq!(
            split_to_strings(&expr),
            vec!["$.a", "$.col1", "$.col2", "$.bool1"]
        );
    }

    #[test]
    fn disjunction_is_not_split() {
        // Only `And` is flattened; an `Or` operand stays a single conjunct.
        let expr = and(or(col("a"), col("col1")), col("col2"));
        assert_eq!(
            split_to_strings(&expr),
            vec!["($.a or $.col1)", "$.col2"]
        );

        let top_level_or = or(col("a"), col("col1"));
        assert_eq!(split_to_strings(&top_level_or), vec!["($.a or $.col1)"]);
    }

    #[test]
    fn expr_display() {
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let col1: Expression = col("col1");
        let col2: Expression = col("col2");
        assert_eq!(
            and(col1.clone(), col2.clone()).to_string(),
            "($.col1 and $.col2)"
        );
        assert_eq!(
            or(col1.clone(), col2.clone()).to_string(),
            "($.col1 or $.col2)"
        );
        assert_eq!(
            eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 = $.col2)"
        );
        assert_eq!(
            not_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 != $.col2)"
        );
        assert_eq!(
            gt(col1.clone(), col2.clone()).to_string(),
            "($.col1 > $.col2)"
        );
        assert_eq!(
            gt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 >= $.col2)"
        );
        assert_eq!(
            lt(col1.clone(), col2.clone()).to_string(),
            "($.col1 < $.col2)"
        );
        assert_eq!(
            lt_eq(col1.clone(), col2.clone()).to_string(),
            "($.col1 <= $.col2)"
        );

        assert_eq!(
            or(lt(col1.clone(), col2.clone()), not_eq(col1.clone(), col2),).to_string(),
            "(($.col1 < $.col2) or ($.col1 != $.col2))"
        );

        assert_eq!(not(col1).to_string(), "vortex.not($.col1)");

        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(
                vec![FieldName::from("col1"), FieldName::from("col2")],
                root()
            )
            .to_string(),
            "${~ col1, col2}"
        );

        assert_eq!(lit(Scalar::from(0u8)).to_string(), "0u8");
        assert_eq!(lit(Scalar::from(0.0f32)).to_string(), "0f32");
        assert_eq!(
            lit(Scalar::from(i64::MAX)).to_string(),
            "9223372036854775807i64"
        );
        assert_eq!(lit(Scalar::from(true)).to_string(), "true");
        assert_eq!(
            lit(Scalar::null(DType::Bool(Nullability::Nullable))).to_string(),
            "null"
        );

        assert_eq!(
            lit(Scalar::struct_(
                DType::Struct(
                    StructFields::new(
                        FieldNames::from(["dog", "cat"]),
                        vec![
                            DType::Primitive(PType::U32, Nullability::NonNullable),
                            DType::Utf8(Nullability::NonNullable)
                        ],
                    ),
                    Nullability::NonNullable
                ),
                vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())]
            ))
            .to_string(),
            "{dog: 32u32, cat: \"rufus\"}"
        );
    }
}