vortex_array/scalar_fn/fns/
get_item.rs1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::arrays::StructArray;
16use crate::arrays::struct_::StructArrayExt;
17use crate::builtins::ArrayBuiltins;
18use crate::builtins::ExprBuiltins;
19use crate::dtype::DType;
20use crate::dtype::FieldName;
21use crate::dtype::FieldPath;
22use crate::dtype::Nullability;
23use crate::expr::Expression;
24use crate::expr::StatsCatalog;
25use crate::expr::lit;
26use crate::expr::stats::Stat;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ReduceCtx;
32use crate::scalar_fn::ReduceNode;
33use crate::scalar_fn::ReduceNodeRef;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::literal::Literal;
38use crate::scalar_fn::fns::mask::Mask;
39use crate::scalar_fn::fns::pack::Pack;
40
/// Scalar function that extracts a single named field from a struct-typed input.
///
/// The field name is carried as the function's options. When the input struct is
/// nullable, the extracted field is reported as nullable as well, because a null parent
/// row implies a null field value.
#[derive(Clone, Debug)]
pub struct GetItem;
43
44impl ScalarFnVTable for GetItem {
45 type Options = FieldName;
46
47 fn id(&self) -> ScalarFnId {
48 static ID: CachedId = CachedId::new("vortex.get_item");
49 *ID
50 }
51
52 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 Ok(Some(
54 pb::GetItemOpts {
55 path: instance.to_string(),
56 }
57 .encode_to_vec(),
58 ))
59 }
60
61 fn deserialize(
62 &self,
63 _metadata: &[u8],
64 _session: &VortexSession,
65 ) -> VortexResult<Self::Options> {
66 let opts = pb::GetItemOpts::decode(_metadata)?;
67 Ok(FieldName::from(opts.path))
68 }
69
70 fn arity(&self, _field_name: &FieldName) -> Arity {
71 Arity::Exact(1)
72 }
73
74 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
75 match child_idx {
76 0 => ChildName::from("input"),
77 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
78 }
79 }
80
81 fn fmt_sql(
82 &self,
83 field_name: &FieldName,
84 expr: &Expression,
85 f: &mut Formatter<'_>,
86 ) -> std::fmt::Result {
87 expr.children()[0].fmt_sql(f)?;
88 write!(f, ".{}", field_name)
89 }
90
91 fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
92 let struct_dtype = &arg_dtypes[0];
93 let field_dtype = struct_dtype
94 .as_struct_fields_opt()
95 .and_then(|st| st.field(field_name))
96 .ok_or_else(|| {
97 vortex_err!("Couldn't find the {} field in the input scope", field_name)
98 })?;
99
100 if matches!(
102 (struct_dtype.nullability(), field_dtype.nullability()),
103 (Nullability::Nullable, Nullability::NonNullable)
104 ) {
105 return Ok(field_dtype.with_nullability(Nullability::Nullable));
106 }
107
108 Ok(field_dtype)
109 }
110
111 fn execute(
112 &self,
113 field_name: &FieldName,
114 args: &dyn ExecutionArgs,
115 ctx: &mut ExecutionCtx,
116 ) -> VortexResult<ArrayRef> {
117 let input = args.get(0)?.execute::<StructArray>(ctx)?;
118 let field = input.unmasked_field_by_name(field_name).cloned()?;
119
120 match input.dtype().nullability() {
121 Nullability::NonNullable => Ok(field),
122 Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
123 }
124 }
125
126 fn reduce(
127 &self,
128 field_name: &FieldName,
129 node: &dyn ReduceNode,
130 ctx: &dyn ReduceCtx,
131 ) -> VortexResult<Option<ReduceNodeRef>> {
132 let child = node.child(0);
133 if let Some(child_fn) = child.scalar_fn()
134 && let Some(pack) = child_fn.as_opt::<Pack>()
135 && let Some(idx) = pack.names.find(field_name)
136 {
137 let mut field = child.child(idx);
138
139 if pack.nullability.is_nullable() {
141 field = ctx.new_node(
142 Mask.bind(EmptyOptions),
143 &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
144 )?;
145 }
146
147 return Ok(Some(field));
148 }
149
150 Ok(None)
151 }
152
153 fn simplify_untyped(
154 &self,
155 field_name: &FieldName,
156 expr: &Expression,
157 ) -> VortexResult<Option<Expression>> {
158 let child = expr.child(0);
159
160 if let Some(pack) = child.as_opt::<Pack>() {
162 let idx = pack
163 .names
164 .iter()
165 .position(|name| name == field_name)
166 .ok_or_else(|| {
167 vortex_err!(
168 "Cannot find field {} in pack fields {:?}",
169 field_name,
170 pack.names
171 )
172 })?;
173
174 let mut field = child.child(idx).clone();
175
176 if pack.nullability.is_nullable() {
181 field = field.mask(lit(true))?;
183 }
184
185 return Ok(Some(field));
186 }
187
188 Ok(None)
189 }
190
191 fn stat_expression(
192 &self,
193 field_name: &FieldName,
194 _expr: &Expression,
195 stat: Stat,
196 catalog: &dyn StatsCatalog,
197 ) -> Option<Expression> {
198 catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
207 }
208
209 fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
211 true
212 }
213
214 fn is_fallible(&self, _field_name: &FieldName) -> bool {
215 false
217 }
218}
219
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::checked_add;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::validity::Validity;

    /// Two-column struct: `a: i32`, `b: i64`, three rows.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = st.into_array().apply(&get_item).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32))
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(st.into_array().apply(&get_item).is_err());
    }

    /// A non-nullable field read from a nullable struct must come back nullable.
    #[test]
    fn get_nullable_field() {
        let st = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let get_item_expr = get_item("a", root());
        let item = st.apply(&get_item_expr).unwrap();
        assert_eq!(
            item.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_pack_get_item_rule() {
        let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_item_expr = get_item("b", pack_expr);

        let result = get_item_expr
            .optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
            .unwrap();

        assert_eq!(result, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_a = get_item("a", inner_pack);

        let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let get_z = get_item("z", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_z.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        let innermost = pack([("a", lit(42))], NonNullable);
        let get_a = get_item("a", innermost);

        let level2 = pack([("b", get_a)], NonNullable);
        let get_b = get_item("b", level2);

        let level3 = pack([("c", get_b)], NonNullable);
        let get_c = get_item("c", level3);

        let outermost = pack([("final", get_c)], NonNullable);
        let get_final = get_item("final", outermost);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_final.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let get_x = get_item("x", inner_pack);
        let add_expr = checked_add(get_x, lit(10));

        let outer_pack = pack([("result", add_expr)], NonNullable);
        let get_result = get_item("result", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_result.optimize_recursive(&dtype).unwrap();
        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&result, &expected);
    }

    /// Extracts a field whose child is a `FilterArray` over a `ListArray` that carries an
    /// explicit validity array, and checks the row count survives the filter (2 of 5 kept).
    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        let list = ListArray::try_new(
            buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
            buffer![2u64, 4, 6, 8, 10, 12].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
        )
        .unwrap();

        let filtered = FilterArray::try_new(
            list.into_array(),
            Mask::from_iter([true, true, false, false, false]),
        )
        .unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        let item = st.into_array().apply(&get_item("data", root())).unwrap();
        assert_eq!(item.len(), 2);
    }
}