1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_dtype::DType;
10use vortex_dtype::FieldName;
11use vortex_dtype::FieldNames;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_proto::expr::FieldNames as ProtoFieldNames;
17use vortex_proto::expr::SelectOpts;
18use vortex_proto::expr::select_opts::Opts;
19use vortex_session::VortexSession;
20
21use crate::ArrayRef;
22use crate::IntoArray;
23use crate::arrays::StructArray;
24use crate::expr;
25use crate::expr::Arity;
26use crate::expr::ChildName;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Pack;
30use crate::expr::SimplifyCtx;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::expression::Expression;
34use crate::expr::field::DisplayFieldNames;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38 Include(FieldNames),
39 Exclude(FieldNames),
40}
41
/// Expression vtable that projects a struct child down to a subset of its fields,
/// as described by a [`FieldSelection`].
pub struct Select;
43
44impl VTable for Select {
45 type Options = FieldSelection;
46
47 fn id(&self) -> ExprId {
48 ExprId::new_ref("vortex.select")
49 }
50
51 fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
52 let opts = match instance {
53 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
54 names: fields.iter().map(|f| f.to_string()).collect(),
55 }),
56 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
57 names: fields.iter().map(|f| f.to_string()).collect(),
58 }),
59 };
60
61 let select_opts = SelectOpts { opts: Some(opts) };
62 Ok(Some(select_opts.encode_to_vec()))
63 }
64
65 fn deserialize(
66 &self,
67 _metadata: &[u8],
68 _session: &VortexSession,
69 ) -> VortexResult<FieldSelection> {
70 let prost_metadata = SelectOpts::decode(_metadata)?;
71
72 let select_opts = prost_metadata
73 .opts
74 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
75
76 let field_selection = match select_opts {
77 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
78 field_names.names.iter().map(|s| s.as_str()),
79 )),
80 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
81 field_names.names.iter().map(|s| s.as_str()),
82 )),
83 };
84
85 Ok(field_selection)
86 }
87
88 fn arity(&self, _options: &FieldSelection) -> Arity {
89 Arity::Exact(1)
90 }
91
92 fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
93 match child_idx {
94 0 => ChildName::new_ref("child"),
95 _ => unreachable!(),
96 }
97 }
98
99 fn fmt_sql(
100 &self,
101 selection: &FieldSelection,
102 expr: &Expression,
103 f: &mut Formatter<'_>,
104 ) -> std::fmt::Result {
105 expr.child(0).fmt_sql(f)?;
106 match selection {
107 FieldSelection::Include(fields) => {
108 write!(f, "{{{}}}", DisplayFieldNames(fields))
109 }
110 FieldSelection::Exclude(fields) => {
111 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
112 }
113 }
114 }
115
116 fn return_dtype(
117 &self,
118 selection: &FieldSelection,
119 arg_dtypes: &[DType],
120 ) -> VortexResult<DType> {
121 let child_dtype = &arg_dtypes[0];
122 let child_struct_dtype = child_dtype
123 .as_struct_fields_opt()
124 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
125
126 let projected = match selection {
127 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
128 FieldSelection::Exclude(fields) => child_struct_dtype
129 .names()
130 .iter()
131 .cloned()
132 .zip_eq(child_struct_dtype.fields())
133 .filter(|(name, _)| !fields.as_ref().contains(name))
134 .collect(),
135 };
136
137 Ok(DType::Struct(projected, child_dtype.nullability()))
138 }
139
140 fn execute(
141 &self,
142 selection: &FieldSelection,
143 mut args: ExecutionArgs,
144 ) -> VortexResult<ArrayRef> {
145 let child = args
146 .inputs
147 .pop()
148 .vortex_expect("Missing input child")
149 .execute::<StructArray>(args.ctx)?;
150
151 let result = match selection {
152 FieldSelection::Include(f) => child.project(f.as_ref()),
153 FieldSelection::Exclude(names) => {
154 let included_names = child
155 .names()
156 .iter()
157 .filter(|&f| !names.as_ref().contains(f))
158 .cloned()
159 .collect::<Vec<_>>();
160 child.project(included_names.as_slice())
161 }
162 }?;
163
164 result.into_array().execute(args.ctx)
165 }
166
167 fn simplify(
168 &self,
169 selection: &FieldSelection,
170 expr: &Expression,
171 ctx: &dyn SimplifyCtx,
172 ) -> VortexResult<Option<Expression>> {
173 let child_struct = expr.child(0);
174 let struct_dtype = ctx.return_dtype(child_struct)?;
175 let struct_nullability = struct_dtype.nullability();
176
177 let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
178 vortex_err!(
179 "Select child must return a struct dtype, however it was a {}",
180 struct_dtype
181 )
182 })?;
183
184 let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
186 let all_included_fields_are_nullable = included_fields.iter().all(|name| {
187 struct_fields
188 .field(name)
189 .vortex_expect(
190 "`normalize_to_included_fields` checks that the included fields already exist \
191 in `struct_fields`",
192 )
193 .is_nullable()
194 });
195
196 if included_fields.is_empty() {
201 let empty: Vec<(FieldName, Expression)> = vec![];
202 return Ok(Some(expr::pack(empty, struct_nullability)));
203 }
204
205 let child_is_pack = child_struct.is::<Pack>();
212
213 let would_intersect_validity =
217 struct_nullability.is_nullable() && !all_included_fields_are_nullable;
218
219 if child_is_pack && !would_intersect_validity {
220 let pack_expr = expr::pack(
221 included_fields
222 .into_iter()
223 .map(|name| (name.clone(), expr::get_item(name, child_struct.clone()))),
224 struct_nullability,
225 );
226
227 return Ok(Some(pack_expr));
228 }
229
230 Ok(None)
231 }
232
233 fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
234 true
235 }
236
237 fn is_fallible(&self, _instance: &FieldSelection) -> bool {
238 false
240 }
241}
242
243pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
251 Select
252 .try_new_expr(FieldSelection::Include(field_names.into()), [child])
253 .vortex_expect("Failed to create Select expression")
254}
255
256pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
265 Select
266 .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
267 .vortex_expect("Failed to create Select expression")
268}
269
270impl FieldSelection {
271 pub fn include(columns: FieldNames) -> Self {
272 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
273 Self::Include(columns)
274 }
275
276 pub fn exclude(columns: FieldNames) -> Self {
277 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
278 Self::Exclude(columns)
279 }
280
281 pub fn is_include(&self) -> bool {
282 matches!(self, Self::Include(_))
283 }
284
285 pub fn is_exclude(&self) -> bool {
286 matches!(self, Self::Exclude(_))
287 }
288
289 pub fn field_names(&self) -> &FieldNames {
290 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
291
292 fields
293 }
294
295 pub fn normalize_to_included_fields(
296 &self,
297 available_fields: &FieldNames,
298 ) -> VortexResult<FieldNames> {
299 if self
301 .field_names()
302 .iter()
303 .any(|f| !available_fields.iter().contains(f))
304 {
305 vortex_bail!(
306 "Select fields {:?} must be a subset of child fields {:?}",
307 self,
308 available_fields
309 );
310 }
311
312 match self {
313 FieldSelection::Include(fields) => Ok(fields.clone()),
314 FieldSelection::Exclude(exc_fields) => Ok(available_fields
315 .iter()
316 .filter(|f| !exc_fields.iter().contains(f))
317 .cloned()
318 .collect()),
319 }
320 }
321}
322
323impl Display for FieldSelection {
324 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
325 match self {
326 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
327 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Builds a struct array with two integer columns, `a` and `b`, of three rows each.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn include_columns() {
        let st = test_array();
        let select_expr = select(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select_expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    #[test]
    fn exclude_columns() {
        let st = test_array();
        let select_expr = select_exclude(vec![FieldName::from("a")], root());
        let selected = st.to_array().apply(&select_expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}