1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
/// Expression that builds a struct array out of a list of named child expressions.
///
/// `names[i]` is the field name given to the result of `values[i]`. The only visible
/// constructor, [`Pack::try_new_expr`], rejects mismatched lengths, so the two lists
/// line up positionally.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pack {
    /// Output field names, one per entry in `values`.
    names: FieldNames,
    /// Child expressions; each is evaluated against the input scope to produce one field.
    values: Vec<ExprRef>,
    /// Nullability of the produced struct (both its dtype and its validity).
    nullability: Nullability,
}
49
50impl Pack {
51 pub fn try_new_expr(
52 names: FieldNames,
53 values: Vec<ExprRef>,
54 nullability: Nullability,
55 ) -> VortexResult<ExprRef> {
56 if names.len() != values.len() {
57 vortex_bail!("length mismatch {} {}", names.len(), values.len());
58 }
59 Ok(Arc::new(Pack {
60 names,
61 values,
62 nullability,
63 }))
64 }
65
66 pub fn names(&self) -> &FieldNames {
67 &self.names
68 }
69
70 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71 let idx = self
72 .names
73 .iter()
74 .position(|name| name == field_name)
75 .ok_or_else(|| {
76 vortex_err!(
77 "Cannot find field {} in pack fields {:?}",
78 field_name,
79 self.names
80 )
81 })?;
82
83 self.values
84 .get(idx)
85 .cloned()
86 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87 }
88
89 pub fn nullability(&self) -> Nullability {
90 self.nullability
91 }
92}
93
94pub fn pack(
95 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
96 nullability: Nullability,
97) -> ExprRef {
98 let (names, values): (Vec<_>, Vec<_>) = elements
99 .into_iter()
100 .map(|(name, value)| (name.into(), value))
101 .unzip();
102 Pack::try_new_expr(names.into(), values, nullability)
103 .vortex_expect("pack names and values have the same length")
104}
105
106impl Display for Pack {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 write!(
109 f,
110 "pack({{{}}}){}",
111 self.names
112 .iter()
113 .zip(&self.values)
114 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
115 self.nullability
116 )
117 }
118}
119
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Serialization handle for [`Pack`] expressions, identified by `"pack"`.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Rebuilds a [`Pack`] from its serialized kind plus its already-decoded children.
        ///
        /// Field names come from `paths` and are paired positionally with `children`;
        /// any kind other than `Kind::Pack` is rejected.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            match kind {
                Kind::Pack(op) => {
                    let names = op.paths.iter().map(|path| path.clone().into()).collect();
                    Pack::try_new_expr(names, children, op.nullable.into())
                }
                _ => vortex_bail!("wrong kind {:?}, wanted pack", kind),
            }
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Encodes the field names as `paths` together with the nullability flag.
        ///
        /// The child expressions are not stored in the kind; deserialization receives
        /// them as a separate list.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            let paths = self.names.iter().map(ToString::to_string).collect();
            Ok(Kind::Pack(kind::Pack {
                paths,
                nullable: self.nullability.into(),
            }))
        }
    }
}
163
// `Pack` overrides no `AnalysisExpr` methods and relies entirely on the trait's defaults.
impl AnalysisExpr for Pack {}
165
166impl VortexExpr for Pack {
167 fn as_any(&self) -> &dyn Any {
168 self
169 }
170
171 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
172 let len = scope.len();
173 let value_arrays = self
174 .values
175 .iter()
176 .map(|value_expr| value_expr.unchecked_evaluate(scope))
177 .process_results(|it| it.collect::<Vec<_>>())?;
178 let validity = match self.nullability {
179 Nullability::NonNullable => Validity::NonNullable,
180 Nullability::Nullable => Validity::AllValid,
181 };
182 Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
183 }
184
185 fn children(&self) -> Vec<&ExprRef> {
186 self.values.iter().collect()
187 }
188
189 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
190 assert_eq!(children.len(), self.values.len());
191 Self::try_new_expr(self.names.clone(), children, self.nullability)
192 .vortex_expect("children are known to have the same length as names")
193 }
194
195 fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
196 let value_dtypes = self
197 .values
198 .iter()
199 .map(|value_expr| value_expr.return_dtype(scope))
200 .process_results(|it| it.collect())?;
201 Ok(DType::Struct(
202 Arc::new(StructFields::new(self.names.clone(), value_dtypes)),
203 self.nullability,
204 ))
205 }
206}
207
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Pack, Scope, col};

    /// Two-field struct `{a: [0, 1, 2], b: [4, 5, 6]}` used as the input scope in these tests.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct arrays and returns the leaf as a
    /// primitive array.
    ///
    /// Every failure is returned as an error instead of panicking. Failures include an
    /// empty path, a missing field, and an array that cannot be canonicalized to the
    /// expected type.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate canonicalization failures instead of panicking inside a
        // `Result`-returning helper.
        array.to_primitive()
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    #[test]
    pub fn test_pack_length_mismatch() {
        // Two names but only one value must be rejected rather than silently truncated.
        let result = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a")],
            Nullability::NonNullable,
        );
        assert!(result.is_err());
    }

    #[test]
    pub fn test_pack_field_lookup() {
        let pack = Pack {
            names: ["one".into(), "two".into()].into(),
            values: vec![col("a"), col("b")],
            nullability: Nullability::NonNullable,
        };
        assert_eq!(pack.field(&"two".into()).unwrap(), col("b"));
        assert!(pack.field(&"missing".into()).is_err());
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_nested_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![
                col("a"),
                Pack::try_new_expr(
                    ["two_one".into(), "two_two".into()].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }
}