vortex_layout/layouts/row_idx/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::{Display, Formatter};
8use std::ops::{BitAnd, Range};
9use std::sync::Arc;
10
11use Nullability::NonNullable;
12pub use expr::*;
13use futures::FutureExt;
14use futures::future::BoxFuture;
15use vortex_array::compute::filter;
16use vortex_array::expr::transform::{PartitionedExpr, partition, replace};
17use vortex_array::expr::{ExactExpr, Expression, is_root, root};
18use vortex_array::{ArrayRef, IntoArray, MaskFuture};
19use vortex_dtype::{DType, FieldMask, FieldName, Nullability, PType};
20use vortex_error::{VortexExpect, VortexResult};
21use vortex_mask::Mask;
22use vortex_scalar::PValue;
23use vortex_sequence::SequenceArray;
24use vortex_utils::aliases::dash_map::DashMap;
25
26use crate::layouts::partitioned::PartitionedExprEval;
27use crate::{ArrayFuture, LayoutReader};
28
/// A `LayoutReader` that wraps a child reader and can additionally evaluate expressions that
/// reference the row index (`row_idx()`).
///
/// Each expression is split into a part evaluated over a generated sequence of row indices and
/// a part delegated to the child. The split is cached per expression.
pub struct RowIdxLayoutReader {
    /// Name reported by this reader; copied from the child in `new`.
    name: Arc<str>,
    /// Value added to every generated row index: the first row of a requested `row_range`
    /// gets index `row_offset + row_range.start`.
    row_offset: u64,
    /// The wrapped reader that serves dtype, row count, splits and all non-row-index work.
    child: Arc<dyn LayoutReader>,

    /// Cache of how each expression (keyed by `ExactExpr`) is partitioned.
    partition_cache: DashMap<ExactExpr, Partitioning>,
}
36
37impl RowIdxLayoutReader {
38    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>) -> Self {
39        Self {
40            name: child.name().clone(),
41            row_offset,
42            child,
43            partition_cache: DashMap::with_hasher(Default::default()),
44        }
45    }
46
47    fn partition_expr(&self, expr: &Expression) -> Partitioning {
48        self.partition_cache
49            .entry(ExactExpr(expr.clone()))
50            .or_insert_with(|| {
51                // Partition the expression into row idx and child expressions.
52                let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
53                    if expr.is::<RowIdx>() {
54                        vec![Partition::RowIdx]
55                    } else if is_root(expr) {
56                        vec![Partition::Child]
57                    } else {
58                        vec![]
59                    }
60                })
61                .vortex_expect("We should not fail to partition expression over struct fields");
62
63                // If there's only a single partition, we can directly return the expression.
64                if partitioned.partitions.len() == 1 {
65                    return match &partitioned.partition_annotations[0] {
66                        Partition::RowIdx => {
67                            Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
68                        }
69                        Partition::Child => Partitioning::Child(expr.clone()),
70                    };
71                }
72
73                // Replace the row_idx expression with the root expression in the row_idx partition.
74                partitioned.partitions = partitioned
75                    .partitions
76                    .into_iter()
77                    .map(|p| replace(p, &row_idx(), root()))
78                    .collect();
79
80                Partitioning::Partitioned(Arc::new(partitioned))
81            })
82            .clone()
83    }
84}
85
/// How an expression is split between the row index and the child reader.
#[derive(Clone)]
enum Partitioning {
    /// An expression that only references the row index (e.g., `row_idx == 5`), with
    /// `row_idx()` already replaced by `root()` so it evaluates directly over a row-index array.
    RowIdx(Expression),
    /// An expression that does not reference the row index; it is handed to the child unchanged.
    Child(Expression),
    /// An expression that references both the row index and the child
    /// (e.g., `row_idx < child.some_field`), split into per-partition sub-expressions.
    Partitioned(Arc<PartitionedExpr<Partition>>),
}
95
/// Identifies which side of a split expression a sub-expression is evaluated on.
///
/// `RowIdx` parts are evaluated over a generated array of row indices; `Child` parts are
/// delegated to the wrapped child reader. The type is a plain tag, so it is `Copy` and `Debug`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Partition {
    /// Sub-expressions built around `row_idx()`.
    RowIdx,
    /// Sub-expressions built around the root (the child's data).
    Child,
}
101
102impl Partition {
103    pub fn name(&self) -> &str {
104        match self {
105            Partition::RowIdx => "row_idx",
106            Partition::Child => "child",
107        }
108    }
109}
110
111impl From<Partition> for FieldName {
112    fn from(value: Partition) -> Self {
113        FieldName::from(value.name())
114    }
115}
116
117impl Display for Partition {
118    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
119        write!(f, "{}", self.name())
120    }
121}
122
123impl LayoutReader for RowIdxLayoutReader {
124    fn name(&self) -> &Arc<str> {
125        &self.name
126    }
127
128    fn dtype(&self) -> &DType {
129        self.child.dtype()
130    }
131
132    fn row_count(&self) -> u64 {
133        self.child.row_count()
134    }
135
136    fn register_splits(
137        &self,
138        field_mask: &[FieldMask],
139        row_range: &Range<u64>,
140        splits: &mut BTreeSet<u64>,
141    ) -> VortexResult<()> {
142        self.child.register_splits(field_mask, row_range, splits)
143    }
144
145    fn pruning_evaluation(
146        &self,
147        row_range: &Range<u64>,
148        expr: &Expression,
149        mask: Mask,
150    ) -> VortexResult<MaskFuture> {
151        Ok(match &self.partition_expr(expr) {
152            Partitioning::RowIdx(expr) => {
153                row_idx_mask_future(self.row_offset, row_range, expr, MaskFuture::ready(mask))
154            }
155            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr, mask)?,
156            Partitioning::Partitioned(..) => MaskFuture::ready(mask),
157        })
158    }
159
160    fn filter_evaluation(
161        &self,
162        row_range: &Range<u64>,
163        expr: &Expression,
164        mask: MaskFuture,
165    ) -> VortexResult<MaskFuture> {
166        match &self.partition_expr(expr) {
167            // Since this is run during pruning, we skip re-evaluating the row index expression
168            // during the filter evaluation.
169            Partitioning::RowIdx(_) => Ok(mask),
170            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr, mask),
171            Partitioning::Partitioned(p) => p.clone().into_mask_future(
172                mask,
173                |annotation, expr, mask| match annotation {
174                    Partition::RowIdx => {
175                        Ok(row_idx_mask_future(self.row_offset, row_range, expr, mask))
176                    }
177                    Partition::Child => self.child.filter_evaluation(row_range, expr, mask),
178                },
179                |annotation, expr, mask| match annotation {
180                    Partition::RowIdx => {
181                        Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
182                    }
183                    Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
184                },
185            ),
186        }
187    }
188
189    fn projection_evaluation(
190        &self,
191        row_range: &Range<u64>,
192        expr: &Expression,
193        mask: MaskFuture,
194    ) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
195        match &self.partition_expr(expr) {
196            Partitioning::RowIdx(expr) => {
197                Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
198            }
199            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr, mask),
200            Partitioning::Partitioned(p) => {
201                p.clone()
202                    .into_array_future(mask, |annotation, expr, mask| match annotation {
203                        Partition::RowIdx => {
204                            Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
205                        }
206                        Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
207                    })
208            }
209        }
210    }
211}
212
213// Returns a SequenceArray representing the row indices for the given row range,
214fn idx_array(row_offset: u64, row_range: &Range<u64>) -> SequenceArray {
215    SequenceArray::new(
216        PValue::U64(row_offset + row_range.start),
217        PValue::U64(1),
218        PType::U64,
219        NonNullable,
220        usize::try_from(row_range.end - row_range.start)
221            .vortex_expect("Row range length must fit in usize"),
222    )
223    .vortex_expect("Failed to create row index array")
224}
225
226fn row_idx_mask_future(
227    row_offset: u64,
228    row_range: &Range<u64>,
229    expr: &Expression,
230    mask: MaskFuture,
231) -> MaskFuture {
232    let row_range = row_range.clone();
233    let expr = expr.clone();
234    MaskFuture::new(mask.len(), async move {
235        let array = idx_array(row_offset, &row_range).into_array();
236        let result_mask = expr.evaluate(&array)?.try_to_mask_fill_null_false()?;
237        Ok(result_mask.bitand(&mask.await?))
238    })
239}
240
241fn row_idx_array_future(
242    row_offset: u64,
243    row_range: &Range<u64>,
244    expr: &Expression,
245    mask: MaskFuture,
246) -> ArrayFuture {
247    let row_range = row_range.clone();
248    let expr = expr.clone();
249    async move {
250        let array = idx_array(row_offset, &row_range).into_array();
251        let array = filter(&array, &mask.await?)?;
252        expr.evaluate(&array)
253    }
254    .boxed()
255}
256
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use vortex_array::expr::{eq, gt, lit, or, root};
    use vortex_array::{ArrayContext, IntoArray as _, MaskFuture, ToCanonical};
    use vortex_buffer::{BitBuffer, buffer};
    use vortex_io::runtime::single::block_on;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::{RowIdxLayoutReader, row_idx};
    use crate::segments::TestSegments;
    use crate::sequence::{SequenceId, SequentialArrayStreamExt};
    use crate::{LayoutReader, LayoutStrategy};

    // Each test writes `buffer![1..=5]` into a single flat layout, wraps the flat reader in a
    // `RowIdxLayoutReader` with a row offset of 0, and projects a boolean expression over all
    // rows with an all-true mask.

    /// A child-only expression (`root() == 3`) is delegated to the child reader.
    #[test]
    fn flat_expr_no_row_id() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let expr = eq(root(), lit(3i32));
            let result =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap())
                    .projection_evaluation(
                        &(0..layout.row_count()),
                        &expr,
                        MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                    )
                    .unwrap()
                    .await
                    .unwrap()
                    .to_bool();

            // Only the third value equals 3.
            assert_eq!(
                &BitBuffer::from_iter([false, false, true, false, false]),
                result.bit_buffer()
            );
        })
    }

    /// A row-index-only expression (`row_idx > 3`) is evaluated over the generated indices.
    #[test]
    fn flat_expr_row_id() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let expr = gt(row_idx(), lit(3u64));
            let result =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap())
                    .projection_evaluation(
                        &(0..layout.row_count()),
                        &expr,
                        MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                    )
                    .unwrap()
                    .await
                    .unwrap()
                    .to_bool();

            // Indices are 0..5, so only index 4 is greater than 3.
            assert_eq!(
                &BitBuffer::from_iter([false, false, false, false, true]),
                result.bit_buffer()
            );
        })
    }

    /// A mixed expression (child and row-index terms joined by `or`) goes through the
    /// partitioned path.
    #[test]
    fn flat_expr_or() {
        block_on(|handle| async {
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let expr = or(
                eq(root(), lit(3i32)),
                or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
            );

            let result =
                RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap())
                    .projection_evaluation(
                        &(0..layout.row_count()),
                        &expr,
                        MaskFuture::new_true(layout.row_count().try_into().unwrap()),
                    )
                    .unwrap()
                    .await
                    .unwrap()
                    .to_bool();

            // Row 0 (value 1), row 2 (value 3) and row 4 (index > 3) match.
            assert_eq!(
                vec![true, false, true, false, true],
                result.bit_buffer().iter().collect_vec()
            );
        })
    }
}