vortex_expr/exprs/
pack.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
/// Pack zero or more expressions into a structure with named fields.
///
/// # Examples
///
/// ```
/// use vortex_array::{IntoArray, ToCanonical};
/// use vortex_buffer::buffer;
/// use vortex_expr::{root, Pack, Scope, VortexExpr};
/// use vortex_scalar::Scalar;
/// use vortex_dtype::Nullability;
///
/// let example = Pack::try_new_expr(
///     ["x", "x copy", "second x copy"].into(),
///     vec![root(), root(), root()],
///     Nullability::NonNullable,
/// ).unwrap();
/// let packed = example.evaluate(&Scope::new(buffer![100, 110, 200].into_array())).unwrap();
/// let x_copy = packed
///     .to_struct()
///     .unwrap()
///     .field_by_name("x copy")
///     .unwrap()
///     .clone();
/// assert_eq!(x_copy.scalar_at(0).unwrap(), Scalar::from(100));
/// assert_eq!(x_copy.scalar_at(1).unwrap(), Scalar::from(110));
/// assert_eq!(x_copy.scalar_at(2).unwrap(), Scalar::from(200));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pack {
    /// Names of the struct fields; `names[i]` labels `values[i]`.
    names: FieldNames,
    /// Child expressions, one per field name, in the same order as `names`.
    values: Vec<ExprRef>,
    /// Nullability of the struct produced by this expression.
    nullability: Nullability,
}
49
50impl Pack {
51    pub fn try_new_expr(
52        names: FieldNames,
53        values: Vec<ExprRef>,
54        nullability: Nullability,
55    ) -> VortexResult<ExprRef> {
56        if names.len() != values.len() {
57            vortex_bail!("length mismatch {} {}", names.len(), values.len());
58        }
59        Ok(Arc::new(Pack {
60            names,
61            values,
62            nullability,
63        }))
64    }
65
66    pub fn names(&self) -> &FieldNames {
67        &self.names
68    }
69
70    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71        let idx = self
72            .names
73            .iter()
74            .position(|name| name == field_name)
75            .ok_or_else(|| {
76                vortex_err!(
77                    "Cannot find field {} in pack fields {:?}",
78                    field_name,
79                    self.names
80                )
81            })?;
82
83        self.values
84            .get(idx)
85            .cloned()
86            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87    }
88
89    pub fn nullability(&self) -> Nullability {
90        self.nullability
91    }
92}
93
94pub fn pack(
95    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
96    nullability: Nullability,
97) -> ExprRef {
98    let (names, values): (Vec<_>, Vec<_>) = elements
99        .into_iter()
100        .map(|(name, value)| (name.into(), value))
101        .unzip();
102    Pack::try_new_expr(names.into(), values, nullability)
103        .vortex_expect("pack names and values have the same length")
104}
105
106impl Display for Pack {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(
109            f,
110            "pack({{{}}}){}",
111            self.names
112                .iter()
113                .zip(&self.values)
114                .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
115            self.nullability
116        )
117    }
118}
119
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Protobuf (de)serialization entry point for [`Pack`] expressions.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Rebuilds a [`Pack`] from its serialized `kind` and the already
        /// deserialized `children`; fails if `kind` is not `Kind::Pack`.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            match kind {
                Kind::Pack(op) => Pack::try_new_expr(
                    op.paths.iter().map(|path| path.as_str()).collect(),
                    children,
                    op.nullable.into(),
                ),
                _ => vortex_bail!("wrong kind {:?}, wanted pack", kind),
            }
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serializes the field names and nullability; the child expressions
        /// are not part of the returned `Kind`.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            let serialized = kind::Pack {
                paths: self.names.iter().map(|name| name.to_string()).collect(),
                nullable: self.nullability.into(),
            };
            Ok(Kind::Pack(serialized))
        }
    }
}
163
// No methods are overridden here, so `Pack` uses whatever default
// implementations `AnalysisExpr` provides.
impl AnalysisExpr for Pack {}
165
166impl VortexExpr for Pack {
167    fn as_any(&self) -> &dyn Any {
168        self
169    }
170
171    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
172        let len = scope.len();
173        let value_arrays = self
174            .values
175            .iter()
176            .map(|value_expr| value_expr.unchecked_evaluate(scope))
177            .process_results(|it| it.collect::<Vec<_>>())?;
178        let validity = match self.nullability {
179            Nullability::NonNullable => Validity::NonNullable,
180            Nullability::Nullable => Validity::AllValid,
181        };
182        Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
183    }
184
185    fn children(&self) -> Vec<&ExprRef> {
186        self.values.iter().collect()
187    }
188
189    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
190        assert_eq!(children.len(), self.values.len());
191        Self::try_new_expr(self.names.clone(), children, self.nullability)
192            .vortex_expect("children are known to have the same length as names")
193    }
194
195    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
196        let value_dtypes = self
197            .values
198            .iter()
199            .map(|value_expr| value_expr.return_dtype(scope))
200            .process_results(|it| it.collect())?;
201        Ok(DType::Struct(
202            StructFields::new(self.names.clone(), value_dtypes),
203            self.nullability,
204        ))
205    }
206}
207
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{ExprRef, Pack, Scope, col};

    /// A two-column struct array: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs starting at `array` and
    /// returns the addressed field as a primitive array.
    ///
    /// Every failure, including the final primitive conversion, is returned
    /// as an error rather than panicking; an empty path is rejected.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct()?.field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct()?.field_by_name(name)?.clone();
        }
        current.to_primitive()
    }

    /// Evaluates `expr` against [`test_array`] and canonicalizes the result to a struct.
    fn evaluate_to_struct(expr: &ExprRef) -> StructArray {
        expr.evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap()
    }

    /// Asserts that the `i32` field at `field_path` inside `array` equals `expected`.
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(FieldNames::default(), Vec::new(), Nullability::NonNullable)
            .unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        // Even without fields, the packed struct keeps the scope's row count.
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = evaluate_to_struct(&expr);
        let expected_names: FieldNames = ["one", "two", "three"].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_i32_field(actual_array.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(actual_array.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(actual_array.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack::try_new_expr(
            ["two_one", "two_two"].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();
        let expr = Pack::try_new_expr(
            ["one", "two", "three"].into(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = evaluate_to_struct(&expr);
        let expected_names = FieldNames::from(["one", "two", "three"]);
        assert_eq!(actual_array.names(), &expected_names);

        assert_i32_field(actual_array.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(actual_array.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(actual_array.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(actual_array.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = evaluate_to_struct(&expr);
        let expected_names: FieldNames = ["one", "two", "three"].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }
}