vortex_array/scalar_fn/fns/
dynamic.rs1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::traversal::NodeExt;
24use crate::expr::traversal::NodeVisitor;
25use crate::expr::traversal::TraversalOrder;
26use crate::scalar::Scalar;
27use crate::scalar::ScalarValue;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::VecExecutionArgs;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
/// Scalar function comparing its single child (`lhs`) against a right-hand-side value that is
/// only resolved when the function is executed.
///
/// The right-hand side, comparison operator and fallback result all live in the function's
/// options ([`DynamicComparisonExpr`]); this type is a stateless marker implementing
/// [`ScalarFnVTable`].
#[derive(Clone)]
pub struct DynamicComparison;
44
45impl ScalarFnVTable for DynamicComparison {
46 type Options = DynamicComparisonExpr;
47
48 fn id(&self) -> ScalarFnId {
49 static ID: CachedId = CachedId::new("vortex.dynamic");
50 *ID
51 }
52
53 fn arity(&self, _options: &Self::Options) -> Arity {
54 Arity::Exact(1)
55 }
56
57 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
58 match child_idx {
59 0 => ChildName::from("lhs"),
60 _ => unreachable!(),
61 }
62 }
63
64 fn fmt_sql(
65 &self,
66 dynamic: &DynamicComparisonExpr,
67 expr: &Expression,
68 f: &mut Formatter<'_>,
69 ) -> std::fmt::Result {
70 expr.child(0).fmt_sql(f)?;
71 write!(f, " {} dynamic(", dynamic.operator)?;
72 match dynamic.scalar() {
73 None => write!(f, "scalar=<none>")?,
74 Some(scalar) => write!(f, "scalar={scalar}")?,
75 }
76 write!(f, ")")
77 }
78
79 fn return_dtype(
80 &self,
81 dynamic: &DynamicComparisonExpr,
82 arg_dtypes: &[DType],
83 ) -> VortexResult<DType> {
84 let lhs = &arg_dtypes[0];
85 if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
86 vortex_bail!(
87 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
88 &dynamic.rhs.dtype,
89 lhs
90 );
91 }
92 Ok(DType::Bool(
93 lhs.nullability() | dynamic.rhs.dtype.nullability(),
94 ))
95 }
96
97 fn execute(
98 &self,
99 data: &Self::Options,
100 args: &dyn ExecutionArgs,
101 ctx: &mut ExecutionCtx,
102 ) -> VortexResult<ArrayRef> {
103 if let Some(scalar) = data.rhs.scalar() {
104 let lhs = args.get(0)?;
105 let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
106
107 let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
108 return Binary
109 .bind(Operator::from(data.operator))
110 .execute(&delegate_args, ctx);
111 }
112 let ret_dtype =
113 DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
114
115 Ok(ConstantArray::new(
116 Scalar::try_new(ret_dtype, Some(data.default.into()))?,
117 args.row_count(),
118 )
119 .into_array())
120 }
121
122 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
124 false
125 }
126}
127
128#[derive(Clone, Debug)]
129pub struct DynamicComparisonExpr {
130 pub(crate) operator: CompareOperator,
131 pub(crate) rhs: Arc<Rhs>,
132 pub(crate) default: bool,
134}
135
136impl DynamicComparisonExpr {
137 pub fn scalar(&self) -> Option<Scalar> {
138 (self.rhs.value)().map(|v| {
139 Scalar::try_new(self.rhs.dtype.clone(), Some(v))
140 .vortex_expect("`DynamicComparisonExpr` was invalid")
141 })
142 }
143}
144
145impl Display for DynamicComparisonExpr {
146 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147 write!(
148 f,
149 "{} {}",
150 self.operator,
151 self.scalar()
152 .map_or_else(|| "<none>".to_string(), |v| v.to_string())
153 )
154 }
155}
156
157impl PartialEq for DynamicComparisonExpr {
158 fn eq(&self, other: &Self) -> bool {
159 self.operator == other.operator
160 && Arc::ptr_eq(&self.rhs, &other.rhs)
161 && self.default == other.default
162 }
163}
164impl Eq for DynamicComparisonExpr {}
165
166impl Hash for DynamicComparisonExpr {
167 fn hash<H: Hasher>(&self, state: &mut H) {
168 self.operator.hash(state);
169 Arc::as_ptr(&self.rhs).hash(state);
170 self.default.hash(state);
171 }
172}
173
174pub(crate) struct Rhs {
177 pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
179 pub(crate) dtype: DType,
181}
182
183impl Rhs {
184 pub fn scalar(&self) -> Option<Scalar> {
185 (self.value)().map(|v| {
186 Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
187 })
188 }
189}
190
191impl Debug for Rhs {
192 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193 f.debug_struct("Rhs")
194 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
195 .field("dtype", &self.dtype)
196 .finish()
197 }
198}
199
200pub struct DynamicExprUpdates {
202 exprs: Box<[DynamicComparisonExpr]>,
203 prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
205}
206
207impl DynamicExprUpdates {
208 pub fn new(expr: &Expression) -> Option<Self> {
209 #[derive(Default)]
210 struct Visitor(Vec<DynamicComparisonExpr>);
211
212 impl NodeVisitor<'_> for Visitor {
213 type NodeTy = Expression;
214
215 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
216 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
217 self.0.push(dynamic.clone());
218 }
219 Ok(TraversalOrder::Continue)
220 }
221 }
222
223 let mut visitor = Visitor::default();
224 expr.accept(&mut visitor).vortex_expect("Infallible");
225
226 if visitor.0.is_empty() {
227 return None;
228 }
229
230 let exprs = visitor.0.into_boxed_slice();
231 let prev_versions = exprs
232 .iter()
233 .map(|expr| {
234 (expr.rhs.value)().map(|v| {
235 Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
236 .vortex_expect("`DynamicExprUpdates` was invalid")
237 })
238 })
239 .collect();
240
241 Some(Self {
242 exprs,
243 prev_versions: Mutex::new((0, prev_versions)),
244 })
245 }
246
247 pub fn version(&self) -> u64 {
248 let mut guard = self.prev_versions.lock();
249
250 let mut updated = false;
251 for (i, expr) in self.exprs.iter().enumerate() {
252 let current = expr.scalar();
253 if current != guard.1[i] {
254 updated = true;
258 guard.1[i] = current;
259 }
260 }
261
262 if updated {
263 guard.0 += 1;
264 }
265
266 guard.0
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        assert_eq!(
            expr.return_dtype(&input_dtype)?,
            DType::Bool(Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]), &mut ctx);
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]), &mut ctx);
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let input = buffer![1i32, 5, 10].into_array();
        let expr = dynamic(
            CompareOperator::Lt,
            || None,
            DType::Primitive(PType::I32, Nullability::NonNullable),
            false,
            root(),
        );
        let result = input.apply(&expr)?;
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, false]),
            &mut ctx
        );
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = Arc::clone(&threshold);
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let input = buffer![1i32, 5, 10].into_array();

        let result = input.clone().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]), &mut ctx);

        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]), &mut ctx);

        Ok(())
    }

    /// An expression without any dynamic comparison has nothing to track.
    #[test]
    fn updates_none_without_dynamic() {
        assert!(DynamicExprUpdates::new(&root()).is_none());
    }

    /// The version only advances when an observed right-hand-side value actually changes,
    /// and advances by exactly one per observed change.
    #[test]
    fn updates_version_tracks_changes() {
        let threshold = Arc::new(AtomicI32::new(5));
        let threshold_clone = Arc::clone(&threshold);
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(threshold_clone.load(Ordering::SeqCst).into()),
            DType::Primitive(PType::I32, Nullability::NonNullable),
            true,
            root(),
        );
        let updates = DynamicExprUpdates::new(&expr)
            .vortex_expect("expression contains a dynamic comparison");

        // Unchanged value: the version stays at the baseline.
        assert_eq!(updates.version(), 0);
        assert_eq!(updates.version(), 0);

        // Changed value: the version is bumped once and then stays put.
        threshold.store(10, Ordering::SeqCst);
        assert_eq!(updates.version(), 1);
        assert_eq!(updates.version(), 1);

        // Another change bumps it again.
        threshold.store(3, Ordering::SeqCst);
        assert_eq!(updates.version(), 2);
    }
}