vortex_array/scalar_fn/fns/
select.rs1use std::fmt::Display;
5use std::fmt::Formatter;
6
7use itertools::Itertools;
8use prost::Message;
9use vortex_error::VortexExpect;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr::FieldNames as ProtoFieldNames;
14use vortex_proto::expr::SelectOpts;
15use vortex_proto::expr::select_opts::Opts;
16use vortex_session::VortexSession;
17use vortex_session::registry::CachedId;
18
19use crate::ArrayRef;
20use crate::ExecutionCtx;
21use crate::IntoArray;
22use crate::arrays::StructArray;
23use crate::arrays::struct_::StructArrayExt;
24use crate::dtype::DType;
25use crate::dtype::FieldName;
26use crate::dtype::FieldNames;
27use crate::expr::expression::Expression;
28use crate::expr::field::DisplayFieldNames;
29use crate::expr::get_item;
30use crate::expr::pack;
31use crate::scalar_fn::Arity;
32use crate::scalar_fn::ChildName;
33use crate::scalar_fn::ExecutionArgs;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::SimplifyCtx;
37use crate::scalar_fn::fns::pack::Pack;
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash)]
40pub enum FieldSelection {
41 Include(FieldNames),
42 Exclude(FieldNames),
43}
44
/// Scalar function that selects a subset of the fields of a struct-typed child.
///
/// The type carries no state; the fields to keep or drop are supplied through
/// its [`FieldSelection`] options.
#[derive(Clone)]
pub struct Select;
47
48impl ScalarFnVTable for Select {
49 type Options = FieldSelection;
50
51 fn id(&self) -> ScalarFnId {
52 static ID: CachedId = CachedId::new("vortex.select");
53 *ID
54 }
55
56 fn serialize(&self, instance: &FieldSelection) -> VortexResult<Option<Vec<u8>>> {
57 let opts = match instance {
58 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
59 names: fields.iter().map(|f| f.to_string()).collect(),
60 }),
61 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
62 names: fields.iter().map(|f| f.to_string()).collect(),
63 }),
64 };
65
66 let select_opts = SelectOpts { opts: Some(opts) };
67 Ok(Some(select_opts.encode_to_vec()))
68 }
69
70 fn deserialize(
71 &self,
72 _metadata: &[u8],
73 _session: &VortexSession,
74 ) -> VortexResult<FieldSelection> {
75 let prost_metadata = SelectOpts::decode(_metadata)?;
76
77 let select_opts = prost_metadata
78 .opts
79 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
80
81 let field_selection = match select_opts {
82 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
83 field_names.names.iter().map(|s| s.as_str()),
84 )),
85 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
86 field_names.names.iter().map(|s| s.as_str()),
87 )),
88 };
89
90 Ok(field_selection)
91 }
92
93 fn arity(&self, _options: &FieldSelection) -> Arity {
94 Arity::Exact(1)
95 }
96
97 fn child_name(&self, _instance: &FieldSelection, child_idx: usize) -> ChildName {
98 match child_idx {
99 0 => ChildName::from("child"),
100 _ => unreachable!(),
101 }
102 }
103
104 fn fmt_sql(
105 &self,
106 selection: &FieldSelection,
107 expr: &Expression,
108 f: &mut Formatter<'_>,
109 ) -> std::fmt::Result {
110 expr.child(0).fmt_sql(f)?;
111 match selection {
112 FieldSelection::Include(fields) => {
113 write!(f, "{{{}}}", DisplayFieldNames(fields))
114 }
115 FieldSelection::Exclude(fields) => {
116 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
117 }
118 }
119 }
120
121 fn return_dtype(
122 &self,
123 selection: &FieldSelection,
124 arg_dtypes: &[DType],
125 ) -> VortexResult<DType> {
126 let child_dtype = &arg_dtypes[0];
127 let child_struct_dtype = child_dtype
128 .as_struct_fields_opt()
129 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
130
131 let projected = match selection {
132 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
133 FieldSelection::Exclude(fields) => child_struct_dtype
134 .names()
135 .iter()
136 .cloned()
137 .zip_eq(child_struct_dtype.fields())
138 .filter(|(name, _)| !fields.as_ref().contains(name))
139 .collect(),
140 };
141
142 Ok(DType::Struct(projected, child_dtype.nullability()))
143 }
144
145 fn execute(
146 &self,
147 selection: &FieldSelection,
148 args: &dyn ExecutionArgs,
149 ctx: &mut ExecutionCtx,
150 ) -> VortexResult<ArrayRef> {
151 let child = args.get(0)?.execute::<StructArray>(ctx)?;
152
153 let result = match selection {
154 FieldSelection::Include(f) => child.project(f.as_ref()),
155 FieldSelection::Exclude(names) => {
156 let included_names = child
157 .names()
158 .iter()
159 .filter(|&f| !names.as_ref().contains(f))
160 .cloned()
161 .collect::<Vec<_>>();
162 child.project(included_names.as_slice())
163 }
164 }?;
165
166 result.into_array().execute(ctx)
167 }
168
169 fn simplify(
170 &self,
171 selection: &FieldSelection,
172 expr: &Expression,
173 ctx: &dyn SimplifyCtx,
174 ) -> VortexResult<Option<Expression>> {
175 let child_struct = expr.child(0);
176 let struct_dtype = ctx.return_dtype(child_struct)?;
177 let struct_nullability = struct_dtype.nullability();
178
179 let struct_fields = struct_dtype.as_struct_fields_opt().ok_or_else(|| {
180 vortex_err!(
181 "Select child must return a struct dtype, however it was a {}",
182 struct_dtype
183 )
184 })?;
185
186 let included_fields = selection.normalize_to_included_fields(struct_fields.names())?;
188 let all_included_fields_are_nullable = included_fields.iter().all(|name| {
189 struct_fields
190 .field(name)
191 .vortex_expect(
192 "`normalize_to_included_fields` checks that the included fields already exist \
193 in `struct_fields`",
194 )
195 .is_nullable()
196 });
197
198 if included_fields.is_empty() {
203 let empty: Vec<(FieldName, Expression)> = vec![];
204 return Ok(Some(pack(empty, struct_nullability)));
205 }
206
207 let child_is_pack = child_struct.is::<Pack>();
214
215 let would_intersect_validity =
219 struct_nullability.is_nullable() && !all_included_fields_are_nullable;
220
221 if child_is_pack && !would_intersect_validity {
222 let pack_expr = pack(
223 included_fields
224 .into_iter()
225 .map(|name| (name.clone(), get_item(name, child_struct.clone()))),
226 struct_nullability,
227 );
228
229 return Ok(Some(pack_expr));
230 }
231
232 Ok(None)
233 }
234
235 fn is_null_sensitive(&self, _instance: &FieldSelection) -> bool {
236 true
237 }
238
239 fn is_fallible(&self, _instance: &FieldSelection) -> bool {
240 false
242 }
243}
244
245impl FieldSelection {
246 pub fn include(columns: FieldNames) -> Self {
247 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
248 Self::Include(columns)
249 }
250
251 pub fn exclude(columns: FieldNames) -> Self {
252 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
253 Self::Exclude(columns)
254 }
255
256 pub fn is_include(&self) -> bool {
257 matches!(self, Self::Include(_))
258 }
259
260 pub fn is_exclude(&self) -> bool {
261 matches!(self, Self::Exclude(_))
262 }
263
264 pub fn field_names(&self) -> &FieldNames {
265 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
266
267 fields
268 }
269
270 pub fn normalize_to_included_fields(
271 &self,
272 available_fields: &FieldNames,
273 ) -> VortexResult<FieldNames> {
274 if self
276 .field_names()
277 .iter()
278 .any(|f| !available_fields.iter().contains(f))
279 {
280 vortex_bail!(
281 "Select fields {:?} must be a subset of child fields {:?}",
282 self,
283 available_fields
284 );
285 }
286
287 match self {
288 FieldSelection::Include(fields) => Ok(fields.clone()),
289 FieldSelection::Exclude(exc_fields) => Ok(available_fields
290 .iter()
291 .filter(|f| !exc_fields.iter().contains(f))
292 .cloned()
293 .collect()),
294 }
295 }
296}
297
298impl Display for FieldSelection {
299 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
300 match self {
301 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
302 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
303 }
304 }
305}
306
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::StructArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::Nullable;
    use crate::dtype::PType::I32;
    use crate::dtype::StructFields;
    use crate::expr::root;
    use crate::expr::select;
    use crate::expr::select_exclude;
    use crate::expr::test_harness;
    use crate::scalar_fn::fns::select::Select;

    /// Two-column struct array (`a`, `b`) shared by the execution tests.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    fn include_columns() {
        let st = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a"]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    fn exclude_columns() {
        let st = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        #[expect(deprecated)]
        let selected = st.into_array().apply(&expr).unwrap().to_struct();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b"]);
    }

    /// The return dtype of include and exclude selections matches the
    /// corresponding projection of the input struct dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype
                .as_struct_fields_opt()
                .unwrap()
                .project(&["a".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(select_expr.return_dtype(&dtype).unwrap(), expected_dtype);

        // Excluding every other field is equivalent to including only `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            expected_dtype
        );

        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude.return_dtype(&dtype).unwrap(),
            DType::Struct(
                dtype
                    .as_struct_fields_opt()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }

    /// An include of `a` and an exclude of `b`, `c` normalize to the same
    /// field list over `a`, `b`, `c`.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());
        assert_eq!(
            &include
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap(),
            &exclude
                .as_::<Select>()
                .normalize_to_included_fields(&field_names)
                .unwrap()
        );
    }

    /// Optimizing a select of all fields keeps the nullable struct dtype.
    #[test]
    fn test_remove_select_rule() {
        let dtype = DType::Struct(
            StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
            Nullable,
        );
        let e = select(["a", "b"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        assert!(result.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclusion keeps nullability and the remaining field names.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b", "c"].into(),
                vec![I32.into(), I32.into(), I32.into()],
            ),
            Nullable,
        );
        let e = select_exclude(["c"], root());

        let result = e.optimize_recursive(&dtype).unwrap();

        let result_dtype = result.return_dtype(&dtype).unwrap();
        assert!(result_dtype.is_nullable());
        let fields = result_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(fields.names().as_ref(), &["a", "b"]);
    }
}