Skip to main content

vortex_layout/layouts/row_idx/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod expr;
5
6use std::collections::BTreeSet;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::ops::BitAnd;
10use std::ops::Range;
11use std::sync::Arc;
12use std::sync::OnceLock;
13
14use Nullability::NonNullable;
15pub use expr::*;
16use futures::FutureExt;
17use futures::future::BoxFuture;
18use vortex_array::ArrayRef;
19use vortex_array::Canonical;
20use vortex_array::IntoArray;
21use vortex_array::MaskFuture;
22use vortex_array::VortexSessionExecute;
23use vortex_array::dtype::DType;
24use vortex_array::dtype::FieldMask;
25use vortex_array::dtype::FieldName;
26use vortex_array::dtype::Nullability;
27use vortex_array::dtype::PType;
28use vortex_array::expr::ExactExpr;
29use vortex_array::expr::Expression;
30use vortex_array::expr::is_root;
31use vortex_array::expr::root;
32use vortex_array::expr::transform::PartitionedExpr;
33use vortex_array::expr::transform::partition;
34use vortex_array::expr::transform::replace;
35use vortex_array::scalar::PValue;
36use vortex_error::VortexExpect;
37use vortex_error::VortexResult;
38use vortex_mask::Mask;
39use vortex_sequence::Sequence;
40use vortex_sequence::SequenceArray;
41use vortex_session::VortexSession;
42use vortex_utils::aliases::dash_map::DashMap;
43
44use crate::ArrayFuture;
45use crate::LayoutReader;
46use crate::layouts::partitioned::PartitionedExprEval;
47
/// A [`LayoutReader`] that wraps a `child` reader and adds support for the `row_idx()`
/// expression.
///
/// Incoming expressions are partitioned into the parts that reference the row index, which are
/// evaluated here against a generated sequence of row indices, and the parts that reference the
/// root, which are delegated to `child`.
pub struct RowIdxLayoutReader {
    /// Reader name, copied from the child reader on construction.
    name: Arc<str>,
    /// Offset added to the start of every row range when generating row indices.
    row_offset: u64,
    /// The wrapped reader that evaluates everything other than `row_idx()`.
    child: Arc<dyn LayoutReader>,
    /// Cache of partitioning results keyed by `ExactExpr`; each entry is computed at most once
    /// via its `OnceLock`.
    partition_cache: DashMap<ExactExpr, Arc<OnceLock<Partitioning>>>,
    /// Session used to create execution contexts when evaluating row index expressions.
    session: VortexSession,
}
55
56impl RowIdxLayoutReader {
57    pub fn new(row_offset: u64, child: Arc<dyn LayoutReader>, session: VortexSession) -> Self {
58        Self {
59            name: Arc::clone(child.name()),
60            row_offset,
61            child,
62            partition_cache: DashMap::with_hasher(Default::default()),
63            session,
64        }
65    }
66
67    fn partition_expr(&self, expr: &Expression) -> Partitioning {
68        let key = ExactExpr(expr.clone());
69
70        // Check cache first with read-only lock.
71        if let Some(entry) = self.partition_cache.get(&key)
72            && let Some(partitioning) = entry.value().get()
73        {
74            return partitioning.clone();
75        }
76
77        let cell = self
78            .partition_cache
79            .entry(key)
80            .or_insert_with(|| Arc::new(OnceLock::new()))
81            .clone();
82
83        cell.get_or_init(|| self.compute_partitioning(expr)).clone()
84    }
85
86    fn compute_partitioning(&self, expr: &Expression) -> Partitioning {
87        // Partition the expression into row idx and child expressions.
88        let mut partitioned = partition(expr.clone(), self.dtype(), |expr| {
89            if expr.is::<RowIdx>() {
90                vec![Partition::RowIdx]
91            } else if is_root(expr) {
92                vec![Partition::Child]
93            } else {
94                vec![]
95            }
96        })
97        .vortex_expect("We should not fail to partition expression over struct fields");
98
99        // If there's only a single partition, we can directly return the expression.
100        if partitioned.partitions.len() == 1 {
101            return match &partitioned.partition_annotations[0] {
102                Partition::RowIdx => {
103                    Partitioning::RowIdx(replace(expr.clone(), &row_idx(), root()))
104                }
105                Partition::Child => Partitioning::Child(expr.clone()),
106            };
107        }
108
109        // Replace the row_idx expression with the root expression in the row_idx partition.
110        partitioned.partitions = partitioned
111            .partitions
112            .into_iter()
113            .map(|p| replace(p, &row_idx(), root()))
114            .collect();
115
116        Partitioning::Partitioned(Arc::new(partitioned))
117    }
118}
119
/// How an expression is split between the row index and the child reader.
#[derive(Clone)]
enum Partitioning {
    /// An expression that only references the row index (e.g., `row_idx == 5`), with `row_idx()`
    /// already rewritten to `root()` so it can be applied directly to the row index array.
    RowIdx(Expression),
    /// An expression that does not reference the row index; evaluated entirely by the child.
    Child(Expression),
    /// Contains both RowIdx and Child partitions (e.g., `row_idx < child.some_field`). Row index
    /// partitions have `row_idx()` rewritten to `root()`.
    Partitioned(Arc<PartitionedExpr<Partition>>),
}
129
/// Annotation identifying which evaluator handles a partition of a partitioned expression.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Partition {
    /// The partition references `row_idx()` and is evaluated against the row index array.
    RowIdx,
    /// The partition references the root and is evaluated by the child reader.
    Child,
}
135
136impl Partition {
137    pub fn name(&self) -> &str {
138        match self {
139            Partition::RowIdx => "row_idx",
140            Partition::Child => "child",
141        }
142    }
143}
144
145impl From<Partition> for FieldName {
146    fn from(value: Partition) -> Self {
147        FieldName::from(value.name())
148    }
149}
150
151impl Display for Partition {
152    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
153        write!(f, "{}", self.name())
154    }
155}
156
impl LayoutReader for RowIdxLayoutReader {
    /// The name copied from the child reader at construction.
    fn name(&self) -> &Arc<str> {
        &self.name
    }

    /// Delegates to the child reader's dtype.
    fn dtype(&self) -> &DType {
        self.child.dtype()
    }

    /// Delegates to the child reader's row count.
    fn row_count(&self) -> u64 {
        self.child.row_count()
    }

    /// Split points come entirely from the child reader.
    fn register_splits(
        &self,
        field_mask: &[FieldMask],
        row_range: &Range<u64>,
        splits: &mut BTreeSet<u64>,
    ) -> VortexResult<()> {
        self.child.register_splits(field_mask, row_range, splits)
    }

    /// Prunes `mask` for `row_range` using `expr`.
    ///
    /// - Row-index-only expressions are evaluated against the generated row indices.
    /// - Child-only expressions are delegated to the child's pruning.
    /// - Mixed expressions are not pruned here; `mask` is returned unchanged.
    fn pruning_evaluation(
        &self,
        row_range: &Range<u64>,
        expr: &Expression,
        mask: Mask,
    ) -> VortexResult<MaskFuture> {
        Ok(match &self.partition_expr(expr) {
            Partitioning::RowIdx(expr) => row_idx_mask_future(
                self.row_offset,
                row_range,
                expr,
                MaskFuture::ready(mask),
                self.session.clone(),
            ),
            Partitioning::Child(expr) => self.child.pruning_evaluation(row_range, expr, mask)?,
            Partitioning::Partitioned(..) => MaskFuture::ready(mask),
        })
    }

    /// Evaluates `expr` as a filter over `row_range`, starting from `mask`.
    ///
    /// For mixed expressions, row index partitions are evaluated against the generated row
    /// indices and child partitions are delegated to the child reader. The results are passed to
    /// `PartitionedExpr::into_mask_future`.
    fn filter_evaluation(
        &self,
        row_range: &Range<u64>,
        expr: &Expression,
        mask: MaskFuture,
    ) -> VortexResult<MaskFuture> {
        match &self.partition_expr(expr) {
            // Since this is run during pruning, we skip re-evaluating the row index expression
            // during the filter evaluation.
            // NOTE(review): this relies on `pruning_evaluation` having been applied with the same
            // `expr` and `row_range` before filtering; confirm every scan path guarantees that.
            Partitioning::RowIdx(_) => Ok(mask),
            Partitioning::Child(expr) => self.child.filter_evaluation(row_range, expr, mask),
            Partitioning::Partitioned(p) => Arc::clone(p).into_mask_future(
                mask,
                // Filter-side evaluation of each partition.
                |annotation, expr, mask| match annotation {
                    Partition::RowIdx => Ok(row_idx_mask_future(
                        self.row_offset,
                        row_range,
                        expr,
                        mask,
                        self.session.clone(),
                    )),
                    Partition::Child => self.child.filter_evaluation(row_range, expr, mask),
                },
                // Projection-side evaluation of each partition.
                |annotation, expr, mask| match annotation {
                    Partition::RowIdx => Ok(row_idx_array_future(
                        self.row_offset,
                        row_range,
                        expr,
                        mask,
                        self.session.clone(),
                    )),
                    Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
                },
                self.session.clone(),
            ),
        }
    }

    /// Projects `expr` over the rows of `row_range` selected by `mask`.
    ///
    /// Row index partitions are evaluated against the generated row indices. Child partitions
    /// are delegated to the child reader.
    fn projection_evaluation(
        &self,
        row_range: &Range<u64>,
        expr: &Expression,
        mask: MaskFuture,
    ) -> VortexResult<BoxFuture<'static, VortexResult<ArrayRef>>> {
        match &self.partition_expr(expr) {
            Partitioning::RowIdx(expr) => Ok(row_idx_array_future(
                self.row_offset,
                row_range,
                expr,
                mask,
                self.session.clone(),
            )),
            Partitioning::Child(expr) => self.child.projection_evaluation(row_range, expr, mask),
            Partitioning::Partitioned(p) => {
                Arc::clone(p).into_array_future(mask, |annotation, expr, mask| match annotation {
                    Partition::RowIdx => Ok(row_idx_array_future(
                        self.row_offset,
                        row_range,
                        expr,
                        mask,
                        self.session.clone(),
                    )),
                    Partition::Child => self.child.projection_evaluation(row_range, expr, mask),
                })
            }
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
270
271// Returns a SequenceArray representing the row indices for the given row range,
272fn idx_array(row_offset: u64, row_range: &Range<u64>) -> SequenceArray {
273    Sequence::try_new(
274        PValue::U64(row_offset + row_range.start),
275        PValue::U64(1),
276        PType::U64,
277        NonNullable,
278        usize::try_from(row_range.end - row_range.start)
279            .vortex_expect("Row range length must fit in usize"),
280    )
281    .vortex_expect("Failed to create row index array")
282}
283
284fn row_idx_mask_future(
285    row_offset: u64,
286    row_range: &Range<u64>,
287    expr: &Expression,
288    mask: MaskFuture,
289    session: VortexSession,
290) -> MaskFuture {
291    let row_range = row_range.clone();
292    let expr = expr.clone();
293    MaskFuture::new(mask.len(), async move {
294        let array = idx_array(row_offset, &row_range).into_array();
295
296        let mut ctx = session.create_execution_ctx();
297        let result_mask = array.apply(&expr)?.execute::<Mask>(&mut ctx)?;
298
299        Ok(result_mask.bitand(&mask.await?))
300    })
301}
302
303fn row_idx_array_future(
304    row_offset: u64,
305    row_range: &Range<u64>,
306    expr: &Expression,
307    mask: MaskFuture,
308    session: VortexSession,
309) -> ArrayFuture {
310    let row_range = row_range.clone();
311    let expr = expr.clone();
312    async move {
313        let array = idx_array(row_offset, &row_range).into_array();
314        let filtered = array.filter(mask.await?)?;
315        let mut ctx = session.create_execution_ctx();
316        let array = filtered.execute::<Canonical>(&mut ctx)?.into_array();
317        array.apply(&expr)
318    }
319    .boxed()
320}
321
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::ArrayContext;
    use vortex_array::IntoArray as _;
    use vortex_array::MaskFuture;
    use vortex_array::arrays::BoolArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::expr::eq;
    use vortex_array::expr::gt;
    use vortex_array::expr::lit;
    use vortex_array::expr::or;
    use vortex_array::expr::root;
    use vortex_buffer::buffer;
    use vortex_io::runtime::single::block_on;
    use vortex_io::session::RuntimeSessionExt;

    use crate::LayoutReader;
    use crate::LayoutStrategy;
    use crate::layouts::flat::writer::FlatLayoutStrategy;
    use crate::layouts::row_idx::RowIdxLayoutReader;
    use crate::layouts::row_idx::row_idx;
    use crate::segments::TestSegments;
    use crate::sequence::SequenceId;
    use crate::sequence::SequentialArrayStreamExt;
    use crate::test::SESSION;

    // Each test writes the values [1, 2, 3, 4, 5] as a single flat layout, wraps its reader in a
    // `RowIdxLayoutReader` with row offset 0 (so row indices are 0..5), and projects a boolean
    // expression over all rows.

    // The expression never mentions `row_idx()`, so it is evaluated by the child reader.
    #[test]
    fn flat_expr_no_row_id() {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    &session,
                )
                .await
                .unwrap();

            let expr = eq(root(), lit(3i32));
            let result = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                SESSION.clone(),
            )
            .projection_evaluation(
                &(0..layout.row_count()),
                &expr,
                MaskFuture::new_true(layout.row_count().try_into().unwrap()),
            )
            .unwrap()
            .await
            .unwrap();

            // Only the third row holds the value 3.
            assert_arrays_eq!(
                result,
                BoolArray::from_iter([false, false, true, false, false])
            );
        })
    }

    // The expression only references `row_idx()`, so it is evaluated against the generated
    // row index sequence rather than the stored values.
    #[test]
    fn flat_expr_row_id() {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    &session,
                )
                .await
                .unwrap();

            let expr = gt(row_idx(), lit(3u64));
            let result = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                SESSION.clone(),
            )
            .projection_evaluation(
                &(0..layout.row_count()),
                &expr,
                MaskFuture::new_true(layout.row_count().try_into().unwrap()),
            )
            .unwrap()
            .await
            .unwrap();

            // Row indices are 0..5, so only the last row (index 4) is greater than 3.
            assert_arrays_eq!(
                result,
                BoolArray::from_iter([false, false, false, false, true])
            );
        })
    }

    // The expression mixes `row_idx()` and root references, exercising the partitioned path.
    #[test]
    fn flat_expr_or() {
        block_on(|handle| async {
            let session = SESSION.clone().with_handle(handle);
            let ctx = ArrayContext::empty();
            let segments = Arc::new(TestSegments::default());
            let (ptr, eof) = SequenceId::root().split();
            let array = buffer![1..=5].into_array();
            let layout = FlatLayoutStrategy::default()
                .write_stream(
                    ctx,
                    Arc::<TestSegments>::clone(&segments),
                    array.to_array_stream().sequenced(ptr),
                    eof,
                    &session,
                )
                .await
                .unwrap();

            let expr = or(
                eq(root(), lit(3i32)),
                or(gt(row_idx(), lit(3u64)), eq(root(), lit(1i32))),
            );

            let result = RowIdxLayoutReader::new(
                0,
                layout.new_reader("".into(), segments, &SESSION).unwrap(),
                SESSION.clone(),
            )
            .projection_evaluation(
                &(0..layout.row_count()),
                &expr,
                MaskFuture::new_true(layout.row_count().try_into().unwrap()),
            )
            .unwrap()
            .await
            .unwrap();

            // Matches are:
            // - row 0 (value 1),
            // - row 2 (value 3),
            // - row 4 (row index 4 > 3).
            assert_arrays_eq!(
                result,
                BoolArray::from_iter([true, false, true, false, true])
            );
        })
    }
}