Skip to main content

vortex_array/scalar_fn/fns/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_session::registry::CachedId;
16
17use crate::ArrayRef;
18use crate::ExecutionCtx;
19use crate::IntoArray;
20use crate::arrays::ConstantArray;
21use crate::dtype::DType;
22use crate::expr::Expression;
23use crate::expr::traversal::NodeExt;
24use crate::expr::traversal::NodeVisitor;
25use crate::expr::traversal::TraversalOrder;
26use crate::scalar::Scalar;
27use crate::scalar::ScalarValue;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::ScalarFnVTableExt;
34use crate::scalar_fn::VecExecutionArgs;
35use crate::scalar_fn::fns::binary::Binary;
36use crate::scalar_fn::fns::operators::CompareOperator;
37use crate::scalar_fn::fns::operators::Operator;
38
39/// A dynamic comparison expression can be used to capture a comparison to a value that can change
40/// during the execution of a query, such as when a compute engine pushes down an ORDER BY + LIMIT
41/// operation and is able to progressively tighten the bounds of the filter.
42#[derive(Clone)]
43pub struct DynamicComparison;
44
45impl ScalarFnVTable for DynamicComparison {
46    type Options = DynamicComparisonExpr;
47
48    fn id(&self) -> ScalarFnId {
49        static ID: CachedId = CachedId::new("vortex.dynamic");
50        *ID
51    }
52
53    fn arity(&self, _options: &Self::Options) -> Arity {
54        Arity::Exact(1)
55    }
56
57    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
58        match child_idx {
59            0 => ChildName::from("lhs"),
60            _ => unreachable!(),
61        }
62    }
63
64    fn fmt_sql(
65        &self,
66        dynamic: &DynamicComparisonExpr,
67        expr: &Expression,
68        f: &mut Formatter<'_>,
69    ) -> std::fmt::Result {
70        expr.child(0).fmt_sql(f)?;
71        write!(f, " {} dynamic(", dynamic.operator)?;
72        match dynamic.scalar() {
73            None => write!(f, "scalar=<none>")?,
74            Some(scalar) => write!(f, "scalar={scalar}")?,
75        }
76        write!(f, ")")
77    }
78
79    fn return_dtype(
80        &self,
81        dynamic: &DynamicComparisonExpr,
82        arg_dtypes: &[DType],
83    ) -> VortexResult<DType> {
84        let lhs = &arg_dtypes[0];
85        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
86            vortex_bail!(
87                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
88                &dynamic.rhs.dtype,
89                lhs
90            );
91        }
92        Ok(DType::Bool(
93            lhs.nullability() | dynamic.rhs.dtype.nullability(),
94        ))
95    }
96
97    fn execute(
98        &self,
99        data: &Self::Options,
100        args: &dyn ExecutionArgs,
101        ctx: &mut ExecutionCtx,
102    ) -> VortexResult<ArrayRef> {
103        if let Some(scalar) = data.rhs.scalar() {
104            let lhs = args.get(0)?;
105            let rhs = ConstantArray::new(scalar, args.row_count()).into_array();
106
107            let delegate_args = VecExecutionArgs::new(vec![lhs, rhs], args.row_count());
108            return Binary
109                .bind(Operator::from(data.operator))
110                .execute(&delegate_args, ctx);
111        }
112        let ret_dtype =
113            DType::Bool(args.get(0)?.dtype().nullability() | data.rhs.dtype.nullability());
114
115        Ok(ConstantArray::new(
116            Scalar::try_new(ret_dtype, Some(data.default.into()))?,
117            args.row_count(),
118        )
119        .into_array())
120    }
121
122    // Defer to the child
123    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
124        false
125    }
126}
127
128#[derive(Clone, Debug)]
129pub struct DynamicComparisonExpr {
130    pub(crate) operator: CompareOperator,
131    pub(crate) rhs: Arc<Rhs>,
132    // Default value for the dynamic comparison.
133    pub(crate) default: bool,
134}
135
136impl DynamicComparisonExpr {
137    pub fn scalar(&self) -> Option<Scalar> {
138        (self.rhs.value)().map(|v| {
139            Scalar::try_new(self.rhs.dtype.clone(), Some(v))
140                .vortex_expect("`DynamicComparisonExpr` was invalid")
141        })
142    }
143}
144
145impl Display for DynamicComparisonExpr {
146    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
147        write!(
148            f,
149            "{} {}",
150            self.operator,
151            self.scalar()
152                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
153        )
154    }
155}
156
157impl PartialEq for DynamicComparisonExpr {
158    fn eq(&self, other: &Self) -> bool {
159        self.operator == other.operator
160            && Arc::ptr_eq(&self.rhs, &other.rhs)
161            && self.default == other.default
162    }
163}
164impl Eq for DynamicComparisonExpr {}
165
166impl Hash for DynamicComparisonExpr {
167    fn hash<H: Hasher>(&self, state: &mut H) {
168        self.operator.hash(state);
169        Arc::as_ptr(&self.rhs).hash(state);
170        self.default.hash(state);
171    }
172}
173
174/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
175/// internal value doesn't impact the hash of an expression tree.
176pub(crate) struct Rhs {
177    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
178    pub(crate) value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
179    // The data type of the right-hand side value.
180    pub(crate) dtype: DType,
181}
182
183impl Rhs {
184    pub fn scalar(&self) -> Option<Scalar> {
185        (self.value)().map(|v| {
186            Scalar::try_new(self.dtype.clone(), Some(v)).vortex_expect("`Rhs` was invalid")
187        })
188    }
189}
190
191impl Debug for Rhs {
192    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193        f.debug_struct("Rhs")
194            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
195            .field("dtype", &self.dtype)
196            .finish()
197    }
198}
199
200/// A utility for checking whether any dynamic expressions have been updated.
201pub struct DynamicExprUpdates {
202    exprs: Box<[DynamicComparisonExpr]>,
203    // Track the latest observed versions of each dynamic expression, along with a version counter.
204    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
205}
206
207impl DynamicExprUpdates {
208    pub fn new(expr: &Expression) -> Option<Self> {
209        #[derive(Default)]
210        struct Visitor(Vec<DynamicComparisonExpr>);
211
212        impl NodeVisitor<'_> for Visitor {
213            type NodeTy = Expression;
214
215            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
216                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
217                    self.0.push(dynamic.clone());
218                }
219                Ok(TraversalOrder::Continue)
220            }
221        }
222
223        let mut visitor = Visitor::default();
224        expr.accept(&mut visitor).vortex_expect("Infallible");
225
226        if visitor.0.is_empty() {
227            return None;
228        }
229
230        let exprs = visitor.0.into_boxed_slice();
231        let prev_versions = exprs
232            .iter()
233            .map(|expr| {
234                (expr.rhs.value)().map(|v| {
235                    Scalar::try_new(expr.rhs.dtype.clone(), Some(v))
236                        .vortex_expect("`DynamicExprUpdates` was invalid")
237                })
238            })
239            .collect();
240
241        Some(Self {
242            exprs,
243            prev_versions: Mutex::new((0, prev_versions)),
244        })
245    }
246
247    pub fn version(&self) -> u64 {
248        let mut guard = self.prev_versions.lock();
249
250        let mut updated = false;
251        for (i, expr) in self.exprs.iter().enumerate() {
252            let current = expr.scalar();
253            if current != guard.1[i] {
254                // At least one expression has been updated.
255                // We don't bail out early in order to avoid false positives for future calls
256                // to `is_updated`.
257                updated = true;
258                guard.1[i] = current;
259            }
260        }
261
262        if updated {
263            guard.0 += 1;
264        }
265
266        guard.0
267    }
268}
269
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicI32;
    use std::sync::atomic::Ordering;

    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use super::*;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::BoolArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::dynamic;
    use crate::expr::root;

    /// The non-nullable `i32` dtype used for both sides of the comparison in every test.
    fn i32_dtype() -> DType {
        DType::Primitive(PType::I32, Nullability::NonNullable)
    }

    /// The left-hand side input shared by the execution tests.
    fn sample_input() -> ArrayRef {
        buffer![1i32, 5, 10].into_array()
    }

    #[test]
    fn return_dtype_bool() -> VortexResult<()> {
        let input_dtype = i32_dtype();
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );

        let returned = expr.return_dtype(&input_dtype)?;
        assert_eq!(returned, DType::Bool(Nullability::NonNullable));
        Ok(())
    }

    #[test]
    fn execute_with_value() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let expr = dynamic(
            CompareOperator::Lt,
            || Some(5i32.into()),
            i32_dtype(),
            true,
            root(),
        );

        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, false, false]), &mut ctx);
        Ok(())
    }

    #[test]
    fn execute_without_value_default_true() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // No rhs value is available, so every row takes the default (`true`).
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), true, root());

        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(result, BoolArray::from_iter([true, true, true]), &mut ctx);
        Ok(())
    }

    #[test]
    fn execute_without_value_default_false() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        // No rhs value is available, so every row takes the default (`false`).
        let expr = dynamic(CompareOperator::Lt, || None, i32_dtype(), false, root());

        let result = sample_input().apply(&expr)?;
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, false]),
            &mut ctx
        );
        Ok(())
    }

    #[test]
    fn execute_value_flips() -> VortexResult<()> {
        let mut ctx = array_session().create_execution_ctx();
        let threshold = Arc::new(AtomicI32::new(5));
        let expr = {
            let threshold = Arc::clone(&threshold);
            dynamic(
                CompareOperator::Lt,
                move || Some(threshold.load(Ordering::SeqCst).into()),
                i32_dtype(),
                true,
                root(),
            )
        };
        let input = sample_input();

        // With the threshold at 5 only the first element is strictly smaller.
        let before = input.clone().apply(&expr)?;
        assert_arrays_eq!(before, BoolArray::from_iter([true, false, false]), &mut ctx);

        // Raising the threshold to 10 is picked up by the same expression on the next run.
        threshold.store(10, Ordering::SeqCst);
        let after = input.apply(&expr)?;
        assert_arrays_eq!(after, BoolArray::from_iter([true, true, false]), &mut ctx);

        Ok(())
    }
}