1use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12use vortex_proto::expr as pb;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
// Invokes the crate's `vtable!` macro for `Pack`. Presumably this declares the
// `PackVTable` type implemented below -- confirm against the macro definition.
vtable!(Pack);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
47#[derive(Debug, Clone, PartialEq, Eq, Hash)]
48pub struct PackExpr {
49 names: FieldNames,
50 values: Vec<ExprRef>,
51 nullability: Nullability,
52}
53
/// Encoding marker for pack expressions; used as the `Encoding` associated
/// type of the `VTable` implementation for `PackVTable`.
pub struct PackExprEncoding;
55
56impl VTable for PackVTable {
57 type Expr = PackExpr;
58 type Encoding = PackExprEncoding;
59 type Metadata = ProstMetadata<pb::PackOpts>;
60
61 fn id(_encoding: &Self::Encoding) -> ExprId {
62 ExprId::new_ref("pack")
63 }
64
65 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
66 ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
67 }
68
69 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
70 Some(ProstMetadata(pb::PackOpts {
71 paths: expr.names.iter().map(|n| n.to_string()).collect(),
72 nullable: expr.nullability.into(),
73 }))
74 }
75
76 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
77 expr.values.iter().collect()
78 }
79
80 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
81 PackExpr::try_new(expr.names.clone(), children, expr.nullability)
82 }
83
84 fn build(
85 _encoding: &Self::Encoding,
86 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
87 children: Vec<ExprRef>,
88 ) -> VortexResult<Self::Expr> {
89 if children.len() != metadata.paths.len() {
90 vortex_bail!(
91 "Pack expression expects {} children, got {}",
92 metadata.paths.len(),
93 children.len()
94 );
95 }
96 let names: FieldNames = metadata
97 .paths
98 .iter()
99 .map(|name| FieldName::from(name.as_str()))
100 .collect();
101 PackExpr::try_new(names, children, metadata.nullable.into())
102 }
103
104 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
105 let len = scope.len();
106 let value_arrays = expr
107 .values
108 .iter()
109 .zip_eq(expr.names.iter())
110 .map(|(value_expr, name)| {
111 value_expr
112 .unchecked_evaluate(scope)
113 .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
114 })
115 .process_results(|it| it.collect::<Vec<_>>())?;
116 let validity = match expr.nullability {
117 Nullability::NonNullable => Validity::NonNullable,
118 Nullability::Nullable => Validity::AllValid,
119 };
120 Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
121 }
122
123 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
124 let value_dtypes = expr
125 .values
126 .iter()
127 .map(|value_expr| value_expr.return_dtype(scope))
128 .collect::<VortexResult<Vec<_>>>()?;
129 Ok(DType::Struct(
130 StructFields::new(expr.names.clone(), value_dtypes),
131 expr.nullability,
132 ))
133 }
134}
135
136impl PackExpr {
137 pub fn try_new(
138 names: FieldNames,
139 values: Vec<ExprRef>,
140 nullability: Nullability,
141 ) -> VortexResult<Self> {
142 if names.len() != values.len() {
143 vortex_bail!("length mismatch {} {}", names.len(), values.len());
144 }
145 Ok(PackExpr {
146 names,
147 values,
148 nullability,
149 })
150 }
151
152 pub fn try_new_expr(
153 names: FieldNames,
154 values: Vec<ExprRef>,
155 nullability: Nullability,
156 ) -> VortexResult<ExprRef> {
157 Self::try_new(names, values, nullability).map(|v| v.into_expr())
158 }
159
160 pub fn names(&self) -> &FieldNames {
161 &self.names
162 }
163
164 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
165 let idx = self
166 .names
167 .iter()
168 .position(|name| name == field_name)
169 .ok_or_else(|| {
170 vortex_err!(
171 "Cannot find field {} in pack fields {:?}",
172 field_name,
173 self.names
174 )
175 })?;
176
177 self.values
178 .get(idx)
179 .cloned()
180 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
181 }
182
183 pub fn nullability(&self) -> Nullability {
184 self.nullability
185 }
186}
187
188pub fn pack(
196 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
197 nullability: Nullability,
198) -> ExprRef {
199 let (names, values): (Vec<_>, Vec<_>) = elements
200 .into_iter()
201 .map(|(name, value)| (name.into(), value))
202 .unzip();
203 PackExpr::try_new(names.into(), values, nullability)
204 .vortex_expect("pack names and values have the same length")
205 .into_expr()
206}
207
208impl DisplayAs for PackExpr {
209 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
210 match df {
211 DisplayFormat::Compact => {
212 write!(
213 f,
214 "pack({}){}",
215 self.names
216 .iter()
217 .zip(&self.values)
218 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
219 self.nullability
220 )
221 }
222 DisplayFormat::Tree => {
223 write!(f, "Pack")
224 }
225 }
226 }
227
228 fn child_names(&self) -> Option<Vec<String>> {
229 Some(self.names.iter().map(|n| n.to_string()).collect())
230 }
231}
232
// No overrides: pack expressions use whatever default behaviour the
// `AnalysisExpr` trait provides.
impl AnalysisExpr for PackExpr {}
234
#[cfg(test)]
mod tests {
    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col, pack};

    /// Struct array with two three-element integer columns, `a` and `b`.
    fn test_array() -> ArrayRef {
        let fields = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the leaf as a
    /// primitive array. An empty path is an error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Evaluates `expr` against a fresh `test_array()` and returns the struct result.
    fn evaluate_on_test_array(expr: PackExpr) -> StructArray {
        expr.evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
    }

    /// Asserts that the `i32` leaf at `field_path` holds exactly `expected`.
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: [i32; 3]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let input = test_array();
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let packed = expr.evaluate(&Scope::new(input.clone())).unwrap();

        // A pack with no fields still has the scope's length.
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = PackExpr::try_new(
            ["two_one", "two_two"].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap()
        .into_expr();

        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable = PackExpr::try_new(
            ["x", "y"].into(),
            vec![col("a"), col("b")],
            Nullability::Nullable,
        )
        .unwrap();
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}