1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_array::{ArrayRef, IntoArray, ToCanonical};
7use vortex_dtype::{DType, FieldNames};
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::field::DisplayFieldNames;
11use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum SelectField {
15 Include(FieldNames),
16 Exclude(FieldNames),
17}
18
19#[derive(Debug, Clone, Eq, Hash)]
20#[allow(clippy::derived_hash_with_manual_eq)]
21pub struct Select {
22 fields: SelectField,
23 child: ExprRef,
24}
25
26pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
27 Select::include_expr(fields.into(), child)
28}
29
30pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
31 Select::exclude_expr(fields.into(), child)
32}
33
34impl Select {
35 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
36 Arc::new(Self { fields, child })
37 }
38
39 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
40 Self::new_expr(SelectField::Include(columns), child)
41 }
42
43 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
44 Self::new_expr(SelectField::Exclude(columns), child)
45 }
46
47 pub fn fields(&self) -> &SelectField {
48 &self.fields
49 }
50
51 pub fn child(&self) -> &ExprRef {
52 &self.child
53 }
54
55 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
56 Ok(Self::new_expr(
57 SelectField::Include(self.fields.as_include_names(field_names)?),
58 self.child.clone(),
59 ))
60 }
61}
62
63impl SelectField {
64 pub fn include(columns: FieldNames) -> Self {
65 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
66 Self::Include(columns)
67 }
68
69 pub fn exclude(columns: FieldNames) -> Self {
70 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
71 Self::Exclude(columns)
72 }
73
74 pub fn is_include(&self) -> bool {
75 matches!(self, Self::Include(_))
76 }
77
78 pub fn is_exclude(&self) -> bool {
79 matches!(self, Self::Exclude(_))
80 }
81
82 pub fn fields(&self) -> &FieldNames {
83 match self {
84 SelectField::Include(fields) => fields,
85 SelectField::Exclude(fields) => fields,
86 }
87 }
88
89 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
90 if self
91 .fields()
92 .iter()
93 .any(|f| !field_names.iter().contains(f))
94 {
95 vortex_bail!(
96 "Field {:?} in select not in field names {:?}",
97 self,
98 field_names
99 );
100 }
101 match self {
102 SelectField::Include(fields) => Ok(fields.clone()),
103 SelectField::Exclude(exc_fields) => Ok(field_names
104 .iter()
105 .filter(|f| exc_fields.iter().contains(f))
106 .cloned()
107 .collect()),
108 }
109 }
110}
111
112impl Display for SelectField {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match self {
115 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
116 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
117 }
118 }
119}
120
121impl Display for Select {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 write!(f, "{}{}", self.child, self.fields)
124 }
125}
126
/// Protobuf serialization hooks for [`Select`]; only compiled with the `proto` feature.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Select};

    /// Marker type carrying the `"select"` expression id.
    pub struct SelectSerde;

    impl Id for SelectSerde {
        fn id(&self) -> &'static str {
            "select"
        }
    }

    impl ExprDeserialize for SelectSerde {
        /// Not supported yet: always fails with a `NotImplemented` error.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "", self.id())
        }
    }

    impl ExprSerializable for Select {
        // Reuses the id from `SelectSerde` so both directions agree on the tag.
        fn id(&self) -> &'static str {
            SelectSerde.id()
        }

        /// Not supported yet: always fails with a `NotImplemented` error.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "", self.id())
        }
    }
}
158
// `Select` overrides nothing here; it relies entirely on the trait's provided methods.
impl AnalysisExpr for Select {}
160
161impl VortexExpr for Select {
162 fn as_any(&self) -> &dyn Any {
163 self
164 }
165
166 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
167 let batch = self.child.unchecked_evaluate(scope)?.to_struct()?;
168 Ok(match &self.fields {
169 SelectField::Include(f) => batch.project(f.as_ref()),
170 SelectField::Exclude(names) => {
171 let included_names = batch
172 .names()
173 .iter()
174 .filter(|&f| !names.iter().contains(f))
175 .cloned()
176 .collect::<Vec<_>>();
177 batch.project(included_names.as_slice())
178 }
179 }?
180 .into_array())
181 }
182
183 fn children(&self) -> Vec<&ExprRef> {
184 vec![&self.child]
185 }
186
187 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
188 assert_eq!(children.len(), 1);
189 Self::new_expr(self.fields.clone(), children[0].clone())
190 }
191
192 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
193 let child_dtype = self.child.return_dtype(scope)?;
194 let child_struct_dtype = child_dtype
195 .as_struct()
196 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
197
198 let projected = match &self.fields {
199 SelectField::Include(fields) => child_struct_dtype.project(fields.as_ref())?,
200 SelectField::Exclude(fields) => child_struct_dtype
201 .names()
202 .iter()
203 .cloned()
204 .zip_eq(child_struct_dtype.fields())
205 .filter(|(name, _)| !fields.iter().contains(name))
206 .collect(),
207 };
208
209 Ok(DType::Struct(projected, child_dtype.nullability()))
210 }
211}
212
213impl PartialEq for Select {
214 fn eq(&self, other: &Select) -> bool {
215 self.fields == other.fields && self.child.eq(&other.child)
216 }
217}
218
#[cfg(test)]
mod tests {

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{Scope, ScopeDType, root, select, select_exclude, test_harness};

    /// Two-column struct array with fields `a` and `b`, three rows each.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Including `a` leaves only `a` in the evaluated result.
    #[test]
    pub fn include_columns() {
        let st = test_array();
        let select = select(vec![FieldName::from("a")], root());
        let selected = select
            .evaluate(&Scope::new(st.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["a".into()]);
    }

    /// Excluding `a` leaves only `b` in the evaluated result.
    #[test]
    pub fn exclude_columns() {
        let st = test_array();
        let select = select_exclude(vec![FieldName::from("a")], root());
        let selected = select
            .evaluate(&Scope::new(st.to_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let selected_names = selected.names().clone();
        assert_eq!(selected_names.as_ref(), &["b".into()]);
    }

    /// Checks `return_dtype` for include and exclude selections against the shared
    /// `test_harness` struct dtype. The assertions imply that fixture holds fields
    /// `a`, `col1`, `col2`, `bool1`, `bool2` (exact order: see `test_harness`).
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        let select_expr = select(vec![FieldName::from("a")], root());
        let expected_dtype = DType::Struct(
            dtype.as_struct().unwrap().project(&["a".into()]).unwrap(),
            Nullability::NonNullable,
        );
        assert_eq!(
            select_expr
                .return_dtype(&ScopeDType::new(dtype.clone()))
                .unwrap(),
            expected_dtype
        );

        // Excluding every field except `a` must match including just `a`.
        let select_expr_exclude = select_exclude(
            vec![
                FieldName::from("col1"),
                FieldName::from("col2"),
                FieldName::from("bool1"),
                FieldName::from("bool2"),
            ],
            root(),
        );
        assert_eq!(
            select_expr_exclude
                .return_dtype(&ScopeDType::new(dtype.clone()))
                .unwrap(),
            expected_dtype
        );

        // A partial exclude keeps the remaining fields in the child's order.
        let select_expr_exclude = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            root(),
        );
        assert_eq!(
            select_expr_exclude
                .return_dtype(&ScopeDType::new(dtype.clone()))
                .unwrap(),
            DType::Struct(
                dtype
                    .as_struct()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
                Nullability::NonNullable
            )
        );
    }
}