1use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::field::DisplayFieldNames;
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18pub enum FieldSelection {
19 Include(FieldNames),
20 Exclude(FieldNames),
21}
22
23vtable!(Select);
24
25#[derive(Debug, Clone, Hash, Eq)]
26#[allow(clippy::derived_hash_with_manual_eq)]
27pub struct SelectExpr {
28 selection: FieldSelection,
29 child: ExprRef,
30}
31
32impl PartialEq for SelectExpr {
33 fn eq(&self, other: &Self) -> bool {
34 self.selection == other.selection && self.child.eq(&other.child)
35 }
36}
37
38pub struct SelectExprEncoding;
39
40impl VTable for SelectVTable {
41 type Expr = SelectExpr;
42 type Encoding = SelectExprEncoding;
43 type Metadata = ProstMetadata<SelectOpts>;
44
45 fn id(_encoding: &Self::Encoding) -> ExprId {
46 ExprId::new_ref("select")
47 }
48
49 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
50 ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
51 }
52
53 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
54 let names = expr
55 .selection()
56 .field_names()
57 .iter()
58 .map(|f| f.to_string())
59 .collect_vec();
60
61 let opts = if expr.selection().is_include() {
62 Opts::Include(ProtoFieldNames { names })
63 } else {
64 Opts::Exclude(ProtoFieldNames { names })
65 };
66
67 Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
68 }
69
70 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
71 vec![&expr.child]
72 }
73
74 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
75 Ok(SelectExpr {
76 selection: expr.selection.clone(),
77 child: children[0].clone(),
78 })
79 }
80
81 fn build(
82 _encoding: &Self::Encoding,
83 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
84 mut children: Vec<ExprRef>,
85 ) -> VortexResult<Self::Expr> {
86 if children.len() != 1 {
87 vortex_bail!("Select expression must have exactly one child");
88 }
89
90 let fields = match metadata.opts.as_ref() {
91 Some(opts) => match opts {
92 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
93 field_names.names.iter().map(|s| s.as_str()),
94 )),
95 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
96 field_names.names.iter().map(|s| s.as_str()),
97 )),
98 },
99 None => {
100 vortex_bail!("Select expressions must be provided with fields to select or exclude")
101 }
102 };
103
104 let child = children
105 .drain(..)
106 .next()
107 .vortex_expect("number of children validated to be one");
108
109 Ok(SelectExpr {
110 selection: fields,
111 child,
112 })
113 }
114
115 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
116 let batch = expr.child.unchecked_evaluate(scope)?.to_struct();
117 Ok(match &expr.selection {
118 FieldSelection::Include(f) => batch.project(f.as_ref()),
119 FieldSelection::Exclude(names) => {
120 let included_names = batch
121 .names()
122 .iter()
123 .filter(|&f| !names.as_ref().contains(f))
124 .cloned()
125 .collect::<Vec<_>>();
126 batch.project(included_names.as_slice())
127 }
128 }?
129 .into_array())
130 }
131
132 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
133 let child_dtype = expr.child.return_dtype(scope)?;
134 let child_struct_dtype = child_dtype
135 .as_struct_fields_opt()
136 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
137
138 let projected = match &expr.selection {
139 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
140 FieldSelection::Exclude(fields) => child_struct_dtype
141 .names()
142 .iter()
143 .cloned()
144 .zip_eq(child_struct_dtype.fields())
145 .filter(|(name, _)| !fields.as_ref().contains(name))
146 .collect(),
147 };
148
149 Ok(DType::Struct(projected, child_dtype.nullability()))
150 }
151}
152
153pub fn select(field_names: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
161 SelectExpr::include_expr(field_names.into(), child)
162}
163
164pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
173 SelectExpr::exclude_expr(fields.into(), child)
174}
175
176impl SelectExpr {
177 pub fn new(fields: FieldSelection, child: ExprRef) -> Self {
178 Self {
179 selection: fields,
180 child,
181 }
182 }
183
184 pub fn new_expr(fields: FieldSelection, child: ExprRef) -> ExprRef {
185 Self::new(fields, child).into_expr()
186 }
187
188 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
189 Self::new(FieldSelection::Include(columns), child).into_expr()
190 }
191
192 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
193 Self::new(FieldSelection::Exclude(columns), child).into_expr()
194 }
195
196 pub fn selection(&self) -> &FieldSelection {
197 &self.selection
198 }
199
200 pub fn child(&self) -> &ExprRef {
201 &self.child
202 }
203
204 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
220 Ok(Self::new(
221 FieldSelection::Include(self.selection.as_include_names(field_names)?),
222 self.child.clone(),
223 )
224 .into_expr())
225 }
226}
227
228impl FieldSelection {
229 pub fn include(columns: FieldNames) -> Self {
230 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
231 Self::Include(columns)
232 }
233
234 pub fn exclude(columns: FieldNames) -> Self {
235 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
236 Self::Exclude(columns)
237 }
238
239 pub fn is_include(&self) -> bool {
240 matches!(self, Self::Include(_))
241 }
242
243 pub fn is_exclude(&self) -> bool {
244 matches!(self, Self::Exclude(_))
245 }
246
247 pub fn field_names(&self) -> &FieldNames {
248 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
249
250 fields
251 }
252
253 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
254 if self
255 .field_names()
256 .iter()
257 .any(|f| !field_names.iter().contains(f))
258 {
259 vortex_bail!(
260 "Field {:?} in select not in field names {:?}",
261 self,
262 field_names
263 );
264 }
265 match self {
266 FieldSelection::Include(fields) => Ok(fields.clone()),
267 FieldSelection::Exclude(exc_fields) => Ok(field_names
268 .iter()
269 .filter(|f| !exc_fields.iter().contains(f))
270 .cloned()
271 .collect()),
272 }
273 }
274}
275
276impl Display for FieldSelection {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 match self {
279 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
280 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
281 }
282 }
283}
284
285impl DisplayAs for SelectExpr {
286 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
287 match df {
288 DisplayFormat::Compact => {
289 write!(f, "{}{}", self.child, self.selection)
290 }
291 DisplayFormat::Tree => {
292 let field_type = if self.selection.is_include() {
293 "include"
294 } else {
295 "exclude"
296 };
297
298 write!(
299 f,
300 "Select({}): {}",
301 field_type,
302 self.selection().field_names()
303 )
304 }
305 }
306 }
307
308 fn child_names(&self) -> Option<Vec<String>> {
309 None
311 }
312}
313
314impl AnalysisExpr for SelectExpr {}
315
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use crate::{
        ExprRef, FieldSelection, Scope, SelectExpr, root, select, select_exclude, test_harness,
    };

    /// Two-column struct fixture with fields `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Evaluates `expr` against the fixture and returns the resulting struct.
    fn evaluate_on_fixture(expr: ExprRef) -> StructArray {
        let st = test_array();
        expr.evaluate(&Scope::new(st.to_array()))
            .unwrap()
            .to_struct()
    }

    #[test]
    pub fn include_columns() {
        let selected = evaluate_on_fixture(select(vec![FieldName::from("a")], root()));
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    pub fn exclude_columns() {
        let selected = evaluate_on_fixture(select_exclude(vec![FieldName::from("a")], root()));
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Expected result: the harness dtype projected onto `names`,
        // non-nullable.
        let projected = |names: &[FieldName]| {
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(names)
                    .unwrap(),
                Nullability::NonNullable,
            )
        };

        // Including `a` alone.
        let only_a = projected(&[FieldName::from("a")]);
        let select_expr = select(vec![FieldName::from("a")], root());
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), only_a);

        // Excluding every other field leaves just `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(select_expr_exclude.return_dtype(&dtype).unwrap(), only_a);

        // Excluding two fields keeps the remaining three.
        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            projected(&[
                FieldName::from("a"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ])
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = SelectExpr::new(FieldSelection::Include(["a"].into()), root());
        let exclude = SelectExpr::new(FieldSelection::Exclude(["b", "c"].into()), root());

        // Including `a` and excluding `b`, `c` resolve to the same select.
        let from_include = include.as_include(&field_names).unwrap();
        let from_exclude = exclude.as_include(&field_names).unwrap();
        assert_eq!(&from_include, &from_exclude);
    }
}