vortex_array/expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use prost::Message;
9use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12
13use crate::arrays::StructArray;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, VTable, VTableExt};
15use crate::validity::Validity;
16use crate::{ArrayRef, IntoArray};
17
/// Pack zero or more expressions into a structure with named fields.
///
/// `Pack` itself carries no state. The field names and the nullability of the
/// resulting struct are stored in [`PackOptions`], and the field values are
/// produced by the expression's children (one child per field name).
pub struct Pack;
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct PackOptions {
23    pub names: FieldNames,
24    pub nullability: Nullability,
25}
26
27impl VTable for Pack {
28    type Instance = PackOptions;
29
30    fn id(&self) -> ExprId {
31        ExprId::new_ref("vortex.pack")
32    }
33
34    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
35        Ok(Some(
36            pb::PackOpts {
37                paths: instance.names.iter().map(|n| n.to_string()).collect(),
38                nullable: instance.nullability.into(),
39            }
40            .encode_to_vec(),
41        ))
42    }
43
44    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
45        let opts = pb::PackOpts::decode(metadata)?;
46        let names: FieldNames = opts
47            .paths
48            .iter()
49            .map(|name| FieldName::from(name.as_str()))
50            .collect();
51        Ok(Some(PackOptions {
52            names,
53            nullability: opts.nullable.into(),
54        }))
55    }
56
57    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
58        let instance = expr.data();
59        if expr.children().len() != instance.names.len() {
60            vortex_bail!(
61                "Pack expression expects {} children, got {}",
62                instance.names.len(),
63                expr.children().len()
64            );
65        }
66        Ok(())
67    }
68
69    fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName {
70        match instance.names.get(child_idx) {
71            Some(name) => ChildName::from(name.inner().clone()),
72            None => unreachable!(
73                "Invalid child index {} for Pack expression with {} fields",
74                child_idx,
75                instance.names.len()
76            ),
77        }
78    }
79
80    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
81        write!(f, "pack(")?;
82        for (i, (name, child)) in expr
83            .data()
84            .names
85            .iter()
86            .zip(expr.children().iter())
87            .enumerate()
88        {
89            write!(f, "{}: ", name)?;
90            child.fmt_sql(f)?;
91            if i + 1 < expr.data().names.len() {
92                write!(f, ", ")?;
93            }
94        }
95        write!(f, "){}", expr.data().nullability)
96    }
97
98    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
99        let value_dtypes = expr
100            .children()
101            .iter()
102            .map(|child| child.return_dtype(scope))
103            .collect::<VortexResult<Vec<_>>>()?;
104        Ok(DType::Struct(
105            StructFields::new(expr.data().names.clone(), value_dtypes),
106            expr.data().nullability,
107        ))
108    }
109
110    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
111        let len = scope.len();
112        let value_arrays = expr
113            .children()
114            .iter()
115            .zip_eq(expr.data().names.iter())
116            .map(|(child_expr, name)| {
117                child_expr
118                    .evaluate(scope)
119                    .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
120            })
121            .process_results(|it| it.collect::<Vec<_>>())?;
122        let validity = match expr.data().nullability {
123            Nullability::NonNullable => Validity::NonNullable,
124            Nullability::Nullable => Validity::AllValid,
125        };
126        Ok(
127            StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)?
128                .into_array(),
129        )
130    }
131}
132
133impl ExpressionView<'_, Pack> {
134    pub fn field(&self, field_name: &FieldName) -> VortexResult<Expression> {
135        let idx = self
136            .data()
137            .names
138            .iter()
139            .position(|name| name == field_name)
140            .ok_or_else(|| {
141                vortex_err!(
142                    "Cannot find field {} in pack fields {:?}",
143                    field_name,
144                    self.data().names
145                )
146            })?;
147
148        Ok(self.child(idx).clone())
149    }
150}
151
152/// Creates an expression that packs values into a struct with named fields.
153///
154/// ```rust
155/// # use vortex_dtype::Nullability;
156/// # use vortex_array::expr::{pack, col, lit};
157/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
158/// ```
159pub fn pack(
160    elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
161    nullability: Nullability,
162) -> Expression {
163    let (names, values): (Vec<_>, Vec<_>) = elements
164        .into_iter()
165        .map(|(name, value)| (name.into(), value))
166        .unzip();
167    Pack.new_expr(
168        PackOptions {
169            names: names.into(),
170            nullability,
171        },
172        values,
173    )
174}
175
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use super::{Pack, PackOptions, pack};
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::VTableExt;
    use crate::expr::exprs::get_item::col;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray, ToCanonical};

    /// Scope used by the tests: a struct array with fields `a = [0, 1, 2]`
    /// and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested struct arrays and returns the final
    /// field as a primitive array. Fails on an empty path or a missing field.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Asserts that the `i32` field reached via `field_path` holds `expected`.
    #[track_caller]
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        let field = primitive_field(array, field_path).unwrap();
        assert_eq!(field.as_slice::<i32>(), expected);
    }

    #[test]
    pub fn test_empty_pack() {
        let options = PackOptions {
            names: Default::default(),
            nullability: Default::default(),
        };
        let expr = Pack.new_expr(options, []);

        let scope = test_array();
        let packed = expr.evaluate(&scope).unwrap();

        // A pack without fields still has one row per row of the scope.
        assert_eq!(packed.len(), scope.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), inner, col("a")],
        );

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::Nullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}