Skip to main content

vortex_array/scalar_fn/fns/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::StructArray;
21use crate::dtype::DType;
22use crate::dtype::FieldName;
23use crate::dtype::FieldNames;
24use crate::expr::expression::Expression;
25use crate::expr::field::DisplayFieldNames;
26use crate::expr::get_item;
27use crate::expr::pack;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::pack::Pack;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38    Include(FieldNames),
39    Exclude(FieldNames),
40}
41
/// Scalar function (`vortex.select`) that projects a struct-typed child down
/// to the subset of fields described by a [`FieldSelection`].
#[derive(Clone)]
pub struct Select;
44
45impl ScalarFnVTable for Select {
46    type Options = FieldSelection;
47
48    fn id(&self) -> ScalarFnId {
49        ScalarFnId::new_ref("vortex.select")
50    }
51
52    fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
53        let opts = match instance {
54            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
55                names: fields.iter().map(|f| f.to_string()).collect(),
56            }),
57            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
58                names: fields.iter().map(|f| f.to_string()).collect(),
59            }),
60        };
61
62        let select_opts = SelectOpts { opts: Some(opts) };
63        Ok(Some(select_opts.encode_to_vec()))
64    }
65
66    fn deserialize(
67        &self,
68        _metadata: &[u8],
69        _session: &VortexSession,
70    ) -> VortexResult<FieldSelection> {
71        let prost_metadata = SelectOpts::decode(_metadata)?;
72
73        let select_opts = prost_metadata
74            .opts
75            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
76
77        let field_selection = match select_opts {
78            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
79                field_names.names.iter().map(|s| s.as_str()),
80            )),
81            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
82                field_names.names.iter().map(|s| s.as_str()),
83            )),
84        };
85
86        Ok(field_selection)
87    }
88
89    fn arity(&self, _options: &FieldSelection) -> Arity {
90        Arity::Exact(1)
91    }
92
93    fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
94        match child_idx {
95            0 => ChildName::new_ref("child"),
96            _ => unreachable!(),
97        }
98    }
99
100    fn fmt_sql(
101        &self,
102        selection: &FieldSelection,
103        expr: &Expression,
104        f: &mut Formatter<'_>,
105    ) -> std::fmt::Result {
106        expr.child(0).fmt_sql(f)?;
107        match selection {
108            FieldSelection::Include(fields) => {
109                write!(f, "{{{}}}", DisplayFieldNames(fields))
110            }
111            FieldSelection::Exclude(fields) => {
112                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
113            }
114        }
115    }
116
117    fn return_dtype(
118        &self,
119        selection: &FieldSelection,
120        arg_dtypes: &[DType],
121    ) -> VortexResult<DType> {
122        let child_dtype = &arg_dtypes[0];
123        let child_struct_dtype = child_dtype
124            .as_struct_fields_opt()
125            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
126
127        let projected = match selection {
128            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
129            FieldSelection::Exclude(fields) => child_struct_dtype
130                .names()
131                .iter()
132                .cloned()
133                .zip_eq(child_struct_dtype.fields())
134                .filter(|(name, _)| !fields.as_ref().contains(name))
135                .collect(),
136        };
137
138        Ok(DType::Struct(projected, child_dtype.nullability()))
139    }
140
141    fn execute(
142        &self,
143        selection: &FieldSelection,
144        mut args: ExecutionArgs,
145    ) -> VortexResult<ArrayRef> {
146        let child = args
147            .inputs
148            .pop()
149            .vortex_expect("Missing input child")
150            .execute::<StructArray>(args.ctx)?;
151
152        let result = match selection {
153            FieldSelection::Include(f) => child.project(f.as_ref()),
154            FieldSelection::Exclude(names) => {
155                let included_names = child
156                    .names()
157                    .iter()
158                    .filter(|&f| !names.as_ref().contains(f))
159                    .cloned()
160                    .collect::<Vec<_>>();
161                child.project(included_names.as_slice())
162            }
163        }?;
164
165        result.into_array().execute(args.ctx)
166    }
167
168    fn simplify(
169        &self,
170        selection: &FieldSelection,
171        expr: &Expression,
172        ctx: &dyn SimplifyCtx,
173    ) -> VortexResult<Option<Expression>> {
174        let child_struct = expr.child(0);
175        let struct_dtype = ctx.return_dtype(child_struct)?;
176        let struct_nullability = struct_dtype.nullability();
177
178        let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
179            vortex_err!(
180                "Select child must return a struct dtype, however it was a {}",
181                struct_dtype
182            )
183        })?;
184
185        // "Mask" out the unwanted fields of the child struct `DType`.
186        let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
187        let all_included_fields_are_nullable = included_fields.iter().all(|name| {
188            struct_fields
189                .field(name)
190                .vortex_expect(
191                    "`normalize_to_included_fields` checks that the included fields already exist \
192                     in `struct_fields`",
193                )
194                .is_nullable()
195        });
196
197        // If no fields are included, we can trivially simplify to a pack expression.
198        // NOTE(ngates): we do this knowing that our layout expression partitioning logic has
199        //  special-casing for pack, but not for select. We will fix this up when we revisit the
200        //  layout APIs.
201        if included_fields.is_empty() {
202            let empty: Vec<(FieldName, Expression)> = vec![];
203            return Ok(Some(pack(empty, struct_nullability)));
204        }
205
206        // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`.
207        // This is because `get_item` does a validity intersection of the struct validity with its
208        // fields, which is not the same as just "masking" out the unwanted fields (a selection).
209        //
210        // We can, however, make this simplification when the child of the `select` is already a
211        // `pack` and we know that `get_item` will do no validity intersections.
212        let child_is_pack = child_struct.is::<Pack>();
213
214        // `get_item` only performs validity intersection when the struct is nullable but the field
215        // is not. This would change the semantics of a `select`, so we can only simplify when this
216        // won't happen.
217        let would_intersect_validity =
218            struct_nullability.is_nullable() && !all_included_fields_are_nullable;
219
220        if child_is_pack && !would_intersect_validity {
221            let pack_expr = pack(
222                included_fields
223                    .into_iter()
224                    .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
225                struct_nullability,
226            );
227
228            return Ok(Some(pack_expr));
229        }
230
231        Ok(None)
232    }
233
234    fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
235        true
236    }
237
238    fn is_fallible(&self, _instance: &FieldSelection) -> bool {
239        // If this type-checks its infallible.
240        false
241    }
242}
243
244impl FieldSelection {
245    pub fn include(columns: FieldNames) -> Self {
246        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
247        Self::Include(columns)
248    }
249
250    pub fn exclude(columns: FieldNames) -> Self {
251        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
252        Self::Exclude(columns)
253    }
254
255    pub fn is_include(&self) -> bool {
256        matches!(self, Self::Include(_))
257    }
258
259    pub fn is_exclude(&self) -> bool {
260        matches!(self, Self::Exclude(_))
261    }
262
263    pub fn field_names(&self) -> &FieldNames {
264        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
265
266        fields
267    }
268
269    pub fn normalize_to_included_fields(
270        &self,
271        available_fields: &FieldNames,
272    ) -> VortexResult<FieldNames> {
273        // Check that all of the field names exist in the available fields.
274        if self
275            .field_names()
276            .iter()
277            .any(|f| !available_fields.iter().contains(f))
278        {
279            vortex_bail!(
280                "Select fields {:?} must be a subset of child fields {:?}",
281                self,
282                available_fields
283            );
284        }
285
286        match self {
287            FieldSelection::Include(fields) => Ok(fields.clone()),
288            FieldSelection::Exclude(exc_fields) => Ok(available_fields
289                .iter()
290                .filter(|f| !exc_fields.iter().contains(f))
291                .cloned()
292                .collect()),
293        }
294    }
295}
296
297impl Display for FieldSelection {
298    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
299        match self {
300            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
301            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;

    /// Two-column struct array (`a`, `b`) shared by the execution tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn include_columns() {
        let st = test_array();
        let select = select(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    fn exclude_columns() {
        let st = test_array();
        let select = select_exclude(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Including only `a` and excluding everything but `a` must agree.
        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        // A partial exclusion keeps the remaining fields in their original order.
        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_as_include_names() {
        // Including `a` and excluding `b`, `c` must normalize to the same list.
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Optimizing must preserve both the nullability and the selected fields.
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }

    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Should exclude "c" and include "a" and "b"
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}