vortex_expr/exprs/
cast.rs1use vortex_array::compute::cast as compute_cast;
5use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
6use vortex_dtype::DType;
7use vortex_error::{VortexResult, vortex_bail, vortex_err};
8use vortex_proto::expr as pb;
9
10use crate::display::{DisplayAs, DisplayFormat};
11use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
12
// Declares the vtable plumbing for the `Cast` expression kind.
// NOTE(review): `CastVTable`, implemented below, is presumably generated by this macro — confirm.
vtable!(Cast);

/// An expression that casts the result of its child expression to a fixed target [`DType`].
// `PartialEq` is hand-written but compares exactly the two fields below (the same comparison a
// derive would make), so the derived `Hash` stays consistent with equality; hence the allow.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct CastExpr {
    /// Dtype that the child's result is cast to.
    target: DType,
    /// Expression whose result is cast.
    child: ExprRef,
}
21
22impl PartialEq for CastExpr {
23 fn eq(&self, other: &Self) -> bool {
24 self.target == other.target && self.child.eq(&other.child)
25 }
26}
27
/// Stateless encoding for [`CastExpr`]; the same unit value is handed out for every cast
/// expression (see `CastVTable::encoding`).
pub struct CastExprEncoding;
29
30impl VTable for CastVTable {
31 type Expr = CastExpr;
32 type Encoding = CastExprEncoding;
33 type Metadata = ProstMetadata<pb::CastOpts>;
34
35 fn id(_encoding: &Self::Encoding) -> ExprId {
36 ExprId::new_ref("cast")
37 }
38
39 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
40 ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
41 }
42
43 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
44 Some(ProstMetadata(pb::CastOpts {
45 target: Some((&expr.target).into()),
46 }))
47 }
48
49 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
50 vec![&expr.child]
51 }
52
53 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
54 Ok(CastExpr {
55 target: expr.target.clone(),
56 child: children[0].clone(),
57 })
58 }
59
60 fn build(
61 _encoding: &Self::Encoding,
62 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
63 children: Vec<ExprRef>,
64 ) -> VortexResult<Self::Expr> {
65 if children.len() != 1 {
66 vortex_bail!(
67 "Cast expression must have exactly 1 child, got {}",
68 children.len()
69 );
70 }
71 let target: DType = metadata
72 .target
73 .as_ref()
74 .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
75 .try_into()?;
76 Ok(CastExpr {
77 target,
78 child: children[0].clone(),
79 })
80 }
81
82 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
83 let array = expr.child.evaluate(scope)?;
84 compute_cast(&array, &expr.target).map_err(|e| {
85 e.with_context(format!(
86 "Failed to cast array of dtype {} to {}",
87 array.dtype(),
88 expr.target
89 ))
90 })
91 }
92
93 fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
94 Ok(expr.target.clone())
95 }
96}
97
98impl CastExpr {
99 pub fn new(child: ExprRef, target: DType) -> Self {
100 Self { target, child }
101 }
102
103 pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
104 Self::new(child, target).into_expr()
105 }
106}
107
108impl DisplayAs for CastExpr {
109 fn fmt_as(&self, df: DisplayFormat, f: &mut std::fmt::Formatter) -> std::fmt::Result {
110 match df {
111 DisplayFormat::Compact => {
112 write!(f, "cast({}, {})", self.child, self.target)
113 }
114 DisplayFormat::Tree => {
115 write!(f, "Cast(target: {})", self.target)
116 }
117 }
118 }
119}
120
// Casts add no analysis of their own; every `AnalysisExpr` method uses the trait's default.
impl AnalysisExpr for CastExpr {}
122
123pub fn cast(child: ExprRef, target: DType) -> ExprRef {
133 CastExpr::new(child, target).into_expr()
134}
135
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    /// Swapping the child must keep the target dtype and actually install the new child.
    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = expr
            .with_children(vec![get_item("a", root())])
            .expect("replacing the single child of a cast must succeed");
        assert_eq!(replaced.to_string(), "cast($.a, bool?)");
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
        // Casting must not change the number of rows.
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_display() {
        let expr = cast(
            get_item("value", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        assert_eq!(expr.to_string(), "cast($.value, i64)");

        let expr2 = cast(root(), DType::Bool(Nullability::Nullable));
        assert_eq!(expr2.to_string(), "cast($, bool?)");
    }
}
194}