Skip to main content

vortex_array/scalar_fn/fns/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17use vortex_session::registry::CachedId;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::StructArray;
23use crate::arrays::struct_::StructArrayExt;
24use crate::dtype::DType;
25use crate::dtype::FieldName;
26use crate::dtype::FieldNames;
27use crate::expr::expression::Expression;
28use crate::expr::field::DisplayFieldNames;
29use crate::expr::get_item;
30use crate::expr::pack;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::SimplifyCtx;
37use crate::scalar_fn::fns::pack::Pack;
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub enum FieldSelection {
41    Include(FieldNames),
42    Exclude(FieldNames),
43}
44
/// Marker type for the `vortex.select` scalar function.
///
/// It carries no state of its own; each instance of the function is configured by a
/// [`FieldSelection`].
#[derive(Clone)]
pub struct Select;
47
48impl ScalarFnVTable for Select {
49    type Options = FieldSelection;
50
51    fn id(&self) -> ScalarFnId {
52        static ID: CachedId = CachedId::new("vortex.select");
53        *ID
54    }
55
56    fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
57        let opts = match instance {
58            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
59                names: fields.iter().map(|f| f.to_string()).collect(),
60            }),
61            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
62                names: fields.iter().map(|f| f.to_string()).collect(),
63            }),
64        };
65
66        let select_opts = SelectOpts { opts: Some(opts) };
67        Ok(Some(select_opts.encode_to_vec()))
68    }
69
70    fn deserialize(
71        &self,
72        _metadata: &[u8],
73        _session: &VortexSession,
74    ) -> VortexResult<FieldSelection> {
75        let prost_metadata = SelectOpts::decode(_metadata)?;
76
77        let select_opts = prost_metadata
78            .opts
79            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
80
81        let field_selection = match select_opts {
82            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
83                field_names.names.iter().map(|s| s.as_str()),
84            )),
85            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
86                field_names.names.iter().map(|s| s.as_str()),
87            )),
88        };
89
90        Ok(field_selection)
91    }
92
93    fn arity(&self, _options: &FieldSelection) -> Arity {
94        Arity::Exact(1)
95    }
96
97    fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
98        match child_idx {
99            0 => ChildName::from("child"),
100            _ => unreachable!(),
101        }
102    }
103
104    fn fmt_sql(
105        &self,
106        selection: &FieldSelection,
107        expr: &Expression,
108        f: &mut Formatter<'_>,
109    ) -> std::fmt::Result {
110        expr.child(0).fmt_sql(f)?;
111        match selection {
112            FieldSelection::Include(fields) => {
113                write!(f, "{{{}}}", DisplayFieldNames(fields))
114            }
115            FieldSelection::Exclude(fields) => {
116                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
117            }
118        }
119    }
120
121    fn return_dtype(
122        &self,
123        selection: &FieldSelection,
124        arg_dtypes: &[DType],
125    ) -> VortexResult<DType> {
126        let child_dtype = &arg_dtypes[0];
127        let child_struct_dtype = child_dtype
128            .as_struct_fields_opt()
129            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
130
131        let projected = match selection {
132            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
133            FieldSelection::Exclude(fields) => child_struct_dtype
134                .names()
135                .iter()
136                .cloned()
137                .zip_eq(child_struct_dtype.fields())
138                .filter(|(name, _)| !fields.as_ref().contains(name))
139                .collect(),
140        };
141
142        Ok(DType::Struct(projected, child_dtype.nullability()))
143    }
144
145    fn execute(
146        &self,
147        selection: &FieldSelection,
148        args: &dyn ExecutionArgs,
149        ctx: &mut ExecutionCtx,
150    ) -> VortexResult<ArrayRef> {
151        let child = args.get(0)?.execute::<StructArray>(ctx)?;
152
153        let result = match selection {
154            FieldSelection::Include(f) => child.project(f.as_ref()),
155            FieldSelection::Exclude(names) => {
156                let included_names = child
157                    .names()
158                    .iter()
159                    .filter(|&f| !names.as_ref().contains(f))
160                    .cloned()
161                    .collect::<Vec<_>>();
162                child.project(included_names.as_slice())
163            }
164        }?;
165
166        result.into_array().execute(ctx)
167    }
168
169    fn simplify(
170        &self,
171        selection: &FieldSelection,
172        expr: &Expression,
173        ctx: &dyn SimplifyCtx,
174    ) -> VortexResult<Option<Expression>> {
175        let child_struct = expr.child(0);
176        let struct_dtype = ctx.return_dtype(child_struct)?;
177        let struct_nullability = struct_dtype.nullability();
178
179        let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
180            vortex_err!(
181                "Select child must return a struct dtype, however it was a {}",
182                struct_dtype
183            )
184        })?;
185
186        // "Mask" out the unwanted fields of the child struct `DType`.
187        let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
188        let all_included_fields_are_nullable = included_fields.iter().all(|name| {
189            struct_fields
190                .field(name)
191                .vortex_expect(
192                    "`normalize_to_included_fields` checks that the included fields already exist \
193                     in `struct_fields`",
194                )
195                .is_nullable()
196        });
197
198        // If no fields are included, we can trivially simplify to a pack expression.
199        // NOTE(ngates): we do this knowing that our layout expression partitioning logic has
200        //  special-casing for pack, but not for select. We will fix this up when we revisit the
201        //  layout APIs.
202        if included_fields.is_empty() {
203            let empty: Vec<(FieldName, Expression)> = vec![];
204            return Ok(Some(pack(empty, struct_nullability)));
205        }
206
207        // We cannot always convert a `select` into a `pack(get_item(f1), get_item(f2), ...)`.
208        // This is because `get_item` does a validity intersection of the struct validity with its
209        // fields, which is not the same as just "masking" out the unwanted fields (a selection).
210        //
211        // We can, however, make this simplification when the child of the `select` is already a
212        // `pack` and we know that `get_item` will do no validity intersections.
213        let child_is_pack = child_struct.is::<Pack>();
214
215        // `get_item` only performs validity intersection when the struct is nullable but the field
216        // is not. This would change the semantics of a `select`, so we can only simplify when this
217        // won't happen.
218        let would_intersect_validity =
219            struct_nullability.is_nullable() && !all_included_fields_are_nullable;
220
221        if child_is_pack && !would_intersect_validity {
222            let pack_expr = pack(
223                included_fields
224                    .into_iter()
225                    .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
226                struct_nullability,
227            );
228
229            return Ok(Some(pack_expr));
230        }
231
232        Ok(None)
233    }
234
235    fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
236        true
237    }
238
239    fn is_fallible(&self, _instance: &FieldSelection) -> bool {
240        // If this type-checks its infallible.
241        false
242    }
243}
244
245impl FieldSelection {
246    pub fn include(columns: FieldNames) -> Self {
247        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
248        Self::Include(columns)
249    }
250
251    pub fn exclude(columns: FieldNames) -> Self {
252        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
253        Self::Exclude(columns)
254    }
255
256    pub fn is_include(&self) -> bool {
257        matches!(self, Self::Include(_))
258    }
259
260    pub fn is_exclude(&self) -> bool {
261        matches!(self, Self::Exclude(_))
262    }
263
264    pub fn field_names(&self) -> &FieldNames {
265        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
266
267        fields
268    }
269
270    pub fn normalize_to_included_fields(
271        &self,
272        available_fields: &FieldNames,
273    ) -> VortexResult<FieldNames> {
274        // Check that all of the field names exist in the available fields.
275        if self
276            .field_names()
277            .iter()
278            .any(|f| !available_fields.iter().contains(f))
279        {
280            vortex_bail!(
281                "Select fields {:?} must be a subset of child fields {:?}",
282                self,
283                available_fields
284            );
285        }
286
287        match self {
288            FieldSelection::Include(fields) => Ok(fields.clone()),
289            FieldSelection::Exclude(exc_fields) => Ok(available_fields
290                .iter()
291                .filter(|f| !exc_fields.iter().contains(f))
292                .cloned()
293                .collect()),
294        }
295    }
296}
297
298impl Display for FieldSelection {
299    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
300        match self {
301            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
302            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
303        }
304    }
305}
306
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::StructArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;

    /// A two-column struct array with fields `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` keeps only `a`.
    #[test]
    fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = test_array().into_array().apply(&expr).unwrap().to_struct();
        assert_eq!(selected.names().as_ref(), &["a"]);
    }

    /// Excluding `a` keeps only `b`.
    #[test]
    fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = test_array().into_array().apply(&expr).unwrap().to_struct();
        assert_eq!(selected.names().as_ref(), &["b"]);
    }

    /// Inclusion and exclusion both produce the projected, non-nullable struct dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding everything but `a` gives the same dtype as including `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    /// An inclusion and the complementary exclusion normalize to the same field list.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    /// Optimizing a select over a nullable struct keeps the result nullable.
    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclusion keeps nullability and the remaining fields in order.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        // Should exclude "c" and include "a" and "b"
        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}