vortex_layout/layouts/row_idx/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::ops::BitAnd;
10use std::ops::Range;
11use std::sync::Arc;
12
13use Nullability::NonNullable;
14pub use expr::*;
15use futures::FutureExt;
16use futures::future::BoxFuture;
17use vortex_array::ArrayRef;
18use vortex_array::IntoArray;
19use vortex_array::MaskFuture;
20use vortex_array::compute::filter;
21use vortex_array::expr::ExactExpr;
22use vortex_array::expr::Expression;
23use vortex_array::expr::is_root;
24use vortex_array::expr::root;
25use vortex_array::expr::session::ExprSessionExt;
26use vortex_array::expr::transform::ExprOptimizer;
27use vortex_array::expr::transform::PartitionedExpr;
28use vortex_array::expr::transform::partition;
29use vortex_array::expr::transform::replace;
30use vortex_dtype::DType;
31use vortex_dtype::FieldMask;
32use vortex_dtype::FieldName;
33use vortex_dtype::Nullability;
34use vortex_dtype::PType;
35use vortex_error::VortexExpect;
36use vortex_error::VortexResult;
37use vortex_mask::Mask;
38use vortex_scalar::PValue;
39use vortex_sequence::SequenceArray;
40use vortex_session::VortexSession;
41use vortex_utils::aliases::dash_map::DashMap;
42
43use crate::ArrayFuture;
44use crate::LayoutReader;
45use crate::layouts::partitioned::PartitionedExprEval;
46
47pub struct RowIdxLayoutReader {
48    name: Arc<str>,
49    row_offset: u64,
50    child: Arc<dyn LayoutReader>,
51    partition_cache: DashMap<ExactExpr, Partitioning>,
52    expr_optimizer: ExprOptimizer,
53}
54
55impl RowIdxLayoutReader {
56    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>, session: &VortexSession) -> Self {
57        let expr_optimizer = ExprOptimizer::new(&session.expressions());
58        Self {
59            name: child.name().clone(),
60            row_offset,
61            child,
62            partition_cache: DashMap::with_hasher(Default::default()),
63            expr_optimizer,
64        }
65    }
66
67    fn partition_expr(&self, expr: &Expression) -> Partitioning {
68        self.partition_cache
69            .entry(ExactExpr(expr.clone()))
70            .or_insert_with(|| {
71                // Partition the expression into row idx and child expressions.
72                let mut partitioned = partition(
73                    expr.clone(),
74                    self.dtype(),
75                    |expr| {
76                        if expr.is::<RowIdx>() {
77                            vec![Partition::RowIdx]
78                        } else if is_root(expr) {
79                            vec![Partition::Child]
80                        } else {
81                            vec![]
82                        }
83                    },
84                    &self.expr_optimizer,
85                )
86                .vortex_expect("We should not fail to partition expression over struct fields");
87
88                // If there's only a single partition, we can directly return the expression.
89                if partitioned.partitions.len() == 1 {
90                    return match &partitioned.partition_annotations[0] {
91                        Partition::RowIdx => {
92                            Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
93                        }
94                        Partition::Child => Partitioning::Child(expr.clone()),
95                    };
96                }
97
98                // Replace the row_idx expression with the root expression in the row_idx partition.
99                partitioned.partitions = partitioned
100                    .partitions
101                    .into_iter()
102                    .map(|p| replace(p, &row_idx(), root()))
103                    .collect();
104
105                Partitioning::Partitioned(Arc::new(partitioned))
106            })
107            .clone()
108    }
109}
110
111#[derive(Clone)]
112enum Partitioning {
113    // An expression that only references the row index (e.g., `row_idx == 5`).
114    RowIdx(Expression),
115    // An expression that does not reference the row index.
116    Child(Expression),
117    // Contains both the RowIdx and Child expressions, (e.g., `row_idx < child.some_field`).
118    Partitioned(Arc<PartitionedExpr<Partition>>),
119}
120
/// Annotation recording which side a sub-expression is evaluated on.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Partition {
    /// Evaluated against the generated row-index array.
    RowIdx,
    /// Evaluated by the child layout reader.
    Child,
}
126
127impl Partition {
128    pub fn name(&self) -> &str {
129        match self {
130            Partition::RowIdx => "row_idx",
131            Partition::Child => "child",
132        }
133    }
134}
135
136impl From<Partition> for FieldName {
137    fn from(value: Partition) -> Self {
138        FieldName::from(value.name())
139    }
140}
141
142impl Display for Partition {
143    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
144        write!(f, "{}", self.name())
145    }
146}
147
148impl LayoutReader for RowIdxLayoutReader {
149    fn name(&self) -> &Arc<str> {
150        &self.name
151    }
152
153    fn dtype(&self) -> &DType {
154        self.child.dtype()
155    }
156
157    fn row_count(&self) -> u64 {
158        self.child.row_count()
159    }
160
161    fn register_splits(
162        &self,
163        field_mask: &[FieldMask],
164        row_range: &Range<u64>,
165        splits: &mut BTreeSet<u64>,
166    ) -> VortexResult<()> {
167        self.child.register_splits(field_mask, row_range, splits)
168    }
169
170    fn pruning_evaluation(
171        &self,
172        row_range: &Range<u64>,
173        expr: &Expression,
174        mask: Mask,
175    ) -> VortexResult<MaskFuture> {
176        Ok(match &self.partition_expr(expr) {
177            Partitioning::RowIdx(expr) => {
178                row_idx_mask_future(self.row_offset, row_range, expr, MaskFuture::ready(mask))
179            }
180            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr, mask)?,
181            Partitioning::Partitioned(..) => MaskFuture::ready(mask),
182        })
183    }
184
185    fn filter_evaluation(
186        &self,
187        row_range: &Range<u64>,
188        expr: &Expression,
189        mask: MaskFuture,
190    ) -> VortexResult<MaskFuture> {
191        match &self.partition_expr(expr) {
192            // Since this is run during pruning, we skip re-evaluating the row index expression
193            // during the filter evaluation.
194            Partitioning::RowIdx(_) => Ok(mask),
195            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr, mask),
196            Partitioning::Partitioned(p) => p.clone().into_mask_future(
197                mask,
198                |annotation, expr, mask| match annotation {
199                    Partition::RowIdx => {
200                        Ok(row_idx_mask_future(self.row_offset, row_range, expr, mask))
201                    }
202                    Partition::Child => self.child.filter_evaluation(row_range, expr, mask),
203                },
204                |annotation, expr, mask| match annotation {
205                    Partition::RowIdx => {
206                        Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
207                    }
208                    Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
209                },
210            ),
211        }
212    }
213
214    fn projection_evaluation(
215        &self,
216        row_range: &Range<u64>,
217        expr: &Expression,
218        mask: MaskFuture,
219    ) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
220        match &self.partition_expr(expr) {
221            Partitioning::RowIdx(expr) => {
222                Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
223            }
224            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr, mask),
225            Partitioning::Partitioned(p) => {
226                p.clone()
227                    .into_array_future(mask, |annotation, expr, mask| match annotation {
228                        Partition::RowIdx => {
229                            Ok(row_idx_array_future(self.row_offset, row_range, expr, mask))
230                        }
231                        Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
232                    })
233            }
234        }
235    }
236}
237
238// Returns a SequenceArray representing the row indices for the given row range,
239fn idx_array(row_offset: u64, row_range: &Range<u64>) -> SequenceArray {
240    SequenceArray::new(
241        PValue::U64(row_offset + row_range.start),
242        PValue::U64(1),
243        PType::U64,
244        NonNullable,
245        usize::try_from(row_range.end - row_range.start)
246            .vortex_expect("Row range length must fit in usize"),
247    )
248    .vortex_expect("Failed to create row index array")
249}
250
251fn row_idx_mask_future(
252    row_offset: u64,
253    row_range: &Range<u64>,
254    expr: &Expression,
255    mask: MaskFuture,
256) -> MaskFuture {
257    let row_range = row_range.clone();
258    let expr = expr.clone();
259    MaskFuture::new(mask.len(), async move {
260        let array = idx_array(row_offset, &row_range).into_array();
261        let result_mask = expr.evaluate(&array)?.try_to_mask_fill_null_false()?;
262        Ok(result_mask.bitand(&mask.await?))
263    })
264}
265
266fn row_idx_array_future(
267    row_offset: u64,
268    row_range: &Range<u64>,
269    expr: &Expression,
270    mask: MaskFuture,
271) -> ArrayFuture {
272    let row_range = row_range.clone();
273    let expr = expr.clone();
274    async move {
275        let array = idx_array(row_offset, &row_range).into_array();
276        let array = filter(&array, &mask.await?)?;
277        expr.evaluate(&array)
278    }
279    .boxed()
280}
281
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use itertools::Itertools;
    use vortex_array::ArrayContext;
    use vortex_array::IntoArray as _;
    use vortex_array::MaskFuture;
    use vortex_array::ToCanonical;
    use vortex_array::expr::eq;
    use vortex_array::expr::gt;
    use vortex_array::expr::lit;
    use vortex_array::expr::or;
    use vortex_array::expr::root;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_io::runtime::single::block_on;

    use crate::LayoutReader;
    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::RowIdxLayoutReader;
    use crate::layouts::row_idx::row_idx;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    /// An expression over the child's values only is answered by the child reader.
    #[test]
    fn flat_expr_no_row_id() {
        block_on(|handle| async {
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let row_count = layout.row_count();
            let reader = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                &SESSION,
            );
            let predicate = eq(root(), lit(3i32));
            let result = reader
                .projection_evaluation(
                    &(0..row_count),
                    &predicate,
                    MaskFuture::new_true(row_count.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            let expected = BitBuffer::from_iter([false, false, true, false, false]);
            assert_eq!(&expected, result.bit_buffer());
        })
    }

    /// An expression over the row index only is answered from the generated index array.
    #[test]
    fn flat_expr_row_id() {
        block_on(|handle| async {
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let row_count = layout.row_count();
            let reader = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                &SESSION,
            );
            let predicate = gt(row_idx(), lit(3u64));
            let result = reader
                .projection_evaluation(
                    &(0..row_count),
                    &predicate,
                    MaskFuture::new_true(row_count.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            let expected = BitBuffer::from_iter([false, false, false, false, true]);
            assert_eq!(&expected, result.bit_buffer());
        })
    }

    /// An expression mixing row-index and child references is split and recombined.
    #[test]
    fn flat_expr_or() {
        block_on(|handle| async {
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ArrayContext::empty(),
                    segments.clone(),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    handle,
                )
                .await
                .unwrap();

            let row_count = layout.row_count();
            let reader = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                &SESSION,
            );
            let predicate = or(
                eq(root(), lit(3i32)),
                or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
            );
            let result = reader
                .projection_evaluation(
                    &(0..row_count),
                    &predicate,
                    MaskFuture::new_true(row_count.try_into().unwrap()),
                )
                .unwrap()
                .await
                .unwrap()
                .to_bool();

            let expected = vec![true, false, true, false, true];
            assert_eq!(expected, result.bit_buffer().iter().collect_vec());
        })
    }
}