1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
/// An expression that packs the results of several child expressions into a struct.
///
/// The field named `names[i]` in the resulting struct is produced by evaluating `values[i]`.
/// `nullability` applies to the packed struct itself, not to its individual fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pack {
    /// Output field names; `try_new_expr` guarantees one per entry in `values`.
    names: FieldNames,
    /// Child expressions; `values[i]` produces the field named `names[i]`.
    values: Vec<ExprRef>,
    /// Nullability of the packed struct (used for both its validity and its dtype).
    nullability: Nullability,
}
49
50impl Pack {
51 pub fn try_new_expr(
52 names: FieldNames,
53 values: Vec<ExprRef>,
54 nullability: Nullability,
55 ) -> VortexResult<Arc<Self>> {
56 if names.len() != values.len() {
57 vortex_bail!("length mismatch {} {}", names.len(), values.len());
58 }
59 Ok(Arc::new(Pack {
60 names,
61 values,
62 nullability,
63 }))
64 }
65
66 pub fn names(&self) -> &FieldNames {
67 &self.names
68 }
69
70 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71 let idx = self
72 .names
73 .iter()
74 .position(|name| name == field_name)
75 .ok_or_else(|| {
76 vortex_err!(
77 "Cannot find field {} in pack fields {:?}",
78 field_name,
79 self.names
80 )
81 })?;
82
83 self.values
84 .get(idx)
85 .cloned()
86 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87 }
88}
89
90pub fn pack(
91 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
92 nullability: Nullability,
93) -> ExprRef {
94 let (names, values): (Vec<_>, Vec<_>) = elements
95 .into_iter()
96 .map(|(name, value)| (name.into(), value))
97 .unzip();
98 Pack::try_new_expr(names.into(), values, nullability)
99 .vortex_expect("pack names and values have the same length")
100}
101
102impl Display for Pack {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.write_str("{")?;
105 self.names
106 .iter()
107 .zip(&self.values)
108 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
109 .fmt(f)?;
110 f.write_str("}")
111 }
112}
113
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Serialization/deserialization handle for [`Pack`] expressions, identified by `"pack"`.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Deserialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error. It reports an error instead of
        /// panicking so that callers decoding unsupported input can recover.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", Id::id(self))
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error naming this function and the `"pack"` id.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
145
// `Pack` overrides no `AnalysisExpr` methods; it relies entirely on the trait's defaults.
impl AnalysisExpr for Pack {}
147
148impl VortexExpr for Pack {
149 fn as_any(&self) -> &dyn Any {
150 self
151 }
152
153 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
154 let len = scope.len();
155 let value_arrays = self
156 .values
157 .iter()
158 .map(|value_expr| value_expr.unchecked_evaluate(scope))
159 .process_results(|it| it.collect::<Vec<_>>())?;
160 let validity = match self.nullability {
161 Nullability::NonNullable => Validity::NonNullable,
162 Nullability::Nullable => Validity::AllValid,
163 };
164 Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
165 }
166
167 fn children(&self) -> Vec<&ExprRef> {
168 self.values.iter().collect()
169 }
170
171 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
172 assert_eq!(children.len(), self.values.len());
173 Self::try_new_expr(self.names.clone(), children, self.nullability)
174 .vortex_expect("children are known to have the same length as names")
175 }
176
177 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
178 let value_dtypes = self
179 .values
180 .iter()
181 .map(|value_expr| value_expr.return_dtype(scope))
182 .process_results(|it| it.collect())?;
183 Ok(DType::Struct(
184 Arc::new(StructFields::new(self.names.clone(), value_dtypes)),
185 self.nullability,
186 ))
187 }
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldName, FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Pack, Scope, VortexExpr, col};

    /// Three-row struct array `{a: [0, 1, 2], b: [4, 5, 6]}` used as the evaluation scope.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct arrays and returns the leaf as a
    /// primitive array.
    ///
    /// Returns an error (rather than panicking) if the path is empty, a field is missing,
    /// or the leaf cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate conversion failures instead of unwrapping inside a `Result`-returning helper.
        array.to_primitive()
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_nested_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![
                col("a"),
                Pack::try_new_expr(
                    ["two_one".into(), "two_two".into()].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }

    /// `try_new_expr` must reject a names/values length mismatch instead of building a
    /// malformed expression.
    #[test]
    pub fn test_pack_length_mismatch() {
        let result = Pack::try_new_expr(
            ["one".into()].into(),
            vec![col("a"), col("b")],
            Nullability::NonNullable,
        );
        assert!(result.is_err());
    }

    /// `Pack::field` returns the child expression for a known name and errors otherwise.
    #[test]
    pub fn test_pack_field_lookup() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();

        let two: FieldName = "two".into();
        assert_eq!(expr.field(&two).unwrap(), col("b"));

        let missing: FieldName = "missing".into();
        assert!(expr.field(&missing).is_err());
    }
}