1use crate::query::df_graph::common::{
10 ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow::compute::take;
13use arrow_array::{RecordBatch, UInt32Array};
14use arrow_schema::SchemaRef;
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::fmt;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
/// A single ordering criterion used to rank rows within a key group.
///
/// Derives `Copy`/`PartialEq`/`Eq` (all fields are plain `usize`/`bool`)
/// so criteria can be cheaply duplicated and compared in planner code
/// and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SortCriterion {
    /// Index of the column (in the input schema) to order by.
    pub col_index: usize,
    /// Sort ascending when `true`, descending otherwise.
    pub ascending: bool,
    /// Whether nulls order before non-null values.
    pub nulls_first: bool,
}
33
/// Physical operator that, for every distinct combination of the key
/// columns, emits only the first row after ordering by `sort_criteria` —
/// i.e. the "best" row per group. The output schema equals the input
/// schema; only rows are filtered.
#[derive(Debug)]
pub struct BestByExec {
    /// Child plan supplying the rows to reduce.
    input: Arc<dyn ExecutionPlan>,
    /// Indices (into the input schema) of the grouping columns.
    key_indices: Vec<usize>,
    /// Ordering that decides which row wins within each group.
    sort_criteria: Vec<SortCriterion>,
    /// Output schema — identical to the input's schema.
    schema: SchemaRef,
    /// Plan properties computed once from the schema at construction.
    properties: PlanProperties,
    /// Execution metrics set; per-partition baselines are derived in `execute`.
    metrics: ExecutionPlanMetricsSet,
    /// When true, all columns not already used as keys or criteria are
    /// appended as extra sort tie-breakers so rows that tie on every
    /// criterion still resolve to the same winner on every run.
    deterministic: bool,
}
50
51impl BestByExec {
52 pub fn new(
60 input: Arc<dyn ExecutionPlan>,
61 key_indices: Vec<usize>,
62 sort_criteria: Vec<SortCriterion>,
63 deterministic: bool,
64 ) -> Self {
65 let schema = input.schema();
66 let properties = compute_plan_properties(Arc::clone(&schema));
67 Self {
68 input,
69 key_indices,
70 sort_criteria,
71 schema,
72 properties,
73 metrics: ExecutionPlanMetricsSet::new(),
74 deterministic,
75 }
76 }
77}
78
impl DisplayAs for BestByExec {
    /// Renders the operator for EXPLAIN output; the same single-line form
    /// is used for every display format.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "BestByExec: key_indices={:?}, criteria={:?}",
            self.key_indices, self.sort_criteria
        )
    }
}
88
impl ExecutionPlan for BestByExec {
    fn name(&self) -> &str {
        "BestByExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        // Output schema is the unchanged input schema: this operator only
        // filters rows.
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuilds the node over a replacement child, preserving keys,
    /// criteria, and the determinism flag. Exactly one child is required.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Plan(
                "BestByExec requires exactly one child".to_string(),
            ));
        }
        Ok(Arc::new(Self::new(
            Arc::clone(&children[0]),
            self.key_indices.clone(),
            self.sort_criteria.clone(),
            self.deterministic,
        )))
    }

    /// Executes one partition: materializes the entire child stream,
    /// lexicographically sorts it by (keys, criteria[, remaining columns]),
    /// then keeps only the first row of every key group.
    ///
    /// NOTE(review): this buffers the whole partition in memory before
    /// sorting — there is no spill path for large inputs.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone everything the async block captures so it is 'static.
        let key_indices = self.key_indices.clone();
        let sort_criteria = self.sort_criteria.clone();
        let schema = Arc::clone(&self.schema);
        let input_schema = self.input.schema();
        let deterministic = self.deterministic;

        let fut = async move {
            // Drain the child completely; the best-per-group reduction
            // needs every row before any output can be produced.
            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;

            if batches.is_empty() {
                return Ok(RecordBatch::new_empty(schema));
            }

            let batch =
                arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;

            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(schema));
            }

            let num_cols = batch.num_columns();
            let mut sort_columns = Vec::new();

            // Key columns come first so rows of the same group end up
            // adjacent after the lexsort; their direction is irrelevant to
            // group membership (grouping compares key equality below).
            for &ki in &key_indices {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(ki)),
                    options: Some(arrow::compute::SortOptions {
                        descending: false,
                        nulls_first: false,
                    }),
                });
            }

            // Then the user criteria, which decide the winner in a group.
            for criterion in &sort_criteria {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(criterion.col_index)),
                    options: Some(arrow::compute::SortOptions {
                        descending: !criterion.ascending,
                        nulls_first: criterion.nulls_first,
                    }),
                });
            }

            // Deterministic mode: append every column not already used so
            // rows tied on all criteria still sort in a reproducible order.
            if deterministic {
                let used_cols: std::collections::HashSet<usize> = key_indices
                    .iter()
                    .copied()
                    .chain(sort_criteria.iter().map(|c| c.col_index))
                    .collect();
                for col_idx in 0..num_cols {
                    if !used_cols.contains(&col_idx) {
                        sort_columns.push(arrow::compute::SortColumn {
                            values: Arc::clone(batch.column(col_idx)),
                            options: Some(arrow::compute::SortOptions {
                                descending: false,
                                nulls_first: false,
                            }),
                        });
                    }
                }
            }

            let sorted_indices =
                arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;

            // Reorder every column by the sorted permutation.
            let sorted_columns: Vec<_> = batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &sorted_indices, None))
                .collect::<Result<Vec<_>, _>>()?;
            let sorted_batch =
                RecordBatch::try_new(Arc::clone(&schema), sorted_columns).map_err(arrow_err)?;

            // Single pass over the sorted rows: keep a row exactly when its
            // key differs from the previous row's key, i.e. the first
            // ("best") row of each group. `extract_scalar_key` (from
            // common) builds the comparable per-row key values.
            let mut keep_indices: Vec<u32> = Vec::new();
            let mut prev_key: Option<Vec<ScalarKey>> = None;

            for row_idx in 0..sorted_batch.num_rows() {
                let key = extract_scalar_key(&sorted_batch, &key_indices, row_idx);
                let is_new_group = match &prev_key {
                    None => true,
                    Some(prev) => *prev != key,
                };
                if is_new_group {
                    keep_indices.push(row_idx as u32);
                    prev_key = Some(key);
                }
            }

            // Gather the surviving rows into the output batch.
            let keep_array = UInt32Array::from(keep_indices);
            let output_columns: Vec<_> = sorted_batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &keep_array, None))
                .collect::<Result<Vec<_>, _>>()?;

            RecordBatch::try_new(schema, output_columns).map_err(arrow_err)
        };

        // Wrap the one-shot future in a stream that yields its single batch.
        Ok(Box::pin(BestByStream {
            state: BestByStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
251
/// Progress of a `BestByStream`: the whole computation runs as one boxed
/// future that resolves to a single output batch.
enum BestByStreamState {
    /// Still collecting/sorting/deduplicating; holds the pending future.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The single result (or error) has already been emitted.
    Done,
}
260
/// Stream adapter that drives the one-shot best-by future and emits
/// exactly one `RecordBatch` (then ends).
struct BestByStream {
    /// Current execution state: running future or finished.
    state: BestByStreamState,
    /// Schema of the batch this stream produces.
    schema: SchemaRef,
    /// Baseline metrics; the output row count is recorded on success.
    metrics: BaselineMetrics,
}
266
267impl Stream for BestByStream {
268 type Item = DFResult<RecordBatch>;
269
270 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271 match &mut self.state {
272 BestByStreamState::Running(fut) => match fut.as_mut().poll(cx) {
273 Poll::Ready(Ok(batch)) => {
274 self.metrics.record_output(batch.num_rows());
275 self.state = BestByStreamState::Done;
276 Poll::Ready(Some(Ok(batch)))
277 }
278 Poll::Ready(Err(e)) => {
279 self.state = BestByStreamState::Done;
280 Poll::Ready(Some(Err(e)))
281 }
282 Poll::Pending => Poll::Pending,
283 },
284 BestByStreamState::Done => Poll::Ready(None),
285 }
286 }
287}
288
289impl RecordBatchStream for BestByStream {
290 fn schema(&self) -> SchemaRef {
291 Arc::clone(&self.schema)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Builds a (name: Utf8, score: Float64, age: Int64) test batch from
    /// parallel column vectors.
    fn make_test_batch(names: Vec<&str>, scores: Vec<f64>, ages: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("age", DataType::Int64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(scores)),
                Arc::new(Int64Array::from(ages)),
            ],
        )
        .unwrap()
    }

    /// Wraps a single batch in a minimal in-memory `ExecutionPlan`.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    /// Trivial memory-backed plan used as the child of `BestByExec` in tests.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn properties(&self) -> &PlanProperties {
            &self.properties
        }

        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }

        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }

        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    /// Runs a deterministic `BestByExec` over `input` and concatenates the
    /// resulting stream into a single batch for easy assertions.
    async fn execute_best_by(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        sort_criteria: Vec<SortCriterion>,
    ) -> RecordBatch {
        let exec = BestByExec::new(input, key_indices, sort_criteria, true);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    /// Ascending criterion keeps the minimum score of the single group.
    #[tokio::test]
    async fn test_best_ascending() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 1.0);
    }

    /// Descending criterion keeps the maximum score of the single group.
    #[tokio::test]
    async fn test_best_descending() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: false,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 3.0);
    }

    /// One winner per distinct key: group "a" keeps 1.0, group "b" keeps 2.0.
    #[tokio::test]
    async fn test_multiple_groups() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![3.0, 1.0, 5.0, 2.0],
            vec![20, 30, 40, 50],
        );
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(scores.value(i), 1.0),
                "b" => assert_eq!(scores.value(i), 2.0),
                _ => panic!("unexpected name"),
            }
        }
    }

    /// A secondary criterion breaks ties on the first: equal scores, so the
    /// lower age wins.
    #[tokio::test]
    async fn test_multi_column_criteria() {
        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0], vec![30, 20]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![
                SortCriterion {
                    col_index: 1,
                    ascending: true,
                    nulls_first: false,
                },
                SortCriterion {
                    col_index: 2,
                    ascending: true,
                    nulls_first: false,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let ages = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ages.value(0), 20);
    }

    /// An empty input stream produces an empty output batch.
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_best_by(input, vec![0], vec![]).await;
        assert_eq!(result.num_rows(), 0);
    }

    /// A single row forms its own group and passes through unchanged.
    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42.0], vec![10]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 42.0);
    }
}