vortex_expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::field::DisplayFieldNames;
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18pub enum SelectField {
19    Include(FieldNames),
20    Exclude(FieldNames),
21}
22
23vtable!(Select);
24
25#[derive(Debug, Clone, Hash, Eq)]
26#[allow(clippy::derived_hash_with_manual_eq)]
27pub struct SelectExpr {
28    fields: SelectField,
29    child: ExprRef,
30}
31
32impl PartialEq for SelectExpr {
33    fn eq(&self, other: &Self) -> bool {
34        self.fields == other.fields && self.child.eq(&other.child)
35    }
36}
37
38pub struct SelectExprEncoding;
39
40impl VTable for SelectVTable {
41    type Expr = SelectExpr;
42    type Encoding = SelectExprEncoding;
43    type Metadata = ProstMetadata<SelectOpts>;
44
45    fn id(_encoding: &Self::Encoding) -> ExprId {
46        ExprId::new_ref("select")
47    }
48
49    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
50        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
51    }
52
53    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
54        let names = expr
55            .fields()
56            .fields()
57            .iter()
58            .map(|f| f.to_string())
59            .collect_vec();
60
61        let opts = if expr.fields().is_include() {
62            Opts::Include(ProtoFieldNames { names })
63        } else {
64            Opts::Exclude(ProtoFieldNames { names })
65        };
66
67        Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
68    }
69
70    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71        vec![&expr.child]
72    }
73
74    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75        Ok(SelectExpr {
76            fields: expr.fields.clone(),
77            child: children[0].clone(),
78        })
79    }
80
81    fn build(
82        _encoding: &Self::Encoding,
83        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
84        mut children: Vec<ExprRef>,
85    ) -> VortexResult<Self::Expr> {
86        if children.len() != 1 {
87            vortex_bail!("Select expression must have exactly one child");
88        }
89
90        let fields = match metadata.opts.as_ref() {
91            Some(opts) => match opts {
92                Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
93                    field_names.names.iter().map(|s| s.as_str()),
94                )),
95                Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
96                    field_names.names.iter().map(|s| s.as_str()),
97                )),
98            },
99            None => {
100                vortex_bail!("Select expressions must be provided with fields to select or exclude")
101            }
102        };
103
104        let child = children
105            .drain(..)
106            .next()
107            .vortex_expect("number of children validated to be one");
108
109        Ok(SelectExpr { fields, child })
110    }
111
112    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
113        let batch = expr.child.unchecked_evaluate(scope)?.to_struct();
114        Ok(match &expr.fields {
115            SelectField::Include(f) => batch.project(f.as_ref()),
116            SelectField::Exclude(names) => {
117                let included_names = batch
118                    .names()
119                    .iter()
120                    .filter(|&f| !names.as_ref().contains(f))
121                    .cloned()
122                    .collect::<Vec<_>>();
123                batch.project(included_names.as_slice())
124            }
125        }?
126        .into_array())
127    }
128
129    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
130        let child_dtype = expr.child.return_dtype(scope)?;
131        let child_struct_dtype = child_dtype
132            .as_struct_fields_opt()
133            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
134
135        let projected = match &expr.fields {
136            SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
137            SelectField::Exclude(fields) => child_struct_dtype
138                .names()
139                .iter()
140                .cloned()
141                .zip_eq(child_struct_dtype.fields())
142                .filter(|(name, _)| !fields.as_ref().contains(name))
143                .collect(),
144        };
145
146        Ok(DType::Struct(projected, child_dtype.nullability()))
147    }
148}
149
150/// Creates an expression that selects (includes) specific fields from an array.
151///
152/// Projects only the specified fields from the child expression, which must be of DType struct.
153/// ```rust
154/// # use vortex_expr::{select, root};
155/// let expr = select(["name", "age"], root());
156/// ```
157pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
158    SelectExpr::include_expr(fields.into(), child)
159}
160
161/// Creates an expression that excludes specific fields from an array.
162///
163/// Projects all fields except the specified ones from the input struct expression.
164///
165/// ```rust
166/// # use vortex_expr::{select_exclude, root};
167/// let expr = select_exclude(["internal_id", "metadata"], root());
168/// ```
169pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
170    SelectExpr::exclude_expr(fields.into(), child)
171}
172
173impl SelectExpr {
174    pub fn new(fields: SelectField, child: ExprRef) -> Self {
175        Self { fields, child }
176    }
177
178    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
179        Self::new(fields, child).into_expr()
180    }
181
182    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
183        Self::new(SelectField::Include(columns), child).into_expr()
184    }
185
186    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
187        Self::new(SelectField::Exclude(columns), child).into_expr()
188    }
189
190    pub fn fields(&self) -> &SelectField {
191        &self.fields
192    }
193
194    pub fn child(&self) -> &ExprRef {
195        &self.child
196    }
197
198    /// Turn the select expression into an `include`, relative to a provided array of field names.
199    ///
200    /// For example:
201    /// ```rust
202    /// # use vortex_expr::root;
203    /// # use vortex_expr::{SelectExpr, SelectField};
204    /// # use vortex_dtype::FieldNames;
205    /// let field_names = FieldNames::from(["a", "b", "c"]);
206    /// let include = SelectExpr::new(SelectField::Include(["a"].into()), root());
207    /// let exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());
208    /// assert_eq!(
209    ///     &include.as_include(&field_names).unwrap(),
210    ///     &exclude.as_include(&field_names).unwrap()
211    /// );
212    /// ```
213    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
214        Ok(Self::new(
215            SelectField::Include(self.fields.as_include_names(field_names)?),
216            self.child.clone(),
217        )
218        .into_expr())
219    }
220}
221
222impl SelectField {
223    pub fn include(columns: FieldNames) -> Self {
224        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
225        Self::Include(columns)
226    }
227
228    pub fn exclude(columns: FieldNames) -> Self {
229        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
230        Self::Exclude(columns)
231    }
232
233    pub fn is_include(&self) -> bool {
234        matches!(self, Self::Include(_))
235    }
236
237    pub fn is_exclude(&self) -> bool {
238        matches!(self, Self::Exclude(_))
239    }
240
241    pub fn fields(&self) -> &FieldNames {
242        let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
243
244        fields
245    }
246
247    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
248        if self
249            .fields()
250            .iter()
251            .any(|f| !field_names.iter().contains(f))
252        {
253            vortex_bail!(
254                "Field {:?} in select not in field names {:?}",
255                self,
256                field_names
257            );
258        }
259        match self {
260            SelectField::Include(fields) => Ok(fields.clone()),
261            SelectField::Exclude(exc_fields) => Ok(field_names
262                .iter()
263                .filter(|f| !exc_fields.iter().contains(f))
264                .cloned()
265                .collect()),
266        }
267    }
268}
269
270impl Display for SelectField {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        match self {
273            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
274            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
275        }
276    }
277}
278
279impl DisplayAs for SelectExpr {
280    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281        match df {
282            DisplayFormat::Compact => {
283                write!(f, "{}{}", self.child, self.fields)
284            }
285            DisplayFormat::Tree => {
286                let field_type = if self.fields.is_include() {
287                    "include"
288                } else {
289                    "exclude"
290                };
291
292                write!(f, "Select({}): {}", field_type, self.fields().fields())
293            }
294        }
295    }
296
297    fn child_names(&self) -> Option<Vec<String>> {
298        // Single child - no need to name it, the tree structure makes it obvious
299        None
300    }
301}
302
303impl AnalysisExpr for SelectExpr {}
304
#[cfg(test)]
mod tests {
    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{
        ExprRef, Scope, SelectExpr, SelectField, root, select, select_exclude, test_harness,
    };

    /// Two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` used by the evaluation tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Evaluates `expr` with [`test_array`] as the root scope and canonicalizes to a struct.
    fn evaluate_on_test_array(expr: &ExprRef) -> StructArray {
        expr.evaluate(&Scope::new(test_array().to_array()))
            .unwrap()
            .to_struct()
    }

    #[test]
    pub fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        let selected = evaluate_on_test_array(&expr);
        let names = selected.names().clone();
        assert_eq!(names.as_ref(), &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let selected = evaluate_on_test_array(&expr);
        let names = selected.names().clone();
        assert_eq!(names.as_ref(), &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct_fields_opt().unwrap();
        // Expected non-nullable struct dtype holding only `names`.
        let projected = |names: &[FieldName]| {
            DType::Struct(struct_fields.project(names).unwrap(), Nullability::NonNullable)
        };

        let only_a = projected(&["a".into()]);

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            exclude_cols.return_dtype(&dtype).unwrap(),
            projected(&["a".into(), "bool1".into(), "bool2".into()])
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = SelectExpr::new(SelectField::Include(["a"].into()), root());
        let exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());

        let from_include = include.as_include(&field_names).unwrap();
        let from_exclude = exclude.as_include(&field_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}