vortex_array/expr/exprs/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use prost::Message;
11use vortex_dtype::DType;
12use vortex_dtype::FieldName;
13use vortex_dtype::FieldNames;
14use vortex_dtype::Nullability;
15use vortex_dtype::StructFields;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_mask::Mask;
19use vortex_proto::expr as pb;
20use vortex_vector::Datum;
21use vortex_vector::ScalarOps;
22use vortex_vector::VectorMutOps;
23use vortex_vector::VectorOps;
24use vortex_vector::struct_::StructVector;
25
26use crate::ArrayRef;
27use crate::IntoArray;
28use crate::arrays::StructArray;
29use crate::expr::Arity;
30use crate::expr::ChildName;
31use crate::expr::ExecutionArgs;
32use crate::expr::ExprId;
33use crate::expr::Expression;
34use crate::expr::VTable;
35use crate::expr::VTableExt;
36use crate::validity::Validity;
37
/// Pack zero or more expressions into a structure with named fields.
///
/// The field names and the nullability of the resulting struct are supplied through
/// [`PackOptions`]; the values of the fields come from the child expressions.
pub struct Pack;
40
/// Options of a [`Pack`] expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PackOptions {
    /// Names of the struct fields; the expression takes exactly one child per name, and the
    /// child at index `i` is named after `names[i]`.
    pub names: FieldNames,
    /// Nullability of the struct produced by the expression.
    pub nullability: Nullability,
}
46
47impl Display for PackOptions {
48    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
49        write!(
50            f,
51            "names: [{}], nullability: {}",
52            self.names.iter().join(", "),
53            self.nullability
54        )
55    }
56}
57
58impl VTable for Pack {
59    type Options = PackOptions;
60
61    fn id(&self) -> ExprId {
62        ExprId::new_ref("vortex.pack")
63    }
64
65    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
66        Ok(Some(
67            pb::PackOpts {
68                paths: instance.names.iter().map(|n| n.to_string()).collect(),
69                nullable: instance.nullability.into(),
70            }
71            .encode_to_vec(),
72        ))
73    }
74
75    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
76        let opts = pb::PackOpts::decode(metadata)?;
77        let names: FieldNames = opts
78            .paths
79            .iter()
80            .map(|name| FieldName::from(name.as_str()))
81            .collect();
82        Ok(PackOptions {
83            names,
84            nullability: opts.nullable.into(),
85        })
86    }
87
88    fn arity(&self, options: &Self::Options) -> Arity {
89        Arity::Exact(options.names.len())
90    }
91
92    fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
93        match instance.names.get(child_idx) {
94            Some(name) => ChildName::from(name.inner().clone()),
95            None => unreachable!(
96                "Invalid child index {} for Pack expression with {} fields",
97                child_idx,
98                instance.names.len()
99            ),
100        }
101    }
102
103    fn fmt_sql(
104        &self,
105        options: &Self::Options,
106        expr: &Expression,
107        f: &mut Formatter<'_>,
108    ) -> std::fmt::Result {
109        write!(f, "pack(")?;
110        for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
111            write!(f, "{}: ", name)?;
112            child.fmt_sql(f)?;
113            if i + 1 < options.names.len() {
114                write!(f, ", ")?;
115            }
116        }
117        write!(f, "){}", options.nullability)
118    }
119
120    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
121        Ok(DType::Struct(
122            StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
123            options.nullability,
124        ))
125    }
126
127    fn evaluate(
128        &self,
129        options: &Self::Options,
130        expr: &Expression,
131        scope: &ArrayRef,
132    ) -> VortexResult<ArrayRef> {
133        let len = scope.len();
134        let value_arrays = expr
135            .children()
136            .iter()
137            .zip_eq(options.names.iter())
138            .map(|(child_expr, name)| {
139                child_expr
140                    .evaluate(scope)
141                    .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
142            })
143            .process_results(|it| it.collect::<Vec<_>>())?;
144        let validity = match options.nullability {
145            Nullability::NonNullable => Validity::NonNullable,
146            Nullability::Nullable => Validity::AllValid,
147        };
148        Ok(StructArray::try_new(options.names.clone(), value_arrays, len, validity)?.into_array())
149    }
150
151    fn execute(&self, _options: &Self::Options, args: ExecutionArgs) -> VortexResult<Datum> {
152        // If any datum is a vector, we must convert them all to vectors.
153        if args.datums.iter().any(|d| matches!(d, Datum::Vector(_))) {
154            let fields: Box<[_]> = args
155                .datums
156                .into_iter()
157                .map(|v| v.unwrap_into_vector(args.row_count))
158                .collect();
159            return Ok(Datum::Vector(
160                StructVector::try_new(Arc::new(fields), Mask::new_true(args.row_count))?.into(),
161            ));
162        }
163
164        // Otherwise, we can produce a scalar datum by constructing a length-1 struct vector.
165        let fields: Box<[_]> = args
166            .datums
167            .into_iter()
168            .map(|d| {
169                d.into_scalar()
170                    .vortex_expect("all scalars")
171                    .repeat(1)
172                    .freeze()
173            })
174            .collect();
175        let vector = StructVector::new(Arc::new(fields), Mask::new_true(1));
176        Ok(Datum::Scalar(vector.scalar_at(0).into()))
177    }
178
179    // This applies a nullability
180    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
181        true
182    }
183
184    fn is_fallible(&self, _instance: &Self::Options) -> bool {
185        false
186    }
187}
188
189/// Creates an expression that packs values into a struct with named fields.
190///
191/// ```rust
192/// # use vortex_dtype::Nullability;
193/// # use vortex_array::expr::{pack, col, lit};
194/// let expr = pack([("id", col("user_id")), ("constant", lit(42))], Nullability::NonNullable);
195/// ```
196pub fn pack(
197    elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
198    nullability: Nullability,
199) -> Expression {
200    let (names, values): (Vec<_>, Vec<_>) = elements
201        .into_iter()
202        .map(|(name, value)| (name.into(), value))
203        .unzip();
204    Pack.new_expr(
205        PackOptions {
206            names: names.into(),
207            nullability,
208        },
209        values,
210    )
211}
212
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use super::pack;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::VTableExt;
    use crate::expr::exprs::get_item::col;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Builds the struct array used as evaluation scope: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs and returns the final field as a
    /// primitive array. Fails if the path is empty or a field is missing.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(first)?.clone();
        for segment in rest {
            current = current.to_struct().field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Asserts that the field reached via `field_path` holds exactly `expected` as `i32`s.
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let options = PackOptions {
            names: Default::default(),
            nullability: Default::default(),
        };
        let expr = Pack.new_expr(options, []);

        let input = test_array();
        let packed = expr.evaluate(&input).unwrap();
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let outer = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), inner, col("a")],
        );

        let packed = outer.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::Nullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        // Non-nullable packs render without a suffix.
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        // Nullable packs render with a trailing `?`.
        let nullable = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}