uni_query/query/df_graph/
read_set_exec.rs1use std::any::Any;
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17
18use arrow::datatypes::SchemaRef;
19use arrow_array::RecordBatch;
20use datafusion::common::Result as DFResult;
21use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::{Stream, StreamExt};
24
25use crate::query::df_graph::GraphExecutionContext;
26
27#[derive(Debug)]
34pub struct ReadSetRecordingExec {
35 input: Arc<dyn ExecutionPlan>,
36 graph_ctx: Arc<GraphExecutionContext>,
37 vertex_cols: Vec<usize>,
39 edge_cols: Vec<usize>,
41}
42
43impl ReadSetRecordingExec {
44 pub fn new(
49 input: Arc<dyn ExecutionPlan>,
50 graph_ctx: Arc<GraphExecutionContext>,
51 variable: &str,
52 ) -> Self {
53 let vid_name = format!("{variable}._vid");
54 let eid_name = format!("{variable}._eid");
55 let mut vertex_cols = Vec::new();
56 let mut edge_cols = Vec::new();
57 for (i, field) in input.schema().fields().iter().enumerate() {
58 if field.name() == &vid_name {
59 vertex_cols.push(i);
60 } else if field.name() == &eid_name {
61 edge_cols.push(i);
62 }
63 }
64 Self {
65 input,
66 graph_ctx,
67 vertex_cols,
68 edge_cols,
69 }
70 }
71}
72
73impl DisplayAs for ReadSetRecordingExec {
74 fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
75 write!(f, "ReadSetRecordingExec")
76 }
77}
78
79impl ExecutionPlan for ReadSetRecordingExec {
80 fn name(&self) -> &str {
81 "ReadSetRecordingExec"
82 }
83
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn properties(&self) -> &Arc<PlanProperties> {
89 self.input.properties()
90 }
91
92 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
93 vec![&self.input]
94 }
95
96 fn with_new_children(
97 self: Arc<Self>,
98 children: Vec<Arc<dyn ExecutionPlan>>,
99 ) -> DFResult<Arc<dyn ExecutionPlan>> {
100 let input = children.into_iter().next().ok_or_else(|| {
101 datafusion::error::DataFusionError::Internal(
102 "ReadSetRecordingExec requires exactly one child".to_string(),
103 )
104 })?;
105 Ok(Arc::new(ReadSetRecordingExec {
106 input,
107 graph_ctx: self.graph_ctx.clone(),
108 vertex_cols: self.vertex_cols.clone(),
109 edge_cols: self.edge_cols.clone(),
110 }))
111 }
112
113 fn execute(
114 &self,
115 partition: usize,
116 context: Arc<TaskContext>,
117 ) -> DFResult<SendableRecordBatchStream> {
118 let inner = self.input.execute(partition, context)?;
119 Ok(Box::pin(ReadSetRecordingStream {
120 schema: self.input.schema(),
121 inner,
122 graph_ctx: self.graph_ctx.clone(),
123 vertex_cols: self.vertex_cols.clone(),
124 edge_cols: self.edge_cols.clone(),
125 }))
126 }
127}
128
129struct ReadSetRecordingStream {
131 schema: SchemaRef,
132 inner: SendableRecordBatchStream,
133 graph_ctx: Arc<GraphExecutionContext>,
134 vertex_cols: Vec<usize>,
135 edge_cols: Vec<usize>,
136}
137
138impl Stream for ReadSetRecordingStream {
139 type Item = DFResult<RecordBatch>;
140
141 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
142 match self.inner.poll_next_unpin(cx) {
143 Poll::Ready(Some(Ok(batch))) => {
144 self.graph_ctx
145 .record_batch_ids(&batch, &self.vertex_cols, &self.edge_cols);
146 Poll::Ready(Some(Ok(batch)))
147 }
148 other => other,
149 }
150 }
151}
152
153impl RecordBatchStream for ReadSetRecordingStream {
154 fn schema(&self) -> SchemaRef {
155 self.schema.clone()
156 }
157}