vortex_array/expr/exprs/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Formatter;
5use std::ops::Not;
6
7use prost::Message;
8use vortex_dtype::{DType, FieldName, FieldPath, Nullability};
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::compute::mask;
13use crate::expr::exprs::root::root;
14use crate::expr::{ChildName, ExprId, Expression, ExpressionView, StatsCatalog, VTable, VTableExt};
15use crate::stats::Stat;
16use crate::{ArrayRef, ToCanonical};
17
/// Expression vtable for extracting a single named field from a struct-typed child expression.
///
/// The per-expression instance data is the [`FieldName`] of the field to extract.
pub struct GetItem;
19
20impl VTable for GetItem {
21    type Instance = FieldName;
22
23    fn id(&self) -> ExprId {
24        ExprId::from("vortex.get_item")
25    }
26
27    fn serialize(&self, instance: &Self::Instance) -> VortexResult<Option<Vec<u8>>> {
28        Ok(Some(
29            pb::GetItemOpts {
30                path: instance.to_string(),
31            }
32            .encode_to_vec(),
33        ))
34    }
35
36    fn deserialize(&self, metadata: &[u8]) -> VortexResult<Option<Self::Instance>> {
37        let opts = pb::GetItemOpts::decode(metadata)?;
38        Ok(Some(FieldName::from(opts.path)))
39    }
40
41    fn validate(&self, expr: &ExpressionView<Self>) -> VortexResult<()> {
42        if expr.children().len() != 1 {
43            vortex_bail!(
44                "GetItem expression requires exactly 1 child, got {}",
45                expr.children().len()
46            );
47        }
48        Ok(())
49    }
50
51    fn child_name(&self, _instance: &Self::Instance, child_idx: usize) -> ChildName {
52        match child_idx {
53            0 => ChildName::from("input"),
54            _ => unreachable!("Invalid child index {} for GetItem expression", child_idx),
55        }
56    }
57
58    fn fmt_sql(&self, expr: &ExpressionView<Self>, f: &mut Formatter<'_>) -> std::fmt::Result {
59        expr.children()[0].fmt_sql(f)?;
60        write!(f, ".{}", expr.data())
61    }
62
63    fn fmt_data(&self, instance: &Self::Instance, f: &mut Formatter<'_>) -> std::fmt::Result {
64        write!(f, "\"{}\"", instance.inner().as_ref())
65    }
66
67    fn return_dtype(&self, expr: &ExpressionView<Self>, scope: &DType) -> VortexResult<DType> {
68        let struct_dtype = expr.children()[0].return_dtype(scope)?;
69        let field_dtype = struct_dtype
70            .as_struct_fields_opt()
71            .and_then(|st| st.field(expr.data()))
72            .ok_or_else(|| {
73                vortex_err!("Couldn't find the {} field in the input scope", expr.data())
74            })?;
75
76        // Match here to avoid cloning the dtype if nullability doesn't need to change
77        if matches!(
78            (struct_dtype.nullability(), field_dtype.nullability()),
79            (Nullability::Nullable, Nullability::NonNullable)
80        ) {
81            return Ok(field_dtype.with_nullability(Nullability::Nullable));
82        }
83
84        Ok(field_dtype)
85    }
86
87    fn evaluate(&self, expr: &ExpressionView<Self>, scope: &ArrayRef) -> VortexResult<ArrayRef> {
88        let input = expr.children()[0].evaluate(scope)?.to_struct();
89        let field = input.field_by_name(expr.data()).cloned()?;
90
91        match input.dtype().nullability() {
92            Nullability::NonNullable => Ok(field),
93            Nullability::Nullable => mask(&field, &input.validity_mask().not()),
94        }
95    }
96
97    fn stat_max(
98        &self,
99        expr: &ExpressionView<Self>,
100        catalog: &mut dyn StatsCatalog,
101    ) -> Option<Expression> {
102        catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::Max)
103    }
104
105    fn stat_min(
106        &self,
107        expr: &ExpressionView<Self>,
108        catalog: &mut dyn StatsCatalog,
109    ) -> Option<Expression> {
110        catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::Min)
111    }
112
113    fn stat_nan_count(
114        &self,
115        expr: &ExpressionView<Self>,
116        catalog: &mut dyn StatsCatalog,
117    ) -> Option<Expression> {
118        catalog.stats_ref(&FieldPath::from_name(expr.data().clone()), Stat::NaNCount)
119    }
120
121    fn stat_field_path(&self, expr: &ExpressionView<Self>) -> Option<FieldPath> {
122        expr.children()[0]
123            .stat_field_path()
124            .map(|fp| fp.push(expr.data().clone()))
125    }
126}
127
128/// Creates an expression that accesses a field from the root array.
129///
130/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
131///
132/// ```rust
133/// # use vortex_array::expr::col;
134/// let expr = col("name");
135/// ```
136pub fn col(field: impl Into<FieldName>) -> Expression {
137    GetItem.new_expr(field.into(), vec![root()])
138}
139
140/// Creates an expression that extracts a named field from a struct expression.
141///
142/// Accesses the specified field from the result of the child expression.
143///
144/// ```rust
145/// # use vortex_array::expr::{get_item, root};
146/// let expr = get_item("user_id", root());
147/// ```
148pub fn get_item(field: impl Into<FieldName>, child: Expression) -> Expression {
149    GetItem.new_expr(field.into(), vec![child])
150}
151
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PType::I32;
    use vortex_dtype::{DType, FieldNames, Nullability};
    use vortex_scalar::Scalar;

    use super::get_item;
    use crate::arrays::StructArray;
    use crate::expr::exprs::root::root;
    use crate::validity::Validity;
    use crate::{Array, IntoArray};

    /// Two-column struct fixture: `a` holds i32 values, `b` holds i64 values.
    fn test_array() -> StructArray {
        let fields = [
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ];
        StructArray::from_fields(&fields).unwrap()
    }

    /// Extracting an existing field yields an array with that field's dtype.
    #[test]
    fn get_item_by_name() {
        let array = test_array().to_array();
        let expr = get_item("a", root());
        let item = expr.evaluate(&array).unwrap();
        assert_eq!(item.dtype(), &DType::from(I32));
    }

    /// Extracting a field that does not exist is an error.
    #[test]
    fn get_item_by_name_none() {
        let array = test_array().to_array();
        let expr = get_item("c", root());
        assert!(expr.evaluate(&array).is_err());
    }

    /// A field read through an all-invalid struct comes back as a nullable null value.
    #[test]
    fn get_nullable_field() {
        let names = FieldNames::from(["a"]);
        let columns = vec![buffer![1i32].into_array()];
        let array = StructArray::try_new(names, columns, 1, Validity::AllInvalid)
            .unwrap()
            .to_array();

        let expr = get_item("a", root());
        let item = expr.evaluate(&array).unwrap();
        let expected = Scalar::null(DType::Primitive(I32, Nullability::Nullable));
        assert_eq!(item.scalar_at(0), expected);
    }
}