vortex_array/scalar_fn/fns/
get_item.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::arrays::StructArray;
15use crate::arrays::struct_::StructArrayExt;
16use crate::builtins::ArrayBuiltins;
17use crate::builtins::ExprBuiltins;
18use crate::dtype::DType;
19use crate::dtype::FieldName;
20use crate::dtype::FieldPath;
21use crate::dtype::Nullability;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::lit;
25use crate::expr::stats::Stat;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::EmptyOptions;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::literal::Literal;
37use crate::scalar_fn::fns::mask::Mask;
38use crate::scalar_fn::fns::pack::Pack;
39
40#[derive(Clone)]
41pub struct GetItem;
42
43impl ScalarFnVTable for GetItem {
44 type Options = FieldName;
45
46 fn id(&self) -> ScalarFnId {
47 ScalarFnId::from("vortex.get_item")
48 }
49
50 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51 Ok(Some(
52 pb::GetItemOpts {
53 path: instance.to_string(),
54 }
55 .encode_to_vec(),
56 ))
57 }
58
59 fn deserialize(
60 &self,
61 _metadata: &[u8],
62 _session: &VortexSession,
63 ) -> VortexResult<Self::Options> {
64 let opts = pb::GetItemOpts::decode(_metadata)?;
65 Ok(FieldName::from(opts.path))
66 }
67
68 fn arity(&self, _field_name: &FieldName) -> Arity {
69 Arity::Exact(1)
70 }
71
72 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
73 match child_idx {
74 0 => ChildName::from("input"),
75 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
76 }
77 }
78
79 fn fmt_sql(
80 &self,
81 field_name: &FieldName,
82 expr: &Expression,
83 f: &mut Formatter<'_>,
84 ) -> std::fmt::Result {
85 expr.children()[0].fmt_sql(f)?;
86 write!(f, ".{}", field_name)
87 }
88
89 fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
90 let struct_dtype = &arg_dtypes[0];
91 let field_dtype = struct_dtype
92 .as_struct_fields_opt()
93 .and_then(|st| st.field(field_name))
94 .ok_or_else(|| {
95 vortex_err!("Couldn't find the {} field in the input scope", field_name)
96 })?;
97
98 if matches!(
100 (struct_dtype.nullability(), field_dtype.nullability()),
101 (Nullability::Nullable, Nullability::NonNullable)
102 ) {
103 return Ok(field_dtype.with_nullability(Nullability::Nullable));
104 }
105
106 Ok(field_dtype)
107 }
108
109 fn execute(
110 &self,
111 field_name: &FieldName,
112 args: &dyn ExecutionArgs,
113 ctx: &mut ExecutionCtx,
114 ) -> VortexResult<ArrayRef> {
115 let input = args.get(0)?.execute::<StructArray>(ctx)?;
116 let field = input.unmasked_field_by_name(field_name).cloned()?;
117
118 match input.dtype().nullability() {
119 Nullability::NonNullable => Ok(field),
120 Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
121 }
122 }
123
124 fn reduce(
125 &self,
126 field_name: &FieldName,
127 node: &dyn ReduceNode,
128 ctx: &dyn ReduceCtx,
129 ) -> VortexResult<Option<ReduceNodeRef>> {
130 let child = node.child(0);
131 if let Some(child_fn) = child.scalar_fn()
132 && let Some(pack) = child_fn.as_opt::<Pack>()
133 && let Some(idx) = pack.names.find(field_name)
134 {
135 let mut field = child.child(idx);
136
137 if pack.nullability.is_nullable() {
139 field = ctx.new_node(
140 Mask.bind(EmptyOptions),
141 &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
142 )?;
143 }
144
145 return Ok(Some(field));
146 }
147
148 Ok(None)
149 }
150
151 fn simplify_untyped(
152 &self,
153 field_name: &FieldName,
154 expr: &Expression,
155 ) -> VortexResult<Option<Expression>> {
156 let child = expr.child(0);
157
158 if let Some(pack) = child.as_opt::<Pack>() {
160 let idx = pack
161 .names
162 .iter()
163 .position(|name| name == field_name)
164 .ok_or_else(|| {
165 vortex_err!(
166 "Cannot find field {} in pack fields {:?}",
167 field_name,
168 pack.names
169 )
170 })?;
171
172 let mut field = child.child(idx).clone();
173
174 if pack.nullability.is_nullable() {
179 field = field.mask(lit(true))?;
181 }
182
183 return Ok(Some(field));
184 }
185
186 Ok(None)
187 }
188
189 fn stat_expression(
190 &self,
191 field_name: &FieldName,
192 _expr: &Expression,
193 stat: Stat,
194 catalog: &dyn StatsCatalog,
195 ) -> Option<Expression> {
196 catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
205 }
206
207 fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
209 true
210 }
211
212 fn is_fallible(&self, _field_name: &FieldName) -> bool {
213 false
215 }
216}
217
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::checked_add;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::validity::Validity;

    /// Builds a non-nullable struct with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = st.into_array().apply(&get_item).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(st.into_array().apply(&get_item).is_err());
    }

    #[test]
    #[ignore = "apply() has a bug with null propagation from struct validity to non-nullable child fields"]
    fn get_nullable_field() {
        let st = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let get_item_expr = get_item("a", root());
        let item = st.apply(&get_item_expr).unwrap();
        assert_eq!(
            item.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_pack_get_item_rule() {
        let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_item_expr = get_item("b", pack_expr);

        let result = get_item_expr
            .optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
            .unwrap();

        assert_eq!(result, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_a = get_item("a", inner_pack);

        let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let get_z = get_item("z", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_z.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        let innermost = pack([("a", lit(42))], NonNullable);
        let get_a = get_item("a", innermost);

        let level2 = pack([("b", get_a)], NonNullable);
        let get_b = get_item("b", level2);

        let level3 = pack([("c", get_b)], NonNullable);
        let get_c = get_item("c", level3);

        let outermost = pack([("final", get_c)], NonNullable);
        let get_final = get_item("final", outermost);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_final.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let get_x = get_item("x", inner_pack);
        let add_expr = checked_add(get_x, lit(10));

        let outer_pack = pack([("result", add_expr)], NonNullable);
        let get_result = get_item("result", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_result.optimize_recursive(&dtype).unwrap();
        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&result, &expected);
    }

    /// Extracting a filtered list column from a struct must succeed and keep the filtered
    /// row count.
    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        let list = ListArray::try_new(
            buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
            buffer![2u64, 4, 6, 8, 10, 12].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
        )
        .unwrap();

        let filtered = FilterArray::try_new(
            list.into_array(),
            Mask::from_iter([true, true, false, false, false]),
        )
        .unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        let result = st.into_array().apply(&get_item("data", root())).unwrap();
        // The filter keeps the first two rows, so the extracted field must have two rows.
        assert_eq!(result.len(), 2);
    }
}