vortex_expr/exprs/
get_item.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Formatter};
5use std::hash::Hash;
6
7use vortex_array::stats::Stat;
8use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata, ToCanonical};
9use vortex_dtype::{DType, FieldName, FieldPath};
10use vortex_error::{VortexResult, vortex_bail, vortex_err};
11use vortex_proto::expr as pb;
12
13use crate::display::{DisplayAs, DisplayFormat};
14use crate::{
15    AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, StatsCatalog, VTable, root,
16    vtable,
17};
18
// NOTE(review): `GetItemVTable`, used by the `impl VTable` below, is not declared anywhere
// else in this file, so this macro invocation presumably generates it (and related glue) —
// confirm against the `vtable!` macro definition.
vtable!(GetItem);
20
/// Expression that selects the field named `field` from the value produced by `child`.
// `Hash` is derived while `PartialEq` is implemented by hand further down in this file.
// The manual impl compares the same two members (`field` and `child`) that the derived
// `Hash` feeds into the hasher, so the clippy lint is silenced instead of acted upon.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct GetItemExpr {
    /// Name of the field to extract.
    field: FieldName,
    /// Expression whose result the field is extracted from.
    child: ExprRef,
}
27
28impl PartialEq for GetItemExpr {
29    fn eq(&self, other: &Self) -> bool {
30        self.field == other.field && self.child.eq(&other.child)
31    }
32}
33
/// Unit type used as the `Encoding` associated type of the `VTable` implementation for
/// `GetItemVTable`.
pub struct GetItemExprEncoding;
35
36impl VTable for GetItemVTable {
37    type Expr = GetItemExpr;
38    type Encoding = GetItemExprEncoding;
39    type Metadata = ProstMetadata<pb::GetItemOpts>;
40
41    fn id(_encoding: &Self::Encoding) -> ExprId {
42        ExprId::new_ref("get_item")
43    }
44
45    fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
46        ExprEncodingRef::new_ref(GetItemExprEncoding.as_ref())
47    }
48
49    fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
50        Some(ProstMetadata(pb::GetItemOpts {
51            path: expr.field.to_string(),
52        }))
53    }
54
55    fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
56        vec![&expr.child]
57    }
58
59    fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
60        Ok(GetItemExpr {
61            field: expr.field.clone(),
62            child: children[0].clone(),
63        })
64    }
65
66    fn build(
67        _encoding: &Self::Encoding,
68        metadata: &<Self::Metadata as DeserializeMetadata>::Output,
69        children: Vec<ExprRef>,
70    ) -> VortexResult<Self::Expr> {
71        if children.len() != 1 {
72            vortex_bail!(
73                "GetItem expression must have exactly 1 child, got {}",
74                children.len()
75            );
76        }
77
78        let field = FieldName::from(metadata.path.clone());
79        Ok(GetItemExpr {
80            field,
81            child: children[0].clone(),
82        })
83    }
84
85    fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
86        expr.child
87            .unchecked_evaluate(scope)?
88            .to_struct()?
89            .field_by_name(expr.field())
90            .cloned()
91    }
92
93    fn return_dtype(expr: &Self::Expr, scope: &DType) -> VortexResult<DType> {
94        let input = expr.child.return_dtype(scope)?;
95        input
96            .as_struct_fields_opt()
97            .and_then(|st| st.field(expr.field()))
98            .ok_or_else(|| {
99                vortex_err!(
100                    "Couldn't find the {} field in the input scope",
101                    expr.field()
102                )
103            })
104    }
105}
106
107impl GetItemExpr {
108    pub fn new(field: impl Into<FieldName>, child: ExprRef) -> Self {
109        Self {
110            field: field.into(),
111            child,
112        }
113    }
114
115    pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
116        Self::new(field, child).into_expr()
117    }
118
119    pub fn field(&self) -> &FieldName {
120        &self.field
121    }
122
123    pub fn child(&self) -> &ExprRef {
124        &self.child
125    }
126
127    pub fn is(expr: &ExprRef) -> bool {
128        expr.is::<GetItemVTable>()
129    }
130}
131
132/// Creates an expression that accesses a field from the root array.
133///
134/// Equivalent to `get_item(field, root())` - extracts a named field from the input array.
135///
136/// ```rust
137/// # use vortex_expr::col;
138/// let expr = col("name");
139/// ```
140pub fn col(field: impl Into<FieldName>) -> ExprRef {
141    GetItemExpr::new(field, root()).into_expr()
142}
143
144/// Creates an expression that extracts a named field from a struct expression.
145///
146/// Accesses the specified field from the result of the child expression.
147///
148/// ```rust
149/// # use vortex_expr::{get_item, root};
150/// let expr = get_item("user_id", root());
151/// ```
152pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
153    GetItemExpr::new(field, child).into_expr()
154}
155
156impl DisplayAs for GetItemExpr {
157    fn fmt_as(&self, df: DisplayFormat, f: &mut Formatter) -> std::fmt::Result {
158        match df {
159            DisplayFormat::Compact => {
160                write!(f, "{}.{}", self.child, &self.field)
161            }
162            DisplayFormat::Tree => {
163                write!(f, "GetItem({})", self.field)
164            }
165        }
166    }
167}
168impl AnalysisExpr for GetItemExpr {
169    fn max(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
170        catalog.stats_ref(&self.field_path()?, Stat::Max)
171    }
172
173    fn min(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
174        catalog.stats_ref(&self.field_path()?, Stat::Min)
175    }
176
177    fn nan_count(&self, catalog: &mut dyn StatsCatalog) -> Option<ExprRef> {
178        catalog.stats_ref(&self.field_path()?, Stat::NaNCount)
179    }
180
181    fn field_path(&self) -> Option<FieldPath> {
182        self.child()
183            .field_path()
184            .map(|fp| fp.push(self.field.clone()))
185    }
186}
187
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::PType::I32;

    use crate::get_item::get_item;
    use crate::{Scope, root};

    /// Fixture: a struct array with an `i32` column `a` and an `i64` column `b`.
    fn test_array() -> StructArray {
        let fixture = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ]);
        fixture.unwrap()
    }

    #[test]
    pub fn get_item_by_name() {
        // Selecting an existing field yields an array with that field's dtype.
        let input = test_array();
        let expr = get_item("a", root());
        let scope = Scope::new(input.to_array());
        let selected = expr.evaluate(&scope).unwrap();
        assert_eq!(selected.dtype(), &DType::from(I32));
    }

    #[test]
    pub fn get_item_by_name_none() {
        // Selecting a field that does not exist must surface an error.
        let input = test_array();
        let expr = get_item("c", root());
        let scope = Scope::new(input.to_array());
        let outcome = expr.evaluate(&scope);
        assert!(outcome.is_err());
    }
}
221}