1use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use uni_common::core::id::Vid;
30use uni_common::{Path, Value};
31use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
/// Shared state required by all mutation operators (CREATE / SET / REMOVE /
/// DELETE / MERGE).
///
/// Heavy members are behind `Arc`, so cloning the context is cheap enough to
/// hand one copy to every operator in a plan.
#[derive(Clone)]
pub struct MutationContext {
    /// Executor used to evaluate expressions and apply pattern mutations.
    pub executor: Executor,

    /// Store writer behind an async lock; taken exclusively while applying
    /// non-MERGE writes (see `execute_mutation_inner`).
    pub writer: Arc<RwLock<Writer>>,

    /// Property manager used when encoding/resolving property values.
    pub prop_manager: Arc<PropertyManager>,

    /// Query parameters (`$param`) available during expression evaluation.
    pub params: HashMap<String, Value>,

    /// Optional query context. When `None`, mutations may not see the latest
    /// L0 buffer state (a warning is logged in `execute_mutation_stream`).
    pub query_ctx: Option<uni_store::QueryContext>,

    /// Transaction-scoped L0 buffer override, threaded through to the
    /// executor/writer calls. NOTE(review): presumably set only inside an
    /// explicit transaction — confirm against the caller that builds this.
    pub tx_l0_override: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
}
65
66impl std::fmt::Debug for MutationContext {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("MutationContext")
69 .field("has_writer", &true)
70 .field("has_prop_manager", &true)
71 .field("params_count", &self.params.len())
72 .field("has_query_ctx", &self.query_ctx.is_some())
73 .finish()
74 }
75}
76
/// The concrete mutation a [`MutationExec`] node performs.
#[derive(Debug, Clone)]
pub enum MutationKind {
    /// CREATE a single pattern per input row.
    Create { pattern: Pattern },

    /// CREATE several patterns per input row (multiple CREATE clauses fused).
    CreateBatch { patterns: Vec<Pattern> },

    /// SET property/label items on already-bound entities.
    Set { items: Vec<SetItem> },

    /// REMOVE properties/labels from already-bound entities.
    Remove { items: Vec<RemoveItem> },

    /// DELETE the entities the expressions evaluate to; `detach` corresponds
    /// to `DETACH DELETE` (also removes incident edges).
    Delete { items: Vec<Expr>, detach: bool },

    /// MERGE a pattern with optional ON MATCH / ON CREATE set clauses.
    Merge {
        pattern: Pattern,
        on_match: Option<SetClause>,
        on_create: Option<SetClause>,
    },
}
102
/// Convert Arrow record batches into per-row maps of column name → [`Value`].
///
/// Special handling:
/// - datetime/time struct columns are decoded with their logical type;
/// - columns whose field metadata has `cv_encoded == "true"` and whose decoded
///   value is a JSON string are parsed back into a structured `Value`;
/// - system fields (`var._vid`, `var._labels`, …) are merged into bare
///   variable maps via [`merge_system_fields_for_write`].
pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
    let mut rows = Vec::new();

    for batch in batches {
        let num_rows = batch.num_rows();
        let schema = batch.schema();

        for row_idx in 0..num_rows {
            let mut row = HashMap::new();

            for (col_idx, field) in schema.fields().iter().enumerate() {
                let column = batch.column(col_idx);
                // Hint the logical type for struct-encoded temporal columns so
                // the converter yields DateTime/Time values rather than maps.
                let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
                    Some(&uni_common::DataType::DateTime)
                } else if uni_common::core::schema::is_time_struct(field.data_type()) {
                    Some(&uni_common::DataType::Time)
                } else {
                    None
                };
                let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);

                // cv-encoded columns may surface as JSON strings; re-hydrate
                // them into structured values when they parse cleanly.
                if field
                    .metadata()
                    .get("cv_encoded")
                    .is_some_and(|v| v == "true")
                    && let Value::String(s) = &value
                    && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
                {
                    value = Value::from(parsed);
                }

                row.insert(field.name().clone(), value);
            }

            // Fold dotted system columns into bare variable maps so downstream
            // writers see complete entity maps.
            merge_system_fields_for_write(&mut row);

            rows.push(row);
        }
    }

    Ok(rows)
}
162
163fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
172 for row in rows {
173 let map_keys: Vec<String> = row
174 .keys()
175 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
176 .cloned()
177 .collect();
178
179 for key in map_keys {
180 if let Some(Value::Map(map)) = row.get_mut(&key)
181 && map.contains_key("_all_props")
182 {
183 let updates: Vec<(String, Value)> = map
185 .iter()
186 .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
187 .map(|(k, v)| (k.clone(), v.clone()))
188 .collect();
189
190 if !updates.is_empty()
191 && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
192 {
193 for (k, v) in updates {
194 all_props.insert(k, v);
195 }
196 }
197 }
198 }
199 }
200}
201
202fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
210 for row in rows {
211 for field in schema.fields() {
212 let name = field.name();
213 if let Some(dot_pos) = name.find('.') {
214 let var_name = &name[..dot_pos];
215 let prop_name = &name[dot_pos + 1..];
216 if let Some(Value::Map(map)) = row.get(var_name) {
217 let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
218 row.insert(name.clone(), val);
219 }
220 }
221 }
222 }
223}
224
225fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
229 if let Some(val) = map.remove("_src_vid") {
230 map.entry("_src".to_string()).or_insert(val);
231 }
232 if let Some(val) = map.remove("_dst_vid") {
233 map.entry("_dst".to_string()).or_insert(val);
234 }
235}
236
237fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
243 const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
246 const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
247
248 let dotted_vars: HashSet<String> = row
250 .keys()
251 .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
252 .collect();
253
254 for var in &dotted_vars {
258 if !row.contains_key(var) {
259 let prefix = format!("{var}.");
260 let mut map: HashMap<String, Value> = row
261 .iter()
262 .filter_map(|(k, v)| {
263 k.strip_prefix(prefix.as_str())
264 .map(|field| (field.to_string(), v.clone()))
265 })
266 .collect();
267 normalize_edge_field_names(&mut map);
268 if !map.is_empty() {
269 row.insert(var.clone(), Value::Map(map));
270 }
271 }
272 }
273
274 let bare_vars: Vec<String> = row
277 .keys()
278 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
279 .cloned()
280 .collect();
281
282 for var in &bare_vars {
283 let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
285 .iter()
286 .filter_map(|&field| {
287 row.get(&format!("{var}.{field}"))
288 .cloned()
289 .map(|v| (field, v))
290 })
291 .collect();
292 let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
293 .iter()
294 .filter_map(|&field| {
295 row.get(&format!("{var}.{field}"))
296 .cloned()
297 .map(|v| (field, v))
298 })
299 .collect();
300
301 if let Some(Value::Map(map)) = row.get_mut(var) {
302 for (field, v) in vertex_vals {
303 map.insert(field.to_string(), v);
304 }
305 for (field, v) in edge_vals {
306 map.entry(field.to_string()).or_insert(v);
307 }
308 normalize_edge_field_names(map);
309 }
310 }
311}
312
313pub fn rows_to_batches(
322 rows: &[HashMap<String, Value>],
323 schema: &SchemaRef,
324) -> Result<Vec<RecordBatch>> {
325 if rows.is_empty() {
326 let batch = if schema.fields().is_empty() {
328 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
329 RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
330 } else {
331 RecordBatch::new_empty(schema.clone())
332 };
333 return Ok(vec![batch]);
334 }
335
336 if schema.fields().is_empty() {
337 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
341 let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
342 return Ok(vec![batch]);
343 }
344
345 let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
347
348 for field in schema.fields() {
349 let name = field.name();
350 let values: Vec<Value> = rows
351 .iter()
352 .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
353 .collect();
354
355 let array = value_column_to_arrow(&values, field.data_type(), field)?;
356 columns.push(array);
357 }
358
359 let batch = RecordBatch::try_new(schema.clone(), columns)?;
360 Ok(vec![batch])
361}
362
363fn value_column_to_arrow(
365 values: &[Value],
366 arrow_type: &DataType,
367 field: &arrow_schema::Field,
368) -> Result<arrow_array::ArrayRef> {
369 let is_cv_encoded = field
370 .metadata()
371 .get("cv_encoded")
372 .is_some_and(|v| v == "true");
373
374 if *arrow_type == DataType::LargeBinary || is_cv_encoded {
375 Ok(encode_as_large_binary(values))
376 } else if *arrow_type == DataType::Binary {
377 Ok(encode_as_binary(values))
379 } else {
380 arrow_convert::values_to_array(values, arrow_type)
382 .or_else(|_| Ok(encode_as_large_binary(values)))
383 }
384}
385
/// Encode a `&[Value]` into a (Large)Binary Arrow array with the cypher-value
/// codec; `Null` values become array nulls.
///
/// A macro rather than a generic function because `BinaryBuilder` and
/// `LargeBinaryBuilder` share this API without a common trait.
macro_rules! encode_as_cv {
    ($builder_ty:ty, $values:expr) => {{
        let values = $values;
        // Capacity heuristic: ~64 bytes per encoded value.
        let mut builder = <$builder_ty>::with_capacity(values.len(), values.len() * 64);
        for v in values {
            if v.is_null() {
                builder.append_null();
            } else {
                let bytes = uni_common::cypher_value_codec::encode(v);
                builder.append_value(&bytes);
            }
        }
        Arc::new(builder.finish()) as arrow_array::ArrayRef
    }};
}
402
/// Codec-encode values into a `Binary` array (i32 offsets).
fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
}

/// Codec-encode values into a `LargeBinary` array (i64 offsets).
fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
}
412
/// Build the output stream for a mutation operator.
///
/// The real work lives in [`execute_mutation_inner`]; wrapping it in
/// `stream::once(...).try_flatten()` lets this function return a stream
/// synchronously while the mutation itself runs lazily on first poll.
pub fn execute_mutation_stream(
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    mutation_ctx: Arc<MutationContext>,
    mutation_kind: MutationKind,
    partition: usize,
    task_ctx: Arc<datafusion::execution::TaskContext>,
) -> DFResult<SendableRecordBatchStream> {
    // A missing query context is tolerated, but mutations may then read stale
    // state — surface it loudly in logs rather than failing.
    if mutation_ctx.query_ctx.is_none() {
        tracing::warn!(
            "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
        );
    }

    let stream = futures::stream::once(execute_mutation_inner(
        input,
        output_schema.clone(),
        mutation_ctx,
        mutation_kind,
        partition,
        task_ctx,
    ))
    .try_flatten();

    Ok(Box::pin(RecordBatchStreamAdapter::new(
        output_schema,
        stream,
    )))
}
452
/// Core mutation driver: drain the input partition, apply the mutation, and
/// re-materialize the (possibly updated) rows as output batches.
///
/// MERGE takes a separate path because `Executor::execute_merge` both reads
/// and writes and manages its own locking; every other kind takes the shared
/// writer lock here.
async fn execute_mutation_inner(
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    mutation_ctx: Arc<MutationContext>,
    mutation_kind: MutationKind,
    partition: usize,
    task_ctx: Arc<datafusion::execution::TaskContext>,
) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
    let mutation_label = mutation_kind_label(&mutation_kind);

    // Mutations are fully materialized: collect the entire input before
    // applying any writes.
    let input_stream = input.execute(partition, task_ctx)?;
    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;

    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
    tracing::debug!(
        mutation = mutation_label,
        batches = input_batches.len(),
        rows = input_row_count,
        "Executing mutation"
    );

    let mut rows = batches_to_rows(&input_batches).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!(
            "Failed to convert batches to rows: {e}"
        ))
    })?;

    // MERGE path: the executor may match or create, and may change the row
    // count, so it returns its own result rows.
    if let MutationKind::Merge {
        ref pattern,
        ref on_match,
        ref on_create,
    } = mutation_kind
    {
        let exec = &mutation_ctx.executor;
        let pm = &mutation_ctx.prop_manager;
        let params = &mutation_ctx.params;
        let ctx = mutation_ctx.query_ctx.as_ref();

        let mut result_rows = exec
            .execute_merge(
                rows,
                pattern,
                on_match.as_ref(),
                on_create.as_ref(),
                pm,
                params,
                ctx,
                mutation_ctx.tx_l0_override.as_ref(),
            )
            .await
            .map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
            })?;

        tracing::debug!(
            mutation = mutation_label,
            input_rows = input_row_count,
            output_rows = result_rows.len(),
            "MERGE mutation complete"
        );

        // Keep `_all_props` and dotted projection columns consistent with the
        // post-mutation maps before rebuilding batches.
        sync_all_props_in_maps(&mut result_rows);
        sync_dotted_columns(&mut result_rows, &output_schema);
        let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to reconstruct MERGE batches: {e}"
            ))
        })?;
        let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
        return Ok(futures::stream::iter(results));
    }

    // Non-MERGE path: apply writes under the exclusive writer lock, and drop
    // the guard before building the output so the error path releases it too.
    let mut writer = mutation_ctx.writer.write().await;
    let tx_l0 = mutation_ctx.tx_l0_override.as_ref();
    let result =
        apply_mutations(&mutation_ctx, &mutation_kind, &mut rows, &mut writer, tx_l0).await;
    drop(writer);
    result?;

    tracing::debug!(
        mutation = mutation_label,
        rows = input_row_count,
        "Mutation complete"
    );

    sync_all_props_in_maps(&mut rows);
    sync_dotted_columns(&mut rows, &output_schema);
    let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
    })?;
    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
    Ok(futures::stream::iter(results))
}
567
/// Accumulates the nodes and edges a DELETE clause refers to before any
/// deletion happens, optionally deduplicating by vid/eid.
struct DeleteCollector {
    /// Nodes to delete: vid plus labels when they could be extracted.
    node_entries: Vec<(Vid, Option<Vec<String>>)>,
    /// Values representing edges to delete (edge or map values).
    edge_vals: Vec<Value>,
    /// Vids already collected (used only when `dedup` is set).
    seen_vids: HashSet<u64>,
    /// Eids already collected (used only when `dedup` is set).
    seen_eids: HashSet<u64>,
    /// Whether to skip duplicates; `apply_mutations` passes `!detach`.
    dedup: bool,
}
584
impl DeleteCollector {
    /// Create an empty collector; `dedup` controls vid/eid deduplication.
    fn new(dedup: bool) -> Self {
        Self {
            node_entries: Vec::new(),
            edge_vals: Vec::new(),
            seen_vids: HashSet::new(),
            seen_eids: HashSet::new(),
            dedup,
        }
    }

    /// Classify one evaluated DELETE target. Order matters:
    /// paths first (expand into their edges and nodes), then node-like values,
    /// then anything map/edge-shaped is treated as an edge.
    fn add(&mut self, val: Value) {
        if val.is_null() {
            return;
        }

        // Paths can appear directly or as values convertible to a Path.
        let path = match &val {
            Value::Path(p) => Some(p.clone()),
            _ => Path::try_from(&val).ok(),
        };

        if let Some(path) = path {
            for edge in &path.edges {
                if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
                    self.edge_vals.push(Value::Edge(edge.clone()));
                }
            }
            for node in &path.nodes {
                self.add_node(node.vid, Some(node.labels.clone()));
            }
            return;
        }

        // A value with an extractable vid is a node.
        if let Ok(vid) = Executor::vid_from_value(&val) {
            let labels = Executor::extract_labels_from_node(&val);
            self.add_node(vid, labels);
            return;
        }

        // Fallback: edge-shaped values (maps without a vid, or edges).
        if matches!(&val, Value::Map(_) | Value::Edge(_)) {
            self.edge_vals.push(val);
        }
    }

    /// Record a node for deletion, honoring the dedup flag.
    fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
        if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
            return;
        }
        self.node_entries.push((vid, labels));
    }
}
639
/// Apply a non-MERGE mutation to every row using the already-locked writer.
///
/// DELETE semantics: targets are collected first (edges then nodes), edges are
/// deleted before nodes, and DETACH uses a batched vertex delete. Dedup is
/// disabled for DETACH (`DeleteCollector::new(!*detach)`) — presumably the
/// batched detach path tolerates duplicates; confirm in `Executor`.
async fn apply_mutations(
    mutation_ctx: &MutationContext,
    mutation_kind: &MutationKind,
    rows: &mut [HashMap<String, Value>],
    writer: &mut Writer,
    tx_l0: Option<&Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
) -> DFResult<()> {
    tracing::trace!(
        mutation = mutation_kind_label(mutation_kind),
        rows = rows.len(),
        "Applying mutations"
    );

    let exec = &mutation_ctx.executor;
    let pm = &mutation_ctx.prop_manager;
    let params = &mutation_ctx.params;
    let ctx = mutation_ctx.query_ctx.as_ref();

    // Uniform anyhow → DataFusion error conversion with a per-kind prefix.
    let df_err = |msg: &str, e: anyhow::Error| {
        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
    };

    match mutation_kind {
        MutationKind::Create { pattern } => {
            for row in rows.iter_mut() {
                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
                    .await
                    .map_err(|e| df_err("CREATE failed", e))?;
            }
        }
        MutationKind::CreateBatch { patterns } => {
            for row in rows.iter_mut() {
                for pattern in patterns {
                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0)
                        .await
                        .map_err(|e| df_err("CREATE failed", e))?;
                }
            }
        }
        MutationKind::Set { items } => {
            for row in rows.iter_mut() {
                exec.execute_set_items_locked(items, row, writer, pm, params, ctx, tx_l0)
                    .await
                    .map_err(|e| df_err("SET failed", e))?;
            }
        }
        MutationKind::Remove { items } => {
            for row in rows.iter_mut() {
                exec.execute_remove_items_locked(items, row, writer, pm, ctx, tx_l0)
                    .await
                    .map_err(|e| df_err("REMOVE failed", e))?;
            }
        }
        MutationKind::Delete { items, detach } => {
            // Collect every target before deleting anything.
            let mut collector = DeleteCollector::new(!*detach);
            for row in rows.iter() {
                for expr in items {
                    let val = exec
                        .evaluate_expr(expr, row, pm, params, ctx)
                        .await
                        .map_err(|e| df_err("DELETE eval failed", e))?;
                    collector.add(val);
                }
            }

            // Edges first, so node deletion doesn't encounter dangling edges.
            for val in &collector.edge_vals {
                exec.execute_delete_item_locked(val, false, writer, tx_l0)
                    .await
                    .map_err(|e| df_err("DELETE edge failed", e))?;
            }

            if *detach {
                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
                    collector.node_entries.into_iter().unzip();
                exec.batch_detach_delete_vertices(&vids, labels, writer, tx_l0)
                    .await
                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
            } else {
                for (vid, labels) in &collector.node_entries {
                    exec.execute_delete_vertex(*vid, false, labels.clone(), writer, tx_l0)
                        .await
                        .map_err(|e| df_err("DELETE node failed", e))?;
                }
            }
        }
        MutationKind::Merge { .. } => {
            // `execute_mutation_inner` routes MERGE before reaching here.
            unreachable!("MERGE mutations are handled before apply_mutations is called");
        }
    }

    Ok(())
}
737
738pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
743 let mut vars = Vec::new();
744 for path in &pattern.paths {
745 if let Some(ref v) = path.variable {
746 vars.push(v.clone());
747 }
748 for element in &path.elements {
749 match element {
750 PatternElement::Node(n) => {
751 if let Some(ref v) = n.variable {
752 vars.push(v.clone());
753 }
754 }
755 PatternElement::Relationship(r) => {
756 if let Some(ref v) = r.variable {
757 vars.push(v.clone());
758 }
759 }
760 PatternElement::Parenthesized { pattern, .. } => {
761 let sub = Pattern {
763 paths: vec![pattern.as_ref().clone()],
764 };
765 vars.extend(pattern_variable_names(&sub));
766 }
767 }
768 }
769 }
770 vars
771}
772
773fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
781 use arrow_schema::{Field, Schema};
782
783 let needs_normalization = schema
784 .fields()
785 .iter()
786 .any(|f| matches!(f.data_type(), DataType::Struct(_)));
787
788 if !needs_normalization {
789 return schema.clone();
790 }
791
792 let fields: Vec<Arc<Field>> = schema
793 .fields()
794 .iter()
795 .map(|field| {
796 if matches!(field.data_type(), DataType::Struct(_)) {
797 let mut metadata = field.metadata().clone();
798 metadata.insert("cv_encoded".to_string(), "true".to_string());
799 Arc::new(
800 Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
801 )
802 } else {
803 field.clone()
804 }
805 })
806 .collect();
807
808 Arc::new(Schema::new(fields))
809}
810
/// Extend `input_schema` with columns for variables newly bound by `patterns`.
///
/// Each new variable gets a cv-encoded LargeBinary column; a new node variable
/// additionally gets `{var}._vid` / `{var}._labels` companions, and a new
/// relationship variable gets `{var}._eid` / `{var}._type`. Columns already
/// present in the (normalized) input schema are never duplicated.
pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
    use arrow_schema::{Field, Schema};

    // Struct columns are normalized first so name comparisons see the final
    // column set.
    let normalized = normalize_mutation_schema(input_schema);

    let existing_names: HashSet<&str> = normalized
        .fields()
        .iter()
        .map(|f| f.name().as_str())
        .collect();

    let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
    // Names appended during this call (existing_names stays frozen).
    let mut added: HashSet<String> = HashSet::new();

    // Metadata marking a column as cypher-value-codec encoded.
    fn cv_metadata() -> std::collections::HashMap<String, String> {
        let mut m = std::collections::HashMap::new();
        m.insert("cv_encoded".to_string(), "true".to_string());
        m
    }

    // Append a bare LargeBinary column for `var` unless it already exists;
    // returns true only when a column was actually added.
    fn add_bare_column(
        var: &str,
        fields: &mut Vec<Arc<arrow_schema::Field>>,
        existing: &HashSet<&str>,
        added: &mut HashSet<String>,
    ) -> bool {
        if existing.contains(var) || added.contains(var) {
            return false;
        }
        added.insert(var.to_string());
        fields.push(Arc::new(
            Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
        ));
        true
    }

    for pattern in patterns {
        for path in &pattern.paths {
            // Path variables get only the bare column (no _vid/_eid pair).
            if let Some(ref var) = path.variable {
                add_bare_column(var, &mut fields, &existing_names, &mut added);
            }
            for element in &path.elements {
                match element {
                    PatternElement::Node(n) => {
                        if let Some(ref var) = n.variable
                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
                        {
                            // Companion system columns for a new node variable.
                            fields.push(Arc::new(Field::new(
                                format!("{var}._vid"),
                                DataType::UInt64,
                                true,
                            )));
                            fields.push(Arc::new(
                                Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
                                    .with_metadata(cv_metadata()),
                            ));
                        }
                    }
                    PatternElement::Relationship(r) => {
                        if let Some(ref var) = r.variable
                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
                        {
                            // Companion system columns for a new edge variable.
                            fields.push(Arc::new(Field::new(
                                format!("{var}._eid"),
                                DataType::UInt64,
                                true,
                            )));
                            fields.push(Arc::new(
                                Field::new(format!("{var}._type"), DataType::LargeBinary, true)
                                    .with_metadata(cv_metadata()),
                            ));
                        }
                    }
                    PatternElement::Parenthesized { pattern, .. } => {
                        // Recurse over the nested path; the recursive call sees
                        // the fields accumulated so far, and its full output
                        // replaces `fields` wholesale. All its names go into
                        // `added` so later elements don't re-add them.
                        let sub = Pattern {
                            paths: vec![pattern.as_ref().clone()],
                        };
                        let sub_schema = extended_schema_for_new_vars(
                            &Arc::new(Schema::new(fields.clone())),
                            &[sub],
                        );
                        for field in sub_schema.fields() {
                            added.insert(field.name().clone());
                        }
                        fields = sub_schema.fields().to_vec();
                    }
                }
            }
        }
    }

    Arc::new(Schema::new(fields))
}
927
928fn mutation_kind_label(kind: &MutationKind) -> &'static str {
930 match kind {
931 MutationKind::Create { .. } => "CREATE",
932 MutationKind::CreateBatch { .. } => "CREATE_BATCH",
933 MutationKind::Set { .. } => "SET",
934 MutationKind::Remove { .. } => "REMOVE",
935 MutationKind::Delete { .. } => "DELETE",
936 MutationKind::Merge { .. } => "MERGE",
937 }
938}
939
/// DataFusion physical operator that applies a [`MutationKind`] to the rows
/// produced by its single child plan and re-emits the (updated) rows.
#[derive(Debug)]
pub struct MutationExec {
    /// Child plan supplying the rows to mutate.
    input: Arc<dyn ExecutionPlan>,

    /// Which mutation to apply.
    kind: MutationKind,

    /// Name shown in EXPLAIN output (e.g. "MutationCreateExec").
    display_name: &'static str,

    /// Shared mutation state (executor, writer, params, …).
    mutation_ctx: Arc<MutationContext>,

    /// Output schema; may extend the input schema for CREATE/MERGE.
    schema: SchemaRef,

    /// Cached plan properties derived from `schema`.
    properties: PlanProperties,

    /// Execution metrics reported through `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
975
976impl MutationExec {
977 pub fn new(
983 input: Arc<dyn ExecutionPlan>,
984 kind: MutationKind,
985 display_name: &'static str,
986 mutation_ctx: Arc<MutationContext>,
987 ) -> Self {
988 let schema = normalize_mutation_schema(&input.schema());
989 let properties = compute_plan_properties(schema.clone());
990 Self {
991 input,
992 kind,
993 display_name,
994 mutation_ctx,
995 schema,
996 properties,
997 metrics: ExecutionPlanMetricsSet::new(),
998 }
999 }
1000
1001 pub fn new_with_schema(
1006 input: Arc<dyn ExecutionPlan>,
1007 kind: MutationKind,
1008 display_name: &'static str,
1009 mutation_ctx: Arc<MutationContext>,
1010 output_schema: SchemaRef,
1011 ) -> Self {
1012 let properties = compute_plan_properties(output_schema.clone());
1013 Self {
1014 input,
1015 kind,
1016 display_name,
1017 mutation_ctx,
1018 schema: output_schema,
1019 properties,
1020 metrics: ExecutionPlanMetricsSet::new(),
1021 }
1022 }
1023}
1024
1025impl DisplayAs for MutationExec {
1026 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1027 if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1028 write!(f, "{} [DETACH]", self.display_name)
1029 } else {
1030 write!(f, "{}", self.display_name)
1031 }
1032 }
1033}
1034
impl ExecutionPlan for MutationExec {
    fn name(&self) -> &str {
        self.display_name
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node over a new child; the output schema is preserved as
    /// chosen at construction, not recomputed from the new child.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Plan(format!(
                "{} requires exactly one child",
                self.display_name,
            )));
        }
        Ok(Arc::new(MutationExec::new_with_schema(
            children[0].clone(),
            self.kind.clone(),
            self.display_name,
            self.mutation_ctx.clone(),
            self.schema.clone(),
        )))
    }

    /// Delegates to `execute_mutation_stream`, which defers the actual
    /// mutation until the returned stream is polled.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        execute_mutation_stream(
            self.input.clone(),
            self.schema.clone(),
            self.mutation_ctx.clone(),
            self.kind.clone(),
            partition,
            context,
        )
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
1094
1095pub fn new_create_exec(
1101 input: Arc<dyn ExecutionPlan>,
1102 pattern: Pattern,
1103 mutation_ctx: Arc<MutationContext>,
1104) -> MutationExec {
1105 let output_schema =
1106 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1107 MutationExec::new_with_schema(
1108 input,
1109 MutationKind::Create { pattern },
1110 "MutationCreateExec",
1111 mutation_ctx,
1112 output_schema,
1113 )
1114}
1115
1116pub fn new_merge_exec(
1122 input: Arc<dyn ExecutionPlan>,
1123 pattern: Pattern,
1124 on_match: Option<SetClause>,
1125 on_create: Option<SetClause>,
1126 mutation_ctx: Arc<MutationContext>,
1127) -> MutationExec {
1128 let output_schema =
1129 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1130 MutationExec::new_with_schema(
1131 input,
1132 MutationKind::Merge {
1133 pattern,
1134 on_match,
1135 on_create,
1136 },
1137 "MutationMergeExec",
1138 mutation_ctx,
1139 output_schema,
1140 )
1141}
1142
#[cfg(test)]
mod tests {
    //! Unit tests for the row ⟷ batch conversion helpers.
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    // Batches decode into one map per row, keyed by column name.
    #[test]
    fn test_batches_to_rows_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("Alice"), Some("Bob")])),
                Arc::new(Int64Array::from(vec![Some(30), Some(25)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    // Row maps re-pack into a single batch matching the target schema.
    #[test]
    fn test_rows_to_batches_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let rows = vec![
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Alice".into()));
                m.insert("age".to_string(), Value::Int(30));
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Bob".into()));
                m.insert("age".to_string(), Value::Int(25));
                m
            },
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);
    }

    // batch → rows → batch → rows preserves scalar values of each type.
    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![Some("hello")])),
                Arc::new(Int64Array::from(vec![Some(42)])),
                Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
                Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(
            roundtrip_rows[0].get("s"),
            Some(&Value::String("hello".into()))
        );
        assert_eq!(roundtrip_rows[0].get("i"), Some(&Value::Int(42)));
        assert_eq!(roundtrip_rows[0].get("b"), Some(&Value::Bool(true)));
        // 3.125 is exactly representable, but compare with tolerance anyway.
        if let Some(Value::Float(f)) = roundtrip_rows[0].get("f") {
            assert!((*f - 3.125).abs() < 1e-10);
        } else {
            panic!("Expected float value");
        }
    }

    // cv-encoded LargeBinary columns decode back into structured map values
    // and survive a second round trip.
    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        use std::collections::HashMap as StdHashMap;

        let mut metadata = StdHashMap::new();
        metadata.insert("cv_encoded".to_string(), "true".to_string());
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        let mut node_map = HashMap::new();
        node_map.insert("name".to_string(), Value::String("Alice".into()));
        node_map.insert("_vid".to_string(), Value::Int(1));
        let map_val = Value::Map(node_map);

        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
                encoded.as_slice(),
            )]))],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        let val = rows[0].get("n").unwrap();
        assert!(matches!(val, Value::Map(_)));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
    }

    // An empty row set still yields a single zero-row batch.
    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }
}