vortex_array/scalar_fn/fns/
dynamic.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::traversal::NodeExt;
25use crate::expr::traversal::NodeVisitor;
26use crate::expr::traversal::TraversalOrder;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::VecExecutionArgs;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
/// Scalar function comparing its single child (`lhs`) against a right-hand-side value that is
/// resolved from a callback each time the expression is executed.
///
/// When the callback yields no value, every row evaluates to the configured default instead.
#[derive(Clone)]
pub struct DynamicComparison;
45
46impl ScalarFnVTable for DynamicComparison {
47 type Options = DynamicComparisonExpr;
48
49 fn id(&self) -> ScalarFnId {
50 static ID: CachedId = CachedId::new("vortex.dynamic");
51 *ID
52 }
53
54 fn arity(&self, _options: &Self::Options) -> Arity {
55 Arity::Exact(1)
56 }
57
58 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
59 match child_idx {
60 0 => ChildName::from("lhs"),
61 _ => unreachable!(),
62 }
63 }
64
65 fn fmt_sql(
66 &self,
67 dynamic: &DynamicComparisonExpr,
68 expr: &Expression,
69 f: &mut Formatter<'_>,
70 ) -> std::fmt::Result {
71 expr.child(0).fmt_sql(f)?;
72 write!(f, " {} dynamic(", dynamic.operator)?;
73 match dynamic.scalar() {
74 None => write!(f, "scalar=<none>")?,
75 Some(scalar) => write!(f, "scalar={scalar}")?,
76 }
77 write!(f, ")")
78 }
79
80 fn return_dtype(
81 &self,
82 dynamic: &DynamicComparisonExpr,
83 arg_dtypes: &[DType],
84 ) -> VortexResult<DType> {
85 let lhs = &arg_dtypes[0];
86 if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
87 vortex_bail!(
88 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
89 &dynamic.rhs.dtype,
90 lhs
91 );
92 }
93 Ok(DType::Bool(
94 lhs.nullability() | dynamic.rhs.dtype.nullability(),
95 ))
96 }
97
98 fn execute(
99 &self,
100 data: &Self::Options,
101 args: &dyn ExecutionArgs,
102 ctx: &mut ExecutionCtx,
103 ) -> VortexResult<ArrayRef> {
104 if let Some(scalar) = data.rhs.scalar() {
105 let lhs = args.get(0)?;
106 let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
107
108 let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
109 return Binary
110 .bind(Operator::from(data.operator))
111 .execute(&delegate_args, ctx);
112 }
113 let ret_dtype =
114 DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
115
116 Ok(ConstantArray::new(
117 Scalar::try_new(ret_dtype, Some(data.default.into()))?,
118 args.row_count(),
119 )
120 .into_array())
121 }
122
123 fn stat_falsification(
124 &self,
125 dynamic: &DynamicComparisonExpr,
126 expr: &Expression,
127 catalog: &dyn StatsCatalog,
128 ) -> Option<Expression> {
129 let lhs = expr.child(0);
130 match dynamic.operator {
131 CompareOperator::Eq | CompareOperator::NotEq => None,
132 CompareOperator::Gt => Some(DynamicComparison.new_expr(
133 DynamicComparisonExpr {
134 operator: CompareOperator::Lte,
135 rhs: Arc::clone(&dynamic.rhs),
136 default: !dynamic.default,
137 },
138 vec![lhs.stat_max(catalog)?],
139 )),
140 CompareOperator::Gte => Some(DynamicComparison.new_expr(
141 DynamicComparisonExpr {
142 operator: CompareOperator::Lt,
143 rhs: Arc::clone(&dynamic.rhs),
144 default: !dynamic.default,
145 },
146 vec![lhs.stat_max(catalog)?],
147 )),
148 CompareOperator::Lt => Some(DynamicComparison.new_expr(
149 DynamicComparisonExpr {
150 operator: CompareOperator::Gte,
151 rhs: Arc::clone(&dynamic.rhs),
152 default: !dynamic.default,
153 },
154 vec![lhs.stat_min(catalog)?],
155 )),
156 CompareOperator::Lte => Some(DynamicComparison.new_expr(
157 DynamicComparisonExpr {
158 operator: CompareOperator::Gt,
159 rhs: Arc::clone(&dynamic.rhs),
160 default: !dynamic.default,
161 },
162 vec![lhs.stat_min(catalog)?],
163 )),
164 }
165 }
166
167 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
169 false
170 }
171}
172
173#[derive(Clone, Debug)]
174pub struct DynamicComparisonExpr {
175 pub(crate) operator: CompareOperator,
176 pub(crate) rhs: Arc<Rhs>,
177 pub(crate) default: bool,
179}
180
181impl DynamicComparisonExpr {
182 pub fn scalar(&self) -> Option<Scalar> {
183 (self.rhs.value)().map(|v| {
184 Scalar::try_new(self.rhs.dtype.clone(), Some(v))
185 .vortex_expect("`DynamicComparisonExpr` was invalid")
186 })
187 }
188}
189
190impl Display for DynamicComparisonExpr {
191 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192 write!(
193 f,
194 "{} {}",
195 self.operator,
196 self.scalar()
197 .map_or_else(|| "<none>".to_string(), |v| v.to_string())
198 )
199 }
200}
201
202impl PartialEq for DynamicComparisonExpr {
203 fn eq(&self, other: &Self) -> bool {
204 self.operator == other.operator
205 && Arc::ptr_eq(&self.rhs, &other.rhs)
206 && self.default == other.default
207 }
208}
209impl Eq for DynamicComparisonExpr {}
210
211impl Hash for DynamicComparisonExpr {
212 fn hash<H: Hasher>(&self, state: &mut H) {
213 self.operator.hash(state);
214 Arc::as_ptr(&self.rhs).hash(state);
215 self.default.hash(state);
216 }
217}
218
219pub(crate) struct Rhs {
222 pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
224 pub(crate) dtype: DType,
226}
227
228impl Rhs {
229 pub fn scalar(&self) -> Option<Scalar> {
230 (self.value)().map(|v| {
231 Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
232 })
233 }
234}
235
236impl Debug for Rhs {
237 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238 f.debug_struct("Rhs")
239 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
240 .field("dtype", &self.dtype)
241 .finish()
242 }
243}
244
245pub struct DynamicExprUpdates {
247 exprs: Box<[DynamicComparisonExpr]>,
248 prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
250}
251
252impl DynamicExprUpdates {
253 pub fn new(expr: &Expression) -> Option<Self> {
254 #[derive(Default)]
255 struct Visitor(Vec<DynamicComparisonExpr>);
256
257 impl NodeVisitor<'_> for Visitor {
258 type NodeTy = Expression;
259
260 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
261 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
262 self.0.push(dynamic.clone());
263 }
264 Ok(TraversalOrder::Continue)
265 }
266 }
267
268 let mut visitor = Visitor::default();
269 expr.accept(&mut visitor).vortex_expect("Infallible");
270
271 if visitor.0.is_empty() {
272 return None;
273 }
274
275 let exprs = visitor.0.into_boxed_slice();
276 let prev_versions = exprs
277 .iter()
278 .map(|expr| {
279 (expr.rhs.value)().map(|v| {
280 Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
281 .vortex_expect("`DynamicExprUpdates` was invalid")
282 })
283 })
284 .collect();
285
286 Some(Self {
287 exprs,
288 prev_versions: Mutex::new((0, prev_versions)),
289 })
290 }
291
292 pub fn version(&self) -> u64 {
293 let mut guard = self.prev_versions.lock();
294
295 let mut updated = false;
296 for (i, expr) in self.exprs.iter().enumerate() {
297 let current = expr.scalar();
298 if current != guard.1[i] {
299 updated = true;
303 guard.1[i] = current;
304 }
305 }
306
307 if updated {
308 guard.0 += 1;
309 }
310
311 guard.0
312 }
313}
314
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(
            expr.return_dtype(&input_dtype)?,
            DType::Bool(Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            false,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = Arc::clone(&threshold);
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input = buffer![1i32, 5, 10].into_array();

        let result = input.clone().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));

        Ok(())
    }

    /// An expression with no dynamic comparison has nothing to track.
    #[test]
    fn updates_none_without_dynamic() {
        assert!(DynamicExprUpdates::new(&root()).is_none());
    }

    /// The version only advances when a polled rhs value actually differs from the last one.
    #[test]
    fn updates_version_tracks_rhs_changes() {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = Arc::clone(&threshold);
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let updates =
            DynamicExprUpdates::new(&expr).expect("expression contains a dynamic comparison");

        // Unchanged rhs: repeated polls keep the initial version.
        assert_eq!(updates.version(), 0);
        assert_eq!(updates.version(), 0);

        // A changed rhs bumps the version exactly once.
        threshold.store(10, Ordering::SeqCst);
        assert_eq!(updates.version(), 1);
        assert_eq!(updates.version(), 1);
    }
}