// uni_query/query/df_graph/locy_priority.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! PRIORITY operator for Locy.
//!
//! `PriorityExec` applies priority semantics: among rows sharing KEY columns,
//! only those from the highest-priority clause are kept. The `__priority` column
//! is consumed and removed from the output schema.
10use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
11use arrow::compute::filter as arrow_filter;
12use arrow_array::{BooleanArray, Int64Array, RecordBatch};
13use arrow_schema::{Field, Schema, SchemaRef};
14use datafusion::common::Result as DFResult;
15use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
16use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
18use futures::{Stream, TryStreamExt};
19use std::any::Any;
20use std::collections::HashMap;
21use std::fmt;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
/// DataFusion `ExecutionPlan` that applies PRIORITY filtering.
///
/// For each group of rows sharing the same KEY columns, keeps only those rows
/// whose `__priority` value equals the maximum priority in that group. The
/// `__priority` column is stripped from the output.
#[derive(Debug)]
pub struct PriorityExec {
    // Child plan; must expose the `__priority` Int64 column at `priority_col_index`.
    input: Arc<dyn ExecutionPlan>,
    // Indices (into the child schema) of the KEY columns used for grouping.
    key_indices: Vec<usize>,
    // Index (into the child schema) of the `__priority` column to consume.
    priority_col_index: usize,
    // Output schema: the child schema minus the `__priority` column.
    schema: SchemaRef,
    // Cached plan properties derived from the output schema.
    properties: PlanProperties,
    // Execution metrics (row counts etc.), reported via `metrics()`.
    metrics: ExecutionPlanMetricsSet,
}
40
41impl PriorityExec {
42    /// Create a new `PriorityExec`.
43    ///
44    /// # Arguments
45    /// * `input` - Child execution plan (must contain `__priority` column)
46    /// * `key_indices` - Indices of KEY columns for grouping
47    /// * `priority_col_index` - Index of the `__priority` Int64 column
48    pub fn new(
49        input: Arc<dyn ExecutionPlan>,
50        key_indices: Vec<usize>,
51        priority_col_index: usize,
52    ) -> Self {
53        let input_schema = input.schema();
54        // Build output schema: all columns except __priority
55        let output_fields: Vec<Arc<Field>> = input_schema
56            .fields()
57            .iter()
58            .enumerate()
59            .filter(|(i, _)| *i != priority_col_index)
60            .map(|(_, f)| Arc::clone(f))
61            .collect();
62        let schema = Arc::new(Schema::new(output_fields));
63        let properties = compute_plan_properties(Arc::clone(&schema));
64
65        Self {
66            input,
67            key_indices,
68            priority_col_index,
69            schema,
70            properties,
71            metrics: ExecutionPlanMetricsSet::new(),
72        }
73    }
74}
75
76impl DisplayAs for PriorityExec {
77    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        write!(
79            f,
80            "PriorityExec: key_indices={:?}, priority_col={}",
81            self.key_indices, self.priority_col_index
82        )
83    }
84}
85
86impl ExecutionPlan for PriorityExec {
87    fn name(&self) -> &str {
88        "PriorityExec"
89    }
90
91    fn as_any(&self) -> &dyn Any {
92        self
93    }
94
95    fn schema(&self) -> SchemaRef {
96        Arc::clone(&self.schema)
97    }
98
99    fn properties(&self) -> &PlanProperties {
100        &self.properties
101    }
102
103    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
104        vec![&self.input]
105    }
106
107    fn with_new_children(
108        self: Arc<Self>,
109        children: Vec<Arc<dyn ExecutionPlan>>,
110    ) -> DFResult<Arc<dyn ExecutionPlan>> {
111        if children.len() != 1 {
112            return Err(datafusion::error::DataFusionError::Plan(
113                "PriorityExec requires exactly one child".to_string(),
114            ));
115        }
116        Ok(Arc::new(Self::new(
117            Arc::clone(&children[0]),
118            self.key_indices.clone(),
119            self.priority_col_index,
120        )))
121    }
122
123    fn execute(
124        &self,
125        partition: usize,
126        context: Arc<TaskContext>,
127    ) -> DFResult<SendableRecordBatchStream> {
128        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
129        let metrics = BaselineMetrics::new(&self.metrics, partition);
130        let key_indices = self.key_indices.clone();
131        let priority_col_index = self.priority_col_index;
132        let output_schema = Arc::clone(&self.schema);
133        let input_schema = self.input.schema();
134
135        let fut = async move {
136            // Collect all input batches
137            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
138
139            if batches.is_empty() {
140                return Ok(RecordBatch::new_empty(output_schema));
141            }
142
143            let batch = arrow::compute::concat_batches(&input_schema, &batches)
144                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
145
146            if batch.num_rows() == 0 {
147                return Ok(RecordBatch::new_empty(output_schema));
148            }
149
150            // Extract priority column
151            let priority_col = batch
152                .column(priority_col_index)
153                .as_any()
154                .downcast_ref::<Int64Array>()
155                .ok_or_else(|| {
156                    datafusion::error::DataFusionError::Execution(
157                        "__priority column must be Int64".to_string(),
158                    )
159                })?;
160
161            // Group by key columns → find max priority per group
162            let mut group_max: HashMap<Vec<ScalarKey>, i64> = HashMap::new();
163            for row_idx in 0..batch.num_rows() {
164                let key = extract_scalar_key(&batch, &key_indices, row_idx);
165                let prio = priority_col.value(row_idx);
166                let entry = group_max.entry(key).or_insert(i64::MIN);
167                if prio > *entry {
168                    *entry = prio;
169                }
170            }
171
172            // Build filter mask: keep rows where priority == max for their group
173            let keep: Vec<bool> = (0..batch.num_rows())
174                .map(|row_idx| {
175                    let key = extract_scalar_key(&batch, &key_indices, row_idx);
176                    let prio = priority_col.value(row_idx);
177                    group_max
178                        .get(&key)
179                        .is_some_and(|&max_prio| prio == max_prio)
180                })
181                .collect();
182
183            let filter_mask = BooleanArray::from(keep);
184
185            // Filter columns, skipping __priority
186            let mut output_columns = Vec::with_capacity(output_schema.fields().len());
187            for (i, col) in batch.columns().iter().enumerate() {
188                if i == priority_col_index {
189                    continue;
190                }
191                let filtered = arrow_filter(col.as_ref(), &filter_mask).map_err(|e| {
192                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
193                })?;
194                output_columns.push(filtered);
195            }
196
197            RecordBatch::try_new(output_schema, output_columns)
198                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
199        };
200
201        Ok(Box::pin(PriorityStream {
202            state: PriorityStreamState::Running(Box::pin(fut)),
203            schema: Arc::clone(&self.schema),
204            metrics,
205        }))
206    }
207
208    fn metrics(&self) -> Option<MetricsSet> {
209        Some(self.metrics.clone_inner())
210    }
211}
212
// ---------------------------------------------------------------------------
// Stream implementation
// ---------------------------------------------------------------------------

// State machine for `PriorityStream`: the whole output is computed by one
// future that yields a single batch, after which the stream is exhausted.
enum PriorityStreamState {
    // The output batch is still being computed.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // The batch (or an error) has been emitted; subsequent polls return None.
    Done,
}
221
// Single-batch output stream returned by `PriorityExec::execute`.
struct PriorityStream {
    // Running future, or Done once the single result has been yielded.
    state: PriorityStreamState,
    // Output schema (child schema minus `__priority`), reported via `schema()`.
    schema: SchemaRef,
    // Baseline metrics; output row count is recorded when the batch is ready.
    metrics: BaselineMetrics,
}
227
228impl Stream for PriorityStream {
229    type Item = DFResult<RecordBatch>;
230
231    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
232        match &mut self.state {
233            PriorityStreamState::Running(fut) => match fut.as_mut().poll(cx) {
234                Poll::Ready(Ok(batch)) => {
235                    self.metrics.record_output(batch.num_rows());
236                    self.state = PriorityStreamState::Done;
237                    Poll::Ready(Some(Ok(batch)))
238                }
239                Poll::Ready(Err(e)) => {
240                    self.state = PriorityStreamState::Done;
241                    Poll::Ready(Some(Err(e)))
242                }
243                Poll::Pending => Poll::Pending,
244            },
245            PriorityStreamState::Done => Poll::Ready(None),
246        }
247    }
248}
249
250impl RecordBatchStream for PriorityStream {
251    fn schema(&self) -> SchemaRef {
252        Arc::clone(&self.schema)
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    // Build a 3-column batch: `name` (Utf8), `value` (Int64), and the
    // non-null `__priority` (Int64) column that PriorityExec consumes.
    fn make_test_batch(names: Vec<&str>, values: Vec<i64>, priorities: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values)),
                Arc::new(Int64Array::from(priorities)),
            ],
        )
        .unwrap()
    }

    // Wrap a single batch in an in-memory ExecutionPlan to feed PriorityExec.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    // Minimal ExecutionPlan that replays a fixed set of batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    // Run PriorityExec over `input` and concatenate the output into one batch.
    async fn execute_priority(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        priority_col_index: usize,
    ) -> RecordBatch {
        let exec = PriorityExec::new(input, key_indices, priority_col_index);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    // Within one key group, only the row with the highest priority survives.
    #[tokio::test]
    async fn test_single_group_keeps_highest_priority() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![1, 3, 2]);
        let input = make_memory_exec(batch);
        // key_indices=[0] (name), priority_col=2
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 20); // priority 3 was highest
    }

    // Priority maxima are computed per key group, independently.
    #[tokio::test]
    async fn test_multiple_groups_independent_priority() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![10, 20, 30, 40],
            vec![1, 2, 3, 1],
        );
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 2);
        // Group "a": priority 2 wins → value 20
        // Group "b": priority 3 wins → value 30
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        // Find each group's result (output row order is not guaranteed).
        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(values.value(i), 20),
                "b" => assert_eq!(values.value(i), 30),
                _ => panic!("unexpected name"),
            }
        }
    }

    // Ties at the maximum priority keep every tied row.
    #[tokio::test]
    async fn test_all_same_priority_keeps_all() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![5, 5, 5]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 3);
    }

    // Empty input yields an empty batch with the stripped schema.
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 1).await;

        assert_eq!(result.num_rows(), 0);
    }

    // A lone row is trivially its group's maximum and passes through.
    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42], vec![1]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 42);
    }

    // The operator's reported schema must not contain `__priority`.
    #[tokio::test]
    async fn test_output_schema_lacks_priority() {
        let batch = make_test_batch(vec!["a"], vec![1], vec![1]);
        let input = make_memory_exec(batch);
        let exec = PriorityExec::new(input, vec![0], 2);

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 2); // name + value, no __priority
        assert!(schema.column_with_name("__priority").is_none());
        assert!(schema.column_with_name("name").is_some());
        assert!(schema.column_with_name("value").is_some());
    }
}
458}