// uni_query/query/df_graph/locy_priority.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! PRIORITY operator for Locy.
//!
//! `PriorityExec` applies priority semantics: among rows sharing KEY columns,
//! only those from the highest-priority clause are kept. The `__priority` column
//! is consumed and removed from the output schema.
use crate::query::df_graph::common::{
    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
};
use arrow::compute::filter as arrow_filter;
use arrow_array::{BooleanArray, Int64Array, RecordBatch};
use arrow_schema::{Field, Schema, SchemaRef};
use datafusion::common::Result as DFResult;
use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
use futures::{Stream, TryStreamExt};
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
27
/// DataFusion `ExecutionPlan` that applies PRIORITY filtering.
///
/// For each group of rows sharing the same KEY columns, keeps only those rows
/// whose `__priority` value equals the maximum priority in that group. The
/// `__priority` column is stripped from the output.
#[derive(Debug)]
pub struct PriorityExec {
    // Child plan producing the rows to filter (must include `__priority`).
    input: Arc<dyn ExecutionPlan>,
    // Indices of the KEY columns used for grouping.
    key_indices: Vec<usize>,
    // Index of the Int64 `__priority` column in the input schema.
    priority_col_index: usize,
    // Output schema: input schema minus the `__priority` column.
    schema: SchemaRef,
    // Cached plan properties derived from the output schema.
    properties: PlanProperties,
    // Execution metrics set, surfaced through `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
42
43impl PriorityExec {
44    /// Create a new `PriorityExec`.
45    ///
46    /// # Arguments
47    /// * `input` - Child execution plan (must contain `__priority` column)
48    /// * `key_indices` - Indices of KEY columns for grouping
49    /// * `priority_col_index` - Index of the `__priority` Int64 column
50    pub fn new(
51        input: Arc<dyn ExecutionPlan>,
52        key_indices: Vec<usize>,
53        priority_col_index: usize,
54    ) -> Self {
55        let input_schema = input.schema();
56        // Build output schema: all columns except __priority
57        let output_fields: Vec<Arc<Field>> = input_schema
58            .fields()
59            .iter()
60            .enumerate()
61            .filter(|(i, _)| *i != priority_col_index)
62            .map(|(_, f)| Arc::clone(f))
63            .collect();
64        let schema = Arc::new(Schema::new(output_fields));
65        let properties = compute_plan_properties(Arc::clone(&schema));
66
67        Self {
68            input,
69            key_indices,
70            priority_col_index,
71            schema,
72            properties,
73            metrics: ExecutionPlanMetricsSet::new(),
74        }
75    }
76}
77
78impl DisplayAs for PriorityExec {
79    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(
81            f,
82            "PriorityExec: key_indices={:?}, priority_col={}",
83            self.key_indices, self.priority_col_index
84        )
85    }
86}
87
88impl ExecutionPlan for PriorityExec {
89    fn name(&self) -> &str {
90        "PriorityExec"
91    }
92
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn schema(&self) -> SchemaRef {
98        Arc::clone(&self.schema)
99    }
100
101    fn properties(&self) -> &PlanProperties {
102        &self.properties
103    }
104
105    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
106        vec![&self.input]
107    }
108
109    fn with_new_children(
110        self: Arc<Self>,
111        children: Vec<Arc<dyn ExecutionPlan>>,
112    ) -> DFResult<Arc<dyn ExecutionPlan>> {
113        if children.len() != 1 {
114            return Err(datafusion::error::DataFusionError::Plan(
115                "PriorityExec requires exactly one child".to_string(),
116            ));
117        }
118        Ok(Arc::new(Self::new(
119            Arc::clone(&children[0]),
120            self.key_indices.clone(),
121            self.priority_col_index,
122        )))
123    }
124
125    fn execute(
126        &self,
127        partition: usize,
128        context: Arc<TaskContext>,
129    ) -> DFResult<SendableRecordBatchStream> {
130        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
131        let metrics = BaselineMetrics::new(&self.metrics, partition);
132        let key_indices = self.key_indices.clone();
133        let priority_col_index = self.priority_col_index;
134        let output_schema = Arc::clone(&self.schema);
135        let input_schema = self.input.schema();
136
137        let fut = async move {
138            // Collect all input batches
139            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
140
141            if batches.is_empty() {
142                return Ok(RecordBatch::new_empty(output_schema));
143            }
144
145            let batch =
146                arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;
147
148            if batch.num_rows() == 0 {
149                return Ok(RecordBatch::new_empty(output_schema));
150            }
151
152            // Extract priority column
153            let priority_col = batch
154                .column(priority_col_index)
155                .as_any()
156                .downcast_ref::<Int64Array>()
157                .ok_or_else(|| {
158                    datafusion::error::DataFusionError::Execution(
159                        "__priority column must be Int64".to_string(),
160                    )
161                })?;
162
163            // Group by key columns → find max priority per group
164            let mut group_max: HashMap<Vec<ScalarKey>, i64> = HashMap::new();
165            for row_idx in 0..batch.num_rows() {
166                let key = extract_scalar_key(&batch, &key_indices, row_idx);
167                let prio = priority_col.value(row_idx);
168                let entry = group_max.entry(key).or_insert(i64::MIN);
169                if prio > *entry {
170                    *entry = prio;
171                }
172            }
173
174            // Build filter mask: keep rows where priority == max for their group
175            let keep: Vec<bool> = (0..batch.num_rows())
176                .map(|row_idx| {
177                    let key = extract_scalar_key(&batch, &key_indices, row_idx);
178                    let prio = priority_col.value(row_idx);
179                    group_max
180                        .get(&key)
181                        .is_some_and(|&max_prio| prio == max_prio)
182                })
183                .collect();
184
185            let filter_mask = BooleanArray::from(keep);
186
187            // Filter columns, skipping __priority
188            let mut output_columns = Vec::with_capacity(output_schema.fields().len());
189            for (i, col) in batch.columns().iter().enumerate() {
190                if i == priority_col_index {
191                    continue;
192                }
193                let filtered = arrow_filter(col.as_ref(), &filter_mask).map_err(|e| {
194                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
195                })?;
196                output_columns.push(filtered);
197            }
198
199            RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
200        };
201
202        Ok(Box::pin(PriorityStream {
203            state: PriorityStreamState::Running(Box::pin(fut)),
204            schema: Arc::clone(&self.schema),
205            metrics,
206        }))
207    }
208
209    fn metrics(&self) -> Option<MetricsSet> {
210        Some(self.metrics.clone_inner())
211    }
212}
213
// ---------------------------------------------------------------------------
// Stream implementation
// ---------------------------------------------------------------------------

// State machine for `PriorityStream`: the entire result is produced by one
// future, so the stream is either still driving it or already finished.
enum PriorityStreamState {
    // The filtering future has not resolved yet.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // The single output batch (or an error) has already been emitted.
    Done,
}
222
// One-shot stream adapter: yields the single `RecordBatch` produced by the
// priority-filtering future, then terminates.
struct PriorityStream {
    // Current execution state (running future or done).
    state: PriorityStreamState,
    // Output schema (input schema without `__priority`).
    schema: SchemaRef,
    // Baseline metrics; the output row count is recorded on completion.
    metrics: BaselineMetrics,
}
228
229impl Stream for PriorityStream {
230    type Item = DFResult<RecordBatch>;
231
232    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233        match &mut self.state {
234            PriorityStreamState::Running(fut) => match fut.as_mut().poll(cx) {
235                Poll::Ready(Ok(batch)) => {
236                    self.metrics.record_output(batch.num_rows());
237                    self.state = PriorityStreamState::Done;
238                    Poll::Ready(Some(Ok(batch)))
239                }
240                Poll::Ready(Err(e)) => {
241                    self.state = PriorityStreamState::Done;
242                    Poll::Ready(Some(Err(e)))
243                }
244                Poll::Pending => Poll::Pending,
245            },
246            PriorityStreamState::Done => Poll::Ready(None),
247        }
248    }
249}
250
impl RecordBatchStream for PriorityStream {
    // Schema of the batches this stream yields (input minus `__priority`).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    // Build a three-column batch: name (Utf8), value (Int64), and the
    // non-null __priority (Int64) column consumed by `PriorityExec`.
    fn make_test_batch(names: Vec<&str>, values: Vec<i64>, priorities: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values)),
                Arc::new(Int64Array::from(priorities)),
            ],
        )
        .unwrap()
    }

    // Wrap a single batch in a minimal in-memory ExecutionPlan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    // Minimal ExecutionPlan that replays a fixed set of batches; serves as
    // the child input for `PriorityExec` in these tests.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    // Run `PriorityExec` over `input` and concatenate the output into one
    // batch for easy assertions.
    async fn execute_priority(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        priority_col_index: usize,
    ) -> RecordBatch {
        let exec = PriorityExec::new(input, key_indices, priority_col_index);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    // One group: only the row with the maximum priority survives.
    #[tokio::test]
    async fn test_single_group_keeps_highest_priority() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![1, 3, 2]);
        let input = make_memory_exec(batch);
        // key_indices=[0] (name), priority_col=2
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 20); // priority 3 was highest
    }

    // Two groups: each group resolves its maximum priority independently.
    #[tokio::test]
    async fn test_multiple_groups_independent_priority() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![10, 20, 30, 40],
            vec![1, 2, 3, 1],
        );
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 2);
        // Group "a": priority 2 wins → value 20
        // Group "b": priority 3 wins → value 30
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        // Find each group's result
        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(values.value(i), 20),
                "b" => assert_eq!(values.value(i), 30),
                _ => panic!("unexpected name"),
            }
        }
    }

    // Ties at the maximum priority: every tied row is kept.
    #[tokio::test]
    async fn test_all_same_priority_keeps_all() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![5, 5, 5]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 3);
    }

    // Empty input yields an empty batch with the __priority-free schema.
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 1).await;

        assert_eq!(result.num_rows(), 0);
    }

    // A lone row trivially has the maximum priority and passes through.
    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42], vec![1]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 42);
    }

    // The operator's reported schema must not contain __priority.
    #[tokio::test]
    async fn test_output_schema_lacks_priority() {
        let batch = make_test_batch(vec!["a"], vec![1], vec![1]);
        let input = make_memory_exec(batch);
        let exec = PriorityExec::new(input, vec![0], 2);

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 2); // name + value, no __priority
        assert!(schema.column_with_name("__priority").is_none());
        assert!(schema.column_with_name("name").is_some());
        assert!(schema.column_with_name("value").is_some());
    }
}
459}