vortex_expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use vortex_array::arrays::StructArray;
9use vortex_array::validity::Validity;
10use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
11use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
12use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
13use vortex_proto::expr as pb;
14
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
// NOTE(review): presumably this macro declares the `PackVTable` type that the `VTable` impl
// below is written for — confirm against the `vtable!` definition in this crate.
vtable!(Pack);
18
/// Pack zero or more expressions into a structure with named fields.
///
/// Field names and value expressions are matched by position; [`PackExpr::try_new`] rejects
/// inputs whose two lists differ in length.
///
/// # Examples
///
/// ```
/// use vortex_array::{IntoArray, ToCanonical};
/// use vortex_buffer::buffer;
/// use vortex_expr::{root, PackExpr, Scope, VortexExpr};
/// use vortex_scalar::Scalar;
/// use vortex_dtype::Nullability;
///
/// let example = PackExpr::try_new(
///     ["x", "x copy", "second x copy"].into(),
///     vec![root(), root(), root()],
///     Nullability::NonNullable,
/// ).unwrap();
/// let packed = example.evaluate(&Scope::new(buffer![100, 110, 200].into_array())).unwrap();
/// let x_copy = packed
///     .to_struct()
///     .unwrap()
///     .field_by_name("x copy")
///     .unwrap()
///     .clone();
/// assert_eq!(x_copy.scalar_at(0).unwrap(), Scalar::from(100));
/// assert_eq!(x_copy.scalar_at(1).unwrap(), Scalar::from(110));
/// assert_eq!(x_copy.scalar_at(2).unwrap(), Scalar::from(200));
/// ```
// NOTE(review): `PartialEq` is derived rather than hand-written here, so this `allow` looks
// unnecessary — confirm whether it is a leftover from an earlier manual `PartialEq` impl.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PackExpr {
    /// Names of the output struct's fields, in field order.
    names: FieldNames,
    /// Expression producing each field; `values[i]` belongs to `names[i]`.
    values: Vec<ExprRef>,
    /// Nullability of the produced struct itself (used for both its dtype and its validity).
    nullability: Nullability,
}
54
/// Unit type used as the `Encoding` associated type of the `VTable` implementation for
/// [`PackExpr`].
pub struct PackExprEncoding;
56
57impl VTable for PackVTable {
58    type Expr = PackExpr;
59    type Encoding = PackExprEncoding;
60    type Metadata = ProstMetadata<pb::PackOpts>;
61
62    fn id(_encoding: &Self::Encoding) -> ExprId {
63        ExprId::new_ref("pack")
64    }
65
66    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
67        ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
68    }
69
70    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
71        Some(ProstMetadata(pb::PackOpts {
72            paths: expr.names.iter().map(|n| n.to_string()).collect(),
73            nullable: expr.nullability.into(),
74        }))
75    }
76
77    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
78        expr.values.iter().collect()
79    }
80
81    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
82        PackExpr::try_new(expr.names.clone(), children, expr.nullability)
83    }
84
85    fn build(
86        _encoding: &Self::Encoding,
87        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
88        children: Vec<ExprRef>,
89    ) -> VortexResult<Self::Expr> {
90        if children.len() != metadata.paths.len() {
91            vortex_bail!(
92                "Pack expression expects {} children, got {}",
93                metadata.paths.len(),
94                children.len()
95            );
96        }
97        let names: FieldNames = metadata
98            .paths
99            .iter()
100            .map(|name| FieldName::from(name.as_str()))
101            .collect();
102        PackExpr::try_new(names, children, metadata.nullable.into())
103    }
104
105    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
106        let len = scope.len();
107        let value_arrays = expr
108            .values
109            .iter()
110            .map(|value_expr| value_expr.unchecked_evaluate(scope))
111            .process_results(|it| it.collect::<Vec<_>>())?;
112        let validity = match expr.nullability {
113            Nullability::NonNullable => Validity::NonNullable,
114            Nullability::Nullable => Validity::AllValid,
115        };
116        Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
117    }
118
119    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
120        let value_dtypes = expr
121            .values
122            .iter()
123            .map(|value_expr| value_expr.return_dtype(scope))
124            .process_results(|it| it.collect())?;
125        Ok(DType::Struct(
126            StructFields::new(expr.names.clone(), value_dtypes),
127            expr.nullability,
128        ))
129    }
130}
131
132impl PackExpr {
133    pub fn try_new(
134        names: FieldNames,
135        values: Vec<ExprRef>,
136        nullability: Nullability,
137    ) -> VortexResult<Self> {
138        if names.len() != values.len() {
139            vortex_bail!("length mismatch {} {}", names.len(), values.len());
140        }
141        Ok(PackExpr {
142            names,
143            values,
144            nullability,
145        })
146    }
147
148    pub fn try_new_expr(
149        names: FieldNames,
150        values: Vec<ExprRef>,
151        nullability: Nullability,
152    ) -> VortexResult<ExprRef> {
153        Self::try_new(names, values, nullability).map(|v| v.into_expr())
154    }
155
156    pub fn names(&self) -> &FieldNames {
157        &self.names
158    }
159
160    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
161        let idx = self
162            .names
163            .iter()
164            .position(|name| name == field_name)
165            .ok_or_else(|| {
166                vortex_err!(
167                    "Cannot find field {} in pack fields {:?}",
168                    field_name,
169                    self.names
170                )
171            })?;
172
173        self.values
174            .get(idx)
175            .cloned()
176            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
177    }
178
179    pub fn nullability(&self) -> Nullability {
180        self.nullability
181    }
182}
183
184/// Creates an expression that packs values into a struct with named fields.
185///
186/// ```rust
187/// # use vortex_dtype::Nullability;
188/// # use vortex_expr::{pack, col, lit};
189/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
190/// ```
191pub fn pack(
192    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
193    nullability: Nullability,
194) -> ExprRef {
195    let (names, values): (Vec<_>, Vec<_>) = elements
196        .into_iter()
197        .map(|(name, value)| (name.into(), value))
198        .unzip();
199    PackExpr::try_new(names.into(), values, nullability)
200        .vortex_expect("pack names and values have the same length")
201        .into_expr()
202}
203
204impl Display for PackExpr {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        write!(
207            f,
208            "pack({}){}",
209            self.names
210                .iter()
211                .zip(&self.values)
212                .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
213            self.nullability
214        )
215    }
216}
217
// No `AnalysisExpr` methods are overridden: the empty body relies entirely on the trait's
// default methods.
impl AnalysisExpr for PackExpr {}
219
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col};

    /// Three-row struct fixture with columns `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the final field as a primitive
    /// array. An empty path is an error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct()?.field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct()?.field_by_name(name)?.clone();
        }
        Ok(current.to_primitive().unwrap())
    }

    /// Evaluates `expr` against a fresh copy of the fixture array.
    fn evaluate_on_test_array(expr: &PackExpr) -> ArrayRef {
        expr.evaluate(&Scope::new(test_array())).unwrap()
    }

    /// Asserts that the `i32` field reached via `path` holds exactly `expected`.
    fn assert_i32_field(array: &dyn Array, path: &[&str], expected: [i32; 3]) {
        assert_eq!(
            primitive_field(array, path).unwrap().as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let input = test_array();
        let expr =
            PackExpr::try_new(FieldNames::default(), vec![], Nullability::NonNullable).unwrap();

        let packed = expr.evaluate(&Scope::new(input.clone())).unwrap();

        // Packing no fields still yields one row per input row.
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().unwrap().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr).to_struct().unwrap();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = PackExpr::try_new(
            ["two_one", "two_two"].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap()
        .into_expr();

        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr).to_struct().unwrap();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], [0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], [4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr).to_struct().unwrap();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }
}
379}