#![deny(missing_docs)]

//! Scalar values: single logical items made of a dtype and a value.
//!
//! The central type is [`Scalar`]. Its value is read through typed views such as
//! [`BoolScalar`], [`PrimitiveScalar`], [`Utf8Scalar`] or [`StructScalar`], which are
//! re-exported from this module alongside [`ScalarValue`] and [`ScalarType`].
11
12use std::cmp::Ordering;
13use std::hash::Hash;
14use std::sync::Arc;
15
16pub use scalar_type::ScalarType;
17use vortex_buffer::{Buffer, BufferString, ByteBuffer};
18use vortex_dtype::half::f16;
19use vortex_dtype::{DECIMAL128_MAX_PRECISION, DType, Nullability};
20#[cfg(feature = "arbitrary")]
21pub mod arbitrary;
22mod arrow;
23mod bigint;
24mod binary;
25mod bool;
26mod decimal;
27mod display;
28mod extension;
29mod list;
30mod null;
31mod primitive;
32mod proto;
33mod pvalue;
34mod scalar_type;
35mod scalar_value;
36mod struct_;
37#[cfg(test)]
38mod tests;
39mod utf8;
40
41pub use bigint::*;
42pub use binary::*;
43pub use bool::*;
44pub use decimal::*;
45pub use extension::*;
46pub use list::*;
47pub use primitive::*;
48pub use pvalue::*;
49pub use scalar_value::*;
50pub use struct_::*;
51pub use utf8::*;
52use vortex_error::{VortexExpect, VortexResult, vortex_bail};
53
54#[derive(Debug, Clone)]
63pub struct Scalar {
64 dtype: DType,
65 value: ScalarValue,
66}
67
68impl Scalar {
69 pub fn new(dtype: DType, value: ScalarValue) -> Self {
71 Self { dtype, value }
72 }
73
74 #[inline]
76 pub fn dtype(&self) -> &DType {
77 &self.dtype
78 }
79
80 #[inline]
82 pub fn value(&self) -> &ScalarValue {
83 &self.value
84 }
85
86 #[inline]
88 pub fn into_parts(self) -> (DType, ScalarValue) {
89 (self.dtype, self.value)
90 }
91
92 #[inline]
94 pub fn into_value(self) -> ScalarValue {
95 self.value
96 }
97
98 pub fn is_valid(&self) -> bool {
100 !self.value.is_null()
101 }
102
103 pub fn is_null(&self) -> bool {
105 self.value.is_null()
106 }
107
108 pub fn null(dtype: DType) -> Self {
114 assert!(
115 dtype.is_nullable(),
116 "Creating null scalar for non-nullable DType {dtype}"
117 );
118 Self {
119 dtype,
120 value: ScalarValue(InnerScalarValue::Null),
121 }
122 }
123
124 pub fn null_typed<T: ScalarType>() -> Self {
128 Self {
129 dtype: T::dtype().as_nullable(),
130 value: ScalarValue(InnerScalarValue::Null),
131 }
132 }
133
134 pub fn cast(&self, target: &DType) -> VortexResult<Self> {
139 if let DType::Extension(ext_dtype) = target {
140 let storage_scalar = self.cast_to_non_extension(ext_dtype.storage_dtype())?;
141 Ok(Scalar::extension(ext_dtype.clone(), storage_scalar))
142 } else {
143 self.cast_to_non_extension(target)
144 }
145 }
146
147 fn cast_to_non_extension(&self, target: &DType) -> VortexResult<Self> {
148 assert!(!matches!(target, DType::Extension(..)));
149 if self.is_null() {
150 if target.is_nullable() {
151 return Ok(Scalar::new(target.clone(), self.value.clone()));
152 } else {
153 vortex_bail!(
154 "Cannot cast null to {}: target type is non-nullable",
155 target
156 )
157 }
158 }
159
160 if self.dtype().eq_ignore_nullability(target) {
161 return Ok(Scalar::new(target.clone(), self.value.clone()));
162 }
163
164 match &self.dtype {
165 DType::Null => unreachable!(), DType::Bool(_) => self.as_bool().cast(target),
167 DType::Primitive(..) => self.as_primitive().cast(target),
168 DType::Decimal(..) => self.as_decimal().cast(target),
169 DType::Utf8(_) => self.as_utf8().cast(target),
170 DType::Binary(_) => self.as_binary().cast(target),
171 DType::Struct(..) => self.as_struct().cast(target),
172 DType::List(..) => self.as_list().cast(target),
173 DType::Extension(..) => self.as_extension().cast(target),
174 }
175 }
176
177 pub fn into_nullable(self) -> Self {
179 Self {
180 dtype: self.dtype.as_nullable(),
181 value: self.value,
182 }
183 }
184
185 pub fn nbytes(&self) -> usize {
187 match self.dtype() {
188 DType::Null => 0,
189 DType::Bool(_) => 1,
190 DType::Primitive(ptype, _) => ptype.byte_width(),
191 DType::Decimal(dt, _) => {
192 if dt.precision() <= DECIMAL128_MAX_PRECISION {
193 size_of::<i128>()
194 } else {
195 size_of::<i256>()
196 }
197 }
198 DType::Binary(_) | DType::Utf8(_) => self
199 .value()
200 .as_buffer()
201 .ok()
202 .flatten()
203 .map_or(0, |s| s.len()),
204 DType::Struct(_dtype, _) => self
205 .as_struct()
206 .fields()
207 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
208 .unwrap_or_default(),
209 DType::List(_dtype, _) => self
210 .as_list()
211 .elements()
212 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
213 .unwrap_or_default(),
214 DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
215 }
216 }
217
218 pub fn default_value(dtype: DType) -> Self {
223 if dtype.is_nullable() {
224 return Self::null(dtype);
225 }
226
227 match dtype {
228 DType::Null => Self::null(dtype),
229 DType::Bool(nullability) => Self::bool(false, nullability),
230 DType::Primitive(pt, nullability) => {
231 Self::primitive_value(PValue::zero(pt), pt, nullability)
232 }
233 DType::Decimal(dt, nullability) => {
234 Self::decimal(DecimalValue::from(0), dt, nullability)
235 }
236 DType::Utf8(nullability) => Self::utf8("", nullability),
237 DType::Binary(nullability) => Self::binary(Buffer::empty(), nullability),
238 DType::Struct(sf, nullability) => {
239 let fields: Vec<_> = sf.fields().map(Scalar::default_value).collect();
240 Self::struct_(DType::Struct(sf, nullability), fields)
241 }
242 DType::List(dt, nullability) => Self::list(dt, vec![], nullability),
243 DType::Extension(dt) => {
244 let scalar = Self::default_value(dt.storage_dtype().clone());
245 Self::extension(dt, scalar)
246 }
247 }
248 }
249}
250
251impl Scalar {
252 pub fn as_bool(&self) -> BoolScalar<'_> {
258 BoolScalar::try_from(self).vortex_expect("Failed to convert scalar to bool")
259 }
260
261 pub fn as_bool_opt(&self) -> Option<BoolScalar<'_>> {
263 matches!(self.dtype, DType::Bool(..)).then(|| self.as_bool())
264 }
265
266 pub fn as_primitive(&self) -> PrimitiveScalar<'_> {
272 PrimitiveScalar::try_from(self).vortex_expect("Failed to convert scalar to primitive")
273 }
274
275 pub fn as_primitive_opt(&self) -> Option<PrimitiveScalar<'_>> {
277 matches!(self.dtype, DType::Primitive(..)).then(|| self.as_primitive())
278 }
279
280 pub fn as_decimal(&self) -> DecimalScalar<'_> {
286 DecimalScalar::try_from(self).vortex_expect("Failed to convert scalar to decimal")
287 }
288
289 pub fn as_decimal_opt(&self) -> Option<DecimalScalar<'_>> {
291 matches!(self.dtype, DType::Decimal(..)).then(|| self.as_decimal())
292 }
293
294 pub fn as_utf8(&self) -> Utf8Scalar<'_> {
300 Utf8Scalar::try_from(self).vortex_expect("Failed to convert scalar to utf8")
301 }
302
303 pub fn as_utf8_opt(&self) -> Option<Utf8Scalar<'_>> {
305 matches!(self.dtype, DType::Utf8(..)).then(|| self.as_utf8())
306 }
307
308 pub fn as_binary(&self) -> BinaryScalar<'_> {
314 BinaryScalar::try_from(self).vortex_expect("Failed to convert scalar to binary")
315 }
316
317 pub fn as_binary_opt(&self) -> Option<BinaryScalar<'_>> {
319 matches!(self.dtype, DType::Binary(..)).then(|| self.as_binary())
320 }
321
322 pub fn as_struct(&self) -> StructScalar<'_> {
328 StructScalar::try_from(self).vortex_expect("Failed to convert scalar to struct")
329 }
330
331 pub fn as_struct_opt(&self) -> Option<StructScalar<'_>> {
333 matches!(self.dtype, DType::Struct(..)).then(|| self.as_struct())
334 }
335
336 pub fn as_list(&self) -> ListScalar<'_> {
342 ListScalar::try_from(self).vortex_expect("Failed to convert scalar to list")
343 }
344
345 pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
347 matches!(self.dtype, DType::List(..)).then(|| self.as_list())
348 }
349
350 pub fn as_extension(&self) -> ExtScalar<'_> {
356 ExtScalar::try_from(self).vortex_expect("Failed to convert scalar to extension")
357 }
358
359 pub fn as_extension_opt(&self) -> Option<ExtScalar<'_>> {
361 matches!(self.dtype, DType::Extension(..)).then(|| self.as_extension())
362 }
363}
364
365impl PartialEq for Scalar {
366 fn eq(&self, other: &Self) -> bool {
367 if !self.dtype.eq_ignore_nullability(&other.dtype) {
368 return false;
369 }
370
371 match self.dtype() {
372 DType::Null => true,
373 DType::Bool(_) => self.as_bool() == other.as_bool(),
374 DType::Primitive(..) => self.as_primitive() == other.as_primitive(),
375 DType::Decimal(..) => self.as_decimal() == other.as_decimal(),
376 DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
377 DType::Binary(_) => self.as_binary() == other.as_binary(),
378 DType::Struct(..) => self.as_struct() == other.as_struct(),
379 DType::List(..) => self.as_list() == other.as_list(),
380 DType::Extension(_) => self.as_extension() == other.as_extension(),
381 }
382 }
383}
384
385impl Eq for Scalar {}
386
387impl PartialOrd for Scalar {
388 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
417 if !self.dtype().eq_ignore_nullability(other.dtype()) {
418 return None;
419 }
420 match self.dtype() {
421 DType::Null => Some(Ordering::Equal),
422 DType::Bool(_) => self.as_bool().partial_cmp(&other.as_bool()),
423 DType::Primitive(..) => self.as_primitive().partial_cmp(&other.as_primitive()),
424 DType::Decimal(..) => self.as_decimal().partial_cmp(&other.as_decimal()),
425 DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
426 DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
427 DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
428 DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
429 DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
430 }
431 }
432}
433
434impl Hash for Scalar {
435 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
436 match self.dtype() {
437 DType::Null => self.dtype().hash(state), DType::Bool(_) => self.as_bool().hash(state),
439 DType::Primitive(..) => self.as_primitive().hash(state),
440 DType::Decimal(..) => self.as_decimal().hash(state),
441 DType::Utf8(_) => self.as_utf8().hash(state),
442 DType::Binary(_) => self.as_binary().hash(state),
443 DType::Struct(..) => self.as_struct().hash(state),
444 DType::List(..) => self.as_list().hash(state),
445 DType::Extension(_) => self.as_extension().hash(state),
446 }
447 }
448}
449
450impl AsRef<Self> for Scalar {
451 fn as_ref(&self) -> &Self {
452 self
453 }
454}
455
456impl<T> From<Option<T>> for Scalar
457where
458 T: ScalarType,
459 Scalar: From<T>,
460{
461 fn from(value: Option<T>) -> Self {
462 value
463 .map(Scalar::from)
464 .map(|x| x.into_nullable())
465 .unwrap_or_else(|| Scalar {
466 dtype: T::dtype().as_nullable(),
467 value: ScalarValue(InnerScalarValue::Null),
468 })
469 }
470}
471
472impl From<PrimitiveScalar<'_>> for Scalar {
473 fn from(pscalar: PrimitiveScalar<'_>) -> Self {
474 let dtype = pscalar.dtype().clone();
475 let value = pscalar
476 .pvalue()
477 .map(|pvalue| ScalarValue(InnerScalarValue::Primitive(pvalue)))
478 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
479 Self::new(dtype, value)
480 }
481}
482
483impl From<DecimalScalar<'_>> for Scalar {
484 fn from(decimal_scalar: DecimalScalar<'_>) -> Self {
485 let dtype = decimal_scalar.dtype().clone();
486 let value = decimal_scalar
487 .decimal_value()
488 .map(|value| ScalarValue(InnerScalarValue::Decimal(value)))
489 .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null));
490 Self::new(dtype, value)
491 }
492}
493
494macro_rules! from_vec_for_scalar {
495 ($T:ty) => {
496 impl From<Vec<$T>> for Scalar {
497 fn from(value: Vec<$T>) -> Self {
498 Scalar {
499 dtype: DType::List(Arc::from(<$T>::dtype()), Nullability::NonNullable),
500 value: ScalarValue(InnerScalarValue::List(
501 value
502 .into_iter()
503 .map(Scalar::from)
504 .map(|s| s.into_value())
505 .collect::<Arc<[_]>>(),
506 )),
507 }
508 }
509 }
510 };
511}
512
513from_vec_for_scalar!(u16);
515from_vec_for_scalar!(u32);
516from_vec_for_scalar!(u64);
517from_vec_for_scalar!(usize); from_vec_for_scalar!(i8);
519from_vec_for_scalar!(i16);
520from_vec_for_scalar!(i32);
521from_vec_for_scalar!(i64);
522from_vec_for_scalar!(f16);
523from_vec_for_scalar!(f32);
524from_vec_for_scalar!(f64);
525from_vec_for_scalar!(String);
526from_vec_for_scalar!(BufferString);
527from_vec_for_scalar!(ByteBuffer);