1use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::field::DisplayFieldNames;
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
/// Which fields of the child struct a [`SelectExpr`] keeps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field of the child except the listed ones (child field order is preserved
    /// when evaluating).
    Exclude(FieldNames),
}
22
// NOTE(review): `vtable!` is a crate-local macro; it presumably generates the `SelectVTable`
// type that is implemented below — confirm against the macro definition.
vtable!(Select);
24
/// Expression that evaluates `child` to a struct and keeps the subset of its fields described
/// by `fields`.
#[derive(Debug, Clone, Hash, Eq)]
// `Hash` is derived while `PartialEq` is written by hand below, which this lint flags as a
// possible Hash/Eq mismatch. The manual `PartialEq` compares exactly the two fields `Hash`
// covers, so the two stay consistent.
// NOTE(review): the manual impl presumably exists because of how `ExprRef` equality is
// resolved — confirm before replacing it with a derive.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// The include/exclude selection applied to the child's fields.
    fields: SelectField,
    /// Expression producing the struct to select from.
    child: ExprRef,
}
31
32impl PartialEq for SelectExpr {
33 fn eq(&self, other: &Self) -> bool {
34 self.fields == other.fields && self.child.eq(&other.child)
35 }
36}
37
/// Encoding marker for [`SelectExpr`]; a unit struct carrying no state.
pub struct SelectExprEncoding;
39
40impl VTable for SelectVTable {
41 type Expr = SelectExpr;
42 type Encoding = SelectExprEncoding;
43 type Metadata = ProstMetadata<SelectOpts>;
44
45 fn id(_encoding: &Self::Encoding) -> ExprId {
46 ExprId::new_ref("select")
47 }
48
49 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
50 ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
51 }
52
53 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
54 let names = expr
55 .fields()
56 .fields()
57 .iter()
58 .map(|f| f.to_string())
59 .collect_vec();
60
61 let opts = if expr.fields().is_include() {
62 Opts::Include(ProtoFieldNames { names })
63 } else {
64 Opts::Exclude(ProtoFieldNames { names })
65 };
66
67 Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
68 }
69
70 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71 vec![&expr.child]
72 }
73
74 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75 Ok(SelectExpr {
76 fields: expr.fields.clone(),
77 child: children[0].clone(),
78 })
79 }
80
81 fn build(
82 _encoding: &Self::Encoding,
83 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
84 mut children: Vec<ExprRef>,
85 ) -> VortexResult<Self::Expr> {
86 if children.len() != 1 {
87 vortex_bail!("Select expression must have exactly one child");
88 }
89
90 let fields = match metadata.opts.as_ref() {
91 Some(opts) => match opts {
92 Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
93 field_names.names.iter().map(|s| s.as_str()),
94 )),
95 Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
96 field_names.names.iter().map(|s| s.as_str()),
97 )),
98 },
99 None => {
100 vortex_bail!("Select expressions must be provided with fields to select or exclude")
101 }
102 };
103
104 let child = children
105 .drain(..)
106 .next()
107 .vortex_expect("number of children validated to be one");
108
109 Ok(SelectExpr { fields, child })
110 }
111
112 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
113 let batch = expr.child.unchecked_evaluate(scope)?.to_struct();
114 Ok(match &expr.fields {
115 SelectField::Include(f) => batch.project(f.as_ref()),
116 SelectField::Exclude(names) => {
117 let included_names = batch
118 .names()
119 .iter()
120 .filter(|&f| !names.as_ref().contains(f))
121 .cloned()
122 .collect::<Vec<_>>();
123 batch.project(included_names.as_slice())
124 }
125 }?
126 .into_array())
127 }
128
129 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
130 let child_dtype = expr.child.return_dtype(scope)?;
131 let child_struct_dtype = child_dtype
132 .as_struct_fields_opt()
133 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
134
135 let projected = match &expr.fields {
136 SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
137 SelectField::Exclude(fields) => child_struct_dtype
138 .names()
139 .iter()
140 .cloned()
141 .zip_eq(child_struct_dtype.fields())
142 .filter(|(name, _)| !fields.as_ref().contains(name))
143 .collect(),
144 };
145
146 Ok(DType::Struct(projected, child_dtype.nullability()))
147 }
148}
149
150pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
158 SelectExpr::include_expr(fields.into(), child)
159}
160
161pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
170 SelectExpr::exclude_expr(fields.into(), child)
171}
172
173impl SelectExpr {
174 pub fn new(fields: SelectField, child: ExprRef) -> Self {
175 Self { fields, child }
176 }
177
178 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
179 Self::new(fields, child).into_expr()
180 }
181
182 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
183 Self::new(SelectField::Include(columns), child).into_expr()
184 }
185
186 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
187 Self::new(SelectField::Exclude(columns), child).into_expr()
188 }
189
190 pub fn fields(&self) -> &SelectField {
191 &self.fields
192 }
193
194 pub fn child(&self) -> &ExprRef {
195 &self.child
196 }
197
198 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
214 Ok(Self::new(
215 SelectField::Include(self.fields.as_include_names(field_names)?),
216 self.child.clone(),
217 )
218 .into_expr())
219 }
220}
221
222impl SelectField {
223 pub fn include(columns: FieldNames) -> Self {
224 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
225 Self::Include(columns)
226 }
227
228 pub fn exclude(columns: FieldNames) -> Self {
229 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
230 Self::Exclude(columns)
231 }
232
233 pub fn is_include(&self) -> bool {
234 matches!(self, Self::Include(_))
235 }
236
237 pub fn is_exclude(&self) -> bool {
238 matches!(self, Self::Exclude(_))
239 }
240
241 pub fn fields(&self) -> &FieldNames {
242 let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
243
244 fields
245 }
246
247 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
248 if self
249 .fields()
250 .iter()
251 .any(|f| !field_names.iter().contains(f))
252 {
253 vortex_bail!(
254 "Field {:?} in select not in field names {:?}",
255 self,
256 field_names
257 );
258 }
259 match self {
260 SelectField::Include(fields) => Ok(fields.clone()),
261 SelectField::Exclude(exc_fields) => Ok(field_names
262 .iter()
263 .filter(|f| !exc_fields.iter().contains(f))
264 .cloned()
265 .collect()),
266 }
267 }
268}
269
270impl Display for SelectField {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 match self {
273 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
274 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
275 }
276 }
277}
278
279impl DisplayAs for SelectExpr {
280 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
281 match df {
282 DisplayFormat::Compact => {
283 write!(f, "{}{}", self.child, self.fields)
284 }
285 DisplayFormat::Tree => {
286 let field_type = if self.fields.is_include() {
287 "include"
288 } else {
289 "exclude"
290 };
291
292 write!(f, "Select({}): {}", field_type, self.fields().fields())
293 }
294 }
295 }
296
297 fn child_names(&self) -> Option<Vec<String>> {
298 None
300 }
301}
302
// Empty impl: every `AnalysisExpr` method falls back to the trait's default implementation.
impl AnalysisExpr for SelectExpr {}
304
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{Scope, SelectExpr, SelectField, root, select, select_exclude, test_harness};

    /// Two-column struct fixture: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    // Including `a` from {a, b} leaves a struct whose only field is `a`.
    #[test]
    pub fn include_columns() {
        let st = test_array();
        let select = select(vec![FieldName::from("a")], root());
        let selected = select
            .evaluate(&Scope::new(st.to_array()))
            .unwrap()
            .to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a".into()]);
    }

    // Excluding `a` from {a, b} leaves a struct whose only field is `b`.
    #[test]
    pub fn exclude_columns() {
        let st = test_array();
        let select = select_exclude(vec![FieldName::from("a")], root());
        let selected = select
            .evaluate(&Scope::new(st.to_array()))
            .unwrap()
            .to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b".into()]);
    }

    // `return_dtype` for include and exclude selections over the shared test_harness dtype.
    // NOTE(review): these assertions assume test_harness::struct_dtype() has exactly the
    // fields a, col1, col2, bool1, bool2 and is non-nullable — confirm against the harness.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding everything except `a` must match including just `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        // Excluding a subset keeps the remaining fields.
        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    // An include and an exclude that resolve to the same names against {a, b, c} must
    // produce equal include-mode expressions.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = SelectExpr::new(SelectField::Include(["a"].into()), root());
        let exclude = SelectExpr::new(SelectField::Exclude(["b", "c"].into()), root());
        assert_eq!(
            &include.as_include(&field_names).unwrap(),
            &exclude.as_include(&field_names).unwrap()
        );
    }
}