vortex_array/expr/exprs/
select.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use prost::Message;
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::expr::expression::Expression;
14use crate::expr::field::DisplayFieldNames;
15use crate::expr::{ChildName, ExprId, ExpressionView, VTable, VTableExt};
16use crate::{ArrayRef, IntoArray, ToCanonical};
17
/// Describes which fields of a struct-typed input a [`Select`] expression keeps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldSelection {
    /// Keep only the named fields.
    Include(FieldNames),
    /// Keep every field except the named ones.
    Exclude(FieldNames),
}
23
/// Expression vtable that projects fields out of a struct-typed child expression.
///
/// Its per-expression data is a [`FieldSelection`], and its expression id is `vortex.select`.
pub struct Select;
25
26impl VTable for Select {
27    type Instance = FieldSelection;
28
29    fn id(&self) -> ExprId {
30        ExprId::new_ref("vortex.select")
31    }
32
33    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
34        let opts = match instance {
35            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
36                names: fields.iter().map(|f| f.to_string()).collect(),
37            }),
38            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
39                names: fields.iter().map(|f| f.to_string()).collect(),
40            }),
41        };
42
43        let select_opts = SelectOpts { opts: Some(opts) };
44        Ok(Some(select_opts.encode_to_vec()))
45    }
46
47    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
48        let prost_metadata = SelectOpts::decode(metadata)?;
49
50        let select_opts = prost_metadata
51            .opts
52            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
53
54        let field_selection = match select_opts {
55            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
56                field_names.names.iter().map(|s| s.as_str()),
57            )),
58            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
59                field_names.names.iter().map(|s| s.as_str()),
60            )),
61        };
62
63        Ok(Some(field_selection))
64    }
65
66    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
67        if expr.children().len() != 1 {
68            vortex_bail!(
69                "Select expression requires exactly 1 child, got {}",
70                expr.children().len()
71            );
72        }
73        Ok(())
74    }
75
76    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
77        match child_idx {
78            0 => ChildName::new_ref("child"),
79            _ => unreachable!(),
80        }
81    }
82
83    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
84        expr.child().fmt_sql(f)?;
85        match expr.data() {
86            FieldSelection::Include(fields) => {
87                write!(f, "{{{}}}", DisplayFieldNames(fields))
88            }
89            FieldSelection::Exclude(fields) => {
90                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
91            }
92        }
93    }
94
95    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
96        let names = match instance {
97            FieldSelection::Include(names) => {
98                write!(f, "include=")?;
99                names
100            }
101            FieldSelection::Exclude(names) => {
102                write!(f, "exclude=")?;
103                names
104            }
105        };
106        write!(f, "{{{}}}", DisplayFieldNames(names))
107    }
108
109    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
110        let child_dtype = expr.child().return_dtype(scope)?;
111        let child_struct_dtype = child_dtype
112            .as_struct_fields_opt()
113            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
114
115        let projected = match expr.data() {
116            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
117            FieldSelection::Exclude(fields) => child_struct_dtype
118                .names()
119                .iter()
120                .cloned()
121                .zip_eq(child_struct_dtype.fields())
122                .filter(|(name, _)| !fields.as_ref().contains(name))
123                .collect(),
124        };
125
126        Ok(DType::Struct(projected, child_dtype.nullability()))
127    }
128
129    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
130        let batch = expr.child().evaluate(scope)?.to_struct();
131        Ok(match expr.data() {
132            FieldSelection::Include(f) => batch.project(f.as_ref()),
133            FieldSelection::Exclude(names) => {
134                let included_names = batch
135                    .names()
136                    .iter()
137                    .filter(|&f| !names.as_ref().contains(f))
138                    .cloned()
139                    .collect::<Vec<_>>();
140                batch.project(included_names.as_slice())
141            }
142        }?
143        .into_array())
144    }
145}
146
147/// Creates an expression that selects (includes) specific fields from an array.
148///
149/// Projects only the specified fields from the child expression, which must be of DType struct.
150/// ```rust
151/// # use vortex_array::expr::{select, root};
152/// let expr = select(["name", "age"], root());
153/// ```
154pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
155    Select
156        .try_new_expr(FieldSelection::Include(field_names.into()), [child])
157        .vortex_expect("Failed to create Select expression")
158}
159
160/// Creates an expression that excludes specific fields from an array.
161///
162/// Projects all fields except the specified ones from the input struct expression.
163///
164/// ```rust
165/// # use vortex_array::expr::{select_exclude, root};
166/// let expr = select_exclude(["internal_id", "metadata"], root());
167/// ```
168pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
169    Select
170        .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
171        .vortex_expect("Failed to create Select expression")
172}
173
174impl ExpressionView<'_, Select> {
175    pub fn child(&self) -> &Expression {
176        &self.children()[0]
177    }
178
179    /// Turn the select expression into an `include`, relative to a provided array of field names.
180    ///
181    /// For example:
182    /// ```rust
183    /// # use vortex_array::expr::{root, Select};
184    /// # use vortex_array::expr::{FieldSelection, select, select_exclude};
185    /// # use vortex_dtype::FieldNames;
186    /// let field_names = FieldNames::from(["a", "b", "c"]);
187    /// let include = select(["a"], root());
188    /// let exclude = select_exclude(["b", "c"], root());
189    /// assert_eq!(
190    ///     &include.as_::<Select>().as_include(&field_names).unwrap(),
191    ///     &exclude.as_::<Select>().as_include(&field_names).unwrap(),
192    /// );
193    /// ```
194    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<Expression> {
195        Select.try_new_expr(
196            FieldSelection::Include(self.data().as_include_names(field_names)?),
197            [self.child().clone()],
198        )
199    }
200}
201
202impl FieldSelection {
203    pub fn include(columns: FieldNames) -> Self {
204        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
205        Self::Include(columns)
206    }
207
208    pub fn exclude(columns: FieldNames) -> Self {
209        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
210        Self::Exclude(columns)
211    }
212
213    pub fn is_include(&self) -> bool {
214        matches!(self, Self::Include(_))
215    }
216
217    pub fn is_exclude(&self) -> bool {
218        matches!(self, Self::Exclude(_))
219    }
220
221    pub fn field_names(&self) -> &FieldNames {
222        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
223
224        fields
225    }
226
227    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
228        if self
229            .field_names()
230            .iter()
231            .any(|f| !field_names.iter().contains(f))
232        {
233            vortex_bail!(
234                "Field {:?} in select not in field names {:?}",
235                self,
236                field_names
237            );
238        }
239        match self {
240            FieldSelection::Include(fields) => Ok(fields.clone()),
241            FieldSelection::Exclude(exc_fields) => Ok(field_names
242                .iter()
243                .filter(|f| !exc_fields.iter().contains(f))
244                .cloned()
245                .collect()),
246        }
247    }
248}
249
250impl Display for FieldSelection {
251    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
252        match self {
253            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
254            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use super::{select, select_exclude};
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;
    use crate::{IntoArray, ToCanonical};

    /// Two-column struct array (`a`, `b`) shared by the evaluation tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn include_columns() {
        let st = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let selected = expr.evaluate(&st.to_array()).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    fn exclude_columns() {
        let st = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let selected = expr.evaluate(&st.to_array()).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every other field must yield the same dtype as including only `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .data()
                .as_include_names(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .data()
                .as_include_names(&field_names)
                .unwrap()
        );
    }

    #[test]
    fn as_include_names_rejects_unknown_field() {
        // A selection that mentions a name absent from the reference list must error,
        // for both include and exclude selections.
        let field_names = FieldNames::from(["a", "b", "c"]);
        for expr in [select(["z"], root()), select_exclude(["z"], root())] {
            assert!(
                expr.as_::<Select>()
                    .data()
                    .as_include_names(&field_names)
                    .is_err()
            );
        }
    }
}