// vortex_dtype/dtype.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8use vortex_error::vortex_panic;
9
10use crate::decimal::DecimalDType;
11use crate::nullability::Nullability;
12use crate::{ExtDType, PType, StructFields};
13
/// The name of a single field within a struct dtype (a shared, immutable string).
pub type FieldName = Arc<str>;
/// The names of all fields of a struct dtype, in field order (a shared, immutable slice).
pub type FieldNames = Arc<[FieldName]>;
18
19/// The logical types of elements in Vortex arrays.
20///
21/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
22/// physical ways to encode that type.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
25pub enum DType {
26    /// The logical null type (only has a single value, `null`)
27    Null,
28    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
29    Bool(Nullability),
30    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
31    Primitive(PType, Nullability),
32    /// Real numbers with fixed exact precision and scale.
33    Decimal(DecimalDType, Nullability),
34    /// UTF-8 strings
35    Utf8(Nullability),
36    /// Binary data
37    Binary(Nullability),
38    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
39    Struct(Arc<StructFields>, Nullability),
40    /// A variable-length list type, parameterized by a single element DType
41    List(Arc<DType>, Nullability),
42    /// User-defined extension types
43    Extension(Arc<ExtDType>),
44}
45
46#[cfg(not(target_arch = "wasm32"))]
47const_assert_eq!(size_of::<DType>(), 16);
48
49#[cfg(target_arch = "wasm32")]
50const_assert_eq!(size_of::<DType>(), 8);
51
52impl DType {
53    /// The default DType for bytes
54    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
55
56    /// Get the nullability of the DType
57    pub fn nullability(&self) -> Nullability {
58        self.is_nullable().into()
59    }
60
61    /// Check if the DType is nullable
62    pub fn is_nullable(&self) -> bool {
63        use crate::nullability::Nullability::*;
64
65        match self {
66            Null => true,
67            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
68            Bool(n)
69            | Primitive(_, n)
70            | Decimal(_, n)
71            | Utf8(n)
72            | Binary(n)
73            | Struct(_, n)
74            | List(_, n) => matches!(n, Nullable),
75        }
76    }
77
78    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
79    pub fn as_nonnullable(&self) -> Self {
80        self.with_nullability(Nullability::NonNullable)
81    }
82
83    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
84    pub fn as_nullable(&self) -> Self {
85        self.with_nullability(Nullability::Nullable)
86    }
87
88    /// Get a new DType with the given nullability (but otherwise the same as `self`)
89    pub fn with_nullability(&self, nullability: Nullability) -> Self {
90        match self {
91            Null => Null,
92            Bool(_) => Bool(nullability),
93            Primitive(p, _) => Primitive(*p, nullability),
94            Decimal(d, _) => Decimal(*d, nullability),
95            Utf8(_) => Utf8(nullability),
96            Binary(_) => Binary(nullability),
97            Struct(st, _) => Struct(st.clone(), nullability),
98            List(c, _) => List(c.clone(), nullability),
99            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
100        }
101    }
102
103    /// Union the nullability of this dtype with the other nullability, returning a new dtype.
104    pub fn union_nullability(&self, other: Nullability) -> Self {
105        let nullability = self.nullability() | other;
106        self.with_nullability(nullability)
107    }
108
109    /// Check if `self` and `other` are equal, ignoring nullability
110    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
111        match (self, other) {
112            (Null, Null) => true,
113            (Bool(_), Bool(_)) => true,
114            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
115            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
116            (Utf8(_), Utf8(_)) => true,
117            (Binary(_), Binary(_)) => true,
118            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
119            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
120                (lhs_dtype.names() == rhs_dtype.names())
121                    && (lhs_dtype
122                        .fields()
123                        .zip_eq(rhs_dtype.fields())
124                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
125            }
126            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
127                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
128            }
129            _ => false,
130        }
131    }
132
133    /// Check if `self` is a `StructDType`
134    pub fn is_struct(&self) -> bool {
135        matches!(self, Struct(_, _))
136    }
137
138    /// Check if `self` is a primitive tpye
139    pub fn is_primitive(&self) -> bool {
140        matches!(self, Primitive(_, _))
141    }
142
143    /// Returns this DType's `PType` if it is a primitive type, otherwise panics.
144    pub fn as_ptype(&self) -> PType {
145        match self {
146            Primitive(ptype, _) => *ptype,
147            _ => vortex_panic!("DType is not a primitive type"),
148        }
149    }
150
151    /// Check if `self` is an unsigned integer
152    pub fn is_unsigned_int(&self) -> bool {
153        if let Primitive(ptype, _) = self {
154            return ptype.is_unsigned_int();
155        }
156        false
157    }
158
159    /// Check if `self` is a signed integer
160    pub fn is_signed_int(&self) -> bool {
161        if let Primitive(ptype, _) = self {
162            return ptype.is_signed_int();
163        }
164        false
165    }
166
167    /// Check if `self` is an integer (signed or unsigned)
168    pub fn is_int(&self) -> bool {
169        if let Primitive(ptype, _) = self {
170            return ptype.is_int();
171        }
172        false
173    }
174
175    /// Check if `self` is a floating point number
176    pub fn is_float(&self) -> bool {
177        if let Primitive(ptype, _) = self {
178            return ptype.is_float();
179        }
180        false
181    }
182
183    /// Check if `self` is a boolean
184    pub fn is_boolean(&self) -> bool {
185        matches!(self, Bool(_))
186    }
187
188    /// Check if `self` is a binary
189    pub fn is_binary(&self) -> bool {
190        matches!(self, Binary(_))
191    }
192
193    /// Check if `self` is a utf8
194    pub fn is_utf8(&self) -> bool {
195        matches!(self, Utf8(_))
196    }
197
198    /// Check if `self` is an extension type
199    pub fn is_extension(&self) -> bool {
200        matches!(self, Extension(_))
201    }
202
203    /// Check if `self` is a decimal type
204    pub fn is_decimal(&self) -> bool {
205        matches!(self, Decimal(..))
206    }
207
208    /// Check returns the inner decimal type if the dtype is a decimal
209    pub fn as_decimal(&self) -> Option<&DecimalDType> {
210        match self {
211            Decimal(decimal, _) => Some(decimal),
212            _ => None,
213        }
214    }
215
216    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
217    pub fn as_struct(&self) -> Option<&Arc<StructFields>> {
218        match self {
219            Struct(s, _) => Some(s),
220            _ => None,
221        }
222    }
223
224    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
225    pub fn as_list_element(&self) -> Option<&Arc<DType>> {
226        match self {
227            List(s, _) => Some(s),
228            _ => None,
229        }
230    }
231}
232
233impl Display for DType {
234    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235        match self {
236            Null => write!(f, "null"),
237            Bool(n) => write!(f, "bool{n}"),
238            Primitive(pt, n) => write!(f, "{pt}{n}"),
239            Decimal(dt, n) => write!(f, "{dt}{n}"),
240            Utf8(n) => write!(f, "utf8{n}"),
241            Binary(n) => write!(f, "binary{n}"),
242            Struct(sdt, n) => write!(
243                f,
244                "{{{}}}{}",
245                sdt.names()
246                    .iter()
247                    .zip(sdt.fields())
248                    .map(|(n, dt)| format!("{n}={dt}"))
249                    .join(", "),
250                n
251            ),
252            List(edt, n) => write!(f, "list({edt}){n}"),
253            Extension(ext) => write!(
254                f,
255                "ext({}, {}{}){}",
256                ext.id(),
257                ext.storage_dtype()
258                    .with_nullability(Nullability::NonNullable),
259                ext.metadata()
260                    .map(|m| format!(", {m:?}"))
261                    .unwrap_or_else(|| "".to_string()),
262                ext.storage_dtype().nullability(),
263            ),
264        }
265    }
266}