Skip to main content

vortex_array/scalar_fn/fns/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7use std::sync::Arc;
8
9use itertools::Itertools as _;
10use prost::Message;
11use vortex_error::VortexResult;
12use vortex_proto::expr as pb;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::arrays::StructArray;
19use crate::dtype::DType;
20use crate::dtype::FieldName;
21use crate::dtype::FieldNames;
22use crate::dtype::Nullability;
23use crate::dtype::StructFields;
24use crate::expr::Expression;
25use crate::expr::lit;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::ExecutionArgs;
29use crate::scalar_fn::ScalarFnId;
30use crate::scalar_fn::ScalarFnVTable;
31use crate::validity::Validity;
32
/// Scalar function that packs zero or more child expressions into a struct with named fields.
///
/// The field names and the nullability of the resulting struct are carried by [`PackOptions`].
#[derive(Clone)]
pub struct Pack;
36
37#[derive(Debug, Clone, PartialEq, Eq, Hash)]
38pub struct PackOptions {
39    pub names: FieldNames,
40    pub nullability: Nullability,
41}
42
43impl Display for PackOptions {
44    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
45        write!(
46            f,
47            "names: [{}], nullability: {:#}",
48            self.names.iter().join(", "),
49            self.nullability
50        )
51    }
52}
53
54impl ScalarFnVTable for Pack {
55    type Options = PackOptions;
56
57    fn id(&self) -> ScalarFnId {
58        ScalarFnId::from("vortex.pack")
59    }
60
61    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
62        Ok(Some(
63            pb::PackOpts {
64                paths: instance.names.iter().map(|n| n.to_string()).collect(),
65                nullable: instance.nullability.into(),
66            }
67            .encode_to_vec(),
68        ))
69    }
70
71    fn deserialize(
72        &self,
73        _metadata: &[u8],
74        _session: &VortexSession,
75    ) -> VortexResult<Self::Options> {
76        let opts = pb::PackOpts::decode(_metadata)?;
77        let names: FieldNames = opts
78            .paths
79            .iter()
80            .map(|name| FieldName::from(name.as_str()))
81            .collect();
82        Ok(PackOptions {
83            names,
84            nullability: opts.nullable.into(),
85        })
86    }
87
88    fn arity(&self, options: &Self::Options) -> Arity {
89        Arity::Exact(options.names.len())
90    }
91
92    fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
93        match instance.names.get(child_idx) {
94            Some(name) => ChildName::from(Arc::clone(name.inner())),
95            None => unreachable!(
96                "Invalid child index {} for Pack expression with {} fields",
97                child_idx,
98                instance.names.len()
99            ),
100        }
101    }
102
103    fn fmt_sql(
104        &self,
105        options: &Self::Options,
106        expr: &Expression,
107        f: &mut Formatter<'_>,
108    ) -> std::fmt::Result {
109        write!(f, "pack(")?;
110        for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
111            write!(f, "{}: ", name)?;
112            child.fmt_sql(f)?;
113            if i + 1 < options.names.len() {
114                write!(f, ", ")?;
115            }
116        }
117        write!(f, "){}", options.nullability)
118    }
119
120    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
121        Ok(DType::Struct(
122            StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
123            options.nullability,
124        ))
125    }
126
127    fn validity(
128        &self,
129        _options: &Self::Options,
130        _expression: &Expression,
131    ) -> VortexResult<Option<Expression>> {
132        Ok(Some(lit(true)))
133    }
134
135    fn execute(
136        &self,
137        options: &Self::Options,
138        args: &dyn ExecutionArgs,
139        ctx: &mut ExecutionCtx,
140    ) -> VortexResult<ArrayRef> {
141        let len = args.row_count();
142        let value_arrays: Vec<ArrayRef> = (0..args.num_inputs())
143            .map(|i| args.get(i))
144            .collect::<VortexResult<_>>()?;
145        let validity: Validity = options.nullability.into();
146        StructArray::try_new(options.names.clone(), value_arrays, len, validity)?
147            .into_array()
148            .execute(ctx)
149    }
150
151    // This applies a nullability
152    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
153        true
154    }
155
156    fn is_fallible(&self, _instance: &Self::Options) -> bool {
157        false
158    }
159}
160
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use crate::ArrayRef;
    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::arrays::struct_::StructArrayExt;
    use crate::assert_arrays_eq;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::pack;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;

    /// Builds a three-row struct array with integer fields `a` = [0, 1, 2] and `b` = [4, 5, 6].
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Walks `field_path` through nested struct arrays starting at `array` and returns the
    /// field reached at the end of the path as a primitive array.
    ///
    /// Fails with "empty field path" when `field_path` has no elements, or when a field along
    /// the path cannot be looked up by name.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        if field_path.is_empty() {
            vortex_bail!("empty field path");
        }

        // Every path element is handled the same way: view the current array as a struct and
        // step into the named field.
        let mut current = array.clone();
        for field in field_path {
            #[expect(deprecated)]
            let next = current.to_struct().unmasked_field_by_name(field)?.clone();
            current = next;
        }
        #[expect(deprecated)]
        let result = current.to_primitive();
        Ok(result)
    }

    /// Packing zero expressions yields a struct with no fields but the input's row count.
    #[test]
    pub fn test_empty_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Default::default(),
            },
            [],
        );

        let test_array = test_array();
        let actual_array = test_array.clone().apply(&expr).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        #[expect(deprecated)]
        let nfields = actual_array.to_struct().struct_fields().nfields();
        assert_eq!(nfields, 0);
    }

    /// Packing columns under new names produces a non-nullable struct with those fields.
    #[test]
    pub fn test_simple_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), col("b"), col("a")],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert!(matches!(actual_array.validity(), Ok(Validity::NonNullable)));

        // Convert once; the field lookups below only need a shared reference.
        let packed = actual_array.into_array();
        assert_arrays_eq!(
            primitive_field(&packed, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
    }

    /// A pack expression used as a child of another pack produces a nested struct field.
    #[test]
    pub fn test_nested_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [
                col("a"),
                Pack.new_expr(
                    PackOptions {
                        names: ["two_one", "two_two"].into(),
                        nullability: Nullability::NonNullable,
                    },
                    [col("b"), col("b")],
                ),
                col("a"),
            ],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        // Convert once; the field lookups below only need a shared reference.
        let packed = actual_array.into_array();
        assert_arrays_eq!(
            primitive_field(&packed, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two", "two_one"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["two", "two_two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
    }

    /// A nullable pack produces a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b"), col("a")],
        );

        #[expect(deprecated)]
        let actual_array = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert!(matches!(actual_array.validity(), Ok(Validity::AllValid)));
    }

    /// The textual form lists `name: child` pairs and appends `?` for a nullable pack.
    #[test]
    pub fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }
}