Skip to main content

vortex_array/scalar_fn/fns/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15
16use crate::Array;
17use crate::ArrayRef;
18use crate::IntoArray;
19use crate::arrays::ConstantArray;
20use crate::dtype::DType;
21use crate::expr::Expression;
22use crate::expr::StatsCatalog;
23use crate::expr::traversal::NodeExt;
24use crate::expr::traversal::NodeVisitor;
25use crate::expr::traversal::TraversalOrder;
26use crate::scalar::Scalar;
27use crate::scalar::ScalarValue;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::fns::binary::Binary;
35use crate::scalar_fn::fns::operators::CompareOperator;
36use crate::scalar_fn::fns::operators::Operator;
37
/// Scalar function for comparisons whose right-hand side may change while a query is running.
///
/// A typical use is a compute engine pushing down an ORDER BY + LIMIT operation: as execution
/// proceeds, the engine can progressively tighten the bound that the filter compares against.
#[derive(Clone)]
pub struct DynamicComparison;
43
44impl ScalarFnVTable for DynamicComparison {
45    type Options = DynamicComparisonExpr;
46
47    fn id(&self) -> ScalarFnId {
48        ScalarFnId::new_ref("vortex.dynamic")
49    }
50
51    fn arity(&self, _options: &Self::Options) -> Arity {
52        Arity::Exact(1)
53    }
54
55    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
56        match child_idx {
57            0 => ChildName::from("lhs"),
58            _ => unreachable!(),
59        }
60    }
61
62    fn fmt_sql(
63        &self,
64        dynamic: &DynamicComparisonExpr,
65        expr: &Expression,
66        f: &mut Formatter<'_>,
67    ) -> std::fmt::Result {
68        expr.child(0).fmt_sql(f)?;
69        write!(f, " {} dynamic(", dynamic)?;
70        match dynamic.scalar() {
71            None => write!(f, "<none>")?,
72            Some(scalar) => write!(f, "{}", scalar)?,
73        }
74        write!(f, ")")
75    }
76
77    fn return_dtype(
78        &self,
79        dynamic: &DynamicComparisonExpr,
80        arg_dtypes: &[DType],
81    ) -> VortexResult<DType> {
82        let lhs = &arg_dtypes[0];
83        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
84            vortex_bail!(
85                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
86                &dynamic.rhs.dtype,
87                lhs
88            );
89        }
90        Ok(DType::Bool(
91            lhs.nullability() | dynamic.rhs.dtype.nullability(),
92        ))
93    }
94
95    fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
96        if let Some(scalar) = data.rhs.scalar() {
97            let [lhs]: [ArrayRef; _] = args
98                .inputs
99                .try_into()
100                .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
101            let rhs = ConstantArray::new(scalar, args.row_count).into_array();
102
103            return Binary
104                .bind(Operator::from(data.operator))
105                .execute(ExecutionArgs {
106                    inputs: vec![lhs, rhs],
107                    row_count: args.row_count,
108                    ctx: args.ctx,
109                });
110        }
111        let ret_dtype =
112            DType::Bool(args.inputs[0].dtype().nullability() | data.rhs.dtype.nullability());
113
114        Ok(ConstantArray::new(
115            Scalar::try_new(ret_dtype, Some(data.default.into()))?,
116            args.row_count,
117        )
118        .into_array())
119    }
120
121    fn stat_falsification(
122        &self,
123        dynamic: &DynamicComparisonExpr,
124        expr: &Expression,
125        catalog: &dyn StatsCatalog,
126    ) -> Option<Expression> {
127        let lhs = expr.child(0);
128        match dynamic.operator {
129            CompareOperator::Eq | CompareOperator::NotEq => None,
130            CompareOperator::Gt => Some(DynamicComparison.new_expr(
131                DynamicComparisonExpr {
132                    operator: CompareOperator::Lte,
133                    rhs: dynamic.rhs.clone(),
134                    default: !dynamic.default,
135                },
136                vec![lhs.stat_max(catalog)?],
137            )),
138            CompareOperator::Gte => Some(DynamicComparison.new_expr(
139                DynamicComparisonExpr {
140                    operator: CompareOperator::Lt,
141                    rhs: dynamic.rhs.clone(),
142                    default: !dynamic.default,
143                },
144                vec![lhs.stat_max(catalog)?],
145            )),
146            CompareOperator::Lt => Some(DynamicComparison.new_expr(
147                DynamicComparisonExpr {
148                    operator: CompareOperator::Gte,
149                    rhs: dynamic.rhs.clone(),
150                    default: !dynamic.default,
151                },
152                vec![lhs.stat_min(catalog)?],
153            )),
154            CompareOperator::Lte => Some(DynamicComparison.new_expr(
155                DynamicComparisonExpr {
156                    operator: CompareOperator::Gt,
157                    rhs: dynamic.rhs.clone(),
158                    default: !dynamic.default,
159                },
160                vec![lhs.stat_min(catalog)?],
161            )),
162        }
163    }
164
165    // Defer to the child
166    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
167        false
168    }
169}
170
171#[derive(Clone, Debug)]
172pub struct DynamicComparisonExpr {
173    pub(crate) operator: CompareOperator,
174    pub(crate) rhs: Arc<Rhs>,
175    // Default value for the dynamic comparison.
176    pub(crate) default: bool,
177}
178
179impl DynamicComparisonExpr {
180    pub fn scalar(&self) -> Option<Scalar> {
181        (self.rhs.value)().map(|v| {
182            Scalar::try_new(self.rhs.dtype.clone(), Some(v))
183                .vortex_expect("`DynamicComparisonExpr` was invalid")
184        })
185    }
186}
187
188impl Display for DynamicComparisonExpr {
189    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
190        write!(
191            f,
192            "{} {}",
193            self.operator,
194            self.scalar()
195                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
196        )
197    }
198}
199
200impl PartialEq for DynamicComparisonExpr {
201    fn eq(&self, other: &Self) -> bool {
202        self.operator == other.operator
203            && Arc::ptr_eq(&self.rhs, &other.rhs)
204            && self.default == other.default
205    }
206}
207impl Eq for DynamicComparisonExpr {}
208
209impl Hash for DynamicComparisonExpr {
210    fn hash<H: Hasher>(&self, state: &mut H) {
211        self.operator.hash(state);
212        Arc::as_ptr(&self.rhs).hash(state);
213        self.default.hash(state);
214    }
215}
216
217/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
218/// internal value doesn't impact the hash of an expression tree.
219pub(crate) struct Rhs {
220    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
221    pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
222    // The data type of the right-hand side value.
223    pub(crate) dtype: DType,
224}
225
226impl Rhs {
227    pub fn scalar(&self) -> Option<Scalar> {
228        (self.value)().map(|v| {
229            Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
230        })
231    }
232}
233
234impl Debug for Rhs {
235    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
236        f.debug_struct("Rhs")
237            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
238            .field("dtype", &self.dtype)
239            .finish()
240    }
241}
242
243/// A utility for checking whether any dynamic expressions have been updated.
244pub struct DynamicExprUpdates {
245    exprs: Box<[DynamicComparisonExpr]>,
246    // Track the latest observed versions of each dynamic expression, along with a version counter.
247    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
248}
249
250impl DynamicExprUpdates {
251    pub fn new(expr: &Expression) -> Option<Self> {
252        #[derive(Default)]
253        struct Visitor(Vec<DynamicComparisonExpr>);
254
255        impl NodeVisitor<'_> for Visitor {
256            type NodeTy = Expression;
257
258            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
259                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
260                    self.0.push(dynamic.clone());
261                }
262                Ok(TraversalOrder::Continue)
263            }
264        }
265
266        let mut visitor = Visitor::default();
267        expr.accept(&mut visitor).vortex_expect("Infallible");
268
269        if visitor.0.is_empty() {
270            return None;
271        }
272
273        let exprs = visitor.0.into_boxed_slice();
274        let prev_versions = exprs
275            .iter()
276            .map(|expr| {
277                (expr.rhs.value)().map(|v| {
278                    Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
279                        .vortex_expect("`DynamicExprUpdates` was invalid")
280                })
281            })
282            .collect();
283
284        Some(Self {
285            exprs,
286            prev_versions: Mutex::new((0, prev_versions)),
287        })
288    }
289
290    pub fn version(&self) -> u64 {
291        let mut guard = self.prev_versions.lock();
292
293        let mut updated = false;
294        for (i, expr) in self.exprs.iter().enumerate() {
295            let current = expr.scalar();
296            if current != guard.1[i] {
297                // At least one expression has been updated.
298                // We don't bail out early in order to avoid false positives for future calls
299                // to `is_updated`.
300                updated = true;
301                guard.1[i] = current;
302            }
303        }
304
305        if updated {
306            guard.0 += 1;
307        }
308
309        guard.0
310    }
311}
312
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    /// The non-nullable `i32` dtype used by every test in this module.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// The three-element input array `[1, 5, 10]` used by the execution tests.
    fn sample_input() -> ArrayRef {
        buffer![1i32, 5, 10].into_array()
    }

    // A dynamic comparison over non-nullable inputs yields a non-nullable bool.
    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );
        let actual = expr.return_dtype(&i32_dtype())?;
        assert_eq!(actual, DType::Bool(Nullability::NonNullable));
        Ok(())
    }

    // With a value present, rows are compared against it.
    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]));
        Ok(())
    }

    // Without a value, every row takes the default (`true` here).
    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), true, root());
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]));
        Ok(())
    }

    // Without a value, every row takes the default (`false` here).
    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), false, root());
        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([false, false, false]));
        Ok(())
    }

    // Changing the shared threshold changes the result of re-applying the same expression.
    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let threshold = Arc::new(AtomicI32::new(5));
        let reader = threshold.clone();
        let expr = dynamic(
            CompareOperator::Lt,
            move || Some(reader.load(Ordering::SeqCst).into()),
            i32_dtype(),
            true,
            root(),
        );
        let input = sample_input();

        let before = input.apply(&expr)?;
        assert_arrays_eq!(before, BoolArray::from_iter([true, false, false]));

        threshold.store(10, Ordering::SeqCst);
        let after = input.apply(&expr)?;
        assert_arrays_eq!(after, BoolArray::from_iter([true, true, false]));

        Ok(())
    }
}
414}