Skip to main content

uni_query/query/df_graph/
mutation_common.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Common infrastructure for DataFusion mutation operators (CREATE, SET, REMOVE, DELETE).
5//!
6//! Provides:
7//! - [`MutationContext`]: Shared context for mutation operators containing executor, writer, etc.
8//! - [`batches_to_rows`]: Convert RecordBatches to row-based HashMaps (batch→row direction).
9//! - [`rows_to_batches`]: Convert row-based HashMaps back to RecordBatches (row→batch direction).
10//! - [`MutationExec`]: Eager-barrier RecordBatchStream that collects all input, applies
11//!   mutations via Writer, and yields output batches.
12
13use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use uni_common::core::id::Vid;
30use uni_common::{Path, Value};
31use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
39/// Shared context for mutation operators.
40///
41/// Contains all resources needed to execute write operations from within
42/// DataFusion ExecutionPlan operators. The Executor is `Clone` with all
43/// Arc-wrapped fields, so cloning it is cheap.
44#[derive(Clone)]
45pub struct MutationContext {
46    /// The query executor (cheap clone, all Arc fields).
47    pub executor: Executor,
48
49    /// Writer for graph mutations (vertices, edges, properties).
50    pub writer: Arc<RwLock<Writer>>,
51
52    /// Property manager for lazy-loading vertex/edge properties.
53    pub prop_manager: Arc<PropertyManager>,
54
55    /// Query parameters (e.g., `$param` references in Cypher).
56    pub params: HashMap<String, Value>,
57
58    /// Query context for L0 buffer visibility.
59    pub query_ctx: Option<uni_store::QueryContext>,
60}
61
62impl std::fmt::Debug for MutationContext {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("MutationContext")
65            .field("has_writer", &true)
66            .field("has_prop_manager", &true)
67            .field("params_count", &self.params.len())
68            .field("has_query_ctx", &self.query_ctx.is_some())
69            .finish()
70    }
71}
72
73/// The kind of mutation to apply per row.
74#[derive(Debug, Clone)]
75pub enum MutationKind {
76    /// CREATE clause: create nodes/edges per the pattern.
77    Create { pattern: Pattern },
78
79    /// CREATE with multiple patterns (batched CREATE).
80    CreateBatch { patterns: Vec<Pattern> },
81
82    /// SET clause: update properties/labels.
83    Set { items: Vec<SetItem> },
84
85    /// REMOVE clause: remove properties/labels.
86    Remove { items: Vec<RemoveItem> },
87
88    /// DELETE clause: delete nodes/edges.
89    Delete { items: Vec<Expr>, detach: bool },
90
91    /// MERGE clause: match-or-create with optional ON MATCH/ON CREATE actions.
92    Merge {
93        pattern: Pattern,
94        on_match: Option<SetClause>,
95        on_create: Option<SetClause>,
96    },
97}
98
99/// Convert RecordBatches to row-based HashMaps for mutation processing.
100///
101/// Handles special metadata on fields:
102/// - `cv_encoded=true`: Parse string value as JSON to restore original type
103/// - DateTime/Time struct types: Decode to temporal values
104///
105/// NOTE: This does NOT merge system fields (like `n._vid`) into bare variable
106/// maps. The raw column names are preserved so that `rows_to_batches` can
107/// reconstruct the RecordBatch with the same schema. System field merging
108/// happens later in `Executor::record_batches_to_rows()` for user-facing output.
109pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
110    let mut rows = Vec::new();
111
112    for batch in batches {
113        let num_rows = batch.num_rows();
114        let schema = batch.schema();
115
116        for row_idx in 0..num_rows {
117            let mut row = HashMap::new();
118
119            for (col_idx, field) in schema.fields().iter().enumerate() {
120                let column = batch.column(col_idx);
121                // Infer Uni DataType from Arrow type for DateTime/Time struct decoding
122                let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
123                    Some(&uni_common::DataType::DateTime)
124                } else if uni_common::core::schema::is_time_struct(field.data_type()) {
125                    Some(&uni_common::DataType::Time)
126                } else {
127                    None
128                };
129                let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);
130
131                // Check if this field contains JSON-encoded values (e.g., from UNWIND)
132                // Parse JSON string to restore the original type
133                if field
134                    .metadata()
135                    .get("cv_encoded")
136                    .is_some_and(|v| v == "true")
137                    && let Value::String(s) = &value
138                    && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
139                {
140                    value = Value::from(parsed);
141                }
142
143                row.insert(field.name().clone(), value);
144            }
145
146            // Also merge system fields into bare variable maps for the write helpers.
147            // The write helpers (execute_set_items_locked, etc.) expect variables
148            // as bare Maps with _vid/_labels inside. We do this AFTER preserving
149            // the raw keys so rows_to_batches can reconstruct the schema.
150            merge_system_fields_for_write(&mut row);
151
152            rows.push(row);
153        }
154    }
155
156    Ok(rows)
157}
158
159/// After mutations, sync `_all_props` within bare variable Maps from their direct property keys.
160///
161/// SET/REMOVE modify direct property keys in the bare Map (e.g., `row["n"]["name"] = "Bob"`)
162/// but the `_all_props` sub-map retains its stale pre-mutation values. The result normalizer
163/// and property UDFs (keys(), properties()) read from `_all_props`, so it must be kept in sync.
164///
165/// This must be called BEFORE `sync_dotted_columns` so that the dotted `n._all_props` column
166/// also gets the updated value.
167fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
168    for row in rows {
169        let map_keys: Vec<String> = row
170            .keys()
171            .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
172            .cloned()
173            .collect();
174
175        for key in map_keys {
176            if let Some(Value::Map(map)) = row.get_mut(&key)
177                && map.contains_key("_all_props")
178            {
179                // Collect non-internal property keys and their values
180                let updates: Vec<(String, Value)> = map
181                    .iter()
182                    .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
183                    .map(|(k, v)| (k.clone(), v.clone()))
184                    .collect();
185
186                if !updates.is_empty()
187                    && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
188                {
189                    for (k, v) in updates {
190                        all_props.insert(k, v);
191                    }
192                }
193            }
194        }
195    }
196}
197
198/// After mutations, sync dotted property columns from bare variable Maps.
199///
200/// SET/REMOVE modify the bare Map (e.g., `row["n"]["name"] = "Bob"`) but the
201/// dotted column (`row["n.name"]`) retains its stale pre-mutation value.
202/// This step overwrites dotted columns from the Map so `rows_to_batches()`
203/// produces correct output. Also handles newly created variables from CREATE/MERGE
204/// by inserting dotted columns that didn't exist in the input.
205fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
206    for row in rows {
207        for field in schema.fields() {
208            let name = field.name();
209            if let Some(dot_pos) = name.find('.') {
210                let var_name = &name[..dot_pos];
211                let prop_name = &name[dot_pos + 1..];
212                if let Some(Value::Map(map)) = row.get(var_name) {
213                    let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
214                    row.insert(name.clone(), val);
215                }
216            }
217        }
218    }
219}
220
221/// Normalize edge system field names in a map: `_src_vid` -> `_src`, `_dst_vid` -> `_dst`.
222///
223/// The write executor expects `_src`/`_dst` but DataFusion traverse emits `_src_vid`/`_dst_vid`.
224fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
225    if let Some(val) = map.remove("_src_vid") {
226        map.entry("_src".to_string()).or_insert(val);
227    }
228    if let Some(val) = map.remove("_dst_vid") {
229        map.entry("_dst".to_string()).or_insert(val);
230    }
231}
232
233/// Merge system fields into bare variable maps for write helper consumption.
234///
235/// The write helpers expect variables like `n` to be a Map containing `_vid`, `_labels`, etc.
236/// This merges dotted columns (like `n._vid`, `n._labels`) into the variable Map,
237/// while KEEPING the dotted columns in the row so `rows_to_batches` still works.
238fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
239    // Vertex system fields (overwrite into the bare map) and edge system fields
240    // (insert only if absent) that should be copied from dotted columns.
241    const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
242    const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
243
244    // Collect all variable names that have dotted columns (var.field).
245    let dotted_vars: HashSet<String> = row
246        .keys()
247        .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
248        .collect();
249
250    // For each variable with dotted columns, ensure a bare Map exists.
251    // If the variable is only represented via dotted columns (e.g., edge from
252    // TraverseMainByType), assemble a Map from those columns.
253    for var in &dotted_vars {
254        if !row.contains_key(var) {
255            let prefix = format!("{var}.");
256            let mut map: HashMap<String, Value> = row
257                .iter()
258                .filter_map(|(k, v)| {
259                    k.strip_prefix(prefix.as_str())
260                        .map(|field| (field.to_string(), v.clone()))
261                })
262                .collect();
263            normalize_edge_field_names(&mut map);
264            if !map.is_empty() {
265                row.insert(var.clone(), Value::Map(map));
266            }
267        }
268    }
269
270    // Merge system fields from dotted columns into bare Maps and normalize edge names.
271    // Single pass: vertex fields overwrite, edge fields insert-if-absent, then normalize.
272    let bare_vars: Vec<String> = row
273        .keys()
274        .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
275        .cloned()
276        .collect();
277
278    for var in &bare_vars {
279        // Collect dotted values to merge (avoids borrowing row mutably while reading)
280        let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
281            .iter()
282            .filter_map(|&field| {
283                row.get(&format!("{var}.{field}"))
284                    .cloned()
285                    .map(|v| (field, v))
286            })
287            .collect();
288        let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
289            .iter()
290            .filter_map(|&field| {
291                row.get(&format!("{var}.{field}"))
292                    .cloned()
293                    .map(|v| (field, v))
294            })
295            .collect();
296
297        if let Some(Value::Map(map)) = row.get_mut(var) {
298            for (field, v) in vertex_vals {
299                map.insert(field.to_string(), v);
300            }
301            for (field, v) in edge_vals {
302                map.entry(field.to_string()).or_insert(v);
303            }
304            normalize_edge_field_names(map);
305        }
306    }
307}
308
309/// Convert row-based HashMaps back to RecordBatches.
310///
311/// This is the inverse of `batches_to_rows`. Schema-driven: iterates over the
312/// output schema fields and extracts named values from each row HashMap.
313///
314/// - Entity columns (LargeBinary with `cv_encoded=true`): serialize Map/Node/Edge values
315///   to CypherValue binary encoding.
316/// - Scalar columns: use `arrow_convert::values_to_array()` for type-appropriate conversion.
317pub fn rows_to_batches(
318    rows: &[HashMap<String, Value>],
319    schema: &SchemaRef,
320) -> Result<Vec<RecordBatch>> {
321    if rows.is_empty() {
322        // Handle empty schema case (no fields)
323        let batch = if schema.fields().is_empty() {
324            let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
325            RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
326        } else {
327            RecordBatch::new_empty(schema.clone())
328        };
329        return Ok(vec![batch]);
330    }
331
332    if schema.fields().is_empty() {
333        // Schema has no fields but there ARE rows. Preserve the row count so that
334        // downstream operators (chained mutations, aggregations) see the correct
335        // number of rows. A RecordBatch with 0 columns can still carry a row count.
336        let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
337        let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
338        return Ok(vec![batch]);
339    }
340
341    // Build columns from rows using schema
342    let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
343
344    for field in schema.fields() {
345        let name = field.name();
346        let values: Vec<Value> = rows
347            .iter()
348            .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
349            .collect();
350
351        let array = value_column_to_arrow(&values, field.data_type(), field)?;
352        columns.push(array);
353    }
354
355    let batch = RecordBatch::try_new(schema.clone(), columns)?;
356    Ok(vec![batch])
357}
358
359/// Convert a column of Values to an Arrow array, handling entity-encoded columns.
360fn value_column_to_arrow(
361    values: &[Value],
362    arrow_type: &DataType,
363    field: &arrow_schema::Field,
364) -> Result<arrow_array::ArrayRef> {
365    let is_cv_encoded = field
366        .metadata()
367        .get("cv_encoded")
368        .is_some_and(|v| v == "true");
369
370    if *arrow_type == DataType::LargeBinary || is_cv_encoded {
371        Ok(encode_as_large_binary(values))
372    } else if *arrow_type == DataType::Binary {
373        // Binary columns (e.g., CRDT payloads): encode as Binary, not LargeBinary
374        Ok(encode_as_binary(values))
375    } else {
376        // Use arrow_convert for scalar types, falling back to CypherValue encoding
377        arrow_convert::values_to_array(values, arrow_type)
378            .or_else(|_| Ok(encode_as_large_binary(values)))
379    }
380}
381
/// Build an Arrow binary array of CypherValue blobs with the given builder type.
///
/// Null values become Arrow nulls. Every other value is serialized with
/// `uni_common::cypher_value_codec::encode`. Expands to an `ArrayRef`.
macro_rules! encode_as_cv {
    ($builder_ty:ty, $values:expr) => {{
        let items = $values;
        // Reserve roughly 64 bytes per value up front to limit reallocations.
        let mut builder = <$builder_ty>::with_capacity(items.len(), items.len() * 64);
        items.iter().for_each(|item| {
            if item.is_null() {
                builder.append_null();
            } else {
                let encoded = uni_common::cypher_value_codec::encode(item);
                builder.append_value(&encoded);
            }
        });
        Arc::new(builder.finish()) as arrow_array::ArrayRef
    }};
}
398
399/// Encode values as CypherValue Binary blobs.
400fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
401    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
402}
403
404/// Encode values as CypherValue LargeBinary blobs.
405fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
406    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
407}
408
409/// Execute a mutation stream: collect all input batches, apply mutations, yield output.
410///
411/// This is the core logic shared by all mutation operators. It implements the
412/// "eager barrier" pattern:
413/// 1. Pull ALL input batches to completion
414/// 2. Convert to rows
415/// 3. Acquire writer lock once for the entire clause
416/// 4. Apply mutations per row
417/// 5. Convert back to batches
418/// 6. Yield output
419pub fn execute_mutation_stream(
420    input: Arc<dyn ExecutionPlan>,
421    output_schema: SchemaRef,
422    mutation_ctx: Arc<MutationContext>,
423    mutation_kind: MutationKind,
424    partition: usize,
425    task_ctx: Arc<datafusion::execution::TaskContext>,
426) -> DFResult<SendableRecordBatchStream> {
427    if mutation_ctx.query_ctx.is_none() {
428        tracing::warn!(
429            "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
430        );
431    }
432
433    let stream = futures::stream::once(execute_mutation_inner(
434        input,
435        output_schema.clone(),
436        mutation_ctx,
437        mutation_kind,
438        partition,
439        task_ctx,
440    ))
441    .try_flatten();
442
443    Ok(Box::pin(RecordBatchStreamAdapter::new(
444        output_schema,
445        stream,
446    )))
447}
448
449/// Inner async function for mutation execution.
450///
451/// Separated from the stream combinator to provide explicit return type
452/// annotation, avoiding type inference issues with multiple From<DataFusionError> impls.
453///
454/// Mutations are applied as storage-level side effects via Writer/L0 buffer.
455/// After mutations, output batches are reconstructed from the modified rows
456/// so downstream operators (RETURN, WITH, subsequent mutations) see the
457/// created/updated variables and properties.
458async fn execute_mutation_inner(
459    input: Arc<dyn ExecutionPlan>,
460    output_schema: SchemaRef,
461    mutation_ctx: Arc<MutationContext>,
462    mutation_kind: MutationKind,
463    partition: usize,
464    task_ctx: Arc<datafusion::execution::TaskContext>,
465) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
466    let mutation_label = mutation_kind_label(&mutation_kind);
467
468    // 1. Collect all input batches (eager barrier)
469    let input_stream = input.execute(partition, task_ctx)?;
470    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;
471
472    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
473    tracing::debug!(
474        mutation = mutation_label,
475        batches = input_batches.len(),
476        rows = input_row_count,
477        "Executing mutation"
478    );
479
480    // 2. Convert to rows for mutation helpers (they operate on HashMap rows)
481    let mut rows = batches_to_rows(&input_batches).map_err(|e| {
482        datafusion::error::DataFusionError::Execution(format!(
483            "Failed to convert batches to rows: {e}"
484        ))
485    })?;
486
487    // 3. Apply mutations.
488    // MERGE manages its own writer lock internally (acquires/releases per-row because
489    // execute_merge_match needs to run a read subplan between lock acquisitions).
490    // All other mutations acquire the writer lock once for the entire clause.
491    if let MutationKind::Merge {
492        ref pattern,
493        ref on_match,
494        ref on_create,
495    } = mutation_kind
496    {
497        let exec = &mutation_ctx.executor;
498        let pm = &mutation_ctx.prop_manager;
499        let params = &mutation_ctx.params;
500        let ctx = mutation_ctx.query_ctx.as_ref();
501
502        let mut result_rows = exec
503            .execute_merge(
504                rows,
505                pattern,
506                on_match.as_ref(),
507                on_create.as_ref(),
508                pm,
509                params,
510                ctx,
511            )
512            .await
513            .map_err(|e| {
514                datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
515            })?;
516
517        tracing::debug!(
518            mutation = mutation_label,
519            input_rows = input_row_count,
520            output_rows = result_rows.len(),
521            "MERGE mutation complete"
522        );
523
524        // Reconstruct output batches from modified rows so downstream operators
525        // (RETURN, WITH, subsequent mutations) see the merged/created variables.
526        sync_all_props_in_maps(&mut result_rows);
527        sync_dotted_columns(&mut result_rows, &output_schema);
528        let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
529            datafusion::error::DataFusionError::Execution(format!(
530                "Failed to reconstruct MERGE batches: {e}"
531            ))
532        })?;
533        let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
534        return Ok(futures::stream::iter(results));
535    }
536
537    let mut writer = mutation_ctx.writer.write().await;
538    apply_mutations(&mutation_ctx, &mutation_kind, &mut rows, &mut writer).await?;
539    drop(writer);
540
541    tracing::debug!(
542        mutation = mutation_label,
543        rows = input_row_count,
544        "Mutation complete"
545    );
546
547    // 4. Reconstruct output batches from modified rows.
548    // Mutations modify the row HashMaps in place (CREATE adds new variable keys,
549    // SET updates property values). Reconstruct batches so downstream operators
550    // (RETURN, WITH, subsequent mutations) see these modifications.
551    sync_all_props_in_maps(&mut rows);
552    sync_dotted_columns(&mut rows, &output_schema);
553    let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
554        datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
555    })?;
556    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
557    Ok(futures::stream::iter(results))
558}
559
560/// Collects and classifies DELETE targets into nodes and edges.
561///
562/// Handles `Value::Path`, `Value::Node`, `Value::Edge`, map-encoded paths
563/// (from Arrow round-trip), and raw VID values. When `dedup` is true,
564/// uses HashSets to skip duplicates (needed for non-DETACH DELETE to
565/// handle shared nodes across paths).
566struct DeleteCollector {
567    /// Collected node entries: (vid, labels) pairs.
568    node_entries: Vec<(Vid, Option<Vec<String>>)>,
569    /// Collected edge values to delete.
570    edge_vals: Vec<Value>,
571    /// Deduplication sets (only used when dedup=true).
572    seen_vids: HashSet<u64>,
573    seen_eids: HashSet<u64>,
574    dedup: bool,
575}
576
577impl DeleteCollector {
578    fn new(dedup: bool) -> Self {
579        Self {
580            node_entries: Vec::new(),
581            edge_vals: Vec::new(),
582            seen_vids: HashSet::new(),
583            seen_eids: HashSet::new(),
584            dedup,
585        }
586    }
587
588    fn add(&mut self, val: Value) {
589        if val.is_null() {
590            return;
591        }
592
593        // Try to resolve value as a Path (native or map-encoded).
594        let path = match &val {
595            Value::Path(p) => Some(p.clone()),
596            _ => Path::try_from(&val).ok(),
597        };
598
599        if let Some(path) = path {
600            for edge in &path.edges {
601                if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
602                    self.edge_vals.push(Value::Edge(edge.clone()));
603                }
604            }
605            for node in &path.nodes {
606                self.add_node(node.vid, Some(node.labels.clone()));
607            }
608            return;
609        }
610
611        // Not a path -- try as a node (by VID).
612        if let Ok(vid) = Executor::vid_from_value(&val) {
613            let labels = Executor::extract_labels_from_node(&val);
614            self.add_node(vid, labels);
615            return;
616        }
617
618        // Otherwise treat as an edge value.
619        if matches!(&val, Value::Map(_) | Value::Edge(_)) {
620            self.edge_vals.push(val);
621        }
622    }
623
624    fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
625        if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
626            return;
627        }
628        self.node_entries.push((vid, labels));
629    }
630}
631
632/// Apply mutations to rows using the appropriate executor helper.
633async fn apply_mutations(
634    mutation_ctx: &MutationContext,
635    mutation_kind: &MutationKind,
636    rows: &mut [HashMap<String, Value>],
637    writer: &mut Writer,
638) -> DFResult<()> {
639    tracing::trace!(
640        mutation = mutation_kind_label(mutation_kind),
641        rows = rows.len(),
642        "Applying mutations"
643    );
644
645    let exec = &mutation_ctx.executor;
646    let pm = &mutation_ctx.prop_manager;
647    let params = &mutation_ctx.params;
648    let ctx = mutation_ctx.query_ctx.as_ref();
649
650    let df_err = |msg: &str, e: anyhow::Error| {
651        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
652    };
653
654    match mutation_kind {
655        MutationKind::Create { pattern } => {
656            for row in rows.iter_mut() {
657                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx)
658                    .await
659                    .map_err(|e| df_err("CREATE failed", e))?;
660            }
661        }
662        MutationKind::CreateBatch { patterns } => {
663            for row in rows.iter_mut() {
664                for pattern in patterns {
665                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx)
666                        .await
667                        .map_err(|e| df_err("CREATE failed", e))?;
668                }
669            }
670        }
671        MutationKind::Set { items } => {
672            for row in rows.iter_mut() {
673                exec.execute_set_items_locked(items, row, writer, pm, params, ctx)
674                    .await
675                    .map_err(|e| df_err("SET failed", e))?;
676            }
677        }
678        MutationKind::Remove { items } => {
679            for row in rows.iter_mut() {
680                exec.execute_remove_items_locked(items, row, writer, pm, ctx)
681                    .await
682                    .map_err(|e| df_err("REMOVE failed", e))?;
683            }
684        }
685        MutationKind::Delete { items, detach } => {
686            // Evaluate all DELETE targets and classify into nodes vs edges.
687            let mut collector = DeleteCollector::new(!*detach);
688            for row in rows.iter() {
689                for expr in items {
690                    let val = exec
691                        .evaluate_expr(expr, row, pm, params, ctx)
692                        .await
693                        .map_err(|e| df_err("DELETE eval failed", e))?;
694                    collector.add(val);
695                }
696            }
697
698            // Delete edges before nodes so non-detach DELETE satisfies constraints.
699            for val in &collector.edge_vals {
700                exec.execute_delete_item_locked(val, false, writer)
701                    .await
702                    .map_err(|e| df_err("DELETE edge failed", e))?;
703            }
704
705            if *detach {
706                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
707                    collector.node_entries.into_iter().unzip();
708                exec.batch_detach_delete_vertices(&vids, labels, writer)
709                    .await
710                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
711            } else {
712                for (vid, labels) in &collector.node_entries {
713                    exec.execute_delete_vertex(*vid, false, labels.clone(), writer)
714                        .await
715                        .map_err(|e| df_err("DELETE node failed", e))?;
716                }
717            }
718        }
719        MutationKind::Merge { .. } => {
720            // MERGE is handled before the writer lock in execute_mutation_inner.
721            // This branch is unreachable but required for exhaustive matching.
722            unreachable!("MERGE mutations are handled before apply_mutations is called");
723        }
724    }
725
726    Ok(())
727}
728
729/// Extract variable names introduced by a CREATE/MERGE pattern.
730///
731/// Walks the pattern tree and collects all node and relationship variable names.
732/// Used to compute extended output schemas for CREATE/MERGE operators.
733pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
734    let mut vars = Vec::new();
735    for path in &pattern.paths {
736        if let Some(ref v) = path.variable {
737            vars.push(v.clone());
738        }
739        for element in &path.elements {
740            match element {
741                PatternElement::Node(n) => {
742                    if let Some(ref v) = n.variable {
743                        vars.push(v.clone());
744                    }
745                }
746                PatternElement::Relationship(r) => {
747                    if let Some(ref v) = r.variable {
748                        vars.push(v.clone());
749                    }
750                }
751                PatternElement::Parenthesized { pattern, .. } => {
752                    // Recurse into parenthesized sub-patterns
753                    let sub = Pattern {
754                        paths: vec![pattern.as_ref().clone()],
755                    };
756                    vars.extend(pattern_variable_names(&sub));
757                }
758            }
759        }
760    }
761    vars
762}
763
764/// Normalize a schema for mutation output.
765///
766/// After mutation processing, entity values (nodes/edges) are stored as
767/// `Value::Map` in row HashMaps. The input schema may have Struct columns
768/// for these entities, but `rows_to_batches()` encodes Map values as
769/// cv_encoded LargeBinary. This function converts Struct and Binary entity
770/// columns to cv_encoded LargeBinary to match the actual output format.
771fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
772    use arrow_schema::{Field, Schema};
773
774    let needs_normalization = schema
775        .fields()
776        .iter()
777        .any(|f| matches!(f.data_type(), DataType::Struct(_)));
778
779    if !needs_normalization {
780        return schema.clone();
781    }
782
783    let fields: Vec<Arc<Field>> = schema
784        .fields()
785        .iter()
786        .map(|field| {
787            if matches!(field.data_type(), DataType::Struct(_)) {
788                let mut metadata = field.metadata().clone();
789                metadata.insert("cv_encoded".to_string(), "true".to_string());
790                Arc::new(
791                    Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
792                )
793            } else {
794                field.clone()
795            }
796        })
797        .collect();
798
799    Arc::new(Schema::new(fields))
800}
801
802/// Compute an extended output schema that includes columns for newly created variables.
803///
804/// Extracts variables from CREATE/MERGE patterns and adds:
805/// - Bare cv_encoded LargeBinary column for each variable
806/// - System dotted columns based on element type:
807///   - Node → `{var}._vid` (UInt64), `{var}._labels` (LargeBinary cv_encoded)
808///   - Edge → `{var}._eid` (UInt64), `{var}._type` (LargeBinary cv_encoded)
809///   - Path → bare column only (no system columns)
810///
811/// Property access on mutation variables uses dynamic `index()` UDF extraction,
812/// so property columns are NOT added here.
813///
814/// Also normalizes existing Struct entity columns to cv_encoded LargeBinary,
815/// since after mutation processing, entities are stored as Maps in row HashMaps.
816pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
817    use arrow_schema::{Field, Schema};
818
819    // First normalize existing columns
820    let normalized = normalize_mutation_schema(input_schema);
821
822    let existing_names: HashSet<&str> = normalized
823        .fields()
824        .iter()
825        .map(|f| f.name().as_str())
826        .collect();
827
828    let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
829    let mut added: HashSet<String> = HashSet::new();
830
831    fn cv_metadata() -> std::collections::HashMap<String, String> {
832        let mut m = std::collections::HashMap::new();
833        m.insert("cv_encoded".to_string(), "true".to_string());
834        m
835    }
836
837    fn add_bare_column(
838        var: &str,
839        fields: &mut Vec<Arc<arrow_schema::Field>>,
840        existing: &HashSet<&str>,
841        added: &mut HashSet<String>,
842    ) -> bool {
843        if existing.contains(var) || added.contains(var) {
844            return false;
845        }
846        added.insert(var.to_string());
847        fields.push(Arc::new(
848            Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
849        ));
850        true
851    }
852
853    for pattern in patterns {
854        for path in &pattern.paths {
855            // Path variable (e.g., `p` in `MERGE p = (a)-[r]->(b)`)
856            if let Some(ref var) = path.variable {
857                add_bare_column(var, &mut fields, &existing_names, &mut added);
858            }
859            for element in &path.elements {
860                match element {
861                    PatternElement::Node(n) => {
862                        if let Some(ref var) = n.variable
863                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
864                        {
865                            // Node system columns for id()/labels()
866                            fields.push(Arc::new(Field::new(
867                                format!("{var}._vid"),
868                                DataType::UInt64,
869                                true,
870                            )));
871                            fields.push(Arc::new(
872                                Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
873                                    .with_metadata(cv_metadata()),
874                            ));
875                        }
876                    }
877                    PatternElement::Relationship(r) => {
878                        if let Some(ref var) = r.variable
879                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
880                        {
881                            // Edge system columns for id()/type()
882                            fields.push(Arc::new(Field::new(
883                                format!("{var}._eid"),
884                                DataType::UInt64,
885                                true,
886                            )));
887                            fields.push(Arc::new(
888                                Field::new(format!("{var}._type"), DataType::LargeBinary, true)
889                                    .with_metadata(cv_metadata()),
890                            ));
891                        }
892                    }
893                    PatternElement::Parenthesized { pattern, .. } => {
894                        // Recurse into sub-patterns. Pass current fields as
895                        // input so the recursive call's `existing_names` check
896                        // prevents duplicates for variables already added.
897                        let sub = Pattern {
898                            paths: vec![pattern.as_ref().clone()],
899                        };
900                        let sub_schema = extended_schema_for_new_vars(
901                            &Arc::new(Schema::new(fields.clone())),
902                            &[sub],
903                        );
904                        // Sync `added` from new fields to prevent duplicates
905                        // if a later pattern element reuses a variable.
906                        for field in sub_schema.fields() {
907                            added.insert(field.name().clone());
908                        }
909                        fields = sub_schema.fields().to_vec();
910                    }
911                }
912            }
913        }
914    }
915
916    Arc::new(Schema::new(fields))
917}
918
919/// Human-readable label for a MutationKind (used in tracing spans).
920fn mutation_kind_label(kind: &MutationKind) -> &'static str {
921    match kind {
922        MutationKind::Create { .. } => "CREATE",
923        MutationKind::CreateBatch { .. } => "CREATE_BATCH",
924        MutationKind::Set { .. } => "SET",
925        MutationKind::Remove { .. } => "REMOVE",
926        MutationKind::Delete { .. } => "DELETE",
927        MutationKind::Merge { .. } => "MERGE",
928    }
929}
930
931// ============================================================================
932// Unified MutationExec: single ExecutionPlan for all mutation kinds
933// ============================================================================
934
/// Unified DataFusion `ExecutionPlan` for all Cypher mutation clauses
/// (CREATE, SET, REMOVE, DELETE and MERGE — one [`MutationKind`] variant each,
/// plus the batched `CreateBatch` form).
///
/// Instead of one near-identical ExecutionPlan struct per clause, this single
/// struct holds a [`MutationKind`] discriminant and delegates to the shared
/// [`execute_mutation_stream`] implementation. Typed constructors (such as
/// [`new_create_exec`] and [`new_merge_exec`] in this module, and those in
/// `mutation_set`, `mutation_remove`, and `mutation_delete`) provide
/// ergonomic construction with the correct kind.
#[derive(Debug)]
pub struct MutationExec {
    /// Child plan producing input rows.
    input: Arc<dyn ExecutionPlan>,

    /// The kind of mutation to apply.
    kind: MutationKind,

    /// Display name for EXPLAIN output; also returned by `ExecutionPlan::name`.
    display_name: &'static str,

    /// Shared mutation context with executor and writer.
    mutation_ctx: Arc<MutationContext>,

    /// Output schema. With `MutationExec::new` this is the input schema with
    /// Struct entity columns normalized to cv_encoded LargeBinary; with
    /// `new_with_schema` (used for CREATE/MERGE) it is the caller-supplied
    /// schema, which may add columns for newly introduced variables.
    schema: SchemaRef,

    /// Plan properties for DataFusion optimizer, computed from `schema`.
    properties: PlanProperties,

    /// Metrics set exposed through `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
966
967impl MutationExec {
968    /// Create a new `MutationExec` with the given kind.
969    ///
970    /// The output schema is derived from the input schema with Struct entity
971    /// columns normalized to cv_encoded LargeBinary. For mutations that
972    /// introduce new variables (CREATE, MERGE), use [`Self::new_with_schema`] instead.
973    pub fn new(
974        input: Arc<dyn ExecutionPlan>,
975        kind: MutationKind,
976        display_name: &'static str,
977        mutation_ctx: Arc<MutationContext>,
978    ) -> Self {
979        let schema = normalize_mutation_schema(&input.schema());
980        let properties = compute_plan_properties(schema.clone());
981        Self {
982            input,
983            kind,
984            display_name,
985            mutation_ctx,
986            schema,
987            properties,
988            metrics: ExecutionPlanMetricsSet::new(),
989        }
990    }
991
992    /// Create a new `MutationExec` with an explicit output schema.
993    ///
994    /// Used by CREATE and MERGE operators whose output includes newly created
995    /// variables not present in the input schema.
996    pub fn new_with_schema(
997        input: Arc<dyn ExecutionPlan>,
998        kind: MutationKind,
999        display_name: &'static str,
1000        mutation_ctx: Arc<MutationContext>,
1001        output_schema: SchemaRef,
1002    ) -> Self {
1003        let properties = compute_plan_properties(output_schema.clone());
1004        Self {
1005            input,
1006            kind,
1007            display_name,
1008            mutation_ctx,
1009            schema: output_schema,
1010            properties,
1011            metrics: ExecutionPlanMetricsSet::new(),
1012        }
1013    }
1014}
1015
1016impl DisplayAs for MutationExec {
1017    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1018        if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1019            write!(f, "{} [DETACH]", self.display_name)
1020        } else {
1021            write!(f, "{}", self.display_name)
1022        }
1023    }
1024}
1025
1026impl ExecutionPlan for MutationExec {
1027    fn name(&self) -> &str {
1028        self.display_name
1029    }
1030
1031    fn as_any(&self) -> &dyn Any {
1032        self
1033    }
1034
1035    fn schema(&self) -> SchemaRef {
1036        self.schema.clone()
1037    }
1038
1039    fn properties(&self) -> &PlanProperties {
1040        &self.properties
1041    }
1042
1043    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1044        vec![&self.input]
1045    }
1046
1047    fn with_new_children(
1048        self: Arc<Self>,
1049        children: Vec<Arc<dyn ExecutionPlan>>,
1050    ) -> DFResult<Arc<dyn ExecutionPlan>> {
1051        if children.len() != 1 {
1052            return Err(datafusion::error::DataFusionError::Plan(format!(
1053                "{} requires exactly one child",
1054                self.display_name,
1055            )));
1056        }
1057        Ok(Arc::new(MutationExec::new_with_schema(
1058            children[0].clone(),
1059            self.kind.clone(),
1060            self.display_name,
1061            self.mutation_ctx.clone(),
1062            self.schema.clone(),
1063        )))
1064    }
1065
1066    fn execute(
1067        &self,
1068        partition: usize,
1069        context: Arc<TaskContext>,
1070    ) -> DFResult<SendableRecordBatchStream> {
1071        execute_mutation_stream(
1072            self.input.clone(),
1073            self.schema.clone(),
1074            self.mutation_ctx.clone(),
1075            self.kind.clone(),
1076            partition,
1077            context,
1078        )
1079    }
1080
1081    fn metrics(&self) -> Option<MetricsSet> {
1082        Some(self.metrics.clone_inner())
1083    }
1084}
1085
1086/// Create a new `MutationExec` configured for a CREATE clause.
1087///
1088/// Computes an extended output schema that includes LargeBinary cv_encoded
1089/// columns for any variables introduced by the pattern that are not already
1090/// in the input schema.
1091pub fn new_create_exec(
1092    input: Arc<dyn ExecutionPlan>,
1093    pattern: Pattern,
1094    mutation_ctx: Arc<MutationContext>,
1095) -> MutationExec {
1096    let output_schema =
1097        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1098    MutationExec::new_with_schema(
1099        input,
1100        MutationKind::Create { pattern },
1101        "MutationCreateExec",
1102        mutation_ctx,
1103        output_schema,
1104    )
1105}
1106
1107/// Create a new `MutationExec` configured for a MERGE clause.
1108///
1109/// Computes an extended output schema that includes LargeBinary cv_encoded
1110/// columns for any variables introduced by the pattern that are not already
1111/// in the input schema.
1112pub fn new_merge_exec(
1113    input: Arc<dyn ExecutionPlan>,
1114    pattern: Pattern,
1115    on_match: Option<SetClause>,
1116    on_create: Option<SetClause>,
1117    mutation_ctx: Arc<MutationContext>,
1118) -> MutationExec {
1119    let output_schema =
1120        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1121    MutationExec::new_with_schema(
1122        input,
1123        MutationKind::Merge {
1124            pattern,
1125            on_match,
1126            on_create,
1127        },
1128        "MutationMergeExec",
1129        mutation_ctx,
1130        output_schema,
1131    )
1132}
1133
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    #[test]
    fn test_batches_to_rows_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("Alice"), Some("Bob")])),
                Arc::new(Int64Array::from(vec![Some(30), Some(25)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    #[test]
    fn test_rows_to_batches_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let rows = vec![
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Alice".into()));
                m.insert("age".to_string(), Value::Int(30));
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Bob".into()));
                m.insert("age".to_string(), Value::Int(25));
                m
            },
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);
    }

    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![Some("hello")])),
                Arc::new(Int64Array::from(vec![Some(42)])),
                Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
                Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
            ],
        )
        .unwrap();

        // Roundtrip: batches → rows → batches
        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        // Verify roundtrip fidelity
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(
            roundtrip_rows[0].get("s"),
            Some(&Value::String("hello".into()))
        );
        assert_eq!(roundtrip_rows[0].get("i"), Some(&Value::Int(42)));
        assert_eq!(roundtrip_rows[0].get("b"), Some(&Value::Bool(true)));
        // Float comparison
        if let Some(Value::Float(f)) = roundtrip_rows[0].get("f") {
            assert!((*f - 3.125).abs() < 1e-10);
        } else {
            panic!("Expected float value");
        }
    }

    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        use std::collections::HashMap as StdHashMap;

        // Create a schema with a cv_encoded LargeBinary column (entity column)
        let mut metadata = StdHashMap::new();
        metadata.insert("cv_encoded".to_string(), "true".to_string());
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        // Create a node-like Map value
        let mut node_map = HashMap::new();
        node_map.insert("name".to_string(), Value::String("Alice".into()));
        node_map.insert("_vid".to_string(), Value::Int(1));
        let map_val = Value::Map(node_map);

        // Encode to CypherValue bytes
        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
                encoded.as_slice(),
            )]))],
        )
        .unwrap();

        // Roundtrip
        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        // The decoded value should be a Map
        let val = rows[0].get("n").unwrap();
        assert!(matches!(val, Value::Map(_)));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        // Verify we can decode it back, and that re-encoding preserved the value
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(roundtrip_rows[0].get("n"), rows[0].get("n"));
    }

    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }

    #[test]
    fn test_normalize_mutation_schema_converts_struct_columns() {
        let struct_type =
            DataType::Struct(vec![Field::new("_vid", DataType::UInt64, true)].into());
        let schema = Arc::new(Schema::new(vec![
            Field::new("n", struct_type, false),
            Field::new("x", DataType::Int64, false),
        ]));

        let normalized = normalize_mutation_schema(&schema);

        let n = normalized.field_with_name("n").unwrap();
        assert_eq!(n.data_type(), &DataType::LargeBinary);
        assert!(n.is_nullable());
        assert_eq!(
            n.metadata().get("cv_encoded").map(String::as_str),
            Some("true")
        );
        // Non-struct columns pass through untouched.
        assert_eq!(normalized.field(1), schema.field(1));
    }

    #[test]
    fn test_normalize_mutation_schema_without_structs_is_identity() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let normalized = normalize_mutation_schema(&schema);
        assert!(Arc::ptr_eq(&schema, &normalized));
    }

    #[test]
    fn test_extended_schema_without_patterns_only_normalizes() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let extended = extended_schema_for_new_vars(&schema, &[]);
        assert_eq!(extended.fields(), schema.fields());
    }
}