Skip to main content

vortex_array/scalar_fn/fns/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::traversal::NodeExt;
25use crate::expr::traversal::NodeVisitor;
26use crate::expr::traversal::TraversalOrder;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::VecExecutionArgs;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
/// Scalar function for a comparison whose right-hand side may change while a query executes.
///
/// This is useful when a compute engine pushes down an `ORDER BY` + `LIMIT` operation and can
/// progressively tighten the bounds of the filter as execution proceeds.
#[derive(Clone)]
pub struct DynamicComparison;
45
46impl ScalarFnVTable for DynamicComparison {
47    type Options = DynamicComparisonExpr;
48
49    fn id(&self) -> ScalarFnId {
50        static ID: CachedId = CachedId::new("vortex.dynamic");
51        *ID
52    }
53
54    fn arity(&self, _options: &Self::Options) -> Arity {
55        Arity::Exact(1)
56    }
57
58    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
59        match child_idx {
60            0 => ChildName::from("lhs"),
61            _ => unreachable!(),
62        }
63    }
64
65    fn fmt_sql(
66        &self,
67        dynamic: &DynamicComparisonExpr,
68        expr: &Expression,
69        f: &mut Formatter<'_>,
70    ) -> std::fmt::Result {
71        expr.child(0).fmt_sql(f)?;
72        write!(f, " {} dynamic(", dynamic.operator)?;
73        match dynamic.scalar() {
74            None => write!(f, "scalar=<none>")?,
75            Some(scalar) => write!(f, "scalar={scalar}")?,
76        }
77        write!(f, ")")
78    }
79
80    fn return_dtype(
81        &self,
82        dynamic: &DynamicComparisonExpr,
83        arg_dtypes: &[DType],
84    ) -> VortexResult<DType> {
85        let lhs = &arg_dtypes[0];
86        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
87            vortex_bail!(
88                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
89                &dynamic.rhs.dtype,
90                lhs
91            );
92        }
93        Ok(DType::Bool(
94            lhs.nullability() | dynamic.rhs.dtype.nullability(),
95        ))
96    }
97
98    fn execute(
99        &self,
100        data: &Self::Options,
101        args: &dyn ExecutionArgs,
102        ctx: &mut ExecutionCtx,
103    ) -> VortexResult<ArrayRef> {
104        if let Some(scalar) = data.rhs.scalar() {
105            let lhs = args.get(0)?;
106            let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
107
108            let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
109            return Binary
110                .bind(Operator::from(data.operator))
111                .execute(&delegate_args, ctx);
112        }
113        let ret_dtype =
114            DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
115
116        Ok(ConstantArray::new(
117            Scalar::try_new(ret_dtype, Some(data.default.into()))?,
118            args.row_count(),
119        )
120        .into_array())
121    }
122
123    fn stat_falsification(
124        &self,
125        dynamic: &DynamicComparisonExpr,
126        expr: &Expression,
127        catalog: &dyn StatsCatalog,
128    ) -> Option<Expression> {
129        let lhs = expr.child(0);
130        match dynamic.operator {
131            CompareOperator::Eq | CompareOperator::NotEq => None,
132            CompareOperator::Gt => Some(DynamicComparison.new_expr(
133                DynamicComparisonExpr {
134                    operator: CompareOperator::Lte,
135                    rhs: Arc::clone(&dynamic.rhs),
136                    default: !dynamic.default,
137                },
138                vec![lhs.stat_max(catalog)?],
139            )),
140            CompareOperator::Gte => Some(DynamicComparison.new_expr(
141                DynamicComparisonExpr {
142                    operator: CompareOperator::Lt,
143                    rhs: Arc::clone(&dynamic.rhs),
144                    default: !dynamic.default,
145                },
146                vec![lhs.stat_max(catalog)?],
147            )),
148            CompareOperator::Lt => Some(DynamicComparison.new_expr(
149                DynamicComparisonExpr {
150                    operator: CompareOperator::Gte,
151                    rhs: Arc::clone(&dynamic.rhs),
152                    default: !dynamic.default,
153                },
154                vec![lhs.stat_min(catalog)?],
155            )),
156            CompareOperator::Lte => Some(DynamicComparison.new_expr(
157                DynamicComparisonExpr {
158                    operator: CompareOperator::Gt,
159                    rhs: Arc::clone(&dynamic.rhs),
160                    default: !dynamic.default,
161                },
162                vec![lhs.stat_min(catalog)?],
163            )),
164        }
165    }
166
167    // Defer to the child
168    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
169        false
170    }
171}
172
173#[derive(Clone, Debug)]
174pub struct DynamicComparisonExpr {
175    pub(crate) operator: CompareOperator,
176    pub(crate) rhs: Arc<Rhs>,
177    // Default value for the dynamic comparison.
178    pub(crate) default: bool,
179}
180
181impl DynamicComparisonExpr {
182    pub fn scalar(&self) -> Option<Scalar> {
183        (self.rhs.value)().map(|v| {
184            Scalar::try_new(self.rhs.dtype.clone(), Some(v))
185                .vortex_expect("`DynamicComparisonExpr` was invalid")
186        })
187    }
188}
189
190impl Display for DynamicComparisonExpr {
191    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192        write!(
193            f,
194            "{} {}",
195            self.operator,
196            self.scalar()
197                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
198        )
199    }
200}
201
202impl PartialEq for DynamicComparisonExpr {
203    fn eq(&self, other: &Self) -> bool {
204        self.operator == other.operator
205            && Arc::ptr_eq(&self.rhs, &other.rhs)
206            && self.default == other.default
207    }
208}
209impl Eq for DynamicComparisonExpr {}
210
211impl Hash for DynamicComparisonExpr {
212    fn hash<H: Hasher>(&self, state: &mut H) {
213        self.operator.hash(state);
214        Arc::as_ptr(&self.rhs).hash(state);
215        self.default.hash(state);
216    }
217}
218
219/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
220/// internal value doesn't impact the hash of an expression tree.
221pub(crate) struct Rhs {
222    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
223    pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
224    // The data type of the right-hand side value.
225    pub(crate) dtype: DType,
226}
227
228impl Rhs {
229    pub fn scalar(&self) -> Option<Scalar> {
230        (self.value)().map(|v| {
231            Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
232        })
233    }
234}
235
236impl Debug for Rhs {
237    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
238        f.debug_struct("Rhs")
239            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
240            .field("dtype", &self.dtype)
241            .finish()
242    }
243}
244
245/// A utility for checking whether any dynamic expressions have been updated.
246pub struct DynamicExprUpdates {
247    exprs: Box<[DynamicComparisonExpr]>,
248    // Track the latest observed versions of each dynamic expression, along with a version counter.
249    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
250}
251
252impl DynamicExprUpdates {
253    pub fn new(expr: &Expression) -> Option<Self> {
254        #[derive(Default)]
255        struct Visitor(Vec<DynamicComparisonExpr>);
256
257        impl NodeVisitor<'_> for Visitor {
258            type NodeTy = Expression;
259
260            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
261                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
262                    self.0.push(dynamic.clone());
263                }
264                Ok(TraversalOrder::Continue)
265            }
266        }
267
268        let mut visitor = Visitor::default();
269        expr.accept(&mut visitor).vortex_expect("Infallible");
270
271        if visitor.0.is_empty() {
272            return None;
273        }
274
275        let exprs = visitor.0.into_boxed_slice();
276        let prev_versions = exprs
277            .iter()
278            .map(|expr| {
279                (expr.rhs.value)().map(|v| {
280                    Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
281                        .vortex_expect("`DynamicExprUpdates` was invalid")
282                })
283            })
284            .collect();
285
286        Some(Self {
287            exprs,
288            prev_versions: Mutex::new((0, prev_versions)),
289        })
290    }
291
292    pub fn version(&self) -> u64 {
293        let mut guard = self.prev_versions.lock();
294
295        let mut updated = false;
296        for (i, expr) in self.exprs.iter().enumerate() {
297            let current = expr.scalar();
298            if current != guard.1[i] {
299                // At least one expression has been updated.
300                // We don't bail out early in order to avoid false positives for future calls
301                // to `is_updated`.
302                updated = true;
303                guard.1[i] = current;
304            }
305        }
306
307        if updated {
308            guard.0 += 1;
309        }
310
311        guard.0
312    }
313}
314
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    /// Non-nullable `i32`, the dtype used by every test in this module.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// Builds `root() < dynamic(rhs)` over non-nullable `i32` with the given default.
    fn lt_dynamic(
        rhs: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
        default: bool,
    ) -> Expression {
        dynamic(CompareOperator::Lt, rhs, i32_dtype(), default, root())
    }

    /// Comparing two non-nullable sides yields a non-nullable boolean.
    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = lt_dynamic(|| Some(5i32.into()), true);
        assert_eq!(
            expr.return_dtype(&i32_dtype())?,
            DType::Bool(Nullability::NonNullable)
        );
        Ok(())
    }

    /// With a value present the comparison is evaluated row by row.
    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let expr = lt_dynamic(|| Some(5i32.into()), true);
        let values = buffer![1i32, 5, 10].into_array();
        let result = values.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    /// Without a value every row takes the default (`true`).
    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let expr = lt_dynamic(|| None, true);
        let values = buffer![1i32, 5, 10].into_array();
        let result = values.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    /// Without a value every row takes the default (`false`).
    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let expr = lt_dynamic(|| None, false);
        let values = buffer![1i32, 5, 10].into_array();
        let result = values.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    /// Changing the threshold between executions changes the result of the same expression.
    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let reader = Arc::clone(&threshold);
        let expr = lt_dynamic(move || Some(reader.load(Ordering::SeqCst).into()), true);
        let values = buffer![1i32, 5, 10].into_array();

        let before = values.clone().apply(&expr)?;
        assert_arrays_eq!(before, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let after = values.apply(&expr)?;
        assert_arrays_eq!(after, BoolArray::from_iter([true, true, false]));

        Ok(())
    }
}