// uni_query/query/df_graph/comprehension.rs
use std::any::Any;
5use std::fmt::{self, Display, Formatter};
6use std::hash::Hash;
7use std::sync::Arc;
8
9use datafusion::arrow::array::{Array, BooleanArray, RecordBatch, UInt32Array};
10use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
11use datafusion::arrow::compute::{cast, filter, filter_record_batch, take};
12use datafusion::arrow::datatypes::{DataType, Field, Schema};
13use datafusion::common::Result;
14use datafusion::logical_expr::ColumnarValue;
15use datafusion::physical_plan::PhysicalExpr;
16
/// Physical expression evaluating a Cypher-style list comprehension:
/// `[variable_name IN input_list WHERE predicate | map_expr]`.
#[derive(Debug)]
pub struct ListComprehensionExecExpr {
    // Expression producing the list to iterate (one list value per input row).
    input_list: Arc<dyn PhysicalExpr>,
    // Projection applied to each kept element; evaluated against the
    // flattened "inner" batch where the loop variable appears as a column.
    map_expr: Arc<dyn PhysicalExpr>,
    // Optional `WHERE` filter over elements; `None` keeps every element.
    predicate: Option<Arc<dyn PhysicalExpr>>,
    // Name of the loop variable bound to each list element.
    variable_name: String,
    // Schema of the outer input batch. NOTE(review): stored but never read in
    // this file — confirm whether callers rely on it or it can be dropped.
    input_schema: Arc<Schema>,
    // Declared element type of the resulting list (used for eq/hash only here).
    output_item_type: DataType,
    // When true, `evaluate` also materializes a `<var>._vid` UInt64 column in
    // the inner batch, extracted from the loop-variable column.
    needs_vid_extraction: bool,
}
36
37impl Clone for ListComprehensionExecExpr {
38 fn clone(&self) -> Self {
39 Self {
40 input_list: self.input_list.clone(),
41 map_expr: self.map_expr.clone(),
42 predicate: self.predicate.clone(),
43 variable_name: self.variable_name.clone(),
44 input_schema: self.input_schema.clone(),
45 output_item_type: self.output_item_type.clone(),
46 needs_vid_extraction: self.needs_vid_extraction,
47 }
48 }
49}
50
51impl ListComprehensionExecExpr {
52 pub fn new(
53 input_list: Arc<dyn PhysicalExpr>,
54 map_expr: Arc<dyn PhysicalExpr>,
55 predicate: Option<Arc<dyn PhysicalExpr>>,
56 variable_name: String,
57 input_schema: Arc<Schema>,
58 output_item_type: DataType,
59 needs_vid_extraction: bool,
60 ) -> Self {
61 Self {
62 input_list,
63 map_expr,
64 predicate,
65 variable_name,
66 input_schema,
67 output_item_type,
68 needs_vid_extraction,
69 }
70 }
71}
72
73impl Display for ListComprehensionExecExpr {
74 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
75 write!(
76 f,
77 "ListComprehension(var={}, list={})",
78 self.variable_name, self.input_list
79 )
80 }
81}
82
83impl PartialEq for ListComprehensionExecExpr {
84 fn eq(&self, other: &Self) -> bool {
85 self.variable_name == other.variable_name
86 && self.output_item_type == other.output_item_type
87 && Arc::ptr_eq(&self.input_list, &other.input_list)
88 && Arc::ptr_eq(&self.map_expr, &other.map_expr)
89 && match (&self.predicate, &other.predicate) {
90 (Some(a), Some(b)) => Arc::ptr_eq(a, b),
91 (None, None) => true,
92 _ => false,
93 }
94 }
95}
96
// Sound: `PartialEq` is reflexive here (`Arc::ptr_eq` of a value with itself
// is true and the remaining fields compare structurally).
impl Eq for ListComprehensionExecExpr {}
98
impl Hash for ListComprehensionExecExpr {
    // Hashes only the fields that `PartialEq` compares structurally; the
    // pointer-compared sub-expressions are deliberately omitted. Equal values
    // therefore always hash identically (the Eq/Hash contract holds), at the
    // cost of extra collisions between comprehensions sharing var/type.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.variable_name.hash(state);
        self.output_item_type.hash(state);
    }
}
105
106impl PartialEq<dyn Any> for ListComprehensionExecExpr {
107 fn eq(&self, other: &dyn Any) -> bool {
108 other
109 .downcast_ref::<Self>()
110 .map(|x| self == x)
111 .unwrap_or(false)
112 }
113}
114
impl PhysicalExpr for ListComprehensionExecExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        // The result is emitted as an encoded value column (LargeBinary), not
        // as a typed Arrow list — presumably a serialized CypherValue; see
        // `typed_large_list_to_cv_array` at the end of `evaluate`.
        Ok(DataType::LargeBinary)
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluates `[var IN list (WHERE pred)? | map]` for every row of `batch`.
    ///
    /// Strategy:
    /// 1. Evaluate `input_list` and normalize it to a `LargeListArray`.
    /// 2. Flatten into an "inner" batch with one row per list element:
    ///    outer columns are repeated via `take`, and the loop variable is
    ///    bound as an ordinary column.
    /// 3. Optionally filter inner rows with `predicate`.
    /// 4. Evaluate `map_expr` over the (filtered) inner batch.
    /// 5. Reassemble per-row lists from (re)computed offsets and re-encode.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let list_val = self.input_list.evaluate(batch)?;
        let list_array = list_val.into_array(batch.num_rows())?;

        // Encoded input: decode LargeBinary into a LargeList of encoded items.
        let list_array = if let DataType::LargeBinary = list_array.data_type() {
            crate::query::df_graph::common::cv_array_to_large_list(
                list_array.as_ref(),
                &DataType::LargeBinary,
            )?
        } else {
            list_array
        };

        // Normalize List -> LargeList so a single downcast suffices below.
        let list_array = if let DataType::List(field) = list_array.data_type() {
            let target_type = DataType::LargeList(field.clone());
            cast(&list_array, &target_type).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Cast failed: {}", e))
            })?
        } else {
            list_array
        };

        let large_list = list_array
            .as_any()
            .downcast_ref::<datafusion::arrow::array::LargeListArray>()
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution(format!(
                    "Expected LargeListArray, got {:?}",
                    list_array.data_type()
                ))
            })?;

        let values = large_list.values();
        let offsets = large_list.offsets();
        let nulls = large_list.nulls();

        // Build take-indices: the outer row index repeated once per element of
        // that row's list, so `take` replicates each outer row alongside its
        // elements. Null rows have start == end and contribute nothing.
        let num_rows = batch.num_rows();
        let num_values = values.len();
        let mut indices_builder =
            datafusion::arrow::array::UInt32Builder::with_capacity(num_values);
        for row_idx in 0..num_rows {
            let start = offsets[row_idx] as usize;
            let end = offsets[row_idx + 1] as usize;
            let len = end - start;
            for _ in 0..len {
                indices_builder.append_value(row_idx as u32);
            }
        }
        let indices = indices_builder.finish();

        // Replicate every outer column into the flattened inner batch.
        let mut inner_columns = Vec::with_capacity(batch.num_columns() + 1);
        for col in batch.columns() {
            let taken = take(col, &indices, None).map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("Take failed: {}", e))
            })?;
            inner_columns.push(taken);
        }

        let mut inner_fields = batch.schema().fields().to_vec();
        let loop_field = Arc::new(Field::new(
            &self.variable_name,
            values.data_type().clone(),
            true,
        ));

        // Bind the loop variable: shadow an existing same-named column if
        // present, otherwise append a new column. `values` already has one
        // entry per inner row, matching the taken columns.
        if let Some(pos) = inner_fields
            .iter()
            .position(|f| f.name() == &self.variable_name)
        {
            inner_columns[pos] = values.clone();
            inner_fields[pos] = loop_field;
        } else {
            inner_columns.push(values.clone());
            inner_fields.push(loop_field);
        }

        // Optionally expose `<var>._vid` (UInt64) extracted from the loop
        // variable's encoded values, for inner expressions that need vertex
        // ids. Skipped if a column of that name already exists.
        if self.needs_vid_extraction {
            let vid_field_name = format!("{}._vid", self.variable_name);
            if !inner_fields.iter().any(|f| f.name() == &vid_field_name) {
                let vid_field = Arc::new(Field::new(&vid_field_name, DataType::UInt64, true));
                let loop_var_idx = inner_fields
                    .iter()
                    .position(|f| f.name() == &self.variable_name);
                if let Some(idx) = loop_var_idx {
                    let vid_array = super::common::extract_vids_from_cypher_value_column(
                        inner_columns[idx].as_ref(),
                    )?;
                    inner_fields.push(vid_field);
                    inner_columns.push(vid_array);
                }
            }
        }

        let inner_schema = Arc::new(Schema::new(inner_fields));

        let inner_batch = RecordBatch::try_new(inner_schema, inner_columns)?;

        // Apply the WHERE predicate. The surviving outer-row indices are kept
        // so per-row list lengths can be recomputed below. A null mask entry
        // drops the element (filter treats null as false).
        let (filtered_batch, filtered_indices) = if let Some(pred) = &self.predicate {
            let mask = pred
                .evaluate(&inner_batch)?
                .into_array(inner_batch.num_rows())?;
            let mask = cast(&mask, &DataType::Boolean)?;
            // unwrap is safe: the cast above guarantees a BooleanArray.
            let boolean_mask = mask.as_any().downcast_ref::<BooleanArray>().unwrap();

            let filtered_batch = filter_record_batch(&inner_batch, boolean_mask)?;

            let indices_array: Arc<dyn Array> = Arc::new(indices.clone());
            let filtered_indices = filter(&indices_array, boolean_mask)?;
            let filtered_indices = filtered_indices
                .as_any()
                .downcast_ref::<UInt32Array>()
                .unwrap()
                .clone();

            (filtered_batch, filtered_indices)
        } else {
            (inner_batch, indices.clone())
        };

        let mapped_val = self.map_expr.evaluate(&filtered_batch)?;
        let mapped_array = mapped_val.into_array(filtered_batch.num_rows())?;

        // With a predicate, list lengths changed: rebuild offsets by counting
        // surviving elements per outer row in a single pass. Correct because
        // `filtered_indices` is non-decreasing (built in row order above and
        // filtering preserves order).
        let new_offsets = if self.predicate.is_some() {
            let num_rows = batch.num_rows();
            let mut new_offsets = Vec::with_capacity(num_rows + 1);
            new_offsets.push(0);

            let indices_slice = filtered_indices.values();
            let mut pos = 0;
            let mut current_len = 0;

            for row_idx in 0..num_rows {
                let mut count = 0;
                while pos < indices_slice.len() && indices_slice[pos] as usize == row_idx {
                    count += 1;
                    pos += 1;
                }
                current_len += count;
                new_offsets.push(current_len);
            }
            OffsetBuffer::new(ScalarBuffer::from(new_offsets))
        } else {
            offsets.clone()
        };

        // Reassemble the list and keep the original row-level null mask:
        // null input lists stay null (their element ranges were empty).
        let new_field = Arc::new(Field::new("item", mapped_array.data_type().clone(), true));
        let new_list = datafusion::arrow::array::LargeListArray::new(
            new_field,
            new_offsets,
            mapped_array,
            nulls.cloned(),
        );

        // Re-encode the typed list back into the engine's value encoding.
        let cypher_value_array =
            crate::query::df_graph::common::typed_large_list_to_cv_array(&new_list)?;
        Ok(ColumnarValue::Array(cypher_value_array))
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        // Only `input_list` is exposed. `map_expr` and `predicate` are bound
        // to the synthesized inner schema, so generic tree rewrites over the
        // outer schema must not rewrite their column references.
        // NOTE(review): confirm no optimizer pass needs to traverse them.
        vec![&self.input_list]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Mirrors `children()`: exactly one replaceable child (the list input);
        // all other sub-expressions are carried over unchanged.
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Internal(
                "ListComprehension requires exactly 1 child (input_list)".to_string(),
            ));
        }

        Ok(Arc::new(Self {
            input_list: children[0].clone(),
            map_expr: self.map_expr.clone(),
            predicate: self.predicate.clone(),
            variable_name: self.variable_name.clone(),
            input_schema: self.input_schema.clone(),
            output_item_type: self.output_item_type.clone(),
            needs_vid_extraction: self.needs_vid_extraction,
        }))
    }

    fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Renders Cypher comprehension syntax rather than SQL.
        if let Some(pred) = &self.predicate {
            write!(
                f,
                "[{} IN {} WHERE {} | {}]",
                self.variable_name, self.input_list, pred, self.map_expr
            )
        } else {
            write!(
                f,
                "[{} IN {} | {}]",
                self.variable_name, self.input_list, self.map_expr
            )
        }
    }
}