Skip to main content

vortex_array/scalar_fn/fns/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::dtype::DType;
23use crate::dtype::FieldName;
24use crate::dtype::FieldNames;
25use crate::expr::expression::Expression;
26use crate::expr::field::DisplayFieldNames;
27use crate::expr::get_item;
28use crate::expr::pack;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::SimplifyCtx;
35use crate::scalar_fn::fns::pack::Pack;
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum FieldSelection {
39    Include(FieldNames),
40    Exclude(FieldNames),
41}
42
/// Scalar function (`vortex.select`) that projects a struct-typed child down to a subset of its
/// fields.
///
/// Unlike `pack(get_item(..), ..)`, a selection only masks out the unwanted fields; it does not
/// intersect the struct's validity into the selected fields.
#[derive(Clone)]
pub struct Select;
45
46impl ScalarFnVTable for Select {
47    type Options = FieldSelection;
48
49    fn id(&self) -> ScalarFnId {
50        ScalarFnId::new_ref("vortex.select")
51    }
52
53    fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
54        let opts = match instance {
55            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
56                names: fields.iter().map(|f| f.to_string()).collect(),
57            }),
58            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
59                names: fields.iter().map(|f| f.to_string()).collect(),
60            }),
61        };
62
63        let select_opts = SelectOpts { opts: Some(opts) };
64        Ok(Some(select_opts.encode_to_vec()))
65    }
66
67    fn deserialize(
68        &self,
69        _metadata: &[u8],
70        _session: &VortexSession,
71    ) -> VortexResult<FieldSelection> {
72        let prost_metadata = SelectOpts::decode(_metadata)?;
73
74        let select_opts = prost_metadata
75            .opts
76            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
77
78        let field_selection = match select_opts {
79            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
80                field_names.names.iter().map(|s| s.as_str()),
81            )),
82            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
83                field_names.names.iter().map(|s| s.as_str()),
84            )),
85        };
86
87        Ok(field_selection)
88    }
89
90    fn arity(&self, _options: &FieldSelection) -> Arity {
91        Arity::Exact(1)
92    }
93
94    fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
95        match child_idx {
96            0 => ChildName::new_ref("child"),
97            _ => unreachable!(),
98        }
99    }
100
101    fn fmt_sql(
102        &self,
103        selection: &FieldSelection,
104        expr: &Expression,
105        f: &mut Formatter<'_>,
106    ) -> std::fmt::Result {
107        expr.child(0).fmt_sql(f)?;
108        match selection {
109            FieldSelection::Include(fields) => {
110                write!(f, "{{{}}}", DisplayFieldNames(fields))
111            }
112            FieldSelection::Exclude(fields) => {
113                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
114            }
115        }
116    }
117
118    fn return_dtype(
119        &self,
120        selection: &FieldSelection,
121        arg_dtypes: &[DType],
122    ) -> VortexResult<DType> {
123        let child_dtype = &arg_dtypes[0];
124        let child_struct_dtype = child_dtype
125            .as_struct_fields_opt()
126            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
127
128        let projected = match selection {
129            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
130            FieldSelection::Exclude(fields) => child_struct_dtype
131                .names()
132                .iter()
133                .cloned()
134                .zip_eq(child_struct_dtype.fields())
135                .filter(|(name, _)| !fields.as_ref().contains(name))
136                .collect(),
137        };
138
139        Ok(DType::Struct(projected, child_dtype.nullability()))
140    }
141
142    fn execute(
143        &self,
144        selection: &FieldSelection,
145        args: &dyn ExecutionArgs,
146        ctx: &mut ExecutionCtx,
147    ) -> VortexResult<ArrayRef> {
148        let child = args.get(0)?.execute::<StructArray>(ctx)?;
149
150        let result = match selection {
151            FieldSelection::Include(f) => child.project(f.as_ref()),
152            FieldSelection::Exclude(names) => {
153                let included_names = child
154                    .names()
155                    .iter()
156                    .filter(|&f| !names.as_ref().contains(f))
157                    .cloned()
158                    .collect::<Vec<_>>();
159                child.project(included_names.as_slice())
160            }
161        }?;
162
163        result.into_array().execute(ctx)
164    }
165
166    fn simplify(
167        &self,
168        selection: &FieldSelection,
169        expr: &Expression,
170        ctx: &dyn SimplifyCtx,
171    ) -> VortexResult<Option<Expression>> {
172        let child_struct = expr.child(0);
173        let struct_dtype = ctx.return_dtype(child_struct)?;
174        let struct_nullability = struct_dtype.nullability();
175
176        let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
177            vortex_err!(
178                "Select child must return a struct dtype, however it was a {}",
179                struct_dtype
180            )
181        })?;
182
183        // "Mask" out the unwanted fields of the child struct `DType`.
184        let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
185        let all_included_fields_are_nullable = included_fields.iter().all(|name| {
186            struct_fields
187                .field(name)
188                .vortex_expect(
189                    "`normalize_to_included_fields` checks that the included fields already exist \
190                     in `struct_fields`",
191                )
192                .is_nullable()
193        });
194
195        // If no fields are included, we can trivially simplify to a pack expression.
196        // NOTE(ngates): we do this knowing that our layout expression partitioning logic has
197        //  special-casing for pack, but not for select. We will fix this up when we revisit the
198        //  layout APIs.
199        if included_fields.is_empty() {
200            let empty: Vec<(FieldName, Expression)> = vec![];
201            return Ok(Some(pack(empty, struct_nullability)));
202        }
203
204        // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`.
205        // This is because `get_item` does a validity intersection of the struct validity with its
206        // fields, which is not the same as just "masking" out the unwanted fields (a selection).
207        //
208        // We can, however, make this simplification when the child of the `select` is already a
209        // `pack` and we know that `get_item` will do no validity intersections.
210        let child_is_pack = child_struct.is::<Pack>();
211
212        // `get_item` only performs validity intersection when the struct is nullable but the field
213        // is not. This would change the semantics of a `select`, so we can only simplify when this
214        // won't happen.
215        let would_intersect_validity =
216            struct_nullability.is_nullable() && !all_included_fields_are_nullable;
217
218        if child_is_pack && !would_intersect_validity {
219            let pack_expr = pack(
220                included_fields
221                    .into_iter()
222                    .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
223                struct_nullability,
224            );
225
226            return Ok(Some(pack_expr));
227        }
228
229        Ok(None)
230    }
231
232    fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
233        true
234    }
235
236    fn is_fallible(&self, _instance: &FieldSelection) -> bool {
237        // If this type-checks its infallible.
238        false
239    }
240}
241
242impl FieldSelection {
243    pub fn include(columns: FieldNames) -> Self {
244        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
245        Self::Include(columns)
246    }
247
248    pub fn exclude(columns: FieldNames) -> Self {
249        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
250        Self::Exclude(columns)
251    }
252
253    pub fn is_include(&self) -> bool {
254        matches!(self, Self::Include(_))
255    }
256
257    pub fn is_exclude(&self) -> bool {
258        matches!(self, Self::Exclude(_))
259    }
260
261    pub fn field_names(&self) -> &FieldNames {
262        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
263
264        fields
265    }
266
267    pub fn normalize_to_included_fields(
268        &self,
269        available_fields: &FieldNames,
270    ) -> VortexResult<FieldNames> {
271        // Check that all of the field names exist in the available fields.
272        if self
273            .field_names()
274            .iter()
275            .any(|f| !available_fields.iter().contains(f))
276        {
277            vortex_bail!(
278                "Select fields {:?} must be a subset of child fields {:?}",
279                self,
280                available_fields
281            );
282        }
283
284        match self {
285            FieldSelection::Include(fields) => Ok(fields.clone()),
286            FieldSelection::Exclude(exc_fields) => Ok(available_fields
287                .iter()
288                .filter(|f| !exc_fields.iter().contains(f))
289                .cloned()
290                .collect()),
291        }
292    }
293}
294
295impl Display for FieldSelection {
296    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
297        match self {
298            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
299            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::FieldSelection;
    use crate::scalar_fn::fns::select::Select;

    /// Two-column struct array `{a: [0, 1, 2], b: [4, 5, 6]}` used by the execution tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn include_columns() {
        let st = test_array();
        let select = select(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    fn exclude_columns() {
        let st = test_array();
        let select = select_exclude(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    #[test]
    fn normalize_rejects_unknown_fields() {
        let field_names = FieldNames::from(["a", "b"]);

        let include = select(["z"], root());
        assert!(
            include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .is_err()
        );

        let exclude = select_exclude(["z"], root());
        assert!(
            exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .is_err()
        );
    }

    #[test]
    #[should_panic]
    fn include_rejects_duplicate_names() {
        let _ = FieldSelection::include(FieldNames::from(["a", "a"]));
    }

    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }

    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Should exclude "c" and include "a" and "b"
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}