vortex_expr/exprs/
select.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_array::{ArrayRef, IntoArray, ToCanonical};
7use vortex_dtype::{DType, FieldNames};
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::field::DisplayFieldNames;
11use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
12
/// The field names a [`Select`] operates on, tagged with how they are applied.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field except the listed ones.
    Exclude(FieldNames),
}
18
/// An expression that evaluates `child` to a struct and projects a subset of
/// its fields, as described by `fields`.
#[derive(Debug, Clone, Eq, Hash)]
// `Hash` is derived while `PartialEq` is written by hand further down in this
// file, which is what this lint would otherwise complain about.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct Select {
    /// Which fields of the child's struct to include or exclude.
    fields: SelectField,
    /// The expression whose struct result is projected.
    child: ExprRef,
}
25
26pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
27    Select::include_expr(fields.into(), child)
28}
29
30pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
31    Select::exclude_expr(fields.into(), child)
32}
33
34impl Select {
35    pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
36        Arc::new(Self { fields, child })
37    }
38
39    pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
40        Self::new_expr(SelectField::Include(columns), child)
41    }
42
43    pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
44        Self::new_expr(SelectField::Exclude(columns), child)
45    }
46
47    pub fn fields(&self) -> &SelectField {
48        &self.fields
49    }
50
51    pub fn child(&self) -> &ExprRef {
52        &self.child
53    }
54
55    pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
56        Ok(Self::new_expr(
57            SelectField::Include(self.fields.as_include_names(field_names)?),
58            self.child.clone(),
59        ))
60    }
61}
62
63impl SelectField {
64    pub fn include(columns: FieldNames) -> Self {
65        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
66        Self::Include(columns)
67    }
68
69    pub fn exclude(columns: FieldNames) -> Self {
70        assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
71        Self::Exclude(columns)
72    }
73
74    pub fn is_include(&self) -> bool {
75        matches!(self, Self::Include(_))
76    }
77
78    pub fn is_exclude(&self) -> bool {
79        matches!(self, Self::Exclude(_))
80    }
81
82    pub fn fields(&self) -> &FieldNames {
83        match self {
84            SelectField::Include(fields) => fields,
85            SelectField::Exclude(fields) => fields,
86        }
87    }
88
89    pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
90        if self
91            .fields()
92            .iter()
93            .any(|f| !field_names.iter().contains(f))
94        {
95            vortex_bail!(
96                "Field {:?} in select not in field names {:?}",
97                self,
98                field_names
99            );
100        }
101        match self {
102            SelectField::Include(fields) => Ok(fields.clone()),
103            SelectField::Exclude(exc_fields) => Ok(field_names
104                .iter()
105                .filter(|f| exc_fields.iter().contains(f))
106                .cloned()
107                .collect()),
108        }
109    }
110}
111
112impl Display for SelectField {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match self {
115            SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
116            SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
117        }
118    }
119}
120
121impl Display for Select {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        write!(f, "{}{}", self.child, self.fields)
124    }
125}
126
/// Protobuf (de)serialization hooks for [`Select`], compiled only with the
/// `proto` feature. Both directions currently bail with `NotImplemented`.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Select};

    /// Unit type carrying the identifier and deserializer for select expressions.
    pub struct SelectSerde;

    impl Id for SelectSerde {
        /// The identifier for select expressions: `"select"`.
        fn id(&self) -> &'static str {
            "select"
        }
    }

    impl ExprDeserialize for SelectSerde {
        /// Always fails: deserializing a select expression is not implemented.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // NOTE(review): the first macro argument is an empty string; presumably it
            // is meant to name the unimplemented function. Confirm against `vortex_bail!`.
            vortex_bail!(NotImplemented: "", self.id())
        }
    }

    impl ExprSerializable for Select {
        /// Reuses the identifier reported by [`SelectSerde`].
        fn id(&self) -> &'static str {
            SelectSerde.id()
        }

        /// Always fails: serializing a select expression is not implemented.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            // NOTE(review): same empty first argument as in `deserialize` above; confirm intent.
            vortex_bail!(NotImplemented: "", self.id())
        }
    }
}
158
// Empty impl: `Select` overrides nothing from `AnalysisExpr`.
impl AnalysisExpr for Select {}
160
161impl VortexExpr for Select {
162    fn as_any(&self) -> &dyn Any {
163        self
164    }
165
166    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
167        let batch = self.child.unchecked_evaluate(scope)?.to_struct()?;
168        Ok(match &self.fields {
169            SelectField::Include(f) => batch.project(f.as_ref()),
170            SelectField::Exclude(names) => {
171                let included_names = batch
172                    .names()
173                    .iter()
174                    .filter(|&f| !names.iter().contains(f))
175                    .cloned()
176                    .collect::<Vec<_>>();
177                batch.project(included_names.as_slice())
178            }
179        }?
180        .into_array())
181    }
182
183    fn children(&self) -> Vec<&ExprRef> {
184        vec![&self.child]
185    }
186
187    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
188        assert_eq!(children.len(), 1);
189        Self::new_expr(self.fields.clone(), children[0].clone())
190    }
191
192    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
193        let child_dtype = self.child.return_dtype(scope)?;
194        let child_struct_dtype = child_dtype
195            .as_struct()
196            .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
197
198        let projected = match &self.fields {
199            SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
200            SelectField::Exclude(fields) => child_struct_dtype
201                .names()
202                .iter()
203                .cloned()
204                .zip_eq(child_struct_dtype.fields())
205                .filter(|(name, _)| !fields.iter().contains(name))
206                .collect(),
207        };
208
209        Ok(DType::Struct(projected, child_dtype.nullability()))
210    }
211}
212
213impl PartialEq for Select {
214    fn eq(&self, other: &Select) -> bool {
215        self.fields == other.fields && self.child.eq(&other.child)
216    }
217}
218
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{ExprRef, Scope, ScopeDType, root, select, select_exclude, test_harness};

    /// Fixture: a struct array with two three-element columns, `a` and `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Evaluates `expr` against the fixture and returns the result as a struct array.
    fn evaluate_on_fixture(expr: &ExprRef) -> StructArray {
        let input = test_array();
        expr.evaluate(&Scope::new(input.to_array()))
            .unwrap()
            .to_struct()
            .unwrap()
    }

    #[test]
    pub fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        let projected = evaluate_on_fixture(&expr);
        let selected_names = projected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = evaluate_on_fixture(&expr);
        let selected_names = projected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Return dtype of an expression evaluated against the harness struct dtype.
        let return_dtype_of =
            |expr: &ExprRef| expr.return_dtype(&ScopeDType::new(dtype.clone())).unwrap();
        // Non-nullable struct dtype obtained by projecting the harness dtype onto `names`.
        let projected_dtype = |names: &[FieldName]| {
            DType::Struct(
                dtype.as_struct().unwrap().project(names).unwrap(),
                Nullability::NonNullable,
            )
        };

        let only_a = projected_dtype(&[FieldName::from("a")]);

        // Including just `a`.
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(return_dtype_of(&include_a), only_a);

        // Excluding the four other fields yields the same dtype.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(return_dtype_of(&exclude_all_but_a), only_a);

        // Excluding two fields leaves the remaining three in their original order.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            return_dtype_of(&exclude_cols),
            projected_dtype(&[
                FieldName::from("a"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ])
        );
    }
}