1use crate::Error;
21use arrow::array::{
22 Array, ArrayRef, Float32Array, Float64Array, Int32Array, Int64Array, StringArray,
23};
24use arrow::compute::SortOptions;
25use arrow::record_batch::RecordBatch;
26use std::cmp::Ordering;
27use std::collections::BinaryHeap;
28use std::sync::Arc;
29
/// Direction in which rows are ranked by a top-K selection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SortOrder {
    /// Smallest values rank first.
    Ascending,
    /// Largest values rank first.
    Descending,
}
38
39impl From<SortOrder> for SortOptions {
40 fn from(order: SortOrder) -> Self {
41 Self { descending: matches!(order, SortOrder::Descending), nulls_first: false }
42 }
43}
44
45pub trait TopKSelection {
47 fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch>;
87}
88
89impl TopKSelection for RecordBatch {
90 fn top_k(&self, column_index: usize, k: usize, order: SortOrder) -> crate::Result<RecordBatch> {
91 if k == 0 {
93 return Err(Error::InvalidInput("k must be greater than 0".to_string()));
94 }
95
96 if column_index >= self.num_columns() {
97 return Err(Error::InvalidInput(format!(
98 "Column index {} out of bounds (batch has {} columns)",
99 column_index,
100 self.num_columns()
101 )));
102 }
103
104 if k >= self.num_rows() {
106 return sort_all_rows(self, column_index, order);
107 }
108
109 let column = self.column(column_index);
111 let indices = select_top_k_indices(column, k, order)?;
112
113 build_batch_from_indices(self, &indices)
115 }
116}
117
118fn select_top_k_indices(
123 column: &ArrayRef,
124 k: usize,
125 order: SortOrder,
126) -> crate::Result<Vec<usize>> {
127 match column.data_type() {
128 arrow::datatypes::DataType::Int32 => {
129 let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
130 Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
131 })?;
132 select_top_k_i32(array, k, order)
133 }
134 arrow::datatypes::DataType::Int64 => {
135 let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
136 Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
137 })?;
138 select_top_k_i64(array, k, order)
139 }
140 arrow::datatypes::DataType::Float32 => {
141 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
142 Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
143 })?;
144 select_top_k_f32(array, k, order)
145 }
146 arrow::datatypes::DataType::Float64 => {
147 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
148 Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
149 })?;
150 select_top_k_f64(array, k, order)
151 }
152 dt => Err(Error::InvalidInput(format!("Top-K not supported for data type: {dt:?}"))),
153 }
154}
155
/// Heap entry that compares in *reverse* order of its `value`.
///
/// `BinaryHeap` is a max-heap, so reversing the comparison puts the entry with the smallest
/// `value` at the root. Only `value` takes part in comparisons; `index` is payload.
///
/// Values that cannot be compared (`partial_cmp` returns `None`, e.g. a float NaN) are
/// treated as equal by `cmp` but as unequal by `eq`.
#[derive(Debug)]
struct MinHeapItem<V> {
    value: V,
    index: usize,
}

impl<V: PartialOrd> Ord for MinHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: this is what reverses the ordering.
        let reversed = other.value.partial_cmp(&self.value);
        reversed.unwrap_or(Ordering::Equal)
    }
}

impl<V: PartialOrd> PartialOrd for MinHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<V: PartialOrd> PartialEq for MinHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MinHeapItem<V> {}
183
/// Heap entry that compares in the natural order of its `value`.
///
/// Inside a `BinaryHeap` (a max-heap) the entry with the largest `value` sits at the root.
/// Only `value` takes part in comparisons; `index` is payload.
///
/// Values that cannot be compared (`partial_cmp` returns `None`, e.g. a float NaN) are
/// treated as equal by `cmp` but as unequal by `eq`.
#[derive(Debug)]
struct MaxHeapItem<V> {
    value: V,
    index: usize,
}

impl<V: PartialOrd> Ord for MaxHeapItem<V> {
    fn cmp(&self, other: &Self) -> Ordering {
        let natural = self.value.partial_cmp(&other.value);
        natural.unwrap_or(Ordering::Equal)
    }
}

impl<V: PartialOrd> PartialOrd for MaxHeapItem<V> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<V: PartialOrd> PartialEq for MaxHeapItem<V> {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.value.partial_cmp(&other.value), Some(Ordering::Equal))
    }
}

impl<V: PartialOrd> Eq for MaxHeapItem<V> {}
211
212#[allow(clippy::unnecessary_wraps)]
214fn select_top_k_i32(array: &Int32Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
215 match order {
216 SortOrder::Descending => {
217 let mut heap: BinaryHeap<MinHeapItem<i32>> = BinaryHeap::with_capacity(k);
219
220 for index in 0..array.len() {
221 if !array.is_null(index) {
222 let value = array.value(index);
223 let item = MinHeapItem { value, index };
224
225 if heap.len() < k {
226 heap.push(item);
227 } else if let Some(top) = heap.peek() {
228 if value > top.value {
229 heap.pop();
230 heap.push(item);
231 }
232 }
233 }
234 }
235
236 let mut result: Vec<_> = heap.into_vec();
237 result.sort_by(|a, b| b.value.cmp(&a.value));
238 Ok(result.into_iter().map(|item| item.index).collect())
239 }
240 SortOrder::Ascending => {
241 let mut heap: BinaryHeap<MaxHeapItem<i32>> = BinaryHeap::with_capacity(k);
243
244 for index in 0..array.len() {
245 if !array.is_null(index) {
246 let value = array.value(index);
247 let item = MaxHeapItem { value, index };
248
249 if heap.len() < k {
250 heap.push(item);
251 } else if let Some(top) = heap.peek() {
252 if value < top.value {
253 heap.pop();
254 heap.push(item);
255 }
256 }
257 }
258 }
259
260 let mut result: Vec<_> = heap.into_vec();
261 result.sort_by(|a, b| a.value.cmp(&b.value));
262 Ok(result.into_iter().map(|item| item.index).collect())
263 }
264 }
265}
266
267#[allow(clippy::unnecessary_wraps)]
269fn select_top_k_i64(array: &Int64Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
270 match order {
271 SortOrder::Descending => {
272 let mut heap: BinaryHeap<MinHeapItem<i64>> = BinaryHeap::with_capacity(k);
273 for index in 0..array.len() {
274 if !array.is_null(index) {
275 let value = array.value(index);
276 if heap.len() < k {
277 heap.push(MinHeapItem { value, index });
278 } else if let Some(top) = heap.peek() {
279 if value > top.value {
280 heap.pop();
281 heap.push(MinHeapItem { value, index });
282 }
283 }
284 }
285 }
286 let mut result: Vec<_> = heap.into_vec();
287 result.sort_by(|a, b| b.value.cmp(&a.value));
288 Ok(result.into_iter().map(|item| item.index).collect())
289 }
290 SortOrder::Ascending => {
291 let mut heap: BinaryHeap<MaxHeapItem<i64>> = BinaryHeap::with_capacity(k);
292 for index in 0..array.len() {
293 if !array.is_null(index) {
294 let value = array.value(index);
295 if heap.len() < k {
296 heap.push(MaxHeapItem { value, index });
297 } else if let Some(top) = heap.peek() {
298 if value < top.value {
299 heap.pop();
300 heap.push(MaxHeapItem { value, index });
301 }
302 }
303 }
304 }
305 let mut result: Vec<_> = heap.into_vec();
306 result.sort_by(|a, b| a.value.cmp(&b.value));
307 Ok(result.into_iter().map(|item| item.index).collect())
308 }
309 }
310}
311
312#[allow(clippy::unnecessary_wraps)]
314fn select_top_k_f32(array: &Float32Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
315 match order {
316 SortOrder::Descending => {
317 let mut heap: BinaryHeap<MinHeapItem<f32>> = BinaryHeap::with_capacity(k);
318 for index in 0..array.len() {
319 if !array.is_null(index) {
320 let value = array.value(index);
321 if heap.len() < k {
322 heap.push(MinHeapItem { value, index });
323 } else if let Some(top) = heap.peek() {
324 if value > top.value {
325 heap.pop();
326 heap.push(MinHeapItem { value, index });
327 }
328 }
329 }
330 }
331 let mut result: Vec<_> = heap.into_vec();
332 result.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap_or(Ordering::Equal));
333 Ok(result.into_iter().map(|item| item.index).collect())
334 }
335 SortOrder::Ascending => {
336 let mut heap: BinaryHeap<MaxHeapItem<f32>> = BinaryHeap::with_capacity(k);
337 for index in 0..array.len() {
338 if !array.is_null(index) {
339 let value = array.value(index);
340 if heap.len() < k {
341 heap.push(MaxHeapItem { value, index });
342 } else if let Some(top) = heap.peek() {
343 if value < top.value {
344 heap.pop();
345 heap.push(MaxHeapItem { value, index });
346 }
347 }
348 }
349 }
350 let mut result: Vec<_> = heap.into_vec();
351 result.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal));
352 Ok(result.into_iter().map(|item| item.index).collect())
353 }
354 }
355}
356
357#[allow(clippy::unnecessary_wraps)]
359fn select_top_k_f64(array: &Float64Array, k: usize, order: SortOrder) -> crate::Result<Vec<usize>> {
360 match order {
361 SortOrder::Descending => {
362 let mut heap: BinaryHeap<MinHeapItem<f64>> = BinaryHeap::with_capacity(k);
363 for index in 0..array.len() {
364 if !array.is_null(index) {
365 let value = array.value(index);
366 if heap.len() < k {
367 heap.push(MinHeapItem { value, index });
368 } else if let Some(top) = heap.peek() {
369 if value > top.value {
370 heap.pop();
371 heap.push(MinHeapItem { value, index });
372 }
373 }
374 }
375 }
376 let mut result: Vec<_> = heap.into_vec();
377 result.sort_by(|a, b| b.value.partial_cmp(&a.value).unwrap_or(Ordering::Equal));
378 Ok(result.into_iter().map(|item| item.index).collect())
379 }
380 SortOrder::Ascending => {
381 let mut heap: BinaryHeap<MaxHeapItem<f64>> = BinaryHeap::with_capacity(k);
382 for index in 0..array.len() {
383 if !array.is_null(index) {
384 let value = array.value(index);
385 if heap.len() < k {
386 heap.push(MaxHeapItem { value, index });
387 } else if let Some(top) = heap.peek() {
388 if value < top.value {
389 heap.pop();
390 heap.push(MaxHeapItem { value, index });
391 }
392 }
393 }
394 }
395 let mut result: Vec<_> = heap.into_vec();
396 result.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap_or(Ordering::Equal));
397 Ok(result.into_iter().map(|item| item.index).collect())
398 }
399 }
400}
401
402fn build_batch_from_indices(batch: &RecordBatch, indices: &[usize]) -> crate::Result<RecordBatch> {
404 use arrow::datatypes::DataType;
405
406 let mut new_columns: Vec<ArrayRef> = Vec::with_capacity(batch.num_columns());
407
408 for col_idx in 0..batch.num_columns() {
409 let column = batch.column(col_idx);
410
411 let new_array: ArrayRef = match column.data_type() {
412 DataType::Int32 => {
413 let array = column.as_any().downcast_ref::<Int32Array>().ok_or_else(|| {
414 Error::Other("Failed to downcast Int32 column to Int32Array".to_string())
415 })?;
416 let values: Vec<i32> = indices.iter().map(|&idx| array.value(idx)).collect();
417 Arc::new(Int32Array::from(values))
418 }
419 DataType::Int64 => {
420 let array = column.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
421 Error::Other("Failed to downcast Int64 column to Int64Array".to_string())
422 })?;
423 let values: Vec<i64> = indices.iter().map(|&idx| array.value(idx)).collect();
424 Arc::new(Int64Array::from(values))
425 }
426 DataType::Float32 => {
427 let array = column.as_any().downcast_ref::<Float32Array>().ok_or_else(|| {
428 Error::Other("Failed to downcast Float32 column to Float32Array".to_string())
429 })?;
430 let values: Vec<f32> = indices.iter().map(|&idx| array.value(idx)).collect();
431 Arc::new(Float32Array::from(values))
432 }
433 DataType::Float64 => {
434 let array = column.as_any().downcast_ref::<Float64Array>().ok_or_else(|| {
435 Error::Other("Failed to downcast Float64 column to Float64Array".to_string())
436 })?;
437 let values: Vec<f64> = indices.iter().map(|&idx| array.value(idx)).collect();
438 Arc::new(Float64Array::from(values))
439 }
440 DataType::Utf8 => {
441 let array = column.as_any().downcast_ref::<StringArray>().ok_or_else(|| {
442 Error::Other("Failed to downcast Utf8 column to StringArray".to_string())
443 })?;
444 let values: Vec<&str> = indices.iter().map(|&idx| array.value(idx)).collect();
445 Arc::new(StringArray::from(values))
446 }
447 dt => {
448 return Err(Error::InvalidInput(format!(
449 "Top-K not implemented for column data type: {dt:?}"
450 )));
451 }
452 };
453
454 new_columns.push(new_array);
455 }
456
457 RecordBatch::try_new(batch.schema(), new_columns)
458 .map_err(|e| Error::StorageError(format!("Failed to create result batch: {e}")))
459}
460
461fn sort_all_rows(
463 batch: &RecordBatch,
464 column_index: usize,
465 order: SortOrder,
466) -> crate::Result<RecordBatch> {
467 use arrow::compute::sort_to_indices;
468
469 let sort_options = SortOptions::from(order);
470 let indices = sort_to_indices(batch.column(column_index).as_ref(), Some(sort_options), None)
471 .map_err(|e| Error::StorageError(format!("Failed to sort: {e}")))?;
472
473 let indices_array =
475 indices.as_any().downcast_ref::<arrow::array::UInt32Array>().ok_or_else(|| {
476 Error::Other(
477 "Failed to downcast sort indices to UInt32Array (expected from sort_to_indices)"
478 .to_string(),
479 )
480 })?;
481 let indices_vec: Vec<usize> =
482 (0..indices_array.len()).map(|i| indices_array.value(i) as usize).collect();
483
484 build_batch_from_indices(batch, &indices_vec)
485}
486
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::float_cmp,
    clippy::redundant_closure
)]
mod tests {
    use super::*;
    use arrow::datatypes::{DataType, Field, Schema};
    use std::sync::Arc;

    /// Two-column fixture: `id` (the row position as `i32`) and `score` (the given values).
    fn create_test_batch(values: Vec<f64>) -> RecordBatch {
        let ids = Int32Array::from((0..values.len() as i32).collect::<Vec<i32>>());
        let scores = Float64Array::from(values);
        let fields = vec![
            Field::new("id", DataType::Int32, false),
            Field::new("score", DataType::Float64, false),
        ];

        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![Arc::new(ids), Arc::new(scores)])
            .unwrap()
    }

    /// Single-column fixture whose only, non-nullable field is named `value`.
    fn single_column_batch(data_type: DataType, column: ArrayRef) -> RecordBatch {
        let schema = Schema::new(vec![Field::new("value", data_type, false)]);
        RecordBatch::try_new(Arc::new(schema), vec![column]).unwrap()
    }

    /// Column at `position`, downcast to the concrete array type `T`.
    fn typed_column<T: 'static>(batch: &RecordBatch, position: usize) -> &T {
        batch.column(position).as_any().downcast_ref::<T>().unwrap()
    }

    /// Asserts that the leading rows of the `Float64` column at `position` equal `expected`.
    fn assert_f64_prefix(batch: &RecordBatch, position: usize, expected: &[f64]) {
        let scores = typed_column::<Float64Array>(batch, position);
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(scores.value(row), *want);
        }
    }

    fn int32_fixture() -> RecordBatch {
        single_column_batch(DataType::Int32, Arc::new(Int32Array::from(vec![5, 2, 8, 1, 9, 3])))
    }

    fn int64_fixture() -> RecordBatch {
        single_column_batch(
            DataType::Int64,
            Arc::new(Int64Array::from(vec![100i64, 200, 50, 300, 150])),
        )
    }

    fn float32_fixture() -> RecordBatch {
        single_column_batch(
            DataType::Float32,
            Arc::new(Float32Array::from(vec![1.5f32, 2.7, 0.3, 4.2, 3.1])),
        )
    }

    #[test]
    fn test_top_k_descending_basic() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
        let top = batch.top_k(1, 3, SortOrder::Descending).unwrap();

        assert_eq!(top.num_rows(), 3);
        assert_f64_prefix(&top, 1, &[9.0, 5.0, 3.0]);
    }

    #[test]
    fn test_top_k_ascending_basic() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0, 9.0, 2.0]);
        let top = batch.top_k(1, 3, SortOrder::Ascending).unwrap();

        assert_eq!(top.num_rows(), 3);
        assert_f64_prefix(&top, 1, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_top_k_k_equals_length() {
        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
        let top = batch.top_k(1, 3, SortOrder::Descending).unwrap();

        assert_eq!(top.num_rows(), 3);
        assert_f64_prefix(&top, 1, &[3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_top_k_k_greater_than_length() {
        let batch = create_test_batch(vec![3.0, 1.0, 2.0]);
        let top = batch.top_k(1, 10, SortOrder::Descending).unwrap();

        assert_eq!(top.num_rows(), 3);
        assert_f64_prefix(&top, 1, &[3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_top_k_k_zero_fails() {
        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
        let outcome = batch.top_k(1, 0, SortOrder::Descending);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("must be greater than 0"));
    }

    #[test]
    fn test_top_k_invalid_column_index() {
        let batch = create_test_batch(vec![1.0, 2.0, 3.0]);
        let outcome = batch.top_k(99, 2, SortOrder::Descending);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("out of bounds"));
    }

    #[test]
    fn test_top_k_preserves_row_integrity() {
        let batch = create_test_batch(vec![1.0, 5.0, 3.0]);
        let top = batch.top_k(1, 2, SortOrder::Descending).unwrap();

        let ids = typed_column::<Int32Array>(&top, 0);
        let scores = typed_column::<Float64Array>(&top, 1);

        // Each score must still travel with the id of the row it came from.
        assert_eq!(scores.value(0), 5.0);
        assert_eq!(ids.value(0), 1);

        assert_eq!(scores.value(1), 3.0);
        assert_eq!(ids.value(1), 2);
    }

    #[test]
    fn test_top_k_large_dataset() {
        let batch = create_test_batch((0..1_000_000).map(|i| f64::from(i)).collect());

        let started = std::time::Instant::now();
        let top = batch.top_k(1, 10, SortOrder::Descending).unwrap();
        let elapsed = started.elapsed();

        assert_eq!(top.num_rows(), 10);

        let scores = typed_column::<Float64Array>(&top, 1);
        for rank in 0..10 {
            assert_eq!(scores.value(rank), 999_999.0 - rank as f64);
        }

        assert!(
            elapsed.as_millis() < 500,
            "Top-K took {}ms (expected <500ms)",
            elapsed.as_millis()
        );
    }

    mod property_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_top_k_returns_k_rows(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let available = values.len();
                let batch = create_test_batch(values);
                let top = batch.top_k(1, k, SortOrder::Descending).unwrap();

                prop_assert_eq!(top.num_rows(), k.min(available));
            }

            #[test]
            fn prop_top_k_descending_is_sorted(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let batch = create_test_batch(values);
                let top = batch.top_k(1, k, SortOrder::Descending).unwrap();
                let scores = typed_column::<Float64Array>(&top, 1);

                for row in 1..scores.len() {
                    let (earlier, later) = (scores.value(row - 1), scores.value(row));
                    prop_assert!(
                        earlier >= later,
                        "Not in descending order: {} < {}",
                        earlier,
                        later
                    );
                }
            }

            #[test]
            fn prop_top_k_ascending_is_sorted(
                values in prop::collection::vec(0.0f64..1000.0, 10..1000),
                k in 1usize..100
            ) {
                let batch = create_test_batch(values);
                let top = batch.top_k(1, k, SortOrder::Ascending).unwrap();
                let scores = typed_column::<Float64Array>(&top, 1);

                for row in 1..scores.len() {
                    let (earlier, later) = (scores.value(row - 1), scores.value(row));
                    prop_assert!(
                        earlier <= later,
                        "Not in ascending order: {} > {}",
                        earlier,
                        later
                    );
                }
            }
        }
    }

    #[test]
    fn test_top_k_int32() {
        let top = int32_fixture().top_k(0, 3, SortOrder::Descending).unwrap();
        assert_eq!(top.num_rows(), 3);

        let col = typed_column::<Int32Array>(&top, 0);
        for (row, want) in [9, 8, 5].iter().enumerate() {
            assert_eq!(col.value(row), *want);
        }
    }

    #[test]
    fn test_top_k_int32_ascending() {
        let top = int32_fixture().top_k(0, 3, SortOrder::Ascending).unwrap();
        assert_eq!(top.num_rows(), 3);

        let col = typed_column::<Int32Array>(&top, 0);
        for (row, want) in [1, 2, 3].iter().enumerate() {
            assert_eq!(col.value(row), *want);
        }
    }

    #[test]
    fn test_top_k_int64() {
        let top = int64_fixture().top_k(0, 2, SortOrder::Ascending).unwrap();
        assert_eq!(top.num_rows(), 2);

        let col = typed_column::<Int64Array>(&top, 0);
        for (row, want) in [50i64, 100].iter().enumerate() {
            assert_eq!(col.value(row), *want);
        }
    }

    #[test]
    fn test_top_k_int64_descending() {
        let top = int64_fixture().top_k(0, 2, SortOrder::Descending).unwrap();
        assert_eq!(top.num_rows(), 2);

        let col = typed_column::<Int64Array>(&top, 0);
        for (row, want) in [300i64, 200].iter().enumerate() {
            assert_eq!(col.value(row), *want);
        }
    }

    #[test]
    fn test_top_k_float32() {
        let top = float32_fixture().top_k(0, 3, SortOrder::Descending).unwrap();
        assert_eq!(top.num_rows(), 3);

        let col = typed_column::<Float32Array>(&top, 0);
        for (row, want) in [4.2f32, 3.1, 2.7].iter().enumerate() {
            assert!((col.value(row) - *want).abs() < 0.001);
        }
    }

    #[test]
    fn test_top_k_float32_ascending() {
        let top = float32_fixture().top_k(0, 3, SortOrder::Ascending).unwrap();
        assert_eq!(top.num_rows(), 3);

        let col = typed_column::<Float32Array>(&top, 0);
        for (row, want) in [0.3f32, 1.5, 2.7].iter().enumerate() {
            assert!((col.value(row) - *want).abs() < 0.001);
        }
    }

    #[test]
    fn test_top_k_unsupported_type() {
        let batch = single_column_batch(
            DataType::Utf8,
            Arc::new(StringArray::from(vec!["a", "b", "c"])),
        );

        let outcome = batch.top_k(0, 2, SortOrder::Descending);
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("Top-K not supported for data type"));
    }

    #[test]
    fn test_min_heap_item_eq() {
        let first = MinHeapItem { value: 42i32, index: 0 };
        let same_value = MinHeapItem { value: 42i32, index: 1 };
        let other_value = MinHeapItem { value: 43i32, index: 2 };

        // Equality looks at `value` only, never at `index`.
        assert_eq!(first, same_value);
        assert_ne!(first, other_value);
    }

    #[test]
    fn test_min_heap_item_ord() {
        let ten = MinHeapItem { value: 10i32, index: 0 };
        let twenty = MinHeapItem { value: 20i32, index: 1 };
        let thirty = MinHeapItem { value: 30i32, index: 2 };

        // Reversed ordering: the larger value compares as the smaller item.
        assert!(thirty < twenty);
        assert!(twenty < ten);
    }

    #[test]
    fn test_min_heap_item_partial_ord() {
        let five = MinHeapItem { value: 5i32, index: 0 };
        let ten = MinHeapItem { value: 10i32, index: 1 };

        assert_eq!(five.partial_cmp(&ten), Some(Ordering::Greater));
    }

    #[test]
    fn test_max_heap_item_eq() {
        let first = MaxHeapItem { value: 42i32, index: 0 };
        let same_value = MaxHeapItem { value: 42i32, index: 1 };
        let other_value = MaxHeapItem { value: 43i32, index: 2 };

        // Equality looks at `value` only, never at `index`.
        assert_eq!(first, same_value);
        assert_ne!(first, other_value);
    }

    #[test]
    fn test_max_heap_item_ord() {
        let ten = MaxHeapItem { value: 10i32, index: 0 };
        let twenty = MaxHeapItem { value: 20i32, index: 1 };
        let thirty = MaxHeapItem { value: 30i32, index: 2 };

        assert!(thirty > twenty);
        assert!(twenty > ten);
    }

    #[test]
    fn test_max_heap_item_partial_ord() {
        let five = MaxHeapItem { value: 5i32, index: 0 };
        let ten = MaxHeapItem { value: 10i32, index: 1 };

        assert_eq!(five.partial_cmp(&ten), Some(Ordering::Less));
    }

    #[test]
    fn test_heap_item_with_floats() {
        let smaller = MinHeapItem { value: 1.5f64, index: 0 };
        let larger = MinHeapItem { value: 2.5f64, index: 1 };

        assert_ne!(smaller, larger);
        assert!(larger < smaller);
    }

    #[test]
    fn test_heap_item_eq_method_with_floats() {
        let first = MaxHeapItem { value: 3.25f64, index: 0 };
        let same_value = MaxHeapItem { value: 3.25f64, index: 1 };
        let other_value = MaxHeapItem { value: 2.75f64, index: 2 };

        assert!(first.eq(&same_value));
        assert!(!first.eq(&other_value));
    }
}