1use std::fmt::Debug;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use parking_lot::Mutex;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::{Operator, compare};
11use vortex_array::{Array, ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail};
14use vortex_proto::expr as pb;
15use vortex_scalar::{Scalar, ScalarValue};
16
17use crate::display::{DisplayAs, DisplayFormat};
18use crate::traversal::{NodeExt, NodeVisitor, TraversalOrder};
19use crate::{
20 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
21};
22
23vtable!(DynamicComparison);
24
25#[derive(Clone, Debug)]
29pub struct DynamicComparisonExpr {
30 lhs: ExprRef,
31 operator: Operator,
32 rhs: Arc<Rhs>,
33 default: bool,
35}
36
37impl PartialEq for DynamicComparisonExpr {
38 fn eq(&self, other: &Self) -> bool {
39 self.default == other.default
40 && self.operator == other.operator
41 && self.lhs.eq(&other.lhs)
42 && Arc::ptr_eq(&self.rhs.value, &other.rhs.value)
43 && self.rhs.dtype == other.rhs.dtype
44 }
45}
46impl Eq for DynamicComparisonExpr {}
47
48impl Hash for DynamicComparisonExpr {
49 fn hash<H: Hasher>(&self, state: &mut H) {
50 self.default.hash(state);
51 self.operator.hash(state);
52 self.lhs.hash(state);
53 Arc::as_ptr(&self.rhs.value).hash(state);
54 self.rhs.dtype.hash(state);
55 }
56}
57
58struct Rhs {
61 value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
63 dtype: DType,
65}
66
67impl Debug for Rhs {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct("Rhs")
70 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
71 .field("dtype", &self.dtype)
72 .finish()
73 }
74}
75
76pub struct DynamicComparisonExprEncoding;
77
78impl VTable for DynamicComparisonVTable {
79 type Expr = DynamicComparisonExpr;
80 type Encoding = DynamicComparisonExprEncoding;
81 type Metadata = ProstMetadata<pb::LiteralOpts>;
82
83 fn id(_encoding: &Self::Encoding) -> ExprId {
84 ExprId::new_ref("dynamic")
85 }
86
87 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
88 ExprEncodingRef::new_ref(DynamicComparisonExprEncoding.as_ref())
89 }
90
91 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
92 None
93 }
94
95 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
96 vec![&expr.lhs]
97 }
98
99 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
100 Ok(DynamicComparisonExpr {
101 lhs: children[0].clone(),
102 operator: expr.operator,
103 rhs: expr.rhs.clone(),
104 default: expr.default,
105 })
106 }
107
108 fn build(
109 _encoding: &Self::Encoding,
110 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
111 _children: Vec<ExprRef>,
112 ) -> VortexResult<Self::Expr> {
113 vortex_bail!("DynamicComparison expression does not support building from metadata");
114 }
115
116 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
117 if let Some(value) = expr.scalar() {
118 let lhs = expr.lhs.evaluate(scope)?;
119 let rhs = ConstantArray::new(value, scope.len());
120 return compare(lhs.as_ref(), rhs.as_ref(), expr.operator);
121 }
122
123 let lhs = expr.return_dtype(scope.dtype())?;
125 Ok(ConstantArray::new(
126 Scalar::new(
127 DType::Bool(lhs.nullability() | expr.rhs.dtype.nullability()),
128 expr.default.into(),
129 ),
130 scope.len(),
131 )
132 .into_array())
133 }
134
135 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
136 let lhs = expr.lhs.return_dtype(scope)?;
137 if !expr.rhs.dtype.eq_ignore_nullability(&lhs) {
138 vortex_bail!(
139 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
140 &expr.rhs.dtype,
141 lhs
142 );
143 }
144 Ok(DType::Bool(
145 lhs.nullability() | expr.rhs.dtype.nullability(),
146 ))
147 }
148}
149
150impl DynamicComparisonExpr {
151 pub fn new(
152 rhs: ExprRef,
153 operator: Operator,
154 rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
155 rhs_dtype: DType,
156 default: bool,
157 ) -> Self {
158 DynamicComparisonExpr {
159 lhs: rhs,
160 operator,
161 rhs: Arc::new(Rhs {
162 value: Arc::new(rhs_value),
163 dtype: rhs_dtype,
164 }),
165 default,
166 }
167 }
168
169 pub fn scalar(&self) -> Option<Scalar> {
170 (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
171 }
172}
173
174impl DisplayAs for DynamicComparisonExpr {
175 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
176 match df {
177 DisplayFormat::Compact => {
178 write!(
179 f,
180 "{} {} dynamic({})",
181 &self.lhs, self.operator, &self.rhs.dtype,
182 )
183 }
184 DisplayFormat::Tree => {
185 write!(f, "DynamicComparison")
186 }
187 }
188 }
189}
190
191impl AnalysisExpr for DynamicComparisonExpr {
192 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
193 match self.operator {
194 Operator::Gt => Some(
195 DynamicComparisonExpr {
196 lhs: self.lhs.max(catalog)?,
197 operator: Operator::Lte,
198 rhs: self.rhs.clone(),
199 default: !self.default,
200 }
201 .into_expr(),
202 ),
203 Operator::Gte => Some(
204 DynamicComparisonExpr {
205 lhs: self.lhs.max(catalog)?,
206 operator: Operator::Lt,
207 rhs: self.rhs.clone(),
208 default: !self.default,
209 }
210 .into_expr(),
211 ),
212 Operator::Lt => Some(
213 DynamicComparisonExpr {
214 lhs: self.lhs.min(catalog)?,
215 operator: Operator::Gte,
216 rhs: self.rhs.clone(),
217 default: !self.default,
218 }
219 .into_expr(),
220 ),
221 Operator::Lte => Some(
222 DynamicComparisonExpr {
223 lhs: self.lhs.min(catalog)?,
224 operator: Operator::Gt,
225 rhs: self.rhs.clone(),
226 default: !self.default,
227 }
228 .into_expr(),
229 ),
230 _ => None,
231 }
232 }
233}
234
235pub struct DynamicExprUpdates {
237 exprs: Box<[DynamicComparisonExpr]>,
238 prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
240}
241
242impl DynamicExprUpdates {
243 pub fn new(expr: &ExprRef) -> Option<Self> {
244 #[derive(Default)]
245 struct Visitor(Vec<DynamicComparisonExpr>);
246
247 impl NodeVisitor<'_> for Visitor {
248 type NodeTy = ExprRef;
249
250 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
251 if let Some(dynamic) = node.as_opt::<DynamicComparisonVTable>() {
252 self.0.push(dynamic.clone());
253 }
254 Ok(TraversalOrder::Continue)
255 }
256 }
257
258 let mut visitor = Visitor::default();
259 expr.accept(&mut visitor).vortex_expect("Infallible");
260
261 if visitor.0.is_empty() {
262 return None;
263 }
264
265 let exprs = visitor.0.into_boxed_slice();
266 let prev_versions = exprs
267 .iter()
268 .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
269 .collect();
270
271 Some(Self {
272 exprs,
273 prev_versions: Mutex::new((0, prev_versions)),
274 })
275 }
276
277 pub fn version(&self) -> u64 {
278 let mut guard = self.prev_versions.lock();
279
280 let mut updated = false;
281 for (i, expr) in self.exprs.iter().enumerate() {
282 let current = expr.scalar();
283 if current != guard.1[i] {
284 updated = true;
288 guard.1[i] = current;
289 }
290 }
291
292 if updated {
293 guard.0 += 1;
294 }
295
296 guard.0
297 }
298}