vortex_expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::field::DisplayFieldNames;
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
/// How a select expression chooses fields: by naming the fields to keep, or by naming
/// the fields to drop.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep only the listed field names.
    Include(FieldNames),
    /// Keep every field except the listed field names.
    Exclude(FieldNames),
}
21
// Invokes the crate's `vtable!` macro for `Select`.
// NOTE(review): presumably this declares the `SelectVTable` type that the `VTable` impl in
// this file is written for — confirm against the macro definition.
vtable!(Select);
23
/// Expression that projects the struct produced by `child` onto a subset of its fields.
#[derive(Debug, Clone, Hash, Eq)]
// `Hash` is derived while `PartialEq` is implemented by hand in this file; that combination
// is what the allowed clippy lint reports.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// Which fields to keep (include) or drop (exclude).
    fields: SelectField,
    /// Expression whose struct output is projected.
    child: ExprRef,
}
30
31impl PartialEq for SelectExpr {
32    fn eq(&self, other: &Self) -> bool {
33        self.fields == other.fields && self.child.eq(&other.child)
34    }
35}
36
/// Unit marker type used as the `Encoding` associated type of the select expression's
/// `VTable` implementation.
pub struct SelectExprEncoding;
38
39impl VTable for SelectVTable {
40    type Expr = SelectExpr;
41    type Encoding = SelectExprEncoding;
42    type Metadata = ProstMetadata<SelectOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("select")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        let names = expr
54            .fields()
55            .fields()
56            .iter()
57            .map(|f| f.to_string())
58            .collect_vec();
59
60        let opts = if expr.fields().is_include() {
61            Opts::Include(ProtoFieldNames { names })
62        } else {
63            Opts::Exclude(ProtoFieldNames { names })
64        };
65
66        Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
67    }
68
69    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
70        vec![&expr.child]
71    }
72
73    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
74        Ok(SelectExpr {
75            fields: expr.fields.clone(),
76            child: children[0].clone(),
77        })
78    }
79
80    fn build(
81        _encoding: &Self::Encoding,
82        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
83        mut children: Vec<ExprRef>,
84    ) -> VortexResult<Self::Expr> {
85        if children.len() != 1 {
86            vortex_bail!("Select expression must have exactly one child");
87        }
88
89        let fields = match metadata.opts.as_ref() {
90            Some(opts) => match opts {
91                Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
92                    field_names.names.iter().map(|s| s.as_str()),
93                )),
94                Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
95                    field_names.names.iter().map(|s| s.as_str()),
96                )),
97            },
98            None => {
99                vortex_bail!("Select expressions must be provided with fields to select or exclude")
100            }
101        };
102
103        let child = children
104            .drain(..)
105            .next()
106            .vortex_expect("number of children validated to be one");
107
108        Ok(SelectExpr { fields, child })
109    }
110
111    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
112        let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
113        Ok(match &expr.fields {
114            SelectField::Include(f) => batch.project(f.as_ref()),
115            SelectField::Exclude(names) => {
116                let included_names = batch
117                    .names()
118                    .iter()
119                    .filter(|&f| !names.as_ref().contains(f))
120                    .cloned()
121                    .collect::<Vec<_>>();
122                batch.project(included_names.as_slice())
123            }
124        }?
125        .into_array())
126    }
127
128    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
129        let child_dtype = expr.child.return_dtype(scope)?;
130        let child_struct_dtype = child_dtype
131            .as_struct()
132            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
133
134        let projected = match &expr.fields {
135            SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
136            SelectField::Exclude(fields) => child_struct_dtype
137                .names()
138                .iter()
139                .cloned()
140                .zip_eq(child_struct_dtype.fields())
141                .filter(|(name, _)| !fields.as_ref().contains(name))
142                .collect(),
143        };
144
145        Ok(DType::Struct(projected, child_dtype.nullability()))
146    }
147}
148
149pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
150    SelectExpr::include_expr(fields.into(), child)
151}
152
153pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
154    SelectExpr::exclude_expr(fields.into(), child)
155}
156
157impl SelectExpr {
158    pub fn new(fields: SelectField, child: ExprRef) -> Self {
159        Self { fields, child }
160    }
161
162    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
163        Self::new(fields, child).into_expr()
164    }
165
166    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
167        Self::new(SelectField::Include(columns), child).into_expr()
168    }
169
170    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
171        Self::new(SelectField::Exclude(columns), child).into_expr()
172    }
173
174    pub fn fields(&self) -> &SelectField {
175        &self.fields
176    }
177
178    pub fn child(&self) -> &ExprRef {
179        &self.child
180    }
181
182    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
183        Ok(Self::new(
184            SelectField::Include(self.fields.as_include_names(field_names)?),
185            self.child.clone(),
186        )
187        .into_expr())
188    }
189}
190
191impl SelectField {
192    pub fn include(columns: FieldNames) -> Self {
193        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
194        Self::Include(columns)
195    }
196
197    pub fn exclude(columns: FieldNames) -> Self {
198        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
199        Self::Exclude(columns)
200    }
201
202    pub fn is_include(&self) -> bool {
203        matches!(self, Self::Include(_))
204    }
205
206    pub fn is_exclude(&self) -> bool {
207        matches!(self, Self::Exclude(_))
208    }
209
210    pub fn fields(&self) -> &FieldNames {
211        let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
212
213        fields
214    }
215
216    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
217        if self
218            .fields()
219            .iter()
220            .any(|f| !field_names.iter().contains(f))
221        {
222            vortex_bail!(
223                "Field {:?} in select not in field names {:?}",
224                self,
225                field_names
226            );
227        }
228        match self {
229            SelectField::Include(fields) => Ok(fields.clone()),
230            SelectField::Exclude(exc_fields) => Ok(field_names
231                .iter()
232                .filter(|f| exc_fields.iter().contains(f))
233                .cloned()
234                .collect()),
235        }
236    }
237}
238
239impl Display for SelectField {
240    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241        match self {
242            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
243            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
244        }
245    }
246}
247
248impl Display for SelectExpr {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        write!(f, "{}{}", self.child, self.fields)
251    }
252}
253
// Empty impl: `SelectExpr` overrides nothing and relies on whatever `AnalysisExpr` provides
// by default.
impl AnalysisExpr for SelectExpr {}
255
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{Scope, root, select, select_exclude, test_harness};

    /// Two-column struct fixture: `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let scope = Scope::new(input.to_array());
        let selected = expr.evaluate(&scope).unwrap().to_struct().unwrap();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a".into()]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let scope = Scope::new(input.to_array());
        let selected = expr.evaluate(&scope).unwrap().to_struct().unwrap();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b".into()]);
    }

    /// `return_dtype` agrees with projecting the harness struct dtype directly.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Expected result when only `a` survives the selection.
        let expected_dtype = DType::Struct(
            dtype.as_struct().unwrap().project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every other column is equivalent to including only `a`.
        let all_but_a = vec![
            FieldName::from("col1"),
            FieldName::from("col2"),
            FieldName::from("bool1"),
            FieldName::from("bool2"),
        ];
        let exclude_all_but_a = select_exclude(all_but_a, root());
        assert_eq!(
            exclude_all_but_a.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        // Excluding two columns keeps the remaining three.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let expected_remaining = DType::Struct(
            dtype
                .as_struct()
                .unwrap()
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(
            exclude_cols.return_dtype(&dtype).unwrap(),
            expected_remaining
        );
    }
}