Skip to main content

vortex_array/scalar_fn/fns/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::arrays::struct_::StructArrayExt;
23use crate::dtype::DType;
24use crate::dtype::FieldName;
25use crate::dtype::FieldNames;
26use crate::expr::expression::Expression;
27use crate::expr::field::DisplayFieldNames;
28use crate::expr::get_item;
29use crate::expr::pack;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::SimplifyCtx;
36use crate::scalar_fn::fns::pack::Pack;
37
/// Describes which fields of a struct a [`Select`] keeps.
///
/// The same list of names is interpreted either as an allow-list or as a deny-list depending on
/// the variant.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldSelection {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field of the child struct except the listed ones.
    Exclude(FieldNames),
}
43
/// Scalar function (`vortex.select`) that projects a single struct-typed child down to the
/// fields described by a [`FieldSelection`].
#[derive(Clone)]
pub struct Select;
46
47impl ScalarFnVTable for Select {
48    type Options = FieldSelection;
49
50    fn id(&self) -> ScalarFnId {
51        ScalarFnId::from("vortex.select")
52    }
53
54    fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
55        let opts = match instance {
56            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
57                names: fields.iter().map(|f| f.to_string()).collect(),
58            }),
59            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
60                names: fields.iter().map(|f| f.to_string()).collect(),
61            }),
62        };
63
64        let select_opts = SelectOpts { opts: Some(opts) };
65        Ok(Some(select_opts.encode_to_vec()))
66    }
67
68    fn deserialize(
69        &self,
70        _metadata: &[u8],
71        _session: &VortexSession,
72    ) -> VortexResult<FieldSelection> {
73        let prost_metadata = SelectOpts::decode(_metadata)?;
74
75        let select_opts = prost_metadata
76            .opts
77            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
78
79        let field_selection = match select_opts {
80            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
81                field_names.names.iter().map(|s| s.as_str()),
82            )),
83            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
84                field_names.names.iter().map(|s| s.as_str()),
85            )),
86        };
87
88        Ok(field_selection)
89    }
90
91    fn arity(&self, _options: &FieldSelection) -> Arity {
92        Arity::Exact(1)
93    }
94
95    fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
96        match child_idx {
97            0 => ChildName::from("child"),
98            _ => unreachable!(),
99        }
100    }
101
102    fn fmt_sql(
103        &self,
104        selection: &FieldSelection,
105        expr: &Expression,
106        f: &mut Formatter<'_>,
107    ) -> std::fmt::Result {
108        expr.child(0).fmt_sql(f)?;
109        match selection {
110            FieldSelection::Include(fields) => {
111                write!(f, "{{{}}}", DisplayFieldNames(fields))
112            }
113            FieldSelection::Exclude(fields) => {
114                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
115            }
116        }
117    }
118
119    fn return_dtype(
120        &self,
121        selection: &FieldSelection,
122        arg_dtypes: &[DType],
123    ) -> VortexResult<DType> {
124        let child_dtype = &arg_dtypes[0];
125        let child_struct_dtype = child_dtype
126            .as_struct_fields_opt()
127            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
128
129        let projected = match selection {
130            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
131            FieldSelection::Exclude(fields) => child_struct_dtype
132                .names()
133                .iter()
134                .cloned()
135                .zip_eq(child_struct_dtype.fields())
136                .filter(|(name, _)| !fields.as_ref().contains(name))
137                .collect(),
138        };
139
140        Ok(DType::Struct(projected, child_dtype.nullability()))
141    }
142
143    fn execute(
144        &self,
145        selection: &FieldSelection,
146        args: &dyn ExecutionArgs,
147        ctx: &mut ExecutionCtx,
148    ) -> VortexResult<ArrayRef> {
149        let child = args.get(0)?.execute::<StructArray>(ctx)?;
150
151        let result = match selection {
152            FieldSelection::Include(f) => child.project(f.as_ref()),
153            FieldSelection::Exclude(names) => {
154                let included_names = child
155                    .names()
156                    .iter()
157                    .filter(|&f| !names.as_ref().contains(f))
158                    .cloned()
159                    .collect::<Vec<_>>();
160                child.project(included_names.as_slice())
161            }
162        }?;
163
164        result.into_array().execute(ctx)
165    }
166
167    fn simplify(
168        &self,
169        selection: &FieldSelection,
170        expr: &Expression,
171        ctx: &dyn SimplifyCtx,
172    ) -> VortexResult<Option<Expression>> {
173        let child_struct = expr.child(0);
174        let struct_dtype = ctx.return_dtype(child_struct)?;
175        let struct_nullability = struct_dtype.nullability();
176
177        let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
178            vortex_err!(
179                "Select child must return a struct dtype, however it was a {}",
180                struct_dtype
181            )
182        })?;
183
184        // "Mask" out the unwanted fields of the child struct `DType`.
185        let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
186        let all_included_fields_are_nullable = included_fields.iter().all(|name| {
187            struct_fields
188                .field(name)
189                .vortex_expect(
190                    "`normalize_to_included_fields` checks that the included fields already exist \
191                     in `struct_fields`",
192                )
193                .is_nullable()
194        });
195
196        // If no fields are included, we can trivially simplify to a pack expression.
197        // NOTE(ngates): we do this knowing that our layout expression partitioning logic has
198        //  special-casing for pack, but not for select. We will fix this up when we revisit the
199        //  layout APIs.
200        if included_fields.is_empty() {
201            let empty: Vec<(FieldName, Expression)> = vec![];
202            return Ok(Some(pack(empty, struct_nullability)));
203        }
204
205        // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`.
206        // This is because `get_item` does a validity intersection of the struct validity with its
207        // fields, which is not the same as just "masking" out the unwanted fields (a selection).
208        //
209        // We can, however, make this simplification when the child of the `select` is already a
210        // `pack` and we know that `get_item` will do no validity intersections.
211        let child_is_pack = child_struct.is::<Pack>();
212
213        // `get_item` only performs validity intersection when the struct is nullable but the field
214        // is not. This would change the semantics of a `select`, so we can only simplify when this
215        // won't happen.
216        let would_intersect_validity =
217            struct_nullability.is_nullable() && !all_included_fields_are_nullable;
218
219        if child_is_pack && !would_intersect_validity {
220            let pack_expr = pack(
221                included_fields
222                    .into_iter()
223                    .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
224                struct_nullability,
225            );
226
227            return Ok(Some(pack_expr));
228        }
229
230        Ok(None)
231    }
232
233    fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
234        true
235    }
236
237    fn is_fallible(&self, _instance: &FieldSelection) -> bool {
238        // If this type-checks its infallible.
239        false
240    }
241}
242
243impl FieldSelection {
244    pub fn include(columns: FieldNames) -> Self {
245        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
246        Self::Include(columns)
247    }
248
249    pub fn exclude(columns: FieldNames) -> Self {
250        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
251        Self::Exclude(columns)
252    }
253
254    pub fn is_include(&self) -> bool {
255        matches!(self, Self::Include(_))
256    }
257
258    pub fn is_exclude(&self) -> bool {
259        matches!(self, Self::Exclude(_))
260    }
261
262    pub fn field_names(&self) -> &FieldNames {
263        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
264
265        fields
266    }
267
268    pub fn normalize_to_included_fields(
269        &self,
270        available_fields: &FieldNames,
271    ) -> VortexResult<FieldNames> {
272        // Check that all of the field names exist in the available fields.
273        if self
274            .field_names()
275            .iter()
276            .any(|f| !available_fields.iter().contains(f))
277        {
278            vortex_bail!(
279                "Select fields {:?} must be a subset of child fields {:?}",
280                self,
281                available_fields
282            );
283        }
284
285        match self {
286            FieldSelection::Include(fields) => Ok(fields.clone()),
287            FieldSelection::Exclude(exc_fields) => Ok(available_fields
288                .iter()
289                .filter(|f| !exc_fields.iter().contains(f))
290                .cloned()
291                .collect()),
292        }
293    }
294}
295
296impl Display for FieldSelection {
297    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298        match self {
299            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
300            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
301        }
302    }
303}
304
// Unit tests for `Select`: execution on a struct array, dtype inference, selection
// normalization, and optimization of select expressions.
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::struct_::StructArrayExt;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;
    use crate::scalar_fn::fns::select::StructArray;

    // Builds a two-field struct array with fields `a` and `b`, three rows each.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    // Including `a` leaves only `a` in the result.
    #[test]
    pub fn include_columns() {
        let st = test_array();
        let select = select(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    // Excluding `a` leaves only `b` in the result.
    #[test]
    pub fn exclude_columns() {
        let st = test_array();
        let select = select_exclude(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    // Include and exclude selections that describe the same field set must infer the same dtype.
    // NOTE(review): the assertions imply the harness dtype has fields `a`, `col1`, `col2`,
    // `bool1`, `bool2` and is non-nullable; confirm against `test_harness::struct_dtype`.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every field but `a` is equivalent to including just `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        // Excluding a subset keeps the remaining fields in their original order.
        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    // `include(a)` and `exclude(b, c)` over fields `a, b, c` normalize to the same field list.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    // Optimizing a select over a nullable struct must still produce a nullable dtype.
    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    // Optimizing an exclude-select keeps nullability and the remaining fields `a` and `b`.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        use crate::expr::select_exclude;

        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Should exclude "c" and include "a" and "b"
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}
452}