vortex_array/scalar_fn/fns/
dynamic.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::traversal::NodeExt;
25use crate::expr::traversal::NodeVisitor;
26use crate::expr::traversal::TraversalOrder;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::VecExecutionArgs;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
/// Scalar function vtable for a comparison whose right-hand side is supplied at evaluation time.
///
/// The type carries no state of its own; everything instance-specific lives in
/// [`DynamicComparisonExpr`], which is used as the function's options.
#[derive(Clone)]
pub struct DynamicComparison;
45
46impl ScalarFnVTable for DynamicComparison {
47 type Options = DynamicComparisonExpr;
48
49 fn id(&self) -> ScalarFnId {
50 ScalarFnId::new_ref("vortex.dynamic")
51 }
52
53 fn arity(&self, _options: &Self::Options) -> Arity {
54 Arity::Exact(1)
55 }
56
57 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
58 match child_idx {
59 0 => ChildName::from("lhs"),
60 _ => unreachable!(),
61 }
62 }
63
64 fn fmt_sql(
65 &self,
66 dynamic: &DynamicComparisonExpr,
67 expr: &Expression,
68 f: &mut Formatter<'_>,
69 ) -> std::fmt::Result {
70 expr.child(0).fmt_sql(f)?;
71 write!(f, " {} dynamic(", dynamic)?;
72 match dynamic.scalar() {
73 None => write!(f, "<none>")?,
74 Some(scalar) => write!(f, "{}", scalar)?,
75 }
76 write!(f, ")")
77 }
78
79 fn return_dtype(
80 &self,
81 dynamic: &DynamicComparisonExpr,
82 arg_dtypes: &[DType],
83 ) -> VortexResult<DType> {
84 let lhs = &arg_dtypes[0];
85 if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
86 vortex_bail!(
87 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
88 &dynamic.rhs.dtype,
89 lhs
90 );
91 }
92 Ok(DType::Bool(
93 lhs.nullability() | dynamic.rhs.dtype.nullability(),
94 ))
95 }
96
97 fn execute(
98 &self,
99 data: &Self::Options,
100 args: &dyn ExecutionArgs,
101 ctx: &mut ExecutionCtx,
102 ) -> VortexResult<ArrayRef> {
103 if let Some(scalar) = data.rhs.scalar() {
104 let lhs = args.get(0)?;
105 let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
106
107 let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
108 return Binary
109 .bind(Operator::from(data.operator))
110 .execute(&delegate_args, ctx);
111 }
112 let ret_dtype =
113 DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
114
115 Ok(ConstantArray::new(
116 Scalar::try_new(ret_dtype, Some(data.default.into()))?,
117 args.row_count(),
118 )
119 .into_array())
120 }
121
122 fn stat_falsification(
123 &self,
124 dynamic: &DynamicComparisonExpr,
125 expr: &Expression,
126 catalog: &dyn StatsCatalog,
127 ) -> Option<Expression> {
128 let lhs = expr.child(0);
129 match dynamic.operator {
130 CompareOperator::Eq | CompareOperator::NotEq => None,
131 CompareOperator::Gt => Some(DynamicComparison.new_expr(
132 DynamicComparisonExpr {
133 operator: CompareOperator::Lte,
134 rhs: dynamic.rhs.clone(),
135 default: !dynamic.default,
136 },
137 vec![lhs.stat_max(catalog)?],
138 )),
139 CompareOperator::Gte => Some(DynamicComparison.new_expr(
140 DynamicComparisonExpr {
141 operator: CompareOperator::Lt,
142 rhs: dynamic.rhs.clone(),
143 default: !dynamic.default,
144 },
145 vec![lhs.stat_max(catalog)?],
146 )),
147 CompareOperator::Lt => Some(DynamicComparison.new_expr(
148 DynamicComparisonExpr {
149 operator: CompareOperator::Gte,
150 rhs: dynamic.rhs.clone(),
151 default: !dynamic.default,
152 },
153 vec![lhs.stat_min(catalog)?],
154 )),
155 CompareOperator::Lte => Some(DynamicComparison.new_expr(
156 DynamicComparisonExpr {
157 operator: CompareOperator::Gt,
158 rhs: dynamic.rhs.clone(),
159 default: !dynamic.default,
160 },
161 vec![lhs.stat_min(catalog)?],
162 )),
163 }
164 }
165
166 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
168 false
169 }
170}
171
172#[derive(Clone, Debug)]
173pub struct DynamicComparisonExpr {
174 pub(crate) operator: CompareOperator,
175 pub(crate) rhs: Arc<Rhs>,
176 pub(crate) default: bool,
178}
179
180impl DynamicComparisonExpr {
181 pub fn scalar(&self) -> Option<Scalar> {
182 (self.rhs.value)().map(|v| {
183 Scalar::try_new(self.rhs.dtype.clone(), Some(v))
184 .vortex_expect("`DynamicComparisonExpr` was invalid")
185 })
186 }
187}
188
189impl Display for DynamicComparisonExpr {
190 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191 write!(
192 f,
193 "{} {}",
194 self.operator,
195 self.scalar()
196 .map_or_else(|| "<none>".to_string(), |v| v.to_string())
197 )
198 }
199}
200
201impl PartialEq for DynamicComparisonExpr {
202 fn eq(&self, other: &Self) -> bool {
203 self.operator == other.operator
204 && Arc::ptr_eq(&self.rhs, &other.rhs)
205 && self.default == other.default
206 }
207}
208impl Eq for DynamicComparisonExpr {}
209
210impl Hash for DynamicComparisonExpr {
211 fn hash<H: Hasher>(&self, state: &mut H) {
212 self.operator.hash(state);
213 Arc::as_ptr(&self.rhs).hash(state);
214 self.default.hash(state);
215 }
216}
217
218pub(crate) struct Rhs {
221 pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
223 pub(crate) dtype: DType,
225}
226
227impl Rhs {
228 pub fn scalar(&self) -> Option<Scalar> {
229 (self.value)().map(|v| {
230 Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
231 })
232 }
233}
234
235impl Debug for Rhs {
236 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237 f.debug_struct("Rhs")
238 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
239 .field("dtype", &self.dtype)
240 .finish()
241 }
242}
243
244pub struct DynamicExprUpdates {
246 exprs: Box<[DynamicComparisonExpr]>,
247 prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
249}
250
251impl DynamicExprUpdates {
252 pub fn new(expr: &Expression) -> Option<Self> {
253 #[derive(Default)]
254 struct Visitor(Vec<DynamicComparisonExpr>);
255
256 impl NodeVisitor<'_> for Visitor {
257 type NodeTy = Expression;
258
259 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
260 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
261 self.0.push(dynamic.clone());
262 }
263 Ok(TraversalOrder::Continue)
264 }
265 }
266
267 let mut visitor = Visitor::default();
268 expr.accept(&mut visitor).vortex_expect("Infallible");
269
270 if visitor.0.is_empty() {
271 return None;
272 }
273
274 let exprs = visitor.0.into_boxed_slice();
275 let prev_versions = exprs
276 .iter()
277 .map(|expr| {
278 (expr.rhs.value)().map(|v| {
279 Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
280 .vortex_expect("`DynamicExprUpdates` was invalid")
281 })
282 })
283 .collect();
284
285 Some(Self {
286 exprs,
287 prev_versions: Mutex::new((0, prev_versions)),
288 })
289 }
290
291 pub fn version(&self) -> u64 {
292 let mut guard = self.prev_versions.lock();
293
294 let mut updated = false;
295 for (i, expr) in self.exprs.iter().enumerate() {
296 let current = expr.scalar();
297 if current != guard.1[i] {
298 updated = true;
302 guard.1[i] = current;
303 }
304 }
305
306 if updated {
307 guard.0 += 1;
308 }
309
310 guard.0
311 }
312}
313
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    /// Non-nullable `i32` dtype used for both sides of every comparison below.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// The three-element input `[1, 5, 10]` shared by the execution tests.
    fn sample_input() -> ArrayRef {
        buffer![1i32, 5, 10].into_array()
    }

    /// A non-nullable child compared with a non-nullable rhs yields a non-nullable bool.
    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );

        let actual = expr.return_dtype(&i32_dtype())?;
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
        Ok(())
    }

    /// With a value present the comparison is evaluated row by row.
    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );

        let actual = sample_input().apply(&expr)?;
        assert_arrays_eq!(actual, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    /// Without a value and `default = true`, every row is `true`.
    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), true, root());

        let actual = sample_input().apply(&expr)?;
        assert_arrays_eq!(actual, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    /// Without a value and `default = false`, every row is `false`.
    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), false, root());

        let actual = sample_input().apply(&expr)?;
        assert_arrays_eq!(actual, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    /// Changing the value behind the closure changes the result of the same expression.
    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let reader = threshold.clone();
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(reader.load(Ordering::SeqCst).into()),
            i32_dtype(),
            true,
            root(),
        );
        let input = sample_input();

        let before = input.apply(&expr)?;
        assert_arrays_eq!(before, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let after = input.apply(&expr)?;
        assert_arrays_eq!(after, BoolArray::from_iter([true, true, false]));

        Ok(())
    }
}