vortex_dtype/
dtype.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use DType::*;
9use itertools::Itertools;
10use static_assertions::const_assert_eq;
11use vortex_error::vortex_panic;
12
13use crate::decimal::DecimalDType;
14use crate::nullability::Nullability;
15use crate::{ExtDType, FieldDType, FieldName, PType, StructFields};
16
17/// The logical types of elements in Vortex arrays.
18///
19/// `DType` represents the different logical data types that can be represented in a Vortex array.
20///
21/// This is different from physical types, which represent the actual layout of data (compressed or
22/// uncompressed). The set of physical types/formats (or data layout) is surjective into the set of
23/// logical types (or in other words, all physical types map to a single logical type).
24///
25/// Note that a `DType` represents the logical type of the elements in the `Array`s, **not** the
26/// logical type of the `Array` itself.
27///
28/// For example, an array with [`DType::Primitive`]([`I32`], [`NonNullable`]) could be physically
29/// encoded as any of the following:
30///
31/// - A flat array of `i32` values.
32/// - A run-length encoded sequence.
33/// - Dictionary encoded values with bitpacked codes.
34///
35/// All of these physical encodings preserve the same logical [`I32`] type, even if the physical
36/// data is different.
37///
38/// [`I32`]: PType::I32
39/// [`NonNullable`]: Nullability::NonNullable
40#[derive(Debug, Clone, PartialEq, Eq, Hash)]
41#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
42pub enum DType {
43    /// A logical null type.
44    ///
45    /// `Null` only has a single value, `null`.
46    Null,
47
48    /// A logical boolean type.
49    ///
50    /// `Bool` can be `true` or `false` if non-nullable. It can be `true`, `false`, or `null` if
51    /// nullable.
52    Bool(Nullability),
53
54    /// A logical fixed-width numeric type.
55    ///
56    /// This can be unsigned, signed, or floating point. See [`PType`] for more information.
57    Primitive(PType, Nullability),
58
59    /// Logical real numbers with fixed precision and scale.
60    ///
61    /// See [`DecimalDType`] for more information.
62    Decimal(DecimalDType, Nullability),
63
64    /// Logical UTF-8 strings.
65    Utf8(Nullability),
66
67    /// Logical binary data.
68    Binary(Nullability),
69
70    /// A logical variable-length list type.
71    ///
72    /// This is parameterized by a single `DType` that represents the element type of the inner
73    /// lists.
74    List(Arc<DType>, Nullability),
75
76    /// A logical struct type.
77    ///
78    /// A `Struct` type is composed of an ordered list of fields, each with a corresponding name and
79    /// `DType`. See [`StructFields`] for more information.
80    Struct(StructFields, Nullability),
81
82    /// A user-defined extension type.
83    ///
84    /// See [`ExtDType`] for more information.
85    Extension(Arc<ExtDType>),
86}
87
88#[cfg(not(target_arch = "wasm32"))]
89const_assert_eq!(size_of::<DType>(), 16);
90
91#[cfg(target_arch = "wasm32")]
92const_assert_eq!(size_of::<DType>(), 8);
93
94impl DType {
95    /// The default `DType` for bytes.
96    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
97
98    /// Get the nullability of the `DType`.
99    pub fn nullability(&self) -> Nullability {
100        self.is_nullable().into()
101    }
102
103    /// Check if the `DType` is [`Nullability::Nullable`].
104    pub fn is_nullable(&self) -> bool {
105        match self {
106            Null => true,
107            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
108            Bool(null)
109            | Primitive(_, null)
110            | Decimal(_, null)
111            | Utf8(null)
112            | Binary(null)
113            | Struct(_, null)
114            | List(_, null) => matches!(null, Nullability::Nullable),
115        }
116    }
117
118    /// Get a new `DType` with [`Nullability::NonNullable`] (but otherwise the same as `self`)
119    pub fn as_nonnullable(&self) -> Self {
120        self.with_nullability(Nullability::NonNullable)
121    }
122
123    /// Get a new `DType` with [`Nullability::Nullable`] (but otherwise the same as `self`)
124    pub fn as_nullable(&self) -> Self {
125        self.with_nullability(Nullability::Nullable)
126    }
127
128    /// Get a new DType with the given nullability (but otherwise the same as `self`)
129    pub fn with_nullability(&self, nullability: Nullability) -> Self {
130        match self {
131            Null => Null,
132            Bool(_) => Bool(nullability),
133            Primitive(pdt, _) => Primitive(*pdt, nullability),
134            Decimal(ddt, _) => Decimal(*ddt, nullability),
135            Utf8(_) => Utf8(nullability),
136            Binary(_) => Binary(nullability),
137            Struct(sf, _) => Struct(sf.clone(), nullability),
138            List(edt, _) => List(edt.clone(), nullability),
139            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
140        }
141    }
142
143    /// Union the nullability of this `DType` with the other nullability, returning a new `DType`.
144    pub fn union_nullability(&self, other: Nullability) -> Self {
145        let nullability = self.nullability() | other;
146        self.with_nullability(nullability)
147    }
148
149    /// Check if `self` and `other` are equal, ignoring nullability.
150    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
151        match (self, other) {
152            (Null, Null) => true,
153            (Bool(_), Bool(_)) => true,
154            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
155            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
156            (Utf8(_), Utf8(_)) => true,
157            (Binary(_), Binary(_)) => true,
158            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
159            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
160                (lhs_dtype.names() == rhs_dtype.names())
161                    && (lhs_dtype
162                        .fields()
163                        .zip_eq(rhs_dtype.fields())
164                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
165            }
166            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
167                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
168            }
169            _ => false,
170        }
171    }
172
173    /// Check if `self` is a `StructDType`
174    pub fn is_struct(&self) -> bool {
175        matches!(self, Struct(_, _))
176    }
177
178    /// Check if `self` is a `ListDType`
179    pub fn is_list(&self) -> bool {
180        matches!(self, List(_, _))
181    }
182
183    /// Check if `self` is a primitive type
184    pub fn is_primitive(&self) -> bool {
185        matches!(self, Primitive(_, _))
186    }
187
188    /// Returns this [`DType`]'s [`PType`] if it is a primitive type, otherwise panics.
189    pub fn as_ptype(&self) -> PType {
190        if let Primitive(ptype, _) = self {
191            *ptype
192        } else {
193            vortex_panic!("DType is not a primitive type")
194        }
195    }
196
197    /// Check if `self` is an unsigned integer
198    pub fn is_unsigned_int(&self) -> bool {
199        if let Primitive(ptype, _) = self {
200            return ptype.is_unsigned_int();
201        }
202        false
203    }
204
205    /// Check if `self` is a signed integer
206    pub fn is_signed_int(&self) -> bool {
207        if let Primitive(ptype, _) = self {
208            return ptype.is_signed_int();
209        }
210        false
211    }
212
213    /// Check if `self` is an integer (signed or unsigned)
214    pub fn is_int(&self) -> bool {
215        if let Primitive(ptype, _) = self {
216            return ptype.is_int();
217        }
218        false
219    }
220
221    /// Check if `self` is a floating point number
222    pub fn is_float(&self) -> bool {
223        if let Primitive(ptype, _) = self {
224            return ptype.is_float();
225        }
226        false
227    }
228
229    /// Check if `self` is a boolean
230    pub fn is_boolean(&self) -> bool {
231        matches!(self, Bool(_))
232    }
233
234    /// Check if `self` is a binary
235    pub fn is_binary(&self) -> bool {
236        matches!(self, Binary(_))
237    }
238
239    /// Check if `self` is a utf8
240    pub fn is_utf8(&self) -> bool {
241        matches!(self, Utf8(_))
242    }
243
244    /// Check if `self` is an extension type
245    pub fn is_extension(&self) -> bool {
246        matches!(self, Extension(_))
247    }
248
249    /// Check if `self` is a decimal type
250    pub fn is_decimal(&self) -> bool {
251        matches!(self, Decimal(..))
252    }
253
254    /// Check returns the inner decimal type if the dtype is a decimal
255    pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
256        if let Decimal(decimal, _) = self {
257            Some(decimal)
258        } else {
259            None
260        }
261    }
262
263    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
264    pub fn as_struct_opt(&self) -> Option<&StructFields> {
265        if let Struct(f, _) = self {
266            Some(f)
267        } else {
268            None
269        }
270    }
271
272    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
273    pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
274        if let List(s, _) = self { Some(s) } else { None }
275    }
276
277    /// Convenience method for creating a [`DType::Struct`].
278    pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
279        iter: I,
280        nullability: Nullability,
281    ) -> Self {
282        Struct(StructFields::from_iter(iter), nullability)
283    }
284
285    /// Convenience method for creating a list dtype
286    pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
287        List(Arc::new(dtype.into()), nullability)
288    }
289}
290
291impl Display for DType {
292    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293        match self {
294            Null => write!(f, "null"),
295            Bool(null) => write!(f, "bool{null}"),
296            Primitive(pdt, null) => write!(f, "{pdt}{null}"),
297            Decimal(ddt, null) => write!(f, "{ddt}{null}"),
298            Utf8(null) => write!(f, "utf8{null}"),
299            Binary(null) => write!(f, "binary{null}"),
300            Struct(sf, null) => write!(
301                f,
302                "{{{}}}{null}",
303                sf.names()
304                    .iter()
305                    .zip(sf.fields())
306                    .map(|(field_null, dt)| format!("{field_null}={dt}"))
307                    .join(", "),
308            ),
309            List(edt, null) => write!(f, "list({edt}){null}"),
310            Extension(ext) => write!(
311                f,
312                "ext({}, {}{}){}",
313                ext.id(),
314                ext.storage_dtype()
315                    .with_nullability(Nullability::NonNullable),
316                ext.metadata()
317                    .map(|m| format!(", {m:?}"))
318                    .unwrap_or_else(|| "".to_string()),
319                ext.storage_dtype().nullability(),
320            ),
321        }
322    }
323}