Skip to main content

vortex_array/expr/exprs/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16
17use crate::Array;
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::compute::Operator;
22use crate::expr::Arity;
23use crate::expr::Binary;
24use crate::expr::ChildName;
25use crate::expr::ExecutionArgs;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::StatsCatalog;
29use crate::expr::VTable;
30use crate::expr::VTableExt;
31use crate::expr::traversal::NodeExt;
32use crate::expr::traversal::NodeVisitor;
33use crate::expr::traversal::TraversalOrder;
34use crate::scalar::Scalar;
35use crate::scalar::ScalarValue;
36
/// A dynamic comparison expression can be used to capture a comparison to a value that can change
/// during the execution of a query, such as when a compute engine pushes down an ORDER BY + LIMIT
/// operation and is able to progressively tighten the bounds of the filter.
///
/// The expression (id `vortex.dynamic`) has exactly one child, the left-hand side, which is
/// compared against the current right-hand side value each time the expression is executed.
/// While no right-hand side value is available, every row evaluates to the configured default.
/// See [`dynamic`] for the constructor.
pub struct DynamicComparison;
41
42impl VTable for DynamicComparison {
43    type Options = DynamicComparisonExpr;
44
45    fn id(&self) -> ExprId {
46        ExprId::new_ref("vortex.dynamic")
47    }
48
49    fn arity(&self, _options: &Self::Options) -> Arity {
50        Arity::Exact(1)
51    }
52
53    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
54        match child_idx {
55            0 => ChildName::from("lhs"),
56            _ => unreachable!(),
57        }
58    }
59
60    fn fmt_sql(
61        &self,
62        dynamic: &DynamicComparisonExpr,
63        expr: &Expression,
64        f: &mut Formatter<'_>,
65    ) -> std::fmt::Result {
66        expr.child(0).fmt_sql(f)?;
67        write!(f, " {} dynamic(", dynamic)?;
68        match dynamic.scalar() {
69            None => write!(f, "<none>")?,
70            Some(scalar) => write!(f, "{}", scalar)?,
71        }
72        write!(f, ")")
73    }
74
75    fn return_dtype(
76        &self,
77        dynamic: &DynamicComparisonExpr,
78        arg_dtypes: &[DType],
79    ) -> VortexResult<DType> {
80        let lhs = &arg_dtypes[0];
81        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
82            vortex_bail!(
83                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
84                &dynamic.rhs.dtype,
85                lhs
86            );
87        }
88        Ok(DType::Bool(
89            lhs.nullability() | dynamic.rhs.dtype.nullability(),
90        ))
91    }
92
93    fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
94        if let Some(scalar) = data.rhs.scalar() {
95            let [lhs]: [ArrayRef; _] = args
96                .inputs
97                .try_into()
98                .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
99            let rhs = ConstantArray::new(scalar, args.row_count).into_array();
100
101            return Binary.bind(data.operator.into()).execute(ExecutionArgs {
102                inputs: vec![lhs, rhs],
103                row_count: args.row_count,
104                ctx: args.ctx,
105            });
106        }
107        let ret_dtype =
108            DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability());
109
110        Ok(ConstantArray::new(
111            Scalar::try_new(ret_dtype, Some(data.default.into()))?,
112            args.row_count,
113        )
114        .into_array())
115    }
116
117    fn stat_falsification(
118        &self,
119        dynamic: &DynamicComparisonExpr,
120        expr: &Expression,
121        catalog: &dyn StatsCatalog,
122    ) -> Option<Expression> {
123        let lhs = expr.child(0);
124        match dynamic.operator {
125            Operator::Gt => Some(DynamicComparison.new_expr(
126                DynamicComparisonExpr {
127                    operator: Operator::Lte,
128                    rhs: dynamic.rhs.clone(),
129                    default: !dynamic.default,
130                },
131                vec![lhs.stat_max(catalog)?],
132            )),
133            Operator::Gte => Some(DynamicComparison.new_expr(
134                DynamicComparisonExpr {
135                    operator: Operator::Lt,
136                    rhs: dynamic.rhs.clone(),
137                    default: !dynamic.default,
138                },
139                vec![lhs.stat_max(catalog)?],
140            )),
141            Operator::Lt => Some(DynamicComparison.new_expr(
142                DynamicComparisonExpr {
143                    operator: Operator::Gte,
144                    rhs: dynamic.rhs.clone(),
145                    default: !dynamic.default,
146                },
147                vec![lhs.stat_min(catalog)?],
148            )),
149            Operator::Lte => Some(DynamicComparison.new_expr(
150                DynamicComparisonExpr {
151                    operator: Operator::Gt,
152                    rhs: dynamic.rhs.clone(),
153                    default: !dynamic.default,
154                },
155                vec![lhs.stat_min(catalog)?],
156            )),
157            _ => None,
158        }
159    }
160
161    // Defer to the child
162    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
163        false
164    }
165}
166
167pub fn dynamic(
168    operator: Operator,
169    rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
170    rhs_dtype: DType,
171    default: bool,
172    lhs: Expression,
173) -> Expression {
174    DynamicComparison.new_expr(
175        DynamicComparisonExpr {
176            operator,
177            rhs: Arc::new(Rhs {
178                value: Arc::new(rhs_value),
179                dtype: rhs_dtype,
180            }),
181            default,
182        },
183        [lhs],
184    )
185}
186
/// Options for a [`DynamicComparison`] expression: the comparison operator, the shared dynamic
/// right-hand side, and the result used while no right-hand side value is available.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    // The comparison applied as `lhs <operator> rhs`.
    operator: Operator,
    // The dynamic right-hand side. Clones share the same allocation, and its address (not the
    // value it produces) is what `PartialEq` and `Hash` use.
    rhs: Arc<Rhs>,
    // Default value for the dynamic comparison, produced for every row while the right-hand
    // side has no value.
    default: bool,
}
194
195impl DynamicComparisonExpr {
196    pub fn scalar(&self) -> Option<Scalar> {
197        (self.rhs.value)().map(|v| {
198            Scalar::try_new(self.rhs.dtype.clone(), Some(v))
199                .vortex_expect("`DynamicComparisonExpr` was invalid")
200        })
201    }
202}
203
204impl Display for DynamicComparisonExpr {
205    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
206        write!(
207            f,
208            "{} {}",
209            self.operator,
210            self.scalar()
211                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
212        )
213    }
214}
215
216impl PartialEq for DynamicComparisonExpr {
217    fn eq(&self, other: &Self) -> bool {
218        self.operator == other.operator
219            && Arc::ptr_eq(&self.rhs, &other.rhs)
220            && self.default == other.default
221    }
222}
223impl Eq for DynamicComparisonExpr {}
224
225impl Hash for DynamicComparisonExpr {
226    fn hash<H: Hasher>(&self, state: &mut H) {
227        self.operator.hash(state);
228        Arc::as_ptr(&self.rhs).hash(state);
229        self.default.hash(state);
230    }
231}
232
233/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
234/// internal value doesn't impact the hash of an expression tree.
235struct Rhs {
236    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
237    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
238    // The data type of the right-hand side value.
239    dtype: DType,
240}
241
242impl Rhs {
243    pub fn scalar(&self) -> Option<Scalar> {
244        (self.value)().map(|v| {
245            Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
246        })
247    }
248}
249
250impl Debug for Rhs {
251    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
252        f.debug_struct("Rhs")
253            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
254            .field("dtype", &self.dtype)
255            .finish()
256    }
257}
258
/// A utility for checking whether any dynamic expressions have been updated.
///
/// It records the right-hand side value of every [`DynamicComparison`] found in an expression
/// tree and exposes a version counter that is bumped whenever a call to `version` observes that
/// at least one of those values changed.
pub struct DynamicExprUpdates {
    // The dynamic comparisons discovered in the expression tree.
    exprs: Box<[DynamicComparisonExpr]>,
    // Track the latest observed versions of each dynamic expression, along with a version counter.
    // `.0` is the counter; `.1` holds the last observed value of each entry of `exprs`, in order.
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
265
266impl DynamicExprUpdates {
267    pub fn new(expr: &Expression) -> Option<Self> {
268        #[derive(Default)]
269        struct Visitor(Vec<DynamicComparisonExpr>);
270
271        impl NodeVisitor<'_> for Visitor {
272            type NodeTy = Expression;
273
274            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
275                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
276                    self.0.push(dynamic.clone());
277                }
278                Ok(TraversalOrder::Continue)
279            }
280        }
281
282        let mut visitor = Visitor::default();
283        expr.accept(&mut visitor).vortex_expect("Infallible");
284
285        if visitor.0.is_empty() {
286            return None;
287        }
288
289        let exprs = visitor.0.into_boxed_slice();
290        let prev_versions = exprs
291            .iter()
292            .map(|expr| {
293                (expr.rhs.value)().map(|v| {
294                    Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
295                        .vortex_expect("`DynamicExprUpdates` was invalid")
296                })
297            })
298            .collect();
299
300        Some(Self {
301            exprs,
302            prev_versions: Mutex::new((0, prev_versions)),
303        })
304    }
305
306    pub fn version(&self) -> u64 {
307        let mut guard = self.prev_versions.lock();
308
309        let mut updated = false;
310        for (i, expr) in self.exprs.iter().enumerate() {
311            let current = expr.scalar();
312            if current != guard.1[i] {
313                // At least one expression has been updated.
314                // We don't bail out early in order to avoid false positives for future calls
315                // to `is_updated`.
316                updated = true;
317                guard.1[i] = current;
318            }
319        }
320
321        if updated {
322            guard.0 += 1;
323        }
324
325        guard.0
326    }
327}
328
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::expr::exprs::root::root;

    /// The non-nullable `i32` dtype shared by every test in this module.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// Builds `root() < dynamic(rhs)` with the given default.
    fn lt_root(
        rhs: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
        default: bool,
    ) -> Expression {
        dynamic(Operator::Lt, rhs, i32_dtype(), default, root())
    }

    /// The `[1, 5, 10]` column used by the execution tests.
    fn sample_input() -> ArrayRef {
        buffer![1i32, 5, 10].into_array()
    }

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = lt_root(|| Some(5i32.into()), true);
        let actual = expr.return_dtype(&i32_dtype())?;
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let input = sample_input();
        let expr = lt_root(|| Some(5i32.into()), true);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let input = sample_input();
        let expr = lt_root(|| None, true);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let input = sample_input();
        let expr = lt_root(|| None, false);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let expr = {
            let threshold = Arc::clone(&threshold);
            lt_root(move || Some(threshold.load(Ordering::SeqCst).into()), true)
        };
        let input = sample_input();

        // With a threshold of 5, only 1 is strictly below it.
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));

        // Raising the threshold is picked up without rebuilding the expression.
        threshold.store(10, Ordering::SeqCst);
        let result = input.apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, false]));

        Ok(())
    }
}