Skip to main content

uni_query/query/df_graph/
mutation_foreach.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! DataFusion ExecutionPlan for Cypher FOREACH clauses.
5//!
6//! FOREACH executes side-effect mutations (CREATE, SET, REMOVE, DELETE, MERGE)
7//! for each item in a list expression, per input row. The output rows are
8//! passed through unchanged (FOREACH does not modify the caller-visible result).
9
10use super::common::compute_plan_properties;
11use super::mutation_common::{MutationContext, batches_to_rows, rows_to_batches};
12use arrow_array::RecordBatch;
13use arrow_schema::SchemaRef;
14use datafusion::common::Result as DFResult;
15use datafusion::execution::TaskContext;
16use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use datafusion::physical_plan::{
19    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
20};
21use futures::TryStreamExt;
22use std::any::Any;
23use std::fmt;
24use std::sync::Arc;
25use uni_common::Value;
26use uni_cypher::ast::Expr;
27
28use crate::query::planner::LogicalPlan;
29
30/// DataFusion `ExecutionPlan` for Cypher FOREACH clauses.
31///
32/// FOREACH is a side-effect-only operator: it iterates over a list expression
33/// per input row and executes body plans (mutations) for each item. The input
34/// rows are passed through unchanged to downstream operators.
35///
36/// Implements the "eager barrier" pattern: collects all input batches, then
37/// processes FOREACH side effects, then yields original batches.
#[derive(Debug)]
pub struct ForeachExec {
    /// Child plan producing input rows.
    input: Arc<dyn ExecutionPlan>,

    /// Iteration variable name (bound per list item in the body scope).
    variable: String,

    /// AST expression for the list to iterate over (evaluated once per input row).
    list_expr: Expr,

    /// Body plans (mutations) to execute per list item, in order.
    body: Vec<LogicalPlan>,

    /// Shared mutation context with executor and writer.
    mutation_ctx: Arc<MutationContext>,

    /// Output schema (same as input — FOREACH is pass-through).
    schema: SchemaRef,

    /// Plan properties for DataFusion optimizer (computed from the schema).
    properties: PlanProperties,

    /// Execution metrics exposed via `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
64
65impl ForeachExec {
66    /// Create a new `ForeachExec`.
67    pub fn new(
68        input: Arc<dyn ExecutionPlan>,
69        variable: String,
70        list_expr: Expr,
71        body: Vec<LogicalPlan>,
72        mutation_ctx: Arc<MutationContext>,
73    ) -> Self {
74        let schema = input.schema();
75        let properties = compute_plan_properties(schema.clone());
76        Self {
77            input,
78            variable,
79            list_expr,
80            body,
81            mutation_ctx,
82            schema,
83            properties,
84            metrics: ExecutionPlanMetricsSet::new(),
85        }
86    }
87}
88
89impl DisplayAs for ForeachExec {
90    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
91        write!(f, "ForeachExec [var={}]", self.variable)
92    }
93}
94
95impl ExecutionPlan for ForeachExec {
96    fn name(&self) -> &str {
97        "ForeachExec"
98    }
99
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn schema(&self) -> SchemaRef {
105        self.schema.clone()
106    }
107
108    fn properties(&self) -> &PlanProperties {
109        &self.properties
110    }
111
112    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
113        vec![&self.input]
114    }
115
116    fn with_new_children(
117        self: Arc<Self>,
118        children: Vec<Arc<dyn ExecutionPlan>>,
119    ) -> DFResult<Arc<dyn ExecutionPlan>> {
120        if children.len() != 1 {
121            return Err(datafusion::error::DataFusionError::Plan(
122                "ForeachExec requires exactly one child".to_string(),
123            ));
124        }
125        Ok(Arc::new(ForeachExec::new(
126            children[0].clone(),
127            self.variable.clone(),
128            self.list_expr.clone(),
129            self.body.clone(),
130            self.mutation_ctx.clone(),
131        )))
132    }
133
134    fn execute(
135        &self,
136        partition: usize,
137        context: Arc<TaskContext>,
138    ) -> DFResult<SendableRecordBatchStream> {
139        let input = self.input.clone();
140        let schema = self.schema.clone();
141        let variable = self.variable.clone();
142        let list_expr = self.list_expr.clone();
143        let body = self.body.clone();
144        let mutation_ctx = self.mutation_ctx.clone();
145
146        let stream = futures::stream::once(execute_foreach_inner(
147            input,
148            schema.clone(),
149            variable,
150            list_expr,
151            body,
152            mutation_ctx,
153            partition,
154            context,
155        ))
156        .try_flatten();
157
158        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
159    }
160
161    fn metrics(&self) -> Option<MetricsSet> {
162        Some(self.metrics.clone_inner())
163    }
164}
165
/// Inner async function for FOREACH execution.
///
/// Eagerly drains the input partition (barrier), runs the FOREACH body for
/// every (row, list item) pair while holding the shared writer lock, then
/// yields the original rows unchanged as a finite stream of batches.
#[expect(clippy::too_many_arguments)]
async fn execute_foreach_inner(
    input: Arc<dyn ExecutionPlan>,
    schema: SchemaRef,
    variable: String,
    list_expr: Expr,
    body: Vec<LogicalPlan>,
    mutation_ctx: Arc<MutationContext>,
    partition: usize,
    task_ctx: Arc<TaskContext>,
) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
    // 1. Collect all input batches (eager barrier): every upstream row must be
    //    materialized before any side effects run.
    let input_stream = input.execute(partition, task_ctx)?;
    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;

    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
    tracing::debug!(
        variable = variable.as_str(),
        rows = input_row_count,
        "Executing FOREACH"
    );

    // Uniform error wrapper: prefixes every failure with "FOREACH: " so the
    // failing stage is identifiable in the surfaced DataFusion error.
    let df_err = |msg: &str, e: &dyn std::fmt::Display| {
        datafusion::error::DataFusionError::Execution(format!("FOREACH: {msg}: {e}"))
    };

    // 2. Convert to rows for expression evaluation
    let rows = batches_to_rows(&input_batches)
        .map_err(|e| df_err("failed to convert batches to rows", &e))?;

    // 3. Execute FOREACH body per row, per list item
    let exec = &mutation_ctx.executor;
    let pm = &mutation_ctx.prop_manager;
    let params = &mutation_ctx.params;
    let ctx = mutation_ctx.query_ctx.as_ref();

    // Acquire the writer once for the whole loop and hold it across the body
    // awaits, serializing all mutations from this FOREACH. NOTE(review): this
    // must be an async-aware lock (`.write().await`) — confirm it is, e.g., a
    // tokio RwLock, since a std lock must not be held across await points.
    let writer_lock = &mutation_ctx.writer;
    let mut writer = writer_lock.write().await;

    for row in &rows {
        // Evaluate the list expression against the current row's bindings.
        let list_val = exec
            .evaluate_expr(&list_expr, row, pm, params, ctx)
            .await
            .map_err(|e| df_err("list evaluation failed", &e))?;

        let items = match list_val {
            Value::List(arr) => arr,
            // A NULL list produces no iterations for this row (the row itself
            // is still passed through in the output).
            Value::Null => continue,
            _ => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "FOREACH requires a list expression".to_string(),
                ));
            }
        };

        // Execute body plans for each item. Each item gets a fresh scope
        // cloned from the row, so body-side bindings never leak into the
        // pass-through output or into subsequent iterations.
        for item in items {
            let mut scope = row.clone();
            scope.insert(variable.clone(), item);

            for plan in &body {
                exec.execute_foreach_body_plan(
                    plan.clone(),
                    &mut scope,
                    &mut writer,
                    pm,
                    params,
                    ctx,
                    mutation_ctx.tx_l0_override.as_ref(),
                )
                .await
                .map_err(|e| df_err("body execution failed", &e))?;
            }
        }
    }

    // Release the writer before building output so downstream work does not
    // extend the lock's critical section.
    drop(writer);

    tracing::debug!(
        variable = variable.as_str(),
        rows = input_row_count,
        "FOREACH complete"
    );

    // 4. Pass through original rows (FOREACH is side-effect only)
    // Reconstruct from rows in case the schema needs normalization
    let result_batches =
        rows_to_batches(&rows, &schema).map_err(|e| df_err("failed to reconstruct batches", &e))?;
    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
    Ok(futures::stream::iter(results))
}
258}