// vortex_dtype/dtype.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8
9use crate::nullability::Nullability;
10use crate::{ExtDType, PType, StructDType};
11
/// The name of a single struct field (a cheaply clonable, shared string).
pub type FieldName = Arc<str>;
/// The names of a struct's fields, in field order.
pub type FieldNames = Arc<[FieldName]>;
16
17/// The logical types of elements in Vortex arrays.
18///
19/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
20/// physical ways to encode that type.
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
23pub enum DType {
24    /// The logical null type (only has a single value, `null`)
25    Null,
26    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
27    Bool(Nullability),
28    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
29    Primitive(PType, Nullability),
30    /// UTF-8 strings
31    Utf8(Nullability),
32    /// Binary data
33    Binary(Nullability),
34    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
35    Struct(Arc<StructDType>, Nullability),
36    /// A variable-length list type, parameterized by a single element DType
37    List(Arc<DType>, Nullability),
38    /// User-defined extension types
39    Extension(Arc<ExtDType>),
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43const_assert_eq!(size_of::<DType>(), 16);
44
45#[cfg(target_arch = "wasm32")]
46const_assert_eq!(size_of::<DType>(), 8);
47
48impl DType {
49    /// The default DType for bytes
50    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
51
52    /// Get the nullability of the DType
53    pub fn nullability(&self) -> Nullability {
54        self.is_nullable().into()
55    }
56
57    /// Check if the DType is nullable
58    pub fn is_nullable(&self) -> bool {
59        use crate::nullability::Nullability::*;
60
61        match self {
62            Null => true,
63            Bool(n) => matches!(n, Nullable),
64            Primitive(_, n) => matches!(n, Nullable),
65            Utf8(n) => matches!(n, Nullable),
66            Binary(n) => matches!(n, Nullable),
67            Struct(_, n) => matches!(n, Nullable),
68            List(_, n) => matches!(n, Nullable),
69            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
70        }
71    }
72
73    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
74    pub fn as_nonnullable(&self) -> Self {
75        self.with_nullability(Nullability::NonNullable)
76    }
77
78    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
79    pub fn as_nullable(&self) -> Self {
80        self.with_nullability(Nullability::Nullable)
81    }
82
83    /// Get a new DType with the given nullability (but otherwise the same as `self`)
84    pub fn with_nullability(&self, nullability: Nullability) -> Self {
85        match self {
86            Null => Null,
87            Bool(_) => Bool(nullability),
88            Primitive(p, _) => Primitive(*p, nullability),
89            Utf8(_) => Utf8(nullability),
90            Binary(_) => Binary(nullability),
91            Struct(st, _) => Struct(st.clone(), nullability),
92            List(c, _) => List(c.clone(), nullability),
93            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
94        }
95    }
96
97    /// Check if `self` and `other` are equal, ignoring nullability
98    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
99        match (self, other) {
100            (Null, Null) => true,
101            (Null, _) => false,
102            (Bool(_), Bool(_)) => true,
103            (Bool(_), _) => false,
104            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
105            (Primitive(..), _) => false,
106            (Utf8(_), Utf8(_)) => true,
107            (Utf8(_), _) => false,
108            (Binary(_), Binary(_)) => true,
109            (Binary(_), _) => false,
110            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
111            (List(..), _) => false,
112            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
113                (lhs_dtype.names() == rhs_dtype.names())
114                    && (lhs_dtype
115                        .fields()
116                        .zip_eq(rhs_dtype.fields())
117                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
118            }
119            (Struct(..), _) => false,
120            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
121                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
122            }
123            (Extension(_), _) => false,
124        }
125    }
126
127    /// Check if `self` is a `StructDType`
128    pub fn is_struct(&self) -> bool {
129        matches!(self, Struct(_, _))
130    }
131
132    /// Check if `self` is an unsigned integer
133    pub fn is_unsigned_int(&self) -> bool {
134        PType::try_from(self).is_ok_and(PType::is_unsigned_int)
135    }
136
137    /// Check if `self` is a signed integer
138    pub fn is_signed_int(&self) -> bool {
139        PType::try_from(self).is_ok_and(PType::is_signed_int)
140    }
141
142    /// Check if `self` is an integer (signed or unsigned)
143    pub fn is_int(&self) -> bool {
144        PType::try_from(self).is_ok_and(PType::is_int)
145    }
146
147    /// Check if `self` is a floating point number
148    pub fn is_float(&self) -> bool {
149        PType::try_from(self).is_ok_and(PType::is_float)
150    }
151
152    /// Check if `self` is a boolean
153    pub fn is_boolean(&self) -> bool {
154        matches!(self, Bool(_))
155    }
156
157    /// Check if `self` is an extension type
158    pub fn is_extension(&self) -> bool {
159        matches!(self, Extension(_))
160    }
161
162    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
163    pub fn as_struct(&self) -> Option<&Arc<StructDType>> {
164        match self {
165            Struct(s, _) => Some(s),
166            _ => None,
167        }
168    }
169
170    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
171    pub fn as_list_element(&self) -> Option<&DType> {
172        match self {
173            List(s, _) => Some(s.as_ref()),
174            _ => None,
175        }
176    }
177}
178
179impl Display for DType {
180    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
181        match self {
182            Null => write!(f, "null"),
183            Bool(n) => write!(f, "bool{}", n),
184            Primitive(pt, n) => write!(f, "{}{}", pt, n),
185            Utf8(n) => write!(f, "utf8{}", n),
186            Binary(n) => write!(f, "binary{}", n),
187            Struct(sdt, n) => write!(
188                f,
189                "{{{}}}{}",
190                sdt.names()
191                    .iter()
192                    .zip(sdt.fields())
193                    .map(|(n, dt)| format!("{}={}", n, dt))
194                    .join(", "),
195                n
196            ),
197            List(edt, n) => write!(f, "list({}){}", edt, n),
198            Extension(ext) => write!(
199                f,
200                "ext({}, {}{}){}",
201                ext.id(),
202                ext.storage_dtype()
203                    .with_nullability(Nullability::NonNullable),
204                ext.metadata()
205                    .map(|m| format!(", {:?}", m))
206                    .unwrap_or_else(|| "".to_string()),
207                ext.storage_dtype().nullability(),
208            ),
209        }
210    }
211}