Skip to main content

vortex_dtype/
dtype.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use DType::*;
11use itertools::Itertools;
12use vortex_error::VortexExpect;
13use vortex_error::vortex_panic;
14
15use crate::FieldDType;
16use crate::FieldName;
17use crate::PType;
18use crate::StructFields;
19use crate::decimal::DecimalDType;
20use crate::decimal::DecimalType;
21use crate::extension::ExtDTypeRef;
22use crate::nullability::Nullability;
23
24/// The logical types of elements in Vortex arrays.
25///
26/// `DType` represents the different logical data types that can be represented in a Vortex array.
27///
28/// This is different from physical types, which represent the actual layout of data (compressed or
29/// uncompressed). The set of physical types/formats (or data layout) is surjective into the set of
30/// logical types (or in other words, all physical types map to a single logical type).
31///
32/// Note that a `DType` represents the logical type of the elements in the `Array`s, **not** the
33/// logical type of the `Array` itself.
34///
35/// For example, an array with [`DType::Primitive`]([`I32`], [`NonNullable`]) could be physically
36/// encoded as any of the following:
37///
38/// - A flat array of `i32` values.
39/// - A run-length encoded sequence.
40/// - Dictionary encoded values with bitpacked codes.
41///
42/// All of these physical encodings preserve the same logical [`I32`] type, even if the physical
43/// data is different.
44///
45/// [`I32`]: PType::I32
46/// [`NonNullable`]: Nullability::NonNullable
47#[derive(Debug, Clone, PartialEq, Eq, Hash)]
48pub enum DType {
49    /// A logical null type.
50    ///
51    /// `Null` only has a single value, `null`.
52    Null,
53
54    /// A logical boolean type.
55    ///
56    /// `Bool` can be `true` or `false` if non-nullable. It can be `true`, `false`, or `null` if
57    /// nullable.
58    Bool(Nullability),
59
60    /// A logical fixed-width numeric type.
61    ///
62    /// This can be unsigned, signed, or floating point. See [`PType`] for more information.
63    Primitive(PType, Nullability),
64
65    /// Logical real numbers with fixed precision and scale.
66    ///
67    /// See [`DecimalDType`] for more information.
68    Decimal(DecimalDType, Nullability),
69
70    /// Logical UTF-8 strings.
71    Utf8(Nullability),
72
73    /// Logical binary data.
74    Binary(Nullability),
75
76    /// A logical variable-length list type.
77    ///
78    /// This is parameterized by a single `DType` that represents the element type of the inner
79    /// lists.
80    List(Arc<DType>, Nullability),
81
82    /// A logical fixed-size list type.
83    ///
84    /// This is parameterized by a `DType` that represents the element type of the inner lists, as
85    /// well as a `u32` size that determines the fixed length of each `FixedSizeList` scalar.
86    FixedSizeList(Arc<DType>, u32, Nullability),
87
88    /// A logical struct type.
89    ///
90    /// A `Struct` type is composed of an ordered list of fields, each with a corresponding name and
91    /// `DType`. See [`StructFields`] for more information.
92    Struct(StructFields, Nullability),
93
94    /// A user-defined extension type.
95    ///
96    /// See [`ExtDTypeRef`] for more information.
97    Extension(ExtDTypeRef),
98}
99
100/// This trait is implemented by native Rust types that can be converted
101/// to and from Vortex scalar values.
102/// e.g. `&str` -> `DType::Utf8`
103///      `bool` -> `DType::Bool`
104///
105/// The dtype is the one closet matching the domain of the rust type
106/// e.g. `Option<T>` -> Nullable DType.
107pub trait NativeDType {
108    /// Returns the Vortex data type for this scalar type.
109    fn dtype() -> DType;
110}
111
112/// Assert that the size of DType is 16 bytes.
113#[cfg(not(target_arch = "wasm32"))]
114const _: [(); size_of::<DType>()] = [(); 24]; // FIXME(ngates): should we keep this at 16?
115
116/// Assert that the size of DType is 12 bytes on wasm32.
117#[cfg(target_arch = "wasm32")]
118const _: [(); size_of::<DType>()] = [(); 12];
119
120impl DType {
121    /// The default `DType` for bytes.
122    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
123
124    /// Get the nullability of the `DType`.
125    #[inline]
126    pub fn nullability(&self) -> Nullability {
127        self.is_nullable().into()
128    }
129
130    /// Check if the `DType` is [`Nullability::Nullable`].
131    #[inline]
132    pub fn is_nullable(&self) -> bool {
133        match self {
134            Null => true,
135            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
136            Bool(null)
137            | Primitive(_, null)
138            | Decimal(_, null)
139            | Utf8(null)
140            | Binary(null)
141            | Struct(_, null)
142            | List(_, null)
143            | FixedSizeList(_, _, null) => matches!(null, Nullability::Nullable),
144        }
145    }
146
147    /// Get a new `DType` with [`Nullability::NonNullable`] (but otherwise the same as `self`)
148    pub fn as_nonnullable(&self) -> Self {
149        self.with_nullability(Nullability::NonNullable)
150    }
151
152    /// Get a new `DType` with [`Nullability::Nullable`] (but otherwise the same as `self`)
153    pub fn as_nullable(&self) -> Self {
154        self.with_nullability(Nullability::Nullable)
155    }
156
157    /// Get a new DType with the given nullability (but otherwise the same as `self`)
158    pub fn with_nullability(&self, nullability: Nullability) -> Self {
159        match self {
160            Null => Null,
161            Bool(_) => Bool(nullability),
162            Primitive(pdt, _) => Primitive(*pdt, nullability),
163            Decimal(ddt, _) => Decimal(*ddt, nullability),
164            Utf8(_) => Utf8(nullability),
165            Binary(_) => Binary(nullability),
166            Struct(sf, _) => Struct(sf.clone(), nullability),
167            List(edt, _) => List(edt.clone(), nullability),
168            FixedSizeList(edt, size, _) => FixedSizeList(edt.clone(), *size, nullability),
169            Extension(ext) => Extension(ext.with_nullability(nullability)),
170        }
171    }
172
173    /// Union the nullability of this `DType` with the other nullability, returning a new `DType`.
174    pub fn union_nullability(&self, other: Nullability) -> Self {
175        let nullability = self.nullability() | other;
176        self.with_nullability(nullability)
177    }
178
179    /// Check if `self` and `other` are equal, ignoring nullability.
180    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
181        match (self, other) {
182            (Null, Null) => true,
183            (Bool(_), Bool(_)) => true,
184            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
185            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
186            (Utf8(_), Utf8(_)) => true,
187            (Binary(_), Binary(_)) => true,
188            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
189            (FixedSizeList(lhs_dtype, lhs_size, _), FixedSizeList(rhs_dtype, rhs_size, _)) => {
190                lhs_size == rhs_size && lhs_dtype.eq_ignore_nullability(rhs_dtype)
191            }
192            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
193                (lhs_dtype.names() == rhs_dtype.names())
194                    && (lhs_dtype
195                        .fields()
196                        .zip_eq(rhs_dtype.fields())
197                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
198            }
199            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
200                lhs_extdtype.eq_ignore_nullability(rhs_extdtype)
201            }
202            _ => false,
203        }
204    }
205
206    /// Returns `true` if `self` is a subset type of `other, otherwise `false`.
207    ///
208    /// If `self` is nullable, this means that the other `DType` must also be nullable (since a
209    /// nullable type represents more values than a non-nullable type) and equal.
210    ///
211    /// If `self` is non-nullable, then the other `DType` must be equal ignoring nullabillity.
212    ///
213    /// We implement this functionality as a complement to `is_superset_of`.
214    pub fn eq_with_nullability_subset(&self, other: &Self) -> bool {
215        if self.is_nullable() {
216            self == other
217        } else {
218            self.eq_ignore_nullability(other)
219        }
220    }
221
222    /// Returns `true` if `self` is a superset type of `other, otherwise `false`.
223    ///
224    /// If `self` is non-nullable, this means that the other `DType` must also be non-nullable
225    /// (since a non-nullable type represents less values than a nullable type) and equal.
226    ///
227    /// If `self` is nullable, then the other `DType` must be equal ignoring nullabillity.
228    ///
229    /// This function is useful (in the `vortex-array` crate) for determining if an `Array` can
230    /// extend a given `ArrayBuilder`: it can only extend it if the `DType` of the builder is a
231    /// superset of the `Array`.
232    pub fn eq_with_nullability_superset(&self, other: &Self) -> bool {
233        if self.is_nullable() {
234            self.eq_ignore_nullability(other)
235        } else {
236            self == other
237        }
238    }
239
240    /// Check if `self` is a boolean
241    pub fn is_boolean(&self) -> bool {
242        matches!(self, Bool(_))
243    }
244
245    /// Check if `self` is a primitive type
246    pub fn is_primitive(&self) -> bool {
247        matches!(self, Primitive(_, _))
248    }
249
250    /// Returns this [`DType`]'s [`PType`] if it is a primitive type, otherwise panics.
251    pub fn as_ptype(&self) -> PType {
252        if let Primitive(ptype, _) = self {
253            *ptype
254        } else {
255            vortex_panic!("DType {self} is not a primitive type")
256        }
257    }
258
259    /// Check if `self` is an unsigned integer
260    pub fn is_unsigned_int(&self) -> bool {
261        if let Primitive(ptype, _) = self {
262            return ptype.is_unsigned_int();
263        }
264        false
265    }
266
267    /// Check if `self` is a signed integer
268    pub fn is_signed_int(&self) -> bool {
269        if let Primitive(ptype, _) = self {
270            return ptype.is_signed_int();
271        }
272        false
273    }
274
275    /// Check if `self` is an integer (signed or unsigned)
276    pub fn is_int(&self) -> bool {
277        if let Primitive(ptype, _) = self {
278            return ptype.is_int();
279        }
280        false
281    }
282
283    /// Check if `self` is a floating point number
284    pub fn is_float(&self) -> bool {
285        if let Primitive(ptype, _) = self {
286            return ptype.is_float();
287        }
288        false
289    }
290
291    /// Check if `self` is a [`DType::Decimal`].
292    pub fn is_decimal(&self) -> bool {
293        matches!(self, Decimal(..))
294    }
295
296    /// Check if `self` is a [`DType::Utf8`]
297    pub fn is_utf8(&self) -> bool {
298        matches!(self, Utf8(_))
299    }
300
301    /// Check if `self` is a [`DType::Binary`]
302    pub fn is_binary(&self) -> bool {
303        matches!(self, Binary(_))
304    }
305
306    /// Check if `self` is a [`DType::List`].
307    pub fn is_list(&self) -> bool {
308        matches!(self, List(_, _))
309    }
310
311    /// Check if `self` is a [`DType::FixedSizeList`],
312    pub fn is_fixed_size_list(&self) -> bool {
313        matches!(self, FixedSizeList(..))
314    }
315
316    /// Check if `self` is a [`DType::Struct`]
317    pub fn is_struct(&self) -> bool {
318        matches!(self, Struct(_, _))
319    }
320
321    /// Check if `self` is a [`DType::Extension`] type
322    pub fn is_extension(&self) -> bool {
323        matches!(self, Extension(_))
324    }
325
326    /// Check if `self` is a nested type, i.e. list, fixed size list, struct, or extension of a
327    /// recursive type.
328    pub fn is_nested(&self) -> bool {
329        match self {
330            List(..) | FixedSizeList(..) | Struct(..) => true,
331            Extension(ext) => ext.storage_dtype().is_nested(),
332            _ => false,
333        }
334    }
335
336    /// Returns the number of bytes occupied by a single scalar of this fixed-width type.
337    ///
338    /// For non-fixed-width types, return None.
339    ///
340    /// [`Bool`] is defined as 1 even though a Vortex array may pack Booleans to one bit per element.
341    pub fn element_size(&self) -> Option<usize> {
342        match self {
343            Null => Some(0),
344            Bool(_) => Some(1),
345            Primitive(ptype, _) => Some(ptype.byte_width()),
346            Decimal(decimal, _) => {
347                Some(DecimalType::smallest_decimal_value_type(decimal).byte_width())
348            }
349            Utf8(_) | Binary(_) | List(..) => None,
350            FixedSizeList(elem_dtype, list_size, _) => {
351                elem_dtype.element_size().map(|s| s * *list_size as usize)
352            }
353            Struct(fields, ..) => {
354                let mut sum = 0_usize;
355                for f in fields.fields() {
356                    let element_size = f.element_size()?;
357                    sum = sum
358                        .checked_add(element_size)
359                        .vortex_expect("sum of field sizes is bigger than usize");
360                }
361                Some(sum)
362            }
363            Extension(ext) => ext.storage_dtype().element_size(),
364        }
365    }
366
367    /// Check returns the inner decimal type if the dtype is a [`DType::Decimal`].
368    pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
369        if let Decimal(decimal, _) = self {
370            Some(decimal)
371        } else {
372            None
373        }
374    }
375
376    /// Owned version of [Self::as_decimal_opt].
377    pub fn into_decimal_opt(self) -> Option<DecimalDType> {
378        if let Decimal(decimal, _) = self {
379            Some(decimal)
380        } else {
381            None
382        }
383    }
384
385    /// Get the inner element dtype if `self` is a [`DType::List`], otherwise returns `None`.
386    ///
387    /// Note that this does _not_ return `Some` if `self` is a [`DType::FixedSizeList`].
388    pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
389        if let List(edt, _) = self {
390            Some(edt)
391        } else {
392            None
393        }
394    }
395
396    /// Owned version of [Self::as_list_element_opt].
397    pub fn into_list_element_opt(self) -> Option<Arc<DType>> {
398        if let List(edt, _) = self {
399            Some(edt)
400        } else {
401            None
402        }
403    }
404
405    /// Get the inner element dtype if `self` is a [`DType::FixedSizeList`], otherwise returns
406    /// `None`.
407    ///
408    /// Note that this does _not_ return `Some` if `self` is a [`DType::List`].
409    pub fn as_fixed_size_list_element_opt(&self) -> Option<&Arc<DType>> {
410        if let FixedSizeList(edt, ..) = self {
411            Some(edt)
412        } else {
413            None
414        }
415    }
416
417    /// Owned version of [Self::as_fixed_size_list_element_opt].
418    pub fn into_fixed_size_list_element_opt(self) -> Option<Arc<DType>> {
419        if let FixedSizeList(edt, ..) = self {
420            Some(edt)
421        } else {
422            None
423        }
424    }
425
426    /// Get the inner element dtype if `self` is **either** a [`DType::List`] or a
427    /// [`DType::FixedSizeList`], otherwise returns `None`
428    pub fn as_any_size_list_element_opt(&self) -> Option<&Arc<DType>> {
429        if let FixedSizeList(edt, ..) = self {
430            Some(edt)
431        } else if let List(edt, ..) = self {
432            Some(edt)
433        } else {
434            None
435        }
436    }
437
438    /// Owned version of [Self::as_any_size_list_element_opt].
439    pub fn into_any_size_list_element_opt(self) -> Option<Arc<DType>> {
440        if let FixedSizeList(edt, ..) = self {
441            Some(edt)
442        } else if let List(edt, ..) = self {
443            Some(edt)
444        } else {
445            None
446        }
447    }
448
449    /// Returns the [`StructFields`] from a struct [`DType`].
450    ///
451    /// # Panics
452    ///
453    /// If the [`DType`] is not a struct.
454    pub fn as_struct_fields(&self) -> &StructFields {
455        if let Struct(f, _) = self {
456            return f;
457        }
458        vortex_panic!("DType is not a Struct")
459    }
460
461    /// Owned version of [Self::as_struct_fields].
462    pub fn into_struct_fields(self) -> StructFields {
463        if let Struct(f, _) = self {
464            return f;
465        }
466        vortex_panic!("DType is not a Struct")
467    }
468
469    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
470    pub fn as_struct_fields_opt(&self) -> Option<&StructFields> {
471        if let Struct(f, _) = self {
472            Some(f)
473        } else {
474            None
475        }
476    }
477
478    /// Owned version of [Self::as_struct_fields_opt].
479    pub fn into_struct_fields_opt(self) -> Option<StructFields> {
480        if let Struct(f, _) = self {
481            Some(f)
482        } else {
483            None
484        }
485    }
486
487    /// Downcast a `DType` to an `ExtDType`
488    pub fn as_extension(&self) -> &ExtDTypeRef {
489        let Extension(ext) = self else {
490            vortex_panic!("DType is not an Extension")
491        };
492        ext
493    }
494
495    /// Get the `ExtDTypeRef` if `self` is an `Extension` type, otherwise `None`
496    pub fn as_extension_opt(&self) -> Option<&ExtDTypeRef> {
497        if let Extension(ext) = self {
498            Some(ext)
499        } else {
500            None
501        }
502    }
503
504    /// Convenience method for creating a [`DType::List`].
505    pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
506        List(Arc::new(dtype.into()), nullability)
507    }
508
509    /// Convenience method for creating a [`DType::Struct`].
510    pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
511        iter: I,
512        nullability: Nullability,
513    ) -> Self {
514        Struct(StructFields::from_iter(iter), nullability)
515    }
516}
517
518impl Display for DType {
519    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520        match self {
521            Null => write!(f, "null"),
522            Bool(null) => write!(f, "bool{null}"),
523            Primitive(pdt, null) => write!(f, "{pdt}{null}"),
524            Decimal(ddt, null) => write!(f, "{ddt}{null}"),
525            Utf8(null) => write!(f, "utf8{null}"),
526            Binary(null) => write!(f, "binary{null}"),
527            Struct(sf, null) => write!(
528                f,
529                "{{{}}}{null}",
530                sf.names()
531                    .iter()
532                    .zip(sf.fields())
533                    .map(|(field_null, dt)| format!("{field_null}={dt}"))
534                    .join(", "),
535            ),
536            List(edt, null) => write!(f, "list({edt}){null}"),
537            FixedSizeList(edt, size, null) => write!(f, "fixed_size_list({edt})[{size}]{null}"),
538            Extension(ext) => write!(f, "{}", ext),
539        }
540    }
541}
542
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::DType;
    use crate::Nullability::NonNullable;
    use crate::Nullability::Nullable;
    use crate::PType;
    use crate::datetime::Date;
    use crate::datetime::Time;
    use crate::datetime::TimeUnit;
    use crate::datetime::Timestamp;
    use crate::decimal::DecimalDType;

    /// Builds a non-nullable primitive dtype.
    fn primitive(ptype: PType) -> DType {
        DType::Primitive(ptype, NonNullable)
    }

    /// Builds a non-nullable fixed-size list holding `len` values of `element`.
    fn fixed_size_list(element: DType, len: u32) -> DType {
        DType::FixedSizeList(Arc::new(element), len, NonNullable)
    }

    #[test]
    fn test_ext_dtype_eq_ignore_nullability() {
        // Same extension type, differing only in nullability: considered equal.
        let nullable_time = DType::Extension(Time::new(TimeUnit::Seconds, Nullable).erased());
        let non_nullable_time =
            DType::Extension(Time::new(TimeUnit::Seconds, NonNullable).erased());
        assert!(nullable_time.eq_ignore_nullability(&non_nullable_time));

        // Same nullability but different time zones: not equal.
        let utc = DType::Extension(
            Timestamp::new_with_tz(TimeUnit::Seconds, Some("UTC".into()), Nullable).erased(),
        );
        let et = DType::Extension(
            Timestamp::new_with_tz(TimeUnit::Seconds, Some("ET".into()), Nullable).erased(),
        );
        assert!(!utc.eq_ignore_nullability(&et));
    }

    #[test]
    fn element_size_null() {
        let size = DType::Null.element_size();
        assert_eq!(size, Some(0));
    }

    #[test]
    fn element_size_bool() {
        let size = DType::Bool(NonNullable).element_size();
        assert_eq!(size, Some(1));
    }

    #[test]
    fn element_size_primitives() {
        assert_eq!(primitive(PType::U8).element_size(), Some(1));
        assert_eq!(primitive(PType::I32).element_size(), Some(4));
        assert_eq!(primitive(PType::F64).element_size(), Some(8));
    }

    #[test]
    fn element_size_decimal() {
        // precision 10 -> DecimalType::I64 -> 8 bytes
        let dtype = DType::Decimal(DecimalDType::new(10, 2), NonNullable);
        assert_eq!(dtype.element_size(), Some(8));
    }

    #[test]
    fn element_size_fixed_size_list() {
        // 1000 x f64
        let flat = fixed_size_list(primitive(PType::F64), 1000);
        assert_eq!(flat.element_size(), Some(8000));

        // 1000 x (20 x f64)
        let nested = fixed_size_list(fixed_size_list(primitive(PType::F64), 20), 1000);
        assert_eq!(nested.element_size(), Some(160_000));
    }

    #[test]
    fn element_size_nested_fixed_size_list() {
        // 100 x (10 x f64)
        let inner = fixed_size_list(primitive(PType::F64), 10);
        let outer = fixed_size_list(inner, 100);
        assert_eq!(outer.element_size(), Some(8000));
    }

    #[test]
    fn element_size_extension() {
        let date = DType::Extension(Date::new(TimeUnit::Days, NonNullable).erased());
        assert_eq!(date.element_size(), Some(4));
    }

    #[test]
    fn element_size_variable_width() {
        let list_of_i32 = DType::List(Arc::new(primitive(PType::I32)), NonNullable);

        assert_eq!(DType::Utf8(NonNullable).element_size(), None);
        assert_eq!(DType::Binary(NonNullable).element_size(), None);
        assert_eq!(list_of_i32.element_size(), None);
    }
}
658}