// uni_query/query/df_graph/apply.rs
// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! Apply (correlated subquery) execution plan for DataFusion.
//!
//! Implements `CALL { ... }` subqueries by executing the subquery once per
//! input row, injecting the input row's columns as parameters, and cross-joining
//! the results.
//!
//! # Semantics
//!
//! For each row from the input plan:
//! 1. Optionally filter via `input_filter`
//! 2. Inject the input row's columns as parameters
//! 3. Re-plan and execute the subquery with those parameters
//! 4. Cross-join: merge each subquery result row with the input row
//!
//! If input produces zero rows (after filtering), execute the subquery once
//! with the base parameters (standalone CALL support).
21use crate::query::df_graph::GraphExecutionContext;
22use crate::query::df_graph::common::{
23    collect_all_partitions, compute_plan_properties, execute_subplan, extract_row_params,
24};
25use crate::query::planner::LogicalPlan;
26use arrow_array::builder::{
27    BooleanBuilder, Float64Builder, Int32Builder, Int64Builder, StringBuilder, UInt64Builder,
28};
29use arrow_array::{ArrayRef, RecordBatch};
30use arrow_schema::{DataType, SchemaRef};
31use datafusion::common::Result as DFResult;
32use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
35use datafusion::prelude::SessionContext;
36use futures::Stream;
37use parking_lot::RwLock;
38use std::any::Any;
39use std::collections::HashMap;
40use std::collections::hash_map::DefaultHasher;
41use std::fmt;
42use std::hash::{Hash, Hasher};
43use std::pin::Pin;
44use std::sync::Arc;
45use std::task::{Context, Poll};
46use uni_common::Value;
47use uni_common::core::schema::Schema as UniSchema;
48use uni_cypher::ast::{Expr, UnaryOp};
49use uni_store::storage::manager::StorageManager;
50
/// Apply (correlated subquery) execution plan.
///
/// The input is pre-planned as a physical plan (executed directly).
/// The subquery is stored as a **logical** plan and re-planned per row at runtime
/// with correlated parameters injected.
/// Handles both `SubqueryCall` (no input_filter) and `Apply` (with input_filter).
pub struct GraphApplyExec {
    /// Physical plan for the driving input (e.g., MATCH scan).
    /// Pre-planned at construction time to preserve property context.
    input_exec: Arc<dyn ExecutionPlan>,

    /// Logical plan for the correlated subquery (re-planned per row).
    subquery_plan: LogicalPlan,

    /// Optional pre-filter applied to input rows before subquery execution.
    /// `None` for plain `CALL { ... }` subqueries; present for `Apply` nodes.
    input_filter: Option<Expr>,

    /// Graph execution context shared with sub-planners.
    graph_ctx: Arc<GraphExecutionContext>,

    /// DataFusion session context (used to obtain a `TaskContext` at run time).
    session_ctx: Arc<RwLock<SessionContext>>,

    /// Storage manager for creating sub-planners.
    storage: Arc<StorageManager>,

    /// Schema for label/edge type lookups.
    schema_info: Arc<UniSchema>,

    /// Query parameters (base parameters; per-row values are merged in during execution).
    params: HashMap<String, Value>,

    /// Output schema (merged: input columns + subquery columns).
    output_schema: SchemaRef,

    /// Cached plan properties (computed once from the output schema in `new`).
    properties: PlanProperties,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
92
93impl fmt::Debug for GraphApplyExec {
94    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        f.debug_struct("GraphApplyExec")
96            .field("has_input_filter", &self.input_filter.is_some())
97            .finish()
98    }
99}
100
101impl GraphApplyExec {
102    /// Create a new Apply execution plan.
103    #[expect(clippy::too_many_arguments)]
104    pub fn new(
105        input_exec: Arc<dyn ExecutionPlan>,
106        subquery_plan: LogicalPlan,
107        input_filter: Option<Expr>,
108        graph_ctx: Arc<GraphExecutionContext>,
109        session_ctx: Arc<RwLock<SessionContext>>,
110        storage: Arc<StorageManager>,
111        schema_info: Arc<UniSchema>,
112        params: HashMap<String, Value>,
113        output_schema: SchemaRef,
114    ) -> Self {
115        let properties = compute_plan_properties(output_schema.clone());
116
117        Self {
118            input_exec,
119            subquery_plan,
120            input_filter,
121            graph_ctx,
122            session_ctx,
123            storage,
124            schema_info,
125            params,
126            output_schema,
127            properties,
128            metrics: ExecutionPlanMetricsSet::new(),
129        }
130    }
131}
132
133impl DisplayAs for GraphApplyExec {
134    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        write!(
136            f,
137            "GraphApplyExec: filter={}",
138            if self.input_filter.is_some() {
139                "yes"
140            } else {
141                "none"
142            }
143        )
144    }
145}
146
impl ExecutionPlan for GraphApplyExec {
    fn name(&self) -> &str {
        "GraphApplyExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // No physical children — sub-plans are re-planned at execution time.
        // NOTE(review): `input_exec` is deliberately not exposed here, so
        // DataFusion optimizer rules will not visit it — confirm intentional.
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Consistent with `children()` returning empty: any attempt to attach
        // children is a planner error.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "GraphApplyExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Clone everything the async body needs: the future must own its data
        // because it outlives this call inside the returned stream.
        let input_exec = self.input_exec.clone();
        let subquery_plan = self.subquery_plan.clone();
        let input_filter = self.input_filter.clone();
        let graph_ctx = self.graph_ctx.clone();
        let session_ctx = self.session_ctx.clone();
        let storage = self.storage.clone();
        let schema_info = self.schema_info.clone();
        let params = self.params.clone();
        let output_schema = self.output_schema.clone();

        // The whole apply runs as one future producing a single RecordBatch;
        // ApplyStream adapts that future to the pull-based stream interface.
        let fut = async move {
            run_apply(
                input_exec,
                &subquery_plan,
                input_filter.as_ref(),
                &graph_ctx,
                &session_ctx,
                &storage,
                &schema_info,
                &params,
                &output_schema,
            )
            .await
        };

        Ok(Box::pin(ApplyStream {
            state: ApplyStreamState::Running(Box::pin(fut)),
            schema: self.output_schema.clone(),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
224
225// ---------------------------------------------------------------------------
226// Core apply logic
227// ---------------------------------------------------------------------------
228
229/// Convert record batches into row-oriented `HashMap<String, Value>` representation.
230fn batches_to_row_maps(batches: &[RecordBatch]) -> Vec<HashMap<String, Value>> {
231    batches
232        .iter()
233        .flat_map(|batch| {
234            (0..batch.num_rows()).map(move |row_idx| extract_row_params(batch, row_idx))
235        })
236        .collect()
237}
238
239/// Evaluate a Cypher filter expression against a row.
240///
241/// Supports simple binary comparisons and boolean operations needed for
242/// input_filter pushdown (e.g., `p.age > 30`, `p.status = 'active'`).
243fn evaluate_filter(filter: &Expr, row: &HashMap<String, Value>) -> bool {
244    match filter {
245        Expr::BinaryOp { left, op, right } => {
246            use uni_cypher::ast::BinaryOp;
247            match op {
248                BinaryOp::And => evaluate_filter(left, row) && evaluate_filter(right, row),
249                BinaryOp::Or => evaluate_filter(left, row) || evaluate_filter(right, row),
250                _ => {
251                    let left_val = resolve_expr_value(left, row);
252                    let right_val = resolve_expr_value(right, row);
253                    evaluate_comparison(op, &left_val, &right_val)
254                }
255            }
256        }
257        Expr::UnaryOp {
258            op: UnaryOp::Not,
259            expr,
260        } => !evaluate_filter(expr, row),
261        _ => {
262            // Treat any other expression as a truth test on its resolved value
263            let val = resolve_expr_value(filter, row);
264            val.as_bool().unwrap_or(false)
265        }
266    }
267}
268
269/// Resolve a simple expression to a Value using the row context.
270fn resolve_expr_value(expr: &Expr, row: &HashMap<String, Value>) -> Value {
271    match expr {
272        Expr::Literal(lit) => lit.to_value(),
273        Expr::Variable(name) => row.get(name).cloned().unwrap_or(Value::Null),
274        Expr::Property(base_expr, key) => {
275            if let Expr::Variable(var) = base_expr.as_ref() {
276                // Look up "var.key" in the row map
277                let col_name = format!("{}.{}", var, key);
278                row.get(&col_name).cloned().unwrap_or(Value::Null)
279            } else {
280                Value::Null
281            }
282        }
283        _ => Value::Null,
284    }
285}
286
287/// Compare two Values for ordering.
288fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
289    match (a, b) {
290        (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
291        (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
292        (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
293        (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
294        (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
295        _ => None,
296    }
297}
298
299/// Evaluate a binary comparison operator on two Values.
300///
301/// Handles equality (`Eq`, `NotEq`) directly and delegates ordering
302/// comparisons (`Lt`, `LtEq`, `Gt`, `GtEq`) to [`compare_values`].
303fn evaluate_comparison(op: &uni_cypher::ast::BinaryOp, left: &Value, right: &Value) -> bool {
304    use std::cmp::Ordering;
305    use uni_cypher::ast::BinaryOp;
306
307    match op {
308        BinaryOp::Eq => left == right,
309        BinaryOp::NotEq => left != right,
310        BinaryOp::Lt => compare_values(left, right) == Some(Ordering::Less),
311        BinaryOp::LtEq => matches!(
312            compare_values(left, right),
313            Some(Ordering::Less | Ordering::Equal)
314        ),
315        BinaryOp::Gt => compare_values(left, right) == Some(Ordering::Greater),
316        BinaryOp::GtEq => matches!(
317            compare_values(left, right),
318            Some(Ordering::Greater | Ordering::Equal)
319        ),
320        _ => false,
321    }
322}
323
324/// Build a typed column from row maps using a builder and value extractor.
325///
326/// For each row, looks up `col_name`, applies `extract` to get an `Option<T>`,
327/// and appends the value or null to the builder.
328fn build_column<B, T>(
329    rows: &[HashMap<String, Value>],
330    col_name: &str,
331    mut builder: B,
332    extract: impl Fn(&Value) -> Option<T>,
333) -> ArrayRef
334where
335    B: arrow_array::builder::ArrayBuilder,
336    B: PrimitiveAppend<T>,
337{
338    for row in rows {
339        match row.get(col_name).and_then(&extract) {
340            Some(v) => builder.append_typed_value(v),
341            None => builder.append_typed_null(),
342        }
343    }
344    Arc::new(builder.finish_to_array())
345}
346
/// Trait to abstract over typed append for primitive Arrow builders.
///
/// This avoids repeating the same get-value/convert/append-or-null pattern
/// for each numeric/boolean type in `rows_to_batch`.
trait PrimitiveAppend<T> {
    /// Append a concrete (non-null) value to the builder.
    fn append_typed_value(&mut self, val: T);
    /// Append a null slot to the builder.
    fn append_typed_null(&mut self);
    /// Consume the builder and produce the finished array.
    fn finish_to_array(self) -> ArrayRef;
}
356
357macro_rules! impl_primitive_append {
358    ($builder:ty, $native:ty, $array:ty) => {
359        impl PrimitiveAppend<$native> for $builder {
360            fn append_typed_value(&mut self, val: $native) {
361                self.append_value(val);
362            }
363            fn append_typed_null(&mut self) {
364                self.append_null();
365            }
366            fn finish_to_array(mut self) -> ArrayRef {
367                Arc::new(self.finish()) as ArrayRef
368            }
369        }
370    };
371}
372
373impl_primitive_append!(UInt64Builder, u64, arrow_array::UInt64Array);
374impl_primitive_append!(Int64Builder, i64, arrow_array::Int64Array);
375impl_primitive_append!(Int32Builder, i32, arrow_array::Int32Array);
376impl_primitive_append!(Float64Builder, f64, arrow_array::Float64Array);
377impl_primitive_append!(BooleanBuilder, bool, arrow_array::BooleanArray);
378
/// Build a RecordBatch from merged row maps using the output schema.
///
/// Each output column is materialized by scanning every row map for the
/// field's name and converting the `Value` to the field's Arrow data type.
/// Missing keys and unconvertible values become nulls; unknown types fall
/// back to a Utf8 column via `Display` formatting.
fn rows_to_batch(rows: &[HashMap<String, Value>], schema: &SchemaRef) -> DFResult<RecordBatch> {
    if rows.is_empty() {
        return Ok(RecordBatch::new_empty(schema.clone()));
    }

    let num_rows = rows.len();
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(schema.fields().len());

    for field in schema.fields() {
        let col_name = field.name();
        let col = match field.data_type() {
            // Accepts Int values in a UInt64 column as well.
            // NOTE(review): negative i64 wraps via `as u64` — confirm inputs
            // (e.g. vertex ids) are always non-negative.
            DataType::UInt64 => build_column(
                rows,
                col_name,
                UInt64Builder::with_capacity(num_rows),
                |v| v.as_u64().or_else(|| v.as_i64().map(|i| i as u64)),
            ),
            DataType::Int64 => build_column(
                rows,
                col_name,
                Int64Builder::with_capacity(num_rows),
                Value::as_i64,
            ),
            // NOTE(review): i64 → i32 via `as` truncates out-of-range values.
            DataType::Int32 => {
                build_column(rows, col_name, Int32Builder::with_capacity(num_rows), |v| {
                    v.as_i64().map(|i| i as i32)
                })
            }
            DataType::Float64 => build_column(
                rows,
                col_name,
                Float64Builder::with_capacity(num_rows),
                Value::as_f64,
            ),
            DataType::Boolean => build_column(
                rows,
                col_name,
                BooleanBuilder::with_capacity(num_rows),
                Value::as_bool,
            ),
            // LargeBinary carries codec-encoded Cypher values; 64 bytes/row is
            // a capacity hint only.
            DataType::LargeBinary => {
                let mut builder = arrow_array::builder::LargeBinaryBuilder::with_capacity(
                    num_rows,
                    num_rows * 64,
                );
                for row in rows {
                    match row.get(col_name) {
                        Some(val) if !val.is_null() => {
                            let cv_bytes = uni_common::cypher_value_codec::encode(val);
                            builder.append_value(&cv_bytes);
                        }
                        _ => builder.append_null(),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
            // Only List<Utf8> is supported; non-string list items are
            // stringified via Display.
            DataType::List(inner_field) if inner_field.data_type() == &DataType::Utf8 => {
                let mut builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
                for row in rows {
                    match row.get(col_name) {
                        Some(Value::List(items)) => {
                            for item in items {
                                match item {
                                    Value::String(s) => builder.values().append_value(s),
                                    Value::Null => builder.values().append_null(),
                                    other => builder.values().append_value(format!("{}", other)),
                                }
                            }
                            builder.append(true);
                        }
                        _ => builder.append_null(),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
            DataType::Null => Arc::new(arrow_array::NullArray::new(num_rows)) as ArrayRef,
            // Default: Utf8 for everything else (32 bytes/row capacity hint).
            _ => {
                let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
                for row in rows {
                    match row.get(col_name) {
                        Some(Value::Null) | None => builder.append_null(),
                        Some(Value::String(s)) => builder.append_value(s),
                        Some(other) => builder.append_value(format!("{}", other)),
                    }
                }
                Arc::new(builder.finish()) as ArrayRef
            }
        };
        columns.push(col);
    }

    RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
}
475
476/// Slice a single row from a RecordBatch, preserving Arrow types.
477fn slice_row(batch: &RecordBatch, row_idx: usize) -> Vec<ArrayRef> {
478    batch
479        .columns()
480        .iter()
481        .map(|col| col.slice(row_idx, 1))
482        .collect()
483}
484
485/// Check if a logical plan is or contains a ProcedureCall node.
486/// This helps distinguish procedure calls (CALL...YIELD) from regular subqueries (CALL { ... }).
487fn is_procedure_call(plan: &LogicalPlan) -> bool {
488    match plan {
489        LogicalPlan::ProcedureCall { .. } => true,
490        LogicalPlan::Project { input, .. }
491        | LogicalPlan::Filter { input, .. }
492        | LogicalPlan::Sort { input, .. }
493        | LogicalPlan::Limit { input, .. }
494        | LogicalPlan::Distinct { input } => is_procedure_call(input),
495        _ => false,
496    }
497}
498
499/// Compute a hash for row parameters to enable deduplication.
500///
501/// Sorts entries by key for deterministic hashing regardless of iteration order.
502fn hash_row_params(params: &HashMap<String, Value>) -> u64 {
503    let mut hasher = DefaultHasher::new();
504    let mut entries: Vec<_> = params.iter().collect();
505    entries.sort_unstable_by_key(|(k, _)| *k);
506    for (key, val) in entries {
507        key.hash(&mut hasher);
508        format!("{:?}", val).hash(&mut hasher);
509    }
510    hasher.finish()
511}
512
513/// Check if batching is eligible for this apply operation.
514/// Returns true if:
515/// - There are 2+ filtered entries (single row → existing path)
516/// - At least one `._vid` correlation key exists
517fn is_batch_eligible(filtered_entries: &[(&RecordBatch, usize, HashMap<String, Value>)]) -> bool {
518    if filtered_entries.len() < 2 {
519        return false;
520    }
521
522    // Check if at least one correlation key (._vid) exists
523    filtered_entries
524        .iter()
525        .any(|(_, _, row_params)| row_params.keys().any(|k| k.ends_with("._vid")))
526}
527
528/// Run the apply operation: execute input, filter, correlate subquery, merge results.
529///
530/// Uses Arrow-native row slicing for input columns to preserve complex types
531/// (Struct, List, etc.), and only converts to Value for parameter injection.
532#[expect(clippy::too_many_arguments)]
533async fn run_apply(
534    input_exec: Arc<dyn ExecutionPlan>,
535    subquery_plan: &LogicalPlan,
536    input_filter: Option<&Expr>,
537    graph_ctx: &Arc<GraphExecutionContext>,
538    session_ctx: &Arc<RwLock<SessionContext>>,
539    storage: &Arc<StorageManager>,
540    schema_info: &Arc<UniSchema>,
541    params: &HashMap<String, Value>,
542    output_schema: &SchemaRef,
543) -> DFResult<RecordBatch> {
544    let apply_start = std::time::Instant::now();
545    let is_proc_call = is_procedure_call(subquery_plan);
546    tracing::debug!("run_apply: is_procedure_call={}", is_proc_call);
547
548    // 1. Execute pre-planned input physical plan directly
549    let task_ctx = session_ctx.read().task_ctx();
550    let input_batches = collect_all_partitions(&input_exec, task_ctx).await?;
551
552    // 2. Collect (batch_ref, row_idx) for rows that pass the input filter,
553    //    along with their Value-based params for subquery injection.
554    let mut filtered_entries: Vec<(&RecordBatch, usize, HashMap<String, Value>)> = Vec::new();
555    for batch in &input_batches {
556        for row_idx in 0..batch.num_rows() {
557            let row_params = extract_row_params(batch, row_idx);
558            if let Some(filter) = input_filter
559                && !evaluate_filter(filter, &row_params)
560            {
561                continue;
562            }
563            filtered_entries.push((batch, row_idx, row_params));
564        }
565    }
566
567    tracing::debug!(
568        "run_apply: filtered_entries count = {}",
569        filtered_entries.len()
570    );
571
572    // 3. Handle empty input: execute subquery once with base params
573    if filtered_entries.is_empty() {
574        let sub_batches = execute_subplan(
575            subquery_plan,
576            params,
577            &HashMap::new(), // No outer values for empty input case
578            graph_ctx,
579            session_ctx,
580            storage,
581            schema_info,
582        )
583        .await?;
584        let sub_rows = batches_to_row_maps(&sub_batches);
585        return rows_to_batch(&sub_rows, output_schema);
586    }
587
588    // 4. Check if we can batch the subplan execution
589    // IMPORTANT: Only batch when NOT a procedure call AND has input_filter.
590    // - Procedure calls use outer_values (not params), incompatible with batching
591    // - No input_filter indicates CALL subquery (e.g., MATCH (p) CALL { MATCH (p) })
592    //   which requires per-row correlation, not batching
593    // - Target pattern: procedure call → Apply with filter → MATCH traversal
594    let has_filter = input_filter.is_some();
595
596    if is_batch_eligible(&filtered_entries) && !is_proc_call && has_filter {
597        tracing::debug!("run_apply: batching eligible, attempting batch execution");
598
599        // Collect unique VID values and build batched params
600        let mut vid_values: HashMap<String, Vec<Value>> = HashMap::new();
601        for (_, _, row_params) in &filtered_entries {
602            for (key, value) in row_params {
603                if key.ends_with("._vid") {
604                    vid_values
605                        .entry(key.clone())
606                        .or_default()
607                        .push(value.clone());
608                }
609            }
610        }
611
612        // Build batched params: VID keys become Value::List
613        let mut batched_params = params.clone();
614        for (key, values) in &vid_values {
615            batched_params.insert(key.clone(), Value::List(values.clone()));
616        }
617
618        // Add carry-through parameters from first row (for literals in projections)
619        // These won't affect the WHERE filter but ensure planning succeeds
620        if let Some((_, _, first_row_params)) = filtered_entries.first() {
621            for (key, value) in first_row_params {
622                if !key.ends_with("._vid") {
623                    batched_params
624                        .entry(key.clone())
625                        .or_insert_with(|| value.clone());
626                }
627            }
628        }
629
630        // Execute subquery ONCE with batched VID params
631        let subplan_start = std::time::Instant::now();
632        let sub_batches = execute_subplan(
633            subquery_plan,
634            &batched_params,
635            &HashMap::new(),
636            graph_ctx,
637            session_ctx,
638            storage,
639            schema_info,
640        )
641        .await?;
642        let subplan_elapsed = subplan_start.elapsed();
643        tracing::debug!(
644            "run_apply: batch execute_subplan took {:?}",
645            subplan_elapsed
646        );
647
648        // Build hash index: VID → Vec<subquery result rows>
649        let sub_rows = batches_to_row_maps(&sub_batches);
650        let mut sub_index: HashMap<i64, Vec<&HashMap<String, Value>>> = HashMap::new();
651
652        // Find the VID key (should be the same for all rows)
653        let vid_key = vid_values.keys().next().expect("at least one VID key");
654
655        for sub_row in &sub_rows {
656            if let Some(Value::Int(vid)) = sub_row.get(vid_key) {
657                sub_index.entry(*vid).or_default().push(sub_row);
658            }
659        }
660
661        // Hash-join: for each input row, look up by VID, emit input+subquery columns
662        let input_schema = input_batches[0].schema();
663        let num_input_cols = input_schema.fields().len();
664        let num_output_cols = output_schema.fields().len();
665        let mut column_arrays: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_output_cols];
666
667        for (batch, row_idx, row_params) in &filtered_entries {
668            // Extract VID from row params
669            let input_vid = if let Some(Value::Int(vid)) = row_params.get(vid_key) {
670                *vid
671            } else {
672                continue; // Skip if VID is not present
673            };
674
675            // Look up matching subquery rows by VID
676            if let Some(matching_sub_rows) = sub_index.get(&input_vid) {
677                let input_row_arrays = slice_row(batch, *row_idx);
678
679                for sub_row in matching_sub_rows {
680                    append_cross_join_row(
681                        &mut column_arrays,
682                        &input_row_arrays,
683                        sub_row,
684                        output_schema,
685                        num_input_cols,
686                    )?;
687                }
688            }
689            // else: inner join — skip input row (no subquery matches)
690        }
691
692        let result = concat_column_arrays(&column_arrays, output_schema);
693
694        let apply_elapsed = apply_start.elapsed();
695        tracing::debug!(
696            "run_apply: completed (batched) in {:?}, 1 subplan execution",
697            apply_elapsed
698        );
699
700        return result;
701    }
702
703    // 5. Fallback: For each input row, execute subquery and collect output column arrays.
704    //    Used when batching is not eligible (single row, no VID keys, or procedure call).
705    //    Each output row is: input columns (sliced) + subquery columns (sliced).
706    let input_schema = input_batches[0].schema();
707    let num_input_cols = input_schema.fields().len();
708    let num_output_cols = output_schema.fields().len();
709    // Accumulate per-column arrays for all output rows
710    let mut column_arrays: Vec<Vec<ArrayRef>> = vec![Vec::new(); num_output_cols];
711
712    let mut total_subplan_time = std::time::Duration::ZERO;
713    let mut subplan_executions = 0;
714
715    // Cache to deduplicate subplan executions for identical row parameters
716    let mut subplan_cache: HashMap<u64, Vec<HashMap<String, Value>>> = HashMap::new();
717    let mut cache_hits = 0;
718
719    for (batch, row_idx, row_params) in &filtered_entries {
720        // For procedure calls (CALL...YIELD), pass row_params as outer_values to avoid
721        // shadowing user parameters. For regular subqueries (CALL { ... }), merge them
722        // into parameters for backward compatibility with correlated variables.
723        let (sub_params, sub_outer_values) = if is_procedure_call(subquery_plan) {
724            // Procedure call: keep params separate from outer values
725            (params.clone(), row_params.clone())
726        } else {
727            // Regular subquery: merge outer values into params (old behavior)
728            let mut merged = params.clone();
729            merged.extend(row_params.clone());
730            (merged, HashMap::new())
731        };
732
733        // Check cache for identical row params
734        let params_hash = hash_row_params(row_params);
735        let sub_rows = if let Some(cached_rows) = subplan_cache.get(&params_hash) {
736            // Cache hit: reuse previous results
737            cache_hits += 1;
738            tracing::debug!(
739                "run_apply: cache hit for params hash {}, skipping execute_subplan",
740                params_hash
741            );
742            cached_rows.clone()
743        } else {
744            // Cache miss: execute subplan
745            let subplan_start = std::time::Instant::now();
746            let sub_batches = execute_subplan(
747                subquery_plan,
748                &sub_params,
749                &sub_outer_values,
750                graph_ctx,
751                session_ctx,
752                storage,
753                schema_info,
754            )
755            .await?;
756            let subplan_elapsed = subplan_start.elapsed();
757            total_subplan_time += subplan_elapsed;
758            subplan_executions += 1;
759
760            tracing::debug!(
761                "run_apply: execute_subplan #{} took {:?}",
762                subplan_executions,
763                subplan_elapsed
764            );
765
766            let rows = batches_to_row_maps(&sub_batches);
767            subplan_cache.insert(params_hash, rows.clone());
768            rows
769        };
770
771        let input_row_arrays = slice_row(batch, *row_idx);
772
773        if sub_rows.is_empty() {
774            // No subquery results — skip this input row (inner join semantics)
775            continue;
776        }
777
778        for sub_row in &sub_rows {
779            append_cross_join_row(
780                &mut column_arrays,
781                &input_row_arrays,
782                sub_row,
783                output_schema,
784                num_input_cols,
785            )?;
786        }
787    }
788
789    // 5. Concatenate all accumulated arrays per column
790    let result = concat_column_arrays(&column_arrays, output_schema);
791
792    let apply_elapsed = apply_start.elapsed();
793    tracing::debug!(
794        "run_apply: completed in {:?}, {} subplan executions, {} cache hits, {:?} total subplan time",
795        apply_elapsed,
796        subplan_executions,
797        cache_hits,
798        total_subplan_time
799    );
800
801    result
802}
803
804/// Build a single-row Arrow array from a builder and optional value.
805fn single_row_array<B, T>(mut builder: B, val: Option<T>) -> ArrayRef
806where
807    B: PrimitiveAppend<T>,
808{
809    match val {
810        Some(v) => builder.append_typed_value(v),
811        None => builder.append_typed_null(),
812    }
813    builder.finish_to_array()
814}
815
/// Convert a single Value to a single-row Arrow array of the given type.
///
/// Mirrors the conversions in `rows_to_batch` for one value: unconvertible
/// values become nulls, and unknown data types fall back to Utf8 via
/// `Display` formatting.
fn value_to_single_row_array(val: &Value, data_type: &DataType) -> DFResult<ArrayRef> {
    Ok(match data_type {
        // Accepts Int values for UInt64 columns.
        // NOTE(review): negative i64 wraps via `as u64` — confirm non-negative inputs.
        DataType::UInt64 => single_row_array(
            UInt64Builder::with_capacity(1),
            val.as_u64().or_else(|| val.as_i64().map(|v| v as u64)),
        ),
        DataType::Int64 => single_row_array(Int64Builder::with_capacity(1), val.as_i64()),
        // NOTE(review): i64 → i32 via `as` truncates out-of-range values.
        DataType::Int32 => single_row_array(
            Int32Builder::with_capacity(1),
            val.as_i64().map(|v| v as i32),
        ),
        DataType::Float64 => single_row_array(Float64Builder::with_capacity(1), val.as_f64()),
        DataType::Boolean => single_row_array(BooleanBuilder::with_capacity(1), val.as_bool()),
        DataType::Null => Arc::new(arrow_array::NullArray::new(1)) as ArrayRef,
        // Default: Utf8 for everything else.
        _ => {
            let mut b = StringBuilder::with_capacity(1, 64);
            match val {
                Value::Null => b.append_null(),
                Value::String(s) => b.append_value(s),
                other => b.append_value(format!("{}", other)),
            }
            Arc::new(b.finish()) as ArrayRef
        }
    })
}
842
843/// Append one cross-joined row (input + subquery) to the per-column accumulator.
844///
845/// For input columns, uses the Arrow-native sliced arrays to preserve complex types.
846/// For subquery columns, converts `Value` to single-row Arrow arrays.
847fn append_cross_join_row(
848    column_arrays: &mut [Vec<ArrayRef>],
849    input_row_arrays: &[ArrayRef],
850    sub_row: &HashMap<String, Value>,
851    output_schema: &SchemaRef,
852    num_input_cols: usize,
853) -> DFResult<()> {
854    // Add input columns (Arrow-native, preserves types)
855    for (col_idx, arr) in input_row_arrays.iter().enumerate() {
856        column_arrays[col_idx].push(arr.clone());
857    }
858
859    // Add subquery columns using Value -> Arrow conversion
860    let num_output_cols = output_schema.fields().len();
861    for (col_arr, field) in column_arrays[num_input_cols..num_output_cols]
862        .iter_mut()
863        .zip(output_schema.fields()[num_input_cols..num_output_cols].iter())
864    {
865        let col_name = field.name();
866        let val = sub_row.get(col_name).cloned().unwrap_or(Value::Null);
867        let arr = value_to_single_row_array(&val, field.data_type())?;
868        col_arr.push(arr);
869    }
870    Ok(())
871}
872
873/// Concatenate per-column array accumulators into a single `RecordBatch`.
874///
875/// Returns an empty batch if no rows were accumulated.
876fn concat_column_arrays(
877    column_arrays: &[Vec<ArrayRef>],
878    output_schema: &SchemaRef,
879) -> DFResult<RecordBatch> {
880    if column_arrays[0].is_empty() {
881        return Ok(RecordBatch::new_empty(output_schema.clone()));
882    }
883
884    let mut final_columns: Vec<ArrayRef> = Vec::with_capacity(column_arrays.len());
885    for arrays in column_arrays {
886        let refs: Vec<&dyn arrow_array::Array> = arrays.iter().map(|a| a.as_ref()).collect();
887        let concatenated = arrow::compute::concat(&refs)
888            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
889        final_columns.push(concatenated);
890    }
891
892    RecordBatch::try_new(output_schema.clone(), final_columns)
893        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
894}
895
896// ---------------------------------------------------------------------------
897// Stream implementation
898// ---------------------------------------------------------------------------
899
/// Stream state for the apply operation.
///
/// The apply result is produced by a single one-shot future; the stream is
/// a tiny two-state machine around it.
enum ApplyStreamState {
    /// The apply computation is running; holds the pinned, boxed future that
    /// resolves to the final cross-joined `RecordBatch` (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// Computation completed (successfully or with an error); subsequent
    /// polls yield end-of-stream.
    Done,
}
907
/// Stream that runs the apply operation and emits the result.
///
/// Emits exactly one `RecordBatch` (or one error), then terminates.
struct ApplyStream {
    // Current phase: the pending computation future, or Done.
    state: ApplyStreamState,
    // Output schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Baseline execution metrics; output row count is recorded on success.
    metrics: BaselineMetrics,
}
914
915impl Stream for ApplyStream {
916    type Item = DFResult<RecordBatch>;
917
918    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
919        match &mut self.state {
920            ApplyStreamState::Running(fut) => match fut.as_mut().poll(cx) {
921                Poll::Ready(Ok(batch)) => {
922                    self.metrics.record_output(batch.num_rows());
923                    self.state = ApplyStreamState::Done;
924                    Poll::Ready(Some(Ok(batch)))
925                }
926                Poll::Ready(Err(e)) => {
927                    self.state = ApplyStreamState::Done;
928                    Poll::Ready(Some(Err(e)))
929                }
930                Poll::Pending => Poll::Pending,
931            },
932            ApplyStreamState::Done => Poll::Ready(None),
933        }
934    }
935}
936
937impl RecordBatchStream for ApplyStream {
938    fn schema(&self) -> SchemaRef {
939        self.schema.clone()
940    }
941}