Skip to main content

uni_query/query/df_graph/
mutation_foreach.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! DataFusion ExecutionPlan for Cypher FOREACH clauses.
5//!
6//! FOREACH executes side-effect mutations (CREATE, SET, REMOVE, DELETE, MERGE)
7//! for each item in a list expression, per input row. The output rows are
8//! passed through unchanged (FOREACH does not modify the caller-visible result).
9
10use super::common::compute_plan_properties;
11use super::mutation_common::{MutationContext, batches_to_rows, rows_to_batches};
12use arrow_array::RecordBatch;
13use arrow_schema::SchemaRef;
14use datafusion::common::Result as DFResult;
15use datafusion::execution::TaskContext;
16use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use datafusion::physical_plan::{
19    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
20};
21use futures::TryStreamExt;
22use std::any::Any;
23use std::fmt;
24use std::sync::Arc;
25use uni_common::Value;
26use uni_cypher::ast::Expr;
27
28use crate::query::planner::LogicalPlan;
29
30/// DataFusion `ExecutionPlan` for Cypher FOREACH clauses.
31///
32/// FOREACH is a side-effect-only operator: it iterates over a list expression
33/// per input row and executes body plans (mutations) for each item. The input
34/// rows are passed through unchanged to downstream operators.
35///
36/// Implements the "eager barrier" pattern: collects all input batches, then
37/// processes FOREACH side effects, then yields original batches.
38#[derive(Debug)]
39pub struct ForeachExec {
40    /// Child plan producing input rows.
41    input: Arc<dyn ExecutionPlan>,
42
43    /// Iteration variable name (bound per list item).
44    variable: String,
45
46    /// AST expression for the list to iterate over.
47    list_expr: Expr,
48
49    /// Body plans to execute per list item.
50    body: Vec<LogicalPlan>,
51
52    /// Shared mutation context with executor and writer.
53    mutation_ctx: Arc<MutationContext>,
54
55    /// Output schema (same as input — FOREACH is pass-through).
56    schema: SchemaRef,
57
58    /// Plan properties for DataFusion optimizer.
59    properties: Arc<PlanProperties>,
60
61    /// Metrics.
62    metrics: ExecutionPlanMetricsSet,
63}
64
65impl ForeachExec {
66    /// Create a new `ForeachExec`.
67    pub fn new(
68        input: Arc<dyn ExecutionPlan>,
69        variable: String,
70        list_expr: Expr,
71        body: Vec<LogicalPlan>,
72        mutation_ctx: Arc<MutationContext>,
73    ) -> Self {
74        let schema = input.schema();
75        let properties = compute_plan_properties(schema.clone());
76        Self {
77            input,
78            variable,
79            list_expr,
80            body,
81            mutation_ctx,
82            schema,
83            properties,
84            metrics: ExecutionPlanMetricsSet::new(),
85        }
86    }
87}
88
89impl DisplayAs for ForeachExec {
90    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
91        write!(f, "ForeachExec [var={}]", self.variable)
92    }
93}
94
95impl ExecutionPlan for ForeachExec {
96    fn name(&self) -> &str {
97        "ForeachExec"
98    }
99
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn schema(&self) -> SchemaRef {
105        self.schema.clone()
106    }
107
108    fn properties(&self) -> &Arc<PlanProperties> {
109        &self.properties
110    }
111
112    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
113        vec![&self.input]
114    }
115
116    fn with_new_children(
117        self: Arc<Self>,
118        children: Vec<Arc<dyn ExecutionPlan>>,
119    ) -> DFResult<Arc<dyn ExecutionPlan>> {
120        if children.len() != 1 {
121            return Err(datafusion::error::DataFusionError::Plan(
122                "ForeachExec requires exactly one child".to_string(),
123            ));
124        }
125        Ok(Arc::new(ForeachExec::new(
126            children[0].clone(),
127            self.variable.clone(),
128            self.list_expr.clone(),
129            self.body.clone(),
130            self.mutation_ctx.clone(),
131        )))
132    }
133
134    fn execute(
135        &self,
136        partition: usize,
137        context: Arc<TaskContext>,
138    ) -> DFResult<SendableRecordBatchStream> {
139        let input = self.input.clone();
140        let schema = self.schema.clone();
141        let variable = self.variable.clone();
142        let list_expr = self.list_expr.clone();
143        let body = self.body.clone();
144        let mutation_ctx = self.mutation_ctx.clone();
145        let baseline = BaselineMetrics::new(&self.metrics, partition);
146
147        let stream = futures::stream::once(execute_foreach_inner(
148            input,
149            schema.clone(),
150            variable,
151            list_expr,
152            body,
153            mutation_ctx,
154            partition,
155            context,
156            baseline,
157        ))
158        .try_flatten();
159
160        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
161    }
162
163    fn metrics(&self) -> Option<MetricsSet> {
164        Some(self.metrics.clone_inner())
165    }
166}
167
168/// Inner async function for FOREACH execution.
169#[expect(clippy::too_many_arguments)]
170async fn execute_foreach_inner(
171    input: Arc<dyn ExecutionPlan>,
172    schema: SchemaRef,
173    variable: String,
174    list_expr: Expr,
175    body: Vec<LogicalPlan>,
176    mutation_ctx: Arc<MutationContext>,
177    partition: usize,
178    task_ctx: Arc<TaskContext>,
179    baseline: BaselineMetrics,
180) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
181    // Time the whole eager-barrier body. Timer records on Drop.
182    let _timer = baseline.elapsed_compute().timer();
183    // 1. Collect all input batches (eager barrier)
184    let input_stream = input.execute(partition, task_ctx)?;
185    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;
186
187    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
188    tracing::debug!(
189        variable = variable.as_str(),
190        rows = input_row_count,
191        "Executing FOREACH"
192    );
193
194    let df_err = |msg: &str, e: &dyn std::fmt::Display| {
195        datafusion::error::DataFusionError::Execution(format!("FOREACH: {msg}: {e}"))
196    };
197
198    // 2. Convert to rows for expression evaluation
199    let rows = batches_to_rows(&input_batches)
200        .map_err(|e| df_err("failed to convert batches to rows", &e))?;
201
202    // 3. Execute FOREACH body per row, per list item
203    let exec = &mutation_ctx.executor;
204    let pm = &mutation_ctx.prop_manager;
205    let params = &mutation_ctx.params;
206    let ctx = mutation_ctx.query_ctx.as_ref();
207
208    let writer_lock = &mutation_ctx.writer;
209    let writer: &uni_store::Writer = writer_lock.as_ref();
210
211    for row in &rows {
212        // Evaluate the list expression
213        let list_val = exec
214            .evaluate_expr(&list_expr, row, pm, params, ctx)
215            .await
216            .map_err(|e| df_err("list evaluation failed", &e))?;
217
218        let items = match list_val {
219            Value::List(arr) => arr,
220            Value::Null => continue,
221            _ => {
222                return Err(datafusion::error::DataFusionError::Execution(
223                    "FOREACH requires a list expression".to_string(),
224                ));
225            }
226        };
227
228        // Execute body plans for each item
229        for item in items {
230            let mut scope = row.clone();
231            scope.insert(variable.clone(), item);
232
233            for plan in &body {
234                exec.execute_foreach_body_plan(
235                    plan.clone(),
236                    &mut scope,
237                    writer,
238                    pm,
239                    params,
240                    ctx,
241                    mutation_ctx.tx_l0_override.as_ref(),
242                )
243                .await
244                .map_err(|e| df_err("body execution failed", &e))?;
245            }
246        }
247    }
248
249    tracing::debug!(
250        variable = variable.as_str(),
251        rows = input_row_count,
252        "FOREACH complete"
253    );
254
255    // 4. Pass through original rows (FOREACH is side-effect only)
256    // Reconstruct from rows in case the schema needs normalization
257    let result_batches =
258        rows_to_batches(&rows, &schema).map_err(|e| df_err("failed to reconstruct batches", &e))?;
259    let output_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
260    baseline.record_output(output_rows);
261    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
262    Ok(futures::stream::iter(results))
263}