vortex_array/stream/
take_rows.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures_util::{Stream, ready};
5use pin_project::pin_project;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::{VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::compute::{SearchSortedSide, search_sorted_usize, slice, sub_scalar, take};
11use crate::stats::Stat;
12use crate::stream::ArrayStream;
13use crate::variants::PrimitiveArrayTrait;
14use crate::{Array, ArrayRef, ToCanonical};
15
16#[pin_project]
17pub struct TakeRows<R: ArrayStream> {
18 #[pin]
19 reader: R,
20 indices: ArrayRef,
21 row_offset: usize,
22}
23
24impl<R: ArrayStream> TakeRows<R> {
25 pub fn try_new(reader: R, indices: ArrayRef) -> VortexResult<Self> {
26 if !indices.is_empty() {
27 if !indices.statistics().compute_is_sorted().unwrap_or(false) {
28 vortex_bail!("Indices must be sorted to take from IPC stream")
29 }
30
31 if indices
32 .statistics()
33 .compute_null_count()
34 .map(|nc| nc > 0)
35 .unwrap_or(true)
36 {
37 vortex_bail!("Indices must not contain nulls")
38 }
39
40 if !indices.dtype().is_int() {
41 vortex_bail!("Indices must be integers")
42 }
43
44 if indices.dtype().is_signed_int()
45 && indices
46 .statistics()
47 .compute_as::<i64>(Stat::Min)
48 .map(|min| min < 0)
49 .unwrap_or(true)
50 {
51 vortex_bail!("Indices must be positive")
52 }
53 }
54
55 Ok(Self {
56 reader,
57 indices,
58 row_offset: 0,
59 })
60 }
61}
62
63impl<R: ArrayStream> Stream for TakeRows<R> {
64 type Item = VortexResult<ArrayRef>;
65
66 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67 let mut this = self.project();
68
69 if this.indices.is_empty() {
70 return Poll::Ready(None);
71 }
72
73 while let Some(batch) = ready!(this.reader.as_mut().poll_next(cx)?) {
74 let curr_offset = *this.row_offset;
75 let left =
76 search_sorted_usize(this.indices, curr_offset, SearchSortedSide::Left)?.to_index();
77 let right = search_sorted_usize(
78 this.indices,
79 curr_offset + batch.len(),
80 SearchSortedSide::Left,
81 )?
82 .to_index();
83
84 *this.row_offset += batch.len();
85
86 if left == right {
87 continue;
88 }
89
90 let indices_for_batch = slice(this.indices, left, right)?.to_primitive()?;
93 let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| {
94 sub_scalar(&indices_for_batch.to_array(), Scalar::from(curr_offset as $T))?
95 });
96 return Poll::Ready(take(&batch, &shifted_arr).map(Some).transpose());
97 }
98
99 Poll::Ready(None)
100 }
101}