vortex_array/scalar_fn/fns/
select.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17
18use crate::ArrayRef;
19use crate::ExecutionCtx;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::arrays::struct_::StructArrayExt;
23use crate::dtype::DType;
24use crate::dtype::FieldName;
25use crate::dtype::FieldNames;
26use crate::expr::expression::Expression;
27use crate::expr::field::DisplayFieldNames;
28use crate::expr::get_item;
29use crate::expr::pack;
30use crate::scalar_fn::Arity;
31use crate::scalar_fn::ChildName;
32use crate::scalar_fn::ExecutionArgs;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::SimplifyCtx;
36use crate::scalar_fn::fns::pack::Pack;
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39pub enum FieldSelection {
40 Include(FieldNames),
41 Exclude(FieldNames),
42}
43
/// Scalar function (`vortex.select`) that projects a struct-typed child down to
/// the fields described by a [`FieldSelection`].
#[derive(Clone)]
pub struct Select;
46
47impl ScalarFnVTable for Select {
48 type Options = FieldSelection;
49
50 fn id(&self) -> ScalarFnId {
51 ScalarFnId::from("vortex.select")
52 }
53
54 fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
55 let opts = match instance {
56 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
57 names: fields.iter().map(|f| f.to_string()).collect(),
58 }),
59 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
60 names: fields.iter().map(|f| f.to_string()).collect(),
61 }),
62 };
63
64 let select_opts = SelectOpts { opts: Some(opts) };
65 Ok(Some(select_opts.encode_to_vec()))
66 }
67
68 fn deserialize(
69 &self,
70 _metadata: &[u8],
71 _session: &VortexSession,
72 ) -> VortexResult<FieldSelection> {
73 let prost_metadata = SelectOpts::decode(_metadata)?;
74
75 let select_opts = prost_metadata
76 .opts
77 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
78
79 let field_selection = match select_opts {
80 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
81 field_names.names.iter().map(|s| s.as_str()),
82 )),
83 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
84 field_names.names.iter().map(|s| s.as_str()),
85 )),
86 };
87
88 Ok(field_selection)
89 }
90
91 fn arity(&self, _options: &FieldSelection) -> Arity {
92 Arity::Exact(1)
93 }
94
95 fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
96 match child_idx {
97 0 => ChildName::from("child"),
98 _ => unreachable!(),
99 }
100 }
101
102 fn fmt_sql(
103 &self,
104 selection: &FieldSelection,
105 expr: &Expression,
106 f: &mut Formatter<'_>,
107 ) -> std::fmt::Result {
108 expr.child(0).fmt_sql(f)?;
109 match selection {
110 FieldSelection::Include(fields) => {
111 write!(f, "{{{}}}", DisplayFieldNames(fields))
112 }
113 FieldSelection::Exclude(fields) => {
114 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
115 }
116 }
117 }
118
119 fn return_dtype(
120 &self,
121 selection: &FieldSelection,
122 arg_dtypes: &[DType],
123 ) -> VortexResult<DType> {
124 let child_dtype = &arg_dtypes[0];
125 let child_struct_dtype = child_dtype
126 .as_struct_fields_opt()
127 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
128
129 let projected = match selection {
130 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
131 FieldSelection::Exclude(fields) => child_struct_dtype
132 .names()
133 .iter()
134 .cloned()
135 .zip_eq(child_struct_dtype.fields())
136 .filter(|(name, _)| !fields.as_ref().contains(name))
137 .collect(),
138 };
139
140 Ok(DType::Struct(projected, child_dtype.nullability()))
141 }
142
143 fn execute(
144 &self,
145 selection: &FieldSelection,
146 args: &dyn ExecutionArgs,
147 ctx: &mut ExecutionCtx,
148 ) -> VortexResult<ArrayRef> {
149 let child = args.get(0)?.execute::<StructArray>(ctx)?;
150
151 let result = match selection {
152 FieldSelection::Include(f) => child.project(f.as_ref()),
153 FieldSelection::Exclude(names) => {
154 let included_names = child
155 .names()
156 .iter()
157 .filter(|&f| !names.as_ref().contains(f))
158 .cloned()
159 .collect::<Vec<_>>();
160 child.project(included_names.as_slice())
161 }
162 }?;
163
164 result.into_array().execute(ctx)
165 }
166
167 fn simplify(
168 &self,
169 selection: &FieldSelection,
170 expr: &Expression,
171 ctx: &dyn SimplifyCtx,
172 ) -> VortexResult<Option<Expression>> {
173 let child_struct = expr.child(0);
174 let struct_dtype = ctx.return_dtype(child_struct)?;
175 let struct_nullability = struct_dtype.nullability();
176
177 let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
178 vortex_err!(
179 "Select child must return a struct dtype, however it was a {}",
180 struct_dtype
181 )
182 })?;
183
184 let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
186 let all_included_fields_are_nullable = included_fields.iter().all(|name| {
187 struct_fields
188 .field(name)
189 .vortex_expect(
190 "`normalize_to_included_fields` checks that the included fields already exist \
191 in `struct_fields`",
192 )
193 .is_nullable()
194 });
195
196 if included_fields.is_empty() {
201 let empty: Vec<(FieldName, Expression)> = vec![];
202 return Ok(Some(pack(empty, struct_nullability)));
203 }
204
205 let child_is_pack = child_struct.is::<Pack>();
212
213 let would_intersect_validity =
217 struct_nullability.is_nullable() && !all_included_fields_are_nullable;
218
219 if child_is_pack && !would_intersect_validity {
220 let pack_expr = pack(
221 included_fields
222 .into_iter()
223 .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
224 struct_nullability,
225 );
226
227 return Ok(Some(pack_expr));
228 }
229
230 Ok(None)
231 }
232
233 fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
234 true
235 }
236
237 fn is_fallible(&self, _instance: &FieldSelection) -> bool {
238 false
240 }
241}
242
243impl FieldSelection {
244 pub fn include(columns: FieldNames) -> Self {
245 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
246 Self::Include(columns)
247 }
248
249 pub fn exclude(columns: FieldNames) -> Self {
250 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
251 Self::Exclude(columns)
252 }
253
254 pub fn is_include(&self) -> bool {
255 matches!(self, Self::Include(_))
256 }
257
258 pub fn is_exclude(&self) -> bool {
259 matches!(self, Self::Exclude(_))
260 }
261
262 pub fn field_names(&self) -> &FieldNames {
263 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
264
265 fields
266 }
267
268 pub fn normalize_to_included_fields(
269 &self,
270 available_fields: &FieldNames,
271 ) -> VortexResult<FieldNames> {
272 if self
274 .field_names()
275 .iter()
276 .any(|f| !available_fields.iter().contains(f))
277 {
278 vortex_bail!(
279 "Select fields {:?} must be a subset of child fields {:?}",
280 self,
281 available_fields
282 );
283 }
284
285 match self {
286 FieldSelection::Include(fields) => Ok(fields.clone()),
287 FieldSelection::Exclude(exc_fields) => Ok(available_fields
288 .iter()
289 .filter(|f| !exc_fields.iter().contains(f))
290 .cloned()
291 .collect()),
292 }
293 }
294}
295
296impl Display for FieldSelection {
297 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
298 match self {
299 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
300 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
301 }
302 }
303}
304
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::struct_::StructArrayExt;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;
    use crate::scalar_fn::fns::select::StructArray;

    /// A two-field struct array (`a`, `b`) with three rows.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` keeps only `a`.
    #[test]
    fn include_columns() {
        let st = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    /// Excluding `a` keeps only `b`.
    #[test]
    fn exclude_columns() {
        let st = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    /// The return dtype of include and exclude selections matches the
    /// equivalent projection of the harness struct dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every field but `a` yields the same dtype as including `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    /// An include of `a` and an exclude of `b`, `c` normalize to the same
    /// included fields against `a`, `b`, `c`.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    /// Optimizing a select of all fields over a nullable struct keeps the
    /// result nullable.
    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclude over a nullable struct keeps the result nullable
    /// and leaves the remaining fields in order.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}