1use std::fmt::Formatter;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use prost::Message;
9use vortex_dtype::DType;
10use vortex_dtype::FieldName;
11use vortex_dtype::FieldNames;
12use vortex_dtype::Nullability;
13use vortex_dtype::StructFields;
14use vortex_error::VortexResult;
15use vortex_error::vortex_bail;
16use vortex_error::vortex_err;
17use vortex_proto::expr as pb;
18
19use crate::ArrayRef;
20use crate::IntoArray;
21use crate::arrays::StructArray;
22use crate::expr::ChildName;
23use crate::expr::ExprId;
24use crate::expr::Expression;
25use crate::expr::ExpressionView;
26use crate::expr::VTable;
27use crate::expr::VTableExt;
28use crate::validity::Validity;
29
/// Expression vtable that packs the results of its child expressions into a single struct
/// value, producing one named field per child (see [`PackOptions`]).
pub struct Pack;
32
/// Options carried by a [`Pack`] expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PackOptions {
    /// Field names of the produced struct, one per child expression and in child order.
    pub names: FieldNames,
    /// Nullability of the produced struct dtype. When `Nullable`, evaluation yields an
    /// all-valid struct; validity is never derived from the children.
    pub nullability: Nullability,
}
38
39impl VTable for Pack {
40 type Instance = PackOptions;
41
42 fn id(&self) -> ExprId {
43 ExprId::new_ref("vortex.pack")
44 }
45
46 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
47 Ok(Some(
48 pb::PackOpts {
49 paths: instance.names.iter().map(|n| n.to_string()).collect(),
50 nullable: instance.nullability.into(),
51 }
52 .encode_to_vec(),
53 ))
54 }
55
56 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
57 let opts = pb::PackOpts::decode(metadata)?;
58 let names: FieldNames = opts
59 .paths
60 .iter()
61 .map(|name| FieldName::from(name.as_str()))
62 .collect();
63 Ok(Some(PackOptions {
64 names,
65 nullability: opts.nullable.into(),
66 }))
67 }
68
69 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
70 let instance = expr.data();
71 if expr.children().len() != instance.names.len() {
72 vortex_bail!(
73 "Pack expression expects {} children, got {}",
74 instance.names.len(),
75 expr.children().len()
76 );
77 }
78 Ok(())
79 }
80
81 fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName {
82 match instance.names.get(child_idx) {
83 Some(name) => ChildName::from(name.inner().clone()),
84 None => unreachable!(
85 "Invalid child index {} for Pack expression with {} fields",
86 child_idx,
87 instance.names.len()
88 ),
89 }
90 }
91
92 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
93 write!(f, "pack(")?;
94 for (i, (name, child)) in expr
95 .data()
96 .names
97 .iter()
98 .zip(expr.children().iter())
99 .enumerate()
100 {
101 write!(f, "{}: ", name)?;
102 child.fmt_sql(f)?;
103 if i + 1 < expr.data().names.len() {
104 write!(f, ", ")?;
105 }
106 }
107 write!(f, "){}", expr.data().nullability)
108 }
109
110 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
111 let value_dtypes = expr
112 .children()
113 .iter()
114 .map(|child| child.return_dtype(scope))
115 .collect::<VortexResult<Vec<_>>>()?;
116 Ok(DType::Struct(
117 StructFields::new(expr.data().names.clone(), value_dtypes),
118 expr.data().nullability,
119 ))
120 }
121
122 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
123 let len = scope.len();
124 let value_arrays = expr
125 .children()
126 .iter()
127 .zip_eq(expr.data().names.iter())
128 .map(|(child_expr, name)| {
129 child_expr
130 .evaluate(scope)
131 .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
132 })
133 .process_results(|it| it.collect::<Vec<_>>())?;
134 let validity = match expr.data().nullability {
135 Nullability::NonNullable => Validity::NonNullable,
136 Nullability::Nullable => Validity::AllValid,
137 };
138 Ok(
139 StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)?
140 .into_array(),
141 )
142 }
143
144 fn is_null_sensitive(&self, _instance: &Self::Instance) -> bool {
146 true
147 }
148
149 fn is_fallible(&self, _instance: &Self::Instance) -> bool {
150 false
151 }
152}
153
154impl ExpressionView<'_, Pack> {
155 pub fn field(&self, field_name: &FieldName) -> VortexResult<Expression> {
156 let idx = self
157 .data()
158 .names
159 .iter()
160 .position(|name| name == field_name)
161 .ok_or_else(|| {
162 vortex_err!(
163 "Cannot find field {} in pack fields {:?}",
164 field_name,
165 self.data().names
166 )
167 })?;
168
169 Ok(self.child(idx).clone())
170 }
171}
172
173pub fn pack(
181 elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
182 nullability: Nullability,
183) -> Expression {
184 let (names, values): (Vec<_>, Vec<_>) = elements
185 .into_iter()
186 .map(|(name, value)| (name.into(), value))
187 .unzip();
188 Pack.new_expr(
189 PackOptions {
190 names: names.into(),
191 nullability,
192 },
193 values,
194 )
195}
196
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use super::pack;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::expr::VTableExt;
    use crate::expr::exprs::get_item::col;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Two-column struct fixture: `a = [0, 1, 2]`, `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct arrays, starting at `array`, and returns
    /// the leaf canonicalized as a primitive array.
    ///
    /// Errors on an empty path or a missing field.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((head, tail)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().field_by_name(head)?.clone();
        for name in tail {
            current = current.to_struct().field_by_name(name)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Asserts that the `i32` field at `path` inside `array` holds exactly `expected`.
    fn assert_i32_field(array: &dyn Array, path: &[&str], expected: &[i32]) {
        let field = primitive_field(array, path).unwrap();
        assert_eq!(field.as_slice::<i32>(), expected);
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Default::default(),
            },
            [],
        );

        let input = test_array();
        // `evaluate` borrows its scope, so no clone of the input is needed.
        let packed = expr.evaluate(&input).unwrap();
        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), inner, col("a")],
        );

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        assert_i32_field(packed.as_ref(), &["one"], &[0, 1, 2]);
        assert_i32_field(packed.as_ref(), &["two", "two_one"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["two", "two_two"], &[4, 5, 6]);
        assert_i32_field(packed.as_ref(), &["three"], &[0, 1, 2]);
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let packed = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }

    /// The protobuf metadata written by `serialize` must decode back to the same options.
    #[test]
    pub fn test_serialize_roundtrip() {
        let options = PackOptions {
            names: ["x", "y"].into(),
            nullability: Nullability::Nullable,
        };

        // Fully qualified calls avoid any method-name ambiguity with `VTableExt`.
        let bytes = <Pack as crate::expr::VTable>::serialize(&Pack, &options)
            .unwrap()
            .expect("Pack always serializes its options");
        let decoded = <Pack as crate::expr::VTable>::deserialize(&Pack, &bytes)
            .unwrap()
            .expect("Pack always deserializes its options");

        assert_eq!(decoded, options);
    }
}