vortex_expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12use vortex_proto::expr as pb;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
17vtable!(Pack);
18
19/// Pack zero or more expressions into a structure with named fields.
20///
21/// # Examples
22///
23/// ```
24/// use vortex_array::{IntoArray, ToCanonical};
25/// use vortex_buffer::buffer;
26/// use vortex_expr::{root, PackExpr, Scope, VortexExpr};
27/// use vortex_scalar::Scalar;
28/// use vortex_dtype::Nullability;
29///
30/// let example = PackExpr::try_new(
31///     ["x", "x copy", "second x copy"].into(),
32///     vec![root(), root(), root()],
33///     Nullability::NonNullable,
34/// ).unwrap();
35/// let packed = example.evaluate(&Scope::new(buffer![100, 110, 200].into_array())).unwrap();
36/// let x_copy = packed
37///     .to_struct()
38///     .field_by_name("x copy")
39///     .unwrap()
40///     .clone();
41/// assert_eq!(x_copy.scalar_at(0), Scalar::from(100));
42/// assert_eq!(x_copy.scalar_at(1), Scalar::from(110));
43/// assert_eq!(x_copy.scalar_at(2), Scalar::from(200));
44/// ```
45///
46#[allow(clippy::derived_hash_with_manual_eq)]
47#[derive(Debug, Clone, PartialEq, Eq, Hash)]
48pub struct PackExpr {
49    names: FieldNames,
50    values: Vec<ExprRef>,
51    nullability: Nullability,
52}
53
54pub struct PackExprEncoding;
55
56impl VTable for PackVTable {
57    type Expr = PackExpr;
58    type Encoding = PackExprEncoding;
59    type Metadata = ProstMetadata<pb::PackOpts>;
60
61    fn id(_encoding: &Self::Encoding) -> ExprId {
62        ExprId::new_ref("pack")
63    }
64
65    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
66        ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
67    }
68
69    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
70        Some(ProstMetadata(pb::PackOpts {
71            paths: expr.names.iter().map(|n| n.to_string()).collect(),
72            nullable: expr.nullability.into(),
73        }))
74    }
75
76    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
77        expr.values.iter().collect()
78    }
79
80    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
81        PackExpr::try_new(expr.names.clone(), children, expr.nullability)
82    }
83
84    fn build(
85        _encoding: &Self::Encoding,
86        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
87        children: Vec<ExprRef>,
88    ) -> VortexResult<Self::Expr> {
89        if children.len() != metadata.paths.len() {
90            vortex_bail!(
91                "Pack expression expects {} children, got {}",
92                metadata.paths.len(),
93                children.len()
94            );
95        }
96        let names: FieldNames = metadata
97            .paths
98            .iter()
99            .map(|name| FieldName::from(name.as_str()))
100            .collect();
101        PackExpr::try_new(names, children, metadata.nullable.into())
102    }
103
104    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
105        let len = scope.len();
106        let value_arrays = expr
107            .values
108            .iter()
109            .zip_eq(expr.names.iter())
110            .map(|(value_expr, name)| {
111                value_expr
112                    .unchecked_evaluate(scope)
113                    .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
114            })
115            .process_results(|it| it.collect::<Vec<_>>())?;
116        let validity = match expr.nullability {
117            Nullability::NonNullable => Validity::NonNullable,
118            Nullability::Nullable => Validity::AllValid,
119        };
120        Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
121    }
122
123    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
124        let value_dtypes = expr
125            .values
126            .iter()
127            .map(|value_expr| value_expr.return_dtype(scope))
128            .process_results(|it| it.collect())?;
129        Ok(DType::Struct(
130            StructFields::new(expr.names.clone(), value_dtypes),
131            expr.nullability,
132        ))
133    }
134}
135
136impl PackExpr {
137    pub fn try_new(
138        names: FieldNames,
139        values: Vec<ExprRef>,
140        nullability: Nullability,
141    ) -> VortexResult<Self> {
142        if names.len() != values.len() {
143            vortex_bail!("length mismatch {} {}", names.len(), values.len());
144        }
145        Ok(PackExpr {
146            names,
147            values,
148            nullability,
149        })
150    }
151
152    pub fn try_new_expr(
153        names: FieldNames,
154        values: Vec<ExprRef>,
155        nullability: Nullability,
156    ) -> VortexResult<ExprRef> {
157        Self::try_new(names, values, nullability).map(|v| v.into_expr())
158    }
159
160    pub fn names(&self) -> &FieldNames {
161        &self.names
162    }
163
164    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
165        let idx = self
166            .names
167            .iter()
168            .position(|name| name == field_name)
169            .ok_or_else(|| {
170                vortex_err!(
171                    "Cannot find field {} in pack fields {:?}",
172                    field_name,
173                    self.names
174                )
175            })?;
176
177        self.values
178            .get(idx)
179            .cloned()
180            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
181    }
182
183    pub fn nullability(&self) -> Nullability {
184        self.nullability
185    }
186}
187
188/// Creates an expression that packs values into a struct with named fields.
189///
190/// ```rust
191/// # use vortex_dtype::Nullability;
192/// # use vortex_expr::{pack, col, lit};
193/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
194/// ```
195pub fn pack(
196    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
197    nullability: Nullability,
198) -> ExprRef {
199    let (names, values): (Vec<_>, Vec<_>) = elements
200        .into_iter()
201        .map(|(name, value)| (name.into(), value))
202        .unzip();
203    PackExpr::try_new(names.into(), values, nullability)
204        .vortex_expect("pack names and values have the same length")
205        .into_expr()
206}
207
208impl DisplayAs for PackExpr {
209    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
210        match df {
211            DisplayFormat::Compact => {
212                write!(
213                    f,
214                    "pack({}){}",
215                    self.names
216                        .iter()
217                        .zip(&self.values)
218                        .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
219                    self.nullability
220                )
221            }
222            DisplayFormat::Tree => {
223                write!(f, "Pack")
224            }
225        }
226    }
227
228    fn child_names(&self) -> Option<Vec<String>> {
229        Some(self.names.iter().map(|n| n.to_string()).collect())
230    }
231}
232
233impl AnalysisExpr for PackExpr {}
234
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col, pack};

    /// Fixture: a struct array with columns `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the leaf as a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for segment in rest {
            current = current.to_struct().field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Evaluates `expr` over the fixture and canonicalizes the result to a struct array.
    fn evaluate_on_fixture(expr: &PackExpr) -> StructArray {
        expr.evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
    }

    /// Asserts that the `i32` leaf at `field_path` holds exactly `expected`.
    #[track_caller]
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: [i32; 3]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let input = test_array();
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let packed = expr.evaluate(&Scope::new(input.clone())).unwrap();

        // No fields, but the row count still follows the scope.
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_fixture(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = PackExpr::try_new(
            ["two_one", "two_two"].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap()
        .into_expr();
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_fixture(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let packed = evaluate_on_fixture(&expr);

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        // A nullable pack renders with a trailing `?`.
        let nullable = PackExpr::try_new(
            ["x", "y"].into(),
            vec![col("a"), col("b")],
            Nullability::Nullable,
        )
        .unwrap();
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}