vortex_expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hash;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, DeserializeMetadata, IntoArray, ProstMetadata};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12use vortex_proto::expr as pb;
13
14use crate::display::{DisplayAs, DisplayFormat};
15use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
16
17vtable!(Pack);
18
19/// Pack zero or more expressions into a structure with named fields.
20///
21/// # Examples
22///
23/// ```
24/// use vortex_array::{IntoArray, ToCanonical};
25/// use vortex_buffer::buffer;
26/// use vortex_expr::{root, PackExpr, Scope, VortexExpr};
27/// use vortex_scalar::Scalar;
28/// use vortex_dtype::Nullability;
29///
30/// let example = PackExpr::try_new(
31///     ["x", "x copy", "second x copy"].into(),
32///     vec![root(), root(), root()],
33///     Nullability::NonNullable,
34/// ).unwrap();
35/// let packed = example.evaluate(&Scope::new(buffer![100, 110, 200].into_array())).unwrap();
36/// let x_copy = packed
37///     .to_struct()
38///     .unwrap()
39///     .field_by_name("x copy")
40///     .unwrap()
41///     .clone();
42/// assert_eq!(x_copy.scalar_at(0), Scalar::from(100));
43/// assert_eq!(x_copy.scalar_at(1), Scalar::from(110));
44/// assert_eq!(x_copy.scalar_at(2), Scalar::from(200));
45/// ```
46///
47#[allow(clippy::derived_hash_with_manual_eq)]
48#[derive(Debug, Clone, PartialEq, Eq, Hash)]
49pub struct PackExpr {
50    names: FieldNames,
51    values: Vec<ExprRef>,
52    nullability: Nullability,
53}
54
55pub struct PackExprEncoding;
56
57impl VTable for PackVTable {
58    type Expr = PackExpr;
59    type Encoding = PackExprEncoding;
60    type Metadata = ProstMetadata<pb::PackOpts>;
61
62    fn id(_encoding: &Self::Encoding) -> ExprId {
63        ExprId::new_ref("pack")
64    }
65
66    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
67        ExprEncodingRef::new_ref(PackExprEncoding.as_ref())
68    }
69
70    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
71        Some(ProstMetadata(pb::PackOpts {
72            paths: expr.names.iter().map(|n| n.to_string()).collect(),
73            nullable: expr.nullability.into(),
74        }))
75    }
76
77    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
78        expr.values.iter().collect()
79    }
80
81    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
82        PackExpr::try_new(expr.names.clone(), children, expr.nullability)
83    }
84
85    fn build(
86        _encoding: &Self::Encoding,
87        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
88        children: Vec<ExprRef>,
89    ) -> VortexResult<Self::Expr> {
90        if children.len() != metadata.paths.len() {
91            vortex_bail!(
92                "Pack expression expects {} children, got {}",
93                metadata.paths.len(),
94                children.len()
95            );
96        }
97        let names: FieldNames = metadata
98            .paths
99            .iter()
100            .map(|name| FieldName::from(name.as_str()))
101            .collect();
102        PackExpr::try_new(names, children, metadata.nullable.into())
103    }
104
105    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
106        let len = scope.len();
107        let value_arrays = expr
108            .values
109            .iter()
110            .zip_eq(expr.names.iter())
111            .map(|(value_expr, name)| {
112                value_expr
113                    .unchecked_evaluate(scope)
114                    .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
115            })
116            .process_results(|it| it.collect::<Vec<_>>())?;
117        let validity = match expr.nullability {
118            Nullability::NonNullable => Validity::NonNullable,
119            Nullability::Nullable => Validity::AllValid,
120        };
121        Ok(StructArray::try_new(expr.names.clone(), value_arrays, len, validity)?.into_array())
122    }
123
124    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
125        let value_dtypes = expr
126            .values
127            .iter()
128            .map(|value_expr| value_expr.return_dtype(scope))
129            .process_results(|it| it.collect())?;
130        Ok(DType::Struct(
131            StructFields::new(expr.names.clone(), value_dtypes),
132            expr.nullability,
133        ))
134    }
135}
136
137impl PackExpr {
138    pub fn try_new(
139        names: FieldNames,
140        values: Vec<ExprRef>,
141        nullability: Nullability,
142    ) -> VortexResult<Self> {
143        if names.len() != values.len() {
144            vortex_bail!("length mismatch {} {}", names.len(), values.len());
145        }
146        Ok(PackExpr {
147            names,
148            values,
149            nullability,
150        })
151    }
152
153    pub fn try_new_expr(
154        names: FieldNames,
155        values: Vec<ExprRef>,
156        nullability: Nullability,
157    ) -> VortexResult<ExprRef> {
158        Self::try_new(names, values, nullability).map(|v| v.into_expr())
159    }
160
161    pub fn names(&self) -> &FieldNames {
162        &self.names
163    }
164
165    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
166        let idx = self
167            .names
168            .iter()
169            .position(|name| name == field_name)
170            .ok_or_else(|| {
171                vortex_err!(
172                    "Cannot find field {} in pack fields {:?}",
173                    field_name,
174                    self.names
175                )
176            })?;
177
178        self.values
179            .get(idx)
180            .cloned()
181            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
182    }
183
184    pub fn nullability(&self) -> Nullability {
185        self.nullability
186    }
187}
188
189/// Creates an expression that packs values into a struct with named fields.
190///
191/// ```rust
192/// # use vortex_dtype::Nullability;
193/// # use vortex_expr::{pack, col, lit};
194/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
195/// ```
196pub fn pack(
197    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
198    nullability: Nullability,
199) -> ExprRef {
200    let (names, values): (Vec<_>, Vec<_>) = elements
201        .into_iter()
202        .map(|(name, value)| (name.into(), value))
203        .unzip();
204    PackExpr::try_new(names.into(), values, nullability)
205        .vortex_expect("pack names and values have the same length")
206        .into_expr()
207}
208
209impl DisplayAs for PackExpr {
210    fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
211        match df {
212            DisplayFormat::Compact => {
213                write!(
214                    f,
215                    "pack({}){}",
216                    self.names
217                        .iter()
218                        .zip(&self.values)
219                        .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}"))),
220                    self.nullability
221                )
222            }
223            DisplayFormat::Tree => {
224                write!(f, "Pack")
225            }
226        }
227    }
228
229    fn child_names(&self) -> Option<Vec<String>> {
230        Some(self.names.iter().map(|n| n.to_string()).collect())
231    }
232}
233
234impl AnalysisExpr for PackExpr {}
235
#[cfg(test)]
mod tests {

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{IntoExpr, PackExpr, Scope, col, pack};

    /// Struct array with two i32 columns: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct arrays and returns the leaf as
    /// a primitive array, propagating any conversion or lookup error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        array.to_primitive()
    }

    #[test]
    fn test_empty_pack() {
        let expr =
            PackExpr::try_new(FieldNames::default(), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    #[test]
    fn test_simple_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    fn test_nested_pack() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![
                col("a"),
                PackExpr::try_new(
                    ["two_one", "two_two"].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap()
                .into_expr(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    fn test_pack_nullable() {
        let expr = PackExpr::try_new(
            ["one", "two", "three"].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }

    #[test]
    fn test_try_new_length_mismatch() {
        let result = PackExpr::try_new(
            ["one", "two"].into(),
            vec![col("a")],
            Nullability::NonNullable,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_field_lookup() {
        let expr = PackExpr::try_new(
            ["one", "two"].into(),
            vec![col("a"), col("b")],
            Nullability::Nullable,
        )
        .unwrap();

        assert_eq!(expr.names().len(), 2);
        assert_eq!(expr.nullability(), Nullability::Nullable);
        assert_eq!(expr.field(&"one".into()).unwrap().to_string(), "$.a");
        assert_eq!(expr.field(&"two".into()).unwrap().to_string(), "$.b");
        assert!(expr.field(&"missing".into()).is_err());
    }

    #[test]
    fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = PackExpr::try_new(
            ["x", "y"].into(),
            vec![col("a"), col("b")],
            Nullability::Nullable,
        )
        .unwrap();
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }
}