1use std::any::Any;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use vortex_array::ArrayRef;
6use vortex_array::compute::cast as compute_cast;
7use vortex_dtype::DType;
8use vortex_error::{VortexExpect, VortexResult};
9
10use crate::{AnalysisExpr, ExprRef, Scope, ScopeDType, VortexExpr};
11
/// Expression that evaluates `child` and casts the resulting array to the `target` [`DType`].
#[derive(Debug, Eq, Hash)]
// `PartialEq` is written by hand; it compares exactly the fields the derived `Hash` hashes
// (`target` and `child`), so values that compare equal also hash equally.
#[allow(clippy::derived_hash_with_manual_eq)]
pub struct Cast {
    /// Data type that the child's result is cast to; also the expression's return dtype.
    target: DType,
    /// Expression whose evaluated result is cast.
    child: ExprRef,
}
18
19impl Cast {
20 pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
21 Arc::new(Self { target, child })
22 }
23}
24
25impl PartialEq for Cast {
26 fn eq(&self, other: &Self) -> bool {
27 self.target.eq(&other.target) && self.child.eq(&other.child)
28 }
29}
30
31impl Display for Cast {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "cast({}, {})", self.child, self.target)
34 }
35}
36
#[cfg(feature = "proto")]
pub(crate) mod proto {
    use vortex_dtype::DType;
    use vortex_error::{VortexResult, vortex_bail, vortex_err};
    use vortex_proto::expr::kind;
    use vortex_proto::expr::kind::Kind;

    use crate::cast::Cast;
    use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};

    /// Serialization registration for [`Cast`] expressions, identified by `"cast"`.
    pub(crate) struct CastSerde;

    impl Id for CastSerde {
        fn id(&self) -> &'static str {
            "cast"
        }
    }

    impl ExprDeserialize for CastSerde {
        /// Rebuilds a [`Cast`] from its serialized `kind` and its deserialized `children`.
        ///
        /// # Errors
        ///
        /// Returns an error if `kind` is not a cast, if the target dtype is absent or fails
        /// to convert, or if `children` does not contain exactly one expression.
        fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
            let Kind::Cast(kind::Cast { target }) = kind else {
                vortex_bail!("wrong kind {:?}, want cast", kind)
            };
            let target: DType = target
                .as_ref()
                .ok_or_else(|| vortex_err!("empty target dtype"))?
                .try_into()?;

            // The child list comes from serialized input, so a malformed message must surface
            // as an error rather than an index-out-of-bounds panic (or silently dropped extras).
            let child = match <[ExprRef; 1]>::try_from(children) {
                Ok([child]) => child,
                Err(children) => {
                    vortex_bail!("cast expects exactly one child, got {}", children.len())
                }
            };

            Ok(Cast::new_expr(child, target))
        }
    }

    impl ExprSerializable for Cast {
        fn id(&self) -> &'static str {
            CastSerde.id()
        }

        /// Encodes the cast's target dtype; the child is serialized separately as a child node.
        fn serialize_kind(&self) -> VortexResult<Kind> {
            Ok(Kind::Cast(kind::Cast {
                target: Some((&self.target).into()),
            }))
        }
    }
}
81
// `Cast` overrides nothing here and relies on the provided `AnalysisExpr` method defaults.
impl AnalysisExpr for Cast {}
83
84impl VortexExpr for Cast {
85 fn as_any(&self) -> &dyn Any {
86 self
87 }
88
89 fn unchecked_evaluate(&self, scope: &Scope) -> VortexResult<ArrayRef> {
90 let array = self.child.evaluate(scope)?;
91 compute_cast(&array, &self.target)
92 }
93
94 fn children(&self) -> Vec<&ExprRef> {
95 vec![&self.child]
96 }
97
98 fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
99 Self::new_expr(
100 children
101 .pop()
102 .vortex_expect("Cast::replacing_children should have one child"),
103 self.target.clone(),
104 )
105 }
106
107 fn return_dtype(&self, _scope_dtype: &ScopeDType) -> VortexResult<DType> {
108 Ok(self.target.clone())
109 }
110}
111
112pub fn cast(child: ExprRef, target: DType) -> ExprRef {
113 Cast::new_expr(child, target)
114}
115
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::arrays::StructArray;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::{ExprRef, Scope, ScopeDType, cast, get_item, root, test_harness};

    /// The return dtype of a cast is its target, regardless of the input scope dtype.
    #[test]
    fn dtype() {
        let dtype = test_harness::struct_dtype();
        assert_eq!(
            cast(root(), DType::Bool(Nullability::NonNullable))
                .return_dtype(&ScopeDType::new(dtype))
                .unwrap(),
            DType::Bool(Nullability::NonNullable)
        );
    }

    /// Replacing the child swaps in the new child and keeps the original target dtype.
    #[test]
    fn replace_children() {
        let target = DType::Bool(Nullability::Nullable);
        let expr = cast(root(), target.clone());

        let replaced = expr.clone().replacing_children(vec![get_item("a", root())]);

        let expected = cast(get_item("a", root()), target);
        assert!(replaced.eq(&expected));
        assert!(!replaced.eq(&expr));
    }

    /// Casting an i32 column to i64 yields an i64 array of the same length.
    #[test]
    fn evaluate() {
        let test_array = StructArray::from_fields(&[
            ("a", buffer![0i32, 1, 2].into_array()),
            ("b", buffer![4i64, 5, 6].into_array()),
        ])
        .unwrap()
        .into_array();

        let expr: ExprRef = cast(
            get_item("a", root()),
            DType::Primitive(PType::I64, Nullability::NonNullable),
        );
        let result = expr.evaluate(&Scope::new(test_array)).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(
            result.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );
    }
}