Skip to main content

vortex_array/scalar_fn/fns/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7
8use itertools::Itertools as _;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::StructArray;
17use crate::dtype::DType;
18use crate::dtype::FieldName;
19use crate::dtype::FieldNames;
20use crate::dtype::Nullability;
21use crate::dtype::StructFields;
22use crate::expr::Expression;
23use crate::expr::lit;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ScalarFnId;
28use crate::scalar_fn::ScalarFnVTable;
29use crate::validity::Validity;
30
/// Pack zero or more expressions into a structure with named fields.
///
/// The field names and the nullability of the resulting struct are supplied
/// through [`PackOptions`], which is the options type of this function's
/// `ScalarFnVTable` implementation.
#[derive(Clone)]
pub struct Pack;
34
/// Options for the [`Pack`] scalar function.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PackOptions {
    /// Names of the fields of the packed struct; `Pack` reports an exact arity
    /// equal to the number of names, so there is one child expression per name.
    pub names: FieldNames,
    /// Nullability given to the resulting struct dtype.
    pub nullability: Nullability,
}
40
41impl Display for PackOptions {
42    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
43        write!(
44            f,
45            "names: [{}], nullability: {:#}",
46            self.names.iter().join(", "),
47            self.nullability
48        )
49    }
50}
51
52impl ScalarFnVTable for Pack {
53    type Options = PackOptions;
54
55    fn id(&self) -> ScalarFnId {
56        ScalarFnId::new_ref("vortex.pack")
57    }
58
59    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60        Ok(Some(
61            pb::PackOpts {
62                paths: instance.names.iter().map(|n| n.to_string()).collect(),
63                nullable: instance.nullability.into(),
64            }
65            .encode_to_vec(),
66        ))
67    }
68
69    fn deserialize(
70        &self,
71        _metadata: &[u8],
72        _session: &VortexSession,
73    ) -> VortexResult<Self::Options> {
74        let opts = pb::PackOpts::decode(_metadata)?;
75        let names: FieldNames = opts
76            .paths
77            .iter()
78            .map(|name| FieldName::from(name.as_str()))
79            .collect();
80        Ok(PackOptions {
81            names,
82            nullability: opts.nullable.into(),
83        })
84    }
85
86    fn arity(&self, options: &Self::Options) -> Arity {
87        Arity::Exact(options.names.len())
88    }
89
90    fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
91        match instance.names.get(child_idx) {
92            Some(name) => ChildName::from(name.inner().clone()),
93            None => unreachable!(
94                "Invalid child index {} for Pack expression with {} fields",
95                child_idx,
96                instance.names.len()
97            ),
98        }
99    }
100
101    fn fmt_sql(
102        &self,
103        options: &Self::Options,
104        expr: &Expression,
105        f: &mut Formatter<'_>,
106    ) -> std::fmt::Result {
107        write!(f, "pack(")?;
108        for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
109            write!(f, "{}: ", name)?;
110            child.fmt_sql(f)?;
111            if i + 1 < options.names.len() {
112                write!(f, ", ")?;
113            }
114        }
115        write!(f, "){}", options.nullability)
116    }
117
118    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
119        Ok(DType::Struct(
120            StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
121            options.nullability,
122        ))
123    }
124
125    fn validity(
126        &self,
127        _options: &Self::Options,
128        _expression: &Expression,
129    ) -> VortexResult<Option<Expression>> {
130        Ok(Some(lit(true)))
131    }
132
133    fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
134        let len = args.row_count;
135        let value_arrays = args.inputs;
136        let validity: Validity = options.nullability.into();
137        StructArray::try_new(options.names.clone(), value_arrays, len, validity)?
138            .into_array()
139            .execute(args.ctx)
140    }
141
142    // This applies a nullability
143    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
144        true
145    }
146
147    fn is_fallible(&self, _instance: &Self::Options) -> bool {
148        false
149    }
150}
151
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::pack;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Builds the shared fixture: a struct array with columns `a = [0, 1, 2]`
    /// and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap().into_array()
    }

    /// Walks `field_path` through nested struct arrays, starting at `array`, and
    /// returns the array found at the end of the path as a primitive array.
    ///
    /// Fails with "empty field path" when `field_path` has no segments.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(head)?.clone();
        for segment in tail {
            current = current.to_struct().unmasked_field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    // Packing zero children yields a struct with no fields but the input's length.
    #[test]
    pub fn test_empty_pack() {
        let options = PackOptions {
            names: Default::default(),
            nullability: Default::default(),
        };
        let expr = Pack.new_expr(options, []);

        let input = test_array();
        let packed = input.clone().apply(&expr).unwrap();
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    // Packing columns under new names keeps their values and is non-nullable.
    #[test]
    pub fn test_simple_pack() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        let field = |path: &[&str]| primitive_field(packed.as_ref(), path).unwrap();

        assert_arrays_eq!(field(&["one"]), PrimitiveArray::from_iter([0i32, 1, 2]));
        assert_arrays_eq!(field(&["two"]), PrimitiveArray::from_iter([4i32, 5, 6]));
        assert_arrays_eq!(field(&["three"]), PrimitiveArray::from_iter([0i32, 1, 2]));
    }

    // A pack used as a child of another pack produces a nested struct field.
    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let outer_options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(outer_options, [col("a"), inner, col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        let field = |path: &[&str]| primitive_field(packed.as_ref(), path).unwrap();

        assert_arrays_eq!(field(&["one"]), PrimitiveArray::from_iter([0i32, 1, 2]));
        assert_arrays_eq!(
            field(&["two", "two_one"]),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            field(&["two", "two_two"]),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(field(&["three"]), PrimitiveArray::from_iter([0i32, 1, 2]));
    }

    // A nullable pack produces a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::Nullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    // The string form lists `name: child` pairs; a nullable pack ends in `?`.
    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable_options = PackOptions {
            names: ["x", "y"].into(),
            nullability: Nullability::Nullable,
        };
        let nullable = Pack.new_expr(nullable_options, [col("a"), col("b")]);
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}