1use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
10use arrow::compute::take;
11use arrow_array::{RecordBatch, UInt32Array};
12use arrow_schema::SchemaRef;
13use datafusion::common::Result as DFResult;
14use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
15use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
16use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::{Stream, TryStreamExt};
18use std::any::Any;
19use std::fmt;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
/// One ordering rule used to rank rows within a key group.
#[derive(Debug, Clone)]
pub struct SortCriterion {
    /// Index of the column (into the input schema) to order by.
    pub col_index: usize,
    /// Sort ascending when `true`, descending when `false`.
    pub ascending: bool,
    /// Place null values before non-null values when `true`.
    pub nulls_first: bool,
}
31
/// Physical operator that emits the single "best" row per key group.
///
/// Rows are grouped by the columns in `key_indices`; within each group the
/// `sort_criteria` decide which row ranks first, and only that row is output.
#[derive(Debug)]
pub struct BestByExec {
    /// Child plan supplying the input rows.
    input: Arc<dyn ExecutionPlan>,
    /// Column indices that form the grouping key.
    key_indices: Vec<usize>,
    /// Ordering rules that select the "best" row in each group.
    sort_criteria: Vec<SortCriterion>,
    /// Output schema — identical to the input schema (see `new`).
    schema: SchemaRef,
    /// Cached plan properties computed from the schema.
    properties: PlanProperties,
    /// Metric registry; `execute` records output row counts into it.
    metrics: ExecutionPlanMetricsSet,
    /// When `true`, ties are broken by additionally sorting on every column
    /// not already used as a key or criterion, so results are stable across runs.
    deterministic: bool,
}
48
49impl BestByExec {
50 pub fn new(
58 input: Arc<dyn ExecutionPlan>,
59 key_indices: Vec<usize>,
60 sort_criteria: Vec<SortCriterion>,
61 deterministic: bool,
62 ) -> Self {
63 let schema = input.schema();
64 let properties = compute_plan_properties(Arc::clone(&schema));
65 Self {
66 input,
67 key_indices,
68 sort_criteria,
69 schema,
70 properties,
71 metrics: ExecutionPlanMetricsSet::new(),
72 deterministic,
73 }
74 }
75}
76
77impl DisplayAs for BestByExec {
78 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 write!(
80 f,
81 "BestByExec: key_indices={:?}, criteria={:?}",
82 self.key_indices, self.sort_criteria
83 )
84 }
85}
86
impl ExecutionPlan for BestByExec {
    fn name(&self) -> &str {
        "BestByExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // This operator wraps exactly one input plan.
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Plan(
                "BestByExec requires exactly one child".to_string(),
            ));
        }
        Ok(Arc::new(Self::new(
            Arc::clone(&children[0]),
            self.key_indices.clone(),
            self.sort_criteria.clone(),
            self.deterministic,
        )))
    }

    /// Runs the operator for one partition.
    ///
    /// Strategy: fully collect the partition's input, lexicographically sort
    /// it by (keys, criteria[, tie-breakers]), then keep the first row of
    /// each key run. NOTE(review): this materializes the entire partition in
    /// memory before producing any output — acceptable only if inputs are
    /// bounded; confirm against expected data sizes.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone everything the async block needs so it can be 'static.
        let key_indices = self.key_indices.clone();
        let sort_criteria = self.sort_criteria.clone();
        let schema = Arc::clone(&self.schema);
        let input_schema = self.input.schema();
        let deterministic = self.deterministic;

        let fut = async move {
            // Drain the child stream; any input error propagates via `?`.
            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;

            if batches.is_empty() {
                return Ok(RecordBatch::new_empty(schema));
            }

            // Merge all input batches into one so a single global sort works.
            let batch = arrow::compute::concat_batches(&input_schema, &batches)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(schema));
            }

            let num_cols = batch.num_columns();
            let mut sort_columns = Vec::new();

            // 1) Key columns first (ascending, nulls last) so equal keys
            //    become contiguous runs after the lexsort.
            for &ki in &key_indices {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(ki)),
                    options: Some(arrow::compute::SortOptions {
                        descending: false,
                        nulls_first: false,
                    }),
                });
            }

            // 2) User criteria next — within a key run, the "best" row per
            //    the criteria sorts to the front.
            for criterion in &sort_criteria {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(criterion.col_index)),
                    options: Some(arrow::compute::SortOptions {
                        descending: !criterion.ascending,
                        nulls_first: criterion.nulls_first,
                    }),
                });
            }

            // 3) Optional tie-breakers: every column not already used, so
            //    rows tying on all criteria resolve the same way every run.
            if deterministic {
                let used_cols: std::collections::HashSet<usize> = key_indices
                    .iter()
                    .copied()
                    .chain(sort_criteria.iter().map(|c| c.col_index))
                    .collect();
                for col_idx in 0..num_cols {
                    if !used_cols.contains(&col_idx) {
                        sort_columns.push(arrow::compute::SortColumn {
                            values: Arc::clone(batch.column(col_idx)),
                            options: Some(arrow::compute::SortOptions {
                                descending: false,
                                nulls_first: false,
                            }),
                        });
                    }
                }
            }

            // Compute the sorted row permutation without moving any data yet.
            let sorted_indices = arrow::compute::lexsort_to_indices(&sort_columns, None)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            // Materialize the sorted batch by gathering each column.
            let sorted_columns: Vec<_> = batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &sorted_indices, None))
                .collect::<Result<Vec<_>, _>>()?;
            let sorted_batch = RecordBatch::try_new(Arc::clone(&schema), sorted_columns)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            // Walk the sorted rows; the first row of each new key run is the
            // group's best row (everything after it ranks lower by criteria).
            let mut keep_indices: Vec<u32> = Vec::new();
            let mut prev_key: Option<Vec<ScalarKey>> = None;

            for row_idx in 0..sorted_batch.num_rows() {
                let key = extract_scalar_key(&sorted_batch, &key_indices, row_idx);
                let is_new_group = match &prev_key {
                    None => true,
                    Some(prev) => *prev != key,
                };
                if is_new_group {
                    keep_indices.push(row_idx as u32);
                    prev_key = Some(key);
                }
            }

            // Gather only the kept rows into the final output batch.
            let keep_array = UInt32Array::from(keep_indices);
            let output_columns: Vec<_> = sorted_batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &keep_array, None))
                .collect::<Result<Vec<_>, _>>()?;

            RecordBatch::try_new(schema, output_columns)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
        };

        // Wrap the one-shot future in a stream that yields its single batch.
        Ok(Box::pin(BestByStream {
            state: BestByStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
250
/// State of [`BestByStream`]: either the pending computation or finished.
enum BestByStreamState {
    /// The boxed future computing the single output batch; polled until ready.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The batch (or error) has been yielded; the stream is exhausted.
    Done,
}
259
/// Adapter exposing the one-shot best-by future as a `RecordBatchStream`.
struct BestByStream {
    /// Current polling state (running future or done).
    state: BestByStreamState,
    /// Schema reported to consumers; matches the operator's output schema.
    schema: SchemaRef,
    /// Metrics sink; output row count is recorded when the batch is yielded.
    metrics: BaselineMetrics,
}
265
266impl Stream for BestByStream {
267 type Item = DFResult<RecordBatch>;
268
269 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270 match &mut self.state {
271 BestByStreamState::Running(fut) => match fut.as_mut().poll(cx) {
272 Poll::Ready(Ok(batch)) => {
273 self.metrics.record_output(batch.num_rows());
274 self.state = BestByStreamState::Done;
275 Poll::Ready(Some(Ok(batch)))
276 }
277 Poll::Ready(Err(e)) => {
278 self.state = BestByStreamState::Done;
279 Poll::Ready(Some(Err(e)))
280 }
281 Poll::Pending => Poll::Pending,
282 },
283 BestByStreamState::Done => Poll::Ready(None),
284 }
285 }
286}
287
288impl RecordBatchStream for BestByStream {
289 fn schema(&self) -> SchemaRef {
290 Arc::clone(&self.schema)
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    /// Builds a (name: Utf8, score: Float64, age: Int64) batch for tests.
    fn make_test_batch(names: Vec<&str>, scores: Vec<f64>, ages: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("age", DataType::Int64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(scores)),
                Arc::new(Int64Array::from(ages)),
            ],
        )
        .unwrap()
    }

    /// Wraps a batch in a single-partition in-memory execution plan.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    /// Minimal ExecutionPlan that streams pre-built batches from memory.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    /// Runs BestByExec (deterministic mode) over `input` and concatenates
    /// the resulting stream into one batch.
    async fn execute_best_by(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        sort_criteria: Vec<SortCriterion>,
    ) -> RecordBatch {
        let exec = BestByExec::new(input, key_indices, sort_criteria, true);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    /// Ascending criterion keeps the row with the minimum score.
    #[tokio::test]
    async fn test_best_ascending() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 1.0);
    }

    /// Descending criterion keeps the row with the maximum score.
    #[tokio::test]
    async fn test_best_descending() {
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: false,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 3.0);
    }

    /// One best row is emitted per distinct key.
    #[tokio::test]
    async fn test_multiple_groups() {
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![3.0, 1.0, 5.0, 2.0],
            vec![20, 30, 40, 50],
        );
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(scores.value(i), 1.0),
                "b" => assert_eq!(scores.value(i), 2.0),
                _ => panic!("unexpected name"),
            }
        }
    }

    /// A second criterion breaks ties left by the first.
    #[tokio::test]
    async fn test_multi_column_criteria() {
        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0], vec![30, 20]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![
                SortCriterion {
                    col_index: 1,
                    ascending: true,
                    nulls_first: false,
                },
                SortCriterion {
                    col_index: 2,
                    ascending: true,
                    nulls_first: false,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let ages = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ages.value(0), 20);
    }

    /// Empty input yields an empty output batch.
    #[tokio::test]
    async fn test_empty_input() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_best_by(input, vec![0], vec![]).await;
        assert_eq!(result.num_rows(), 0);
    }

    /// A single row passes through unchanged.
    #[tokio::test]
    async fn test_single_row_passthrough() {
        let batch = make_test_batch(vec!["x"], vec![42.0], vec![10]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 42.0);
    }
}