vortex_dtype/
dtype.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8
9use crate::decimal::DecimalDType;
10use crate::nullability::Nullability;
11use crate::{ExtDType, PType, StructDType};
12
/// Name of a single field within a struct dtype.
pub type FieldName = std::sync::Arc<str>;

/// Ordered sequence of [`FieldName`]s, one entry per struct field.
pub type FieldNames = std::sync::Arc<[FieldName]>;
17
18/// The logical types of elements in Vortex arrays.
19///
20/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
21/// physical ways to encode that type.
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24pub enum DType {
25    /// The logical null type (only has a single value, `null`)
26    Null,
27    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
28    Bool(Nullability),
29    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
30    Primitive(PType, Nullability),
31    /// Real numbers with fixed exact precision and scale.
32    Decimal(DecimalDType, Nullability),
33    /// UTF-8 strings
34    Utf8(Nullability),
35    /// Binary data
36    Binary(Nullability),
37    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
38    Struct(Arc<StructDType>, Nullability),
39    /// A variable-length list type, parameterized by a single element DType
40    List(Arc<DType>, Nullability),
41    /// User-defined extension types
42    Extension(Arc<ExtDType>),
43}
44
45#[cfg(not(target_arch = "wasm32"))]
46const_assert_eq!(size_of::<DType>(), 16);
47
48#[cfg(target_arch = "wasm32")]
49const_assert_eq!(size_of::<DType>(), 8);
50
51impl DType {
52    /// The default DType for bytes
53    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
54
55    /// Get the nullability of the DType
56    pub fn nullability(&self) -> Nullability {
57        self.is_nullable().into()
58    }
59
60    /// Check if the DType is nullable
61    pub fn is_nullable(&self) -> bool {
62        use crate::nullability::Nullability::*;
63
64        match self {
65            Null => true,
66            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
67            Bool(n)
68            | Primitive(_, n)
69            | Decimal(_, n)
70            | Utf8(n)
71            | Binary(n)
72            | Struct(_, n)
73            | List(_, n) => matches!(n, Nullable),
74        }
75    }
76
77    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
78    pub fn as_nonnullable(&self) -> Self {
79        self.with_nullability(Nullability::NonNullable)
80    }
81
82    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
83    pub fn as_nullable(&self) -> Self {
84        self.with_nullability(Nullability::Nullable)
85    }
86
87    /// Get a new DType with the given nullability (but otherwise the same as `self`)
88    pub fn with_nullability(&self, nullability: Nullability) -> Self {
89        match self {
90            Null => Null,
91            Bool(_) => Bool(nullability),
92            Primitive(p, _) => Primitive(*p, nullability),
93            Decimal(d, _) => Decimal(*d, nullability),
94            Utf8(_) => Utf8(nullability),
95            Binary(_) => Binary(nullability),
96            Struct(st, _) => Struct(st.clone(), nullability),
97            List(c, _) => List(c.clone(), nullability),
98            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
99        }
100    }
101
102    /// Check if `self` and `other` are equal, ignoring nullability
103    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
104        match (self, other) {
105            (Null, Null) => true,
106            (Bool(_), Bool(_)) => true,
107            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
108            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
109            (Utf8(_), Utf8(_)) => true,
110            (Binary(_), Binary(_)) => true,
111            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
112            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
113                (lhs_dtype.names() == rhs_dtype.names())
114                    && (lhs_dtype
115                        .fields()
116                        .zip_eq(rhs_dtype.fields())
117                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
118            }
119            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
120                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
121            }
122            _ => false,
123        }
124    }
125
126    /// Check if `self` is a `StructDType`
127    pub fn is_struct(&self) -> bool {
128        matches!(self, Struct(_, _))
129    }
130
131    /// Check if `self` is an unsigned integer
132    pub fn is_unsigned_int(&self) -> bool {
133        PType::try_from(self).is_ok_and(PType::is_unsigned_int)
134    }
135
136    /// Check if `self` is a signed integer
137    pub fn is_signed_int(&self) -> bool {
138        PType::try_from(self).is_ok_and(PType::is_signed_int)
139    }
140
141    /// Check if `self` is an integer (signed or unsigned)
142    pub fn is_int(&self) -> bool {
143        PType::try_from(self).is_ok_and(PType::is_int)
144    }
145
146    /// Check if `self` is a floating point number
147    pub fn is_float(&self) -> bool {
148        PType::try_from(self).is_ok_and(PType::is_float)
149    }
150
151    /// Check if `self` is a boolean
152    pub fn is_boolean(&self) -> bool {
153        matches!(self, Bool(_))
154    }
155
156    /// Check if `self` is an extension type
157    pub fn is_extension(&self) -> bool {
158        matches!(self, Extension(_))
159    }
160
161    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
162    pub fn as_struct(&self) -> Option<&Arc<StructDType>> {
163        match self {
164            Struct(s, _) => Some(s),
165            _ => None,
166        }
167    }
168
169    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
170    pub fn as_list_element(&self) -> Option<&DType> {
171        match self {
172            List(s, _) => Some(s.as_ref()),
173            _ => None,
174        }
175    }
176}
177
178impl Display for DType {
179    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
180        match self {
181            Null => write!(f, "null"),
182            Bool(n) => write!(f, "bool{}", n),
183            Primitive(pt, n) => write!(f, "{}{}", pt, n),
184            Decimal(dt, n) => write!(f, "decimal({},{}){}", dt.precision(), dt.scale(), n),
185            Utf8(n) => write!(f, "utf8{}", n),
186            Binary(n) => write!(f, "binary{}", n),
187            Struct(sdt, n) => write!(
188                f,
189                "{{{}}}{}",
190                sdt.names()
191                    .iter()
192                    .zip(sdt.fields())
193                    .map(|(n, dt)| format!("{}={}", n, dt))
194                    .join(", "),
195                n
196            ),
197            List(edt, n) => write!(f, "list({}){}", edt, n),
198            Extension(ext) => write!(
199                f,
200                "ext({}, {}{}){}",
201                ext.id(),
202                ext.storage_dtype()
203                    .with_nullability(Nullability::NonNullable),
204                ext.metadata()
205                    .map(|m| format!(", {:?}", m))
206                    .unwrap_or_else(|| "".to_string()),
207                ext.storage_dtype().nullability(),
208            ),
209        }
210    }
211}