// uni_query/query/df_graph/comprehension.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4use std::any::Any;
5use std::fmt::{self, Display, Formatter};
6use std::hash::Hash;
7use std::sync::Arc;
8
9use datafusion::arrow::array::{Array, BooleanArray, RecordBatch, UInt32Array};
10use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
11use datafusion::arrow::compute::{cast, filter, filter_record_batch, take};
12use datafusion::arrow::datatypes::{DataType, Field, Schema};
13use datafusion::common::Result;
14use datafusion::logical_expr::ColumnarValue;
15use datafusion::physical_plan::PhysicalExpr;
16
/// Physical expression for Cypher List Comprehension: `[x IN list WHERE pred | expr]`
///
/// Evaluation (see `PhysicalExpr::evaluate`) flattens the input lists, binds
/// each element to `variable_name` in an inner batch alongside the replicated
/// outer-scope columns, applies the optional predicate and the map expression,
/// then re-assembles one list per outer row, encoded as CypherValue
/// (`LargeBinary`).
#[derive(Debug)]
pub struct ListComprehensionExecExpr {
    /// Expression producing the input list
    input_list: Arc<dyn PhysicalExpr>,
    /// Expression to map each element (projection); compiled against the inner
    /// schema that contains the loop variable
    map_expr: Arc<dyn PhysicalExpr>,
    /// Optional filter predicate; also compiled against the inner schema
    predicate: Option<Arc<dyn PhysicalExpr>>,
    /// Name of the loop variable (e.g., "x")
    variable_name: String,
    /// Schema of the input batch (outer scope)
    input_schema: Arc<Schema>,
    /// Data type of the items in the output list
    output_item_type: DataType,
    /// Whether to extract VIDs from CypherValue-encoded loop variable
    /// for nested pattern comprehension anchor binding
    needs_vid_extraction: bool,
}
36
37impl Clone for ListComprehensionExecExpr {
38    fn clone(&self) -> Self {
39        Self {
40            input_list: self.input_list.clone(),
41            map_expr: self.map_expr.clone(),
42            predicate: self.predicate.clone(),
43            variable_name: self.variable_name.clone(),
44            input_schema: self.input_schema.clone(),
45            output_item_type: self.output_item_type.clone(),
46            needs_vid_extraction: self.needs_vid_extraction,
47        }
48    }
49}
50
51impl ListComprehensionExecExpr {
52    pub fn new(
53        input_list: Arc<dyn PhysicalExpr>,
54        map_expr: Arc<dyn PhysicalExpr>,
55        predicate: Option<Arc<dyn PhysicalExpr>>,
56        variable_name: String,
57        input_schema: Arc<Schema>,
58        output_item_type: DataType,
59        needs_vid_extraction: bool,
60    ) -> Self {
61        Self {
62            input_list,
63            map_expr,
64            predicate,
65            variable_name,
66            input_schema,
67            output_item_type,
68            needs_vid_extraction,
69        }
70    }
71}
72
73impl Display for ListComprehensionExecExpr {
74    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
75        write!(
76            f,
77            "ListComprehension(var={}, list={})",
78            self.variable_name, self.input_list
79        )
80    }
81}
82
83impl PartialEq for ListComprehensionExecExpr {
84    fn eq(&self, other: &Self) -> bool {
85        self.variable_name == other.variable_name
86            && self.output_item_type == other.output_item_type
87            && Arc::ptr_eq(&self.input_list, &other.input_list)
88            && Arc::ptr_eq(&self.map_expr, &other.map_expr)
89            && match (&self.predicate, &other.predicate) {
90                (Some(a), Some(b)) => Arc::ptr_eq(a, b),
91                (None, None) => true,
92                _ => false,
93            }
94    }
95}
96
// `eq` above combines `==` on plain fields with `Arc::ptr_eq`, both of which
// are reflexive and total, so the `Eq` marker contract holds.
impl Eq for ListComprehensionExecExpr {}
98
99impl Hash for ListComprehensionExecExpr {
100    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
101        self.variable_name.hash(state);
102        self.output_item_type.hash(state);
103    }
104}
105
106impl PartialEq<dyn Any> for ListComprehensionExecExpr {
107    fn eq(&self, other: &dyn Any) -> bool {
108        other
109            .downcast_ref::<Self>()
110            .map(|x| self == x)
111            .unwrap_or(false)
112    }
113}
114
impl PhysicalExpr for ListComprehensionExecExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        // Always return LargeBinary (CypherValue encoding).
        // This is consistent with ALL other list-producing operations (reverse(),
        // tail(), list_concat(), etc.) which always return LargeBinary. Returning
        // LargeList<T> for typed inputs would cause type mismatches in CASE/coalesce
        // branches when mixed with other list ops that return LargeBinary.
        Ok(DataType::LargeBinary)
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        // The input list itself may be NULL for some rows, so the result is
        // always declared nullable.
        Ok(true)
    }

    /// Evaluates `[var IN list WHERE pred | expr]` for every row of `batch`.
    ///
    /// Strategy: flatten all lists into one values array, replicate the outer
    /// columns so each flattened element sees its originating row's scope,
    /// evaluate the predicate / map expression over that flattened batch, then
    /// rebuild per-row list offsets and re-encode as CypherValue.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        // 1. Evaluate input list
        let list_val = self.input_list.evaluate(batch)?;
        let list_array = list_val.into_array(batch.num_rows())?;

        // 2. Decode CypherValue-encoded arrays (LargeBinary → LargeList<LargeBinary>)
        let list_array = if let DataType::LargeBinary = list_array.data_type() {
            crate::query::df_graph::common::cv_array_to_large_list(
                list_array.as_ref(),
                &DataType::LargeBinary,
            )?
        } else {
            list_array
        };

        // Normalize to LargeListArray (List<T> → LargeList<T>) so the rest of
        // the function only has to handle 64-bit offsets.
        let list_array = if let DataType::List(field) = list_array.data_type() {
            let target_type = DataType::LargeList(field.clone());
            cast(&list_array, &target_type).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
            })?
        } else {
            list_array
        };

        let large_list = list_array
            .as_any()
            .downcast_ref::<datafusion::arrow::array::LargeListArray>()
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution(format!(
                    "Expected LargeListArray, got {:?}",
                    list_array.data_type()
                ))
            })?;

        let values = large_list.values();
        let offsets = large_list.offsets();
        let nulls = large_list.nulls();

        // 3. Prepare inner batch
        // `indices[i]` = outer row that flattened element `i` came from; used
        // below to replicate outer columns and, in step 6, to rebuild offsets.
        let num_rows = batch.num_rows();
        let num_values = values.len();
        let mut indices_builder =
            datafusion::arrow::array::UInt32Builder::with_capacity(num_values);
        for row_idx in 0..num_rows {
            let start = offsets[row_idx] as usize;
            let end = offsets[row_idx + 1] as usize;
            let len = end - start;
            for _ in 0..len {
                indices_builder.append_value(row_idx as u32);
            }
        }
        let indices = indices_builder.finish();

        // Replicate every outer column once per list element so predicate and
        // map expressions can still reference outer-scope columns.
        let mut inner_columns = Vec::with_capacity(batch.num_columns() + 1);
        for col in batch.columns() {
            let taken = take(col, &indices, None).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Take failed: {}", e))
            })?;
            inner_columns.push(taken);
        }

        let mut inner_fields = batch.schema().fields().to_vec();
        let loop_field = Arc::new(Field::new(
            &self.variable_name,
            values.data_type().clone(),
            true,
        ));

        // Replace existing column if loop variable shadows an outer column,
        // otherwise append at the end — matching compile_list_comprehension's schema construction.
        if let Some(pos) = inner_fields
            .iter()
            .position(|f| f.name() == &self.variable_name)
        {
            inner_columns[pos] = values.clone();
            inner_fields[pos] = loop_field;
        } else {
            inner_columns.push(values.clone());
            inner_fields.push(loop_field);
        }

        // Materialize VID column from CypherValue-encoded loop variable for nested
        // pattern comprehension anchor binding
        if self.needs_vid_extraction {
            let vid_field_name = format!("{}._vid", self.variable_name);
            if !inner_fields.iter().any(|f| f.name() == &vid_field_name) {
                let vid_field = Arc::new(Field::new(&vid_field_name, DataType::UInt64, true));
                // Find the loop variable column
                let loop_var_idx = inner_fields
                    .iter()
                    .position(|f| f.name() == &self.variable_name);
                if let Some(idx) = loop_var_idx {
                    let vid_array = super::common::extract_vids_from_cypher_value_column(
                        inner_columns[idx].as_ref(),
                    )?;
                    inner_fields.push(vid_field);
                    inner_columns.push(vid_array);
                }
            }
        }

        let inner_schema = Arc::new(Schema::new(inner_fields));

        let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;

        // 4. Filter (Predicate)
        // Keep the surviving row-indices alongside the filtered batch so step 6
        // can recount how many elements remain per outer row.
        let (filtered_batch, filtered_indices) = if let Some(pred) = &self.predicate {
            let mask = pred
                .evaluate(&inner_batch)?
                .into_array(inner_batch.num_rows())?;
            let mask = cast(&mask, &DataType::Boolean)?;
            // Cast to Boolean just succeeded, so this downcast cannot fail.
            let boolean_mask = mask.as_any().downcast_ref::<BooleanArray>().unwrap();

            let filtered_batch = filter_record_batch(&inner_batch, boolean_mask)?;

            // Filter the row-index array with the same mask so it stays
            // aligned with `filtered_batch`.
            let indices_array: Arc<dyn Array> = Arc::new(indices.clone());
            let filtered_indices = filter(&indices_array, boolean_mask)?;
            let filtered_indices = filtered_indices
                .as_any()
                .downcast_ref::<UInt32Array>()
                .unwrap()
                .clone();

            (filtered_batch, filtered_indices)
        } else {
            (inner_batch, indices.clone())
        };

        // 5. Evaluate Map Expression
        let mapped_val = self.map_expr.evaluate(&filtered_batch)?;
        let mapped_array = mapped_val.into_array(filtered_batch.num_rows())?;

        // 6. Reconstruct ListArray
        // With a predicate, recompute offsets by counting surviving elements
        // per outer row. This relies on `filtered_indices` being
        // non-decreasing, which holds because `indices` was built in row order
        // and the filter kernel preserves element order.
        let new_offsets = if self.predicate.is_some() {
            let num_rows = batch.num_rows();
            let mut new_offsets = Vec::with_capacity(num_rows + 1);
            new_offsets.push(0);

            let indices_slice = filtered_indices.values();
            let mut pos = 0;
            let mut current_len = 0;

            for row_idx in 0..num_rows {
                let mut count = 0;
                while pos < indices_slice.len() && indices_slice[pos] as usize == row_idx {
                    count += 1;
                    pos += 1;
                }
                current_len += count;
                new_offsets.push(current_len);
            }
            OffsetBuffer::new(ScalarBuffer::from(new_offsets))
        } else {
            // No predicate: per-row element counts are unchanged, so the input
            // offsets can be reused as-is.
            offsets.clone()
        };

        // Propagate the input's null mask so rows with a NULL input list
        // produce a NULL output list.
        let new_field = Arc::new(Field::new("item", mapped_array.data_type().clone(), true));
        let new_list = datafusion::arrow::array::LargeListArray::new(
            new_field,
            new_offsets,
            mapped_array,
            nulls.cloned(),
        );

        // Always encode the result as LargeBinary (CypherValue), consistent with
        // data_type(). typed_large_list_to_cv_array handles all element types
        // (Int64, Float64, Utf8, Boolean, Struct, LargeBinary/nested CypherValue).
        let cypher_value_array =
            crate::query::df_graph::common::typed_large_list_to_cv_array(&new_list)?;
        Ok(ColumnarValue::Array(cypher_value_array))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        // Only expose input_list as a child. The map_expr and predicate are compiled
        // against an inner schema (with the loop variable) and should not be exposed
        // to DataFusion's expression tree traversal (e.g., equivalence analysis).
        vec![&self.input_list]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Mirrors children(): exactly one child (input_list) is accepted;
        // map_expr and predicate are carried over untouched.
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Internal(
                "ListComprehension requires exactly 1 child (input_list)".to_string(),
            ));
        }

        Ok(Arc::new(Self {
            input_list: children[0].clone(),
            map_expr: self.map_expr.clone(),
            predicate: self.predicate.clone(),
            variable_name: self.variable_name.clone(),
            input_schema: self.input_schema.clone(),
            output_item_type: self.output_item_type.clone(),
            needs_vid_extraction: self.needs_vid_extraction,
        }))
    }

    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render back in Cypher comprehension syntax, with or without the
        // WHERE clause depending on whether a predicate is present.
        if let Some(pred) = &self.predicate {
            write!(
                f,
                "[{} IN {} WHERE {} | {}]",
                self.variable_name, self.input_list, pred, self.map_expr
            )
        } else {
            write!(
                f,
                "[{} IN {} | {}]",
                self.variable_name, self.input_list, self.map_expr
            )
        }
    }
}