vortex_compute/cast/bool.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::One;
5use num_traits::Zero;
6use vortex_buffer::Buffer;
7use vortex_dtype::DType;
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_vector::Scalar;
12use vortex_vector::ScalarOps;
13use vortex_vector::Vector;
14use vortex_vector::VectorOps;
15use vortex_vector::bool::BoolScalar;
16use vortex_vector::bool::BoolVector;
17use vortex_vector::primitive::PScalar;
18use vortex_vector::primitive::PVector;
19
20use crate::cast::Cast;
21use crate::cast::try_cast_scalar_common;
22use crate::cast::try_cast_vector_common;
23
24impl Cast for BoolVector {
25    type Output = Vector;
26
27    /// Casts to Bool (identity) or Primitive (as 0/1).
28    fn cast(&self, target_dtype: &DType) -> VortexResult<Vector> {
29        if let Some(result) = try_cast_vector_common(self, target_dtype)? {
30            return Ok(result);
31        }
32
33        match target_dtype {
34            DType::Bool(n) if n.is_nullable() || self.validity().all_true() => {
35                Ok(self.clone().into())
36            }
37            DType::Primitive(ptype, n) if n.is_nullable() || self.validity().all_true() => {
38                match_each_native_ptype!(ptype, |T| {
39                    Ok(PVector::<T>::new(
40                        Buffer::<T>::from_trusted_len_iter(
41                            self.bits()
42                                .iter()
43                                .map(|b| if b { T::one() } else { T::zero() }),
44                        ),
45                        self.validity().clone(),
46                    )
47                    .into())
48                })
49            }
50            _ => {
51                vortex_bail!("Cannot cast BoolVector to {}", target_dtype);
52            }
53        }
54    }
55}
56
57impl Cast for BoolScalar {
58    type Output = Scalar;
59
60    /// Casts to Bool (identity) or Primitive (as 0/1).
61    fn cast(&self, target_dtype: &DType) -> VortexResult<Scalar> {
62        if let Some(result) = try_cast_scalar_common(self, target_dtype)? {
63            return Ok(result);
64        }
65        match target_dtype {
66            DType::Bool(n) if n.is_nullable() || self.is_valid() => Ok(self.clone().into()),
67            DType::Primitive(ptype, n) if n.is_nullable() || self.is_valid() => {
68                match_each_native_ptype!(ptype, |T| {
69                    let value = self.value().map(|b| if b { T::one() } else { T::zero() });
70                    Ok(PScalar::<T>::new(value).into())
71                })
72            }
73            _ => {
74                vortex_bail!("Cannot cast BoolScalar to {}", target_dtype);
75            }
76        }
77    }
78}