vortex_scalar/
bool.rs

1use std::cmp::Ordering;
2
3use vortex_dtype::Nullability::NonNullable;
4use vortex_dtype::{DType, Nullability};
5use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err};
6
7use crate::{InnerScalarValue, Scalar, ScalarValue};
8
9#[derive(Debug, Hash)]
10pub struct BoolScalar<'a> {
11    dtype: &'a DType,
12    value: Option<bool>,
13}
14
15impl PartialEq for BoolScalar<'_> {
16    fn eq(&self, other: &Self) -> bool {
17        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
18    }
19}
20
21impl Eq for BoolScalar<'_> {}
22
23impl PartialOrd for BoolScalar<'_> {
24    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
25        Some(self.value.cmp(&other.value))
26    }
27}
28
29impl Ord for BoolScalar<'_> {
30    fn cmp(&self, other: &Self) -> Ordering {
31        self.value.cmp(&other.value)
32    }
33}
34
35impl<'a> BoolScalar<'a> {
36    #[inline]
37    pub fn dtype(&self) -> &'a DType {
38        self.dtype
39    }
40
41    pub fn value(&self) -> Option<bool> {
42        self.value
43    }
44
45    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
46        if !matches!(dtype, DType::Bool(..)) {
47            vortex_bail!("Can't cast bool to {}", dtype)
48        }
49        Ok(Scalar::bool(
50            self.value.vortex_expect("nullness handled in Scalar::cast"),
51            dtype.nullability(),
52        ))
53    }
54
55    pub fn invert(self) -> BoolScalar<'a> {
56        BoolScalar {
57            dtype: self.dtype,
58            value: self.value.map(|v| !v),
59        }
60    }
61
62    pub fn into_scalar(self) -> Scalar {
63        Scalar {
64            dtype: self.dtype.clone(),
65            value: self
66                .value
67                .map(|x| ScalarValue(InnerScalarValue::Bool(x)))
68                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
69        }
70    }
71}
72
73impl Scalar {
74    pub fn bool(value: bool, nullability: Nullability) -> Self {
75        Self {
76            dtype: DType::Bool(nullability),
77            value: ScalarValue(InnerScalarValue::Bool(value)),
78        }
79    }
80}
81
82impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> {
83    type Error = VortexError;
84
85    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
86        if !matches!(value.dtype(), DType::Bool(_)) {
87            vortex_bail!("Expected bool scalar, found {}", value.dtype())
88        }
89        Ok(Self {
90            dtype: value.dtype(),
91            value: value.value.as_bool()?,
92        })
93    }
94}
95
96impl TryFrom<&Scalar> for bool {
97    type Error = VortexError;
98
99    fn try_from(value: &Scalar) -> VortexResult<Self> {
100        <Option<bool>>::try_from(value)?
101            .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
102    }
103}
104
105impl TryFrom<&Scalar> for Option<bool> {
106    type Error = VortexError;
107
108    fn try_from(value: &Scalar) -> VortexResult<Self> {
109        Ok(BoolScalar::try_from(value)?.value())
110    }
111}
112
113impl TryFrom<Scalar> for bool {
114    type Error = VortexError;
115
116    fn try_from(value: Scalar) -> VortexResult<Self> {
117        Self::try_from(&value)
118    }
119}
120
121impl TryFrom<Scalar> for Option<bool> {
122    type Error = VortexError;
123
124    fn try_from(value: Scalar) -> VortexResult<Self> {
125        Self::try_from(&value)
126    }
127}
128
129impl From<bool> for Scalar {
130    fn from(value: bool) -> Self {
131        Self {
132            dtype: DType::Bool(NonNullable),
133            value: value.into(),
134        }
135    }
136}
137
138impl From<bool> for ScalarValue {
139    fn from(value: bool) -> Self {
140        ScalarValue(InnerScalarValue::Bool(value))
141    }
142}
143
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::*;

    use super::*;

    #[test]
    fn into_from() {
        let scalar: Scalar = false.into();
        assert!(!bool::try_from(&scalar).unwrap());
    }

    #[test]
    fn equality() {
        assert_eq!(&Scalar::bool(true, Nullable), &Scalar::bool(true, Nullable));
        // Equality ignores nullability
        assert_eq!(
            &Scalar::bool(true, Nullable),
            &Scalar::bool(true, NonNullable)
        );
    }

    #[test]
    fn from_bool_is_non_nullable() {
        let scalar = Scalar::from(true);
        assert_eq!(scalar.dtype(), &DType::Bool(NonNullable));
        assert_eq!(<Option<bool>>::try_from(&scalar).unwrap(), Some(true));
    }

    #[test]
    fn invert_and_round_trip() {
        let scalar = Scalar::bool(true, Nullable);
        let inverted = BoolScalar::try_from(&scalar).unwrap().invert();
        assert_eq!(inverted.value(), Some(false));
        // Inverting keeps the original dtype, including its nullability.
        assert_eq!(inverted.dtype(), &DType::Bool(Nullable));

        let back = inverted.into_scalar();
        assert_eq!(back.dtype(), &DType::Bool(Nullable));
        assert!(!bool::try_from(&back).unwrap());
    }

    #[test]
    fn cast_changes_only_nullability() {
        let scalar = Scalar::bool(false, NonNullable);
        let cast = BoolScalar::try_from(&scalar)
            .unwrap()
            .cast(&DType::Bool(Nullable))
            .unwrap();
        assert_eq!(cast.dtype(), &DType::Bool(Nullable));
        assert!(!bool::try_from(&cast).unwrap());
    }
}