vortex_expr/exprs/
cast.rs1use std::fmt::Display;
5
6use vortex_array::compute::cast as compute_cast;
7use vortex_array::{ArrayRef, DeserializeMetadata, ProstMetadata};
8use vortex_dtype::DType;
9use vortex_error::{VortexResult, vortex_bail, vortex_err};
10use vortex_proto::expr as pb;
11
12use crate::{AnalysisExpr, ExprEncodingRef, ExprId, ExprRef, IntoExpr, Scope, VTable, vtable};
13
// NOTE(review): presumably generates the `CastVTable` type that the `VTable` impl in this
// file is written against (it is not declared anywhere else in this file) — confirm.
vtable!(Cast);
15
/// Expression that evaluates its `child` and casts the resulting array to `target`.
// `Hash` is derived while `PartialEq` is implemented by hand further down, which is what
// `clippy::derived_hash_with_manual_eq` warns about.
// NOTE(review): the manual `PartialEq` compares the same two fields a derive would; confirm
// why it is hand-written (possibly how `ExprRef` equality resolves) before deriving it.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash)]
pub struct CastExpr {
    /// Dtype the child's result is cast to; also the expression's return dtype.
    target: DType,
    /// Expression producing the array that is cast.
    child: ExprRef,
}
22
23impl PartialEq for CastExpr {
24 fn eq(&self, other: &Self) -> bool {
25 self.target == other.target && self.child.eq(&other.child)
26 }
27}
28
/// Encoding handle for cast expressions; its [`ExprId`] is `"cast"`.
pub struct CastExprEncoding;
30
31impl VTable for CastVTable {
32 type Expr = CastExpr;
33 type Encoding = CastExprEncoding;
34 type Metadata = ProstMetadata<pb::CastOpts>;
35
36 fn id(_encoding: &Self::Encoding) -> ExprId {
37 ExprId::new_ref("cast")
38 }
39
40 fn encoding(_expr: &Self::Expr) -> ExprEncodingRef {
41 ExprEncodingRef::new_ref(CastExprEncoding.as_ref())
42 }
43
44 fn metadata(expr: &Self::Expr) -> Option<Self::Metadata> {
45 Some(ProstMetadata(pb::CastOpts {
46 target: Some((&expr.target).into()),
47 }))
48 }
49
50 fn children(expr: &Self::Expr) -> Vec<&ExprRef> {
51 vec![&expr.child]
52 }
53
54 fn with_children(expr: &Self::Expr, children: Vec<ExprRef>) -> VortexResult<Self::Expr> {
55 Ok(CastExpr {
56 target: expr.target.clone(),
57 child: children[0].clone(),
58 })
59 }
60
61 fn build(
62 _encoding: &Self::Encoding,
63 metadata: &<Self::Metadata as DeserializeMetadata>::Output,
64 children: Vec<ExprRef>,
65 ) -> VortexResult<Self::Expr> {
66 if children.len() != 1 {
67 vortex_bail!(
68 "Cast expression must have exactly 1 child, got {}",
69 children.len()
70 );
71 }
72 let target: DType = metadata
73 .target
74 .as_ref()
75 .ok_or_else(|| vortex_err!("missing target dtype in CastOpts"))?
76 .try_into()?;
77 Ok(CastExpr {
78 target,
79 child: children[0].clone(),
80 })
81 }
82
83 fn evaluate(expr: &Self::Expr, scope: &Scope) -> VortexResult<ArrayRef> {
84 let array = expr.child.evaluate(scope)?;
85 compute_cast(&array, &expr.target)
86 }
87
88 fn return_dtype(expr: &Self::Expr, _scope: &DType) -> VortexResult<DType> {
89 Ok(expr.target.clone())
90 }
91}
92
93impl CastExpr {
94 pub fn new(child: ExprRef, target: DType) -> Self {
95 Self { target, child }
96 }
97
98 pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
99 Self::new(child, target).into_expr()
100 }
101}
102
103impl Display for CastExpr {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 write!(f, "cast({}, {})", self.child, self.target)
106 }
107}
108
// Empty impl: cast relies entirely on the trait's default `AnalysisExpr` behaviour.
impl AnalysisExpr for CastExpr {}
110
111pub fn cast(child: ExprRef, target: DType) -> ExprRef {
112 CastExpr::new(child, target).into_expr()
113}
114
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, cast, get_item, root, test_harness};

    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&dtype)
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // Replacing the child must succeed and must keep the original cast target.
        let replaced = expr.with_children(vec![root()]).unwrap();
        assert_eq!(
            replaced
                .return_dtype(&test_harness::struct_dtype())
                .unwrap(),
            DType::Bool(Nullability::Nullable)
        );
    }

    #[test]
    fn replace_children_rejects_wrong_arity() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        // A cast has exactly one child; zero or several must be an error, not a panic.
        assert!(expr.clone().with_children(vec![]).is_err());
        assert!(expr.with_children(vec![root(), root()]).is_err());
    }

    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        // The cast must preserve the row count while changing the dtype.
        assert_eq!(result.len(), 3);
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }
}