Skip to main content

uni_query/query/df_graph/
mutation_common.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Common infrastructure for DataFusion mutation operators (CREATE, SET, REMOVE, DELETE).
5//!
6//! Provides:
7//! - [`MutationContext`]: Shared context for mutation operators containing executor, writer, etc.
8//! - [`batches_to_rows`]: Convert RecordBatches to row-based HashMaps (batch→row direction).
9//! - [`rows_to_batches`]: Convert row-based HashMaps back to RecordBatches (row→batch direction).
10//! - [`MutationExec`]: Eager-barrier RecordBatchStream that collects all input, applies
11//!   mutations via Writer, and yields output batches.
12
13use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use uni_common::core::id::{Eid, Vid};
29use uni_common::{Path, Properties, Value};
30use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
31use uni_store::QueryContext;
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
39/// Shared context for mutation operators.
40///
41/// Contains all resources needed to execute write operations from within
42/// DataFusion ExecutionPlan operators. The Executor is `Clone` with all
43/// Arc-wrapped fields, so cloning it is cheap.
44#[derive(Clone)]
45pub struct MutationContext {
46    /// The query executor (cheap clone, all Arc fields).
47    pub executor: Executor,
48
49    /// Writer for graph mutations (vertices, edges, properties).
50    pub writer: Arc<Writer>,
51
52    /// Property manager for lazy-loading vertex/edge properties.
53    pub prop_manager: Arc<PropertyManager>,
54
55    /// Query parameters (e.g., `$param` references in Cypher).
56    pub params: HashMap<String, Value>,
57
58    /// Query context for L0 buffer visibility.
59    pub query_ctx: Option<uni_store::QueryContext>,
60
61    /// When set, mutations are routed to this private L0 buffer instead of
62    /// the global L0. Passed explicitly to Writer methods during mutation execution.
63    pub tx_l0_override: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
64}
65
66impl std::fmt::Debug for MutationContext {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("MutationContext")
69            .field("has_writer", &true)
70            .field("has_prop_manager", &true)
71            .field("params_count", &self.params.len())
72            .field("has_query_ctx", &self.query_ctx.is_some())
73            .finish()
74    }
75}
76
77/// The kind of mutation to apply per row.
78#[derive(Debug, Clone)]
79pub enum MutationKind {
80    /// CREATE clause: create nodes/edges per the pattern.
81    Create { pattern: Pattern },
82
83    /// CREATE with multiple patterns (batched CREATE).
84    CreateBatch { patterns: Vec<Pattern> },
85
86    /// SET clause: update properties/labels.
87    Set { items: Vec<SetItem> },
88
89    /// REMOVE clause: remove properties/labels.
90    Remove { items: Vec<RemoveItem> },
91
92    /// DELETE clause: delete nodes/edges.
93    Delete { items: Vec<Expr>, detach: bool },
94
95    /// MERGE clause: match-or-create with optional ON MATCH/ON CREATE actions.
96    Merge {
97        pattern: Pattern,
98        on_match: Option<SetClause>,
99        on_create: Option<SetClause>,
100    },
101}
102
103/// Pre-fetched property maps for SET/REMOVE mutation rows.
104///
105/// Built once at the top of `apply_mutations` for `MutationKind::Set` and
106/// `MutationKind::Remove` (see `prefetch_set_targets`, `prefetch_remove_targets`).
107/// Replaces N per-row `get_all_vertex_props_with_ctx` calls (each ~716 µs of
108/// DataFusion+Lance setup) with a single `_vid IN (...)` scan amortized
109/// across all rows.
110///
111/// Per-row read sites in `execute_set_items_locked`,
112/// `apply_properties_to_entity`, `flush_pending_var`, and
113/// `execute_remove_items_locked` look up the VID/EID here first and fall
114/// back to the per-row call if absent — preserving correctness for newly
115/// created VIDs, schemaless rows, and non-Mutation callers.
116#[derive(Default, Debug)]
117pub(crate) struct Prefetch {
118    pub vertex: HashMap<Vid, Properties>,
119    pub edge: HashMap<Eid, Properties>,
120}
121
122/// Extract `(label, vid)` pairs and `(type_name, eid)` pairs that a SET
123/// clause will touch, batch-fetch them, return a [`Prefetch`] map.
124///
125/// Walks `rows` once to dedupe by VID/EID per label/type. Issues one
126/// `get_batch_vertex_props_for_label` per distinct vertex label. Edge
127/// prefetch is a no-op for now (Phase A.1 only covers vertices; the
128/// fallback path in execute_set_items_locked keeps edge SET correct).
129pub(crate) async fn prefetch_set_targets(
130    items: &[SetItem],
131    rows: &[HashMap<String, Value>],
132    pm: &PropertyManager,
133    ctx: Option<&QueryContext>,
134) -> Result<Prefetch> {
135    // 1. Collect vertex-targeted variable names from SET items.
136    let touched_vars: HashSet<&str> = items
137        .iter()
138        .filter_map(|item| match item {
139            SetItem::Property { expr, .. } => extract_var_from_property_expr(expr),
140            SetItem::Variable { variable, .. } | SetItem::VariablePlus { variable, .. } => {
141                Some(variable.as_str())
142            }
143            SetItem::Labels { .. } => None,
144        })
145        .collect();
146    if touched_vars.is_empty() {
147        return Ok(Prefetch::default());
148    }
149
150    collect_and_fetch_vertex_prefetch(&touched_vars, rows, pm, ctx).await
151}
152
153/// Same shape as [`prefetch_set_targets`] for REMOVE clauses.
154pub(crate) async fn prefetch_remove_targets(
155    items: &[RemoveItem],
156    rows: &[HashMap<String, Value>],
157    pm: &PropertyManager,
158    ctx: Option<&QueryContext>,
159) -> Result<Prefetch> {
160    let touched_vars: HashSet<&str> = items
161        .iter()
162        .filter_map(|item| match item {
163            RemoveItem::Property(expr) => extract_var_from_property_expr(expr),
164            RemoveItem::Labels { .. } => None,
165        })
166        .collect();
167    if touched_vars.is_empty() {
168        return Ok(Prefetch::default());
169    }
170
171    collect_and_fetch_vertex_prefetch(&touched_vars, rows, pm, ctx).await
172}
173
174/// Inspect a property-access expr (e.g. `n.prop` or `n[expr]`) and return
175/// the root variable name if it's a plain `Variable("n").Property("prop")`.
176fn extract_var_from_property_expr(expr: &Expr) -> Option<&str> {
177    if let Expr::Property(inner, _) = expr
178        && let Expr::Variable(name) = inner.as_ref()
179    {
180        return Some(name.as_str());
181    }
182    None
183}
184
185/// Group VIDs/EIDs by their primary label / edge type across `rows`,
186/// issue one batched fetch per group, merge results into a [`Prefetch`].
187///
188/// Schemaless rows / rows without a label or type fall through to the
189/// per-row fallback at the fetch site. Per-row bindings are inspected
190/// once; a variable bound to a vertex in some rows and an edge in
191/// others (rare; e.g., from union) is grouped under each kind it
192/// appears as.
193async fn collect_and_fetch_vertex_prefetch(
194    touched_vars: &HashSet<&str>,
195    rows: &[HashMap<String, Value>],
196    pm: &PropertyManager,
197    ctx: Option<&QueryContext>,
198) -> Result<Prefetch> {
199    let mut by_label: HashMap<String, HashSet<Vid>> = HashMap::new();
200    let mut by_type: HashMap<String, HashSet<Eid>> = HashMap::new();
201
202    for row in rows {
203        for &var in touched_vars {
204            let Some(bound) = row.get(var) else { continue };
205            if let Some((vid, labels)) = vertex_vid_and_labels(bound) {
206                if let Some(label) = labels.first() {
207                    by_label.entry(label.clone()).or_default().insert(vid);
208                }
209            } else if let Some((eid, type_name)) = edge_eid_and_type(bound)
210                && !type_name.is_empty()
211            {
212                by_type.entry(type_name).or_default().insert(eid);
213            }
214        }
215    }
216
217    let mut prefetch = Prefetch::default();
218    for (label, vid_set) in by_label {
219        let vids: Vec<Vid> = vid_set.into_iter().collect();
220        if vids.is_empty() {
221            continue;
222        }
223        if let Ok(label_results) = pm
224            .get_batch_vertex_props_for_label(&vids, &label, ctx)
225            .await
226        {
227            for (vid, props) in label_results {
228                prefetch.vertex.entry(vid).or_insert(props);
229            }
230        }
231        // Batch errors fall through to per-row fallback (correctness preserved).
232    }
233    for (type_name, eid_set) in by_type {
234        let eids: Vec<Eid> = eid_set.into_iter().collect();
235        if eids.is_empty() {
236            continue;
237        }
238        if let Ok(type_results) = pm
239            .get_batch_edge_props_for_type(&eids, &type_name, ctx)
240            .await
241        {
242            for (eid, props) in type_results {
243                prefetch.edge.entry(eid).or_insert(props);
244            }
245        }
246    }
247    Ok(prefetch)
248}
249
250/// Extract (vid, labels) from a bound vertex value (`Value::Node` or a
251/// `Value::Map` that has the `_vid`/`_labels` shape produced by the
252/// planner). Returns `None` for edges, paths, scalars, or untyped maps.
253fn vertex_vid_and_labels(val: &Value) -> Option<(Vid, Vec<String>)> {
254    match val {
255        Value::Node(node) => Some((node.vid, node.labels.clone())),
256        Value::Map(map) => {
257            if map.contains_key("_eid") {
258                return None;
259            }
260            let vid_val = map.get("_vid")?;
261            let vid = match vid_val {
262                Value::Int(i) if *i >= 0 => Vid::from(*i as u64),
263                _ => return None,
264            };
265            let labels = map
266                .get("_labels")
267                .and_then(|v| match v {
268                    Value::List(items) => Some(
269                        items
270                            .iter()
271                            .filter_map(|x| {
272                                if let Value::String(s) = x {
273                                    Some(s.clone())
274                                } else {
275                                    None
276                                }
277                            })
278                            .collect::<Vec<_>>(),
279                    ),
280                    _ => None,
281                })
282                .unwrap_or_default();
283            Some((vid, labels))
284        }
285        _ => None,
286    }
287}
288
289/// Extract (eid, type_name) from a bound edge value. Mirrors
290/// [`vertex_vid_and_labels`] for the edge case. The type name is
291/// resolved from `_type_name` (string) or `_type` (string) on map-encoded
292/// edges, and from `Value::Edge::edge_type` on typed edges.
293fn edge_eid_and_type(val: &Value) -> Option<(Eid, String)> {
294    match val {
295        Value::Edge(edge) => Some((edge.eid, edge.edge_type.clone())),
296        Value::Map(map) => {
297            // Must be edge-shaped: _eid, _src, _dst.
298            let eid_val = map.get("_eid")?;
299            if !map.contains_key("_src") || !map.contains_key("_dst") {
300                return None;
301            }
302            let eid = match eid_val {
303                Value::Int(i) if *i >= 0 => Eid::from(*i as u64),
304                Value::Null => return None,
305                _ => return None,
306            };
307            let type_name = map
308                .get("_type_name")
309                .or_else(|| map.get("_type"))
310                .and_then(|v| match v {
311                    Value::String(s) => Some(s.clone()),
312                    _ => None,
313                })
314                .unwrap_or_default();
315            Some((eid, type_name))
316        }
317        _ => None,
318    }
319}
320
321/// Convert RecordBatches to row-based HashMaps for mutation processing.
322///
323/// Handles special metadata on fields:
324/// - `cv_encoded=true`: Parse string value as JSON to restore original type
325/// - DateTime/Time struct types: Decode to temporal values
326///
327/// NOTE: This does NOT merge system fields (like `n._vid`) into bare variable
328/// maps. The raw column names are preserved so that `rows_to_batches` can
329/// reconstruct the RecordBatch with the same schema. System field merging
330/// happens later in `Executor::record_batches_to_rows()` for user-facing output.
331pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
332    let mut rows = Vec::new();
333
334    for batch in batches {
335        let num_rows = batch.num_rows();
336        let schema = batch.schema();
337
338        for row_idx in 0..num_rows {
339            let mut row = HashMap::new();
340
341            for (col_idx, field) in schema.fields().iter().enumerate() {
342                let column = batch.column(col_idx);
343                // Infer Uni DataType from Arrow type for DateTime/Time struct decoding
344                let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
345                    Some(&uni_common::DataType::DateTime)
346                } else if uni_common::core::schema::is_time_struct(field.data_type()) {
347                    Some(&uni_common::DataType::Time)
348                } else {
349                    None
350                };
351                let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);
352
353                // Check if this field contains JSON-encoded values (e.g., from UNWIND)
354                // Parse JSON string to restore the original type
355                if field
356                    .metadata()
357                    .get("cv_encoded")
358                    .is_some_and(|v| v == "true")
359                    && let Value::String(s) = &value
360                    && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
361                {
362                    value = Value::from(parsed);
363                }
364
365                row.insert(field.name().clone(), value);
366            }
367
368            // Also merge system fields into bare variable maps for the write helpers.
369            // The write helpers (execute_set_items_locked, etc.) expect variables
370            // as bare Maps with _vid/_labels inside. We do this AFTER preserving
371            // the raw keys so rows_to_batches can reconstruct the schema.
372            merge_system_fields_for_write(&mut row);
373
374            rows.push(row);
375        }
376    }
377
378    Ok(rows)
379}
380
381/// After mutations, sync `_all_props` within bare variable Maps from their direct property keys.
382///
383/// SET/REMOVE modify direct property keys in the bare Map (e.g., `row["n"]["name"] = "Bob"`)
384/// but the `_all_props` sub-map retains its stale pre-mutation values. The result normalizer
385/// and property UDFs (keys(), properties()) read from `_all_props`, so it must be kept in sync.
386///
387/// This must be called BEFORE `sync_dotted_columns` so that the dotted `n._all_props` column
388/// also gets the updated value.
389fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
390    for row in rows {
391        let map_keys: Vec<String> = row
392            .keys()
393            .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
394            .cloned()
395            .collect();
396
397        for key in map_keys {
398            if let Some(Value::Map(map)) = row.get_mut(&key)
399                && map.contains_key("_all_props")
400            {
401                // Collect non-internal property keys and their values
402                let updates: Vec<(String, Value)> = map
403                    .iter()
404                    .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
405                    .map(|(k, v)| (k.clone(), v.clone()))
406                    .collect();
407
408                if !updates.is_empty()
409                    && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
410                {
411                    for (k, v) in updates {
412                        all_props.insert(k, v);
413                    }
414                }
415            }
416        }
417    }
418}
419
420/// After mutations, sync dotted property columns from bare variable Maps.
421///
422/// SET/REMOVE modify the bare Map (e.g., `row["n"]["name"] = "Bob"`) but the
423/// dotted column (`row["n.name"]`) retains its stale pre-mutation value.
424/// This step overwrites dotted columns from the Map so `rows_to_batches()`
425/// produces correct output. Also handles newly created variables from CREATE/MERGE
426/// by inserting dotted columns that didn't exist in the input.
427fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
428    for row in rows {
429        for field in schema.fields() {
430            let name = field.name();
431            if let Some(dot_pos) = name.find('.') {
432                let var_name = &name[..dot_pos];
433                let prop_name = &name[dot_pos + 1..];
434                if let Some(Value::Map(map)) = row.get(var_name) {
435                    let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
436                    row.insert(name.clone(), val);
437                }
438            }
439        }
440    }
441}
442
443/// Normalize edge system field names in a map: `_src_vid` -> `_src`, `_dst_vid` -> `_dst`.
444///
445/// The write executor expects `_src`/`_dst` but DataFusion traverse emits `_src_vid`/`_dst_vid`.
446fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
447    if let Some(val) = map.remove("_src_vid") {
448        map.entry("_src".to_string()).or_insert(val);
449    }
450    if let Some(val) = map.remove("_dst_vid") {
451        map.entry("_dst".to_string()).or_insert(val);
452    }
453}
454
455/// Merge system fields into bare variable maps for write helper consumption.
456///
457/// The write helpers expect variables like `n` to be a Map containing `_vid`, `_labels`, etc.
458/// This merges dotted columns (like `n._vid`, `n._labels`) into the variable Map,
459/// while KEEPING the dotted columns in the row so `rows_to_batches` still works.
460fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
461    // Vertex system fields (overwrite into the bare map) and edge system fields
462    // (insert only if absent) that should be copied from dotted columns.
463    const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
464    const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
465
466    // Collect all variable names that have dotted columns (var.field).
467    let dotted_vars: HashSet<String> = row
468        .keys()
469        .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
470        .collect();
471
472    // For each variable with dotted columns, ensure a bare Map exists.
473    // If the variable is only represented via dotted columns (e.g., edge from
474    // TraverseMainByType), assemble a Map from those columns.
475    for var in &dotted_vars {
476        if !row.contains_key(var) {
477            let prefix = format!("{var}.");
478            let mut map: HashMap<String, Value> = row
479                .iter()
480                .filter_map(|(k, v)| {
481                    k.strip_prefix(prefix.as_str())
482                        .map(|field| (field.to_string(), v.clone()))
483                })
484                .collect();
485            normalize_edge_field_names(&mut map);
486            if !map.is_empty() {
487                row.insert(var.clone(), Value::Map(map));
488            }
489        }
490    }
491
492    // Merge system fields from dotted columns into bare Maps and normalize edge names.
493    // Single pass: vertex fields overwrite, edge fields insert-if-absent, then normalize.
494    let bare_vars: Vec<String> = row
495        .keys()
496        .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
497        .cloned()
498        .collect();
499
500    for var in &bare_vars {
501        // Collect dotted values to merge (avoids borrowing row mutably while reading)
502        let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
503            .iter()
504            .filter_map(|&field| {
505                row.get(&format!("{var}.{field}"))
506                    .cloned()
507                    .map(|v| (field, v))
508            })
509            .collect();
510        let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
511            .iter()
512            .filter_map(|&field| {
513                row.get(&format!("{var}.{field}"))
514                    .cloned()
515                    .map(|v| (field, v))
516            })
517            .collect();
518
519        if let Some(Value::Map(map)) = row.get_mut(var) {
520            for (field, v) in vertex_vals {
521                map.insert(field.to_string(), v);
522            }
523            for (field, v) in edge_vals {
524                map.entry(field.to_string()).or_insert(v);
525            }
526            normalize_edge_field_names(map);
527        }
528    }
529}
530
531/// Convert row-based HashMaps back to RecordBatches.
532///
533/// This is the inverse of `batches_to_rows`. Schema-driven: iterates over the
534/// output schema fields and extracts named values from each row HashMap.
535///
536/// - Entity columns (LargeBinary with `cv_encoded=true`): serialize Map/Node/Edge values
537///   to CypherValue binary encoding.
538/// - Scalar columns: use `arrow_convert::values_to_array()` for type-appropriate conversion.
539pub fn rows_to_batches(
540    rows: &[HashMap<String, Value>],
541    schema: &SchemaRef,
542) -> Result<Vec<RecordBatch>> {
543    if rows.is_empty() {
544        // Handle empty schema case (no fields)
545        let batch = if schema.fields().is_empty() {
546            let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
547            RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
548        } else {
549            RecordBatch::new_empty(schema.clone())
550        };
551        return Ok(vec![batch]);
552    }
553
554    if schema.fields().is_empty() {
555        // Schema has no fields but there ARE rows. Preserve the row count so that
556        // downstream operators (chained mutations, aggregations) see the correct
557        // number of rows. A RecordBatch with 0 columns can still carry a row count.
558        let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
559        let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
560        return Ok(vec![batch]);
561    }
562
563    // Build columns from rows using schema
564    let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
565
566    for field in schema.fields() {
567        let name = field.name();
568        let values: Vec<Value> = rows
569            .iter()
570            .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
571            .collect();
572
573        let array = value_column_to_arrow(&values, field.data_type(), field)?;
574        columns.push(array);
575    }
576
577    let batch = RecordBatch::try_new(schema.clone(), columns)?;
578    Ok(vec![batch])
579}
580
581/// Convert a column of Values to an Arrow array, handling entity-encoded columns.
582fn value_column_to_arrow(
583    values: &[Value],
584    arrow_type: &DataType,
585    field: &arrow_schema::Field,
586) -> Result<arrow_array::ArrayRef> {
587    let is_cv_encoded = field
588        .metadata()
589        .get("cv_encoded")
590        .is_some_and(|v| v == "true");
591
592    if *arrow_type == DataType::LargeBinary || is_cv_encoded {
593        Ok(encode_as_large_binary(values))
594    } else if *arrow_type == DataType::Binary {
595        // Binary columns (e.g., CRDT payloads): encode as Binary, not LargeBinary
596        Ok(encode_as_binary(values))
597    } else {
598        // Use arrow_convert for scalar types, falling back to CypherValue encoding
599        arrow_convert::values_to_array(values, arrow_type)
600            .or_else(|_| Ok(encode_as_large_binary(values)))
601    }
602}
603
/// Build an Arrow binary array of CypherValue-encoded values.
///
/// `$builder_ty` is the concrete builder type (e.g. `BinaryBuilder`,
/// `LargeBinaryBuilder`). `$values` is evaluated once and must support `len()`
/// and iteration over `&Value`. Null values become null slots. Capacity is
/// pre-sized at 64 bytes per value.
macro_rules! encode_as_cv {
    ($builder_ty:ty, $values:expr) => {{
        let cells = $values;
        let cell_count = cells.len();
        let mut builder = <$builder_ty>::with_capacity(cell_count, cell_count * 64);
        for cell in cells {
            if cell.is_null() {
                builder.append_null();
                continue;
            }
            let encoded = uni_common::cypher_value_codec::encode(cell);
            builder.append_value(&encoded);
        }
        Arc::new(builder.finish()) as arrow_array::ArrayRef
    }};
}
620
621/// Encode values as CypherValue Binary blobs.
622fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
623    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
624}
625
626/// Encode values as CypherValue LargeBinary blobs.
627fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
628    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
629}
630
631/// Execute a mutation stream: collect all input batches, apply mutations, yield output.
632///
633/// This is the core logic shared by all mutation operators. It implements the
634/// "eager barrier" pattern:
635/// 1. Pull ALL input batches to completion
636/// 2. Convert to rows
637/// 3. Acquire writer lock once for the entire clause
638/// 4. Apply mutations per row
639/// 5. Convert back to batches
640/// 6. Yield output
641pub fn execute_mutation_stream(
642    input: Arc<dyn ExecutionPlan>,
643    output_schema: SchemaRef,
644    mutation_ctx: Arc<MutationContext>,
645    mutation_kind: MutationKind,
646    partition: usize,
647    task_ctx: Arc<datafusion::execution::TaskContext>,
648    baseline: BaselineMetrics,
649) -> DFResult<SendableRecordBatchStream> {
650    if mutation_ctx.query_ctx.is_none() {
651        tracing::warn!(
652            "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
653        );
654    }
655
656    let stream = futures::stream::once(execute_mutation_inner(
657        input,
658        output_schema.clone(),
659        mutation_ctx,
660        mutation_kind,
661        partition,
662        task_ctx,
663        baseline,
664    ))
665    .try_flatten();
666
667    Ok(Box::pin(RecordBatchStreamAdapter::new(
668        output_schema,
669        stream,
670    )))
671}
672
673/// Inner async function for mutation execution.
674///
675/// Separated from the stream combinator to provide explicit return type
676/// annotation, avoiding type inference issues with multiple From<DataFusionError> impls.
677///
678/// Mutations are applied as storage-level side effects via Writer/L0 buffer.
679/// After mutations, output batches are reconstructed from the modified rows
680/// so downstream operators (RETURN, WITH, subsequent mutations) see the
681/// created/updated variables and properties.
682async fn execute_mutation_inner(
683    input: Arc<dyn ExecutionPlan>,
684    output_schema: SchemaRef,
685    mutation_ctx: Arc<MutationContext>,
686    mutation_kind: MutationKind,
687    partition: usize,
688    task_ctx: Arc<datafusion::execution::TaskContext>,
689    baseline: BaselineMetrics,
690) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
691    // Time the whole eager-barrier body: input collection, mutation writes,
692    // and output batch reconstruction. Timer records on Drop.
693    let _timer = baseline.elapsed_compute().timer();
694    let mutation_label = mutation_kind_label(&mutation_kind);
695
696    // 1. Collect all input batches (eager barrier)
697    let input_stream = input.execute(partition, task_ctx)?;
698    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;
699
700    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
701    tracing::debug!(
702        mutation = mutation_label,
703        batches = input_batches.len(),
704        rows = input_row_count,
705        "Executing mutation"
706    );
707
708    // 2. Convert to rows for mutation helpers (they operate on HashMap rows)
709    let mut rows = batches_to_rows(&input_batches).map_err(|e| {
710        datafusion::error::DataFusionError::Execution(format!(
711            "Failed to convert batches to rows: {e}"
712        ))
713    })?;
714
715    // 3. Apply mutations.
716    // MERGE manages its own writer lock internally (acquires/releases per-row because
717    // execute_merge_match needs to run a read subplan between lock acquisitions).
718    // All other mutations acquire the writer lock once for the entire clause.
719    if let MutationKind::Merge {
720        ref pattern,
721        ref on_match,
722        ref on_create,
723    } = mutation_kind
724    {
725        let exec = &mutation_ctx.executor;
726        let pm = &mutation_ctx.prop_manager;
727        let params = &mutation_ctx.params;
728        let ctx = mutation_ctx.query_ctx.as_ref();
729
730        let mut result_rows = exec
731            .execute_merge(
732                rows,
733                pattern,
734                on_match.as_ref(),
735                on_create.as_ref(),
736                pm,
737                params,
738                ctx,
739                mutation_ctx.tx_l0_override.as_ref(),
740            )
741            .await
742            .map_err(|e| {
743                datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
744            })?;
745
746        tracing::debug!(
747            mutation = mutation_label,
748            input_rows = input_row_count,
749            output_rows = result_rows.len(),
750            "MERGE mutation complete"
751        );
752
753        // Reconstruct output batches from modified rows so downstream operators
754        // (RETURN, WITH, subsequent mutations) see the merged/created variables.
755        sync_all_props_in_maps(&mut result_rows);
756        sync_dotted_columns(&mut result_rows, &output_schema);
757        let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
758            datafusion::error::DataFusionError::Execution(format!(
759                "Failed to reconstruct MERGE batches: {e}"
760            ))
761        })?;
762        let output_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
763        baseline.record_output(output_rows);
764        let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
765        return Ok(futures::stream::iter(results));
766    }
767
768    let tx_l0 = mutation_ctx.tx_l0_override.as_ref();
769    apply_mutations(
770        &mutation_ctx,
771        &mutation_kind,
772        &mut rows,
773        &mutation_ctx.writer,
774        tx_l0,
775    )
776    .await?;
777
778    tracing::debug!(
779        mutation = mutation_label,
780        rows = input_row_count,
781        "Mutation complete"
782    );
783
784    // 4. Reconstruct output batches from modified rows.
785    // Mutations modify the row HashMaps in place (CREATE adds new variable keys,
786    // SET updates property values). Reconstruct batches so downstream operators
787    // (RETURN, WITH, subsequent mutations) see these modifications.
788    sync_all_props_in_maps(&mut rows);
789    sync_dotted_columns(&mut rows, &output_schema);
790    let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
791        datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
792    })?;
793    let output_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
794    baseline.record_output(output_rows);
795    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
796    Ok(futures::stream::iter(results))
797}
798
799/// Collects and classifies DELETE targets into nodes and edges.
800///
801/// Handles `Value::Path`, `Value::Node`, `Value::Edge`, map-encoded paths
802/// (from Arrow round-trip), and raw VID values. When `dedup` is true,
803/// uses HashSets to skip duplicates (needed for non-DETACH DELETE to
804/// handle shared nodes across paths).
805struct DeleteCollector {
806    /// Collected node entries: (vid, labels) pairs.
807    node_entries: Vec<(Vid, Option<Vec<String>>)>,
808    /// Collected edge values to delete.
809    edge_vals: Vec<Value>,
810    /// Deduplication sets (only used when dedup=true).
811    seen_vids: HashSet<u64>,
812    seen_eids: HashSet<u64>,
813    dedup: bool,
814}
815
816impl DeleteCollector {
817    fn new(dedup: bool) -> Self {
818        Self {
819            node_entries: Vec::new(),
820            edge_vals: Vec::new(),
821            seen_vids: HashSet::new(),
822            seen_eids: HashSet::new(),
823            dedup,
824        }
825    }
826
827    fn add(&mut self, val: Value) {
828        if val.is_null() {
829            return;
830        }
831
832        // Try to resolve value as a Path (native or map-encoded).
833        let path = match &val {
834            Value::Path(p) => Some(p.clone()),
835            _ => Path::try_from(&val).ok(),
836        };
837
838        if let Some(path) = path {
839            for edge in &path.edges {
840                if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
841                    self.edge_vals.push(Value::Edge(edge.clone()));
842                }
843            }
844            for node in &path.nodes {
845                self.add_node(node.vid, Some(node.labels.clone()));
846            }
847            return;
848        }
849
850        // Not a path -- try as a node (by VID).
851        if let Ok(vid) = Executor::vid_from_value(&val) {
852            let labels = Executor::extract_labels_from_node(&val);
853            self.add_node(vid, labels);
854            return;
855        }
856
857        // Otherwise treat as an edge value.
858        if matches!(&val, Value::Map(_) | Value::Edge(_)) {
859            self.edge_vals.push(val);
860        }
861    }
862
863    fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
864        if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
865            return;
866        }
867        self.node_entries.push((vid, labels));
868    }
869}
870
871/// Apply mutations to rows using the appropriate executor helper.
872async fn apply_mutations(
873    mutation_ctx: &MutationContext,
874    mutation_kind: &MutationKind,
875    rows: &mut [HashMap<String, Value>],
876    writer: &Writer,
877    tx_l0: Option<&Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
878) -> DFResult<()> {
879    tracing::trace!(
880        mutation = mutation_kind_label(mutation_kind),
881        rows = rows.len(),
882        "Applying mutations"
883    );
884
885    let exec = &mutation_ctx.executor;
886    let pm = &mutation_ctx.prop_manager;
887    let params = &mutation_ctx.params;
888    let ctx = mutation_ctx.query_ctx.as_ref();
889
890    let df_err = |msg: &str, e: anyhow::Error| {
891        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
892    };
893
894    match mutation_kind {
895        MutationKind::Create { pattern } => {
896            for row in rows.iter_mut() {
897                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
898                    .await
899                    .map_err(|e| df_err("CREATE failed", e))?;
900            }
901        }
902        MutationKind::CreateBatch { patterns } => {
903            for row in rows.iter_mut() {
904                for pattern in patterns {
905                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
906                        .await
907                        .map_err(|e| df_err("CREATE failed", e))?;
908                }
909            }
910        }
911        MutationKind::Set { items } => {
912            let prefetch = prefetch_set_targets(items, rows, pm, ctx)
913                .await
914                .map_err(|e| df_err("SET prefetch failed", e))?;
915            for row in rows.iter_mut() {
916                exec.execute_set_items_locked(
917                    items, row, writer, pm, params, ctx, tx_l0, &prefetch,
918                )
919                .await
920                .map_err(|e| df_err("SET failed", e))?;
921            }
922        }
923        MutationKind::Remove { items } => {
924            let prefetch = prefetch_remove_targets(items, rows, pm, ctx)
925                .await
926                .map_err(|e| df_err("REMOVE prefetch failed", e))?;
927            for row in rows.iter_mut() {
928                exec.execute_remove_items_locked(items, row, writer, pm, ctx, tx_l0, &prefetch)
929                    .await
930                    .map_err(|e| df_err("REMOVE failed", e))?;
931            }
932        }
933        MutationKind::Delete { items, detach } => {
934            // Evaluate all DELETE targets and classify into nodes vs edges.
935            let mut collector = DeleteCollector::new(!*detach);
936            for row in rows.iter() {
937                for expr in items {
938                    let val = exec
939                        .evaluate_expr(expr, row, pm, params, ctx)
940                        .await
941                        .map_err(|e| df_err("DELETE eval failed", e))?;
942                    collector.add(val);
943                }
944            }
945
946            // Delete edges before nodes so non-detach DELETE satisfies constraints.
947            for val in &collector.edge_vals {
948                exec.execute_delete_item_locked(val, false, writer, tx_l0)
949                    .await
950                    .map_err(|e| df_err("DELETE edge failed", e))?;
951            }
952
953            if *detach {
954                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
955                    collector.node_entries.into_iter().unzip();
956                exec.batch_detach_delete_vertices(&vids, labels, writer, tx_l0)
957                    .await
958                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
959            } else {
960                // Non-detach: one batched edge-dependency check across all
961                // targets (Phase C — collapses N per-VID subgraph loads to
962                // one), then the per-VID writer.delete_vertex calls
963                // (cheap; no scan).
964                let vids: Vec<Vid> = collector.node_entries.iter().map(|(v, _)| *v).collect();
965                exec.batch_check_vertices_have_no_edges(&vids, writer, tx_l0)
966                    .await
967                    .map_err(|e| df_err("DELETE check failed", e))?;
968                for (vid, labels) in &collector.node_entries {
969                    writer
970                        .delete_vertex(*vid, labels.clone(), tx_l0)
971                        .await
972                        .map_err(|e| df_err("DELETE node failed", e))?;
973                }
974            }
975        }
976        MutationKind::Merge { .. } => {
977            // MERGE is handled before the writer lock in execute_mutation_inner.
978            // This branch is unreachable but required for exhaustive matching.
979            unreachable!("MERGE mutations are handled before apply_mutations is called");
980        }
981    }
982
983    Ok(())
984}
985
986/// Extract variable names introduced by a CREATE/MERGE pattern.
987///
988/// Walks the pattern tree and collects all node and relationship variable names.
989/// Used to compute extended output schemas for CREATE/MERGE operators.
990pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
991    let mut vars = Vec::new();
992    for path in &pattern.paths {
993        if let Some(ref v) = path.variable {
994            vars.push(v.clone());
995        }
996        for element in &path.elements {
997            match element {
998                PatternElement::Node(n) => {
999                    if let Some(ref v) = n.variable {
1000                        vars.push(v.clone());
1001                    }
1002                }
1003                PatternElement::Relationship(r) => {
1004                    if let Some(ref v) = r.variable {
1005                        vars.push(v.clone());
1006                    }
1007                }
1008                PatternElement::Parenthesized { pattern, .. } => {
1009                    // Recurse into parenthesized sub-patterns
1010                    let sub = Pattern {
1011                        paths: vec![pattern.as_ref().clone()],
1012                    };
1013                    vars.extend(pattern_variable_names(&sub));
1014                }
1015            }
1016        }
1017    }
1018    vars
1019}
1020
1021/// Normalize a schema for mutation output.
1022///
1023/// After mutation processing, entity values (nodes/edges) are stored as
1024/// `Value::Map` in row HashMaps. The input schema may have Struct columns
1025/// for these entities, but `rows_to_batches()` encodes Map values as
1026/// cv_encoded LargeBinary. This function converts Struct and Binary entity
1027/// columns to cv_encoded LargeBinary to match the actual output format.
1028fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
1029    use arrow_schema::{Field, Schema};
1030
1031    // Detect any field whose type round-trips through `rows_to_batches` as
1032    // CV-encoded LargeBinary, so the declared schema must match or
1033    // `RecordBatch::try_new` rejects the batch with "column types must
1034    // match schema types".
1035    //   * `Struct(_)` — bare graph entities (Node/Edge).
1036    //   * `List<Struct>` / `LargeList<Struct>` — the VLP-bound edge-list
1037    //     `r` from `MATCH (a)-[r*1..2]->(b)`.
1038    //   * `List<T>` for any T that `arrow_convert::values_to_array` doesn't
1039    //     know how to construct from `Value::List` (today: everything
1040    //     except `List<Utf8>`). RecursiveCTE-side WHERE-IN aggregations
1041    //     emit `List<Int64>` here, for example.
1042    fn needs_norm(dt: &DataType) -> bool {
1043        match dt {
1044            DataType::Struct(_) => true,
1045            DataType::List(inner) | DataType::LargeList(inner) => {
1046                !matches!(inner.data_type(), DataType::Utf8)
1047            }
1048            _ => false,
1049        }
1050    }
1051
1052    if !schema.fields().iter().any(|f| needs_norm(f.data_type())) {
1053        return schema.clone();
1054    }
1055
1056    let fields: Vec<Arc<Field>> = schema
1057        .fields()
1058        .iter()
1059        .map(|field| {
1060            if needs_norm(field.data_type()) {
1061                let mut metadata = field.metadata().clone();
1062                metadata.insert("cv_encoded".to_string(), "true".to_string());
1063                Arc::new(
1064                    Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
1065                )
1066            } else {
1067                field.clone()
1068            }
1069        })
1070        .collect();
1071
1072    Arc::new(Schema::new(fields))
1073}
1074
1075/// Compute an extended output schema that includes columns for newly created variables.
1076///
1077/// Extracts variables from CREATE/MERGE patterns and adds:
1078/// - Bare cv_encoded LargeBinary column for each variable
1079/// - System dotted columns based on element type:
1080///   - Node → `{var}._vid` (UInt64), `{var}._labels` (LargeBinary cv_encoded)
1081///   - Edge → `{var}._eid` (UInt64), `{var}._type` (LargeBinary cv_encoded)
1082///   - Path → bare column only (no system columns)
1083///
1084/// Property access on mutation variables uses dynamic `index()` UDF extraction,
1085/// so property columns are NOT added here.
1086///
1087/// Also normalizes existing Struct entity columns to cv_encoded LargeBinary,
1088/// since after mutation processing, entities are stored as Maps in row HashMaps.
1089pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
1090    use arrow_schema::{Field, Schema};
1091
1092    // First normalize existing columns
1093    let normalized = normalize_mutation_schema(input_schema);
1094
1095    let existing_names: HashSet<&str> = normalized
1096        .fields()
1097        .iter()
1098        .map(|f| f.name().as_str())
1099        .collect();
1100
1101    let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
1102    let mut added: HashSet<String> = HashSet::new();
1103
1104    fn cv_metadata() -> std::collections::HashMap<String, String> {
1105        let mut m = std::collections::HashMap::new();
1106        m.insert("cv_encoded".to_string(), "true".to_string());
1107        m
1108    }
1109
1110    fn add_bare_column(
1111        var: &str,
1112        fields: &mut Vec<Arc<arrow_schema::Field>>,
1113        existing: &HashSet<&str>,
1114        added: &mut HashSet<String>,
1115    ) -> bool {
1116        if existing.contains(var) || added.contains(var) {
1117            return false;
1118        }
1119        added.insert(var.to_string());
1120        fields.push(Arc::new(
1121            Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
1122        ));
1123        true
1124    }
1125
1126    for pattern in patterns {
1127        for path in &pattern.paths {
1128            // Path variable (e.g., `p` in `MERGE p = (a)-[r]->(b)`)
1129            if let Some(ref var) = path.variable {
1130                add_bare_column(var, &mut fields, &existing_names, &mut added);
1131            }
1132            for element in &path.elements {
1133                match element {
1134                    PatternElement::Node(n) => {
1135                        if let Some(ref var) = n.variable
1136                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
1137                        {
1138                            // Node system columns for id()/labels()
1139                            fields.push(Arc::new(Field::new(
1140                                format!("{var}._vid"),
1141                                DataType::UInt64,
1142                                true,
1143                            )));
1144                            fields.push(Arc::new(
1145                                Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
1146                                    .with_metadata(cv_metadata()),
1147                            ));
1148                        }
1149                    }
1150                    PatternElement::Relationship(r) => {
1151                        if let Some(ref var) = r.variable
1152                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
1153                        {
1154                            // Edge system columns for id()/type()
1155                            fields.push(Arc::new(Field::new(
1156                                format!("{var}._eid"),
1157                                DataType::UInt64,
1158                                true,
1159                            )));
1160                            fields.push(Arc::new(
1161                                Field::new(format!("{var}._type"), DataType::LargeBinary, true)
1162                                    .with_metadata(cv_metadata()),
1163                            ));
1164                        }
1165                    }
1166                    PatternElement::Parenthesized { pattern, .. } => {
1167                        // Recurse into sub-patterns. Pass current fields as
1168                        // input so the recursive call's `existing_names` check
1169                        // prevents duplicates for variables already added.
1170                        let sub = Pattern {
1171                            paths: vec![pattern.as_ref().clone()],
1172                        };
1173                        let sub_schema = extended_schema_for_new_vars(
1174                            &Arc::new(Schema::new(fields.clone())),
1175                            &[sub],
1176                        );
1177                        // Sync `added` from new fields to prevent duplicates
1178                        // if a later pattern element reuses a variable.
1179                        for field in sub_schema.fields() {
1180                            added.insert(field.name().clone());
1181                        }
1182                        fields = sub_schema.fields().to_vec();
1183                    }
1184                }
1185            }
1186        }
1187    }
1188
1189    Arc::new(Schema::new(fields))
1190}
1191
1192/// Human-readable label for a MutationKind (used in tracing spans).
1193fn mutation_kind_label(kind: &MutationKind) -> &'static str {
1194    match kind {
1195        MutationKind::Create { .. } => "CREATE",
1196        MutationKind::CreateBatch { .. } => "CREATE_BATCH",
1197        MutationKind::Set { .. } => "SET",
1198        MutationKind::Remove { .. } => "REMOVE",
1199        MutationKind::Delete { .. } => "DELETE",
1200        MutationKind::Merge { .. } => "MERGE",
1201    }
1202}
1203
1204// ============================================================================
1205// Unified MutationExec: single ExecutionPlan for all mutation kinds
1206// ============================================================================
1207
/// Unified DataFusion `ExecutionPlan` for all Cypher mutation clauses
/// (CREATE, SET, REMOVE, DELETE, MERGE).
///
/// Instead of near-identical ExecutionPlan structs per clause, this single
/// struct holds a [`MutationKind`] discriminant and delegates to the shared
/// [`execute_mutation_stream`] implementation. Typed constructors in
/// `mutation_create`, `mutation_set`, `mutation_remove`, and `mutation_delete`
/// provide ergonomic construction with the correct kind; CREATE and MERGE
/// plans can also be built with [`new_create_exec`] / [`new_merge_exec`].
#[derive(Debug)]
pub struct MutationExec {
    /// Child plan producing input rows.
    input: Arc<dyn ExecutionPlan>,

    /// The kind of mutation to apply.
    kind: MutationKind,

    /// Display name for EXPLAIN output (also returned by `ExecutionPlan::name`).
    display_name: &'static str,

    /// Shared mutation context with executor and writer.
    mutation_ctx: Arc<MutationContext>,

    /// Output schema. `new` uses the input schema with entity/list columns
    /// normalized to cv_encoded LargeBinary; `new_with_schema` (used for
    /// CREATE/MERGE) takes an explicit schema that may add columns for newly
    /// bound variables. The mutations themselves are side effects.
    schema: SchemaRef,

    /// Plan properties for DataFusion optimizer, derived from `schema`.
    properties: Arc<PlanProperties>,

    /// Metrics; `execute` creates per-partition `BaselineMetrics` from this set.
    metrics: ExecutionPlanMetricsSet,
}
1239
1240impl MutationExec {
1241    /// Create a new `MutationExec` with the given kind.
1242    ///
1243    /// The output schema is derived from the input schema with Struct entity
1244    /// columns normalized to cv_encoded LargeBinary. For mutations that
1245    /// introduce new variables (CREATE, MERGE), use [`Self::new_with_schema`] instead.
1246    pub fn new(
1247        input: Arc<dyn ExecutionPlan>,
1248        kind: MutationKind,
1249        display_name: &'static str,
1250        mutation_ctx: Arc<MutationContext>,
1251    ) -> Self {
1252        let schema = normalize_mutation_schema(&input.schema());
1253        let properties = compute_plan_properties(schema.clone());
1254        Self {
1255            input,
1256            kind,
1257            display_name,
1258            mutation_ctx,
1259            schema,
1260            properties,
1261            metrics: ExecutionPlanMetricsSet::new(),
1262        }
1263    }
1264
1265    /// Create a new `MutationExec` with an explicit output schema.
1266    ///
1267    /// Used by CREATE and MERGE operators whose output includes newly created
1268    /// variables not present in the input schema.
1269    pub fn new_with_schema(
1270        input: Arc<dyn ExecutionPlan>,
1271        kind: MutationKind,
1272        display_name: &'static str,
1273        mutation_ctx: Arc<MutationContext>,
1274        output_schema: SchemaRef,
1275    ) -> Self {
1276        let properties = compute_plan_properties(output_schema.clone());
1277        Self {
1278            input,
1279            kind,
1280            display_name,
1281            mutation_ctx,
1282            schema: output_schema,
1283            properties,
1284            metrics: ExecutionPlanMetricsSet::new(),
1285        }
1286    }
1287}
1288
1289impl DisplayAs for MutationExec {
1290    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1291        if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1292            write!(f, "{} [DETACH]", self.display_name)
1293        } else {
1294            write!(f, "{}", self.display_name)
1295        }
1296    }
1297}
1298
1299impl ExecutionPlan for MutationExec {
1300    fn name(&self) -> &str {
1301        self.display_name
1302    }
1303
1304    fn as_any(&self) -> &dyn Any {
1305        self
1306    }
1307
1308    fn schema(&self) -> SchemaRef {
1309        self.schema.clone()
1310    }
1311
1312    fn properties(&self) -> &Arc<PlanProperties> {
1313        &self.properties
1314    }
1315
1316    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1317        vec![&self.input]
1318    }
1319
1320    fn with_new_children(
1321        self: Arc<Self>,
1322        children: Vec<Arc<dyn ExecutionPlan>>,
1323    ) -> DFResult<Arc<dyn ExecutionPlan>> {
1324        if children.len() != 1 {
1325            return Err(datafusion::error::DataFusionError::Plan(format!(
1326                "{} requires exactly one child",
1327                self.display_name,
1328            )));
1329        }
1330        Ok(Arc::new(MutationExec::new_with_schema(
1331            children[0].clone(),
1332            self.kind.clone(),
1333            self.display_name,
1334            self.mutation_ctx.clone(),
1335            self.schema.clone(),
1336        )))
1337    }
1338
1339    fn execute(
1340        &self,
1341        partition: usize,
1342        context: Arc<TaskContext>,
1343    ) -> DFResult<SendableRecordBatchStream> {
1344        let baseline = BaselineMetrics::new(&self.metrics, partition);
1345        execute_mutation_stream(
1346            self.input.clone(),
1347            self.schema.clone(),
1348            self.mutation_ctx.clone(),
1349            self.kind.clone(),
1350            partition,
1351            context,
1352            baseline,
1353        )
1354    }
1355
1356    fn metrics(&self) -> Option<MetricsSet> {
1357        Some(self.metrics.clone_inner())
1358    }
1359}
1360
1361/// Create a new `MutationExec` configured for a CREATE clause.
1362///
1363/// Computes an extended output schema that includes LargeBinary cv_encoded
1364/// columns for any variables introduced by the pattern that are not already
1365/// in the input schema.
1366pub fn new_create_exec(
1367    input: Arc<dyn ExecutionPlan>,
1368    pattern: Pattern,
1369    mutation_ctx: Arc<MutationContext>,
1370) -> MutationExec {
1371    let output_schema =
1372        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1373    MutationExec::new_with_schema(
1374        input,
1375        MutationKind::Create { pattern },
1376        "MutationCreateExec",
1377        mutation_ctx,
1378        output_schema,
1379    )
1380}
1381
1382/// Create a new `MutationExec` configured for a MERGE clause.
1383///
1384/// Computes an extended output schema that includes LargeBinary cv_encoded
1385/// columns for any variables introduced by the pattern that are not already
1386/// in the input schema.
1387pub fn new_merge_exec(
1388    input: Arc<dyn ExecutionPlan>,
1389    pattern: Pattern,
1390    on_match: Option<SetClause>,
1391    on_create: Option<SetClause>,
1392    mutation_ctx: Arc<MutationContext>,
1393) -> MutationExec {
1394    let output_schema =
1395        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1396    MutationExec::new_with_schema(
1397        input,
1398        MutationKind::Merge {
1399            pattern,
1400            on_match,
1401            on_create,
1402        },
1403        "MutationMergeExec",
1404        mutation_ctx,
1405        output_schema,
1406    )
1407}
1408
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    #[test]
    fn test_batches_to_rows_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("Alice"), Some("Bob")])),
                Arc::new(Int64Array::from(vec![Some(30), Some(25)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    #[test]
    fn test_rows_to_batches_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let rows = vec![
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Alice".into()));
                m.insert("age".to_string(), Value::Int(30));
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Bob".into()));
                m.insert("age".to_string(), Value::Int(25));
                m
            },
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);

        // Decode the batch again so the cell values are checked, not just the shape.
        let decoded = batches_to_rows(&batches).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(decoded[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(decoded[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(decoded[1].get("age"), Some(&Value::Int(25)));
    }

    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![Some("hello")])),
                Arc::new(Int64Array::from(vec![Some(42)])),
                Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
                Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
            ],
        )
        .unwrap();

        // Roundtrip: batches → rows → batches
        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        // Verify roundtrip fidelity
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(
            roundtrip_rows[0].get("s"),
            Some(&Value::String("hello".into()))
        );
        assert_eq!(roundtrip_rows[0].get("i"), Some(&Value::Int(42)));
        assert_eq!(roundtrip_rows[0].get("b"), Some(&Value::Bool(true)));
        // Float comparison
        if let Some(Value::Float(f)) = roundtrip_rows[0].get("f") {
            assert!((*f - 3.125).abs() < 1e-10);
        } else {
            panic!("Expected float value");
        }
    }

    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        use std::collections::HashMap as StdHashMap;

        // Create a schema with a cv_encoded LargeBinary column (entity column)
        let mut metadata = StdHashMap::new();
        metadata.insert("cv_encoded".to_string(), "true".to_string());
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        // Create a node-like Map value
        let mut node_map = HashMap::new();
        node_map.insert("name".to_string(), Value::String("Alice".into()));
        node_map.insert("_vid".to_string(), Value::Int(1));
        let map_val = Value::Map(node_map);

        // Encode to CypherValue bytes
        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
                encoded.as_slice(),
            )]))],
        )
        .unwrap();

        // Roundtrip
        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        // The decoded value should be a Map
        let val = rows[0].get("n").unwrap();
        assert!(matches!(val, Value::Map(_)));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        // Verify we can decode it back, and that re-encoding did not alter the value.
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(roundtrip_rows[0].get("n"), rows[0].get("n"));
    }

    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }

    #[test]
    fn test_normalize_mutation_schema_rewrites_entity_columns() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("x", DataType::Int64, true),
            Field::new(
                "n",
                DataType::Struct(vec![Field::new("_vid", DataType::UInt64, true)].into()),
                true,
            ),
            Field::new(
                "tags",
                DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
                true,
            ),
            Field::new(
                "ids",
                DataType::List(Arc::new(Field::new("item", DataType::Int64, true))),
                true,
            ),
        ]));

        let out = normalize_mutation_schema(&schema);
        assert_eq!(out.fields().len(), 4);
        assert_eq!(out.field(0).data_type(), &DataType::Int64);
        assert_eq!(out.field(1).data_type(), &DataType::LargeBinary);
        assert_eq!(
            out.field(1).metadata().get("cv_encoded").map(String::as_str),
            Some("true")
        );
        // List<Utf8> is natively constructible and stays as-is.
        assert_eq!(out.field(2).data_type(), schema.field(2).data_type());
        assert_eq!(out.field(3).data_type(), &DataType::LargeBinary);
    }

    #[test]
    fn test_normalize_mutation_schema_passthrough() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));
        let out = normalize_mutation_schema(&schema);
        assert!(Arc::ptr_eq(&out, &schema));
    }
}