vortex_array/scalar_fn/fns/
select.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::IntoArray;
20use crate::arrays::StructArray;
21use crate::dtype::DType;
22use crate::dtype::FieldName;
23use crate::dtype::FieldNames;
24use crate::expr::expression::Expression;
25use crate::expr::field::DisplayFieldNames;
26use crate::expr::get_item;
27use crate::expr::pack;
28use crate::scalar_fn::Arity;
29use crate::scalar_fn::ChildName;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ScalarFnId;
32use crate::scalar_fn::ScalarFnVTable;
33use crate::scalar_fn::SimplifyCtx;
34use crate::scalar_fn::fns::pack::Pack;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38 Include(FieldNames),
39 Exclude(FieldNames),
40}
41
/// Scalar function (`vortex.select`) that projects a struct-typed child onto a subset of its
/// fields, as described by a [`FieldSelection`].
#[derive(Clone)]
pub struct Select;
44
45impl ScalarFnVTable for Select {
46 type Options = FieldSelection;
47
48 fn id(&self) -> ScalarFnId {
49 ScalarFnId::new_ref("vortex.select")
50 }
51
52 fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
53 let opts = match instance {
54 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
55 names: fields.iter().map(|f| f.to_string()).collect(),
56 }),
57 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
58 names: fields.iter().map(|f| f.to_string()).collect(),
59 }),
60 };
61
62 let select_opts = SelectOpts { opts: Some(opts) };
63 Ok(Some(select_opts.encode_to_vec()))
64 }
65
66 fn deserialize(
67 &self,
68 _metadata: &[u8],
69 _session: &VortexSession,
70 ) -> VortexResult<FieldSelection> {
71 let prost_metadata = SelectOpts::decode(_metadata)?;
72
73 let select_opts = prost_metadata
74 .opts
75 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
76
77 let field_selection = match select_opts {
78 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
79 field_names.names.iter().map(|s| s.as_str()),
80 )),
81 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
82 field_names.names.iter().map(|s| s.as_str()),
83 )),
84 };
85
86 Ok(field_selection)
87 }
88
89 fn arity(&self, _options: &FieldSelection) -> Arity {
90 Arity::Exact(1)
91 }
92
93 fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
94 match child_idx {
95 0 => ChildName::new_ref("child"),
96 _ => unreachable!(),
97 }
98 }
99
100 fn fmt_sql(
101 &self,
102 selection: &FieldSelection,
103 expr: &Expression,
104 f: &mut Formatter<'_>,
105 ) -> std::fmt::Result {
106 expr.child(0).fmt_sql(f)?;
107 match selection {
108 FieldSelection::Include(fields) => {
109 write!(f, "{{{}}}", DisplayFieldNames(fields))
110 }
111 FieldSelection::Exclude(fields) => {
112 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
113 }
114 }
115 }
116
117 fn return_dtype(
118 &self,
119 selection: &FieldSelection,
120 arg_dtypes: &[DType],
121 ) -> VortexResult<DType> {
122 let child_dtype = &arg_dtypes[0];
123 let child_struct_dtype = child_dtype
124 .as_struct_fields_opt()
125 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
126
127 let projected = match selection {
128 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
129 FieldSelection::Exclude(fields) => child_struct_dtype
130 .names()
131 .iter()
132 .cloned()
133 .zip_eq(child_struct_dtype.fields())
134 .filter(|(name, _)| !fields.as_ref().contains(name))
135 .collect(),
136 };
137
138 Ok(DType::Struct(projected, child_dtype.nullability()))
139 }
140
141 fn execute(
142 &self,
143 selection: &FieldSelection,
144 mut args: ExecutionArgs,
145 ) -> VortexResult<ArrayRef> {
146 let child = args
147 .inputs
148 .pop()
149 .vortex_expect("Missing input child")
150 .execute::<StructArray>(args.ctx)?;
151
152 let result = match selection {
153 FieldSelection::Include(f) => child.project(f.as_ref()),
154 FieldSelection::Exclude(names) => {
155 let included_names = child
156 .names()
157 .iter()
158 .filter(|&f| !names.as_ref().contains(f))
159 .cloned()
160 .collect::<Vec<_>>();
161 child.project(included_names.as_slice())
162 }
163 }?;
164
165 result.into_array().execute(args.ctx)
166 }
167
168 fn simplify(
169 &self,
170 selection: &FieldSelection,
171 expr: &Expression,
172 ctx: &dyn SimplifyCtx,
173 ) -> VortexResult<Option<Expression>> {
174 let child_struct = expr.child(0);
175 let struct_dtype = ctx.return_dtype(child_struct)?;
176 let struct_nullability = struct_dtype.nullability();
177
178 let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
179 vortex_err!(
180 "Select child must return a struct dtype, however it was a {}",
181 struct_dtype
182 )
183 })?;
184
185 let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
187 let all_included_fields_are_nullable = included_fields.iter().all(|name| {
188 struct_fields
189 .field(name)
190 .vortex_expect(
191 "`normalize_to_included_fields` checks that the included fields already exist \
192 in `struct_fields`",
193 )
194 .is_nullable()
195 });
196
197 if included_fields.is_empty() {
202 let empty: Vec<(FieldName, Expression)> = vec![];
203 return Ok(Some(pack(empty, struct_nullability)));
204 }
205
206 let child_is_pack = child_struct.is::<Pack>();
213
214 let would_intersect_validity =
218 struct_nullability.is_nullable() && !all_included_fields_are_nullable;
219
220 if child_is_pack && !would_intersect_validity {
221 let pack_expr = pack(
222 included_fields
223 .into_iter()
224 .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
225 struct_nullability,
226 );
227
228 return Ok(Some(pack_expr));
229 }
230
231 Ok(None)
232 }
233
234 fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
235 true
236 }
237
238 fn is_fallible(&self, _instance: &FieldSelection) -> bool {
239 false
241 }
242}
243
244impl FieldSelection {
245 pub fn include(columns: FieldNames) -> Self {
246 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
247 Self::Include(columns)
248 }
249
250 pub fn exclude(columns: FieldNames) -> Self {
251 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
252 Self::Exclude(columns)
253 }
254
255 pub fn is_include(&self) -> bool {
256 matches!(self, Self::Include(_))
257 }
258
259 pub fn is_exclude(&self) -> bool {
260 matches!(self, Self::Exclude(_))
261 }
262
263 pub fn field_names(&self) -> &FieldNames {
264 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
265
266 fields
267 }
268
269 pub fn normalize_to_included_fields(
270 &self,
271 available_fields: &FieldNames,
272 ) -> VortexResult<FieldNames> {
273 if self
275 .field_names()
276 .iter()
277 .any(|f| !available_fields.iter().contains(f))
278 {
279 vortex_bail!(
280 "Select fields {:?} must be a subset of child fields {:?}",
281 self,
282 available_fields
283 );
284 }
285
286 match self {
287 FieldSelection::Include(fields) => Ok(fields.clone()),
288 FieldSelection::Exclude(exc_fields) => Ok(available_fields
289 .iter()
290 .filter(|f| !exc_fields.iter().contains(f))
291 .cloned()
292 .collect()),
293 }
294 }
295}
296
297impl Display for FieldSelection {
298 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
299 match self {
300 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
301 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;

    /// Builds a struct array with two columns, `a` and `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Including `a` keeps only `a`.
    #[test]
    pub fn include_columns() {
        let array = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let projected = array.to_array().apply(&expr).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["a"]);
    }

    /// Excluding `a` keeps only `b`.
    #[test]
    pub fn exclude_columns() {
        let array = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = array.to_array().apply(&expr).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["b"]);
    }

    /// The return dtype of include and exclude selections matches the equivalent projection.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Including `a` directly.
        let only_a = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding the four listed fields is expected to leave the same dtype.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding only some fields leaves the rest in their original order.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let a_and_bools = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), a_and_bools);
    }

    /// An include and the complementary exclude normalize to the same field list.
    #[test]
    fn test_as_include_names() {
        let available = FieldNames::from(["a", "b", "c"]);
        let include_expr = select(["a"], root());
        let exclude_expr = select_exclude(["b", "c"], root());

        let from_include = include_expr
            .as_::<Select>()
            .normalize_to_included_fields(&available)
            .unwrap();
        let from_exclude = exclude_expr
            .as_::<Select>()
            .normalize_to_included_fields(&available)
            .unwrap();

        assert_eq!(&from_include, &from_exclude);
    }

    /// Optimizing a select of all fields keeps the struct nullable.
    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let expr = select(["a", "b"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();

        assert!(optimized.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclude keeps nullability and the remaining fields in order.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let expr = select_exclude(["c"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();

        let optimized_dtype = optimized.return_dtype(&dtype).unwrap();
        assert!(optimized_dtype.is_nullable());
        let fields = optimized_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}