Skip to main content

vortex_array/scalar_fn/fns/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_error::VortexResult;
8use vortex_error::vortex_err;
9use vortex_proto::expr as pb;
10use vortex_session::VortexSession;
11use vortex_session::registry::CachedId;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::arrays::StructArray;
16use crate::arrays::struct_::StructArrayExt;
17use crate::builtins::ArrayBuiltins;
18use crate::builtins::ExprBuiltins;
19use crate::dtype::DType;
20use crate::dtype::FieldName;
21use crate::dtype::FieldPath;
22use crate::dtype::Nullability;
23use crate::expr::Expression;
24use crate::expr::StatsCatalog;
25use crate::expr::lit;
26use crate::expr::stats::Stat;
27use crate::scalar_fn::Arity;
28use crate::scalar_fn::ChildName;
29use crate::scalar_fn::EmptyOptions;
30use crate::scalar_fn::ExecutionArgs;
31use crate::scalar_fn::ReduceCtx;
32use crate::scalar_fn::ReduceNode;
33use crate::scalar_fn::ReduceNodeRef;
34use crate::scalar_fn::ScalarFnId;
35use crate::scalar_fn::ScalarFnVTable;
36use crate::scalar_fn::ScalarFnVTableExt;
37use crate::scalar_fn::fns::literal::Literal;
38use crate::scalar_fn::fns::mask::Mask;
39use crate::scalar_fn::fns::pack::Pack;
40
/// Scalar function that extracts a single named field from a struct-typed input.
///
/// The name of the field to extract is supplied as the function's options
/// (`type Options = FieldName` in its `ScalarFnVTable` impl).
#[derive(Clone)]
pub struct GetItem;
43
44impl ScalarFnVTable for GetItem {
45    type Options = FieldName;
46
47    fn id(&self) -> ScalarFnId {
48        static ID: CachedId = CachedId::new("vortex.get_item");
49        *ID
50    }
51
52    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
53        Ok(Some(
54            pb::GetItemOpts {
55                path: instance.to_string(),
56            }
57            .encode_to_vec(),
58        ))
59    }
60
61    fn deserialize(
62        &self,
63        _metadata: &[u8],
64        _session: &VortexSession,
65    ) -> VortexResult<Self::Options> {
66        let opts = pb::GetItemOpts::decode(_metadata)?;
67        Ok(FieldName::from(opts.path))
68    }
69
70    fn arity(&self, _field_name: &FieldName) -> Arity {
71        Arity::Exact(1)
72    }
73
74    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
75        match child_idx {
76            0 => ChildName::from("input"),
77            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
78        }
79    }
80
81    fn fmt_sql(
82        &self,
83        field_name: &FieldName,
84        expr: &Expression,
85        f: &mut Formatter<'_>,
86    ) -> std::fmt::Result {
87        expr.children()[0].fmt_sql(f)?;
88        write!(f, ".{}", field_name)
89    }
90
91    fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
92        let struct_dtype = &arg_dtypes[0];
93        let field_dtype = struct_dtype
94            .as_struct_fields_opt()
95            .and_then(|st| st.field(field_name))
96            .ok_or_else(|| {
97                vortex_err!("Couldn't find the {} field in the input scope", field_name)
98            })?;
99
100        // Match here to avoid cloning the dtype if nullability doesn't need to change
101        if matches!(
102            (struct_dtype.nullability(), field_dtype.nullability()),
103            (Nullability::Nullable, Nullability::NonNullable)
104        ) {
105            return Ok(field_dtype.with_nullability(Nullability::Nullable));
106        }
107
108        Ok(field_dtype)
109    }
110
111    fn execute(
112        &self,
113        field_name: &FieldName,
114        args: &dyn ExecutionArgs,
115        ctx: &mut ExecutionCtx,
116    ) -> VortexResult<ArrayRef> {
117        let input = args.get(0)?.execute::<StructArray>(ctx)?;
118        let field = input.unmasked_field_by_name(field_name).cloned()?;
119
120        match input.dtype().nullability() {
121            Nullability::NonNullable => Ok(field),
122            Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
123        }
124    }
125
126    fn reduce(
127        &self,
128        field_name: &FieldName,
129        node: &dyn ReduceNode,
130        ctx: &dyn ReduceCtx,
131    ) -> VortexResult<Option<ReduceNodeRef>> {
132        let child = node.child(0);
133        if let Some(child_fn) = child.scalar_fn()
134            && let Some(pack) = child_fn.as_opt::<Pack>()
135            && let Some(idx) = pack.names.find(field_name)
136        {
137            let mut field = child.child(idx);
138
139            // Possibly mask the field if the pack is nullable
140            if pack.nullability.is_nullable() {
141                field = ctx.new_node(
142                    Mask.bind(EmptyOptions),
143                    &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
144                )?;
145            }
146
147            return Ok(Some(field));
148        }
149
150        Ok(None)
151    }
152
153    fn simplify_untyped(
154        &self,
155        field_name: &FieldName,
156        expr: &Expression,
157    ) -> VortexResult<Option<Expression>> {
158        let child = expr.child(0);
159
160        // If the child is a Pack expression, we can directly return the corresponding child.
161        if let Some(pack) = child.as_opt::<Pack>() {
162            let idx = pack
163                .names
164                .iter()
165                .position(|name| name == field_name)
166                .ok_or_else(|| {
167                    vortex_err!(
168                        "Cannot find field {} in pack fields {:?}",
169                        field_name,
170                        pack.names
171                    )
172                })?;
173
174            let mut field = child.child(idx).clone();
175
176            // It's useful to simplify this node without type info, but we need to make sure
177            // the nullability is correct. We cannot cast since we don't have the dtype info here,
178            // so instead we insert a Mask expression that we know converts a child's dtype to
179            // nullable.
180            if pack.nullability.is_nullable() {
181                // Mask with an all-true array to ensure the field DType is nullable.
182                field = field.mask(lit(true))?;
183            }
184
185            return Ok(Some(field));
186        }
187
188        Ok(None)
189    }
190
191    fn stat_expression(
192        &self,
193        field_name: &FieldName,
194        _expr: &Expression,
195        stat: Stat,
196        catalog: &dyn StatsCatalog,
197    ) -> Option<Expression> {
198        // TODO(ngates): I think we can do better here and support stats over nested fields.
199        //  It would be nice if delegating to our child would return a struct of statistics
200        //  matching the nested DType such that we can write:
201        //    `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
202
203        // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
204        //  name as a field in the root struct. This should be resolved with upcoming change to
205        //  falsify expressions, but for now I'm preserving the existing buggy behavior.
206        catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
207    }
208
209    // This will apply struct nullability field. We could add a dtype??
210    fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
211        true
212    }
213
214    fn is_fallible(&self, _field_name: &FieldName) -> bool {
215        // If this type-checks its infallible.
216        false
217    }
218}
219
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::FieldNames;
    use crate::dtype::Nullability;
    use crate::dtype::Nullability::NonNullable;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::checked_add;
    use crate::expr::get_item;
    use crate::expr::lit;
    use crate::expr::pack;
    use crate::expr::root;
    use crate::validity::Validity;

    /// Builds a three-row struct with an `i32` field `a` and an `i64` field `b`.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let st = test_array();
        let get_item = get_item("a", root());
        let item = st.into_array().apply(&get_item).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
        assert_eq!(item.len(), 3);
    }

    #[test]
    fn get_item_by_name_none() {
        let st = test_array();
        let get_item = get_item("c", root());
        assert!(st.into_array().apply(&get_item).is_err());
    }

    #[test]
    fn get_nullable_field() {
        let st = StructArray::try_new(
            FieldNames::from(["a"]),
            vec![buffer![1i32].into_array()],
            1,
            Validity::AllInvalid,
        )
        .unwrap()
        .into_array();

        let get_item_expr = get_item("a", root());
        let item = st.apply(&get_item_expr).unwrap();
        // The dtype should be nullable since it inherits struct validity
        assert_eq!(
            item.dtype(),
            &DType::Primitive(PType::I32, Nullability::Nullable)
        );
        assert_eq!(item.len(), 1);
    }

    #[test]
    fn test_pack_get_item_rule() {
        // Create: pack(a: lit(1), b: lit(2)).get_item("b")
        let pack_expr = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_item_expr = get_item("b", pack_expr);

        let result = get_item_expr
            .optimize_recursive(&DType::Struct(StructFields::empty(), NonNullable))
            .unwrap();

        assert_eq!(result, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner_pack = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let get_a = get_item("a", inner_pack);

        let outer_pack = pack([("x", get_a), ("y", lit(3)), ("z", lit(4))], NonNullable);
        let get_z = get_item("z", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_z.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        let innermost = pack([("a", lit(42))], NonNullable);
        let get_a = get_item("a", innermost);

        let level2 = pack([("b", get_a)], NonNullable);
        let get_b = get_item("b", level2);

        let level3 = pack([("c", get_b)], NonNullable);
        let get_c = get_item("c", level3);

        let outermost = pack([("final", get_c)], NonNullable);
        let get_final = get_item("final", outermost);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_final.optimize_recursive(&dtype).unwrap();
        assert_eq!(result, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner_pack = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let get_x = get_item("x", inner_pack);
        let add_expr = checked_add(get_x, lit(10));

        let outer_pack = pack([("result", add_expr)], NonNullable);
        let get_result = get_item("result", outer_pack);

        let dtype = DType::Primitive(PType::I32, NonNullable);

        let result = get_result.optimize_recursive(&dtype).unwrap();
        let expected = checked_add(lit(1), lit(10));
        assert_eq!(&result, &expected);
    }

    /// Extracting a filtered, nullable list field from a struct keeps the filtered length and
    /// a nullable dtype.
    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        // Five lists of two elements each (offsets start at 2), with the third list null.
        let list = ListArray::try_new(
            buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array(),
            buffer![2u64, 4, 6, 8, 10, 12].into_array(),
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array()),
        )
        .unwrap();

        // Keep only the first two lists.
        let filtered = FilterArray::try_new(
            list.into_array(),
            Mask::from_iter([true, true, false, false, false]),
        )
        .unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        let field = st.into_array().apply(&get_item("data", root())).unwrap();
        assert_eq!(field.len(), 2);
        assert_eq!(field.dtype().nullability(), Nullability::Nullable);
    }
}