// uni_query/query/df_graph/recursive_cte.rs
use crate::query::df_graph::GraphExecutionContext;
21use crate::query::df_graph::common::{arrow_err, compute_plan_properties, execute_subplan};
22use crate::query::df_graph::unwind::arrow_to_json_value;
23use crate::query::planner::LogicalPlan;
24use arrow_array::RecordBatch;
25use arrow_array::builder::{Int64Builder, LargeListBuilder};
26use arrow_schema::{DataType, Field, Schema, SchemaRef};
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
30use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
31use datafusion::prelude::SessionContext;
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_store::storage::manager::StorageManager;
43
/// Safety cap on recursive iterations: guards against non-terminating
/// recursion. When the cap is hit the loop simply stops (no error is raised).
const MAX_ITERATIONS: usize = 1000;
46
/// Physical operator that evaluates a recursive CTE to a fixpoint.
///
/// Produces a single one-row batch whose only column is a non-nullable
/// `LargeList<Int64>` (named after the CTE) holding the collected vertex ids.
pub struct RecursiveCTEExec {
    /// CTE name; also the parameter key under which the current frontier's
    /// vertex ids are exposed to the recursive plan, and the output column name.
    cte_name: String,

    /// Anchor (non-recursive) part of the CTE, executed exactly once.
    initial_plan: LogicalPlan,

    /// Recursive part, re-executed each iteration against the frontier.
    recursive_plan: LogicalPlan,

    /// Graph execution context handed to `execute_subplan`.
    graph_ctx: Arc<GraphExecutionContext>,

    /// Shared DataFusion session used to run the sub-plans.
    session_ctx: Arc<RwLock<SessionContext>>,

    /// Storage layer the sub-plans read from.
    storage: Arc<StorageManager>,

    /// Graph schema metadata for sub-plan execution.
    schema_info: Arc<UniSchema>,

    /// Base query parameters; cloned and extended with the frontier each iteration.
    params: HashMap<String, Value>,

    /// Single-column `LargeList<Int64>` output schema (built in `new`).
    output_schema: SchemaRef,

    /// Cached plan properties derived from `output_schema`.
    properties: PlanProperties,

    /// Execution metrics reported through `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
86
// Manual `Debug` that reports only the CTE name; the remaining fields
// (plans, contexts, storage handles) are intentionally omitted.
impl fmt::Debug for RecursiveCTEExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RecursiveCTEExec")
            .field("cte_name", &self.cte_name)
            .finish()
    }
}
94
95impl RecursiveCTEExec {
96 #[expect(clippy::too_many_arguments)]
98 pub fn new(
99 cte_name: String,
100 initial_plan: LogicalPlan,
101 recursive_plan: LogicalPlan,
102 graph_ctx: Arc<GraphExecutionContext>,
103 session_ctx: Arc<RwLock<SessionContext>>,
104 storage: Arc<StorageManager>,
105 schema_info: Arc<UniSchema>,
106 params: HashMap<String, Value>,
107 ) -> Self {
108 let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
113 let field = Field::new(&cte_name, DataType::LargeList(inner_field), false);
114 let output_schema = Arc::new(Schema::new(vec![field]));
115 let properties = compute_plan_properties(output_schema.clone());
116
117 Self {
118 cte_name,
119 initial_plan,
120 recursive_plan,
121 graph_ctx,
122 session_ctx,
123 storage,
124 schema_info,
125 params,
126 output_schema,
127 properties,
128 metrics: ExecutionPlanMetricsSet::new(),
129 }
130 }
131}
132
133impl DisplayAs for RecursiveCTEExec {
134 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 write!(f, "RecursiveCTEExec: {}", self.cte_name)
136 }
137}
138
139impl ExecutionPlan for RecursiveCTEExec {
140 fn name(&self) -> &str {
141 "RecursiveCTEExec"
142 }
143
144 fn as_any(&self) -> &dyn Any {
145 self
146 }
147
148 fn schema(&self) -> SchemaRef {
149 self.output_schema.clone()
150 }
151
152 fn properties(&self) -> &PlanProperties {
153 &self.properties
154 }
155
156 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
157 vec![]
159 }
160
161 fn with_new_children(
162 self: Arc<Self>,
163 children: Vec<Arc<dyn ExecutionPlan>>,
164 ) -> DFResult<Arc<dyn ExecutionPlan>> {
165 if !children.is_empty() {
166 return Err(datafusion::error::DataFusionError::Plan(
167 "RecursiveCTEExec has no children".to_string(),
168 ));
169 }
170 Ok(self)
171 }
172
173 fn execute(
174 &self,
175 partition: usize,
176 _context: Arc<TaskContext>,
177 ) -> DFResult<SendableRecordBatchStream> {
178 let metrics = BaselineMetrics::new(&self.metrics, partition);
179
180 let cte_name = self.cte_name.clone();
182 let initial_plan = self.initial_plan.clone();
183 let recursive_plan = self.recursive_plan.clone();
184 let graph_ctx = self.graph_ctx.clone();
185 let session_ctx = self.session_ctx.clone();
186 let storage = self.storage.clone();
187 let schema_info = self.schema_info.clone();
188 let params = self.params.clone();
189 let output_schema = self.output_schema.clone();
190
191 let fut = async move {
192 run_cte_loop(
193 &cte_name,
194 &initial_plan,
195 &recursive_plan,
196 &graph_ctx,
197 &session_ctx,
198 &storage,
199 &schema_info,
200 ¶ms,
201 &output_schema,
202 )
203 .await
204 };
205
206 Ok(Box::pin(RecursiveCTEStream {
207 state: RecursiveCTEStreamState::Running(Box::pin(fut)),
208 schema: self.output_schema.clone(),
209 metrics,
210 }))
211 }
212
213 fn metrics(&self) -> Option<MetricsSet> {
214 Some(self.metrics.clone_inner())
215 }
216}
217
218fn batches_to_values(batches: &[RecordBatch]) -> Vec<Value> {
227 let mut values = Vec::new();
228 for batch in batches {
229 let num_cols = batch.num_columns();
230 let schema = batch.schema();
231
232 for row_idx in 0..batch.num_rows() {
233 if num_cols == 1 {
234 values.push(arrow_to_json_value(batch.column(0).as_ref(), row_idx));
235 } else {
236 let mut map = Vec::new();
237 for col_idx in 0..num_cols {
238 let col_name = schema.field(col_idx).name().clone();
239 let val = arrow_to_json_value(batch.column(col_idx).as_ref(), row_idx);
240 map.push((col_name, val));
241 }
242 values.push(Value::Map(map.into_iter().collect()));
243 }
244 }
245 }
246 values
247}
248
249fn value_key(val: &Value) -> String {
251 format!("{val:?}")
252}
253
254fn extract_vid(val: &Value) -> Option<u64> {
261 match val {
262 Value::Int(v) => Some(*v as u64),
263 Value::Map(map) => {
264 for (k, v) in map {
266 if k == "_vid" || k.ends_with("._vid") {
267 return v.as_u64();
268 }
269 }
270 None
271 }
272 _ => val.as_u64(),
273 }
274}
275
/// Runs the recursive-CTE fixpoint loop and packs the result into a single
/// one-row `RecordBatch` whose only column is a `LargeList<Int64>` of vids.
///
/// Algorithm:
/// 1. Execute `initial_plan` (the anchor) once.
/// 2. Repeatedly execute `recursive_plan`, exposing the current frontier's
///    vertex ids under the parameter key `cte_name`.
/// 3. Stop when the frontier is empty, the recursive step yields nothing,
///    no previously-unseen value is produced, or `MAX_ITERATIONS` is hit
///    (the cap terminates silently, without an error).
#[expect(clippy::too_many_arguments)]
async fn run_cte_loop(
    cte_name: &str,
    initial_plan: &LogicalPlan,
    recursive_plan: &LogicalPlan,
    graph_ctx: &Arc<GraphExecutionContext>,
    session_ctx: &Arc<RwLock<SessionContext>>,
    storage: &Arc<StorageManager>,
    schema_info: &Arc<UniSchema>,
    params: &HashMap<String, Value>,
    output_schema: &SchemaRef,
) -> DFResult<RecordBatch> {
    // Anchor: one-shot evaluation of the non-recursive part.
    let anchor_batches = execute_subplan(
        initial_plan,
        params,
        &HashMap::new(), graph_ctx,
        session_ctx,
        storage,
        schema_info,
    )
    .await?;
    let mut working_values = batches_to_values(&anchor_batches);
    let mut result_values = working_values.clone();

    // Dedup set keyed by the values' Debug rendering (see `value_key`).
    // NOTE(review): if Value::Map's Debug output is iteration-order dependent,
    // equal maps could get distinct keys — confirm the map type is ordered.
    let mut seen: HashSet<String> = working_values.iter().map(value_key).collect();

    for _iteration in 0..MAX_ITERATIONS {
        if working_values.is_empty() {
            break;
        }

        // Frontier vertex ids, exposed to the recursive plan as `cte_name`.
        // Values without an extractable vid are silently dropped here.
        // NOTE(review): negative Int vids round-trip through `as u64` / `as
        // i64` — confirm vids are always non-negative.
        let vid_list = Value::List(
            working_values
                .iter()
                .filter_map(|v| extract_vid(v).map(|vid| Value::Int(vid as i64)))
                .collect(),
        );
        let mut next_params = params.clone();
        next_params.insert(cte_name.to_string(), vid_list);

        // Recursive step: expand the frontier one hop.
        let recursive_batches = execute_subplan(
            recursive_plan,
            &next_params,
            &HashMap::new(), graph_ctx,
            session_ctx,
            storage,
            schema_info,
        )
        .await?;
        let next_values = batches_to_values(&recursive_batches);

        if next_values.is_empty() {
            break;
        }

        // Keep only never-seen values; `HashSet::insert` returns false on
        // duplicates, so this both filters and records in one pass.
        let new_values: Vec<Value> = next_values
            .into_iter()
            .filter(|val| {
                let key = value_key(val);
                seen.insert(key) })
            .collect();

        if new_values.is_empty() {
            break;
        }

        result_values.extend(new_values.clone());
        working_values = new_values;
    }

    // Materialize every collected vid into one list value in a single row;
    // values lacking a vid are omitted from the output list.
    let mut list_builder = LargeListBuilder::new(Int64Builder::new());
    for val in &result_values {
        if let Some(vid) = extract_vid(val) {
            list_builder.values().append_value(vid as i64);
        }
    }
    list_builder.append(true);
    let array = Arc::new(list_builder.finish());

    RecordBatch::try_new(output_schema.clone(), vec![array]).map_err(arrow_err)
}
371
/// State machine for `RecursiveCTEStream`: the entire CTE evaluation runs as
/// one future producing a single `RecordBatch`, after which the stream ends.
enum RecursiveCTEStreamState {
    /// The CTE loop future is still executing (or not yet polled).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The single result (or error) has been emitted; subsequent polls yield `None`.
    Done,
}
383
/// Single-batch `RecordBatchStream` adapter over the CTE-loop future.
struct RecursiveCTEStream {
    /// Current position in the Running -> Done lifecycle.
    state: RecursiveCTEStreamState,
    /// Schema reported to DataFusion; matches the batch produced by the loop.
    schema: SchemaRef,
    /// Output row count is recorded here when the batch resolves.
    metrics: BaselineMetrics,
}
390
impl Stream for RecursiveCTEStream {
    type Item = DFResult<RecordBatch>;

    /// Drives the inner CTE future to completion, emits its single batch
    /// (or error), then reports end-of-stream on every later poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match &mut self.state {
            RecursiveCTEStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                Poll::Ready(Ok(batch)) => {
                    // Count output rows before handing the batch over.
                    self.metrics.record_output(batch.num_rows());
                    self.state = RecursiveCTEStreamState::Done;
                    Poll::Ready(Some(Ok(batch)))
                }
                Poll::Ready(Err(e)) => {
                    // An error also terminates the stream.
                    self.state = RecursiveCTEStreamState::Done;
                    Poll::Ready(Some(Err(e)))
                }
                Poll::Pending => Poll::Pending,
            },
            RecursiveCTEStreamState::Done => Poll::Ready(None),
        }
    }
}
412
413impl RecordBatchStream for RecursiveCTEStream {
414 fn schema(&self) -> SchemaRef {
415 self.schema.clone()
416 }
417}