1use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use tokio::sync::RwLock;
29use uni_common::core::id::Vid;
30use uni_common::{Path, Value};
31use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
#[derive(Clone)]
/// Shared state required to execute any mutation operator.
///
/// Cloned freely (all heavy members are behind `Arc`s) and passed into the
/// physical plan nodes as `Arc<MutationContext>`.
pub struct MutationContext {
    /// Executor that evaluates expressions and runs mutation primitives
    /// (CREATE/SET/REMOVE/DELETE/MERGE helpers).
    pub executor: Executor,

    /// Store writer; taken with `.write().await` for exclusive access while
    /// applying non-MERGE mutations.
    pub writer: Arc<RwLock<Writer>>,

    /// Property manager handed through to executor mutation calls.
    pub prop_manager: Arc<PropertyManager>,

    /// Query parameters (e.g. `$param`) available to mutation expressions.
    pub params: HashMap<String, Value>,

    /// Optional store query context. When absent, mutations may not observe
    /// the latest L0 buffer state (a warning is logged at execution time).
    pub query_ctx: Option<uni_store::QueryContext>,
}
61
62impl std::fmt::Debug for MutationContext {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 f.debug_struct("MutationContext")
65 .field("has_writer", &true)
66 .field("has_prop_manager", &true)
67 .field("params_count", &self.params.len())
68 .field("has_query_ctx", &self.query_ctx.is_some())
69 .finish()
70 }
71}
72
#[derive(Debug, Clone)]
/// The mutation a [`MutationExec`] node performs on each input row.
pub enum MutationKind {
    /// CREATE with a single pattern.
    Create { pattern: Pattern },

    /// CREATE with several patterns applied in order per row.
    CreateBatch { patterns: Vec<Pattern> },

    /// SET clause items (property/label assignments).
    Set { items: Vec<SetItem> },

    /// REMOVE clause items.
    Remove { items: Vec<RemoveItem> },

    /// DELETE of the values produced by `items`; `detach` marks DETACH DELETE.
    Delete { items: Vec<Expr>, detach: bool },

    /// MERGE with optional ON MATCH / ON CREATE SET clauses.
    Merge {
        pattern: Pattern,
        on_match: Option<SetClause>,
        on_create: Option<SetClause>,
    },
}
98
99pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
110 let mut rows = Vec::new();
111
112 for batch in batches {
113 let num_rows = batch.num_rows();
114 let schema = batch.schema();
115
116 for row_idx in 0..num_rows {
117 let mut row = HashMap::new();
118
119 for (col_idx, field) in schema.fields().iter().enumerate() {
120 let column = batch.column(col_idx);
121 let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
123 Some(&uni_common::DataType::DateTime)
124 } else if uni_common::core::schema::is_time_struct(field.data_type()) {
125 Some(&uni_common::DataType::Time)
126 } else {
127 None
128 };
129 let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);
130
131 if field
134 .metadata()
135 .get("cv_encoded")
136 .is_some_and(|v| v == "true")
137 && let Value::String(s) = &value
138 && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
139 {
140 value = Value::from(parsed);
141 }
142
143 row.insert(field.name().clone(), value);
144 }
145
146 merge_system_fields_for_write(&mut row);
151
152 rows.push(row);
153 }
154 }
155
156 Ok(rows)
157}
158
159fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
168 for row in rows {
169 let map_keys: Vec<String> = row
170 .keys()
171 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
172 .cloned()
173 .collect();
174
175 for key in map_keys {
176 if let Some(Value::Map(map)) = row.get_mut(&key)
177 && map.contains_key("_all_props")
178 {
179 let updates: Vec<(String, Value)> = map
181 .iter()
182 .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
183 .map(|(k, v)| (k.clone(), v.clone()))
184 .collect();
185
186 if !updates.is_empty()
187 && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
188 {
189 for (k, v) in updates {
190 all_props.insert(k, v);
191 }
192 }
193 }
194 }
195 }
196}
197
198fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
206 for row in rows {
207 for field in schema.fields() {
208 let name = field.name();
209 if let Some(dot_pos) = name.find('.') {
210 let var_name = &name[..dot_pos];
211 let prop_name = &name[dot_pos + 1..];
212 if let Some(Value::Map(map)) = row.get(var_name) {
213 let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
214 row.insert(name.clone(), val);
215 }
216 }
217 }
218 }
219}
220
221fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
225 if let Some(val) = map.remove("_src_vid") {
226 map.entry("_src".to_string()).or_insert(val);
227 }
228 if let Some(val) = map.remove("_dst_vid") {
229 map.entry("_dst".to_string()).or_insert(val);
230 }
231}
232
233fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
239 const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
242 const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
243
244 let dotted_vars: HashSet<String> = row
246 .keys()
247 .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
248 .collect();
249
250 for var in &dotted_vars {
254 if !row.contains_key(var) {
255 let prefix = format!("{var}.");
256 let mut map: HashMap<String, Value> = row
257 .iter()
258 .filter_map(|(k, v)| {
259 k.strip_prefix(prefix.as_str())
260 .map(|field| (field.to_string(), v.clone()))
261 })
262 .collect();
263 normalize_edge_field_names(&mut map);
264 if !map.is_empty() {
265 row.insert(var.clone(), Value::Map(map));
266 }
267 }
268 }
269
270 let bare_vars: Vec<String> = row
273 .keys()
274 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
275 .cloned()
276 .collect();
277
278 for var in &bare_vars {
279 let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
281 .iter()
282 .filter_map(|&field| {
283 row.get(&format!("{var}.{field}"))
284 .cloned()
285 .map(|v| (field, v))
286 })
287 .collect();
288 let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
289 .iter()
290 .filter_map(|&field| {
291 row.get(&format!("{var}.{field}"))
292 .cloned()
293 .map(|v| (field, v))
294 })
295 .collect();
296
297 if let Some(Value::Map(map)) = row.get_mut(var) {
298 for (field, v) in vertex_vals {
299 map.insert(field.to_string(), v);
300 }
301 for (field, v) in edge_vals {
302 map.entry(field.to_string()).or_insert(v);
303 }
304 normalize_edge_field_names(map);
305 }
306 }
307}
308
309pub fn rows_to_batches(
318 rows: &[HashMap<String, Value>],
319 schema: &SchemaRef,
320) -> Result<Vec<RecordBatch>> {
321 if rows.is_empty() {
322 let batch = if schema.fields().is_empty() {
324 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
325 RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
326 } else {
327 RecordBatch::new_empty(schema.clone())
328 };
329 return Ok(vec![batch]);
330 }
331
332 if schema.fields().is_empty() {
333 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
337 let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
338 return Ok(vec![batch]);
339 }
340
341 let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
343
344 for field in schema.fields() {
345 let name = field.name();
346 let values: Vec<Value> = rows
347 .iter()
348 .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
349 .collect();
350
351 let array = value_column_to_arrow(&values, field.data_type(), field)?;
352 columns.push(array);
353 }
354
355 let batch = RecordBatch::try_new(schema.clone(), columns)?;
356 Ok(vec![batch])
357}
358
359fn value_column_to_arrow(
361 values: &[Value],
362 arrow_type: &DataType,
363 field: &arrow_schema::Field,
364) -> Result<arrow_array::ArrayRef> {
365 let is_cv_encoded = field
366 .metadata()
367 .get("cv_encoded")
368 .is_some_and(|v| v == "true");
369
370 if *arrow_type == DataType::LargeBinary || is_cv_encoded {
371 Ok(encode_as_large_binary(values))
372 } else if *arrow_type == DataType::Binary {
373 Ok(encode_as_binary(values))
375 } else {
376 arrow_convert::values_to_array(values, arrow_type)
378 .or_else(|_| Ok(encode_as_large_binary(values)))
379 }
380}
381
/// Encode a slice of [`Value`]s into a binary builder of type `$builder_ty`
/// using the cypher-value codec. Null values stay null; everything else is
/// serialized to bytes. Shared by the `Binary` / `LargeBinary` encoders.
macro_rules! encode_as_cv {
    ($builder_ty:ty, $values:expr) => {{
        let values = $values;
        // 64 bytes/value is a capacity guess, not a limit.
        let mut builder = <$builder_ty>::with_capacity(values.len(), values.len() * 64);
        for v in values {
            if v.is_null() {
                builder.append_null();
            } else {
                let bytes = uni_common::cypher_value_codec::encode(v);
                builder.append_value(&bytes);
            }
        }
        Arc::new(builder.finish()) as arrow_array::ArrayRef
    }};
}
398
/// Cypher-value-encode `values` into an Arrow `Binary` (32-bit offsets) array.
fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
}
403
/// Cypher-value-encode `values` into an Arrow `LargeBinary` (64-bit offsets) array.
fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
}
408
409pub fn execute_mutation_stream(
420 input: Arc<dyn ExecutionPlan>,
421 output_schema: SchemaRef,
422 mutation_ctx: Arc<MutationContext>,
423 mutation_kind: MutationKind,
424 partition: usize,
425 task_ctx: Arc<datafusion::execution::TaskContext>,
426) -> DFResult<SendableRecordBatchStream> {
427 if mutation_ctx.query_ctx.is_none() {
428 tracing::warn!(
429 "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
430 );
431 }
432
433 let stream = futures::stream::once(execute_mutation_inner(
434 input,
435 output_schema.clone(),
436 mutation_ctx,
437 mutation_kind,
438 partition,
439 task_ctx,
440 ))
441 .try_flatten();
442
443 Ok(Box::pin(RecordBatchStreamAdapter::new(
444 output_schema,
445 stream,
446 )))
447}
448
/// Drive one mutation end-to-end for a single partition: drain the input
/// stream into memory, convert batches to row maps, apply the mutation, and
/// return the rebuilt result batches as a finished stream source.
///
/// MERGE is special-cased: it is delegated entirely to
/// `Executor::execute_merge` and does NOT take the writer lock in this
/// function (presumably `execute_merge` handles its own locking — confirm in
/// the executor). All other kinds acquire the writer lock here and go
/// through `apply_mutations`.
async fn execute_mutation_inner(
    input: Arc<dyn ExecutionPlan>,
    output_schema: SchemaRef,
    mutation_ctx: Arc<MutationContext>,
    mutation_kind: MutationKind,
    partition: usize,
    task_ctx: Arc<datafusion::execution::TaskContext>,
) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
    let mutation_label = mutation_kind_label(&mutation_kind);

    // Buffer the entire input partition; mutations operate row-by-row on maps.
    let input_stream = input.execute(partition, task_ctx)?;
    let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;

    let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
    tracing::debug!(
        mutation = mutation_label,
        batches = input_batches.len(),
        rows = input_row_count,
        "Executing mutation"
    );

    let mut rows = batches_to_rows(&input_batches).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!(
            "Failed to convert batches to rows: {e}"
        ))
    })?;

    // MERGE path: may produce a different number of output rows than input.
    if let MutationKind::Merge {
        ref pattern,
        ref on_match,
        ref on_create,
    } = mutation_kind
    {
        let exec = &mutation_ctx.executor;
        let pm = &mutation_ctx.prop_manager;
        let params = &mutation_ctx.params;
        let ctx = mutation_ctx.query_ctx.as_ref();

        let mut result_rows = exec
            .execute_merge(
                rows,
                pattern,
                on_match.as_ref(),
                on_create.as_ref(),
                pm,
                params,
                ctx,
            )
            .await
            .map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
            })?;

        tracing::debug!(
            mutation = mutation_label,
            input_rows = input_row_count,
            output_rows = result_rows.len(),
            "MERGE mutation complete"
        );

        // Re-sync derived columns before converting back to Arrow.
        sync_all_props_in_maps(&mut result_rows);
        sync_dotted_columns(&mut result_rows, &output_schema);
        let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
            datafusion::error::DataFusionError::Execution(format!(
                "Failed to reconstruct MERGE batches: {e}"
            ))
        })?;
        let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
        return Ok(futures::stream::iter(results));
    }

    // Non-MERGE path: hold the writer lock only for the duration of the
    // mutation, then release it before rebuilding output batches.
    let mut writer = mutation_ctx.writer.write().await;
    apply_mutations(&mutation_ctx, &mutation_kind, &mut rows, &mut writer).await?;
    drop(writer);

    tracing::debug!(
        mutation = mutation_label,
        rows = input_row_count,
        "Mutation complete"
    );

    sync_all_props_in_maps(&mut rows);
    sync_dotted_columns(&mut rows, &output_schema);
    let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
        datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
    })?;
    let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
    Ok(futures::stream::iter(results))
}
559
/// Accumulates DELETE targets evaluated from row expressions, splitting
/// them into nodes and edges and optionally de-duplicating by id.
struct DeleteCollector {
    /// Nodes to delete: vid plus labels when they could be extracted.
    node_entries: Vec<(Vid, Option<Vec<String>>)>,
    /// Edge-shaped values (edges or maps) to delete.
    edge_vals: Vec<Value>,
    /// Vids already queued (used only when `dedup` is set).
    seen_vids: HashSet<u64>,
    /// Eids already queued (used only when `dedup` is set).
    seen_eids: HashSet<u64>,
    /// Whether to collapse repeated vids/eids.
    dedup: bool,
}
576
577impl DeleteCollector {
578 fn new(dedup: bool) -> Self {
579 Self {
580 node_entries: Vec::new(),
581 edge_vals: Vec::new(),
582 seen_vids: HashSet::new(),
583 seen_eids: HashSet::new(),
584 dedup,
585 }
586 }
587
588 fn add(&mut self, val: Value) {
589 if val.is_null() {
590 return;
591 }
592
593 let path = match &val {
595 Value::Path(p) => Some(p.clone()),
596 _ => Path::try_from(&val).ok(),
597 };
598
599 if let Some(path) = path {
600 for edge in &path.edges {
601 if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
602 self.edge_vals.push(Value::Edge(edge.clone()));
603 }
604 }
605 for node in &path.nodes {
606 self.add_node(node.vid, Some(node.labels.clone()));
607 }
608 return;
609 }
610
611 if let Ok(vid) = Executor::vid_from_value(&val) {
613 let labels = Executor::extract_labels_from_node(&val);
614 self.add_node(vid, labels);
615 return;
616 }
617
618 if matches!(&val, Value::Map(_) | Value::Edge(_)) {
620 self.edge_vals.push(val);
621 }
622 }
623
624 fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
625 if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
626 return;
627 }
628 self.node_entries.push((vid, labels));
629 }
630}
631
/// Apply a non-MERGE mutation to every row, using the already-acquired
/// writer. The caller (`execute_mutation_inner`) holds the write lock for
/// the duration of this call.
///
/// DELETE runs in two phases: first evaluate all item expressions and
/// collect targets (de-duplicating by id unless DETACH — the batch detach
/// API receives everything), then delete edges before nodes.
///
/// # Panics
/// Panics via `unreachable!` if called with `MutationKind::Merge`; MERGE is
/// routed elsewhere before this function.
async fn apply_mutations(
    mutation_ctx: &MutationContext,
    mutation_kind: &MutationKind,
    rows: &mut [HashMap<String, Value>],
    writer: &mut Writer,
) -> DFResult<()> {
    tracing::trace!(
        mutation = mutation_kind_label(mutation_kind),
        rows = rows.len(),
        "Applying mutations"
    );

    let exec = &mutation_ctx.executor;
    let pm = &mutation_ctx.prop_manager;
    let params = &mutation_ctx.params;
    let ctx = mutation_ctx.query_ctx.as_ref();

    // Uniform anyhow -> DataFusion error adapter with a per-call prefix.
    let df_err = |msg: &str, e: anyhow::Error| {
        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
    };

    match mutation_kind {
        MutationKind::Create { pattern } => {
            for row in rows.iter_mut() {
                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx)
                    .await
                    .map_err(|e| df_err("CREATE failed", e))?;
            }
        }
        MutationKind::CreateBatch { patterns } => {
            // Each row applies every pattern, in declaration order.
            for row in rows.iter_mut() {
                for pattern in patterns {
                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx)
                        .await
                        .map_err(|e| df_err("CREATE failed", e))?;
                }
            }
        }
        MutationKind::Set { items } => {
            for row in rows.iter_mut() {
                exec.execute_set_items_locked(items, row, writer, pm, params, ctx)
                    .await
                    .map_err(|e| df_err("SET failed", e))?;
            }
        }
        MutationKind::Remove { items } => {
            for row in rows.iter_mut() {
                exec.execute_remove_items_locked(items, row, writer, pm, ctx)
                    .await
                    .map_err(|e| df_err("REMOVE failed", e))?;
            }
        }
        MutationKind::Delete { items, detach } => {
            // Phase 1: evaluate and collect all targets. Plain DELETE dedups
            // by id; DETACH defers duplicate handling to the batch API.
            let mut collector = DeleteCollector::new(!*detach);
            for row in rows.iter() {
                for expr in items {
                    let val = exec
                        .evaluate_expr(expr, row, pm, params, ctx)
                        .await
                        .map_err(|e| df_err("DELETE eval failed", e))?;
                    collector.add(val);
                }
            }

            // Phase 2: edges first, then nodes.
            for val in &collector.edge_vals {
                exec.execute_delete_item_locked(val, false, writer)
                    .await
                    .map_err(|e| df_err("DELETE edge failed", e))?;
            }

            if *detach {
                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
                    collector.node_entries.into_iter().unzip();
                exec.batch_detach_delete_vertices(&vids, labels, writer)
                    .await
                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
            } else {
                for (vid, labels) in &collector.node_entries {
                    exec.execute_delete_vertex(*vid, false, labels.clone(), writer)
                        .await
                        .map_err(|e| df_err("DELETE node failed", e))?;
                }
            }
        }
        MutationKind::Merge { .. } => {
            unreachable!("MERGE mutations are handled before apply_mutations is called");
        }
    }

    Ok(())
}
728
729pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
734 let mut vars = Vec::new();
735 for path in &pattern.paths {
736 if let Some(ref v) = path.variable {
737 vars.push(v.clone());
738 }
739 for element in &path.elements {
740 match element {
741 PatternElement::Node(n) => {
742 if let Some(ref v) = n.variable {
743 vars.push(v.clone());
744 }
745 }
746 PatternElement::Relationship(r) => {
747 if let Some(ref v) = r.variable {
748 vars.push(v.clone());
749 }
750 }
751 PatternElement::Parenthesized { pattern, .. } => {
752 let sub = Pattern {
754 paths: vec![pattern.as_ref().clone()],
755 };
756 vars.extend(pattern_variable_names(&sub));
757 }
758 }
759 }
760 }
761 vars
762}
763
764fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
772 use arrow_schema::{Field, Schema};
773
774 let needs_normalization = schema
775 .fields()
776 .iter()
777 .any(|f| matches!(f.data_type(), DataType::Struct(_)));
778
779 if !needs_normalization {
780 return schema.clone();
781 }
782
783 let fields: Vec<Arc<Field>> = schema
784 .fields()
785 .iter()
786 .map(|field| {
787 if matches!(field.data_type(), DataType::Struct(_)) {
788 let mut metadata = field.metadata().clone();
789 metadata.insert("cv_encoded".to_string(), "true".to_string());
790 Arc::new(
791 Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
792 )
793 } else {
794 field.clone()
795 }
796 })
797 .collect();
798
799 Arc::new(Schema::new(fields))
800}
801
/// Extend `input_schema` with columns for every variable the given patterns
/// introduce, after normalizing struct columns.
///
/// For each new variable: a bare cv-encoded `LargeBinary` column is added;
/// new node variables also get `{var}._vid` / `{var}._labels` companion
/// columns, and new relationship variables get `{var}._eid` / `{var}._type`.
/// Variables already present in the input schema (or added earlier in this
/// call) are skipped. Parenthesized sub-patterns are handled by recursing
/// with the schema built so far.
pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
    use arrow_schema::{Field, Schema};

    let normalized = normalize_mutation_schema(input_schema);

    // Names already in the (normalized) input — never re-added.
    let existing_names: HashSet<&str> = normalized
        .fields()
        .iter()
        .map(|f| f.name().as_str())
        .collect();

    let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
    // Names added during this call, so a variable is only appended once.
    let mut added: HashSet<String> = HashSet::new();

    // Metadata marking a column as cypher-value-encoded binary.
    fn cv_metadata() -> std::collections::HashMap<String, String> {
        let mut m = std::collections::HashMap::new();
        m.insert("cv_encoded".to_string(), "true".to_string());
        m
    }

    // Append a bare cv-encoded column for `var` unless it already exists;
    // returns whether the column was added (gating the companion columns).
    fn add_bare_column(
        var: &str,
        fields: &mut Vec<Arc<arrow_schema::Field>>,
        existing: &HashSet<&str>,
        added: &mut HashSet<String>,
    ) -> bool {
        if existing.contains(var) || added.contains(var) {
            return false;
        }
        added.insert(var.to_string());
        fields.push(Arc::new(
            Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
        ));
        true
    }

    for pattern in patterns {
        for path in &pattern.paths {
            // Path variables only get a bare column (no companion fields).
            if let Some(ref var) = path.variable {
                add_bare_column(var, &mut fields, &existing_names, &mut added);
            }
            for element in &path.elements {
                match element {
                    PatternElement::Node(n) => {
                        if let Some(ref var) = n.variable
                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
                        {
                            fields.push(Arc::new(Field::new(
                                format!("{var}._vid"),
                                DataType::UInt64,
                                true,
                            )));
                            fields.push(Arc::new(
                                Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
                                    .with_metadata(cv_metadata()),
                            ));
                        }
                    }
                    PatternElement::Relationship(r) => {
                        if let Some(ref var) = r.variable
                            && add_bare_column(var, &mut fields, &existing_names, &mut added)
                        {
                            fields.push(Arc::new(Field::new(
                                format!("{var}._eid"),
                                DataType::UInt64,
                                true,
                            )));
                            fields.push(Arc::new(
                                Field::new(format!("{var}._type"), DataType::LargeBinary, true)
                                    .with_metadata(cv_metadata()),
                            ));
                        }
                    }
                    PatternElement::Parenthesized { pattern, .. } => {
                        // Recurse with the schema accumulated so far as the
                        // new "input", then adopt the recursion's field list.
                        let sub = Pattern {
                            paths: vec![pattern.as_ref().clone()],
                        };
                        let sub_schema = extended_schema_for_new_vars(
                            &Arc::new(Schema::new(fields.clone())),
                            &[sub],
                        );
                        // Record every name so later elements don't re-add.
                        for field in sub_schema.fields() {
                            added.insert(field.name().clone());
                        }
                        fields = sub_schema.fields().to_vec();
                    }
                }
            }
        }
    }

    Arc::new(Schema::new(fields))
}
918
919fn mutation_kind_label(kind: &MutationKind) -> &'static str {
921 match kind {
922 MutationKind::Create { .. } => "CREATE",
923 MutationKind::CreateBatch { .. } => "CREATE_BATCH",
924 MutationKind::Set { .. } => "SET",
925 MutationKind::Remove { .. } => "REMOVE",
926 MutationKind::Delete { .. } => "DELETE",
927 MutationKind::Merge { .. } => "MERGE",
928 }
929}
930
#[derive(Debug)]
/// Physical plan node that applies one [`MutationKind`] to every row of its
/// single child and re-emits the (possibly schema-extended) rows.
pub struct MutationExec {
    /// Single upstream plan supplying the rows to mutate.
    input: Arc<dyn ExecutionPlan>,

    /// Which mutation to perform.
    kind: MutationKind,

    /// Name reported to DataFusion (`name()`, EXPLAIN output).
    display_name: &'static str,

    /// Shared mutation state (executor, writer, params, ...).
    mutation_ctx: Arc<MutationContext>,

    /// Output schema (input schema normalized and/or extended for new vars).
    schema: SchemaRef,

    /// Cached plan properties derived from `schema`.
    properties: PlanProperties,

    /// DataFusion metrics registry for this node.
    metrics: ExecutionPlanMetricsSet,
}
966
967impl MutationExec {
968 pub fn new(
974 input: Arc<dyn ExecutionPlan>,
975 kind: MutationKind,
976 display_name: &'static str,
977 mutation_ctx: Arc<MutationContext>,
978 ) -> Self {
979 let schema = normalize_mutation_schema(&input.schema());
980 let properties = compute_plan_properties(schema.clone());
981 Self {
982 input,
983 kind,
984 display_name,
985 mutation_ctx,
986 schema,
987 properties,
988 metrics: ExecutionPlanMetricsSet::new(),
989 }
990 }
991
992 pub fn new_with_schema(
997 input: Arc<dyn ExecutionPlan>,
998 kind: MutationKind,
999 display_name: &'static str,
1000 mutation_ctx: Arc<MutationContext>,
1001 output_schema: SchemaRef,
1002 ) -> Self {
1003 let properties = compute_plan_properties(output_schema.clone());
1004 Self {
1005 input,
1006 kind,
1007 display_name,
1008 mutation_ctx,
1009 schema: output_schema,
1010 properties,
1011 metrics: ExecutionPlanMetricsSet::new(),
1012 }
1013 }
1014}
1015
1016impl DisplayAs for MutationExec {
1017 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1018 if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1019 write!(f, "{} [DETACH]", self.display_name)
1020 } else {
1021 write!(f, "{}", self.display_name)
1022 }
1023 }
1024}
1025
1026impl ExecutionPlan for MutationExec {
1027 fn name(&self) -> &str {
1028 self.display_name
1029 }
1030
1031 fn as_any(&self) -> &dyn Any {
1032 self
1033 }
1034
1035 fn schema(&self) -> SchemaRef {
1036 self.schema.clone()
1037 }
1038
1039 fn properties(&self) -> &PlanProperties {
1040 &self.properties
1041 }
1042
1043 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1044 vec![&self.input]
1045 }
1046
1047 fn with_new_children(
1048 self: Arc<Self>,
1049 children: Vec<Arc<dyn ExecutionPlan>>,
1050 ) -> DFResult<Arc<dyn ExecutionPlan>> {
1051 if children.len() != 1 {
1052 return Err(datafusion::error::DataFusionError::Plan(format!(
1053 "{} requires exactly one child",
1054 self.display_name,
1055 )));
1056 }
1057 Ok(Arc::new(MutationExec::new_with_schema(
1058 children[0].clone(),
1059 self.kind.clone(),
1060 self.display_name,
1061 self.mutation_ctx.clone(),
1062 self.schema.clone(),
1063 )))
1064 }
1065
1066 fn execute(
1067 &self,
1068 partition: usize,
1069 context: Arc<TaskContext>,
1070 ) -> DFResult<SendableRecordBatchStream> {
1071 execute_mutation_stream(
1072 self.input.clone(),
1073 self.schema.clone(),
1074 self.mutation_ctx.clone(),
1075 self.kind.clone(),
1076 partition,
1077 context,
1078 )
1079 }
1080
1081 fn metrics(&self) -> Option<MetricsSet> {
1082 Some(self.metrics.clone_inner())
1083 }
1084}
1085
1086pub fn new_create_exec(
1092 input: Arc<dyn ExecutionPlan>,
1093 pattern: Pattern,
1094 mutation_ctx: Arc<MutationContext>,
1095) -> MutationExec {
1096 let output_schema =
1097 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1098 MutationExec::new_with_schema(
1099 input,
1100 MutationKind::Create { pattern },
1101 "MutationCreateExec",
1102 mutation_ctx,
1103 output_schema,
1104 )
1105}
1106
1107pub fn new_merge_exec(
1113 input: Arc<dyn ExecutionPlan>,
1114 pattern: Pattern,
1115 on_match: Option<SetClause>,
1116 on_create: Option<SetClause>,
1117 mutation_ctx: Arc<MutationContext>,
1118) -> MutationExec {
1119 let output_schema =
1120 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1121 MutationExec::new_with_schema(
1122 input,
1123 MutationKind::Merge {
1124 pattern,
1125 on_match,
1126 on_create,
1127 },
1128 "MutationMergeExec",
1129 mutation_ctx,
1130 output_schema,
1131 )
1132}
1133
#[cfg(test)]
mod tests {
    //! Unit tests for the row <-> batch conversion helpers.

    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    // Plain scalar batch flattens into one map per row.
    #[test]
    fn test_batches_to_rows_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(vec![Some("Alice"), Some("Bob")])),
                Arc::new(Int64Array::from(vec![Some(30), Some(25)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    // Row maps rebuild into a single batch matching the schema.
    #[test]
    fn test_rows_to_batches_basic() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]));

        let rows = vec![
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Alice".into()));
                m.insert("age".to_string(), Value::Int(30));
                m
            },
            {
                let mut m = HashMap::new();
                m.insert("name".to_string(), Value::String("Bob".into()));
                m.insert("age".to_string(), Value::Int(25));
                m
            },
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);
    }

    // batch -> rows -> batch -> rows preserves scalar values.
    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(StringArray::from(vec![Some("hello")])),
                Arc::new(Int64Array::from(vec![Some(42)])),
                Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
                Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
            ],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
        assert_eq!(
            roundtrip_rows[0].get("s"),
            Some(&Value::String("hello".into()))
        );
        assert_eq!(roundtrip_rows[0].get("i"), Some(&Value::Int(42)));
        assert_eq!(roundtrip_rows[0].get("b"), Some(&Value::Bool(true)));
        // Float compared with a tolerance rather than exact equality.
        if let Some(Value::Float(f)) = roundtrip_rows[0].get("f") {
            assert!((*f - 3.125).abs() < 1e-10);
        } else {
            panic!("Expected float value");
        }
    }

    // cv_encoded LargeBinary columns decode to maps and re-encode cleanly.
    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        use std::collections::HashMap as StdHashMap;

        let mut metadata = StdHashMap::new();
        metadata.insert("cv_encoded".to_string(), "true".to_string());
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        let mut node_map = HashMap::new();
        node_map.insert("name".to_string(), Value::String("Alice".into()));
        node_map.insert("_vid".to_string(), Value::Int(1));
        let map_val = Value::Map(node_map);

        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(arrow_array::LargeBinaryArray::from(vec![Some(
                encoded.as_slice(),
            )]))],
        )
        .unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        let val = rows[0].get("n").unwrap();
        assert!(matches!(val, Value::Map(_)));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
    }

    // Empty row slice yields a single zero-row batch.
    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }
}