vortex_array/scalar_fn/fns/
select.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::dtype::DType;
23use crate::dtype::FieldName;
24use crate::dtype::FieldNames;
25use crate::expr::expression::Expression;
26use crate::expr::field::DisplayFieldNames;
27use crate::expr::get_item;
28use crate::expr::pack;
29use crate::scalar_fn::Arity;
30use crate::scalar_fn::ChildName;
31use crate::scalar_fn::ExecutionArgs;
32use crate::scalar_fn::ScalarFnId;
33use crate::scalar_fn::ScalarFnVTable;
34use crate::scalar_fn::SimplifyCtx;
35use crate::scalar_fn::fns::pack::Pack;
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub enum FieldSelection {
39 Include(FieldNames),
40 Exclude(FieldNames),
41}
42
/// Scalar function that projects a struct-typed child down to a subset of its fields.
///
/// The set of fields is described by a [`FieldSelection`] passed as the function's options.
#[derive(Clone)]
pub struct Select;
45
46impl ScalarFnVTable for Select {
47 type Options = FieldSelection;
48
49 fn id(&self) -> ScalarFnId {
50 ScalarFnId::new_ref("vortex.select")
51 }
52
53 fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
54 let opts = match instance {
55 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
56 names: fields.iter().map(|f| f.to_string()).collect(),
57 }),
58 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
59 names: fields.iter().map(|f| f.to_string()).collect(),
60 }),
61 };
62
63 let select_opts = SelectOpts { opts: Some(opts) };
64 Ok(Some(select_opts.encode_to_vec()))
65 }
66
67 fn deserialize(
68 &self,
69 _metadata: &[u8],
70 _session: &VortexSession,
71 ) -> VortexResult<FieldSelection> {
72 let prost_metadata = SelectOpts::decode(_metadata)?;
73
74 let select_opts = prost_metadata
75 .opts
76 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
77
78 let field_selection = match select_opts {
79 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
80 field_names.names.iter().map(|s| s.as_str()),
81 )),
82 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
83 field_names.names.iter().map(|s| s.as_str()),
84 )),
85 };
86
87 Ok(field_selection)
88 }
89
90 fn arity(&self, _options: &FieldSelection) -> Arity {
91 Arity::Exact(1)
92 }
93
94 fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
95 match child_idx {
96 0 => ChildName::new_ref("child"),
97 _ => unreachable!(),
98 }
99 }
100
101 fn fmt_sql(
102 &self,
103 selection: &FieldSelection,
104 expr: &Expression,
105 f: &mut Formatter<'_>,
106 ) -> std::fmt::Result {
107 expr.child(0).fmt_sql(f)?;
108 match selection {
109 FieldSelection::Include(fields) => {
110 write!(f, "{{{}}}", DisplayFieldNames(fields))
111 }
112 FieldSelection::Exclude(fields) => {
113 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
114 }
115 }
116 }
117
118 fn return_dtype(
119 &self,
120 selection: &FieldSelection,
121 arg_dtypes: &[DType],
122 ) -> VortexResult<DType> {
123 let child_dtype = &arg_dtypes[0];
124 let child_struct_dtype = child_dtype
125 .as_struct_fields_opt()
126 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
127
128 let projected = match selection {
129 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
130 FieldSelection::Exclude(fields) => child_struct_dtype
131 .names()
132 .iter()
133 .cloned()
134 .zip_eq(child_struct_dtype.fields())
135 .filter(|(name, _)| !fields.as_ref().contains(name))
136 .collect(),
137 };
138
139 Ok(DType::Struct(projected, child_dtype.nullability()))
140 }
141
142 fn execute(
143 &self,
144 selection: &FieldSelection,
145 args: &dyn ExecutionArgs,
146 ctx: &mut ExecutionCtx,
147 ) -> VortexResult<ArrayRef> {
148 let child = args.get(0)?.execute::<StructArray>(ctx)?;
149
150 let result = match selection {
151 FieldSelection::Include(f) => child.project(f.as_ref()),
152 FieldSelection::Exclude(names) => {
153 let included_names = child
154 .names()
155 .iter()
156 .filter(|&f| !names.as_ref().contains(f))
157 .cloned()
158 .collect::<Vec<_>>();
159 child.project(included_names.as_slice())
160 }
161 }?;
162
163 result.into_array().execute(ctx)
164 }
165
166 fn simplify(
167 &self,
168 selection: &FieldSelection,
169 expr: &Expression,
170 ctx: &dyn SimplifyCtx,
171 ) -> VortexResult<Option<Expression>> {
172 let child_struct = expr.child(0);
173 let struct_dtype = ctx.return_dtype(child_struct)?;
174 let struct_nullability = struct_dtype.nullability();
175
176 let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
177 vortex_err!(
178 "Select child must return a struct dtype, however it was a {}",
179 struct_dtype
180 )
181 })?;
182
183 let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
185 let all_included_fields_are_nullable = included_fields.iter().all(|name| {
186 struct_fields
187 .field(name)
188 .vortex_expect(
189 "`normalize_to_included_fields` checks that the included fields already exist \
190 in `struct_fields`",
191 )
192 .is_nullable()
193 });
194
195 if included_fields.is_empty() {
200 let empty: Vec<(FieldName, Expression)> = vec![];
201 return Ok(Some(pack(empty, struct_nullability)));
202 }
203
204 let child_is_pack = child_struct.is::<Pack>();
211
212 let would_intersect_validity =
216 struct_nullability.is_nullable() && !all_included_fields_are_nullable;
217
218 if child_is_pack && !would_intersect_validity {
219 let pack_expr = pack(
220 included_fields
221 .into_iter()
222 .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
223 struct_nullability,
224 );
225
226 return Ok(Some(pack_expr));
227 }
228
229 Ok(None)
230 }
231
232 fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
233 true
234 }
235
236 fn is_fallible(&self, _instance: &FieldSelection) -> bool {
237 false
239 }
240}
241
242impl FieldSelection {
243 pub fn include(columns: FieldNames) -> Self {
244 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
245 Self::Include(columns)
246 }
247
248 pub fn exclude(columns: FieldNames) -> Self {
249 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
250 Self::Exclude(columns)
251 }
252
253 pub fn is_include(&self) -> bool {
254 matches!(self, Self::Include(_))
255 }
256
257 pub fn is_exclude(&self) -> bool {
258 matches!(self, Self::Exclude(_))
259 }
260
261 pub fn field_names(&self) -> &FieldNames {
262 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
263
264 fields
265 }
266
267 pub fn normalize_to_included_fields(
268 &self,
269 available_fields: &FieldNames,
270 ) -> VortexResult<FieldNames> {
271 if self
273 .field_names()
274 .iter()
275 .any(|f| !available_fields.iter().contains(f))
276 {
277 vortex_bail!(
278 "Select fields {:?} must be a subset of child fields {:?}",
279 self,
280 available_fields
281 );
282 }
283
284 match self {
285 FieldSelection::Include(fields) => Ok(fields.clone()),
286 FieldSelection::Exclude(exc_fields) => Ok(available_fields
287 .iter()
288 .filter(|f| !exc_fields.iter().contains(f))
289 .cloned()
290 .collect()),
291 }
292 }
293}
294
295impl Display for FieldSelection {
296 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
297 match self {
298 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
299 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;

    /// Builds a struct array with two integer columns, `a` and `b`, of three rows each.
    fn test_array() -> StructArray {
        let column_a = buffer![0, 1, 2].into_array();
        let column_b = buffer![4, 5, 6].into_array();
        StructArray::from_fields(&[("a", column_a), ("b", column_b)]).unwrap()
    }

    #[test]
    pub fn include_columns() {
        let expr = select(vec![FieldName::from("a")], root());
        let projected = test_array().to_array().apply(&expr).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["a"]);
    }

    #[test]
    pub fn exclude_columns() {
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = test_array().to_array().apply(&expr).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct_fields_opt().unwrap();
        // Expected dtype of a projection of the harness struct onto `names`.
        let non_nullable_projection = |names: &[FieldName]| {
            DType::Struct(
                struct_fields.project(names).unwrap(),
                Nullability::NonNullable,
            )
        };

        let only_a = non_nullable_projection(&["a".into()]);

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_int_columns = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            exclude_int_columns.return_dtype(&dtype).unwrap(),
            non_nullable_projection(&["a".into(), "bool1".into(), "bool2".into()])
        );
    }

    #[test]
    fn test_as_include_names() {
        let available = FieldNames::from(["a", "b", "c"]);
        let include_expr = select(["a"], root());
        let exclude_expr = select_exclude(["b", "c"], root());

        let from_include = include_expr
            .as_::<Select>()
            .normalize_to_included_fields(&available)
            .unwrap();
        let from_exclude = exclude_expr
            .as_::<Select>()
            .normalize_to_included_fields(&available)
            .unwrap();

        assert_eq!(from_include, from_exclude);
    }

    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let expr = select(["a", "b"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();
        let optimized_dtype = optimized.return_dtype(&dtype).unwrap();

        assert!(optimized_dtype.is_nullable());
    }

    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let expr = select_exclude(["c"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();
        let optimized_dtype = optimized.return_dtype(&dtype).unwrap();

        assert!(optimized_dtype.is_nullable());
        let remaining = optimized_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(remaining.names().as_ref(), &["a", "b"]);
    }
}