vortex_layout/layouts/row_idx/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::{Display, Formatter};
8use std::ops::{BitAnd, Range};
9use std::sync::Arc;
10
11use Nullability::NonNullable;
12use async_trait::async_trait;
13use dashmap::DashMap;
14pub use expr::*;
15use vortex_array::compute::filter;
16use vortex_array::stats::Precision;
17use vortex_array::{ArrayRef, IntoArray};
18use vortex_dtype::{DType, FieldMask, Nullability, PType};
19use vortex_error::{VortexExpect, VortexResult};
20use vortex_expr::transform::partition::{PartitionedExpr, partition};
21use vortex_expr::transform::replace::replace;
22use vortex_expr::{ExactExpr, ExprRef, Scope, is_root, root};
23use vortex_mask::Mask;
24use vortex_scalar::PValue;
25use vortex_sequence::SequenceArray;
26
27use crate::layouts::partitioned::{PartitionedArrayEvaluation, PartitionedMaskEvaluation};
28use crate::{
29    ArrayEvaluation, LayoutReader, MaskEvaluation, NoOpMaskEvaluation, NoOpPruningEvaluation,
30    PruningEvaluation,
31};
32
33pub struct RowIdxLayoutReader {
34    name: Arc<str>,
35    row_offset: u64,
36    child: Arc<dyn LayoutReader>,
37
38    partition_cache: DashMap<ExactExpr, Partitioning>,
39}
40
41impl RowIdxLayoutReader {
42    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>) -> Self {
43        Self {
44            name: child.name().clone(),
45            row_offset,
46            child,
47            partition_cache: DashMap::new(),
48        }
49    }
50
51    fn partition_expr(&self, expr: &ExprRef) -> Partitioning {
52        self.partition_cache
53            .entry(ExactExpr(expr.clone()))
54            .or_insert_with(|| {
55                // Partition the expression into row idx and child expressions.
56                let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
57                    if expr.is::<RowIdxVTable>() {
58                        vec![Partition::RowIdx]
59                    } else if is_root(expr) {
60                        vec![Partition::Child]
61                    } else {
62                        vec![]
63                    }
64                })
65                .vortex_expect("We should not fail to partition expression over struct fields");
66
67                // If there's only a single partition, we can directly return the expression.
68                if partitioned.partitions.len() == 1 {
69                    return match &partitioned.partition_annotations[0] {
70                        Partition::RowIdx => {
71                            Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
72                        }
73                        Partition::Child => Partitioning::Child(expr.clone()),
74                    };
75                }
76
77                // Replace the row_idx expression with the root expression in the row_idx partition.
78                partitioned.partitions = partitioned
79                    .partitions
80                    .into_iter()
81                    .map(|p| replace(p, &row_idx(), root()))
82                    .collect();
83
84                Partitioning::Partitioned(Arc::new(partitioned))
85            })
86            .clone()
87    }
88}
89
90#[derive(Clone)]
91enum Partitioning {
92    // An expression that only references the row index (e.g., `row_idx == 5`).
93    RowIdx(ExprRef),
94    // An expression that does not reference the row index.
95    Child(ExprRef),
96    // Contains both the RowIdx and Child expressions, (e.g., `row_idx < child.some_field`).
97    Partitioned(Arc<PartitionedExpr<Partition>>),
98}
99
/// The side of a [`RowIdxLayoutReader`] that a partition of an expression is evaluated against.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Partition {
    /// Evaluated against the synthesized row-index array.
    RowIdx,
    /// Evaluated by the child reader.
    Child,
}
105
106impl Display for Partition {
107    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
108        match self {
109            Partition::RowIdx => write!(f, "row_idx"),
110            Partition::Child => write!(f, "child"),
111        }
112    }
113}
114
115impl LayoutReader for RowIdxLayoutReader {
116    fn name(&self) -> &Arc<str> {
117        &self.name
118    }
119
120    fn dtype(&self) -> &DType {
121        self.child.dtype()
122    }
123
124    fn row_count(&self) -> Precision<u64> {
125        self.child.row_count()
126    }
127
128    fn register_splits(
129        &self,
130        field_mask: &[FieldMask],
131        row_offset: u64,
132        splits: &mut BTreeSet<u64>,
133    ) -> VortexResult<()> {
134        self.child.register_splits(field_mask, row_offset, splits)
135    }
136
137    fn pruning_evaluation(
138        &self,
139        row_range: &Range<u64>,
140        expr: &ExprRef,
141    ) -> VortexResult<Box<dyn PruningEvaluation>> {
142        match &self.partition_expr(expr) {
143            Partitioning::RowIdx(expr) => Ok(Box::new(RowIdxEvaluation::new(
144                self.row_offset,
145                row_range,
146                expr,
147            ))),
148            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr),
149            Partitioning::Partitioned(..) => Ok(Box::new(NoOpPruningEvaluation)),
150        }
151    }
152
153    fn filter_evaluation(
154        &self,
155        row_range: &Range<u64>,
156        expr: &ExprRef,
157    ) -> VortexResult<Box<dyn MaskEvaluation>> {
158        match &self.partition_expr(expr) {
159            // Since this is run during pruning, we skip re-evaluating the row index expression
160            // during the filter evaluation.
161            Partitioning::RowIdx(_) => Ok(Box::new(NoOpMaskEvaluation)),
162            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr),
163            Partitioning::Partitioned(p) => Ok(Box::new(PartitionedMaskEvaluation::try_new(
164                p.clone(),
165                |annotation, expr| match annotation {
166                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
167                        self.row_offset,
168                        row_range,
169                        expr,
170                    ))),
171                    Partition::Child => self.child.filter_evaluation(row_range, expr),
172                },
173                |annotation, expr| match annotation {
174                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
175                        self.row_offset,
176                        row_range,
177                        expr,
178                    ))),
179                    Partition::Child => self.child.projection_evaluation(row_range, expr),
180                },
181            )?)),
182        }
183    }
184
185    fn projection_evaluation(
186        &self,
187        row_range: &Range<u64>,
188        expr: &ExprRef,
189    ) -> VortexResult<Box<dyn ArrayEvaluation>> {
190        match &self.partition_expr(expr) {
191            Partitioning::RowIdx(expr) => Ok(Box::new(RowIdxEvaluation::new(
192                self.row_offset,
193                row_range,
194                expr,
195            ))),
196            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr),
197            Partitioning::Partitioned(p) => Ok(Box::new(PartitionedArrayEvaluation::try_new(
198                p.clone(),
199                |annotation, expr| match annotation {
200                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
201                        self.row_offset,
202                        row_range,
203                        expr,
204                    ))),
205                    Partition::Child => self.child.projection_evaluation(row_range, expr),
206                },
207            )?)),
208        }
209    }
210}
211
212/// We need a custom RowIdx evaluation because we need to defer creating the SequenceArray until
213/// we are given the final row_offset. We cannot just create a RowIdxLayout that spans the entire
214/// dataset because arrays can only cover up to usize rows, not u64.
215struct RowIdxEvaluation {
216    array: ArrayRef,
217    expr: ExprRef,
218}
219
220impl RowIdxEvaluation {
221    fn new(row_offset: u64, row_range: &Range<u64>, expr: &ExprRef) -> Self {
222        let array = SequenceArray::new(
223            PValue::U64(row_offset + row_range.start),
224            PValue::U64(1),
225            PType::U64,
226            NonNullable,
227            usize::try_from(row_range.end - row_range.start)
228                .vortex_expect("Row range length must fit in usize"),
229        )
230        .vortex_expect("Failed to create row index array");
231
232        Self {
233            array: array.into_array(),
234            expr: expr.clone(),
235        }
236    }
237}
238
239#[async_trait]
240impl PruningEvaluation for RowIdxEvaluation {
241    async fn invoke(&self, _mask: Mask) -> VortexResult<Mask> {
242        // TODO(ngates): we could optimize this if the mask was already quite sparse.
243        Mask::try_from(
244            self.expr
245                .evaluate(&Scope::new(self.array.clone()))?
246                .as_ref(),
247        )
248    }
249}
250
251#[async_trait]
252impl MaskEvaluation for RowIdxEvaluation {
253    async fn invoke(&self, mask: Mask) -> VortexResult<Mask> {
254        // TODO(ngates): we could optimize this if the mask was already quite sparse.
255        let result = Mask::try_from(
256            self.expr
257                .evaluate(&Scope::new(self.array.clone()))?
258                .as_ref(),
259        )?;
260
261        // Note that mask evaluation requires an intersection with the input mask, whereas
262        // pruning evaluation does not.
263        Ok(result.bitand(&mask))
264    }
265}
266
267#[async_trait]
268impl ArrayEvaluation for RowIdxEvaluation {
269    async fn invoke(&self, mask: Mask) -> VortexResult<ArrayRef> {
270        let array = filter(&self.array, &mask)?;
271        self.expr.evaluate(&Scope::new(array))
272    }
273}
274
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_buffer::BooleanBuffer;
    use futures::executor::block_on;
    use futures::stream;
    use itertools::Itertools;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::{ArrayContext, ToCanonical};
    use vortex_expr::{eq, gt, lit, or, root};
    use vortex_mask::Mask;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::{RowIdxLayoutReader, row_idx};
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutReader, LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    /// Writes the values `1..=5` as a single flat layout, wraps its reader in a
    /// [`RowIdxLayoutReader`] numbering rows from 0, and returns the boolean result of
    /// projecting `expr` over every row.
    async fn project_over_one_to_five(expr: vortex_expr::ExprRef) -> BooleanBuffer {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let sequence_writer = SequenceWriter::new(Box::new(segments.clone()));
        let array = PrimitiveArray::from_iter(1..=5).to_array();
        let streamed = array.clone();
        let layout = FlatLayoutStrategy::default()
            .write_stream(
                &ctx,
                sequence_writer.clone(),
                SequentialStreamAdapter::new(
                    array.dtype().clone(),
                    stream::once(async { Ok((SequenceId::root().downgrade(), streamed)) }),
                )
                .sendable(),
            )
            .await
            .unwrap();
        let segments: Arc<dyn SegmentSource> = Arc::new(segments);
        let row_count = layout.row_count();

        let reader = RowIdxLayoutReader::new(0, layout.new_reader("".into(), segments).unwrap());
        let result = reader
            .projection_evaluation(&(0..row_count), &expr)
            .unwrap()
            .invoke(Mask::new_true(row_count.try_into().unwrap()))
            .await
            .unwrap()
            .to_bool()
            .unwrap();
        result.boolean_buffer().clone()
    }

    #[test]
    fn flat_expr_no_row_id() {
        // Only row 2 holds the value 3.
        let expected = BooleanBuffer::from_iter([false, false, true, false, false]);
        let actual = block_on(project_over_one_to_five(eq(root(), lit(3i32))));
        assert_eq!(&expected, &actual);
    }

    #[test]
    fn flat_expr_row_id() {
        // Only row index 4 is greater than 3.
        let expected = BooleanBuffer::from_iter([false, false, false, false, true]);
        let actual = block_on(project_over_one_to_five(gt(row_idx(), lit(3u64))));
        assert_eq!(&expected, &actual);
    }

    #[test]
    fn flat_expr_or() {
        // Row 0 holds 1, row 2 holds 3, and row index 4 is the only one above 3.
        let expr = or(
            eq(root(), lit(3i32)),
            or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
        );
        let actual = block_on(project_over_one_to_five(expr));
        assert_eq!(
            vec![true, false, true, false, true],
            actual.iter().collect_vec()
        );
    }
}