vortex_expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::field::DisplayFieldNames;
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
/// The field selection carried by a [`SelectExpr`].
///
/// Both variants hold a list of field names; the variant decides whether the listed fields
/// are the ones kept or the ones dropped when the select is evaluated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep exactly the listed fields.
    Include(FieldNames),
    /// Keep every field of the input except the listed ones.
    Exclude(FieldNames),
}
21
// NOTE(review): presumably declares the `SelectVTable` type implemented further down in this
// file, plus related glue — confirm against the `vtable!` macro definition.
vtable!(Select);
23
/// Expression that projects a subset of the fields of the struct produced by `child`.
#[derive(Debug, Clone, Hash, Eq)]
// `Hash` is derived while `PartialEq` is written by hand in this file. The manual impl compares
// the same two fields that the derived `Hash` feeds, so the lint is silenced here.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// Which fields to keep (include list) or drop (exclude list).
    fields: SelectField,
    /// Expression whose result is projected.
    child: ExprRef,
}
30
31impl PartialEq for SelectExpr {
32    fn eq(&self, other: &Self) -> bool {
33        self.fields == other.fields && self.child.eq(&other.child)
34    }
35}
36
/// Unit marker type used as the `Encoding` associated type of the select expression vtable.
pub struct SelectExprEncoding;
38
39impl VTable for SelectVTable {
40    type Expr = SelectExpr;
41    type Encoding = SelectExprEncoding;
42    type Metadata = ProstMetadata<SelectOpts>;
43
44    fn id(_encoding: &Self::Encoding) -> ExprId {
45        ExprId::new_ref("select")
46    }
47
48    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
50    }
51
52    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53        let names = expr
54            .fields()
55            .fields()
56            .iter()
57            .map(|f| f.to_string())
58            .collect_vec();
59
60        let opts = if expr.fields().is_include() {
61            Opts::Include(ProtoFieldNames { names })
62        } else {
63            Opts::Exclude(ProtoFieldNames { names })
64        };
65
66        Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
67    }
68
69    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
70        vec![&expr.child]
71    }
72
73    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
74        Ok(SelectExpr {
75            fields: expr.fields.clone(),
76            child: children[0].clone(),
77        })
78    }
79
80    fn build(
81        _encoding: &Self::Encoding,
82        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
83        mut children: Vec<ExprRef>,
84    ) -> VortexResult<Self::Expr> {
85        if children.len() != 1 {
86            vortex_bail!("Select expression must have exactly one child");
87        }
88
89        let fields = match metadata.opts.as_ref() {
90            Some(opts) => match opts {
91                Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
92                    field_names.names.iter().map(|s| s.as_str()),
93                )),
94                Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
95                    field_names.names.iter().map(|s| s.as_str()),
96                )),
97            },
98            None => {
99                vortex_bail!("Select expressions must be provided with fields to select or exclude")
100            }
101        };
102
103        let child = children
104            .drain(..)
105            .next()
106            .vortex_expect("number of children validated to be one");
107
108        Ok(SelectExpr { fields, child })
109    }
110
111    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
112        let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
113        Ok(match &expr.fields {
114            SelectField::Include(f) => batch.project(f.as_ref()),
115            SelectField::Exclude(names) => {
116                let included_names = batch
117                    .names()
118                    .iter()
119                    .filter(|&f| !names.as_ref().contains(f))
120                    .cloned()
121                    .collect::<Vec<_>>();
122                batch.project(included_names.as_slice())
123            }
124        }?
125        .into_array())
126    }
127
128    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
129        let child_dtype = expr.child.return_dtype(scope)?;
130        let child_struct_dtype = child_dtype
131            .as_struct()
132            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
133
134        let projected = match &expr.fields {
135            SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
136            SelectField::Exclude(fields) => child_struct_dtype
137                .names()
138                .iter()
139                .cloned()
140                .zip_eq(child_struct_dtype.fields())
141                .filter(|(name, _)| !fields.as_ref().contains(name))
142                .collect(),
143        };
144
145        Ok(DType::Struct(projected, child_dtype.nullability()))
146    }
147}
148
149/// Creates an expression that selects (includes) specific fields from an array.
150///
151/// Projects only the specified fields from the child expression, which must be of DType struct.
152/// ```rust
153/// # use vortex_expr::{select, root};
154/// let expr = select(["name", "age"], root());
155/// ```
156pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
157    SelectExpr::include_expr(fields.into(), child)
158}
159
160/// Creates an expression that excludes specific fields from an array.
161///
162/// Projects all fields except the specified ones from the input struct expression.
163///
164/// ```rust
165/// # use vortex_expr::{select_exclude, root};
166/// let expr = select_exclude(["internal_id", "metadata"], root());
167/// ```
168pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
169    SelectExpr::exclude_expr(fields.into(), child)
170}
171
172impl SelectExpr {
173    pub fn new(fields: SelectField, child: ExprRef) -> Self {
174        Self { fields, child }
175    }
176
177    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
178        Self::new(fields, child).into_expr()
179    }
180
181    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
182        Self::new(SelectField::Include(columns), child).into_expr()
183    }
184
185    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
186        Self::new(SelectField::Exclude(columns), child).into_expr()
187    }
188
189    pub fn fields(&self) -> &SelectField {
190        &self.fields
191    }
192
193    pub fn child(&self) -> &ExprRef {
194        &self.child
195    }
196
197    /// Turn the select expression into an `include`, relative to a provided array of field names.
198    ///
199    /// For example:
200    /// ```rust
201    /// # use vortex_expr::root;
202    /// # use vortex_expr::{SelectExpr, SelectField};
203    /// # use vortex_dtype::FieldNames;
204    /// let field_names = FieldNames::from(["a", "b", "c"]);
205    /// let include = SelectExpr::new(SelectField::Include(["a"].into()), root());
206    /// let exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());
207    /// assert_eq!(
208    ///     &include.as_include(&field_names).unwrap(),
209    ///     &exclude.as_include(&field_names).unwrap()
210    /// );
211    /// ```
212    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
213        Ok(Self::new(
214            SelectField::Include(self.fields.as_include_names(field_names)?),
215            self.child.clone(),
216        )
217        .into_expr())
218    }
219}
220
221impl SelectField {
222    pub fn include(columns: FieldNames) -> Self {
223        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
224        Self::Include(columns)
225    }
226
227    pub fn exclude(columns: FieldNames) -> Self {
228        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
229        Self::Exclude(columns)
230    }
231
232    pub fn is_include(&self) -> bool {
233        matches!(self, Self::Include(_))
234    }
235
236    pub fn is_exclude(&self) -> bool {
237        matches!(self, Self::Exclude(_))
238    }
239
240    pub fn fields(&self) -> &FieldNames {
241        let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
242
243        fields
244    }
245
246    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
247        if self
248            .fields()
249            .iter()
250            .any(|f| !field_names.iter().contains(f))
251        {
252            vortex_bail!(
253                "Field {:?} in select not in field names {:?}",
254                self,
255                field_names
256            );
257        }
258        match self {
259            SelectField::Include(fields) => Ok(fields.clone()),
260            SelectField::Exclude(exc_fields) => Ok(field_names
261                .iter()
262                .filter(|f| !exc_fields.iter().contains(f))
263                .cloned()
264                .collect()),
265        }
266    }
267}
268
269impl Display for SelectField {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        match self {
272            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
273            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
274        }
275    }
276}
277
278impl Display for SelectExpr {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        write!(f, "{}{}", self.child, self.fields)
281    }
282}
283
// No `AnalysisExpr` items are overridden here; any behaviour comes from the trait's defaults.
impl AnalysisExpr for SelectExpr {}
285
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{
        ExprRef, Scope, SelectExpr, SelectField, root, select, select_exclude, test_harness,
    };

    /// Struct fixture with two integer columns, `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Evaluates `expr` against the fixture and checks the field names of the resulting struct.
    fn assert_selected_names(expr: ExprRef, expected: &[FieldName]) {
        let input = test_array();
        let output = expr
            .evaluate(&Scope::new(input.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let output_names = output.names().clone();
        let actual: &[FieldName] = output_names.as_ref();
        assert_eq!(actual, expected);
    }

    #[test]
    pub fn include_columns() {
        let keep_a = select(vec![FieldName::from("a")], root());
        assert_selected_names(keep_a, &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let drop_a = select_exclude(vec![FieldName::from("a")], root());
        assert_selected_names(drop_a, &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Non-nullable struct dtype made of the given fields of the harness dtype.
        let projected = |names: &[FieldName]| {
            DType::Struct(
                dtype.as_struct().unwrap().project(names).unwrap(),
                Nullability::NonNullable,
            )
        };
        let only_a = projected(&["a".into()]);

        // Including just `a` ...
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // ... is the same as excluding everything else.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding a subset keeps the remaining fields.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            exclude_cols.return_dtype(&dtype).unwrap(),
            projected(&["a".into(), "bool1".into(), "bool2".into()])
        );
    }

    #[test]
    fn test_as_include_names() {
        let all_names = FieldNames::from(["a", "b", "c"]);
        let by_include = SelectExpr::new(SelectField::Include(["a"].into()), root());
        let by_exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());

        let from_include = by_include.as_include(&all_names).unwrap();
        let from_exclude = by_exclude.as_include(&all_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}