Skip to main content

vortex_array/scalar_fn/fns/
pack.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7
8use itertools::Itertools as _;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::StructArray;
18use crate::dtype::DType;
19use crate::dtype::FieldName;
20use crate::dtype::FieldNames;
21use crate::dtype::Nullability;
22use crate::dtype::StructFields;
23use crate::expr::Expression;
24use crate::expr::lit;
25use crate::scalar_fn::Arity;
26use crate::scalar_fn::ChildName;
27use crate::scalar_fn::ExecutionArgs;
28use crate::scalar_fn::ScalarFnId;
29use crate::scalar_fn::ScalarFnVTable;
30use crate::validity::Validity;
31
/// Scalar function that packs zero or more child expressions into a single
/// struct value with named fields.
///
/// The field names and the nullability of the resulting struct are carried by
/// [`PackOptions`]; this type itself holds no state.
#[derive(Clone)]
pub struct Pack;
35
36#[derive(Debug, Clone, PartialEq, Eq, Hash)]
37pub struct PackOptions {
38    pub names: FieldNames,
39    pub nullability: Nullability,
40}
41
42impl Display for PackOptions {
43    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
44        write!(
45            f,
46            "names: [{}], nullability: {:#}",
47            self.names.iter().join(", "),
48            self.nullability
49        )
50    }
51}
52
53impl ScalarFnVTable for Pack {
54    type Options = PackOptions;
55
56    fn id(&self) -> ScalarFnId {
57        ScalarFnId::new_ref("vortex.pack")
58    }
59
60    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
61        Ok(Some(
62            pb::PackOpts {
63                paths: instance.names.iter().map(|n| n.to_string()).collect(),
64                nullable: instance.nullability.into(),
65            }
66            .encode_to_vec(),
67        ))
68    }
69
70    fn deserialize(
71        &self,
72        _metadata: &[u8],
73        _session: &VortexSession,
74    ) -> VortexResult<Self::Options> {
75        let opts = pb::PackOpts::decode(_metadata)?;
76        let names: FieldNames = opts
77            .paths
78            .iter()
79            .map(|name| FieldName::from(name.as_str()))
80            .collect();
81        Ok(PackOptions {
82            names,
83            nullability: opts.nullable.into(),
84        })
85    }
86
87    fn arity(&self, options: &Self::Options) -> Arity {
88        Arity::Exact(options.names.len())
89    }
90
91    fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
92        match instance.names.get(child_idx) {
93            Some(name) => ChildName::from(name.inner().clone()),
94            None => unreachable!(
95                "Invalid child index {} for Pack expression with {} fields",
96                child_idx,
97                instance.names.len()
98            ),
99        }
100    }
101
102    fn fmt_sql(
103        &self,
104        options: &Self::Options,
105        expr: &Expression,
106        f: &mut Formatter<'_>,
107    ) -> std::fmt::Result {
108        write!(f, "pack(")?;
109        for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
110            write!(f, "{}: ", name)?;
111            child.fmt_sql(f)?;
112            if i + 1 < options.names.len() {
113                write!(f, ", ")?;
114            }
115        }
116        write!(f, "){}", options.nullability)
117    }
118
119    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
120        Ok(DType::Struct(
121            StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
122            options.nullability,
123        ))
124    }
125
126    fn validity(
127        &self,
128        _options: &Self::Options,
129        _expression: &Expression,
130    ) -> VortexResult<Option<Expression>> {
131        Ok(Some(lit(true)))
132    }
133
134    fn execute(
135        &self,
136        options: &Self::Options,
137        args: &dyn ExecutionArgs,
138        ctx: &mut ExecutionCtx,
139    ) -> VortexResult<ArrayRef> {
140        let len = args.row_count();
141        let value_arrays: Vec<ArrayRef> = (0..args.num_inputs())
142            .map(|i| args.get(i))
143            .collect::<VortexResult<_>>()?;
144        let validity: Validity = options.nullability.into();
145        StructArray::try_new(options.names.clone(), value_arrays, len, validity)?
146            .into_array()
147            .execute(ctx)
148    }
149
150    // This applies a nullability
151    fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
152        true
153    }
154
155    fn is_fallible(&self, _instance: &Self::Options) -> bool {
156        false
157    }
158}
159
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::pack;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Builds the shared fixture: a struct array with fields `a = [0, 1, 2]`
    /// and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let fields = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested struct arrays starting at `array` and
    /// returns the field at the end of the path as a primitive array.
    ///
    /// Fails with "empty field path" when no path segment is given, or with
    /// the lookup error when a segment cannot be resolved.
    fn primitive_field(array: &ArrayRef, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(head)?.clone();
        for segment in tail {
            current = current.to_struct().unmasked_field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Packing no children yields a struct with no fields but the input's length.
    #[test]
    pub fn test_empty_pack() {
        let options = PackOptions {
            names: Default::default(),
            nullability: Default::default(),
        };
        let expr = Pack.new_expr(options, []);

        let input = test_array();
        let packed = input.clone().apply(&expr).unwrap();
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    /// Packing columns under new names exposes each column under its new name.
    #[test]
    pub fn test_simple_pack() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        let packed_ref = packed.to_array();
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
    }

    /// A pack used as the child of another pack produces a nested struct field.
    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let outer = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), inner, col("a")],
        );

        let packed = test_array().apply(&outer).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        let packed_ref = packed.to_array();
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["one"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["two", "two_one"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["two", "two_two"]).unwrap(),
            PrimitiveArray::from_iter([4i32, 5, 6])
        );
        assert_arrays_eq!(
            primitive_field(&packed_ref, &["three"]).unwrap(),
            PrimitiveArray::from_iter([0i32, 1, 2])
        );
    }

    /// A nullable pack yields a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::Nullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    /// The string rendering lists `name: child` pairs; a nullable pack gets a
    /// trailing `?`.
    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable_options = PackOptions {
            names: ["x", "y"].into(),
            nullability: Nullability::Nullable,
        };
        let nullable = Pack.new_expr(nullable_options, [col("a"), col("b")]);
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}