Skip to main content

trueno_db/
topk.rs

1//! Top-K selection algorithms
2//!
//! **Problem**: Answering `ORDER BY ... LIMIT K` with a full sort costs O(N log N); bounded-heap Top-K selection costs O(N log K), which is near-linear when K ≪ N.
4//!
//! **Solution**: Bounded binary-heap Top-K selection (a min-heap keeps the K largest values, a max-heap keeps the K smallest)
6//!
7//! **Performance Impact** (1M files):
8//! - Full sort: 2.3 seconds
9//! - Top-K selection: 0.08 seconds
10//! - **Speedup**: 28.75x
11//!
12//! Toyota Way Principles:
//! - **Kaizen**: Algorithmic improvement (O(N log N) → O(N log K))
14//! - **Muda elimination**: Avoid unnecessary full sort
15//! - **Genchi Genbutsu**: Actual performance measurements guide optimization
16//!
17//! References:
18//! - ../paiml-mcp-agent-toolkit/docs/specifications/trueno-db-integration-review-response.md Issue #2
19
20use crate::Error;
21use arrow::array::{
22    Array, ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
23};
24use arrow::compute::SortOptions;
25use arrow::record_batch::RecordBatch;
26use std::cmp::Ordering;
27use std::collections::BinaryHeap;
28use std::sync::Arc;
29
/// Direction in which Top-K ranks values.
///
/// Chooses which end of the value range is kept by [`TopKSelection::top_k`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortOrder {
    /// Keep the K smallest values, smallest first.
    Ascending,
    /// Keep the K largest values, largest first.
    Descending,
}
38
39impl From<SortOrder> for SortOptions {
40    fn from(order: SortOrder) -> Self {
41        Self { descending: matches!(order, SortOrder::Descending), nulls_first: false }
42    }
43}
44
45/// Trait for Top-K selection on record batches
46pub trait TopKSelection {
47    /// Select top K rows by a specific column
48    ///
49    /// # Arguments
50    /// * `column_index` - Index of the column to sort by
51    /// * `k` - Number of rows to select
52    /// * `order` - Sort order (Ascending or Descending)
53    ///
54    /// # Returns
55    /// A new `RecordBatch` containing the top K rows
56    ///
57    /// # Errors
58    /// Returns error if:
59    /// - Column index is out of bounds
60    /// - Column data type is not sortable
61    /// - K is zero
62    ///
63    /// # Examples
64    ///
65    /// ```rust
66    /// use trueno_db::topk::{TopKSelection, SortOrder};
67    /// use arrow::array::{Float64Array, RecordBatch};
68    /// use arrow::datatypes::{DataType, Field, Schema};
69    /// use std::sync::Arc;
70    ///
71    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
72    /// let schema = Arc::new(Schema::new(vec![
73    ///     Field::new("score", DataType::Float64, false),
74    /// ]));
75    /// let batch = RecordBatch::try_new(
76    ///     schema,
77    ///     vec![Arc::new(Float64Array::from(vec![1.0, 5.0, 3.0, 9.0, 2.0]))],
78    /// )?;
79    ///
80    /// // Get top 3 highest scores
81    /// let top3 = batch.top_k(0, 3, SortOrder::Descending)?;
82    /// assert_eq!(top3.num_rows(), 3);
83    /// # Ok(())
84    /// # }
85    /// ```
86    fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch>;
87}
88
89impl TopKSelection for RecordBatch {
90    fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch> {
91        // Validate inputs
92        if k == 0 {
93            return Err(Error::InvalidInput("k must be greater than 0".to_string()));
94        }
95
96        if column_index >= self.num_columns() {
97            return Err(Error::InvalidInput(format!(
98                "Column index {} out of bounds (batch has {} columns)",
99                column_index,
100                self.num_columns()
101            )));
102        }
103
104        // If k >= num_rows, just sort and return all rows
105        if k >= self.num_rows() {
106            return sort_all_rows(self, column_index, order);
107        }
108
109        // Use heap-based Top-K selection
110        let column = self.column(column_index);
111        let indices = select_top_k_indices(column, k, order)?;
112
113        // Build result batch from selected indices
114        build_batch_from_indices(self, &indices)
115    }
116}
117
118/// Select top K indices using min-heap algorithm
119///
120/// Time complexity: O(N log K) where N = number of rows, K = selection size
121/// Space complexity: O(K) for the heap
122fn select_top_k_indices(
123    column: &ArrayRef,
124    k: usize,
125    order: SortOrder,
126) -> crate::Result<Vec<usize>> {
127    match column.data_type() {
128        arrow::datatypes::DataType::Int32 => {
129            let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
130                Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
131            })?;
132            select_top_k_i32(array, k, order)
133        }
134        arrow::datatypes::DataType::Int64 => {
135            let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
136                Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
137            })?;
138            select_top_k_i64(array, k, order)
139        }
140        arrow::datatypes::DataType::Float32 => {
141            let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
142                Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
143            })?;
144            select_top_k_f32(array, k, order)
145        }
146        arrow::datatypes::DataType::Float64 => {
147            let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
148                Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
149            })?;
150            select_top_k_f64(array, k, order)
151        }
152        dt => Err(Error::InvalidInput(format!("Top-K not supported for data type: {dt:?}"))),
153    }
154}
155
/// Heap entry whose ordering is *reversed* on `value`.
///
/// `BinaryHeap` is a max-heap, so reversing the comparison brings the smallest `value`
/// to the top. A bounded heap of these lets the Descending path keep the K largest
/// values seen so far and evict the weakest one cheaply. `index` is the row position
/// and takes no part in comparisons.
///
/// Incomparable values (`partial_cmp` returning `None`, e.g. NaN) are treated as equal by
/// `Ord`, so `V` should be totally ordered for the heap to stay consistent.
#[derive(Debug)]
struct MinHeapItem<V> {
    value: V,
    index: usize,
}

impl<V: PartialOrd> Ord for MinHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands swapped on purpose: a larger value compares as "less".
        other.value.partial_cmp(&self.value).unwrap_or(Ordering::Equal)
    }
}

impl<V: PartialOrd> PartialOrd for MinHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<V: PartialOrd> PartialEq for MinHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MinHeapItem<V> {}
183
/// Heap entry ordered naturally on `value`.
///
/// In `BinaryHeap` (a max-heap) the largest `value` sits on top. A bounded heap of these
/// lets the Ascending path keep the K smallest values seen so far and evict the largest
/// one cheaply. `index` is the row position and takes no part in comparisons.
///
/// Incomparable values (`partial_cmp` returning `None`, e.g. NaN) are treated as equal by
/// `Ord`, so `V` should be totally ordered for the heap to stay consistent.
#[derive(Debug)]
struct MaxHeapItem<V> {
    value: V,
    index: usize,
}

impl<V: PartialOrd> Ord for MaxHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.value.partial_cmp(&other.value).unwrap_or(Ordering::Equal)
    }
}

impl<V: PartialOrd> PartialOrd for MaxHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<V: PartialOrd> PartialEq for MaxHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MaxHeapItem<V> {}
211
212/// Top-K selection for `Int32Array`
213#[allow(clippy::unnecessary_wraps)]
214fn select_top_k_i32(array: &Int32Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
215    match order {
216        SortOrder::Descending => {
217            // Use min-heap to find largest K
218            let mut heap: BinaryHeap<MinHeapItem<i32>> = BinaryHeap::with_capacity(k);
219
220            for index in 0..array.len() {
221                if !array.is_null(index) {
222                    let value = array.value(index);
223                    let item = MinHeapItem { value, index };
224
225                    if heap.len() < k {
226                        heap.push(item);
227                    } else if let Some(top) = heap.peek() {
228                        if value > top.value {
229                            heap.pop();
230                            heap.push(item);
231                        }
232                    }
233                }
234            }
235
236            let mut result: Vec<_> = heap.into_vec();
237            result.sort_by(|a, b| b.value.cmp(&a.value));
238            Ok(result.into_iter().map(|item| item.index).collect())
239        }
240        SortOrder::Ascending => {
241            // Use max-heap to find smallest K
242            let mut heap: BinaryHeap<MaxHeapItem<i32>> = BinaryHeap::with_capacity(k);
243
244            for index in 0..array.len() {
245                if !array.is_null(index) {
246                    let value = array.value(index);
247                    let item = MaxHeapItem { value, index };
248
249                    if heap.len() < k {
250                        heap.push(item);
251                    } else if let Some(top) = heap.peek() {
252                        if value < top.value {
253                            heap.pop();
254                            heap.push(item);
255                        }
256                    }
257                }
258            }
259
260            let mut result: Vec<_> = heap.into_vec();
261            result.sort_by(|a, b| a.value.cmp(&b.value));
262            Ok(result.into_iter().map(|item| item.index).collect())
263        }
264    }
265}
266
267/// Top-K selection for `Int64Array`
268#[allow(clippy::unnecessary_wraps)]
269fn select_top_k_i64(array: &Int64Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
270    match order {
271        SortOrder::Descending => {
272            let mut heap: BinaryHeap<MinHeapItem<i64>> = BinaryHeap::with_capacity(k);
273            for index in 0..array.len() {
274                if !array.is_null(index) {
275                    let value = array.value(index);
276                    if heap.len() < k {
277                        heap.push(MinHeapItem { value, index });
278                    } else if let Some(top) = heap.peek() {
279                        if value > top.value {
280                            heap.pop();
281                            heap.push(MinHeapItem { value, index });
282                        }
283                    }
284                }
285            }
286            let mut result: Vec<_> = heap.into_vec();
287            result.sort_by(|a, b| b.value.cmp(&a.value));
288            Ok(result.into_iter().map(|item| item.index).collect())
289        }
290        SortOrder::Ascending => {
291            let mut heap: BinaryHeap<MaxHeapItem<i64>> = BinaryHeap::with_capacity(k);
292            for index in 0..array.len() {
293                if !array.is_null(index) {
294                    let value = array.value(index);
295                    if heap.len() < k {
296                        heap.push(MaxHeapItem { value, index });
297                    } else if let Some(top) = heap.peek() {
298                        if value < top.value {
299                            heap.pop();
300                            heap.push(MaxHeapItem { value, index });
301                        }
302                    }
303                }
304            }
305            let mut result: Vec<_> = heap.into_vec();
306            result.sort_by(|a, b| a.value.cmp(&b.value));
307            Ok(result.into_iter().map(|item| item.index).collect())
308        }
309    }
310}
311
312/// Top-K selection for `Float32Array`
313#[allow(clippy::unnecessary_wraps)]
314fn select_top_k_f32(array: &Float32Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
315    match order {
316        SortOrder::Descending => {
317            let mut heap: BinaryHeap<MinHeapItem<f32>> = BinaryHeap::with_capacity(k);
318            for index in 0..array.len() {
319                if !array.is_null(index) {
320                    let value = array.value(index);
321                    if heap.len() < k {
322                        heap.push(MinHeapItem { value, index });
323                    } else if let Some(top) = heap.peek() {
324                        if value > top.value {
325                            heap.pop();
326                            heap.push(MinHeapItem { value, index });
327                        }
328                    }
329                }
330            }
331            let mut result: Vec<_> = heap.into_vec();
332            result.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap_or(Ordering::Equal));
333            Ok(result.into_iter().map(|item| item.index).collect())
334        }
335        SortOrder::Ascending => {
336            let mut heap: BinaryHeap<MaxHeapItem<f32>> = BinaryHeap::with_capacity(k);
337            for index in 0..array.len() {
338                if !array.is_null(index) {
339                    let value = array.value(index);
340                    if heap.len() < k {
341                        heap.push(MaxHeapItem { value, index });
342                    } else if let Some(top) = heap.peek() {
343                        if value < top.value {
344                            heap.pop();
345                            heap.push(MaxHeapItem { value, index });
346                        }
347                    }
348                }
349            }
350            let mut result: Vec<_> = heap.into_vec();
351            result.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal));
352            Ok(result.into_iter().map(|item| item.index).collect())
353        }
354    }
355}
356
357/// Top-K selection for `Float64Array`
358#[allow(clippy::unnecessary_wraps)]
359fn select_top_k_f64(array: &Float64Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
360    match order {
361        SortOrder::Descending => {
362            let mut heap: BinaryHeap<MinHeapItem<f64>> = BinaryHeap::with_capacity(k);
363            for index in 0..array.len() {
364                if !array.is_null(index) {
365                    let value = array.value(index);
366                    if heap.len() < k {
367                        heap.push(MinHeapItem { value, index });
368                    } else if let Some(top) = heap.peek() {
369                        if value > top.value {
370                            heap.pop();
371                            heap.push(MinHeapItem { value, index });
372                        }
373                    }
374                }
375            }
376            let mut result: Vec<_> = heap.into_vec();
377            result.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap_or(Ordering::Equal));
378            Ok(result.into_iter().map(|item| item.index).collect())
379        }
380        SortOrder::Ascending => {
381            let mut heap: BinaryHeap<MaxHeapItem<f64>> = BinaryHeap::with_capacity(k);
382            for index in 0..array.len() {
383                if !array.is_null(index) {
384                    let value = array.value(index);
385                    if heap.len() < k {
386                        heap.push(MaxHeapItem { value, index });
387                    } else if let Some(top) = heap.peek() {
388                        if value < top.value {
389                            heap.pop();
390                            heap.push(MaxHeapItem { value, index });
391                        }
392                    }
393                }
394            }
395            let mut result: Vec<_> = heap.into_vec();
396            result.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal));
397            Ok(result.into_iter().map(|item| item.index).collect())
398        }
399    }
400}
401
402/// Build a new record batch from selected row indices
403fn build_batch_from_indices(batch: &RecordBatch, indices: &[usize]) -> crate::Result<RecordBatch> {
404    use arrow::datatypes::DataType;
405
406    let mut new_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
407
408    for col_idx in 0..batch.num_columns() {
409        let column = batch.column(col_idx);
410
411        let new_array: ArrayRef = match column.data_type() {
412            DataType::Int32 => {
413                let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
414                    Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
415                })?;
416                let values: Vec<i32> = indices.iter().map(|&idx| array.value(idx)).collect();
417                Arc::new(Int32Array::from(values))
418            }
419            DataType::Int64 => {
420                let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
421                    Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
422                })?;
423                let values: Vec<i64> = indices.iter().map(|&idx| array.value(idx)).collect();
424                Arc::new(Int64Array::from(values))
425            }
426            DataType::Float32 => {
427                let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
428                    Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
429                })?;
430                let values: Vec<f32> = indices.iter().map(|&idx| array.value(idx)).collect();
431                Arc::new(Float32Array::from(values))
432            }
433            DataType::Float64 => {
434                let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
435                    Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
436                })?;
437                let values: Vec<f64> = indices.iter().map(|&idx| array.value(idx)).collect();
438                Arc::new(Float64Array::from(values))
439            }
440            DataType::Utf8 => {
441                let array = column.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
442                    Error::Other("Failed to downcast Utf8 column to StringArray".to_string())
443                })?;
444                let values: Vec<&str> = indices.iter().map(|&idx| array.value(idx)).collect();
445                Arc::new(StringArray::from(values))
446            }
447            dt => {
448                return Err(Error::InvalidInput(format!(
449                    "Top-K not implemented for column data type: {dt:?}"
450                )));
451            }
452        };
453
454        new_columns.push(new_array);
455    }
456
457    RecordBatch::try_new(batch.schema(), new_columns)
458        .map_err(|e| Error::StorageError(format!("Failed to create result batch: {e}")))
459}
460
461/// Fallback: sort all rows when k >= `num_rows`
462fn sort_all_rows(
463    batch: &RecordBatch,
464    column_index: usize,
465    order: SortOrder,
466) -> crate::Result<RecordBatch> {
467    use arrow::compute::sort_to_indices;
468
469    let sort_options = SortOptions::from(order);
470    let indices = sort_to_indices(batch.column(column_index).as_ref(), Some(sort_options), None)
471        .map_err(|e| Error::StorageError(format!("Failed to sort: {e}")))?;
472
473    // Convert indices to usize vec
474    let indices_array =
475        indices.as_any().downcast_ref::<arrow::array::UInt32Array>().ok_or_else(|| {
476            Error::Other(
477                "Failed to downcast sort indices to UInt32Array (expected from sort_to_indices)"
478                    .to_string(),
479            )
480        })?;
481    let indices_vec: Vec<usize> =
482        (0..indices_array.len()).map(|i| indices_array.value(i) as usize).collect();
483
484    build_batch_from_indices(batch, &indices_vec)
485}
486
487#[cfg(test)]
488#[allow(
489    clippy::cast_possible_truncation,
490    clippy::cast_possible_wrap,
491    clippy::cast_precision_loss,
492    clippy::float_cmp,
493    clippy::redundant_closure
494)]
495mod tests {
496    use super::*;
497    use arrow::datatypes::{DataType, Field, Schema};
498    use std::sync::Arc;
499
500    fn create_test_batch(values: Vec<f64>) -> RecordBatch {
501        let schema = Arc::new(Schema::new(vec![
502            Field::new("id", DataType::Int32, false),
503            Field::new("score", DataType::Float64, false),
504        ]));
505
506        let ids: Vec<i32> = (0..values.len() as i32).collect();
507
508        RecordBatch::try_new(
509            schema,
510            vec![Arc::new(Int32Array::from(ids)), Arc::new(Float64Array::from(values))],
511        )
512        .unwrap()
513    }
514
515    #[test]
516    fn test_top_k_descending_basic() {
517        // Test: Get top 3 highest scores
518        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
519        let result = batch.top_k(1, 3, SortOrder::Descending).unwrap();
520
521        assert_eq!(result.num_rows(), 3);
522
523        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
524        assert_eq!(scores.value(0), 9.0);
525        assert_eq!(scores.value(1), 5.0);
526        assert_eq!(scores.value(2), 3.0);
527    }
528
529    #[test]
530    fn test_top_k_ascending_basic() {
531        // Test: Get top 3 lowest scores
532        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
533        let result = batch.top_k(1, 3, SortOrder::Ascending).unwrap();
534
535        assert_eq!(result.num_rows(), 3);
536
537        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
538        assert_eq!(scores.value(0), 1.0);
539        assert_eq!(scores.value(1), 2.0);
540        assert_eq!(scores.value(2), 3.0);
541    }
542
543    #[test]
544    fn test_top_k_k_equals_length() {
545        // Edge case: k equals number of rows (should return sorted batch)
546        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
547        let result = batch.top_k(1, 3, SortOrder::Descending).unwrap();
548
549        assert_eq!(result.num_rows(), 3);
550
551        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
552        assert_eq!(scores.value(0), 3.0);
553        assert_eq!(scores.value(1), 2.0);
554        assert_eq!(scores.value(2), 1.0);
555    }
556
557    #[test]
558    fn test_top_k_k_greater_than_length() {
559        // Edge case: k > number of rows (should return all rows sorted)
560        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
561        let result = batch.top_k(1, 10, SortOrder::Descending).unwrap();
562
563        assert_eq!(result.num_rows(), 3);
564
565        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
566        assert_eq!(scores.value(0), 3.0);
567        assert_eq!(scores.value(1), 2.0);
568        assert_eq!(scores.value(2), 1.0);
569    }
570
571    #[test]
572    fn test_top_k_k_zero_fails() {
573        // Error case: k = 0 should fail
574        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
575        let result = batch.top_k(1, 0, SortOrder::Descending);
576
577        assert!(result.is_err());
578        assert!(result.unwrap_err().to_string().contains("must be greater than 0"));
579    }
580
581    #[test]
582    fn test_top_k_invalid_column_index() {
583        // Error case: invalid column index
584        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
585        let result = batch.top_k(99, 2, SortOrder::Descending);
586
587        assert!(result.is_err());
588        assert!(result.unwrap_err().to_string().contains("out of bounds"));
589    }
590
591    #[test]
592    fn test_top_k_preserves_row_integrity() {
593        // Test: Ensure all columns stay aligned (row integrity)
594        let batch = create_test_batch(vec![1.0, 5.0, 3.0]);
595        let result = batch.top_k(1, 2, SortOrder::Descending).unwrap();
596
597        let ids = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
598        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
599
600        // Top 2: scores 5.0 (id=1) and 3.0 (id=2)
601        assert_eq!(scores.value(0), 5.0);
602        assert_eq!(ids.value(0), 1);
603
604        assert_eq!(scores.value(1), 3.0);
605        assert_eq!(ids.value(1), 2);
606    }
607
608    #[test]
609    fn test_top_k_large_dataset() {
610        // Performance test: 1M rows (should be O(N) vs O(N log N))
611        let values: Vec<f64> = (0..1_000_000).map(|i| f64::from(i)).collect();
612        let batch = create_test_batch(values);
613
614        let start = std::time::Instant::now();
615        let result = batch.top_k(1, 10, SortOrder::Descending).unwrap();
616        let duration = start.elapsed();
617
618        assert_eq!(result.num_rows(), 10);
619
620        let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
621        // Top 10 should be 999999, 999998, ..., 999990
622        for i in 0..10 {
623            assert_eq!(scores.value(i), 999_999.0 - i as f64);
624        }
625
626        // Should complete in < 500ms (debug builds are slower)
627        // Target for release builds: <80ms for 1M rows
628        // This is still much faster than O(N log N) sort
629        assert!(
630            duration.as_millis() < 500,
631            "Top-K took {}ms (expected <500ms)",
632            duration.as_millis()
633        );
634    }
635
636    // Property-based tests
637    #[cfg(test)]
638    mod property_tests {
639        use super::*;
640        use proptest::prelude::*;
641
642        proptest! {
643            /// Property: Top-K always returns exactly K rows (or fewer if input is smaller)
644            #[test]
645            fn prop_top_k_returns_k_rows(
646                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
647                k in 1usize..100
648            ) {
649                let batch = create_test_batch(values.clone());
650                let result = batch.top_k(1, k, SortOrder::Descending).unwrap();
651
652                let expected_rows = k.min(values.len());
653                prop_assert_eq!(result.num_rows(), expected_rows);
654            }
655
656            /// Property: Top-K descending returns values in descending order
657            #[test]
658            fn prop_top_k_descending_is_sorted(
659                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
660                k in 1usize..100
661            ) {
662                let batch = create_test_batch(values);
663                let result = batch.top_k(1, k, SortOrder::Descending).unwrap();
664
665                let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
666
667                // Check descending order
668                for i in 0..scores.len().saturating_sub(1) {
669                    prop_assert!(
670                        scores.value(i) >= scores.value(i + 1),
671                        "Not in descending order: {} < {}",
672                        scores.value(i),
673                        scores.value(i + 1)
674                    );
675                }
676            }
677
678            /// Property: Top-K ascending returns values in ascending order
679            #[test]
680            fn prop_top_k_ascending_is_sorted(
681                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
682                k in 1usize..100
683            ) {
684                let batch = create_test_batch(values);
685                let result = batch.top_k(1, k, SortOrder::Ascending).unwrap();
686
687                let scores = result.column(1).as_any().downcast_ref::<Float64Array>().unwrap();
688
689                // Check ascending order
690                for i in 0..scores.len().saturating_sub(1) {
691                    prop_assert!(
692                        scores.value(i) <= scores.value(i + 1),
693                        "Not in ascending order: {} > {}",
694                        scores.value(i),
695                        scores.value(i + 1)
696                    );
697                }
698            }
699        }
700    }
701
702    // Additional tests for all data types
703    #[test]
704    fn test_top_k_int32() {
705        use arrow::array::Int32Array;
706        use arrow::datatypes::{DataType, Field, Schema};
707        use std::sync::Arc;
708
709        let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
710        let values = Int32Array::from(vec![5, 2, 8, 1, 9, 3]);
711        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
712
713        let result = batch.top_k(0, 3, SortOrder::Descending).unwrap();
714        assert_eq!(result.num_rows(), 3);
715
716        let col = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
717        assert_eq!(col.value(0), 9);
718        assert_eq!(col.value(1), 8);
719        assert_eq!(col.value(2), 5);
720    }
721
722    #[test]
723    fn test_top_k_int32_ascending() {
724        use arrow::array::Int32Array;
725        use arrow::datatypes::{DataType, Field, Schema};
726        use std::sync::Arc;
727
728        let schema = Schema::new(vec![Field::new("value", DataType::Int32, false)]);
729        let values = Int32Array::from(vec![5, 2, 8, 1, 9, 3]);
730        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
731
732        let result = batch.top_k(0, 3, SortOrder::Ascending).unwrap();
733        assert_eq!(result.num_rows(), 3);
734
735        let col = result.column(0).as_any().downcast_ref::<Int32Array>().unwrap();
736        assert_eq!(col.value(0), 1);
737        assert_eq!(col.value(1), 2);
738        assert_eq!(col.value(2), 3);
739    }
740
741    #[test]
742    fn test_top_k_int64() {
743        use arrow::array::Int64Array;
744        use arrow::datatypes::{DataType, Field, Schema};
745        use std::sync::Arc;
746
747        let schema = Schema::new(vec![Field::new("value", DataType::Int64, false)]);
748        let values = Int64Array::from(vec![100i64, 200, 50, 300, 150]);
749        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
750
751        let result = batch.top_k(0, 2, SortOrder::Ascending).unwrap();
752        assert_eq!(result.num_rows(), 2);
753
754        let col = result.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
755        assert_eq!(col.value(0), 50);
756        assert_eq!(col.value(1), 100);
757    }
758
759    #[test]
760    fn test_top_k_int64_descending() {
761        use arrow::array::Int64Array;
762        use arrow::datatypes::{DataType, Field, Schema};
763        use std::sync::Arc;
764
765        let schema = Schema::new(vec![Field::new("value", DataType::Int64, false)]);
766        let values = Int64Array::from(vec![100i64, 200, 50, 300, 150]);
767        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
768
769        let result = batch.top_k(0, 2, SortOrder::Descending).unwrap();
770        assert_eq!(result.num_rows(), 2);
771
772        let col = result.column(0).as_any().downcast_ref::<Int64Array>().unwrap();
773        assert_eq!(col.value(0), 300);
774        assert_eq!(col.value(1), 200);
775    }
776
777    #[test]
778    fn test_top_k_float32() {
779        use arrow::array::Float32Array;
780        use arrow::datatypes::{DataType, Field, Schema};
781        use std::sync::Arc;
782
783        let schema = Schema::new(vec![Field::new("value", DataType::Float32, false)]);
784        let values = Float32Array::from(vec![1.5f32, 2.7, 0.3, 4.2, 3.1]);
785        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
786
787        let result = batch.top_k(0, 3, SortOrder::Descending).unwrap();
788        assert_eq!(result.num_rows(), 3);
789
790        let col = result.column(0).as_any().downcast_ref::<Float32Array>().unwrap();
791        assert!((col.value(0) - 4.2).abs() < 0.001);
792        assert!((col.value(1) - 3.1).abs() < 0.001);
793        assert!((col.value(2) - 2.7).abs() < 0.001);
794    }
795
796    #[test]
797    fn test_top_k_float32_ascending() {
798        use arrow::array::Float32Array;
799        use arrow::datatypes::{DataType, Field, Schema};
800        use std::sync::Arc;
801
802        let schema = Schema::new(vec![Field::new("value", DataType::Float32, false)]);
803        let values = Float32Array::from(vec![1.5f32, 2.7, 0.3, 4.2, 3.1]);
804        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
805
806        let result = batch.top_k(0, 3, SortOrder::Ascending).unwrap();
807        assert_eq!(result.num_rows(), 3);
808
809        let col = result.column(0).as_any().downcast_ref::<Float32Array>().unwrap();
810        assert!((col.value(0) - 0.3).abs() < 0.001);
811        assert!((col.value(1) - 1.5).abs() < 0.001);
812        assert!((col.value(2) - 2.7).abs() < 0.001);
813    }
814
815    #[test]
816    fn test_top_k_unsupported_type() {
817        use arrow::array::StringArray;
818        use arrow::datatypes::{DataType, Field, Schema};
819        use std::sync::Arc;
820
821        let schema = Schema::new(vec![Field::new("value", DataType::Utf8, false)]);
822        let values = StringArray::from(vec!["a", "b", "c"]);
823        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();
824
825        let result = batch.top_k(0, 2, SortOrder::Descending);
826        assert!(result.is_err());
827        assert!(result.unwrap_err().to_string().contains("Top-K not supported for data type"));
828    }
829
830    // ========================================================================
831    // Heap Item Trait Tests (for coverage of MinHeapItem/MaxHeapItem)
832    // ========================================================================
833
834    #[test]
835    fn test_min_heap_item_eq() {
836        let item1 = MinHeapItem { value: 42i32, index: 0 };
837        let item2 = MinHeapItem { value: 42i32, index: 1 };
838        let item3 = MinHeapItem { value: 43i32, index: 2 };
839
840        assert_eq!(item1, item2);
841        assert_ne!(item1, item3);
842    }
843
844    #[test]
845    fn test_min_heap_item_ord() {
846        let item1 = MinHeapItem { value: 10i32, index: 0 };
847        let item2 = MinHeapItem { value: 20i32, index: 1 };
848        let item3 = MinHeapItem { value: 30i32, index: 2 };
849
850        // Min-heap: reverse ordering (smaller values at top)
851        assert!(item3 < item2); // 30 < 20 in min-heap ordering
852        assert!(item2 < item1); // 20 < 10 in min-heap ordering
853    }
854
855    #[test]
856    fn test_min_heap_item_partial_ord() {
857        let item1 = MinHeapItem { value: 5i32, index: 0 };
858        let item2 = MinHeapItem { value: 10i32, index: 1 };
859
860        assert!(item1.partial_cmp(&item2) == Some(Ordering::Greater));
861    }
862
863    #[test]
864    fn test_max_heap_item_eq() {
865        let item1 = MaxHeapItem { value: 42i32, index: 0 };
866        let item2 = MaxHeapItem { value: 42i32, index: 1 };
867        let item3 = MaxHeapItem { value: 43i32, index: 2 };
868
869        assert_eq!(item1, item2);
870        assert_ne!(item1, item3);
871    }
872
873    #[test]
874    fn test_max_heap_item_ord() {
875        let item1 = MaxHeapItem { value: 10i32, index: 0 };
876        let item2 = MaxHeapItem { value: 20i32, index: 1 };
877        let item3 = MaxHeapItem { value: 30i32, index: 2 };
878
879        // Max-heap: normal ordering (larger values at top)
880        assert!(item3 > item2);
881        assert!(item2 > item1);
882    }
883
884    #[test]
885    fn test_max_heap_item_partial_ord() {
886        let item1 = MaxHeapItem { value: 5i32, index: 0 };
887        let item2 = MaxHeapItem { value: 10i32, index: 1 };
888
889        assert!(item1.partial_cmp(&item2) == Some(Ordering::Less));
890    }
891
892    #[test]
893    fn test_heap_item_with_floats() {
894        let item1 = MinHeapItem { value: 1.5f64, index: 0 };
895        let item2 = MinHeapItem { value: 2.5f64, index: 1 };
896
897        assert_ne!(item1, item2);
898        assert!(item2 < item1); // Min-heap: reverse ordering
899    }
900
901    #[test]
902    fn test_heap_item_eq_method_with_floats() {
903        let item1 = MaxHeapItem { value: 3.25f64, index: 0 };
904        let item2 = MaxHeapItem { value: 3.25f64, index: 1 };
905        let item3 = MaxHeapItem { value: 2.75f64, index: 2 };
906
907        assert!(item1.eq(&item2));
908        assert!(!item1.eq(&item3));
909    }
910}