1use std::fmt::Display;
5use std::fmt::Formatter;
6use std::sync::Arc;
7
8use itertools::Itertools;
9use prost::Message;
10use vortex_dtype::DType;
11use vortex_dtype::FieldNames;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16use vortex_proto::expr::FieldNames as ProtoFieldNames;
17use vortex_proto::expr::SelectOpts;
18use vortex_proto::expr::select_opts::Opts;
19use vortex_vector::Datum;
20use vortex_vector::StructDatum;
21use vortex_vector::VectorOps;
22use vortex_vector::struct_::StructVector;
23
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::expr::Arity;
28use crate::expr::ChildName;
29use crate::expr::ExecutionArgs;
30use crate::expr::ExprId;
31use crate::expr::SimplifyCtx;
32use crate::expr::VTable;
33use crate::expr::VTableExt;
34use crate::expr::expression::Expression;
35use crate::expr::field::DisplayFieldNames;
36use crate::expr::get_item;
37use crate::expr::pack;
38
/// Describes which fields of a struct-typed child a [`Select`] expression keeps.
///
/// Both variants carry a list of field names; they differ only in whether that
/// list names the fields to keep or the fields to drop.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FieldSelection {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field of the child except the listed ones.
    Exclude(FieldNames),
}
44
/// Expression that projects a subset of fields out of its single struct-typed child.
///
/// The fields to keep (or drop) are carried as the expression's [`FieldSelection`] options.
pub struct Select;
46
47impl VTable for Select {
48 type Options = FieldSelection;
49
50 fn id(&self) -> ExprId {
51 ExprId::new_ref("vortex.select")
52 }
53
54 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
55 let opts = match instance {
56 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
57 names: fields.iter().map(|f| f.to_string()).collect(),
58 }),
59 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
60 names: fields.iter().map(|f| f.to_string()).collect(),
61 }),
62 };
63
64 let select_opts = SelectOpts { opts: Some(opts) };
65 Ok(Some(select_opts.encode_to_vec()))
66 }
67
68 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
69 let prost_metadata = SelectOpts::decode(metadata)?;
70
71 let select_opts = prost_metadata
72 .opts
73 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
74
75 let field_selection = match select_opts {
76 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
77 field_names.names.iter().map(|s| s.as_str()),
78 )),
79 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
80 field_names.names.iter().map(|s| s.as_str()),
81 )),
82 };
83
84 Ok(field_selection)
85 }
86
87 fn arity(&self, _options: &Self::Options) -> Arity {
88 Arity::Exact(1)
89 }
90
91 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
92 match child_idx {
93 0 => ChildName::new_ref("child"),
94 _ => unreachable!(),
95 }
96 }
97
98 fn fmt_sql(
99 &self,
100 selection: &FieldSelection,
101 expr: &Expression,
102 f: &mut Formatter<'_>,
103 ) -> std::fmt::Result {
104 expr.child(0).fmt_sql(f)?;
105 match selection {
106 FieldSelection::Include(fields) => {
107 write!(f, "{{{}}}", DisplayFieldNames(fields))
108 }
109 FieldSelection::Exclude(fields) => {
110 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
111 }
112 }
113 }
114
115 fn return_dtype(
116 &self,
117 selection: &FieldSelection,
118 arg_dtypes: &[DType],
119 ) -> VortexResult<DType> {
120 let child_dtype = &arg_dtypes[0];
121 let child_struct_dtype = child_dtype
122 .as_struct_fields_opt()
123 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
124
125 let projected = match selection {
126 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
127 FieldSelection::Exclude(fields) => child_struct_dtype
128 .names()
129 .iter()
130 .cloned()
131 .zip_eq(child_struct_dtype.fields())
132 .filter(|(name, _)| !fields.as_ref().contains(name))
133 .collect(),
134 };
135
136 Ok(DType::Struct(projected, child_dtype.nullability()))
137 }
138
139 fn evaluate(
140 &self,
141 selection: &FieldSelection,
142 expr: &Expression,
143 scope: &ArrayRef,
144 ) -> VortexResult<ArrayRef> {
145 let batch = expr.child(0).evaluate(scope)?.to_struct();
146 Ok(match selection {
147 FieldSelection::Include(f) => batch.project(f.as_ref()),
148 FieldSelection::Exclude(names) => {
149 let included_names = batch
150 .names()
151 .iter()
152 .filter(|&f| !names.as_ref().contains(f))
153 .cloned()
154 .collect::<Vec<_>>();
155 batch.project(included_names.as_slice())
156 }
157 }?
158 .into_array())
159 }
160
161 fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult<Datum> {
162 let child_fields = args
163 .dtypes
164 .pop()
165 .vortex_expect("Missing input dtype")
166 .into_struct_fields();
167
168 let field_indices: Vec<usize> = match selection {
169 FieldSelection::Include(f) => f
170 .iter()
171 .map(|name| {
172 child_fields
173 .find(name)
174 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
175 })
176 .try_collect(),
177 FieldSelection::Exclude(names) => child_fields
178 .names()
179 .iter()
180 .filter(|&f| !names.as_ref().contains(f))
181 .map(|name| {
182 child_fields
183 .find(name)
184 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
185 })
186 .try_collect(),
187 }?;
188
189 let child = args
190 .datums
191 .pop()
192 .vortex_expect("Missing input child")
193 .into_struct();
194
195 Ok(match child {
196 StructDatum::Scalar(s) => StructDatum::Scalar(
197 select_from_struct_vector(s.value(), &field_indices)?.scalar_at(0),
198 ),
199 StructDatum::Vector(v) => {
200 StructDatum::Vector(select_from_struct_vector(&v, &field_indices)?)
201 }
202 }
203 .into())
204 }
205
206 fn simplify(
207 &self,
208 options: &Self::Options,
209 expr: &Expression,
210 ctx: &dyn SimplifyCtx,
211 ) -> VortexResult<Option<Expression>> {
212 let child = expr.child(0);
213 let child_dtype = ctx.return_dtype(child)?;
214 let child_nullability = child_dtype.nullability();
215
216 let child_dtype = child_dtype.as_struct_fields_opt().ok_or_else(|| {
217 vortex_err!(
218 "Select child must return a struct dtype, however it was a {}",
219 child_dtype
220 )
221 })?;
222
223 let expr = pack(
224 options
225 .as_include_names(child_dtype.names())
226 .map_err(|e| {
227 e.with_context(format!(
228 "Select fields {:?} must be a subset of child fields {:?}",
229 options,
230 child_dtype.names()
231 ))
232 })?
233 .iter()
234 .map(|name| (name.clone(), get_item(name.clone(), child.clone()))),
235 child_nullability,
236 );
237
238 Ok(Some(expr))
239 }
240
241 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
242 true
243 }
244
245 fn is_fallible(&self, _instance: &Self::Options) -> bool {
246 false
248 }
249}
250
251fn select_from_struct_vector(
252 vec: &StructVector,
253 field_indices: &[usize],
254) -> VortexResult<StructVector> {
255 let new_fields = field_indices
256 .iter()
257 .map(|&idx| vec.fields()[idx].clone())
258 .collect();
259 Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), vec.validity().clone()) })
260}
261
262pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
270 Select
271 .try_new_expr(FieldSelection::Include(field_names.into()), [child])
272 .vortex_expect("Failed to create Select expression")
273}
274
275pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
284 Select
285 .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
286 .vortex_expect("Failed to create Select expression")
287}
288
289impl FieldSelection {
290 pub fn include(columns: FieldNames) -> Self {
291 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
292 Self::Include(columns)
293 }
294
295 pub fn exclude(columns: FieldNames) -> Self {
296 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
297 Self::Exclude(columns)
298 }
299
300 pub fn is_include(&self) -> bool {
301 matches!(self, Self::Include(_))
302 }
303
304 pub fn is_exclude(&self) -> bool {
305 matches!(self, Self::Exclude(_))
306 }
307
308 pub fn field_names(&self) -> &FieldNames {
309 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
310
311 fields
312 }
313
314 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
315 if self
316 .field_names()
317 .iter()
318 .any(|f| !field_names.iter().contains(f))
319 {
320 vortex_bail!(
321 "Field {:?} in select not in field names {:?}",
322 self,
323 field_names
324 );
325 }
326 match self {
327 FieldSelection::Include(fields) => Ok(fields.clone()),
328 FieldSelection::Exclude(exc_fields) => Ok(field_names
329 .iter()
330 .filter(|f| !exc_fields.iter().contains(f))
331 .cloned()
332 .collect()),
333 }
334 }
335}
336
337impl Display for FieldSelection {
338 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
339 match self {
340 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
341 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
342 }
343 }
344}
345
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::Nullable;
    use vortex_dtype::PType::I32;
    use vortex_dtype::StructFields;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::pack::Pack;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Two-column struct fixture with fields `a` and `b`, three rows each.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let kept_names = projected.names().clone();
        assert_eq!(kept_names.as_ref(), &["a"]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let kept_names = projected.names().clone();
        assert_eq!(kept_names.as_ref(), &["b"]);
    }

    /// Include and exclude selections infer the same dtype as a direct projection.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let harness_fields = dtype.as_struct_fields_opt().unwrap();

        // Expected result of keeping only `a`.
        let only_a = DType::Struct(
            harness_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let a_and_bools = DType::Struct(
            harness_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), a_and_bools);
    }

    /// `{a}` and `~{b, c}` normalize to the same include list over `[a, b, c]`.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());

        let from_include = include
            .as_::<Select>()
            .as_include_names(&field_names)
            .unwrap();
        let from_exclude = exclude
            .as_::<Select>()
            .as_include_names(&field_names)
            .unwrap();
        assert_eq!(&from_include, &from_exclude);
    }

    /// Optimizing an include select yields a `Pack` that stays nullable.
    #[test]
    fn test_remove_select_rule() {
        let fields = StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]);
        let dtype = DType::Struct(fields, Nullable);
        let expr = select(["a", "b"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();

        assert!(optimized.is::<Pack>());
        assert!(optimized.return_dtype(&dtype).unwrap().is_nullable());
    }

    /// Optimizing an exclude select yields a nullable `Pack` of the remaining fields.
    #[test]
    fn test_remove_select_rule_exclude_fields() {
        let fields = StructFields::new(
            ["a", "b", "c"].into(),
            vec![I32.into(), I32.into(), I32.into()],
        );
        let dtype = DType::Struct(fields, Nullable);
        let expr = select_exclude(["c"], root());

        let optimized = expr.optimize_recursive(&dtype).unwrap();
        assert!(optimized.is::<Pack>());

        let optimized_dtype = optimized.return_dtype(&dtype).unwrap();
        assert!(optimized_dtype.is_nullable());
        let kept = optimized_dtype.as_struct_fields_opt().unwrap();
        assert_eq!(kept.names().as_ref(), &["a", "b"]);
    }
}