Skip to main content

vortex_array/expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_dtype::DType;
10use vortex_dtype::FieldName;
11use vortex_dtype::FieldNames;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_proto::expr::FieldNames as ProtoFieldNames;
17use vortex_proto::expr::SelectOpts;
18use vortex_proto::expr::select_opts::Opts;
19use vortex_session::VortexSession;
20
21use crate::ArrayRef;
22use crate::IntoArray;
23use crate::arrays::StructArray;
24use crate::expr;
25use crate::expr::Arity;
26use crate::expr::ChildName;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Pack;
30use crate::expr::SimplifyCtx;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::expression::Expression;
34use crate::expr::field::DisplayFieldNames;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38    Include(FieldNames),
39    Exclude(FieldNames),
40}
41
/// Expression vtable that projects a struct-typed child down to a subset of its fields, as
/// described by a [`FieldSelection`].
pub struct Select;
43
44impl VTable for Select {
45    type Options = FieldSelection;
46
47    fn id(&self) -> ExprId {
48        ExprId::new_ref("vortex.select")
49    }
50
51    fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
52        let opts = match instance {
53            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
54                names: fields.iter().map(|f| f.to_string()).collect(),
55            }),
56            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
57                names: fields.iter().map(|f| f.to_string()).collect(),
58            }),
59        };
60
61        let select_opts = SelectOpts { opts: Some(opts) };
62        Ok(Some(select_opts.encode_to_vec()))
63    }
64
65    fn deserialize(
66        &self,
67        _metadata: &[u8],
68        _session: &VortexSession,
69    ) -> VortexResult<FieldSelection> {
70        let prost_metadata = SelectOpts::decode(_metadata)?;
71
72        let select_opts = prost_metadata
73            .opts
74            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
75
76        let field_selection = match select_opts {
77            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
78                field_names.names.iter().map(|s| s.as_str()),
79            )),
80            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
81                field_names.names.iter().map(|s| s.as_str()),
82            )),
83        };
84
85        Ok(field_selection)
86    }
87
88    fn arity(&self, _options: &FieldSelection) -> Arity {
89        Arity::Exact(1)
90    }
91
92    fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
93        match child_idx {
94            0 => ChildName::new_ref("child"),
95            _ => unreachable!(),
96        }
97    }
98
99    fn fmt_sql(
100        &self,
101        selection: &FieldSelection,
102        expr: &Expression,
103        f: &mut Formatter<'_>,
104    ) -> std::fmt::Result {
105        expr.child(0).fmt_sql(f)?;
106        match selection {
107            FieldSelection::Include(fields) => {
108                write!(f, "{{{}}}", DisplayFieldNames(fields))
109            }
110            FieldSelection::Exclude(fields) => {
111                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
112            }
113        }
114    }
115
116    fn return_dtype(
117        &self,
118        selection: &FieldSelection,
119        arg_dtypes: &[DType],
120    ) -> VortexResult<DType> {
121        let child_dtype = &arg_dtypes[0];
122        let child_struct_dtype = child_dtype
123            .as_struct_fields_opt()
124            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
125
126        let projected = match selection {
127            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
128            FieldSelection::Exclude(fields) => child_struct_dtype
129                .names()
130                .iter()
131                .cloned()
132                .zip_eq(child_struct_dtype.fields())
133                .filter(|(name, _)| !fields.as_ref().contains(name))
134                .collect(),
135        };
136
137        Ok(DType::Struct(projected, child_dtype.nullability()))
138    }
139
140    fn execute(
141        &self,
142        selection: &FieldSelection,
143        mut args: ExecutionArgs,
144    ) -> VortexResult<ArrayRef> {
145        let child = args
146            .inputs
147            .pop()
148            .vortex_expect("Missing input child")
149            .execute::<StructArray>(args.ctx)?;
150
151        let result = match selection {
152            FieldSelection::Include(f) => child.project(f.as_ref()),
153            FieldSelection::Exclude(names) => {
154                let included_names = child
155                    .names()
156                    .iter()
157                    .filter(|&f| !names.as_ref().contains(f))
158                    .cloned()
159                    .collect::<Vec<_>>();
160                child.project(included_names.as_slice())
161            }
162        }?;
163
164        result.into_array().execute(args.ctx)
165    }
166
167    fn simplify(
168        &self,
169        selection: &FieldSelection,
170        expr: &Expression,
171        ctx: &dyn SimplifyCtx,
172    ) -> VortexResult<Option<Expression>> {
173        let child_struct = expr.child(0);
174        let struct_dtype = ctx.return_dtype(child_struct)?;
175        let struct_nullability = struct_dtype.nullability();
176
177        let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
178            vortex_err!(
179                "Select child must return a struct dtype, however it was a {}",
180                struct_dtype
181            )
182        })?;
183
184        // "Mask" out the unwanted fields of the child struct `DType`.
185        let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
186        let all_included_fields_are_nullable = included_fields.iter().all(|name| {
187            struct_fields
188                .field(name)
189                .vortex_expect(
190                    "`normalize_to_included_fields` checks that the included fields already exist \
191                     in `struct_fields`",
192                )
193                .is_nullable()
194        });
195
196        // If no fields are included, we can trivially simplify to a pack expression.
197        // NOTE(ngates): we do this knowing that our layout expression partitioning logic has
198        //  special-casing for pack, but not for select. We will fix this up when we revisit the
199        //  layout APIs.
200        if included_fields.is_empty() {
201            let empty: Vec<(FieldName, Expression)> = vec![];
202            return Ok(Some(expr::pack(empty, struct_nullability)));
203        }
204
205        // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`.
206        // This is because `get_item` does a validity intersection of the struct validity with its
207        // fields, which is not the same as just "masking" out the unwanted fields (a selection).
208        //
209        // We can, however, make this simplification when the child of the `select` is already a
210        // `pack` and we know that `get_item` will do no validity intersections.
211        let child_is_pack = child_struct.is::<Pack>();
212
213        // `get_item` only performs validity intersection when the struct is nullable but the field
214        // is not. This would change the semantics of a `select`, so we can only simplify when this
215        // won't happen.
216        let would_intersect_validity =
217            struct_nullability.is_nullable() && !all_included_fields_are_nullable;
218
219        if child_is_pack && !would_intersect_validity {
220            let pack_expr = expr::pack(
221                included_fields
222                    .into_iter()
223                    .map(|name| (name.clone(), expr::get_item(name, child_struct.clone()))),
224                struct_nullability,
225            );
226
227            return Ok(Some(pack_expr));
228        }
229
230        Ok(None)
231    }
232
233    fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
234        true
235    }
236
237    fn is_fallible(&self, _instance: &FieldSelection) -> bool {
238        // If this type-checks its infallible.
239        false
240    }
241}
242
243/// Creates an expression that selects (includes) specific fields from an array.
244///
245/// Projects only the specified fields from the child expression, which must be of DType struct.
246/// ```rust
247/// # use vortex_array::expr::{select, root};
248/// let expr = select(["name", "age"], root());
249/// ```
250pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
251    Select
252        .try_new_expr(FieldSelection::Include(field_names.into()), [child])
253        .vortex_expect("Failed to create Select expression")
254}
255
256/// Creates an expression that excludes specific fields from an array.
257///
258/// Projects all fields except the specified ones from the input struct expression.
259///
260/// ```rust
261/// # use vortex_array::expr::{select_exclude, root};
262/// let expr = select_exclude(["internal_id", "metadata"], root());
263/// ```
264pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
265    Select
266        .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
267        .vortex_expect("Failed to create Select expression")
268}
269
270impl FieldSelection {
271    pub fn include(columns: FieldNames) -> Self {
272        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
273        Self::Include(columns)
274    }
275
276    pub fn exclude(columns: FieldNames) -> Self {
277        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
278        Self::Exclude(columns)
279    }
280
281    pub fn is_include(&self) -> bool {
282        matches!(self, Self::Include(_))
283    }
284
285    pub fn is_exclude(&self) -> bool {
286        matches!(self, Self::Exclude(_))
287    }
288
289    pub fn field_names(&self) -> &FieldNames {
290        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
291
292        fields
293    }
294
295    pub fn normalize_to_included_fields(
296        &self,
297        available_fields: &FieldNames,
298    ) -> VortexResult<FieldNames> {
299        // Check that all of the field names exist in the available fields.
300        if self
301            .field_names()
302            .iter()
303            .any(|f| !available_fields.iter().contains(f))
304        {
305            vortex_bail!(
306                "Select fields {:?} must be a subset of child fields {:?}",
307                self,
308                available_fields
309            );
310        }
311
312        match self {
313            FieldSelection::Include(fields) => Ok(fields.clone()),
314            FieldSelection::Exclude(exc_fields) => Ok(available_fields
315                .iter()
316                .filter(|f| !exc_fields.iter().contains(f))
317                .cloned()
318                .collect()),
319        }
320    }
321}
322
323impl Display for FieldSelection {
324    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
325        match self {
326            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
327            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
328        }
329    }
330}
331
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` shared by the execution tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    pub fn include_columns() {
        let st = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    pub fn exclude_columns() {
        let st = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every other field must yield the same dtype as including only "a".
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Should exclude "c" and include "a" and "b"
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}