vortex_array/expr/exprs/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_scalar::Scalar;
17use vortex_scalar::ScalarValue;
18use vortex_vector::Datum;
19use vortex_vector::Scalar as VectorScalar;
20use vortex_vector::bool::BoolScalar;
21
22use crate::Array;
23use crate::ArrayRef;
24use crate::IntoArray;
25use crate::arrays::ConstantArray;
26use crate::compute::Operator;
27use crate::compute::compare;
28use crate::expr::Arity;
29use crate::expr::Binary;
30use crate::expr::ChildName;
31use crate::expr::ExecutionArgs;
32use crate::expr::ExprId;
33use crate::expr::Expression;
34use crate::expr::StatsCatalog;
35use crate::expr::VTable;
36use crate::expr::VTableExt;
37use crate::expr::traversal::NodeExt;
38use crate::expr::traversal::NodeVisitor;
39use crate::expr::traversal::TraversalOrder;
40
/// Expression kind for a comparison whose right-hand side may change while a query runs.
///
/// The typical use is an engine that pushes down an ORDER BY + LIMIT operation: as execution
/// progresses it can tighten the bound of the filter, and the comparison picks up the new bound
/// the next time it is evaluated.
pub struct DynamicComparison;
45
46impl VTable for DynamicComparison {
47    type Options = DynamicComparisonExpr;
48
49    fn id(&self) -> ExprId {
50        ExprId::new_ref("vortex.dynamic")
51    }
52
53    fn arity(&self, _options: &Self::Options) -> Arity {
54        Arity::Exact(1)
55    }
56
57    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
58        match child_idx {
59            0 => ChildName::from("lhs"),
60            _ => unreachable!(),
61        }
62    }
63
64    fn fmt_sql(
65        &self,
66        dynamic: &DynamicComparisonExpr,
67        expr: &Expression,
68        f: &mut Formatter<'_>,
69    ) -> std::fmt::Result {
70        expr.child(0).fmt_sql(f)?;
71        write!(f, " {} dynamic(", dynamic)?;
72        match dynamic.scalar() {
73            None => write!(f, "<none>")?,
74            Some(scalar) => write!(f, "{}", scalar)?,
75        }
76        write!(f, ")")
77    }
78
79    fn return_dtype(
80        &self,
81        dynamic: &DynamicComparisonExpr,
82        arg_dtypes: &[DType],
83    ) -> VortexResult<DType> {
84        let lhs = &arg_dtypes[0];
85        if !dynamic.rhs.dtype.eq_ignore_nullability(lhs) {
86            vortex_bail!(
87                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
88                &dynamic.rhs.dtype,
89                lhs
90            );
91        }
92        Ok(DType::Bool(
93            lhs.nullability() | dynamic.rhs.dtype.nullability(),
94        ))
95    }
96
97    fn evaluate(
98        &self,
99        dynamic: &DynamicComparisonExpr,
100        expr: &Expression,
101        scope: &ArrayRef,
102    ) -> VortexResult<ArrayRef> {
103        if let Some(value) = dynamic.rhs.scalar() {
104            let lhs = expr.child(0).evaluate(scope)?;
105            let rhs = ConstantArray::new(value, scope.len());
106            return compare(lhs.as_ref(), rhs.as_ref(), dynamic.operator);
107        }
108
109        // Otherwise, we return the default value.
110        let lhs = expr.return_dtype(scope.dtype())?;
111        Ok(ConstantArray::new(
112            Scalar::new(
113                DType::Bool(lhs.nullability() | dynamic.rhs.dtype.nullability()),
114                dynamic.default.into(),
115            ),
116            scope.len(),
117        )
118        .into_array())
119    }
120
121    fn execute(&self, data: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
122        if let Some(scalar) = data.rhs.scalar() {
123            let [lhs]: [Datum; _] = args
124                .datums
125                .try_into()
126                .map_err(|_| vortex_error::vortex_err!("Wrong arg count for DynamicComparison"))?;
127            let rhs_vector_scalar = scalar.to_vector_scalar();
128            let rhs = Datum::Scalar(rhs_vector_scalar);
129
130            return Binary.bind(data.operator.into()).execute(ExecutionArgs {
131                datums: vec![lhs, rhs],
132                dtypes: args.dtypes,
133                row_count: args.row_count,
134                return_dtype: args.return_dtype,
135            });
136        }
137
138        Ok(Datum::Scalar(VectorScalar::Bool(BoolScalar::new(Some(
139            data.default,
140        )))))
141    }
142
143    fn stat_falsification(
144        &self,
145        dynamic: &DynamicComparisonExpr,
146        expr: &Expression,
147        catalog: &dyn StatsCatalog,
148    ) -> Option<Expression> {
149        let lhs = expr.child(0);
150        match dynamic.operator {
151            Operator::Gt => Some(DynamicComparison.new_expr(
152                DynamicComparisonExpr {
153                    operator: Operator::Lte,
154                    rhs: dynamic.rhs.clone(),
155                    default: !dynamic.default,
156                },
157                vec![lhs.stat_max(catalog)?],
158            )),
159            Operator::Gte => Some(DynamicComparison.new_expr(
160                DynamicComparisonExpr {
161                    operator: Operator::Lt,
162                    rhs: dynamic.rhs.clone(),
163                    default: !dynamic.default,
164                },
165                vec![lhs.stat_max(catalog)?],
166            )),
167            Operator::Lt => Some(DynamicComparison.new_expr(
168                DynamicComparisonExpr {
169                    operator: Operator::Gte,
170                    rhs: dynamic.rhs.clone(),
171                    default: !dynamic.default,
172                },
173                vec![lhs.stat_min(catalog)?],
174            )),
175            Operator::Lte => Some(DynamicComparison.new_expr(
176                DynamicComparisonExpr {
177                    operator: Operator::Gt,
178                    rhs: dynamic.rhs.clone(),
179                    default: !dynamic.default,
180                },
181                vec![lhs.stat_min(catalog)?],
182            )),
183            _ => None,
184        }
185    }
186
187    // Defer to the child
188    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
189        false
190    }
191}
192
193pub fn dynamic(
194    operator: Operator,
195    rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
196    rhs_dtype: DType,
197    default: bool,
198    lhs: Expression,
199) -> Expression {
200    DynamicComparison.new_expr(
201        DynamicComparisonExpr {
202            operator,
203            rhs: Arc::new(Rhs {
204                value: Arc::new(rhs_value),
205                dtype: rhs_dtype,
206            }),
207            default,
208        },
209        [lhs],
210    )
211}
212
213#[derive(Clone, Debug)]
214pub struct DynamicComparisonExpr {
215    operator: Operator,
216    rhs: Arc<Rhs>,
217    // Default value for the dynamic comparison.
218    default: bool,
219}
220
221impl DynamicComparisonExpr {
222    pub fn scalar(&self) -> Option<Scalar> {
223        (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
224    }
225}
226
227impl Display for DynamicComparisonExpr {
228    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
229        write!(
230            f,
231            "{} {}",
232            self.operator,
233            self.scalar()
234                .map_or_else(|| "<none>".to_string(), |v| v.to_string())
235        )
236    }
237}
238
239impl PartialEq for DynamicComparisonExpr {
240    fn eq(&self, other: &Self) -> bool {
241        self.operator == other.operator
242            && Arc::ptr_eq(&self.rhs, &other.rhs)
243            && self.default == other.default
244    }
245}
246impl Eq for DynamicComparisonExpr {}
247
248impl Hash for DynamicComparisonExpr {
249    fn hash<H: Hasher>(&self, state: &mut H) {
250        self.operator.hash(state);
251        Arc::as_ptr(&self.rhs).hash(state);
252        self.default.hash(state);
253    }
254}
255
256/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
257/// internal value doesn't impact the hash of an expression tree.
258struct Rhs {
259    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
260    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
261    // The data type of the right-hand side value.
262    dtype: DType,
263}
264
265impl Rhs {
266    pub fn scalar(&self) -> Option<Scalar> {
267        (self.value)().map(|v| Scalar::new(self.dtype.clone(), v))
268    }
269}
270
271impl Debug for Rhs {
272    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
273        f.debug_struct("Rhs")
274            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
275            .field("dtype", &self.dtype)
276            .finish()
277    }
278}
279
280/// A utility for checking whether any dynamic expressions have been updated.
281pub struct DynamicExprUpdates {
282    exprs: Box<[DynamicComparisonExpr]>,
283    // Track the latest observed versions of each dynamic expression, along with a version counter.
284    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
285}
286
287impl DynamicExprUpdates {
288    pub fn new(expr: &Expression) -> Option<Self> {
289        #[derive(Default)]
290        struct Visitor(Vec<DynamicComparisonExpr>);
291
292        impl NodeVisitor<'_> for Visitor {
293            type NodeTy = Expression;
294
295            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
296                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
297                    self.0.push(dynamic.clone());
298                }
299                Ok(TraversalOrder::Continue)
300            }
301        }
302
303        let mut visitor = Visitor::default();
304        expr.accept(&mut visitor).vortex_expect("Infallible");
305
306        if visitor.0.is_empty() {
307            return None;
308        }
309
310        let exprs = visitor.0.into_boxed_slice();
311        let prev_versions = exprs
312            .iter()
313            .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
314            .collect();
315
316        Some(Self {
317            exprs,
318            prev_versions: Mutex::new((0, prev_versions)),
319        })
320    }
321
322    pub fn version(&self) -> u64 {
323        let mut guard = self.prev_versions.lock();
324
325        let mut updated = false;
326        for (i, expr) in self.exprs.iter().enumerate() {
327            let current = expr.scalar();
328            if current != guard.1[i] {
329                // At least one expression has been updated.
330                // We don't bail out early in order to avoid false positives for future calls
331                // to `is_updated`.
332                updated = true;
333                guard.1[i] = current;
334            }
335        }
336
337        if updated {
338            guard.0 += 1;
339        }
340
341        guard.0
342    }
343}