1use std::fmt::{Debug, Display, Formatter};
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use parking_lot::Mutex;
9use vortex_dtype::DType;
10use vortex_error::{VortexExpect, VortexResult, vortex_bail};
11use vortex_scalar::{Scalar, ScalarValue};
12
13use crate::arrays::ConstantArray;
14use crate::compute::{Operator, compare};
15use crate::expr::traversal::{NodeExt, NodeVisitor, TraversalOrder};
16use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
17use crate::{Array, ArrayRef, IntoArray};
18
/// Expression vtable for comparing a single child expression (`lhs`) against a right-hand
/// side whose value is produced lazily by a closure (see [`dynamic`]).
pub struct DynamicComparison;
23
24impl VTable for DynamicComparison {
25 type Instance = DynamicComparisonExpr;
26
27 fn id(&self) -> ExprId {
28 ExprId::new_ref("vortex.dynamic")
29 }
30
31 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
32 if expr.children().len() != 1 {
33 vortex_bail!(
34 "DynamicComparison expression requires exactly one child, got {}",
35 expr.children().len()
36 );
37 }
38 Ok(())
39 }
40
41 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
42 match child_idx {
43 0 => ChildName::from("lhs"),
44 _ => unreachable!(),
45 }
46 }
47
48 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
49 expr.lhs().fmt_sql(f)?;
50 write!(f, " {} dynamic(", expr.data())?;
51 match expr.scalar() {
52 None => write!(f, "<none>")?,
53 Some(scalar) => write!(f, "{}", scalar)?,
54 }
55 write!(f, ")")
56 }
57
58 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
59 let lhs = expr.lhs().return_dtype(scope)?;
60 if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) {
61 vortex_bail!(
62 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
63 &expr.data().rhs.dtype,
64 lhs
65 );
66 }
67 Ok(DType::Bool(
68 lhs.nullability() | expr.data().rhs.dtype.nullability(),
69 ))
70 }
71
72 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
73 if let Some(value) = expr.scalar() {
74 let lhs = expr.lhs().evaluate(scope)?;
75 let rhs = ConstantArray::new(value, scope.len());
76 return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator);
77 }
78
79 let lhs = expr.return_dtype(scope.dtype())?;
81 Ok(ConstantArray::new(
82 Scalar::new(
83 DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()),
84 expr.data().default.into(),
85 ),
86 scope.len(),
87 )
88 .into_array())
89 }
90
91 fn stat_falsification(
92 &self,
93 expr: &ExpressionView<DynamicComparison>,
94 catalog: &mut dyn StatsCatalog,
95 ) -> Option<Expression> {
96 match expr.data().operator {
97 Operator::Gt => Some(DynamicComparison.new_expr(
98 DynamicComparisonExpr {
99 operator: Operator::Lte,
100 rhs: expr.data().rhs.clone(),
101 default: !expr.data().default,
102 },
103 vec![expr.lhs().stat_max(catalog)?],
104 )),
105 Operator::Gte => Some(DynamicComparison.new_expr(
106 DynamicComparisonExpr {
107 operator: Operator::Lt,
108 rhs: expr.data().rhs.clone(),
109 default: !expr.data().default,
110 },
111 vec![expr.lhs().stat_max(catalog)?],
112 )),
113 Operator::Lt => Some(DynamicComparison.new_expr(
114 DynamicComparisonExpr {
115 operator: Operator::Gte,
116 rhs: expr.data().rhs.clone(),
117 default: !expr.data().default,
118 },
119 vec![expr.lhs().stat_min(catalog)?],
120 )),
121 Operator::Lte => Some(DynamicComparison.new_expr(
122 DynamicComparisonExpr {
123 operator: Operator::Gt,
124 rhs: expr.data().rhs.clone(),
125 default: !expr.data().default,
126 },
127 vec![expr.lhs().stat_min(catalog)?],
128 )),
129 _ => None,
130 }
131 }
132}
133
134pub fn dynamic(
135 operator: Operator,
136 rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
137 rhs_dtype: DType,
138 default: bool,
139 lhs: Expression,
140) -> Expression {
141 DynamicComparison.new_expr(
142 DynamicComparisonExpr {
143 operator,
144 rhs: Arc::new(Rhs {
145 value: Arc::new(rhs_value),
146 dtype: rhs_dtype,
147 }),
148 default,
149 },
150 [lhs],
151 )
152}
153
/// Per-node data of a [`DynamicComparison`] expression.
///
/// Equality and hashing use the identity of the shared `rhs` allocation, not the value
/// its closure currently produces.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    /// Comparison applied as `lhs <operator> rhs`.
    operator: Operator,
    /// Lazily-evaluated right-hand side; clones share the same allocation.
    rhs: Arc<Rhs>,
    /// Value produced for every row while the rhs closure returns `None`.
    default: bool,
}
161
162impl DynamicComparisonExpr {
163 pub fn scalar(&self) -> Option<Scalar> {
164 (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
165 }
166}
167
168impl Display for DynamicComparisonExpr {
169 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170 write!(
171 f,
172 "{} {}",
173 self.operator,
174 self.scalar()
175 .map_or("<none>".to_string(), |v| v.to_string())
176 )
177 }
178}
179
180impl PartialEq for DynamicComparisonExpr {
181 fn eq(&self, other: &Self) -> bool {
182 self.operator == other.operator
183 && Arc::ptr_eq(&self.rhs, &other.rhs)
184 && self.default == other.default
185 }
186}
187impl Eq for DynamicComparisonExpr {}
188
189impl Hash for DynamicComparisonExpr {
190 fn hash<H: Hasher>(&self, state: &mut H) {
191 self.operator.hash(state);
192 Arc::as_ptr(&self.rhs).hash(state);
193 self.default.hash(state);
194 }
195}
196
/// Right-hand side of a dynamic comparison: a closure yielding the current value (or
/// `None` if none is available) plus the dtype used to interpret that value.
struct Rhs {
    /// Called each time the current value is needed; its result can change between calls,
    /// which `DynamicExprUpdates` tracks.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    /// Dtype used to turn the produced [`ScalarValue`] into a [`Scalar`].
    dtype: DType,
}
205
206impl Debug for Rhs {
207 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
208 f.debug_struct("Rhs")
209 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
210 .field("dtype", &self.dtype)
211 .finish()
212 }
213}
214
215impl ExpressionView<'_, DynamicComparison> {
216 pub fn lhs(&self) -> &Expression {
217 &self.children()[0]
218 }
219
220 pub fn scalar(&self) -> Option<Scalar> {
221 (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v))
222 }
223}
224
/// Tracks changes to the right-hand-side values of every dynamic comparison found in an
/// expression, exposing them as a monotonically increasing version number.
pub struct DynamicExprUpdates {
    /// Data of each `DynamicComparison` node discovered in the expression.
    exprs: Box<[DynamicComparisonExpr]>,
    /// The current version counter and the last observed value for each entry of
    /// `exprs` (same order and length).
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
231
232impl DynamicExprUpdates {
233 pub fn new(expr: &Expression) -> Option<Self> {
234 #[derive(Default)]
235 struct Visitor(Vec<DynamicComparisonExpr>);
236
237 impl NodeVisitor<'_> for Visitor {
238 type NodeTy = Expression;
239
240 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
241 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
242 self.0.push(dynamic.data().clone());
243 }
244 Ok(TraversalOrder::Continue)
245 }
246 }
247
248 let mut visitor = Visitor::default();
249 expr.accept(&mut visitor).vortex_expect("Infallible");
250
251 if visitor.0.is_empty() {
252 return None;
253 }
254
255 let exprs = visitor.0.into_boxed_slice();
256 let prev_versions = exprs
257 .iter()
258 .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
259 .collect();
260
261 Some(Self {
262 exprs,
263 prev_versions: Mutex::new((0, prev_versions)),
264 })
265 }
266
267 pub fn version(&self) -> u64 {
268 let mut guard = self.prev_versions.lock();
269
270 let mut updated = false;
271 for (i, expr) in self.exprs.iter().enumerate() {
272 let current = expr.scalar();
273 if current != guard.1[i] {
274 updated = true;
278 guard.1[i] = current;
279 }
280 }
281
282 if updated {
283 guard.0 += 1;
284 }
285
286 guard.0
287 }
288}