vortex_array/scalar_fn/fns/cast/
mod.rs1mod kernel;
5
6use std::fmt::Formatter;
7
8pub use kernel::*;
9use prost::Message;
10use vortex_error::VortexResult;
11use vortex_error::vortex_bail;
12use vortex_error::vortex_err;
13use vortex_proto::expr as pb;
14use vortex_session::VortexSession;
15use vortex_session::registry::CachedId;
16
17use crate::AnyColumnar;
18use crate::ArrayRef;
19use crate::ArrayView;
20use crate::CanonicalView;
21use crate::ColumnarView;
22use crate::ExecutionCtx;
23use crate::arrays::Bool;
24use crate::arrays::Constant;
25use crate::arrays::Decimal;
26use crate::arrays::Extension;
27use crate::arrays::FixedSizeList;
28use crate::arrays::ListView;
29use crate::arrays::Null;
30use crate::arrays::Primitive;
31use crate::arrays::VarBinView;
32use crate::arrays::struct_::compute::cast::struct_cast;
33use crate::builtins::ArrayBuiltins;
34use crate::dtype::DType;
35use crate::expr::StatsCatalog;
36use crate::expr::cast;
37use crate::expr::expression::Expression;
38use crate::expr::lit;
39use crate::expr::stats::Stat;
40use crate::scalar_fn::Arity;
41use crate::scalar_fn::ChildName;
42use crate::scalar_fn::ExecutionArgs;
43use crate::scalar_fn::ReduceCtx;
44use crate::scalar_fn::ReduceNode;
45use crate::scalar_fn::ReduceNodeRef;
46use crate::scalar_fn::ScalarFnId;
47use crate::scalar_fn::ScalarFnVTable;
48
/// Scalar function that casts its single input to a target [`DType`].
///
/// The type carries no state of its own; the target dtype is supplied as the function's options.
#[derive(Clone, Debug)]
pub struct Cast;
52
53impl ScalarFnVTable for Cast {
54 type Options = DType;
55
56 fn id(&self) -> ScalarFnId {
57 static ID: CachedId = CachedId::new("vortex.cast");
58 *ID
59 }
60
61 fn serialize(&self, dtype: &DType) -> VortexResult<Option<Vec<u8>>> {
62 Ok(Some(
63 pb::CastOpts {
64 target: Some(dtype.try_into()?),
65 }
66 .encode_to_vec(),
67 ))
68 }
69
70 fn deserialize(
71 &self,
72 _metadata: &[u8],
73 session: &VortexSession,
74 ) -> VortexResult<Self::Options> {
75 let proto = pb::CastOpts::decode(_metadata)?.target;
76 DType::from_proto(
77 proto
78 .as_ref()
79 .ok_or_else(|| vortex_err!("Missing target dtype in Cast expression"))?,
80 session,
81 )
82 }
83
84 fn arity(&self, _options: &DType) -> Arity {
85 Arity::Exact(1)
86 }
87
88 fn child_name(&self, _instance: &DType, child_idx: usize) -> ChildName {
89 match child_idx {
90 0 => ChildName::from("input"),
91 _ => unreachable!("Invalid child index {} for Cast expression", child_idx),
92 }
93 }
94
95 fn fmt_sql(&self, dtype: &DType, expr: &Expression, f: &mut Formatter<'_>) -> std::fmt::Result {
96 write!(f, "cast(")?;
97 expr.children()[0].fmt_sql(f)?;
98 write!(f, " as {}", dtype)?;
99 write!(f, ")")
100 }
101
102 fn return_dtype(&self, dtype: &DType, _arg_dtypes: &[DType]) -> VortexResult<DType> {
103 Ok(dtype.clone())
104 }
105
106 fn execute(
107 &self,
108 target_dtype: &DType,
109 args: &dyn ExecutionArgs,
110 ctx: &mut ExecutionCtx,
111 ) -> VortexResult<ArrayRef> {
112 let input = args.get(0)?;
113
114 let Some(columnar) = input.as_opt::<AnyColumnar>() else {
115 return input.execute::<ArrayRef>(ctx)?.cast(target_dtype.clone());
116 };
117
118 match columnar {
119 ColumnarView::Canonical(canonical) => {
120 match cast_canonical(canonical, target_dtype, ctx)? {
121 Some(result) => Ok(result),
122 None => vortex_bail!(
123 "No CastKernel to cast canonical array {} from {} to {}",
124 canonical.to_array_ref().encoding_id(),
125 canonical.to_array_ref().dtype(),
126 target_dtype,
127 ),
128 }
129 }
130 ColumnarView::Constant(constant) => match cast_constant(constant, target_dtype)? {
131 Some(result) => Ok(result),
132 None => vortex_bail!(
133 "No CastReduce to cast constant array from {} to {}",
134 constant.dtype(),
135 target_dtype,
136 ),
137 },
138 }
139 }
140
141 fn reduce(
142 &self,
143 target_dtype: &DType,
144 node: &dyn ReduceNode,
145 _ctx: &dyn ReduceCtx,
146 ) -> VortexResult<Option<ReduceNodeRef>> {
147 let child = node.child(0);
149 if &child.node_dtype()? == target_dtype {
150 return Ok(Some(child));
151 }
152 Ok(None)
153 }
154
155 fn stat_expression(
156 &self,
157 dtype: &DType,
158 expr: &Expression,
159 stat: Stat,
160 catalog: &dyn StatsCatalog,
161 ) -> Option<Expression> {
162 match stat {
163 Stat::IsConstant
164 | Stat::IsSorted
165 | Stat::IsStrictSorted
166 | Stat::NaNCount
167 | Stat::Sum
168 | Stat::UncompressedSizeInBytes => expr.child(0).stat_expression(stat, catalog),
169 Stat::Max | Stat::Min => {
170 expr.child(0)
172 .stat_expression(stat, catalog)
173 .map(|x| cast(x, dtype.clone()))
174 }
175 Stat::NullCount => {
176 None
184 }
185 }
186 }
187
188 fn validity(&self, dtype: &DType, expression: &Expression) -> VortexResult<Option<Expression>> {
189 Ok(Some(if dtype.is_nullable() {
190 expression.child(0).validity()?
191 } else {
192 lit(true)
193 }))
194 }
195
196 fn is_null_sensitive(&self, _instance: &DType) -> bool {
198 true
199 }
200}
201
202fn cast_canonical(
211 canonical: CanonicalView<'_>,
212 dtype: &DType,
213 ctx: &mut ExecutionCtx,
214) -> VortexResult<Option<ArrayRef>> {
215 match canonical {
216 CanonicalView::Null(a) => <Null as CastReduce>::cast(a, dtype),
217 CanonicalView::Bool(a) => <Bool as CastKernel>::cast(a, dtype, ctx),
218 CanonicalView::Primitive(a) => <Primitive as CastKernel>::cast(a, dtype, ctx),
219 CanonicalView::Decimal(a) => <Decimal as CastKernel>::cast(a, dtype, ctx),
220 CanonicalView::VarBinView(a) => <VarBinView as CastKernel>::cast(a, dtype, ctx),
221 CanonicalView::List(a) => <ListView as CastKernel>::cast(a, dtype, ctx),
222 CanonicalView::FixedSizeList(a) => <FixedSizeList as CastKernel>::cast(a, dtype, ctx),
223 CanonicalView::Struct(a) => struct_cast(a, dtype, ctx),
224 CanonicalView::Extension(a) => <Extension as CastReduce>::cast(a, dtype),
225 CanonicalView::Variant(_) => {
226 vortex_bail!("Variant arrays don't support casting")
227 }
228 }
229}
230
/// Casts a constant array to `dtype` by delegating to the `CastReduce` implementation for
/// `Constant`, returning that call's result unchanged.
231fn cast_constant(array: ArrayView<Constant>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
233 <Constant as CastReduce>::cast(array, dtype)
234}
235
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;

    use crate::IntoArray;
    use crate::arrays::StructArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::Expression;
    use crate::expr::cast;
    use crate::expr::get_item;
    use crate::expr::root;
    use crate::expr::test_harness;

    /// The return dtype of a cast is its target dtype.
    #[test]
    fn dtype() {
        let target = DType::Bool(Nullability::NonNullable);
        let scope = test_harness::struct_dtype();

        let returned = cast(root(), target.clone())
            .return_dtype(&scope)
            .unwrap();

        assert_eq!(returned, target);
    }

    /// Replacing the single child of a cast succeeds.
    #[test]
    fn replace_children() {
        let expr = cast(root(), DType::Bool(Nullability::Nullable));
        let replaced = expr.with_children(vec![root()]);
        replaced.vortex_expect("operation should succeed in test");
    }

    /// Applying an i32 -> i64 cast to a struct field yields an array of the target dtype.
    #[test]
    fn evaluate() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);

        let field_a = buffer![0i32, 1, 2].into_array();
        let field_b = buffer![4i64, 5, 6].into_array();
        let test_array = StructArray::from_fields(&[("a", field_a), ("b", field_b)])
            .unwrap()
            .into_array();

        let expr: Expression = cast(get_item("a", root()), target.clone());
        let result = test_array.apply(&expr).unwrap();

        assert_eq!(result.dtype(), &target);
    }

    /// Cast expressions render as `cast(<child> as <dtype>)`.
    #[test]
    fn test_display() {
        let cases = [
            (
                cast(
                    get_item("value", root()),
                    DType::Primitive(PType::I64, Nullability::NonNullable),
                ),
                "cast($.value as i64)",
            ),
            (
                cast(root(), DType::Bool(Nullability::Nullable)),
                "cast($ as bool?)",
            ),
        ];

        for (expr, expected) in cases {
            assert_eq!(expr.to_string(), expected);
        }
    }
}