vortex_array/expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::sync::Arc;
7
8use itertools::Itertools;
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_proto::expr::FieldNames as ProtoFieldNames;
17use vortex_proto::expr::SelectOpts;
18use vortex_proto::expr::select_opts::Opts;
19use vortex_vector::Datum;
20use vortex_vector::StructDatum;
21use vortex_vector::VectorOps;
22use vortex_vector::struct_::StructVector;
23
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::expr::Arity;
28use crate::expr::ChildName;
29use crate::expr::ExecutionArgs;
30use crate::expr::ExprId;
31use crate::expr::SimplifyCtx;
32use crate::expr::VTable;
33use crate::expr::VTableExt;
34use crate::expr::expression::Expression;
35use crate::expr::field::DisplayFieldNames;
36use crate::expr::get_item;
37use crate::expr::pack;
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub enum FieldSelection {
41    Include(FieldNames),
42    Exclude(FieldNames),
43}
44
/// Expression vtable for `vortex.select`: projects a subset of fields out of a struct-typed child.
pub struct Select;
46
47impl VTable for Select {
48    type Options = FieldSelection;
49
50    fn id(&self) -> ExprId {
51        ExprId::new_ref("vortex.select")
52    }
53
54    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
55        let opts = match instance {
56            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
57                names: fields.iter().map(|f| f.to_string()).collect(),
58            }),
59            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
60                names: fields.iter().map(|f| f.to_string()).collect(),
61            }),
62        };
63
64        let select_opts = SelectOpts { opts: Some(opts) };
65        Ok(Some(select_opts.encode_to_vec()))
66    }
67
68    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
69        let prost_metadata = SelectOpts::decode(metadata)?;
70
71        let select_opts = prost_metadata
72            .opts
73            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
74
75        let field_selection = match select_opts {
76            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
77                field_names.names.iter().map(|s| s.as_str()),
78            )),
79            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
80                field_names.names.iter().map(|s| s.as_str()),
81            )),
82        };
83
84        Ok(field_selection)
85    }
86
87    fn arity(&self, _options: &Self::Options) -> Arity {
88        Arity::Exact(1)
89    }
90
91    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
92        match child_idx {
93            0 => ChildName::new_ref("child"),
94            _ => unreachable!(),
95        }
96    }
97
98    fn fmt_sql(
99        &self,
100        selection: &FieldSelection,
101        expr: &Expression,
102        f: &mut Formatter<'_>,
103    ) -> std::fmt::Result {
104        expr.child(0).fmt_sql(f)?;
105        match selection {
106            FieldSelection::Include(fields) => {
107                write!(f, "{{{}}}", DisplayFieldNames(fields))
108            }
109            FieldSelection::Exclude(fields) => {
110                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
111            }
112        }
113    }
114
115    fn return_dtype(
116        &self,
117        selection: &FieldSelection,
118        arg_dtypes: &[DType],
119    ) -> VortexResult<DType> {
120        let child_dtype = &arg_dtypes[0];
121        let child_struct_dtype = child_dtype
122            .as_struct_fields_opt()
123            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
124
125        let projected = match selection {
126            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
127            FieldSelection::Exclude(fields) => child_struct_dtype
128                .names()
129                .iter()
130                .cloned()
131                .zip_eq(child_struct_dtype.fields())
132                .filter(|(name, _)| !fields.as_ref().contains(name))
133                .collect(),
134        };
135
136        Ok(DType::Struct(projected, child_dtype.nullability()))
137    }
138
139    fn evaluate(
140        &self,
141        selection: &FieldSelection,
142        expr: &Expression,
143        scope: &ArrayRef,
144    ) -> VortexResult<ArrayRef> {
145        let batch = expr.child(0).evaluate(scope)?.to_struct();
146        Ok(match selection {
147            FieldSelection::Include(f) => batch.project(f.as_ref()),
148            FieldSelection::Exclude(names) => {
149                let included_names = batch
150                    .names()
151                    .iter()
152                    .filter(|&f| !names.as_ref().contains(f))
153                    .cloned()
154                    .collect::<Vec<_>>();
155                batch.project(included_names.as_slice())
156            }
157        }?
158        .into_array())
159    }
160
161    fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult<Datum> {
162        let child_fields = args
163            .dtypes
164            .pop()
165            .vortex_expect("Missing input dtype")
166            .into_struct_fields();
167
168        let field_indices: Vec<usize> = match selection {
169            FieldSelection::Include(f) => f
170                .iter()
171                .map(|name| {
172                    child_fields
173                        .find(name)
174                        .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
175                })
176                .try_collect(),
177            FieldSelection::Exclude(names) => child_fields
178                .names()
179                .iter()
180                .filter(|&f| !names.as_ref().contains(f))
181                .map(|name| {
182                    child_fields
183                        .find(name)
184                        .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
185                })
186                .try_collect(),
187        }?;
188
189        let child = args
190            .datums
191            .pop()
192            .vortex_expect("Missing input child")
193            .into_struct();
194
195        Ok(match child {
196            StructDatum::Scalar(s) => StructDatum::Scalar(
197                select_from_struct_vector(s.value(), &field_indices)?.scalar_at(0),
198            ),
199            StructDatum::Vector(v) => {
200                StructDatum::Vector(select_from_struct_vector(&v, &field_indices)?)
201            }
202        }
203        .into())
204    }
205
206    fn simplify(
207        &self,
208        options: &Self::Options,
209        expr: &Expression,
210        ctx: &dyn SimplifyCtx,
211    ) -> VortexResult<Option<Expression>> {
212        let child = expr.child(0);
213        let child_dtype = ctx.return_dtype(child)?;
214        let child_nullability = child_dtype.nullability();
215
216        let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
217            vortex_err!(
218                "Select child must return a struct dtype, however it was a {}",
219                child_dtype
220            )
221        })?;
222
223        let expr = pack(
224            options
225                .as_include_names(child_dtype.names())
226                .map_err(|e| {
227                    e.with_context(format!(
228                        "Select fields {:?} must be a subset of child fields {:?}",
229                        options,
230                        child_dtype.names()
231                    ))
232                })?
233                .iter()
234                .map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
235            child_nullability,
236        );
237
238        Ok(Some(expr))
239    }
240
241    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
242        true
243    }
244
245    fn is_fallible(&self, _instance: &Self::Options) -> bool {
246        // If this type-checks its infallible.
247        false
248    }
249}
250
251fn select_from_struct_vector(
252    vec: &StructVector,
253    field_indices: &[usize],
254) -> VortexResult<StructVector> {
255    let new_fields = field_indices
256        .iter()
257        .map(|&idx| vec.fields()[idx].clone())
258        .collect();
259    Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), vec.validity().clone()) })
260}
261
262/// Creates an expression that selects (includes) specific fields from an array.
263///
264/// Projects only the specified fields from the child expression, which must be of DType struct.
265/// ```rust
266/// # use vortex_array::expr::{select, root};
267/// let expr = select(["name", "age"], root());
268/// ```
269pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
270    Select
271        .try_new_expr(FieldSelection::Include(field_names.into()), [child])
272        .vortex_expect("Failed to create Select expression")
273}
274
275/// Creates an expression that excludes specific fields from an array.
276///
277/// Projects all fields except the specified ones from the input struct expression.
278///
279/// ```rust
280/// # use vortex_array::expr::{select_exclude, root};
281/// let expr = select_exclude(["internal_id", "metadata"], root());
282/// ```
283pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
284    Select
285        .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
286        .vortex_expect("Failed to create Select expression")
287}
288
289impl FieldSelection {
290    pub fn include(columns: FieldNames) -> Self {
291        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
292        Self::Include(columns)
293    }
294
295    pub fn exclude(columns: FieldNames) -> Self {
296        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
297        Self::Exclude(columns)
298    }
299
300    pub fn is_include(&self) -> bool {
301        matches!(self, Self::Include(_))
302    }
303
304    pub fn is_exclude(&self) -> bool {
305        matches!(self, Self::Exclude(_))
306    }
307
308    pub fn field_names(&self) -> &FieldNames {
309        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
310
311        fields
312    }
313
314    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
315        if self
316            .field_names()
317            .iter()
318            .any(|f| !field_names.iter().contains(f))
319        {
320            vortex_bail!(
321                "Field {:?} in select not in field names {:?}",
322                self,
323                field_names
324            );
325        }
326        match self {
327            FieldSelection::Include(fields) => Ok(fields.clone()),
328            FieldSelection::Exclude(exc_fields) => Ok(field_names
329                .iter()
330                .filter(|f| !exc_fields.iter().contains(f))
331                .cloned()
332                .collect()),
333        }
334    }
335}
336
337impl Display for FieldSelection {
338    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
339        match self {
340            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
341            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::pack::Pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Struct array with two fields, `a` and `b`, of three rows each.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["a"]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["b"]);
    }

    /// Include and exclude selections yield the expected projected struct dtypes.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Non-nullable struct dtype holding just the given fields of the harness dtype.
        let projected = |names: &[FieldName]| {
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(names)
                    .unwrap(),
                Nullability::NonNullable,
            )
        };
        let only_a = projected(&["a".into()]);

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            exclude_cols.return_dtype(&dtype).unwrap(),
            projected(&["a".into(), "bool1".into(), "bool2".into()])
        );
    }

    /// Including `a` and excluding `b`, `c` resolve to the same names over `a`, `b`, `c`.
    #[test]
    fn test_as_include_names() {
        let available = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());

        let from_include = include
            .as_::<Select>()
            .as_include_names(&available)
            .unwrap();
        let from_exclude = exclude
            .as_::<Select>()
            .as_include_names(&available)
            .unwrap();

        assert_eq!(&from_include, &from_exclude);
    }

    /// Optimizing an include select produces a `Pack` that keeps the struct nullable.
    #[test]
    fn test_remove_select_rule() {
        let names = ["a", "b"].into();
        let dtype = DType::Struct(
            StructFields::new(names, vec![I32.into(), I32.into()]),
            Nullable,
        );
        let expr = select(["a", "b"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();

        assert!(optimized.is::<Pack>());
        assert!(optimized.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclude select produces a `Pack` over the remaining fields.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let names = ["a", "b", "c"].into();
        let dtype = DType::Struct(
            StructFields::new(names, vec![I32.into(), I32.into(), I32.into()]),
            Nullable,
        );
        let expr = select_exclude(["c"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();

        assert!(optimized.is::<Pack>());

        // Should exclude "c" and include "a" and "b"
        let optimized_dtype = optimized.return_dtype(&dtype).unwrap();
        assert!(optimized_dtype.is_nullable());
        let fields = optimized_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}
493}