vortex_expr/
pack.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{ArrayRef, IntoArray};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
14
15/// Pack zero or more expressions into a structure with named fields.
16///
17/// # Examples
18///
19/// ```
20/// use vortex_array::{IntoArray, ToCanonical};
21/// use vortex_buffer::buffer;
22/// use vortex_expr::{root, Pack, Scope, VortexExpr};
23/// use vortex_scalar::Scalar;
24/// use vortex_dtype::Nullability;
25///
26/// let example = Pack::try_new_expr(
27///     ["x".into(), "x copy".into(), "second x copy".into()].into(),
28///     vec![root(), root(), root()],
29///     Nullability::NonNullable,
30/// ).unwrap();
31/// let packed = example.evaluate(&Scope::new(buffer![100, 110, 200].into_array())).unwrap();
32/// let x_copy = packed
33///     .to_struct()
34///     .unwrap()
35///     .field_by_name("x copy")
36///     .unwrap()
37///     .clone();
38/// assert_eq!(x_copy.scalar_at(0).unwrap(), Scalar::from(100));
39/// assert_eq!(x_copy.scalar_at(1).unwrap(), Scalar::from(110));
40/// assert_eq!(x_copy.scalar_at(2).unwrap(), Scalar::from(200));
41/// ```
42///
43#[derive(Debug, Clone, PartialEq, Eq, Hash)]
44pub struct Pack {
45    names: FieldNames,
46    values: Vec<ExprRef>,
47    nullability: Nullability,
48}
49
50impl Pack {
51    pub fn try_new_expr(
52        names: FieldNames,
53        values: Vec<ExprRef>,
54        nullability: Nullability,
55    ) -> VortexResult<Arc<Self>> {
56        if names.len() != values.len() {
57            vortex_bail!("length mismatch {} {}", names.len(), values.len());
58        }
59        Ok(Arc::new(Pack {
60            names,
61            values,
62            nullability,
63        }))
64    }
65
66    pub fn names(&self) -> &FieldNames {
67        &self.names
68    }
69
70    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
71        let idx = self
72            .names
73            .iter()
74            .position(|name| name == field_name)
75            .ok_or_else(|| {
76                vortex_err!(
77                    "Cannot find field {} in pack fields {:?}",
78                    field_name,
79                    self.names
80                )
81            })?;
82
83        self.values
84            .get(idx)
85            .cloned()
86            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
87    }
88}
89
90pub fn pack(
91    elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>,
92    nullability: Nullability,
93) -> ExprRef {
94    let (names, values): (Vec<_>, Vec<_>) = elements
95        .into_iter()
96        .map(|(name, value)| (name.into(), value))
97        .unzip();
98    Pack::try_new_expr(names.into(), values, nullability)
99        .vortex_expect("pack names and values have the same length")
100}
101
102impl Display for Pack {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.write_str("{")?;
105        self.names
106            .iter()
107            .zip(&self.values)
108            .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
109            .fmt(f)?;
110        f.write_str("}")
111    }
112}
113
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Serde handle for [`Pack`] expressions, identified by the id `"pack"`.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Deserialization of `Pack` is not supported yet.
        ///
        /// Returns a `NotImplemented` error instead of panicking, so a caller decoding a
        /// serialized expression tree can handle the failure rather than abort.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", Id::id(self))
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serialization of `Pack` is not supported yet; always returns `NotImplemented`.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
145
// `Pack` overrides none of the `AnalysisExpr` methods; it uses the trait's default implementations.
impl AnalysisExpr for Pack {}
147
148impl VortexExpr for Pack {
149    fn as_any(&self) -> &dyn Any {
150        self
151    }
152
153    fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
154        let len = scope.len();
155        let value_arrays = self
156            .values
157            .iter()
158            .map(|value_expr| value_expr.unchecked_evaluate(scope))
159            .process_results(|it| it.collect::<Vec<_>>())?;
160        let validity = match self.nullability {
161            Nullability::NonNullable => Validity::NonNullable,
162            Nullability::Nullable => Validity::AllValid,
163        };
164        Ok(StructArray::try_new(self.names.clone(), value_arrays, len, validity)?.into_array())
165    }
166
167    fn children(&self) -> Vec<&ExprRef> {
168        self.values.iter().collect()
169    }
170
171    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
172        assert_eq!(children.len(), self.values.len());
173        Self::try_new_expr(self.names.clone(), children, self.nullability)
174            .vortex_expect("children are known to have the same length as names")
175    }
176
177    fn return_dtype(&self, scope: &ScopeDType) -> VortexResult<DType> {
178        let value_dtypes = self
179            .values
180            .iter()
181            .map(|value_expr| value_expr.return_dtype(scope))
182            .process_results(|it| it.collect())?;
183        Ok(DType::Struct(
184            Arc::new(StructFields::new(self.names.clone(), value_dtypes)),
185            self.nullability,
186        ))
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};
    use vortex_error::{VortexResult, vortex_bail};

    use super::pack;
    use crate::{Pack, Scope, VortexExpr, col};

    /// A two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` used as the evaluation scope.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested structs and returns the leaf as a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct()?.field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct()?.field_by_name(field)?.clone();
        }
        Ok(array.to_primitive().unwrap())
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();

        let test_array = test_array();
        let actual_array = expr.evaluate(&Scope::new(test_array.clone())).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(
            actual_array.to_struct().unwrap().struct_fields().nfields(),
            0
        );
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_nested_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![
                col("a"),
                Pack::try_new_expr(
                    ["two_one".into(), "two_two".into()].into(),
                    vec![col("b"), col("b")],
                    Nullability::NonNullable,
                )
                .unwrap(),
                col("a"),
            ],
            Nullability::NonNullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
            Nullability::Nullable,
        )
        .unwrap();

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_length_mismatch_is_error() {
        let result = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a")],
            Nullability::NonNullable,
        );
        assert!(result.is_err());
    }

    #[test]
    pub fn test_field_lookup() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();

        assert_eq!(
            expr.field(&"two".into()).unwrap().to_string(),
            col("b").to_string()
        );
        assert_eq!(
            expr.field(&"one".into()).unwrap().to_string(),
            col("a").to_string()
        );
        assert!(expr.field(&"missing".into()).is_err());
    }

    #[test]
    pub fn test_display() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into()].into(),
            vec![col("a"), col("b")],
            Nullability::NonNullable,
        )
        .unwrap();
        assert_eq!(
            expr.to_string(),
            format!("{{one: {}, two: {}}}", col("a"), col("b"))
        );

        let empty =
            Pack::try_new_expr(Arc::new([]), Vec::new(), Nullability::NonNullable).unwrap();
        assert_eq!(empty.to_string(), "{}");
    }

    #[test]
    pub fn test_pack_helper() {
        let expr = pack([("one", col("a")), ("two", col("b"))], Nullability::Nullable);

        let actual_array = expr
            .evaluate(&Scope::new(test_array()))
            .unwrap()
            .to_struct()
            .unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into()].into();
        assert_eq!(actual_array.names(), &expected_names);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
    }
}