// uni_query/query/df_graph/locy_best_by.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright 2024-2026 Dragonscale Team

//! BEST BY operator for Locy.
//!
//! `BestByExec` selects the "best" row per group of KEY columns, using ordered
//! criteria (ASC/DESC) to rank rows and keeping only the top-ranked row per group.
8
9use crate::query::df_graph::common::{
10    ScalarKey, arrow_err, compute_plan_properties, extract_scalar_key,
11};
12use arrow::compute::take;
13use arrow_array::{RecordBatch, UInt32Array};
14use arrow_schema::SchemaRef;
15use datafusion::common::Result as DFResult;
16use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
17use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
18use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
19use futures::{Stream, TryStreamExt};
20use std::any::Any;
21use std::fmt;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
/// Sort criterion for BEST BY ordering.
#[derive(Debug, Clone)]
pub struct SortCriterion {
    /// Index (into the input schema) of the column to rank rows by.
    pub col_index: usize,
    /// `true` ranks ascending (smallest value is "best"), `false` descending.
    pub ascending: bool,
    /// Whether NULL values in this column sort before non-NULL values.
    pub nulls_first: bool,
}
33
/// DataFusion `ExecutionPlan` that applies BEST BY selection.
///
/// For each group of rows sharing the same KEY columns, sorts by the given
/// criteria and keeps only the first (best) row per group.
#[derive(Debug)]
pub struct BestByExec {
    /// Child plan producing the rows to select from.
    input: Arc<dyn ExecutionPlan>,
    /// Indices of the KEY (grouping) columns in the input schema.
    key_indices: Vec<usize>,
    /// Ordering criteria; earlier entries take priority over later ones.
    sort_criteria: Vec<SortCriterion>,
    /// Output schema — identical to the input schema (rows are only dropped).
    schema: SchemaRef,
    /// Plan properties computed once at construction from the schema.
    properties: PlanProperties,
    /// Metrics registry shared with the streams produced by `execute`.
    metrics: ExecutionPlanMetricsSet,
    /// When true, apply a secondary sort on remaining columns for deterministic
    /// tie-breaking. When false, tied rows are selected non-deterministically.
    deterministic: bool,
}
50
51impl BestByExec {
52    /// Create a new `BestByExec`.
53    ///
54    /// # Arguments
55    /// * `input` - Child execution plan
56    /// * `key_indices` - Indices of KEY columns for grouping
57    /// * `sort_criteria` - Ordering criteria for selecting the "best" row
58    /// * `deterministic` - Whether to apply secondary sort for tie-breaking
59    pub fn new(
60        input: Arc<dyn ExecutionPlan>,
61        key_indices: Vec<usize>,
62        sort_criteria: Vec<SortCriterion>,
63        deterministic: bool,
64    ) -> Self {
65        let schema = input.schema();
66        let properties = compute_plan_properties(Arc::clone(&schema));
67        Self {
68            input,
69            key_indices,
70            sort_criteria,
71            schema,
72            properties,
73            metrics: ExecutionPlanMetricsSet::new(),
74            deterministic,
75        }
76    }
77}
78
79impl DisplayAs for BestByExec {
80    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        write!(
82            f,
83            "BestByExec: key_indices={:?}, criteria={:?}",
84            self.key_indices, self.sort_criteria
85        )
86    }
87}
88
impl ExecutionPlan for BestByExec {
    fn name(&self) -> &str {
        "BestByExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema equals the input schema: BEST BY only drops rows.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node over a replacement child; exactly one child is required.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Plan(
                "BestByExec requires exactly one child".to_string(),
            ));
        }
        Ok(Arc::new(Self::new(
            Arc::clone(&children[0]),
            self.key_indices.clone(),
            self.sort_criteria.clone(),
            self.deterministic,
        )))
    }

    /// Execute BEST BY for one partition.
    ///
    /// NOTE(review): this operator is fully blocking — it buffers the entire
    /// child partition, concatenates it into a single batch, sorts it, and
    /// emits exactly one output batch. Peak memory is proportional to the
    /// partition size; confirm that is acceptable for expected inputs.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone everything the async block captures so it is 'static + Send.
        let key_indices = self.key_indices.clone();
        let sort_criteria = self.sort_criteria.clone();
        let schema = Arc::clone(&self.schema);
        let input_schema = self.input.schema();
        let deterministic = self.deterministic;

        let fut = async move {
            // Drain the child completely; any child error propagates as-is.
            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;

            if batches.is_empty() {
                return Ok(RecordBatch::new_empty(schema));
            }

            let batch =
                arrow::compute::concat_batches(&input_schema, &batches).map_err(arrow_err)?;

            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(schema));
            }

            // Build sort columns: key columns ASC first (for grouping contiguity),
            // then criteria columns, then remaining columns ASC (deterministic tie-breaking).
            let num_cols = batch.num_columns();
            let mut sort_columns = Vec::new();

            // 1. Key columns ASC, nulls last. Direction is irrelevant for
            //    correctness — only contiguity of equal keys matters here.
            for &ki in &key_indices {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(ki)),
                    options: Some(arrow::compute::SortOptions {
                        descending: false,
                        nulls_first: false,
                    }),
                });
            }

            // 2. Criteria columns, in declaration order (highest priority first).
            for criterion in &sort_criteria {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(criterion.col_index)),
                    options: Some(arrow::compute::SortOptions {
                        descending: !criterion.ascending,
                        nulls_first: criterion.nulls_first,
                    }),
                });
            }

            // 3. Remaining columns ASC for deterministic tie-breaking (optional).
            //    Columns already used as keys or criteria are skipped.
            if deterministic {
                let used_cols: std::collections::HashSet<usize> = key_indices
                    .iter()
                    .copied()
                    .chain(sort_criteria.iter().map(|c| c.col_index))
                    .collect();
                for col_idx in 0..num_cols {
                    if !used_cols.contains(&col_idx) {
                        sort_columns.push(arrow::compute::SortColumn {
                            values: Arc::clone(batch.column(col_idx)),
                            options: Some(arrow::compute::SortOptions {
                                descending: false,
                                nulls_first: false,
                            }),
                        });
                    }
                }
            }

            // Sort to get indices (lexicographic over all sort columns).
            let sorted_indices =
                arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;

            // Reorder batch by sorted indices.
            // NOTE(review): rows are materialized twice (this `take`, then a
            // second `take` of the kept rows below); the two gathers could be
            // fused by composing index arrays if this shows up in profiles.
            let sorted_columns: Vec<_> = batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &sorted_indices, None))
                .collect::<Result<Vec<_>, _>>()?;
            let sorted_batch =
                RecordBatch::try_new(Arc::clone(&schema), sorted_columns).map_err(arrow_err)?;

            // Dedup: keep first row per key group (linear scan). Equal keys are
            // contiguous after the sort above, so comparing against the previous
            // row's key is sufficient.
            let mut keep_indices: Vec<u32> = Vec::new();
            let mut prev_key: Option<Vec<ScalarKey>> = None;

            for row_idx in 0..sorted_batch.num_rows() {
                let key = extract_scalar_key(&sorted_batch, &key_indices, row_idx);
                let is_new_group = match &prev_key {
                    None => true,
                    Some(prev) => *prev != key,
                };
                if is_new_group {
                    keep_indices.push(row_idx as u32);
                    prev_key = Some(key);
                }
            }

            // Gather the surviving (best-per-group) rows into the output batch.
            let keep_array = UInt32Array::from(keep_indices);
            let output_columns: Vec<_> = sorted_batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &keep_array, None))
                .collect::<Result<Vec<_>, _>>()?;

            RecordBatch::try_new(schema, output_columns).map_err(arrow_err)
        };

        // Wrap the future in a single-batch stream; output rows are recorded
        // in the baseline metrics when the batch is produced.
        Ok(Box::pin(BestByStream {
            state: BestByStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
251
// ---------------------------------------------------------------------------
// Stream implementation
// ---------------------------------------------------------------------------
255
/// Internal state of `BestByStream`: either still computing the single output
/// batch, or already finished (batch or error emitted).
enum BestByStreamState {
    /// In-flight future that collects the input and produces the output batch.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The single item has been yielded; subsequent polls return `None`.
    Done,
}
260
/// Single-batch output stream for `BestByExec`.
struct BestByStream {
    /// Current execution state (`Running` future or `Done`).
    state: BestByStreamState,
    /// Schema reported to consumers; equals the operator's output schema.
    schema: SchemaRef,
    /// Baseline metrics; output row count is recorded when the batch is ready.
    metrics: BaselineMetrics,
}
266
267impl Stream for BestByStream {
268    type Item = DFResult<RecordBatch>;
269
270    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271        match &mut self.state {
272            BestByStreamState::Running(fut) => match fut.as_mut().poll(cx) {
273                Poll::Ready(Ok(batch)) => {
274                    self.metrics.record_output(batch.num_rows());
275                    self.state = BestByStreamState::Done;
276                    Poll::Ready(Some(Ok(batch)))
277                }
278                Poll::Ready(Err(e)) => {
279                    self.state = BestByStreamState::Done;
280                    Poll::Ready(Some(Err(e)))
281                }
282                Poll::Pending => Poll::Pending,
283            },
284            BestByStreamState::Done => Poll::Ready(None),
285        }
286    }
287}
288
289impl RecordBatchStream for BestByStream {
290    fn schema(&self) -> SchemaRef {
291        Arc::clone(&self.schema)
292    }
293}
294
295#[cfg(test)]
296mod tests {
297    use super::*;
298    use arrow_array::{Float64Array, Int64Array, StringArray};
299    use arrow_schema::{DataType, Field, Schema};
300    use datafusion::physical_plan::memory::MemoryStream;
301    use datafusion::prelude::SessionContext;
302
    /// Build a three-column test batch (`name: Utf8`, `score: Float64`,
    /// `age: Int64`) from parallel value vectors; all columns are nullable
    /// but every value provided here is non-null.
    fn make_test_batch(names: Vec<&str>, scores: Vec<f64>, ages: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("age", DataType::Int64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(scores)),
                Arc::new(Int64Array::from(ages)),
            ],
        )
        .unwrap()
    }
321
    /// Wrap a single batch in a minimal in-memory `ExecutionPlan` to act as
    /// the child of `BestByExec` in tests.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }
330
    /// Minimal in-memory leaf plan used as the child of `BestByExec` in tests.
    #[derive(Debug)]
    struct TestMemoryExec {
        /// Batches replayed on every `execute` call.
        batches: Vec<RecordBatch>,
        /// Schema of the stored batches.
        schema: SchemaRef,
        /// Plan properties derived from the schema at construction.
        properties: PlanProperties,
    }
337
    impl DisplayAs for TestMemoryExec {
        // Fixed label is sufficient for test plan displays.
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }
343
    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        // Leaf node: no children.
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        // Leaf node: replacement children are ignored and `self` is returned.
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        // Replays a clone of the stored batches regardless of partition.
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }
378
    /// Run `BestByExec` (deterministic mode) over `input` on partition 0 and
    /// concatenate the resulting stream into a single batch for assertions.
    async fn execute_best_by(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        sort_criteria: Vec<SortCriterion>,
    ) -> RecordBatch {
        let exec = BestByExec::new(input, key_indices, sort_criteria, true);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        // An empty stream still needs a schema-carrying empty batch.
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }
397
    #[tokio::test]
    async fn test_best_ascending() {
        // Group "a" with scores 3.0, 1.0, 2.0 → best ascending = 1.0
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0], // key: name
            vec![SortCriterion {
                col_index: 1, // sort by score
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        // Single group → exactly one surviving row: the minimum score.
        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 1.0);
    }
422
    #[tokio::test]
    async fn test_best_descending() {
        // DESC criterion keeps the maximum score per group.
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: false,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 3.0);
    }
446
    #[tokio::test]
    async fn test_multiple_groups() {
        // Two key groups ("a", "b"); each keeps its own minimum score.
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![3.0, 1.0, 5.0, 2.0],
            vec![20, 30, 40, 50],
        );
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        // Output order is not asserted, only the per-group winners.
        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(scores.value(i), 1.0),
                "b" => assert_eq!(scores.value(i), 2.0),
                _ => panic!("unexpected name"),
            }
        }
    }
486
    #[tokio::test]
    async fn test_multi_column_criteria() {
        // Two rows with same score, different age → second criterion breaks tie
        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0], vec![30, 20]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![
                // Primary criterion: score ASC (ties here).
                SortCriterion {
                    col_index: 1,
                    ascending: true,
                    nulls_first: false,
                },
                // Secondary criterion: age ASC (resolves the tie).
                SortCriterion {
                    col_index: 2,
                    ascending: true,
                    nulls_first: false,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let ages = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ages.value(0), 20); // younger wins with ascending age
    }
518
    #[tokio::test]
    async fn test_empty_input() {
        // Zero-row input short-circuits to an empty batch with the same schema.
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_best_by(input, vec![0], vec![]).await;
        assert_eq!(result.num_rows(), 0);
    }
530
    #[tokio::test]
    async fn test_single_row_passthrough() {
        // A single row is trivially the best of its group and passes through.
        let batch = make_test_batch(vec!["x"], vec![42.0], vec![10]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 42.0);
    }
554}