1use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::FieldName;
9use vortex_dtype::FieldPath;
10use vortex_dtype::Nullability;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::arrays::StructArray;
19use crate::builtins::ArrayBuiltins;
20use crate::builtins::ExprBuiltins;
21use crate::expr::Arity;
22use crate::expr::ChildName;
23use crate::expr::EmptyOptions;
24use crate::expr::ExecutionArgs;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::Literal;
28use crate::expr::Mask;
29use crate::expr::Pack;
30use crate::expr::ReduceCtx;
31use crate::expr::ReduceNode;
32use crate::expr::ReduceNodeRef;
33use crate::expr::StatsCatalog;
34use crate::expr::VTable;
35use crate::expr::VTableExt;
36use crate::expr::exprs::root::root;
37use crate::expr::lit;
38use crate::expr::stats::Stat;
39
40pub struct GetItem;
41
42impl VTable for GetItem {
43 type Options = FieldName;
44
45 fn id(&self) -> ExprId {
46 ExprId::from("vortex.get_item")
47 }
48
49 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50 Ok(Some(
51 pb::GetItemOpts {
52 path: instance.to_string(),
53 }
54 .encode_to_vec(),
55 ))
56 }
57
58 fn deserialize(
59 &self,
60 _metadata: &[u8],
61 _session: &VortexSession,
62 ) -> VortexResult<Self::Options> {
63 let opts = pb::GetItemOpts::decode(_metadata)?;
64 Ok(FieldName::from(opts.path))
65 }
66
67 fn arity(&self, _field_name: &FieldName) -> Arity {
68 Arity::Exact(1)
69 }
70
71 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
72 match child_idx {
73 0 => ChildName::from("input"),
74 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
75 }
76 }
77
78 fn fmt_sql(
79 &self,
80 field_name: &FieldName,
81 expr: &Expression,
82 f: &mut Formatter<'_>,
83 ) -> std::fmt::Result {
84 expr.children()[0].fmt_sql(f)?;
85 write!(f, ".{}", field_name)
86 }
87
88 fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
89 let struct_dtype = &arg_dtypes[0];
90 let field_dtype = struct_dtype
91 .as_struct_fields_opt()
92 .and_then(|st| st.field(field_name))
93 .ok_or_else(|| {
94 vortex_err!("Couldn't find the {} field in the input scope", field_name)
95 })?;
96
97 if matches!(
99 (struct_dtype.nullability(), field_dtype.nullability()),
100 (Nullability::Nullable, Nullability::NonNullable)
101 ) {
102 return Ok(field_dtype.with_nullability(Nullability::Nullable));
103 }
104
105 Ok(field_dtype)
106 }
107
108 fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
109 let input = args
110 .inputs
111 .pop()
112 .vortex_expect("missing input for GetItem expression")
113 .execute::<StructArray>(args.ctx)?;
114 let field = input.unmasked_field_by_name(field_name).cloned()?;
115
116 match input.dtype().nullability() {
117 Nullability::NonNullable => Ok(field),
118 Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
119 }?
120 .execute(args.ctx)
121 }
122
123 fn reduce(
124 &self,
125 field_name: &FieldName,
126 node: &dyn ReduceNode,
127 ctx: &dyn ReduceCtx,
128 ) -> VortexResult<Option<ReduceNodeRef>> {
129 let child = node.child(0);
130 if let Some(child_fn) = child.scalar_fn()
131 && let Some(pack) = child_fn.as_opt::<Pack>()
132 && let Some(idx) = pack.names.find(field_name)
133 {
134 let mut field = child.child(idx);
135
136 if pack.nullability.is_nullable() {
138 field = ctx.new_node(
139 Mask.bind(EmptyOptions),
140 &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
141 )?;
142 }
143
144 return Ok(Some(field));
145 }
146
147 Ok(None)
148 }
149
150 fn simplify_untyped(
151 &self,
152 field_name: &FieldName,
153 expr: &Expression,
154 ) -> VortexResult<Option<Expression>> {
155 let child = expr.child(0);
156
157 if let Some(pack) = child.as_opt::<Pack>() {
159 let idx = pack
160 .names
161 .iter()
162 .position(|name| name == field_name)
163 .ok_or_else(|| {
164 vortex_err!(
165 "Cannot find field {} in pack fields {:?}",
166 field_name,
167 pack.names
168 )
169 })?;
170
171 let mut field = child.child(idx).clone();
172
173 if pack.nullability.is_nullable() {
178 field = field.mask(lit(true))?;
180 }
181
182 return Ok(Some(field));
183 }
184
185 Ok(None)
186 }
187
188 fn stat_expression(
189 &self,
190 field_name: &FieldName,
191 _expr: &Expression,
192 stat: Stat,
193 catalog: &dyn StatsCatalog,
194 ) -> Option<Expression> {
195 catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
204 }
205
206 fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
208 true
209 }
210
211 fn is_fallible(&self, _field_name: &FieldName) -> bool {
212 false
214 }
215}
216
217pub fn col(field: impl Into<FieldName>) -> Expression {
226 GetItem.new_expr(field.into(), vec![root()])
227}
228
229pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
238 GetItem.new_expr(field.into(), vec![child])
239}
240
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    /// Non-nullable struct `{a: i32, b: i64}` with three rows.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = st.to_array().apply(&get_item).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(st.to_array().apply(&get_item).is_err());
    }

    #[test]
    #[ignore = "apply() has a bug with null propagation from struct validity to non-nullable child fields"]
    fn get_nullable_field() {
        let st = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .to_array();

        let get_item_expr = get_item("a", root());
        let item = st.apply(&get_item_expr).unwrap();
        assert_eq!(
            item.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_pack_get_item_rule() {
        let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_item_expr = get_item("b", pack_expr);

        let result = get_item_expr
            .optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
            .unwrap();

        assert_eq!(result, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_a = get_item("a", inner_pack);

        let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let get_z = get_item("z", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_z.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        let innermost = pack([("a", lit(42))], NonNullable);
        let get_a = get_item("a", innermost);

        let level2 = pack([("b", get_a)], NonNullable);
        let get_b = get_item("b", level2);

        let level3 = pack([("c", get_b)], NonNullable);
        let get_c = get_item("c", level3);

        let outermost = pack([("final", get_c)], NonNullable);
        let get_final = get_item("final", outermost);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_final.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let get_x = get_item("x", inner_pack);
        let add_expr = checked_add(get_x, lit(10));

        let outer_pack = pack([("result", add_expr)], NonNullable);
        let get_result = get_item("result", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_result.optimize_recursive(&dtype).unwrap();
        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&result, &expected);
    }

    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        let list = ListArray::try_new(
            buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
            buffer![2u64, 4, 6, 8, 10, 12].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
        )
        .unwrap();

        let filtered = FilterArray::try_new(
            list.into_array(),
            Mask::from_iter([true, true, false, false, false]),
        )
        .unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        let item = st.to_array().apply(&get_item("data", root())).unwrap();
        // Only the first two list rows survive the filter, matching the struct's length.
        assert_eq!(item.len(), 2);
    }
}