vortex_layout/layouts/row_idx/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::{Display, Formatter};
8use std::ops::{BitAnd, Range};
9use std::sync::Arc;
10
11use Nullability::NonNullable;
12use async_trait::async_trait;
13use dashmap::DashMap;
14pub use expr::*;
15use vortex_array::compute::filter;
16use vortex_array::stats::Precision;
17use vortex_array::{ArrayRef, IntoArray};
18use vortex_dtype::{DType, FieldMask, Nullability, PType};
19use vortex_error::{VortexExpect, VortexResult};
20use vortex_expr::transform::{PartitionedExpr, partition, replace};
21use vortex_expr::{ExactExpr, ExprRef, Scope, is_root, root};
22use vortex_mask::Mask;
23use vortex_scalar::PValue;
24use vortex_sequence::SequenceArray;
25
26use crate::layouts::partitioned::{PartitionedArrayEvaluation, PartitionedMaskEvaluation};
27use crate::{
28    ArrayEvaluation, LayoutReader, MaskEvaluation, NoOpMaskEvaluation, NoOpPruningEvaluation,
29    PruningEvaluation,
30};
31
32pub struct RowIdxLayoutReader {
33    name: Arc<str>,
34    row_offset: u64,
35    child: Arc<dyn LayoutReader>,
36
37    partition_cache: DashMap<ExactExpr, Partitioning>,
38}
39
40impl RowIdxLayoutReader {
41    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>) -> Self {
42        Self {
43            name: child.name().clone(),
44            row_offset,
45            child,
46            partition_cache: DashMap::new(),
47        }
48    }
49
50    fn partition_expr(&self, expr: &ExprRef) -> Partitioning {
51        self.partition_cache
52            .entry(ExactExpr(expr.clone()))
53            .or_insert_with(|| {
54                // Partition the expression into row idx and child expressions.
55                let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
56                    if expr.is::<RowIdxVTable>() {
57                        vec![Partition::RowIdx]
58                    } else if is_root(expr) {
59                        vec![Partition::Child]
60                    } else {
61                        vec![]
62                    }
63                })
64                .vortex_expect("We should not fail to partition expression over struct fields");
65
66                // If there's only a single partition, we can directly return the expression.
67                if partitioned.partitions.len() == 1 {
68                    return match &partitioned.partition_annotations[0] {
69                        Partition::RowIdx => {
70                            Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
71                        }
72                        Partition::Child => Partitioning::Child(expr.clone()),
73                    };
74                }
75
76                // Replace the row_idx expression with the root expression in the row_idx partition.
77                partitioned.partitions = partitioned
78                    .partitions
79                    .into_iter()
80                    .map(|p| replace(p, &row_idx(), root()))
81                    .collect();
82
83                Partitioning::Partitioned(Arc::new(partitioned))
84            })
85            .clone()
86    }
87}
88
89#[derive(Clone)]
90enum Partitioning {
91    // An expression that only references the row index (e.g., `row_idx == 5`).
92    RowIdx(ExprRef),
93    // An expression that does not reference the row index.
94    Child(ExprRef),
95    // Contains both the RowIdx and Child expressions, (e.g., `row_idx < child.some_field`).
96    Partitioned(Arc<PartitionedExpr<Partition>>),
97}
98
/// The side of the row-index/child split that a sub-expression belongs to.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Partition {
    /// Sub-expression evaluated over the materialized row indices.
    RowIdx,
    /// Sub-expression evaluated by the child reader.
    Child,
}

impl Display for Partition {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::RowIdx => "row_idx",
            Self::Child => "child",
        };
        f.write_str(label)
    }
}
113
114impl LayoutReader for RowIdxLayoutReader {
115    fn name(&self) -> &Arc<str> {
116        &self.name
117    }
118
119    fn dtype(&self) -> &DType {
120        self.child.dtype()
121    }
122
123    fn row_count(&self) -> Precision<u64> {
124        self.child.row_count()
125    }
126
127    fn register_splits(
128        &self,
129        field_mask: &[FieldMask],
130        row_offset: u64,
131        splits: &mut BTreeSet<u64>,
132    ) -> VortexResult<()> {
133        self.child.register_splits(field_mask, row_offset, splits)
134    }
135
136    fn pruning_evaluation(
137        &self,
138        row_range: &Range<u64>,
139        expr: &ExprRef,
140    ) -> VortexResult<Box<dyn PruningEvaluation>> {
141        match &self.partition_expr(expr) {
142            Partitioning::RowIdx(expr) => Ok(Box::new(RowIdxEvaluation::new(
143                self.row_offset,
144                row_range,
145                expr,
146            ))),
147            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr),
148            Partitioning::Partitioned(..) => Ok(Box::new(NoOpPruningEvaluation)),
149        }
150    }
151
152    fn filter_evaluation(
153        &self,
154        row_range: &Range<u64>,
155        expr: &ExprRef,
156    ) -> VortexResult<Box<dyn MaskEvaluation>> {
157        match &self.partition_expr(expr) {
158            // Since this is run during pruning, we skip re-evaluating the row index expression
159            // during the filter evaluation.
160            Partitioning::RowIdx(_) => Ok(Box::new(NoOpMaskEvaluation)),
161            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr),
162            Partitioning::Partitioned(p) => Ok(Box::new(PartitionedMaskEvaluation::try_new(
163                p.clone(),
164                |annotation, expr| match annotation {
165                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
166                        self.row_offset,
167                        row_range,
168                        expr,
169                    ))),
170                    Partition::Child => self.child.filter_evaluation(row_range, expr),
171                },
172                |annotation, expr| match annotation {
173                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
174                        self.row_offset,
175                        row_range,
176                        expr,
177                    ))),
178                    Partition::Child => self.child.projection_evaluation(row_range, expr),
179                },
180            )?)),
181        }
182    }
183
184    fn projection_evaluation(
185        &self,
186        row_range: &Range<u64>,
187        expr: &ExprRef,
188    ) -> VortexResult<Box<dyn ArrayEvaluation>> {
189        match &self.partition_expr(expr) {
190            Partitioning::RowIdx(expr) => Ok(Box::new(RowIdxEvaluation::new(
191                self.row_offset,
192                row_range,
193                expr,
194            ))),
195            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr),
196            Partitioning::Partitioned(p) => Ok(Box::new(PartitionedArrayEvaluation::try_new(
197                p.clone(),
198                |annotation, expr| match annotation {
199                    Partition::RowIdx => Ok(Box::new(RowIdxEvaluation::new(
200                        self.row_offset,
201                        row_range,
202                        expr,
203                    ))),
204                    Partition::Child => self.child.projection_evaluation(row_range, expr),
205                },
206            )?)),
207        }
208    }
209}
210
211/// We need a custom RowIdx evaluation because we need to defer creating the SequenceArray until
212/// we are given the final row_offset. We cannot just create a RowIdxLayout that spans the entire
213/// dataset because arrays can only cover up to usize rows, not u64.
214struct RowIdxEvaluation {
215    array: ArrayRef,
216    expr: ExprRef,
217}
218
219impl RowIdxEvaluation {
220    fn new(row_offset: u64, row_range: &Range<u64>, expr: &ExprRef) -> Self {
221        let array = SequenceArray::new(
222            PValue::U64(row_offset + row_range.start),
223            PValue::U64(1),
224            PType::U64,
225            NonNullable,
226            usize::try_from(row_range.end - row_range.start)
227                .vortex_expect("Row range length must fit in usize"),
228        )
229        .vortex_expect("Failed to create row index array");
230
231        Self {
232            array: array.into_array(),
233            expr: expr.clone(),
234        }
235    }
236}
237
238#[async_trait]
239impl PruningEvaluation for RowIdxEvaluation {
240    async fn invoke(&self, _mask: Mask) -> VortexResult<Mask> {
241        // TODO(ngates): we could optimize this if the mask was already quite sparse.
242        Mask::try_from(
243            self.expr
244                .evaluate(&Scope::new(self.array.clone()))?
245                .as_ref(),
246        )
247    }
248}
249
250#[async_trait]
251impl MaskEvaluation for RowIdxEvaluation {
252    async fn invoke(&self, mask: Mask) -> VortexResult<Mask> {
253        // TODO(ngates): we could optimize this if the mask was already quite sparse.
254        let result = Mask::try_from(
255            self.expr
256                .evaluate(&Scope::new(self.array.clone()))?
257                .as_ref(),
258        )?;
259
260        // Note that mask evaluation requires an intersection with the input mask, whereas
261        // pruning evaluation does not.
262        Ok(result.bitand(&mask))
263    }
264}
265
266#[async_trait]
267impl ArrayEvaluation for RowIdxEvaluation {
268    async fn invoke(&self, mask: Mask) -> VortexResult<ArrayRef> {
269        let array = filter(&self.array, &mask)?;
270        self.expr.evaluate(&Scope::new(array))
271    }
272}
273
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use futures::executor::block_on;
    use futures::stream;
    use itertools::Itertools;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::{ArrayContext, ToCanonical};
    use vortex_expr::{ExprRef, eq, gt, lit, or, root};
    use vortex_mask::Mask;

    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::{RowIdxLayoutReader, row_idx};
    use crate::segments::{SegmentSource, SequenceWriter, TestSegments};
    use crate::sequence::SequenceId;
    use crate::{LayoutReader, LayoutStrategy, SequentialStreamAdapter, SequentialStreamExt};

    /// Writes the values `1..=5` as a flat layout and wraps its reader in a
    /// [`RowIdxLayoutReader`] with a zero row offset. It then projects `expr` over every row and
    /// returns the resulting booleans.
    async fn project_over_one_to_five(expr: &ExprRef) -> Vec<bool> {
        let ctx = ArrayContext::empty();
        let segments = TestSegments::default();
        let writer = SequenceWriter::new(Box::new(segments.clone()));
        let values = PrimitiveArray::from_iter(1..=5).to_array();
        let chunk = values.clone();
        let layout = FlatLayoutStrategy::default()
            .write_stream(
                &ctx,
                writer.clone(),
                SequentialStreamAdapter::new(
                    values.dtype().clone(),
                    stream::once(async { Ok((SequenceId::root().downgrade(), chunk)) }),
                )
                .sendable(),
            )
            .await
            .unwrap();
        let source: Arc<dyn SegmentSource> = Arc::new(segments);

        let row_count = layout.row_count();
        let projected = RowIdxLayoutReader::new(0, layout.new_reader("".into(), source).unwrap())
            .projection_evaluation(&(0..row_count), expr)
            .unwrap()
            .invoke(Mask::new_true(row_count.try_into().unwrap()))
            .await
            .unwrap()
            .to_bool()
            .unwrap();

        projected.boolean_buffer().iter().collect_vec()
    }

    #[test]
    fn flat_expr_no_row_id() {
        let expr = eq(root(), lit(3i32));
        assert_eq!(
            block_on(project_over_one_to_five(&expr)),
            vec![false, false, true, false, false]
        );
    }

    #[test]
    fn flat_expr_row_id() {
        let expr = gt(row_idx(), lit(3u64));
        assert_eq!(
            block_on(project_over_one_to_five(&expr)),
            vec![false, false, false, false, true]
        );
    }

    #[test]
    fn flat_expr_or() {
        // References both the child (`root()`) and the row index, so this exercises the
        // partitioned evaluation path.
        let expr = or(
            eq(root(), lit(3i32)),
            or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
        );
        assert_eq!(
            block_on(project_over_one_to_five(&expr)),
            vec![true, false, true, false, true]
        );
    }
}