1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{ExprRef, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct Pack {
45 names: FieldNames,
46 values: Vec<ExprRef>,
47 nullability: Nullability,
48}
49
50impl Pack {
51 pub fn try_new_expr(
52 names: FieldNames,
53 values: Vec<ExprRef>,
54 nullability: Nullability,
55 ) -> VortexResult<Arc<Self>> {
56 if names.len() != values.len() {
57 vortex_bail!("length mismatch {} {}", names.len(), values.len());
58 }
59 Ok(Arc::new(Pack {
60 names,
61 values,
62 nullability,
63 }))
64 }
65
66 pub fn names(&self) -> &FieldNames {
67 &self.names
68 }
69
70 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71 let idx = self
72 .names
73 .iter()
74 .position(|name| name == field_name)
75 .ok_or_else(|| {
76 vortex_err!(
77 "Cannot find field {} in pack fields {:?}",
78 field_name,
79 self.names
80 )
81 })?;
82
83 self.values
84 .get(idx)
85 .cloned()
86 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87 }
88}
89
90pub fn pack(
91 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
92 nullability: Nullability,
93) -> ExprRef {
94 let (names, values): (Vec<_>, Vec<_>) = elements
95 .into_iter()
96 .map(|(name, value)| (name.into(), value))
97 .unzip();
98 Pack::try_new_expr(names.into(), values, nullability)
99 .vortex_expect("pack names and values have the same length")
100}
101
102impl Display for Pack {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.write_str("{")?;
105 self.names
106 .iter()
107 .zip(&self.values)
108 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
109 .fmt(f)?;
110 f.write_str("}")
111 }
112}
113
/// Protobuf (de)serialization hooks for [`Pack`].
///
/// Neither direction is implemented yet: both entry points report a `NotImplemented` error.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Serde handle identifying `Pack` expressions under the id `"pack"`.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Deserialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error. This deliberately returns an error instead
        /// of panicking so that encountering a serialized `pack` expression cannot abort the
        /// caller, mirroring the behaviour of `serialize_kind`.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error naming this function and the expression id.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
145
146impl VortexExpr for Pack {
147 fn as_any(&self) -> &dyn Any {
148 self
149 }
150
151 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
152 let len = batch.len();
153 let value_arrays = self
154 .values
155 .iter()
156 .map(|value_expr| value_expr.evaluate(batch))
157 .process_results(|it| it.collect::<Vec<_>>())?;
158 let validity = match self.nullability {
159 Nullability::NonNullable => Validity::NonNullable,
160 Nullability::Nullable => Validity::AllValid,
161 };
162 Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
163 }
164
165 fn children(&self) -> Vec<&ExprRef> {
166 self.values.iter().collect()
167 }
168
169 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
170 assert_eq!(children.len(), self.values.len());
171 Self::try_new_expr(self.names.clone(), children, self.nullability)
172 .vortex_expect("children are known to have the same length as names")
173 }
174
175 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
176 let value_dtypes = self
177 .values
178 .iter()
179 .map(|value_expr| value_expr.return_dtype(scope_dtype))
180 .process_results(|it| it.collect())?;
181 Ok(DType::Struct(
182 Arc::new(StructDType::new(self.names.clone(), value_dtypes)),
183 self.nullability,
184 ))
185 }
186}
187
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Pack, VortexExpr, col};

    /// Builds the input batch shared by the tests: a struct with fields `a = [0, 1, 2]` and
    /// `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Walks `field_path` through nested struct arrays starting at `array` and returns the
    /// final field as a primitive array.
    ///
    /// Fails when the path is empty, when any step is not a struct or lacks the named field,
    /// or when the final field cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate a conversion failure to the caller instead of panicking here.
        array.to_primitive()
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array().into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_dtype().nfields(),
            0
        );
    }

    #[test]
    pub fn test_mismatched_lengths() {
        let result = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a")],
            Nullability::NonNullable,
        );
        assert!(result.is_err());
    }

    #[test]
    pub fn test_field_lookup() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();

        assert!(expr.field(&"two".into()).is_ok());
        assert!(expr.field(&"missing".into()).is_err());
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(test_array().as_ref())
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_nested_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![
                col("a"),
                Pack::try_new_expr(
                    ["two_one".into(), "two_two".into()].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(test_array().as_ref())
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(test_array().as_ref())
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }
}