vortex_expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::field::DisplayFieldNames;
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
/// Which fields of a struct a [`SelectExpr`] keeps.
///
/// Both variants carry a list of field names; the variant decides whether those names are
/// the ones to keep or the ones to drop.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldSelection {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field of the input except the listed ones.
    Exclude(FieldNames),
}
22
// NOTE(review): `SelectVTable` is implemented in this file but never declared here, so this
// macro presumably declares it, along with the glue that lets `SelectExpr` be turned into an
// `ExprRef` via `into_expr()` — confirm against the `vtable!` definition.
vtable!(Select);
24
/// An expression that projects a subset of fields out of its (struct-typed) child.
#[derive(Debug, Clone, Hash, Eq)]
// `PartialEq` is written by hand for this type and compares exactly the two fields that the
// derived `Hash` feeds into the hasher, so equal values still hash equally.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// The fields to include or exclude.
    selection: FieldSelection,
    /// The expression whose result is projected.
    child: ExprRef,
}
31
32impl PartialEq for SelectExpr {
33    fn eq(&self, other: &Self) -> bool {
34        self.selection == other.selection && self.child.eq(&other.child)
35    }
36}
37
/// Stateless marker type used as the `VTable::Encoding` of the select expression.
pub struct SelectExprEncoding;
39
40impl VTable for SelectVTable {
41    type Expr = SelectExpr;
42    type Encoding = SelectExprEncoding;
43    type Metadata = ProstMetadata<SelectOpts>;
44
45    fn id(_encoding: &Self::Encoding) -> ExprId {
46        ExprId::new_ref("select")
47    }
48
49    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
50        ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
51    }
52
53    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
54        let names = expr
55            .selection()
56            .field_names()
57            .iter()
58            .map(|f| f.to_string())
59            .collect_vec();
60
61        let opts = if expr.selection().is_include() {
62            Opts::Include(ProtoFieldNames { names })
63        } else {
64            Opts::Exclude(ProtoFieldNames { names })
65        };
66
67        Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
68    }
69
70    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71        vec![&expr.child]
72    }
73
74    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75        Ok(SelectExpr {
76            selection: expr.selection.clone(),
77            child: children[0].clone(),
78        })
79    }
80
81    fn build(
82        _encoding: &Self::Encoding,
83        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
84        mut children: Vec<ExprRef>,
85    ) -> VortexResult<Self::Expr> {
86        if children.len() != 1 {
87            vortex_bail!("Select expression must have exactly one child");
88        }
89
90        let fields = match metadata.opts.as_ref() {
91            Some(opts) => match opts {
92                Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
93                    field_names.names.iter().map(|s| s.as_str()),
94                )),
95                Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
96                    field_names.names.iter().map(|s| s.as_str()),
97                )),
98            },
99            None => {
100                vortex_bail!("Select expressions must be provided with fields to select or exclude")
101            }
102        };
103
104        let child = children
105            .drain(..)
106            .next()
107            .vortex_expect("number of children validated to be one");
108
109        Ok(SelectExpr {
110            selection: fields,
111            child,
112        })
113    }
114
115    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
116        let batch = expr.child.unchecked_evaluate(scope)?.to_struct();
117        Ok(match &expr.selection {
118            FieldSelection::Include(f) => batch.project(f.as_ref()),
119            FieldSelection::Exclude(names) => {
120                let included_names = batch
121                    .names()
122                    .iter()
123                    .filter(|&f| !names.as_ref().contains(f))
124                    .cloned()
125                    .collect::<Vec<_>>();
126                batch.project(included_names.as_slice())
127            }
128        }?
129        .into_array())
130    }
131
132    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
133        let child_dtype = expr.child.return_dtype(scope)?;
134        let child_struct_dtype = child_dtype
135            .as_struct_fields_opt()
136            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
137
138        let projected = match &expr.selection {
139            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
140            FieldSelection::Exclude(fields) => child_struct_dtype
141                .names()
142                .iter()
143                .cloned()
144                .zip_eq(child_struct_dtype.fields())
145                .filter(|(name, _)| !fields.as_ref().contains(name))
146                .collect(),
147        };
148
149        Ok(DType::Struct(projected, child_dtype.nullability()))
150    }
151}
152
153/// Creates an expression that selects (includes) specific fields from an array.
154///
155/// Projects only the specified fields from the child expression, which must be of DType struct.
156/// ```rust
157/// # use vortex_expr::{select, root};
158/// let expr = select(["name", "age"], root());
159/// ```
160pub fn select(field_names: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
161    SelectExpr::include_expr(field_names.into(), child)
162}
163
164/// Creates an expression that excludes specific fields from an array.
165///
166/// Projects all fields except the specified ones from the input struct expression.
167///
168/// ```rust
169/// # use vortex_expr::{select_exclude, root};
170/// let expr = select_exclude(["internal_id", "metadata"], root());
171/// ```
172pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
173    SelectExpr::exclude_expr(fields.into(), child)
174}
175
176impl SelectExpr {
177    pub fn new(fields: FieldSelection, child: ExprRef) -> Self {
178        Self {
179            selection: fields,
180            child,
181        }
182    }
183
184    pub fn new_expr(fields: FieldSelection, child: ExprRef) -> ExprRef {
185        Self::new(fields, child).into_expr()
186    }
187
188    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
189        Self::new(FieldSelection::Include(columns), child).into_expr()
190    }
191
192    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
193        Self::new(FieldSelection::Exclude(columns), child).into_expr()
194    }
195
196    pub fn selection(&self) -> &FieldSelection {
197        &self.selection
198    }
199
200    pub fn child(&self) -> &ExprRef {
201        &self.child
202    }
203
204    /// Turn the select expression into an `include`, relative to a provided array of field names.
205    ///
206    /// For example:
207    /// ```rust
208    /// # use vortex_expr::root;
209    /// # use vortex_expr::{FieldSelection, SelectExpr};
210    /// # use vortex_dtype::FieldNames;
211    /// let field_names = FieldNames::from(["a", "b", "c"]);
212    /// let include = SelectExpr::new(FieldSelection::Include(["a"].into()), root());
213    /// let exclude = SelectExpr::new(FieldSelection::Exclude(["b", "c"].into()), root());
214    /// assert_eq!(
215    ///     &include.as_include(&field_names).unwrap(),
216    ///     &exclude.as_include(&field_names).unwrap()
217    /// );
218    /// ```
219    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
220        Ok(Self::new(
221            FieldSelection::Include(self.selection.as_include_names(field_names)?),
222            self.child.clone(),
223        )
224        .into_expr())
225    }
226}
227
228impl FieldSelection {
229    pub fn include(columns: FieldNames) -> Self {
230        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
231        Self::Include(columns)
232    }
233
234    pub fn exclude(columns: FieldNames) -> Self {
235        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
236        Self::Exclude(columns)
237    }
238
239    pub fn is_include(&self) -> bool {
240        matches!(self, Self::Include(_))
241    }
242
243    pub fn is_exclude(&self) -> bool {
244        matches!(self, Self::Exclude(_))
245    }
246
247    pub fn field_names(&self) -> &FieldNames {
248        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
249
250        fields
251    }
252
253    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
254        if self
255            .field_names()
256            .iter()
257            .any(|f| !field_names.iter().contains(f))
258        {
259            vortex_bail!(
260                "Field {:?} in select not in field names {:?}",
261                self,
262                field_names
263            );
264        }
265        match self {
266            FieldSelection::Include(fields) => Ok(fields.clone()),
267            FieldSelection::Exclude(exc_fields) => Ok(field_names
268                .iter()
269                .filter(|f| !exc_fields.iter().contains(f))
270                .cloned()
271                .collect()),
272        }
273    }
274}
275
276impl Display for FieldSelection {
277    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278        match self {
279            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
280            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
281        }
282    }
283}
284
285impl DisplayAs for SelectExpr {
286    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
287        match df {
288            DisplayFormat::Compact => {
289                write!(f, "{}{}", self.child, self.selection)
290            }
291            DisplayFormat::Tree => {
292                let field_type = if self.selection.is_include() {
293                    "include"
294                } else {
295                    "exclude"
296                };
297
298                write!(
299                    f,
300                    "Select({}): {}",
301                    field_type,
302                    self.selection().field_names()
303                )
304            }
305        }
306    }
307
308    fn child_names(&self) -> Option<Vec<String>> {
309        // Single child - no need to name it, the tree structure makes it obvious
310        None
311    }
312}
313
// No methods are overridden here.
// NOTE(review): presumably the trait's default behavior applies to select — confirm against
// the `AnalysisExpr` definition.
impl AnalysisExpr for SelectExpr {}
315
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{FieldSelection, Scope, SelectExpr, root, select, select_exclude, test_harness};

    /// Fixture: a struct with two columns, `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    pub fn include_columns() {
        let scope = Scope::new(test_array().to_array());
        let expr = select(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&scope).unwrap().to_struct();
        assert_eq!(projected.names().as_ref(), &["a"]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    pub fn exclude_columns() {
        let scope = Scope::new(test_array().to_array());
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&scope).unwrap().to_struct();
        assert_eq!(projected.names().as_ref(), &["b"]);
    }

    /// `return_dtype` agrees with projecting the harness struct dtype by hand, for both
    /// include- and exclude-style selections.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Only `a` remains, whether it is included directly or everything else is excluded.
        let only_a = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding just the `col*` fields keeps the rest in their original order.
        let without_cols = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );

        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), without_cols);
    }

    /// An include of `a` and an exclude of `b, c` resolve to the same include expression.
    #[test]
    fn test_as_include_names() {
        let all_names = FieldNames::from(["a", "b", "c"]);
        let by_include = SelectExpr::new(FieldSelection::Include(["a"].into()), root());
        let by_exclude = SelectExpr::new(FieldSelection::Exclude(["b", "c"].into()), root());

        let from_include = by_include.as_include(&all_names).unwrap();
        let from_exclude = by_exclude.as_include(&all_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}