Skip to main content

vortex_array/scalar_fn/fns/pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_proto::expr as pb;
13use vortex_session::VortexSession;
14use vortex_session::registry::CachedId;
15
16use crate::ArrayRef;
17use crate::ExecutionCtx;
18use crate::IntoArray;
19use crate::arrays::StructArray;
20use crate::dtype::DType;
21use crate::dtype::FieldName;
22use crate::dtype::FieldNames;
23use crate::dtype::Nullability;
24use crate::dtype::StructFields;
25use crate::expr::Expression;
26use crate::expr::lit;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ScalarFnId;
31use crate::scalar_fn::ScalarFnVTable;
32use crate::validity::Validity;
33
/// Scalar function that packs zero or more child expressions into a single struct value,
/// producing one named field per child.
///
/// Field names and the nullability of the resulting struct are supplied via [`PackOptions`].
#[derive(Clone)]
pub struct Pack;
37
38#[derive(Debug, Clone, PartialEq, Eq, Hash)]
39pub struct PackOptions {
40    pub names: FieldNames,
41    pub nullability: Nullability,
42}
43
44impl Display for PackOptions {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        write!(
47            f,
48            "names: [{}], nullability: {:#}",
49            self.names.iter().join(", "),
50            self.nullability
51        )
52    }
53}
54
55impl ScalarFnVTable for Pack {
56    type Options = PackOptions;
57
58    fn id(&self) -> ScalarFnId {
59        static ID: CachedId = CachedId::new("vortex.pack");
60        *ID
61    }
62
63    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
64        Ok(Some(
65            pb::PackOpts {
66                paths: instance.names.iter().map(|n| n.to_string()).collect(),
67                nullable: instance.nullability.into(),
68            }
69            .encode_to_vec(),
70        ))
71    }
72
73    fn deserialize(
74        &self,
75        _metadata: &[u8],
76        _session: &VortexSession,
77    ) -> VortexResult<Self::Options> {
78        let opts = pb::PackOpts::decode(_metadata)?;
79        let names: FieldNames = opts
80            .paths
81            .iter()
82            .map(|name| FieldName::from(name.as_str()))
83            .collect();
84        Ok(PackOptions {
85            names,
86            nullability: opts.nullable.into(),
87        })
88    }
89
90    fn arity(&self, options: &Self::Options) -> Arity {
91        Arity::Exact(options.names.len())
92    }
93
94    fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
95        match instance.names.get(child_idx) {
96            Some(name) => ChildName::from(Arc::clone(name.inner())),
97            None => unreachable!(
98                "Invalid child index {} for Pack expression with {} fields",
99                child_idx,
100                instance.names.len()
101            ),
102        }
103    }
104
105    fn fmt_sql(
106        &self,
107        options: &Self::Options,
108        expr: &Expression,
109        f: &mut Formatter<'_>,
110    ) -> std::fmt::Result {
111        write!(f, "pack(")?;
112        for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
113            write!(f, "{}: ", name)?;
114            child.fmt_sql(f)?;
115            if i + 1 < options.names.len() {
116                write!(f, ", ")?;
117            }
118        }
119        write!(f, "){}", options.nullability)
120    }
121
122    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
123        Ok(DType::Struct(
124            StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
125            options.nullability,
126        ))
127    }
128
129    fn validity(
130        &self,
131        _options: &Self::Options,
132        _expression: &Expression,
133    ) -> VortexResult<Option<Expression>> {
134        Ok(Some(lit(true)))
135    }
136
137    fn execute(
138        &self,
139        options: &Self::Options,
140        args: &dyn ExecutionArgs,
141        ctx: &mut ExecutionCtx,
142    ) -> VortexResult<ArrayRef> {
143        let len = args.row_count();
144        let value_arrays: Vec<ArrayRef> = (0..args.num_inputs())
145            .map(|i| args.get(i))
146            .collect::<VortexResult<_>>()?;
147        let validity: Validity = options.nullability.into();
148        StructArray::try_new(options.names.clone(), value_arrays, len, validity)?
149            .into_array()
150            .execute(ctx)
151    }
152
153    // This applies a nullability
154    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
155        true
156    }
157
158    fn is_fallible(&self, _instance: &Self::Options) -> bool {
159        false
160    }
161}
162
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::pack;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;

    /// Three-row input scope `{a: [0, 1, 2], b: [4, 5, 6]}` shared by the tests below.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct fields of `array` and returns the final field
    /// canonicalized as a primitive array.
    ///
    /// Bails on an empty path and propagates any error from the per-segment field lookup.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        if field_path.is_empty() {
            vortex_bail!("empty field path");
        }

        let mut current = array.clone();
        for field in field_path {
            #[expect(deprecated)]
            let next = current.to_struct().unmasked_field_by_name(field)?.clone();
            current = next;
        }

        #[expect(deprecated)]
        let result = current.to_primitive();
        Ok(result)
    }

    #[test]
    fn test_empty_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Default::default(),
            },
            [],
        );

        let input = test_array();
        let actual_array = input.clone().apply(&expr).unwrap();
        // A zero-field pack still yields one (empty) struct row per input row.
        assert_eq!(actual_array.len(), input.len());
        #[expect(deprecated)]
        let nfields = actual_array.to_struct().struct_fields().nfields();
        assert_eq!(nfields, 0);
    }

    #[test]
    fn test_simple_pack() {
        let mut ctx = array_session().create_execution_ctx();
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), col("b"), col("a")],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert!(matches!(actual_array.validity(), Ok(Validity::NonNullable)));

        let packed = actual_array.into_array();
        assert_arrays_eq!(
            primitive_field(&packed, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2]),
            &mut ctx
        );
    }

    #[test]
    fn test_nested_pack() {
        let mut ctx = array_session().create_execution_ctx();
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [
                col("a"),
                Pack.new_expr(
                    PackOptions {
                        names: ["two_one", "two_two"].into(),
                        nullability: Nullability::NonNullable,
                    },
                    [col("b"), col("b")],
                ),
                col("a"),
            ],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        let packed = actual_array.into_array();
        assert_arrays_eq!(
            primitive_field(&packed, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two", "two_one"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two", "two_two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6]),
            &mut ctx
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2]),
            &mut ctx
        );
    }

    #[test]
    fn test_pack_nullable() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b"), col("a")],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        // A nullable pack is still valid in every row.
        assert!(matches!(actual_array.validity(), Ok(Validity::AllValid)));
    }

    #[test]
    fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }

    #[test]
    fn test_display_zero_and_single_field() {
        // No fields: no separator, just the parentheses.
        let empty = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Nullability::NonNullable,
            },
            [],
        );
        assert_eq!(empty.to_string(), "pack()");

        // One field: no trailing separator.
        let single = pack([("x", col("a"))], Nullability::NonNullable);
        assert_eq!(single.to_string(), "pack(x: $.a)");
    }
}