1use std::cmp::Ordering;
5use std::fmt::Display;
6use std::fmt::Formatter;
7
8use vortex_dtype::DType;
9use vortex_dtype::Nullability;
10use vortex_dtype::Nullability::NonNullable;
11use vortex_error::VortexError;
12use vortex_error::VortexExpect as _;
13use vortex_error::VortexResult;
14use vortex_error::vortex_bail;
15use vortex_error::vortex_err;
16
17use crate::InnerScalarValue;
18use crate::Scalar;
19use crate::ScalarValue;
20
21#[derive(Debug, Clone, Hash, Eq)]
26pub struct BoolScalar<'a> {
27 dtype: &'a DType,
28 value: Option<bool>,
29}
30
31impl Display for BoolScalar<'_> {
32 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
33 match self.value {
34 None => write!(f, "null"),
35 Some(v) => write!(f, "{v}"),
36 }
37 }
38}
39
40impl PartialEq for BoolScalar<'_> {
41 fn eq(&self, other: &Self) -> bool {
42 self.dtype.eq_ignore_nullability(other.dtype) && self.value == other.value
43 }
44}
45
46impl PartialOrd for BoolScalar<'_> {
47 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48 Some(self.cmp(other))
49 }
50}
51
52impl Ord for BoolScalar<'_> {
53 fn cmp(&self, other: &Self) -> Ordering {
54 self.value.cmp(&other.value)
55 }
56}
57
58impl<'a> BoolScalar<'a> {
59 #[inline]
61 pub fn dtype(&self) -> &'a DType {
62 self.dtype
63 }
64
65 pub fn value(&self) -> Option<bool> {
67 self.value
68 }
69
70 pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
71 if !matches!(dtype, DType::Bool(..)) {
72 vortex_bail!(
73 "Cannot cast bool to {dtype}: boolean scalars can only be cast to boolean types with different nullability"
74 )
75 }
76 Ok(Scalar::bool(
77 self.value.vortex_expect("nullness handled in Scalar::cast"),
78 dtype.nullability(),
79 ))
80 }
81
82 pub fn invert(self) -> BoolScalar<'a> {
86 BoolScalar {
87 dtype: self.dtype,
88 value: self.value.map(|v| !v),
89 }
90 }
91
92 pub fn into_scalar(self) -> Scalar {
94 Scalar::new(
95 self.dtype.clone(),
96 self.value
97 .map(|x| ScalarValue(InnerScalarValue::Bool(x)))
98 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)),
99 )
100 }
101}
102
103impl Scalar {
104 pub fn bool(value: bool, nullability: Nullability) -> Self {
106 Self::new(
107 DType::Bool(nullability),
108 ScalarValue(InnerScalarValue::Bool(value)),
109 )
110 }
111}
112
113impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> {
114 type Error = VortexError;
115
116 fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
117 if !matches!(value.dtype(), DType::Bool(_)) {
118 vortex_bail!("Expected bool scalar, found {}", value.dtype())
119 }
120 Ok(Self {
121 dtype: value.dtype(),
122 value: value.value().as_bool()?,
123 })
124 }
125}
126
127impl TryFrom<&Scalar> for bool {
128 type Error = VortexError;
129
130 fn try_from(value: &Scalar) -> VortexResult<Self> {
131 <Option<bool>>::try_from(value)?
132 .ok_or_else(|| vortex_err!("Can't extract present value from null scalar"))
133 }
134}
135
136impl TryFrom<&Scalar> for Option<bool> {
137 type Error = VortexError;
138
139 fn try_from(value: &Scalar) -> VortexResult<Self> {
140 Ok(BoolScalar::try_from(value)?.value())
141 }
142}
143
144impl TryFrom<Scalar> for bool {
145 type Error = VortexError;
146
147 fn try_from(value: Scalar) -> VortexResult<Self> {
148 Self::try_from(&value)
149 }
150}
151
152impl TryFrom<Scalar> for Option<bool> {
153 type Error = VortexError;
154
155 fn try_from(value: Scalar) -> VortexResult<Self> {
156 Self::try_from(&value)
157 }
158}
159
160impl From<bool> for Scalar {
161 fn from(value: bool) -> Self {
162 Self::new(DType::Bool(NonNullable), value.into())
163 }
164}
165
166impl From<bool> for ScalarValue {
167 fn from(value: bool) -> Self {
168 ScalarValue(InnerScalarValue::Bool(value))
169 }
170}
171
#[cfg(test)]
mod test {
    use vortex_dtype::Nullability::*;

    use super::*;

    /// `bool -> Scalar -> bool` round-trips the value.
    #[test]
    fn into_from() {
        let scalar: Scalar = false.into();
        assert!(!bool::try_from(&scalar).unwrap());
    }

    /// Scalars holding the same value compare equal, including across nullabilities.
    #[test]
    fn equality() {
        assert_eq!(&Scalar::bool(true, Nullable), &Scalar::bool(true, Nullable));
        assert_eq!(
            &Scalar::bool(true, Nullable),
            &Scalar::bool(true, NonNullable)
        );
    }

    /// Ordering is null < false < true.
    #[test]
    fn test_bool_scalar_ordering() {
        let false_scalar = Scalar::bool(false, NonNullable);
        let true_scalar = Scalar::bool(true, NonNullable);
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        let false_bool = BoolScalar::try_from(&false_scalar).unwrap();
        let true_bool = BoolScalar::try_from(&true_scalar).unwrap();
        let null_bool = BoolScalar::try_from(&null_scalar).unwrap();

        assert!(false_bool < true_bool);
        assert!(true_bool > false_bool);

        assert!(null_bool < false_bool);
        assert!(null_bool < true_bool);
        assert!(false_bool > null_bool);
        assert!(true_bool > null_bool);
    }

    /// Inversion flips present values and leaves null as null.
    #[test]
    fn test_bool_invert() {
        let true_scalar = Scalar::bool(true, NonNullable);
        let false_scalar = Scalar::bool(false, NonNullable);
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        let true_bool = BoolScalar::try_from(&true_scalar).unwrap();
        let false_bool = BoolScalar::try_from(&false_scalar).unwrap();
        let null_bool = BoolScalar::try_from(&null_scalar).unwrap();

        assert_eq!(true_bool.invert().value(), Some(false));
        assert_eq!(false_bool.invert().value(), Some(true));
        assert_eq!(null_bool.invert().value(), None);
    }

    /// `into_scalar` preserves dtype and value, and maps a missing value to null.
    #[test]
    fn test_bool_into_scalar() {
        let bool_scalar = BoolScalar {
            dtype: &DType::Bool(NonNullable),
            value: Some(true),
        };

        let scalar = bool_scalar.into_scalar();
        assert_eq!(scalar.dtype(), &DType::Bool(NonNullable));
        assert!(bool::try_from(&scalar).unwrap());

        let null_bool_scalar = BoolScalar {
            dtype: &DType::Bool(Nullable),
            value: None,
        };

        let null_scalar = null_bool_scalar.into_scalar();
        assert!(null_scalar.is_null());
    }

    /// Casting between boolean nullabilities keeps the value and adopts the target dtype.
    #[test]
    fn test_bool_cast_to_bool() {
        let bool_scalar = Scalar::bool(true, NonNullable);
        // Named `view` rather than `bool` so it does not shadow the primitive type
        // used by `bool::try_from` below.
        let view = BoolScalar::try_from(&bool_scalar).unwrap();

        let result = view.cast(&DType::Bool(Nullable)).unwrap();
        assert_eq!(result.dtype(), &DType::Bool(Nullable));
        assert!(bool::try_from(&result).unwrap());

        let result = view.cast(&DType::Bool(NonNullable)).unwrap();
        assert_eq!(result.dtype(), &DType::Bool(NonNullable));
        assert!(bool::try_from(&result).unwrap());
    }

    /// Casting to a non-boolean dtype is rejected.
    #[test]
    fn test_bool_cast_to_non_bool_fails() {
        use vortex_dtype::PType;

        let bool_scalar = Scalar::bool(true, NonNullable);
        let view = BoolScalar::try_from(&bool_scalar).unwrap();

        let result = view.cast(&DType::Primitive(PType::I32, NonNullable));
        assert!(result.is_err());
    }

    /// A non-boolean scalar cannot be viewed as a `BoolScalar`.
    #[test]
    fn test_try_from_non_bool_scalar() {
        let int_scalar = Scalar::primitive(42i32, NonNullable);
        let result = BoolScalar::try_from(&int_scalar);
        assert!(result.is_err());
    }

    /// A null boolean fails conversion to `bool` but converts to `Option::None`.
    #[test]
    fn test_try_from_null_scalar() {
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        let result: Result<bool, _> = (&null_scalar).try_into();
        assert!(result.is_err());

        let result: Option<bool> = (&null_scalar).try_into().unwrap();
        assert_eq!(result, None);
    }

    /// Owned-scalar conversions behave like their borrowed counterparts.
    #[test]
    fn test_try_from_owned_scalar() {
        let value: bool = Scalar::bool(true, NonNullable).try_into().unwrap();
        assert!(value);

        let value: Option<bool> = Scalar::bool(false, Nullable).try_into().unwrap();
        assert_eq!(value, Some(false));

        let value: Option<bool> = Scalar::null(DType::Bool(Nullable)).try_into().unwrap();
        assert_eq!(value, None);
    }

    /// `ScalarValue::from(bool)` produces values readable back as the same bool.
    #[test]
    fn test_scalar_value_from_bool() {
        let value: ScalarValue = true.into();
        let scalar = Scalar::new(DType::Bool(NonNullable), value);
        assert!(bool::try_from(&scalar).unwrap());

        let value: ScalarValue = false.into();
        let scalar = Scalar::new(DType::Bool(NonNullable), value);
        assert!(!bool::try_from(&scalar).unwrap());
    }

    /// Different present values are unequal.
    #[test]
    fn test_bool_partial_eq_different_values() {
        let true_scalar = Scalar::bool(true, NonNullable);
        let false_scalar = Scalar::bool(false, NonNullable);

        let true_bool = BoolScalar::try_from(&true_scalar).unwrap();
        let false_bool = BoolScalar::try_from(&false_scalar).unwrap();

        assert_ne!(true_bool, false_bool);
    }

    /// Nulls equal each other but not a present value.
    #[test]
    fn test_bool_partial_eq_null() {
        let null_scalar1 = Scalar::null(DType::Bool(Nullable));
        let null_scalar2 = Scalar::null(DType::Bool(Nullable));
        let non_null_scalar = Scalar::bool(true, Nullable);

        let null_bool1 = BoolScalar::try_from(&null_scalar1).unwrap();
        let null_bool2 = BoolScalar::try_from(&null_scalar2).unwrap();
        let non_null_bool = BoolScalar::try_from(&non_null_scalar).unwrap();

        assert_eq!(null_bool1, null_bool2);
        assert_ne!(null_bool1, non_null_bool);
    }

    /// `value()` reports the stored boolean or `None` for null.
    #[test]
    fn test_bool_value_accessor() {
        let true_scalar = Scalar::bool(true, NonNullable);
        let false_scalar = Scalar::bool(false, NonNullable);
        let null_scalar = Scalar::null(DType::Bool(Nullable));

        let true_bool = BoolScalar::try_from(&true_scalar).unwrap();
        let false_bool = BoolScalar::try_from(&false_scalar).unwrap();
        let null_bool = BoolScalar::try_from(&null_scalar).unwrap();

        assert_eq!(true_bool.value(), Some(true));
        assert_eq!(false_bool.value(), Some(false));
        assert_eq!(null_bool.value(), None);
    }

    /// `dtype()` reports the source scalar's dtype, including nullability.
    #[test]
    fn test_bool_dtype_accessor() {
        let nullable_scalar = Scalar::bool(true, Nullable);
        let non_nullable_scalar = Scalar::bool(false, NonNullable);

        let nullable_bool = BoolScalar::try_from(&nullable_scalar).unwrap();
        let non_nullable_bool = BoolScalar::try_from(&non_nullable_scalar).unwrap();

        assert_eq!(nullable_bool.dtype(), &DType::Bool(Nullable));
        assert_eq!(non_nullable_bool.dtype(), &DType::Bool(NonNullable));
    }

    /// `partial_cmp` always returns `Some` and agrees with the value ordering.
    #[test]
    fn test_bool_partial_cmp() {
        let false_scalar = Scalar::bool(false, NonNullable);
        let true_scalar = Scalar::bool(true, NonNullable);

        let false_bool = BoolScalar::try_from(&false_scalar).unwrap();
        let true_bool = BoolScalar::try_from(&true_scalar).unwrap();

        assert_eq!(false_bool.partial_cmp(&false_bool), Some(Ordering::Equal));
        assert_eq!(false_bool.partial_cmp(&true_bool), Some(Ordering::Less));
        assert_eq!(true_bool.partial_cmp(&false_bool), Some(Ordering::Greater));
    }
}