1use anyhow::Result;
14use arrow_array::RecordBatch;
15use arrow_schema::{DataType, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::TaskContext;
18use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
20use datafusion::physical_plan::{
21 DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
22};
23use futures::TryStreamExt;
24use std::any::Any;
25use std::collections::{HashMap, HashSet};
26use std::fmt;
27use std::sync::Arc;
28use uni_common::core::id::{Eid, Vid};
29use uni_common::{Path, Properties, Value};
30use uni_cypher::ast::{Expr, Pattern, PatternElement, RemoveItem, SetClause, SetItem};
31use uni_store::QueryContext;
32use uni_store::runtime::property_manager::PropertyManager;
33use uni_store::runtime::writer::Writer;
34use uni_store::storage::arrow_convert;
35
36use super::common::compute_plan_properties;
37use crate::query::executor::core::Executor;
38
/// Handles and per-query state shared by every mutation operator in this file.
///
/// An `Arc<MutationContext>` is stored on [`MutationExec`] and handed to
/// `execute_mutation_stream` on every `execute` call.
#[derive(Clone)]
pub struct MutationContext {
    /// Executor whose routines (`execute_create_pattern`, `execute_merge`,
    /// `execute_set_items_locked`, ...) perform the actual mutation work.
    pub executor: Executor,

    /// Writer passed to the executor's mutation routines and used directly for
    /// non-detach vertex deletes.
    pub writer: Arc<Writer>,

    /// Property manager passed to the executor routines and used for the bulk
    /// property prefetch ahead of SET / REMOVE.
    pub prop_manager: Arc<PropertyManager>,

    /// Query parameters forwarded to the executor routines.
    pub params: HashMap<String, Value>,

    /// Optional query context forwarded to the executor routines. A warning is
    /// logged at stream creation when this is `None`.
    pub query_ctx: Option<uni_store::QueryContext>,

    /// Optional L0 buffer forwarded to every write call.
    /// NOTE(review): presumably the transaction-local buffer, judging by the name — confirm.
    pub tx_l0_override: Option<Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
}
65
66impl std::fmt::Debug for MutationContext {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("MutationContext")
69 .field("has_writer", &true)
70 .field("has_prop_manager", &true)
71 .field("params_count", &self.params.len())
72 .field("has_query_ctx", &self.query_ctx.is_some())
73 .finish()
74 }
75}
76
/// The mutation a [`MutationExec`] applies to each input row.
#[derive(Debug, Clone)]
pub enum MutationKind {
    /// Create a single pattern once per input row.
    Create { pattern: Pattern },

    /// Create several patterns per input row, in the order given.
    CreateBatch { patterns: Vec<Pattern> },

    /// Apply SET items to each input row.
    Set { items: Vec<SetItem> },

    /// Apply REMOVE items to each input row.
    Remove { items: Vec<RemoveItem> },

    /// Delete the entities the expressions evaluate to. `detach` selects the
    /// detach-delete path for vertices.
    Delete { items: Vec<Expr>, detach: bool },

    /// MERGE a pattern, with optional SET clauses forwarded to the executor
    /// as its `on_match` / `on_create` arguments.
    Merge {
        pattern: Pattern,
        on_match: Option<SetClause>,
        on_create: Option<SetClause>,
    },
}
102
/// Properties loaded in bulk ahead of a SET / REMOVE pass, keyed by entity id.
///
/// The default value holds two empty maps.
#[derive(Default, Debug)]
pub(crate) struct Prefetch {
    /// Prefetched vertex properties by vertex id.
    pub vertex: HashMap<Vid, Properties>,
    /// Prefetched edge properties by edge id.
    pub edge: HashMap<Eid, Properties>,
}
121
122pub(crate) async fn prefetch_set_targets(
130 items: &[SetItem],
131 rows: &[HashMap<String, Value>],
132 pm: &PropertyManager,
133 ctx: Option<&QueryContext>,
134) -> Result<Prefetch> {
135 let touched_vars: HashSet<&str> = items
137 .iter()
138 .filter_map(|item| match item {
139 SetItem::Property { expr, .. } => extract_var_from_property_expr(expr),
140 SetItem::Variable { variable, .. } | SetItem::VariablePlus { variable, .. } => {
141 Some(variable.as_str())
142 }
143 SetItem::Labels { .. } => None,
144 })
145 .collect();
146 if touched_vars.is_empty() {
147 return Ok(Prefetch::default());
148 }
149
150 collect_and_fetch_vertex_prefetch(&touched_vars, rows, pm, ctx).await
151}
152
153pub(crate) async fn prefetch_remove_targets(
155 items: &[RemoveItem],
156 rows: &[HashMap<String, Value>],
157 pm: &PropertyManager,
158 ctx: Option<&QueryContext>,
159) -> Result<Prefetch> {
160 let touched_vars: HashSet<&str> = items
161 .iter()
162 .filter_map(|item| match item {
163 RemoveItem::Property(expr) => extract_var_from_property_expr(expr),
164 RemoveItem::Labels { .. } => None,
165 })
166 .collect();
167 if touched_vars.is_empty() {
168 return Ok(Prefetch::default());
169 }
170
171 collect_and_fetch_vertex_prefetch(&touched_vars, rows, pm, ctx).await
172}
173
174fn extract_var_from_property_expr(expr: &Expr) -> Option<&str> {
177 if let Expr::Property(inner, _) = expr
178 && let Expr::Variable(name) = inner.as_ref()
179 {
180 return Some(name.as_str());
181 }
182 None
183}
184
185async fn collect_and_fetch_vertex_prefetch(
194 touched_vars: &HashSet<&str>,
195 rows: &[HashMap<String, Value>],
196 pm: &PropertyManager,
197 ctx: Option<&QueryContext>,
198) -> Result<Prefetch> {
199 let mut by_label: HashMap<String, HashSet<Vid>> = HashMap::new();
200 let mut by_type: HashMap<String, HashSet<Eid>> = HashMap::new();
201
202 for row in rows {
203 for &var in touched_vars {
204 let Some(bound) = row.get(var) else { continue };
205 if let Some((vid, labels)) = vertex_vid_and_labels(bound) {
206 if let Some(label) = labels.first() {
207 by_label.entry(label.clone()).or_default().insert(vid);
208 }
209 } else if let Some((eid, type_name)) = edge_eid_and_type(bound)
210 && !type_name.is_empty()
211 {
212 by_type.entry(type_name).or_default().insert(eid);
213 }
214 }
215 }
216
217 let mut prefetch = Prefetch::default();
218 for (label, vid_set) in by_label {
219 let vids: Vec<Vid> = vid_set.into_iter().collect();
220 if vids.is_empty() {
221 continue;
222 }
223 if let Ok(label_results) = pm
224 .get_batch_vertex_props_for_label(&vids, &label, ctx)
225 .await
226 {
227 for (vid, props) in label_results {
228 prefetch.vertex.entry(vid).or_insert(props);
229 }
230 }
231 }
233 for (type_name, eid_set) in by_type {
234 let eids: Vec<Eid> = eid_set.into_iter().collect();
235 if eids.is_empty() {
236 continue;
237 }
238 if let Ok(type_results) = pm
239 .get_batch_edge_props_for_type(&eids, &type_name, ctx)
240 .await
241 {
242 for (eid, props) in type_results {
243 prefetch.edge.entry(eid).or_insert(props);
244 }
245 }
246 }
247 Ok(prefetch)
248}
249
250fn vertex_vid_and_labels(val: &Value) -> Option<(Vid, Vec<String>)> {
254 match val {
255 Value::Node(node) => Some((node.vid, node.labels.clone())),
256 Value::Map(map) => {
257 if map.contains_key("_eid") {
258 return None;
259 }
260 let vid_val = map.get("_vid")?;
261 let vid = match vid_val {
262 Value::Int(i) if *i >= 0 => Vid::from(*i as u64),
263 _ => return None,
264 };
265 let labels = map
266 .get("_labels")
267 .and_then(|v| match v {
268 Value::List(items) => Some(
269 items
270 .iter()
271 .filter_map(|x| {
272 if let Value::String(s) = x {
273 Some(s.clone())
274 } else {
275 None
276 }
277 })
278 .collect::<Vec<_>>(),
279 ),
280 _ => None,
281 })
282 .unwrap_or_default();
283 Some((vid, labels))
284 }
285 _ => None,
286 }
287}
288
289fn edge_eid_and_type(val: &Value) -> Option<(Eid, String)> {
294 match val {
295 Value::Edge(edge) => Some((edge.eid, edge.edge_type.clone())),
296 Value::Map(map) => {
297 let eid_val = map.get("_eid")?;
299 if !map.contains_key("_src") || !map.contains_key("_dst") {
300 return None;
301 }
302 let eid = match eid_val {
303 Value::Int(i) if *i >= 0 => Eid::from(*i as u64),
304 Value::Null => return None,
305 _ => return None,
306 };
307 let type_name = map
308 .get("_type_name")
309 .or_else(|| map.get("_type"))
310 .and_then(|v| match v {
311 Value::String(s) => Some(s.clone()),
312 _ => None,
313 })
314 .unwrap_or_default();
315 Some((eid, type_name))
316 }
317 _ => None,
318 }
319}
320
321pub fn batches_to_rows(batches: &[RecordBatch]) -> Result<Vec<HashMap<String, Value>>> {
332 let mut rows = Vec::new();
333
334 for batch in batches {
335 let num_rows = batch.num_rows();
336 let schema = batch.schema();
337
338 for row_idx in 0..num_rows {
339 let mut row = HashMap::new();
340
341 for (col_idx, field) in schema.fields().iter().enumerate() {
342 let column = batch.column(col_idx);
343 let data_type = if uni_common::core::schema::is_datetime_struct(field.data_type()) {
345 Some(&uni_common::DataType::DateTime)
346 } else if uni_common::core::schema::is_time_struct(field.data_type()) {
347 Some(&uni_common::DataType::Time)
348 } else {
349 None
350 };
351 let mut value = arrow_convert::arrow_to_value(column.as_ref(), row_idx, data_type);
352
353 if field
356 .metadata()
357 .get("cv_encoded")
358 .is_some_and(|v| v == "true")
359 && let Value::String(s) = &value
360 && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
361 {
362 value = Value::from(parsed);
363 }
364
365 row.insert(field.name().clone(), value);
366 }
367
368 merge_system_fields_for_write(&mut row);
373
374 rows.push(row);
375 }
376 }
377
378 Ok(rows)
379}
380
381fn sync_all_props_in_maps(rows: &mut [HashMap<String, Value>]) {
390 for row in rows {
391 let map_keys: Vec<String> = row
392 .keys()
393 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
394 .cloned()
395 .collect();
396
397 for key in map_keys {
398 if let Some(Value::Map(map)) = row.get_mut(&key)
399 && map.contains_key("_all_props")
400 {
401 let updates: Vec<(String, Value)> = map
403 .iter()
404 .filter(|(k, _)| !k.starts_with('_') && k.as_str() != "ext_id")
405 .map(|(k, v)| (k.clone(), v.clone()))
406 .collect();
407
408 if !updates.is_empty()
409 && let Some(Value::Map(all_props)) = map.get_mut("_all_props")
410 {
411 for (k, v) in updates {
412 all_props.insert(k, v);
413 }
414 }
415 }
416 }
417 }
418}
419
420fn sync_dotted_columns(rows: &mut [HashMap<String, Value>], schema: &SchemaRef) {
428 for row in rows {
429 for field in schema.fields() {
430 let name = field.name();
431 if let Some(dot_pos) = name.find('.') {
432 let var_name = &name[..dot_pos];
433 let prop_name = &name[dot_pos + 1..];
434 if let Some(Value::Map(map)) = row.get(var_name) {
435 let val = map.get(prop_name).cloned().unwrap_or(Value::Null);
436 row.insert(name.clone(), val);
437 }
438 }
439 }
440 }
441}
442
443fn normalize_edge_field_names(map: &mut HashMap<String, Value>) {
447 if let Some(val) = map.remove("_src_vid") {
448 map.entry("_src".to_string()).or_insert(val);
449 }
450 if let Some(val) = map.remove("_dst_vid") {
451 map.entry("_dst".to_string()).or_insert(val);
452 }
453}
454
455fn merge_system_fields_for_write(row: &mut HashMap<String, Value>) {
461 const VERTEX_FIELDS: &[&str] = &["_vid", "_labels"];
464 const EDGE_FIELDS: &[&str] = &["_eid", "_type", "_src_vid", "_dst_vid"];
465
466 let dotted_vars: HashSet<String> = row
468 .keys()
469 .filter_map(|key| key.find('.').map(|pos| key[..pos].to_string()))
470 .collect();
471
472 for var in &dotted_vars {
476 if !row.contains_key(var) {
477 let prefix = format!("{var}.");
478 let mut map: HashMap<String, Value> = row
479 .iter()
480 .filter_map(|(k, v)| {
481 k.strip_prefix(prefix.as_str())
482 .map(|field| (field.to_string(), v.clone()))
483 })
484 .collect();
485 normalize_edge_field_names(&mut map);
486 if !map.is_empty() {
487 row.insert(var.clone(), Value::Map(map));
488 }
489 }
490 }
491
492 let bare_vars: Vec<String> = row
495 .keys()
496 .filter(|k| !k.contains('.') && matches!(row.get(*k), Some(Value::Map(_))))
497 .cloned()
498 .collect();
499
500 for var in &bare_vars {
501 let vertex_vals: Vec<(&str, Value)> = VERTEX_FIELDS
503 .iter()
504 .filter_map(|&field| {
505 row.get(&format!("{var}.{field}"))
506 .cloned()
507 .map(|v| (field, v))
508 })
509 .collect();
510 let edge_vals: Vec<(&str, Value)> = EDGE_FIELDS
511 .iter()
512 .filter_map(|&field| {
513 row.get(&format!("{var}.{field}"))
514 .cloned()
515 .map(|v| (field, v))
516 })
517 .collect();
518
519 if let Some(Value::Map(map)) = row.get_mut(var) {
520 for (field, v) in vertex_vals {
521 map.insert(field.to_string(), v);
522 }
523 for (field, v) in edge_vals {
524 map.entry(field.to_string()).or_insert(v);
525 }
526 normalize_edge_field_names(map);
527 }
528 }
529}
530
531pub fn rows_to_batches(
540 rows: &[HashMap<String, Value>],
541 schema: &SchemaRef,
542) -> Result<Vec<RecordBatch>> {
543 if rows.is_empty() {
544 let batch = if schema.fields().is_empty() {
546 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
547 RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?
548 } else {
549 RecordBatch::new_empty(schema.clone())
550 };
551 return Ok(vec![batch]);
552 }
553
554 if schema.fields().is_empty() {
555 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(rows.len()));
559 let batch = RecordBatch::try_new_with_options(schema.clone(), vec![], &options)?;
560 return Ok(vec![batch]);
561 }
562
563 let mut columns: Vec<arrow_array::ArrayRef> = Vec::with_capacity(schema.fields().len());
565
566 for field in schema.fields() {
567 let name = field.name();
568 let values: Vec<Value> = rows
569 .iter()
570 .map(|row| row.get(name).cloned().unwrap_or(Value::Null))
571 .collect();
572
573 let array = value_column_to_arrow(&values, field.data_type(), field)?;
574 columns.push(array);
575 }
576
577 let batch = RecordBatch::try_new(schema.clone(), columns)?;
578 Ok(vec![batch])
579}
580
581fn value_column_to_arrow(
583 values: &[Value],
584 arrow_type: &DataType,
585 field: &arrow_schema::Field,
586) -> Result<arrow_array::ArrayRef> {
587 let is_cv_encoded = field
588 .metadata()
589 .get("cv_encoded")
590 .is_some_and(|v| v == "true");
591
592 if *arrow_type == DataType::LargeBinary || is_cv_encoded {
593 Ok(encode_as_large_binary(values))
594 } else if *arrow_type == DataType::Binary {
595 Ok(encode_as_binary(values))
597 } else {
598 arrow_convert::values_to_array(values, arrow_type)
600 .or_else(|_| Ok(encode_as_large_binary(values)))
601 }
602}
603
/// Encodes a slice of values with `cypher_value_codec::encode` into a binary
/// Arrow array built by `$builder_ty`.
///
/// Null values become Arrow nulls. The builder is pre-sized for one slot per
/// value and 64 data bytes per value.
macro_rules! encode_as_cv {
    ($builder_ty:ty, $values:expr) => {{
        let items = $values;
        let mut builder = <$builder_ty>::with_capacity(items.len(), items.len() * 64);
        for item in items {
            if item.is_null() {
                builder.append_null();
                continue;
            }
            let encoded = uni_common::cypher_value_codec::encode(item);
            builder.append_value(&encoded);
        }
        Arc::new(builder.finish()) as arrow_array::ArrayRef
    }};
}
620
/// Codec-encodes `values` into an Arrow `Binary` array (nulls stay null).
fn encode_as_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::BinaryBuilder, values)
}
625
/// Codec-encodes `values` into an Arrow `LargeBinary` array (nulls stay null).
fn encode_as_large_binary(values: &[Value]) -> arrow_array::ArrayRef {
    encode_as_cv!(arrow_array::builder::LargeBinaryBuilder, values)
}
630
631pub fn execute_mutation_stream(
642 input: Arc<dyn ExecutionPlan>,
643 output_schema: SchemaRef,
644 mutation_ctx: Arc<MutationContext>,
645 mutation_kind: MutationKind,
646 partition: usize,
647 task_ctx: Arc<datafusion::execution::TaskContext>,
648 baseline: BaselineMetrics,
649) -> DFResult<SendableRecordBatchStream> {
650 if mutation_ctx.query_ctx.is_none() {
651 tracing::warn!(
652 "MutationContext.query_ctx is None — mutations may not see latest L0 buffer state"
653 );
654 }
655
656 let stream = futures::stream::once(execute_mutation_inner(
657 input,
658 output_schema.clone(),
659 mutation_ctx,
660 mutation_kind,
661 partition,
662 task_ctx,
663 baseline,
664 ))
665 .try_flatten();
666
667 Ok(Box::pin(RecordBatchStreamAdapter::new(
668 output_schema,
669 stream,
670 )))
671}
672
673async fn execute_mutation_inner(
683 input: Arc<dyn ExecutionPlan>,
684 output_schema: SchemaRef,
685 mutation_ctx: Arc<MutationContext>,
686 mutation_kind: MutationKind,
687 partition: usize,
688 task_ctx: Arc<datafusion::execution::TaskContext>,
689 baseline: BaselineMetrics,
690) -> DFResult<futures::stream::Iter<std::vec::IntoIter<DFResult<RecordBatch>>>> {
691 let _timer = baseline.elapsed_compute().timer();
694 let mutation_label = mutation_kind_label(&mutation_kind);
695
696 let input_stream = input.execute(partition, task_ctx)?;
698 let input_batches: Vec<RecordBatch> = input_stream.try_collect().await?;
699
700 let input_row_count: usize = input_batches.iter().map(|b| b.num_rows()).sum();
701 tracing::debug!(
702 mutation = mutation_label,
703 batches = input_batches.len(),
704 rows = input_row_count,
705 "Executing mutation"
706 );
707
708 let mut rows = batches_to_rows(&input_batches).map_err(|e| {
710 datafusion::error::DataFusionError::Execution(format!(
711 "Failed to convert batches to rows: {e}"
712 ))
713 })?;
714
715 if let MutationKind::Merge {
720 ref pattern,
721 ref on_match,
722 ref on_create,
723 } = mutation_kind
724 {
725 let exec = &mutation_ctx.executor;
726 let pm = &mutation_ctx.prop_manager;
727 let params = &mutation_ctx.params;
728 let ctx = mutation_ctx.query_ctx.as_ref();
729
730 let mut result_rows = exec
731 .execute_merge(
732 rows,
733 pattern,
734 on_match.as_ref(),
735 on_create.as_ref(),
736 pm,
737 params,
738 ctx,
739 mutation_ctx.tx_l0_override.as_ref(),
740 )
741 .await
742 .map_err(|e| {
743 datafusion::error::DataFusionError::Execution(format!("MERGE failed: {e}"))
744 })?;
745
746 tracing::debug!(
747 mutation = mutation_label,
748 input_rows = input_row_count,
749 output_rows = result_rows.len(),
750 "MERGE mutation complete"
751 );
752
753 sync_all_props_in_maps(&mut result_rows);
756 sync_dotted_columns(&mut result_rows, &output_schema);
757 let result_batches = rows_to_batches(&result_rows, &output_schema).map_err(|e| {
758 datafusion::error::DataFusionError::Execution(format!(
759 "Failed to reconstruct MERGE batches: {e}"
760 ))
761 })?;
762 let output_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
763 baseline.record_output(output_rows);
764 let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
765 return Ok(futures::stream::iter(results));
766 }
767
768 let tx_l0 = mutation_ctx.tx_l0_override.as_ref();
769 apply_mutations(
770 &mutation_ctx,
771 &mutation_kind,
772 &mut rows,
773 &mutation_ctx.writer,
774 tx_l0,
775 )
776 .await?;
777
778 tracing::debug!(
779 mutation = mutation_label,
780 rows = input_row_count,
781 "Mutation complete"
782 );
783
784 sync_all_props_in_maps(&mut rows);
789 sync_dotted_columns(&mut rows, &output_schema);
790 let result_batches = rows_to_batches(&rows, &output_schema).map_err(|e| {
791 datafusion::error::DataFusionError::Execution(format!("Failed to reconstruct batches: {e}"))
792 })?;
793 let output_rows: usize = result_batches.iter().map(|b| b.num_rows()).sum();
794 baseline.record_output(output_rows);
795 let results: Vec<DFResult<RecordBatch>> = result_batches.into_iter().map(Ok).collect();
796 Ok(futures::stream::iter(results))
797}
798
/// Accumulates the vertices and edges selected by a DELETE clause before any
/// of them is deleted.
struct DeleteCollector {
    /// Vertices to delete, each with the labels known for it (if any).
    node_entries: Vec<(Vid, Option<Vec<String>>)>,
    /// Edge values to delete, in the order they were collected.
    edge_vals: Vec<Value>,
    /// Vertex ids already collected; consulted only when `dedup` is set.
    seen_vids: HashSet<u64>,
    /// Edge ids already collected from paths; consulted only when `dedup` is set.
    seen_eids: HashSet<u64>,
    /// When true, repeated vertices and repeated path edges are collected once.
    dedup: bool,
}
815
816impl DeleteCollector {
817 fn new(dedup: bool) -> Self {
818 Self {
819 node_entries: Vec::new(),
820 edge_vals: Vec::new(),
821 seen_vids: HashSet::new(),
822 seen_eids: HashSet::new(),
823 dedup,
824 }
825 }
826
827 fn add(&mut self, val: Value) {
828 if val.is_null() {
829 return;
830 }
831
832 let path = match &val {
834 Value::Path(p) => Some(p.clone()),
835 _ => Path::try_from(&val).ok(),
836 };
837
838 if let Some(path) = path {
839 for edge in &path.edges {
840 if !self.dedup || self.seen_eids.insert(edge.eid.as_u64()) {
841 self.edge_vals.push(Value::Edge(edge.clone()));
842 }
843 }
844 for node in &path.nodes {
845 self.add_node(node.vid, Some(node.labels.clone()));
846 }
847 return;
848 }
849
850 if let Ok(vid) = Executor::vid_from_value(&val) {
852 let labels = Executor::extract_labels_from_node(&val);
853 self.add_node(vid, labels);
854 return;
855 }
856
857 if matches!(&val, Value::Map(_) | Value::Edge(_)) {
859 self.edge_vals.push(val);
860 }
861 }
862
863 fn add_node(&mut self, vid: Vid, labels: Option<Vec<String>>) {
864 if self.dedup && !self.seen_vids.insert(vid.as_u64()) {
865 return;
866 }
867 self.node_entries.push((vid, labels));
868 }
869}
870
/// Applies a non-MERGE mutation to every row, in row order.
///
/// Rows are mutated in place by the executor routines for CREATE, SET and
/// REMOVE. DELETE first evaluates every operand for every row, then deletes
/// the collected edges, then the collected vertices.
///
/// # Errors
///
/// The first failing step aborts the loop and is returned as a
/// `DataFusionError::Execution` whose message is prefixed with the failing
/// operation (for example `"SET failed"`). Work already performed for earlier
/// rows is not undone by this function.
///
/// # Panics
///
/// Panics if called with `MutationKind::Merge`, which the caller handles
/// separately.
async fn apply_mutations(
    mutation_ctx: &MutationContext,
    mutation_kind: &MutationKind,
    rows: &mut [HashMap<String, Value>],
    writer: &Writer,
    tx_l0: Option<&Arc<parking_lot::RwLock<uni_store::runtime::l0::L0Buffer>>>,
) -> DFResult<()> {
    tracing::trace!(
        mutation = mutation_kind_label(mutation_kind),
        rows = rows.len(),
        "Applying mutations"
    );

    let exec = &mutation_ctx.executor;
    let pm = &mutation_ctx.prop_manager;
    let params = &mutation_ctx.params;
    let ctx = mutation_ctx.query_ctx.as_ref();

    // Wraps an executor/writer error in a DataFusion execution error with a
    // short operation prefix.
    let df_err = |msg: &str, e: anyhow::Error| {
        datafusion::error::DataFusionError::Execution(format!("{msg}: {e}"))
    };

    match mutation_kind {
        MutationKind::Create { pattern } => {
            for row in rows.iter_mut() {
                exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0, None)
                    .await
                    .map_err(|e| df_err("CREATE failed", e))?;
            }
        }
        MutationKind::CreateBatch { patterns } => {
            // All patterns are applied to one row before moving to the next row.
            for row in rows.iter_mut() {
                for pattern in patterns {
                    exec.execute_create_pattern(pattern, row, writer, pm, params, ctx, tx_l0, None)
                        .await
                        .map_err(|e| df_err("CREATE failed", e))?;
                }
            }
        }
        MutationKind::Set { items } => {
            // Load the current properties of all targets up front, then apply per row.
            let prefetch = prefetch_set_targets(items, rows, pm, ctx)
                .await
                .map_err(|e| df_err("SET prefetch failed", e))?;
            for row in rows.iter_mut() {
                exec.execute_set_items_locked(
                    items, row, writer, pm, params, ctx, tx_l0, &prefetch,
                )
                .await
                .map_err(|e| df_err("SET failed", e))?;
            }
        }
        MutationKind::Remove { items } => {
            // Same prefetch-then-apply shape as SET.
            let prefetch = prefetch_remove_targets(items, rows, pm, ctx)
                .await
                .map_err(|e| df_err("REMOVE prefetch failed", e))?;
            for row in rows.iter_mut() {
                exec.execute_remove_items_locked(items, row, writer, pm, ctx, tx_l0, &prefetch)
                    .await
                    .map_err(|e| df_err("REMOVE failed", e))?;
            }
        }
        MutationKind::Delete { items, detach } => {
            // De-duplication is enabled only for plain (non-detach) DELETE.
            let mut collector = DeleteCollector::new(!*detach);
            for row in rows.iter() {
                for expr in items {
                    let val = exec
                        .evaluate_expr(expr, row, pm, params, ctx)
                        .await
                        .map_err(|e| df_err("DELETE eval failed", e))?;
                    collector.add(val);
                }
            }

            // Edges are deleted before any vertex.
            for val in &collector.edge_vals {
                exec.execute_delete_item_locked(val, false, writer, tx_l0)
                    .await
                    .map_err(|e| df_err("DELETE edge failed", e))?;
            }

            if *detach {
                let (vids, labels): (Vec<Vid>, Vec<Option<Vec<String>>>) =
                    collector.node_entries.into_iter().unzip();
                exec.batch_detach_delete_vertices(&vids, labels, writer, tx_l0)
                    .await
                    .map_err(|e| df_err("DETACH DELETE failed", e))?;
            } else {
                let vids: Vec<Vid> = collector.node_entries.iter().map(|(v, _)| *v).collect();
                // The edge check runs for the whole batch before the first vertex is deleted.
                exec.batch_check_vertices_have_no_edges(&vids, writer, tx_l0)
                    .await
                    .map_err(|e| df_err("DELETE check failed", e))?;
                for (vid, labels) in &collector.node_entries {
                    writer
                        .delete_vertex(*vid, labels.clone(), tx_l0)
                        .await
                        .map_err(|e| df_err("DELETE node failed", e))?;
                }
            }
        }
        MutationKind::Merge { .. } => {
            unreachable!("MERGE mutations are handled before apply_mutations is called");
        }
    }

    Ok(())
}
985
986pub fn pattern_variable_names(pattern: &Pattern) -> Vec<String> {
991 let mut vars = Vec::new();
992 for path in &pattern.paths {
993 if let Some(ref v) = path.variable {
994 vars.push(v.clone());
995 }
996 for element in &path.elements {
997 match element {
998 PatternElement::Node(n) => {
999 if let Some(ref v) = n.variable {
1000 vars.push(v.clone());
1001 }
1002 }
1003 PatternElement::Relationship(r) => {
1004 if let Some(ref v) = r.variable {
1005 vars.push(v.clone());
1006 }
1007 }
1008 PatternElement::Parenthesized { pattern, .. } => {
1009 let sub = Pattern {
1011 paths: vec![pattern.as_ref().clone()],
1012 };
1013 vars.extend(pattern_variable_names(&sub));
1014 }
1015 }
1016 }
1017 }
1018 vars
1019}
1020
1021fn normalize_mutation_schema(schema: &SchemaRef) -> SchemaRef {
1029 use arrow_schema::{Field, Schema};
1030
1031 fn needs_norm(dt: &DataType) -> bool {
1043 match dt {
1044 DataType::Struct(_) => true,
1045 DataType::List(inner) | DataType::LargeList(inner) => {
1046 !matches!(inner.data_type(), DataType::Utf8)
1047 }
1048 _ => false,
1049 }
1050 }
1051
1052 if !schema.fields().iter().any(|f| needs_norm(f.data_type())) {
1053 return schema.clone();
1054 }
1055
1056 let fields: Vec<Arc<Field>> = schema
1057 .fields()
1058 .iter()
1059 .map(|field| {
1060 if needs_norm(field.data_type()) {
1061 let mut metadata = field.metadata().clone();
1062 metadata.insert("cv_encoded".to_string(), "true".to_string());
1063 Arc::new(
1064 Field::new(field.name(), DataType::LargeBinary, true).with_metadata(metadata),
1065 )
1066 } else {
1067 field.clone()
1068 }
1069 })
1070 .collect();
1071
1072 Arc::new(Schema::new(fields))
1073}
1074
1075pub fn extended_schema_for_new_vars(input_schema: &SchemaRef, patterns: &[Pattern]) -> SchemaRef {
1090 use arrow_schema::{Field, Schema};
1091
1092 let normalized = normalize_mutation_schema(input_schema);
1094
1095 let existing_names: HashSet<&str> = normalized
1096 .fields()
1097 .iter()
1098 .map(|f| f.name().as_str())
1099 .collect();
1100
1101 let mut fields: Vec<Arc<arrow_schema::Field>> = normalized.fields().to_vec();
1102 let mut added: HashSet<String> = HashSet::new();
1103
1104 fn cv_metadata() -> std::collections::HashMap<String, String> {
1105 let mut m = std::collections::HashMap::new();
1106 m.insert("cv_encoded".to_string(), "true".to_string());
1107 m
1108 }
1109
1110 fn add_bare_column(
1111 var: &str,
1112 fields: &mut Vec<Arc<arrow_schema::Field>>,
1113 existing: &HashSet<&str>,
1114 added: &mut HashSet<String>,
1115 ) -> bool {
1116 if existing.contains(var) || added.contains(var) {
1117 return false;
1118 }
1119 added.insert(var.to_string());
1120 fields.push(Arc::new(
1121 Field::new(var, DataType::LargeBinary, true).with_metadata(cv_metadata()),
1122 ));
1123 true
1124 }
1125
1126 for pattern in patterns {
1127 for path in &pattern.paths {
1128 if let Some(ref var) = path.variable {
1130 add_bare_column(var, &mut fields, &existing_names, &mut added);
1131 }
1132 for element in &path.elements {
1133 match element {
1134 PatternElement::Node(n) => {
1135 if let Some(ref var) = n.variable
1136 && add_bare_column(var, &mut fields, &existing_names, &mut added)
1137 {
1138 fields.push(Arc::new(Field::new(
1140 format!("{var}._vid"),
1141 DataType::UInt64,
1142 true,
1143 )));
1144 fields.push(Arc::new(
1145 Field::new(format!("{var}._labels"), DataType::LargeBinary, true)
1146 .with_metadata(cv_metadata()),
1147 ));
1148 }
1149 }
1150 PatternElement::Relationship(r) => {
1151 if let Some(ref var) = r.variable
1152 && add_bare_column(var, &mut fields, &existing_names, &mut added)
1153 {
1154 fields.push(Arc::new(Field::new(
1156 format!("{var}._eid"),
1157 DataType::UInt64,
1158 true,
1159 )));
1160 fields.push(Arc::new(
1161 Field::new(format!("{var}._type"), DataType::LargeBinary, true)
1162 .with_metadata(cv_metadata()),
1163 ));
1164 }
1165 }
1166 PatternElement::Parenthesized { pattern, .. } => {
1167 let sub = Pattern {
1171 paths: vec![pattern.as_ref().clone()],
1172 };
1173 let sub_schema = extended_schema_for_new_vars(
1174 &Arc::new(Schema::new(fields.clone())),
1175 &[sub],
1176 );
1177 for field in sub_schema.fields() {
1180 added.insert(field.name().clone());
1181 }
1182 fields = sub_schema.fields().to_vec();
1183 }
1184 }
1185 }
1186 }
1187 }
1188
1189 Arc::new(Schema::new(fields))
1190}
1191
1192fn mutation_kind_label(kind: &MutationKind) -> &'static str {
1194 match kind {
1195 MutationKind::Create { .. } => "CREATE",
1196 MutationKind::CreateBatch { .. } => "CREATE_BATCH",
1197 MutationKind::Set { .. } => "SET",
1198 MutationKind::Remove { .. } => "REMOVE",
1199 MutationKind::Delete { .. } => "DELETE",
1200 MutationKind::Merge { .. } => "MERGE",
1201 }
1202}
1203
/// Physical plan node that applies one [`MutationKind`] to the rows produced
/// by its single child.
#[derive(Debug)]
pub struct MutationExec {
    /// Child plan whose output rows drive the mutation.
    input: Arc<dyn ExecutionPlan>,

    /// The mutation to apply.
    kind: MutationKind,

    /// Name reported by `name()` and used when displaying the plan.
    display_name: &'static str,

    /// Shared handles and per-query state needed to perform the mutation.
    mutation_ctx: Arc<MutationContext>,

    /// Output schema of this node.
    schema: SchemaRef,

    /// Plan properties computed from `schema` at construction time.
    properties: Arc<PlanProperties>,

    /// Metrics set from which per-partition baseline metrics are created.
    metrics: ExecutionPlanMetricsSet,
}
1239
1240impl MutationExec {
1241 pub fn new(
1247 input: Arc<dyn ExecutionPlan>,
1248 kind: MutationKind,
1249 display_name: &'static str,
1250 mutation_ctx: Arc<MutationContext>,
1251 ) -> Self {
1252 let schema = normalize_mutation_schema(&input.schema());
1253 let properties = compute_plan_properties(schema.clone());
1254 Self {
1255 input,
1256 kind,
1257 display_name,
1258 mutation_ctx,
1259 schema,
1260 properties,
1261 metrics: ExecutionPlanMetricsSet::new(),
1262 }
1263 }
1264
1265 pub fn new_with_schema(
1270 input: Arc<dyn ExecutionPlan>,
1271 kind: MutationKind,
1272 display_name: &'static str,
1273 mutation_ctx: Arc<MutationContext>,
1274 output_schema: SchemaRef,
1275 ) -> Self {
1276 let properties = compute_plan_properties(output_schema.clone());
1277 Self {
1278 input,
1279 kind,
1280 display_name,
1281 mutation_ctx,
1282 schema: output_schema,
1283 properties,
1284 metrics: ExecutionPlanMetricsSet::new(),
1285 }
1286 }
1287}
1288
1289impl DisplayAs for MutationExec {
1290 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
1291 if matches!(&self.kind, MutationKind::Delete { detach: true, .. }) {
1292 write!(f, "{} [DETACH]", self.display_name)
1293 } else {
1294 write!(f, "{}", self.display_name)
1295 }
1296 }
1297}
1298
1299impl ExecutionPlan for MutationExec {
1300 fn name(&self) -> &str {
1301 self.display_name
1302 }
1303
1304 fn as_any(&self) -> &dyn Any {
1305 self
1306 }
1307
1308 fn schema(&self) -> SchemaRef {
1309 self.schema.clone()
1310 }
1311
1312 fn properties(&self) -> &Arc<PlanProperties> {
1313 &self.properties
1314 }
1315
1316 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1317 vec![&self.input]
1318 }
1319
1320 fn with_new_children(
1321 self: Arc<Self>,
1322 children: Vec<Arc<dyn ExecutionPlan>>,
1323 ) -> DFResult<Arc<dyn ExecutionPlan>> {
1324 if children.len() != 1 {
1325 return Err(datafusion::error::DataFusionError::Plan(format!(
1326 "{} requires exactly one child",
1327 self.display_name,
1328 )));
1329 }
1330 Ok(Arc::new(MutationExec::new_with_schema(
1331 children[0].clone(),
1332 self.kind.clone(),
1333 self.display_name,
1334 self.mutation_ctx.clone(),
1335 self.schema.clone(),
1336 )))
1337 }
1338
1339 fn execute(
1340 &self,
1341 partition: usize,
1342 context: Arc<TaskContext>,
1343 ) -> DFResult<SendableRecordBatchStream> {
1344 let baseline = BaselineMetrics::new(&self.metrics, partition);
1345 execute_mutation_stream(
1346 self.input.clone(),
1347 self.schema.clone(),
1348 self.mutation_ctx.clone(),
1349 self.kind.clone(),
1350 partition,
1351 context,
1352 baseline,
1353 )
1354 }
1355
1356 fn metrics(&self) -> Option<MetricsSet> {
1357 Some(self.metrics.clone_inner())
1358 }
1359}
1360
1361pub fn new_create_exec(
1367 input: Arc<dyn ExecutionPlan>,
1368 pattern: Pattern,
1369 mutation_ctx: Arc<MutationContext>,
1370) -> MutationExec {
1371 let output_schema =
1372 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1373 MutationExec::new_with_schema(
1374 input,
1375 MutationKind::Create { pattern },
1376 "MutationCreateExec",
1377 mutation_ctx,
1378 output_schema,
1379 )
1380}
1381
1382pub fn new_merge_exec(
1388 input: Arc<dyn ExecutionPlan>,
1389 pattern: Pattern,
1390 on_match: Option<SetClause>,
1391 on_create: Option<SetClause>,
1392 mutation_ctx: Arc<MutationContext>,
1393) -> MutationExec {
1394 let output_schema =
1395 extended_schema_for_new_vars(&input.schema(), std::slice::from_ref(&pattern));
1396 MutationExec::new_with_schema(
1397 input,
1398 MutationKind::Merge {
1399 pattern,
1400 on_match,
1401 on_create,
1402 },
1403 "MutationMergeExec",
1404 mutation_ctx,
1405 output_schema,
1406 )
1407}
1408
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{Field, Schema};

    /// Schema with a nullable `name` (Utf8) and a nullable `age` (Int64) column.
    fn name_age_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("age", DataType::Int64, true),
        ]))
    }

    /// Plain scalar columns come back as one map per row.
    #[test]
    fn test_batches_to_rows_basic() {
        let names = StringArray::from(vec![Some("Alice"), Some("Bob")]);
        let ages = Int64Array::from(vec![Some(30), Some(25)]);
        let batch =
            RecordBatch::try_new(name_age_schema(), vec![Arc::new(names), Arc::new(ages)]).unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
        assert_eq!(rows[1].get("name"), Some(&Value::String("Bob".into())));
        assert_eq!(rows[1].get("age"), Some(&Value::Int(25)));
    }

    /// Row maps are packed into a single batch carrying the requested schema.
    #[test]
    fn test_rows_to_batches_basic() {
        let schema = name_age_schema();
        let rows = vec![
            HashMap::from([
                ("name".to_string(), Value::String("Alice".into())),
                ("age".to_string(), Value::Int(30)),
            ]),
            HashMap::from([
                ("name".to_string(), Value::String("Bob".into())),
                ("age".to_string(), Value::Int(25)),
            ]),
        ];

        let batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 2);
        assert_eq!(batches[0].schema(), schema);
    }

    /// String, int, float and bool survive batch -> rows -> batch -> rows.
    #[test]
    fn test_roundtrip_scalar_types() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("s", DataType::Utf8, true),
            Field::new("i", DataType::Int64, true),
            Field::new("f", DataType::Float64, true),
            Field::new("b", DataType::Boolean, true),
        ]));
        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(StringArray::from(vec![Some("hello")])),
            Arc::new(Int64Array::from(vec![Some(42)])),
            Arc::new(arrow_array::Float64Array::from(vec![Some(3.125)])),
            Arc::new(arrow_array::BooleanArray::from(vec![Some(true)])),
        ];
        let batch = RecordBatch::try_new(schema.clone(), columns).unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        let output_batches = rows_to_batches(&rows, &schema).unwrap();

        assert_eq!(output_batches.len(), 1);
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);

        let row = &roundtrip_rows[0];
        assert_eq!(row.get("s"), Some(&Value::String("hello".into())));
        assert_eq!(row.get("i"), Some(&Value::Int(42)));
        assert_eq!(row.get("b"), Some(&Value::Bool(true)));
        match row.get("f") {
            Some(Value::Float(f)) => assert!((*f - 3.125).abs() < 1e-10),
            _ => panic!("Expected float value"),
        }
    }

    /// A codec-encoded LargeBinary column decodes to a map and can be written
    /// back out and read again.
    #[test]
    fn test_roundtrip_cypher_value_encoded() {
        let metadata = HashMap::from([("cv_encoded".to_string(), "true".to_string())]);
        let field = Field::new("n", DataType::LargeBinary, true).with_metadata(metadata);
        let schema = Arc::new(Schema::new(vec![field]));

        let map_val = Value::Map(HashMap::from([
            ("name".to_string(), Value::String("Alice".into())),
            ("_vid".to_string(), Value::Int(1)),
        ]));
        let encoded = uni_common::cypher_value_codec::encode(&map_val);
        let column = arrow_array::LargeBinaryArray::from(vec![Some(encoded.as_slice())]);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(column)]).unwrap();

        let rows = batches_to_rows(&[batch]).unwrap();
        assert_eq!(rows.len(), 1);

        let val = rows[0].get("n").unwrap();
        assert!(matches!(val, Value::Map(_)));

        let output_batches = rows_to_batches(&rows, &schema).unwrap();
        assert_eq!(output_batches[0].num_rows(), 1);

        let roundtrip_rows = batches_to_rows(&output_batches).unwrap();
        assert_eq!(roundtrip_rows.len(), 1);
    }

    /// No rows still yields exactly one (empty) batch.
    #[test]
    fn test_empty_rows() {
        let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Int64, true)]));

        let batches = rows_to_batches(&[], &schema).unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 0);
    }
}