vortex_array/expr/exprs/select/
mod.rs1pub mod transform;
5
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::sync::Arc;
9
10use itertools::Itertools;
11use prost::Message;
12use vortex_dtype::DType;
13use vortex_dtype::FieldNames;
14use vortex_error::VortexExpect;
15use vortex_error::VortexResult;
16use vortex_error::vortex_bail;
17use vortex_error::vortex_err;
18use vortex_proto::expr::FieldNames as ProtoFieldNames;
19use vortex_proto::expr::SelectOpts;
20use vortex_proto::expr::select_opts::Opts;
21use vortex_vector::Vector;
22use vortex_vector::struct_::StructVector;
23
24use crate::ArrayRef;
25use crate::IntoArray;
26use crate::ToCanonical;
27use crate::expr::ChildName;
28use crate::expr::ExecutionArgs;
29use crate::expr::ExprId;
30use crate::expr::ExpressionView;
31use crate::expr::VTable;
32use crate::expr::VTableExt;
33use crate::expr::expression::Expression;
34use crate::expr::field::DisplayFieldNames;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub enum FieldSelection {
38 Include(FieldNames),
39 Exclude(FieldNames),
40}
41
/// Expression vtable that projects a struct-typed child onto a subset of its fields.
///
/// The per-expression data is a [`FieldSelection`] naming the fields to include or exclude.
pub struct Select;
43
44impl VTable for Select {
45 type Instance = FieldSelection;
46
47 fn id(&self) -> ExprId {
48 ExprId::new_ref("vortex.select")
49 }
50
51 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
52 let opts = match instance {
53 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
54 names: fields.iter().map(|f| f.to_string()).collect(),
55 }),
56 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
57 names: fields.iter().map(|f| f.to_string()).collect(),
58 }),
59 };
60
61 let select_opts = SelectOpts { opts: Some(opts) };
62 Ok(Some(select_opts.encode_to_vec()))
63 }
64
65 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
66 let prost_metadata = SelectOpts::decode(metadata)?;
67
68 let select_opts = prost_metadata
69 .opts
70 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
71
72 let field_selection = match select_opts {
73 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
74 field_names.names.iter().map(|s| s.as_str()),
75 )),
76 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
77 field_names.names.iter().map(|s| s.as_str()),
78 )),
79 };
80
81 Ok(Some(field_selection))
82 }
83
84 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
85 if expr.children().len() != 1 {
86 vortex_bail!(
87 "Select expression requires exactly 1 child, got {}",
88 expr.children().len()
89 );
90 }
91 Ok(())
92 }
93
94 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
95 match child_idx {
96 0 => ChildName::new_ref("child"),
97 _ => unreachable!(),
98 }
99 }
100
101 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
102 expr.child().fmt_sql(f)?;
103 match expr.data() {
104 FieldSelection::Include(fields) => {
105 write!(f, "{{{}}}", DisplayFieldNames(fields))
106 }
107 FieldSelection::Exclude(fields) => {
108 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
109 }
110 }
111 }
112
113 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
114 let names = match instance {
115 FieldSelection::Include(names) => {
116 write!(f, "include=")?;
117 names
118 }
119 FieldSelection::Exclude(names) => {
120 write!(f, "exclude=")?;
121 names
122 }
123 };
124 write!(f, "{{{}}}", DisplayFieldNames(names))
125 }
126
127 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
128 let child_dtype = expr.child().return_dtype(scope)?;
129 let child_struct_dtype = child_dtype
130 .as_struct_fields_opt()
131 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
132
133 let projected = match expr.data() {
134 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
135 FieldSelection::Exclude(fields) => child_struct_dtype
136 .names()
137 .iter()
138 .cloned()
139 .zip_eq(child_struct_dtype.fields())
140 .filter(|(name, _)| !fields.as_ref().contains(name))
141 .collect(),
142 };
143
144 Ok(DType::Struct(projected, child_dtype.nullability()))
145 }
146
147 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
148 let batch = expr.child().evaluate(scope)?.to_struct();
149 Ok(match expr.data() {
150 FieldSelection::Include(f) => batch.project(f.as_ref()),
151 FieldSelection::Exclude(names) => {
152 let included_names = batch
153 .names()
154 .iter()
155 .filter(|&f| !names.as_ref().contains(f))
156 .cloned()
157 .collect::<Vec<_>>();
158 batch.project(included_names.as_slice())
159 }
160 }?
161 .into_array())
162 }
163
164 fn execute(&self, selection: &FieldSelection, mut args: ExecutionArgs) -> VortexResult<Vector> {
165 let child = args
166 .vectors
167 .pop()
168 .vortex_expect("Missing input child")
169 .into_struct();
170 let child_fields = args
171 .dtypes
172 .pop()
173 .vortex_expect("Missing input dtype")
174 .into_struct_fields();
175
176 let field_indices: Vec<usize> = match selection {
177 FieldSelection::Include(f) => f
178 .iter()
179 .map(|name| {
180 child_fields
181 .find(name)
182 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
183 })
184 .try_collect(),
185 FieldSelection::Exclude(names) => child_fields
186 .names()
187 .iter()
188 .filter(|&f| !names.as_ref().contains(f))
189 .map(|name| {
190 child_fields
191 .find(name)
192 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", name))
193 })
194 .try_collect(),
195 }?;
196
197 let (fields, mask) = child.into_parts();
198 let new_fields = field_indices
199 .iter()
200 .map(|&idx| fields[idx].clone())
201 .collect();
202 Ok(unsafe { StructVector::new_unchecked(Arc::new(new_fields), mask) }.into())
203 }
204
205 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
206 true
207 }
208
209 fn is_fallible(&self, _instance: &Self::Instance) -> bool {
210 false
212 }
213}
214
215pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
223 Select
224 .try_new_expr(FieldSelection::Include(field_names.into()), [child])
225 .vortex_expect("Failed to create Select expression")
226}
227
228pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
237 Select
238 .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
239 .vortex_expect("Failed to create Select expression")
240}
241
242impl ExpressionView<'_, Select> {
243 pub fn child(&self) -> &Expression {
244 &self.children()[0]
245 }
246
247 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<Expression> {
263 Select.try_new_expr(
264 FieldSelection::Include(self.data().as_include_names(field_names)?),
265 [self.child().clone()],
266 )
267 }
268}
269
270impl FieldSelection {
271 pub fn include(columns: FieldNames) -> Self {
272 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
273 Self::Include(columns)
274 }
275
276 pub fn exclude(columns: FieldNames) -> Self {
277 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
278 Self::Exclude(columns)
279 }
280
281 pub fn is_include(&self) -> bool {
282 matches!(self, Self::Include(_))
283 }
284
285 pub fn is_exclude(&self) -> bool {
286 matches!(self, Self::Exclude(_))
287 }
288
289 pub fn field_names(&self) -> &FieldNames {
290 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
291
292 fields
293 }
294
295 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
296 if self
297 .field_names()
298 .iter()
299 .any(|f| !field_names.iter().contains(f))
300 {
301 vortex_bail!(
302 "Field {:?} in select not in field names {:?}",
303 self,
304 field_names
305 );
306 }
307 match self {
308 FieldSelection::Include(fields) => Ok(fields.clone()),
309 FieldSelection::Exclude(exc_fields) => Ok(field_names
310 .iter()
311 .filter(|f| !exc_fields.iter().contains(f))
312 .cloned()
313 .collect()),
314 }
315 }
316}
317
318impl Display for FieldSelection {
319 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320 match self {
321 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
322 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldName;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;

    use super::select;
    use super::select_exclude;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;

    /// Builds the two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` used by the evaluation tests.
    fn test_array() -> StructArray {
        let column_a = buffer![0, 1, 2].into_array();
        let column_b = buffer![4, 5, 6].into_array();
        StructArray::from_fields(&[("a", column_a), ("b", column_b)]).unwrap()
    }

    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["a"]);
    }

    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["b"]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct_fields_opt().unwrap();
        // Expected non-nullable struct dtype containing only `names`.
        let projected_dtype = |names: &[FieldName]| {
            DType::Struct(struct_fields.project(names).unwrap(), Nullability::NonNullable)
        };

        // Including `a` keeps only `a`.
        let only_a = projected_dtype(&[FieldName::from("a")]);
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding every other column yields the same dtype.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding a subset keeps the remaining columns in their original order.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let remaining = projected_dtype(&[
            FieldName::from("a"),
            FieldName::from("bool1"),
            FieldName::from("bool2"),
        ]);
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), remaining);
    }

    #[test]
    fn test_as_include_names() {
        let all_names = FieldNames::from(["a", "b", "c"]);
        let via_include = select(["a"], root());
        let via_exclude = select_exclude(["b", "c"], root());

        let from_include = via_include
            .as_::<Select>()
            .data()
            .as_include_names(&all_names)
            .unwrap();
        let from_exclude = via_exclude
            .as_::<Select>()
            .data()
            .as_include_names(&all_names)
            .unwrap();

        assert_eq!(from_include, from_exclude);
    }
}