1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::compute::Operator;
22use crate::expr::Arity;
23use crate::expr::Binary;
24use crate::expr::ChildName;
25use crate::expr::ExecutionArgs;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::StatsCatalog;
29use crate::expr::VTable;
30use crate::expr::VTableExt;
31use crate::expr::traversal::NodeExt;
32use crate::expr::traversal::NodeVisitor;
33use crate::expr::traversal::TraversalOrder;
34use crate::scalar::Scalar;
35use crate::scalar::ScalarValue;
36
/// Expression VTable for comparing a single child (`lhs`) against a right-hand side whose
/// value is obtained by calling a closure at execution time rather than being fixed when the
/// expression is built. See [`dynamic`] for the constructor.
pub struct DynamicComparison;
41
42impl VTable for DynamicComparison {
43 type Options = DynamicComparisonExpr;
44
45 fn id(&self) -> ExprId {
46 ExprId::new_ref("vortex.dynamic")
47 }
48
49 fn arity(&self, _options: &Self::Options) -> Arity {
50 Arity::Exact(1)
51 }
52
53 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
54 match child_idx {
55 0 => ChildName::from("lhs"),
56 _ => unreachable!(),
57 }
58 }
59
60 fn fmt_sql(
61 &self,
62 dynamic: &DynamicComparisonExpr,
63 expr: &Expression,
64 f: &mut Formatter<'_>,
65 ) -> std::fmt::Result {
66 expr.child(0).fmt_sql(f)?;
67 write!(f, " {} dynamic(", dynamic)?;
68 match dynamic.scalar() {
69 None => write!(f, "<none>")?,
70 Some(scalar) => write!(f, "{}", scalar)?,
71 }
72 write!(f, ")")
73 }
74
75 fn return_dtype(
76 &self,
77 dynamic: &DynamicComparisonExpr,
78 arg_dtypes: &[DType],
79 ) -> VortexResult<DType> {
80 let lhs = &arg_dtypes[0];
81 if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
82 vortex_bail!(
83 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
84 &dynamic.rhs.dtype,
85 lhs
86 );
87 }
88 Ok(DType::Bool(
89 lhs.nullability() | dynamic.rhs.dtype.nullability(),
90 ))
91 }
92
93 fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
94 if let Some(scalar) = data.rhs.scalar() {
95 let [lhs]: [ArrayRef; _] = args
96 .inputs
97 .try_into()
98 .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
99 let rhs = ConstantArray::new(scalar, args.row_count).into_array();
100
101 return Binary.bind(data.operator.into()).execute(ExecutionArgs {
102 inputs: vec![lhs, rhs],
103 row_count: args.row_count,
104 ctx: args.ctx,
105 });
106 }
107 let ret_dtype =
108 DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability());
109
110 Ok(ConstantArray::new(
111 Scalar::try_new(ret_dtype, Some(data.default.into()))?,
112 args.row_count,
113 )
114 .into_array())
115 }
116
117 fn stat_falsification(
118 &self,
119 dynamic: &DynamicComparisonExpr,
120 expr: &Expression,
121 catalog: &dyn StatsCatalog,
122 ) -> Option<Expression> {
123 let lhs = expr.child(0);
124 match dynamic.operator {
125 Operator::Gt => Some(DynamicComparison.new_expr(
126 DynamicComparisonExpr {
127 operator: Operator::Lte,
128 rhs: dynamic.rhs.clone(),
129 default: !dynamic.default,
130 },
131 vec![lhs.stat_max(catalog)?],
132 )),
133 Operator::Gte => Some(DynamicComparison.new_expr(
134 DynamicComparisonExpr {
135 operator: Operator::Lt,
136 rhs: dynamic.rhs.clone(),
137 default: !dynamic.default,
138 },
139 vec![lhs.stat_max(catalog)?],
140 )),
141 Operator::Lt => Some(DynamicComparison.new_expr(
142 DynamicComparisonExpr {
143 operator: Operator::Gte,
144 rhs: dynamic.rhs.clone(),
145 default: !dynamic.default,
146 },
147 vec![lhs.stat_min(catalog)?],
148 )),
149 Operator::Lte => Some(DynamicComparison.new_expr(
150 DynamicComparisonExpr {
151 operator: Operator::Gt,
152 rhs: dynamic.rhs.clone(),
153 default: !dynamic.default,
154 },
155 vec![lhs.stat_min(catalog)?],
156 )),
157 _ => None,
158 }
159 }
160
161 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
163 false
164 }
165}
166
167pub fn dynamic(
168 operator: Operator,
169 rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
170 rhs_dtype: DType,
171 default: bool,
172 lhs: Expression,
173) -> Expression {
174 DynamicComparison.new_expr(
175 DynamicComparisonExpr {
176 operator,
177 rhs: Arc::new(Rhs {
178 value: Arc::new(rhs_value),
179 dtype: rhs_dtype,
180 }),
181 default,
182 },
183 [lhs],
184 )
185}
186
/// Options of a [`DynamicComparison`] expression.
///
/// Equality and hashing compare the identity of the shared right-hand side (its `Arc`
/// pointer), not the value it currently produces.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    /// Comparison applied as `lhs <operator> rhs`.
    operator: Operator,
    /// Shared right-hand side, evaluated by calling its closure.
    rhs: Arc<Rhs>,
    /// Result for every row when the right-hand side yields no value.
    default: bool,
}
194
195impl DynamicComparisonExpr {
196 pub fn scalar(&self) -> Option<Scalar> {
197 (self.rhs.value)().map(|v| {
198 Scalar::try_new(self.rhs.dtype.clone(), Some(v))
199 .vortex_expect("`DynamicComparisonExpr` was invalid")
200 })
201 }
202}
203
204impl Display for DynamicComparisonExpr {
205 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206 write!(
207 f,
208 "{} {}",
209 self.operator,
210 self.scalar()
211 .map_or_else(|| "<none>".to_string(), |v| v.to_string())
212 )
213 }
214}
215
216impl PartialEq for DynamicComparisonExpr {
217 fn eq(&self, other: &Self) -> bool {
218 self.operator == other.operator
219 && Arc::ptr_eq(&self.rhs, &other.rhs)
220 && self.default == other.default
221 }
222}
223impl Eq for DynamicComparisonExpr {}
224
225impl Hash for DynamicComparisonExpr {
226 fn hash<H: Hasher>(&self, state: &mut H) {
227 self.operator.hash(state);
228 Arc::as_ptr(&self.rhs).hash(state);
229 self.default.hash(state);
230 }
231}
232
/// Right-hand side of a dynamic comparison: a value-producing closure plus the dtype of the
/// values it produces.
struct Rhs {
    /// Produces the current right-hand side value, or `None` when there is no value.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    /// Dtype used to wrap produced values into a [`Scalar`].
    dtype: DType,
}
241
242impl Rhs {
243 pub fn scalar(&self) -> Option<Scalar> {
244 (self.value)().map(|v| {
245 Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
246 })
247 }
248}
249
250impl Debug for Rhs {
251 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
252 f.debug_struct("Rhs")
253 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
254 .field("dtype", &self.dtype)
255 .finish()
256 }
257}
258
/// Tracks changes to the right-hand sides of all dynamic comparisons within an expression.
///
/// [`DynamicExprUpdates::version`] returns a counter that is incremented whenever any
/// right-hand side produces a different value than at the previous check.
pub struct DynamicExprUpdates {
    /// Every dynamic comparison found in the expression, in traversal order.
    exprs: Box<[DynamicComparisonExpr]>,
    /// The current version counter, plus the last observed right-hand side value for each
    /// entry of `exprs` (same index).
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
265
266impl DynamicExprUpdates {
267 pub fn new(expr: &Expression) -> Option<Self> {
268 #[derive(Default)]
269 struct Visitor(Vec<DynamicComparisonExpr>);
270
271 impl NodeVisitor<'_> for Visitor {
272 type NodeTy = Expression;
273
274 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
275 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
276 self.0.push(dynamic.clone());
277 }
278 Ok(TraversalOrder::Continue)
279 }
280 }
281
282 let mut visitor = Visitor::default();
283 expr.accept(&mut visitor).vortex_expect("Infallible");
284
285 if visitor.0.is_empty() {
286 return None;
287 }
288
289 let exprs = visitor.0.into_boxed_slice();
290 let prev_versions = exprs
291 .iter()
292 .map(|expr| {
293 (expr.rhs.value)().map(|v| {
294 Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
295 .vortex_expect("`DynamicExprUpdates` was invalid")
296 })
297 })
298 .collect();
299
300 Some(Self {
301 exprs,
302 prev_versions: Mutex::new((0, prev_versions)),
303 })
304 }
305
306 pub fn version(&self) -> u64 {
307 let mut guard = self.prev_versions.lock();
308
309 let mut updated = false;
310 for (i, expr) in self.exprs.iter().enumerate() {
311 let current = expr.scalar();
312 if current != guard.1[i] {
313 updated = true;
317 guard.1[i] = current;
318 }
319 }
320
321 if updated {
322 guard.0 += 1;
323 }
324
325 guard.0
326 }
327}
328
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::expr::exprs::root::root;

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            Operator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(
            expr.return_dtype(&input_dtype)?,
            DType::Bool(Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            Operator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            Operator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            Operator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            false,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = threshold.clone();
        let expr = dynamic(
            Operator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input = buffer![1i32, 5, 10].into_array();

        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));

        Ok(())
    }

    /// An expression without any dynamic comparison has nothing to track.
    #[test]
    fn updates_none_without_dynamic() {
        assert!(DynamicExprUpdates::new(&root()).is_none());
    }

    /// The version only advances when a right-hand side value actually changes.
    #[test]
    fn updates_version_bumps_only_on_change() {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = threshold.clone();
        let expr = dynamic(
            Operator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let updates = DynamicExprUpdates::new(&expr)
            .vortex_expect("expression contains a dynamic comparison");

        // The threshold matches the snapshot taken at construction.
        assert_eq!(updates.version(), 0);
        assert_eq!(updates.version(), 0);

        threshold.store(10, Ordering::SeqCst);
        assert_eq!(updates.version(), 1);
        // Re-checking without a further change keeps the version stable.
        assert_eq!(updates.version(), 1);
    }
}