Skip to main content

uni_query/query/df_graph/
mutation_common.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Common infrastructure for DataFusion mutation operators (CREATE, SET, REMOVE, DELETE).
5//!
6//! Provides:
7//! - [`MutationContext`]: Shared context for mutation operators containing executor, writer, etc.
8//! - [`batches_to_rows`]: Convert RecordBatches to row-based HashMaps (batch→row direction).
9//! - [`rows_to_batches`]: Convert row-based HashMaps back to RecordBatches (row→batch direction).
10//! - [`MutationExec`]: Eager-barrier RecordBatchStream that collects all input, applies
11//!   mutations via Writer, and yields output batches.
12
13use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use uni_common::core::id::Vid;
30use uni_common::{Path, Value};
31use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
39/// Shared context for mutation operators.
40///
41/// Contains all resources needed to execute write operations from within
42/// DataFusion ExecutionPlan operators. The Executor is `Clone` with all
43/// Arc-wrapped fields, so cloning it is cheap.
44#[derive(Clone)]
45pub struct MutationContext {
46    /// The query executor (cheap clone, all Arc fields).
47    pub executor: Executor,
48
49    /// Writer for graph mutations (vertices, edges, properties).
50    pub writer: Arc<RwLock<Writer>>,
51
52    /// Property manager for lazy-loading vertex/edge properties.
53    pub prop_manager: Arc<PropertyManager>,
54
55    /// Query parameters (e.g., `$param` references in Cypher).
56    pub params: HashMap<String, Value>,
57
58    /// Query context for L0 buffer visibility.
59    pub query_ctx: Option<uni_store::QueryContext>,
60
61    /// When set, mutations are routed to this private L0 buffer instead of
62    /// the global L0. Passed explicitly to Writer methods during mutation execution.
63    pub tx_l0_override: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
64}
65
66impl std::fmt::Debug for MutationContext {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("MutationContext")
69            .field("has_writer", &true)
70            .field("has_prop_manager", &true)
71            .field("params_count", &self.params.len())
72            .field("has_query_ctx", &self.query_ctx.is_some())
73            .finish()
74    }
75}
76
77/// The kind of mutation to apply per row.
78#[derive(Debug, Clone)]
79pub enum MutationKind {
80    /// CREATE clause: create nodes/edges per the pattern.
81    Create { pattern: Pattern },
82
83    /// CREATE with multiple patterns (batched CREATE).
84    CreateBatch { patterns: Vec<Pattern> },
85
86    /// SET clause: update properties/labels.
87    Set { items: Vec<SetItem> },
88
89    /// REMOVE clause: remove properties/labels.
90    Remove { items: Vec<RemoveItem> },
91
92    /// DELETE clause: delete nodes/edges.
93    Delete { items: Vec<Expr>, detach: bool },
94
95    /// MERGE clause: match-or-create with optional ON MATCH/ON CREATE actions.
96    Merge {
97        pattern: Pattern,
98        on_match: Option<SetClause>,
99        on_create: Option<SetClause>,
100    },
101}
102
103/// Convert RecordBatches to row-based HashMaps for mutation processing.
104///
105/// Handles special metadata on fields:
106/// - `cv_encoded=true`: Parse string value as JSON to restore original type
107/// - DateTime/Time struct types: Decode to temporal values
108///
109/// NOTE: This does NOT merge system fields (like `n._vid`) into bare variable
110/// maps. The raw column names are preserved so that `rows_to_batches` can
111/// reconstruct the RecordBatch with the same schema. System field merging
112/// happens later in `Executor::record_batches_to_rows()` for user-facing output.
113pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
114    let mut rows = Vec::new();
115
116    for batch in batches {
117        let num_rows = batch.num_rows();
118        let schema = batch.schema();
119
120        for row_idx in 0..num_rows {
121            let mut row = HashMap::new();
122
123            for (col_idx, field) in schema.fields().iter().enumerate() {
124                let column = batch.column(col_idx);
125                // Infer Uni DataType from Arrow type for DateTime/Time struct decoding
126                let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
127                    Some(&uni_common::DataType::DateTime)
128                } else if uni_common::core::schema::is_time_struct(field.data_type()) {
129                    Some(&uni_common::DataType::Time)
130                } else {
131                    None
132                };
133                let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);
134
135                // Check if this field contains JSON-encoded values (e.g., from UNWIND)
136                // Parse JSON string to restore the original type
137                if field
138                    .metadata()
139                    .get("cv_encoded")
140                    .is_some_and(|v| v == "true")
141                    && let Value::String(s) = &value
142                    && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
143                {
144                    value = Value::from(parsed);
145                }
146
147                row.insert(field.name().clone(), value);
148            }
149
150            // Also merge system fields into bare variable maps for the write helpers.
151            // The write helpers (execute_set_items_locked, etc.) expect variables
152            // as bare Maps with _vid/_labels inside. We do this AFTER preserving
153            // the raw keys so rows_to_batches can reconstruct the schema.
154            merge_system_fields_for_write(&mut row);
155
156            rows.push(row);
157        }
158    }
159
160    Ok(rows)
161}
162
163/// After mutations, sync `_all_props` within bare variable Maps from their direct property keys.
164///
165/// SET/REMOVE modify direct property keys in the bare Map (e.g., `row["n"]["name"] = "Bob"`)
166/// but the `_all_props` sub-map retains its stale pre-mutation values. The result normalizer
167/// and property UDFs (keys(), properties()) read from `_all_props`, so it must be kept in sync.
168///
169/// This must be called BEFORE `sync_dotted_columns` so that the dotted `n._all_props` column
170/// also gets the updated value.
171fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
172    for row in rows {
173        let map_keys: Vec<String> = row
174            .keys()
175            .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
176            .cloned()
177            .collect();
178
179        for key in map_keys {
180            if let Some(Value::Map(map)) = row.get_mut(&key)
181                && map.contains_key("_all_props")
182            {
183                // Collect non-internal property keys and their values
184                let updates: Vec<(String, Value)> = map
185                    .iter()
186                    .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
187                    .map(|(k, v)| (k.clone(), v.clone()))
188                    .collect();
189
190                if !updates.is_empty()
191                    && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
192                {
193                    for (k, v) in updates {
194                        all_props.insert(k, v);
195                    }
196                }
197            }
198        }
199    }
200}
201
202/// After mutations, sync dotted property columns from bare variable Maps.
203///
204/// SET/REMOVE modify the bare Map (e.g., `row["n"]["name"] = "Bob"`) but the
205/// dotted column (`row["n.name"]`) retains its stale pre-mutation value.
206/// This step overwrites dotted columns from the Map so `rows_to_batches()`
207/// produces correct output. Also handles newly created variables from CREATE/MERGE
208/// by inserting dotted columns that didn't exist in the input.
209fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
210    for row in rows {
211        for field in schema.fields() {
212            let name = field.name();
213            if let Some(dot_pos) = name.find('.') {
214                let var_name = &name[..dot_pos];
215                let prop_name = &name[dot_pos + 1..];
216                if let Some(Value::Map(map)) = row.get(var_name) {
217                    let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
218                    row.insert(name.clone(), val);
219                }
220            }
221        }
222    }
223}
224
225/// Normalize edge system field names in a map: `_src_vid` -> `_src`, `_dst_vid` -> `_dst`.
226///
227/// The write executor expects `_src`/`_dst` but DataFusion traverse emits `_src_vid`/`_dst_vid`.
228fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
229    if let Some(val) = map.remove("_src_vid") {
230        map.entry("_src".to_string()).or_insert(val);
231    }
232    if let Some(val) = map.remove("_dst_vid") {
233        map.entry("_dst".to_string()).or_insert(val);
234    }
235}
236
237/// Merge system fields into bare variable maps for write helper consumption.
238///
239/// The write helpers expect variables like `n` to be a Map containing `_vid`, `_labels`, etc.
240/// This merges dotted columns (like `n._vid`, `n._labels`) into the variable Map,
241/// while KEEPING the dotted columns in the row so `rows_to_batches` still works.
242fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
243    // Vertex system fields (overwrite into the bare map) and edge system fields
244    // (insert only if absent) that should be copied from dotted columns.
245    const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
246    const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
247
248    // Collect all variable names that have dotted columns (var.field).
249    let dotted_vars: HashSet<String> = row
250        .keys()
251        .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
252        .collect();
253
254    // For each variable with dotted columns, ensure a bare Map exists.
255    // If the variable is only represented via dotted columns (e.g., edge from
256    // TraverseMainByType), assemble a Map from those columns.
257    for var in &dotted_vars {
258        if !row.contains_key(var) {
259            let prefix = format!("{var}.");
260            let mut map: HashMap<String, Value> = row
261                .iter()
262                .filter_map(|(k, v)| {
263                    k.strip_prefix(prefix.as_str())
264                        .map(|field| (field.to_string(), v.clone()))
265                })
266                .collect();
267            normalize_edge_field_names(&mut map);
268            if !map.is_empty() {
269                row.insert(var.clone(), Value::Map(map));
270            }
271        }
272    }
273
274    // Merge system fields from dotted columns into bare Maps and normalize edge names.
275    // Single pass: vertex fields overwrite, edge fields insert-if-absent, then normalize.
276    let bare_vars: Vec<String> = row
277        .keys()
278        .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
279        .cloned()
280        .collect();
281
282    for var in &bare_vars {
283        // Collect dotted values to merge (avoids borrowing row mutably while reading)
284        let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
285            .iter()
286            .filter_map(|&field| {
287                row.get(&format!("{var}.{field}"))
288                    .cloned()
289                    .map(|v| (field, v))
290            })
291            .collect();
292        let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
293            .iter()
294            .filter_map(|&field| {
295                row.get(&format!("{var}.{field}"))
296                    .cloned()
297                    .map(|v| (field, v))
298            })
299            .collect();
300
301        if let Some(Value::Map(map)) = row.get_mut(var) {
302            for (field, v) in vertex_vals {
303                map.insert(field.to_string(), v);
304            }
305            for (field, v) in edge_vals {
306                map.entry(field.to_string()).or_insert(v);
307            }
308            normalize_edge_field_names(map);
309        }
310    }
311}
312
313/// Convert row-based HashMaps back to RecordBatches.
314///
315/// This is the inverse of `batches_to_rows`. Schema-driven: iterates over the
316/// output schema fields and extracts named values from each row HashMap.
317///
318/// - Entity columns (LargeBinary with `cv_encoded=true`): serialize Map/Node/Edge values
319///   to CypherValue binary encoding.
320/// - Scalar columns: use `arrow_convert::values_to_array()` for type-appropriate conversion.
321pub fn rows_to_batches(
322    rows: &[HashMap<String, Value>],
323    schema: &SchemaRef,
324) -> Result<Vec<RecordBatch>> {
325    if rows.is_empty() {
326        // Handle empty schema case (no fields)
327        let batch = if schema.fields().is_empty() {
328            let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
329            RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
330        } else {
331            RecordBatch::new_empty(schema.clone())
332        };
333        return Ok(vec![batch]);
334    }
335
336    if schema.fields().is_empty() {
337        // Schema has no fields but there ARE rows. Preserve the row count so that
338        // downstream operators (chained mutations, aggregations) see the correct
339        // number of rows. A RecordBatch with 0 columns can still carry a row count.
340        let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
341        let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
342        return Ok(vec![batch]);
343    }
344
345    // Build columns from rows using schema
346    let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
347
348    for field in schema.fields() {
349        let name = field.name();
350        let values: Vec<Value> = rows
351            .iter()
352            .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
353            .collect();
354
355        let array = value_column_to_arrow(&values, field.data_type(), field)?;
356        columns.push(array);
357    }
358
359    let batch = RecordBatch::try_new(schema.clone(), columns)?;
360    Ok(vec![batch])
361}
362
363/// Convert a column of Values to an Arrow array, handling entity-encoded columns.
364fn value_column_to_arrow(
365    values: &[Value],
366    arrow_type: &DataType,
367    field: &arrow_schema::Field,
368) -> Result<arrow_array::ArrayRef> {
369    let is_cv_encoded = field
370        .metadata()
371        .get("cv_encoded")
372        .is_some_and(|v| v == "true");
373
374    if *arrow_type == DataType::LargeBinary || is_cv_encoded {
375        Ok(encode_as_large_binary(values))
376    } else if *arrow_type == DataType::Binary {
377        // Binary columns (e.g., CRDT payloads): encode as Binary, not LargeBinary
378        Ok(encode_as_binary(values))
379    } else {
380        // Use arrow_convert for scalar types, falling back to CypherValue encoding
381        arrow_convert::values_to_array(values, arrow_type)
382            .or_else(|_| Ok(encode_as_large_binary(values)))
383    }
384}
385
386/// Encode values as CypherValue blobs using the given builder type.
387macro_rules! encode_as_cv {
388    ($builder_ty:ty, $values:expr) => {{
389        let values = $values;
390        let mut builder = <$builder_ty>::with_capacity(values.len(), values.len() * 64);
391        for v in values {
392            if v.is_null() {
393                builder.append_null();
394            } else {
395                let bytes = uni_common::cypher_value_codec::encode(v);
396                builder.append_value(&bytes);
397            }
398        }
399        Arc::new(builder.finish()) as arrow_array::ArrayRef
400    }};
401}
402
403/// Encode values as CypherValue Binary blobs.
404fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
405    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
406}
407
408/// Encode values as CypherValue LargeBinary blobs.
409fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
410    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
411}
412
413/// Execute a mutation stream: collect all input batches, apply mutations, yield output.
414///
415/// This is the core logic shared by all mutation operators. It implements the
416/// "eager barrier" pattern:
417/// 1. Pull ALL input batches to completion
418/// 2. Convert to rows
419/// 3. Acquire writer lock once for the entire clause
420/// 4. Apply mutations per row
421/// 5. Convert back to batches
422/// 6. Yield output
423pub fn execute_mutation_stream(
424    input: Arc<dyn ExecutionPlan>,
425    output_schema: SchemaRef,
426    mutation_ctx: Arc<MutationContext>,
427    mutation_kind: MutationKind,
428    partition: usize,
429    task_ctx: Arc<datafusion::execution::TaskContext>,
430) -> DFResult<SendableRecordBatchStream> {
431    if mutation_ctx.query_ctx.is_none() {
432        tracing::warn!(
433            "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
434        );
435    }
436
437    let stream = futures::stream::once(execute_mutation_inner(
438        input,
439        output_schema.clone(),
440        mutation_ctx,
441        mutation_kind,
442        partition,
443        task_ctx,
444    ))
445    .try_flatten();
446
447    Ok(Box::pin(RecordBatchStreamAdapter::new(
448        output_schema,
449        stream,
450    )))
451}
452
453/// Inner async function for mutation execution.
454///
455/// Separated from the stream combinator to provide explicit return type
456/// annotation, avoiding type inference issues with multiple From<DataFusionError> impls.
457///
458/// Mutations are applied as storage-level side effects via Writer/L0 buffer.
459/// After mutations, output batches are reconstructed from the modified rows
460/// so downstream operators (RETURN, WITH, subsequent mutations) see the
461/// created/updated variables and properties.
462async fn execute_mutation_inner(
463    input: Arc<dyn ExecutionPlan>,
464    output_schema: SchemaRef,
465    mutation_ctx: Arc<MutationContext>,
466    mutation_kind: MutationKind,
467    partition: usize,
468    task_ctx: Arc<datafusion::execution::TaskContext>,
469) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
470    let mutation_label = mutation_kind_label(&mutation_kind);
471
472    // 1. Collect all input batches (eager barrier)
473    let input_stream = input.execute(partition, task_ctx)?;
474    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;
475
476    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
477    tracing::debug!(
478        mutation = mutation_label,
479        batches = input_batches.len(),
480        rows = input_row_count,
481        "Executing mutation"
482    );
483
484    // 2. Convert to rows for mutation helpers (they operate on HashMap rows)
485    let mut rows = batches_to_rows(&input_batches).map_err(|e| {
486        datafusion::error::DataFusionError::Execution(format!(
487            "Failed to convert batches to rows: {e}"
488        ))
489    })?;
490
491    // 3. Apply mutations.
492    // MERGE manages its own writer lock internally (acquires/releases per-row because
493    // execute_merge_match needs to run a read subplan between lock acquisitions).
494    // All other mutations acquire the writer lock once for the entire clause.
495    if let MutationKind::Merge {
496        ref pattern,
497        ref on_match,
498        ref on_create,
499    } = mutation_kind
500    {
501        let exec = &mutation_ctx.executor;
502        let pm = &mutation_ctx.prop_manager;
503        let params = &mutation_ctx.params;
504        let ctx = mutation_ctx.query_ctx.as_ref();
505
506        let mut result_rows = exec
507            .execute_merge(
508                rows,
509                pattern,
510                on_match.as_ref(),
511                on_create.as_ref(),
512                pm,
513                params,
514                ctx,
515                mutation_ctx.tx_l0_override.as_ref(),
516            )
517            .await
518            .map_err(|e| {
519                datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
520            })?;
521
522        tracing::debug!(
523            mutation = mutation_label,
524            input_rows = input_row_count,
525            output_rows = result_rows.len(),
526            "MERGE mutation complete"
527        );
528
529        // Reconstruct output batches from modified rows so downstream operators
530        // (RETURN, WITH, subsequent mutations) see the merged/created variables.
531        sync_all_props_in_maps(&mut result_rows);
532        sync_dotted_columns(&mut result_rows, &output_schema);
533        let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
534            datafusion::error::DataFusionError::Execution(format!(
535                "Failed to reconstruct MERGE batches: {e}"
536            ))
537        })?;
538        let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
539        return Ok(futures::stream::iter(results));
540    }
541
542    let mut writer = mutation_ctx.writer.write().await;
543    let tx_l0 = mutation_ctx.tx_l0_override.as_ref();
544    let result =
545        apply_mutations(&mutation_ctx, &mutation_kind, &mut rows, &mut writer, tx_l0).await;
546    drop(writer);
547    result?;
548
549    tracing::debug!(
550        mutation = mutation_label,
551        rows = input_row_count,
552        "Mutation complete"
553    );
554
555    // 4. Reconstruct output batches from modified rows.
556    // Mutations modify the row HashMaps in place (CREATE adds new variable keys,
557    // SET updates property values). Reconstruct batches so downstream operators
558    // (RETURN, WITH, subsequent mutations) see these modifications.
559    sync_all_props_in_maps(&mut rows);
560    sync_dotted_columns(&mut rows, &output_schema);
561    let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
562        datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
563    })?;
564    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
565    Ok(futures::stream::iter(results))
566}
567
568/// Collects and classifies DELETE targets into nodes and edges.
569///
570/// Handles `Value::Path`, `Value::Node`, `Value::Edge`, map-encoded paths
571/// (from Arrow round-trip), and raw VID values. When `dedup` is true,
572/// uses HashSets to skip duplicates (needed for non-DETACH DELETE to
573/// handle shared nodes across paths).
574struct DeleteCollector {
575    /// Collected node entries: (vid, labels) pairs.
576    node_entries: Vec<(Vid, Option<Vec<String>>)>,
577    /// Collected edge values to delete.
578    edge_vals: Vec<Value>,
579    /// Deduplication sets (only used when dedup=true).
580    seen_vids: HashSet<u64>,
581    seen_eids: HashSet<u64>,
582    dedup: bool,
583}
584
585impl DeleteCollector {
586    fn new(dedup: bool) -> Self {
587        Self {
588            node_entries: Vec::new(),
589            edge_vals: Vec::new(),
590            seen_vids: HashSet::new(),
591            seen_eids: HashSet::new(),
592            dedup,
593        }
594    }
595
596    fn add(&mut self, val: Value) {
597        if val.is_null() {
598            return;
599        }
600
601        // Try to resolve value as a Path (native or map-encoded).
602        let path = match &val {
603            Value::Path(p) => Some(p.clone()),
604            _ => Path::try_from(&val).ok(),
605        };
606
607        if let Some(path) = path {
608            for edge in &path.edges {
609                if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
610                    self.edge_vals.push(Value::Edge(edge.clone()));
611                }
612            }
613            for node in &path.nodes {
614                self.add_node(node.vid, Some(node.labels.clone()));
615            }
616            return;
617        }
618
619        // Not a path -- try as a node (by VID).
620        if let Ok(vid) = Executor::vid_from_value(&val) {
621            let labels = Executor::extract_labels_from_node(&val);
622            self.add_node(vid, labels);
623            return;
624        }
625
626        // Otherwise treat as an edge value.
627        if matches!(&val, Value::Map(_) | Value::Edge(_)) {
628            self.edge_vals.push(val);
629        }
630    }
631
632    fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
633        if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
634            return;
635        }
636        self.node_entries.push((vid, labels));
637    }
638}
639
640/// Apply mutations to rows using the appropriate executor helper.
641async fn apply_mutations(
642    mutation_ctx: &MutationContext,
643    mutation_kind: &MutationKind,
644    rows: &mut [HashMap<String, Value>],
645    writer: &mut Writer,
646    tx_l0: Option<&Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
647) -> DFResult<()> {
648    tracing::trace!(
649        mutation = mutation_kind_label(mutation_kind),
650        rows = rows.len(),
651        "Applying mutations"
652    );
653
654    let exec = &mutation_ctx.executor;
655    let pm = &mutation_ctx.prop_manager;
656    let params = &mutation_ctx.params;
657    let ctx = mutation_ctx.query_ctx.as_ref();
658
659    let df_err = |msg: &str, e: anyhow::Error| {
660        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
661    };
662
663    match mutation_kind {
664        MutationKind::Create { pattern } => {
665            for row in rows.iter_mut() {
666                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
667                    .await
668                    .map_err(|e| df_err("CREATE failed", e))?;
669            }
670        }
671        MutationKind::CreateBatch { patterns } => {
672            for row in rows.iter_mut() {
673                for pattern in patterns {
674                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
675                        .await
676                        .map_err(|e| df_err("CREATE failed", e))?;
677                }
678            }
679        }
680        MutationKind::Set { items } => {
681            for row in rows.iter_mut() {
682                exec.execute_set_items_locked(items, row, writer, pm, params, ctx, tx_l0)
683                    .await
684                    .map_err(|e| df_err("SET failed", e))?;
685            }
686        }
687        MutationKind::Remove { items } => {
688            for row in rows.iter_mut() {
689                exec.execute_remove_items_locked(items, row, writer, pm, ctx, tx_l0)
690                    .await
691                    .map_err(|e| df_err("REMOVE failed", e))?;
692            }
693        }
694        MutationKind::Delete { items, detach } => {
695            // Evaluate all DELETE targets and classify into nodes vs edges.
696            let mut collector = DeleteCollector::new(!*detach);
697            for row in rows.iter() {
698                for expr in items {
699                    let val = exec
700                        .evaluate_expr(expr, row, pm, params, ctx)
701                        .await
702                        .map_err(|e| df_err("DELETE eval failed", e))?;
703                    collector.add(val);
704                }
705            }
706
707            // Delete edges before nodes so non-detach DELETE satisfies constraints.
708            for val in &collector.edge_vals {
709                exec.execute_delete_item_locked(val, false, writer, tx_l0)
710                    .await
711                    .map_err(|e| df_err("DELETE edge failed", e))?;
712            }
713
714            if *detach {
715                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
716                    collector.node_entries.into_iter().unzip();
717                exec.batch_detach_delete_vertices(&vids, labels, writer, tx_l0)
718                    .await
719                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
720            } else {
721                for (vid, labels) in &collector.node_entries {
722                    exec.execute_delete_vertex(*vid, false, labels.clone(), writer, tx_l0)
723                        .await
724                        .map_err(|e| df_err("DELETE node failed", e))?;
725                }
726            }
727        }
728        MutationKind::Merge { .. } => {
729            // MERGE is handled before the writer lock in execute_mutation_inner.
730            // This branch is unreachable but required for exhaustive matching.
731            unreachable!("MERGE mutations are handled before apply_mutations is called");
732        }
733    }
734
735    Ok(())
736}
737
738/// Extract variable names introduced by a CREATE/MERGE pattern.
739///
740/// Walks the pattern tree and collects all node and relationship variable names.
741/// Used to compute extended output schemas for CREATE/MERGE operators.
742pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
743    let mut vars = Vec::new();
744    for path in &pattern.paths {
745        if let Some(ref v) = path.variable {
746            vars.push(v.clone());
747        }
748        for element in &path.elements {
749            match element {
750                PatternElement::Node(n) => {
751                    if let Some(ref v) = n.variable {
752                        vars.push(v.clone());
753                    }
754                }
755                PatternElement::Relationship(r) => {
756                    if let Some(ref v) = r.variable {
757                        vars.push(v.clone());
758                    }
759                }
760                PatternElement::Parenthesized { pattern, .. } => {
761                    // Recurse into parenthesized sub-patterns
762                    let sub = Pattern {
763                        paths: vec![pattern.as_ref().clone()],
764                    };
765                    vars.extend(pattern_variable_names(&sub));
766                }
767            }
768        }
769    }
770    vars
771}
772
773/// Normalize a schema for mutation output.
774///
775/// After mutation processing, entity values (nodes/edges) are stored as
776/// `Value::Map` in row HashMaps. The input schema may have Struct columns
777/// for these entities, but `rows_to_batches()` encodes Map values as
778/// cv_encoded LargeBinary. This function converts Struct and Binary entity
779/// columns to cv_encoded LargeBinary to match the actual output format.
780fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
781    use arrow_schema::{Field, Schema};
782
783    let needs_normalization = schema
784        .fields()
785        .iter()
786        .any(|f| matches!(f.data_type(), DataType::Struct(_)));
787
788    if !needs_normalization {
789        return schema.clone();
790    }
791
792    let fields: Vec<Arc<Field>> = schema
793        .fields()
794        .iter()
795        .map(|field| {
796            if matches!(field.data_type(), DataType::Struct(_)) {
797                let mut metadata = field.metadata().clone();
798                metadata.insert("cv_encoded".to_string(), "true".to_string());
799                Arc::new(
800                    Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
801                )
802            } else {
803                field.clone()
804            }
805        })
806        .collect();
807
808    Arc::new(Schema::new(fields))
809}
810
811/// Compute an extended output schema that includes columns for newly created variables.
812///
813/// Extracts variables from CREATE/MERGE patterns and adds:
814/// - Bare cv_encoded LargeBinary column for each variable
815/// - System dotted columns based on element type:
816///   - Node → `{var}._vid` (UInt64), `{var}._labels` (LargeBinary cv_encoded)
817///   - Edge → `{var}._eid` (UInt64), `{var}._type` (LargeBinary cv_encoded)
818///   - Path → bare column only (no system columns)
819///
820/// Property access on mutation variables uses dynamic `index()` UDF extraction,
821/// so property columns are NOT added here.
822///
823/// Also normalizes existing Struct entity columns to cv_encoded LargeBinary,
824/// since after mutation processing, entities are stored as Maps in row HashMaps.
825pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
826    use arrow_schema::{Field, Schema};
827
828    // First normalize existing columns
829    let normalized = normalize_mutation_schema(input_schema);
830
831    let existing_names: HashSet<&str> = normalized
832        .fields()
833        .iter()
834        .map(|f| f.name().as_str())
835        .collect();
836
837    let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
838    let mut added: HashSet<String> = HashSet::new();
839
840    fn cv_metadata() -> std::collections::HashMap<String, String> {
841        let mut m = std::collections::HashMap::new();
842        m.insert("cv_encoded".to_string(), "true".to_string());
843        m
844    }
845
846    fn add_bare_column(
847        var: &str,
848        fields: &mut Vec<Arc<arrow_schema::Field>>,
849        existing: &HashSet<&str>,
850        added: &mut HashSet<String>,
851    ) -> bool {
852        if existing.contains(var) || added.contains(var) {
853            return false;
854        }
855        added.insert(var.to_string());
856        fields.push(Arc::new(
857            Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
858        ));
859        true
860    }
861
862    for pattern in patterns {
863        for path in &pattern.paths {
864            // Path variable (e.g., `p` in `MERGE p = (a)-[r]->(b)`)
865            if let Some(ref var) = path.variable {
866                add_bare_column(var, &mut fields, &existing_names, &mut added);
867            }
868            for element in &path.elements {
869                match element {
870                    PatternElement::Node(n) => {
871                        if let Some(ref var) = n.variable
872                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
873                        {
874                            // Node system columns for id()/labels()
875                            fields.push(Arc::new(Field::new(
876                                format!("{var}._vid"),
877                                DataType::UInt64,
878                                true,
879                            )));
880                            fields.push(Arc::new(
881                                Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
882                                    .with_metadata(cv_metadata()),
883                            ));
884                        }
885                    }
886                    PatternElement::Relationship(r) => {
887                        if let Some(ref var) = r.variable
888                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
889                        {
890                            // Edge system columns for id()/type()
891                            fields.push(Arc::new(Field::new(
892                                format!("{var}._eid"),
893                                DataType::UInt64,
894                                true,
895                            )));
896                            fields.push(Arc::new(
897                                Field::new(format!("{var}._type"), DataType::LargeBinary, true)
898                                    .with_metadata(cv_metadata()),
899                            ));
900                        }
901                    }
902                    PatternElement::Parenthesized { pattern, .. } => {
903                        // Recurse into sub-patterns. Pass current fields as
904                        // input so the recursive call's `existing_names` check
905                        // prevents duplicates for variables already added.
906                        let sub = Pattern {
907                            paths: vec![pattern.as_ref().clone()],
908                        };
909                        let sub_schema = extended_schema_for_new_vars(
910                            &Arc::new(Schema::new(fields.clone())),
911                            &[sub],
912                        );
913                        // Sync `added` from new fields to prevent duplicates
914                        // if a later pattern element reuses a variable.
915                        for field in sub_schema.fields() {
916                            added.insert(field.name().clone());
917                        }
918                        fields = sub_schema.fields().to_vec();
919                    }
920                }
921            }
922        }
923    }
924
925    Arc::new(Schema::new(fields))
926}
927
928/// Human-readable label for a MutationKind (used in tracing spans).
929fn mutation_kind_label(kind: &MutationKind) -> &'static str {
930    match kind {
931        MutationKind::Create { .. } => "CREATE",
932        MutationKind::CreateBatch { .. } => "CREATE_BATCH",
933        MutationKind::Set { .. } => "SET",
934        MutationKind::Remove { .. } => "REMOVE",
935        MutationKind::Delete { .. } => "DELETE",
936        MutationKind::Merge { .. } => "MERGE",
937    }
938}
939
940// ============================================================================
941// Unified MutationExec: single ExecutionPlan for all mutation kinds
942// ============================================================================
943
/// Unified DataFusion `ExecutionPlan` for all Cypher mutation clauses
/// (CREATE, SET, REMOVE, DELETE, MERGE).
///
/// Instead of one near-identical ExecutionPlan struct per clause, this single
/// struct holds a [`MutationKind`] discriminant and delegates execution to the
/// shared [`execute_mutation_stream`] implementation. Typed constructors build
/// it with the correct kind and output schema. In this module those are
/// [`new_create_exec`] and [`new_merge_exec`]; per the original docs, others
/// live in `mutation_set`, `mutation_remove`, and `mutation_delete`.
#[derive(Debug)]
pub struct MutationExec {
    /// Child plan producing input rows.
    input: Arc<dyn ExecutionPlan>,

    /// The kind of mutation to apply.
    kind: MutationKind,

    /// Operator name returned by `ExecutionPlan::name` and shown in EXPLAIN output.
    display_name: &'static str,

    /// Shared mutation context with executor and writer.
    mutation_ctx: Arc<MutationContext>,

    /// Output schema.
    ///
    /// When built with `MutationExec::new`, this is the input schema with
    /// Struct entity columns normalized to cv_encoded LargeBinary. When built
    /// with `new_with_schema`, it is caller-supplied. CREATE and MERGE use that
    /// path to add columns for newly bound variables.
    schema: SchemaRef,

    /// Plan properties for DataFusion optimizer, derived from `schema`.
    properties: PlanProperties,

    /// Metrics.
    ///
    /// NOTE(review): this set is not passed to `execute_mutation_stream`, and
    /// nothing visible here records into it, so `metrics()` probably reports an
    /// empty set. Confirm.
    metrics: ExecutionPlanMetricsSet,
}
975
976impl MutationExec {
977    /// Create a new `MutationExec` with the given kind.
978    ///
979    /// The output schema is derived from the input schema with Struct entity
980    /// columns normalized to cv_encoded LargeBinary. For mutations that
981    /// introduce new variables (CREATE, MERGE), use [`Self::new_with_schema`] instead.
982    pub fn new(
983        input: Arc<dyn ExecutionPlan>,
984        kind: MutationKind,
985        display_name: &'static str,
986        mutation_ctx: Arc<MutationContext>,
987    ) -> Self {
988        let schema = normalize_mutation_schema(&input.schema());
989        let properties = compute_plan_properties(schema.clone());
990        Self {
991            input,
992            kind,
993            display_name,
994            mutation_ctx,
995            schema,
996            properties,
997            metrics: ExecutionPlanMetricsSet::new(),
998        }
999    }
1000
1001    /// Create a new `MutationExec` with an explicit output schema.
1002    ///
1003    /// Used by CREATE and MERGE operators whose output includes newly created
1004    /// variables not present in the input schema.
1005    pub fn new_with_schema(
1006        input: Arc<dyn ExecutionPlan>,
1007        kind: MutationKind,
1008        display_name: &'static str,
1009        mutation_ctx: Arc<MutationContext>,
1010        output_schema: SchemaRef,
1011    ) -> Self {
1012        let properties = compute_plan_properties(output_schema.clone());
1013        Self {
1014            input,
1015            kind,
1016            display_name,
1017            mutation_ctx,
1018            schema: output_schema,
1019            properties,
1020            metrics: ExecutionPlanMetricsSet::new(),
1021        }
1022    }
1023}
1024
1025impl DisplayAs for MutationExec {
1026    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1027        if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1028            write!(f, "{} [DETACH]", self.display_name)
1029        } else {
1030            write!(f, "{}", self.display_name)
1031        }
1032    }
1033}
1034
1035impl ExecutionPlan for MutationExec {
1036    fn name(&self) -> &str {
1037        self.display_name
1038    }
1039
1040    fn as_any(&self) -> &dyn Any {
1041        self
1042    }
1043
1044    fn schema(&self) -> SchemaRef {
1045        self.schema.clone()
1046    }
1047
1048    fn properties(&self) -> &PlanProperties {
1049        &self.properties
1050    }
1051
1052    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1053        vec![&self.input]
1054    }
1055
1056    fn with_new_children(
1057        self: Arc<Self>,
1058        children: Vec<Arc<dyn ExecutionPlan>>,
1059    ) -> DFResult<Arc<dyn ExecutionPlan>> {
1060        if children.len() != 1 {
1061            return Err(datafusion::error::DataFusionError::Plan(format!(
1062                "{} requires exactly one child",
1063                self.display_name,
1064            )));
1065        }
1066        Ok(Arc::new(MutationExec::new_with_schema(
1067            children[0].clone(),
1068            self.kind.clone(),
1069            self.display_name,
1070            self.mutation_ctx.clone(),
1071            self.schema.clone(),
1072        )))
1073    }
1074
1075    fn execute(
1076        &self,
1077        partition: usize,
1078        context: Arc<TaskContext>,
1079    ) -> DFResult<SendableRecordBatchStream> {
1080        execute_mutation_stream(
1081            self.input.clone(),
1082            self.schema.clone(),
1083            self.mutation_ctx.clone(),
1084            self.kind.clone(),
1085            partition,
1086            context,
1087        )
1088    }
1089
1090    fn metrics(&self) -> Option<MetricsSet> {
1091        Some(self.metrics.clone_inner())
1092    }
1093}
1094
1095/// Create a new `MutationExec` configured for a CREATE clause.
1096///
1097/// Computes an extended output schema that includes LargeBinary cv_encoded
1098/// columns for any variables introduced by the pattern that are not already
1099/// in the input schema.
1100pub fn new_create_exec(
1101    input: Arc<dyn ExecutionPlan>,
1102    pattern: Pattern,
1103    mutation_ctx: Arc<MutationContext>,
1104) -> MutationExec {
1105    let output_schema =
1106        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1107    MutationExec::new_with_schema(
1108        input,
1109        MutationKind::Create { pattern },
1110        "MutationCreateExec",
1111        mutation_ctx,
1112        output_schema,
1113    )
1114}
1115
1116/// Create a new `MutationExec` configured for a MERGE clause.
1117///
1118/// Computes an extended output schema that includes LargeBinary cv_encoded
1119/// columns for any variables introduced by the pattern that are not already
1120/// in the input schema.
1121pub fn new_merge_exec(
1122    input: Arc<dyn ExecutionPlan>,
1123    pattern: Pattern,
1124    on_match: Option<SetClause>,
1125    on_create: Option<SetClause>,
1126    mutation_ctx: Arc<MutationContext>,
1127) -> MutationExec {
1128    let output_schema =
1129        extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1130    MutationExec::new_with_schema(
1131        input,
1132        MutationKind::Merge {
1133            pattern,
1134            on_match,
1135            on_create,
1136        },
1137        "MutationMergeExec",
1138        mutation_ctx,
1139        output_schema,
1140    )
1141}
1142
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    #[test]
    fn test_batches_to_rows_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("Alice"), Some("Bob")])),
                Arc::new(Int64Array::from(vec![Some(30), Some(25)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    #[test]
    fn test_rows_to_batches_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let rows = vec![
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Alice".into()));
                m.insert("age".to_string(), Value::Int(30));
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Bob".into()));
                m.insert("age".to_string(), Value::Int(25));
                m
            },
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);

        // The produced batch must carry the row values, not just the right shape.
        let back = batches_to_rows(&batches).unwrap();
        assert_eq!(back.len(), 2);
        for (expected, actual) in rows.iter().zip(&back) {
            assert_eq!(actual.get("name"), expected.get("name"));
            assert_eq!(actual.get("age"), expected.get("age"));
        }
    }

    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![Some("hello")])),
                Arc::new(Int64Array::from(vec![Some(42)])),
                Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
                Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
            ],
        )
        .unwrap();

        // Roundtrip: batches → rows → batches
        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        // Verify roundtrip fidelity
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(
            roundtrip_rows[0].get("s"),
            Some(&Value::String("hello".into()))
        );
        assert_eq!(roundtrip_rows[0].get("i"), Some(&Value::Int(42)));
        assert_eq!(roundtrip_rows[0].get("b"), Some(&Value::Bool(true)));
        // Float comparison
        if let Some(Value::Float(f)) = roundtrip_rows[0].get("f") {
            assert!((*f - 3.125).abs() < 1e-10);
        } else {
            panic!("Expected float value");
        }
    }

    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        // Create a schema with a cv_encoded LargeBinary column (entity column)
        let mut metadata = HashMap::new();
        metadata.insert("cv_encoded".to_string(), "true".to_string());
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        // Create a node-like Map value
        let mut node_map = HashMap::new();
        node_map.insert("name".to_string(), Value::String("Alice".into()));
        node_map.insert("_vid".to_string(), Value::Int(1));
        let map_val = Value::Map(node_map);

        // Encode to CypherValue bytes
        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
                encoded.as_slice(),
            )]))],
        )
        .unwrap();

        // Roundtrip
        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        // The decoded value should be a Map that still carries the properties.
        let val = rows[0].get("n").unwrap();
        let Value::Map(decoded) = val else {
            panic!("Expected Map value, got {val:?}");
        };
        assert_eq!(decoded.get("name"), Some(&Value::String("Alice".into())));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        // Re-encoding and decoding again must yield the same value.
        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(roundtrip_rows[0].get("n"), Some(val));
    }

    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }

    #[test]
    fn test_normalize_mutation_schema_struct_to_cv_encoded() {
        let struct_type = DataType::Struct(vec![Field::new("x", DataType::Int64, true)].into());
        let schema = Arc::new(Schema::new(vec![
            Field::new("n", struct_type, false),
            Field::new("k", DataType::Int64, true),
        ]));

        let normalized = normalize_mutation_schema(&schema);

        let n = normalized.field_with_name("n").unwrap();
        assert_eq!(n.data_type(), &DataType::LargeBinary);
        assert!(n.is_nullable());
        assert_eq!(
            n.metadata().get("cv_encoded").map(String::as_str),
            Some("true")
        );
        // Non-Struct columns pass through unchanged.
        assert_eq!(
            normalized.field_with_name("k").unwrap().data_type(),
            &DataType::Int64
        );
    }

    #[test]
    fn test_extended_schema_without_patterns_keeps_columns() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let extended = extended_schema_for_new_vars(&schema, &[]);

        assert_eq!(extended.fields().len(), 1);
        assert_eq!(extended.field(0).name(), "x");
        assert_eq!(extended.field(0).data_type(), &DataType::Int64);
    }
}