1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use parking_lot::Mutex;
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_scalar::Scalar;
17use vortex_scalar::ScalarValue;
18
19use crate::Array;
20use crate::ArrayRef;
21use crate::IntoArray;
22use crate::arrays::ConstantArray;
23use crate::compute::Operator;
24use crate::compute::compare;
25use crate::expr::ChildName;
26use crate::expr::ExprId;
27use crate::expr::Expression;
28use crate::expr::ExpressionView;
29use crate::expr::StatsCatalog;
30use crate::expr::VTable;
31use crate::expr::VTableExt;
32use crate::expr::traversal::NodeExt;
33use crate::expr::traversal::NodeVisitor;
34use crate::expr::traversal::TraversalOrder;
35
/// Expression vtable for comparing an expression against a right-hand value that is
/// fetched lazily on each evaluation. Expressions of this kind are built with [`dynamic`].
pub struct DynamicComparison;
40
41impl VTable for DynamicComparison {
42 type Instance = DynamicComparisonExpr;
43
44 fn id(&self) -> ExprId {
45 ExprId::new_ref("vortex.dynamic")
46 }
47
48 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
49 if expr.children().len() != 1 {
50 vortex_bail!(
51 "DynamicComparison expression requires exactly one child, got {}",
52 expr.children().len()
53 );
54 }
55 Ok(())
56 }
57
58 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
59 match child_idx {
60 0 => ChildName::from("lhs"),
61 _ => unreachable!(),
62 }
63 }
64
65 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
66 expr.lhs().fmt_sql(f)?;
67 write!(f, " {} dynamic(", expr.data())?;
68 match expr.scalar() {
69 None => write!(f, "<none>")?,
70 Some(scalar) => write!(f, "{}", scalar)?,
71 }
72 write!(f, ")")
73 }
74
75 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
76 let lhs = expr.lhs().return_dtype(scope)?;
77 if !expr.data().rhs.dtype.eq_ignore_nullability(&lhs) {
78 vortex_bail!(
79 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
80 &expr.data().rhs.dtype,
81 lhs
82 );
83 }
84 Ok(DType::Bool(
85 lhs.nullability() | expr.data().rhs.dtype.nullability(),
86 ))
87 }
88
89 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
90 if let Some(value) = expr.scalar() {
91 let lhs = expr.lhs().evaluate(scope)?;
92 let rhs = ConstantArray::new(value, scope.len());
93 return compare(lhs.as_ref(), rhs.as_ref(), expr.data().operator);
94 }
95
96 let lhs = expr.return_dtype(scope.dtype())?;
98 Ok(ConstantArray::new(
99 Scalar::new(
100 DType::Bool(lhs.nullability() | expr.data().rhs.dtype.nullability()),
101 expr.data().default.into(),
102 ),
103 scope.len(),
104 )
105 .into_array())
106 }
107
108 fn stat_falsification(
109 &self,
110 expr: &ExpressionView<DynamicComparison>,
111 catalog: &dyn StatsCatalog,
112 ) -> Option<Expression> {
113 match expr.data().operator {
114 Operator::Gt => Some(DynamicComparison.new_expr(
115 DynamicComparisonExpr {
116 operator: Operator::Lte,
117 rhs: expr.data().rhs.clone(),
118 default: !expr.data().default,
119 },
120 vec![expr.lhs().stat_max(catalog)?],
121 )),
122 Operator::Gte => Some(DynamicComparison.new_expr(
123 DynamicComparisonExpr {
124 operator: Operator::Lt,
125 rhs: expr.data().rhs.clone(),
126 default: !expr.data().default,
127 },
128 vec![expr.lhs().stat_max(catalog)?],
129 )),
130 Operator::Lt => Some(DynamicComparison.new_expr(
131 DynamicComparisonExpr {
132 operator: Operator::Gte,
133 rhs: expr.data().rhs.clone(),
134 default: !expr.data().default,
135 },
136 vec![expr.lhs().stat_min(catalog)?],
137 )),
138 Operator::Lte => Some(DynamicComparison.new_expr(
139 DynamicComparisonExpr {
140 operator: Operator::Gt,
141 rhs: expr.data().rhs.clone(),
142 default: !expr.data().default,
143 },
144 vec![expr.lhs().stat_min(catalog)?],
145 )),
146 _ => None,
147 }
148 }
149
150 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
152 false
153 }
154}
155
156pub fn dynamic(
157 operator: Operator,
158 rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
159 rhs_dtype: DType,
160 default: bool,
161 lhs: Expression,
162) -> Expression {
163 DynamicComparison.new_expr(
164 DynamicComparisonExpr {
165 operator,
166 rhs: Arc::new(Rhs {
167 value: Arc::new(rhs_value),
168 dtype: rhs_dtype,
169 }),
170 default,
171 },
172 [lhs],
173 )
174}
175
/// Instance data for a [`DynamicComparison`] expression.
///
/// Equality and hashing use the identity of `rhs` (pointer comparison), not the value it
/// currently yields, since that value may change between calls.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    /// Comparison applied as `lhs <operator> rhs`.
    operator: Operator,
    /// Shared, lazily evaluated right-hand side; clones of the instance share this `Arc`.
    rhs: Arc<Rhs>,
    /// Constant result produced for every row while the rhs value is unavailable (`None`).
    default: bool,
}
183
184impl DynamicComparisonExpr {
185 pub fn scalar(&self) -> Option<Scalar> {
186 (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
187 }
188}
189
190impl Display for DynamicComparisonExpr {
191 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
192 write!(
193 f,
194 "{} {}",
195 self.operator,
196 self.scalar()
197 .map_or("<none>".to_string(), |v| v.to_string())
198 )
199 }
200}
201
202impl PartialEq for DynamicComparisonExpr {
203 fn eq(&self, other: &Self) -> bool {
204 self.operator == other.operator
205 && Arc::ptr_eq(&self.rhs, &other.rhs)
206 && self.default == other.default
207 }
208}
209impl Eq for DynamicComparisonExpr {}
210
211impl Hash for DynamicComparisonExpr {
212 fn hash<H: Hasher>(&self, state: &mut H) {
213 self.operator.hash(state);
214 Arc::as_ptr(&self.rhs).hash(state);
215 self.default.hash(state);
216 }
217}
218
/// Right-hand side of a dynamic comparison: a value provider plus the dtype its values are
/// interpreted as.
struct Rhs {
    /// Called on demand to fetch the current rhs value; `None` means no value is available.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    /// DType used to wrap each fetched [`ScalarValue`] into a [`Scalar`].
    dtype: DType,
}
227
228impl Debug for Rhs {
229 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
230 f.debug_struct("Rhs")
231 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
232 .field("dtype", &self.dtype)
233 .finish()
234 }
235}
236
237impl ExpressionView<'_, DynamicComparison> {
238 pub fn lhs(&self) -> &Expression {
239 &self.children()[0]
240 }
241
242 pub fn scalar(&self) -> Option<Scalar> {
243 (self.data().rhs.value)().map(|v| Scalar::new(self.data().rhs.dtype.clone(), v))
244 }
245}
246
/// Detects changes in the rhs values of every [`DynamicComparison`] in an expression tree.
///
/// [`DynamicExprUpdates::version`] returns a counter that is incremented whenever a call
/// observes that at least one rhs value differs from the value recorded by the previous call.
pub struct DynamicExprUpdates {
    /// Instance data of every dynamic comparison found, in the order the visitor reached them.
    exprs: Box<[DynamicComparisonExpr]>,
    /// The version counter, plus the last observed rhs value for each entry of `exprs`
    /// (same index).
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
253
254impl DynamicExprUpdates {
255 pub fn new(expr: &Expression) -> Option<Self> {
256 #[derive(Default)]
257 struct Visitor(Vec<DynamicComparisonExpr>);
258
259 impl NodeVisitor<'_> for Visitor {
260 type NodeTy = Expression;
261
262 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
263 if let Some(dynamic) = node.as_opt::<DynamicComparison>() {
264 self.0.push(dynamic.data().clone());
265 }
266 Ok(TraversalOrder::Continue)
267 }
268 }
269
270 let mut visitor = Visitor::default();
271 expr.accept(&mut visitor).vortex_expect("Infallible");
272
273 if visitor.0.is_empty() {
274 return None;
275 }
276
277 let exprs = visitor.0.into_boxed_slice();
278 let prev_versions = exprs
279 .iter()
280 .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
281 .collect();
282
283 Some(Self {
284 exprs,
285 prev_versions: Mutex::new((0, prev_versions)),
286 })
287 }
288
289 pub fn version(&self) -> u64 {
290 let mut guard = self.prev_versions.lock();
291
292 let mut updated = false;
293 for (i, expr) in self.exprs.iter().enumerate() {
294 let current = expr.scalar();
295 if current != guard.1[i] {
296 updated = true;
300 guard.1[i] = current;
301 }
302 }
303
304 if updated {
305 guard.0 += 1;
306 }
307
308 guard.0
309 }
310}