1use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use vortex_array::arrays::StructArray;
9use vortex_array::validity::Validity;
10use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
11use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
12use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
13use vortex_proto::expr as pb;
14
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
// Invokes the crate-local `vtable!` macro for `Pack`.
// NOTE(review): presumably this declares the `PackVTable` type that the `impl VTable` below
// is written for — confirm against the macro definition.
vtable!(Pack);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
49pub struct PackExpr {
50 names: FieldNames,
51 values: Vec<ExprRef>,
52 nullability: Nullability,
53}
54
/// Unit-struct encoding marker for [`PackExpr`]; it is the `Encoding` associated type of the
/// `VTable` implementation for `PackVTable`.
pub struct PackExprEncoding;
56
57impl VTable for PackVTable {
58 type Expr = PackExpr;
59 type Encoding = PackExprEncoding;
60 type Metadata = ProstMetadata<pb::PackOpts>;
61
62 fn id(_encoding: &Self::Encoding) -> ExprId {
63 ExprId::new_ref("pack")
64 }
65
66 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
67 ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
68 }
69
70 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
71 Some(ProstMetadata(pb::PackOpts {
72 paths: expr.names.iter().map(|n| n.to_string()).collect(),
73 nullable: expr.nullability.into(),
74 }))
75 }
76
77 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
78 expr.values.iter().collect()
79 }
80
81 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
82 PackExpr::try_new(expr.names.clone(), children, expr.nullability)
83 }
84
85 fn build(
86 _encoding: &Self::Encoding,
87 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
88 children: Vec<ExprRef>,
89 ) -> VortexResult<Self::Expr> {
90 if children.len() != metadata.paths.len() {
91 vortex_bail!(
92 "Pack expression expects {} children, got {}",
93 metadata.paths.len(),
94 children.len()
95 );
96 }
97 let names: FieldNames = metadata
98 .paths
99 .iter()
100 .map(|name| FieldName::from(name.as_str()))
101 .collect();
102 PackExpr::try_new(names, children, metadata.nullable.into())
103 }
104
105 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
106 let len = scope.len();
107 let value_arrays = expr
108 .values
109 .iter()
110 .map(|value_expr| value_expr.unchecked_evaluate(scope))
111 .process_results(|it| it.collect::<Vec<_>>())?;
112 let validity = match expr.nullability {
113 Nullability::NonNullable => Validity::NonNullable,
114 Nullability::Nullable => Validity::AllValid,
115 };
116 Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
117 }
118
119 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
120 let value_dtypes = expr
121 .values
122 .iter()
123 .map(|value_expr| value_expr.return_dtype(scope))
124 .process_results(|it| it.collect())?;
125 Ok(DType::Struct(
126 StructFields::new(expr.names.clone(), value_dtypes),
127 expr.nullability,
128 ))
129 }
130}
131
132impl PackExpr {
133 pub fn try_new(
134 names: FieldNames,
135 values: Vec<ExprRef>,
136 nullability: Nullability,
137 ) -> VortexResult<Self> {
138 if names.len() != values.len() {
139 vortex_bail!("length mismatch {} {}", names.len(), values.len());
140 }
141 Ok(PackExpr {
142 names,
143 values,
144 nullability,
145 })
146 }
147
148 pub fn try_new_expr(
149 names: FieldNames,
150 values: Vec<ExprRef>,
151 nullability: Nullability,
152 ) -> VortexResult<ExprRef> {
153 Self::try_new(names, values, nullability).map(|v| v.into_expr())
154 }
155
156 pub fn names(&self) -> &FieldNames {
157 &self.names
158 }
159
160 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
161 let idx = self
162 .names
163 .iter()
164 .position(|name| name == field_name)
165 .ok_or_else(|| {
166 vortex_err!(
167 "Cannot find field {} in pack fields {:?}",
168 field_name,
169 self.names
170 )
171 })?;
172
173 self.values
174 .get(idx)
175 .cloned()
176 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
177 }
178
179 pub fn nullability(&self) -> Nullability {
180 self.nullability
181 }
182}
183
184pub fn pack(
192 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
193 nullability: Nullability,
194) -> ExprRef {
195 let (names, values): (Vec<_>, Vec<_>) = elements
196 .into_iter()
197 .map(|(name, value)| (name.into(), value))
198 .unzip();
199 PackExpr::try_new(names.into(), values, nullability)
200 .vortex_expect("pack names and values have the same length")
201 .into_expr()
202}
203
204impl Display for PackExpr {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 write!(
207 f,
208 "pack({}){}",
209 self.names
210 .iter()
211 .zip(&self.values)
212 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
213 self.nullability
214 )
215 }
216}
217
// Empty impl: `PackExpr` overrides nothing and relies on whatever default method
// implementations `AnalysisExpr` provides.
impl AnalysisExpr for PackExpr {}
219
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col};

    /// Three-row struct fixture with columns `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs starting at `array` and returns the final
    /// field as a primitive array.
    ///
    /// Errors (instead of panicking) when the path is empty, when any step is not a struct
    /// or lacks the named field, or when the final field cannot be converted to a primitive
    /// array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(first)?.clone();
        for field in rest {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate the conversion failure to the caller rather than unwrapping here.
        array.to_primitive()
    }

    /// Packing zero fields yields a field-less struct with the scope's row count.
    #[test]
    pub fn test_empty_pack() {
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    /// Columns can be renamed and repeated; the result is non-nullable.
    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    /// A pack expression can be a child of another pack, producing nested structs.
    #[test]
    pub fn test_nested_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![
                col("a"),
                PackExpr::try_new(
                    ["two_one", "two_two"].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap()
                .into_expr(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    /// A nullable pack produces a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }
}
379}