// uni_query/query/df_graph/mutation_foreach.rs
use super::common::compute_plan_properties;
11use super::mutation_common::{MutationContext, batches_to_rows, rows_to_batches};
12use arrow_array::RecordBatch;
13use arrow_schema::SchemaRef;
14use datafusion::common::Result as DFResult;
15use datafusion::execution::TaskContext;
16use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
18use datafusion::physical_plan::{
19 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
20};
21use futures::TryStreamExt;
22use std::any::Any;
23use std::fmt;
24use std::sync::Arc;
25use uni_common::Value;
26use uni_cypher::ast::Expr;
27
28use crate::query::planner::LogicalPlan;
29
/// Physical operator for a Cypher `FOREACH` clause: for every input row it
/// evaluates a list expression and executes the mutation body once per list
/// element, then re-emits the input rows unchanged (FOREACH is pass-through).
#[derive(Debug)]
pub struct ForeachExec {
    /// Child plan producing the rows FOREACH iterates over.
    input: Arc<dyn ExecutionPlan>,
    /// Variable name bound to each list element inside the body.
    variable: String,
    /// Expression expected to yield a list (or NULL, meaning "skip") per row.
    list_expr: Expr,
    /// Logical mutation plans executed once per list element.
    body: Vec<LogicalPlan>,
    /// Shared mutation state: executor, property manager, params, writer lock.
    mutation_ctx: Arc<MutationContext>,
    /// Output schema — identical to the input schema (see `new`).
    schema: SchemaRef,
    /// Cached plan properties derived from `schema`.
    properties: PlanProperties,
    /// Metrics set exposed via `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
64
65impl ForeachExec {
66 pub fn new(
68 input: Arc<dyn ExecutionPlan>,
69 variable: String,
70 list_expr: Expr,
71 body: Vec<LogicalPlan>,
72 mutation_ctx: Arc<MutationContext>,
73 ) -> Self {
74 let schema = input.schema();
75 let properties = compute_plan_properties(schema.clone());
76 Self {
77 input,
78 variable,
79 list_expr,
80 body,
81 mutation_ctx,
82 schema,
83 properties,
84 metrics: ExecutionPlanMetricsSet::new(),
85 }
86 }
87}
88
89impl DisplayAs for ForeachExec {
90 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
91 write!(f, "ForeachExec [var={}]", self.variable)
92 }
93}
94
95impl ExecutionPlan for ForeachExec {
96 fn name(&self) -> &str {
97 "ForeachExec"
98 }
99
100 fn as_any(&self) -> &dyn Any {
101 self
102 }
103
104 fn schema(&self) -> SchemaRef {
105 self.schema.clone()
106 }
107
108 fn properties(&self) -> &PlanProperties {
109 &self.properties
110 }
111
112 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
113 vec![&self.input]
114 }
115
116 fn with_new_children(
117 self: Arc<Self>,
118 children: Vec<Arc<dyn ExecutionPlan>>,
119 ) -> DFResult<Arc<dyn ExecutionPlan>> {
120 if children.len() != 1 {
121 return Err(datafusion::error::DataFusionError::Plan(
122 "ForeachExec requires exactly one child".to_string(),
123 ));
124 }
125 Ok(Arc::new(ForeachExec::new(
126 children[0].clone(),
127 self.variable.clone(),
128 self.list_expr.clone(),
129 self.body.clone(),
130 self.mutation_ctx.clone(),
131 )))
132 }
133
134 fn execute(
135 &self,
136 partition: usize,
137 context: Arc<TaskContext>,
138 ) -> DFResult<SendableRecordBatchStream> {
139 let input = self.input.clone();
140 let schema = self.schema.clone();
141 let variable = self.variable.clone();
142 let list_expr = self.list_expr.clone();
143 let body = self.body.clone();
144 let mutation_ctx = self.mutation_ctx.clone();
145
146 let stream = futures::stream::once(execute_foreach_inner(
147 input,
148 schema.clone(),
149 variable,
150 list_expr,
151 body,
152 mutation_ctx,
153 partition,
154 context,
155 ))
156 .try_flatten();
157
158 Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
159 }
160
161 fn metrics(&self) -> Option<MetricsSet> {
162 Some(self.metrics.clone_inner())
163 }
164}
165
/// Runs the FOREACH loop for one partition and returns the input rows as a
/// finished stream of batches.
///
/// Steps: (1) fully drain the child plan, (2) convert batches to row form,
/// (3) for each row evaluate `list_expr` and execute every body plan once per
/// list element with `variable` bound to that element, (4) rebuild and return
/// the ORIGINAL rows — bindings made inside the per-element scope are
/// discarded, so FOREACH is pass-through for the surrounding pipeline.
#[expect(clippy::too_many_arguments)]
async fn execute_foreach_inner(
    input: Arc<dyn ExecutionPlan>,
    schema: SchemaRef,
    variable: String,
    list_expr: Expr,
    body: Vec<LogicalPlan>,
    mutation_ctx: Arc<MutationContext>,
    partition: usize,
    task_ctx: Arc<TaskContext>,
) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
    // Materialize the entire child output up front; the mutation executor
    // below works on row-oriented data, not a live Arrow stream.
    let input_stream = input.execute(partition, task_ctx)?;
    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;

    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
    tracing::debug!(
        variable = variable.as_str(),
        rows = input_row_count,
        "Executing FOREACH"
    );

    // Uniform wrapper: every failure surfaces as a DataFusion execution
    // error prefixed with "FOREACH:".
    let df_err = |msg: &str, e: &dyn std::fmt::Display| {
        datafusion::error::DataFusionError::Execution(format!("FOREACH: {msg}: {e}"))
    };

    let rows = batches_to_rows(&input_batches)
        .map_err(|e| df_err("failed to convert batches to rows", &e))?;

    let exec = &mutation_ctx.executor;
    let pm = &mutation_ctx.prop_manager;
    let params = &mutation_ctx.params;
    let ctx = mutation_ctx.query_ctx.as_ref();

    // Acquire the shared writer once and hold the write guard across every
    // row/element/plan iteration (including awaits), serializing this whole
    // FOREACH against other holders of the lock.
    let writer_lock = &mutation_ctx.writer;
    let mut writer = writer_lock.write().await;

    for row in &rows {
        let list_val = exec
            .evaluate_expr(&list_expr, row, pm, params, ctx)
            .await
            .map_err(|e| df_err("list evaluation failed", &e))?;

        // NULL list => no-op for this row; any other non-list value is a
        // type error that aborts the whole statement.
        let items = match list_val {
            Value::List(arr) => arr,
            Value::Null => continue,
            _ => {
                return Err(datafusion::error::DataFusionError::Execution(
                    "FOREACH requires a list expression".to_string(),
                ));
            }
        };

        for item in items {
            // Fresh scope per element: `scope` is a clone of the row, so the
            // loop variable and any body-side bindings never leak back into
            // `rows` (which is re-emitted unchanged below).
            let mut scope = row.clone();
            scope.insert(variable.clone(), item);

            for plan in &body {
                exec.execute_foreach_body_plan(
                    plan.clone(),
                    &mut scope,
                    &mut writer,
                    pm,
                    params,
                    ctx,
                    mutation_ctx.tx_l0_override.as_ref(),
                )
                .await
                .map_err(|e| df_err("body execution failed", &e))?;
            }
        }
    }

    // Release the write guard before rebuilding output batches so the lock
    // is not held any longer than the mutation phase requires.
    drop(writer);

    tracing::debug!(
        variable = variable.as_str(),
        rows = input_row_count,
        "FOREACH complete"
    );

    // Re-emit the untouched input rows under the original (input) schema.
    let result_batches =
        rows_to_batches(&rows, &schema).map_err(|e| df_err("failed to reconstruct batches", &e))?;
    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
    Ok(futures::stream::iter(results))
}