vortex_expr/
pack.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{ExprRef, VortexExpr};
14
/// Pack zero or more expressions into a structure with named fields.
///
/// Each value expression is evaluated against the same input batch and the
/// results become the fields of a struct array, in the order given.
///
/// # Examples
///
/// ```
/// use vortex_array::{IntoArray, ToCanonical};
/// use vortex_buffer::buffer;
/// use vortex_expr::{Pack, Identity, VortexExpr};
/// use vortex_scalar::Scalar;
/// use vortex_dtype::Nullability;
///
/// let example = Pack::try_new_expr(
///     ["x".into(), "x copy".into(), "second x copy".into()].into(),
///     vec![Identity::new_expr(), Identity::new_expr(), Identity::new_expr()],
///     Nullability::NonNullable,
/// ).unwrap();
/// let packed = example.evaluate(&buffer![100, 110, 200].into_array()).unwrap();
/// let x_copy = packed
///     .to_struct()
///     .unwrap()
///     .field_by_name("x copy")
///     .unwrap()
///     .clone();
/// assert_eq!(x_copy.scalar_at(0).unwrap(), Scalar::from(100));
/// assert_eq!(x_copy.scalar_at(1).unwrap(), Scalar::from(110));
/// assert_eq!(x_copy.scalar_at(2).unwrap(), Scalar::from(200));
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pack {
    // Names of the output struct fields; `names[i]` labels `values[i]`.
    names: FieldNames,
    // One expression per field; `try_new_expr` rejects a length that differs from `names`.
    values: Vec<ExprRef>,
    // Nullability of the produced struct (used for both its dtype and its validity).
    nullability: Nullability,
}
49
50impl Pack {
51    pub fn try_new_expr(
52        names: FieldNames,
53        values: Vec<ExprRef>,
54        nullability: Nullability,
55    ) -> VortexResult<Arc<Self>> {
56        if names.len() != values.len() {
57            vortex_bail!("length mismatch {} {}", names.len(), values.len());
58        }
59        Ok(Arc::new(Pack {
60            names,
61            values,
62            nullability,
63        }))
64    }
65
66    pub fn names(&self) -> &FieldNames {
67        &self.names
68    }
69
70    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71        let idx = self
72            .names
73            .iter()
74            .position(|name| name == field_name)
75            .ok_or_else(|| {
76                vortex_err!(
77                    "Cannot find field {} in pack fields {:?}",
78                    field_name,
79                    self.names
80                )
81            })?;
82
83        self.values
84            .get(idx)
85            .cloned()
86            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87    }
88}
89
90pub fn pack(
91    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
92    nullability: Nullability,
93) -> ExprRef {
94    let (names, values): (Vec<_>, Vec<_>) = elements
95        .into_iter()
96        .map(|(name, value)| (name.into(), value))
97        .unzip();
98    Pack::try_new_expr(names.into(), values, nullability)
99        .vortex_expect("pack names and values have the same length")
100}
101
102impl Display for Pack {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.write_str("{")?;
105        self.names
106            .iter()
107            .zip(&self.values)
108            .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
109            .fmt(f)?;
110        f.write_str("}")
111    }
112}
113
/// Protobuf (de)serialization hooks for [`Pack`](crate::Pack).
///
/// Neither direction is implemented yet; both report a `NotImplemented` error
/// instead of panicking or emitting an error with a blank function name.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Marker type identifying the `pack` expression for (de)serialization.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Deserialization of `pack` expressions is not implemented.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error, so callers that hit a
        /// serialized pack expression get a recoverable error rather than a panic.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serialization of `pack` expressions is not implemented.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error naming `serialize_kind`.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
145
146impl VortexExpr for Pack {
147    fn as_any(&self) -> &dyn Any {
148        self
149    }
150
151    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
152        let len = batch.len();
153        let value_arrays = self
154            .values
155            .iter()
156            .map(|value_expr| value_expr.evaluate(batch))
157            .process_results(|it| it.collect::<Vec<_>>())?;
158        let validity = match self.nullability {
159            Nullability::NonNullable => Validity::NonNullable,
160            Nullability::Nullable => Validity::AllValid,
161        };
162        Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
163    }
164
165    fn children(&self) -> Vec<&ExprRef> {
166        self.values.iter().collect()
167    }
168
169    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
170        assert_eq!(children.len(), self.values.len());
171        Self::try_new_expr(self.names.clone(), children, self.nullability)
172            .vortex_expect("children are known to have the same length as names")
173    }
174
175    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
176        let value_dtypes = self
177            .values
178            .iter()
179            .map(|value_expr| value_expr.return_dtype(scope_dtype))
180            .process_results(|it| it.collect())?;
181        Ok(DType::Struct(
182            Arc::new(StructDType::new(self.names.clone(), value_dtypes)),
183            self.nullability,
184        ))
185    }
186}
187
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use crate::{Pack, VortexExpr, col};

    /// Fixture: a struct array with columns `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Walks `field_path` through nested structs and returns the primitive
    /// array found at the end; an empty path is an error.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct()?.field_by_name(first)?.clone();
        for name in rest {
            current = current.to_struct()?.field_by_name(name)?.clone();
        }
        Ok(current.to_primitive().unwrap())
    }

    /// The field names shared by most tests below.
    fn one_two_three() -> FieldNames {
        ["one".into(), "two".into(), "three".into()].into()
    }

    /// Evaluates `expr` against the fixture and canonicalizes the result to a struct.
    fn evaluate_on_test_array(expr: &Pack) -> StructArray {
        expr.evaluate(test_array().as_ref())
            .unwrap()
            .to_struct()
            .unwrap()
    }

    /// Asserts that the `i32` primitive array at `field_path` equals `expected`.
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected
        );
    }

    #[test]
    pub fn test_empty_pack() {
        let input = test_array().into_array();
        let empty = Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();

        let output = empty.evaluate(&input).unwrap();

        // An empty pack still yields one (field-less) row per input row.
        assert_eq!(output.len(), input.len());
        let field_count = output.to_struct().unwrap().struct_dtype().nfields();
        assert_eq!(field_count, 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            one_two_three(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        let expected_names: FieldNames = one_two_three();
        assert_eq!(packed.names(), &expected_names);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack::try_new_expr(
            ["two_one".into(), "two_two".into()].into(),
            vec![col("b"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();
        let expr = Pack::try_new_expr(
            one_two_three(),
            vec![col("a"), inner, col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        let expected_names: FieldNames = one_two_three();
        assert_eq!(packed.names(), &expected_names);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            one_two_three(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let packed = evaluate_on_test_array(&expr);

        let expected_names: FieldNames = one_two_three();
        assert_eq!(packed.names(), &expected_names);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }
}
345}