Skip to main content

uni_query/query/df_graph/
quantifier.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Vectorized quantifier expression for Cypher `ALL/ANY/SINGLE/NONE(x IN list WHERE pred)`.
5//!
6//! Implements three-valued null semantics required by the OpenCypher TCK:
7//! - `ALL`: false if any false; null if any null (no false); true otherwise. Empty → true.
8//! - `ANY`: true if any true; null if any null (no true); false otherwise. Empty → false.
//! - `SINGLE`: false if >1 true; null if nulls present with ≤1 true; true if exactly 1 true
//!   and no nulls; false if no true and no nulls. Empty → false.
11//! - `NONE`: false if any true; null if any null (no true); true otherwise. Empty → true.
12
13use std::any::Any;
14use std::fmt::{self, Display, Formatter};
15use std::hash::Hash;
16use std::sync::Arc;
17
18use datafusion::arrow::array::{Array, BooleanArray, BooleanBuilder, RecordBatch};
19use datafusion::arrow::compute::cast;
20use datafusion::arrow::datatypes::{DataType, Field, Schema};
21use datafusion::common::Result;
22use datafusion::logical_expr::ColumnarValue;
23use datafusion::physical_plan::PhysicalExpr;
24
/// The four Cypher list quantifiers; each folds per-element predicate outcomes into one boolean.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum QuantifierType {
    /// `ALL(x IN list WHERE pred)` — holds when every element satisfies `pred`.
    All,
    /// `ANY(x IN list WHERE pred)` — holds when some element satisfies `pred`.
    Any,
    /// `SINGLE(x IN list WHERE pred)` — holds when exactly one element satisfies `pred`.
    Single,
    /// `NONE(x IN list WHERE pred)` — holds when no element satisfies `pred`.
    None,
}
37
38impl Display for QuantifierType {
39    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::All => write!(f, "ALL"),
42            Self::Any => write!(f, "ANY"),
43            Self::Single => write!(f, "SINGLE"),
44            Self::None => write!(f, "NONE"),
45        }
46    }
47}
48
49/// Physical expression evaluating `ALL/ANY/SINGLE/NONE(x IN list WHERE pred)`.
50///
51/// Steps 1–4 mirror [`super::comprehension::ListComprehensionExecExpr`]: evaluate input
52/// list, CypherValue-decode, normalize to `LargeList`, flatten with row indices, build inner
53/// batch. Step 5 evaluates the predicate on the inner batch and performs boolean reduction
54/// with three-valued null logic per parent row.
55#[derive(Debug)]
56pub struct QuantifierExecExpr {
57    /// Expression producing the input list.
58    input_list: Arc<dyn PhysicalExpr>,
59    /// Predicate evaluated for each element.
60    predicate: Arc<dyn PhysicalExpr>,
61    /// Name of the loop variable (e.g., `"x"`).
62    variable_name: String,
63    /// Schema of the outer input batch.
64    input_schema: Arc<Schema>,
65    /// Which quantifier to apply.
66    quantifier_type: QuantifierType,
67}
68
69impl Clone for QuantifierExecExpr {
70    fn clone(&self) -> Self {
71        Self {
72            input_list: self.input_list.clone(),
73            predicate: self.predicate.clone(),
74            variable_name: self.variable_name.clone(),
75            input_schema: self.input_schema.clone(),
76            quantifier_type: self.quantifier_type,
77        }
78    }
79}
80
81impl QuantifierExecExpr {
82    /// Create a new quantifier expression.
83    ///
84    /// # Arguments
85    ///
86    /// * `input_list` — expression producing the list to iterate
87    /// * `predicate` — expression evaluated per element (compiled against inner schema)
88    /// * `variable_name` — loop variable name bound to each element
89    /// * `input_schema` — schema of the outer batch
90    /// * `quantifier_type` — `All`, `Any`, `Single`, or `None`
91    pub fn new(
92        input_list: Arc<dyn PhysicalExpr>,
93        predicate: Arc<dyn PhysicalExpr>,
94        variable_name: String,
95        input_schema: Arc<Schema>,
96        quantifier_type: QuantifierType,
97    ) -> Self {
98        Self {
99            input_list,
100            predicate,
101            variable_name,
102            input_schema,
103            quantifier_type,
104        }
105    }
106}
107
108impl Display for QuantifierExecExpr {
109    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
110        write!(
111            f,
112            "{}(var={}, list={})",
113            self.quantifier_type, self.variable_name, self.input_list
114        )
115    }
116}
117
118impl PartialEq for QuantifierExecExpr {
119    fn eq(&self, other: &Self) -> bool {
120        self.variable_name == other.variable_name
121            && self.quantifier_type == other.quantifier_type
122            && Arc::ptr_eq(&self.input_list, &other.input_list)
123            && Arc::ptr_eq(&self.predicate, &other.predicate)
124    }
125}
126
127impl Eq for QuantifierExecExpr {}
128
129impl Hash for QuantifierExecExpr {
130    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
131        self.variable_name.hash(state);
132        self.quantifier_type.hash(state);
133    }
134}
135
136impl PartialEq<dyn Any> for QuantifierExecExpr {
137    fn eq(&self, other: &dyn Any) -> bool {
138        other
139            .downcast_ref::<Self>()
140            .map(|x| self == x)
141            .unwrap_or(false)
142    }
143}
144
145impl PhysicalExpr for QuantifierExecExpr {
146    fn as_any(&self) -> &dyn Any {
147        self
148    }
149
150    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
151        Ok(DataType::Boolean)
152    }
153
154    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
155        // Three-valued logic can produce null results.
156        Ok(true)
157    }
158
159    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
160        let num_rows = batch.num_rows();
161
162        // --- Step 1: Evaluate input list ---
163        let list_val = self.input_list.evaluate(batch)?;
164        let list_array = list_val.into_array(num_rows)?;
165
166        // --- Step 2: CypherValue decode (LargeBinary → LargeList<LargeBinary>) ---
167        // Keep elements as CypherValue (LargeBinary) to match the compile-time schema.
168        // The compiled predicate handles LargeBinary via CypherValue comparison/arithmetic UDFs.
169        let list_array = if let DataType::LargeBinary = list_array.data_type() {
170            crate::query::df_graph::common::cv_array_to_large_list(
171                list_array.as_ref(),
172                &DataType::LargeBinary,
173            )?
174        } else {
175            list_array
176        };
177
178        // --- Step 3: Normalize List → LargeList ---
179        let list_array = if let DataType::List(field) = list_array.data_type() {
180            let target_type = DataType::LargeList(field.clone());
181            cast(&list_array, &target_type).map_err(|e| {
182                datafusion::error::DataFusionError::Execution(format!("Cast failed: {e}"))
183            })?
184        } else {
185            list_array
186        };
187
188        // Handle Null type: all rows produce null result
189        if let DataType::Null = list_array.data_type() {
190            let mut builder = BooleanBuilder::with_capacity(num_rows);
191            for _ in 0..num_rows {
192                builder.append_null();
193            }
194            return Ok(ColumnarValue::Array(Arc::new(builder.finish())));
195        }
196
197        let large_list = list_array
198            .as_any()
199            .downcast_ref::<datafusion::arrow::array::LargeListArray>()
200            .ok_or_else(|| {
201                datafusion::error::DataFusionError::Execution(format!(
202                    "Expected LargeListArray, got {:?}",
203                    list_array.data_type()
204                ))
205            })?;
206
207        let values = large_list.values();
208        let offsets = large_list.offsets();
209        let list_nulls = large_list.nulls();
210
211        // --- Step 4: Flatten — build inner batch ---
212        let num_values = values.len();
213
214        // If there are no values at all, short-circuit with empty-list semantics
215        if num_values == 0 {
216            return Ok(ColumnarValue::Array(Arc::new(
217                self.reduce_empty_lists(num_rows, offsets, list_nulls),
218            )));
219        }
220
221        let mut indices_builder =
222            datafusion::arrow::array::UInt32Builder::with_capacity(num_values);
223        for row_idx in 0..num_rows {
224            let start = offsets[row_idx] as usize;
225            let end = offsets[row_idx + 1] as usize;
226            let len = end - start;
227            for _ in 0..len {
228                indices_builder.append_value(row_idx as u32);
229            }
230        }
231        let indices = indices_builder.finish();
232
233        let mut inner_columns = Vec::with_capacity(batch.num_columns() + 1);
234        for col in batch.columns() {
235            let taken = datafusion::arrow::compute::take(col, &indices, None).map_err(|e| {
236                datafusion::error::DataFusionError::Execution(format!("Take failed: {e}"))
237            })?;
238            inner_columns.push(taken);
239        }
240
241        let mut inner_fields = batch.schema().fields().to_vec();
242        let loop_field = Arc::new(Field::new(
243            &self.variable_name,
244            values.data_type().clone(),
245            true,
246        ));
247
248        // Replace existing column if loop variable shadows an outer column,
249        // otherwise append at the end — matching compile_quantifier's schema construction.
250        if let Some(pos) = inner_fields
251            .iter()
252            .position(|f| f.name() == &self.variable_name)
253        {
254            inner_columns[pos] = values.clone();
255            inner_fields[pos] = loop_field;
256        } else {
257            inner_columns.push(values.clone());
258            inner_fields.push(loop_field);
259        }
260
261        let inner_schema = Arc::new(Schema::new(inner_fields));
262        let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;
263
264        // --- Step 5: Evaluate predicate and reduce ---
265        let pred_val = self.predicate.evaluate(&inner_batch).map_err(|e| {
266            let err_msg = e.to_string();
267            if err_msg.contains("Invalid arithmetic operation") {
268                datafusion::error::DataFusionError::Execution(format!(
269                    "SyntaxError: InvalidArgumentType - {}",
270                    err_msg
271                ))
272            } else {
273                e
274            }
275        })?;
276        let pred_array = pred_val.into_array(inner_batch.num_rows())?;
277        let pred_array = cast(&pred_array, &DataType::Boolean).map_err(|e| {
278            let err_msg = e.to_string();
279            if err_msg.contains("Invalid arithmetic operation") {
280                datafusion::error::DataFusionError::Execution(format!(
281                    "SyntaxError: InvalidArgumentType - {}",
282                    err_msg
283                ))
284            } else {
285                datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
286            }
287        })?;
288        let pred_bools = pred_array
289            .as_any()
290            .downcast_ref::<BooleanArray>()
291            .ok_or_else(|| {
292                datafusion::error::DataFusionError::Execution(
293                    "Quantifier predicate did not produce BooleanArray".to_string(),
294                )
295            })?;
296
297        let result = self.reduce_predicate_results(num_rows, offsets, list_nulls, pred_bools);
298        Ok(ColumnarValue::Array(Arc::new(result)))
299    }
300
301    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
302        // Only expose input_list. The predicate is compiled against the inner schema
303        // (with the loop variable) and must not be exposed to DF tree traversal.
304        vec![&self.input_list]
305    }
306
307    fn with_new_children(
308        self: Arc<Self>,
309        children: Vec<Arc<dyn PhysicalExpr>>,
310    ) -> Result<Arc<dyn PhysicalExpr>> {
311        if children.len() != 1 {
312            return Err(datafusion::error::DataFusionError::Internal(
313                "QuantifierExecExpr requires exactly 1 child (input_list)".to_string(),
314            ));
315        }
316
317        Ok(Arc::new(Self {
318            input_list: children[0].clone(),
319            predicate: self.predicate.clone(),
320            variable_name: self.variable_name.clone(),
321            input_schema: self.input_schema.clone(),
322            quantifier_type: self.quantifier_type,
323        }))
324    }
325
326    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
327        write!(
328            f,
329            "{}({} IN {} WHERE {})",
330            self.quantifier_type, self.variable_name, self.input_list, self.predicate
331        )
332    }
333}
334
335impl QuantifierExecExpr {
336    /// Reduce predicate results per parent row using three-valued null logic.
337    ///
338    /// For each parent row, slices the predicate boolean array using offsets and
339    /// counts true/false/null, then applies the quantifier semantics.
340    fn reduce_predicate_results(
341        &self,
342        num_rows: usize,
343        offsets: &datafusion::arrow::buffer::OffsetBuffer<i64>,
344        list_nulls: Option<&datafusion::arrow::buffer::NullBuffer>,
345        pred_bools: &BooleanArray,
346    ) -> BooleanArray {
347        let mut builder = BooleanBuilder::with_capacity(num_rows);
348
349        for row_idx in 0..num_rows {
350            // If the list itself is null, result is null
351            if list_nulls.is_some_and(|n| !n.is_valid(row_idx)) {
352                builder.append_null();
353                continue;
354            }
355
356            let start = offsets[row_idx] as usize;
357            let end = offsets[row_idx + 1] as usize;
358            let len = end - start;
359
360            if len == 0 {
361                // Empty list semantics
362                match self.quantifier_type {
363                    QuantifierType::All | QuantifierType::None => builder.append_value(true),
364                    QuantifierType::Any | QuantifierType::Single => builder.append_value(false),
365                }
366                continue;
367            }
368
369            let mut true_count: usize = 0;
370            let mut false_count: usize = 0;
371            let mut null_count: usize = 0;
372
373            for i in start..end {
374                if pred_bools.is_null(i) {
375                    null_count += 1;
376                } else if pred_bools.value(i) {
377                    true_count += 1;
378                } else {
379                    false_count += 1;
380                }
381            }
382
383            match self.quantifier_type {
384                QuantifierType::All => {
385                    if false_count > 0 {
386                        builder.append_value(false);
387                    } else if null_count > 0 {
388                        builder.append_null();
389                    } else {
390                        builder.append_value(true);
391                    }
392                }
393                QuantifierType::Any => {
394                    if true_count > 0 {
395                        builder.append_value(true);
396                    } else if null_count > 0 {
397                        builder.append_null();
398                    } else {
399                        builder.append_value(false);
400                    }
401                }
402                QuantifierType::Single => {
403                    if true_count > 1 {
404                        builder.append_value(false);
405                    } else if true_count == 1 && null_count == 0 {
406                        builder.append_value(true);
407                    } else if true_count == 0 && null_count == 0 {
408                        builder.append_value(false);
409                    } else {
410                        // true_count <= 1 with nulls present — indeterminate
411                        builder.append_null();
412                    }
413                }
414                QuantifierType::None => {
415                    if true_count > 0 {
416                        builder.append_value(false);
417                    } else if null_count > 0 {
418                        builder.append_null();
419                    } else {
420                        builder.append_value(true);
421                    }
422                }
423            }
424        }
425
426        builder.finish()
427    }
428
429    /// Produce results for the degenerate case where every list is empty (or null).
430    ///
431    /// This avoids building an inner batch when there are zero flattened values.
432    fn reduce_empty_lists(
433        &self,
434        num_rows: usize,
435        offsets: &datafusion::arrow::buffer::OffsetBuffer<i64>,
436        list_nulls: Option<&datafusion::arrow::buffer::NullBuffer>,
437    ) -> BooleanArray {
438        let mut builder = BooleanBuilder::with_capacity(num_rows);
439
440        for row_idx in 0..num_rows {
441            if list_nulls.is_some_and(|n| !n.is_valid(row_idx)) {
442                builder.append_null();
443                continue;
444            }
445
446            let start = offsets[row_idx] as usize;
447            let end = offsets[row_idx + 1] as usize;
448
449            if start == end {
450                // Empty list
451                match self.quantifier_type {
452                    QuantifierType::All | QuantifierType::None => builder.append_value(true),
453                    QuantifierType::Any | QuantifierType::Single => builder.append_value(false),
454                }
455            } else {
456                // Should not reach here since num_values == 0, but handle defensively
457                builder.append_null();
458            }
459        }
460
461        builder.finish()
462    }
463}