Skip to main content

vortex_array/expr/exprs/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5
6use prost::Message;
7use vortex_dtype::DType;
8use vortex_dtype::FieldName;
9use vortex_dtype::FieldPath;
10use vortex_dtype::Nullability;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_err;
14use vortex_proto::expr as pb;
15use vortex_session::VortexSession;
16
17use crate::ArrayRef;
18use crate::arrays::StructArray;
19use crate::builtins::ArrayBuiltins;
20use crate::builtins::ExprBuiltins;
21use crate::expr::Arity;
22use crate::expr::ChildName;
23use crate::expr::EmptyOptions;
24use crate::expr::ExecutionArgs;
25use crate::expr::ExprId;
26use crate::expr::Expression;
27use crate::expr::Literal;
28use crate::expr::Mask;
29use crate::expr::Pack;
30use crate::expr::ReduceCtx;
31use crate::expr::ReduceNode;
32use crate::expr::ReduceNodeRef;
33use crate::expr::StatsCatalog;
34use crate::expr::VTable;
35use crate::expr::VTableExt;
36use crate::expr::exprs::root::root;
37use crate::expr::lit;
38use crate::expr::stats::Stat;
39
/// Expression that extracts a single named field from its struct-typed child.
///
/// The field name is carried as the expression's options; see the `VTable` implementation
/// for the evaluation, typing and simplification rules.
pub struct GetItem;
41
42impl VTable for GetItem {
43    type Options = FieldName;
44
45    fn id(&self) -> ExprId {
46        ExprId::from("vortex.get_item")
47    }
48
49    fn serialize(&self, instance: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
50        Ok(Some(
51            pb::GetItemOpts {
52                path: instance.to_string(),
53            }
54            .encode_to_vec(),
55        ))
56    }
57
58    fn deserialize(
59        &self,
60        _metadata: &[u8],
61        _session: &VortexSession,
62    ) -> VortexResult<Self::Options> {
63        let opts = pb::GetItemOpts::decode(_metadata)?;
64        Ok(FieldName::from(opts.path))
65    }
66
67    fn arity(&self, _field_name: &FieldName) -> Arity {
68        Arity::Exact(1)
69    }
70
71    fn child_name(&self, _instance: &Self::Options, child_idx: usize) -> ChildName {
72        match child_idx {
73            0 => ChildName::from("input"),
74            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
75        }
76    }
77
78    fn fmt_sql(
79        &self,
80        field_name: &FieldName,
81        expr: &Expression,
82        f: &mut Formatter<'_>,
83    ) -> std::fmt::Result {
84        expr.children()[0].fmt_sql(f)?;
85        write!(f, ".{}", field_name)
86    }
87
88    fn return_dtype(&self, field_name: &FieldName, arg_dtypes: &[DType]) -> VortexResult<DType> {
89        let struct_dtype = &arg_dtypes[0];
90        let field_dtype = struct_dtype
91            .as_struct_fields_opt()
92            .and_then(|st| st.field(field_name))
93            .ok_or_else(|| {
94                vortex_err!("Couldn't find the {} field in the input scope", field_name)
95            })?;
96
97        // Match here to avoid cloning the dtype if nullability doesn't need to change
98        if matches!(
99            (struct_dtype.nullability(), field_dtype.nullability()),
100            (Nullability::Nullable, Nullability::NonNullable)
101        ) {
102            return Ok(field_dtype.with_nullability(Nullability::Nullable));
103        }
104
105        Ok(field_dtype)
106    }
107
108    fn execute(&self, field_name: &FieldName, mut args: ExecutionArgs) -> VortexResult<ArrayRef> {
109        let input = args
110            .inputs
111            .pop()
112            .vortex_expect("missing input for GetItem expression")
113            .execute::<StructArray>(args.ctx)?;
114        let field = input.unmasked_field_by_name(field_name).cloned()?;
115
116        match input.dtype().nullability() {
117            Nullability::NonNullable => Ok(field),
118            Nullability::Nullable => field.mask(input.validity()?.to_array(input.len())),
119        }?
120        .execute(args.ctx)
121    }
122
123    fn reduce(
124        &self,
125        field_name: &FieldName,
126        node: &dyn ReduceNode,
127        ctx: &dyn ReduceCtx,
128    ) -> VortexResult<Option<ReduceNodeRef>> {
129        let child = node.child(0);
130        if let Some(child_fn) = child.scalar_fn()
131            && let Some(pack) = child_fn.as_opt::<Pack>()
132            && let Some(idx) = pack.names.find(field_name)
133        {
134            let mut field = child.child(idx);
135
136            // Possibly mask the field if the pack is nullable
137            if pack.nullability.is_nullable() {
138                field = ctx.new_node(
139                    Mask.bind(EmptyOptions),
140                    &[field, ctx.new_node(Literal.bind(true.into()), &[])?],
141                )?;
142            }
143
144            return Ok(Some(field));
145        }
146
147        Ok(None)
148    }
149
150    fn simplify_untyped(
151        &self,
152        field_name: &FieldName,
153        expr: &Expression,
154    ) -> VortexResult<Option<Expression>> {
155        let child = expr.child(0);
156
157        // If the child is a Pack expression, we can directly return the corresponding child.
158        if let Some(pack) = child.as_opt::<Pack>() {
159            let idx = pack
160                .names
161                .iter()
162                .position(|name| name == field_name)
163                .ok_or_else(|| {
164                    vortex_err!(
165                        "Cannot find field {} in pack fields {:?}",
166                        field_name,
167                        pack.names
168                    )
169                })?;
170
171            let mut field = child.child(idx).clone();
172
173            // It's useful to simplify this node without type info, but we need to make sure
174            // the nullability is correct. We cannot cast since we don't have the dtype info here,
175            // so instead we insert a Mask expression that we know converts a child's dtype to
176            // nullable.
177            if pack.nullability.is_nullable() {
178                // Mask with an all-true array to ensure the field DType is nullable.
179                field = field.mask(lit(true))?;
180            }
181
182            return Ok(Some(field));
183        }
184
185        Ok(None)
186    }
187
188    fn stat_expression(
189        &self,
190        field_name: &FieldName,
191        _expr: &Expression,
192        stat: Stat,
193        catalog: &dyn StatsCatalog,
194    ) -> Option<Expression> {
195        // TODO(ngates): I think we can do better here and support stats over nested fields.
196        //  It would be nice if delegating to our child would return a struct of statistics
197        //  matching the nested DType such that we can write:
198        //    `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
199
200        // TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
201        //  name as a field in the root struct. This should be resolved with upcoming change to
202        //  falsify expressions, but for now I'm preserving the existing buggy behavior.
203        catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
204    }
205
206    // This will apply struct nullability field. We could add a dtype??
207    fn is_null_sensitive(&self, _field_name: &FieldName) -> bool {
208        true
209    }
210
211    fn is_fallible(&self, _field_name: &FieldName) -> bool {
212        // If this type-checks its infallible.
213        false
214    }
215}
216
217/// Creates an expression that accesses a field from the root array.
218///
219/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
220///
221/// ```rust
222/// # use vortex_array::expr::col;
223/// let expr = col("name");
224/// ```
225pub fn col(field: impl Into<FieldName>) -> Expression {
226    GetItem.new_expr(field.into(), vec![root()])
227}
228
229/// Creates an expression that extracts a named field from a struct expression.
230///
231/// Accesses the specified field from the result of the child expression.
232///
233/// ```rust
234/// # use vortex_array::expr::{get_item, root};
235/// let expr = get_item("user_id", root());
236/// ```
237pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
238    GetItem.new_expr(field.into(), vec![child])
239}
240
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::FieldNames;
    use vortex_dtype::Nullability;
    use vortex_dtype::Nullability::NonNullable;
    use vortex_dtype::PType;
    use vortex_dtype::StructFields;

    use crate::Array;
    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::expr::exprs::binary::checked_add;
    use crate::expr::exprs::get_item::get_item;
    use crate::expr::exprs::literal::lit;
    use crate::expr::exprs::pack::pack;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;

    /// Two-column struct fixture: `a` holds i32 values and `b` holds i64 values.
    fn test_array() -> StructArray {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap()
    }

    #[test]
    fn get_item_by_name() {
        let expr = get_item("a", root());
        let item = test_array().to_array().apply(&expr).unwrap();
        assert_eq!(item.dtype(), &DType::from(PType::I32));
    }

    #[test]
    fn get_item_by_name_none() {
        // Field "c" does not exist in the fixture, so applying the expression must fail.
        let expr = get_item("c", root());
        let result = test_array().to_array().apply(&expr);
        assert!(result.is_err());
    }

    #[test]
    #[ignore = "apply() has a bug with null propagation from struct validity to non-nullable child fields"]
    fn get_nullable_field() {
        let names = FieldNames::from(["a"]);
        let columns = vec![buffer![1i32].into_array()];
        let st = StructArray::try_new(names, columns, 1, Validity::AllInvalid)
            .unwrap()
            .to_array();

        let item = st.apply(&get_item("a", root())).unwrap();

        // The field inherits the struct's validity, so its dtype must be nullable.
        let expected = DType::Primitive(PType::I32, Nullability::Nullable);
        assert_eq!(item.dtype(), &expected);
    }

    #[test]
    fn test_pack_get_item_rule() {
        // pack(a: lit(1), b: lit(2)).get_item("b") should collapse to lit(2).
        let packed = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let expr = get_item("b", packed);

        let scope = DType::Struct(StructFields::empty(), NonNullable);
        let optimized = expr.optimize_recursive(&scope).unwrap();

        assert_eq!(optimized, lit(2));
    }

    #[test]
    fn test_multi_level_pack_get_item_simplify() {
        let inner = pack([("a", lit(1)), ("b", lit(2))], NonNullable);
        let outer = pack(
            [("x", get_item("a", inner)), ("y", lit(3)), ("z", lit(4))],
            NonNullable,
        );
        let expr = get_item("z", outer);

        let scope = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope).unwrap();

        assert_eq!(optimized, lit(4));
    }

    #[test]
    fn test_deeply_nested_pack_get_item() {
        // Four pack/get_item layers wrapped around a single literal.
        let level1 = get_item("a", pack([("a", lit(42))], NonNullable));
        let level2 = get_item("b", pack([("b", level1)], NonNullable));
        let level3 = get_item("c", pack([("c", level2)], NonNullable));
        let expr = get_item("final", pack([("final", level3)], NonNullable));

        let scope = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope).unwrap();

        assert_eq!(optimized, lit(42));
    }

    #[test]
    fn test_partial_pack_get_item_simplify() {
        let inner = pack([("x", lit(1)), ("y", lit(2))], NonNullable);
        let sum = checked_add(get_item("x", inner), lit(10));
        let expr = get_item("result", pack([("result", sum)], NonNullable));

        let scope = DType::Primitive(PType::I32, NonNullable);
        let optimized = expr.optimize_recursive(&scope).unwrap();

        assert_eq!(optimized, checked_add(lit(1), lit(10)));
    }

    #[test]
    fn get_item_filter_list_field() {
        use vortex_mask::Mask;

        use crate::arrays::BoolArray;
        use crate::arrays::FilterArray;
        use crate::arrays::ListArray;

        let elements = buffer![0f32, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.].into_array();
        let offsets = buffer![2u64, 4, 6, 8, 10, 12].into_array();
        let list_validity =
            Validity::Array(BoolArray::from_iter([true, true, false, true, true]).into_array());
        let list = ListArray::try_new(elements, offsets, list_validity).unwrap();

        let keep = Mask::from_iter([true, true, false, false, false]);
        let filtered = FilterArray::try_new(list.into_array(), keep).unwrap();

        let st = StructArray::try_new(
            FieldNames::from(["data"]),
            vec![filtered.into_array()],
            2,
            Validity::AllValid,
        )
        .unwrap();

        // Only checks that extracting the filtered list field succeeds.
        st.to_array().apply(&get_item("data", root())).unwrap();
    }
}