vortex_array/operator/
slice.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::Hash;
6use std::ops::Range;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use itertools::Itertools;
11use vortex_dtype::DType;
12use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail};
13
14use crate::operator::{
15    BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, LengthBounds, Operator,
16    OperatorEq, OperatorHash, OperatorId, OperatorRef,
17};
18use crate::{Array, Canonical, IntoArray};
19
20#[derive(Debug, Clone)]
21pub struct SliceOperator {
22    child: OperatorRef,
23    range: Range<usize>,
24}
25
26impl SliceOperator {
27    pub fn try_new(child: OperatorRef, range: Range<usize>) -> VortexResult<Self> {
28        if range.start > range.end {
29            vortex_bail!(
30                "invalid slice range: start > end ({} > {})",
31                range.start,
32                range.end
33            );
34        }
35        if range.end > child.bounds().max {
36            vortex_bail!(
37                "slice range end out of bounds: {} > {}",
38                range.end,
39                child.bounds().max
40            );
41        }
42        Ok(SliceOperator { child, range })
43    }
44
45    pub fn range(&self) -> &Range<usize> {
46        &self.range
47    }
48}
49
50impl OperatorHash for SliceOperator {
51    fn operator_hash<H: std::hash::Hasher>(&self, state: &mut H) {
52        self.child.operator_hash(state);
53        self.range.hash(state);
54    }
55}
56
57impl OperatorEq for SliceOperator {
58    fn operator_eq(&self, other: &Self) -> bool {
59        self.range == other.range && self.child.operator_eq(&other.child)
60    }
61}
62
63impl Operator for SliceOperator {
64    fn id(&self) -> OperatorId {
65        OperatorId::from("vortex.slice")
66    }
67
68    fn as_any(&self) -> &dyn Any {
69        self
70    }
71
72    fn dtype(&self) -> &DType {
73        self.child.dtype()
74    }
75
76    fn bounds(&self) -> LengthBounds {
77        (self.range.end - self.range.start).into()
78    }
79
80    fn children(&self) -> &[OperatorRef] {
81        std::slice::from_ref(&self.child)
82    }
83
84    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
85        Ok(Arc::new(SliceOperator::try_new(
86            children.into_iter().next().vortex_expect("missing child"),
87            self.range.clone(),
88        )?))
89    }
90
91    fn reduce_children(&self) -> VortexResult<Option<OperatorRef>> {
92        // We push down the slice operator to any child that is aligned to the parent.
93        let children = (0..self.nchildren())
94            .map(|i| {
95                let child = self.child.children()[i].clone();
96
97                if self.child.is_selection_target(i).unwrap_or_default() {
98                    // Push-down the filter to this child.
99                    Ok::<_, VortexError>(Arc::new(SliceOperator::try_new(
100                        child,
101                        self.range.clone(),
102                    )?) as OperatorRef)
103                } else {
104                    Ok(child)
105                }
106            })
107            .try_collect()?;
108
109        Ok(Some(self.child.clone().with_children(children)?))
110    }
111
112    fn as_batch(&self) -> Option<&dyn BatchOperator> {
113        Some(self)
114    }
115}
116
117impl BatchOperator for SliceOperator {
118    fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
119        let child_exec = ctx.child(0)?;
120        Ok(Box::new(SliceExecution {
121            child: child_exec,
122            range: self.range.clone(),
123        }))
124    }
125}
126
127struct SliceExecution {
128    child: BatchExecutionRef,
129    range: Range<usize>,
130}
131
132#[async_trait]
133impl BatchExecution for SliceExecution {
134    async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
135        let child = self.child.execute().await?;
136        Ok(child.into_array().slice(self.range).to_canonical())
137    }
138}