vortex_array/expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use prost::Message;
9use vortex_dtype::DType;
10use vortex_dtype::FieldName;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_proto::expr as pb;
18
19use crate::ArrayRef;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::expr::ChildName;
23use crate::expr::ExprId;
24use crate::expr::Expression;
25use crate::expr::ExpressionView;
26use crate::expr::VTable;
27use crate::expr::VTableExt;
28use crate::validity::Validity;
29
/// Expression vtable that packs zero or more child expressions into a struct.
///
/// Each child becomes one field of the output struct, named by the entry at the
/// same position in [`PackOptions::names`].
pub struct Pack;
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash)]
34pub struct PackOptions {
35    pub names: FieldNames,
36    pub nullability: Nullability,
37}
38
39impl VTable for Pack {
40    type Instance = PackOptions;
41
42    fn id(&self) -> ExprId {
43        ExprId::new_ref("vortex.pack")
44    }
45
46    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
47        Ok(Some(
48            pb::PackOpts {
49                paths: instance.names.iter().map(|n| n.to_string()).collect(),
50                nullable: instance.nullability.into(),
51            }
52            .encode_to_vec(),
53        ))
54    }
55
56    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
57        let opts = pb::PackOpts::decode(metadata)?;
58        let names: FieldNames = opts
59            .paths
60            .iter()
61            .map(|name| FieldName::from(name.as_str()))
62            .collect();
63        Ok(Some(PackOptions {
64            names,
65            nullability: opts.nullable.into(),
66        }))
67    }
68
69    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
70        let instance = expr.data();
71        if expr.children().len() != instance.names.len() {
72            vortex_bail!(
73                "Pack expression expects {} children, got {}",
74                instance.names.len(),
75                expr.children().len()
76            );
77        }
78        Ok(())
79    }
80
81    fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName {
82        match instance.names.get(child_idx) {
83            Some(name) => ChildName::from(name.inner().clone()),
84            None => unreachable!(
85                "Invalid child index {} for Pack expression with {} fields",
86                child_idx,
87                instance.names.len()
88            ),
89        }
90    }
91
92    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
93        write!(f, "pack(")?;
94        for (i, (name, child)) in expr
95            .data()
96            .names
97            .iter()
98            .zip(expr.children().iter())
99            .enumerate()
100        {
101            write!(f, "{}: ", name)?;
102            child.fmt_sql(f)?;
103            if i + 1 < expr.data().names.len() {
104                write!(f, ", ")?;
105            }
106        }
107        write!(f, "){}", expr.data().nullability)
108    }
109
110    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
111        let value_dtypes = expr
112            .children()
113            .iter()
114            .map(|child| child.return_dtype(scope))
115            .collect::<VortexResult<Vec<_>>>()?;
116        Ok(DType::Struct(
117            StructFields::new(expr.data().names.clone(), value_dtypes),
118            expr.data().nullability,
119        ))
120    }
121
122    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
123        let len = scope.len();
124        let value_arrays = expr
125            .children()
126            .iter()
127            .zip_eq(expr.data().names.iter())
128            .map(|(child_expr, name)| {
129                child_expr
130                    .evaluate(scope)
131                    .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
132            })
133            .process_results(|it| it.collect::<Vec<_>>())?;
134        let validity = match expr.data().nullability {
135            Nullability::NonNullable => Validity::NonNullable,
136            Nullability::Nullable => Validity::AllValid,
137        };
138        Ok(
139            StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)?
140                .into_array(),
141        )
142    }
143
144    // This applies a nullability
145    fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
146        true
147    }
148
149    fn is_fallible(&self, _instance: &Self::Instance) -> bool {
150        false
151    }
152}
153
154impl ExpressionView<'_, Pack> {
155    pub fn field(&self, field_name: &FieldName) -> VortexResult<Expression> {
156        let idx = self
157            .data()
158            .names
159            .iter()
160            .position(|name| name == field_name)
161            .ok_or_else(|| {
162                vortex_err!(
163                    "Cannot find field {} in pack fields {:?}",
164                    field_name,
165                    self.data().names
166                )
167            })?;
168
169        Ok(self.child(idx).clone())
170    }
171}
172
173/// Creates an expression that packs values into a struct with named fields.
174///
175/// ```rust
176/// # use vortex_dtype::Nullability;
177/// # use vortex_array::expr::{pack, col, lit};
178/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
179/// ```
180pub fn pack(
181    elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
182    nullability: Nullability,
183) -> Expression {
184    let (names, values): (Vec<_>, Vec<_>) = elements
185        .into_iter()
186        .map(|(name, value)| (name.into(), value))
187        .unzip();
188    Pack.new_expr(
189        PackOptions {
190            names: names.into(),
191            nullability,
192        },
193        values,
194    )
195}
196
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use super::pack;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::VTableExt;
    use crate::expr::exprs::get_item::col;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` used as the evaluation scope.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested structs and returns the leaf as a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    #[test]
    fn test_empty_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Default::default(),
            },
            [],
        );

        let test_array = test_array();
        // `evaluate` only borrows its scope, so no clone is needed.
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    fn test_simple_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    fn test_nested_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [
                col("a"),
                Pack.new_expr(
                    PackOptions {
                        names: ["two_one", "two_two"].into(),
                        nullability: Nullability::NonNullable,
                    },
                    [col("b"), col("b")],
                ),
                col("a"),
            ],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    fn test_pack_nullable() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }

    #[test]
    fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }

    #[test]
    fn test_display_empty() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Nullability::NonNullable,
            },
            [],
        );
        assert_eq!(expr.to_string(), "pack()");
    }

    #[test]
    fn test_options_roundtrip() {
        // Both nullabilities must survive the protobuf encoding unchanged.
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let options = PackOptions {
                names: ["one", "two"].into(),
                nullability,
            };
            let bytes = <Pack as crate::expr::VTable>::serialize(&Pack, &options)
                .unwrap()
                .expect("Pack always produces serialized options");
            let decoded = <Pack as crate::expr::VTable>::deserialize(&Pack, &bytes)
                .unwrap()
                .expect("Pack always produces deserialized options");
            assert_eq!(decoded, options);
        }
    }
}