1use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
11use arrow::compute::filter as arrow_filter;
12use arrow_array::{BooleanArray, Int64Array, RecordBatch};
13use arrow_schema::{Field, Schema, SchemaRef};
14use datafusion::common::Result as DFResult;
15use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
16use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
17use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
18use futures::{Stream, TryStreamExt};
19use std::any::Any;
20use std::collections::HashMap;
21use std::fmt;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
/// Physical operator that deduplicates rows per group key, keeping only the
/// rows whose priority value ties the maximum within their group. The
/// priority column itself is removed from the output schema.
#[derive(Debug)]
pub struct PriorityExec {
    /// Child plan producing the rows to filter.
    input: Arc<dyn ExecutionPlan>,
    /// Column indices (into the input schema) that form the group key.
    key_indices: Vec<usize>,
    /// Index (into the input schema) of the Int64 priority column.
    priority_col_index: usize,
    /// Output schema: the input schema minus the priority column.
    schema: SchemaRef,
    /// Plan properties computed once from the output schema.
    properties: PlanProperties,
    /// Metrics registry shared with the streams created by `execute`.
    metrics: ExecutionPlanMetricsSet,
}
40
41impl PriorityExec {
42 pub fn new(
49 input: Arc<dyn ExecutionPlan>,
50 key_indices: Vec<usize>,
51 priority_col_index: usize,
52 ) -> Self {
53 let input_schema = input.schema();
54 let output_fields: Vec<Arc<Field>> = input_schema
56 .fields()
57 .iter()
58 .enumerate()
59 .filter(|(i, _)| *i != priority_col_index)
60 .map(|(_, f)| Arc::clone(f))
61 .collect();
62 let schema = Arc::new(Schema::new(output_fields));
63 let properties = compute_plan_properties(Arc::clone(&schema));
64
65 Self {
66 input,
67 key_indices,
68 priority_col_index,
69 schema,
70 properties,
71 metrics: ExecutionPlanMetricsSet::new(),
72 }
73 }
74}
75
76impl DisplayAs for PriorityExec {
77 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 write!(
79 f,
80 "PriorityExec: key_indices={:?}, priority_col={}",
81 self.key_indices, self.priority_col_index
82 )
83 }
84}
85
86impl ExecutionPlan for PriorityExec {
87 fn name(&self) -> &str {
88 "PriorityExec"
89 }
90
91 fn as_any(&self) -> &dyn Any {
92 self
93 }
94
95 fn schema(&self) -> SchemaRef {
96 Arc::clone(&self.schema)
97 }
98
99 fn properties(&self) -> &PlanProperties {
100 &self.properties
101 }
102
103 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
104 vec![&self.input]
105 }
106
107 fn with_new_children(
108 self: Arc<Self>,
109 children: Vec<Arc<dyn ExecutionPlan>>,
110 ) -> DFResult<Arc<dyn ExecutionPlan>> {
111 if children.len() != 1 {
112 return Err(datafusion::error::DataFusionError::Plan(
113 "PriorityExec requires exactly one child".to_string(),
114 ));
115 }
116 Ok(Arc::new(Self::new(
117 Arc::clone(&children[0]),
118 self.key_indices.clone(),
119 self.priority_col_index,
120 )))
121 }
122
123 fn execute(
124 &self,
125 partition: usize,
126 context: Arc<TaskContext>,
127 ) -> DFResult<SendableRecordBatchStream> {
128 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
129 let metrics = BaselineMetrics::new(&self.metrics, partition);
130 let key_indices = self.key_indices.clone();
131 let priority_col_index = self.priority_col_index;
132 let output_schema = Arc::clone(&self.schema);
133 let input_schema = self.input.schema();
134
135 let fut = async move {
136 let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
138
139 if batches.is_empty() {
140 return Ok(RecordBatch::new_empty(output_schema));
141 }
142
143 let batch = arrow::compute::concat_batches(&input_schema, &batches)
144 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
145
146 if batch.num_rows() == 0 {
147 return Ok(RecordBatch::new_empty(output_schema));
148 }
149
150 let priority_col = batch
152 .column(priority_col_index)
153 .as_any()
154 .downcast_ref::<Int64Array>()
155 .ok_or_else(|| {
156 datafusion::error::DataFusionError::Execution(
157 "__priority column must be Int64".to_string(),
158 )
159 })?;
160
161 let mut group_max: HashMap<Vec<ScalarKey>, i64> = HashMap::new();
163 for row_idx in 0..batch.num_rows() {
164 let key = extract_scalar_key(&batch, &key_indices, row_idx);
165 let prio = priority_col.value(row_idx);
166 let entry = group_max.entry(key).or_insert(i64::MIN);
167 if prio > *entry {
168 *entry = prio;
169 }
170 }
171
172 let keep: Vec<bool> = (0..batch.num_rows())
174 .map(|row_idx| {
175 let key = extract_scalar_key(&batch, &key_indices, row_idx);
176 let prio = priority_col.value(row_idx);
177 group_max
178 .get(&key)
179 .is_some_and(|&max_prio| prio == max_prio)
180 })
181 .collect();
182
183 let filter_mask = BooleanArray::from(keep);
184
185 let mut output_columns = Vec::with_capacity(output_schema.fields().len());
187 for (i, col) in batch.columns().iter().enumerate() {
188 if i == priority_col_index {
189 continue;
190 }
191 let filtered = arrow_filter(col.as_ref(), &filter_mask).map_err(|e| {
192 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
193 })?;
194 output_columns.push(filtered);
195 }
196
197 RecordBatch::try_new(output_schema, output_columns)
198 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
199 };
200
201 Ok(Box::pin(PriorityStream {
202 state: PriorityStreamState::Running(Box::pin(fut)),
203 schema: Arc::clone(&self.schema),
204 metrics,
205 }))
206 }
207
208 fn metrics(&self) -> Option<MetricsSet> {
209 Some(self.metrics.clone_inner())
210 }
211}
212
/// Internal state machine for [`PriorityStream`].
enum PriorityStreamState {
    /// The buffering/filtering future is still in flight; it resolves to the
    /// single output batch (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The single result (or error) has already been yielded.
    Done,
}
221
/// One-shot record-batch stream: drives the filtering future to completion,
/// yields its single `RecordBatch` result, then reports end-of-stream.
struct PriorityStream {
    /// Tracks whether the result future is still pending or already emitted.
    state: PriorityStreamState,
    /// Output schema reported to consumers (input minus the priority column).
    schema: SchemaRef,
    /// Records output row counts and other baseline execution metrics.
    metrics: BaselineMetrics,
}
227
228impl Stream for PriorityStream {
229 type Item = DFResult<RecordBatch>;
230
231 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
232 match &mut self.state {
233 PriorityStreamState::Running(fut) => match fut.as_mut().poll(cx) {
234 Poll::Ready(Ok(batch)) => {
235 self.metrics.record_output(batch.num_rows());
236 self.state = PriorityStreamState::Done;
237 Poll::Ready(Some(Ok(batch)))
238 }
239 Poll::Ready(Err(e)) => {
240 self.state = PriorityStreamState::Done;
241 Poll::Ready(Some(Err(e)))
242 }
243 Poll::Pending => Poll::Pending,
244 },
245 PriorityStreamState::Done => Poll::Ready(None),
246 }
247 }
248}
249
250impl RecordBatchStream for PriorityStream {
251 fn schema(&self) -> SchemaRef {
252 Arc::clone(&self.schema)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Builds a (name, value, __priority) batch from parallel column vectors.
    /// The priority column is declared non-nullable, matching what
    /// `PriorityExec::execute` expects.
    fn make_test_batch(names: Vec<&str>, values: Vec<i64>, priorities: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values)),
                Arc::new(Int64Array::from(priorities)),
            ],
        )
        .unwrap()
    }

    /// Wraps a single batch in an in-memory execution plan for use as the
    /// child of `PriorityExec`.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    /// Minimal leaf plan that replays a fixed set of batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    /// Runs `PriorityExec` over `input` and concatenates the output into a
    /// single batch for easy assertions.
    async fn execute_priority(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        priority_col_index: usize,
    ) -> RecordBatch {
        let exec = PriorityExec::new(input, key_indices, priority_col_index);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    /// With one group, only the row carrying the highest priority survives.
    #[tokio::test]
    async fn test_single_group_keeps_highest_priority() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![1, 3, 2]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        // value=20 is the row with priority 3, the group maximum.
        assert_eq!(values.value(0), 20); }

    /// Each group key selects its own maximum, independently of other groups.
    #[tokio::test]
    async fn test_multiple_groups_independent_priority() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![10, 20, 30, 40],
            vec![1, 2, 3, 1],
        );
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        // Output row order is not asserted; match each row by its group name.
        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(values.value(i), 20),
                "b" => assert_eq!(values.value(i), 30),
                _ => panic!("unexpected name"),
            }
        }
    }

    /// Ties at the maximum priority are all kept, not deduplicated further.
    #[tokio::test]
    async fn test_all_same_priority_keeps_all() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![5, 5, 5]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 3);
    }

    /// An empty input yields an empty output batch (no panic, no error).
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 1).await;

        assert_eq!(result.num_rows(), 0);
    }

    /// A single row is always its group's maximum and passes through intact.
    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42], vec![1]);
        let input = make_memory_exec(batch);
        let result = execute_priority(input, vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let values = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(values.value(0), 42);
    }

    /// The operator's reported schema drops __priority but keeps other columns.
    #[tokio::test]
    async fn test_output_schema_lacks_priority() {
        let batch = make_test_batch(vec!["a"], vec![1], vec![1]);
        let input = make_memory_exec(batch);
        let exec = PriorityExec::new(input, vec![0], 2);

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 2); assert!(schema.column_with_name("__priority").is_none());
        assert!(schema.column_with_name("name").is_some());
        assert!(schema.column_with_name("value").is_some());
    }
}