1use std::fmt::Display;
5
6use itertools::Itertools;
7use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata, ToCanonical};
8use vortex_dtype::{DType, FieldNames};
9use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr::select_opts::Opts;
11use vortex_proto::expr::{FieldNames as ProtoFieldNames, SelectOpts};
12
13use crate::field::DisplayFieldNames;
14use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
15
/// Describes which fields of the child struct a [`SelectExpr`] keeps.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SelectField {
    /// Keep only the listed fields.
    Include(FieldNames),
    /// Keep every field of the child struct except the listed ones.
    Exclude(FieldNames),
}
21
// Presumably generates the `SelectVTable` marker type (and related glue) that the
// `impl VTable for SelectVTable` in this file is written against — TODO confirm in the macro.
vtable!(Select);
23
/// Expression that projects the struct produced by `child` according to `fields`.
#[derive(Debug, Clone, Hash, Eq)]
// `PartialEq` is written by hand, but it compares exactly the two fields that the derived `Hash`
// hashes (`fields` and `child`), so values that compare equal still hash equally.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct SelectExpr {
    /// Which fields to include or exclude.
    fields: SelectField,
    /// Expression whose result is projected; it must evaluate to a struct.
    child: ExprRef,
}
30
31impl PartialEq for SelectExpr {
32 fn eq(&self, other: &Self) -> bool {
33 self.fields == other.fields && self.child.eq(&other.child)
34 }
35}
36
/// Encoding marker for [`SelectExpr`]; its id is `"select"` (see `SelectVTable::id`).
pub struct SelectExprEncoding;
38
39impl VTable for SelectVTable {
40 type Expr = SelectExpr;
41 type Encoding = SelectExprEncoding;
42 type Metadata = ProstMetadata<SelectOpts>;
43
44 fn id(_encoding: &Self::Encoding) -> ExprId {
45 ExprId::new_ref("select")
46 }
47
48 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
49 ExprEncodingRef::new_ref(SelectExprEncoding.as_ref())
50 }
51
52 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
53 let names = expr
54 .fields()
55 .fields()
56 .iter()
57 .map(|f| f.to_string())
58 .collect_vec();
59
60 let opts = if expr.fields().is_include() {
61 Opts::Include(ProtoFieldNames { names })
62 } else {
63 Opts::Exclude(ProtoFieldNames { names })
64 };
65
66 Some(ProstMetadata(SelectOpts { opts: Some(opts) }))
67 }
68
69 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
70 vec![&expr.child]
71 }
72
73 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
74 Ok(SelectExpr {
75 fields: expr.fields.clone(),
76 child: children[0].clone(),
77 })
78 }
79
80 fn build(
81 _encoding: &Self::Encoding,
82 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
83 mut children: Vec<ExprRef>,
84 ) -> VortexResult<Self::Expr> {
85 if children.len() != 1 {
86 vortex_bail!("Select expression must have exactly one child");
87 }
88
89 let fields = match metadata.opts.as_ref() {
90 Some(opts) => match opts {
91 Opts::Include(field_names) => SelectField::Include(FieldNames::from_iter(
92 field_names.names.iter().map(|s| s.as_str()),
93 )),
94 Opts::Exclude(field_names) => SelectField::Exclude(FieldNames::from_iter(
95 field_names.names.iter().map(|s| s.as_str()),
96 )),
97 },
98 None => {
99 vortex_bail!("Select expressions must be provided with fields to select or exclude")
100 }
101 };
102
103 let child = children
104 .drain(..)
105 .next()
106 .vortex_expect("number of children validated to be one");
107
108 Ok(SelectExpr { fields, child })
109 }
110
111 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
112 let batch = expr.child.unchecked_evaluate(scope)?.to_struct()?;
113 Ok(match &expr.fields {
114 SelectField::Include(f) => batch.project(f.as_ref()),
115 SelectField::Exclude(names) => {
116 let included_names = batch
117 .names()
118 .iter()
119 .filter(|&f| !names.as_ref().contains(f))
120 .cloned()
121 .collect::<Vec<_>>();
122 batch.project(included_names.as_slice())
123 }
124 }?
125 .into_array())
126 }
127
128 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
129 let child_dtype = expr.child.return_dtype(scope)?;
130 let child_struct_dtype = child_dtype
131 .as_struct()
132 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
133
134 let projected = match &expr.fields {
135 SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
136 SelectField::Exclude(fields) => child_struct_dtype
137 .names()
138 .iter()
139 .cloned()
140 .zip_eq(child_struct_dtype.fields())
141 .filter(|(name, _)| !fields.as_ref().contains(name))
142 .collect(),
143 };
144
145 Ok(DType::Struct(projected, child_dtype.nullability()))
146 }
147}
148
149pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
150 SelectExpr::include_expr(fields.into(), child)
151}
152
153pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
154 SelectExpr::exclude_expr(fields.into(), child)
155}
156
157impl SelectExpr {
158 pub fn new(fields: SelectField, child: ExprRef) -> Self {
159 Self { fields, child }
160 }
161
162 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
163 Self::new(fields, child).into_expr()
164 }
165
166 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
167 Self::new(SelectField::Include(columns), child).into_expr()
168 }
169
170 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
171 Self::new(SelectField::Exclude(columns), child).into_expr()
172 }
173
174 pub fn fields(&self) -> &SelectField {
175 &self.fields
176 }
177
178 pub fn child(&self) -> &ExprRef {
179 &self.child
180 }
181
182 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
183 Ok(Self::new(
184 SelectField::Include(self.fields.as_include_names(field_names)?),
185 self.child.clone(),
186 )
187 .into_expr())
188 }
189}
190
191impl SelectField {
192 pub fn include(columns: FieldNames) -> Self {
193 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
194 Self::Include(columns)
195 }
196
197 pub fn exclude(columns: FieldNames) -> Self {
198 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
199 Self::Exclude(columns)
200 }
201
202 pub fn is_include(&self) -> bool {
203 matches!(self, Self::Include(_))
204 }
205
206 pub fn is_exclude(&self) -> bool {
207 matches!(self, Self::Exclude(_))
208 }
209
210 pub fn fields(&self) -> &FieldNames {
211 let (SelectField::Include(fields) | SelectField::Exclude(fields)) = self;
212
213 fields
214 }
215
216 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
217 if self
218 .fields()
219 .iter()
220 .any(|f| !field_names.iter().contains(f))
221 {
222 vortex_bail!(
223 "Field {:?} in select not in field names {:?}",
224 self,
225 field_names
226 );
227 }
228 match self {
229 SelectField::Include(fields) => Ok(fields.clone()),
230 SelectField::Exclude(exc_fields) => Ok(field_names
231 .iter()
232 .filter(|f| exc_fields.iter().contains(f))
233 .cloned()
234 .collect()),
235 }
236 }
237}
238
239impl Display for SelectField {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 match self {
242 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
243 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
244 }
245 }
246}
247
248impl Display for SelectExpr {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 write!(f, "{}{}", self.child, self.fields)
251 }
252}
253
// Empty impl: select relies entirely on `AnalysisExpr`'s provided (default) methods.
impl AnalysisExpr for SelectExpr {}
255
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{Scope, root, select, select_exclude, test_harness};

    /// Builds the two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` used by the evaluation tests.
    fn test_array() -> StructArray {
        let column_a = buffer![0, 1, 2].into_array();
        let column_b = buffer![4, 5, 6].into_array();
        StructArray::from_fields(&[("a", column_a), ("b", column_b)]).unwrap()
    }

    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], root());
        let projected = expr
            .evaluate(&Scope::new(input.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        // Only the explicitly included column survives.
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["a".into()]);
    }

    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], root());
        let projected = expr
            .evaluate(&Scope::new(input.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        // Everything except the excluded column survives.
        let names = projected.names().clone();
        assert_eq!(names.as_ref(), &["b".into()]);
    }

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        let struct_fields = dtype.as_struct().unwrap();

        // Including just `a`...
        let only_a = DType::Struct(
            struct_fields.project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );
        let include_a = select(vec![FieldName::from("a")], root());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // ...yields the same dtype as excluding every other column.
        let exclude_all_but_a = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding a subset keeps the remaining columns.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        let remaining = DType::Struct(
            struct_fields
                .project(&["a".into(), "bool1".into(), "bool2".into()])
                .unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), remaining);
    }
}