Skip to main content

uni_query/query/df_graph/
locy_priority.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! PRIORITY operator for Locy.
5//!
6//! `PriorityExec` applies priority semantics: among rows sharing KEY columns,
7//! only those from the highest-priority clause are kept. The `__priority` column
8//! is consumed and removed from the output schema.
9
10use crate::query::df_graph::common::{
11    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
12};
13use arrow::compute::filter as arrow_filter;
14use arrow_array::{BooleanArray, Int64Array, RecordBatch};
15use arrow_schema::{Field, Schema, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
18use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
20use futures::{Stream, TryStreamExt};
21use std::any::Any;
22use std::collections::HashMap;
23use std::fmt;
24use std::pin::Pin;
25use std::sync::Arc;
26use std::task::{Context, Poll};
27
/// DataFusion `ExecutionPlan` that applies PRIORITY filtering.
///
/// For each group of rows sharing the same KEY columns, keeps only those rows
/// whose `__priority` value equals the maximum priority in that group. The
/// `__priority` column is stripped from the output.
#[derive(Debug)]
pub struct PriorityExec {
    /// Child plan whose output rows are filtered.
    input: Arc<dyn ExecutionPlan>,
    /// Indices (into the input schema) of the KEY columns used for grouping.
    key_indices: Vec<usize>,
    /// Index (into the input schema) of the `__priority` column.
    priority_col_index: usize,
    /// Output schema: the input schema with the `__priority` field removed.
    schema: SchemaRef,
    /// Plan properties obtained from `compute_plan_properties` on the output schema.
    properties: Arc<PlanProperties>,
    /// Metrics registry; per-partition baseline metrics are registered in `execute`.
    metrics: ExecutionPlanMetricsSet,
}
42
43impl PriorityExec {
44    /// Create a new `PriorityExec`.
45    ///
46    /// # Arguments
47    /// * `input` - Child execution plan (must contain `__priority` column)
48    /// * `key_indices` - Indices of KEY columns for grouping
49    /// * `priority_col_index` - Index of the `__priority` Int64 column
50    pub fn new(
51        input: Arc<dyn ExecutionPlan>,
52        key_indices: Vec<usize>,
53        priority_col_index: usize,
54    ) -> Self {
55        let input_schema = input.schema();
56        // Build output schema: all columns except __priority
57        let output_fields: Vec<Arc<Field>> = input_schema
58            .fields()
59            .iter()
60            .enumerate()
61            .filter(|(i, _)| *i != priority_col_index)
62            .map(|(_, f)| Arc::clone(f))
63            .collect();
64        let schema = Arc::new(Schema::new(output_fields));
65        let properties = compute_plan_properties(Arc::clone(&schema));
66
67        Self {
68            input,
69            key_indices,
70            priority_col_index,
71            schema,
72            properties,
73            metrics: ExecutionPlanMetricsSet::new(),
74        }
75    }
76}
77
78impl DisplayAs for PriorityExec {
79    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(
81            f,
82            "PriorityExec: key_indices={:?}, priority_col={}",
83            self.key_indices, self.priority_col_index
84        )
85    }
86}
87
88impl ExecutionPlan for PriorityExec {
89    fn name(&self) -> &str {
90        "PriorityExec"
91    }
92
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn schema(&self) -> SchemaRef {
98        Arc::clone(&self.schema)
99    }
100
101    fn properties(&self) -> &Arc<PlanProperties> {
102        &self.properties
103    }
104
105    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
106        vec![&self.input]
107    }
108
109    fn with_new_children(
110        self: Arc<Self>,
111        children: Vec<Arc<dyn ExecutionPlan>>,
112    ) -> DFResult<Arc<dyn ExecutionPlan>> {
113        if children.len() != 1 {
114            return Err(datafusion::error::DataFusionError::Plan(
115                "PriorityExec requires exactly one child".to_string(),
116            ));
117        }
118        Ok(Arc::new(Self::new(
119            Arc::clone(&children[0]),
120            self.key_indices.clone(),
121            self.priority_col_index,
122        )))
123    }
124
125    fn execute(
126        &self,
127        partition: usize,
128        context: Arc<TaskContext>,
129    ) -> DFResult<SendableRecordBatchStream> {
130        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
131        let metrics = BaselineMetrics::new(&self.metrics, partition);
132        let key_indices = self.key_indices.clone();
133        let priority_col_index = self.priority_col_index;
134        let output_schema = Arc::clone(&self.schema);
135        let input_schema = self.input.schema();
136
137        let fut = async move {
138            // Collect all input batches
139            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
140
141            if batches.is_empty() {
142                return Ok(RecordBatch::new_empty(output_schema));
143            }
144
145            let batch =
146                arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;
147
148            if batch.num_rows() == 0 {
149                return Ok(RecordBatch::new_empty(output_schema));
150            }
151
152            // Extract priority column
153            let priority_col = batch
154                .column(priority_col_index)
155                .as_any()
156                .downcast_ref::<Int64Array>()
157                .ok_or_else(|| {
158                    datafusion::error::DataFusionError::Execution(
159                        "__priority column must be Int64".to_string(),
160                    )
161                })?;
162
163            // Group by key columns → find max priority per group
164            let mut group_max: HashMap<Vec<ScalarKey>, i64> = HashMap::new();
165            for row_idx in 0..batch.num_rows() {
166                let key = extract_scalar_key(&batch, &key_indices, row_idx);
167                let prio = priority_col.value(row_idx);
168                let entry = group_max.entry(key).or_insert(i64::MIN);
169                if prio > *entry {
170                    *entry = prio;
171                }
172            }
173
174            // Build filter mask: keep rows where priority == max for their group
175            let keep: Vec<bool> = (0..batch.num_rows())
176                .map(|row_idx| {
177                    let key = extract_scalar_key(&batch, &key_indices, row_idx);
178                    let prio = priority_col.value(row_idx);
179                    group_max
180                        .get(&key)
181                        .is_some_and(|&max_prio| prio == max_prio)
182                })
183                .collect();
184
185            let filter_mask = BooleanArray::from(keep);
186
187            // Filter columns, skipping __priority
188            let mut output_columns = Vec::with_capacity(output_schema.fields().len());
189            for (i, col) in batch.columns().iter().enumerate() {
190                if i == priority_col_index {
191                    continue;
192                }
193                let filtered = arrow_filter(col.as_ref(), &filter_mask).map_err(|e| {
194                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
195                })?;
196                output_columns.push(filtered);
197            }
198
199            RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
200        };
201
202        Ok(Box::pin(PriorityStream {
203            state: PriorityStreamState::Running(Box::pin(fut)),
204            schema: Arc::clone(&self.schema),
205            metrics,
206        }))
207    }
208
209    fn metrics(&self) -> Option<MetricsSet> {
210        Some(self.metrics.clone_inner())
211    }
212}
213
214// ---------------------------------------------------------------------------
215// Stream implementation
216// ---------------------------------------------------------------------------
217
/// Progress of a [`PriorityStream`].
enum PriorityStreamState {
    /// The future computing the single output batch has not completed yet.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The result (batch or error) has been yielded; the stream is exhausted.
    Done,
}
222
/// Stream returned by `PriorityExec::execute`; yields at most one item.
struct PriorityStream {
    /// Either the in-flight computation or the terminal `Done` marker.
    state: PriorityStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics (compute time, output rows) updated while polling.
    metrics: BaselineMetrics,
}
228
229impl Stream for PriorityStream {
230    type Item = DFResult<RecordBatch>;
231
232    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233        let metrics = self.metrics.clone();
234        let _timer = metrics.elapsed_compute().timer();
235        match &mut self.state {
236            PriorityStreamState::Running(fut) => match fut.as_mut().poll(cx) {
237                Poll::Ready(Ok(batch)) => {
238                    self.metrics.record_output(batch.num_rows());
239                    self.state = PriorityStreamState::Done;
240                    Poll::Ready(Some(Ok(batch)))
241                }
242                Poll::Ready(Err(e)) => {
243                    self.state = PriorityStreamState::Done;
244                    Poll::Ready(Some(Err(e)))
245                }
246                Poll::Pending => Poll::Pending,
247            },
248            PriorityStreamState::Done => Poll::Ready(None),
249        }
250    }
251}
252
253impl RecordBatchStream for PriorityStream {
254    fn schema(&self) -> SchemaRef {
255        Arc::clone(&self.schema)
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Minimal leaf plan that replays a fixed list of in-memory batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: Arc<PlanProperties>,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &Arc<PlanProperties> {
            &self.properties
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            Vec::new()
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            let stream =
                MemoryStream::try_new(self.batches.clone(), Arc::clone(&self.schema), None)?;
            Ok(Box::pin(stream))
        }
    }

    /// Build a `(name: Utf8, value: Int64, __priority: Int64)` batch.
    fn make_test_batch(names: Vec<&str>, values: Vec<i64>, priorities: Vec<i64>) -> RecordBatch {
        let fields = vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
            Field::new("__priority", DataType::Int64, false),
        ];
        let name_values: Vec<Option<&str>> = names.into_iter().map(Some).collect();
        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![
                Arc::new(StringArray::from(name_values)),
                Arc::new(Int64Array::from(values)),
                Arc::new(Int64Array::from(priorities)),
            ],
        )
        .unwrap()
    }

    /// Wrap a single batch in a `TestMemoryExec` leaf plan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        let properties = compute_plan_properties(Arc::clone(&schema));
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema,
            properties,
        })
    }

    /// Run `PriorityExec` over `input` and return its output as one batch.
    async fn execute_priority(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        priority_col_index: usize,
    ) -> RecordBatch {
        let plan = PriorityExec::new(input, key_indices, priority_col_index);
        let task_ctx = SessionContext::new().task_ctx();
        let stream = plan.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();

        let out_schema = plan.schema();
        if batches.is_empty() {
            return RecordBatch::new_empty(out_schema);
        }
        arrow::compute::concat_batches(&out_schema, &batches).unwrap()
    }

    /// Borrow column `index` of `batch` as an `Int64Array`, panicking on a type mismatch.
    fn int64_column(batch: &RecordBatch, index: usize) -> &Int64Array {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap()
    }

    #[tokio::test]
    async fn test_single_group_keeps_highest_priority() {
        // One group "a" with priorities 1, 3, 2; key column is `name`, priority column is 2.
        let input = make_memory_exec(make_test_batch(
            vec!["a", "a", "a"],
            vec![10, 20, 30],
            vec![1, 3, 2],
        ));
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        // The row with priority 3 carried value 20.
        assert_eq!(int64_column(&result, 1).value(0), 20);
    }

    #[tokio::test]
    async fn test_multiple_groups_independent_priority() {
        let input = make_memory_exec(make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![10, 20, 30, 40],
            vec![1, 2, 3, 1],
        ));
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let values = int64_column(&result, 1);

        // Group "a": priority 2 wins -> value 20.
        // Group "b": priority 3 wins -> value 30.
        for row in 0..2 {
            let expected = match names.value(row) {
                "a" => 20,
                "b" => 30,
                _ => panic!("unexpected name"),
            };
            assert_eq!(values.value(row), expected);
        }
    }

    #[tokio::test]
    async fn test_all_same_priority_keeps_all() {
        // Ties on the maximum priority are all retained.
        let input = make_memory_exec(make_test_batch(
            vec!["a", "a", "a"],
            vec![10, 20, 30],
            vec![5, 5, 5],
        ));
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 3);
    }

    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let input = make_memory_exec(RecordBatch::new_empty(schema));
        let result = execute_priority(input, vec![0], 1).await;

        assert_eq!(result.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_single_row_passthrough() {
        let input = make_memory_exec(make_test_batch(vec!["x"], vec![42], vec![1]));
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        assert_eq!(int64_column(&result, 1).value(0), 42);
    }

    #[tokio::test]
    async fn test_output_schema_lacks_priority() {
        let input = make_memory_exec(make_test_batch(vec!["a"], vec![1], vec![1]));
        let exec = PriorityExec::new(input, vec![0], 2);

        let schema = exec.schema();
        // Only `name` and `value` remain; `__priority` is dropped.
        assert_eq!(schema.fields().len(), 2);
        assert!(schema.column_with_name("__priority").is_none());
        assert!(schema.column_with_name("name").is_some());
        assert!(schema.column_with_name("value").is_some());
    }
}