vortex_array/expr/exprs/select/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4pub mod transform;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::sync::Arc;
9
10use itertools::Itertools;
11use prost::Message;
12use vortex_dtype::DType;
13use vortex_dtype::FieldNames;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr::FieldNames as ProtoFieldNames;
19use vortex_proto::expr::SelectOpts;
20use vortex_proto::expr::select_opts::Opts;
21use vortex_vector::Vector;
22use vortex_vector::struct_::StructVector;
23
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::expr::ChildName;
28use crate::expr::ExecutionArgs;
29use crate::expr::ExprId;
30use crate::expr::ExpressionView;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::expression::Expression;
34use crate::expr::field::DisplayFieldNames;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38    Include(FieldNames),
39    Exclude(FieldNames),
40}
41
/// Expression vtable marker for field selection; registered under the id `vortex.select`.
pub struct Select;
43
44impl VTable for Select {
45    type Instance = FieldSelection;
46
47    fn id(&self) -> ExprId {
48        ExprId::new_ref("vortex.select")
49    }
50
51    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
52        let opts = match instance {
53            FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
54                names: fields.iter().map(|f| f.to_string()).collect(),
55            }),
56            FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
57                names: fields.iter().map(|f| f.to_string()).collect(),
58            }),
59        };
60
61        let select_opts = SelectOpts { opts: Some(opts) };
62        Ok(Some(select_opts.encode_to_vec()))
63    }
64
65    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
66        let prost_metadata = SelectOpts::decode(metadata)?;
67
68        let select_opts = prost_metadata
69            .opts
70            .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
71
72        let field_selection = match select_opts {
73            Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
74                field_names.names.iter().map(|s| s.as_str()),
75            )),
76            Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
77                field_names.names.iter().map(|s| s.as_str()),
78            )),
79        };
80
81        Ok(Some(field_selection))
82    }
83
84    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
85        if expr.children().len() != 1 {
86            vortex_bail!(
87                "Select expression requires exactly 1 child, got {}",
88                expr.children().len()
89            );
90        }
91        Ok(())
92    }
93
94    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
95        match child_idx {
96            0 => ChildName::new_ref("child"),
97            _ => unreachable!(),
98        }
99    }
100
101    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
102        expr.child().fmt_sql(f)?;
103        match expr.data() {
104            FieldSelection::Include(fields) => {
105                write!(f, "{{{}}}", DisplayFieldNames(fields))
106            }
107            FieldSelection::Exclude(fields) => {
108                write!(f, "{{~ {}}}", DisplayFieldNames(fields))
109            }
110        }
111    }
112
113    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
114        let names = match instance {
115            FieldSelection::Include(names) => {
116                write!(f, "include=")?;
117                names
118            }
119            FieldSelection::Exclude(names) => {
120                write!(f, "exclude=")?;
121                names
122            }
123        };
124        write!(f, "{{{}}}", DisplayFieldNames(names))
125    }
126
127    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
128        let child_dtype = expr.child().return_dtype(scope)?;
129        let child_struct_dtype = child_dtype
130            .as_struct_fields_opt()
131            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
132
133        let projected = match expr.data() {
134            FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
135            FieldSelection::Exclude(fields) => child_struct_dtype
136                .names()
137                .iter()
138                .cloned()
139                .zip_eq(child_struct_dtype.fields())
140                .filter(|(name, _)| !fields.as_ref().contains(name))
141                .collect(),
142        };
143
144        Ok(DType::Struct(projected, child_dtype.nullability()))
145    }
146
147    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
148        let batch = expr.child().evaluate(scope)?.to_struct();
149        Ok(match expr.data() {
150            FieldSelection::Include(f) => batch.project(f.as_ref()),
151            FieldSelection::Exclude(names) => {
152                let included_names = batch
153                    .names()
154                    .iter()
155                    .filter(|&f| !names.as_ref().contains(f))
156                    .cloned()
157                    .collect::<Vec<_>>();
158                batch.project(included_names.as_slice())
159            }
160        }?
161        .into_array())
162    }
163
164    fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult<Vector> {
165        let child = args
166            .vectors
167            .pop()
168            .vortex_expect("Missing input child")
169            .into_struct();
170        let child_fields = args
171            .dtypes
172            .pop()
173            .vortex_expect("Missing input dtype")
174            .into_struct_fields();
175
176        let field_indices: Vec<usize> = match selection {
177            FieldSelection::Include(f) => f
178                .iter()
179                .map(|name| {
180                    child_fields
181                        .find(name)
182                        .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
183                })
184                .try_collect(),
185            FieldSelection::Exclude(names) => child_fields
186                .names()
187                .iter()
188                .filter(|&f| !names.as_ref().contains(f))
189                .map(|name| {
190                    child_fields
191                        .find(name)
192                        .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
193                })
194                .try_collect(),
195        }?;
196
197        let (fields, mask) = child.into_parts();
198        let new_fields = field_indices
199            .iter()
200            .map(|&idx| fields[idx].clone())
201            .collect();
202        Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), mask) }.into())
203    }
204
205    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
206        true
207    }
208
209    fn is_fallible(&self, _instance: &Self::Instance) -> bool {
210        // If this type-checks its infallible.
211        false
212    }
213}
214
215/// Creates an expression that selects (includes) specific fields from an array.
216///
217/// Projects only the specified fields from the child expression, which must be of DType struct.
218/// ```rust
219/// # use vortex_array::expr::{select, root};
220/// let expr = select(["name", "age"], root());
221/// ```
222pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
223    Select
224        .try_new_expr(FieldSelection::Include(field_names.into()), [child])
225        .vortex_expect("Failed to create Select expression")
226}
227
228/// Creates an expression that excludes specific fields from an array.
229///
230/// Projects all fields except the specified ones from the input struct expression.
231///
232/// ```rust
233/// # use vortex_array::expr::{select_exclude, root};
234/// let expr = select_exclude(["internal_id", "metadata"], root());
235/// ```
236pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
237    Select
238        .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
239        .vortex_expect("Failed to create Select expression")
240}
241
242impl ExpressionView<'_, Select> {
243    pub fn child(&self) -> &Expression {
244        &self.children()[0]
245    }
246
247    /// Turn the select expression into an `include`, relative to a provided array of field names.
248    ///
249    /// For example:
250    /// ```rust
251    /// # use vortex_array::expr::{root, Select};
252    /// # use vortex_array::expr::{FieldSelection, select, select_exclude};
253    /// # use vortex_dtype::FieldNames;
254    /// let field_names = FieldNames::from(["a", "b", "c"]);
255    /// let include = select(["a"], root());
256    /// let exclude = select_exclude(["b", "c"], root());
257    /// assert_eq!(
258    ///     &include.as_::<Select>().as_include(&field_names).unwrap(),
259    ///     &exclude.as_::<Select>().as_include(&field_names).unwrap(),
260    /// );
261    /// ```
262    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<Expression> {
263        Select.try_new_expr(
264            FieldSelection::Include(self.data().as_include_names(field_names)?),
265            [self.child().clone()],
266        )
267    }
268}
269
270impl FieldSelection {
271    pub fn include(columns: FieldNames) -> Self {
272        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
273        Self::Include(columns)
274    }
275
276    pub fn exclude(columns: FieldNames) -> Self {
277        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
278        Self::Exclude(columns)
279    }
280
281    pub fn is_include(&self) -> bool {
282        matches!(self, Self::Include(_))
283    }
284
285    pub fn is_exclude(&self) -> bool {
286        matches!(self, Self::Exclude(_))
287    }
288
289    pub fn field_names(&self) -> &FieldNames {
290        let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
291
292        fields
293    }
294
295    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
296        if self
297            .field_names()
298            .iter()
299            .any(|f| !field_names.iter().contains(f))
300        {
301            vortex_bail!(
302                "Field {:?} in select not in field names {:?}",
303                self,
304                field_names
305            );
306        }
307        match self {
308            FieldSelection::Include(fields) => Ok(fields.clone()),
309            FieldSelection::Exclude(exc_fields) => Ok(field_names
310                .iter()
311                .filter(|f| !exc_fields.iter().contains(f))
312                .cloned()
313                .collect()),
314        }
315    }
316}
317
318impl Display for FieldSelection {
319    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320        match self {
321            FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
322            FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Struct fixture with two columns, `a` and `b`, of three rows each.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Non-nullable struct dtype obtained by projecting `names` out of `dtype`.
    fn projected_dtype(dtype: &DType, names: &[FieldName]) -> DType {
        let fields = dtype
            .as_struct_fields_opt()
            .unwrap()
            .project(names)
            .unwrap();
        DType::Struct(fields, Nullability::NonNullable)
    }

    /// Including `a` leaves only `a`.
    #[test]
    fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let output = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let output_names = output.names().clone();
        assert_eq!(output_names.as_ref(), &["a"]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let output = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let output_names = output.names().clone();
        assert_eq!(output_names.as_ref(), &["b"]);
    }

    /// `return_dtype` agrees with projecting the harness dtype directly, for both variants.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let only_a = projected_dtype(&dtype, &[FieldName::from("a")]);

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding the four other fields leaves just `a`.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding only the `col*` fields keeps the remaining three in order.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let remaining = projected_dtype(
            &dtype,
            &[
                FieldName::from("a"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), remaining);
    }

    /// An include of `a` and an exclude of `b, c` resolve to the same names over `a, b, c`.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());

        let from_include = include
            .as_::<Select>()
            .data()
            .as_include_names(&field_names)
            .unwrap();
        let from_exclude = exclude
            .as_::<Select>()
            .data()
            .as_include_names(&field_names)
            .unwrap();

        assert_eq!(from_include, from_exclude);
    }
}