Skip to main content

uni_query/query/df_graph/
recursive_cte.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Recursive CTE execution plan for DataFusion.
5//!
6//! Implements `WITH RECURSIVE` by iteratively executing the recursive part
7//! with an updated working table until a fixed point is reached (no new rows).
8//!
9//! # Algorithm
10//!
11//! 1. Execute the anchor (initial) query → working table
12//! 2. Loop:
13//!    a. Bind working table as a parameter under the CTE name
14//!    b. Re-plan and execute the recursive query with updated params
15//!    c. Deduplicate against previously seen rows (cycle detection)
16//!    d. If no new rows, terminate
17//!    e. Accumulate new rows and repeat
//! 3. Output the VIDs of all accumulated rows as a single row with one list column
19
20use crate::query::df_graph::GraphExecutionContext;
21use crate::query::df_graph::common::{arrow_err, compute_plan_properties, execute_subplan};
22use crate::query::df_graph::unwind::arrow_to_json_value;
23use crate::query::planner::LogicalPlan;
24use arrow_array::RecordBatch;
25use arrow_array::builder::{Int64Builder, LargeListBuilder};
26use arrow_schema::{DataType, Field, Schema, SchemaRef};
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
30use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
31use datafusion::prelude::SessionContext;
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_store::storage::manager::StorageManager;
43
/// Maximum number of CTE iterations before forced termination.
///
/// Bounds the iteration loop in `run_cte_loop`. When the bound is reached the
/// loop simply stops and the rows accumulated so far are emitted; no error is
/// raised.
const MAX_ITERATIONS: usize = 1000;
46
/// Recursive CTE execution plan.
///
/// Stores **logical** plans (not physical) and re-plans + executes on each
/// iteration with updated parameters. The CTE name is injected as a parameter
/// containing the VIDs of the current working table.
pub struct RecursiveCTEExec {
    /// Name of the CTE (e.g., `hierarchy`). Also used as the output column
    /// name and as the parameter key under which the working table is bound.
    cte_name: String,

    /// Logical plan for the anchor query (executed once, before iterating).
    initial_plan: LogicalPlan,

    /// Logical plan for the recursive query (executed once per iteration).
    recursive_plan: LogicalPlan,

    /// Graph execution context shared with sub-planners.
    graph_ctx: Arc<GraphExecutionContext>,

    /// DataFusion session context.
    session_ctx: Arc<RwLock<SessionContext>>,

    /// Storage manager for creating sub-planners.
    storage: Arc<StorageManager>,

    /// Schema for label/edge type lookups.
    schema_info: Arc<UniSchema>,

    /// Query parameters (cloned and extended with the CTE working table on
    /// each iteration; this map itself is never mutated).
    params: HashMap<String, Value>,

    /// Output schema: a single non-nullable column named after the CTE, typed
    /// `LargeList<Int64>` (the list items are VIDs).
    output_schema: SchemaRef,

    /// Cached plan properties, computed from `output_schema` at construction.
    properties: PlanProperties,

    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
}
86
87impl fmt::Debug for RecursiveCTEExec {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("RecursiveCTEExec")
90            .field("cte_name", &self.cte_name)
91            .finish()
92    }
93}
94
95impl RecursiveCTEExec {
96    /// Create a new recursive CTE execution plan.
97    #[expect(clippy::too_many_arguments)]
98    pub fn new(
99        cte_name: String,
100        initial_plan: LogicalPlan,
101        recursive_plan: LogicalPlan,
102        graph_ctx: Arc<GraphExecutionContext>,
103        session_ctx: Arc<RwLock<SessionContext>>,
104        storage: Arc<StorageManager>,
105        schema_info: Arc<UniSchema>,
106        params: HashMap<String, Value>,
107    ) -> Self {
108        // Output schema: single column named after CTE, containing a LargeList<Int64>.
109        // Each element is a VID (cast to Int64) from the CTE results. The `n IN hierarchy`
110        // expression is rewritten to `CAST(n._vid AS Int64) IN hierarchy` by the expression
111        // translator, so the types match.
112        let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
113        let field = Field::new(&cte_name, DataType::LargeList(inner_field), false);
114        let output_schema = Arc::new(Schema::new(vec![field]));
115        let properties = compute_plan_properties(output_schema.clone());
116
117        Self {
118            cte_name,
119            initial_plan,
120            recursive_plan,
121            graph_ctx,
122            session_ctx,
123            storage,
124            schema_info,
125            params,
126            output_schema,
127            properties,
128            metrics: ExecutionPlanMetricsSet::new(),
129        }
130    }
131}
132
133impl DisplayAs for RecursiveCTEExec {
134    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        write!(f, "RecursiveCTEExec: {}", self.cte_name)
136    }
137}
138
139impl ExecutionPlan for RecursiveCTEExec {
140    fn name(&self) -> &str {
141        "RecursiveCTEExec"
142    }
143
144    fn as_any(&self) -> &dyn Any {
145        self
146    }
147
148    fn schema(&self) -> SchemaRef {
149        self.output_schema.clone()
150    }
151
152    fn properties(&self) -> &PlanProperties {
153        &self.properties
154    }
155
156    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
157        // No physical children — sub-plans are re-planned at execution time
158        vec![]
159    }
160
161    fn with_new_children(
162        self: Arc<Self>,
163        children: Vec<Arc<dyn ExecutionPlan>>,
164    ) -> DFResult<Arc<dyn ExecutionPlan>> {
165        if !children.is_empty() {
166            return Err(datafusion::error::DataFusionError::Plan(
167                "RecursiveCTEExec has no children".to_string(),
168            ));
169        }
170        Ok(self)
171    }
172
173    fn execute(
174        &self,
175        partition: usize,
176        _context: Arc<TaskContext>,
177    ) -> DFResult<SendableRecordBatchStream> {
178        let metrics = BaselineMetrics::new(&self.metrics, partition);
179
180        // Clone all fields needed for the async computation
181        let cte_name = self.cte_name.clone();
182        let initial_plan = self.initial_plan.clone();
183        let recursive_plan = self.recursive_plan.clone();
184        let graph_ctx = self.graph_ctx.clone();
185        let session_ctx = self.session_ctx.clone();
186        let storage = self.storage.clone();
187        let schema_info = self.schema_info.clone();
188        let params = self.params.clone();
189        let output_schema = self.output_schema.clone();
190
191        let fut = async move {
192            run_cte_loop(
193                &cte_name,
194                &initial_plan,
195                &recursive_plan,
196                &graph_ctx,
197                &session_ctx,
198                &storage,
199                &schema_info,
200                &params,
201                &output_schema,
202            )
203            .await
204        };
205
206        Ok(Box::pin(RecursiveCTEStream {
207            state: RecursiveCTEStreamState::Running(Box::pin(fut)),
208            schema: self.output_schema.clone(),
209            metrics,
210        }))
211    }
212
213    fn metrics(&self) -> Option<MetricsSet> {
214        Some(self.metrics.clone_inner())
215    }
216}
217
218// ---------------------------------------------------------------------------
219// Free functions for the CTE iteration loop
220// ---------------------------------------------------------------------------
221
222/// Extract values from record batches into a flat list of `Value`.
223///
224/// Each row becomes a single `Value`. If the row has one column, the column
225/// value is used directly. If multiple columns, they are combined into a `Value::Map`.
226fn batches_to_values(batches: &[RecordBatch]) -> Vec<Value> {
227    let mut values = Vec::new();
228    for batch in batches {
229        let num_cols = batch.num_columns();
230        let schema = batch.schema();
231
232        for row_idx in 0..batch.num_rows() {
233            if num_cols == 1 {
234                values.push(arrow_to_json_value(batch.column(0).as_ref(), row_idx));
235            } else {
236                let mut map = Vec::new();
237                for col_idx in 0..num_cols {
238                    let col_name = schema.field(col_idx).name().clone();
239                    let val = arrow_to_json_value(batch.column(col_idx).as_ref(), row_idx);
240                    map.push((col_name, val));
241                }
242                values.push(Value::Map(map.into_iter().collect()));
243            }
244        }
245    }
246    values
247}
248
/// Create a string key for a Value, used for cycle detection.
///
/// The key is the value's `Debug` rendering, so two values are treated as the
/// same row exactly when their `Debug` output is identical.
///
/// NOTE(review): this is only a stable key if `Value`'s `Debug` output is
/// deterministic. If `Value::Map` is backed by a hash map, confirm that equal
/// maps always render their entries in the same order.
fn value_key(val: &Value) -> String {
    format!("{val:?}")
}
253
254/// Extract the VID from a CTE result value.
255///
256/// CTE result values can be:
257/// - A Map with a `*._vid` key (from multi-column scan results)
258/// - A raw integer (from single-column VID returns)
259/// - A Map with a `_vid` key
260fn extract_vid(val: &Value) -> Option<u64> {
261    match val {
262        Value::Int(v) => Some(*v as u64),
263        Value::Map(map) => {
264            // Look for any key ending in `._vid` or exactly `_vid`
265            for (k, v) in map {
266                if k == "_vid" || k.ends_with("._vid") {
267                    return v.as_u64();
268                }
269            }
270            None
271        }
272        _ => val.as_u64(),
273    }
274}
275
276/// Run the full recursive CTE iteration loop and produce the output batch.
277#[expect(clippy::too_many_arguments)]
278async fn run_cte_loop(
279    cte_name: &str,
280    initial_plan: &LogicalPlan,
281    recursive_plan: &LogicalPlan,
282    graph_ctx: &Arc<GraphExecutionContext>,
283    session_ctx: &Arc<RwLock<SessionContext>>,
284    storage: &Arc<StorageManager>,
285    schema_info: &Arc<UniSchema>,
286    params: &HashMap<String, Value>,
287    output_schema: &SchemaRef,
288) -> DFResult<RecordBatch> {
289    // 1. Execute anchor
290    let anchor_batches = execute_subplan(
291        initial_plan,
292        params,
293        &HashMap::new(), // No outer values for anchor
294        graph_ctx,
295        session_ctx,
296        storage,
297        schema_info,
298    )
299    .await?;
300    let mut working_values = batches_to_values(&anchor_batches);
301    let mut result_values = working_values.clone();
302
303    // Track seen values for cycle detection
304    let mut seen: HashSet<String> = working_values.iter().map(value_key).collect();
305
306    // 2. Iterate
307    for _iteration in 0..MAX_ITERATIONS {
308        if working_values.is_empty() {
309            break;
310        }
311
312        // Bind working table VIDs to CTE name in params.
313        // Extract VIDs so the expression translator resolves `hierarchy` as List<Int64>,
314        // matching the VID column type used by `n._vid IN hierarchy`.
315        let vid_list = Value::List(
316            working_values
317                .iter()
318                .filter_map(|v| extract_vid(v).map(|vid| Value::Int(vid as i64)))
319                .collect(),
320        );
321        let mut next_params = params.clone();
322        next_params.insert(cte_name.to_string(), vid_list);
323
324        // Execute recursive part
325        let recursive_batches = execute_subplan(
326            recursive_plan,
327            &next_params,
328            &HashMap::new(), // No outer values for recursive part
329            graph_ctx,
330            session_ctx,
331            storage,
332            schema_info,
333        )
334        .await?;
335        let next_values = batches_to_values(&recursive_batches);
336
337        if next_values.is_empty() {
338            break;
339        }
340
341        // Filter out already-seen values (cycle detection)
342        let new_values: Vec<Value> = next_values
343            .into_iter()
344            .filter(|val| {
345                let key = value_key(val);
346                seen.insert(key) // returns false if already present
347            })
348            .collect();
349
350        if new_values.is_empty() {
351            break;
352        }
353
354        result_values.extend(new_values.clone());
355        working_values = new_values;
356    }
357
358    // 3. Build output: single row with a LargeList<Int64> column of VIDs.
359    // Each element is a VID (as Int64) extracted from the CTE results.
360    let mut list_builder = LargeListBuilder::new(Int64Builder::new());
361    for val in &result_values {
362        if let Some(vid) = extract_vid(val) {
363            list_builder.values().append_value(vid as i64);
364        }
365    }
366    list_builder.append(true);
367    let array = Arc::new(list_builder.finish());
368
369    RecordBatch::try_new(output_schema.clone(), vec![array]).map_err(arrow_err)
370}
371
372// ---------------------------------------------------------------------------
373// Stream implementation
374// ---------------------------------------------------------------------------
375
/// Stream state for the recursive CTE.
enum RecursiveCTEStreamState {
    /// The CTE computation is still in flight; holds the boxed future that
    /// resolves to the single output batch (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The future has resolved and its result (batch or error) has already
    /// been yielded; the stream only returns `None` from here on.
    Done,
}
383
/// Stream that runs the recursive CTE and emits the result.
///
/// Yields exactly one item — the output batch or the error produced by the
/// computation — and then ends.
struct RecursiveCTEStream {
    /// Whether the computation is still running or has already been emitted.
    state: RecursiveCTEStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics; the emitted batch's row count is recorded here.
    metrics: BaselineMetrics,
}
390
391impl Stream for RecursiveCTEStream {
392    type Item = DFResult<RecordBatch>;
393
394    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
395        match &mut self.state {
396            RecursiveCTEStreamState::Running(fut) => match fut.as_mut().poll(cx) {
397                Poll::Ready(Ok(batch)) => {
398                    self.metrics.record_output(batch.num_rows());
399                    self.state = RecursiveCTEStreamState::Done;
400                    Poll::Ready(Some(Ok(batch)))
401                }
402                Poll::Ready(Err(e)) => {
403                    self.state = RecursiveCTEStreamState::Done;
404                    Poll::Ready(Some(Err(e)))
405                }
406                Poll::Pending => Poll::Pending,
407            },
408            RecursiveCTEStreamState::Done => Poll::Ready(None),
409        }
410    }
411}
412
413impl RecordBatchStream for RecursiveCTEStream {
414    fn schema(&self) -> SchemaRef {
415        self.schema.clone()
416    }
417}