vortex_array/stream/
take_rows.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures_util::{Stream, ready};
5use pin_project::pin_project;
6use vortex_dtype::match_each_integer_ptype;
7use vortex_error::{VortexResult, vortex_bail};
8use vortex_scalar::Scalar;
9
10use crate::compute::{SearchSortedSide, search_sorted_usize, slice, sub_scalar, take};
11use crate::stats::Stat;
12use crate::stream::ArrayStream;
13use crate::variants::PrimitiveArrayTrait;
14use crate::{Array, ArrayRef, ToCanonical};
15
16#[pin_project]
17pub struct TakeRows<R: ArrayStream> {
18    #[pin]
19    reader: R,
20    indices: ArrayRef,
21    row_offset: usize,
22}
23
24impl<R: ArrayStream> TakeRows<R> {
25    pub fn try_new(reader: R, indices: ArrayRef) -> VortexResult<Self> {
26        if !indices.is_empty() {
27            if !indices.statistics().compute_is_sorted().unwrap_or(false) {
28                vortex_bail!("Indices must be sorted to take from IPC stream")
29            }
30
31            if indices
32                .statistics()
33                .compute_null_count()
34                .map(|nc| nc > 0)
35                .unwrap_or(true)
36            {
37                vortex_bail!("Indices must not contain nulls")
38            }
39
40            if !indices.dtype().is_int() {
41                vortex_bail!("Indices must be integers")
42            }
43
44            if indices.dtype().is_signed_int()
45                && indices
46                    .statistics()
47                    .compute_as::<i64>(Stat::Min)
48                    .map(|min| min < 0)
49                    .unwrap_or(true)
50            {
51                vortex_bail!("Indices must be positive")
52            }
53        }
54
55        Ok(Self {
56            reader,
57            indices,
58            row_offset: 0,
59        })
60    }
61}
62
63impl<R: ArrayStream> Stream for TakeRows<R> {
64    type Item = VortexResult<ArrayRef>;
65
66    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
67        let mut this = self.project();
68
69        if this.indices.is_empty() {
70            return Poll::Ready(None);
71        }
72
73        while let Some(batch) = ready!(this.reader.as_mut().poll_next(cx)?) {
74            let curr_offset = *this.row_offset;
75            let left =
76                search_sorted_usize(this.indices, curr_offset, SearchSortedSide::Left)?.to_index();
77            let right = search_sorted_usize(
78                this.indices,
79                curr_offset + batch.len(),
80                SearchSortedSide::Left,
81            )?
82            .to_index();
83
84            *this.row_offset += batch.len();
85
86            if left == right {
87                continue;
88            }
89
90            // TODO(ngates): this is probably too heavy to run on the event loop. We should spawn
91            //  onto a worker pool.
92            let indices_for_batch = slice(this.indices, left, right)?.to_primitive()?;
93            let shifted_arr = match_each_integer_ptype!(indices_for_batch.ptype(), |$T| {
94                sub_scalar(&indices_for_batch.to_array(), Scalar::from(curr_offset as $T))?
95            });
96            return Poll::Ready(take(&batch, &shifted_arr).map(Some).transpose());
97        }
98
99        Poll::Ready(None)
100    }
101}