vortex_array/scalar_fn/fns/
dynamic.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::arrays::ConstantArray;
20use crate::dtype::DType;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::traversal::NodeExt;
24use crate::expr::traversal::NodeVisitor;
25use crate::expr::traversal::TraversalOrder;
26use crate::scalar::Scalar;
27use crate::scalar::ScalarValue;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::CompareOperator;
36use crate::scalar_fn::fns::operators::Operator;
37
/// Scalar function that compares its single child (`lhs`) against a right-hand side whose value
/// is looked up each time the function is executed.
///
/// The comparison operator, the right-hand-side source and the fallback result are carried in
/// [`DynamicComparisonExpr`].
#[derive(Clone)]
pub struct DynamicComparison;
43
44impl ScalarFnVTable for DynamicComparison {
45 type Options = DynamicComparisonExpr;
46
47 fn id(&self) -> ScalarFnId {
48 ScalarFnId::new_ref("vortex.dynamic")
49 }
50
51 fn arity(&self, _options: &Self::Options) -> Arity {
52 Arity::Exact(1)
53 }
54
55 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
56 match child_idx {
57 0 => ChildName::from("lhs"),
58 _ => unreachable!(),
59 }
60 }
61
62 fn fmt_sql(
63 &self,
64 dynamic: &DynamicComparisonExpr,
65 expr: &Expression,
66 f: &mut Formatter<'_>,
67 ) -> std::fmt::Result {
68 expr.child(0).fmt_sql(f)?;
69 write!(f, " {} dynamic(", dynamic)?;
70 match dynamic.scalar() {
71 None => write!(f, "<none>")?,
72 Some(scalar) => write!(f, "{}", scalar)?,
73 }
74 write!(f, ")")
75 }
76
77 fn return_dtype(
78 &self,
79 dynamic: &DynamicComparisonExpr,
80 arg_dtypes: &[DType],
81 ) -> VortexResult<DType> {
82 let lhs = &arg_dtypes[0];
83 if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
84 vortex_bail!(
85 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
86 &dynamic.rhs.dtype,
87 lhs
88 );
89 }
90 Ok(DType::Bool(
91 lhs.nullability() | dynamic.rhs.dtype.nullability(),
92 ))
93 }
94
95 fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
96 if let Some(scalar) = data.rhs.scalar() {
97 let [lhs]: [ArrayRef; _] = args
98 .inputs
99 .try_into()
100 .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
101 let rhs = ConstantArray::new(scalar, args.row_count).into_array();
102
103 return Binary
104 .bind(Operator::from(data.operator))
105 .execute(ExecutionArgs {
106 inputs: vec![lhs, rhs],
107 row_count: args.row_count,
108 ctx: args.ctx,
109 });
110 }
111 let ret_dtype =
112 DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability());
113
114 Ok(ConstantArray::new(
115 Scalar::try_new(ret_dtype, Some(data.default.into()))?,
116 args.row_count,
117 )
118 .into_array())
119 }
120
121 fn stat_falsification(
122 &self,
123 dynamic: &DynamicComparisonExpr,
124 expr: &Expression,
125 catalog: &dyn StatsCatalog,
126 ) -> Option<Expression> {
127 let lhs = expr.child(0);
128 match dynamic.operator {
129 CompareOperator::Eq | CompareOperator::NotEq => None,
130 CompareOperator::Gt => Some(DynamicComparison.new_expr(
131 DynamicComparisonExpr {
132 operator: CompareOperator::Lte,
133 rhs: dynamic.rhs.clone(),
134 default: !dynamic.default,
135 },
136 vec![lhs.stat_max(catalog)?],
137 )),
138 CompareOperator::Gte => Some(DynamicComparison.new_expr(
139 DynamicComparisonExpr {
140 operator: CompareOperator::Lt,
141 rhs: dynamic.rhs.clone(),
142 default: !dynamic.default,
143 },
144 vec![lhs.stat_max(catalog)?],
145 )),
146 CompareOperator::Lt => Some(DynamicComparison.new_expr(
147 DynamicComparisonExpr {
148 operator: CompareOperator::Gte,
149 rhs: dynamic.rhs.clone(),
150 default: !dynamic.default,
151 },
152 vec![lhs.stat_min(catalog)?],
153 )),
154 CompareOperator::Lte => Some(DynamicComparison.new_expr(
155 DynamicComparisonExpr {
156 operator: CompareOperator::Gt,
157 rhs: dynamic.rhs.clone(),
158 default: !dynamic.default,
159 },
160 vec![lhs.stat_min(catalog)?],
161 )),
162 }
163 }
164
165 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
167 false
168 }
169}
170
171#[derive(Clone, Debug)]
172pub struct DynamicComparisonExpr {
173 pub(crate) operator: CompareOperator,
174 pub(crate) rhs: Arc<Rhs>,
175 pub(crate) default: bool,
177}
178
179impl DynamicComparisonExpr {
180 pub fn scalar(&self) -> Option<Scalar> {
181 (self.rhs.value)().map(|v| {
182 Scalar::try_new(self.rhs.dtype.clone(), Some(v))
183 .vortex_expect("`DynamicComparisonExpr` was invalid")
184 })
185 }
186}
187
188impl Display for DynamicComparisonExpr {
189 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190 write!(
191 f,
192 "{} {}",
193 self.operator,
194 self.scalar()
195 .map_or_else(|| "<none>".to_string(), |v| v.to_string())
196 )
197 }
198}
199
200impl PartialEq for DynamicComparisonExpr {
201 fn eq(&self, other: &Self) -> bool {
202 self.operator == other.operator
203 && Arc::ptr_eq(&self.rhs, &other.rhs)
204 && self.default == other.default
205 }
206}
207impl Eq for DynamicComparisonExpr {}
208
209impl Hash for DynamicComparisonExpr {
210 fn hash<H: Hasher>(&self, state: &mut H) {
211 self.operator.hash(state);
212 Arc::as_ptr(&self.rhs).hash(state);
213 self.default.hash(state);
214 }
215}
216
217pub(crate) struct Rhs {
220 pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
222 pub(crate) dtype: DType,
224}
225
226impl Rhs {
227 pub fn scalar(&self) -> Option<Scalar> {
228 (self.value)().map(|v| {
229 Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
230 })
231 }
232}
233
234impl Debug for Rhs {
235 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236 f.debug_struct("Rhs")
237 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
238 .field("dtype", &self.dtype)
239 .finish()
240 }
241}
242
243pub struct DynamicExprUpdates {
245 exprs: Box<[DynamicComparisonExpr]>,
246 prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
248}
249
250impl DynamicExprUpdates {
251 pub fn new(expr: &Expression) -> Option<Self> {
252 #[derive(Default)]
253 struct Visitor(Vec<DynamicComparisonExpr>);
254
255 impl NodeVisitor<'_> for Visitor {
256 type NodeTy = Expression;
257
258 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
259 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
260 self.0.push(dynamic.clone());
261 }
262 Ok(TraversalOrder::Continue)
263 }
264 }
265
266 let mut visitor = Visitor::default();
267 expr.accept(&mut visitor).vortex_expect("Infallible");
268
269 if visitor.0.is_empty() {
270 return None;
271 }
272
273 let exprs = visitor.0.into_boxed_slice();
274 let prev_versions = exprs
275 .iter()
276 .map(|expr| {
277 (expr.rhs.value)().map(|v| {
278 Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
279 .vortex_expect("`DynamicExprUpdates` was invalid")
280 })
281 })
282 .collect();
283
284 Some(Self {
285 exprs,
286 prev_versions: Mutex::new((0, prev_versions)),
287 })
288 }
289
290 pub fn version(&self) -> u64 {
291 let mut guard = self.prev_versions.lock();
292
293 let mut updated = false;
294 for (i, expr) in self.exprs.iter().enumerate() {
295 let current = expr.scalar();
296 if current != guard.1[i] {
297 updated = true;
301 guard.1[i] = current;
302 }
303 }
304
305 if updated {
306 guard.0 += 1;
307 }
308
309 guard.0
310 }
311}
312
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(
            expr.return_dtype(&input_dtype)?,
            DType::Bool(Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            false,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = threshold.clone();
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input = buffer![1i32, 5, 10].into_array();

        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));

        Ok(())
    }

    /// An expression without any dynamic comparison has nothing to track.
    #[test]
    fn updates_none_without_dynamic() {
        assert!(DynamicExprUpdates::new(&root()).is_none());
    }

    /// The version stays put while the value is unchanged and bumps exactly once per change.
    #[test]
    fn updates_version_tracks_changes() {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = threshold.clone();
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let updates = DynamicExprUpdates::new(&expr)
            .expect("expression contains a dynamic comparison");

        assert_eq!(updates.version(), 0);
        assert_eq!(updates.version(), 0);

        threshold.store(10, Ordering::SeqCst);
        assert_eq!(updates.version(), 1);
        // No further change: the version must not move again.
        assert_eq!(updates.version(), 1);
    }
}