1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use itertools::Itertools;
6use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
7use vortex_dtype::{DType, FieldNames};
8use vortex_error::{VortexResult, vortex_bail, vortex_err};
9
10use crate::field::DisplayFieldNames;
11use crate::{ExprRef, VortexExpr};
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
14pub enum SelectField {
15 Include(FieldNames),
16 Exclude(FieldNames),
17}
18
19#[derive(Debug, Clone, Eq, Hash)]
20#[allow(clippy::derived_hash_with_manual_eq)]
21pub struct Select {
22 fields: SelectField,
23 child: ExprRef,
24}
25
26pub fn select(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
27 Select::include_expr(fields.into(), child)
28}
29
30pub fn select_exclude(fields: impl Into<FieldNames>, child: ExprRef) -> ExprRef {
31 Select::exclude_expr(fields.into(), child)
32}
33
34impl Select {
35 pub fn new_expr(fields: SelectField, child: ExprRef) -> ExprRef {
36 Arc::new(Self { fields, child })
37 }
38
39 pub fn include_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
40 Self::new_expr(SelectField::Include(columns), child)
41 }
42
43 pub fn exclude_expr(columns: FieldNames, child: ExprRef) -> ExprRef {
44 Self::new_expr(SelectField::Exclude(columns), child)
45 }
46
47 pub fn fields(&self) -> &SelectField {
48 &self.fields
49 }
50
51 pub fn child(&self) -> &ExprRef {
52 &self.child
53 }
54
55 pub fn as_include(&self, field_names: &FieldNames) -> VortexResult<ExprRef> {
56 Ok(Self::new_expr(
57 SelectField::Include(self.fields.as_include_names(field_names)?),
58 self.child.clone(),
59 ))
60 }
61}
62
63impl SelectField {
64 pub fn include(columns: FieldNames) -> Self {
65 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
66 Self::Include(columns)
67 }
68
69 pub fn exclude(columns: FieldNames) -> Self {
70 assert_eq!(columns.iter().unique().collect_vec().len(), columns.len());
71 Self::Exclude(columns)
72 }
73
74 pub fn is_include(&self) -> bool {
75 matches!(self, Self::Include(_))
76 }
77
78 pub fn is_exclude(&self) -> bool {
79 matches!(self, Self::Exclude(_))
80 }
81
82 pub fn fields(&self) -> &FieldNames {
83 match self {
84 SelectField::Include(fields) => fields,
85 SelectField::Exclude(fields) => fields,
86 }
87 }
88
89 pub fn as_include_names(&self, field_names: &FieldNames) -> VortexResult<FieldNames> {
90 if self.fields().iter().any(|f| !field_names.contains(f)) {
91 vortex_bail!(
92 "Field {:?} in select not in field names {:?}",
93 self,
94 field_names
95 );
96 }
97 match self {
98 SelectField::Include(fields) => Ok(fields.clone()),
99 SelectField::Exclude(exc_fields) => Ok(field_names
100 .iter()
101 .filter(|f| exc_fields.contains(f))
102 .cloned()
103 .collect()),
104 }
105 }
106}
107
108impl Display for SelectField {
109 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
110 match self {
111 SelectField::Include(fields) => write!(f, "{{{}}}", DisplayFieldNames(fields)),
112 SelectField::Exclude(fields) => write!(f, "~{{{}}}", DisplayFieldNames(fields)),
113 }
114 }
115}
116
117impl Display for Select {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 write!(f, "{}{}", self.child, self.fields)
120 }
121}
122
/// Protobuf (de)serialization hooks for [`Select`].
///
/// Neither direction is implemented yet; both entry points return a
/// `NotImplemented` error that names the missing operation.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Select};

    /// Serde handle for the `select` expression.
    pub struct SelectSerde;

    impl Id for SelectSerde {
        /// The identifier used for select expressions.
        fn id(&self) -> &'static str {
            "select"
        }
    }

    impl ExprDeserialize for SelectSerde {
        /// Always fails: deserializing a select expression is not implemented.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            // Name the operation so the error says what is missing.
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Select {
        /// Same identifier as [`SelectSerde`].
        fn id(&self) -> &'static str {
            SelectSerde.id()
        }

        /// Always fails: serializing a select expression is not implemented.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            // Name the operation so the error says what is missing.
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
154
155impl VortexExpr for Select {
156 fn as_any(&self) -> &dyn Any {
157 self
158 }
159
160 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
161 let batch = self.child.evaluate(batch)?.to_struct()?;
162 Ok(match &self.fields {
163 SelectField::Include(f) => batch.project(f),
164 SelectField::Exclude(names) => {
165 let included_names = batch
166 .names()
167 .iter()
168 .filter(|&f| !names.contains(f))
169 .cloned()
170 .collect::<Vec<_>>();
171 batch.project(included_names.as_slice())
172 }
173 }?
174 .into_array())
175 }
176
177 fn children(&self) -> Vec<&ExprRef> {
178 vec![&self.child]
179 }
180
181 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
182 assert_eq!(children.len(), 1);
183 Self::new_expr(self.fields.clone(), children[0].clone())
184 }
185
186 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
187 let child_dtype = self.child.return_dtype(scope_dtype)?;
188 let child_struct_dtype = child_dtype
189 .as_struct()
190 .ok_or_else(|| vortex_err!("Select child not a struct dtype"))?;
191
192 let projected = match &self.fields {
193 SelectField::Include(fields) => child_struct_dtype.project(fields)?,
194 SelectField::Exclude(fields) => child_struct_dtype
195 .names()
196 .iter()
197 .cloned()
198 .zip_eq(child_struct_dtype.fields())
199 .filter(|(name, _)| !fields.contains(name))
200 .collect(),
201 };
202
203 Ok(DType::Struct(
204 Arc::new(projected),
205 child_dtype.nullability(),
206 ))
207 }
208}
209
210impl PartialEq for Select {
211 fn eq(&self, other: &Select) -> bool {
212 self.fields == other.fields && self.child.eq(&other.child)
213 }
214}
215
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::StructArray;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, FieldName, Nullability};

    use crate::{ident, select, select_exclude, test_harness};

    /// A two-column struct array with fields `a` and `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Including `a` leaves only `a`.
    #[test]
    pub fn include_columns() {
        let input = test_array();
        let expr = select(vec![FieldName::from("a")], ident());
        let output = expr.evaluate(input.as_ref()).unwrap().to_struct().unwrap();
        let output_names = output.names().clone();
        assert_eq!(output_names.as_ref(), &["a".into()]);
    }

    /// Excluding `a` leaves only `b`.
    #[test]
    pub fn exclude_columns() {
        let input = test_array();
        let expr = select_exclude(vec![FieldName::from("a")], ident());
        let output = expr.evaluate(input.as_ref()).unwrap().to_struct().unwrap();
        let output_names = output.names().clone();
        assert_eq!(output_names.as_ref(), &["b".into()]);
    }

    /// The return dtype matches a projection of the scope dtype for both
    /// include and exclude selections.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();

        // Include a single field.
        let only_a = DType::Struct(
            Arc::new(dtype.as_struct().unwrap().project(&["a".into()]).unwrap()),
            Nullability::NonNullable,
        );
        let include_a = select(vec![FieldName::from("a")], ident());
        assert_eq!(include_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding everything except `a` gives the same dtype.
        let all_but_a = vec![
            FieldName::from("col1"),
            FieldName::from("col2"),
            FieldName::from("bool1"),
            FieldName::from("bool2"),
        ];
        let exclude_all_but_a = select_exclude(all_but_a, ident());
        assert_eq!(exclude_all_but_a.return_dtype(&dtype).unwrap(), only_a);

        // Excluding two fields keeps the remaining three, in order.
        let exclude_cols = select_exclude(
            vec![FieldName::from("col1"), FieldName::from("col2")],
            ident(),
        );
        let remaining = DType::Struct(
            Arc::new(
                dtype
                    .as_struct()
                    .unwrap()
                    .project(&["a".into(), "bool1".into(), "bool2".into()])
                    .unwrap(),
            ),
            Nullability::NonNullable,
        );
        assert_eq!(exclude_cols.return_dtype(&dtype).unwrap(), remaining);
    }
}