Skip to main content

vortex_array/scalar_fn/fns/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::traversal::NodeExt;
25use crate::expr::traversal::NodeVisitor;
26use crate::expr::traversal::TraversalOrder;
27use crate::scalar::Scalar;
28use crate::scalar::ScalarValue;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::ScalarFnVTableExt;
35use crate::scalar_fn::VecExecutionArgs;
36use crate::scalar_fn::fns::binary::Binary;
37use crate::scalar_fn::fns::operators::CompareOperator;
38use crate::scalar_fn::fns::operators::Operator;
39
40/// A dynamic comparison expression can be used to capture a comparison to a value that can change
41/// during the execution of a query, such as when a compute engine pushes down an ORDER BY + LIMIT
42/// operation and is able to progressively tighten the bounds of the filter.
43#[derive(Clone)]
44pub struct DynamicComparison;
45
46impl ScalarFnVTable for DynamicComparison {
47    type Options = DynamicComparisonExpr;
48
49    fn id(&self) -> ScalarFnId {
50        ScalarFnId::new_ref("vortex.dynamic")
51    }
52
53    fn arity(&self, _options: &Self::Options) -> Arity {
54        Arity::Exact(1)
55    }
56
57    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
58        match child_idx {
59            0 => ChildName::from("lhs"),
60            _ => unreachable!(),
61        }
62    }
63
64    fn fmt_sql(
65        &self,
66        dynamic: &DynamicComparisonExpr,
67        expr: &Expression,
68        f: &mut Formatter<'_>,
69    ) -> std::fmt::Result {
70        expr.child(0).fmt_sql(f)?;
71        write!(f, " {} dynamic(", dynamic)?;
72        match dynamic.scalar() {
73            None => write!(f, "<none>")?,
74            Some(scalar) => write!(f, "{}", scalar)?,
75        }
76        write!(f, ")")
77    }
78
79    fn return_dtype(
80        &self,
81        dynamic: &DynamicComparisonExpr,
82        arg_dtypes: &[DType],
83    ) -> VortexResult<DType> {
84        let lhs = &arg_dtypes[0];
85        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
86            vortex_bail!(
87                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
88                &dynamic.rhs.dtype,
89                lhs
90            );
91        }
92        Ok(DType::Bool(
93            lhs.nullability() | dynamic.rhs.dtype.nullability(),
94        ))
95    }
96
97    fn execute(
98        &self,
99        data: &Self::Options,
100        args: &dyn ExecutionArgs,
101        ctx: &mut ExecutionCtx,
102    ) -> VortexResult<ArrayRef> {
103        if let Some(scalar) = data.rhs.scalar() {
104            let lhs = args.get(0)?;
105            let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
106
107            let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
108            return Binary
109                .bind(Operator::from(data.operator))
110                .execute(&delegate_args, ctx);
111        }
112        let ret_dtype =
113            DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
114
115        Ok(ConstantArray::new(
116            Scalar::try_new(ret_dtype, Some(data.default.into()))?,
117            args.row_count(),
118        )
119        .into_array())
120    }
121
122    fn stat_falsification(
123        &self,
124        dynamic: &DynamicComparisonExpr,
125        expr: &Expression,
126        catalog: &dyn StatsCatalog,
127    ) -> Option<Expression> {
128        let lhs = expr.child(0);
129        match dynamic.operator {
130            CompareOperator::Eq | CompareOperator::NotEq => None,
131            CompareOperator::Gt => Some(DynamicComparison.new_expr(
132                DynamicComparisonExpr {
133                    operator: CompareOperator::Lte,
134                    rhs: dynamic.rhs.clone(),
135                    default: !dynamic.default,
136                },
137                vec![lhs.stat_max(catalog)?],
138            )),
139            CompareOperator::Gte => Some(DynamicComparison.new_expr(
140                DynamicComparisonExpr {
141                    operator: CompareOperator::Lt,
142                    rhs: dynamic.rhs.clone(),
143                    default: !dynamic.default,
144                },
145                vec![lhs.stat_max(catalog)?],
146            )),
147            CompareOperator::Lt => Some(DynamicComparison.new_expr(
148                DynamicComparisonExpr {
149                    operator: CompareOperator::Gte,
150                    rhs: dynamic.rhs.clone(),
151                    default: !dynamic.default,
152                },
153                vec![lhs.stat_min(catalog)?],
154            )),
155            CompareOperator::Lte => Some(DynamicComparison.new_expr(
156                DynamicComparisonExpr {
157                    operator: CompareOperator::Gt,
158                    rhs: dynamic.rhs.clone(),
159                    default: !dynamic.default,
160                },
161                vec![lhs.stat_min(catalog)?],
162            )),
163        }
164    }
165
166    // Defer to the child
167    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
168        false
169    }
170}
171
172#[derive(Clone, Debug)]
173pub struct DynamicComparisonExpr {
174    pub(crate) operator: CompareOperator,
175    pub(crate) rhs: Arc<Rhs>,
176    // Default value for the dynamic comparison.
177    pub(crate) default: bool,
178}
179
180impl DynamicComparisonExpr {
181    pub fn scalar(&self) -> Option<Scalar> {
182        (self.rhs.value)().map(|v| {
183            Scalar::try_new(self.rhs.dtype.clone(), Some(v))
184                .vortex_expect("`DynamicComparisonExpr` was invalid")
185        })
186    }
187}
188
189impl Display for DynamicComparisonExpr {
190    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
191        write!(
192            f,
193            "{} {}",
194            self.operator,
195            self.scalar()
196                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
197        )
198    }
199}
200
201impl PartialEq for DynamicComparisonExpr {
202    fn eq(&self, other: &Self) -> bool {
203        self.operator == other.operator
204            && Arc::ptr_eq(&self.rhs, &other.rhs)
205            && self.default == other.default
206    }
207}
208impl Eq for DynamicComparisonExpr {}
209
210impl Hash for DynamicComparisonExpr {
211    fn hash<H: Hasher>(&self, state: &mut H) {
212        self.operator.hash(state);
213        Arc::as_ptr(&self.rhs).hash(state);
214        self.default.hash(state);
215    }
216}
217
218/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
219/// internal value doesn't impact the hash of an expression tree.
220pub(crate) struct Rhs {
221    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
222    pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
223    // The data type of the right-hand side value.
224    pub(crate) dtype: DType,
225}
226
227impl Rhs {
228    pub fn scalar(&self) -> Option<Scalar> {
229        (self.value)().map(|v| {
230            Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
231        })
232    }
233}
234
235impl Debug for Rhs {
236    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
237        f.debug_struct("Rhs")
238            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
239            .field("dtype", &self.dtype)
240            .finish()
241    }
242}
243
244/// A utility for checking whether any dynamic expressions have been updated.
245pub struct DynamicExprUpdates {
246    exprs: Box<[DynamicComparisonExpr]>,
247    // Track the latest observed versions of each dynamic expression, along with a version counter.
248    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
249}
250
251impl DynamicExprUpdates {
252    pub fn new(expr: &Expression) -> Option<Self> {
253        #[derive(Default)]
254        struct Visitor(Vec<DynamicComparisonExpr>);
255
256        impl NodeVisitor<'_> for Visitor {
257            type NodeTy = Expression;
258
259            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
260                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
261                    self.0.push(dynamic.clone());
262                }
263                Ok(TraversalOrder::Continue)
264            }
265        }
266
267        let mut visitor = Visitor::default();
268        expr.accept(&mut visitor).vortex_expect("Infallible");
269
270        if visitor.0.is_empty() {
271            return None;
272        }
273
274        let exprs = visitor.0.into_boxed_slice();
275        let prev_versions = exprs
276            .iter()
277            .map(|expr| {
278                (expr.rhs.value)().map(|v| {
279                    Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
280                        .vortex_expect("`DynamicExprUpdates` was invalid")
281                })
282            })
283            .collect();
284
285        Some(Self {
286            exprs,
287            prev_versions: Mutex::new((0, prev_versions)),
288        })
289    }
290
291    pub fn version(&self) -> u64 {
292        let mut guard = self.prev_versions.lock();
293
294        let mut updated = false;
295        for (i, expr) in self.exprs.iter().enumerate() {
296            let current = expr.scalar();
297            if current != guard.1[i] {
298                // At least one expression has been updated.
299                // We don't bail out early in order to avoid false positives for future calls
300                // to `is_updated`.
301                updated = true;
302                guard.1[i] = current;
303            }
304        }
305
306        if updated {
307            guard.0 += 1;
308        }
309
310        guard.0
311    }
312}
313
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    /// Non-nullable `i32`, the dtype used for both the input and the threshold in every test.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// The three-row input `[1, 5, 10]` shared by the execution tests.
    fn sample_input() -> ArrayRef {
        buffer![1i32, 5, 10].into_array()
    }

    /// Builds `root() < dynamic(threshold)` with the given fallback result.
    fn less_than(
        threshold: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
        default: bool,
    ) -> Expression {
        dynamic(CompareOperator::Lt, threshold, i32_dtype(), default, root())
    }

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = less_than(|| Some(5i32.into()), true);
        let returned = expr.return_dtype(&i32_dtype())?;
        assert_eq!(returned, DType::Bool(Nullability::NonNullable));
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let expr = less_than(|| Some(5i32.into()), true);
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let expr = less_than(|| None, true);
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let expr = less_than(|| None, false);
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let provider = {
            let shared = threshold.clone();
            move || Some(shared.load(Ordering::SeqCst).into())
        };
        let expr = less_than(provider, true);
        let input = sample_input();

        // Threshold 5: only the first row is strictly below it.
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

        // Raising the threshold to 10 is picked up by the very same expression.
        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));

        Ok(())
    }
}
411        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));
412
413        Ok(())
414    }
415}