Skip to main content

vortex_array/scalar/
scalar_impl.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure_eq;
12use vortex_error::vortex_panic;
13
14use crate::dtype::DType;
15use crate::dtype::NativeDType;
16use crate::dtype::PType;
17use crate::dtype::StructFields;
18use crate::scalar::Scalar;
19use crate::scalar::ScalarValue;
20
21impl Scalar {
22    // Constructors for null scalars.
23
24    /// Creates a new null [`Scalar`] with the given [`DType`].
25    ///
26    /// # Panics
27    ///
28    /// Panics if the given [`DType`] is non-nullable.
29    pub fn null(dtype: DType) -> Self {
30        assert!(
31            dtype.is_nullable(),
32            "Cannot create null scalar with non-nullable dtype {dtype}"
33        );
34
35        Self { dtype, value: None }
36    }
37
38    // TODO(connor): This method arguably shouldn't exist...
39    /// Creates a new null [`Scalar`] for the given scalar type.
40    ///
41    /// The resulting scalar will have a nullable version of the type's data type.
42    pub fn null_native<T: NativeDType>() -> Self {
43        Self {
44            dtype: T::dtype().as_nullable(),
45            value: None,
46        }
47    }
48
49    // Constructors for potentially null scalars.
50
51    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`].
52    ///
53    /// This is just a helper function for tests.
54    ///
55    /// # Panics
56    ///
57    /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible.
58    #[cfg(test)]
59    pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
60        use vortex_error::VortexExpect;
61
62        Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
63    }
64
65    /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
66    /// [`ScalarValue`].
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
71    pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
72        Self::validate(&dtype, value.as_ref())?;
73
74        Ok(Self { dtype, value })
75    }
76
77    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
78    /// without checking compatibility.
79    ///
80    /// # Safety
81    ///
82    /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
83    /// rules defined in [`Self::validate`].
84    pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
85        #[cfg(debug_assertions)]
86        {
87            use vortex_error::VortexExpect;
88
89            Self::validate(&dtype, value.as_ref())
90                .vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
91        }
92
93        Self { dtype, value }
94    }
95
96    /// Returns a default value for the given [`DType`].
97    ///
98    /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
99    /// returns the zero value for the type.
100    ///
101    /// See [`Scalar::zero_value`] for more details about "zero" values.
102    ///
103    /// For non-nullable and nested types that may need null values in their children (as of right
104    /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
105    /// children.
106    pub fn default_value(dtype: &DType) -> Self {
107        let value = ScalarValue::default_value(dtype);
108
109        // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
110        unsafe { Self::new_unchecked(dtype.clone(), value) }
111    }
112
113    /// Returns a non-null zero / identity value for the given [`DType`].
114    ///
115    /// # Zero Values
116    ///
117    /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable):
118    ///
119    /// - `Null`: Does not have a "zero" value
120    /// - `Bool`: `false`
121    /// - `Primitive`: `0`
122    /// - `Decimal`: `0`
123    /// - `Utf8`: `""`
124    /// - `Binary`: An empty buffer
125    /// - `List`: An empty list
126    /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the
127    ///   element [`DType`]
128    /// - `Struct`: A struct where each field has a zero value, which is determined by the field
129    ///   [`DType`]
130    /// - `Extension`: The zero value of the storage [`DType`]
131    pub fn zero_value(dtype: &DType) -> Self {
132        let value = ScalarValue::zero_value(dtype);
133
134        // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
135        unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
136    }
137
138    // Other methods.
139
140    /// Check if two scalars are equal, ignoring nullability of the [`DType`].
141    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
142        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
143    }
144
145    /// Returns the parts of the [`Scalar`].
146    pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
147        (self.dtype, self.value)
148    }
149
150    /// Returns the [`DType`] of the [`Scalar`].
151    pub fn dtype(&self) -> &DType {
152        &self.dtype
153    }
154
155    /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
156    pub fn value(&self) -> Option<&ScalarValue> {
157        self.value.as_ref()
158    }
159
160    /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
161    /// consuming the [`Scalar`].
162    pub fn into_value(self) -> Option<ScalarValue> {
163        self.value
164    }
165
166    /// Returns `true` if the [`Scalar`] has a non-null value.
167    pub fn is_valid(&self) -> bool {
168        self.value.is_some()
169    }
170
171    /// Returns `true` if the [`Scalar`] is null.
172    pub fn is_null(&self) -> bool {
173        self.value.is_none()
174    }
175
176    /// Returns `true` if the [`Scalar`] has a non-null zero value.
177    ///
178    /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
179    /// and `Some(false)` otherwise.
180    pub fn is_zero(&self) -> Option<bool> {
181        let value = self.value()?;
182
183        let is_zero = match self.dtype() {
184            DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
185            DType::Bool(_) => !value.as_bool(),
186            DType::Primitive(..) => value.as_primitive().is_zero(),
187            DType::Decimal(..) => value.as_decimal().is_zero(),
188            DType::Utf8(_) => value.as_utf8().is_empty(),
189            DType::Binary(_) => value.as_binary().is_empty(),
190            DType::List(..) => value.as_list().is_empty(),
191            // TODO(connor): This seems wrong...
192            DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
193            // TODO(connor): This also seems wrong...
194            DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
195            DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
196            DType::Variant(_) => self.as_variant().is_zero()?,
197            DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
198        };
199
200        Some(is_zero)
201    }
202
203    /// Reinterprets the bytes of this scalar as a different primitive type.
204    ///
205    /// # Errors
206    ///
207    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
208    pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
209        let primitive = self.as_primitive();
210        if primitive.ptype() == ptype {
211            return Ok(self.clone());
212        }
213
214        vortex_ensure_eq!(
215            primitive.ptype().byte_width(),
216            ptype.byte_width(),
217            "can't reinterpret cast between integers of two different widths"
218        );
219
220        Scalar::try_new(
221            DType::Primitive(ptype, self.dtype().nullability()),
222            primitive
223                .pvalue()
224                .map(|p| p.reinterpret_cast(ptype))
225                .map(ScalarValue::Primitive),
226        )
227    }
228
229    /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
230    ///
231    /// Note that the protobuf serialization of scalars will likely have a different (but roughly
232    /// similar) length.
233    pub fn approx_nbytes(&self) -> usize {
234        use crate::dtype::NativeDecimalType;
235        use crate::dtype::i256;
236
237        match self.dtype() {
238            DType::Null => 0,
239            DType::Bool(_) => 1,
240            DType::Primitive(ptype, _) => ptype.byte_width(),
241            DType::Decimal(dt, _) => {
242                if dt.precision() <= i128::MAX_PRECISION {
243                    size_of::<i128>()
244                } else {
245                    size_of::<i256>()
246                }
247            }
248            DType::Utf8(_) => self
249                .value()
250                .map_or_else(|| 0, |value| value.as_utf8().len()),
251            DType::Binary(_) => self
252                .value()
253                .map_or_else(|| 0, |value| value.as_binary().len()),
254            DType::List(..) | DType::FixedSizeList(..) => self
255                .as_list()
256                .elements()
257                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
258                .unwrap_or_default(),
259            DType::Struct(..) => self
260                .as_struct()
261                .fields_iter()
262                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
263                .unwrap_or_default(),
264            DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
265            DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
266            DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
267        }
268    }
269}
270
271/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability in
272/// equality comparisons, we must also ignore it when hashing to maintain the invariant that equal
273/// values have equal hashes.
274impl Hash for Scalar {
275    fn hash<H: Hasher>(&self, state: &mut H) {
276        self.dtype.as_nonnullable().hash(state);
277        self.value.hash(state);
278    }
279}
280
281/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
282/// Two scalars with the same value but different nullability should be considered equal.
283///
284/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
285/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
286/// implementation simply returns `false`.
287impl PartialEq for Scalar {
288    fn eq(&self, other: &Self) -> bool {
289        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
290    }
291}
292
293impl PartialOrd for Scalar {
294    /// Compares two scalar values for ordering.
295    ///
296    /// # Returns
297    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
298    /// - `None` if the scalars have different data types
299    ///
300    /// # Ordering Rules
301    /// When types match, the ordering follows these rules:
302    /// - Null values are considered less than all non-null values
303    /// - Non-null values are compared according to their natural ordering
304    ///
305    /// # Examples
306    ///
307    /// ```
308    /// use std::cmp::Ordering;
309    /// use vortex_array::dtype::DType;
310    /// use vortex_array::dtype::Nullability;
311    /// use vortex_array::dtype::PType;
312    /// use vortex_array::scalar::Scalar;
313    ///
314    /// // Same types compare successfully
315    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
316    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
317    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
318    ///
319    /// // Different types return None
320    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
321    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
322    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
323    ///
324    /// // Nulls are less than non-nulls
325    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
326    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
327    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
328    /// ```
329    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
330        if !self.dtype().eq_ignore_nullability(other.dtype()) {
331            return None;
332        }
333
334        partial_cmp_scalar_values(self.dtype(), self.value(), other.value())
335    }
336}
337
338/// Compare two optional scalar values using `dtype` for nested tuple interpretation.
339fn partial_cmp_scalar_values(
340    dtype: &DType,
341    lhs: Option<&ScalarValue>,
342    rhs: Option<&ScalarValue>,
343) -> Option<Ordering> {
344    match (lhs, rhs) {
345        (None, None) => Some(Ordering::Equal),
346        (None, Some(_)) => Some(Ordering::Less),
347        (Some(_), None) => Some(Ordering::Greater),
348        (Some(lhs), Some(rhs)) => partial_cmp_non_null_scalar_values(dtype, lhs, rhs),
349    }
350}
351
352/// Compare two non-null scalar values, consulting `dtype` only for tuple-backed values.
353fn partial_cmp_non_null_scalar_values(
354    dtype: &DType,
355    lhs: &ScalarValue,
356    rhs: &ScalarValue,
357) -> Option<Ordering> {
358    // `Scalar::validate` guarantees that a scalar's value matches its dtype. Most of the scalar
359    // value variants have only 1 method of comparison, regardless of the dtype.
360    match (lhs, rhs) {
361        (ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(rhs),
362        (ScalarValue::Primitive(lhs), ScalarValue::Primitive(rhs)) => lhs.partial_cmp(rhs),
363        (ScalarValue::Decimal(lhs), ScalarValue::Decimal(rhs)) => lhs.partial_cmp(rhs),
364        (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => lhs.partial_cmp(rhs),
365        (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => lhs.partial_cmp(rhs),
366        // `Tuple` is the exception here. Since it backs lists, fixed-size lists, and structs, we
367        // need the dtype to know whether children share one element dtype or use per-field dtypes.
368        (ScalarValue::Tuple(lhs), ScalarValue::Tuple(rhs)) => {
369            partial_cmp_tuple_values(dtype, lhs, rhs)
370        }
371        // Variant values can have a different dtype in each row, so it doesn't make sense to
372        // compare them.
373        (ScalarValue::Variant(_), ScalarValue::Variant(_)) => None,
374        _ => None,
375    }
376}
377
378/// Compare tuple values according to the list, fixed-size list, or struct dtype layout.
379fn partial_cmp_tuple_values(
380    dtype: &DType,
381    lhs: &[Option<ScalarValue>],
382    rhs: &[Option<ScalarValue>],
383) -> Option<Ordering> {
384    match dtype {
385        DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
386            partial_cmp_list_values(element_dtype, lhs, rhs)
387        }
388        DType::Struct(fields, _) => partial_cmp_struct_values(fields, lhs, rhs),
389        DType::Extension(ext_dtype) => {
390            partial_cmp_tuple_values(ext_dtype.storage_dtype(), lhs, rhs)
391        }
392        _ => None,
393    }
394}
395
396/// Compare list tuple values using the shared element dtype for each element.
397fn partial_cmp_list_values(
398    element_dtype: &DType,
399    lhs: &[Option<ScalarValue>],
400    rhs: &[Option<ScalarValue>],
401) -> Option<Ordering> {
402    for (lhs, rhs) in lhs.iter().zip(rhs.iter()) {
403        match partial_cmp_scalar_values(element_dtype, lhs.as_ref(), rhs.as_ref())? {
404            Ordering::Equal => continue,
405            ordering => return Some(ordering),
406        }
407    }
408
409    Some(lhs.len().cmp(&rhs.len()))
410}
411
412/// Compare struct tuple values using each field's dtype in field order.
413fn partial_cmp_struct_values(
414    fields: &StructFields,
415    lhs: &[Option<ScalarValue>],
416    rhs: &[Option<ScalarValue>],
417) -> Option<Ordering> {
418    if lhs.len() != fields.nfields() || rhs.len() != fields.nfields() {
419        return None;
420    }
421
422    for ((field_dtype, lhs), rhs) in fields.fields().zip(lhs.iter()).zip(rhs.iter()) {
423        match partial_cmp_scalar_values(&field_dtype, lhs.as_ref(), rhs.as_ref())? {
424            Ordering::Equal => continue,
425            ordering => return Some(ordering),
426        }
427    }
428
429    Some(Ordering::Equal)
430}