vortex_expr/exprs/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use parking_lot::Mutex;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::{Operator, compare};
11use vortex_array::{Array, ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail};
14use vortex_proto::expr as pb;
15use vortex_scalar::{Scalar, ScalarValue};
16
17use crate::display::{DisplayAs, DisplayFormat};
18use crate::traversal::{NodeExt, NodeVisitor, TraversalOrder};
19use crate::{
20    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
21};
22
// Invokes the crate's `vtable!` macro for this expression. Presumably this is what declares
// `DynamicComparisonVTable`, which is implemented below but not defined in this file —
// TODO confirm against the macro definition.
vtable!(DynamicComparison);
24
/// A dynamic comparison expression can be used to capture a comparison to a value that can change
/// during the execution of a query, such as when a compute engine pushes down an ORDER BY + LIMIT
/// operation and is able to progressively tighten the bounds of the filter.
///
/// The comparison has the shape `lhs <operator> rhs`, where the right-hand side is obtained by
/// calling a value function each time the expression is evaluated. While that function yields
/// `None`, evaluation produces the `default` value instead.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    // Expression producing the left-hand side of the comparison.
    lhs: ExprRef,
    // Comparison operator applied between `lhs` and the dynamic right-hand side.
    operator: Operator,
    // Dynamic right-hand side: the value function together with its dtype.
    rhs: Arc<Rhs>,
    // Default value for the dynamic comparison, used while the right-hand side has no value.
    default: bool,
}
36
37impl PartialEq for DynamicComparisonExpr {
38    fn eq(&self, other: &Self) -> bool {
39        self.default == other.default
40            && self.operator == other.operator
41            && self.lhs.eq(&other.lhs)
42            && Arc::ptr_eq(&self.rhs.value, &other.rhs.value)
43            && self.rhs.dtype == other.rhs.dtype
44    }
45}
46impl Eq for DynamicComparisonExpr {}
47
48impl Hash for DynamicComparisonExpr {
49    fn hash<H: Hasher>(&self, state: &mut H) {
50        self.default.hash(state);
51        self.operator.hash(state);
52        self.lhs.hash(state);
53        Arc::as_ptr(&self.rhs.value).hash(state);
54        self.rhs.dtype.hash(state);
55    }
56}
57
/// The dynamic right-hand side of a [`DynamicComparisonExpr`].
///
/// Hash and PartialEq of the owning expression are implemented based on the ptr of the value
/// function, such that the internal value doesn't impact the hash of an expression tree.
struct Rhs {
    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
    // A `None` result means there is currently no value to compare against.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    // The data type of the right-hand side value.
    dtype: DType,
}
66
67impl Debug for Rhs {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("Rhs")
70            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
71            .field("dtype", &self.dtype)
72            .finish()
73    }
74}
75
/// Stateless unit type used as the `Encoding` associated type of the dynamic comparison vtable.
pub struct DynamicComparisonExprEncoding;
77
78impl VTable for DynamicComparisonVTable {
79    type Expr = DynamicComparisonExpr;
80    type Encoding = DynamicComparisonExprEncoding;
81    type Metadata = ProstMetadata<pb::LiteralOpts>;
82
83    fn id(_encoding: &Self::Encoding) -> ExprId {
84        ExprId::new_ref("dynamic")
85    }
86
87    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
88        ExprEncodingRef::new_ref(DynamicComparisonExprEncoding.as_ref())
89    }
90
91    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
92        None
93    }
94
95    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
96        vec![&expr.lhs]
97    }
98
99    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
100        Ok(DynamicComparisonExpr {
101            lhs: children[0].clone(),
102            operator: expr.operator,
103            rhs: expr.rhs.clone(),
104            default: expr.default,
105        })
106    }
107
108    fn build(
109        _encoding: &Self::Encoding,
110        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
111        _children: Vec<ExprRef>,
112    ) -> VortexResult<Self::Expr> {
113        vortex_bail!("DynamicComparison expression does not support building from metadata");
114    }
115
116    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
117        if let Some(value) = expr.scalar() {
118            let lhs = expr.lhs.evaluate(scope)?;
119            let rhs = ConstantArray::new(value, scope.len());
120            return compare(lhs.as_ref(), rhs.as_ref(), expr.operator);
121        }
122
123        // Otherwise, we return the default value.
124        let lhs = expr.return_dtype(scope.dtype())?;
125        Ok(ConstantArray::new(
126            Scalar::new(
127                DType::Bool(lhs.nullability() | expr.rhs.dtype.nullability()),
128                expr.default.into(),
129            ),
130            scope.len(),
131        )
132        .into_array())
133    }
134
135    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
136        let lhs = expr.lhs.return_dtype(scope)?;
137        if !expr.rhs.dtype.eq_ignore_nullability(&lhs) {
138            vortex_bail!(
139                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
140                &expr.rhs.dtype,
141                lhs
142            );
143        }
144        Ok(DType::Bool(
145            lhs.nullability() | expr.rhs.dtype.nullability(),
146        ))
147    }
148}
149
150impl DynamicComparisonExpr {
151    pub fn new(
152        rhs: ExprRef,
153        operator: Operator,
154        rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
155        rhs_dtype: DType,
156        default: bool,
157    ) -> Self {
158        DynamicComparisonExpr {
159            lhs: rhs,
160            operator,
161            rhs: Arc::new(Rhs {
162                value: Arc::new(rhs_value),
163                dtype: rhs_dtype,
164            }),
165            default,
166        }
167    }
168
169    pub fn scalar(&self) -> Option<Scalar> {
170        (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
171    }
172}
173
174impl DisplayAs for DynamicComparisonExpr {
175    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
176        match df {
177            DisplayFormat::Compact => {
178                write!(
179                    f,
180                    "{} {} dynamic({})",
181                    &self.lhs, self.operator, &self.rhs.dtype,
182                )
183            }
184            DisplayFormat::Tree => {
185                write!(f, "DynamicComparison")
186            }
187        }
188    }
189}
190
191impl AnalysisExpr for DynamicComparisonExpr {
192    fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
193        match self.operator {
194            Operator::Gt => Some(
195                DynamicComparisonExpr {
196                    lhs: self.lhs.max(catalog)?,
197                    operator: Operator::Lte,
198                    rhs: self.rhs.clone(),
199                    default: !self.default,
200                }
201                .into_expr(),
202            ),
203            Operator::Gte => Some(
204                DynamicComparisonExpr {
205                    lhs: self.lhs.max(catalog)?,
206                    operator: Operator::Lt,
207                    rhs: self.rhs.clone(),
208                    default: !self.default,
209                }
210                .into_expr(),
211            ),
212            Operator::Lt => Some(
213                DynamicComparisonExpr {
214                    lhs: self.lhs.min(catalog)?,
215                    operator: Operator::Gte,
216                    rhs: self.rhs.clone(),
217                    default: !self.default,
218                }
219                .into_expr(),
220            ),
221            Operator::Lte => Some(
222                DynamicComparisonExpr {
223                    lhs: self.lhs.min(catalog)?,
224                    operator: Operator::Gt,
225                    rhs: self.rhs.clone(),
226                    default: !self.default,
227                }
228                .into_expr(),
229            ),
230            _ => None,
231        }
232    }
233}
234
/// A utility for checking whether any dynamic expressions have been updated.
pub struct DynamicExprUpdates {
    // The dynamic comparison expressions being tracked.
    exprs: Box<[DynamicComparisonExpr]>,
    // Track the latest observed versions of each dynamic expression, along with a version counter.
    // The counter is the first tuple element; the second holds one entry per element of `exprs`,
    // in the same order.
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
241
242impl DynamicExprUpdates {
243    pub fn new(expr: &ExprRef) -> Option<Self> {
244        #[derive(Default)]
245        struct Visitor(Vec<DynamicComparisonExpr>);
246
247        impl NodeVisitor<'_> for Visitor {
248            type NodeTy = ExprRef;
249
250            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
251                if let Some(dynamic) = node.as_opt::<DynamicComparisonVTable>() {
252                    self.0.push(dynamic.clone());
253                }
254                Ok(TraversalOrder::Continue)
255            }
256        }
257
258        let mut visitor = Visitor::default();
259        expr.accept(&mut visitor).vortex_expect("Infallible");
260
261        if visitor.0.is_empty() {
262            return None;
263        }
264
265        let exprs = visitor.0.into_boxed_slice();
266        let prev_versions = exprs
267            .iter()
268            .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
269            .collect();
270
271        Some(Self {
272            exprs,
273            prev_versions: Mutex::new((0, prev_versions)),
274        })
275    }
276
277    pub fn version(&self) -> u64 {
278        let mut guard = self.prev_versions.lock();
279
280        let mut updated = false;
281        for (i, expr) in self.exprs.iter().enumerate() {
282            let current = expr.scalar();
283            if current != guard.1[i] {
284                // At least one expression has been updated.
285                // We don't bail out early in order to avoid false positives for future calls
286                // to `is_updated`.
287                updated = true;
288                guard.1[i] = current;
289            }
290        }
291
292        if updated {
293            guard.0 += 1;
294        }
295
296        guard.0
297    }
298}