1use std::fmt::Formatter;
5use std::hash::Hash;
6
7use itertools::Itertools as _;
8use prost::Message;
9use vortex_dtype::{DType, FieldName, FieldNames, Nullability, StructFields};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12
13use crate::arrays::StructArray;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, VTable, VTableExt};
15use crate::validity::Validity;
16use crate::{ArrayRef, IntoArray};
17
/// Expression kind (`vortex.pack`) that evaluates each child expression and packs the
/// results into a struct array, one named field per child (see `PackOptions`).
pub struct Pack;

/// Per-instance options of a [`Pack`] expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PackOptions {
    /// Field names of the packed struct, in order; `validate` requires exactly one
    /// child expression per name.
    pub names: FieldNames,
    /// Nullability of the resulting struct dtype. When `Nullable`, `evaluate` produces a
    /// struct whose validity is `AllValid`.
    pub nullability: Nullability,
}
26
27impl VTable for Pack {
28 type Instance = PackOptions;
29
30 fn id(&self) -> ExprId {
31 ExprId::new_ref("vortex.pack")
32 }
33
34 fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
35 Ok(Some(
36 pb::PackOpts {
37 paths: instance.names.iter().map(|n| n.to_string()).collect(),
38 nullable: instance.nullability.into(),
39 }
40 .encode_to_vec(),
41 ))
42 }
43
44 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
45 let opts = pb::PackOpts::decode(metadata)?;
46 let names: FieldNames = opts
47 .paths
48 .iter()
49 .map(|name| FieldName::from(name.as_str()))
50 .collect();
51 Ok(Some(PackOptions {
52 names,
53 nullability: opts.nullable.into(),
54 }))
55 }
56
57 fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
58 let instance = expr.data();
59 if expr.children().len() != instance.names.len() {
60 vortex_bail!(
61 "Pack expression expects {} children, got {}",
62 instance.names.len(),
63 expr.children().len()
64 );
65 }
66 Ok(())
67 }
68
69 fn child_name(&self, instance: &Self::Instance, child_idx: usize) -> ChildName {
70 match instance.names.get(child_idx) {
71 Some(name) => ChildName::from(name.inner().clone()),
72 None => unreachable!(
73 "Invalid child index {} for Pack expression with {} fields",
74 child_idx,
75 instance.names.len()
76 ),
77 }
78 }
79
80 fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
81 write!(f, "pack(")?;
82 for (i, (name, child)) in expr
83 .data()
84 .names
85 .iter()
86 .zip(expr.children().iter())
87 .enumerate()
88 {
89 write!(f, "{}: ", name)?;
90 child.fmt_sql(f)?;
91 if i + 1 < expr.data().names.len() {
92 write!(f, ", ")?;
93 }
94 }
95 write!(f, "){}", expr.data().nullability)
96 }
97
98 fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
99 let value_dtypes = expr
100 .children()
101 .iter()
102 .map(|child| child.return_dtype(scope))
103 .collect::<VortexResult<Vec<_>>>()?;
104 Ok(DType::Struct(
105 StructFields::new(expr.data().names.clone(), value_dtypes),
106 expr.data().nullability,
107 ))
108 }
109
110 fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
111 let len = scope.len();
112 let value_arrays = expr
113 .children()
114 .iter()
115 .zip_eq(expr.data().names.iter())
116 .map(|(child_expr, name)| {
117 child_expr
118 .evaluate(scope)
119 .map_err(|e| e.with_context(format!("Can't evaluate '{name}'")))
120 })
121 .process_results(|it| it.collect::<Vec<_>>())?;
122 let validity = match expr.data().nullability {
123 Nullability::NonNullable => Validity::NonNullable,
124 Nullability::Nullable => Validity::AllValid,
125 };
126 Ok(
127 StructArray::try_new(expr.data().names.clone(), value_arrays, len, validity)?
128 .into_array(),
129 )
130 }
131}
132
133impl ExpressionView<'_, Pack> {
134 pub fn field(&self, field_name: &FieldName) -> VortexResult<Expression> {
135 let idx = self
136 .data()
137 .names
138 .iter()
139 .position(|name| name == field_name)
140 .ok_or_else(|| {
141 vortex_err!(
142 "Cannot find field {} in pack fields {:?}",
143 field_name,
144 self.data().names
145 )
146 })?;
147
148 Ok(self.child(idx).clone())
149 }
150}
151
152pub fn pack(
160 elements: impl IntoIterator<Item = (impl Into<FieldName>, Expression)>,
161 nullability: Nullability,
162) -> Expression {
163 let (names, values): (Vec<_>, Vec<_>) = elements
164 .into_iter()
165 .map(|(name, value)| (name.into(), value))
166 .unzip();
167 Pack.new_expr(
168 PackOptions {
169 names: names.into(),
170 nullability,
171 },
172 values,
173 )
174}
175
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_error::{VortexResult, vortex_bail};

    use super::{Pack, PackOptions, pack};
    use crate::arrays::{PrimitiveArray, StructArray};
    use crate::expr::exprs::get_item::col;
    use crate::expr::{VTable, VTableExt};
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;
    use crate::{Array, ArrayRef, IntoArray, ToCanonical};

    /// Two-field struct `{a: [0, 1, 2], b: [4, 5, 6]}` used as the evaluation scope.
    fn test_array() -> ArrayRef {
        StructArray::from_fields(&[
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array()
    }

    /// Follows `field_path` through nested struct arrays and returns the leaf as a
    /// primitive array.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let mut field_path = field_path.iter();

        let Some(field) = field_path.next() else {
            vortex_bail!("empty field path");
        };

        let mut array = array.to_struct().field_by_name(field)?.clone();
        for field in field_path {
            array = array.to_struct().field_by_name(field)?.clone();
        }
        Ok(array.to_primitive())
    }

    #[test]
    pub fn test_empty_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: Default::default(),
                nullability: Default::default(),
            },
            [],
        );

        let test_array = test_array();
        let actual_array = expr.evaluate(&test_array.clone()).unwrap();
        assert_eq!(actual_array.len(), test_array.len());
        assert_eq!(actual_array.to_struct().struct_fields().nfields(), 0);
    }

    #[test]
    pub fn test_simple_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::NonNullable);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_nested_pack() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [
                col("a"),
                Pack.new_expr(
                    PackOptions {
                        names: ["two_one", "two_two"].into(),
                        nullability: Nullability::NonNullable,
                    },
                    [col("b"), col("b")],
                ),
                col("a"),
            ],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);

        assert_eq!(
            primitive_field(actual_array.as_ref(), &["one"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_one"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["two", "two_two"])
                .unwrap()
                .as_slice::<i32>(),
            [4, 5, 6]
        );
        assert_eq!(
            primitive_field(actual_array.as_ref(), &["three"])
                .unwrap()
                .as_slice::<i32>(),
            [0, 1, 2]
        );
    }

    #[test]
    pub fn test_pack_nullable() {
        let expr = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b"), col("a")],
        );

        let actual_array = expr.evaluate(&test_array()).unwrap().to_struct();

        assert_eq!(actual_array.names(), ["one", "two", "three"]);
        assert_eq!(actual_array.validity(), &Validity::AllValid);
    }

    #[test]
    pub fn test_display() {
        let expr = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(expr.to_string(), "pack(id: $.user_id, name: $.username)");

        let expr2 = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(expr2.to_string(), "pack(x: $.a, y: $.b)?");
    }

    /// `serialize` followed by `deserialize` must reproduce the original options. This
    /// is checked for both nullabilities and for an empty field list.
    #[test]
    pub fn test_serde_roundtrip() {
        let cases = [
            PackOptions {
                names: ["one", "two"].into(),
                nullability: Nullability::NonNullable,
            },
            PackOptions {
                names: ["x", "y", "z"].into(),
                nullability: Nullability::Nullable,
            },
            PackOptions {
                names: Default::default(),
                nullability: Nullability::NonNullable,
            },
        ];
        for options in cases {
            let bytes = <Pack as VTable>::serialize(&Pack, &options)
                .unwrap()
                .expect("pack options always serialize to Some");
            let decoded = <Pack as VTable>::deserialize(&Pack, &bytes)
                .unwrap()
                .expect("pack options always deserialize to Some");
            assert_eq!(decoded, options);
        }
    }

    /// The dtype predicted by `return_dtype` must match the dtype of the array that
    /// `evaluate` actually produces.
    #[test]
    pub fn test_return_dtype_matches_evaluate() {
        let scope = test_array();
        for nullability in [Nullability::NonNullable, Nullability::Nullable] {
            let expr = pack([("one", col("a")), ("two", col("b"))], nullability);
            let predicted = expr.return_dtype(scope.dtype()).unwrap();
            let evaluated = expr.evaluate(&scope).unwrap();
            assert_eq!(&predicted, evaluated.dtype());
        }
    }
}