// vortex_array/scalar_fn/fns/pack.rs
use std::fmt::Display;
5use std::fmt::Formatter;
6use std::hash::Hash;
7
8use itertools::Itertools as _;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_proto::expr as pb;
12use vortex_session::VortexSession;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::StructArray;
17use crate::dtype::DType;
18use crate::dtype::FieldName;
19use crate::dtype::FieldNames;
20use crate::dtype::Nullability;
21use crate::dtype::StructFields;
22use crate::expr::Expression;
23use crate::expr::lit;
24use crate::scalar_fn::Arity;
25use crate::scalar_fn::ChildName;
26use crate::scalar_fn::ExecutionArgs;
27use crate::scalar_fn::ScalarFnId;
28use crate::scalar_fn::ScalarFnVTable;
29use crate::validity::Validity;
30
/// Scalar function that packs its child expressions into a single struct array with
/// named fields.
#[derive(Clone)]
pub struct Pack;
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct PackOptions {
37 pub names: FieldNames,
38 pub nullability: Nullability,
39}
40
41impl Display for PackOptions {
42 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
43 write!(
44 f,
45 "names: [{}], nullability: {:#}",
46 self.names.iter().join(", "),
47 self.nullability
48 )
49 }
50}
51
52impl ScalarFnVTable for Pack {
53 type Options = PackOptions;
54
55 fn id(&self) -> ScalarFnId {
56 ScalarFnId::new_ref("vortex.pack")
57 }
58
59 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
60 Ok(Some(
61 pb::PackOpts {
62 paths: instance.names.iter().map(|n| n.to_string()).collect(),
63 nullable: instance.nullability.into(),
64 }
65 .encode_to_vec(),
66 ))
67 }
68
69 fn deserialize(
70 &self,
71 _metadata: &[u8],
72 _session: &VortexSession,
73 ) -> VortexResult<Self::Options> {
74 let opts = pb::PackOpts::decode(_metadata)?;
75 let names: FieldNames = opts
76 .paths
77 .iter()
78 .map(|name| FieldName::from(name.as_str()))
79 .collect();
80 Ok(PackOptions {
81 names,
82 nullability: opts.nullable.into(),
83 })
84 }
85
86 fn arity(&self, options: &Self::Options) -> Arity {
87 Arity::Exact(options.names.len())
88 }
89
90 fn child_name(&self, instance: &Self::Options, child_idx: usize) -> ChildName {
91 match instance.names.get(child_idx) {
92 Some(name) => ChildName::from(name.inner().clone()),
93 None => unreachable!(
94 "Invalid child index {} for Pack expression with {} fields",
95 child_idx,
96 instance.names.len()
97 ),
98 }
99 }
100
101 fn fmt_sql(
102 &self,
103 options: &Self::Options,
104 expr: &Expression,
105 f: &mut Formatter<'_>,
106 ) -> std::fmt::Result {
107 write!(f, "pack(")?;
108 for (i, (name, child)) in options.names.iter().zip(expr.children().iter()).enumerate() {
109 write!(f, "{}: ", name)?;
110 child.fmt_sql(f)?;
111 if i + 1 < options.names.len() {
112 write!(f, ", ")?;
113 }
114 }
115 write!(f, "){}", options.nullability)
116 }
117
118 fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
119 Ok(DType::Struct(
120 StructFields::new(options.names.clone(), arg_dtypes.to_vec()),
121 options.nullability,
122 ))
123 }
124
125 fn validity(
126 &self,
127 _options: &Self::Options,
128 _expression: &Expression,
129 ) -> VortexResult<Option<Expression>> {
130 Ok(Some(lit(true)))
131 }
132
133 fn execute(&self, options: &Self::Options, args: ExecutionArgs) -> VortexResult<ArrayRef> {
134 let len = args.row_count;
135 let value_arrays = args.inputs;
136 let validity: Validity = options.nullability.into();
137 StructArray::try_new(options.names.clone(), value_arrays, len, validity)?
138 .into_array()
139 .execute(args.ctx)
140 }
141
142 fn is_null_sensitive(&self, _instance: &Self::Options) -> bool {
144 true
145 }
146
147 fn is_fallible(&self, _instance: &Self::Options) -> bool {
148 false
149 }
150}
151
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use super::Pack;
    use super::PackOptions;
    use crate::Array;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::Nullability;
    use crate::expr::col;
    use crate::expr::pack;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Fixture: a struct array with fields `a = [0, 1, 2]` and `b = [4, 5, 6]`.
    fn test_array() -> ArrayRef {
        let fields = [
            ("a", buffer![0, 1, 2].into_array()),
            ("b", buffer![4, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap().into_array()
    }

    /// Walks `field_path` through nested structs starting at `array` and returns the
    /// field reached at the end of the path as a primitive array.
    ///
    /// Fails if `field_path` is empty or if any segment cannot be looked up by name.
    fn primitive_field(array: &dyn Array, field_path: &[&str]) -> VortexResult<PrimitiveArray> {
        let Some((first, rest)) = field_path.split_first() else {
            vortex_bail!("empty field path");
        };

        let mut current = array.to_struct().unmasked_field_by_name(first)?.clone();
        for segment in rest {
            current = current.to_struct().unmasked_field_by_name(segment)?.clone();
        }
        Ok(current.to_primitive())
    }

    /// Packing zero children yields a field-less struct with the input's length.
    #[test]
    pub fn test_empty_pack() {
        let options = PackOptions {
            names: Default::default(),
            nullability: Default::default(),
        };
        let expr = Pack.new_expr(options, []);

        let input = test_array();
        let packed = input.clone().apply(&expr).unwrap();

        assert_eq!(packed.len(), input.len());
        assert_eq!(packed.to_struct().struct_fields().nfields(), 0);
    }

    /// Packing columns under new names produces a non-nullable struct with those fields.
    #[test]
    pub fn test_simple_pack() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::NonNullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::NonNullable);

        let expectations = [
            ("one", [0i32, 1, 2]),
            ("two", [4i32, 5, 6]),
            ("three", [0i32, 1, 2]),
        ];
        for (field, values) in expectations {
            assert_arrays_eq!(
                primitive_field(packed.as_ref(), &[field]).unwrap(),
                PrimitiveArray::from_iter(values)
            );
        }
    }

    /// A pack expression may itself be a child of another pack expression.
    #[test]
    pub fn test_nested_pack() {
        let inner = Pack.new_expr(
            PackOptions {
                names: ["two_one", "two_two"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("b"), col("b")],
        );
        let outer = Pack.new_expr(
            PackOptions {
                names: ["one", "two", "three"].into(),
                nullability: Nullability::NonNullable,
            },
            [col("a"), inner, col("a")],
        );

        let packed = test_array().apply(&outer).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);

        let expectations: [(&[&str], [i32; 3]); 4] = [
            (&["one"], [0, 1, 2]),
            (&["two", "two_one"], [4, 5, 6]),
            (&["two", "two_two"], [4, 5, 6]),
            (&["three"], [0, 1, 2]),
        ];
        for (path, values) in expectations {
            assert_arrays_eq!(
                primitive_field(packed.as_ref(), path).unwrap(),
                PrimitiveArray::from_iter(values)
            );
        }
    }

    /// A nullable pack produces a struct whose validity is `AllValid`.
    #[test]
    pub fn test_pack_nullable() {
        let options = PackOptions {
            names: ["one", "two", "three"].into(),
            nullability: Nullability::Nullable,
        };
        let expr = Pack.new_expr(options, [col("a"), col("b"), col("a")]);

        let packed = test_array().apply(&expr).unwrap().to_struct();

        assert_eq!(packed.names(), ["one", "two", "three"]);
        assert_eq!(packed.validity(), &Validity::AllValid);
    }

    /// The textual form lists `name: child` pairs and appends `?` for nullable packs.
    #[test]
    pub fn test_display() {
        let non_nullable = pack(
            [("id", col("user_id")), ("name", col("username"))],
            Nullability::NonNullable,
        );
        assert_eq!(
            non_nullable.to_string(),
            "pack(id: $.user_id, name: $.username)"
        );

        let nullable = Pack.new_expr(
            PackOptions {
                names: ["x", "y"].into(),
                nullability: Nullability::Nullable,
            },
            [col("a"), col("b")],
        );
        assert_eq!(nullable.to_string(), "pack(x: $.a, y: $.b)?");
    }
}