vortex_expr/exprs/
cast.rs1use vortex_array::compute::cast as compute_cast;
5use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail, vortex_err};
8use vortex_proto::expr as pb;
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
// Presumably generates the `CastVTable` type implemented below, plus the glue tying it to
// `CastExpr` / `CastExprEncoding` — TODO confirm against the `vtable!` macro definition.
vtable!(Cast);
14
/// An expression that casts the array produced by `child` to the `target` dtype.
// `Hash` is derived while `PartialEq` is written by hand (see the impl below). The manual
// `PartialEq` compares exactly the two fields the derived `Hash` hashes (`target` and
// `child`), so equal values still hash equally and the lint can be silenced safely.
// NOTE(review): the manual impl presumably exists to avoid derive issues with the
// `ExprRef` trait-object field — confirm before replacing it with a derive.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct CastExpr {
    /// The dtype the child's evaluated array is cast to.
    target: DType,
    /// The expression whose result is cast.
    child: ExprRef,
}
21
22impl PartialEq for CastExpr {
23 fn eq(&self, other: &Self) -> bool {
24 self.target == other.target && self.child.eq(&other.child)
25 }
26}
27
/// Encoding marker for [`CastExpr`]; its vtable reports the expression id `"cast"`.
pub struct CastExprEncoding;
29
30impl VTable for CastVTable {
31 type Expr = CastExpr;
32 type Encoding = CastExprEncoding;
33 type Metadata = ProstMetadata<pb::CastOpts>;
34
35 fn id(_encoding: &Self::Encoding) -> ExprId {
36 ExprId::new_ref("cast")
37 }
38
39 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
40 ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
41 }
42
43 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
44 Some(ProstMetadata(pb::CastOpts {
45 target: Some((&expr.target).into()),
46 }))
47 }
48
49 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50 vec![&expr.child]
51 }
52
53 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54 Ok(CastExpr {
55 target: expr.target.clone(),
56 child: children[0].clone(),
57 })
58 }
59
60 fn build(
61 _encoding: &Self::Encoding,
62 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
63 children: Vec<ExprRef>,
64 ) -> VortexResult<Self::Expr> {
65 if children.len() != 1 {
66 vortex_bail!(
67 "Cast expression must have exactly 1 child, got {}",
68 children.len()
69 );
70 }
71 let target: DType = metadata
72 .target
73 .as_ref()
74 .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
75 .try_into()?;
76 Ok(CastExpr {
77 target,
78 child: children[0].clone(),
79 })
80 }
81
82 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83 let array = expr.child.evaluate(scope)?;
84 compute_cast(&array, &expr.target)
85 }
86
87 fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
88 Ok(expr.target.clone())
89 }
90}
91
92impl CastExpr {
93 pub fn new(child: ExprRef, target: DType) -> Self {
94 Self { target, child }
95 }
96
97 pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
98 Self::new(child, target).into_expr()
99 }
100}
101
102impl DisplayAs for CastExpr {
103 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
104 match df {
105 DisplayFormat::Compact => {
106 write!(f, "cast({}, {})", self.child, self.target)
107 }
108 DisplayFormat::Tree => {
109 write!(f, "Cast(target: {})", self.target)
110 }
111 }
112 }
113}
114
// Empty impl: `CastExpr` relies entirely on the `AnalysisExpr` trait's default methods.
impl AnalysisExpr for CastExpr {}
116
117pub fn cast(child: ExprRef, target: DType) -> ExprRef {
127 CastExpr::new(child, target).into_expr()
128}
129
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // The replacement must succeed, swap in the new child, and keep the target dtype;
        // discarding the result would let a failing `with_children` go unnoticed.
        let replaced = expr
            .with_children(vec![get_item("a", root())])
            .unwrap();
        assert_eq!(replaced.to_string(), "cast($.a, bool?)");
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value, i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($, bool?)");
    }
}