Skip to main content

vortex_array/scalar/
scalar_impl.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure_eq;
12use vortex_error::vortex_panic;
13
14use crate::dtype::DType;
15use crate::dtype::NativeDType;
16use crate::dtype::PType;
17use crate::dtype::StructFields;
18use crate::scalar::Scalar;
19use crate::scalar::ScalarValue;
20
21impl Scalar {
22    // Constructors for null scalars.
23
24    /// Creates a new null [`Scalar`] with the given [`DType`].
25    ///
26    /// # Panics
27    ///
28    /// Panics if the given [`DType`] is non-nullable.
29    pub fn null(dtype: DType) -> Self {
30        assert!(
31            dtype.is_nullable(),
32            "Cannot create null scalar with non-nullable dtype {dtype}"
33        );
34
35        Self { dtype, value: None }
36    }
37
38    // TODO(connor): This method arguably shouldn't exist...
39    /// Creates a new null [`Scalar`] for the given scalar type.
40    ///
41    /// The resulting scalar will have a nullable version of the type's data type.
42    pub fn null_native<T: NativeDType>() -> Self {
43        Self {
44            dtype: T::dtype().as_nullable(),
45            value: None,
46        }
47    }
48
49    // Constructors for potentially null scalars.
50
51    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`].
52    ///
53    /// This is just a helper function for tests.
54    ///
55    /// # Panics
56    ///
57    /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible.
58    #[cfg(test)]
59    pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
60        use vortex_error::VortexExpect;
61
62        Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
63    }
64
65    /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
66    /// [`ScalarValue`].
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
71    pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
72        Self::validate(&dtype, value.as_ref())?;
73
74        Ok(Self { dtype, value })
75    }
76
77    /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
78    /// without checking compatibility.
79    ///
80    /// # Safety
81    ///
82    /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
83    /// rules defined in [`Self::validate`].
84    pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
85        #[cfg(debug_assertions)]
86        {
87            use vortex_error::VortexExpect;
88
89            Self::validate(&dtype, value.as_ref())
90                .vortex_expect("Scalar::new_unchecked called with incompatible dtype and value");
91        }
92
93        Self { dtype, value }
94    }
95
96    /// Returns a default value for the given [`DType`].
97    ///
98    /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
99    /// returns the zero value for the type.
100    ///
101    /// See [`Scalar::zero_value`] for more details about "zero" values.
102    ///
103    /// For non-nullable and nested types that may need null values in their children (as of right
104    /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
105    /// children.
106    pub fn default_value(dtype: &DType) -> Self {
107        let value = ScalarValue::default_value(dtype);
108
109        // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
110        unsafe { Self::new_unchecked(dtype.clone(), value) }
111    }
112
113    /// Returns a non-null zero / identity value for the given [`DType`].
114    ///
115    /// # Zero Values
116    ///
117    /// Here is the list of zero values for each [`DType`] (when the [`DType`] is non-nullable):
118    ///
119    /// - `Null`: Does not have a "zero" value
120    /// - `Bool`: `false`
121    /// - `Primitive`: `0`
122    /// - `Decimal`: `0`
123    /// - `Utf8`: `""`
124    /// - `Binary`: An empty buffer
125    /// - `List`: An empty list
126    /// - `FixedSizeList`: A list (with correct size) of zero values, which is determined by the
127    ///   element [`DType`]
128    /// - `Struct`: A struct where each field has a zero value, which is determined by the field
129    ///   [`DType`]
130    /// - `Extension`: The zero value of the storage [`DType`]
131    pub fn zero_value(dtype: &DType) -> Self {
132        let value = ScalarValue::zero_value(dtype);
133
134        // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
135        unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
136    }
137
138    // Other methods.
139
140    /// Check if two scalars are equal, ignoring nullability of the [`DType`].
141    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
142        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
143    }
144
145    /// Returns the parts of the [`Scalar`].
146    pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
147        (self.dtype, self.value)
148    }
149
150    /// Returns the [`DType`] of the [`Scalar`].
151    pub fn dtype(&self) -> &DType {
152        &self.dtype
153    }
154
155    /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
156    pub fn value(&self) -> Option<&ScalarValue> {
157        self.value.as_ref()
158    }
159
160    /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
161    /// consuming the [`Scalar`].
162    pub fn into_value(self) -> Option<ScalarValue> {
163        self.value
164    }
165
166    /// Returns `true` if the [`Scalar`] has a non-null value.
167    pub fn is_valid(&self) -> bool {
168        self.value.is_some()
169    }
170
171    /// Returns `true` if the [`Scalar`] is null.
172    pub fn is_null(&self) -> bool {
173        self.value.is_none()
174    }
175
176    /// Returns `true` if the [`Scalar`] has a non-null zero value.
177    ///
178    /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
179    /// and `Some(false)` otherwise.
180    pub fn is_zero(&self) -> Option<bool> {
181        let value = self.value()?;
182
183        let is_zero = match self.dtype() {
184            DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
185            DType::Bool(_) => !value.as_bool(),
186            DType::Primitive(..) => value.as_primitive().is_zero(),
187            DType::Decimal(..) => value.as_decimal().is_zero(),
188            DType::Utf8(_) => value.as_utf8().is_empty(),
189            DType::Binary(_) => value.as_binary().is_empty(),
190            DType::List(..) => value.as_list().is_empty(),
191            // A fixed-size list is zero only if it has the expected number of elements and every
192            // element is itself a non-null zero value.
193            DType::FixedSizeList(_, list_size, _) => {
194                let list = self.as_list();
195                list.len() == *list_size as usize
196                    && (0..list.len())
197                        .all(|i| list.element(i).is_some_and(|e| e.is_zero() == Some(true)))
198            }
199            // A struct is zero only if every one of its fields is itself a non-null zero value.
200            DType::Struct(..) => self
201                .as_struct()
202                .fields_iter()
203                .is_some_and(|mut fields| fields.all(|f| f.is_zero() == Some(true))),
204            DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
205            DType::Variant(_) => self.as_variant().is_zero()?,
206            DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
207        };
208
209        Some(is_zero)
210    }
211
212    /// Reinterprets the bytes of this scalar as a different primitive type.
213    ///
214    /// # Errors
215    ///
216    /// Panics if the scalar is not a primitive type or if the types have different byte widths.
217    pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
218        let primitive = self.as_primitive();
219        if primitive.ptype() == ptype {
220            return Ok(self.clone());
221        }
222
223        vortex_ensure_eq!(
224            primitive.ptype().byte_width(),
225            ptype.byte_width(),
226            "can't reinterpret cast between integers of two different widths"
227        );
228
229        Scalar::try_new(
230            DType::Primitive(ptype, self.dtype().nullability()),
231            primitive
232                .pvalue()
233                .map(|p| p.reinterpret_cast(ptype))
234                .map(ScalarValue::Primitive),
235        )
236    }
237
238    /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
239    ///
240    /// Note that the protobuf serialization of scalars will likely have a different (but roughly
241    /// similar) length.
242    pub fn approx_nbytes(&self) -> usize {
243        use crate::dtype::NativeDecimalType;
244        use crate::dtype::i256;
245
246        match self.dtype() {
247            DType::Null => 0,
248            DType::Bool(_) => 1,
249            DType::Primitive(ptype, _) => ptype.byte_width(),
250            DType::Decimal(dt, _) => {
251                if dt.precision() <= i128::MAX_PRECISION {
252                    size_of::<i128>()
253                } else {
254                    size_of::<i256>()
255                }
256            }
257            DType::Utf8(_) => self
258                .value()
259                .map_or_else(|| 0, |value| value.as_utf8().len()),
260            DType::Binary(_) => self
261                .value()
262                .map_or_else(|| 0, |value| value.as_binary().len()),
263            DType::List(..) | DType::FixedSizeList(..) => self
264                .as_list()
265                .elements()
266                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
267                .unwrap_or_default(),
268            DType::Struct(..) => self
269                .as_struct()
270                .fields_iter()
271                .map(|fields| fields.into_iter().map(|f| f.approx_nbytes()).sum::<usize>())
272                .unwrap_or_default(),
273            DType::Union(..) => todo!("TODO(connor)[Union]: unimplemented"),
274            DType::Variant(_) => self.as_variant().value().map_or(0, Scalar::approx_nbytes),
275            DType::Extension(_) => self.as_extension().to_storage_scalar().approx_nbytes(),
276        }
277    }
278}
279
280/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability in
281/// equality comparisons, we must also ignore it when hashing to maintain the invariant that equal
282/// values have equal hashes.
283impl Hash for Scalar {
284    fn hash<H: Hasher>(&self, state: &mut H) {
285        self.dtype.as_nonnullable().hash(state);
286        self.value.hash(state);
287    }
288}
289
290/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
291/// Two scalars with the same value but different nullability should be considered equal.
292///
293/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
294/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
295/// implementation simply returns `false`.
296impl PartialEq for Scalar {
297    fn eq(&self, other: &Self) -> bool {
298        self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
299    }
300}
301
302impl PartialOrd for Scalar {
303    /// Compares two scalar values for ordering.
304    ///
305    /// # Returns
306    /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
307    /// - `None` if the scalars have different data types
308    ///
309    /// # Ordering Rules
310    /// When types match, the ordering follows these rules:
311    /// - Null values are considered less than all non-null values
312    /// - Non-null values are compared according to their natural ordering
313    ///
314    /// # Examples
315    ///
316    /// ```
317    /// use std::cmp::Ordering;
318    /// use vortex_array::dtype::DType;
319    /// use vortex_array::dtype::Nullability;
320    /// use vortex_array::dtype::PType;
321    /// use vortex_array::scalar::Scalar;
322    ///
323    /// // Same types compare successfully
324    /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
325    /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
326    /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
327    ///
328    /// // Different types return None
329    /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
330    /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
331    /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
332    ///
333    /// // Nulls are less than non-nulls
334    /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
335    /// let value = Scalar::primitive(0i32, Nullability::Nullable);
336    /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
337    /// ```
338    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
339        if !self.dtype().eq_ignore_nullability(other.dtype()) {
340            return None;
341        }
342
343        partial_cmp_scalar_values(self.dtype(), self.value(), other.value())
344    }
345}
346
347/// Compare two optional scalar values using `dtype` for nested tuple interpretation.
348fn partial_cmp_scalar_values(
349    dtype: &DType,
350    lhs: Option<&ScalarValue>,
351    rhs: Option<&ScalarValue>,
352) -> Option<Ordering> {
353    match (lhs, rhs) {
354        (None, None) => Some(Ordering::Equal),
355        (None, Some(_)) => Some(Ordering::Less),
356        (Some(_), None) => Some(Ordering::Greater),
357        (Some(lhs), Some(rhs)) => partial_cmp_non_null_scalar_values(dtype, lhs, rhs),
358    }
359}
360
361/// Compare two non-null scalar values, consulting `dtype` only for tuple-backed values.
362fn partial_cmp_non_null_scalar_values(
363    dtype: &DType,
364    lhs: &ScalarValue,
365    rhs: &ScalarValue,
366) -> Option<Ordering> {
367    // `Scalar::validate` guarantees that a scalar's value matches its dtype. Most of the scalar
368    // value variants have only 1 method of comparison, regardless of the dtype.
369    match (lhs, rhs) {
370        (ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(rhs),
371        (ScalarValue::Primitive(lhs), ScalarValue::Primitive(rhs)) => lhs.partial_cmp(rhs),
372        (ScalarValue::Decimal(lhs), ScalarValue::Decimal(rhs)) => lhs.partial_cmp(rhs),
373        (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => lhs.partial_cmp(rhs),
374        (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => lhs.partial_cmp(rhs),
375        // `Tuple` is the exception here. Since it backs lists, fixed-size lists, and structs, we
376        // need the dtype to know whether children share one element dtype or use per-field dtypes.
377        (ScalarValue::Tuple(lhs), ScalarValue::Tuple(rhs)) => {
378            partial_cmp_tuple_values(dtype, lhs, rhs)
379        }
380        // Variant values can have a different dtype in each row, so it doesn't make sense to
381        // compare them.
382        (ScalarValue::Variant(_), ScalarValue::Variant(_)) => None,
383        _ => None,
384    }
385}
386
387/// Compare tuple values according to the list, fixed-size list, or struct dtype layout.
388fn partial_cmp_tuple_values(
389    dtype: &DType,
390    lhs: &[Option<ScalarValue>],
391    rhs: &[Option<ScalarValue>],
392) -> Option<Ordering> {
393    match dtype {
394        DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
395            partial_cmp_list_values(element_dtype, lhs, rhs)
396        }
397        DType::Struct(fields, _) => partial_cmp_struct_values(fields, lhs, rhs),
398        DType::Extension(ext_dtype) => {
399            partial_cmp_tuple_values(ext_dtype.storage_dtype(), lhs, rhs)
400        }
401        _ => None,
402    }
403}
404
405/// Compare list tuple values using the shared element dtype for each element.
406fn partial_cmp_list_values(
407    element_dtype: &DType,
408    lhs: &[Option<ScalarValue>],
409    rhs: &[Option<ScalarValue>],
410) -> Option<Ordering> {
411    for (lhs, rhs) in lhs.iter().zip(rhs.iter()) {
412        match partial_cmp_scalar_values(element_dtype, lhs.as_ref(), rhs.as_ref())? {
413            Ordering::Equal => continue,
414            ordering => return Some(ordering),
415        }
416    }
417
418    Some(lhs.len().cmp(&rhs.len()))
419}
420
421/// Compare struct tuple values using each field's dtype in field order.
422fn partial_cmp_struct_values(
423    fields: &StructFields,
424    lhs: &[Option<ScalarValue>],
425    rhs: &[Option<ScalarValue>],
426) -> Option<Ordering> {
427    if lhs.len() != fields.nfields() || rhs.len() != fields.nfields() {
428        return None;
429    }
430
431    for ((field_dtype, lhs), rhs) in fields.fields().zip(lhs.iter()).zip(rhs.iter()) {
432        match partial_cmp_scalar_values(&field_dtype, lhs.as_ref(), rhs.as_ref())? {
433            Ordering::Equal => continue,
434            ordering => return Some(ordering),
435        }
436    }
437
438    Some(Ordering::Equal)
439}
440
// Unit tests for `Scalar::is_zero` on nested types (fixed-size lists and structs).
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use rstest::rstest;

    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::scalar::Scalar;

    /// Builds a non-nullable `i32` primitive scalar.
    fn i32_scalar(value: i32) -> Scalar {
        Scalar::primitive::<i32>(value, Nullability::NonNullable)
    }

    /// Builds a nullable `i32` primitive scalar, where `None` produces a null scalar.
    fn nullable_i32(value: Option<i32>) -> Scalar {
        match value {
            Some(value) => Scalar::primitive::<i32>(value, Nullability::Nullable),
            None => Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
        }
    }

    /// Builds the dtype of a struct with a non-nullable `i32` field `a` and a non-nullable utf8
    /// field `b`, with the given top-level nullability.
    fn ab_struct_dtype(nullability: Nullability) -> DType {
        DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![
                    DType::Primitive(PType::I32, Nullability::NonNullable),
                    DType::Utf8(Nullability::NonNullable),
                ],
            ),
            nullability,
        )
    }

    #[rstest]
    // A fixed-size list of all-zero elements is itself zero.
    #[case(vec![0, 0], Some(true))]
    #[case(vec![0], Some(true))]
    // A single non-zero element makes the whole list non-zero. Regression cases: on the `develop`
    // branch these incorrectly returned `Some(true)` because only the element count was checked.
    #[case(vec![0, 5], Some(false))]
    #[case(vec![5, 0], Some(false))]
    #[case(vec![1, 2], Some(false))]
    fn fixed_size_list_is_zero(#[case] values: Vec<i32>, #[case] expected: Option<bool>) {
        let element_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let children: Vec<Scalar> = values.into_iter().map(i32_scalar).collect();
        let scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
        assert_eq!(scalar.is_zero(), expected);
    }

    #[test]
    fn null_fixed_size_list_is_zero_is_none() {
        // A null scalar has no value to inspect, so `is_zero` reports `None`.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let scalar = Scalar::null(DType::FixedSizeList(
            element_dtype,
            2,
            Nullability::Nullable,
        ));
        assert_eq!(scalar.is_zero(), None);
    }

    #[test]
    fn fixed_size_list_with_null_element_is_not_zero() {
        // A non-null fixed-size list containing a null element is not a zero value. Regression
        // case: on the `develop` branch this incorrectly returned `Some(true)`.
        let element_dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        let children = vec![nullable_i32(Some(0)), nullable_i32(None)];
        let scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable);
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn struct_with_all_zero_fields_is_zero() {
        // `0` and `""` are the zero values of the two fields, so the struct as a whole is zero.
        let scalar = Scalar::struct_(
            ab_struct_dtype(Nullability::NonNullable),
            vec![i32_scalar(0), Scalar::utf8("", Nullability::NonNullable)],
        );
        assert_eq!(scalar.is_zero(), Some(true));
    }

    #[rstest]
    // A non-zero primitive field, a non-empty string field, or both, make the struct non-zero.
    // Regression cases: on the `develop` branch all of these incorrectly returned `Some(true)`.
    #[case(5, "")]
    #[case(0, "x")]
    #[case(7, "y")]
    fn struct_with_non_zero_field_is_not_zero(#[case] a: i32, #[case] b: &str) {
        let scalar = Scalar::struct_(
            ab_struct_dtype(Nullability::NonNullable),
            vec![i32_scalar(a), Scalar::utf8(b, Nullability::NonNullable)],
        );
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn null_struct_is_zero_is_none() {
        // A null scalar has no value to inspect, so `is_zero` reports `None`.
        let scalar = Scalar::null(ab_struct_dtype(Nullability::Nullable));
        assert_eq!(scalar.is_zero(), None);
    }

    #[test]
    fn struct_with_null_field_is_not_zero() {
        // A non-null struct with a null field is not a zero value. Regression case: on the
        // `develop` branch this incorrectly returned `Some(true)`.
        let dtype = DType::Struct(
            StructFields::new(
                ["a", "b"].into(),
                vec![
                    DType::Primitive(PType::I32, Nullability::Nullable),
                    DType::Primitive(PType::I32, Nullability::Nullable),
                ],
            ),
            Nullability::NonNullable,
        );
        let scalar = Scalar::struct_(dtype, vec![nullable_i32(Some(0)), nullable_i32(None)]);
        assert_eq!(scalar.is_zero(), Some(false));
    }

    #[test]
    fn nested_struct_of_fixed_size_list_recurses() {
        // Zero-checking must recurse through both structs and fixed-size lists. Regression case:
        // on the `develop` branch the non-zero case incorrectly returned `Some(true)`.
        let element_dtype = Arc::new(DType::Primitive(PType::I32, Nullability::NonNullable));
        let fsl_dtype =
            DType::FixedSizeList(Arc::clone(&element_dtype), 2, Nullability::NonNullable);
        let struct_dtype = DType::Struct(
            StructFields::new(["fsl"].into(), vec![fsl_dtype]),
            Nullability::NonNullable,
        );

        // Struct whose only field is a fixed-size list of zeros: zero all the way down.
        let all_zero = Scalar::struct_(
            struct_dtype.clone(),
            vec![Scalar::fixed_size_list(
                Arc::clone(&element_dtype),
                vec![i32_scalar(0), i32_scalar(0)],
                Nullability::NonNullable,
            )],
        );
        assert_eq!(all_zero.is_zero(), Some(true));

        // Same shape, but one nested element is non-zero, which must propagate to the struct.
        let with_non_zero = Scalar::struct_(
            struct_dtype,
            vec![Scalar::fixed_size_list(
                element_dtype,
                vec![i32_scalar(0), i32_scalar(9)],
                Nullability::NonNullable,
            )],
        );
        assert_eq!(with_non_zero.is_zero(), Some(false));
    }
}