1use std::fmt::{Display, Formatter};
5
6use itertools::Itertools;
7use prost::Message;
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::expr::expression::Expression;
14use crate::expr::field::DisplayFieldNames;
15use crate::expr::{ChildName, ExprId, ExpressionView, VTable, VTableExt};
16use crate::{ArrayRef, IntoArray, ToCanonical};
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
19pub enum FieldSelection {
20 Include(FieldNames),
21 Exclude(FieldNames),
22}
23
24pub struct Select;
25
26impl VTable for Select {
27 type Instance = FieldSelection;
28
29 fn id(&self) -> ExprId {
30 ExprId::new_ref("vortex.select")
31 }
32
33 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
34 let opts = match instance {
35 FieldSelection::Include(fields) => Opts::Include(ProtoFieldNames {
36 names: fields.iter().map(|f| f.to_string()).collect(),
37 }),
38 FieldSelection::Exclude(fields) => Opts::Exclude(ProtoFieldNames {
39 names: fields.iter().map(|f| f.to_string()).collect(),
40 }),
41 };
42
43 let select_opts = SelectOpts { opts: Some(opts) };
44 Ok(Some(select_opts.encode_to_vec()))
45 }
46
47 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
48 let prost_metadata = SelectOpts::decode(metadata)?;
49
50 let select_opts = prost_metadata
51 .opts
52 .ok_or_else(|| vortex_err!("SelectOpts missing opts field"))?;
53
54 let field_selection = match select_opts {
55 Opts::Include(field_names) => FieldSelection::Include(FieldNames::from_iter(
56 field_names.names.iter().map(|s| s.as_str()),
57 )),
58 Opts::Exclude(field_names) => FieldSelection::Exclude(FieldNames::from_iter(
59 field_names.names.iter().map(|s| s.as_str()),
60 )),
61 };
62
63 Ok(Some(field_selection))
64 }
65
66 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
67 if expr.children().len() != 1 {
68 vortex_bail!(
69 "Select expression requires exactly 1 child, got {}",
70 expr.children().len()
71 );
72 }
73 Ok(())
74 }
75
76 fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
77 match child_idx {
78 0 => ChildName::new_ref("child"),
79 _ => unreachable!(),
80 }
81 }
82
83 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
84 expr.child().fmt_sql(f)?;
85 match expr.data() {
86 FieldSelection::Include(fields) => {
87 write!(f, "{{{}}}", DisplayFieldNames(fields))
88 }
89 FieldSelection::Exclude(fields) => {
90 write!(f, "{{~ {}}}", DisplayFieldNames(fields))
91 }
92 }
93 }
94
95 fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
96 let names = match instance {
97 FieldSelection::Include(names) => {
98 write!(f, "include=")?;
99 names
100 }
101 FieldSelection::Exclude(names) => {
102 write!(f, "exclude=")?;
103 names
104 }
105 };
106 write!(f, "{{{}}}", DisplayFieldNames(names))
107 }
108
109 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
110 let child_dtype = expr.child().return_dtype(scope)?;
111 let child_struct_dtype = child_dtype
112 .as_struct_fields_opt()
113 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
114
115 let projected = match expr.data() {
116 FieldSelection::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
117 FieldSelection::Exclude(fields) => child_struct_dtype
118 .names()
119 .iter()
120 .cloned()
121 .zip_eq(child_struct_dtype.fields())
122 .filter(|(name, _)| !fields.as_ref().contains(name))
123 .collect(),
124 };
125
126 Ok(DType::Struct(projected, child_dtype.nullability()))
127 }
128
129 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
130 let batch = expr.child().evaluate(scope)?.to_struct();
131 Ok(match expr.data() {
132 FieldSelection::Include(f) => batch.project(f.as_ref()),
133 FieldSelection::Exclude(names) => {
134 let included_names = batch
135 .names()
136 .iter()
137 .filter(|&f| !names.as_ref().contains(f))
138 .cloned()
139 .collect::<Vec<_>>();
140 batch.project(included_names.as_slice())
141 }
142 }?
143 .into_array())
144 }
145}
146
147pub fn select(field_names: impl Into<FieldNames>, child: Expression) -> Expression {
155 Select
156 .try_new_expr(FieldSelection::Include(field_names.into()), [child])
157 .vortex_expect("Failed to create Select expression")
158}
159
160pub fn select_exclude(fields: impl Into<FieldNames>, child: Expression) -> Expression {
169 Select
170 .try_new_expr(FieldSelection::Exclude(fields.into()), [child])
171 .vortex_expect("Failed to create Select expression")
172}
173
174impl ExpressionView<'_, Select> {
175 pub fn child(&self) -> &Expression {
176 &self.children()[0]
177 }
178
179 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<Expression> {
195 Select.try_new_expr(
196 FieldSelection::Include(self.data().as_include_names(field_names)?),
197 [self.child().clone()],
198 )
199 }
200}
201
202impl FieldSelection {
203 pub fn include(columns: FieldNames) -> Self {
204 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
205 Self::Include(columns)
206 }
207
208 pub fn exclude(columns: FieldNames) -> Self {
209 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
210 Self::Exclude(columns)
211 }
212
213 pub fn is_include(&self) -> bool {
214 matches!(self, Self::Include(_))
215 }
216
217 pub fn is_exclude(&self) -> bool {
218 matches!(self, Self::Exclude(_))
219 }
220
221 pub fn field_names(&self) -> &FieldNames {
222 let (FieldSelection::Include(fields) | FieldSelection::Exclude(fields)) = self;
223
224 fields
225 }
226
227 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
228 if self
229 .field_names()
230 .iter()
231 .any(|f| !field_names.iter().contains(f))
232 {
233 vortex_bail!(
234 "Field {:?} in select not in field names {:?}",
235 self,
236 field_names
237 );
238 }
239 match self {
240 FieldSelection::Include(fields) => Ok(fields.clone()),
241 FieldSelection::Exclude(exc_fields) => Ok(field_names
242 .iter()
243 .filter(|f| !exc_fields.iter().contains(f))
244 .cloned()
245 .collect()),
246 }
247 }
248}
249
250impl Display for FieldSelection {
251 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
252 match self {
253 FieldSelection::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
254 FieldSelection::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, FieldNames, Nullability};

    use super::{select, select_exclude};
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::expr::exprs::select::Select;
    use crate::expr::test_harness;
    use crate::{IntoArray, ToCanonical};

    /// Two-column struct fixture with fields `a` and `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` yields a struct containing only `a`.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let result = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let result_names = result.names().clone();
        assert_eq!(result_names.as_ref(), &["a"]);
    }

    /// Excluding `a` yields a struct containing only `b`.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let result = expr.evaluate(&input.to_array()).unwrap().to_struct();
        let result_names = result.names().clone();
        assert_eq!(result_names.as_ref(), &["b"]);
    }

    /// `return_dtype` agrees with projecting the harness struct dtype directly,
    /// for both inclusion and exclusion.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct_fields_opt().unwrap();

        let only_a = DType::Struct(
            struct_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );

        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding these four names is expected to leave the same dtype as including `a`.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let a_and_bools = DType::Struct(
            struct_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), a_and_bools);
    }

    /// Including `a` and excluding `b`, `c` resolve to the same names over `[a, b, c]`.
    #[test]
    fn test_as_include_names() {
        let field_names = FieldNames::from(["a", "b", "c"]);
        let include = select(["a"], root());
        let exclude = select_exclude(["b", "c"], root());

        let from_include = include
            .as_::<Select>()
            .data()
            .as_include_names(&field_names)
            .unwrap();
        let from_exclude = exclude
            .as_::<Select>()
            .data()
            .as_include_names(&field_names)
            .unwrap();

        assert_eq!(&from_include, &from_exclude);
    }
}