Skip to main content

uni_query/query/df_graph/
recursive_cte.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Recursive CTE execution plan for DataFusion.
5//!
6//! Implements `WITH RECURSIVE` by iteratively executing the recursive part
7//! with an updated working table until a fixed point is reached (no new rows).
8//!
9//! # Algorithm
10//!
11//! 1. Execute the anchor (initial) query → working table
12//! 2. Loop:
//!    a. Bind the working table's VIDs as a list parameter under the CTE name
14//!    b. Re-plan and execute the recursive query with updated params
15//!    c. Deduplicate against previously seen rows (cycle detection)
//!    d. If no new rows remain, terminate (the loop is also capped at `MAX_ITERATIONS`)
17//!    e. Accumulate new rows and repeat
//! 3. Output the VIDs of all accumulated rows as a single row holding one `LargeList<Int64>` column
19
20use crate::query::df_graph::common::{arrow_err, compute_plan_properties, execute_subplan};
21use crate::query::df_graph::unwind::arrow_to_json_value;
22use crate::query::df_graph::{GraphExecutionContext, MutationContext};
23use crate::query::planner::LogicalPlan;
24use arrow_array::RecordBatch;
25use arrow_array::builder::{Int64Builder, LargeListBuilder};
26use arrow_schema::{DataType, Field, Schema, SchemaRef};
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
30use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
31use datafusion::prelude::SessionContext;
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::Arc;
39use std::task::{Context, Poll};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_store::storage::manager::StorageManager;
43
/// Upper bound on recursive-step executions; once reached, the loop stops and
/// the rows gathered so far become the CTE result.
const MAX_ITERATIONS: usize = 1_000;
46
47/// Recursive CTE execution plan.
48///
49/// Stores **logical** plans (not physical) and re-plans + executes on each
50/// iteration with updated parameters. The CTE name is injected as a parameter
51/// containing the current working table.
52pub struct RecursiveCTEExec {
53    /// Name of the CTE (e.g., `hierarchy`).
54    cte_name: String,
55
56    /// Logical plan for the anchor query.
57    initial_plan: LogicalPlan,
58
59    /// Logical plan for the recursive query.
60    recursive_plan: LogicalPlan,
61
62    /// Graph execution context shared with sub-planners.
63    graph_ctx: Arc<GraphExecutionContext>,
64
65    /// DataFusion session context.
66    session_ctx: Arc<RwLock<SessionContext>>,
67
68    /// Storage manager for creating sub-planners.
69    storage: Arc<StorageManager>,
70
71    /// Schema for label/edge type lookups.
72    schema_info: Arc<UniSchema>,
73
74    /// Query parameters (will be extended with CTE working table).
75    params: HashMap<String, Value>,
76
77    /// Output schema (single column: the CTE name containing JSON-encoded values).
78    output_schema: SchemaRef,
79
80    /// Cached plan properties.
81    properties: Arc<PlanProperties>,
82
83    /// Outer mutation context, threaded into each iteration's sub-planner
84    /// so that writes inside the recursive body (rare but supported by the
85    /// general API) route through the same transaction's L0 buffer.
86    mutation_ctx: Option<Arc<MutationContext>>,
87
88    /// Execution metrics.
89    metrics: ExecutionPlanMetricsSet,
90}
91
92impl fmt::Debug for RecursiveCTEExec {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        f.debug_struct("RecursiveCTEExec")
95            .field("cte_name", &self.cte_name)
96            .finish()
97    }
98}
99
100impl RecursiveCTEExec {
101    /// Create a new recursive CTE execution plan.
102    #[expect(clippy::too_many_arguments)]
103    pub fn new(
104        cte_name: String,
105        initial_plan: LogicalPlan,
106        recursive_plan: LogicalPlan,
107        graph_ctx: Arc<GraphExecutionContext>,
108        session_ctx: Arc<RwLock<SessionContext>>,
109        storage: Arc<StorageManager>,
110        schema_info: Arc<UniSchema>,
111        params: HashMap<String, Value>,
112        mutation_ctx: Option<Arc<MutationContext>>,
113    ) -> Self {
114        // Output schema: single column named after CTE, containing a LargeList<Int64>.
115        // Each element is a VID (cast to Int64) from the CTE results. The `n IN hierarchy`
116        // expression is rewritten to `CAST(n._vid AS Int64) IN hierarchy` by the expression
117        // translator, so the types match.
118        let inner_field = Arc::new(Field::new("item", DataType::Int64, true));
119        let field = Field::new(&cte_name, DataType::LargeList(inner_field), false);
120        let output_schema = Arc::new(Schema::new(vec![field]));
121        let properties = compute_plan_properties(output_schema.clone());
122
123        Self {
124            cte_name,
125            initial_plan,
126            recursive_plan,
127            graph_ctx,
128            session_ctx,
129            storage,
130            schema_info,
131            params,
132            output_schema,
133            properties,
134            mutation_ctx,
135            metrics: ExecutionPlanMetricsSet::new(),
136        }
137    }
138}
139
140impl DisplayAs for RecursiveCTEExec {
141    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        write!(f, "RecursiveCTEExec: {}", self.cte_name)
143    }
144}
145
146impl ExecutionPlan for RecursiveCTEExec {
147    fn name(&self) -> &str {
148        "RecursiveCTEExec"
149    }
150
151    fn as_any(&self) -> &dyn Any {
152        self
153    }
154
155    fn schema(&self) -> SchemaRef {
156        self.output_schema.clone()
157    }
158
159    fn properties(&self) -> &Arc<PlanProperties> {
160        &self.properties
161    }
162
163    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
164        // No physical children — sub-plans are re-planned at execution time
165        vec![]
166    }
167
168    fn with_new_children(
169        self: Arc<Self>,
170        children: Vec<Arc<dyn ExecutionPlan>>,
171    ) -> DFResult<Arc<dyn ExecutionPlan>> {
172        if !children.is_empty() {
173            return Err(datafusion::error::DataFusionError::Plan(
174                "RecursiveCTEExec has no children".to_string(),
175            ));
176        }
177        Ok(self)
178    }
179
180    fn execute(
181        &self,
182        partition: usize,
183        _context: Arc<TaskContext>,
184    ) -> DFResult<SendableRecordBatchStream> {
185        let metrics = BaselineMetrics::new(&self.metrics, partition);
186
187        // Clone all fields needed for the async computation
188        let cte_name = self.cte_name.clone();
189        let initial_plan = self.initial_plan.clone();
190        let recursive_plan = self.recursive_plan.clone();
191        let graph_ctx = self.graph_ctx.clone();
192        let session_ctx = self.session_ctx.clone();
193        let storage = self.storage.clone();
194        let schema_info = self.schema_info.clone();
195        let params = self.params.clone();
196        let output_schema = self.output_schema.clone();
197        let mutation_ctx = self.mutation_ctx.clone();
198
199        let fut = async move {
200            run_cte_loop(
201                &cte_name,
202                &initial_plan,
203                &recursive_plan,
204                &graph_ctx,
205                &session_ctx,
206                &storage,
207                &schema_info,
208                &params,
209                &output_schema,
210                mutation_ctx.as_ref(),
211            )
212            .await
213        };
214
215        Ok(Box::pin(RecursiveCTEStream {
216            state: RecursiveCTEStreamState::Running(Box::pin(fut)),
217            schema: self.output_schema.clone(),
218            metrics,
219        }))
220    }
221
222    fn metrics(&self) -> Option<MetricsSet> {
223        Some(self.metrics.clone_inner())
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Free functions for the CTE iteration loop
229// ---------------------------------------------------------------------------
230
231/// Extract values from record batches into a flat list of `Value`.
232///
233/// Each row becomes a single `Value`. If the row has one column, the column
234/// value is used directly. If multiple columns, they are combined into a `Value::Map`.
235fn batches_to_values(batches: &[RecordBatch]) -> Vec<Value> {
236    let mut values = Vec::new();
237    for batch in batches {
238        let num_cols = batch.num_columns();
239        let schema = batch.schema();
240
241        for row_idx in 0..batch.num_rows() {
242            if num_cols == 1 {
243                values.push(arrow_to_json_value(batch.column(0).as_ref(), row_idx));
244            } else {
245                let mut map = Vec::new();
246                for col_idx in 0..num_cols {
247                    let col_name = schema.field(col_idx).name().clone();
248                    let val = arrow_to_json_value(batch.column(col_idx).as_ref(), row_idx);
249                    map.push((col_name, val));
250                }
251                values.push(Value::Map(map.into_iter().collect()));
252            }
253        }
254    }
255    values
256}
257
258/// Create a stable string key for a Value, used for cycle detection.
259fn value_key(val: &Value) -> String {
260    format!("{val:?}")
261}
262
263/// Extract the VID from a CTE result value.
264///
265/// CTE result values can be:
266/// - A Map with a `*._vid` key (from multi-column scan results)
267/// - A raw integer (from single-column VID returns)
268/// - A Map with a `_vid` key
269fn extract_vid(val: &Value) -> Option<u64> {
270    match val {
271        Value::Int(v) => Some(*v as u64),
272        Value::Map(map) => {
273            // Look for any key ending in `._vid` or exactly `_vid`
274            for (k, v) in map {
275                if k == "_vid" || k.ends_with("._vid") {
276                    return v.as_u64();
277                }
278            }
279            None
280        }
281        _ => val.as_u64(),
282    }
283}
284
285/// Run the full recursive CTE iteration loop and produce the output batch.
286#[expect(clippy::too_many_arguments)]
287async fn run_cte_loop(
288    cte_name: &str,
289    initial_plan: &LogicalPlan,
290    recursive_plan: &LogicalPlan,
291    graph_ctx: &Arc<GraphExecutionContext>,
292    session_ctx: &Arc<RwLock<SessionContext>>,
293    storage: &Arc<StorageManager>,
294    schema_info: &Arc<UniSchema>,
295    params: &HashMap<String, Value>,
296    output_schema: &SchemaRef,
297    mutation_ctx: Option<&Arc<MutationContext>>,
298) -> DFResult<RecordBatch> {
299    // 1. Execute anchor
300    let anchor_batches = execute_subplan(
301        initial_plan,
302        params,
303        &HashMap::new(), // No outer values for anchor
304        graph_ctx,
305        session_ctx,
306        storage,
307        schema_info,
308        mutation_ctx,
309    )
310    .await?;
311    let mut working_values = batches_to_values(&anchor_batches);
312    let mut result_values = working_values.clone();
313
314    // Track seen values for cycle detection
315    let mut seen: HashSet<String> = working_values.iter().map(value_key).collect();
316
317    // 2. Iterate
318    for _iteration in 0..MAX_ITERATIONS {
319        if working_values.is_empty() {
320            break;
321        }
322
323        // Bind working table VIDs to CTE name in params.
324        // Extract VIDs so the expression translator resolves `hierarchy` as List<Int64>,
325        // matching the VID column type used by `n._vid IN hierarchy`.
326        let vid_list = Value::List(
327            working_values
328                .iter()
329                .filter_map(|v| extract_vid(v).map(|vid| Value::Int(vid as i64)))
330                .collect(),
331        );
332        let mut next_params = params.clone();
333        next_params.insert(cte_name.to_string(), vid_list);
334
335        // Execute recursive part
336        let recursive_batches = execute_subplan(
337            recursive_plan,
338            &next_params,
339            &HashMap::new(), // No outer values for recursive part
340            graph_ctx,
341            session_ctx,
342            storage,
343            schema_info,
344            mutation_ctx,
345        )
346        .await?;
347        let next_values = batches_to_values(&recursive_batches);
348
349        if next_values.is_empty() {
350            break;
351        }
352
353        // Filter out already-seen values (cycle detection)
354        let new_values: Vec<Value> = next_values
355            .into_iter()
356            .filter(|val| {
357                let key = value_key(val);
358                seen.insert(key) // returns false if already present
359            })
360            .collect();
361
362        if new_values.is_empty() {
363            break;
364        }
365
366        result_values.extend(new_values.clone());
367        working_values = new_values;
368    }
369
370    // 3. Build output: single row with a LargeList<Int64> column of VIDs.
371    // Each element is a VID (as Int64) extracted from the CTE results.
372    let mut list_builder = LargeListBuilder::new(Int64Builder::new());
373    for val in &result_values {
374        if let Some(vid) = extract_vid(val) {
375            list_builder.values().append_value(vid as i64);
376        }
377    }
378    list_builder.append(true);
379    let array = Arc::new(list_builder.finish());
380
381    RecordBatch::try_new(output_schema.clone(), vec![array]).map_err(arrow_err)
382}
383
384// ---------------------------------------------------------------------------
385// Stream implementation
386// ---------------------------------------------------------------------------
387
388/// Stream state for the recursive CTE.
389enum RecursiveCTEStreamState {
390    /// The CTE computation is running.
391    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
392    /// Computation completed, batch ready to emit.
393    Done,
394}
395
396/// Stream that runs the recursive CTE and emits the result.
397struct RecursiveCTEStream {
398    state: RecursiveCTEStreamState,
399    schema: SchemaRef,
400    metrics: BaselineMetrics,
401}
402
403impl Stream for RecursiveCTEStream {
404    type Item = DFResult<RecordBatch>;
405
406    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
407        let metrics = self.metrics.clone();
408        let _timer = metrics.elapsed_compute().timer();
409        match &mut self.state {
410            RecursiveCTEStreamState::Running(fut) => match fut.as_mut().poll(cx) {
411                Poll::Ready(Ok(batch)) => {
412                    self.metrics.record_output(batch.num_rows());
413                    self.state = RecursiveCTEStreamState::Done;
414                    Poll::Ready(Some(Ok(batch)))
415                }
416                Poll::Ready(Err(e)) => {
417                    self.state = RecursiveCTEStreamState::Done;
418                    Poll::Ready(Some(Err(e)))
419                }
420                Poll::Pending => Poll::Pending,
421            },
422            RecursiveCTEStreamState::Done => Poll::Ready(None),
423        }
424    }
425}
426
427impl RecordBatchStream for RecursiveCTEStream {
428    fn schema(&self) -> SchemaRef {
429        self.schema.clone()
430    }
431}