vortex_scalar/
bool.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::Ordering;
5use std::fmt::{Display, Formatter};
6
7use vortex_dtype::Nullability::NonNullable;
8use vortex_dtype::{DType, Nullability};
9use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err};
10
11use crate::{InnerScalarValue, Scalar, ScalarValue};
12
13/// A scalar value representing a boolean.
14///
15/// This type provides a view into a boolean scalar value, which can be either
16/// true, false, or null.
17#[derive(Debug, Clone, Hash, Eq)]
18pub struct BoolScalar<'a> {
19    dtype: &'a DType,
20    value: Option<bool>,
21}
22
23impl Display for BoolScalar<'_> {
24    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
25        match self.value {
26            None => write!(f, "null"),
27            Some(v) => write!(f, "{v}"),
28        }
29    }
30}
31
32impl PartialEq for BoolScalar<'_> {
33    fn eq(&self, other: &Self) -> bool {
34        self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
35    }
36}
37
38impl PartialOrd for BoolScalar<'_> {
39    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
40        Some(self.cmp(other))
41    }
42}
43
44impl Ord for BoolScalar<'_> {
45    fn cmp(&self, other: &Self) -> Ordering {
46        self.value.cmp(&other.value)
47    }
48}
49
50impl<'a> BoolScalar<'a> {
51    /// Returns the data type of this boolean scalar.
52    #[inline]
53    pub fn dtype(&self) -> &'a DType {
54        self.dtype
55    }
56
57    /// Returns the boolean value, or None if null.
58    pub fn value(&self) -> Option<bool> {
59        self.value
60    }
61
62    pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
63        if !matches!(dtype, DType::Bool(..)) {
64            vortex_bail!(
65                "Cannot cast bool to {dtype}: boolean scalars can only be cast to boolean types with different nullability"
66            )
67        }
68        Ok(Scalar::bool(
69            self.value.vortex_expect("nullness handled in Scalar::cast"),
70            dtype.nullability(),
71        ))
72    }
73
74    /// Returns a new boolean scalar with the inverted value.
75    ///
76    /// Null values remain null.
77    pub fn invert(self) -> BoolScalar<'a> {
78        BoolScalar {
79            dtype: self.dtype,
80            value: self.value.map(|v| !v),
81        }
82    }
83
84    /// Converts this boolean scalar into a general scalar.
85    pub fn into_scalar(self) -> Scalar {
86        Scalar::new(
87            self.dtype.clone(),
88            self.value
89                .map(|x| ScalarValue(InnerScalarValue::Bool(x)))
90                .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
91        )
92    }
93}
94
95impl Scalar {
96    /// Creates a new boolean scalar with the given value and nullability.
97    pub fn bool(value: bool, nullability: Nullability) -> Self {
98        Self::new(
99            DType::Bool(nullability),
100            ScalarValue(InnerScalarValue::Bool(value)),
101        )
102    }
103}
104
105impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> {
106    type Error = VortexError;
107
108    fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
109        if !matches!(value.dtype(), DType::Bool(_)) {
110            vortex_bail!("Expected bool scalar, found {}", value.dtype())
111        }
112        Ok(Self {
113            dtype: value.dtype(),
114            value: value.value().as_bool()?,
115        })
116    }
117}
118
119impl TryFrom<&Scalar> for bool {
120    type Error = VortexError;
121
122    fn try_from(value: &Scalar) -> VortexResult<Self> {
123        <Option<bool>>::try_from(value)?
124            .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
125    }
126}
127
128impl TryFrom<&Scalar> for Option<bool> {
129    type Error = VortexError;
130
131    fn try_from(value: &Scalar) -> VortexResult<Self> {
132        Ok(BoolScalar::try_from(value)?.value())
133    }
134}
135
136impl TryFrom<Scalar> for bool {
137    type Error = VortexError;
138
139    fn try_from(value: Scalar) -> VortexResult<Self> {
140        Self::try_from(&value)
141    }
142}
143
144impl TryFrom<Scalar> for Option<bool> {
145    type Error = VortexError;
146
147    fn try_from(value: Scalar) -> VortexResult<Self> {
148        Self::try_from(&value)
149    }
150}
151
152impl From<bool> for Scalar {
153    fn from(value: bool) -> Self {
154        Self::new(DType::Bool(NonNullable), value.into())
155    }
156}
157
158impl From<bool> for ScalarValue {
159    fn from(value: bool) -> Self {
160        ScalarValue(InnerScalarValue::Bool(value))
161    }
162}
163
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::*;

    use super::*;

    /// Borrows `scalar` as a [`BoolScalar`], panicking if it is not a boolean scalar.
    fn view(scalar: &Scalar) -> BoolScalar<'_> {
        BoolScalar::try_from(scalar).unwrap()
    }

    #[test]
    fn into_from() {
        let scalar = Scalar::from(false);
        assert!(!bool::try_from(&scalar).unwrap());
    }

    #[test]
    fn equality() {
        let nullable_true = Scalar::bool(true, Nullable);
        assert_eq!(&nullable_true, &Scalar::bool(true, Nullable));
        // Nullability does not participate in equality.
        assert_eq!(&nullable_true, &Scalar::bool(true, NonNullable));
    }

    #[test]
    fn test_bool_scalar_ordering() {
        let false_scalar = Scalar::bool(false, NonNullable);
        let true_scalar = Scalar::bool(true, NonNullable);
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        let (false_bool, true_bool, null_bool) =
            (view(&false_scalar), view(&true_scalar), view(&null_scalar));

        // false < true
        assert!(false_bool < true_bool);
        assert!(true_bool > false_bool);

        // Null sorts before every present value: None < Some(false) < Some(true).
        for present in [&false_bool, &true_bool] {
            assert!(null_bool < *present);
            assert!(*present > null_bool);
        }
    }

    #[test]
    fn test_bool_invert() {
        // (input, expected value after inversion); null stays null.
        let cases = [
            (Scalar::bool(true, NonNullable), Some(false)),
            (Scalar::bool(false, NonNullable), Some(true)),
            (Scalar::null(DType::Bool(Nullable)), None),
        ];
        for (scalar, expected) in &cases {
            assert_eq!(view(scalar).invert().value(), *expected);
        }
    }

    #[test]
    fn test_bool_into_scalar() {
        let non_nullable = DType::Bool(NonNullable);
        let scalar = BoolScalar {
            dtype: &non_nullable,
            value: Some(true),
        }
        .into_scalar();
        assert_eq!(scalar.dtype(), &DType::Bool(NonNullable));
        assert!(bool::try_from(&scalar).unwrap());

        // A missing value converts into a null scalar.
        let nullable = DType::Bool(Nullable);
        let null_scalar = BoolScalar {
            dtype: &nullable,
            value: None,
        }
        .into_scalar();
        assert!(null_scalar.is_null());
    }

    #[test]
    fn test_bool_cast_to_bool() {
        let source = Scalar::bool(true, NonNullable);
        let bs = view(&source);

        // Casting between boolean types only changes nullability.
        for target in [DType::Bool(Nullable), DType::Bool(NonNullable)] {
            let result = bs.cast(&target).unwrap();
            assert_eq!(result.dtype(), &target);
            assert!(bool::try_from(&result).unwrap());
        }
    }

    #[test]
    fn test_bool_cast_to_non_bool_fails() {
        use vortex_dtype::PType;

        let source = Scalar::bool(true, NonNullable);
        let target = DType::Primitive(PType::I32, NonNullable);
        assert!(view(&source).cast(&target).is_err());
    }

    #[test]
    fn test_try_from_non_bool_scalar() {
        let int_scalar = Scalar::primitive(42i32, NonNullable);
        assert!(BoolScalar::try_from(&int_scalar).is_err());
    }

    #[test]
    fn test_try_from_null_scalar() {
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        // A null has no present bool to extract...
        assert!(bool::try_from(&null_scalar).is_err());

        // ...but extracting an Option succeeds with None.
        let maybe = <Option<bool>>::try_from(&null_scalar);
        assert!(maybe.is_ok());
        assert_eq!(maybe.unwrap(), None);
    }

    #[test]
    fn test_try_from_owned_scalar() {
        // Owned Scalar -> bool
        let present = bool::try_from(Scalar::bool(true, NonNullable));
        assert!(present.is_ok());
        assert!(present.unwrap());

        // Owned Scalar -> Option<bool>
        let maybe = <Option<bool>>::try_from(Scalar::bool(false, Nullable));
        assert!(maybe.is_ok());
        assert_eq!(maybe.unwrap(), Some(false));

        // Owned null Scalar -> Option<bool>
        let maybe_null = <Option<bool>>::try_from(Scalar::null(DType::Bool(Nullable)));
        assert!(maybe_null.is_ok());
        assert_eq!(maybe_null.unwrap(), None);
    }

    #[test]
    fn test_scalar_value_from_bool() {
        for flag in [true, false] {
            let scalar = Scalar::new(DType::Bool(NonNullable), ScalarValue::from(flag));
            assert_eq!(bool::try_from(&scalar).unwrap(), flag);
        }
    }

    #[test]
    fn test_bool_partial_eq_different_values() {
        let true_scalar = Scalar::bool(true, NonNullable);
        let false_scalar = Scalar::bool(false, NonNullable);
        assert_ne!(view(&true_scalar), view(&false_scalar));
    }

    #[test]
    fn test_bool_partial_eq_null() {
        let first_null = Scalar::null(DType::Bool(Nullable));
        let second_null = Scalar::null(DType::Bool(Nullable));
        let present = Scalar::bool(true, Nullable);

        // Two nulls are equal...
        assert_eq!(view(&first_null), view(&second_null));
        // ...but a null never equals a present value.
        assert_ne!(view(&first_null), view(&present));
    }

    #[test]
    fn test_bool_value_accessor() {
        let cases = [
            (Scalar::bool(true, NonNullable), Some(true)),
            (Scalar::bool(false, NonNullable), Some(false)),
            (Scalar::null(DType::Bool(Nullable)), None),
        ];
        for (scalar, expected) in &cases {
            assert_eq!(view(scalar).value(), *expected);
        }
    }

    #[test]
    fn test_bool_dtype_accessor() {
        for (scalar, expected) in [
            (Scalar::bool(true, Nullable), DType::Bool(Nullable)),
            (Scalar::bool(false, NonNullable), DType::Bool(NonNullable)),
        ] {
            assert_eq!(view(&scalar).dtype(), &expected);
        }
    }

    #[test]
    fn test_bool_partial_cmp() {
        let false_scalar = Scalar::bool(false, NonNullable);
        let true_scalar = Scalar::bool(true, NonNullable);
        let (lo, hi) = (view(&false_scalar), view(&true_scalar));

        assert_eq!(lo.partial_cmp(&lo), Some(Ordering::Equal));
        assert_eq!(lo.partial_cmp(&hi), Some(Ordering::Less));
        assert_eq!(hi.partial_cmp(&lo), Some(Ordering::Greater));
    }
}