1use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12use vortex_proto::expr as pb;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
17vtable!(Pack);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
47#[derive(Debug, Clone, PartialEq, Eq, Hash)]
48pub struct PackExpr {
49 names: FieldNames,
50 values: Vec<ExprRef>,
51 nullability: Nullability,
52}
53
/// Zero-sized marker type used as the encoding of [`PackExpr`].
pub struct PackExprEncoding;
55
56impl VTable for PackVTable {
57 type Expr = PackExpr;
58 type Encoding = PackExprEncoding;
59 type Metadata = ProstMetadata<pb::PackOpts>;
60
61 fn id(_encoding: &Self::Encoding) -> ExprId {
62 ExprId::new_ref("pack")
63 }
64
65 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
66 ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
67 }
68
69 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
70 Some(ProstMetadata(pb::PackOpts {
71 paths: expr.names.iter().map(|n| n.to_string()).collect(),
72 nullable: expr.nullability.into(),
73 }))
74 }
75
76 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
77 expr.values.iter().collect()
78 }
79
80 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
81 PackExpr::try_new(expr.names.clone(), children, expr.nullability)
82 }
83
84 fn build(
85 _encoding: &Self::Encoding,
86 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
87 children: Vec<ExprRef>,
88 ) -> VortexResult<Self::Expr> {
89 if children.len() != metadata.paths.len() {
90 vortex_bail!(
91 "Pack expression expects {} children, got {}",
92 metadata.paths.len(),
93 children.len()
94 );
95 }
96 let names: FieldNames = metadata
97 .paths
98 .iter()
99 .map(|name| FieldName::from(name.as_str()))
100 .collect();
101 PackExpr::try_new(names, children, metadata.nullable.into())
102 }
103
104 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
105 let len = scope.len();
106 let value_arrays = expr
107 .values
108 .iter()
109 .zip_eq(expr.names.iter())
110 .map(|(value_expr, name)| {
111 value_expr
112 .unchecked_evaluate(scope)
113 .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
114 })
115 .process_results(|it| it.collect::<Vec<_>>())?;
116 let validity = match expr.nullability {
117 Nullability::NonNullable => Validity::NonNullable,
118 Nullability::Nullable => Validity::AllValid,
119 };
120 Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
121 }
122
123 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
124 let value_dtypes = expr
125 .values
126 .iter()
127 .map(|value_expr| value_expr.return_dtype(scope))
128 .process_results(|it| it.collect())?;
129 Ok(DType::Struct(
130 StructFields::new(expr.names.clone(), value_dtypes),
131 expr.nullability,
132 ))
133 }
134}
135
136impl PackExpr {
137 pub fn try_new(
138 names: FieldNames,
139 values: Vec<ExprRef>,
140 nullability: Nullability,
141 ) -> VortexResult<Self> {
142 if names.len() != values.len() {
143 vortex_bail!("length mismatch {} {}", names.len(), values.len());
144 }
145 Ok(PackExpr {
146 names,
147 values,
148 nullability,
149 })
150 }
151
152 pub fn try_new_expr(
153 names: FieldNames,
154 values: Vec<ExprRef>,
155 nullability: Nullability,
156 ) -> VortexResult<ExprRef> {
157 Self::try_new(names, values, nullability).map(|v| v.into_expr())
158 }
159
160 pub fn names(&self) -> &FieldNames {
161 &self.names
162 }
163
164 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
165 let idx = self
166 .names
167 .iter()
168 .position(|name| name == field_name)
169 .ok_or_else(|| {
170 vortex_err!(
171 "Cannot find field {} in pack fields {:?}",
172 field_name,
173 self.names
174 )
175 })?;
176
177 self.values
178 .get(idx)
179 .cloned()
180 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
181 }
182
183 pub fn nullability(&self) -> Nullability {
184 self.nullability
185 }
186}
187
188pub fn pack(
196 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
197 nullability: Nullability,
198) -> ExprRef {
199 let (names, values): (Vec<_>, Vec<_>) = elements
200 .into_iter()
201 .map(|(name, value)| (name.into(), value))
202 .unzip();
203 PackExpr::try_new(names.into(), values, nullability)
204 .vortex_expect("pack names and values have the same length")
205 .into_expr()
206}
207
208impl DisplayAs for PackExpr {
209 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
210 match df {
211 DisplayFormat::Compact => {
212 write!(
213 f,
214 "pack({}){}",
215 self.names
216 .iter()
217 .zip(&self.values)
218 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
219 self.nullability
220 )
221 }
222 DisplayFormat::Tree => {
223 write!(f, "Pack")
224 }
225 }
226 }
227
228 fn child_names(&self) -> Option<Vec<String>> {
229 Some(self.names.iter().map(|n| n.to_string()).collect())
230 }
231}
232
233impl AnalysisExpr for PackExpr {}
234
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col, pack};

    /// A three-row struct with fields `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs and returns the field it ends at as a
    /// primitive array. Fails on an empty path or an unknown field name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Evaluates `expr` against [`test_array`] and returns the result as a struct array.
    fn evaluate_on_test_array(expr: &PackExpr) -> StructArray {
        expr.evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
    }

    /// Asserts that the `i32` field reached via `field_path` holds exactly `expected`.
    #[track_caller]
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let input = test_array();
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let packed = expr.evaluate(&Scope::new(input.clone())).unwrap();

        // A pack without fields still has one row per input row.
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = PackExpr::try_new(
            ["two_one", "two_two"].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap()
        .into_expr();

        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        // A nullable pack is reported as all-valid rather than non-nullable.
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable = PackExpr::try_new(
            ["x", "y"].into(),
            vec![col("a"), col("b")],
            Nullability::Nullable,
        )
        .unwrap();
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}