1use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, EmptyMetadata, IntoArray, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10
11use crate::field::DisplayFieldNames;
12use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
13
14#[derive(Debug, Clone, PartialEq, Eq, Hash)]
15pub enum SelectField {
16 Include(FieldNames),
17 Exclude(FieldNames),
18}
19
// Invokes the crate's `vtable!` macro for `Select`. Presumably this is what declares the
// `SelectVTable` type that this file implements `VTable` for — confirm against the macro.
vtable!(Select);
21
22#[derive(Debug, Clone, Hash)]
23#[allow(clippy::derived_hash_with_manual_eq)]
24pub struct SelectExpr {
25 fields: SelectField,
26 child: ExprRef,
27}
28
29impl PartialEq for SelectExpr {
30 fn eq(&self, other: &Self) -> bool {
31 self.fields == other.fields && self.child.eq(&other.child)
32 }
33}
34
/// Unit type used as the `Encoding` associated type of the select expression's `VTable`.
pub struct SelectExprEncoding;
36
37impl VTable for SelectVTable {
38 type Expr = SelectExpr;
39 type Encoding = SelectExprEncoding;
40 type Metadata = EmptyMetadata;
41
42 fn id(_encoding: &Self::Encoding) -> ExprId {
43 ExprId::new_ref("select")
44 }
45
46 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
47 ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
48 }
49
50 fn metadata(_expr: &Self::Expr) -> Option<Self::Metadata> {
51 None
53 }
54
55 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56 vec![&expr.child]
57 }
58
59 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60 Ok(SelectExpr {
61 fields: expr.fields.clone(),
62 child: children[0].clone(),
63 })
64 }
65
66 fn build(
67 _encoding: &Self::Encoding,
68 _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
69 _children: Vec<ExprRef>,
70 ) -> VortexResult<Self::Expr> {
71 vortex_bail!("Select does not support deserialization")
72 }
73
74 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
75 let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
76 Ok(match &expr.fields {
77 SelectField::Include(f) => batch.project(f.as_ref()),
78 SelectField::Exclude(names) => {
79 let included_names = batch
80 .names()
81 .iter()
82 .filter(|&f| !names.as_ref().contains(f))
83 .cloned()
84 .collect::<Vec<_>>();
85 batch.project(included_names.as_slice())
86 }
87 }?
88 .into_array())
89 }
90
91 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
92 let child_dtype = expr.child.return_dtype(scope)?;
93 let child_struct_dtype = child_dtype
94 .as_struct()
95 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
96
97 let projected = match &expr.fields {
98 SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
99 SelectField::Exclude(fields) => child_struct_dtype
100 .names()
101 .iter()
102 .cloned()
103 .zip_eq(child_struct_dtype.fields())
104 .filter(|(name, _)| !fields.as_ref().contains(name))
105 .collect(),
106 };
107
108 Ok(DType::Struct(projected, child_dtype.nullability()))
109 }
110}
111
112pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
113 SelectExpr::include_expr(fields.into(), child)
114}
115
116pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
117 SelectExpr::exclude_expr(fields.into(), child)
118}
119
120impl SelectExpr {
121 pub fn new(fields: SelectField, child: ExprRef) -> Self {
122 Self { fields, child }
123 }
124
125 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
126 Self::new(fields, child).into_expr()
127 }
128
129 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
130 Self::new(SelectField::Include(columns), child).into_expr()
131 }
132
133 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
134 Self::new(SelectField::Exclude(columns), child).into_expr()
135 }
136
137 pub fn fields(&self) -> &SelectField {
138 &self.fields
139 }
140
141 pub fn child(&self) -> &ExprRef {
142 &self.child
143 }
144
145 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
146 Ok(Self::new(
147 SelectField::Include(self.fields.as_include_names(field_names)?),
148 self.child.clone(),
149 )
150 .into_expr())
151 }
152}
153
154impl SelectField {
155 pub fn include(columns: FieldNames) -> Self {
156 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
157 Self::Include(columns)
158 }
159
160 pub fn exclude(columns: FieldNames) -> Self {
161 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
162 Self::Exclude(columns)
163 }
164
165 pub fn is_include(&self) -> bool {
166 matches!(self, Self::Include(_))
167 }
168
169 pub fn is_exclude(&self) -> bool {
170 matches!(self, Self::Exclude(_))
171 }
172
173 pub fn fields(&self) -> &FieldNames {
174 match self {
175 SelectField::Include(fields) => fields,
176 SelectField::Exclude(fields) => fields,
177 }
178 }
179
180 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
181 if self
182 .fields()
183 .iter()
184 .any(|f| !field_names.iter().contains(f))
185 {
186 vortex_bail!(
187 "Field {:?} in select not in field names {:?}",
188 self,
189 field_names
190 );
191 }
192 match self {
193 SelectField::Include(fields) => Ok(fields.clone()),
194 SelectField::Exclude(exc_fields) => Ok(field_names
195 .iter()
196 .filter(|f| exc_fields.iter().contains(f))
197 .cloned()
198 .collect()),
199 }
200 }
201}
202
203impl Display for SelectField {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 match self {
206 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
207 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
208 }
209 }
210}
211
212impl Display for SelectExpr {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(f, "{}{}", self.child, self.fields)
215 }
216}
217
// Empty impl: `SelectExpr` overrides none of the `AnalysisExpr` methods.
impl AnalysisExpr for SelectExpr {}
219
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{Scope, root, select, select_exclude, test_harness};

    /// Builds a struct array with two columns: `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        let built = StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ]);
        built.unwrap()
    }

    /// Including `a` leaves only `a` in the evaluated struct.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let scope = Scope::new(input.to_array());
        let include_a = select(vec![FieldName::from("a")], root());
        let projected = include_a.evaluate(&scope).unwrap().to_struct().unwrap();
        let projected_names = projected.names().clone();
        assert_eq!(projected_names.as_ref(), &["a".into()]);
    }

    /// Excluding `a` leaves only `b` in the evaluated struct.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let scope = Scope::new(input.to_array());
        let exclude_a = select_exclude(vec![FieldName::from("a")], root());
        let projected = exclude_a.evaluate(&scope).unwrap().to_struct().unwrap();
        let projected_names = projected.names().clone();
        assert_eq!(projected_names.as_ref(), &["b".into()]);
    }

    /// `return_dtype` agrees with projecting the harness struct dtype directly.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Expected result of keeping only `a`.
        let only_a = DType::Struct(
            dtype.as_struct().unwrap().project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding every other column must give the same dtype as including `a`.
        let all_but_a = vec![
            FieldName::from("col1"),
            FieldName::from("col2"),
            FieldName::from("bool1"),
            FieldName::from("bool2"),
        ];
        let exclude_all_but_a = select_exclude(all_but_a, root());
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding only the `col*` columns keeps `a`, `bool1` and `bool2`.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let without_cols = DType::Struct(
            dtype
                .as_struct()
                .unwrap()
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), without_cols);
    }
}