vortex_array/expr/exprs/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use parking_lot::Mutex;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::arrays::ConstantArray;
14use crate::compute::{Operator, compare};
15use crate::expr::traversal::{NodeExt, NodeVisitor, TraversalOrder};
16use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
17use crate::{Array, ArrayRef, IntoArray};
18
/// Expression vtable for a comparison whose right-hand side may change while a query executes.
///
/// A typical producer is a compute engine that pushes down an ORDER BY + LIMIT operation and
/// progressively tightens the filter bound as execution proceeds.
pub struct DynamicComparison;
23
24impl VTable for DynamicComparison {
25    type Instance = DynamicComparisonExpr;
26
27    fn id(&self) -> ExprId {
28        ExprId::new_ref("vortex.dynamic")
29    }
30
31    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
32        if expr.children().len() != 1 {
33            vortex_bail!(
34                "DynamicComparison expression requires exactly one child, got {}",
35                expr.children().len()
36            );
37        }
38        Ok(())
39    }
40
41    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
42        match child_idx {
43            0 => ChildName::from("lhs"),
44            _ => unreachable!(),
45        }
46    }
47
48    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
49        expr.lhs().fmt_sql(f)?;
50        write!(f, " {} dynamic(", expr.data())?;
51        match expr.scalar() {
52            None => write!(f, "<none>")?,
53            Some(scalar) => write!(f, "{}", scalar)?,
54        }
55        write!(f, ")")
56    }
57
58    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
59        let lhs = expr.lhs().return_dtype(scope)?;
60        if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) {
61            vortex_bail!(
62                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
63                &expr.data().rhs.dtype,
64                lhs
65            );
66        }
67        Ok(DType::Bool(
68            lhs.nullability() | expr.data().rhs.dtype.nullability(),
69        ))
70    }
71
72    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
73        if let Some(value) = expr.scalar() {
74            let lhs = expr.lhs().evaluate(scope)?;
75            let rhs = ConstantArray::new(value, scope.len());
76            return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator);
77        }
78
79        // Otherwise, we return the default value.
80        let lhs = expr.return_dtype(scope.dtype())?;
81        Ok(ConstantArray::new(
82            Scalar::new(
83                DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()),
84                expr.data().default.into(),
85            ),
86            scope.len(),
87        )
88        .into_array())
89    }
90
91    fn stat_falsification(
92        &self,
93        expr: &ExpressionView<DynamicComparison>,
94        catalog: &mut dyn StatsCatalog,
95    ) -> Option<Expression> {
96        match expr.data().operator {
97            Operator::Gt => Some(DynamicComparison.new_expr(
98                DynamicComparisonExpr {
99                    operator: Operator::Lte,
100                    rhs: expr.data().rhs.clone(),
101                    default: !expr.data().default,
102                },
103                vec![expr.lhs().stat_max(catalog)?],
104            )),
105            Operator::Gte => Some(DynamicComparison.new_expr(
106                DynamicComparisonExpr {
107                    operator: Operator::Lt,
108                    rhs: expr.data().rhs.clone(),
109                    default: !expr.data().default,
110                },
111                vec![expr.lhs().stat_max(catalog)?],
112            )),
113            Operator::Lt => Some(DynamicComparison.new_expr(
114                DynamicComparisonExpr {
115                    operator: Operator::Gte,
116                    rhs: expr.data().rhs.clone(),
117                    default: !expr.data().default,
118                },
119                vec![expr.lhs().stat_min(catalog)?],
120            )),
121            Operator::Lte => Some(DynamicComparison.new_expr(
122                DynamicComparisonExpr {
123                    operator: Operator::Gt,
124                    rhs: expr.data().rhs.clone(),
125                    default: !expr.data().default,
126                },
127                vec![expr.lhs().stat_min(catalog)?],
128            )),
129            _ => None,
130        }
131    }
132}
133
134pub fn dynamic(
135    operator: Operator,
136    rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
137    rhs_dtype: DType,
138    default: bool,
139    lhs: Expression,
140) -> Expression {
141    DynamicComparison.new_expr(
142        DynamicComparisonExpr {
143            operator,
144            rhs: Arc::new(Rhs {
145                value: Arc::new(rhs_value),
146                dtype: rhs_dtype,
147            }),
148            default,
149        },
150        [lhs],
151    )
152}
153
154#[derive(Clone, Debug)]
155pub struct DynamicComparisonExpr {
156    operator: Operator,
157    rhs: Arc<Rhs>,
158    // Default value for the dynamic comparison.
159    default: bool,
160}
161
162impl DynamicComparisonExpr {
163    pub fn scalar(&self) -> Option<Scalar> {
164        (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
165    }
166}
167
168impl Display for DynamicComparisonExpr {
169    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170        write!(
171            f,
172            "{} {}",
173            self.operator,
174            self.scalar()
175                .map_or("<none>".to_string(), |v| v.to_string())
176        )
177    }
178}
179
180impl PartialEq for DynamicComparisonExpr {
181    fn eq(&self, other: &Self) -> bool {
182        self.operator == other.operator
183            && Arc::ptr_eq(&self.rhs, &other.rhs)
184            && self.default == other.default
185    }
186}
187impl Eq for DynamicComparisonExpr {}
188
189impl Hash for DynamicComparisonExpr {
190    fn hash<H: Hasher>(&self, state: &mut H) {
191        self.operator.hash(state);
192        Arc::as_ptr(&self.rhs).hash(state);
193        self.default.hash(state);
194    }
195}
196
197/// Hash and PartialEq are implemented based on the ptr of the value function, such that the
198/// internal value doesn't impact the hash of an expression tree.
199struct Rhs {
200    // The right-hand side value is a function that returns an `Option<ScalarValue>`.
201    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
202    // The data type of the right-hand side value.
203    dtype: DType,
204}
205
206impl Debug for Rhs {
207    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
208        f.debug_struct("Rhs")
209            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
210            .field("dtype", &self.dtype)
211            .finish()
212    }
213}
214
215impl ExpressionView<'_, DynamicComparison> {
216    pub fn lhs(&self) -> &Expression {
217        &self.children()[0]
218    }
219
220    pub fn scalar(&self) -> Option<Scalar> {
221        (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v))
222    }
223}
224
225/// A utility for checking whether any dynamic expressions have been updated.
226pub struct DynamicExprUpdates {
227    exprs: Box<[DynamicComparisonExpr]>,
228    // Track the latest observed versions of each dynamic expression, along with a version counter.
229    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
230}
231
232impl DynamicExprUpdates {
233    pub fn new(expr: &Expression) -> Option<Self> {
234        #[derive(Default)]
235        struct Visitor(Vec<DynamicComparisonExpr>);
236
237        impl NodeVisitor<'_> for Visitor {
238            type NodeTy = Expression;
239
240            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
241                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
242                    self.0.push(dynamic.data().clone());
243                }
244                Ok(TraversalOrder::Continue)
245            }
246        }
247
248        let mut visitor = Visitor::default();
249        expr.accept(&mut visitor).vortex_expect("Infallible");
250
251        if visitor.0.is_empty() {
252            return None;
253        }
254
255        let exprs = visitor.0.into_boxed_slice();
256        let prev_versions = exprs
257            .iter()
258            .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
259            .collect();
260
261        Some(Self {
262            exprs,
263            prev_versions: Mutex::new((0, prev_versions)),
264        })
265    }
266
267    pub fn version(&self) -> u64 {
268        let mut guard = self.prev_versions.lock();
269
270        let mut updated = false;
271        for (i, expr) in self.exprs.iter().enumerate() {
272            let current = expr.scalar();
273            if current != guard.1[i] {
274                // At least one expression has been updated.
275                // We don't bail out early in order to avoid false positives for future calls
276                // to `is_updated`.
277                updated = true;
278                guard.1[i] = current;
279            }
280        }
281
282        if updated {
283            guard.0 += 1;
284        }
285
286        guard.0
287    }
288}