uni_query/query/df_graph/
quantifier.rs1use std::any::Any;
14use std::fmt::{self, Display, Formatter};
15use std::hash::Hash;
16use std::sync::Arc;
17
18use datafusion::arrow::array::{Array, BooleanArray, BooleanBuilder, RecordBatch};
19use datafusion::arrow::compute::cast;
20use datafusion::arrow::datatypes::{DataType, Field, Schema};
21use datafusion::common::Result;
22use datafusion::logical_expr::ColumnarValue;
23use datafusion::physical_plan::PhysicalExpr;
24
/// The kind of list quantifier a [`QuantifierExecExpr`] evaluates.
///
/// Each variant states how many elements of a list must satisfy the
/// predicate for the quantifier to hold. Over an empty list `All` and `None`
/// hold while `Any` and `Single` do not.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantifierType {
    /// Holds when no element fails the predicate.
    All,
    /// Holds when at least one element satisfies the predicate.
    Any,
    /// Holds when exactly one element satisfies the predicate.
    Single,
    /// Holds when no element satisfies the predicate.
    None,
}
37
38impl Display for QuantifierType {
39 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::All => write!(f, "ALL"),
42 Self::Any => write!(f, "ANY"),
43 Self::Single => write!(f, "SINGLE"),
44 Self::None => write!(f, "NONE"),
45 }
46 }
47}
48
49#[derive(Debug)]
56pub struct QuantifierExecExpr {
57 input_list: Arc<dyn PhysicalExpr>,
59 predicate: Arc<dyn PhysicalExpr>,
61 variable_name: String,
63 input_schema: Arc<Schema>,
65 quantifier_type: QuantifierType,
67}
68
69impl Clone for QuantifierExecExpr {
70 fn clone(&self) -> Self {
71 Self {
72 input_list: self.input_list.clone(),
73 predicate: self.predicate.clone(),
74 variable_name: self.variable_name.clone(),
75 input_schema: self.input_schema.clone(),
76 quantifier_type: self.quantifier_type,
77 }
78 }
79}
80
81impl QuantifierExecExpr {
82 pub fn new(
92 input_list: Arc<dyn PhysicalExpr>,
93 predicate: Arc<dyn PhysicalExpr>,
94 variable_name: String,
95 input_schema: Arc<Schema>,
96 quantifier_type: QuantifierType,
97 ) -> Self {
98 Self {
99 input_list,
100 predicate,
101 variable_name,
102 input_schema,
103 quantifier_type,
104 }
105 }
106}
107
108impl Display for QuantifierExecExpr {
109 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
110 write!(
111 f,
112 "{}(var={}, list={})",
113 self.quantifier_type, self.variable_name, self.input_list
114 )
115 }
116}
117
118impl PartialEq for QuantifierExecExpr {
119 fn eq(&self, other: &Self) -> bool {
120 self.variable_name == other.variable_name
121 && self.quantifier_type == other.quantifier_type
122 && Arc::ptr_eq(&self.input_list, &other.input_list)
123 && Arc::ptr_eq(&self.predicate, &other.predicate)
124 }
125}
126
127impl Eq for QuantifierExecExpr {}
128
129impl Hash for QuantifierExecExpr {
130 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
131 self.variable_name.hash(state);
132 self.quantifier_type.hash(state);
133 }
134}
135
136impl PartialEq<dyn Any> for QuantifierExecExpr {
137 fn eq(&self, other: &dyn Any) -> bool {
138 other
139 .downcast_ref::<Self>()
140 .map(|x| self == x)
141 .unwrap_or(false)
142 }
143}
144
145impl PhysicalExpr for QuantifierExecExpr {
146 fn as_any(&self) -> &dyn Any {
147 self
148 }
149
150 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
151 Ok(DataType::Boolean)
152 }
153
154 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
155 Ok(true)
157 }
158
159 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
160 let num_rows = batch.num_rows();
161
162 let list_val = self.input_list.evaluate(batch)?;
164 let list_array = list_val.into_array(num_rows)?;
165
166 let list_array = if let DataType::LargeBinary = list_array.data_type() {
170 crate::query::df_graph::common::cv_array_to_large_list(
171 list_array.as_ref(),
172 &DataType::LargeBinary,
173 )?
174 } else {
175 list_array
176 };
177
178 let list_array = if let DataType::List(field) = list_array.data_type() {
180 let target_type = DataType::LargeList(field.clone());
181 cast(&list_array, &target_type).map_err(|e| {
182 datafusion::error::DataFusionError::Execution(format!("Cast failed: {e}"))
183 })?
184 } else {
185 list_array
186 };
187
188 if let DataType::Null = list_array.data_type() {
190 let mut builder = BooleanBuilder::with_capacity(num_rows);
191 for _ in 0..num_rows {
192 builder.append_null();
193 }
194 return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
195 }
196
197 let large_list = list_array
198 .as_any()
199 .downcast_ref::<datafusion::arrow::array::LargeListArray>()
200 .ok_or_else(|| {
201 datafusion::error::DataFusionError::Execution(format!(
202 "Expected LargeListArray, got {:?}",
203 list_array.data_type()
204 ))
205 })?;
206
207 let values = large_list.values();
208 let offsets = large_list.offsets();
209 let list_nulls = large_list.nulls();
210
211 let num_values = values.len();
213
214 if num_values == 0 {
216 return Ok(ColumnarValue::Array(Arc::new(
217 self.reduce_empty_lists(num_rows, offsets, list_nulls),
218 )));
219 }
220
221 let mut indices_builder =
222 datafusion::arrow::array::UInt32Builder::with_capacity(num_values);
223 for row_idx in 0..num_rows {
224 let start = offsets[row_idx] as usize;
225 let end = offsets[row_idx + 1] as usize;
226 let len = end - start;
227 for _ in 0..len {
228 indices_builder.append_value(row_idx as u32);
229 }
230 }
231 let indices = indices_builder.finish();
232
233 let mut inner_columns = Vec::with_capacity(batch.num_columns() + 1);
234 for col in batch.columns() {
235 let taken = datafusion::arrow::compute::take(col, &indices, None).map_err(|e| {
236 datafusion::error::DataFusionError::Execution(format!("Take failed: {e}"))
237 })?;
238 inner_columns.push(taken);
239 }
240
241 let mut inner_fields = batch.schema().fields().to_vec();
242 let loop_field = Arc::new(Field::new(
243 &self.variable_name,
244 values.data_type().clone(),
245 true,
246 ));
247
248 if let Some(pos) = inner_fields
251 .iter()
252 .position(|f| f.name() == &self.variable_name)
253 {
254 inner_columns[pos] = values.clone();
255 inner_fields[pos] = loop_field;
256 } else {
257 inner_columns.push(values.clone());
258 inner_fields.push(loop_field);
259 }
260
261 let inner_schema = Arc::new(Schema::new(inner_fields));
262 let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;
263
264 let pred_val = self.predicate.evaluate(&inner_batch).map_err(|e| {
266 let err_msg = e.to_string();
267 if err_msg.contains("Invalid arithmetic operation") {
268 datafusion::error::DataFusionError::Execution(format!(
269 "SyntaxError: InvalidArgumentType - {}",
270 err_msg
271 ))
272 } else {
273 e
274 }
275 })?;
276 let pred_array = pred_val.into_array(inner_batch.num_rows())?;
277 let pred_array = cast(&pred_array, &DataType::Boolean).map_err(|e| {
278 let err_msg = e.to_string();
279 if err_msg.contains("Invalid arithmetic operation") {
280 datafusion::error::DataFusionError::Execution(format!(
281 "SyntaxError: InvalidArgumentType - {}",
282 err_msg
283 ))
284 } else {
285 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
286 }
287 })?;
288 let pred_bools = pred_array
289 .as_any()
290 .downcast_ref::<BooleanArray>()
291 .ok_or_else(|| {
292 datafusion::error::DataFusionError::Execution(
293 "Quantifier predicate did not produce BooleanArray".to_string(),
294 )
295 })?;
296
297 let result = self.reduce_predicate_results(num_rows, offsets, list_nulls, pred_bools);
298 Ok(ColumnarValue::Array(Arc::new(result)))
299 }
300
301 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
302 vec![&self.input_list]
305 }
306
307 fn with_new_children(
308 self: Arc<Self>,
309 children: Vec<Arc<dyn PhysicalExpr>>,
310 ) -> Result<Arc<dyn PhysicalExpr>> {
311 if children.len() != 1 {
312 return Err(datafusion::error::DataFusionError::Internal(
313 "QuantifierExecExpr requires exactly 1 child (input_list)".to_string(),
314 ));
315 }
316
317 Ok(Arc::new(Self {
318 input_list: children[0].clone(),
319 predicate: self.predicate.clone(),
320 variable_name: self.variable_name.clone(),
321 input_schema: self.input_schema.clone(),
322 quantifier_type: self.quantifier_type,
323 }))
324 }
325
326 fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327 write!(
328 f,
329 "{}({} IN {} WHERE {})",
330 self.quantifier_type, self.variable_name, self.input_list, self.predicate
331 )
332 }
333}
334
335impl QuantifierExecExpr {
336 fn reduce_predicate_results(
341 &self,
342 num_rows: usize,
343 offsets: &datafusion::arrow::buffer::OffsetBuffer<i64>,
344 list_nulls: Option<&datafusion::arrow::buffer::NullBuffer>,
345 pred_bools: &BooleanArray,
346 ) -> BooleanArray {
347 let mut builder = BooleanBuilder::with_capacity(num_rows);
348
349 for row_idx in 0..num_rows {
350 if list_nulls.is_some_and(|n| !n.is_valid(row_idx)) {
352 builder.append_null();
353 continue;
354 }
355
356 let start = offsets[row_idx] as usize;
357 let end = offsets[row_idx + 1] as usize;
358 let len = end - start;
359
360 if len == 0 {
361 match self.quantifier_type {
363 QuantifierType::All | QuantifierType::None => builder.append_value(true),
364 QuantifierType::Any | QuantifierType::Single => builder.append_value(false),
365 }
366 continue;
367 }
368
369 let mut true_count: usize = 0;
370 let mut false_count: usize = 0;
371 let mut null_count: usize = 0;
372
373 for i in start..end {
374 if pred_bools.is_null(i) {
375 null_count += 1;
376 } else if pred_bools.value(i) {
377 true_count += 1;
378 } else {
379 false_count += 1;
380 }
381 }
382
383 match self.quantifier_type {
384 QuantifierType::All => {
385 if false_count > 0 {
386 builder.append_value(false);
387 } else if null_count > 0 {
388 builder.append_null();
389 } else {
390 builder.append_value(true);
391 }
392 }
393 QuantifierType::Any => {
394 if true_count > 0 {
395 builder.append_value(true);
396 } else if null_count > 0 {
397 builder.append_null();
398 } else {
399 builder.append_value(false);
400 }
401 }
402 QuantifierType::Single => {
403 if true_count > 1 {
404 builder.append_value(false);
405 } else if true_count == 1 && null_count == 0 {
406 builder.append_value(true);
407 } else if true_count == 0 && null_count == 0 {
408 builder.append_value(false);
409 } else {
410 builder.append_null();
412 }
413 }
414 QuantifierType::None => {
415 if true_count > 0 {
416 builder.append_value(false);
417 } else if null_count > 0 {
418 builder.append_null();
419 } else {
420 builder.append_value(true);
421 }
422 }
423 }
424 }
425
426 builder.finish()
427 }
428
429 fn reduce_empty_lists(
433 &self,
434 num_rows: usize,
435 offsets: &datafusion::arrow::buffer::OffsetBuffer<i64>,
436 list_nulls: Option<&datafusion::arrow::buffer::NullBuffer>,
437 ) -> BooleanArray {
438 let mut builder = BooleanBuilder::with_capacity(num_rows);
439
440 for row_idx in 0..num_rows {
441 if list_nulls.is_some_and(|n| !n.is_valid(row_idx)) {
442 builder.append_null();
443 continue;
444 }
445
446 let start = offsets[row_idx] as usize;
447 let end = offsets[row_idx + 1] as usize;
448
449 if start == end {
450 match self.quantifier_type {
452 QuantifierType::All | QuantifierType::None => builder.append_value(true),
453 QuantifierType::Any | QuantifierType::Single => builder.append_value(false),
454 }
455 } else {
456 builder.append_null();
458 }
459 }
460
461 builder.finish()
462 }
463}