vortex_expr/
select.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_array::{Array, ArrayRef, ArrayVariants};
7use vortex_dtype::{DType, FieldNames};
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::field::DisplayFieldNames;
11use crate::{ExprRef, VortexExpr};
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum SelectField {
15    Include(FieldNames),
16    Exclude(FieldNames),
17}
18
19#[derive(Debug, Clone, Eq, Hash)]
20#[allow(clippy::derived_hash_with_manual_eq)]
21pub struct Select {
22    fields: SelectField,
23    child: ExprRef,
24}
25
26pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
27    Select::include_expr(fields.into(), child)
28}
29
30pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
31    Select::exclude_expr(fields.into(), child)
32}
33
34impl Select {
35    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
36        Arc::new(Self { fields, child })
37    }
38
39    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
40        Self::new_expr(SelectField::Include(columns), child)
41    }
42
43    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
44        Self::new_expr(SelectField::Exclude(columns), child)
45    }
46
47    pub fn fields(&self) -> &SelectField {
48        &self.fields
49    }
50
51    pub fn child(&self) -> &ExprRef {
52        &self.child
53    }
54
55    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
56        Ok(Self::new_expr(
57            SelectField::Include(self.fields.as_include_names(field_names)?),
58            self.child.clone(),
59        ))
60    }
61}
62
63impl SelectField {
64    pub fn include(columns: FieldNames) -> Self {
65        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
66        Self::Include(columns)
67    }
68
69    pub fn exclude(columns: FieldNames) -> Self {
70        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
71        Self::Exclude(columns)
72    }
73
74    pub fn is_include(&self) -> bool {
75        matches!(self, Self::Include(_))
76    }
77
78    pub fn is_exclude(&self) -> bool {
79        matches!(self, Self::Exclude(_))
80    }
81
82    pub fn fields(&self) -> &FieldNames {
83        match self {
84            SelectField::Include(fields) => fields,
85            SelectField::Exclude(fields) => fields,
86        }
87    }
88
89    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
90        if self.fields().iter().any(|f| !field_names.contains(f)) {
91            vortex_bail!(
92                "Field {:?} in select not in field names {:?}",
93                self,
94                field_names
95            );
96        }
97        match self {
98            SelectField::Include(fields) => Ok(fields.clone()),
99            SelectField::Exclude(exc_fields) => Ok(field_names
100                .iter()
101                .filter(|f| exc_fields.contains(f))
102                .cloned()
103                .collect()),
104        }
105    }
106}
107
108impl Display for SelectField {
109    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110        match self {
111            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
112            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
113        }
114    }
115}
116
117impl Display for Select {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        write!(f, "{}{}", self.child, self.fields)
120    }
121}
122
123impl VortexExpr for Select {
124    fn as_any(&self) -> &dyn Any {
125        self
126    }
127
128    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
129        let batch = self.child.evaluate(batch)?;
130        let st = batch
131            .as_struct_typed()
132            .ok_or_else(|| vortex_err!("Not a struct array"))?;
133        match &self.fields {
134            SelectField::Include(f) => st.project(f),
135            SelectField::Exclude(names) => {
136                let included_names = st
137                    .names()
138                    .iter()
139                    .filter(|&f| !names.contains(f))
140                    .cloned()
141                    .collect::<Vec<_>>();
142                st.project(included_names.as_slice())
143            }
144        }
145    }
146
147    fn children(&self) -> Vec<&ExprRef> {
148        vec![&self.child]
149    }
150
151    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
152        assert_eq!(children.len(), 1);
153        Self::new_expr(self.fields.clone(), children[0].clone())
154    }
155
156    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
157        let child_dtype = self.child.return_dtype(scope_dtype)?;
158        let child_struct_dtype = child_dtype
159            .as_struct()
160            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
161
162        let projected = match &self.fields {
163            SelectField::Include(fields) => child_struct_dtype.project(fields)?,
164            SelectField::Exclude(fields) => child_struct_dtype
165                .names()
166                .iter()
167                .cloned()
168                .zip_eq(child_struct_dtype.fields())
169                .filter(|(name, _)| !fields.contains(name))
170                .collect(),
171        };
172
173        Ok(DType::Struct(
174            Arc::new(projected),
175            child_dtype.nullability(),
176        ))
177    }
178}
179
180impl PartialEq for Select {
181    fn eq(&self, other: &Select) -> bool {
182        self.fields == other.fields && self.child.eq(&other.child)
183    }
184}
185
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::StructArray;
    use vortex_array::{ArrayVariants, IntoArray};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{ident, select, select_exclude, test_harness};

    /// Two-column struct fixture with fields `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Non-nullable struct dtype obtained by projecting `dtype` onto `names`.
    fn struct_of(dtype: &DType, names: &[FieldName]) -> DType {
        let projected = dtype.as_struct().unwrap().project(names).unwrap();
        DType::Struct(Arc::new(projected), Nullability::NonNullable)
    }

    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], ident());

        let output = expr.evaluate(&input).unwrap();
        let output_names = output.as_struct_typed().unwrap().names().clone();

        assert_eq!(output_names.as_ref(), &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], ident());

        let output = expr.evaluate(&input).unwrap();
        let output_names = output.as_struct_typed().unwrap().names().clone();

        assert_eq!(output_names.as_ref(), &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let only_a = struct_of(&dtype, &["a".into()]);

        // Including just `a` yields a struct with that single field.
        let include_a = select(vec![FieldName::from("a")], ident());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding the four other fields leaves the same single-field struct.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            ident(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding only the `col*` fields keeps the remaining three.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            ident(),
        );
        assert_eq!(
            exclude_cols.return_dtype(&dtype).unwrap(),
            struct_of(&dtype, &["a".into(), "bool1".into(), "bool2".into()])
        );
    }
}