1use std::hash::Hash;
14use std::hash::Hasher;
15use std::sync::Arc;
16
17use arcref::ArcRef;
18use vortex_dtype::FieldName;
19use vortex_error::VortexExpect;
20use vortex_utils::aliases::hash_set::HashSet;
21
22use crate::expr::traversal::NodeExt;
23use crate::expr::traversal::ReferenceCollector;
24
25pub mod aliases;
26pub mod analysis;
27#[cfg(feature = "arbitrary")]
28pub mod arbitrary;
29pub mod display;
30mod expression;
31mod exprs;
32mod field;
33pub mod forms;
34mod optimize;
35mod options;
36pub mod proto;
37pub mod pruning;
38mod scalar_fn;
39pub mod session;
40mod signature;
41pub mod stats;
42pub mod transform;
43pub mod traversal;
44mod vtable;
45
46pub use analysis::*;
47pub use expression::*;
48pub use exprs::*;
49pub use pruning::StatsCatalog;
50pub use scalar_fn::*;
51pub use vtable::*;
52
/// String identifier type, stored as an [`ArcRef`] over `str`.
///
/// NOTE(review): presumably this names an expression kind/vtable; how ids are
/// created and compared is not visible in this module — confirm before relying on it.
pub type ExprId = ArcRef<str>;
54
/// Extension trait that adds field-reference analysis to an expression type.
///
/// This module implements it for [`Expression`].
pub trait VortexExprExt {
    /// Returns the set of field names gathered from this expression.
    fn field_references(&self) -> HashSet<FieldName>;
}
59
60impl VortexExprExt for Expression {
61 fn field_references(&self) -> HashSet<FieldName> {
62 let mut collector = ReferenceCollector::new();
63 self.accept(&mut collector)
65 .vortex_expect("reference collector should never fail");
66 collector.into_fields()
67 }
68}
69
70pub fn split_conjunction(expr: &Expression) -> Vec<Expression> {
72 let mut conjunctions = vec![];
73 split_inner(expr, &mut conjunctions);
74 conjunctions
75}
76
77fn split_inner(expr: &Expression, exprs: &mut Vec<Expression>) {
78 match expr.as_opt::<Binary>() {
79 Some(operator) if *operator == Operator::And => {
80 split_inner(expr.child(0), exprs);
81 split_inner(expr.child(1), exprs);
82 }
83 Some(_) | None => {
84 exprs.push(expr.clone());
85 }
86 }
87}
88
89#[derive(Clone)]
91pub struct ExactExpr(pub Expression);
92impl PartialEq for ExactExpr {
93 fn eq(&self, other: &Self) -> bool {
94 self.0.scalar_fn() == other.0.scalar_fn()
95 && Arc::ptr_eq(self.0.children(), other.0.children())
96 }
97}
98impl Eq for ExactExpr {}
99
100impl Hash for ExactExpr {
101 fn hash<H: Hasher>(&self, state: &mut H) {
102 self.0.hash(state);
103 }
104}
105
/// Shared fixtures, compiled only when the `test-harness` feature is enabled.
#[cfg(feature = "test-harness")]
pub mod test_harness {
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    /// Builds a non-nullable struct dtype with five fields:
    ///
    /// | field   | dtype                 |
    /// |---------|-----------------------|
    /// | `a`     | `i32`, non-nullable   |
    /// | `col1`  | `u16`, nullable       |
    /// | `col2`  | `u16`, nullable       |
    /// | `bool1` | `bool`, non-nullable  |
    /// | `bool2` | `bool`, non-nullable  |
    pub fn struct_dtype() -> DType {
        // Names and dtypes are positional: entry `i` of one describes entry `i` of the other.
        let field_names = ["a", "col1", "col2", "bool1", "bool2"].into();
        let field_dtypes = vec![
            DType::Primitive(PType::I32, Nullability::NonNullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Primitive(PType::U16, Nullability::Nullable),
            DType::Bool(Nullability::NonNullable),
            DType::Bool(Nullability::NonNullable),
        ];

        let fields = StructFields::new(field_names, field_dtypes);
        DType::Struct(fields, Nullability::NonNullable)
    }
}
129
#[cfg(test)]
mod tests {
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::expr::exprs::binary::and;
    use crate::expr::exprs::binary::eq;
    use crate::expr::exprs::binary::gt;
    use crate::expr::exprs::binary::gt_eq;
    use crate::expr::exprs::binary::lt;
    use crate::expr::exprs::binary::lt_eq;
    use crate::expr::exprs::binary::not_eq;
    use crate::expr::exprs::binary::or;
    use crate::expr::exprs::get_item::col;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::not::not;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::select;
    use crate::expr::exprs::select::select_exclude;

    /// A lone comparison has no top-level `and`, so it is returned as one piece.
    #[test]
    fn basic_expr_split_test() {
        let comparison = eq(get_item("col1", root()), lit(1));
        assert_eq!(split_conjunction(&comparison).len(), 1);
    }

    /// A single `and` node splits into its two operands.
    #[test]
    fn basic_conjunction_split_test() {
        let expr = and(get_item("col1", root()), lit(1));
        let conjunction = split_conjunction(&expr);
        assert_eq!(conjunction.len(), 2, "Conjunction is {conjunction:?}");
    }

    /// Pins the `Display` rendering of each expression kind exercised below.
    #[test]
    fn expr_display() {
        // Field access and the root expression.
        assert_eq!(col("a").to_string(), "$.a");
        assert_eq!(root().to_string(), "$");

        let left: Expression = col("col1");
        let right: Expression = col("col2");

        // Binary operators, including one nested combination.
        let binary_cases = vec![
            (and(left.clone(), right.clone()), "($.col1 and $.col2)"),
            (or(left.clone(), right.clone()), "($.col1 or $.col2)"),
            (eq(left.clone(), right.clone()), "($.col1 = $.col2)"),
            (not_eq(left.clone(), right.clone()), "($.col1 != $.col2)"),
            (gt(left.clone(), right.clone()), "($.col1 > $.col2)"),
            (gt_eq(left.clone(), right.clone()), "($.col1 >= $.col2)"),
            (lt(left.clone(), right.clone()), "($.col1 < $.col2)"),
            (lt_eq(left.clone(), right.clone()), "($.col1 <= $.col2)"),
            (
                or(
                    lt(left.clone(), right.clone()),
                    not_eq(left.clone(), right),
                ),
                "(($.col1 < $.col2) or ($.col1 != $.col2))",
            ),
        ];
        for (expr, rendered) in binary_cases {
            assert_eq!(expr.to_string(), rendered);
        }

        // Unary negation.
        assert_eq!(not(left).to_string(), "not($.col1)");

        // Field selection: include-one, include-many, exclude-many.
        let both_columns = vec![FieldName::from("col1"), FieldName::from("col2")];
        assert_eq!(
            select(vec![FieldName::from("col1")], root()).to_string(),
            "${col1}"
        );
        assert_eq!(
            select(both_columns.clone(), root()).to_string(),
            "${col1, col2}"
        );
        assert_eq!(
            select_exclude(both_columns, root()).to_string(),
            "${~ col1, col2}"
        );

        // Scalar literals.
        let literal_cases = vec![
            (lit(Scalar::from(0u8)), "0u8"),
            (lit(Scalar::from(0.0f32)), "0f32"),
            (lit(Scalar::from(i64::MAX)), "9223372036854775807i64"),
            (lit(Scalar::from(true)), "true"),
            (
                lit(Scalar::null(DType::Bool(Nullability::Nullable))),
                "null",
            ),
        ];
        for (expr, rendered) in literal_cases {
            assert_eq!(expr.to_string(), rendered);
        }

        // Struct literal with a numeric and a string field.
        let pet_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from(["dog", "cat"]),
                vec![
                    DType::Primitive(PType::U32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let pet = Scalar::struct_(
            pet_dtype,
            vec![Scalar::from(32_u32), Scalar::from("rufus".to_string())],
        );
        assert_eq!(lit(pet).to_string(), "{dog: 32u32, cat: \"rufus\"}");
    }
}