// uni_query/query/df_graph/locy_best_by.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! BEST BY operator for Locy.
5//!
6//! `BestByExec` selects the "best" row per group of KEY columns, using ordered
7//! criteria (ASC/DESC) to rank rows and keeping only the top-ranked row per group.
8
9use crate::query::df_graph::common::{ScalarKey, compute_plan_properties, extract_scalar_key};
10use arrow::compute::take;
11use arrow_array::{RecordBatch, UInt32Array};
12use arrow_schema::SchemaRef;
13use datafusion::common::Result as DFResult;
14use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
15use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
16use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::{Stream, TryStreamExt};
18use std::any::Any;
19use std::fmt;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23
/// Sort criterion for BEST BY ordering.
///
/// Each criterion names one input column plus the direction and null placement
/// used when ranking rows inside a key group.
#[derive(Debug, Clone)]
pub struct SortCriterion {
    // Index of the column to rank on, into the input schema.
    pub col_index: usize,
    // `true` => ASC (smallest value is "best"), `false` => DESC.
    pub ascending: bool,
    // Whether nulls sort before non-null values for this criterion.
    pub nulls_first: bool,
}
31
/// DataFusion `ExecutionPlan` that applies BEST BY selection.
///
/// For each group of rows sharing the same KEY columns, sorts by the given
/// criteria and keeps only the first (best) row per group.
#[derive(Debug)]
pub struct BestByExec {
    // Child plan providing the input rows.
    input: Arc<dyn ExecutionPlan>,
    // Indices of the KEY (grouping) columns in the input schema.
    key_indices: Vec<usize>,
    // Ordered ranking criteria; earlier entries take priority.
    sort_criteria: Vec<SortCriterion>,
    // Output schema — identical to the input schema, since rows are only filtered.
    schema: SchemaRef,
    // Cached plan properties derived from the schema at construction time.
    properties: PlanProperties,
    // Metrics registry shared with the streams produced by `execute`.
    metrics: ExecutionPlanMetricsSet,
    /// When true, apply a secondary sort on remaining columns for deterministic
    /// tie-breaking. When false, tied rows are selected non-deterministically.
    deterministic: bool,
}
48
49impl BestByExec {
50    /// Create a new `BestByExec`.
51    ///
52    /// # Arguments
53    /// * `input` - Child execution plan
54    /// * `key_indices` - Indices of KEY columns for grouping
55    /// * `sort_criteria` - Ordering criteria for selecting the "best" row
56    /// * `deterministic` - Whether to apply secondary sort for tie-breaking
57    pub fn new(
58        input: Arc<dyn ExecutionPlan>,
59        key_indices: Vec<usize>,
60        sort_criteria: Vec<SortCriterion>,
61        deterministic: bool,
62    ) -> Self {
63        let schema = input.schema();
64        let properties = compute_plan_properties(Arc::clone(&schema));
65        Self {
66            input,
67            key_indices,
68            sort_criteria,
69            schema,
70            properties,
71            metrics: ExecutionPlanMetricsSet::new(),
72            deterministic,
73        }
74    }
75}
76
77impl DisplayAs for BestByExec {
78    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        write!(
80            f,
81            "BestByExec: key_indices={:?}, criteria={:?}",
82            self.key_indices, self.sort_criteria
83        )
84    }
85}
86
impl ExecutionPlan for BestByExec {
    // Operator name shown in EXPLAIN / debug output.
    fn name(&self) -> &str {
        "BestByExec"
    }

    // Downcast hook used by optimizer rules.
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Output schema — same as the child's schema (rows are only filtered).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    /// Rebuild this node with a replaced child (optimizer requirement).
    ///
    /// Returns a plan error unless exactly one child is supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(datafusion::error::DataFusionError::Plan(
                "BestByExec requires exactly one child".to_string(),
            ));
        }
        Ok(Arc::new(Self::new(
            Arc::clone(&children[0]),
            self.key_indices.clone(),
            self.sort_criteria.clone(),
            self.deterministic,
        )))
    }

    /// Execute one partition: collect the input, sort so that the "best" row of
    /// each key group comes first, then keep only the first row per group.
    ///
    /// NOTE(review): this buffers the ENTIRE input partition in memory
    /// (`try_collect`) before sorting — fine for bounded inputs, but not
    /// streaming-friendly; confirm expected partition sizes.
    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, Arc::clone(&context))?;
        let metrics = BaselineMetrics::new(&self.metrics, partition);
        // Clone everything the async block needs so it can be 'static + Send.
        let key_indices = self.key_indices.clone();
        let sort_criteria = self.sort_criteria.clone();
        let schema = Arc::clone(&self.schema);
        let input_schema = self.input.schema();
        let deterministic = self.deterministic;

        let fut = async move {
            let batches: Vec<RecordBatch> = input_stream.try_collect().await?;

            // No input batches at all: emit an empty batch with the right schema.
            if batches.is_empty() {
                return Ok(RecordBatch::new_empty(schema));
            }

            let batch = arrow::compute::concat_batches(&input_schema, &batches)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            // Batches may exist but contain zero rows; same empty result.
            if batch.num_rows() == 0 {
                return Ok(RecordBatch::new_empty(schema));
            }

            // Build sort columns: key columns ASC first (for grouping contiguity),
            // then criteria columns, then remaining columns ASC (deterministic tie-breaking).
            let num_cols = batch.num_columns();
            let mut sort_columns = Vec::new();

            // 1. Key columns ASC, nulls last — only contiguity matters here, the
            //    direction is arbitrary as long as equal keys end up adjacent.
            for &ki in &key_indices {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(ki)),
                    options: Some(arrow::compute::SortOptions {
                        descending: false,
                        nulls_first: false,
                    }),
                });
            }

            // 2. Criteria columns — these decide which row of a group is "best"
            //    (the one sorted first).
            for criterion in &sort_criteria {
                sort_columns.push(arrow::compute::SortColumn {
                    values: Arc::clone(batch.column(criterion.col_index)),
                    options: Some(arrow::compute::SortOptions {
                        descending: !criterion.ascending,
                        nulls_first: criterion.nulls_first,
                    }),
                });
            }

            // 3. Remaining columns ASC for deterministic tie-breaking (optional)
            if deterministic {
                // Columns already used as keys or criteria must not be repeated.
                let used_cols: std::collections::HashSet<usize> = key_indices
                    .iter()
                    .copied()
                    .chain(sort_criteria.iter().map(|c| c.col_index))
                    .collect();
                for col_idx in 0..num_cols {
                    if !used_cols.contains(&col_idx) {
                        sort_columns.push(arrow::compute::SortColumn {
                            values: Arc::clone(batch.column(col_idx)),
                            options: Some(arrow::compute::SortOptions {
                                descending: false,
                                nulls_first: false,
                            }),
                        });
                    }
                }
            }

            // Sort to get indices (lexicographic over all sort columns).
            let sorted_indices = arrow::compute::lexsort_to_indices(&sort_columns, None)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            // Reorder batch by sorted indices
            let sorted_columns: Vec<_> = batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &sorted_indices, None))
                .collect::<Result<Vec<_>, _>>()?;
            let sorted_batch = RecordBatch::try_new(Arc::clone(&schema), sorted_columns)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

            // Dedup: keep first row per key group (linear scan). Groups are
            // contiguous because the keys were the leading sort columns.
            // NOTE(review): grouping of null keys depends on how `ScalarKey`
            // compares nulls in `extract_scalar_key` — presumably null == null
            // so all-null keys form one group; verify in common.rs.
            let mut keep_indices: Vec<u32> = Vec::new();
            let mut prev_key: Option<Vec<ScalarKey>> = None;

            for row_idx in 0..sorted_batch.num_rows() {
                let key = extract_scalar_key(&sorted_batch, &key_indices, row_idx);
                let is_new_group = match &prev_key {
                    None => true,
                    Some(prev) => *prev != key,
                };
                if is_new_group {
                    // First row of a new group is, by construction, its best row.
                    keep_indices.push(row_idx as u32);
                    prev_key = Some(key);
                }
            }

            // Gather only the kept rows into the output batch.
            let keep_array = UInt32Array::from(keep_indices);
            let output_columns: Vec<_> = sorted_batch
                .columns()
                .iter()
                .map(|col| take(col.as_ref(), &keep_array, None))
                .collect::<Result<Vec<_>, _>>()?;

            RecordBatch::try_new(schema, output_columns)
                .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
        };

        // Wrap the one-shot computation in a stream yielding a single batch.
        Ok(Box::pin(BestByStream {
            state: BestByStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
250
251// ---------------------------------------------------------------------------
252// Stream implementation
253// ---------------------------------------------------------------------------
254
// Internal state of `BestByStream`: either the pending whole-input computation
// or finished (single result already emitted).
enum BestByStreamState {
    // The BEST BY computation, not yet completed.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    // The single output batch (or error) has been emitted.
    Done,
}
259
// Stream adapter that yields exactly one batch (the BEST BY result), then ends.
struct BestByStream {
    // Position in the one-shot state machine.
    state: BestByStreamState,
    // Schema reported to consumers (same as the operator's output schema).
    schema: SchemaRef,
    // Baseline metrics; the output row count is recorded on success.
    metrics: BaselineMetrics,
}
265
266impl Stream for BestByStream {
267    type Item = DFResult<RecordBatch>;
268
269    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
270        match &mut self.state {
271            BestByStreamState::Running(fut) => match fut.as_mut().poll(cx) {
272                Poll::Ready(Ok(batch)) => {
273                    self.metrics.record_output(batch.num_rows());
274                    self.state = BestByStreamState::Done;
275                    Poll::Ready(Some(Ok(batch)))
276                }
277                Poll::Ready(Err(e)) => {
278                    self.state = BestByStreamState::Done;
279                    Poll::Ready(Some(Err(e)))
280                }
281                Poll::Pending => Poll::Pending,
282            },
283            BestByStreamState::Done => Poll::Ready(None),
284        }
285    }
286}
287
288impl RecordBatchStream for BestByStream {
289    fn schema(&self) -> SchemaRef {
290        Arc::clone(&self.schema)
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::physical_plan::memory::MemoryStream;
    use datafusion::prelude::SessionContext;

    // Build a three-column (name, score, age) batch from parallel vectors.
    fn make_test_batch(names: Vec<&str>, scores: Vec<f64>, ages: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
            Field::new("age", DataType::Int64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.into_iter().map(Some).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(scores)),
                Arc::new(Int64Array::from(ages)),
            ],
        )
        .unwrap()
    }

    // Wrap a batch in a minimal in-memory ExecutionPlan to feed BestByExec.
    fn make_memory_exec(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        Arc::new(TestMemoryExec {
            batches: vec![batch],
            schema: schema.clone(),
            properties: compute_plan_properties(schema),
        })
    }

    // Minimal single-partition source plan used only by these tests.
    #[derive(Debug)]
    struct TestMemoryExec {
        batches: Vec<RecordBatch>,
        schema: SchemaRef,
        properties: PlanProperties,
    }

    impl DisplayAs for TestMemoryExec {
        fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "TestMemoryExec")
        }
    }

    impl ExecutionPlan for TestMemoryExec {
        fn name(&self) -> &str {
            "TestMemoryExec"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }
        fn properties(&self) -> &PlanProperties {
            &self.properties
        }
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        fn with_new_children(
            self: Arc<Self>,
            _children: Vec<Arc<dyn ExecutionPlan>>,
        ) -> DFResult<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        fn execute(
            &self,
            _partition: usize,
            _context: Arc<TaskContext>,
        ) -> DFResult<SendableRecordBatchStream> {
            Ok(Box::pin(MemoryStream::try_new(
                self.batches.clone(),
                Arc::clone(&self.schema),
                None,
            )?))
        }
    }

    // Run BestByExec (deterministic mode) over `input` and concat the output
    // into a single batch for assertions.
    async fn execute_best_by(
        input: Arc<dyn ExecutionPlan>,
        key_indices: Vec<usize>,
        sort_criteria: Vec<SortCriterion>,
    ) -> RecordBatch {
        let exec = BestByExec::new(input, key_indices, sort_criteria, true);
        let ctx = SessionContext::new();
        let task_ctx = ctx.task_ctx();
        let stream = exec.execute(0, task_ctx).unwrap();
        let batches: Vec<RecordBatch> = datafusion::physical_plan::common::collect(stream)
            .await
            .unwrap();
        if batches.is_empty() {
            RecordBatch::new_empty(exec.schema())
        } else {
            arrow::compute::concat_batches(&exec.schema(), &batches).unwrap()
        }
    }

    #[tokio::test]
    async fn test_best_ascending() {
        // Group "a" with scores 3.0, 1.0, 2.0 → best ascending = 1.0
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0], // key: name
            vec![SortCriterion {
                col_index: 1, // sort by score
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 1.0);
    }

    #[tokio::test]
    async fn test_best_descending() {
        // Same data as above but DESC: the largest score (3.0) wins.
        let batch = make_test_batch(vec!["a", "a", "a"], vec![3.0, 1.0, 2.0], vec![20, 30, 25]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: false,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 3.0);
    }

    #[tokio::test]
    async fn test_multiple_groups() {
        // Two key groups ("a", "b") — one best row is kept per group.
        let batch = make_test_batch(
            vec!["a", "a", "b", "b"],
            vec![3.0, 1.0, 5.0, 2.0],
            vec![20, 30, 40, 50],
        );
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 2);
        let names = result
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();

        // Output group order is not asserted — only the per-group winner.
        for i in 0..2 {
            match names.value(i) {
                "a" => assert_eq!(scores.value(i), 1.0),
                "b" => assert_eq!(scores.value(i), 2.0),
                _ => panic!("unexpected name"),
            }
        }
    }

    #[tokio::test]
    async fn test_multi_column_criteria() {
        // Two rows with same score, different age → second criterion breaks tie
        let batch = make_test_batch(vec!["a", "a"], vec![1.0, 1.0], vec![30, 20]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![
                SortCriterion {
                    col_index: 1,
                    ascending: true,
                    nulls_first: false,
                },
                SortCriterion {
                    col_index: 2,
                    ascending: true,
                    nulls_first: false,
                },
            ],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let ages = result
            .column(2)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ages.value(0), 20); // younger wins with ascending age
    }

    #[tokio::test]
    async fn test_empty_input() {
        // Zero-row input must produce a zero-row (not error) result.
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Float64, true),
        ]));
        let batch = RecordBatch::new_empty(schema.clone());
        let input = make_memory_exec(batch);
        let result = execute_best_by(input, vec![0], vec![]).await;
        assert_eq!(result.num_rows(), 0);
    }

    #[tokio::test]
    async fn test_single_row_passthrough() {
        // A one-row group is trivially its own best row.
        let batch = make_test_batch(vec!["x"], vec![42.0], vec![10]);
        let input = make_memory_exec(batch);
        let result = execute_best_by(
            input,
            vec![0],
            vec![SortCriterion {
                col_index: 1,
                ascending: true,
                nulls_first: false,
            }],
        )
        .await;

        assert_eq!(result.num_rows(), 1);
        let scores = result
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(scores.value(0), 42.0);
    }
}
553}