Skip to main content

vortex_array/scalar_fn/fns/get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11
12use crate::ArrayRef;
13use crate::ExecutionCtx;
14use crate::arrays::StructArray;
15use crate::arrays::struct_::StructArrayExt;
16use crate::builtins::ArrayBuiltins;
17use crate::builtins::ExprBuiltins;
18use crate::dtype::DType;
19use crate::dtype::FieldName;
20use crate::dtype::FieldPath;
21use crate::dtype::Nullability;
22use crate::expr::Expression;
23use crate::expr::StatsCatalog;
24use crate::expr::lit;
25use crate::expr::stats::Stat;
26use crate::scalar_fn::Arity;
27use crate::scalar_fn::ChildName;
28use crate::scalar_fn::EmptyOptions;
29use crate::scalar_fn::ExecutionArgs;
30use crate::scalar_fn::ReduceCtx;
31use crate::scalar_fn::ReduceNode;
32use crate::scalar_fn::ReduceNodeRef;
33use crate::scalar_fn::ScalarFnId;
34use crate::scalar_fn::ScalarFnVTable;
35use crate::scalar_fn::ScalarFnVTableExt;
36use crate::scalar_fn::fns::literal::Literal;
37use crate::scalar_fn::fns::mask::Mask;
38use crate::scalar_fn::fns::pack::Pack;
39
/// Scalar function that extracts a single named field from a struct-typed input.
///
/// The function's options are the name of the field to extract (`Options = FieldName` in the
/// `ScalarFnVTable` implementation below). When the input struct is nullable, the extracted
/// field is made nullable as well.
#[derive(Clone, Debug)]
pub struct GetItem;
42
43impl ScalarFnVTable for GetItem {
44    type Options = FieldName;
45
46    fn id(&self) -> ScalarFnId {
47        ScalarFnId::from("vortex.get_item")
48    }
49
50    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
51        Ok(Some(
52            pb::GetItemOpts {
53                path: instance.to_string(),
54            }
55            .encode_to_vec(),
56        ))
57    }
58
59    fn deserialize(
60        &self,
61        _metadata: &[u8],
62        _session: &VortexSession,
63    ) -> VortexResult<Self::Options> {
64        let opts = pb::GetItemOpts::decode(_metadata)?;
65        Ok(FieldName::from(opts.path))
66    }
67
68    fn arity(&self, _field_name: &FieldName) -> Arity {
69        Arity::Exact(1)
70    }
71
72    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
73        match child_idx {
74            0 => ChildName::from("input"),
75            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
76        }
77    }
78
79    fn fmt_sql(
80        &self,
81        field_name: &FieldName,
82        expr: &Expression,
83        f: &mut Formatter<'_>,
84    ) -> std::fmt::Result {
85        expr.children()[0].fmt_sql(f)?;
86        write!(f, ".{}", field_name)
87    }
88
89    fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
90        let struct_dtype = &arg_dtypes[0];
91        let field_dtype = struct_dtype
92            .as_struct_fields_opt()
93            .and_then(|st| st.field(field_name))
94            .ok_or_else(|| {
95                vortex_err!("Couldn't find the {} field in the input scope", field_name)
96            })?;
97
98        // Match here to avoid cloning the dtype if nullability doesn't need to change
99        if matches!(
100            (struct_dtype.nullability(), field_dtype.nullability()),
101            (Nullability::Nullable, Nullability::NonNullable)
102        ) {
103            return Ok(field_dtype.with_nullability(Nullability::Nullable));
104        }
105
106        Ok(field_dtype)
107    }
108
109    fn execute(
110        &self,
111        field_name: &FieldName,
112        args: &dyn ExecutionArgs,
113        ctx: &mut ExecutionCtx,
114    ) -> VortexResult<ArrayRef> {
115        let input = args.get(0)?.execute::<StructArray>(ctx)?;
116        let field = input.unmasked_field_by_name(field_name).cloned()?;
117
118        match input.dtype().nullability() {
119            Nullability::NonNullable => Ok(field),
120            Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
121        }
122    }
123
124    fn reduce(
125        &self,
126        field_name: &FieldName,
127        node: &dyn ReduceNode,
128        ctx: &dyn ReduceCtx,
129    ) -> VortexResult<Option<ReduceNodeRef>> {
130        let child = node.child(0);
131        if let Some(child_fn) = child.scalar_fn()
132            && let Some(pack) = child_fn.as_opt::<Pack>()
133            && let Some(idx) = pack.names.find(field_name)
134        {
135            let mut field = child.child(idx);
136
137            // Possibly mask the field if the pack is nullable
138            if pack.nullability.is_nullable() {
139                field = ctx.new_node(
140                    Mask.bind(EmptyOptions),
141                    &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
142                )?;
143            }
144
145            return Ok(Some(field));
146        }
147
148        Ok(None)
149    }
150
151    fn simplify_untyped(
152        &self,
153        field_name: &FieldName,
154        expr: &Expression,
155    ) -> VortexResult<Option<Expression>> {
156        let child = expr.child(0);
157
158        // If the child is a Pack expression, we can directly return the corresponding child.
159        if let Some(pack) = child.as_opt::<Pack>() {
160            let idx = pack
161                .names
162                .iter()
163                .position(|name| name == field_name)
164                .ok_or_else(|| {
165                    vortex_err!(
166                        "Cannot find field {} in pack fields {:?}",
167                        field_name,
168                        pack.names
169                    )
170                })?;
171
172            let mut field = child.child(idx).clone();
173
174            // It's useful to simplify this node without type info, but we need to make sure
175            // the nullability is correct. We cannot cast since we don't have the dtype info here,
176            // so instead we insert a Mask expression that we know converts a child's dtype to
177            // nullable.
178            if pack.nullability.is_nullable() {
179                // Mask with an all-true array to ensure the field DType is nullable.
180                field = field.mask(lit(true))?;
181            }
182
183            return Ok(Some(field));
184        }
185
186        Ok(None)
187    }
188
189    fn stat_expression(
190        &self,
191        field_name: &FieldName,
192        _expr: &Expression,
193        stat: Stat,
194        catalog: &dyn StatsCatalog,
195    ) -> Option<Expression> {
196        // TODO(ngates): I think we can do better here and support stats over nested fields.
197        //  It would be nice if delegating to our child would return a struct of statistics
198        //  matching the nested DType such that we can write:
199        //    `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
200
201        // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
202        //  name as a field in the root struct. This should be resolved with upcoming change to
203        //  falsify expressions, but for now I'm preserving the existing buggy behavior.
204        catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
205    }
206
207    // This will apply struct nullability field. We could add a dtype??
208    fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
209        true
210    }
211
212    fn is_fallible(&self, _field_name: &FieldName) -> bool {
213        // If this type-checks its infallible.
214        false
215    }
216}
217
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::checked_add;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::validity::Validity;

    /// Builds a three-row struct array with fields `a: i32` and `b: i64`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = st.into_array().apply(&get_item).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(st.into_array().apply(&get_item).is_err());
    }

    #[test]
    #[ignore = "apply() has a bug with null propagation from struct validity to non-nullable child fields"]
    fn get_nullable_field() {
        let st = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let get_item_expr = get_item("a", root());
        let item = st.apply(&get_item_expr).unwrap();
        // The dtype should be nullable since it inherits struct validity.
        assert_eq!(
            item.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_pack_get_item_rule() {
        // Create: pack(a: lit(1), b: lit(2)).get_item("b")
        let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_item_expr = get_item("b", pack_expr);

        let result = get_item_expr
            .optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
            .unwrap();

        assert_eq!(result, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_a = get_item("a", inner_pack);

        let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let get_z = get_item("z", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_z.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        let innermost = pack([("a", lit(42))], NonNullable);
        let get_a = get_item("a", innermost);

        let level2 = pack([("b", get_a)], NonNullable);
        let get_b = get_item("b", level2);

        let level3 = pack([("c", get_b)], NonNullable);
        let get_c = get_item("c", level3);

        let outermost = pack([("final", get_c)], NonNullable);
        let get_final = get_item("final", outermost);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_final.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let get_x = get_item("x", inner_pack);
        let add_expr = checked_add(get_x, lit(10));

        let outer_pack = pack([("result", add_expr)], NonNullable);
        let get_result = get_item("result", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_result.optimize_recursive(&dtype).unwrap();
        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&result, &expected);
    }

    /// Extracting a field backed by a filtered list array must succeed and keep the struct's
    /// row count.
    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        let list = ListArray::try_new(
            buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
            buffer![2u64, 4, 6, 8, 10, 12].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
        )
        .unwrap();

        let filtered = FilterArray::try_new(
            list.into_array(),
            Mask::from_iter([true, true, false, false, false]),
        )
        .unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        let result = st.into_array().apply(&get_item("data", root())).unwrap();
        assert_eq!(result.len(), 2);
    }
}