vortex_expr/exprs/select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10
11use crate::field::DisplayFieldNames;
12use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
13
/// Which fields of a struct a [`SelectExpr`] keeps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field except the listed ones.
    Exclude(FieldNames),
}
19
// Presumably generates `SelectVTable`, the type the `VTable` impl in this file is written
// for — confirm against the `vtable!` macro definition.
vtable!(Select);
21
/// Expression that projects fields out of the struct produced by its `child` expression.
#[derive(Debug, Clone, Hash)]
// `PartialEq` is implemented by hand for this type, so the derived-Hash/manual-Eq lint is
// silenced here.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// Which fields to keep or drop.
    fields: SelectField,
    /// Expression whose result (expected to be a struct) is projected.
    child: ExprRef,
}
28
29impl PartialEq for SelectExpr {
30    fn eq(&self, other: &Self) -> bool {
31        self.fields == other.fields && self.child.eq(&other.child)
32    }
33}
34
/// Encoding marker for [`SelectExpr`]; used as the `Encoding` type of `SelectVTable`.
pub struct SelectExprEncoding;
36
37impl VTable for SelectVTable {
38    type Expr = SelectExpr;
39    type Encoding = SelectExprEncoding;
40    type Metadata = EmptyMetadata;
41
42    fn id(_encoding: &Self::Encoding) -> ExprId {
43        ExprId::new_ref("select")
44    }
45
46    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
47        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
48    }
49
50    fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
51        // Select does not support serialization
52        None
53    }
54
55    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56        vec![&expr.child]
57    }
58
59    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60        Ok(SelectExpr {
61            fields: expr.fields.clone(),
62            child: children[0].clone(),
63        })
64    }
65
66    fn build(
67        _encoding: &Self::Encoding,
68        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
69        _children: Vec<ExprRef>,
70    ) -> VortexResult<Self::Expr> {
71        vortex_bail!("Select does not support deserialization")
72    }
73
74    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
75        let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
76        Ok(match &expr.fields {
77            SelectField::Include(f) => batch.project(f.as_ref()),
78            SelectField::Exclude(names) => {
79                let included_names = batch
80                    .names()
81                    .iter()
82                    .filter(|&f| !names.as_ref().contains(f))
83                    .cloned()
84                    .collect::<Vec<_>>();
85                batch.project(included_names.as_slice())
86            }
87        }?
88        .into_array())
89    }
90
91    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
92        let child_dtype = expr.child.return_dtype(scope)?;
93        let child_struct_dtype = child_dtype
94            .as_struct()
95            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
96
97        let projected = match &expr.fields {
98            SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
99            SelectField::Exclude(fields) => child_struct_dtype
100                .names()
101                .iter()
102                .cloned()
103                .zip_eq(child_struct_dtype.fields())
104                .filter(|(name, _)| !fields.as_ref().contains(name))
105                .collect(),
106        };
107
108        Ok(DType::Struct(projected, child_dtype.nullability()))
109    }
110}
111
112pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
113    SelectExpr::include_expr(fields.into(), child)
114}
115
116pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
117    SelectExpr::exclude_expr(fields.into(), child)
118}
119
120impl SelectExpr {
121    pub fn new(fields: SelectField, child: ExprRef) -> Self {
122        Self { fields, child }
123    }
124
125    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
126        Self::new(fields, child).into_expr()
127    }
128
129    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
130        Self::new(SelectField::Include(columns), child).into_expr()
131    }
132
133    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
134        Self::new(SelectField::Exclude(columns), child).into_expr()
135    }
136
137    pub fn fields(&self) -> &SelectField {
138        &self.fields
139    }
140
141    pub fn child(&self) -> &ExprRef {
142        &self.child
143    }
144
145    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
146        Ok(Self::new(
147            SelectField::Include(self.fields.as_include_names(field_names)?),
148            self.child.clone(),
149        )
150        .into_expr())
151    }
152}
153
154impl SelectField {
155    pub fn include(columns: FieldNames) -> Self {
156        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
157        Self::Include(columns)
158    }
159
160    pub fn exclude(columns: FieldNames) -> Self {
161        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
162        Self::Exclude(columns)
163    }
164
165    pub fn is_include(&self) -> bool {
166        matches!(self, Self::Include(_))
167    }
168
169    pub fn is_exclude(&self) -> bool {
170        matches!(self, Self::Exclude(_))
171    }
172
173    pub fn fields(&self) -> &FieldNames {
174        match self {
175            SelectField::Include(fields) => fields,
176            SelectField::Exclude(fields) => fields,
177        }
178    }
179
180    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
181        if self
182            .fields()
183            .iter()
184            .any(|f| !field_names.iter().contains(f))
185        {
186            vortex_bail!(
187                "Field {:?} in select not in field names {:?}",
188                self,
189                field_names
190            );
191        }
192        match self {
193            SelectField::Include(fields) => Ok(fields.clone()),
194            SelectField::Exclude(exc_fields) => Ok(field_names
195                .iter()
196                .filter(|f| exc_fields.iter().contains(f))
197                .cloned()
198                .collect()),
199        }
200    }
201}
202
203impl Display for SelectField {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        match self {
206            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
207            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
208        }
209    }
210}
211
212impl Display for SelectExpr {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        write!(f, "{}{}", self.child, self.fields)
215    }
216}
217
// No overrides: `SelectExpr` uses the default methods of `AnalysisExpr`.
impl AnalysisExpr for SelectExpr {}
219
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{ExprRef, Scope, root, select, select_exclude, test_harness};

    /// Two-column struct fixture: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Evaluates `expr` against [`test_array`] and collects the field names of the result.
    fn evaluated_field_names(expr: &ExprRef) -> Vec<FieldName> {
        let input = test_array();
        let output = expr
            .evaluate(&Scope::new(input.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        output.names().iter().cloned().collect()
    }

    #[test]
    pub fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        assert_eq!(evaluated_field_names(&expr), vec![FieldName::from("a")]);
    }

    #[test]
    pub fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        assert_eq!(evaluated_field_names(&expr), vec![FieldName::from("b")]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct().unwrap();

        let only_a = DType::Struct(
            struct_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        // Including `a` directly...
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // ...must match excluding every other column.
        let exclude_all_but_a = select_exclude(
            ["col1", "col2", "bool1", "bool2"]
                .map(FieldName::from)
                .to_vec(),
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let without_cols = DType::Struct(
            struct_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        let exclude_cols = select_exclude(["col1", "col2"].map(FieldName::from).to_vec(), root());
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), without_cols);
    }
}