1use std::fmt::{Debug, Display};
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use parking_lot::Mutex;
9use vortex_array::arrays::ConstantArray;
10use vortex_array::compute::{Operator, compare};
11use vortex_array::{Array, ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail};
14use vortex_proto::expr as pb;
15use vortex_scalar::{Scalar, ScalarValue};
16
17use crate::traversal::{Node, NodeVisitor, TraversalOrder};
18use crate::{
19 AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, vtable,
20};
21
// NOTE(review): presumably this macro generates `DynamicComparisonVTable`, which is used below
// but not defined in the visible source — confirm against the `vtable!` definition.
vtable!(DynamicComparison);
23
/// A comparison between an expression and a right-hand value that is only obtained at
/// evaluation time, by invoking a callback.
///
/// Equality and hashing identify the right-hand side by the address of its callback, not by
/// the value the callback currently returns.
#[derive(Clone, Debug)]
pub struct DynamicComparisonExpr {
    /// Left-hand expression, evaluated against the scope.
    lhs: ExprRef,
    /// Comparison applied as `lhs <operator> rhs`.
    operator: Operator,
    /// Shared right-hand side; clones of this expression share the same callback.
    rhs: Arc<Rhs>,
    /// Constant result produced for every row while the right-hand callback yields `None`.
    default: bool,
}
35
36impl PartialEq for DynamicComparisonExpr {
37 fn eq(&self, other: &Self) -> bool {
38 self.default == other.default
39 && self.operator == other.operator
40 && self.lhs.eq(&other.lhs)
41 && Arc::ptr_eq(&self.rhs.value, &other.rhs.value)
42 && self.rhs.dtype == other.rhs.dtype
43 }
44}
45impl Eq for DynamicComparisonExpr {}
46
47impl Hash for DynamicComparisonExpr {
48 fn hash<H: Hasher>(&self, state: &mut H) {
49 self.default.hash(state);
50 self.operator.hash(state);
51 self.lhs.hash(state);
52 Arc::as_ptr(&self.rhs.value).hash(state);
53 self.rhs.dtype.hash(state);
54 }
55}
56
/// Right-hand side of a `DynamicComparisonExpr`: a value callback plus the dtype used to wrap
/// what it returns.
struct Rhs {
    /// Produces the current right-hand value, or `None` when no value is available.
    value: Arc<dyn Fn() -> Option<ScalarValue> + Send + Sync>,
    /// DType used to wrap returned values into a `Scalar`; `return_dtype` also checks it against
    /// the left-hand dtype, ignoring nullability.
    dtype: DType,
}
65
66impl Debug for Rhs {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("Rhs")
69 .field("value", &"<dyn Fn() -> Option<ScalarValue> + Send + Sync>")
70 .field("dtype", &self.dtype)
71 .finish()
72 }
73}
74
/// Stateless encoding marker for `DynamicComparisonExpr`.
pub struct DynamicComparisonExprEncoding;
76
77impl VTable for DynamicComparisonVTable {
78 type Expr = DynamicComparisonExpr;
79 type Encoding = DynamicComparisonExprEncoding;
80 type Metadata = ProstMetadata<pb::LiteralOpts>;
81
82 fn id(_encoding: &Self::Encoding) -> ExprId {
83 ExprId::new_ref("dynamic")
84 }
85
86 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
87 ExprEncodingRef::new_ref(DynamicComparisonExprEncoding.as_ref())
88 }
89
90 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
91 None
92 }
93
94 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
95 vec![&expr.lhs]
96 }
97
98 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
99 Ok(DynamicComparisonExpr {
100 lhs: children[0].clone(),
101 operator: expr.operator,
102 rhs: expr.rhs.clone(),
103 default: expr.default,
104 })
105 }
106
107 fn build(
108 _encoding: &Self::Encoding,
109 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
110 _children: Vec<ExprRef>,
111 ) -> VortexResult<Self::Expr> {
112 vortex_bail!("DynamicComparison expression does not support building from metadata");
113 }
114
115 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
116 if let Some(value) = expr.scalar() {
117 let lhs = expr.lhs.evaluate(scope)?;
118 let rhs = ConstantArray::new(value, scope.len());
119 return compare(lhs.as_ref(), rhs.as_ref(), expr.operator);
120 }
121
122 let lhs = expr.return_dtype(scope.dtype())?;
124 Ok(ConstantArray::new(
125 Scalar::new(
126 DType::Bool(lhs.nullability() | expr.rhs.dtype.nullability()),
127 expr.default.into(),
128 ),
129 scope.len(),
130 )
131 .into_array())
132 }
133
134 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
135 let lhs = expr.lhs.return_dtype(scope)?;
136 if !expr.rhs.dtype.eq_ignore_nullability(&lhs) {
137 vortex_bail!(
138 "Incompatible dtypes for dynamic comparison: expected {} (ignore nullability) but got {}",
139 &expr.rhs.dtype,
140 lhs
141 );
142 }
143 Ok(DType::Bool(
144 lhs.nullability() | expr.rhs.dtype.nullability(),
145 ))
146 }
147}
148
149impl DynamicComparisonExpr {
150 pub fn new(
151 rhs: ExprRef,
152 operator: Operator,
153 rhs_value: impl Fn() -> Option<ScalarValue> + Send + Sync + 'static,
154 rhs_dtype: DType,
155 default: bool,
156 ) -> Self {
157 DynamicComparisonExpr {
158 lhs: rhs,
159 operator,
160 rhs: Arc::new(Rhs {
161 value: Arc::new(rhs_value),
162 dtype: rhs_dtype,
163 }),
164 default,
165 }
166 }
167
168 pub fn scalar(&self) -> Option<Scalar> {
169 (self.rhs.value)().map(|v| Scalar::new(self.rhs.dtype.clone(), v))
170 }
171}
172
173impl Display for DynamicComparisonExpr {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 write!(
176 f,
177 "{} {} dynamic({})",
178 &self.lhs, self.operator, &self.rhs.dtype,
179 )
180 }
181}
182
183impl AnalysisExpr for DynamicComparisonExpr {
184 fn stat_falsification(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
185 match self.operator {
186 Operator::Gt => Some(
187 DynamicComparisonExpr {
188 lhs: self.lhs.max(catalog)?,
189 operator: Operator::Lte,
190 rhs: self.rhs.clone(),
191 default: !self.default,
192 }
193 .into_expr(),
194 ),
195 Operator::Gte => Some(
196 DynamicComparisonExpr {
197 lhs: self.lhs.max(catalog)?,
198 operator: Operator::Lt,
199 rhs: self.rhs.clone(),
200 default: !self.default,
201 }
202 .into_expr(),
203 ),
204 Operator::Lt => Some(
205 DynamicComparisonExpr {
206 lhs: self.lhs.min(catalog)?,
207 operator: Operator::Gte,
208 rhs: self.rhs.clone(),
209 default: !self.default,
210 }
211 .into_expr(),
212 ),
213 Operator::Lte => Some(
214 DynamicComparisonExpr {
215 lhs: self.lhs.min(catalog)?,
216 operator: Operator::Gt,
217 rhs: self.rhs.clone(),
218 default: !self.default,
219 }
220 .into_expr(),
221 ),
222 _ => None,
223 }
224 }
225}
226
/// Tracks the dynamic comparisons in an expression tree and exposes a version counter that is
/// bumped whenever any of their right-hand values is observed to change.
pub struct DynamicExprUpdates {
    /// Every `DynamicComparisonExpr` found in the tree, in the order the visitor encountered them.
    exprs: Box<[DynamicComparisonExpr]>,
    /// `(current version, last observed right-hand value per entry of `exprs`)`; the vector is
    /// indexed in step with `exprs`.
    prev_versions: Mutex<(u64, Vec<Option<Scalar>>)>,
}
233
234impl DynamicExprUpdates {
235 pub fn new(expr: &ExprRef) -> Option<Self> {
236 #[derive(Default)]
237 struct Visitor(Vec<DynamicComparisonExpr>);
238
239 impl NodeVisitor<'_> for Visitor {
240 type NodeTy = ExprRef;
241
242 fn visit_down(&mut self, node: &'_ Self::NodeTy) -> VortexResult<TraversalOrder> {
243 if let Some(dynamic) = node.as_opt::<DynamicComparisonVTable>() {
244 self.0.push(dynamic.clone());
245 }
246 Ok(TraversalOrder::Continue)
247 }
248 }
249
250 let mut visitor = Visitor::default();
251 expr.accept(&mut visitor).vortex_expect("Infallible");
252
253 if visitor.0.is_empty() {
254 return None;
255 }
256
257 let exprs = visitor.0.into_boxed_slice();
258 let prev_versions = exprs
259 .iter()
260 .map(|expr| (expr.rhs.value)().map(|v| Scalar::new(expr.rhs.dtype.clone(), v)))
261 .collect();
262
263 Some(Self {
264 exprs,
265 prev_versions: Mutex::new((0, prev_versions)),
266 })
267 }
268
269 pub fn version(&self) -> u64 {
270 let mut guard = self.prev_versions.lock();
271
272 let mut updated = false;
273 for (i, expr) in self.exprs.iter().enumerate() {
274 let current = expr.scalar();
275 if current != guard.1[i] {
276 updated = true;
280 guard.1[i] = current;
281 }
282 }
283
284 if updated {
285 guard.0 += 1;
286 }
287
288 guard.0
289 }
290}