vortex_array/scalar/scalar_impl.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Core [`Scalar`] type definition.
5
6use std::cmp::Ordering;
7use std::hash::Hash;
8use std::hash::Hasher;
9
10use vortex_dtype::DType;
11use vortex_dtype::NativeDType;
12use vortex_dtype::PType;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_error::vortex_ensure_eq;
16use vortex_error::vortex_panic;
17
18use crate::scalar::PValue;
19use crate::scalar::Scalar;
20use crate::scalar::ScalarValue;
21
22/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
23/// Two scalars with the same value but different nullability should be considered equal.
24impl PartialEq for Scalar {
25 fn eq(&self, other: &Self) -> bool {
26 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
27 }
28}
29
30/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability
31/// in equality comparisons, we must also ignore it when hashing to maintain the invariant that
32/// equal values have equal hashes.
33impl Hash for Scalar {
34 fn hash<H: Hasher>(&self, state: &mut H) {
35 self.dtype.as_nonnullable().hash(state);
36 self.value.hash(state);
37 }
38}
39
40impl Scalar {
41 // Constructors for null scalars.
42
43 /// Creates a new null [`Scalar`] with the given [`DType`].
44 ///
45 /// # Panics
46 ///
47 /// Panics if the given [`DType`] is non-nullable.
48 pub fn null(dtype: DType) -> Self {
49 assert!(
50 dtype.is_nullable(),
51 "Cannot create null scalar with non-nullable dtype {dtype}"
52 );
53
54 Self { dtype, value: None }
55 }
56
57 // TODO(connor): This method arguably shouldn't exist...
58 /// Creates a new null [`Scalar`] for the given scalar type.
59 ///
60 /// The resulting scalar will have a nullable version of the type's data type.
61 pub fn null_native<T: NativeDType>() -> Self {
62 Self {
63 dtype: T::dtype().as_nullable(),
64 value: None,
65 }
66 }
67
68 // Constructors for potentially null scalars.
69
70 /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`].
71 ///
72 /// This is just a helper function for tests.
73 ///
74 /// # Panics
75 ///
76 /// Panics if the given [`DType`] and [`ScalarValue`] are incompatible.
77 #[cfg(test)]
78 pub fn new(dtype: DType, value: Option<ScalarValue>) -> Self {
79 use vortex_error::VortexExpect;
80
81 Self::try_new(dtype, value).vortex_expect("Failed to create Scalar")
82 }
83
84 /// Attempts to create a new [`Scalar`] with the given [`DType`] and potentially null
85 /// [`ScalarValue`].
86 ///
87 /// # Errors
88 ///
89 /// Returns an error if the given [`DType`] and [`ScalarValue`] are incompatible.
90 pub fn try_new(dtype: DType, value: Option<ScalarValue>) -> VortexResult<Self> {
91 vortex_ensure!(
92 Self::is_compatible(&dtype, value.as_ref()),
93 "Incompatible dtype {dtype} with value {}",
94 value.map(|v| format!("{}", v)).unwrap_or_default()
95 );
96
97 Ok(Self { dtype, value })
98 }
99
100 /// Creates a new [`Scalar`] with the given [`DType`] and potentially null [`ScalarValue`]
101 /// without checking compatibility.
102 ///
103 /// # Safety
104 ///
105 /// The caller must ensure that the given [`DType`] and [`ScalarValue`] are compatible per the
106 /// rules defined in [`Self::is_compatible`].
107 pub unsafe fn new_unchecked(dtype: DType, value: Option<ScalarValue>) -> Self {
108 debug_assert!(
109 Self::is_compatible(&dtype, value.as_ref()),
110 "Incompatible dtype {dtype} with value {}",
111 value.map(|v| format!("{}", v)).unwrap_or_default()
112 );
113
114 Self { dtype, value }
115 }
116
117 /// Returns a default value for the given [`DType`].
118 ///
119 /// For nullable types, this returns a null scalar. For non-nullable and non-nested types, this
120 /// returns the zero value for the type.
121 ///
122 /// For non-nullable and nested types that may need null values in their children (as of right
123 /// now, that is _only_ `FixedSizeList` and `Struct`), this function will provide null default
124 /// children.
125 ///
126 /// See [`ScalarValue::zero_value`] for more details about "zero" values.
127 pub fn default_value(dtype: &DType) -> Self {
128 let value = ScalarValue::default_value(dtype);
129 // SAFETY: We assume that `default_value` creates a valid `ScalarValue` for the `DType`.
130 unsafe { Self::new_unchecked(dtype.clone(), value) }
131 }
132
133 /// Returns a non-null zero / identity value for the given [`DType`].
134 ///
135 /// See [`ScalarValue::zero_value`] for more details about "zero" values.
136 pub fn zero_value(dtype: &DType) -> Self {
137 let value = ScalarValue::zero_value(dtype);
138 // SAFETY: We assume that `zero_value` creates a valid `ScalarValue` for the `DType`.
139 unsafe { Self::new_unchecked(dtype.clone(), Some(value)) }
140 }
141
142 // Other methods.
143
144 /// Check if the given [`ScalarValue`] is compatible with the given [`DType`].
145 pub fn is_compatible(dtype: &DType, value: Option<&ScalarValue>) -> bool {
146 let Some(value) = value else {
147 return dtype.is_nullable();
148 };
149 // From here on, we know that the value is not null.
150
151 match dtype {
152 DType::Null => false,
153 DType::Bool(_) => matches!(value, ScalarValue::Bool(_)),
154 DType::Primitive(ptype, _) => {
155 if let ScalarValue::Primitive(pvalue) = value {
156 // Note that this is a backwards compatibility check for poor design in the
157 // previous implementation. `f16` `ScalarValue`s used to be serialized as
158 // `pb::ScalarValue::Uint64Value(v.to_bits() as u64)`, so we need to ensure that
159 // we can still represent them as such.
160 let f16_backcompat_still_works =
161 matches!(ptype, &PType::F16) && matches!(pvalue, PValue::U64(_));
162
163 f16_backcompat_still_works || pvalue.ptype() == *ptype
164 } else {
165 false
166 }
167 }
168 DType::Decimal(dec_dtype, _) => {
169 if let ScalarValue::Decimal(dvalue) = value {
170 dvalue.fits_in_precision(*dec_dtype)
171 } else {
172 false
173 }
174 }
175 DType::Utf8(_) => matches!(value, ScalarValue::Utf8(_)),
176 DType::Binary(_) => matches!(value, ScalarValue::Binary(_)),
177 DType::List(elem_dtype, _) => {
178 if let ScalarValue::List(elements) = value {
179 elements
180 .iter()
181 .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref()))
182 } else {
183 false
184 }
185 }
186 DType::FixedSizeList(elem_dtype, size, _) => {
187 if let ScalarValue::List(elements) = value {
188 if elements.len() != *size as usize {
189 return false;
190 }
191 elements
192 .iter()
193 .all(|element| Self::is_compatible(elem_dtype.as_ref(), element.as_ref()))
194 } else {
195 false
196 }
197 }
198 DType::Struct(fields, _) => {
199 if let ScalarValue::List(values) = value {
200 if values.len() != fields.nfields() {
201 return false;
202 }
203 for (field, field_value) in fields.fields().zip(values.iter()) {
204 if !Self::is_compatible(&field, field_value.as_ref()) {
205 return false;
206 }
207 }
208 true
209 } else {
210 false
211 }
212 }
213 DType::Extension(ext_dtype) => {
214 // TODO(connor): Fix this when adding the correct extension scalars!
215 Self::is_compatible(ext_dtype.storage_dtype(), Some(value))
216 }
217 }
218 }
219
220 /// Check if two scalars are equal, ignoring nullability of the [`DType`].
221 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
222 self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
223 }
224
225 /// Returns the parts of the [`Scalar`].
226 pub fn into_parts(self) -> (DType, Option<ScalarValue>) {
227 (self.dtype, self.value)
228 }
229
230 /// Returns the [`DType`] of the [`Scalar`].
231 pub fn dtype(&self) -> &DType {
232 &self.dtype
233 }
234
235 /// Returns an optional [`ScalarValue`] of the [`Scalar`], where `None` means the value is null.
236 pub fn value(&self) -> Option<&ScalarValue> {
237 self.value.as_ref()
238 }
239
240 /// Returns the internal optional [`ScalarValue`], where `None` means the value is null,
241 /// consuming the [`Scalar`].
242 pub fn into_value(self) -> Option<ScalarValue> {
243 self.value
244 }
245
246 /// Returns `true` if the [`Scalar`] has a non-null value.
247 pub fn is_valid(&self) -> bool {
248 self.value.is_some()
249 }
250
251 /// Returns `true` if the [`Scalar`] is null.
252 pub fn is_null(&self) -> bool {
253 self.value.is_none()
254 }
255
256 /// Returns `true` if the [`Scalar`] has a non-null zero value.
257 ///
258 /// Returns `None` if the scalar is null, otherwise returns `Some(true)` if the value is zero
259 /// and `Some(false)` otherwise.
260 pub fn is_zero(&self) -> Option<bool> {
261 let value = self.value()?;
262
263 let is_zero = match self.dtype() {
264 DType::Null => vortex_panic!("non-null value somehow had `DType::Null`"),
265 DType::Bool(_) => !value.as_bool(),
266 DType::Primitive(..) => value.as_primitive().is_zero(),
267 DType::Decimal(..) => value.as_decimal().is_zero(),
268 DType::Utf8(_) => value.as_utf8().is_empty(),
269 DType::Binary(_) => value.as_binary().is_empty(),
270 DType::List(..) => value.as_list().is_empty(),
271 DType::FixedSizeList(_, list_size, _) => value.as_list().len() == *list_size as usize,
272 DType::Struct(struct_fields, _) => value.as_list().len() == struct_fields.nfields(),
273 DType::Extension(_) => self.as_extension().to_storage_scalar().is_zero()?,
274 };
275
276 Some(is_zero)
277 }
278
279 /// Reinterprets the bytes of this scalar as a different primitive type.
280 ///
281 /// # Errors
282 ///
283 /// Panics if the scalar is not a primitive type or if the types have different byte widths.
284 pub fn primitive_reinterpret_cast(&self, ptype: PType) -> VortexResult<Self> {
285 let primitive = self.as_primitive();
286 if primitive.ptype() == ptype {
287 return Ok(self.clone());
288 }
289
290 vortex_ensure_eq!(
291 primitive.ptype().byte_width(),
292 ptype.byte_width(),
293 "can't reinterpret cast between integers of two different widths"
294 );
295
296 Scalar::try_new(
297 DType::Primitive(ptype, self.dtype().nullability()),
298 primitive
299 .pvalue()
300 .map(|p| p.reinterpret_cast(ptype))
301 .map(ScalarValue::Primitive),
302 )
303 }
304
305 /// Returns an **ESTIMATE** of the size of the scalar in bytes, uncompressed.
306 ///
307 /// Note that the protobuf serialization of scalars will likely have a different (but roughly
308 /// similar) length.
309 pub fn nbytes(&self) -> usize {
310 use vortex_dtype::NativeDecimalType;
311 use vortex_dtype::i256;
312
313 match self.dtype() {
314 DType::Null => 0,
315 DType::Bool(_) => 1,
316 DType::Primitive(ptype, _) => ptype.byte_width(),
317 DType::Decimal(dt, _) => {
318 if dt.precision() <= i128::MAX_PRECISION {
319 size_of::<i128>()
320 } else {
321 size_of::<i256>()
322 }
323 }
324 DType::Utf8(_) => self
325 .value()
326 .map_or_else(|| 0, |value| value.as_utf8().len()),
327 DType::Binary(_) => self
328 .value()
329 .map_or_else(|| 0, |value| value.as_binary().len()),
330 DType::Struct(..) => self
331 .as_struct()
332 .fields_iter()
333 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
334 .unwrap_or_default(),
335 DType::List(..) | DType::FixedSizeList(..) => self
336 .as_list()
337 .elements()
338 .map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
339 .unwrap_or_default(),
340 DType::Extension(_) => self.as_extension().to_storage_scalar().nbytes(),
341 }
342 }
343}
344
345impl PartialOrd for Scalar {
346 /// Compares two scalar values for ordering.
347 ///
348 /// # Returns
349 /// - `Some(Ordering)` if both scalars have the same data type (ignoring nullability)
350 /// - `None` if the scalars have different data types
351 ///
352 /// # Ordering Rules
353 /// When types match, the ordering follows these rules:
354 /// - Null values are considered less than all non-null values
355 /// - Non-null values are compared according to their natural ordering
356 ///
357 /// # Examples
358 /// ```ignore
359 /// // Same types compare successfully
360 /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
361 /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
362 /// assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
363 ///
364 /// // Different types return None
365 /// let int_scalar = Scalar::primitive(10i32, Nullability::NonNullable);
366 /// let str_scalar = Scalar::utf8("hello", Nullability::NonNullable);
367 /// assert_eq!(int_scalar.partial_cmp(&str_scalar), None);
368 ///
369 /// // Nulls are less than non-nulls
370 /// let null = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable));
371 /// let value = Scalar::primitive(0i32, Nullability::Nullable);
372 /// assert_eq!(null.partial_cmp(&value), Some(Ordering::Less));
373 /// ```
374 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
375 if !self.dtype().eq_ignore_nullability(other.dtype()) {
376 return None;
377 }
378 self.value().partial_cmp(&other.value())
379 }
380}