vortex_array/expr/exprs/
dynamic.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_scalar::Scalar;
17use vortex_scalar::ScalarValue;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::compute::Operator;
24use crate::compute::compare;
25use crate::expr::ChildName;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::ExpressionView;
29use crate::expr::StatsCatalog;
30use crate::expr::VTable;
31use crate::expr::VTableExt;
32use crate::expr::traversal::NodeExt;
33use crate::expr::traversal::NodeVisitor;
34use crate::expr::traversal::TraversalOrder;
35
/// A dynamic comparison expression can be used to capture a comparison to a value that can change
/// during the execution of a query, such as when a compute engine pushes down an ORDER BY + LIMIT
/// operation and is able to progressively tighten the bounds of the filter.
///
/// This type is the expression vtable; per-node data lives in [`DynamicComparisonExpr`].
/// Instances are built with [`dynamic`].
pub struct DynamicComparison;
40
41impl VTable for DynamicComparison {
42    type Instance = DynamicComparisonExpr;
43
44    fn id(&self) -> ExprId {
45        ExprId::new_ref("vortex.dynamic")
46    }
47
48    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
49        if expr.children().len() != 1 {
50            vortex_bail!(
51                "DynamicComparison expression requires exactly one child, got {}",
52                expr.children().len()
53            );
54        }
55        Ok(())
56    }
57
58    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
59        match child_idx {
60            0 => ChildName::from("lhs"),
61            _ => unreachable!(),
62        }
63    }
64
65    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
66        expr.lhs().fmt_sql(f)?;
67        write!(f, " {} dynamic(", expr.data())?;
68        match expr.scalar() {
69            None => write!(f, "<none>")?,
70            Some(scalar) => write!(f, "{}", scalar)?,
71        }
72        write!(f, ")")
73    }
74
75    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
76        let lhs = expr.lhs().return_dtype(scope)?;
77        if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) {
78            vortex_bail!(
79                "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
80                &expr.data().rhs.dtype,
81                lhs
82            );
83        }
84        Ok(DType::Bool(
85            lhs.nullability() | expr.data().rhs.dtype.nullability(),
86        ))
87    }
88
89    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90        if let Some(value) = expr.scalar() {
91            let lhs = expr.lhs().evaluate(scope)?;
92            let rhs = ConstantArray::new(value, scope.len());
93            return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator);
94        }
95
96        // Otherwise, we return the default value.
97        let lhs = expr.return_dtype(scope.dtype())?;
98        Ok(ConstantArray::new(
99            Scalar::new(
100                DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()),
101                expr.data().default.into(),
102            ),
103            scope.len(),
104        )
105        .into_array())
106    }
107
108    fn stat_falsification(
109        &self,
110        expr: &ExpressionView<DynamicComparison>,
111        catalog: &dyn StatsCatalog,
112    ) -> Option<Expression> {
113        match expr.data().operator {
114            Operator::Gt => Some(DynamicComparison.new_expr(
115                DynamicComparisonExpr {
116                    operator: Operator::Lte,
117                    rhs: expr.data().rhs.clone(),
118                    default: !expr.data().default,
119                },
120                vec![expr.lhs().stat_max(catalog)?],
121            )),
122            Operator::Gte => Some(DynamicComparison.new_expr(
123                DynamicComparisonExpr {
124                    operator: Operator::Lt,
125                    rhs: expr.data().rhs.clone(),
126                    default: !expr.data().default,
127                },
128                vec![expr.lhs().stat_max(catalog)?],
129            )),
130            Operator::Lt => Some(DynamicComparison.new_expr(
131                DynamicComparisonExpr {
132                    operator: Operator::Gte,
133                    rhs: expr.data().rhs.clone(),
134                    default: !expr.data().default,
135                },
136                vec![expr.lhs().stat_min(catalog)?],
137            )),
138            Operator::Lte => Some(DynamicComparison.new_expr(
139                DynamicComparisonExpr {
140                    operator: Operator::Gt,
141                    rhs: expr.data().rhs.clone(),
142                    default: !expr.data().default,
143                },
144                vec![expr.lhs().stat_min(catalog)?],
145            )),
146            _ => None,
147        }
148    }
149
150    // Defer to the child
151    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
152        false
153    }
154}
155
156pub fn dynamic(
157    operator: Operator,
158    rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
159    rhs_dtype: DType,
160    default: bool,
161    lhs: Expression,
162) -> Expression {
163    DynamicComparison.new_expr(
164        DynamicComparisonExpr {
165            operator,
166            rhs: Arc::new(Rhs {
167                value: Arc::new(rhs_value),
168                dtype: rhs_dtype,
169            }),
170            default,
171        },
172        [lhs],
173    )
174}
175
/// Instance data of a [`DynamicComparison`] node: `lhs <operator> rhs`, where the right-hand
/// side value is obtained from a callback each time it is needed.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    /// The comparison operator, applied as `lhs <operator> rhs`.
    operator: Operator,
    /// Shared right-hand side. Equality and hashing use the pointer identity of this `Arc`, so
    /// clones compare equal while separately constructed instances do not.
    rhs: Arc<Rhs>,
    /// Default value for the dynamic comparison: the result produced for every row while the
    /// right-hand side value is `None`.
    default: bool,
}
183
184impl DynamicComparisonExpr {
185    pub fn scalar(&self) -> Option<Scalar> {
186        (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
187    }
188}
189
190impl Display for DynamicComparisonExpr {
191    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192        write!(
193            f,
194            "{} {}",
195            self.operator,
196            self.scalar()
197                .map_or("<none>".to_string(), |v| v.to_string())
198        )
199    }
200}
201
202impl PartialEq for DynamicComparisonExpr {
203    fn eq(&self, other: &Self) -> bool {
204        self.operator == other.operator
205            && Arc::ptr_eq(&self.rhs, &other.rhs)
206            && self.default == other.default
207    }
208}
209impl Eq for DynamicComparisonExpr {}
210
211impl Hash for DynamicComparisonExpr {
212    fn hash<H: Hasher>(&self, state: &mut H) {
213        self.operator.hash(state);
214        Arc::as_ptr(&self.rhs).hash(state);
215        self.default.hash(state);
216    }
217}
218
/// The right-hand side of a dynamic comparison: a value callback plus the dtype of the values it
/// produces.
///
/// `DynamicComparisonExpr` implements `Hash` and `PartialEq` using the pointer of its `Arc<Rhs>`,
/// so the callback's current value never affects the hash or equality of an expression tree.
struct Rhs {
    /// Callback returning the current right-hand side value, or `None` if none is available yet.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    /// The data type of the values returned by `value`.
    dtype: DType,
}
227
228impl Debug for Rhs {
229    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230        f.debug_struct("Rhs")
231            .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
232            .field("dtype", &self.dtype)
233            .finish()
234    }
235}
236
237impl ExpressionView<'_, DynamicComparison> {
238    pub fn lhs(&self) -> &Expression {
239        &self.children()[0]
240    }
241
242    pub fn scalar(&self) -> Option<Scalar> {
243        (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v))
244    }
245}
246
/// A utility for checking whether any dynamic expressions have been updated.
///
/// Collects every [`DynamicComparison`] in an expression tree and remembers the last right-hand
/// side value observed for each. [`DynamicExprUpdates::version`] bumps a counter whenever any of
/// those values changes.
pub struct DynamicExprUpdates {
    /// Instance data of each dynamic comparison found in the tree.
    exprs: Box<[DynamicComparisonExpr]>,
    /// Version counter, plus the last observed value of each entry in `exprs` (in the same order).
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
253
254impl DynamicExprUpdates {
255    pub fn new(expr: &Expression) -> Option<Self> {
256        #[derive(Default)]
257        struct Visitor(Vec<DynamicComparisonExpr>);
258
259        impl NodeVisitor<'_> for Visitor {
260            type NodeTy = Expression;
261
262            fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
263                if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
264                    self.0.push(dynamic.data().clone());
265                }
266                Ok(TraversalOrder::Continue)
267            }
268        }
269
270        let mut visitor = Visitor::default();
271        expr.accept(&mut visitor).vortex_expect("Infallible");
272
273        if visitor.0.is_empty() {
274            return None;
275        }
276
277        let exprs = visitor.0.into_boxed_slice();
278        let prev_versions = exprs
279            .iter()
280            .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
281            .collect();
282
283        Some(Self {
284            exprs,
285            prev_versions: Mutex::new((0, prev_versions)),
286        })
287    }
288
289    pub fn version(&self) -> u64 {
290        let mut guard = self.prev_versions.lock();
291
292        let mut updated = false;
293        for (i, expr) in self.exprs.iter().enumerate() {
294            let current = expr.scalar();
295            if current != guard.1[i] {
296                // At least one expression has been updated.
297                // We don't bail out early in order to avoid false positives for future calls
298                // to `is_updated`.
299                updated = true;
300                guard.1[i] = current;
301            }
302        }
303
304        if updated {
305            guard.0 += 1;
306        }
307
308        guard.0
309    }
310}