vortex_layout/layouts/row_idx/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::{Display, Formatter};
8use std::ops::{BitAnd, Range};
9use std::sync::Arc;
10
11use Nullability::NonNullable;
12pub use expr::*;
13use futures::FutureExt;
14use futures::future::BoxFuture;
15use vortex_array::compute::filter;
16use vortex_array::stats::Precision;
17use vortex_array::{ArrayRef, IntoArray, MaskFuture};
18use vortex_dtype::{DType, FieldMask, FieldName, Nullability, PType};
19use vortex_error::{VortexExpect, VortexResult};
20use vortex_expr::transform::{PartitionedExpr, partition, replace};
21use vortex_expr::{ExactExpr, ExprRef, Scope, is_root, root};
22use vortex_mask::Mask;
23use vortex_scalar::PValue;
24use vortex_sequence::SequenceArray;
25use vortex_utils::aliases::dash_map::DashMap;
26
27use crate::layouts::partitioned::PartitionedExprEval;
28use crate::{ArrayFuture, LayoutReader};
29
/// A [`LayoutReader`] that wraps a child reader and additionally lets expressions reference
/// the row index through `row_idx()`.
///
/// Each expression is split into the parts that only need the row index, which are evaluated
/// here against a generated index sequence, and the parts that are delegated to `child`.
pub struct RowIdxLayoutReader {
    /// Name reported by this reader; copied from the child when the reader is constructed.
    name: Arc<str>,
    /// Offset added to every row index produced by this reader.
    row_offset: u64,
    /// Reader that evaluates every part of an expression that does not need the row index.
    child: Arc<dyn LayoutReader>,

    /// Memoized partitioning of each expression seen so far, keyed by `ExactExpr`.
    partition_cache: DashMap<ExactExpr, Partitioning>,
}
37
38impl RowIdxLayoutReader {
39    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>) -> Self {
40        Self {
41            name: child.name().clone(),
42            row_offset,
43            child,
44            partition_cache: DashMap::with_hasher(Default::default()),
45        }
46    }
47
48    fn partition_expr(&self, expr: &ExprRef) -> Partitioning {
49        self.partition_cache
50            .entry(ExactExpr(expr.clone()))
51            .or_insert_with(|| {
52                // Partition the expression into row idx and child expressions.
53                let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
54                    if expr.is::<RowIdxVTable>() {
55                        vec![Partition::RowIdx]
56                    } else if is_root(expr) {
57                        vec![Partition::Child]
58                    } else {
59                        vec![]
60                    }
61                })
62                .vortex_expect("We should not fail to partition expression over struct fields");
63
64                // If there's only a single partition, we can directly return the expression.
65                if partitioned.partitions.len() == 1 {
66                    return match &partitioned.partition_annotations[0] {
67                        Partition::RowIdx => {
68                            Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
69                        }
70                        Partition::Child => Partitioning::Child(expr.clone()),
71                    };
72                }
73
74                // Replace the row_idx expression with the root expression in the row_idx partition.
75                partitioned.partitions = partitioned
76                    .partitions
77                    .into_iter()
78                    .map(|p| replace(p, &row_idx(), root()))
79                    .collect();
80
81                Partitioning::Partitioned(Arc::new(partitioned))
82            })
83            .clone()
84    }
85}
86
/// How an expression is split between the row index and the child reader.
#[derive(Clone)]
enum Partitioning {
    /// An expression that only references the row index (e.g., `row_idx == 5`), rewritten so
    /// that the row index is the expression's root.
    RowIdx(ExprRef),
    /// An expression that does not reference the row index; it is handed to the child as-is.
    Child(ExprRef),
    /// An expression that references both the row index and the child
    /// (e.g., `row_idx < child.some_field`).
    Partitioned(Arc<PartitionedExpr<Partition>>),
}
96
/// Annotation attached to sub-expressions during partitioning: which side evaluates them.
///
/// The annotation is also turned into the partition's field name via [`Partition::name`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Partition {
    /// Sub-expressions that reference the row index.
    RowIdx,
    /// Sub-expressions that reference the child reader's root scope.
    Child,
}
102
103impl Partition {
104    pub fn name(&self) -> &str {
105        match self {
106            Partition::RowIdx => "row_idx",
107            Partition::Child => "child",
108        }
109    }
110}
111
112impl From<Partition> for FieldName {
113    fn from(value: Partition) -> Self {
114        FieldName::from(value.name())
115    }
116}
117
118impl Display for Partition {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        write!(f, "{}", self.name())
121    }
122}
123
124impl LayoutReader for RowIdxLayoutReader {
125    fn name(&self) -> &Arc<str> {
126        &self.name
127    }
128
129    fn dtype(&self) -> &DType {
130        self.child.dtype()
131    }
132
133    fn row_count(&self) -> Precision<u64> {
134        self.child.row_count()
135    }
136
137    fn register_splits(
138        &self,
139        field_mask: &[FieldMask],
140        row_offset: u64,
141        splits: &mut BTreeSet<u64>,
142    ) -> VortexResult<()> {
143        self.child.register_splits(field_mask, row_offset, splits)
144    }
145
146    fn pruning_evaluation(
147        &self,
148        row_range: &Range<u64>,
149        expr: &ExprRef,
150        mask: Mask,
151    ) -> VortexResult<MaskFuture> {
152        Ok(match &self.partition_expr(expr) {
153            Partitioning::RowIdx(expr) => {
154                row_idx_mask_future(self.row_offset, row_range, expr, MaskFuture::ready(mask))
155            }
156            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr, mask)?,
157            Partitioning::Partitioned(..) => MaskFuture::ready(mask),
158        })
159    }
160
161    fn filter_evaluation(
162        &self,
163        row_range: &Range<u64>,
164        expr: &ExprRef,
165        mask: MaskFuture,
166    ) -> VortexResult<MaskFuture> {
167        match &self.partition_expr(expr) {
168            // Since this is run during pruning, we skip re-evaluating the row index expression
169            // during the filter evaluation.
170            Partitioning::RowIdx(_) => Ok(mask),
171            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr, mask),
172            Partitioning::Partitioned(p) => p.clone().into_mask_future(
173                mask,
174                |annotation, expr, mask| match annotation {
175                    Partition::RowIdx => {
176                        Ok(row_idx_mask_future(self.row_offset, row_range, expr, mask))
177                    }
178                    Partition::Child => self.child.filter_evaluation(row_range, expr, mask),
179                },
180                |annotation, expr, mask| match annotation {
181                    Partition::RowIdx => {
182                        Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
183                    }
184                    Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
185                },
186            ),
187        }
188    }
189
190    fn projection_evaluation(
191        &self,
192        row_range: &Range<u64>,
193        expr: &ExprRef,
194        mask: MaskFuture,
195    ) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
196        match &self.partition_expr(expr) {
197            Partitioning::RowIdx(expr) => {
198                Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
199            }
200            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr, mask),
201            Partitioning::Partitioned(p) => {
202                p.clone()
203                    .into_array_future(mask, |annotation, expr, mask| match annotation {
204                        Partition::RowIdx => {
205                            Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
206                        }
207                        Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
208                    })
209            }
210        }
211    }
212}
213
214// Returns a SequenceArray representing the row indices for the given row range,
215fn idx_array(row_offset: u64, row_range: &Range<u64>) -> SequenceArray {
216    SequenceArray::new(
217        PValue::U64(row_offset + row_range.start),
218        PValue::U64(1),
219        PType::U64,
220        NonNullable,
221        usize::try_from(row_range.end - row_range.start)
222            .vortex_expect("Row range length must fit in usize"),
223    )
224    .vortex_expect("Failed to create row index array")
225}
226
227fn row_idx_mask_future(
228    row_offset: u64,
229    row_range: &Range<u64>,
230    expr: &ExprRef,
231    mask: MaskFuture,
232) -> MaskFuture {
233    let row_range = row_range.clone();
234    let expr = expr.clone();
235    MaskFuture::new(mask.len(), async move {
236        let array = idx_array(row_offset, &row_range).into_array();
237        let result_mask = expr
238            .evaluate(&Scope::new(array))?
239            .try_to_mask_fill_null_false()?;
240        Ok(result_mask.bitand(&mask.await?))
241    })
242}
243
244fn row_idx_array_future(
245    row_offset: u64,
246    row_range: &Range<u64>,
247    expr: &ExprRef,
248    mask: MaskFuture,
249) -> ArrayFuture {
250    let row_range = row_range.clone();
251    let expr = expr.clone();
252    async move {
253        let array = idx_array(row_offset, &row_range).into_array();
254        let array = filter(&array, &mask.await?)?;
255        expr.evaluate(&Scope::new(array))
256    }
257    .boxed()
258}
259
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use itertools::Itertools;
    use vortex_array::{ArrayContext, IntoArray as _, MaskFuture, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_expr::{eq, gt, lit, or, root};
    use vortex_io::runtime::single::block_on;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::{RowIdxLayoutReader, row_idx};
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};
    use crate::{LayoutReader, LayoutStrategy};

    /// A predicate over child values only is evaluated entirely by the child reader.
    #[test]
    fn flat_expr_no_row_id() {
        block_on(|handle| async {
            // Write the values 1..=5 into a single flat layout.
            let segments = Arc::new(TestSegments::default());
            let (seq_ptr, seq_eof) = SequenceId::root().split();
            let values = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    values.to_array_stream().sequenced(seq_ptr),
                    seq_eof,
                    handle,
                )
                .await
                .unwrap();

            let rows = layout.row_count();
            let reader =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap());
            let predicate = eq(root(), lit(3i32));
            let projected = reader
                .projection_evaluation(
                    &(0..rows),
                    &predicate,
                    MaskFuture::new_true(rows.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            // Only the value 3, at index 2, matches.
            assert_eq!(
                &BooleanBuffer::from_iter([false, false, true, false, false]),
                projected.boolean_buffer()
            );
        })
    }

    /// A predicate over the row index only is evaluated without consulting the child values.
    #[test]
    fn flat_expr_row_id() {
        block_on(|handle| async {
            // Write the values 1..=5 into a single flat layout.
            let segments = Arc::new(TestSegments::default());
            let (seq_ptr, seq_eof) = SequenceId::root().split();
            let values = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    values.to_array_stream().sequenced(seq_ptr),
                    seq_eof,
                    handle,
                )
                .await
                .unwrap();

            let rows = layout.row_count();
            let reader =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap());
            let predicate = gt(row_idx(), lit(3u64));
            let projected = reader
                .projection_evaluation(
                    &(0..rows),
                    &predicate,
                    MaskFuture::new_true(rows.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            // With a zero row offset, only row index 4 is greater than 3.
            assert_eq!(
                &BooleanBuffer::from_iter([false, false, false, false, true]),
                projected.boolean_buffer()
            );
        })
    }

    /// A predicate mixing row-index and child references is split and recombined.
    #[test]
    fn flat_expr_or() {
        block_on(|handle| async {
            // Write the values 1..=5 into a single flat layout.
            let segments = Arc::new(TestSegments::default());
            let (seq_ptr, seq_eof) = SequenceId::root().split();
            let values = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    values.to_array_stream().sequenced(seq_ptr),
                    seq_eof,
                    handle,
                )
                .await
                .unwrap();

            let predicate = or(
                eq(root(), lit(3i32)),
                or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
            );

            let rows = layout.row_count();
            let reader =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap());
            let projected = reader
                .projection_evaluation(
                    &(0..rows),
                    &predicate,
                    MaskFuture::new_true(rows.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            // Index 0 (value 1), index 2 (value 3) and index 4 (row_idx > 3) match.
            assert_eq!(
                vec![true, false, true, false, true],
                projected.boolean_buffer().iter().collect_vec()
            );
        })
    }
}