Skip to main content

vortex_array/scalar/
scalar_impl.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_dtype::DType;
11use vortex_dtype::NativeDType;
12use vortex_dtype::PType;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_ensure_eq;
16use vortex_error::vortex_panic;
17
18use crate::scalar::PValue;
19use crate::scalar::Scalar;
20use crate::scalar::ScalarValue;
21
22/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
23/// Two scalars with the same value but different nullability should be considered equal.
24impl PartialEq for Scalar {
25    fn eq(&self, other: &Self) -> bool {
26        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
27    }
28}
29
30/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability
31/// in equality comparisons, we must also ignore it when hashing to maintain the invariant that
32/// equal values have equal hashes.
33impl Hash for Scalar {
34    fn hash<H: Hasher>(&self, state: &mut H) {
35        self.dtype.as_nonnullable().hash(state);
36        self.value.hash(state);
37    }
38}
39
40impl Scalar {
41    // Constructors for null scalars.
42
43    /// Creates a new null [`Scalar`] with the given [`DType`].
44    ///
45    /// # Panics
46    ///
47    /// Panics if the given [`DType`] is non-nullable.
48    pub fn null(dtype: DType) -> Self {
49        assert!(
50            dtype.is_nullable(),
51            "Cannot create null scalar with non-nullable dtype {dtype}"
52        );
53
54        Self { dtype, value: None }
55    }
56
57    // TODO(connor): This method arguably shouldn't exist...
58    /// Creates a new null [`Scalar`] for the given scalar type.
59    ///
60    /// The resulting scalar will have a nullable version of the type's data type.
61    pub fn null_native<T: NativeDType>() -> Self {
62        Self {
63            dtype: T::dtype().as_nullable(),
64            value: None,
65        }
66    }
67
68    // Constructors for potentially null scalars.
69
70    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`].
71    ///
72    /// This is just a helper function for tests.
73    ///
74    /// # Panics
75    ///
76    /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible.
77    #[cfg(test)]
78    pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
79        use vortex_error::VortexExpect;
80
81        Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
82    }
83
84    /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
85    /// [`ScalarValue`].
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
90    pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
91        vortex_ensure!(
92            Self::is_compatible(&dtype, value.as_ref()),
93            "Incompatible dtype {dtype} with value {}",
94            value.map(|v| format!("{}", v)).unwrap_or_default()
95        );
96
97        Ok(Self { dtype, value })
98    }
99
100    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
101    /// without checking compatibility.
102    ///
103    /// # Safety
104    ///
105    /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
106    /// rules defined in [`Self::is_compatible`].
107    pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
108        debug_assert!(
109            Self::is_compatible(&dtype, value.as_ref()),
110            "Incompatible dtype {dtype} with value {}",
111            value.map(|v| format!("{}", v)).unwrap_or_default()
112        );
113
114        Self { dtype, value }
115    }
116
117    /// Returns a default value for the given [`DType`].
118    ///
119    /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
120    /// returns the zero value for the type.
121    ///
122    /// For non-nullable and nested types that may need null values in their children (as of right
123    /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
124    /// children.
125    ///
126    /// See [`ScalarValue::zero_value`] for more details about "zero" values.
127    pub fn default_value(dtype: &DType) -> Self {
128        let value = ScalarValue::default_value(dtype);
129        // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
130        unsafe { Self::new_unchecked(dtype.clone(), value) }
131    }
132
133    /// Returns a non-null zero / identity value for the given [`DType`].
134    ///
135    /// See [`ScalarValue::zero_value`] for more details about "zero" values.
136    pub fn zero_value(dtype: &DType) -> Self {
137        let value = ScalarValue::zero_value(dtype);
138        // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
139        unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
140    }
141
142    // Other methods.
143
144    /// Check if the given [`ScalarValue`] is compatible with the given [`DType`].
145    pub fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool {
146        let Some(value) = value else {
147            return dtype.is_nullable();
148        };
149        // From here on, we know that the value is not null.
150
151        match dtype {
152            DType::Null => false,
153            DType::Bool(_) => matches!(value, ScalarValue::Bool(_)),
154            DType::Primitive(ptype, _) => {
155                if let ScalarValue::Primitive(pvalue) = value {
156                    // Note that this is a backwards compatibility check for poor design in the
157                    // previous implementation. `f16` `ScalarValue`s used to be serialized as
158                    // `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`, so we need to ensure that
159                    // we can still represent them as such.
160                    let f16_backcompat_still_works =
161                        matches!(ptype, &PType::F16) && matches!(pvalue, PValue::U64(_));
162
163                    f16_backcompat_still_works || pvalue.ptype() == *ptype
164                } else {
165                    false
166                }
167            }
168            DType::Decimal(dec_dtype, _) => {
169                if let ScalarValue::Decimal(dvalue) = value {
170                    dvalue.fits_in_precision(*dec_dtype)
171                } else {
172                    false
173                }
174            }
175            DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)),
176            DType::Binary(_) => matches!(value, ScalarValue::Binary(_)),
177            DType::List(elem_dtype, _) => {
178                if let ScalarValue::List(elements) = value {
179                    elements
180                        .iter()
181                        .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref()))
182                } else {
183                    false
184                }
185            }
186            DType::FixedSizeList(elem_dtype, size, _) => {
187                if let ScalarValue::List(elements) = value {
188                    if elements.len() != *size as usize {
189                        return false;
190                    }
191                    elements
192                        .iter()
193                        .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref()))
194                } else {
195                    false
196                }
197            }
198            DType::Struct(fields, _) => {
199                if let ScalarValue::List(values) = value {
200                    if values.len() != fields.nfields() {
201                        return false;
202                    }
203                    for (field, field_value) in fields.fields().zip(values.iter()) {
204                        if !Self::is_compatible(&field, field_value.as_ref()) {
205                            return false;
206                        }
207                    }
208                    true
209                } else {
210                    false
211                }
212            }
213            DType::Extension(ext_dtype) => {
214                // TODO(connor): Fix this when adding the correct extension scalars!
215                Self::is_compatible(ext_dtype.storage_dtype(), Some(value))
216            }
217        }
218    }
219
220    /// Check if two scalars are equal, ignoring nullability of the [`DType`].
221    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
222        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
223    }
224
225    /// Returns the parts of the [`Scalar`].
226    pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
227        (self.dtype, self.value)
228    }
229
230    /// Returns the [`DType`] of the [`Scalar`].
231    pub fn dtype(&self) -> &DType {
232        &self.dtype
233    }
234
235    /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
236    pub fn value(&self) -> Option<&ScalarValue> {
237        self.value.as_ref()
238    }
239
240    /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
241    /// consuming the [`Scalar`].
242    pub fn into_value(self) -> Option<ScalarValue> {
243        self.value
244    }
245
246    /// Returns `true` if the [`Scalar`] has a non-null value.
247    pub fn is_valid(&self) -> bool {
248        self.value.is_some()
249    }
250
251    /// Returns `true` if the [`Scalar`] is null.
252    pub fn is_null(&self) -> bool {
253        self.value.is_none()
254    }
255
256    /// Returns `true` if the [`Scalar`] has a non-null zero value.
257    ///
258    /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
259    /// and `Some(false)` otherwise.
260    pub fn is_zero(&self) -> Option<bool> {
261        let value = self.value()?;
262
263        let is_zero = match self.dtype() {
264            DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
265            DType::Bool(_) => !value.as_bool(),
266            DType::Primitive(..) => value.as_primitive().is_zero(),
267            DType::Decimal(..) => value.as_decimal().is_zero(),
268            DType::Utf8(_) => value.as_utf8().is_empty(),
269            DType::Binary(_) => value.as_binary().is_empty(),
270            DType::List(..) => value.as_list().is_empty(),
271            DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
272            DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
273            DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
274        };
275
276        Some(is_zero)
277    }
278
279    /// Reinterprets the bytes of this scalar as a different primitive type.
280    ///
281    /// # Errors
282    ///
283    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
284    pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
285        let primitive = self.as_primitive();
286        if primitive.ptype() == ptype {
287            return Ok(self.clone());
288        }
289
290        vortex_ensure_eq!(
291            primitive.ptype().byte_width(),
292            ptype.byte_width(),
293            "can't reinterpret cast between integers of two different widths"
294        );
295
296        Scalar::try_new(
297            DType::Primitive(ptype, self.dtype().nullability()),
298            primitive
299                .pvalue()
300                .map(|p| p.reinterpret_cast(ptype))
301                .map(ScalarValue::Primitive),
302        )
303    }
304
305    /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
306    ///
307    /// Note that the protobuf serialization of scalars will likely have a different (but roughly
308    /// similar) length.
309    pub fn nbytes(&self) -> usize {
310        use vortex_dtype::NativeDecimalType;
311        use vortex_dtype::i256;
312
313        match self.dtype() {
314            DType::Null => 0,
315            DType::Bool(_) => 1,
316            DType::Primitive(ptype, _) => ptype.byte_width(),
317            DType::Decimal(dt, _) => {
318                if dt.precision() <= i128::MAX_PRECISION {
319                    size_of::<i128>()
320                } else {
321                    size_of::<i256>()
322                }
323            }
324            DType::Utf8(_) => self
325                .value()
326                .map_or_else(|| 0, |value| value.as_utf8().len()),
327            DType::Binary(_) => self
328                .value()
329                .map_or_else(|| 0, |value| value.as_binary().len()),
330            DType::Struct(..) => self
331                .as_struct()
332                .fields_iter()
333                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
334                .unwrap_or_default(),
335            DType::List(..) | DType::FixedSizeList(..) => self
336                .as_list()
337                .elements()
338                .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
339                .unwrap_or_default(),
340            DType::Extension(_) => self.as_extension().to_storage_scalar().nbytes(),
341        }
342    }
343}
344
345impl PartialOrd for Scalar {
346    /// Compares two scalar values for ordering.
347    ///
348    /// # Returns
349    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
350    /// - `None` if the scalars have different data types
351    ///
352    /// # Ordering Rules
353    /// When types match, the ordering follows these rules:
354    /// - Null values are considered less than all non-null values
355    /// - Non-null values are compared according to their natural ordering
356    ///
357    /// # Examples
358    /// ```ignore
359    /// // Same types compare successfully
360    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
361    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
362    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
363    ///
364    /// // Different types return None
365    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
366    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
367    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
368    ///
369    /// // Nulls are less than non-nulls
370    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
371    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
372    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
373    /// ```
374    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
375        if !self.dtype().eq_ignore_nullability(other.dtype()) {
376            return None;
377        }
378        self.value().partial_cmp(&other.value())
379    }
380}