vortex_scalar/
bool.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_dtype::Nullability::NonNullable;
11use vortex_error::VortexError;
12use vortex_error::VortexExpect as _;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16
17use crate::InnerScalarValue;
18use crate::Scalar;
19use crate::ScalarValue;
20
21/// A scalar value representing a boolean.
22///
23/// This type provides a view into a boolean scalar value, which can be either
24/// true, false, or null.
25#[derive(Debug, Clone, Hash, Eq)]
26pub struct BoolScalar<'a> {
27    dtype: &'a DType,
28    value: Option<bool>,
29}
30
31impl Display for BoolScalar<'_> {
32    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
33        match self.value {
34            None => write!(f, "null"),
35            Some(v) => write!(f, "{v}"),
36        }
37    }
38}
39
40impl PartialEq for BoolScalar<'_> {
41    fn eq(&self, other: &Self) -> bool {
42        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
43    }
44}
45
46impl PartialOrd for BoolScalar<'_> {
47    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48        Some(self.cmp(other))
49    }
50}
51
52impl Ord for BoolScalar<'_> {
53    fn cmp(&self, other: &Self) -> Ordering {
54        self.value.cmp(&other.value)
55    }
56}
57
58impl<'a> BoolScalar<'a> {
59    /// Returns the data type of this boolean scalar.
60    #[inline]
61    pub fn dtype(&self) -> &'a DType {
62        self.dtype
63    }
64
65    /// Returns the boolean value, or None if null.
66    pub fn value(&self) -> Option<bool> {
67        self.value
68    }
69
70    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
71        if !matches!(dtype, DType::Bool(..)) {
72            vortex_bail!(
73                "Cannot cast bool to {dtype}: boolean scalars can only be cast to boolean types with different nullability"
74            )
75        }
76        Ok(Scalar::bool(
77            self.value.vortex_expect("nullness handled in Scalar::cast"),
78            dtype.nullability(),
79        ))
80    }
81
82    /// Returns a new boolean scalar with the inverted value.
83    ///
84    /// Null values remain null.
85    pub fn invert(self) -> BoolScalar<'a> {
86        BoolScalar {
87            dtype: self.dtype,
88            value: self.value.map(|v| !v),
89        }
90    }
91
92    /// Converts this boolean scalar into a general scalar.
93    pub fn into_scalar(self) -> Scalar {
94        Scalar::new(
95            self.dtype.clone(),
96            self.value
97                .map(|x| ScalarValue(InnerScalarValue::Bool(x)))
98                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
99        )
100    }
101}
102
103impl Scalar {
104    /// Creates a new boolean scalar with the given value and nullability.
105    pub fn bool(value: bool, nullability: Nullability) -> Self {
106        Self::new(
107            DType::Bool(nullability),
108            ScalarValue(InnerScalarValue::Bool(value)),
109        )
110    }
111}
112
113impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> {
114    type Error = VortexError;
115
116    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
117        if !matches!(value.dtype(), DType::Bool(_)) {
118            vortex_bail!("Expected bool scalar, found {}", value.dtype())
119        }
120        Ok(Self {
121            dtype: value.dtype(),
122            value: value.value().as_bool()?,
123        })
124    }
125}
126
127impl TryFrom<&Scalar> for bool {
128    type Error = VortexError;
129
130    fn try_from(value: &Scalar) -> VortexResult<Self> {
131        <Option<bool>>::try_from(value)?
132            .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
133    }
134}
135
136impl TryFrom<&Scalar> for Option<bool> {
137    type Error = VortexError;
138
139    fn try_from(value: &Scalar) -> VortexResult<Self> {
140        Ok(BoolScalar::try_from(value)?.value())
141    }
142}
143
144impl TryFrom<Scalar> for bool {
145    type Error = VortexError;
146
147    fn try_from(value: Scalar) -> VortexResult<Self> {
148        Self::try_from(&value)
149    }
150}
151
152impl TryFrom<Scalar> for Option<bool> {
153    type Error = VortexError;
154
155    fn try_from(value: Scalar) -> VortexResult<Self> {
156        Self::try_from(&value)
157    }
158}
159
160impl From<bool> for Scalar {
161    fn from(value: bool) -> Self {
162        Self::new(DType::Bool(NonNullable), value.into())
163    }
164}
165
166impl From<bool> for ScalarValue {
167    fn from(value: bool) -> Self {
168        ScalarValue(InnerScalarValue::Bool(value))
169    }
170}
171
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::*;

    use super::*;

    /// Borrows `scalar` as a [`BoolScalar`], panicking if the conversion fails.
    fn view(scalar: &Scalar) -> BoolScalar<'_> {
        BoolScalar::try_from(scalar).unwrap()
    }

    /// Builds a null scalar with a nullable boolean dtype.
    fn null_bool() -> Scalar {
        Scalar::null(DType::Bool(Nullable))
    }

    #[test]
    fn into_from() {
        let scalar = Scalar::from(false);
        let round_tripped = bool::try_from(&scalar).unwrap();
        assert!(!round_tripped);
    }

    #[test]
    fn equality() {
        let nullable_true = Scalar::bool(true, Nullable);
        assert_eq!(&nullable_true, &Scalar::bool(true, Nullable));
        // Nullability does not take part in equality.
        assert_eq!(&nullable_true, &Scalar::bool(true, NonNullable));
    }

    #[test]
    fn test_bool_scalar_ordering() {
        let low = Scalar::bool(false, NonNullable);
        let high = Scalar::bool(true, NonNullable);
        let missing = null_bool();

        let (f, t, n) = (view(&low), view(&high), view(&missing));

        // false sorts below true.
        assert!(f < t);
        assert!(t > f);

        // Null sorts below both present values.
        assert!(n < f);
        assert!(n < t);
        assert!(f > n);
        assert!(t > n);
    }

    #[test]
    fn test_bool_invert() {
        let yes = Scalar::bool(true, NonNullable);
        let no = Scalar::bool(false, NonNullable);
        let missing = null_bool();

        // true flips to false.
        assert_eq!(view(&yes).invert().value(), Some(false));

        // false flips to true.
        assert_eq!(view(&no).invert().value(), Some(true));

        // Null is left untouched.
        assert_eq!(view(&missing).invert().value(), None);
    }

    #[test]
    fn test_bool_into_scalar() {
        let non_nullable = DType::Bool(NonNullable);
        let present = BoolScalar {
            dtype: &non_nullable,
            value: Some(true),
        };

        let converted = present.into_scalar();
        assert_eq!(converted.dtype(), &DType::Bool(NonNullable));
        assert!(bool::try_from(&converted).unwrap());

        // A view without a value converts into a null scalar.
        let nullable = DType::Bool(Nullable);
        let absent = BoolScalar {
            dtype: &nullable,
            value: None,
        };

        assert!(absent.into_scalar().is_null());
    }

    #[test]
    fn test_bool_cast_to_bool() {
        let source = Scalar::bool(true, NonNullable);
        let source_view = view(&source);

        // Cast to the nullable type first, then to the non-nullable one.
        for nullability in [Nullable, NonNullable] {
            let target = DType::Bool(nullability);
            let casted = source_view.cast(&target).unwrap();
            assert_eq!(casted.dtype(), &target);
            assert!(bool::try_from(&casted).unwrap());
        }
    }

    #[test]
    fn test_bool_cast_to_non_bool_fails() {
        use vortex_dtype::PType;

        let source = Scalar::bool(true, NonNullable);
        let target = DType::Primitive(PType::I32, NonNullable);

        let outcome = view(&source).cast(&target);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_try_from_non_bool_scalar() {
        let int_scalar = Scalar::primitive(42i32, NonNullable);
        assert!(BoolScalar::try_from(&int_scalar).is_err());
    }

    #[test]
    fn test_try_from_null_scalar() {
        let missing = null_bool();

        // A plain bool cannot be produced from null.
        let as_bool = bool::try_from(&missing);
        assert!(as_bool.is_err());

        // An Option<bool> can: null maps to None.
        let as_option = Option::<bool>::try_from(&missing);
        assert!(as_option.is_ok());
        assert_eq!(as_option.unwrap(), None);
    }

    #[test]
    fn test_try_from_owned_scalar() {
        // Owned Scalar -> bool
        let as_bool = bool::try_from(Scalar::bool(true, NonNullable));
        assert!(as_bool.is_ok());
        assert!(as_bool.unwrap());

        // Owned Scalar -> Option<bool>
        let as_option = Option::<bool>::try_from(Scalar::bool(false, Nullable));
        assert!(as_option.is_ok());
        assert_eq!(as_option.unwrap(), Some(false));

        // Owned null Scalar -> Option<bool>
        let from_null = Option::<bool>::try_from(null_bool());
        assert!(from_null.is_ok());
        assert_eq!(from_null.unwrap(), None);
    }

    #[test]
    fn test_scalar_value_from_bool() {
        for expected in [true, false] {
            let payload = ScalarValue::from(expected);
            let scalar = Scalar::new(DType::Bool(NonNullable), payload);
            assert_eq!(bool::try_from(&scalar).unwrap(), expected);
        }
    }

    #[test]
    fn test_bool_partial_eq_different_values() {
        let yes = Scalar::bool(true, NonNullable);
        let no = Scalar::bool(false, NonNullable);

        assert_ne!(view(&yes), view(&no));
    }

    #[test]
    fn test_bool_partial_eq_null() {
        let first_null = null_bool();
        let second_null = null_bool();
        let present = Scalar::bool(true, Nullable);

        // Nulls compare equal to each other.
        assert_eq!(view(&first_null), view(&second_null));

        // A null never equals a present value.
        assert_ne!(view(&first_null), view(&present));
    }

    #[test]
    fn test_bool_value_accessor() {
        let yes = Scalar::bool(true, NonNullable);
        let no = Scalar::bool(false, NonNullable);
        let missing = null_bool();

        assert_eq!(view(&yes).value(), Some(true));
        assert_eq!(view(&no).value(), Some(false));
        assert_eq!(view(&missing).value(), None);
    }

    #[test]
    fn test_bool_dtype_accessor() {
        let nullable = Scalar::bool(true, Nullable);
        let non_nullable = Scalar::bool(false, NonNullable);

        assert_eq!(view(&nullable).dtype(), &DType::Bool(Nullable));
        assert_eq!(view(&non_nullable).dtype(), &DType::Bool(NonNullable));
    }

    #[test]
    fn test_bool_partial_cmp() {
        let low = Scalar::bool(false, NonNullable);
        let high = Scalar::bool(true, NonNullable);

        let (f, t) = (view(&low), view(&high));

        assert_eq!(f.partial_cmp(&f), Some(Ordering::Equal));
        assert_eq!(f.partial_cmp(&t), Some(Ordering::Less));
        assert_eq!(t.partial_cmp(&f), Some(Ordering::Greater));
    }
}