1use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use vortex_array::arrays::StructArray;
9use vortex_array::validity::Validity;
10use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
11use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
12use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
13use vortex_proto::expr as pb;
14
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
// Invokes the crate's `vtable!` macro for the `Pack` expression.
// NOTE(review): presumably this is what declares `PackVTable` (implemented below but not
// defined in the visible source) and wires `PackExpr` into `ExprRef` — confirm against the
// macro definition.
vtable!(Pack);
18
19#[allow(clippy::derived_hash_with_manual_eq)]
48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
49pub struct PackExpr {
50 names: FieldNames,
51 values: Vec<ExprRef>,
52 nullability: Nullability,
53}
54
/// Unit type used as the `VTable::Encoding` associated type of the pack expression vtable.
pub struct PackExprEncoding;
56
57impl VTable for PackVTable {
58 type Expr = PackExpr;
59 type Encoding = PackExprEncoding;
60 type Metadata = ProstMetadata<pb::PackOpts>;
61
62 fn id(_encoding: &Self::Encoding) -> ExprId {
63 ExprId::new_ref("pack")
64 }
65
66 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
67 ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
68 }
69
70 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
71 Some(ProstMetadata(pb::PackOpts {
72 paths: expr.names.iter().map(|n| n.to_string()).collect(),
73 nullable: expr.nullability.into(),
74 }))
75 }
76
77 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
78 expr.values.iter().collect()
79 }
80
81 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
82 PackExpr::try_new(expr.names.clone(), children, expr.nullability)
83 }
84
85 fn build(
86 _encoding: &Self::Encoding,
87 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
88 children: Vec<ExprRef>,
89 ) -> VortexResult<Self::Expr> {
90 if children.len() != metadata.paths.len() {
91 vortex_bail!(
92 "Pack expression expects {} children, got {}",
93 metadata.paths.len(),
94 children.len()
95 );
96 }
97 let names: FieldNames = metadata
98 .paths
99 .iter()
100 .map(|name| FieldName::from(name.as_str()))
101 .collect();
102 PackExpr::try_new(names, children, metadata.nullable.into())
103 }
104
105 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
106 let len = scope.len();
107 let value_arrays = expr
108 .values
109 .iter()
110 .map(|value_expr| value_expr.unchecked_evaluate(scope))
111 .process_results(|it| it.collect::<Vec<_>>())?;
112 let validity = match expr.nullability {
113 Nullability::NonNullable => Validity::NonNullable,
114 Nullability::Nullable => Validity::AllValid,
115 };
116 Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
117 }
118
119 fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
120 let value_dtypes = expr
121 .values
122 .iter()
123 .map(|value_expr| value_expr.return_dtype(scope))
124 .process_results(|it| it.collect())?;
125 Ok(DType::Struct(
126 StructFields::new(expr.names.clone(), value_dtypes),
127 expr.nullability,
128 ))
129 }
130}
131
132impl PackExpr {
133 pub fn try_new(
134 names: FieldNames,
135 values: Vec<ExprRef>,
136 nullability: Nullability,
137 ) -> VortexResult<Self> {
138 if names.len() != values.len() {
139 vortex_bail!("length mismatch {} {}", names.len(), values.len());
140 }
141 Ok(PackExpr {
142 names,
143 values,
144 nullability,
145 })
146 }
147
148 pub fn try_new_expr(
149 names: FieldNames,
150 values: Vec<ExprRef>,
151 nullability: Nullability,
152 ) -> VortexResult<ExprRef> {
153 Self::try_new(names, values, nullability).map(|v| v.into_expr())
154 }
155
156 pub fn names(&self) -> &FieldNames {
157 &self.names
158 }
159
160 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
161 let idx = self
162 .names
163 .iter()
164 .position(|name| name == field_name)
165 .ok_or_else(|| {
166 vortex_err!(
167 "Cannot find field {} in pack fields {:?}",
168 field_name,
169 self.names
170 )
171 })?;
172
173 self.values
174 .get(idx)
175 .cloned()
176 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
177 }
178
179 pub fn nullability(&self) -> Nullability {
180 self.nullability
181 }
182}
183
184pub fn pack(
185 elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
186 nullability: Nullability,
187) -> ExprRef {
188 let (names, values): (Vec<_>, Vec<_>) = elements
189 .into_iter()
190 .map(|(name, value)| (name.into(), value))
191 .unzip();
192 PackExpr::try_new(names.into(), values, nullability)
193 .vortex_expect("pack names and values have the same length")
194 .into_expr()
195}
196
197impl Display for PackExpr {
198 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199 write!(
200 f,
201 "pack({}){}",
202 self.names
203 .iter()
204 .zip(&self.values)
205 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
206 self.nullability
207 )
208 }
209}
210
// Empty impl: `PackExpr` overrides nothing and relies entirely on whatever default methods
// `AnalysisExpr` provides.
impl AnalysisExpr for PackExpr {}
212
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col};

    /// Three-row struct fixture with fields `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs and returns the leaf as a primitive array.
    ///
    /// Every failure (empty path, missing field, non-struct intermediate, non-primitive leaf)
    /// is returned as an error so the calling test decides how to report it.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        // Propagate the conversion error instead of panicking inside a fallible helper.
        array.to_primitive()
    }

    /// A pack with no fields still yields a struct with one row per input row.
    #[test]
    pub fn test_empty_pack() {
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    /// Columns can be renamed and repeated; a non-nullable pack has non-nullable validity.
    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one", "two", "three"].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    /// A pack used as a child of another pack produces a nested struct field.
    #[test]
    pub fn test_nested_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![
                col("a"),
                PackExpr::try_new(
                    ["two_one", "two_two"].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap()
                .into_expr(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names = FieldNames::from(["one", "two", "three"]);
        assert_eq!(actual_array.names(), &expected_names);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    /// A nullable pack produces a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one", "two", "three"].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }
}
372}