vortex_expr/
pack.rs

1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{ExprRef, VortexExpr};
14
15/// Pack zero or more expressions into a structure with named fields.
16///
17/// # Examples
18///
19/// ```
20/// use vortex_array::IntoArray;
21/// use vortex_array::compute::scalar_at;
22/// use vortex_buffer::buffer;
23/// use vortex_expr::{Pack, Identity, VortexExpr};
24/// use vortex_scalar::Scalar;
25///
26/// let example = Pack::try_new_expr(
27///     ["x".into(), "x copy".into(), "second x copy".into()].into(),
28///     vec![Identity::new_expr(), Identity::new_expr(), Identity::new_expr()],
29/// ).unwrap();
30/// let packed = example.evaluate(&buffer![100, 110, 200].into_array()).unwrap();
31/// let x_copy = packed
32///     .as_struct_typed()
33///     .unwrap()
34///     .maybe_null_field_by_name("x copy")
35///     .unwrap();
36/// assert_eq!(scalar_at(&x_copy, 0).unwrap(), Scalar::from(100));
37/// assert_eq!(scalar_at(&x_copy, 1).unwrap(), Scalar::from(110));
38/// assert_eq!(scalar_at(&x_copy, 2).unwrap(), Scalar::from(200));
39/// ```
40///
41#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct Pack {
43    names: FieldNames,
44    values: Vec<ExprRef>,
45}
46
47impl Pack {
48    pub fn try_new_expr(names: FieldNames, values: Vec<ExprRef>) -> VortexResult<Arc<Self>> {
49        if names.len() != values.len() {
50            vortex_bail!("length mismatch {} {}", names.len(), values.len());
51        }
52        Ok(Arc::new(Pack { names, values }))
53    }
54
55    pub fn names(&self) -> &FieldNames {
56        &self.names
57    }
58
59    pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
60        let idx = self
61            .names
62            .iter()
63            .position(|name| name == field_name)
64            .ok_or_else(|| {
65                vortex_err!(
66                    "Cannot find field {} in pack fields {:?}",
67                    field_name,
68                    self.names
69                )
70            })?;
71
72        self.values
73            .get(idx)
74            .cloned()
75            .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
76    }
77}
78
79pub fn pack(elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>) -> ExprRef {
80    let (names, values): (Vec<_>, Vec<_>) = elements
81        .into_iter()
82        .map(|(name, value)| (name.into(), value))
83        .unzip();
84    Pack::try_new_expr(names.into(), values)
85        .vortex_expect("pack names and values have the same length")
86}
87
88impl Display for Pack {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.write_str("{")?;
91        self.names
92            .iter()
93            .zip(&self.values)
94            .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
95            .fmt(f)?;
96        f.write_str("}")
97    }
98}
99
/// Serialization hooks for [`Pack`], compiled only with the `proto` feature.
///
/// Neither direction is implemented yet: both entry points return a
/// `NotImplemented` error that names the missing function and the expression id.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Deserializer handle for [`Pack`]; its id is `"pack"`.
    pub struct PackSerde;

    impl Id for PackSerde {
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Not implemented yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error. The signature is fallible, so an
        /// error is reported instead of panicking on serialized input.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Pack {
        /// Same id as [`PackSerde`].
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Not implemented yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error naming `serialize_kind`.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
131
132impl VortexExpr for Pack {
133    fn as_any(&self) -> &dyn Any {
134        self
135    }
136
137    fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
138        let len = batch.len();
139        let value_arrays = self
140            .values
141            .iter()
142            .map(|value_expr| value_expr.evaluate(batch))
143            .process_results(|it| it.collect::<Vec<_>>())?;
144        Ok(
145            StructArray::try_new(self.names.clone(), value_arrays, len, Validity::NonNullable)?
146                .into_array(),
147        )
148    }
149
150    fn children(&self) -> Vec<&ExprRef> {
151        self.values.iter().collect()
152    }
153
154    fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
155        assert_eq!(children.len(), self.values.len());
156        Self::try_new_expr(self.names.clone(), children)
157            .vortex_expect("children are known to have the same length as names")
158    }
159
160    fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
161        let value_dtypes = self
162            .values
163            .iter()
164            .map(|value_expr| value_expr.return_dtype(scope_dtype))
165            .process_results(|it| it.collect())?;
166        Ok(DType::Struct(
167            Arc::new(StructDType::new(self.names.clone(), value_dtypes)),
168            Nullability::NonNullable,
169        ))
170    }
171}
172
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{Pack, VortexExpr, col};

    /// Fixture: a struct with columns `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// The top-level field names shared by the non-empty pack tests.
    fn outer_names() -> FieldNames {
        ["one".into(), "two".into(), "three".into()].into()
    }

    /// Walks `field_path` through nested structs starting at `array` and returns the
    /// final field as a primitive array.
    ///
    /// Errors on an empty path or when an intermediate value is not a struct.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };
        let not_a_struct = || vortex_err!("expected a struct");

        let mut current = array
            .as_struct_typed()
            .ok_or_else(not_a_struct)?
            .maybe_null_field_by_name(first)?;

        for name in rest {
            current = current
                .as_struct_typed()
                .ok_or_else(not_a_struct)?
                .maybe_null_field_by_name(name)?;
        }
        Ok(current.to_primitive().unwrap())
    }

    /// Asserts that the `i32` field reached via `field_path` holds exactly `expected`.
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: [i32; 3]) {
        let column = primitive_field(array, field_path).unwrap();
        assert_eq!(column.as_slice::<i32>(), expected);
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new()).unwrap();

        let input = test_array().into_array();
        let output = expr.evaluate(&input).unwrap();
        assert_eq!(output.len(), input.len());
        assert!(output.as_struct_typed().unwrap().nfields() == 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack::try_new_expr(outer_names(), vec![col("a"), col("b"), col("a")]).unwrap();

        let output = expr.evaluate(&test_array()).unwrap();
        let expected_names = outer_names();
        assert_eq!(output.as_struct_typed().unwrap().names(), &expected_names);

        assert_i32_field(&output, &["one"], [0, 1, 2]);
        assert_i32_field(&output, &["two"], [4, 5, 6]);
        assert_i32_field(&output, &["three"], [0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack::try_new_expr(
            ["two_one".into(), "two_two".into()].into(),
            vec![col("b"), col("b")],
        )
        .unwrap();
        let expr = Pack::try_new_expr(outer_names(), vec![col("a"), inner, col("a")]).unwrap();

        let output = expr.evaluate(&test_array()).unwrap();
        let expected_names = outer_names();
        assert_eq!(output.as_struct_typed().unwrap().names(), &expected_names);

        assert_i32_field(&output, &["one"], [0, 1, 2]);
        assert_i32_field(&output, &["two", "two_one"], [4, 5, 6]);
        assert_i32_field(&output, &["two", "two_two"], [4, 5, 6]);
        assert_i32_field(&output, &["three"], [0, 1, 2]);
    }
}