1use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
12};
13use arrow::compute::filter as arrow_filter;
14use arrow_array::{BooleanArray, Int64Array, RecordBatch};
15use arrow_schema::{Field, Schema, SchemaRef};
16use datafusion::common::Result as DFResult;
17use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
18use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
19use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
20use futures::{Stream, TryStreamExt};
21use std::any::Any;
22use std::collections::HashMap;
23use std::fmt;
24use std::pin::Pin;
25use std::sync::Arc;
26use std::task::{Context, Poll};
27
/// Physical operator that, within each group of rows sharing the same key
/// columns, keeps only the rows whose priority value equals the group's
/// maximum, and drops the priority column from the output schema.
#[derive(Debug)]
pub struct PriorityExec {
    /// Child plan supplying the rows to filter.
    input: Arc<dyn ExecutionPlan>,
    /// Indices (into the input schema) of the grouping key columns.
    key_indices: Vec<usize>,
    /// Index (into the input schema) of the Int64 priority column.
    priority_col_index: usize,
    /// Output schema: the input schema minus the priority column.
    schema: SchemaRef,
    /// Plan properties derived from `schema` via `compute_plan_properties`.
    properties: PlanProperties,
    /// Metrics set exposed through `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}
42
43impl PriorityExec {
44 pub fn new(
51 input: Arc<dyn ExecutionPlan>,
52 key_indices: Vec<usize>,
53 priority_col_index: usize,
54 ) -> Self {
55 let input_schema = input.schema();
56 let output_fields: Vec<Arc<Field>> = input_schema
58 .fields()
59 .iter()
60 .enumerate()
61 .filter(|(i, _)| *i != priority_col_index)
62 .map(|(_, f)| Arc::clone(f))
63 .collect();
64 let schema = Arc::new(Schema::new(output_fields));
65 let properties = compute_plan_properties(Arc::clone(&schema));
66
67 Self {
68 input,
69 key_indices,
70 priority_col_index,
71 schema,
72 properties,
73 metrics: ExecutionPlanMetricsSet::new(),
74 }
75 }
76}
77
78impl DisplayAs for PriorityExec {
79 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(
81 f,
82 "PriorityExec: key_indices={:?}, priority_col={}",
83 self.key_indices, self.priority_col_index
84 )
85 }
86}
87
88impl ExecutionPlan for PriorityExec {
89 fn name(&self) -> &str {
90 "PriorityExec"
91 }
92
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn schema(&self) -> SchemaRef {
98 Arc::clone(&self.schema)
99 }
100
101 fn properties(&self) -> &PlanProperties {
102 &self.properties
103 }
104
105 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
106 vec![&self.input]
107 }
108
109 fn with_new_children(
110 self: Arc<Self>,
111 children: Vec<Arc<dyn ExecutionPlan>>,
112 ) -> DFResult<Arc<dyn ExecutionPlan>> {
113 if children.len() != 1 {
114 return Err(datafusion::error::DataFusionError::Plan(
115 "PriorityExec requires exactly one child".to_string(),
116 ));
117 }
118 Ok(Arc::new(Self::new(
119 Arc::clone(&children[0]),
120 self.key_indices.clone(),
121 self.priority_col_index,
122 )))
123 }
124
125 fn execute(
126 &self,
127 partition: usize,
128 context: Arc<TaskContext>,
129 ) -> DFResult<SendableRecordBatchStream> {
130 let input_stream = self.input.execute(partition, Arc::clone(&context))?;
131 let metrics = BaselineMetrics::new(&self.metrics, partition);
132 let key_indices = self.key_indices.clone();
133 let priority_col_index = self.priority_col_index;
134 let output_schema = Arc::clone(&self.schema);
135 let input_schema = self.input.schema();
136
137 let fut = async move {
138 let batches: Vec<RecordBatch> = input_stream.try_collect().await?;
140
141 if batches.is_empty() {
142 return Ok(RecordBatch::new_empty(output_schema));
143 }
144
145 let batch =
146 arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;
147
148 if batch.num_rows() == 0 {
149 return Ok(RecordBatch::new_empty(output_schema));
150 }
151
152 let priority_col = batch
154 .column(priority_col_index)
155 .as_any()
156 .downcast_ref::<Int64Array>()
157 .ok_or_else(|| {
158 datafusion::error::DataFusionError::Execution(
159 "__priority column must be Int64".to_string(),
160 )
161 })?;
162
163 let mut group_max: HashMap<Vec<ScalarKey>, i64> = HashMap::new();
165 for row_idx in 0..batch.num_rows() {
166 let key = extract_scalar_key(&batch, &key_indices, row_idx);
167 let prio = priority_col.value(row_idx);
168 let entry = group_max.entry(key).or_insert(i64::MIN);
169 if prio > *entry {
170 *entry = prio;
171 }
172 }
173
174 let keep: Vec<bool> = (0..batch.num_rows())
176 .map(|row_idx| {
177 let key = extract_scalar_key(&batch, &key_indices, row_idx);
178 let prio = priority_col.value(row_idx);
179 group_max
180 .get(&key)
181 .is_some_and(|&max_prio| prio == max_prio)
182 })
183 .collect();
184
185 let filter_mask = BooleanArray::from(keep);
186
187 let mut output_columns = Vec::with_capacity(output_schema.fields().len());
189 for (i, col) in batch.columns().iter().enumerate() {
190 if i == priority_col_index {
191 continue;
192 }
193 let filtered = arrow_filter(col.as_ref(), &filter_mask).map_err(|e| {
194 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
195 })?;
196 output_columns.push(filtered);
197 }
198
199 RecordBatch::try_new(output_schema, output_columns).map_err(arrow_err)
200 };
201
202 Ok(Box::pin(PriorityStream {
203 state: PriorityStreamState::Running(Box::pin(fut)),
204 schema: Arc::clone(&self.schema),
205 metrics,
206 }))
207 }
208
209 fn metrics(&self) -> Option<MetricsSet> {
210 Some(self.metrics.clone_inner())
211 }
212}
213
/// State machine for `PriorityStream`: the stream either still awaits the
/// one-shot filtering future or has already yielded its single result.
enum PriorityStreamState {
    /// Future that collects, filters, and produces the single output batch.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// Terminal state: the batch (or error) has been emitted.
    Done,
}
222
/// Adapter exposing the one-shot filtering future as a `RecordBatchStream`
/// that yields at most one batch.
struct PriorityStream {
    // Current phase: Running (future pending) or Done (result emitted).
    state: PriorityStreamState,
    // Output schema reported to consumers (input schema minus priority column).
    schema: SchemaRef,
    // Records output row counts for the parent plan's metrics.
    metrics: BaselineMetrics,
}
228
229impl Stream for PriorityStream {
230 type Item = DFResult<RecordBatch>;
231
232 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233 match &mut self.state {
234 PriorityStreamState::Running(fut) => match fut.as_mut().poll(cx) {
235 Poll::Ready(Ok(batch)) => {
236 self.metrics.record_output(batch.num_rows());
237 self.state = PriorityStreamState::Done;
238 Poll::Ready(Some(Ok(batch)))
239 }
240 Poll::Ready(Err(e)) => {
241 self.state = PriorityStreamState::Done;
242 Poll::Ready(Some(Err(e)))
243 }
244 Poll::Pending => Poll::Pending,
245 },
246 PriorityStreamState::Done => Poll::Ready(None),
247 }
248 }
249}
250
impl RecordBatchStream for PriorityStream {
    /// Schema of the batches this stream yields (the input schema minus the
    /// priority column).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Builds a `(name: Utf8, value: Int64, __priority: Int64)` batch from
    /// parallel column vectors.
    fn make_test_batch(names: Vec<&str>, values: Vec<i64>, priorities: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let name_col = StringArray::from(names.into_iter().map(Some).collect::<Vec<_>>());
        let value_col = Int64Array::from(values);
        let priority_col = Int64Array::from(priorities);
        RecordBatch::try_new(
            schema,
            vec![Arc::new(name_col), Arc::new(value_col), Arc::new(priority_col)],
        )
        .unwrap()
    }

    /// Wraps a single batch in a one-partition in-memory plan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        let properties = compute_plan_properties(Arc::clone(&schema));
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema,
            properties,
        })
    }

    /// Minimal source operator that replays a fixed set of batches.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &PlanProperties {
            &self.properties
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            let stream =
                MemoryStream::try_new(self.batches.clone(), Arc::clone(&self.schema), None)?;
            Ok(Box::pin(stream))
        }
    }

    /// Runs a `PriorityExec` over `input` and concatenates its output into a
    /// single batch for easy assertions.
    async fn execute_priority(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        priority_col_index: usize,
    ) -> RecordBatch {
        let exec = PriorityExec::new(input, key_indices, priority_col_index);
        let task_ctx = SessionContext::new().task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    #[tokio::test]
    async fn test_single_group_keeps_highest_priority() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![1, 3, 2]);
        let result = execute_priority(make_memory_exec(batch), vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let value_col = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(value_col.value(0), 20);
    }

    #[tokio::test]
    async fn test_multiple_groups_independent_priority() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![10, 20, 30, 40],
            vec![1, 2, 3, 1],
        );
        let result = execute_priority(make_memory_exec(batch), vec![0], 2).await;

        assert_eq!(result.num_rows(), 2);
        let name_col = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let value_col = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();

        // Each group keeps only the row holding its own maximum priority.
        for row in 0..2 {
            match name_col.value(row) {
                "a" => assert_eq!(value_col.value(row), 20),
                "b" => assert_eq!(value_col.value(row), 30),
                _ => panic!("unexpected name"),
            }
        }
    }

    #[tokio::test]
    async fn test_all_same_priority_keeps_all() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![10, 20, 30], vec![5, 5, 5]);
        let result = execute_priority(make_memory_exec(batch), vec![0], 2).await;

        assert_eq!(result.num_rows(), 3);
    }

    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("__priority", DataType::Int64, false),
        ]));
        let empty = RecordBatch::new_empty(schema.clone());
        let result = execute_priority(make_memory_exec(empty), vec![0], 1).await;

        assert_eq!(result.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42], vec![1]);
        let result = execute_priority(make_memory_exec(batch), vec![0], 2).await;

        assert_eq!(result.num_rows(), 1);
        let value_col = result
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(value_col.value(0), 42);
    }

    #[tokio::test]
    async fn test_output_schema_lacks_priority() {
        let batch = make_test_batch(vec!["a"], vec![1], vec![1]);
        let exec = PriorityExec::new(make_memory_exec(batch), vec![0], 2);

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 2);
        assert!(schema.column_with_name("__priority").is_none());
        assert!(schema.column_with_name("name").is_some());
        assert!(schema.column_with_name("value").is_some());
    }
}