vortex_array/scalar/scalar_impl.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure_eq;
12use vortex_error::vortex_panic;
13
14use crate::dtype::DType;
15use crate::dtype::NativeDType;
16use crate::dtype::PType;
17use crate::dtype::StructFields;
18use crate::scalar::Scalar;
19use crate::scalar::ScalarValue;
20
21impl Scalar {
22 // Constructors for null scalars.
23
24 /// Creates a new null [`Scalar`] with the given [`DType`].
25 ///
26 /// # Panics
27 ///
28 /// Panics if the given [`DType`] is non-nullable.
29 pub fn null(dtype: DType) -> Self {
30 assert!(
31 dtype.is_nullable(),
32 "Cannot create null scalar with non-nullable dtype {dtype}"
33 );
34
35 Self { dtype, value: None }
36 }
37
38 // TODO(connor): This method arguably shouldn't exist...
39 /// Creates a new null [`Scalar`] for the given scalar type.
40 ///
41 /// The resulting scalar will have a nullable version of the type's data type.
42 pub fn null_native<T: NativeDType>() -> Self {
43 Self {
44 dtype: T::dtype().as_nullable(),
45 value: None,
46 }
47 }
48
49 // Constructors for potentially null scalars.
50
/// Builds a [`Scalar`] from a [`DType`] and an optional (possibly null) [`ScalarValue`].
///
/// This is a test-only convenience wrapper around [`Self::try_new`].
///
/// # Panics
///
/// Panics if [`Self::try_new`] rejects the `dtype` / `value` pair as incompatible.
#[cfg(test)]
pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
    use vortex_error::VortexExpect;

    let scalar = Self::try_new(dtype, value);
    scalar.vortex_expect("Failed to create Scalar")
}
64
65 /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
66 /// [`ScalarValue`].
67 ///
68 /// # Errors
69 ///
70 /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
71 pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
72 Self::validate(&dtype, value.as_ref())?;
73
74 Ok(Self { dtype, value })
75 }
76
77 /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
78 /// without checking compatibility.
79 ///
80 /// # Safety
81 ///
82 /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
83 /// rules defined in [`Self::validate`].
84 pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
85 #[cfg(debug_assertions)]
86 {
87 use vortex_error::VortexExpect;
88
89 Self::validate(&dtype, value.as_ref())
90 .vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
91 }
92
93 Self { dtype, value }
94 }
95
96 /// Returns a default value for the given [`DType`].
97 ///
98 /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
99 /// returns the zero value for the type.
100 ///
101 /// See [`Scalar::zero_value`] for more details about "zero" values.
102 ///
103 /// For non-nullable and nested types that may need null values in their children (as of right
104 /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
105 /// children.
106 pub fn default_value(dtype: &DType) -> Self {
107 let value = ScalarValue::default_value(dtype);
108
109 // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
110 unsafe { Self::new_unchecked(dtype.clone(), value) }
111 }
112
113 /// Returns a non-null zero / identity value for the given [`DType`].
114 ///
115 /// # Zero Values
116 ///
117 /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable):
118 ///
119 /// - `Null`: Does not have a "zero" value
120 /// - `Bool`: `false`
121 /// - `Primitive`: `0`
122 /// - `Decimal`: `0`
123 /// - `Utf8`: `""`
124 /// - `Binary`: An empty buffer
125 /// - `List`: An empty list
126 /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the
127 /// element [`DType`]
128 /// - `Struct`: A struct where each field has a zero value, which is determined by the field
129 /// [`DType`]
130 /// - `Extension`: The zero value of the storage [`DType`]
131 pub fn zero_value(dtype: &DType) -> Self {
132 let value = ScalarValue::zero_value(dtype);
133
134 // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
135 unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
136 }
137
138 // Other methods.
139
140 /// Check if two scalars are equal, ignoring nullability of the [`DType`].
141 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
142 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
143 }
144
145 /// Returns the parts of the [`Scalar`].
146 pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
147 (self.dtype, self.value)
148 }
149
150 /// Returns the [`DType`] of the [`Scalar`].
151 pub fn dtype(&self) -> &DType {
152 &self.dtype
153 }
154
155 /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
156 pub fn value(&self) -> Option<&ScalarValue> {
157 self.value.as_ref()
158 }
159
160 /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
161 /// consuming the [`Scalar`].
162 pub fn into_value(self) -> Option<ScalarValue> {
163 self.value
164 }
165
166 /// Returns `true` if the [`Scalar`] has a non-null value.
167 pub fn is_valid(&self) -> bool {
168 self.value.is_some()
169 }
170
171 /// Returns `true` if the [`Scalar`] is null.
172 pub fn is_null(&self) -> bool {
173 self.value.is_none()
174 }
175
176 /// Returns `true` if the [`Scalar`] has a non-null zero value.
177 ///
178 /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
179 /// and `Some(false)` otherwise.
180 pub fn is_zero(&self) -> Option<bool> {
181 let value = self.value()?;
182
183 let is_zero = match self.dtype() {
184 DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
185 DType::Bool(_) => !value.as_bool(),
186 DType::Primitive(..) => value.as_primitive().is_zero(),
187 DType::Decimal(..) => value.as_decimal().is_zero(),
188 DType::Utf8(_) => value.as_utf8().is_empty(),
189 DType::Binary(_) => value.as_binary().is_empty(),
190 DType::List(..) => value.as_list().is_empty(),
191 DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
192 DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
193 DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
194 DType::Variant(_) => self.as_variant().is_zero()?,
195 };
196
197 Some(is_zero)
198 }
199
200 /// Reinterprets the bytes of this scalar as a different primitive type.
201 ///
202 /// # Errors
203 ///
204 /// Panics if the scalar is not a primitive type or if the types have different byte widths.
205 pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
206 let primitive = self.as_primitive();
207 if primitive.ptype() == ptype {
208 return Ok(self.clone());
209 }
210
211 vortex_ensure_eq!(
212 primitive.ptype().byte_width(),
213 ptype.byte_width(),
214 "can't reinterpret cast between integers of two different widths"
215 );
216
217 Scalar::try_new(
218 DType::Primitive(ptype, self.dtype().nullability()),
219 primitive
220 .pvalue()
221 .map(|p| p.reinterpret_cast(ptype))
222 .map(ScalarValue::Primitive),
223 )
224 }
225
226 /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
227 ///
228 /// Note that the protobuf serialization of scalars will likely have a different (but roughly
229 /// similar) length.
230 pub fn approx_nbytes(&self) -> usize {
231 use crate::dtype::NativeDecimalType;
232 use crate::dtype::i256;
233
234 match self.dtype() {
235 DType::Null => 0,
236 DType::Bool(_) => 1,
237 DType::Primitive(ptype, _) => ptype.byte_width(),
238 DType::Decimal(dt, _) => {
239 if dt.precision() <= i128::MAX_PRECISION {
240 size_of::<i128>()
241 } else {
242 size_of::<i256>()
243 }
244 }
245 DType::Utf8(_) => self
246 .value()
247 .map_or_else(|| 0, |value| value.as_utf8().len()),
248 DType::Binary(_) => self
249 .value()
250 .map_or_else(|| 0, |value| value.as_binary().len()),
251 DType::Struct(..) => self
252 .as_struct()
253 .fields_iter()
254 .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
255 .unwrap_or_default(),
256 DType::List(..) | DType::FixedSizeList(..) => self
257 .as_list()
258 .elements()
259 .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
260 .unwrap_or_default(),
261 DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
262 DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
263 }
264 }
265}
266
267/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability in
268/// equality comparisons, we must also ignore it when hashing to maintain the invariant that equal
269/// values have equal hashes.
270impl Hash for Scalar {
271 fn hash<H: Hasher>(&self, state: &mut H) {
272 self.dtype.as_nonnullable().hash(state);
273 self.value.hash(state);
274 }
275}
276
277/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
278/// Two scalars with the same value but different nullability should be considered equal.
279///
280/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
281/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
282/// implementation simply returns `false`.
283impl PartialEq for Scalar {
284 fn eq(&self, other: &Self) -> bool {
285 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
286 }
287}
288
289impl PartialOrd for Scalar {
290 /// Compares two scalar values for ordering.
291 ///
292 /// # Returns
293 /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
294 /// - `None` if the scalars have different data types
295 ///
296 /// # Ordering Rules
297 /// When types match, the ordering follows these rules:
298 /// - Null values are considered less than all non-null values
299 /// - Non-null values are compared according to their natural ordering
300 ///
301 /// # Examples
302 ///
303 /// ```
304 /// use std::cmp::Ordering;
305 /// use vortex_array::dtype::DType;
306 /// use vortex_array::dtype::Nullability;
307 /// use vortex_array::dtype::PType;
308 /// use vortex_array::scalar::Scalar;
309 ///
310 /// // Same types compare successfully
311 /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
312 /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
313 /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
314 ///
315 /// // Different types return None
316 /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
317 /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
318 /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
319 ///
320 /// // Nulls are less than non-nulls
321 /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
322 /// let value = Scalar::primitive(0i32, Nullability::Nullable);
323 /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
324 /// ```
325 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
326 if !self.dtype().eq_ignore_nullability(other.dtype()) {
327 return None;
328 }
329
330 partial_cmp_scalar_values(self.dtype(), self.value(), other.value())
331 }
332}
333
334/// Compare two optional scalar values using `dtype` for nested tuple interpretation.
335fn partial_cmp_scalar_values(
336 dtype: &DType,
337 lhs: Option<&ScalarValue>,
338 rhs: Option<&ScalarValue>,
339) -> Option<Ordering> {
340 match (lhs, rhs) {
341 (None, None) => Some(Ordering::Equal),
342 (None, Some(_)) => Some(Ordering::Less),
343 (Some(_), None) => Some(Ordering::Greater),
344 (Some(lhs), Some(rhs)) => partial_cmp_non_null_scalar_values(dtype, lhs, rhs),
345 }
346}
347
348/// Compare two non-null scalar values, consulting `dtype` only for tuple-backed values.
349fn partial_cmp_non_null_scalar_values(
350 dtype: &DType,
351 lhs: &ScalarValue,
352 rhs: &ScalarValue,
353) -> Option<Ordering> {
354 // `Scalar::validate` guarantees that a scalar's value matches its dtype. Most of the scalar
355 // value variants have only 1 method of comparison, regardless of the dtype.
356 match (lhs, rhs) {
357 (ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(rhs),
358 (ScalarValue::Primitive(lhs), ScalarValue::Primitive(rhs)) => lhs.partial_cmp(rhs),
359 (ScalarValue::Decimal(lhs), ScalarValue::Decimal(rhs)) => lhs.partial_cmp(rhs),
360 (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => lhs.partial_cmp(rhs),
361 (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => lhs.partial_cmp(rhs),
362 // `Tuple` is the exception here. Since it backs lists, fixed-size lists, and structs, we
363 // need the dtype to know whether children share one element dtype or use per-field dtypes.
364 (ScalarValue::Tuple(lhs), ScalarValue::Tuple(rhs)) => {
365 partial_cmp_tuple_values(dtype, lhs, rhs)
366 }
367 // Variant values can have a different dtype in each row, so it doesn't make sense to
368 // compare them.
369 (ScalarValue::Variant(_), ScalarValue::Variant(_)) => None,
370 _ => None,
371 }
372}
373
374/// Compare tuple values according to the list, fixed-size list, or struct dtype layout.
375fn partial_cmp_tuple_values(
376 dtype: &DType,
377 lhs: &[Option<ScalarValue>],
378 rhs: &[Option<ScalarValue>],
379) -> Option<Ordering> {
380 match dtype {
381 DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
382 partial_cmp_list_values(element_dtype, lhs, rhs)
383 }
384 DType::Struct(fields, _) => partial_cmp_struct_values(fields, lhs, rhs),
385 DType::Extension(ext_dtype) => {
386 partial_cmp_tuple_values(ext_dtype.storage_dtype(), lhs, rhs)
387 }
388 _ => None,
389 }
390}
391
392/// Compare list tuple values using the shared element dtype for each element.
393fn partial_cmp_list_values(
394 element_dtype: &DType,
395 lhs: &[Option<ScalarValue>],
396 rhs: &[Option<ScalarValue>],
397) -> Option<Ordering> {
398 for (lhs, rhs) in lhs.iter().zip(rhs.iter()) {
399 match partial_cmp_scalar_values(element_dtype, lhs.as_ref(), rhs.as_ref())? {
400 Ordering::Equal => continue,
401 ordering => return Some(ordering),
402 }
403 }
404
405 Some(lhs.len().cmp(&rhs.len()))
406}
407
408/// Compare struct tuple values using each field's dtype in field order.
409fn partial_cmp_struct_values(
410 fields: &StructFields,
411 lhs: &[Option<ScalarValue>],
412 rhs: &[Option<ScalarValue>],
413) -> Option<Ordering> {
414 if lhs.len() != fields.nfields() || rhs.len() != fields.nfields() {
415 return None;
416 }
417
418 for ((field_dtype, lhs), rhs) in fields.fields().zip(lhs.iter()).zip(rhs.iter()) {
419 match partial_cmp_scalar_values(&field_dtype, lhs.as_ref(), rhs.as_ref())? {
420 Ordering::Equal => continue,
421 ordering => return Some(ordering),
422 }
423 }
424
425 Some(Ordering::Equal)
426}