1use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::field::DisplayFieldNames;
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash)]
17pub enum SelectField {
18 Include(FieldNames),
19 Exclude(FieldNames),
20}
21
22vtable!(Select);
23
24#[derive(Debug, Clone, Hash, Eq)]
25#[allow(clippy::derived_hash_with_manual_eq)]
26pub struct SelectExpr {
27 fields: SelectField,
28 child: ExprRef,
29}
30
31impl PartialEq for SelectExpr {
32 fn eq(&self, other: &Self) -> bool {
33 self.fields == other.fields && self.child.eq(&other.child)
34 }
35}
36
/// Field-less marker type used as the `Encoding` of the select expression vtable.
pub struct SelectExprEncoding;
38
39impl VTable for SelectVTable {
40 type Expr = SelectExpr;
41 type Encoding = SelectExprEncoding;
42 type Metadata = ProstMetadata<SelectOpts>;
43
44 fn id(_encoding: &Self::Encoding) -> ExprId {
45 ExprId::new_ref("select")
46 }
47
48 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49 ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
50 }
51
52 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53 let names = expr
54 .fields()
55 .fields()
56 .iter()
57 .map(|f| f.to_string())
58 .collect_vec();
59
60 let opts = if expr.fields().is_include() {
61 Opts::Include(ProtoFieldNames { names })
62 } else {
63 Opts::Exclude(ProtoFieldNames { names })
64 };
65
66 Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
67 }
68
69 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
70 vec![&expr.child]
71 }
72
73 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
74 Ok(SelectExpr {
75 fields: expr.fields.clone(),
76 child: children[0].clone(),
77 })
78 }
79
80 fn build(
81 _encoding: &Self::Encoding,
82 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
83 mut children: Vec<ExprRef>,
84 ) -> VortexResult<Self::Expr> {
85 if children.len() != 1 {
86 vortex_bail!("Select expression must have exactly one child");
87 }
88
89 let fields = match metadata.opts.as_ref() {
90 Some(opts) => match opts {
91 Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
92 field_names.names.iter().map(|s| s.as_str()),
93 )),
94 Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
95 field_names.names.iter().map(|s| s.as_str()),
96 )),
97 },
98 None => {
99 vortex_bail!("Select expressions must be provided with fields to select or exclude")
100 }
101 };
102
103 let child = children
104 .drain(..)
105 .next()
106 .vortex_expect("number of children validated to be one");
107
108 Ok(SelectExpr { fields, child })
109 }
110
111 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
112 let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
113 Ok(match &expr.fields {
114 SelectField::Include(f) => batch.project(f.as_ref()),
115 SelectField::Exclude(names) => {
116 let included_names = batch
117 .names()
118 .iter()
119 .filter(|&f| !names.as_ref().contains(f))
120 .cloned()
121 .collect::<Vec<_>>();
122 batch.project(included_names.as_slice())
123 }
124 }?
125 .into_array())
126 }
127
128 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
129 let child_dtype = expr.child.return_dtype(scope)?;
130 let child_struct_dtype = child_dtype
131 .as_struct()
132 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
133
134 let projected = match &expr.fields {
135 SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
136 SelectField::Exclude(fields) => child_struct_dtype
137 .names()
138 .iter()
139 .cloned()
140 .zip_eq(child_struct_dtype.fields())
141 .filter(|(name, _)| !fields.as_ref().contains(name))
142 .collect(),
143 };
144
145 Ok(DType::Struct(projected, child_dtype.nullability()))
146 }
147}
148
149pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
157 SelectExpr::include_expr(fields.into(), child)
158}
159
160pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
169 SelectExpr::exclude_expr(fields.into(), child)
170}
171
172impl SelectExpr {
173 pub fn new(fields: SelectField, child: ExprRef) -> Self {
174 Self { fields, child }
175 }
176
177 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
178 Self::new(fields, child).into_expr()
179 }
180
181 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
182 Self::new(SelectField::Include(columns), child).into_expr()
183 }
184
185 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
186 Self::new(SelectField::Exclude(columns), child).into_expr()
187 }
188
189 pub fn fields(&self) -> &SelectField {
190 &self.fields
191 }
192
193 pub fn child(&self) -> &ExprRef {
194 &self.child
195 }
196
197 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
213 Ok(Self::new(
214 SelectField::Include(self.fields.as_include_names(field_names)?),
215 self.child.clone(),
216 )
217 .into_expr())
218 }
219}
220
221impl SelectField {
222 pub fn include(columns: FieldNames) -> Self {
223 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
224 Self::Include(columns)
225 }
226
227 pub fn exclude(columns: FieldNames) -> Self {
228 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
229 Self::Exclude(columns)
230 }
231
232 pub fn is_include(&self) -> bool {
233 matches!(self, Self::Include(_))
234 }
235
236 pub fn is_exclude(&self) -> bool {
237 matches!(self, Self::Exclude(_))
238 }
239
240 pub fn fields(&self) -> &FieldNames {
241 let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
242
243 fields
244 }
245
246 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
247 if self
248 .fields()
249 .iter()
250 .any(|f| !field_names.iter().contains(f))
251 {
252 vortex_bail!(
253 "Field {:?} in select not in field names {:?}",
254 self,
255 field_names
256 );
257 }
258 match self {
259 SelectField::Include(fields) => Ok(fields.clone()),
260 SelectField::Exclude(exc_fields) => Ok(field_names
261 .iter()
262 .filter(|f| !exc_fields.iter().contains(f))
263 .cloned()
264 .collect()),
265 }
266 }
267}
268
269impl Display for SelectField {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 match self {
272 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
273 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
274 }
275 }
276}
277
278impl Display for SelectExpr {
279 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280 write!(f, "{}{}", self.child, self.fields)
281 }
282}
283
284impl AnalysisExpr for SelectExpr {}
285
#[cfg(test)]
mod tests {
    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{
        ExprRef, Scope, SelectExpr, SelectField, root, select, select_exclude, test_harness,
    };

    /// Struct fixture with two fields, `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Evaluates `expr` against the fixture and returns the result as a struct array.
    fn evaluate_on_fixture(expr: &ExprRef) -> StructArray {
        let scope = Scope::new(test_array().to_array());
        expr.evaluate(&scope).unwrap().to_struct().unwrap()
    }

    #[test]
    pub fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        let selected_names = evaluate_on_fixture(&expr).names().clone();
        assert_eq!(selected_names.as_ref(), &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let selected_names = evaluate_on_fixture(&expr).names().clone();
        assert_eq!(selected_names.as_ref(), &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct().unwrap();

        // Including only `a` ...
        let only_a = DType::Struct(
            struct_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // ... matches excluding every other field.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding a subset keeps the remaining fields.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let remaining = DType::Struct(
            struct_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), remaining);
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = SelectExpr::new(SelectField::Include(["a"].into()), root());
        let exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());

        let from_include = include.as_include(&field_names).unwrap();
        let from_exclude = exclude.as_include(&field_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}