1use std::fmt::Formatter;
5use std::ops::Not;
6
7use prost::Message;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_dtype::FieldPath;
11use vortex_dtype::Nullability;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_err;
15use vortex_proto::expr as pb;
16use vortex_vector::Datum;
17use vortex_vector::ScalarOps;
18use vortex_vector::VectorOps;
19
20use crate::ArrayRef;
21use crate::ToCanonical;
22use crate::builtins::ExprBuiltins;
23use crate::compute::mask;
24use crate::expr::Arity;
25use crate::expr::ChildName;
26use crate::expr::EmptyOptions;
27use crate::expr::ExecutionArgs;
28use crate::expr::ExprId;
29use crate::expr::Expression;
30use crate::expr::Literal;
31use crate::expr::Mask;
32use crate::expr::Pack;
33use crate::expr::ReduceCtx;
34use crate::expr::ReduceNode;
35use crate::expr::ReduceNodeRef;
36use crate::expr::StatsCatalog;
37use crate::expr::VTable;
38use crate::expr::VTableExt;
39use crate::expr::exprs::root::root;
40use crate::expr::lit;
41use crate::expr::stats::Stat;
42
/// Expression vtable that extracts a single named field from a struct-typed child expression.
///
/// The field to extract is carried as the expression's options (a [`FieldName`]). When the
/// struct itself is nullable, the struct's validity is applied to the extracted field.
pub struct GetItem;
44
45impl VTable for GetItem {
46 type Options = FieldName;
47
48 fn id(&self) -> ExprId {
49 ExprId::from("vortex.get_item")
50 }
51
52 fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53 Ok(Some(
54 pb::GetItemOpts {
55 path: instance.to_string(),
56 }
57 .encode_to_vec(),
58 ))
59 }
60
61 fn deserialize(&self, metadata: &[u8]) -> VortexResult<Self::Options> {
62 let opts = pb::GetItemOpts::decode(metadata)?;
63 Ok(FieldName::from(opts.path))
64 }
65
66 fn arity(&self, _field_name: &FieldName) -> Arity {
67 Arity::Exact(1)
68 }
69
70 fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
71 match child_idx {
72 0 => ChildName::from("input"),
73 _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
74 }
75 }
76
77 fn fmt_sql(
78 &self,
79 field_name: &FieldName,
80 expr: &Expression,
81 f: &mut Formatter<'_>,
82 ) -> std::fmt::Result {
83 expr.children()[0].fmt_sql(f)?;
84 write!(f, ".{}", field_name)
85 }
86
87 fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
88 let struct_dtype = &arg_dtypes[0];
89 let field_dtype = struct_dtype
90 .as_struct_fields_opt()
91 .and_then(|st| st.field(field_name))
92 .ok_or_else(|| {
93 vortex_err!("Couldn't find the {} field in the input scope", field_name)
94 })?;
95
96 if matches!(
98 (struct_dtype.nullability(), field_dtype.nullability()),
99 (Nullability::Nullable, Nullability::NonNullable)
100 ) {
101 return Ok(field_dtype.with_nullability(Nullability::Nullable));
102 }
103
104 Ok(field_dtype)
105 }
106
107 fn evaluate(
108 &self,
109 field_name: &FieldName,
110 expr: &Expression,
111 scope: &ArrayRef,
112 ) -> VortexResult<ArrayRef> {
113 let input = expr.children()[0].evaluate(scope)?.to_struct();
114 let field = input.field_by_name(field_name).cloned()?;
115
116 match input.dtype().nullability() {
117 Nullability::NonNullable => Ok(field),
118 Nullability::Nullable => mask(&field, &input.validity_mask().not()),
119 }
120 }
121
122 fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<Datum> {
123 let struct_dtype = args.dtypes[0]
124 .as_struct_fields_opt()
125 .ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
126 let field_idx = struct_dtype
127 .find(field_name)
128 .ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;
129
130 match args.datums.pop().vortex_expect("missing input") {
131 Datum::Scalar(s) => {
132 let mut field = s.as_struct().field(field_idx);
133 field.mask_validity(s.is_valid());
134 Ok(Datum::Scalar(field))
135 }
136 Datum::Vector(v) => {
137 let mut field = v.as_struct().fields()[field_idx].clone();
138 field.mask_validity(v.validity());
139 Ok(Datum::Vector(field))
140 }
141 }
142 }
143
144 fn reduce(
145 &self,
146 field_name: &FieldName,
147 node: &dyn ReduceNode,
148 ctx: &dyn ReduceCtx,
149 ) -> VortexResult<Option<ReduceNodeRef>> {
150 let child = node.child(0);
151 if let Some(child_fn) = child.scalar_fn()
152 && let Some(pack) = child_fn.as_opt::<Pack>()
153 && let Some(idx) = pack.names.find(field_name)
154 {
155 let mut field = child.child(idx);
156
157 if pack.nullability.is_nullable() {
159 field = ctx.new_node(
160 Mask.bind(EmptyOptions),
161 &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
162 )?;
163 }
164
165 return Ok(Some(field));
166 }
167
168 Ok(None)
169 }
170
171 fn simplify_untyped(
172 &self,
173 field_name: &FieldName,
174 expr: &Expression,
175 ) -> VortexResult<Option<Expression>> {
176 let child = expr.child(0);
177
178 if let Some(pack) = child.as_opt::<Pack>() {
180 let idx = pack
181 .names
182 .iter()
183 .position(|name| name == field_name)
184 .ok_or_else(|| {
185 vortex_err!(
186 "Cannot find field {} in pack fields {:?}",
187 field_name,
188 pack.names
189 )
190 })?;
191
192 let mut field = child.child(idx).clone();
193
194 if pack.nullability.is_nullable() {
199 field = field.mask(lit(true))?;
201 }
202
203 return Ok(Some(field));
204 }
205
206 Ok(None)
207 }
208
209 fn stat_expression(
210 &self,
211 field_name: &FieldName,
212 _expr: &Expression,
213 stat: Stat,
214 catalog: &dyn StatsCatalog,
215 ) -> Option<Expression> {
216 catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
225 }
226
227 fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
229 true
230 }
231
232 fn is_fallible(&self, _field_name: &FieldName) -> bool {
233 false
235 }
236}
237
238pub fn col(field: impl Into<FieldName>) -> Expression {
247 GetItem.new_expr(field.into(), vec![root()])
248}
249
250pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
259 GetItem.new_expr(field.into(), vec![child])
260}
261
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;
    use vortex_scalar::Scalar;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    /// Builds a three-row struct array with an `i32` field `a` and an `i64` field `b`.
    fn test_array() -> StructArray {
        let columns = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&columns).unwrap()
    }

    /// Extracting an existing field yields an array with that field's dtype.
    #[test]
    fn get_item_by_name() {
        let input = test_array().to_array();
        let expr = get_item("a", root());

        let extracted = expr.evaluate(&input).unwrap();

        assert_eq!(extracted.dtype(), &DType::from(PType::I32));
    }

    /// Extracting a field that does not exist is an error.
    #[test]
    fn get_item_by_name_none() {
        let input = test_array().to_array();
        let expr = get_item("c", root());

        let outcome = expr.evaluate(&input);

        assert!(outcome.is_err());
    }

    /// A field taken out of an all-invalid struct is null and has a nullable dtype.
    #[test]
    fn get_nullable_field() {
        let names = FieldNames::from(["a"]);
        let fields = vec![buffer![1i32].into_array()];
        let input = StructArray::try_new(names, fields, 1, Validity::AllInvalid)
            .unwrap()
            .to_array();

        let expr = get_item("a", root());
        let extracted = expr.evaluate(&input).unwrap();

        let expected = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
        assert_eq!(extracted.scalar_at(0), expected);
    }

    /// `get_item` over a `pack` collapses to the packed child.
    #[test]
    fn test_pack_get_item_rule() {
        let packed = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let expr = get_item("b", packed);

        let scope_dtype = DType::Struct(StructFields::empty(), NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(2));
    }

    /// The rewrite also applies when one of the packed children is itself a `get_item(pack)`.
    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let inner_a = get_item("a", inner);

        let outer = pack([("x", inner_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let expr = get_item("z", outer);

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(4));
    }

    /// Four nested `get_item(pack(..))` layers collapse to the innermost literal.
    #[test]
    fn test_deeply_nested_pack_get_item() {
        let layer1 = get_item("a", pack([("a", lit(42))], NonNullable));
        let layer2 = get_item("b", pack([("b", layer1)], NonNullable));
        let layer3 = get_item("c", pack([("c", layer2)], NonNullable));
        let expr = get_item("final", pack([("final", layer3)], NonNullable));

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        assert_eq!(optimized, lit(42));
    }

    /// The rewrite reaches `get_item(pack)` nodes buried under another expression.
    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let sum = checked_add(get_item("x", inner), lit(10));

        let outer = pack([("result", sum)], NonNullable);
        let expr = get_item("result", outer);

        let scope_dtype = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope_dtype).unwrap();

        let expected = checked_add(lit(1), lit(10));
        assert_eq!(optimized, expected);
    }
}