1use std::any::Any;
2use std::fmt::Display;
3use std::hash::Hash;
4use std::sync::Arc;
5
6use itertools::Itertools as _;
7use vortex_array::arrays::StructArray;
8use vortex_array::validity::Validity;
9use vortex_array::{Array, ArrayRef};
10use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructDType};
11use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_err};
12
13use crate::{ExprRef, VortexExpr};
14
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
42pub struct Pack {
43 names: FieldNames,
44 values: Vec<ExprRef>,
45}
46
47impl Pack {
48 pub fn try_new_expr(names: FieldNames, values: Vec<ExprRef>) -> VortexResult<Arc<Self>> {
49 if names.len() != values.len() {
50 vortex_bail!("length mismatch {} {}", names.len(), values.len());
51 }
52 Ok(Arc::new(Pack { names, values }))
53 }
54
55 pub fn names(&self) -> &FieldNames {
56 &self.names
57 }
58
59 pub fn field(&self, field_name: &FieldName) -> VortexResult<ExprRef> {
60 let idx = self
61 .names
62 .iter()
63 .position(|name| name == field_name)
64 .ok_or_else(|| {
65 vortex_err!(
66 "Cannot find field {} in pack fields {:?}",
67 field_name,
68 self.names
69 )
70 })?;
71
72 self.values
73 .get(idx)
74 .cloned()
75 .ok_or_else(|| vortex_err!("field index out of bounds: {}", idx))
76 }
77}
78
79pub fn pack(elements: impl IntoIterator<Item = (impl Into<FieldName>, ExprRef)>) -> ExprRef {
80 let (names, values): (Vec<_>, Vec<_>) = elements
81 .into_iter()
82 .map(|(name, value)| (name.into(), value))
83 .unzip();
84 Pack::try_new_expr(names.into(), values)
85 .vortex_expect("pack names and values have the same length")
86}
87
88impl Display for Pack {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.write_str("{")?;
91 self.names
92 .iter()
93 .zip(&self.values)
94 .format_with(", ", |(name, expr), f| f(&format_args!("{name}: {expr}")))
95 .fmt(f)?;
96 f.write_str("}")
97 }
98}
99
/// Serialization hooks for [`Pack`] expressions, compiled only with the `proto` feature.
///
/// Neither direction is implemented yet: serializing and deserializing both return a
/// `NotImplemented` error rather than panicking.
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id, Pack};

    /// Marker type carrying the [`Id`] and [`ExprDeserialize`] implementations for [`Pack`].
    pub struct PackSerde;

    impl Id for PackSerde {
        /// The identifier shared by the serialize and deserialize sides of `Pack`.
        fn id(&self) -> &'static str {
            "pack"
        }
    }

    impl ExprDeserialize for PackSerde {
        /// Deserialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error. An error is returned instead of panicking
        /// so that callers decoding serialized expressions can handle the failure.
        fn deserialize(&self, _kind: &Kind, _children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            vortex_bail!(NotImplemented: "deserialize", self.id())
        }
    }

    impl ExprSerializable for Pack {
        fn id(&self) -> &'static str {
            PackSerde.id()
        }

        /// Serialization of `Pack` is not supported yet.
        ///
        /// # Errors
        ///
        /// Always returns a `NotImplemented` error naming this function and the `pack` id.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            vortex_bail!(NotImplemented: "serialize_kind", self.id())
        }
    }
}
131
132impl VortexExpr for Pack {
133 fn as_any(&self) -> &dyn Any {
134 self
135 }
136
137 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
138 let len = batch.len();
139 let value_arrays = self
140 .values
141 .iter()
142 .map(|value_expr| value_expr.evaluate(batch))
143 .process_results(|it| it.collect::<Vec<_>>())?;
144 Ok(
145 StructArray::try_new(self.names.clone(), value_arrays, len, Validity::NonNullable)?
146 .into_array(),
147 )
148 }
149
150 fn children(&self) -> Vec<&ExprRef> {
151 self.values.iter().collect()
152 }
153
154 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
155 assert_eq!(children.len(), self.values.len());
156 Self::try_new_expr(self.names.clone(), children)
157 .vortex_expect("children are known to have the same length as names")
158 }
159
160 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
161 let value_dtypes = self
162 .values
163 .iter()
164 .map(|value_expr| value_expr.return_dtype(scope_dtype))
165 .process_results(|it| it.collect())?;
166 Ok(DType::Struct(
167 Arc::new(StructDType::new(self.names.clone(), value_dtypes)),
168 Nullability::NonNullable,
169 ))
170 }
171}
172
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_array::arrays::{PrimitiveArray, StructArray};
    use vortex_array::{Array, IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::FieldNames;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};

    use crate::{Pack, VortexExpr, col};

    /// Two-column struct fixture: `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
    }

    /// Walks `field_path` through nested structs starting at `array` and returns the final
    /// field as a primitive array.
    ///
    /// Fails when the path is empty, when a non-final step is not a struct, when a field name
    /// is missing, or when the final field cannot be converted to a primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array
            .as_struct_typed()
            .ok_or_else(|| vortex_err!("expected a struct"))?
            .maybe_null_field_by_name(first)?;

        for name in rest {
            current = current
                .as_struct_typed()
                .ok_or_else(|| vortex_err!("expected a struct"))?
                .maybe_null_field_by_name(name)?;
        }

        // Propagate conversion failures to the caller instead of panicking inside the helper.
        current.to_primitive()
    }

    /// Asserts that the `i32` field reached via `field_path` holds exactly `expected`.
    #[track_caller]
    fn assert_i32_field(array: &dyn Array, field_path: &[&str], expected: &[i32]) {
        assert_eq!(
            primitive_field(array, field_path)
                .unwrap()
                .as_slice::<i32>(),
            expected,
            "unexpected values at field path {field_path:?}"
        );
    }

    #[test]
    fn test_empty_pack() {
        let expr = Pack::try_new_expr(Arc::new([]), Vec::new()).unwrap();

        let test_array = test_array().into_array();
        let actual_array = expr.evaluate(&test_array).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.as_struct_typed().unwrap().nfields(), 0);
    }

    #[test]
    fn test_simple_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![col("a"), col("b"), col("a")],
        )
        .unwrap();

        let actual_array = expr.evaluate(&test_array()).unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &expected_names
        );

        assert_i32_field(&actual_array, &["one"], &[0, 1, 2]);
        assert_i32_field(&actual_array, &["two"], &[4, 5, 6]);
        assert_i32_field(&actual_array, &["three"], &[0, 1, 2]);
    }

    #[test]
    fn test_nested_pack() {
        let expr = Pack::try_new_expr(
            ["one".into(), "two".into(), "three".into()].into(),
            vec![
                col("a"),
                Pack::try_new_expr(
                    ["two_one".into(), "two_two".into()].into(),
                    vec![col("b"), col("b")],
                )
                .unwrap(),
                col("a"),
            ],
        )
        .unwrap();

        let actual_array = expr.evaluate(&test_array()).unwrap();
        let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into();
        assert_eq!(
            actual_array.as_struct_typed().unwrap().names(),
            &expected_names
        );

        assert_i32_field(&actual_array, &["one"], &[0, 1, 2]);
        assert_i32_field(&actual_array, &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(&actual_array, &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(&actual_array, &["three"], &[0, 1, 2]);
    }

    #[test]
    fn test_mismatched_lengths_rejected() {
        let names: FieldNames = ["one".into(), "two".into()].into();
        assert!(Pack::try_new_expr(names, vec![col("a")]).is_err());
    }

    #[test]
    fn test_field_lookup() {
        let expr = Pack::try_new_expr(["one".into()].into(), vec![col("a")]).unwrap();
        assert!(expr.field(&"one".into()).is_ok());
        assert!(expr.field(&"missing".into()).is_err());
    }
}