Skip to main content

vortex_array/scalar/
scalar_impl.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure_eq;
12use vortex_error::vortex_panic;
13
14use crate::dtype::DType;
15use crate::dtype::NativeDType;
16use crate::dtype::PType;
17use crate::scalar::Scalar;
18use crate::scalar::ScalarValue;
19
20impl Scalar {
21    // Constructors for null scalars.
22
23    /// Creates a new null [`Scalar`] with the given [`DType`].
24    ///
25    /// # Panics
26    ///
27    /// Panics if the given [`DType`] is non-nullable.
28    pub fn null(dtype: DType) -> Self {
29        assert!(
30            dtype.is_nullable(),
31            "Cannot create null scalar with non-nullable dtype {dtype}"
32        );
33
34        Self { dtype, value: None }
35    }
36
37    // TODO(connor): This method arguably shouldn't exist...
38    /// Creates a new null [`Scalar`] for the given scalar type.
39    ///
40    /// The resulting scalar will have a nullable version of the type's data type.
41    pub fn null_native<T: NativeDType>() -> Self {
42        Self {
43            dtype: T::dtype().as_nullable(),
44            value: None,
45        }
46    }
47
48    // Constructors for potentially null scalars.
49
50    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`].
51    ///
52    /// This is just a helper function for tests.
53    ///
54    /// # Panics
55    ///
56    /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible.
57    #[cfg(test)]
58    pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
59        use vortex_error::VortexExpect;
60
61        Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
62    }
63
64    /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
65    /// [`ScalarValue`].
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
70    pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
71        Self::validate(&dtype, value.as_ref())?;
72
73        Ok(Self { dtype, value })
74    }
75
76    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
77    /// without checking compatibility.
78    ///
79    /// # Safety
80    ///
81    /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
82    /// rules defined in [`Self::validate`].
83    pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
84        #[cfg(debug_assertions)]
85        {
86            use vortex_error::VortexExpect;
87
88            Self::validate(&dtype, value.as_ref())
89                .vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
90        }
91
92        Self { dtype, value }
93    }
94
95    /// Returns a default value for the given [`DType`].
96    ///
97    /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
98    /// returns the zero value for the type.
99    ///
100    /// See [`Scalar::zero_value`] for more details about "zero" values.
101    ///
102    /// For non-nullable and nested types that may need null values in their children (as of right
103    /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
104    /// children.
105    pub fn default_value(dtype: &DType) -> Self {
106        let value = ScalarValue::default_value(dtype);
107
108        // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
109        unsafe { Self::new_unchecked(dtype.clone(), value) }
110    }
111
112    /// Returns a non-null zero / identity value for the given [`DType`].
113    ///
114    /// # Zero Values
115    ///
116    /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable):
117    ///
118    /// - `Null`: Does not have a "zero" value
119    /// - `Bool`: `false`
120    /// - `Primitive`: `0`
121    /// - `Decimal`: `0`
122    /// - `Utf8`: `""`
123    /// - `Binary`: An empty buffer
124    /// - `List`: An empty list
125    /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the
126    ///   element [`DType`]
127    /// - `Struct`: A struct where each field has a zero value, which is determined by the field
128    ///   [`DType`]
129    /// - `Extension`: The zero value of the storage [`DType`]
130    pub fn zero_value(dtype: &DType) -> Self {
131        let value = ScalarValue::zero_value(dtype);
132
133        // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
134        unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
135    }
136
137    // Other methods.
138
139    /// Check if two scalars are equal, ignoring nullability of the [`DType`].
140    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
141        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
142    }
143
144    /// Returns the parts of the [`Scalar`].
145    pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
146        (self.dtype, self.value)
147    }
148
149    /// Returns the [`DType`] of the [`Scalar`].
150    pub fn dtype(&self) -> &DType {
151        &self.dtype
152    }
153
154    /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
155    pub fn value(&self) -> Option<&ScalarValue> {
156        self.value.as_ref()
157    }
158
159    /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
160    /// consuming the [`Scalar`].
161    pub fn into_value(self) -> Option<ScalarValue> {
162        self.value
163    }
164
165    /// Returns `true` if the [`Scalar`] has a non-null value.
166    pub fn is_valid(&self) -> bool {
167        self.value.is_some()
168    }
169
170    /// Returns `true` if the [`Scalar`] is null.
171    pub fn is_null(&self) -> bool {
172        self.value.is_none()
173    }
174
175    /// Returns `true` if the [`Scalar`] has a non-null zero value.
176    ///
177    /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
178    /// and `Some(false)` otherwise.
179    pub fn is_zero(&self) -> Option<bool> {
180        let value = self.value()?;
181
182        let is_zero = match self.dtype() {
183            DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
184            DType::Bool(_) => !value.as_bool(),
185            DType::Primitive(..) => value.as_primitive().is_zero(),
186            DType::Decimal(..) => value.as_decimal().is_zero(),
187            DType::Utf8(_) => value.as_utf8().is_empty(),
188            DType::Binary(_) => value.as_binary().is_empty(),
189            DType::List(..) => value.as_list().is_empty(),
190            DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
191            DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
192            DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
193            DType::Variant(_) => self.as_variant().is_zero()?,
194        };
195
196        Some(is_zero)
197    }
198
199    /// Reinterprets the bytes of this scalar as a different primitive type.
200    ///
201    /// # Errors
202    ///
203    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
204    pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
205        let primitive = self.as_primitive();
206        if primitive.ptype() == ptype {
207            return Ok(self.clone());
208        }
209
210        vortex_ensure_eq!(
211            primitive.ptype().byte_width(),
212            ptype.byte_width(),
213            "can't reinterpret cast between integers of two different widths"
214        );
215
216        Scalar::try_new(
217            DType::Primitive(ptype, self.dtype().nullability()),
218            primitive
219                .pvalue()
220                .map(|p| p.reinterpret_cast(ptype))
221                .map(ScalarValue::Primitive),
222        )
223    }
224
225    /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
226    ///
227    /// Note that the protobuf serialization of scalars will likely have a different (but roughly
228    /// similar) length.
229    pub fn approx_nbytes(&self) -> usize {
230        use crate::dtype::NativeDecimalType;
231        use crate::dtype::i256;
232
233        match self.dtype() {
234            DType::Null => 0,
235            DType::Bool(_) => 1,
236            DType::Primitive(ptype, _) => ptype.byte_width(),
237            DType::Decimal(dt, _) => {
238                if dt.precision() <= i128::MAX_PRECISION {
239                    size_of::<i128>()
240                } else {
241                    size_of::<i256>()
242                }
243            }
244            DType::Utf8(_) => self
245                .value()
246                .map_or_else(|| 0, |value| value.as_utf8().len()),
247            DType::Binary(_) => self
248                .value()
249                .map_or_else(|| 0, |value| value.as_binary().len()),
250            DType::Struct(..) => self
251                .as_struct()
252                .fields_iter()
253                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
254                .unwrap_or_default(),
255            DType::List(..) | DType::FixedSizeList(..) => self
256                .as_list()
257                .elements()
258                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
259                .unwrap_or_default(),
260            DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
261            DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
262        }
263    }
264}
265
266/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
267/// Two scalars with the same value but different nullability should be considered equal.
268///
269/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
270/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
271/// implementation simply returns `false`.
272impl PartialEq for Scalar {
273    fn eq(&self, other: &Self) -> bool {
274        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
275    }
276}
277
278impl PartialOrd for Scalar {
279    /// Compares two scalar values for ordering.
280    ///
281    /// # Returns
282    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
283    /// - `None` if the scalars have different data types
284    ///
285    /// # Ordering Rules
286    /// When types match, the ordering follows these rules:
287    /// - Null values are considered less than all non-null values
288    /// - Non-null values are compared according to their natural ordering
289    ///
290    /// # Examples
291    /// ```ignore
292    /// // Same types compare successfully
293    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
294    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
295    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
296    ///
297    /// // Different types return None
298    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
299    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
300    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
301    ///
302    /// // Nulls are less than non-nulls
303    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
304    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
305    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
306    /// ```
307    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
308        if !self.dtype().eq_ignore_nullability(other.dtype()) {
309            return None;
310        }
311        self.value().partial_cmp(&other.value())
312    }
313}
314
315/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability
316/// in equality comparisons, we must also ignore it when hashing to maintain the invariant that
317/// equal values have equal hashes.
318impl Hash for Scalar {
319    fn hash<H: Hasher>(&self, state: &mut H) {
320        self.dtype.as_nonnullable().hash(state);
321        self.value.hash(state);
322    }
323}