1use std::any::Any;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::Hash;
4use std::sync::Arc;
5
6use vortex_array::{Array, ArrayRef, ArrayVariants};
7use vortex_dtype::{DType, FieldName};
8use vortex_error::{VortexResult, vortex_err};
9
10use crate::{ExprRef, VortexExpr, ident};
11
/// Expression that extracts a single named field from a struct-typed value.
///
/// Evaluation first evaluates `child`, then returns the field called `field` from the resulting
/// struct array; type inference likewise looks up `field` in the child's struct dtype.
#[derive(Debug, Clone, Eq, Hash)]
// `PartialEq` is implemented by hand further down in this file, but it compares exactly the
// fields that the derived `Hash` hashes (`field` and `child`), so equal values still hash equally.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct GetItem {
    /// Name of the field to extract.
    field: FieldName,
    /// Expression producing the struct value the field is read from.
    child: ExprRef,
}
18
19impl GetItem {
20 pub fn new_expr(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
21 Arc::new(Self {
22 field: field.into(),
23 child,
24 })
25 }
26
27 pub fn field(&self) -> &FieldName {
28 &self.field
29 }
30
31 pub fn child(&self) -> &ExprRef {
32 &self.child
33 }
34
35 pub fn is(expr: &ExprRef) -> bool {
36 expr.as_any().is::<Self>()
37 }
38}
39
40pub fn col(field: impl Into<FieldName>) -> ExprRef {
41 GetItem::new_expr(field, ident())
42}
43
44pub fn get_item(field: impl Into<FieldName>, child: ExprRef) -> ExprRef {
45 GetItem::new_expr(field, child)
46}
47
48pub fn get_item_scope(field: impl Into<FieldName>) -> ExprRef {
49 GetItem::new_expr(field, ident())
50}
51
52impl Display for GetItem {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}.{}", self.child, &self.field)
55 }
56}
57
#[cfg(feature = "proto")]
pub(crate) mod proto {
    //! Protobuf (de)serialization support for [`GetItem`] expressions.

    use vortex_error::{VortexResult, vortex_bail};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::{ExprDeserialize, ExprRef, ExprSerializable, GetItem, Id};

    /// Deserializer registered under the `"get_item"` id.
    pub(crate) struct GetItemSerde;

    impl Id for GetItemSerde {
        fn id(&self) -> &'static str {
            "get_item"
        }
    }

    impl ExprDeserialize for GetItemSerde {
        /// Rebuilds a [`GetItem`] from its serialized `kind` and already-deserialized children.
        ///
        /// # Errors
        ///
        /// Fails if `kind` is not a `GetItem` kind, or if `children` does not contain exactly
        /// one expression (the value the field is read from).
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::GetItem(kind::GetItem { path }) = kind else {
                vortex_bail!("wrong kind {:?}, want get_item", kind)
            };

            // Serialized input may be malformed: report a bad child count as an error instead
            // of panicking on an out-of-bounds `children[0]` or silently dropping extras.
            let n_children = children.len();
            let Ok([child]) = <[ExprRef; 1]>::try_from(children) else {
                vortex_bail!("get_item expects exactly 1 child, got {}", n_children)
            };

            Ok(GetItem::new_expr(path.to_string(), child))
        }
    }

    impl ExprSerializable for GetItem {
        fn id(&self) -> &'static str {
            GetItemSerde.id()
        }

        /// Serializes the field name as the `path` of the `GetItem` kind. The child expression
        /// is not part of the kind (presumably serialized by the caller from `children()`).
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::GetItem(kind::GetItem {
                path: self.field.to_string(),
            }))
        }
    }
}
96
97impl VortexExpr for GetItem {
98 fn as_any(&self) -> &dyn Any {
99 self
100 }
101
102 fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
103 let child = self.child.evaluate(batch)?;
104 child
105 .as_struct_typed()
106 .ok_or_else(|| vortex_err!("GetItem: child array into struct"))?
107 .maybe_null_field_by_name(self.field())
109 }
110
111 fn children(&self) -> Vec<&ExprRef> {
112 vec![self.child()]
113 }
114
115 fn replacing_children(self: Arc<Self>, children: Vec<ExprRef>) -> ExprRef {
116 assert_eq!(children.len(), 1);
117 Self::new_expr(self.field().clone(), children[0].clone())
118 }
119
120 fn return_dtype(&self, scope_dtype: &DType) -> VortexResult<DType> {
121 let input = self.child.return_dtype(scope_dtype)?;
122 input
123 .as_struct()
124 .ok_or_else(|| vortex_err!("GetItem: child dtype is not a struct"))?
125 .field(self.field())
126 }
127}
128
129impl PartialEq for GetItem {
130 fn eq(&self, other: &GetItem) -> bool {
131 self.field == other.field && self.child.eq(&other.child)
132 }
133}
134
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::PType::I32;

    use crate::get_item::get_item;
    use crate::ident;

    /// Two-column struct fixture: `a` holds i32 values, `b` holds i64 values.
    fn test_array() -> StructArray {
        let a = buffer![0i32, 1, 2].into_array();
        let b = buffer![4i64, 5, 6].into_array();
        StructArray::from_fields(&[("a", a), ("b", b)]).unwrap()
    }

    /// Extracting an existing field yields an array with that field's dtype.
    #[test]
    fn get_item_by_name() {
        let input = test_array();
        let expr = get_item("a", ident());
        let result = expr.evaluate(&input).unwrap();
        assert_eq!(result.dtype(), &DType::from(I32))
    }

    /// Extracting a field that does not exist returns an error.
    #[test]
    fn get_item_by_name_none() {
        let input = test_array();
        let expr = get_item("c", ident());
        let outcome = expr.evaluate(&input);
        assert!(outcome.is_err());
    }
}